use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::{self, Command, ExitCode};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use csv::WriterBuilder;
use rgwml::{
read_csv, AggregateExpr, AggregateOp, ColumnSelector, CompareOp, CsvReadOptions, GroupKey,
JoinKey, JoinOptions, JoinType, Literal, Predicate,
};
type DynError = Box<dyn std::error::Error>;
fn main() -> ExitCode {
match try_main() {
Ok(()) => ExitCode::SUCCESS,
Err(error) => {
eprintln!("error: {error}");
ExitCode::FAILURE
}
}
}
/// Dispatches on the first command-line argument.
///
/// * no arguments, `-h`, `--help`, `help` → print usage and succeed;
/// * `run <rows> ...` → forward everything after `run` to [`run_command`];
/// * anything else → treat the whole argument list as `run`'s arguments,
///   so `bench_ops <rows> ...` is shorthand for `bench_ops run <rows> ...`.
fn try_main() -> Result<(), DynError> {
    let argv: Vec<String> = env::args().skip(1).collect();
    match argv.first().map(String::as_str) {
        None | Some("-h" | "--help" | "help") => {
            print_usage();
            Ok(())
        }
        Some("run") => run_command(&argv[1..]),
        Some(_) => run_command(&argv),
    }
}
/// Writes the command-line synopsis to stderr.
fn print_usage() {
    // Built from explicit `\n` pieces so the emitted text does not depend on
    // source indentation.
    let usage = concat!(
        "usage:\n",
        "cargo run --bin bench_ops -- run <rows> [--python PYTHON] [--skip-pandas] [--keep-csv]\n",
        "cargo run --bin bench_ops -- <rows> [--python PYTHON] [--skip-pandas] [--keep-csv]"
    );
    eprintln!("{}", usage);
}
/// Options accepted by the `run` subcommand.
struct RunOptions {
    /// Number of rows written to every generated CSV.
    rows: usize,
    /// Interpreter used to launch the pandas comparison script.
    python: String,
    /// When true, the pandas comparison is not attempted.
    skip_pandas: bool,
    /// When true, the generated CSV files are left on disk after the run.
    keep_csv: bool,
}

/// Parses `<rows> [--python PYTHON] [--skip-pandas] [--keep-csv]`.
///
/// # Errors
/// Fails when `<rows>` is missing or not a `usize`, when `--python` has no
/// following value, or when an unrecognised argument is present.
fn parse_run_args(args: &[String]) -> Result<RunOptions, DynError> {
    let (rows_arg, flags) = args.split_first().ok_or("run requires <rows>")?;
    let rows = rows_arg
        .parse::<usize>()
        .map_err(|_| format!("invalid row count '{rows_arg}'"))?;
    let mut options = RunOptions {
        rows,
        python: String::from("python3"),
        skip_pandas: false,
        keep_csv: false,
    };
    let mut flags = flags.iter();
    while let Some(flag) = flags.next() {
        match flag.as_str() {
            "--python" => {
                // The next argument is consumed verbatim as the interpreter path.
                options.python = flags
                    .next()
                    .ok_or("--python requires an interpreter path")?
                    .clone();
            }
            "--skip-pandas" => options.skip_pandas = true,
            "--keep-csv" => options.keep_csv = true,
            other => return Err(format!("unknown argument '{other}'").into()),
        }
    }
    Ok(options)
}

/// Generates the input CSVs, runs both engines, and prints the summary.
fn run_benchmarks(
    options: &RunOptions,
    base_path: &Path,
    left_path: &Path,
    right_path: &Path,
) -> Result<(), DynError> {
    generate_base_csv(options.rows, base_path)?;
    generate_join_csvs(options.rows, left_path, right_path)?;
    let v2 = run_v2_ops(options.rows, base_path, left_path, right_path)?;
    let pandas = if options.skip_pandas {
        PandasOpsOutcome::Skipped(String::from("disabled by --skip-pandas"))
    } else {
        run_pandas_ops(&options.python, base_path, left_path, right_path)
    };
    print_summary(options.rows, &v2, &pandas);
    Ok(())
}

