use clap::{CommandFactory, Parser};
use oaxaca_blinder::{OaxacaBuilder, QuantileDecompositionBuilder, ReferenceCoefficients};
use polars::prelude::*;
use std::error::Error;
use std::path::PathBuf;
// Command-line options for the decomposition tool.
// The `///` comments on the fields are rendered by clap as the `--help` text,
// which `main` also prints after any error, so keep them short and accurate.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to the input CSV file; the first row must be a header.
    #[arg(short, long)]
    data: PathBuf,
    /// Outcome column name (not used by mean analysis when --formula is given).
    #[arg(long)]
    outcome: String,
    /// Column identifying group membership.
    #[arg(long)]
    group: String,
    /// Group identifier to treat as the reference group.
    #[arg(long)]
    reference: String,
    /// Comma-separated predictor columns (not used by mean analysis when --formula is given).
    #[arg(long, value_delimiter = ',')]
    predictors: Vec<String>,
    /// Comma-separated predictors to treat as categorical (not used by mean analysis when --formula is given).
    #[arg(long, value_delimiter = ',')]
    categorical: Option<Vec<String>>,
    /// Decomposition to run: mean or quantile.
    // Restricting the values lets clap reject typos at parse time instead of
    // the program silently doing nothing.
    #[arg(
        long,
        default_value = "mean",
        value_parser = clap::builder::PossibleValuesParser::new(["mean", "quantile"])
    )]
    analysis_type: String,
    /// Reference coefficients for mean analysis.
    #[arg(
        long,
        default_value = "group_b",
        value_parser = clap::builder::PossibleValuesParser::new([
            "group_a", "group_b", "pooled", "weighted",
        ])
    )]
    ref_coeffs: String,
    /// Comma-separated quantiles for quantile analysis (default: 0.1,0.25,0.5,0.75,0.9).
    #[arg(long, value_delimiter = ',')]
    quantiles: Option<Vec<f64>>,
    /// Number of bootstrap replications.
    #[arg(long, default_value_t = 500)]
    bootstrap_reps: usize,
    /// Number of simulations for quantile analysis.
    #[arg(long, default_value_t = 1000)]
    simulations: usize,
    /// Model formula for mean analysis; replaces --outcome, --predictors and --categorical.
    #[arg(long)]
    formula: Option<String>,
    /// Weights column for mean analysis.
    #[arg(long)]
    weights: Option<String>,
    /// Selection-equation outcome for Heckman correction in mean analysis; requires --selection-predictors.
    #[arg(long)]
    selection_outcome: Option<String>,
    /// Comma-separated selection-equation predictors (used with --selection-outcome).
    #[arg(long, value_delimiter = ',')]
    selection_predictors: Option<Vec<String>>,
    /// Write mean-analysis results as JSON to this path.
    #[arg(long)]
    output_json: Option<PathBuf>,
    /// Write mean-analysis results as Markdown to this path.
    #[arg(long)]
    output_markdown: Option<PathBuf>,
}
fn run(args: Args) -> Result<(), Box<dyn Error>> {
let df = CsvReader::from_path(&args.data)?
.has_header(true)
.finish()?;
if args.analysis_type == "mean" {
let reference_coeffs = match args.ref_coeffs.as_str() {
"group_a" => ReferenceCoefficients::GroupA,
"group_b" => ReferenceCoefficients::GroupB,
"pooled" => ReferenceCoefficients::Pooled,
"weighted" => ReferenceCoefficients::Weighted,
_ => return Err("Invalid reference coefficient type".into()),
};
let mut builder = if let Some(formula) = &args.formula {
OaxacaBuilder::from_formula(df, formula, &args.group, &args.reference)?
} else {
let predictors: Vec<&str> = args.predictors.iter().map(AsRef::as_ref).collect();
let categorical_predictors: Vec<&str> = args
.categorical
.as_ref()
.map(|v| v.iter().map(AsRef::as_ref).collect())
.unwrap_or_else(Vec::new);
let mut b = OaxacaBuilder::new(df, &args.outcome, &args.group, &args.reference);
b.predictors(&predictors)
.categorical_predictors(&categorical_predictors);
b
};
builder.bootstrap_reps(args.bootstrap_reps)
.reference_coefficients(reference_coeffs);
if let Some(weights) = &args.weights {
builder.weights(weights);
}
if let Some(sel_outcome) = &args.selection_outcome {
if let Some(sel_predictors) = &args.selection_predictors {
let sel_preds_refs: Vec<&str> = sel_predictors.iter().map(AsRef::as_ref).collect();
builder.heckman_selection(sel_outcome, &sel_preds_refs);
} else {
return Err("Selection predictors must be provided if selection outcome is specified".into());
}
}
let results = builder.run()?;
results.summary();
if let Some(path) = args.output_json {
let json = results.to_json().map_err(|e| format!("Failed to serialize to JSON: {}", e))?;
std::fs::write(path, json)?;
}
if let Some(path) = args.output_markdown {
let md = results.to_markdown();
std::fs::write(path, md)?;
}
} else if args.analysis_type == "quantile" {
let predictors: Vec<&str> = args.predictors.iter().map(AsRef::as_ref).collect();
let quantiles = args
.quantiles
.unwrap_or_else(|| vec![0.1, 0.25, 0.5, 0.75, 0.9]);
let categorical_predictors: Vec<&str> = args
.categorical
.as_ref()
.map(|v| v.iter().map(AsRef::as_ref).collect())
.unwrap_or_else(Vec::new);
let mut builder = QuantileDecompositionBuilder::new(df, &args.outcome, &args.group, &args.reference);
builder
.predictors(&predictors)
.categorical_predictors(&categorical_predictors)
.quantiles(&quantiles)
.bootstrap_reps(args.bootstrap_reps)
.simulations(args.simulations);
let results = builder.run()?;
results.summary();
}
Ok(())
}
/// Program entry point.
///
/// Parses the command line and hands it to `run`. On failure it prints the
/// error and then the usage text, and exits with status 1.
fn main() {
    let outcome = run(Args::parse());
    match outcome {
        Ok(()) => {}
        Err(err) => {
            eprintln!("Error: {}", err);
            // Printing the help is best effort; a failure to print it is ignored.
            let _ = Args::command().print_help();
            std::process::exit(1);
        }
    }
}