//! rgwml 2.0.0
//!
//! Typed, local-first tabular data library with columnar in-memory storage.
//! Documentation: see the rgwml crate docs for the full API.
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::{self, Command, ExitCode};
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use csv::WriterBuilder;
use rgwml::{
    read_csv, AggregateExpr, AggregateOp, ColumnSelector, CompareOp, CsvReadOptions, GroupKey,
    JoinKey, JoinOptions, JoinType, Literal, Predicate,
};

type DynError = Box<dyn std::error::Error>;

fn main() -> ExitCode {
    match try_main() {
        Ok(()) => ExitCode::SUCCESS,
        Err(error) => {
            eprintln!("error: {error}");
            ExitCode::FAILURE
        }
    }
}

/// Dispatches on the first command-line argument (after the program name).
///
/// - no arguments, `-h`, `--help` or `help`: print usage and succeed;
/// - `run`: hand the remaining arguments to `run_command`;
/// - anything else: treat the whole argument list as `run`'s arguments, so
///   `bench_ops 1000` behaves like `bench_ops run 1000`.
fn try_main() -> Result<(), DynError> {
    let argv: Vec<String> = env::args().skip(1).collect();

    match argv.first().map(String::as_str) {
        None | Some("-h" | "--help" | "help") => {
            print_usage();
            Ok(())
        }
        // Safe to slice: this arm only matches when `argv` has a first element.
        Some("run") => run_command(&argv[1..]),
        Some(_) => run_command(&argv),
    }
}

/// Writes the command-line synopsis to stderr.
///
/// Both invocation forms are accepted: with an explicit `run` subcommand or
/// with the row count as the first argument.
fn print_usage() {
    const USAGE: &str = "usage:\n  cargo run --bin bench_ops -- run <rows> [--python PYTHON] [--skip-pandas] [--keep-csv]\n  cargo run --bin bench_ops -- <rows> [--python PYTHON] [--skip-pandas] [--keep-csv]";
    eprintln!("{USAGE}");
}

fn run_command(args: &[String]) -> Result<(), DynError> {
    if args.is_empty() {
        return Err("run requires <rows>".into());
    }

    let rows = args[0]
        .parse::<usize>()
        .map_err(|_| format!("invalid row count '{}'", args[0]))?;
    let mut python = String::from("python3");
    let mut skip_pandas = false;
    let mut keep_csv = false;

    let mut index = 1;
    while index < args.len() {
        match args[index].as_str() {
            "--python" => {
                let value = args
                    .get(index + 1)
                    .ok_or("--python requires an interpreter path")?;
                python = value.clone();
                index += 2;
            }
            "--skip-pandas" => {
                skip_pandas = true;
                index += 1;
            }
            "--keep-csv" => {
                keep_csv = true;
                index += 1;
            }
            other => {
                return Err(format!("unknown argument '{other}'").into());
            }
        }
    }

    let base_path = temp_csv_path("base", rows)?;
    let left_path = temp_csv_path("left", rows)?;
    let right_path = temp_csv_path("right", rows)?;

    generate_base_csv(rows, &base_path)?;
    generate_join_csvs(rows, &left_path, &right_path)?;

    let v2 = run_v2_ops(rows, &base_path, &left_path, &right_path)?;
    let pandas = if skip_pandas {
        PandasOpsOutcome::Skipped(String::from("disabled by --skip-pandas"))
    } else {
        run_pandas_ops(&python, &base_path, &left_path, &right_path)
    };

    print_summary(rows, &v2, &pandas);

    if !keep_csv {
        let _ = fs::remove_file(&base_path);
        let _ = fs::remove_file(&left_path);
        let _ = fs::remove_file(&right_path);
    }

    Ok(())
}

/// Timings and result sizes for one engine's run of the benchmark operations.
///
/// All `*_ms` fields are elapsed wall-clock milliseconds. All `*_rows` fields
/// are row counts of the produced results.
#[derive(Clone, Debug, PartialEq, Eq)]
struct OpsMeasurement {
    /// Engine label shown in the summary table (`"v2"` or `"pandas"` in this file).
    mode: &'static str,
    /// Time spent on the filter step.
    filter_ms: u128,
    /// Rows remaining after the filter.
    filter_rows: usize,
    /// Time spent on the group-by/aggregate step.
    group_ms: u128,
    /// Rows (groups) produced by the group-by.
    group_rows: usize,
    /// Time spent on the join step.
    join_ms: u128,
    /// Rows produced by the join.
    join_rows: usize,
}

/// Result of attempting the pandas comparison run.
///
/// Derives `Debug` (like `OpsMeasurement`) so outcomes can be inspected in
/// diagnostics and tests.
#[derive(Clone, Debug, PartialEq, Eq)]
enum PandasOpsOutcome {
    /// The pandas script ran and reported every expected field.
    Measured(OpsMeasurement),
    /// No measurement is available; the string explains why (disabled,
    /// interpreter failure, or unparseable output).
    Skipped(String),
}