/// Implements the `run` subcommand: parse flags, benchmark, then clean up.
///
/// The temporary CSVs are removed (unless `--keep-csv`) whether or not the
/// benchmark succeeded, so a failing run no longer leaks files into the
/// temp directory. With `--keep-csv` the kept paths are reported on stderr,
/// since they embed the pid and a timestamp and are otherwise hard to find.
///
/// # Errors
/// Propagates argument-parsing, CSV generation, and rgwml errors.
fn run_command(args: &[String]) -> Result<(), DynError> {
    let options = parse_run_args(args)?;
    let base_path = temp_csv_path("base", options.rows)?;
    let left_path = temp_csv_path("left", options.rows)?;
    let right_path = temp_csv_path("right", options.rows)?;

    let outcome = run_benchmarks(&options, &base_path, &left_path, &right_path);

    let paths = [&base_path, &left_path, &right_path];
    if options.keep_csv {
        for path in &paths {
            eprintln!("kept {}", path.display());
        }
    } else {
        for path in &paths {
            // Best-effort: a file may not exist if generation failed early.
            let _ = fs::remove_file(path);
        }
    }
    outcome
}
/// Timings and result row counts for one engine's run of the filter,
/// group-by, and join benchmarks.
#[derive(Clone, Debug)]
struct OpsMeasurement {
// Engine label printed in the summary table ("v2" or "pandas" in this file).
mode: &'static str,
// Elapsed milliseconds for the filter step (for pandas, as reported by the
// external script — NOTE(review): confirm the script reports milliseconds).
filter_ms: u128,
// Number of rows that survived the filter.
filter_rows: usize,
// Elapsed milliseconds for the group-by step.
group_ms: u128,
// Number of groups produced by the group-by.
group_rows: usize,
// Elapsed milliseconds for the join step.
join_ms: u128,
// Number of rows produced by the join.
join_rows: usize,
}
/// Result of attempting the pandas comparison run.
enum PandasOpsOutcome {
// The script ran and its output parsed into a measurement.
Measured(OpsMeasurement),
// The comparison did not produce a measurement; the string explains why
// (disabled by flag, spawn failure, non-zero exit, or unparsable output).
Skipped(String),
}
fn run_v2_ops(
rows: usize,
base_path: &Path,
left_path: &Path,
right_path: &Path,
) -> Result<OpsMeasurement, DynError> {
let _ = rows;
let base = read_csv(base_path, &CsvReadOptions::default())?;
let left = read_csv(left_path, &CsvReadOptions::default())?;
let right = read_csv(right_path, &CsvReadOptions::default())?;
let filter_start = Instant::now();
let filtered = base
.filter(&Predicate::And(vec![
Predicate::Comparison {
column: ColumnSelector::from("active"),
op: CompareOp::Eq,
value: Some(Literal::Bool(true)),
},
Predicate::Comparison {
column: ColumnSelector::from("revenue"),
op: CompareOp::Gt,
value: Some(Literal::F64(5_000.0)),
},
Predicate::Comparison {
column: ColumnSelector::from("segment"),
op: CompareOp::StartsWith,
value: Some(Literal::from("segment_0")),
},
]))?
.materialize()?;
let filter_ms = filter_start.elapsed().as_millis();
let group_start = Instant::now();
let grouped = base.group_by(
&[GroupKey {
column: ColumnSelector::from("segment"),
}],
&[
AggregateExpr {
input: None,
op: AggregateOp::CountRows,
alias: "rows".into(),
},
AggregateExpr {
input: Some(ColumnSelector::from("revenue")),
op: AggregateOp::Sum,
alias: "revenue_sum".into(),
},
AggregateExpr {
input: Some(ColumnSelector::from("revenue")),
op: AggregateOp::Mean,
alias: "revenue_mean".into(),
},
],
)?;
let group_ms = group_start.elapsed().as_millis();
let join_start = Instant::now();
let joined = left.join(
&right,
&JoinOptions {
join_type: JoinType::Inner,
keys: vec![JoinKey {
left: ColumnSelector::from("id"),
right: ColumnSelector::from("id"),
}],
..JoinOptions::default()
},
)?;
let join_ms = join_start.elapsed().as_millis();
Ok(OpsMeasurement {
mode: "v2",
filter_ms,
filter_rows: filtered.nrows(),
group_ms,
group_rows: grouped.nrows(),
join_ms,
join_rows: joined.nrows(),
})
}
/// Runs `scripts/bench_pandas_ops.py` (resolved against the crate manifest
/// directory at compile time) with `python`, passing the three CSV paths.
///
/// Never fails: spawn errors, non-zero exits (reported via trimmed stderr, or
/// the exit status when stderr is empty), and unparsable stdout all become
/// [`PandasOpsOutcome::Skipped`] with a reason string.
fn run_pandas_ops(
    python: &str,
    base_path: &Path,
    left_path: &Path,
    right_path: &Path,
) -> PandasOpsOutcome {
    let script = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("scripts/bench_pandas_ops.py");
    let invocation = Command::new(python)
        .arg(&script)
        .args([base_path, left_path, right_path])
        .output();
    let output = match invocation {
        Err(spawn_error) => return PandasOpsOutcome::Skipped(spawn_error.to_string()),
        Ok(output) => output,
    };

    if output.status.success() {
        return parse_pandas_measurement(&output.stdout).map_or_else(
            |parse_error| PandasOpsOutcome::Skipped(parse_error.to_string()),
            PandasOpsOutcome::Measured,
        );
    }

    let stderr_text = String::from_utf8_lossy(&output.stderr);
    let stderr_text = stderr_text.trim();
    let reason = if stderr_text.is_empty() {
        format!("python exited with status {}", output.status)
    } else {
        stderr_text.to_owned()
    };
    PandasOpsOutcome::Skipped(reason)
}
/// Parses the pandas script's stdout into an [`OpsMeasurement`].
///
/// Expects `key=value` lines; lines without `=` are ignored. Keys and values
/// are trimmed. When a key repeats, the last occurrence wins.
///
/// # Errors
/// Fails when stdout is not UTF-8, when a required field is missing, or when
/// a field is not a valid non-negative integer. Parse errors name the field
/// and its raw value so a malformed script output can be diagnosed.
fn parse_pandas_measurement(output: &[u8]) -> Result<OpsMeasurement, DynError> {
    /// Looks up `key` and parses it, attaching the field name on failure.
    fn parse_field<T>(
        values: &std::collections::HashMap<String, String>,
        key: &str,
    ) -> Result<T, DynError>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display,
    {
        let raw = required_ops_value(values, key)?;
        raw.parse::<T>()
            .map_err(|error| format!("invalid benchmark field '{key}' = '{raw}': {error}").into())
    }

    // Borrow-decode the bytes instead of copying them into a new Vec first.
    let text = std::str::from_utf8(output)?;
    let values: std::collections::HashMap<String, String> = text
        .lines()
        .filter_map(|line| line.split_once('='))
        .map(|(key, value)| (key.trim().to_string(), value.trim().to_string()))
        .collect();

    Ok(OpsMeasurement {
        mode: "pandas",
        filter_ms: parse_field(&values, "filter_ms")?,
        filter_rows: parse_field(&values, "filter_rows")?,
        group_ms: parse_field(&values, "group_ms")?,
        group_rows: parse_field(&values, "group_rows")?,
        join_ms: parse_field(&values, "join_ms")?,
        join_rows: parse_field(&values, "join_rows")?,
    })
}
/// Returns the value stored under `key`, borrowed from `values`.
///
/// # Errors
/// Returns `missing benchmark field '<key>'` when the key is absent.
fn required_ops_value<'a>(
    values: &'a std::collections::HashMap<String, String>,
    key: &str,
) -> Result<&'a str, DynError> {
    match values.get(key) {
        Some(found) => Ok(found.as_str()),
        None => Err(format!("missing benchmark field '{key}'").into()),
    }
}
/// Prints the timing table and the v2 result row counts to stdout.
///
/// When pandas was measured, its row counts are cross-checked against v2.
/// Any disagreement is reported as a warning on stderr, because timings are
/// only comparable when both engines computed the same result. Stdout output
/// is unchanged.
fn print_summary(rows: usize, v2: &OpsMeasurement, pandas: &PandasOpsOutcome) {
    /// Prints one aligned row of the timing table.
    fn print_timing_row(measurement: &OpsMeasurement) {
        println!(
            "{:<10} {:>12} {:>12} {:>12}",
            measurement.mode, measurement.filter_ms, measurement.group_ms, measurement.join_ms
        );
    }

    println!("rows={rows}");
    println!();
    println!(
        "{:<10} {:>12} {:>12} {:>12}",
        "mode", "filter_ms", "group_ms", "join_ms"
    );
    print_timing_row(v2);
    match pandas {
        PandasOpsOutcome::Measured(measurement) => print_timing_row(measurement),
        PandasOpsOutcome::Skipped(reason) => println!("pandas=skipped ({reason})"),
    }
    println!("filter_rows={}", v2.filter_rows);
    println!("group_rows={}", v2.group_rows);
    println!("join_rows={}", v2.join_rows);

    if let PandasOpsOutcome::Measured(measurement) = pandas {
        let checks = [
            ("filter_rows", v2.filter_rows, measurement.filter_rows),
            ("group_rows", v2.group_rows, measurement.group_rows),
            ("join_rows", v2.join_rows, measurement.join_rows),
        ];
        for (field, ours, theirs) in checks {
            if ours != theirs {
                eprintln!(
                    "warning: {field} differs between {} ({ours}) and {} ({theirs})",
                    v2.mode, measurement.mode
                );
            }
        }
    }
}
/// Writes the `rows`-row base table used by the filter and group-by benchmarks.
///
/// Columns: `id` (row index), `active` (`true` on even rows), `revenue`
/// (`row * 1.35 + row % 17`, two decimals), `segment` (`segment_00`..`segment_15`),
/// `owner` (`owner_00`..`owner_31`). Parent directories are created if needed.
///
/// # Errors
/// Propagates directory-creation and CSV write errors.
fn generate_base_csv(rows: usize, path: &Path) -> Result<(), DynError> {
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)?;
    }
    let mut csv_out = WriterBuilder::new().from_path(path)?;
    csv_out.write_record(["id", "active", "revenue", "segment", "owner"])?;
    for i in 0..rows {
        let revenue = (i as f64 * 1.35) + ((i % 17) as f64);
        let record = [
            i.to_string(),
            // `bool`'s Display yields exactly "true" / "false".
            (i % 2 == 0).to_string(),
            format!("{revenue:.2}"),
            format!("segment_{:02}", i % 16),
            format!("owner_{:02}", i % 32),
        ];
        csv_out.write_record(&record)?;
    }
    csv_out.flush()?;
    Ok(())
}
/// Writes the two inputs for the join benchmark.
///
/// * left: `rows` rows of `id`, `segment`, `revenue`;
/// * right: only even ids, each with `owner` and `tier`, so an inner join on
///   `id` yields `ceil(rows / 2)` rows.
///
/// The parent directory of *each* path is created if needed. The original
/// created only the left one, so writing `right_path` could fail when it
/// lived in a different, non-existent directory.
///
/// # Errors
/// Propagates directory-creation and CSV write errors.
fn generate_join_csvs(rows: usize, left_path: &Path, right_path: &Path) -> Result<(), DynError> {
    for path in [left_path, right_path] {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
    }

    let mut left_writer = WriterBuilder::new().from_path(left_path)?;
    left_writer.write_record(["id", "segment", "revenue"])?;
    for row in 0..rows {
        left_writer.write_record([
            row.to_string(),
            format!("segment_{:02}", row % 16),
            format!("{:.2}", (row as f64 * 0.75) + ((row % 11) as f64)),
        ])?;
    }
    left_writer.flush()?;

    let mut right_writer = WriterBuilder::new().from_path(right_path)?;
    right_writer.write_record(["id", "owner", "tier"])?;
    // Even ids only: half of the left rows find a match.
    for row in (0..rows).step_by(2) {
        right_writer.write_record([
            row.to_string(),
            format!("owner_{:02}", row % 32),
            format!("tier_{:02}", row % 4),
        ])?;
    }
    right_writer.flush()?;
    Ok(())
}
fn temp_csv_path(prefix: &str, rows: usize) -> Result<PathBuf, DynError> {
let millis = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis();
Ok(env::temp_dir().join(format!(
"rgwml_ops_{}_{}_{}_{}.csv",
prefix,
rows,
process::id(),
millis
)))
}