use crate::metrics::evaluation::Metric;
use crate::objective::ObjectiveFunction;
use serde::{Deserialize, Serialize};
// Propensity scores are clamped into [PROPENSITY_CLIP, 1 - PROPENSITY_CLIP]
// before dividing, to bound the inverse-propensity weights.
const PROPENSITY_CLIP: f64 = 1e-3;
/// Strategy used to build per-sample pseudo-outcomes for policy learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PolicyMode {
/// Inverse propensity weighting: uses only treatment indicators and
/// propensity scores.
IPW,
/// Augmented IPW (doubly robust): additionally uses outcome-model
/// predictions under treatment and under control.
AIPW {
/// Predicted outcome for each sample had it been treated (mu_1(x_i)).
mu_hat_1: Vec<f64>,
/// Predicted outcome for each sample had it been untreated (mu_0(x_i)).
mu_hat_0: Vec<f64>,
},
}
/// Objective for learning a treatment policy from observational data.
///
/// Converts observed outcomes into signed pseudo-outcomes (IPW or AIPW)
/// and fits them with a weighted binary classification loss, so that the
/// model's score favors treatment where the estimated effect is positive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyObjective {
/// Per-sample treatment indicator (0 = control, 1 = treated).
pub treatment: Vec<u8>,
/// Per-sample propensity score P(W=1 | X); clamped before use.
pub propensity: Vec<f64>,
/// Pseudo-outcome construction strategy (IPW or AIPW).
pub mode: PolicyMode,
}
impl PolicyObjective {
    /// Creates an IPW policy objective.
    ///
    /// # Panics
    /// Panics if `treatment` and `propensity` have different lengths, so a
    /// construction-time mistake fails immediately with a clear message
    /// instead of as an out-of-bounds index during training.
    pub fn new(treatment: Vec<u8>, propensity: Vec<f64>) -> Self {
        assert_eq!(
            treatment.len(),
            propensity.len(),
            "treatment and propensity must have the same length"
        );
        Self {
            treatment,
            propensity,
            mode: PolicyMode::IPW,
        }
    }

    /// Creates an AIPW (doubly robust) policy objective, where `mu_hat_1`
    /// and `mu_hat_0` are outcome-model predictions under treatment and
    /// control respectively.
    ///
    /// # Panics
    /// Panics if any of the four input vectors differ in length.
    pub fn new_aipw(
        treatment: Vec<u8>,
        propensity: Vec<f64>,
        mu_hat_1: Vec<f64>,
        mu_hat_0: Vec<f64>,
    ) -> Self {
        let n = treatment.len();
        assert_eq!(propensity.len(), n, "propensity length must match treatment");
        assert_eq!(mu_hat_1.len(), n, "mu_hat_1 length must match treatment");
        assert_eq!(mu_hat_0.len(), n, "mu_hat_0 length must match treatment");
        Self {
            treatment,
            propensity,
            mode: PolicyMode::AIPW { mu_hat_1, mu_hat_0 },
        }
    }

    /// Computes the signed pseudo-outcome `gamma_i` for each sample.
    ///
    /// IPW:  gamma_i = y_i·w_i/p_i − y_i·(1−w_i)/(1−p_i)
    /// AIPW: gamma_i = (m1_i − m0_i) + w_i·(y_i − m1_i)/p_i
    ///                 − (1−w_i)·(y_i − m0_i)/(1−p_i)
    ///
    /// Propensities are clamped to [PROPENSITY_CLIP, 1 − PROPENSITY_CLIP]
    /// to keep the inverse weights bounded.
    fn pseudo_outcome(&self, y: &[f64]) -> Vec<f64> {
        debug_assert_eq!(
            y.len(),
            self.treatment.len(),
            "outcome length must match treatment length"
        );
        (0..y.len())
            .map(|i| {
                let w = f64::from(self.treatment[i]);
                let p = self.propensity[i].clamp(PROPENSITY_CLIP, 1.0 - PROPENSITY_CLIP);
                let yi = y[i];
                match &self.mode {
                    PolicyMode::IPW => yi * w / p - yi * (1.0 - w) / (1.0 - p),
                    PolicyMode::AIPW { mu_hat_1, mu_hat_0 } => {
                        let m1 = mu_hat_1[i];
                        let m0 = mu_hat_0[i];
                        // Doubly robust score: outcome-model contrast plus
                        // propensity-weighted residual corrections.
                        (m1 - m0) + w * (yi - m1) / p - (1.0 - w) * (yi - m0) / (1.0 - p)
                    }
                }
            })
            .collect()
    }
}
impl ObjectiveFunction for PolicyObjective {
    /// Weighted binary cross-entropy on the pseudo-outcomes.
    ///
    /// Each sample is treated as a binary classification example with
    /// target 1 if `gamma_i >= 0` (treat) else 0 (don't treat), weighted by
    /// `|gamma_i|`. The BCE is evaluated in the numerically stable
    /// logits form
    ///   max(s, 0) − s·t + ln(1 + exp(−|s|)),
    /// which avoids the sigmoid saturating to exactly 0/1 in f64 for
    /// |s| ≳ 36 and the loss being silently clipped at −ln(1e-15).
    fn loss(&self, y: &[f64], yhat: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> Vec<f32> {
        let gamma = self.pseudo_outcome(y);
        gamma
            .iter()
            .zip(yhat.iter())
            .map(|(&g, &score)| {
                let target = if g >= 0.0 { 1.0 } else { 0.0 };
                let weight = g.abs();
                // Stable equivalent of -[t·ln(sigma) + (1−t)·ln(1−sigma)].
                let bce = score.max(0.0) - score * target + (-score.abs()).exp().ln_1p();
                (weight * bce) as f32
            })
            .collect()
    }

    /// First and second derivatives of the weighted BCE w.r.t. the raw
    /// score: grad = |gamma|·(sigma − target), hess = |gamma|·sigma·(1 − sigma).
    ///
    /// The sigmoid itself is numerically safe in f64 (it saturates to 0/1
    /// without producing NaN), so no rewriting is needed here.
    fn gradient(
        &self,
        y: &[f64],
        yhat: &[f64],
        _sample_weight: Option<&[f64]>,
        _group: Option<&[u64]>,
    ) -> (Vec<f32>, Option<Vec<f32>>) {
        let n = y.len();
        let gamma = self.pseudo_outcome(y);
        let mut grad = Vec::with_capacity(n);
        let mut hess = Vec::with_capacity(n);
        for i in 0..n {
            let sigma = 1.0 / (1.0 + (-yhat[i]).exp());
            let target = if gamma[i] >= 0.0 { 1.0 } else { 0.0 };
            let weight = gamma[i].abs();
            grad.push((weight * (sigma - target)) as f32);
            hess.push((weight * sigma * (1.0 - sigma)) as f32);
        }
        (grad, Some(hess))
    }

    /// Base score for boosting; 0.0 corresponds to sigma = 0.5 (no
    /// preference between treat / don't treat).
    fn initial_value(&self, _y: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> f64 {
        0.0
    }

    /// Default evaluation metric for this objective.
    fn default_metric(&self) -> Metric {
        Metric::LogLoss
    }
}