fn run_v2_ops(
    rows: usize,
    base_path: &Path,
    left_path: &Path,
    right_path: &Path,
) -> Result<OpsMeasurement, DynError> {
    let _ = rows;
    let base = read_csv(base_path, &CsvReadOptions::default())?;
    let left = read_csv(left_path, &CsvReadOptions::default())?;
    let right = read_csv(right_path, &CsvReadOptions::default())?;

    let filter_start = Instant::now();
    let filtered = base
        .filter(&Predicate::And(vec![
            Predicate::Comparison {
                column: ColumnSelector::from("active"),
                op: CompareOp::Eq,
                value: Some(Literal::Bool(true)),
            },
            Predicate::Comparison {
                column: ColumnSelector::from("revenue"),
                op: CompareOp::Gt,
                value: Some(Literal::F64(5_000.0)),
            },
            Predicate::Comparison {
                column: ColumnSelector::from("segment"),
                op: CompareOp::StartsWith,
                value: Some(Literal::from("segment_0")),
            },
        ]))?
        .materialize()?;
    let filter_ms = filter_start.elapsed().as_millis();

    let group_start = Instant::now();
    let grouped = base.group_by(
        &[GroupKey {
            column: ColumnSelector::from("segment"),
        }],
        &[
            AggregateExpr {
                input: None,
                op: AggregateOp::CountRows,
                alias: "rows".into(),
            },
            AggregateExpr {
                input: Some(ColumnSelector::from("revenue")),
                op: AggregateOp::Sum,
                alias: "revenue_sum".into(),
            },
            AggregateExpr {
                input: Some(ColumnSelector::from("revenue")),
                op: AggregateOp::Mean,
                alias: "revenue_mean".into(),
            },
        ],
    )?;
    let group_ms = group_start.elapsed().as_millis();

    let join_start = Instant::now();
    let joined = left.join(
        &right,
        &JoinOptions {
            join_type: JoinType::Inner,
            keys: vec![JoinKey {
                left: ColumnSelector::from("id"),
                right: ColumnSelector::from("id"),
            }],
            ..JoinOptions::default()
        },
    )?;
    let join_ms = join_start.elapsed().as_millis();

    Ok(OpsMeasurement {
        mode: "v2",
        filter_ms,
        filter_rows: filtered.nrows(),
        group_ms,
        group_rows: grouped.nrows(),
        join_ms,
        join_rows: joined.nrows(),
    })
}

fn run_pandas_ops(
    python: &str,
    base_path: &Path,
    left_path: &Path,
    right_path: &Path,
) -> PandasOpsOutcome {
    let script = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("scripts/bench_pandas_ops.py");
    let output = match Command::new(python)
        .arg(&script)
        .arg(base_path)
        .arg(left_path)
        .arg(right_path)
        .output()
    {
        Ok(output) => output,
        Err(error) => return PandasOpsOutcome::Skipped(error.to_string()),
    };

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
        let reason = if stderr.is_empty() {
            format!("python exited with status {}", output.status)
        } else {
            stderr
        };
        return PandasOpsOutcome::Skipped(reason);
    }

    match parse_pandas_measurement(&output.stdout) {
        Ok(measurement) => PandasOpsOutcome::Measured(measurement),
        Err(error) => PandasOpsOutcome::Skipped(error.to_string()),
    }
}

fn parse_pandas_measurement(output: &[u8]) -> Result<OpsMeasurement, DynError> {
    let text = String::from_utf8(output.to_vec())?;
    let mut values = std::collections::HashMap::new();
    for line in text.lines() {
        if let Some((key, value)) = line.split_once('=') {
            values.insert(key.trim().to_string(), value.trim().to_string());
        }
    }

    Ok(OpsMeasurement {
        mode: "pandas",
        filter_ms: required_ops_value(&values, "filter_ms")?.parse::<u128>()?,
        filter_rows: required_ops_value(&values, "filter_rows")?.parse::<usize>()?,
        group_ms: required_ops_value(&values, "group_ms")?.parse::<u128>()?,
        group_rows: required_ops_value(&values, "group_rows")?.parse::<usize>()?,
        join_ms: required_ops_value(&values, "join_ms")?.parse::<u128>()?,
        join_rows: required_ops_value(&values, "join_rows")?.parse::<usize>()?,
    })
}

/// Returns the raw string stored under `key`, borrowed from `values`.
///
/// # Errors
/// Returns `missing benchmark field '<key>'` when `key` is absent.
fn required_ops_value<'a>(
    values: &'a std::collections::HashMap<String, String>,
    key: &str,
) -> Result<&'a str, DynError> {
    match values.get(key) {
        Some(found) => Ok(found.as_str()),
        None => Err(format!("missing benchmark field '{key}'").into()),
    }
}

