use crate::matching::distance::{DistanceMetric, MahalanobisDistance};
use crate::matching::logistic::LogisticRegression;
use crate::OaxacaError;
use kdtree::distance::squared_euclidean;
use kdtree::KdTree;
use nalgebra::{DMatrix, DVector};
use polars::prelude::*;
pub struct MatchingEngine {
dataframe: DataFrame,
treatment_col: String,
outcome_col: String,
covariates: Vec<String>,
}
impl MatchingEngine {
pub fn new(
dataframe: DataFrame,
treatment_col: &str,
outcome_col: &str,
covariates: &[&str],
) -> Self {
Self {
dataframe,
treatment_col: treatment_col.to_string(),
outcome_col: outcome_col.to_string(),
covariates: covariates.iter().map(|s| s.to_string()).collect(),
}
}
fn prepare_data(
&self,
) -> Result<(DMatrix<f64>, DMatrix<f64>, DVector<f64>, DVector<f64>), OaxacaError> {
let df = &self.dataframe;
let treated_mask = df
.column(&self.treatment_col)?
.as_materialized_series()
.equal(1)?;
let control_mask = df
.column(&self.treatment_col)?
.as_materialized_series()
.equal(0)?;
let treated_df = df.filter(&treated_mask)?;
let control_df = df.filter(&control_mask)?;
if treated_df.height() == 0 || control_df.height() == 0 {
return Err(OaxacaError::InvalidGroupVariable(
"One group is empty".to_string(),
));
}
let get_matrix = |d: &DataFrame| -> Result<DMatrix<f64>, OaxacaError> {
let selected = d.select(&self.covariates)?;
let mat = selected.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
let vec: Vec<f64> = mat.iter().copied().collect();
Ok(DMatrix::from_row_slice(
selected.height(),
selected.width(),
&vec,
))
};
let x_treated = get_matrix(&treated_df)?;
let x_control = get_matrix(&control_df)?;
let get_vec = |d: &DataFrame| -> Result<DVector<f64>, OaxacaError> {
let s = d.column(&self.outcome_col)?.f64()?;
let v: Vec<f64> = s.into_iter().map(|opt| opt.unwrap_or(0.0)).collect();
Ok(DVector::from_vec(v))
};
let y_treated = get_vec(&treated_df)?;
let y_control = get_vec(&control_df)?;
Ok((x_treated, x_control, y_treated, y_control))
}
pub fn match_nearest_neighbor<M: DistanceMetric>(
&self,
k: usize,
metric: &M,
) -> Result<DVector<f64>, OaxacaError> {
let use_mahalanobis = metric.is_mahalanobis();
let weights = self.run_matching(k, use_mahalanobis)?;
Ok(DVector::from_vec(weights))
}
pub fn run_matching(&self, k: usize, use_mahalanobis: bool) -> Result<Vec<f64>, OaxacaError> {
let mut df = self.dataframe.clone();
let indices: Vec<u32> = (0..df.height() as u32).collect();
let df = df
.with_column(Series::new("orig_index".into(), indices))?
.clone();
let treated_mask = df
.column(&self.treatment_col)?
.as_materialized_series()
.equal(1)?;
let control_mask = df
.column(&self.treatment_col)?
.as_materialized_series()
.equal(0)?;
let treated_df = df.filter(&treated_mask)?;
let control_df = df.filter(&control_mask)?;
let get_matrix = |d: &DataFrame| -> Result<DMatrix<f64>, OaxacaError> {
let selected = d.select(&self.covariates)?;
let mat = selected.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
let vec: Vec<f64> = mat.iter().copied().collect();
Ok(DMatrix::from_row_slice(
selected.height(),
selected.width(),
&vec,
))
};
let mut x_treated = get_matrix(&treated_df)?;
let mut x_control = get_matrix(&control_df)?;
if use_mahalanobis {
let metric =
MahalanobisDistance::new(&x_control).map_err(OaxacaError::DiagnosticError)?;
let inv_cov = metric.inv_covariance;
let cholesky = inv_cov.cholesky().ok_or(OaxacaError::DiagnosticError(
"Cholesky decomposition failed".to_string(),
))?;
let l = cholesky.l();
x_treated = x_treated * &l;
x_control = x_control * &l;
}
let n_features = x_control.ncols();
let mut kdtree = KdTree::new(n_features);
for i in 0..x_control.nrows() {
let row = x_control.row(i);
let point: Vec<f64> = row.iter().copied().collect();
kdtree
.add(point, i)
.map_err(|e| OaxacaError::DiagnosticError(format!("K-D Tree error: {}", e)))?;
}
let mut control_counts = vec![0.0; x_control.nrows()];
for i in 0..x_treated.nrows() {
let row = x_treated.row(i);
let point: Vec<f64> = row.iter().copied().collect();
let nearest = kdtree.nearest(&point, k, &squared_euclidean).map_err(|e| {
OaxacaError::DiagnosticError(format!("K-D Tree search error: {}", e))
})?;
for (_dist, &index) in nearest {
control_counts[index] += 1.0 / (k as f64);
}
}
let mut final_weights = vec![0.0; df.height()];
let treated_indices = treated_df.column("orig_index")?.u32()?;
for idx in treated_indices {
if let Some(i) = idx {
final_weights[i as usize] = 1.0;
}
}
let control_indices = control_df.column("orig_index")?.u32()?;
for (local_idx, orig_idx) in control_indices.into_iter().enumerate() {
if let Some(i) = orig_idx {
final_weights[i as usize] = control_counts[local_idx];
}
}
Ok(final_weights)
}
pub fn match_psm(&self, k: usize) -> Result<Vec<f64>, OaxacaError> {
let (_x_treated, _x_control, _, _) = self.prepare_data()?;
let df = &self.dataframe;
let get_matrix = |d: &DataFrame| -> Result<DMatrix<f64>, OaxacaError> {
let selected = d.select(&self.covariates)?;
let mat = selected.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
let vec: Vec<f64> = mat.iter().copied().collect();
Ok(DMatrix::from_row_slice(
selected.height(),
selected.width(),
&vec,
))
};
let x_full = get_matrix(df)?;
let mut x_full_intercept = x_full.clone();
x_full_intercept = x_full_intercept.insert_column(0, 1.0);
let y_treatment = df.column(&self.treatment_col)?.f64()?;
let y_vec: Vec<f64> = y_treatment
.into_iter()
.map(|opt| opt.unwrap_or(0.0))
.collect();
let y_full = DVector::from_vec(y_vec);
let mut logit = LogisticRegression::new();
logit
.fit(&x_full_intercept, &y_full, 100, 1e-6)
.map_err(OaxacaError::DiagnosticError)?;
let scores = logit.predict_proba(&x_full_intercept);
let mut df_with_score = df.clone();
let score_series = Series::new("propensity_score".into(), scores.as_slice());
df_with_score.with_column(score_series)?;
let engine = MatchingEngine::new(
df_with_score,
&self.treatment_col,
&self.outcome_col,
&["propensity_score"],
);
engine.run_matching(k, false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::matching::distance::EuclideanDistance;
fn create_dummy_df() -> DataFrame {
let s0 = Series::new("treatment".into(), &[1, 1, 0, 0, 0]);
let s1 = Series::new("outcome".into(), &[10.0, 12.0, 8.0, 9.0, 8.5]);
let s2 = Series::new("age".into(), &[30.0, 40.0, 31.0, 41.0, 35.0]);
let s3 = Series::new("education".into(), &[12.0, 16.0, 12.0, 16.0, 14.0]);
DataFrame::new(vec![
Column::Series(s0),
Column::Series(s1),
Column::Series(s2),
Column::Series(s3),
])
.unwrap()
}
#[test]
fn test_match_nearest_neighbor_euclidean() {
let df = create_dummy_df();
let engine = MatchingEngine::new(df, "treatment", "outcome", &["age", "education"]);
let weights = engine
.match_nearest_neighbor(1, &EuclideanDistance)
.unwrap();
assert_eq!(weights.len(), 5);
assert_eq!(weights[0], 1.0);
assert_eq!(weights[1], 1.0);
}
}