#[cfg(feature = "dep_arrow")]
use arrow::{
array::Array, array::ArrayRef, datatypes::DataType, datatypes::Schema,
record_batch::RecordBatch,
};
#[cfg(feature = "dep_arrow")]
use arrow_convert::{field::ArrowField, serialize::TryIntoArrow};
use init_data::InitData;
#[cfg(feature = "dep_arrow")]
use rayon::prelude::*;
use crate::cli::get_matching_files;
use crate::details::{PlayerLobbyDetails, PlayerLobbyDetailsFlatRow};
use crate::game_events::VersionedBalanceUnit;
use crate::tracker_events;
use crate::*;
use clap::Subcommand;
use std::path::PathBuf;
pub mod ipc_writer;
use ipc_writer::*;
// Kinds of Arrow IPC exports; derived as a clap `Subcommand`.
// NOTE(review): plain `//` comments are used on purpose here; `///` doc
// comments on a clap-derived type may be picked up as CLI help text, so
// confirm before converting them.
#[derive(Debug, Subcommand, Clone)]
pub enum ArrowIpcTypes {
// Schema is taken from `init_data::UserInitDataFlatRow`.
UserInitData,
// Schema is taken from `details::PlayerLobbyDetailsFlatRow`.
Details,
// Schema is taken from `tracker_events::PlayerStatsFlatRow`.
Stats,
// Schema is taken from `tracker_events::UpgradeEventFlatRow`.
Upgrades,
// Schema is taken from `tracker_events::UnitBornEventFlatRow`.
UnitBorn,
// Schema is taken from `tracker_events::UnitDiedEventFlatRow`.
UnitDied,
// No schema is defined for this variant in `schema()` (it hits `unimplemented!()`).
MessageEvents,
// Schema is taken from `game_events::CmdTargetPointEventFlatRow`.
CmdTargetPoint,
// Schema is taken from `game_events::CmdTargetUnitEventFlatRow`.
CmdTargetUnit,
// No schema is defined for this variant in `schema()` (it hits `unimplemented!()`).
// Presumably selects the full snapshot written by `handle_write_snapshot` — TODO confirm.
All,
}
impl ArrowIpcTypes {
pub fn schema(&self) -> Schema {
match self {
Self::UserInitData => {
if let DataType::Struct(fields) = init_data::UserInitDataFlatRow::data_type() {
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::Details => {
if let DataType::Struct(fields) = details::PlayerLobbyDetailsFlatRow::data_type() {
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::Stats => {
if let DataType::Struct(fields) = tracker_events::PlayerStatsFlatRow::data_type() {
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::Upgrades => {
if let DataType::Struct(fields) = tracker_events::UpgradeEventFlatRow::data_type() {
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::UnitBorn => {
if let DataType::Struct(fields) = tracker_events::UnitBornEventFlatRow::data_type()
{
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::UnitDied => {
if let DataType::Struct(fields) = tracker_events::UnitDiedEventFlatRow::data_type()
{
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::CmdTargetPoint => {
if let DataType::Struct(fields) =
game_events::CmdTargetPointEventFlatRow::data_type()
{
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
Self::CmdTargetUnit => {
if let DataType::Struct(fields) =
game_events::CmdTargetUnitEventFlatRow::data_type()
{
Schema::new(fields.clone())
} else {
panic!("Invalid schema, expected struct");
}
}
_ => unimplemented!(),
}
}
/// Writes every supported export for `sources` into the directory `output`.
///
/// The files are produced in this order: `details.ipc`, `stats.ipc`,
/// `upgrades.ipc`, `unit_born.ipc`, `unit_died.ipc`, `cmd_target_point.ipc`
/// and `cmd_target_unit.ipc`. The first failing export aborts the rest.
///
/// * `sources` - the already loaded replays; cloned for each export because
///   the individual handlers take the vector by value.
/// * `output` - an existing directory that receives the `.ipc` files.
/// * `unit_abilities` - forwarded unchanged to the event handlers.
/// * `serially` - forwarded to the event handlers to select sequential or
///   parallel processing.
///
/// # Errors
///
/// Returns an error when `output` is not a directory, or when any of the
/// individual export handlers fails.
#[tracing::instrument(level = "debug")]
pub fn handle_write_snapshot(
    sources: Vec<InitData>,
    output: PathBuf,
    unit_abilities: &HashMap<(u32, String), VersionedBalanceUnit>,
    serially: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    // A bad output path is a recoverable user error, so report it through
    // the `Result` instead of aborting the process.
    if !output.is_dir() {
        return Err("Output must be a directory for types 'all'".into());
    }
    Self::Details.handle_details_ipc_cmd(sources.clone(), output.join("details.ipc"))?;
    // Tracker-event based exports, in the order they are written.
    for (kind, file_name) in [
        (Self::Stats, "stats.ipc"),
        (Self::Upgrades, "upgrades.ipc"),
        (Self::UnitBorn, "unit_born.ipc"),
        (Self::UnitDied, "unit_died.ipc"),
    ] {
        kind.handle_tracker_events(
            sources.clone(),
            output.join(file_name),
            unit_abilities,
            serially,
        )?;
    }
    Self::CmdTargetPoint.handle_game_events(
        sources.clone(),
        output.join("cmd_target_point.ipc"),
        unit_abilities,
        serially,
    )?;
    // Last consumer: hand over the vector itself instead of one more copy.
    Self::CmdTargetUnit.handle_game_events(
        sources,
        output.join("cmd_target_unit.ipc"),
        unit_abilities,
        serially,
    )?;
    Ok(())
}
/// Exports the tracker-event rows selected by `self` from every source into
/// a single Arrow IPC file at `output`.
///
/// Handles the `Stats`, `Upgrades`, `UnitBorn` and `UnitDied` variants.
/// A source is skipped without any message when its event iterator cannot be
/// created, when its rows cannot be converted into an Arrow array, or when
/// the write helper yields `None`.
///
/// * `versioned_abilities` - cloned once per source for the event iterator.
/// * `serially` - `true` walks the sources with a plain iterator, `false`
///   uses rayon's `par_iter`.
///
/// # Errors
///
/// Fails when the writer cannot be opened or closed.
///
/// # Panics
///
/// Any other variant hits `unimplemented!()` as soon as a source's event
/// iterator has been created; `self.schema()` may panic as well.
#[tracing::instrument(level = "debug")]
pub fn handle_tracker_events(
    &self,
    sources: Vec<InitData>,
    output: PathBuf,
    versioned_abilities: &HashMap<(u32, String), VersionedBalanceUnit>,
    serially: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    tracing::info!("Processing TrackerEvents IPC write request: {:?}", self);
    let writer = open_arrow_mutex_writer(output, self.schema())?;
    // Per-source work, shared by the sequential and the parallel path.
    let export_source = |source: &InitData| {
        let events = SC2EventIterator::new(source, versioned_abilities.clone()).ok()?;
        let (array, row_count): (ArrayRef, usize) = match self {
            Self::Stats => {
                let rows = events.collect_into_player_stats_flat_rows();
                (rows.try_into_arrow().ok()?, rows.len())
            }
            Self::Upgrades => {
                let rows = events.collect_into_upgrades_flat_rows();
                (rows.try_into_arrow().ok()?, rows.len())
            }
            Self::UnitBorn => {
                let rows = events.collect_into_unit_born_flat_rows();
                (rows.try_into_arrow().ok()?, rows.len())
            }
            Self::UnitDied => {
                let rows = events.collect_into_unit_died_flat_rows();
                (rows.try_into_arrow().ok()?, rows.len())
            }
            _ => unimplemented!(),
        };
        write_to_arrow_mutex_writer(&writer, array, row_count)
    };
    // Only one of the two branches runs, so the closure can be handed to either.
    let total_records = if serially {
        sources.iter().filter_map(export_source).sum::<usize>()
    } else {
        sources.par_iter().filter_map(export_source).sum::<usize>()
    };
    tracing::info!("Loaded {} records", total_records);
    close_arrow_mutex_writer(writer)
}
/// Exports the game-event rows selected by `self` from every source into a
/// single Arrow IPC file at `output`.
///
/// Handles the `CmdTargetPoint` and `CmdTargetUnit` variants. A source is
/// skipped without any message when its event iterator cannot be created,
/// when its rows cannot be converted into an Arrow array, or when the write
/// helper yields `None`.
///
/// * `versioned_abilities` - cloned once per source for the event iterator.
/// * `serially` - `true` walks the sources with a plain iterator, `false`
///   uses rayon's `par_iter`.
///
/// # Errors
///
/// Fails when the writer cannot be opened or closed.
///
/// # Panics
///
/// Any other variant hits `unimplemented!` (printing the variant) as soon as
/// a source's event iterator has been created; `self.schema()` may panic as
/// well.
#[tracing::instrument(level = "debug")]
pub fn handle_game_events(
    &self,
    sources: Vec<InitData>,
    output: PathBuf,
    versioned_abilities: &HashMap<(u32, String), VersionedBalanceUnit>,
    serially: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    tracing::info!("Processing GameEvents IPC write request: {:?}", self);
    let writer = open_arrow_mutex_writer(output, self.schema())?;
    // Per-source work, shared by the sequential and the parallel path.
    let export_source = |source: &InitData| {
        let events = SC2EventIterator::new(source, versioned_abilities.clone()).ok()?;
        let (array, row_count): (ArrayRef, usize) = match self {
            Self::CmdTargetPoint => {
                let rows = events.collect_into_game_cmd_target_points_flat_rows();
                (rows.try_into_arrow().ok()?, rows.len())
            }
            Self::CmdTargetUnit => {
                let rows = events.collect_into_game_cmd_target_units_flat_rows();
                (rows.try_into_arrow().ok()?, rows.len())
            }
            e => unimplemented!("{:?}", e),
        };
        write_to_arrow_mutex_writer(&writer, array, row_count)
    };
    // Only one of the two branches runs, so the closure can be handed to either.
    let total_records = if serially {
        sources.iter().filter_map(export_source).sum::<usize>()
    } else {
        sources.par_iter().filter_map(export_source).sum::<usize>()
    };
    tracing::info!("Loaded {} records", total_records);
    close_arrow_mutex_writer(writer)
}
#[tracing::instrument(level = "debug")]
pub fn handle_read_once_write_all(
&self,
sources: Vec<InitData>,
output: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Processing Read Once Write All IPC request");
let details_flaw_rows: Vec<PlayerLobbyDetailsFlatRow> = sources
.iter()
.flat_map(|source| {
let res: Vec<PlayerLobbyDetails> = match source.try_into() {
Ok(details) => details,
Err(err) => {
tracing::error!("Error reading details: {:?}", err);
return vec![];
}
};
res.into_iter()
.map(|d| d.into())
.collect::<Vec<PlayerLobbyDetailsFlatRow>>()
})
.collect();
let res: ArrayRef = details_flaw_rows.try_into_arrow()?;
let chunk: RecordBatch = res
.as_any()
.downcast_ref::<arrow::array::StructArray>()
.unwrap()
.into();
write_batches(output, Self::Details.schema(), chunk)?;
Ok(())
}
#[tracing::instrument(level = "debug")]
pub fn handle_details_ipc_cmd(
&self,
sources: Vec<InitData>,
output: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Processing Details IPC write request");
let details_flaw_rows: Vec<PlayerLobbyDetailsFlatRow> = sources
.iter()
.flat_map(|source| {
let res: Vec<PlayerLobbyDetails> = match source.try_into() {
Ok(details) => details,
Err(err) => {
tracing::error!("Error reading details: {:?}", err);
return vec![];
}
};
res.into_iter()
.map(|d| d.into())
.collect::<Vec<PlayerLobbyDetailsFlatRow>>()
})
.collect();
let res: ArrayRef = details_flaw_rows.try_into_arrow()?;
let chunk: RecordBatch = res
.as_any()
.downcast_ref::<arrow::array::StructArray>()
.unwrap()
.into();
write_batches(output, Self::Details.schema(), chunk)?;
Ok(())
}
/// Entry point of the Arrow IPC export: locates the files under `source`,
/// loads their init data, filters them and writes the snapshot into `output`.
///
/// Steps:
/// 1. Print the effective limits from `cmd`.
/// 2. Collect matching files, bounded by `cmd.scan_max_files` and
///    `cmd.traverse_max_depth`.
/// 3. Load an `InitData` per file, passing the file's position in the list
///    along with its path; files that fail to load are dropped silently.
/// 4. Keep only sources within `cmd.min_version..=cmd.max_version` (each
///    bound is optional) and at most `cmd.process_max_files` of them.
/// 5. Delegate to `handle_write_snapshot`.
///
/// * `serially` - `true` loads and processes the files sequentially,
///   `false` uses rayon.
///
/// # Errors
///
/// Returns an error when the file scan fails, when no source survives
/// loading and filtering ("No files found"), or when writing the snapshot
/// fails.
#[tracing::instrument(level = "debug")]
pub fn handle_arrow_ipc_cmd(
    source: PathBuf,
    output: PathBuf,
    cmd: &WriteArrowIpcProps,
    unit_abilities: &HashMap<(u32, String), VersionedBalanceUnit>,
    serially: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    // Argument order must follow the labels in the format string.
    println!(
        "Processing Arrow write request with scan_max_files: {}, traverse_max_depth: {}, process_max_files: {}, min_version: {:?}, max_version: {:?}",
        cmd.scan_max_files,
        cmd.traverse_max_depth,
        cmd.process_max_files,
        cmd.min_version,
        cmd.max_version
    );
    let sources = get_matching_files(source, cmd.scan_max_files, cmd.traverse_max_depth)?;
    println!("Located {} matching files by extension", sources.len());
    // `serially` selects the plain iterator, same as in the event handlers.
    let sources: Vec<InitData> = if serially {
        sources
            .iter()
            .enumerate()
            .filter_map(|(idx, source)| {
                InitData::try_from((source.clone(), u64::try_from(idx).unwrap())).ok()
            })
            .collect::<Vec<InitData>>()
    } else {
        sources
            .par_iter()
            .enumerate()
            .filter_map(|(idx, source)| {
                InitData::try_from((source.clone(), u64::try_from(idx).unwrap())).ok()
            })
            .collect::<Vec<InitData>>()
    };
    let sources: Vec<InitData> = sources
        .into_iter()
        .filter(|source| {
            if let Some(min_version) = cmd.min_version
                && source.version < min_version
            {
                return false;
            }
            if let Some(max_version) = cmd.max_version
                && source.version > max_version
            {
                return false;
            }
            true
        })
        .take(cmd.process_max_files)
        .collect();
    // Nothing to export is a user-facing condition, not a broken invariant.
    if sources.is_empty() {
        return Err("No files found".into());
    }
    println!(
        "{} files have valid init data, processing...",
        sources.len()
    );
    Self::handle_write_snapshot(sources, output, unit_abilities, serially)
}
}