/// Prints the benchmark results to stdout.
///
/// The output has four parts:
/// 1. the row count and a timing table with one line for rgwml (`v2`);
/// 2. either a pandas line or `pandas=skipped (<reason>)`;
/// 3. v2's result sizes as `key=value` lines;
/// 4. when pandas was measured, a `warning:` line for each result size that
///    differs between the engines. Timings are only comparable when both
///    engines computed the same results.
fn print_summary(rows: usize, v2: &OpsMeasurement, pandas: &PandasOpsOutcome) {
    /// Prints one aligned row of the timing table.
    fn print_timing_row(measurement: &OpsMeasurement) {
        println!(
            "{:<10} {:>12} {:>12} {:>12}",
            measurement.mode, measurement.filter_ms, measurement.group_ms, measurement.join_ms
        );
    }

    println!("rows={rows}");
    println!();
    println!(
        "{:<10} {:>12} {:>12} {:>12}",
        "mode", "filter_ms", "group_ms", "join_ms"
    );
    print_timing_row(v2);
    match pandas {
        PandasOpsOutcome::Measured(measurement) => print_timing_row(measurement),
        PandasOpsOutcome::Skipped(reason) => println!("pandas=skipped ({reason})"),
    }
    println!("filter_rows={}", v2.filter_rows);
    println!("group_rows={}", v2.group_rows);
    println!("join_rows={}", v2.join_rows);

    if let PandasOpsOutcome::Measured(measurement) = pandas {
        for (field, ours, theirs) in [
            ("filter_rows", v2.filter_rows, measurement.filter_rows),
            ("group_rows", v2.group_rows, measurement.group_rows),
            ("join_rows", v2.join_rows, measurement.join_rows),
        ] {
            if ours != theirs {
                println!("warning: {field} mismatch (v2={ours}, pandas={theirs})");
            }
        }
    }
}

/// Writes `rows` synthetic records to `path` for the filter and group-by steps.
///
/// Columns:
/// - `id`: the 0-based row index;
/// - `active`: `true` on even rows, `false` otherwise;
/// - `revenue`: `row * 1.35 + row % 17`, formatted with two decimals;
/// - `segment`: `segment_00` through `segment_15`;
/// - `owner`: `owner_00` through `owner_31`.
///
/// The parent directory of `path` is created if missing.
fn generate_base_csv(rows: usize, path: &Path) -> Result<(), DynError> {
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)?;
    }

    let mut csv_out = WriterBuilder::new().from_path(path)?;
    csv_out.write_record(["id", "active", "revenue", "segment", "owner"])?;

    for row in 0..rows {
        let id = row.to_string();
        // `bool`'s Display is exactly "true" / "false".
        let active = (row % 2 == 0).to_string();
        let revenue = format!("{:.2}", row as f64 * 1.35 + (row % 17) as f64);
        let segment = format!("segment_{:02}", row % 16);
        let owner = format!("owner_{:02}", row % 32);
        csv_out.write_record([&id, &active, &revenue, &segment, &owner])?;
    }

    csv_out.flush()?;
    Ok(())
}

/// Writes the two inputs for the join step.
///
/// - `left_path` gets one record per row with columns `id`, `segment` and
///   `revenue`.
/// - `right_path` gets only the even `id`s, with columns `id`, `owner` and
///   `tier`. An inner join on `id` should therefore match `ceil(rows / 2)` rows.
///
/// The parent directories of both files are created if missing.
fn generate_join_csvs(rows: usize, left_path: &Path, right_path: &Path) -> Result<(), DynError> {
    // Create both parents: callers may place the two files in different directories.
    for path in [left_path, right_path] {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
    }

    let mut left_writer = WriterBuilder::new().from_path(left_path)?;
    left_writer.write_record(["id", "segment", "revenue"])?;
    for row in 0..rows {
        left_writer.write_record([
            row.to_string(),
            format!("segment_{:02}", row % 16),
            format!("{:.2}", (row as f64 * 0.75) + ((row % 11) as f64)),
        ])?;
    }
    left_writer.flush()?;

    let mut right_writer = WriterBuilder::new().from_path(right_path)?;
    right_writer.write_record(["id", "owner", "tier"])?;
    // Only even ids appear on the right side, giving a 50% match rate.
    for row in (0..rows).step_by(2) {
        right_writer.write_record([
            row.to_string(),
            format!("owner_{:02}", row % 32),
            format!("tier_{:02}", row % 4),
        ])?;
    }
    right_writer.flush()?;

    Ok(())
}

/// Builds a CSV path inside the system temp directory. Nothing is created on disk.
///
/// The file name is `rgwml_ops_<prefix>_<rows>_<pid>_<millis>.csv`, where
/// `<millis>` is the current Unix time in milliseconds. Names therefore differ
/// across prefixes, processes and milliseconds.
///
/// # Errors
/// Fails if the system clock reports a time before the Unix epoch.
fn temp_csv_path(prefix: &str, rows: usize) -> Result<PathBuf, DynError> {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH)?;
    let pid = process::id();
    let millis = since_epoch.as_millis();

    let mut path = env::temp_dir();
    path.push(format!("rgwml_ops_{prefix}_{rows}_{pid}_{millis}.csv"));
    Ok(path)
}