#[cfg(feature = "display")]
use comfy_table::{Cell, Table};
use getset::Getters;
use nalgebra::{DMatrix, DVector};
use polars::prelude::*;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt;
mod decomposition;
mod inference;
mod math;
pub use crate::decomposition::BudgetAdjustment;
pub use crate::decomposition::ReferenceCoefficients;
use crate::decomposition::{
detailed_decomposition, three_fold_decomposition, two_fold_decomposition, DetailedComponent,
ThreeFoldDecomposition, TwoFoldDecomposition,
};
use crate::inference::bootstrap_stats;
use crate::math::normalization::normalize_categorical_coefficients;
use crate::math::ols::ols;
pub mod quantile_decomposition;
use crate::math::rif::calculate_rif;
pub use crate::quantile_decomposition::QuantileDecompositionBuilder;
pub mod jmp;
pub use crate::jmp::decompose_changes;
pub mod dfl;
pub use crate::dfl::run_dfl;
pub mod formula;
pub mod heckman;
use crate::formula::Formula;
pub use crate::heckman::heckman_two_step;
pub mod akm;
pub use crate::akm::{AkmBuilder, AkmResult};
pub mod matching;
pub use crate::matching::engine::MatchingEngine;
// Python bindings, compiled only with the `python` feature.
// Fix: the `#[cfg(feature = "python")]` attribute was duplicated; one copy
// suffices (the repetition was redundant, though harmless).
#[cfg(feature = "python")]
pub mod python;
/// Errors that can occur while configuring or running a decomposition.
#[derive(Debug)]
pub enum OaxacaError {
    /// Error propagated from an underlying `polars` operation.
    PolarsError(PolarsError),
    /// A required column was not present in the input `DataFrame`.
    ColumnNotFound(String),
    /// The grouping variable is unusable (e.g. fewer than two groups, or an
    /// empty group after filtering).
    InvalidGroupVariable(String),
    /// A linear-algebra step failed (message produced by this crate).
    NalgebraError(String),
    /// A diagnostic computation failed.
    DiagnosticError(String),
}
impl From<PolarsError> for OaxacaError {
fn from(err: PolarsError) -> Self {
OaxacaError::PolarsError(err)
}
}
impl fmt::Display for OaxacaError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
OaxacaError::PolarsError(e) => write!(f, "Polars error: {}", e),
OaxacaError::ColumnNotFound(s) => write!(f, "Column not found: {}", s),
OaxacaError::InvalidGroupVariable(s) => write!(f, "Invalid group variable: {}", s),
OaxacaError::NalgebraError(s) => write!(f, "Nalgebra error: {}", s),
OaxacaError::DiagnosticError(s) => write!(f, "Diagnostic error: {}", s),
}
}
}
// Marker impl: `OaxacaError` already provides `Debug` + `Display`, so the
// default `std::error::Error` methods are sufficient.
impl std::error::Error for OaxacaError {}
/// Configurable entry point for Oaxaca-Blinder decompositions.
///
/// Construct with [`OaxacaBuilder::new`] or [`OaxacaBuilder::from_formula`],
/// adjust settings through the builder methods, then call `run` (or
/// `decompose_quantile` for a RIF quantile decomposition).
#[derive(Debug, Clone)]
pub struct OaxacaBuilder {
    // Input data; rows with nulls in any used column are dropped before analysis.
    dataframe: DataFrame,
    // Name of the outcome column.
    outcome: String,
    // Numeric predictor column names.
    predictors: Vec<String>,
    // Categorical predictors, expanded into dummy columns at run time.
    categorical_predictors: Vec<String>,
    // Column that splits the sample into the two groups.
    group: String,
    // Level of `group` used as the reference (group B).
    reference_group: String,
    // Number of bootstrap replications for inference (default 100).
    bootstrap_reps: usize,
    // Which coefficient vector serves as beta* in the two-fold split.
    reference_coeffs: ReferenceCoefficients,
    // Categorical variables whose coefficients get normalized across levels.
    normalization_vars: Vec<String>,
    // Optional observation-weight column.
    weights_col: Option<String>,
    // Heckman selection equation: binary selection-outcome column...
    selection_outcome: Option<String>,
    // ...and the predictors of that selection equation.
    selection_predictors: Vec<String>,
}
// Everything produced by one decomposition pass — either the point estimate
// or a single bootstrap replication.
#[derive(Clone)]
#[allow(dead_code)]
struct SinglePassResult {
    three_fold: ThreeFoldDecomposition,
    two_fold: TwoFoldDecomposition,
    // Per-variable explained / unexplained contributions.
    detailed_explained: Vec<DetailedComponent>,
    detailed_unexplained: Vec<DetailedComponent>,
    // Raw (weighted) mean-outcome gap, group A minus group B.
    total_gap: f64,
    // Regression residuals per group (zero-filled under Heckman estimation).
    residuals_a: DVector<f64>,
    residuals_b: DVector<f64>,
    // Design-matrix column means per group.
    xa_mean: DVector<f64>,
    xb_mean: DVector<f64>,
    // Reference coefficient vector used in the two-fold split.
    beta_star: DVector<f64>,
    // Per-variable selection (IMR) contributions; empty without Heckman.
    detailed_selection: Vec<DetailedComponent>,
}
// Output of an `Estimator`: per-group coefficients, design means and
// residuals, plus optional Heckman selection artifacts (None under OLS).
struct EstimationResult {
    beta_a: DVector<f64>,
    beta_b: DVector<f64>,
    // Column means of each group's design matrix.
    xa_mean: DVector<f64>,
    xb_mean: DVector<f64>,
    // Final predictor names ("IMR" is appended under Heckman).
    predictor_names: Vec<String>,
    residuals_a: DVector<f64>,
    residuals_b: DVector<f64>,
    // Base-category coefficients produced by normalization (empty when off).
    base_coeffs_a: HashMap<String, f64>,
    base_coeffs_b: HashMap<String, f64>,
    // Heckman-only: selection-equation coefficients per group.
    selection_coeffs_a: Option<DVector<f64>>,
    selection_coeffs_b: Option<DVector<f64>>,
    // Heckman-only: selection-regressor means per group.
    selection_means_a: Option<DVector<f64>>,
    selection_means_b: Option<DVector<f64>>,
    selection_names: Option<Vec<String>>,
    // Heckman-only: IMR delta terms per group.
    imr_delta_a: Option<f64>,
    imr_delta_b: Option<f64>,
}
// Borrowed view of everything an `Estimator` needs for one pass: per-group
// dataframes, design matrices, outcomes, optional weights and metadata.
struct EstimationContext<'a> {
    df_a: &'a DataFrame,
    df_b: &'a DataFrame,
    x_a: &'a DMatrix<f64>,
    y_a: &'a DVector<f64>,
    // Optional observation weights (None = unweighted).
    w_a: &'a Option<DVector<f64>>,
    x_b: &'a DMatrix<f64>,
    y_b: &'a DVector<f64>,
    w_b: &'a Option<DVector<f64>>,
    // Design-matrix column names (intercept first).
    predictor_names: &'a [String],
    // Number of levels per categorical variable (used by normalization).
    category_counts: &'a HashMap<String, usize>,
}
// Strategy interface: turn one prepared per-group dataset into coefficient
// vectors and moments. Implemented by `OlsEstimator` and `HeckmanEstimator`.
trait Estimator {
    fn estimate(&self, ctx: &EstimationContext) -> Result<EstimationResult, OaxacaError>;
}
// Plain (optionally weighted) OLS estimation per group.
struct OlsEstimator {
    // Categorical variables to normalize after fitting; empty = skip.
    normalization_vars: Vec<String>,
}
impl Estimator for OlsEstimator {
    /// Fits OLS separately for each group and collects the moments needed by
    /// the decomposition, optionally normalizing categorical coefficients.
    fn estimate(&self, ctx: &EstimationContext) -> Result<EstimationResult, OaxacaError> {
        // (Possibly weighted) column means of a design matrix.
        fn column_means(x: &DMatrix<f64>, w: Option<&DVector<f64>>) -> DVector<f64> {
            match w {
                Some(weights) => {
                    let total_weight = weights.sum();
                    DVector::from_iterator(
                        x.ncols(),
                        x.column_iter().map(|col| col.dot(weights) / total_weight),
                    )
                }
                None => x.row_mean().transpose(),
            }
        }

        let mut ols_a = ols(ctx.y_a, ctx.x_a, ctx.w_a.as_ref())?;
        let mut ols_b = ols(ctx.y_b, ctx.x_b, ctx.w_b.as_ref())?;
        let xa_mean = column_means(ctx.x_a, ctx.w_a.as_ref());
        let xb_mean = column_means(ctx.x_b, ctx.w_b.as_ref());

        // Normalization rewrites the fitted coefficients in place and yields
        // the implied base-category coefficients per variable.
        let (base_coeffs_a, base_coeffs_b) = if self.normalization_vars.is_empty() {
            (HashMap::new(), HashMap::new())
        } else {
            (
                normalize_categorical_coefficients(
                    &mut ols_a,
                    ctx.predictor_names,
                    &self.normalization_vars,
                    &xa_mean,
                    ctx.category_counts,
                ),
                normalize_categorical_coefficients(
                    &mut ols_b,
                    ctx.predictor_names,
                    &self.normalization_vars,
                    &xb_mean,
                    ctx.category_counts,
                ),
            )
        };

        Ok(EstimationResult {
            beta_a: ols_a.coefficients,
            beta_b: ols_b.coefficients,
            xa_mean,
            xb_mean,
            predictor_names: ctx.predictor_names.to_vec(),
            residuals_a: ols_a.residuals,
            residuals_b: ols_b.residuals,
            base_coeffs_a,
            base_coeffs_b,
            selection_coeffs_a: None,
            selection_coeffs_b: None,
            selection_means_a: None,
            selection_means_b: None,
            selection_names: None,
            imr_delta_a: None,
            imr_delta_b: None,
        })
    }
}
// Heckman two-step (selection-corrected) estimation per group.
struct HeckmanEstimator {
    // Binary column indicating whether the outcome is observed.
    selection_outcome: String,
    // Predictors of the selection equation.
    selection_predictors: Vec<String>,
}
impl Estimator for HeckmanEstimator {
    /// Heckman two-step estimation for each group: fit the selection equation
    /// on all rows, then the outcome equation on the selected rows, appending
    /// the inverse-Mills-ratio (IMR) term as an extra regressor.
    fn estimate(&self, ctx: &EstimationContext) -> Result<EstimationResult, OaxacaError> {
        let (x_sel_a, y_sel_a, x_sel_sub_a) = self.prepare_selection_data(ctx.df_a)?;
        let (x_sel_b, y_sel_b, x_sel_sub_b) = self.prepare_selection_data(ctx.df_b)?;
        // The outcome equation only sees rows with an observed outcome.
        let (x_a_filt, y_a_filt) = self.filter_outcome_rows(ctx.x_a, ctx.y_a, ctx.df_a)?;
        let (x_b_filt, y_b_filt) = self.filter_outcome_rows(ctx.x_b, ctx.y_b, ctx.df_b)?;
        let res_a = heckman_two_step(&y_sel_a, &x_sel_a, &y_a_filt, &x_a_filt, &x_sel_sub_a)?;
        let res_b = heckman_two_step(&y_sel_b, &x_sel_b, &y_b_filt, &x_b_filt, &x_sel_sub_b)?;
        // Append the IMR coefficient and mean so the decomposition treats the
        // selection correction like any other regressor.
        let mut beta_a = res_a.outcome_coeffs;
        let mut beta_b = res_b.outcome_coeffs;
        let k = beta_a.len();
        beta_a = beta_a.insert_row(k, res_a.imr_coeff);
        beta_b = beta_b.insert_row(k, res_b.imr_coeff);
        let mut xa_mean = x_a_filt.row_mean().transpose();
        let mut xb_mean = x_b_filt.row_mean().transpose();
        let imr_mean_a = res_a.imr.mean();
        let imr_mean_b = res_b.imr.mean();
        let k_x = xa_mean.len();
        xa_mean = xa_mean.insert_row(k_x, imr_mean_a);
        xb_mean = xb_mean.insert_row(k_x, imr_mean_b);
        let mut final_names = ctx.predictor_names.to_vec();
        final_names.push("IMR".to_string());
        // The two-step estimator does not expose residuals here; zero-fill so
        // downstream consumers still get vectors of the right length.
        let residuals_a = DVector::zeros(y_a_filt.len());
        let residuals_b = DVector::zeros(y_b_filt.len());
        Ok(EstimationResult {
            beta_a,
            beta_b,
            xa_mean,
            xb_mean,
            predictor_names: final_names,
            residuals_a,
            residuals_b,
            base_coeffs_a: HashMap::new(),
            base_coeffs_b: HashMap::new(),
            selection_coeffs_a: Some(res_a.selection_coeffs),
            selection_coeffs_b: Some(res_b.selection_coeffs),
            // Fix: `row_mean().transpose()` already yields an owned DVector;
            // the previous `vec_to_dvec(&...)` round-trip forced a pointless
            // extra clone of each mean vector.
            selection_means_a: Some(x_sel_a.row_mean().transpose()),
            selection_means_b: Some(x_sel_b.row_mean().transpose()),
            // NOTE(review): these are the *outcome* predictor names, not the
            // selection predictors. Preserved because the only caller ignores
            // this field (binds it as `_sel_names`) — confirm intent.
            selection_names: Some(ctx.predictor_names.to_vec()),
            imr_delta_a: Some(res_a.imr_delta),
            imr_delta_b: Some(res_b.imr_delta),
        })
    }
}
// NOTE(review): this is just `Clone::clone` behind a different name — call
// sites could clone (or move an owned value) directly; candidate for removal
// once no call sites remain.
fn vec_to_dvec(v: &DVector<f64>) -> DVector<f64> {
    v.clone()
}
impl HeckmanEstimator {
    /// Builds the selection-equation design matrices for one group.
    ///
    /// Returns `(x_sel, y_sel, x_sel_sub)`: `x_sel` covers every row of the
    /// group (intercept column first), `y_sel` is the selection outcome, and
    /// `x_sel_sub` is the same design restricted to rows where the selection
    /// outcome equals 1.
    fn prepare_selection_data(
        &self,
        df_group: &DataFrame,
    ) -> Result<(DMatrix<f64>, DVector<f64>, DMatrix<f64>), OaxacaError> {
        let y_sel_series = df_group.column(&self.selection_outcome)?.f64()?;
        let y_sel_vec: Vec<f64> = y_sel_series
            .into_iter()
            // Nulls were dropped by `clean_dataframe`; a null here is a bug.
            .map(|opt| opt.expect("Selection outcome should be clean"))
            .collect();
        let y_sel = DVector::from_vec(y_sel_vec);
        let mut x_sel_df = df_group.select(&self.selection_predictors)?;
        // Add an explicit intercept column...
        let intercept = Series::new("intercept".into(), vec![1.0; df_group.height()]);
        x_sel_df.with_column(intercept)?;
        let mut cols = vec!["intercept".to_string()];
        cols.extend(self.selection_predictors.clone());
        // ...and re-select to force intercept-first column order.
        let x_sel_df = x_sel_df.select(&cols)?;
        let x_sel_mat = x_sel_df.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
        let x_sel_vec: Vec<f64> = x_sel_mat.iter().copied().collect();
        let x_sel = DMatrix::from_row_slice(x_sel_df.height(), x_sel_df.width(), &x_sel_vec);
        // Rows with an observed outcome (selection indicator == 1).
        let mask = df_group
            .column(&self.selection_outcome)?
            .as_materialized_series()
            .equal(1)?;
        let df_subset = df_group.filter(&mask)?;
        let mut x_sel_sub_df = df_subset.select(&self.selection_predictors)?;
        x_sel_sub_df.with_column(Series::new(
            "intercept".into(),
            vec![1.0; df_subset.height()],
        ))?;
        let x_sel_sub_df = x_sel_sub_df.select(&cols)?;
        let x_sel_sub_mat = x_sel_sub_df.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
        let x_sel_sub_vec: Vec<f64> = x_sel_sub_mat.iter().copied().collect();
        let x_sel_sub =
            DMatrix::from_row_slice(x_sel_sub_df.height(), x_sel_sub_df.width(), &x_sel_sub_vec);
        Ok((x_sel, y_sel, x_sel_sub))
    }
    /// Keeps only the design-matrix/outcome rows whose selection indicator
    /// is 1; errors if no row in the group has an observed outcome.
    fn filter_outcome_rows(
        &self,
        x: &DMatrix<f64>,
        y: &DVector<f64>,
        df: &DataFrame,
    ) -> Result<(DMatrix<f64>, DVector<f64>), OaxacaError> {
        let mask = df
            .column(&self.selection_outcome)?
            .as_materialized_series()
            .equal(1)?;
        let mut rows = Vec::new();
        let mut y_vals = Vec::new();
        for i in 0..df.height() {
            if mask.get(i) == Some(true) {
                rows.push(x.row(i).into_owned());
                y_vals.push(y[i]);
            }
        }
        if rows.is_empty() {
            return Err(OaxacaError::InvalidGroupVariable(
                "No observed outcomes in group".to_string(),
            ));
        }
        let x_filtered = DMatrix::from_rows(&rows);
        let y_filtered = DVector::from_vec(y_vals);
        Ok((x_filtered, y_filtered))
    }
}
impl OaxacaBuilder {
/// Starts a new decomposition of `outcome`, comparing the levels of `group`
/// against `reference_group` (group B). Defaults: 100 bootstrap
/// replications, group-A reference coefficients, no weights, no
/// normalization, no selection equation.
pub fn new(dataframe: DataFrame, outcome: &str, group: &str, reference_group: &str) -> Self {
    Self {
        dataframe,
        outcome: outcome.to_owned(),
        group: group.to_owned(),
        reference_group: reference_group.to_owned(),
        predictors: Vec::new(),
        categorical_predictors: Vec::new(),
        bootstrap_reps: 100,
        reference_coeffs: ReferenceCoefficients::GroupA,
        normalization_vars: Vec::new(),
        weights_col: None,
        selection_outcome: None,
        selection_predictors: Vec::new(),
    }
}
pub fn from_formula(
dataframe: DataFrame,
formula: &str,
group: &str,
reference_group: &str,
) -> Result<Self, OaxacaError> {
let parsed_formula = Formula::parse(formula)?;
Ok(Self {
dataframe,
outcome: parsed_formula.outcome,
predictors: parsed_formula.predictors,
categorical_predictors: parsed_formula.categorical_predictors,
group: group.to_string(),
reference_group: reference_group.to_string(),
bootstrap_reps: 100,
reference_coeffs: ReferenceCoefficients::GroupA,
normalization_vars: Vec::new(),
weights_col: None,
selection_outcome: None,
selection_predictors: Vec::new(),
})
}
/// Chooses which coefficient vector acts as the non-discriminatory
/// reference (beta*) in the two-fold decomposition.
pub fn reference_coefficients(&mut self, reference: ReferenceCoefficients) -> &mut Self {
    self.reference_coeffs = reference;
    self
}
/// Sets the numeric predictor columns used in the outcome regressions.
pub fn predictors(&mut self, predictors: &[&str]) -> &mut Self {
    self.predictors = predictors.iter().map(|name| String::from(*name)).collect();
    self
}
/// Sets the categorical predictor columns; each is expanded into one-hot
/// dummy columns (base level omitted) before estimation.
pub fn categorical_predictors(&mut self, predictors: &[&str]) -> &mut Self {
    self.categorical_predictors = predictors.iter().map(|p| (*p).to_owned()).collect();
    self
}
/// Sets the number of bootstrap replications used for standard errors,
/// p-values and confidence intervals.
pub fn bootstrap_reps(&mut self, reps: usize) -> &mut Self {
    self.bootstrap_reps = reps;
    self
}
/// Marks categorical variables whose coefficients should be normalized
/// across levels (see `normalize_categorical_coefficients`).
pub fn normalize(&mut self, vars: &[&str]) -> &mut Self {
    self.normalization_vars = vars.iter().map(|v| String::from(*v)).collect();
    self
}
/// Uses the given column as observation weights (weighted regressions,
/// means and gap).
pub fn weights(&mut self, weights: &str) -> &mut Self {
    self.weights_col = Some(weights.to_string());
    self
}
/// Enables Heckman selection correction: `outcome` is the binary
/// selection-indicator column, `predictors` are the selection-equation
/// regressors.
pub fn heckman_selection(&mut self, outcome: &str, predictors: &[&str]) -> &mut Self {
    self.selection_outcome = Some(outcome.to_owned());
    self.selection_predictors = predictors.iter().copied().map(String::from).collect();
    self
}
/// Assembles the per-group design matrices and outcome vectors without
/// running a decomposition (useful for diagnostics).
///
/// Returns `(x_a, y_a, x_b, y_b, predictor_names)` where group B is the
/// configured reference group and group A is the first other group level
/// in sort order.
pub fn get_data_matrices(
    &self,
) -> Result<
    (
        DMatrix<f64>,
        DVector<f64>,
        DMatrix<f64>,
        DVector<f64>,
        Vec<String>,
    ),
    OaxacaError,
> {
    let df_dirty = self.dataframe.clone();
    let mut df = self.clean_dataframe(&df_dirty)?;
    // Expand categorical predictors into dummy columns appended to `df`.
    let mut all_dummy_names = Vec::new();
    if !self.categorical_predictors.is_empty() {
        for cat_pred in &self.categorical_predictors {
            let series = df.column(cat_pred)?;
            let (dummies, _, _) =
                self.create_dummies_manual(series.as_materialized_series())?;
            for s in dummies.get_columns() {
                all_dummy_names.push(s.name().to_string());
            }
            df = df.hstack(dummies.get_columns())?;
        }
    }
    // NOTE(review): group levels are read from the *uncleaned* dataframe
    // here, while the filtering below runs on the cleaned one — confirm
    // this asymmetry is intentional.
    let unique_groups = self
        .dataframe
        .column(&self.group)?
        .unique()?
        .sort(SortOptions {
            descending: false,
            nulls_last: false,
            ..Default::default()
        })?;
    if unique_groups.len() < 2 {
        return Err(OaxacaError::InvalidGroupVariable(
            "Not enough groups for comparison".to_string(),
        ));
    }
    // Group B = reference; group A = first distinct other level.
    let group_b_name = self.reference_group.as_str();
    let group_a_name_temp = unique_groups
        .str()?
        .get(0)
        .unwrap_or(self.reference_group.as_str());
    let group_a_name = if group_a_name_temp == group_b_name {
        unique_groups.str()?.get(1).unwrap_or("")
    } else {
        group_a_name_temp
    };
    let df_a = df.filter(
        &df.column(&self.group)?
            .as_materialized_series()
            .equal(group_a_name)?,
    )?;
    let df_b = df.filter(
        &df.column(&self.group)?
            .as_materialized_series()
            .equal(group_b_name)?,
    )?;
    let (x_a, y_a, _, predictor_names) = self.prepare_data(&df_a, &all_dummy_names, &[])?;
    let (x_b, y_b, _, _) = self.prepare_data(&df_b, &all_dummy_names, &[])?;
    Ok((x_a, y_a, x_b, y_b, predictor_names))
}
/// Converts one group's dataframe into `(X, y, weights, names)` for
/// regression.
///
/// The design matrix gets an explicit intercept first, then the numeric
/// predictors plus `extra_predictors`, then every dummy column listed in
/// `all_dummy_names` (zero-filled when a dummy level does not occur in this
/// group, so both groups share an identical column layout).
///
/// Fix: two occurrences of `&current_predictors` had been corrupted by an
/// HTML-entity mangling into `¤t_predictors` (`&curren` → `¤`), which does
/// not compile; the identifiers are restored.
///
/// # Errors
/// Fails if the outcome or weight column contains nulls (should have been
/// removed by `clean_dataframe`) or any polars operation fails.
fn prepare_data(
    &self,
    df: &DataFrame,
    all_dummy_names: &[String],
    extra_predictors: &[String],
) -> Result<
    (
        DMatrix<f64>,
        DVector<f64>,
        Option<DVector<f64>>,
        Vec<String>,
    ),
    OaxacaError,
> {
    let y_series = df.column(&self.outcome)?.f64()?;
    let y_vec: Vec<f64> = y_series
        .into_iter()
        .map(|opt| {
            opt.ok_or_else(|| {
                OaxacaError::PolarsError(PolarsError::ComputeError(
                    "Null values found in outcome after cleaning".into(),
                ))
            })
        })
        .collect::<Result<Vec<f64>, _>>()?;
    let y = DVector::from_vec(y_vec);
    let mut current_predictors = self.predictors.clone();
    current_predictors.extend_from_slice(extra_predictors);
    // Final column order: intercept, numeric/extra predictors, dummies.
    let mut final_predictors: Vec<String> = vec!["intercept".to_string()];
    final_predictors.extend_from_slice(&current_predictors);
    final_predictors.extend_from_slice(all_dummy_names);
    let mut x_df = df.select(&current_predictors)?;
    let intercept_series = Series::new("intercept".into(), vec![1.0; df.height()]);
    x_df.with_column(intercept_series)?;
    for name in all_dummy_names {
        if df
            .get_column_names()
            .iter()
            .any(|s| s.as_str() == name.as_str())
        {
            x_df.with_column(df.column(name)?.clone())?;
        } else {
            // Dummy level never occurs in this group: add an all-zero column.
            let zero_series = Series::new(name.into(), vec![0.0; df.height()]);
            x_df.with_column(zero_series)?;
        }
    }
    let x_df_selected = x_df.select(&final_predictors)?;
    let x_matrix = x_df_selected.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
    let x_vec: Vec<f64> = x_matrix.iter().copied().collect();
    let final_names = x_df_selected
        .get_column_names()
        .iter()
        .map(|s| s.to_string())
        .collect();
    let weights = if let Some(w_col) = &self.weights_col {
        let w_series = df.column(w_col)?.f64()?;
        let w_vec: Vec<f64> = w_series
            .into_iter()
            .map(|opt| {
                opt.ok_or_else(|| {
                    OaxacaError::PolarsError(PolarsError::ComputeError(
                        "Null weights found after cleaning".into(),
                    ))
                })
            })
            .collect::<Result<Vec<f64>, _>>()?;
        Some(DVector::from_vec(w_vec))
    } else {
        None
    };
    Ok((
        DMatrix::from_row_slice(x_df_selected.height(), x_df_selected.width(), &x_vec),
        y,
        weights,
        final_names,
    ))
}
/// One-hot encodes `series`, omitting the first (sorted) level as the base
/// category.
///
/// Returns the dummy columns (each named `{var}_{level}`, values 0.0/1.0),
/// the total number of levels, and the name of the omitted base dummy.
fn create_dummies_manual(
    &self,
    series: &Series,
) -> Result<(DataFrame, usize, String), OaxacaError> {
    let unique_vals = series.unique()?.sort(SortOptions {
        descending: false,
        nulls_last: false,
        ..Default::default()
    })?;
    let m = unique_vals.len();
    let mut dummy_vars: Vec<Series> = Vec::new();
    // First sorted level becomes the omitted reference category.
    let reference_val = if let Some(s) = unique_vals.str()?.get(0) {
        s
    } else {
        return Err(OaxacaError::InvalidGroupVariable(format!(
            "Could not get reference category for {}",
            series.name()
        )));
    };
    let reference_name = format!("{}_{}", series.name(), reference_val);
    // Build a 0/1 indicator (cast to f64) for every level except the first.
    for val in unique_vals.str()?.into_iter().flatten().skip(1) {
        let dummy_name = format!("{}_{}", series.name(), val);
        let ca = series.equal(val)?;
        let mut dummy_series = ca.into_series();
        dummy_series = dummy_series.cast(&DataType::Float64)?;
        dummy_series.rename(dummy_name.as_str().into());
        dummy_vars.push(dummy_series);
    }
    Ok((
        DataFrame::new(dummy_vars.into_iter().map(Column::Series).collect())
            .map_err(OaxacaError::from)?,
        m,
        reference_name,
    ))
}
/// One full decomposition pass over `df`: split into the two groups,
/// estimate (OLS or Heckman), then compute the three-fold, two-fold and
/// detailed components. Called once for the point estimates and once per
/// bootstrap replication.
fn run_single_pass(
    &self,
    df: &DataFrame,
    all_dummy_names: &[String],
    category_counts: &std::collections::HashMap<String, usize>,
    base_categories: &std::collections::HashMap<String, String>,
) -> Result<SinglePassResult, OaxacaError> {
    let unique_groups = df.column(&self.group)?.unique()?.sort(SortOptions {
        descending: false,
        nulls_last: false,
        ..Default::default()
    })?;
    if unique_groups.len() < 2 {
        return Err(OaxacaError::InvalidGroupVariable(
            "Not enough groups for comparison".to_string(),
        ));
    }
    // Group B = configured reference; group A = first distinct other level.
    let group_b_name = self.reference_group.as_str();
    let group_a_name = unique_groups
        .str()?
        .get(0)
        .unwrap_or(self.reference_group.as_str());
    let group_a_name = if group_a_name == group_b_name {
        unique_groups.str()?.get(1).unwrap_or("")
    } else {
        group_a_name
    };
    let df_a = df.filter(
        &df.column(&self.group)?
            .as_materialized_series()
            .equal(group_a_name)?,
    )?;
    let df_b = df.filter(
        &df.column(&self.group)?
            .as_materialized_series()
            .equal(group_b_name)?,
    )?;
    if df_a.height() == 0 || df_b.height() == 0 {
        return Err(OaxacaError::InvalidGroupVariable(
            "One group has no data".to_string(),
        ));
    }
    let (x_a, y_a, w_a, predictor_names) = self.prepare_data(&df_a, all_dummy_names, &[])?;
    let (x_b, y_b, w_b, _) = self.prepare_data(&df_b, all_dummy_names, &[])?;
    let ctx = EstimationContext {
        df_a: &df_a,
        df_b: &df_b,
        x_a: &x_a,
        y_a: &y_a,
        w_a: &w_a,
        x_b: &x_b,
        y_b: &y_b,
        w_b: &w_b,
        predictor_names: &predictor_names,
        category_counts,
    };
    // Heckman two-step when a selection equation is configured, otherwise
    // plain (optionally weighted) OLS.
    let estimator: Box<dyn Estimator> = if let Some(sel_outcome) = &self.selection_outcome {
        Box::new(HeckmanEstimator {
            selection_outcome: sel_outcome.clone(),
            selection_predictors: self.selection_predictors.clone(),
        })
    } else {
        Box::new(OlsEstimator {
            normalization_vars: self.normalization_vars.clone(),
        })
    };
    let result = estimator.estimate(&ctx)?;
    let beta_a = &result.beta_a;
    let beta_b = &result.beta_b;
    let xa_mean = result.xa_mean;
    let xb_mean = result.xb_mean;
    let final_predictor_names = result.predictor_names;
    let residuals_a = result.residuals_a;
    let residuals_b = result.residuals_b;
    let base_coeffs_a = result.base_coeffs_a;
    let base_coeffs_b = result.base_coeffs_b;
    // Detailed selection components (Heckman only): contribution of each
    // selection predictor to the gap via the IMR term.
    let mut detailed_selection_components = Vec::new();
    if let (
        Some(gamma_a),
        Some(gamma_b),
        Some(z_mean_a),
        Some(z_mean_b),
        Some(delta_a),
        Some(delta_b),
        Some(_sel_names),
    ) = (
        &result.selection_coeffs_a,
        &result.selection_coeffs_b,
        &result.selection_means_a,
        &result.selection_means_b,
        result.imr_delta_a,
        result.imr_delta_b,
        &result.selection_names,
    ) {
        // theta = IMR coefficient (last entry of the outcome beta).
        // NOTE(review): every reference scheme other than GroupA falls back
        // to group B's parameters here — confirm that approximation is
        // intended for Pooled/Neumark/Weighted/Cotton.
        let (theta_ref, delta_ref, gamma_ref) = match self.reference_coeffs {
            ReferenceCoefficients::GroupA => (beta_a[beta_a.len() - 1], delta_a, gamma_a),
            ReferenceCoefficients::GroupB => (beta_b[beta_b.len() - 1], delta_b, gamma_b),
            _ => (beta_b[beta_b.len() - 1], delta_b, gamma_b),
        };
        let mut full_sel_names = vec!["intercept".to_string()];
        full_sel_names.extend(self.selection_predictors.clone());
        // Guard against dimension mismatches before indexing.
        if gamma_ref.len() == full_sel_names.len() && z_mean_a.len() == full_sel_names.len() {
            for (i, name) in full_sel_names.iter().enumerate() {
                let diff_z = z_mean_a[i] - z_mean_b[i];
                let contribution = theta_ref * delta_ref * gamma_ref[i] * diff_z;
                detailed_selection_components.push(DetailedComponent {
                    variable_name: name.clone(),
                    contribution,
                });
            }
        }
    }
    // Build the reference ("non-discriminatory") coefficient vector beta*.
    // `beta_star_owned` keeps computed vectors alive for the borrow cases.
    let mut base_coeffs_star = std::collections::HashMap::new();
    let beta_star_owned: DVector<f64>;
    let beta_star: &DVector<f64> = match self.reference_coeffs {
        ReferenceCoefficients::GroupA => {
            base_coeffs_star = base_coeffs_a.clone();
            beta_a
        }
        ReferenceCoefficients::GroupB => {
            base_coeffs_star = base_coeffs_b.clone();
            beta_b
        }
        ReferenceCoefficients::Pooled | ReferenceCoefficients::Neumark => {
            // Pooled regression over both groups with a group indicator; the
            // indicator's coefficient is then removed from beta*.
            let mut df_pooled = df_a.vstack(&df_b)?;
            let group_indicator = Series::new(
                "group_indicator".into(),
                df_pooled
                    .column(&self.group)?
                    .as_materialized_series()
                    .equal(group_a_name)?
                    .into_series()
                    .cast(&DataType::Float64)?,
            );
            df_pooled.with_column(group_indicator)?;
            let (x_pooled, y_pooled, w_pooled, pooled_predictor_names) = self.prepare_data(
                &df_pooled,
                all_dummy_names,
                &["group_indicator".to_string()],
            )?;
            let mut ols_pooled = ols(&y_pooled, &x_pooled, w_pooled.as_ref())?;
            if !self.normalization_vars.is_empty() {
                // Pooled design mean = sample-size-weighted average of the
                // per-group means.
                let n_a = df_a.height() as f64;
                let n_b = df_b.height() as f64;
                let x_pool_mean = (xa_mean.clone() * n_a + xb_mean.clone() * n_b) / (n_a + n_b);
                base_coeffs_star = normalize_categorical_coefficients(
                    &mut ols_pooled,
                    &pooled_predictor_names,
                    &self.normalization_vars,
                    &x_pool_mean,
                    category_counts,
                );
            }
            let group_indicator_idx = pooled_predictor_names
                .iter()
                .position(|r| r == "group_indicator")
                .ok_or_else(|| {
                    OaxacaError::NalgebraError(
                        "group_indicator not found in pooled model predictors".to_string(),
                    )
                })?;
            beta_star_owned = ols_pooled.coefficients.remove_row(group_indicator_idx);
            &beta_star_owned
        }
        ReferenceCoefficients::Weighted | ReferenceCoefficients::Cotton => {
            // beta* as a sample-size (or weight-sum) weighted average of the
            // two group coefficient vectors.
            let n_a = if let Some(w) = &w_a {
                w.sum()
            } else {
                df_a.height() as f64
            };
            let n_b = if let Some(w) = &w_b {
                w.sum()
            } else {
                df_b.height() as f64
            };
            let total_n = n_a + n_b;
            if total_n == 0.0 {
                return Err(OaxacaError::InvalidGroupVariable(
                    "No data in groups for weighted coefficients.".to_string(),
                ));
            }
            let weight_a = n_a / total_n;
            let weight_b = 1.0 - weight_a;
            if !self.normalization_vars.is_empty() {
                for var in &self.normalization_vars {
                    let coeff_a = base_coeffs_a.get(var).unwrap_or(&0.0);
                    let coeff_b = base_coeffs_b.get(var).unwrap_or(&0.0);
                    base_coeffs_star
                        .insert(var.clone(), coeff_a * weight_a + coeff_b * weight_b);
                }
            }
            beta_star_owned = beta_a * weight_a + beta_b * weight_b;
            &beta_star_owned
        }
    };
    let three_fold = three_fold_decomposition(&xa_mean, &xb_mean, beta_a, beta_b);
    let mut two_fold = two_fold_decomposition(&xa_mean, &xb_mean, beta_a, beta_b, beta_star);
    let (mut detailed_explained, mut detailed_unexplained) = detailed_decomposition(
        &xa_mean,
        &xb_mean,
        beta_a,
        beta_b,
        beta_star,
        &final_predictor_names,
    );
    // Add the contribution of each normalized categorical's base (omitted)
    // level, which the regression itself cannot report directly.
    if !self.normalization_vars.is_empty() && self.selection_outcome.is_none() {
        for var in &self.normalization_vars {
            let base_dummy_name = if let Some(name) = base_categories.get(var) {
                name
            } else {
                continue;
            };
            let dummy_indices: Vec<usize> = final_predictor_names
                .iter()
                .enumerate()
                .filter(|(_, name)| name.starts_with(&format!("{}_", var)))
                .map(|(i, _)| i)
                .collect();
            // Base-level share = 1 - sum of the included dummy shares.
            let xa_mean_base = 1.0 - dummy_indices.iter().map(|&i| xa_mean[i]).sum::<f64>();
            let xb_mean_base = 1.0 - dummy_indices.iter().map(|&i| xb_mean[i]).sum::<f64>();
            let beta_a_base = base_coeffs_a.get(var).cloned().unwrap_or(0.0);
            let beta_b_base = base_coeffs_b.get(var).cloned().unwrap_or(0.0);
            let beta_star_base = base_coeffs_star.get(var).cloned().unwrap_or(0.0);
            let contribution_unexplained = xa_mean_base * (beta_a_base - beta_star_base)
                + xb_mean_base * (beta_star_base - beta_b_base);
            let contribution_explained = (xa_mean_base - xb_mean_base) * beta_star_base;
            detailed_unexplained.push(DetailedComponent {
                variable_name: base_dummy_name.clone(),
                contribution: contribution_unexplained,
            });
            detailed_explained.push(DetailedComponent {
                variable_name: base_dummy_name.clone(),
                contribution: contribution_explained,
            });
            two_fold.explained += contribution_explained;
            two_fold.unexplained += contribution_unexplained;
        }
    }
    // Raw gap = (weighted) mean outcome of group A minus that of group B.
    let total_gap = if let Some(w) = &w_a {
        y_a.dot(w) / w.sum()
    } else {
        y_a.mean()
    } - if let Some(w) = &w_b {
        y_b.dot(w) / w.sum()
    } else {
        y_b.mean()
    };
    Ok(SinglePassResult {
        three_fold,
        two_fold,
        detailed_explained,
        detailed_unexplained,
        total_gap,
        residuals_a,
        residuals_b,
        xa_mean: xa_mean.clone(),
        xb_mean: xb_mean.clone(),
        beta_star: beta_star.clone(),
        detailed_selection: detailed_selection_components,
    })
}
/// RIF (recentered influence function) decomposition at `quantile`:
/// computes the per-group RIF of the outcome and reruns the standard mean
/// decomposition on the transformed data.
///
/// NOTE(review): the Heckman selection settings are not copied onto the
/// inner builder — confirm selection is intentionally unsupported here.
pub fn decompose_quantile(&self, quantile: f64) -> Result<OaxacaResults, OaxacaError> {
    let df_dirty = self.dataframe.clone();
    let df = self.clean_dataframe(&df_dirty)?;
    let unique_groups = df.column(&self.group)?.unique()?.sort(SortOptions {
        descending: false,
        nulls_last: false,
        ..Default::default()
    })?;
    if unique_groups.len() < 2 {
        return Err(OaxacaError::InvalidGroupVariable(
            "Not enough groups".to_string(),
        ));
    }
    // Group B = configured reference; group A = first distinct other level.
    let group_b_name = self.reference_group.as_str();
    let group_a_name_temp = unique_groups
        .str()?
        .get(0)
        .unwrap_or(self.reference_group.as_str());
    let group_a_name = if group_a_name_temp == group_b_name {
        unique_groups.str()?.get(1).unwrap_or("")
    } else {
        group_a_name_temp
    };
    let df_a = df.filter(
        &df.column(&self.group)?
            .as_materialized_series()
            .equal(group_a_name)?,
    )?;
    let df_b = df.filter(
        &df.column(&self.group)?
            .as_materialized_series()
            .equal(group_b_name)?,
    )?;
    // Per-group RIF of the outcome at the requested quantile.
    let rif_a = calculate_rif(
        df_a.column(&self.outcome)?.as_materialized_series(),
        quantile,
    )
    .map_err(OaxacaError::PolarsError)?;
    let rif_b = calculate_rif(
        df_b.column(&self.outcome)?.as_materialized_series(),
        quantile,
    )
    .map_err(OaxacaError::PolarsError)?;
    // Attach the RIF columns and restack the groups. Assumes the RIF series
    // carries the outcome's column name so `with_column` replaces it —
    // TODO confirm against `calculate_rif`.
    let mut df_a_mod = df_a.clone();
    df_a_mod.with_column(rif_a)?;
    let mut df_b_mod = df_b.clone();
    df_b_mod.with_column(rif_b)?;
    let df_mod = df_a_mod.vstack(&df_b_mod)?;
    // Rebuild a builder with the same settings and run the decomposition on
    // the transformed data.
    let mut builder =
        OaxacaBuilder::new(df_mod, &self.outcome, &self.group, &self.reference_group);
    builder
        .predictors(
            &self
                .predictors
                .iter()
                .map(|s| s.as_str())
                .collect::<Vec<_>>(),
        )
        .categorical_predictors(
            &self
                .categorical_predictors
                .iter()
                .map(|s| s.as_str())
                .collect::<Vec<_>>(),
        )
        .bootstrap_reps(self.bootstrap_reps)
        .reference_coefficients(self.reference_coeffs.clone())
        .normalize(
            &self
                .normalization_vars
                .iter()
                .map(|s| s.as_str())
                .collect::<Vec<_>>(),
        );
    if let Some(w) = &self.weights_col {
        builder.weights(w);
    }
    builder.run()
}
/// Restricts `df` to the columns this analysis needs and drops every row
/// with a null in any of them.
///
/// # Errors
/// `ColumnNotFound` for the first required column missing from `df`;
/// otherwise any error from the null-dropping itself.
fn clean_dataframe(&self, df: &DataFrame) -> Result<DataFrame, OaxacaError> {
    // Same column order as before: outcome, group, predictors,
    // categoricals, weights, selection outcome, selection predictors.
    let mut cols = vec![self.outcome.clone(), self.group.clone()];
    cols.extend(self.predictors.iter().cloned());
    cols.extend(self.categorical_predictors.iter().cloned());
    cols.extend(self.weights_col.iter().cloned());
    cols.extend(self.selection_outcome.iter().cloned());
    cols.extend(self.selection_predictors.iter().cloned());
    if let Some(missing) = cols.iter().find(|c| df.column(c.as_str()).is_err()) {
        return Err(OaxacaError::ColumnNotFound(missing.clone()));
    }
    df.drop_nulls(Some(&cols)).map_err(OaxacaError::PolarsError)
}
/// Runs the full decomposition: point estimates on the cleaned data plus
/// `bootstrap_reps` stratified (per-group, with replacement) bootstrap
/// replications in parallel, producing standard errors, p-values and CIs.
pub fn run(&self) -> Result<OaxacaResults, OaxacaError> {
    let df_dirty = self.dataframe.clone();
    let mut df = self.clean_dataframe(&df_dirty)?;
    // Expand categoricals into dummy columns; remember per-variable level
    // counts and base (omitted) categories for normalization later on.
    let mut all_dummy_names = Vec::new();
    let mut category_counts = std::collections::HashMap::new();
    let mut base_categories = std::collections::HashMap::new();
    if !self.categorical_predictors.is_empty() {
        for cat_pred in &self.categorical_predictors {
            let series = df.column(cat_pred)?;
            let (dummies, m, base_name) =
                self.create_dummies_manual(series.as_materialized_series())?;
            category_counts.insert(cat_pred.clone(), m);
            base_categories.insert(cat_pred.clone(), base_name);
            for s in dummies.get_columns() {
                all_dummy_names.push(s.name().to_string());
            }
            df = df.hstack(dummies.get_columns())?;
        }
    }
    // NOTE(review): group levels are read from the *uncleaned* dataframe,
    // while all filtering below runs on the cleaned one — confirm intended.
    let unique_groups = self
        .dataframe
        .column(&self.group)?
        .unique()?
        .sort(SortOptions {
            descending: false,
            nulls_last: false,
            ..Default::default()
        })?;
    let group_b_name = self.reference_group.as_str();
    let group_a_name_temp = unique_groups
        .str()?
        .get(0)
        .unwrap_or(self.reference_group.as_str());
    let group_a_name = if group_a_name_temp == group_b_name {
        unique_groups.str()?.get(1).unwrap_or("")
    } else {
        group_a_name_temp
    };
    use rayon::prelude::*;
    let point_estimates =
        self.run_single_pass(&df, &all_dummy_names, &category_counts, &base_categories)?;
    // Owned copies so the parallel closure below does not borrow locals.
    let group_a_name_owned = group_a_name.to_string();
    let group_b_name_owned = group_b_name.to_string();
    // Stratified bootstrap: resample each group with replacement and rerun
    // the whole pass. Failed replications are dropped (and counted below).
    let bootstrap_results: Vec<SinglePassResult> = (0..self.bootstrap_reps)
        .into_par_iter()
        .filter_map(|_| {
            let df_a = df
                .filter(
                    &df.column(&self.group)
                        .ok()?
                        .as_materialized_series()
                        .equal(group_a_name_owned.as_str())
                        .ok()?,
                )
                .ok()?;
            let df_b = df
                .filter(
                    &df.column(&self.group)
                        .ok()?
                        .as_materialized_series()
                        .equal(group_b_name_owned.as_str())
                        .ok()?,
                )
                .ok()?;
            let sample_a = df_a
                .sample_n_literal(df_a.height(), true, false, None)
                .ok()?;
            let sample_b = df_b
                .sample_n_literal(df_b.height(), true, false, None)
                .ok()?;
            let sample_df = sample_a.vstack(&sample_b).ok()?;
            self.run_single_pass(
                &sample_df,
                &all_dummy_names,
                &category_counts,
                &base_categories,
            )
            .ok()
        })
        .collect();
    let successful_bootstraps = bootstrap_results.len();
    if successful_bootstraps < self.bootstrap_reps {
        eprintln!(
            "Warning: {} out of {} bootstrap replications failed and were discarded. The analysis is based on {} successful replications.",
            self.bootstrap_reps - successful_bootstraps, self.bootstrap_reps, successful_bootstraps
        );
    }
    // Turn a point estimate plus its bootstrap draws into a ComponentResult
    // with inference statistics attached.
    let process_component = |name: &str, point: f64, estimates: Vec<f64>| {
        let (std_err, p_value, (ci_lower, ci_upper)) = bootstrap_stats(&estimates, point);
        // Avoid dividing by a (near-)zero standard error.
        let t_stat = if std_err.abs() > 1e-9 {
            point / std_err
        } else {
            0.0
        };
        ComponentResult {
            name: name.to_string(),
            estimate: point,
            std_err,
            t_stat,
            p_value,
            ci_lower,
            ci_upper,
        }
    };
    let two_fold_agg = vec![
        process_component(
            "explained",
            point_estimates.two_fold.explained,
            bootstrap_results
                .iter()
                .map(|r| r.two_fold.explained)
                .collect(),
        ),
        process_component(
            "unexplained",
            point_estimates.two_fold.unexplained,
            bootstrap_results
                .iter()
                .map(|r| r.two_fold.unexplained)
                .collect(),
        ),
    ];
    let three_fold_agg = vec![
        process_component(
            "endowments",
            point_estimates.three_fold.endowments,
            bootstrap_results
                .iter()
                .map(|r| r.three_fold.endowments)
                .collect(),
        ),
        process_component(
            "coefficients",
            point_estimates.three_fold.coefficients,
            bootstrap_results
                .iter()
                .map(|r| r.three_fold.coefficients)
                .collect(),
        ),
        process_component(
            "interaction",
            point_estimates.three_fold.interaction,
            bootstrap_results
                .iter()
                .map(|r| r.three_fold.interaction)
                .collect(),
        ),
    ];
    let detailed_explained = self.process_detailed_components(
        &point_estimates.detailed_explained,
        &bootstrap_results,
        |r| &r.detailed_explained,
        &process_component,
    );
    let detailed_unexplained = self.process_detailed_components(
        &point_estimates.detailed_unexplained,
        &bootstrap_results,
        |r| &r.detailed_unexplained,
        &process_component,
    );
    let detailed_selection = self.process_detailed_components(
        &point_estimates.detailed_selection,
        &bootstrap_results,
        |r| &r.detailed_selection,
        &process_component,
    );
    Ok(OaxacaResults {
        total_gap: point_estimates.total_gap,
        two_fold: TwoFoldResults {
            aggregate: two_fold_agg,
            detailed_explained,
            detailed_unexplained,
            detailed_selection,
        },
        three_fold: DecompositionDetail {
            aggregate: three_fold_agg,
            detailed: Vec::new(),
        },
        n_a: df
            .filter(
                &df.column(&self.group)?
                    .as_materialized_series()
                    .equal(group_a_name)?,
            )?
            .height(),
        n_b: df
            .filter(
                &df.column(&self.group)?
                    .as_materialized_series()
                    .equal(group_b_name)?,
            )?
            .height(),
        // NOTE(review): only group B's residuals are exported here.
        residuals: point_estimates.residuals_b.iter().copied().collect(),
        xa_mean: point_estimates.xa_mean,
        xb_mean: point_estimates.xb_mean,
        beta_star: point_estimates.beta_star,
    })
}
/// Aggregates per-variable bootstrap draws and attaches inference
/// statistics to each point-estimate component.
///
/// `extract_fn` selects which detailed-component vector of a
/// `SinglePassResult` to aggregate; `process_component` combines
/// (name, point estimate, bootstrap draws) into a `ComponentResult`.
/// A variable that never appears in any bootstrap pass receives an empty
/// draw vector.
fn process_detailed_components<'a, F>(
    &self,
    point_components: &[DetailedComponent],
    bootstrap_results: &'a [SinglePassResult],
    extract_fn: F,
    process_component: &dyn Fn(&str, f64, Vec<f64>) -> ComponentResult,
) -> Vec<ComponentResult>
where
    F: Fn(&'a SinglePassResult) -> &'a Vec<DetailedComponent> + Sync,
{
    // Group bootstrap contributions by variable name.
    let mut bootstrap_map: HashMap<String, Vec<f64>> = HashMap::new();
    for r in bootstrap_results.iter() {
        for comp in extract_fn(r) {
            bootstrap_map
                .entry(comp.variable_name.clone())
                .or_default()
                .push(comp.contribution);
        }
    }
    point_components
        .iter()
        .map(|comp| {
            // Idiom: `unwrap_or_default` replaces `unwrap_or_else(Vec::new)`
            // (clippy: unwrap_or_default).
            let estimates = bootstrap_map
                .get(&comp.variable_name)
                .cloned()
                .unwrap_or_default();
            process_component(&comp.variable_name, comp.contribution, estimates)
        })
        .collect()
}
}
/// Aggregate and per-variable results of the two-fold decomposition.
#[derive(Debug, Getters, Serialize)]
#[getset(get = "pub")]
pub struct TwoFoldResults {
    /// The "explained" and "unexplained" aggregate components.
    pub aggregate: Vec<ComponentResult>,
    /// Per-variable explained contributions.
    pub detailed_explained: Vec<ComponentResult>,
    /// Per-variable unexplained contributions.
    pub detailed_unexplained: Vec<ComponentResult>,
    /// Per-variable selection (Heckman) contributions; empty otherwise.
    pub detailed_selection: Vec<ComponentResult>,
}
/// Full results of a decomposition run.
#[derive(Debug, Getters, Serialize)]
#[getset(get = "pub")]
pub struct OaxacaResults {
    /// Raw mean-outcome gap (group A minus group B).
    pub total_gap: f64,
    /// Aggregate and per-variable two-fold results.
    pub two_fold: TwoFoldResults,
    /// Aggregate three-fold results (detailed part currently unpopulated).
    pub three_fold: DecompositionDetail,
    /// Observation count of group A.
    pub n_a: usize,
    /// Observation count of group B (reference).
    pub n_b: usize,
    /// Residuals exported from the point-estimate pass (group B's).
    pub residuals: Vec<f64>,
    // Skipped in serialization: nalgebra vectors are internal artifacts.
    #[serde(skip)]
    pub xa_mean: DVector<f64>,
    #[serde(skip)]
    pub xb_mean: DVector<f64>,
    #[serde(skip)]
    pub beta_star: DVector<f64>,
}
impl OaxacaResults {
    /// Returns the aggregate "explained" component of the two-fold
    /// decomposition.
    ///
    /// # Panics
    /// Panics if the aggregate results contain no component named
    /// "explained".
    pub fn explained(&self) -> &ComponentResult {
        self.two_fold
            .aggregate()
            .iter()
            .find(|c| c.name == "explained")
            .expect("Explained component not found")
    }

