use oaxaca_blinder::{OaxacaBuilder, QuantileDecompositionBuilder, ReferenceCoefficients};
use polars::prelude::*;
fn create_sample_dataframe() -> DataFrame {
df!(
"wage" => &[10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0, 10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0],
"education" => &[12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0],
"gender" => &["F", "F", "F", "F", "F", "M", "M", "M", "M", "M", "F", "F", "F", "F", "F", "M", "M", "M", "M", "M"]
).unwrap()
}
fn run_and_check(builder: OaxacaBuilder, expected_gap: f64) {
let results = builder.run().expect("Oaxaca run failed");
assert!((results.total_gap() - expected_gap).abs() < 1e-9);
let explained = results
.two_fold()
.aggregate()
.iter()
.find(|c| c.name() == "explained")
.unwrap()
.estimate();
let unexplained = results
.two_fold()
.aggregate()
.iter()
.find(|c| c.name() == "unexplained")
.unwrap()
.estimate();
let total_gap = results.total_gap();
println!(
"Explained: {}, Unexplained: {}, Sum: {}, Total Gap: {}",
explained,
unexplained,
explained + unexplained,
total_gap
);
assert!(
(explained + unexplained - results.total_gap()).abs() < 1e-9,
"Decomposition does not sum to total gap"
);
assert_eq!(*results.n_a(), 10);
assert_eq!(*results.n_b(), 10);
results.summary();
}
#[test]
#[ignore]
fn test_detailed_components_with_rare_category() {
let df = df!(
"wage" => &[10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0, 10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0],
"education" => &[12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0],
"gender" => &["F", "F", "F", "F", "F", "F", "F", "F", "F", "F", "M", "M", "M", "M", "M", "M", "M", "M", "M", "M"],
"sector" => &["A", "A", "A", "A", "A", "A", "A", "A", "A", "B", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A"] ).unwrap();
let mut builder = OaxacaBuilder::new(df, "wage", "gender", "F");
let results = builder
.predictors(&["education"])
.categorical_predictors(&["sector"])
.bootstrap_reps(5)
.run()
.expect("Oaxaca run failed");
let detailed_unexplained = results.two_fold().detailed_unexplained();
let intercept = detailed_unexplained
.iter()
.find(|c| c.name() == "intercept")
.unwrap();
assert!(intercept.ci_lower().is_finite());
assert!(intercept.ci_upper().is_finite());
let education = detailed_unexplained
.iter()
.find(|c| c.name() == "education")
.unwrap();
assert!(education.ci_lower().is_finite());
assert!(education.ci_upper().is_finite());
let sector_b = detailed_unexplained.iter().find(|c| c.name() == "sector_B");
assert!(
sector_b.is_some(),
"Detailed component for rare category 'sector_B' should be present"
);
assert!(sector_b.unwrap().ci_lower().is_finite());
assert!(sector_b.unwrap().ci_upper().is_finite());
results.summary();
}
#[test]
fn test_full_run_group_b_ref() {
let df = create_sample_dataframe();
let mut builder = OaxacaBuilder::new(df, "wage", "gender", "F");
builder.predictors(&["education"]).bootstrap_reps(5); run_and_check(builder, 10.0);
}
#[test]
fn test_full_run_group_a_ref() {
let df = create_sample_dataframe();
let mut builder = OaxacaBuilder::new(df, "wage", "gender", "F");
builder
.predictors(&["education"])
.bootstrap_reps(5)
.reference_coefficients(ReferenceCoefficients::GroupA);
run_and_check(builder, 10.0);
}
#[test]
fn test_full_run_pooled_ref() {
let df = create_sample_dataframe();
let mut builder = OaxacaBuilder::new(df, "wage", "gender", "F");
builder
.predictors(&["education"])
.bootstrap_reps(5)
.reference_coefficients(ReferenceCoefficients::Pooled);
run_and_check(builder, 10.0);
}
#[test]
fn test_full_run_weighted_ref() {
let df = create_sample_dataframe();
let mut builder = OaxacaBuilder::new(df, "wage", "gender", "F");
builder
.predictors(&["education"])
.bootstrap_reps(5)
.reference_coefficients(ReferenceCoefficients::Weighted);
run_and_check(builder, 10.0);
}
#[test]
fn test_with_categorical_variable() {
let df = df!(
"wage" => &[10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0, 10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0],
"education" => &[12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0],
"gender" => &["F", "F", "F", "F", "F", "M", "M", "M", "M", "M", "F", "F", "F", "F", "F", "M", "M", "M", "M", "M"],
"union" => &["none", "union", "union_plus", "none", "union", "union_plus", "none", "union", "union_plus", "none", "none", "union", "union_plus", "none", "union", "union_plus", "none", "union", "union_plus", "none"]
).unwrap();
let mut builder = OaxacaBuilder::new(df, "wage", "gender", "F");
builder
.predictors(&["education"])
.categorical_predictors(&["union"])
.normalize(&["union"])
.bootstrap_reps(5);
run_and_check(builder, 10.0);
}
#[test]
fn test_quantile_decomposition() {
let df = df!(
"wage" => &[10.0, 12.0, 11.0, 13.0, 15.0, 20.0, 22.0, 21.0, 23.0, 25.0, 9.0, 18.0],
"education" => &[12.0, 16.0, 14.0, 16.0, 18.0, 12.0, 16.0, 14.0, 16.0, 18.0, 10.0, 20.0],
"gender" => &["F", "F", "F", "F", "F", "F", "M", "M", "M", "M", "M", "M"]
)
.unwrap();
let quantiles_to_test = &[0.25, 0.5, 0.75];
let mut builder = QuantileDecompositionBuilder::new(df, "wage", "gender", "F");
let results = builder
.predictors(&["education"])
.quantiles(quantiles_to_test)
.simulations(10) .bootstrap_reps(2) .run()
.unwrap();
assert!(results.results_by_quantile().contains_key("q25"));
assert!(results.results_by_quantile().contains_key("q50"));
assert!(results.results_by_quantile().contains_key("q75"));
for key in &["q25", "q50", "q75"] {
let detail = results.results_by_quantile().get(*key).unwrap();
let gap = detail.total_gap().estimate();
let chars = detail.characteristics_effect().estimate();
let coeffs = detail.coefficients_effect().estimate();
assert!((chars + coeffs - gap).abs() < 1e-9);
}
results.summary();
}