use crate::metrics::evaluation::Metric;
use crate::objective::ObjectiveFunction;
use serde::{Deserialize, Serialize};

/// Lower bound applied to each per-sample Hessian (and to the denominator in
/// `initial_value`) so that near-zero treatment residuals cannot produce
/// unbounded Newton steps.
const HESSIAN_FLOOR: f64 = 1e-6;

/// Residual-on-residual objective for Double Machine Learning (DML).
///
/// The model output `theta` for each sample is fit to minimize
/// `(y_residual - theta * w_residual)^2`, where `y_residual` is the outcome
/// residual from the outcome nuisance model and `w_residual` is the
/// treatment residual from the treatment nuisance model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DMLObjective {
    /// Outcome residuals, one per training sample.
    pub y_residual: Vec<f64>,
    /// Treatment residuals, one per training sample.
    pub w_residual: Vec<f64>,
}

impl DMLObjective {
    /// Builds the objective from precomputed residual vectors.
    ///
    /// Panics if the two vectors differ in length, since residuals are
    /// paired per sample.
    pub fn new(y_residual: Vec<f64>, w_residual: Vec<f64>) -> Self {
        assert_eq!(
            y_residual.len(),
            w_residual.len(),
            "y_residual and w_residual must have the same length"
        );
        Self {
            y_residual,
            w_residual,
        }
    }
}

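// A minimal construction sketch, assuming out-of-fold nuisance predictions
// are available: `m_hat` for the outcome model E[y|x] and `e_hat` for the
// treatment model E[w|x]. The helper name and both prediction vectors are
// illustrative assumptions, not part of this crate's API; any cross-fitting
// scheme that produces such predictions works the same way.
#[allow(dead_code)]
fn residualized_objective(y: &[f64], w: &[f64], m_hat: &[f64], e_hat: &[f64]) -> DMLObjective {
    // Orthogonalization step of DML: subtract each nuisance prediction
    // from its observable before fitting the effect model.
    let y_residual = y.iter().zip(m_hat).map(|(yi, mi)| yi - mi).collect();
    let w_residual = w.iter().zip(e_hat).map(|(wi, ei)| wi - ei).collect();
    DMLObjective::new(y_residual, w_residual)
}
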
impl ObjectiveFunction for DMLObjective {
    /// Per-sample squared residual-on-residual loss
    /// `(y_residual - theta * w_residual)^2`, where `theta` is the current
    /// prediction. The raw `y` labels are unused; the residuals carry all of
    /// the target information.
    fn loss(
        &self,
        _y: &[f64],
        yhat: &[f64],
        _sample_weight: Option<&[f64]>,
        _group: Option<&[u64]>,
    ) -> Vec<f32> {
        yhat.iter()
            .zip(self.y_residual.iter())
            .zip(self.w_residual.iter())
            .map(|((theta, yr), wr)| {
                let diff = yr - theta * wr;
                (diff * diff) as f32
            })
            .collect()
    }

    /// First- and second-order derivatives of the loss with respect to the
    /// prediction `theta`. The true derivatives are
    /// `-2 * wr * (yr - theta * wr)` and `2 * wr * wr`; the shared factor of
    /// 2 is dropped from both, which leaves the Newton step `grad / hess`
    /// unchanged.
    fn gradient(
        &self,
        _y: &[f64],
        yhat: &[f64],
        _sample_weight: Option<&[f64]>,
        _group: Option<&[u64]>,
    ) -> (Vec<f32>, Option<Vec<f32>>) {
        let n = yhat.len();
        let mut grad = Vec::with_capacity(n);
        let mut hess = Vec::with_capacity(n);
        for ((&theta, &yr), &wr) in yhat
            .iter()
            .zip(self.y_residual.iter())
            .zip(self.w_residual.iter())
        {
            let g = -wr * yr + theta * wr * wr;
            // Floor the Hessian so samples with near-zero treatment
            // residuals cannot dominate the Newton step.
            let h = (wr * wr).max(HESSIAN_FLOOR);
            grad.push(g as f32);
            hess.push(h as f32);
        }
        (grad, Some(hess))
    }

    /// Closed-form starting prediction: the least-squares slope of
    /// `y_residual` on `w_residual` through the origin,
    /// `sum(yr * wr) / sum(wr * wr)`, i.e. the global treatment-effect
    /// estimate. Falls back to 0.0 when the denominator is effectively zero.
    fn initial_value(
        &self,
        _y: &[f64],
        _sample_weight: Option<&[f64]>,
        _group: Option<&[u64]>,
    ) -> f64 {
        let num: f64 = self
            .y_residual
            .iter()
            .zip(self.w_residual.iter())
            .map(|(yr, wr)| yr * wr)
            .sum();
        let den: f64 = self.w_residual.iter().map(|wr| wr * wr).sum();
        if den.abs() < HESSIAN_FLOOR {
            0.0
        } else {
            num / den
        }
    }

    fn default_metric(&self) -> Metric {
        Metric::RootMeanSquaredError
    }
}

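#[cfg(test)]
mod tests {
    use super::*;

    // A small sanity-check sketch (data values are arbitrary): the analytic
    // gradient should match a central finite difference of `loss`, up to the
    // factor of 2 that `gradient` deliberately drops. `_y` is passed as an
    // empty slice because these implementations never read it.
    #[test]
    fn gradient_matches_finite_difference() {
        let obj = DMLObjective::new(vec![1.0, -0.5, 2.0], vec![0.5, 1.5, -1.0]);
        let yhat = vec![0.2, -0.1, 0.4];
        let (grad, _) = obj.gradient(&[], &yhat, None, None);
        let eps = 1e-3;
        for i in 0..yhat.len() {
            let mut up = yhat.clone();
            let mut down = yhat.clone();
            up[i] += eps;
            down[i] -= eps;
            let l_up = obj.loss(&[], &up, None, None)[i] as f64;
            let l_down = obj.loss(&[], &down, None, None)[i] as f64;
            // The loss keeps the squared form, while the gradient halves the
            // true derivative, hence the extra division by 2.
            let fd = (l_up - l_down) / (2.0 * eps) / 2.0;
            assert!((grad[i] as f64 - fd).abs() < 1e-3);
        }
    }

    // `initial_value` should equal the closed-form through-origin slope
    // sum(yr * wr) / sum(wr * wr).
    #[test]
    fn initial_value_is_closed_form_slope() {
        let obj = DMLObjective::new(vec![1.0, -0.5, 2.0], vec![0.5, 1.5, -1.0]);
        let expected = (1.0 * 0.5 + -0.5 * 1.5 + 2.0 * -1.0) / (0.25 + 2.25 + 1.0);
        assert!((obj.initial_value(&[], None, None) - expected).abs() < 1e-12);
    }
}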