    /// Returns the aggregate "unexplained" component of the two-fold
    /// decomposition.
    ///
    /// # Panics
    /// Panics if the aggregate results contain no component named
    /// "unexplained".
    pub fn unexplained(&self) -> &ComponentResult {
        self.two_fold
            .aggregate()
            .iter()
            .find(|c| c.name == "unexplained")
            .expect("Unexplained component not found")
    }

    /// Returns `(name, result)` pairs for each aggregate two-fold component.
    pub fn get_summary_table(&self) -> Vec<(&String, &ComponentResult)> {
        self.two_fold
            .aggregate()
            .iter()
            .map(|c| (&c.name, c))
            .collect()
    }

    /// Returns `(variable, explained, unexplained)` rows, merging the
    /// detailed explained and unexplained contributions by variable name.
    /// A variable missing from one side contributes 0.0 there.
    ///
    /// Rows are sorted by variable name so the output is deterministic
    /// (raw `HashMap` iteration order varies between runs).
    pub fn get_detailed_table(&self) -> Vec<(String, f64, f64)> {
        let mut map = std::collections::HashMap::new();
        for comp in self.two_fold.detailed_explained() {
            map.entry(comp.name().clone()).or_insert((0.0, 0.0)).0 = *comp.estimate();
        }
        for comp in self.two_fold.detailed_unexplained() {
            map.entry(comp.name().clone()).or_insert((0.0, 0.0)).1 = *comp.estimate();
        }
        let mut rows: Vec<(String, f64, f64)> =
            map.into_iter().map(|(k, (v1, v2))| (k, v1, v2)).collect();
        rows.sort_by(|a, b| a.0.cmp(&b.0));
        rows
    }

