use crate::metrics::evaluation::Metric;
use crate::objective::ObjectiveFunction;
use serde::{Deserialize, Serialize};
/// Lower bound for clamping propensity scores away from 0, so the
/// treatment residual `w - p` never uses a degenerate propensity.
const PROPENSITY_CLIP_MIN: f64 = 1e-6;
/// Upper bound for clamping propensity scores away from 1.
const PROPENSITY_CLIP_MAX: f64 = 1.0 - 1e-6;
/// Floor applied to the per-sample hessian (and used as the degeneracy
/// threshold for the denominator in `initial_value`) to keep Newton
/// steps finite when the treatment residual is near zero.
const HESSIAN_FLOOR: f64 = 1e-6;
/// R-learner objective for heterogeneous treatment-effect estimation.
///
/// Each boosting prediction `tau(x)` is fit to minimize the
/// residual-on-residual loss `(y_res - tau * w_res)^2`, where
/// `y_res = y - outcome_predicted` and `w_res = treatment - treatment_predicted`
/// (with the propensity clipped away from 0 and 1).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLearnerObjective {
// Observed treatment assignment per sample (presumably 0.0/1.0 — the
// math also works for continuous treatments; confirm with callers).
pub treatment: Vec<f64>,
// Outcome-model predictions `mu(x)` per sample.
pub outcome_predicted: Vec<f64>,
// Propensity-model predictions `p(x)` per sample.
pub treatment_predicted: Vec<f64>,
}
impl RLearnerObjective {
    /// Builds an R-learner objective from per-sample nuisance estimates.
    ///
    /// All three vectors are aligned by sample index and must have the
    /// same length.
    ///
    /// # Panics
    /// Panics with a descriptive message if the vector lengths differ
    /// (the original bare asserts panicked without context, making the
    /// misaligned input hard to identify).
    pub fn new(treatment: Vec<f64>, outcome_predicted: Vec<f64>, treatment_predicted: Vec<f64>) -> Self {
        assert_eq!(
            treatment.len(),
            outcome_predicted.len(),
            "treatment and outcome_predicted must have the same length"
        );
        assert_eq!(
            treatment.len(),
            treatment_predicted.len(),
            "treatment and treatment_predicted must have the same length"
        );
        Self {
            treatment,
            outcome_predicted,
            treatment_predicted,
        }
    }
}
impl ObjectiveFunction for RLearnerObjective {
    /// Per-sample R-learner loss `(y_res - tau * w_res)^2`, where
    /// `y_res = y - mu(x)` is the outcome residual, `w_res = w - p(x)` is
    /// the treatment residual (propensity clipped into
    /// `[PROPENSITY_CLIP_MIN, PROPENSITY_CLIP_MAX]`), and `tau` is the
    /// current prediction in `yhat`.
    ///
    /// # Panics
    /// Panics if `y`, `yhat`, and the stored nuisance vectors differ in
    /// length. Previously a mismatch was silently truncated by `zip`,
    /// which masks data-alignment bugs; we now fail fast.
    fn loss(&self, y: &[f64], yhat: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> Vec<f32> {
        let n = y.len();
        assert_eq!(yhat.len(), n, "yhat length must match y");
        assert_eq!(self.treatment.len(), n, "treatment length must match y");
        y.iter()
            .zip(yhat)
            .zip(&self.treatment)
            .zip(&self.outcome_predicted)
            .zip(&self.treatment_predicted)
            .map(|((((y_i, tau_i), w_i), mu_i), p_i)| {
                let y_res = y_i - mu_i;
                let w_res = w_i - p_i.clamp(PROPENSITY_CLIP_MIN, PROPENSITY_CLIP_MAX);
                let diff = y_res - tau_i * w_res;
                (diff * diff) as f32
            })
            .collect()
    }

    /// First- and second-order derivatives of the loss w.r.t. `tau = yhat`.
    ///
    /// The analytic derivatives are `2 * (tau * w_res^2 - w_res * y_res)`
    /// and `2 * w_res^2`; the common factor of 2 is dropped from both, so
    /// the Newton step `grad / hess` is unchanged. The hessian is floored
    /// at `HESSIAN_FLOOR` to keep that step finite when `w_res` is ~0.
    ///
    /// # Panics
    /// Panics on length mismatch between `y`, `yhat`, and the stored
    /// nuisance vectors (previously indexed access panicked without
    /// context, and only when the stored vectors were too short).
    fn gradient(
        &self,
        y: &[f64],
        yhat: &[f64],
        _sample_weight: Option<&[f64]>,
        _group: Option<&[u64]>,
    ) -> (Vec<f32>, Option<Vec<f32>>) {
        let n = y.len();
        assert_eq!(yhat.len(), n, "yhat length must match y");
        assert_eq!(self.treatment.len(), n, "treatment length must match y");
        assert_eq!(self.outcome_predicted.len(), n, "outcome_predicted length must match y");
        assert_eq!(self.treatment_predicted.len(), n, "treatment_predicted length must match y");
        let mut grad = Vec::with_capacity(n);
        let mut hess = Vec::with_capacity(n);
        // Iterator zips (same shape as `loss`) avoid the per-element
        // bounds checks the original indexed loop paid.
        for ((((y_i, tau), w_i), mu_i), p_i) in y
            .iter()
            .zip(yhat)
            .zip(&self.treatment)
            .zip(&self.outcome_predicted)
            .zip(&self.treatment_predicted)
        {
            let y_res = y_i - mu_i;
            let w_res = w_i - p_i.clamp(PROPENSITY_CLIP_MIN, PROPENSITY_CLIP_MAX);
            grad.push((tau * w_res * w_res - w_res * y_res) as f32);
            hess.push((w_res * w_res).max(HESSIAN_FLOOR) as f32);
        }
        (grad, Some(hess))
    }

    /// Base prediction: the closed-form least-squares solution
    /// `sum(y_res * w_res) / sum(w_res^2)` for a constant treatment
    /// effect, or 0.0 when the denominator is degenerate (below
    /// `HESSIAN_FLOOR`). The denominator is a sum of squares and hence
    /// non-negative, so the original `den.abs()` was redundant.
    fn initial_value(&self, y: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> f64 {
        assert_eq!(self.treatment.len(), y.len(), "treatment length must match y");
        let (num, den) = y
            .iter()
            .zip(&self.treatment)
            .zip(&self.outcome_predicted)
            .zip(&self.treatment_predicted)
            .fold((0.0_f64, 0.0_f64), |(num, den), (((y_i, w_i), mu_i), p_i)| {
                let y_res = y_i - mu_i;
                let w_res = w_i - p_i.clamp(PROPENSITY_CLIP_MIN, PROPENSITY_CLIP_MAX);
                (num + y_res * w_res, den + w_res * w_res)
            });
        if den < HESSIAN_FLOOR { 0.0 } else { num / den }
    }

    /// RMSE is the natural evaluation metric for this squared loss.
    fn default_metric(&self) -> Metric {
        Metric::RootMeanSquaredError
    }
}