use crate::OaxacaError;
use nalgebra::{DMatrix, DVector};
use statrs::distribution::{Continuous, ContinuousCDF, Normal};
/// Result of a probit maximum-likelihood fit produced by [`probit`].
#[derive(Debug)]
#[allow(dead_code)]
pub struct ProbitResult {
    /// Estimated coefficient vector (one entry per column of the design matrix).
    pub coefficients: DVector<f64>,
    /// Variance-covariance matrix of the coefficients (negative inverse Hessian).
    pub vcov: DMatrix<f64>,
    /// Whether the Newton step norm fell below the tolerance before `max_iter`.
    pub converged: bool,
    /// Number of Newton iterations actually performed.
    pub iterations: usize,
}
pub fn probit(
y: &DVector<f64>,
x: &DMatrix<f64>,
max_iter: usize,
tol: f64,
) -> Result<ProbitResult, OaxacaError> {
let n = x.nrows();
let k = x.ncols();
let mut beta = DVector::zeros(k);
let normal = Normal::new(0.0, 1.0).unwrap();
let mut converged = false;
let mut iterations = 0;
let mut h = DMatrix::zeros(k, k);
for iter in 0..max_iter {
iterations = iter + 1;
let z = x * β
let mut gradient = DVector::zeros(k);
h = DMatrix::zeros(k, k);
for i in 0..n {
let xi = x.row(i).transpose();
let yi = y[i];
let zi = z[i];
let phi = normal.pdf(zi);
let big_phi = normal.cdf(zi);
let big_phi = big_phi.clamp(1e-10, 1.0 - 1e-10);
let lambda = if yi > 0.5 {
phi / big_phi
} else {
-phi / (1.0 - big_phi)
};
gradient += &xi * lambda;
let weight = (phi * phi) / (big_phi * (1.0 - big_phi));
h -= &xi * xi.transpose() * weight;
}
for i in 0..k {
h[(i, i)] -= 1e-9;
}
let h_inv = match h.clone().try_inverse() {
Some(inv) => inv,
None => {
return Err(OaxacaError::NalgebraError(
"Failed to invert Hessian in Probit".to_string(),
))
}
};
let _delta = &h_inv * (-&gradient);
let step = -&h_inv * &gradient;
beta += &step;
if step.norm() < tol {
converged = true;
break;
}
}
let vcov = match h.try_inverse() {
Some(inv) => -inv,
None => {
return Err(OaxacaError::NalgebraError(
"Failed to invert Hessian for VCOV".to_string(),
))
}
};
Ok(ProbitResult {
coefficients: beta,
vcov,
converged,
iterations,
})
}