    /// Builds one display table (name, estimate, std. err., p-value, 95% CI)
    /// from a slice of component results. Shared by `summary` to avoid
    /// triplicating the loop.
    #[cfg(feature = "display")]
    fn component_table(first: &str, second: &str, components: &[ComponentResult]) -> Table {
        let mut table = Table::new();
        table.set_header(vec![first, second, "Std. Err.", "p-value", "95% CI"]);
        for component in components {
            let ci = format!("[{:.3}, {:.3}]", component.ci_lower(), component.ci_upper());
            table.add_row(vec![
                Cell::new(component.name()),
                Cell::new(format!("{:.4}", component.estimate())),
                Cell::new(format!("{:.4}", component.std_err())),
                Cell::new(format!("{:.4}", component.p_value())),
                Cell::new(ci),
            ]);
        }
        table
    }

    /// Prints a human-readable summary of the decomposition to stdout:
    /// group sizes, the total gap, the aggregate two-fold table and the
    /// detailed explained/unexplained tables.
    #[cfg(feature = "display")]
    pub fn summary(&self) {
        println!("Oaxaca-Blinder Decomposition Results");
        println!("========================================");
        println!("Group A (Advantaged): {} observations", self.n_a);
        println!("Group B (Reference): {} observations", self.n_b);
        println!("Total Gap: {:.4}", self.total_gap);
        println!();
        println!("Two-Fold Decomposition");
        println!(
            "{}",
            Self::component_table("Component", "Estimate", self.two_fold.aggregate())
        );
        println!("\nDetailed Decomposition (Explained)");
        println!(
            "{}",
            Self::component_table("Variable", "Contribution", self.two_fold.detailed_explained())
        );
        println!("\nDetailed Decomposition (Unexplained)");
        println!(
            "{}",
            Self::component_table(
                "Variable",
                "Contribution",
                self.two_fold.detailed_unexplained()
            )
        );
    }

