use crate::booster::config::MissingNodeTreatment;
use crate::booster::core::PerpetualBooster;
use crate::constraints::ConstraintMap;
use crate::data::Matrix;
use crate::errors::PerpetualError;
use crate::objective::Objective;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Two-stage booster pairing a treatment model with an outcome model.
///
/// `fit` trains `treatment_model` first, then trains `outcome_model` on
/// features that include the treatment model's predictions and residuals.
/// `predict` only consults `outcome_model`.
#[derive(Serialize, Deserialize)]
pub struct IVBooster {
/// Stage-1 model: fitted on the covariates together with the instruments,
/// with the observed treatment as its target.
pub treatment_model: PerpetualBooster,
/// Stage-2 model: fitted on the covariates plus two extra columns (stage-1
/// prediction and stage-1 residual), with the outcome as its target.
pub outcome_model: PerpetualBooster,
/// Budget value handed to `treatment_model` when it was constructed.
pub stage1_budget: f32,
/// Budget value handed to `outcome_model` when it was constructed.
pub stage2_budget: f32,
}
impl IVBooster {
    /// Creates an `IVBooster` holding two freshly configured boosters.
    ///
    /// `treatment_objective` / `stage1_budget` configure the stage-1
    /// (treatment) model and `outcome_objective` / `stage2_budget` configure
    /// the stage-2 (outcome) model. Every remaining argument is forwarded
    /// unchanged to `PerpetualBooster::new` for both models.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `PerpetualBooster::new` if either model
    /// cannot be constructed; the treatment model is built first.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        treatment_objective: Objective,
        outcome_objective: Objective,
        stage1_budget: f32,
        stage2_budget: f32,
        max_bin: u16,
        num_threads: Option<usize>,
        monotone_constraints: Option<ConstraintMap>,
        interaction_constraints: Option<Vec<Vec<usize>>>,
        force_children_to_bound_parent: bool,
        missing: f64,
        allow_missing_splits: bool,
        create_missing_branch: bool,
        terminate_missing_features: HashSet<usize>,
        missing_node_treatment: MissingNodeTreatment,
        log_iterations: usize,
        seed: u64,
        reset: Option<bool>,
        categorical_features: Option<std::collections::HashSet<usize>>,
        timeout: Option<f32>,
        iteration_limit: Option<usize>,
        memory_limit: Option<f32>,
        stopping_rounds: Option<usize>,
    ) -> Result<Self, PerpetualError> {
        // Both stages share every setting except the objective and the budget,
        // so a single builder closure keeps the two configurations in lock-step.
        let build_stage = |objective: Objective, budget: f32| {
            PerpetualBooster::new(
                objective,
                budget,
                // NOTE(review): the meaning of this NaN argument and of the
                // trailing `false` is defined by `PerpetualBooster::new`,
                // which is not visible here.
                f64::NAN,
                max_bin,
                num_threads,
                monotone_constraints.clone(),
                interaction_constraints.clone(),
                force_children_to_bound_parent,
                missing,
                allow_missing_splits,
                create_missing_branch,
                terminate_missing_features.clone(),
                missing_node_treatment,
                log_iterations,
                seed,
                reset,
                categorical_features.clone(),
                timeout,
                iteration_limit,
                memory_limit,
                stopping_rounds,
                false,
            )
        };

        let treatment_model = build_stage(treatment_objective, stage1_budget)?;
        let outcome_model = build_stage(outcome_objective, stage2_budget)?;

        Ok(Self {
            treatment_model,
            outcome_model,
            stage1_budget,
            stage2_budget,
        })
    }

    /// Fits the treatment model and then the outcome model.
    ///
    /// * Stage 1 fits `treatment_model` with target `w` on a matrix whose
    ///   buffer is `x.data` followed by `z.data`, declared as
    ///   `x.rows` x `(x.cols + z.cols)`.
    /// * Stage 2 fits `outcome_model` with target `y` on a matrix whose buffer
    ///   is `x.data`, then the stage-1 predictions, then the residuals
    ///   `w - prediction`, declared as `x.rows` x `(x.cols + 2)`.
    ///
    /// No shape validation is performed here: the row count is taken from `x`
    /// alone, and the residuals are only as long as the shorter of `w` and the
    /// stage-1 predictions.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by either underlying `fit` call.
    pub fn fit(
        &mut self,
        x: &Matrix<f64>,
        z: &Matrix<f64>,
        y: &[f64],
        w: &[f64],
    ) -> Result<(), PerpetualError> {
        let n_rows = x.rows;
        let n_covariates = x.cols;

        // Stage 1: treatment ~ [x, z].
        // NOTE(review): appending whole buffers only yields the columns of `x`
        // followed by the columns of `z` if `Matrix` is column-major — confirm.
        let stage1_features: Vec<f64> = x.data.iter().chain(z.data.iter()).copied().collect();
        let stage1_matrix = Matrix::new(&stage1_features, n_rows, n_covariates + z.cols);
        self.treatment_model.fit(&stage1_matrix, w, None, None)?;

        // Stage-1 fitted values and what they leave unexplained.
        let fitted_treatment = self.treatment_model.predict(&stage1_matrix, true);
        let residuals: Vec<f64> = w
            .iter()
            .zip(&fitted_treatment)
            .map(|(observed, fitted)| observed - fitted)
            .collect();

        // Stage 2: outcome ~ [x, fitted treatment, residual].
        let stage2_features: Vec<f64> = x
            .data
            .iter()
            .chain(fitted_treatment.iter())
            .chain(residuals.iter())
            .copied()
            .collect();
        let stage2_matrix = Matrix::new(&stage2_features, n_rows, n_covariates + 2);
        self.outcome_model.fit(&stage2_matrix, y, None, None)?;

        Ok(())
    }

    /// Predicts with the outcome model for a given counterfactual treatment.
    ///
    /// The stage-2 feature buffer is `x.data`, followed by one treatment value
    /// per row, followed by `x.rows` zeros in the slot that held the stage-1
    /// residual during `fit`. A single-element `w_counterfactual` is repeated
    /// for every row; otherwise it is used as given.
    ///
    /// # Panics
    ///
    /// Panics if `w_counterfactual.len()` is neither `1` nor `x.rows`.
    pub fn predict(&self, x: &Matrix<f64>, w_counterfactual: &[f64]) -> Vec<f64> {
        let n_rows = x.rows;
        assert!(
            w_counterfactual.len() == 1 || w_counterfactual.len() == n_rows,
            "w_counterfactual must satisfy len == 1 or len == x.rows"
        );

        let mut features = Vec::with_capacity(x.data.len() + n_rows * 2);
        features.extend_from_slice(x.data);

        // Treatment column: broadcast a lone value, or copy per-row values.
        match w_counterfactual {
            [shared] => features.extend(std::iter::repeat(*shared).take(n_rows)),
            per_row => features.extend_from_slice(per_row),
        }

        // Residual column is held at zero for prediction.
        features.extend(std::iter::repeat(0.0).take(n_rows));

        let stage2_matrix = Matrix::new(&features, n_rows, x.cols + 2);
        self.outcome_model.predict(&stage2_matrix, true)
    }
}