use crate::math::kde::{kde, silverman_bandwidth};
use crate::math::logit::logit;
use crate::OaxacaError;
use nalgebra::{DMatrix, DVector};
use polars::prelude::*;
use serde::Serialize;
/// Result of a DFL decomposition: three kernel density estimates of the
/// outcome, all evaluated on the shared `grid` of outcome values.
#[derive(Debug, Serialize)]
pub struct DflResult {
/// Outcome values at which each density is evaluated.
pub grid: Vec<f64>,
/// Outcome density for group A (the non-reference group).
pub density_a: Vec<f64>,
/// Outcome density for group B (the reference group).
pub density_b: Vec<f64>,
/// Outcome density for group B after DFL reweighting, i.e. the
/// counterfactual where group B's predictor distribution matches group A's.
pub density_b_counterfactual: Vec<f64>,
}
/// Runs a DiNardo–Fortin–Lemieux (DFL) decomposition.
///
/// Fits a logit for membership in group A given `predictors`, computes the
/// DFL reweighting factor `psi(x) = [P(A|x) / P(B|x)] * [P(B) / P(A)]` for
/// each group-B observation, and returns kernel density estimates of the
/// outcome for group A, group B, and the reweighted (counterfactual)
/// group B, all evaluated on a common grid spanning the observed outcome
/// range.
///
/// # Arguments
/// * `df` — data containing the outcome, group, and predictor columns.
/// * `outcome` — name of the Float64 outcome column.
/// * `group` — name of the string-typed grouping column (two levels expected).
/// * `reference_group` — level treated as group B (the reference group).
/// * `predictors` — columns entering the logit model (cast to Float64).
///
/// # Errors
/// Propagates `OaxacaError` on missing columns, dtype mismatches,
/// DataFrame construction failures, or failure of the logit fit.
pub fn run_dfl(
    df: &DataFrame,
    outcome: &str,
    group: &str,
    reference_group: &str,
    predictors: &[String],
) -> Result<DflResult, OaxacaError> {
    // --- Resolve the two group labels ------------------------------------
    // B is the caller-supplied reference; A is the other level among the
    // sorted unique values of the group column.
    let unique_groups = df.column(group)?.unique()?.sort(SortOptions {
        descending: false,
        nulls_last: false,
        ..Default::default()
    })?;
    let group_b_name = reference_group;
    let first_level = unique_groups.str()?.get(0).unwrap_or(reference_group);
    let group_a_name = if first_level == group_b_name {
        // NOTE(review): if the group column has only one level this falls
        // back to "" and no row is assigned to group A — confirm upstream
        // validation guarantees two levels.
        unique_groups.str()?.get(1).unwrap_or("")
    } else {
        first_level
    };

    // --- Logit target: 1.0 for group A, 0.0 for group B (nulls -> B) -----
    let group_series = df.column(group)?;
    let target_vec: Vec<f64> = group_series
        .str()?
        .into_iter()
        .map(|opt_s| {
            if opt_s.unwrap_or("") == group_a_name {
                1.0
            } else {
                0.0
            }
        })
        .collect();
    let y = DVector::from_vec(target_vec);

    // --- Design matrix: intercept column plus predictors, all Float64 ----
    let mut x_cols = Vec::with_capacity(predictors.len() + 1);
    let intercept = Series::new("intercept".into(), vec![1.0; df.height()]);
    x_cols.push(Column::Series(intercept));
    for pred in predictors {
        let col = df.column(pred)?.cast(&DataType::Float64)?;
        x_cols.push(col);
    }
    let x_df = DataFrame::new(x_cols).map_err(OaxacaError::from)?;
    // ndarray's `iter()` yields elements in logical (row-major) order
    // regardless of memory layout, so the flattened buffer is consumed
    // correctly by `from_row_slice`.
    let x_matrix = x_df.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
    let x_vec: Vec<f64> = x_matrix.iter().copied().collect();
    let x = DMatrix::from_row_slice(x_df.height(), x_df.width(), &x_vec);

    // Fitted P(A | x) for every observation.
    let logit_res = logit(&y, &x, 100, 1e-6)?;
    let probs = logit_res.predicted_probs;

    // --- Marginal group shares for the reweighting factor -----------------
    let n = df.height() as f64;
    let n_a = df
        .filter(
            &df.column(group)?
                .as_materialized_series()
                .equal(group_a_name)?,
        )?
        .height() as f64;
    let n_b = df
        .filter(
            &df.column(group)?
                .as_materialized_series()
                .equal(group_b_name)?,
        )?
        .height() as f64;
    let p_a_marginal = n_a / n;
    let p_b_marginal = n_b / n;
    let ratio_marginal = p_b_marginal / p_a_marginal;

    // --- DFL weights and per-group outcome vectors ------------------------
    let mut weights_counterfactual = Vec::new();
    let outcome_series = df.column(outcome)?.f64()?;
    let mut outcome_b = Vec::new();
    let mut outcome_a = Vec::new();
    for i in 0..df.height() {
        let is_group_b = y[i] == 0.0;
        // NOTE(review): a null outcome is silently treated as 0.0 here —
        // confirm nulls are filtered out upstream.
        let val = outcome_series.get(i).unwrap_or(0.0);
        if is_group_b {
            // Clamp fitted probabilities away from {0, 1} so the odds
            // ratio stays finite.
            let p_x = probs[i].clamp(0.0001, 0.9999);
            let weight = (p_x / (1.0 - p_x)) * ratio_marginal;
            weights_counterfactual.push(weight);
            outcome_b.push(val);
        } else {
            outcome_a.push(val);
        }
    }

    // --- Common evaluation grid spanning [min, max] inclusive -------------
    let all_outcomes: Vec<f64> = outcome_series.into_no_null_iter().collect();
    let min_val = all_outcomes.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = all_outcomes
        .iter()
        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let grid_size = 100usize;
    // Dividing by (grid_size - 1) makes the last grid point land exactly on
    // max_val; the previous `range / grid_size` step stopped one step short,
    // leaving the densities unevaluated at the upper tail of the sample.
    let step = (max_val - min_val) / (grid_size - 1) as f64;
    let grid: Vec<f64> = (0..grid_size).map(|i| min_val + i as f64 * step).collect();

    // --- Kernel density estimates -----------------------------------------
    let bandwidth_a = silverman_bandwidth(&outcome_a);
    let bandwidth_b = silverman_bandwidth(&outcome_b);
    let density_a = kde(&outcome_a, None, &grid, bandwidth_a);
    let density_b = kde(&outcome_b, None, &grid, bandwidth_b);
    let density_b_counterfactual = kde(
        &outcome_b,
        Some(&weights_counterfactual),
        &grid,
        bandwidth_b,
    );

    Ok(DflResult {
        grid,
        density_a,
        density_b,
        density_b_counterfactual,
    })
}