    /// Renders the aggregate two-fold results as a LaTeX `table`
    /// environment.
    pub fn to_latex(&self) -> String {
        let mut latex = String::new();
        latex.push_str("\\begin{table}[ht]\n");
        latex.push_str("\\centering\n");
        latex.push_str("\\begin{tabular}{lcccc}\n");
        latex.push_str("\\hline\n");
        latex.push_str("Component & Estimate & Std. Err. & p-value & 95\\% CI \\\\\n");
        latex.push_str("\\hline\n");
        latex.push_str("\\multicolumn{5}{l}{\\textit{Two-Fold Decomposition}} \\\\\n");
        for component in self.two_fold.aggregate() {
            latex.push_str(&format!(
                "{} & {:.4} & {:.4} & {:.4} & [{:.3}, {:.3}] \\\\\n",
                component.name(),
                component.estimate(),
                component.std_err(),
                component.p_value(),
                component.ci_lower(),
                component.ci_upper()
            ));
        }
        latex.push_str("\\hline\n");
        latex.push_str("\\end{tabular}\n");
        latex.push_str("\\caption{Oaxaca-Blinder Decomposition Results}\n");
        latex.push_str("\\label{tab:oaxaca_results}\n");
        latex.push_str("\\end{table}\n");
        latex
    }

    /// Renders the aggregate two-fold results as a Markdown table.
    pub fn to_markdown(&self) -> String {
        let mut md = String::new();
        md.push_str("### Oaxaca-Blinder Decomposition Results\n\n");
        md.push_str("| Component | Estimate | Std. Err. | p-value | 95% CI |\n");
        md.push_str("|---|---|---|---|---|\n");
        for component in self.two_fold.aggregate() {
            md.push_str(&format!(
                "| {} | {:.4} | {:.4} | {:.4} | [{:.3}, {:.3}] |\n",
                component.name(),
                component.estimate(),
                component.std_err(),
                component.p_value(),
                component.ci_lower(),
                component.ci_upper()
            ));
        }
        md
    }

    /// Serializes the results to pretty-printed JSON.
    ///
    /// # Errors
    /// Returns a `serde_json::Error` if serialization fails.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }

    /// Allocates `budget` across group-B observations with negative
    /// residuals so as to close the gap toward `target_gap`.
    ///
    /// Observations are treated in ascending residual order (most negative
    /// first); each can be raised at most back to a zero residual, and total
    /// spending is capped at both `budget` and the amount needed to reach
    /// `target_gap` (`(gap - target) * n_b`). Returns an empty vector when
    /// the current gap already meets the target.
    pub fn optimize_budget(&self, budget: f64, target_gap: f64) -> Vec<BudgetAdjustment> {
        let current_gap = self.total_gap;
        if current_gap <= target_gap {
            return Vec::new();
        }
        let required_reduction = current_gap - target_gap;
        // Per-person reduction scales by group size to move the mean gap.
        let total_needed = required_reduction * self.n_b as f64;
        let effective_budget = budget.min(total_needed);
        let mut candidates: Vec<(usize, f64)> = self
            .residuals
            .iter()
            .enumerate()
            .filter(|(_, &r)| r < 0.0)
            .map(|(i, &r)| (i, r))
            .collect();
        // Most negative residuals first (cheapest gap reduction per unit).
        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut adjustments = Vec::new();
        let mut spent = 0.0;
        for (index, residual) in candidates {
            if spent >= effective_budget {
                break;
            }
            // Raise each observation at most back to a zero residual, and
            // never past the remaining budget.
            let max_raise = -residual;
            let raise = max_raise.min(effective_budget - spent);
            if raise > 1e-9 {
                adjustments.push(BudgetAdjustment {
                    index,
                    original_residual: residual,
                    adjustment: raise,
                });
                spent += raise;
            }
        }
        adjustments
    }
}
/// Aggregate and (optionally) per-variable results for a decomposition;
/// used here for the three-fold decomposition, where `detailed` is left
/// empty by the builder.
#[derive(Debug, Getters, Serialize)]
#[getset(get = "pub")]
pub struct DecompositionDetail {
/// Aggregate components of the decomposition.
pub aggregate: Vec<ComponentResult>,
/// Per-variable components (currently unpopulated for three-fold).
pub detailed: Vec<ComponentResult>,
}
/// A single decomposition component with its point estimate and
/// bootstrap-based inference statistics.
#[derive(Debug, Getters, Clone, Serialize)]
#[getset(get = "pub")]
pub struct ComponentResult {
/// Component identifier (e.g. "explained", or a variable name).
pub name: String,
/// Point estimate of the component's contribution.
pub estimate: f64,
/// Standard error of the estimate.
pub std_err: f64,
/// t-statistic of the estimate.
pub t_stat: f64,
/// p-value of the estimate.
pub p_value: f64,
/// Lower bound of the 95% confidence interval.
pub ci_lower: f64,
/// Upper bound of the 95% confidence interval.
pub ci_upper: f64,
}
// NOTE(review): placeholder smoke test only — it exercises nothing from
// this crate. Real coverage presumably lives in integration tests; consider
// replacing with assertions against the decomposition API.
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
}
}