use super::helpers::{
compute_kernel_mean, compute_witness, gaussian_kernel_matrix, greedy_prototype_selection,
median_bandwidth, project_scores, reconstruct_delta_function,
};
use crate::error::FdarError;
use crate::matrix::FdMatrix;
use crate::regression::FpcaResult;
use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
/// Result of a counterfactual search (regression or logistic) for one observation.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CounterfactualResult {
    /// Row index (into the analyzed data) of the observation explained.
    pub observation: usize,
    /// FPCA scores of the observation before perturbation.
    pub original_scores: Vec<f64>,
    /// FPCA scores after the counterfactual perturbation.
    pub counterfactual_scores: Vec<f64>,
    /// Per-component score change (`counterfactual_scores - original_scores`).
    pub delta_scores: Vec<f64>,
    /// The score change mapped back to a curve on the observation grid.
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`.
    pub distance: f64,
    /// Model prediction for the original observation.
    pub original_prediction: f64,
    /// Model prediction at the counterfactual scores.
    pub counterfactual_prediction: f64,
    /// Whether the search reached its target. The closed-form regression
    /// search always sets this to `true`; the iterative logistic search may
    /// report `false` when `max_iter` is exhausted.
    pub found: bool,
}
/// Computes a closed-form counterfactual for a functional linear regression fit.
///
/// Finds the minimum-norm perturbation of the FPCA scores of `observation`
/// that moves the linear model's prediction from its fitted value to
/// `target_value`, then maps that perturbation back to a curve of length `m`.
///
/// `scalar_covariates` is accepted for interface stability but is currently
/// held fixed, so it does not influence the score perturbation.
///
/// # Errors
///
/// Returns an error when `observation` is out of range, when `data` has a
/// different grid length than the fitted FPCA mean, when `ncomp` is zero,
/// when the fitted coefficient vector is shorter than `ncomp + 1`, or when
/// the score coefficients are numerically zero (no predictive signal).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn counterfactual_regression(
    fit: &FregreLmResult,
    data: &FdMatrix,
    scalar_covariates: Option<&FdMatrix>,
    observation: usize,
    target_value: f64,
) -> Result<CounterfactualResult, FdarError> {
    let (n, m) = data.shape();
    if observation >= n {
        return Err(FdarError::InvalidParameter {
            parameter: "observation",
            message: format!("observation {observation} >= n {n}"),
        });
    }
    if m != fit.fpca.mean.len() {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: format!("{} columns", fit.fpca.mean.len()),
            actual: format!("{m}"),
        });
    }
    // Held fixed during the search; see the doc comment above.
    let _ = scalar_covariates;
    let ncomp = fit.ncomp;
    if ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "must be > 0".into(),
        });
    }
    // Guard the `coefficients[1 + k]` indexing below: slot 0 is skipped
    // (intercept), slots 1..=ncomp are the score coefficients. Without this
    // check a malformed fit would panic instead of returning an error.
    if fit.coefficients.len() < ncomp + 1 {
        return Err(FdarError::InvalidDimension {
            parameter: "fit.coefficients",
            expected: format!("at least {} entries (intercept + ncomp)", ncomp + 1),
            actual: format!("{}", fit.coefficients.len()),
        });
    }
    let scores = project_scores(
        data,
        &fit.fpca.mean,
        &fit.fpca.rotation,
        ncomp,
        &fit.fpca.weights,
    );
    // NOTE(review): assumes `fit.fitted_values` has one entry per row of
    // `data`, i.e. the fit was produced from this same data set — confirm.
    let original_prediction = fit.fitted_values[observation];
    let gap = target_value - original_prediction;
    // gamma[k] is the regression coefficient of score k (intercept excluded).
    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
    if gamma_norm_sq < 1e-30 {
        return Err(FdarError::ComputationFailed {
            operation: "counterfactual_regression",
            detail: "coefficient norm is near zero; the model has no predictive signal — try increasing ncomp or check data quality".into(),
        });
    }
    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
    // Minimum-norm solution of <gamma, delta> = gap: delta = gap * gamma / |gamma|^2.
    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
    let counterfactual_scores: Vec<f64> = original_scores
        .iter()
        .zip(&delta_scores)
        .map(|(&o, &d)| o + d)
        .collect();
    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
    // The model is linear in the scores, so the delta closes the gap exactly.
    let counterfactual_prediction = original_prediction + gap;
    Ok(CounterfactualResult {
        observation,
        original_scores,
        counterfactual_scores,
        delta_scores,
        delta_function,
        distance,
        original_prediction,
        counterfactual_prediction,
        found: true,
    })
}
/// Searches for a class-flipping counterfactual under a functional logistic fit.
///
/// Starting from the FPCA scores of `observation`, up to `max_iter` gradient
/// steps of size `step_size` are taken toward the opposite class; the `found`
/// flag of the result reports whether the predicted class actually flipped.
///
/// `scalar_covariates` is accepted for interface stability but is currently
/// held fixed, so it does not influence the score perturbation.
///
/// # Errors
///
/// Returns an error when `observation` is out of range, when `data` has a
/// different grid length than the fitted FPCA mean, when `ncomp` is zero, or
/// when the fitted coefficient vector is shorter than `ncomp + 1`.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn counterfactual_logistic(
    fit: &FunctionalLogisticResult,
    data: &FdMatrix,
    scalar_covariates: Option<&FdMatrix>,
    observation: usize,
    max_iter: usize,
    step_size: f64,
) -> Result<CounterfactualResult, FdarError> {
    let (n, m) = data.shape();
    if observation >= n {
        return Err(FdarError::InvalidParameter {
            parameter: "observation",
            message: format!("observation {observation} >= n {n}"),
        });
    }
    if m != fit.fpca.mean.len() {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: format!("{} columns", fit.fpca.mean.len()),
            actual: format!("{m}"),
        });
    }
    // Held fixed during the search; see the doc comment above.
    let _ = scalar_covariates;
    let ncomp = fit.ncomp;
    if ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "must be > 0".into(),
        });
    }
    // Guard the `coefficients[1 + k]` indexing in the descent helper: slot 0
    // is skipped, slots 1..=ncomp are the score coefficients. Without this
    // check a malformed fit would panic instead of returning an error.
    if fit.coefficients.len() < ncomp + 1 {
        return Err(FdarError::InvalidDimension {
            parameter: "fit.coefficients",
            expected: format!("at least {} entries (intercept + ncomp)", ncomp + 1),
            actual: format!("{}", fit.coefficients.len()),
        });
    }
    let scores = project_scores(
        data,
        &fit.fpca.mean,
        &fit.fpca.rotation,
        ncomp,
        &fit.fpca.weights,
    );
    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
    // NOTE(review): assumes `fit.probabilities` has one entry per row of
    // `data`, i.e. the fit was produced from this same data set — confirm.
    let original_prediction = fit.probabilities[observation];
    // The target is simply the opposite of the currently predicted class.
    let original_class = usize::from(original_prediction >= 0.5);
    let target_class = 1 - original_class;
    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
        fit.intercept,
        &fit.coefficients,
        &original_scores,
        target_class,
        ncomp,
        max_iter,
        step_size,
    );
    let delta_scores: Vec<f64> = current_scores
        .iter()
        .zip(&original_scores)
        .map(|(&c, &o)| c - o)
        .collect();
    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
    Ok(CounterfactualResult {
        observation,
        original_scores,
        counterfactual_scores: current_scores,
        delta_scores,
        delta_function,
        distance,
        original_prediction,
        counterfactual_prediction: current_pred,
        found,
    })
}
/// Gradient-descent search in score space that tries to flip the predicted class.
///
/// Returns the final scores, the prediction at those scores, and whether the
/// target class was reached within `max_iter` iterations.
///
/// `coefficients[0]` is not used here; slots `1..=ncomp` multiply the scores
/// and the constant term comes from the separate `intercept` argument.
fn logistic_counterfactual_descent(
    intercept: f64,
    coefficients: &[f64],
    initial_scores: &[f64],
    target_class: usize,
    ncomp: usize,
    max_iter: usize,
    step_size: f64,
) -> (Vec<f64>, f64, bool) {
    let mut current_scores = initial_scores.to_vec();
    // Fixed: the reference to `current_scores` here and below had been
    // corrupted into the `¤` mojibake (`&curren` eaten as an HTML entity),
    // which did not compile.
    let mut current_pred =
        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
    for _ in 0..max_iter {
        if usize::from(current_pred >= 0.5) == target_class {
            return (current_scores, current_pred, true);
        }
        // Step against (current_pred - target) * beta_k: this raises the
        // probability when the target class is 1 and lowers it when it is 0.
        for k in 0..ncomp {
            let grad = (current_pred - target_class as f64) * coefficients[1 + k];
            current_scores[k] -= step_size * grad;
        }
        // Recompute after stepping so the returned prediction always matches
        // the returned scores (previously the prediction was stale — one step
        // behind — whenever the loop exhausted `max_iter`).
        current_pred =
            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
    }
    (current_scores, current_pred, false)
}
/// Evaluates the fitted logistic model on a single score vector.
///
/// The linear predictor is `intercept + sum_k coefficients[k + 1] * scores[k]`
/// for `k` in `0..ncomp` (the first entry of `coefficients` is not used here),
/// passed through the sigmoid.
fn logistic_predict_from_scores(
    intercept: f64,
    coefficients: &[f64],
    scores: &[f64],
    ncomp: usize,
) -> f64 {
    let eta = (0..ncomp).fold(intercept, |acc, k| acc + coefficients[k + 1] * scores[k]);
    sigmoid(eta)
}
/// Prototypes and criticisms selected in FPCA score space via kernel witness values.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PrototypeCriticismResult {
    /// Indices of the greedily selected prototype observations.
    pub prototype_indices: Vec<usize>,
    /// Witness-function value at each prototype (parallel to `prototype_indices`).
    pub prototype_witness: Vec<f64>,
    /// Indices of the criticisms: non-prototype observations with the largest
    /// absolute witness values, in decreasing |witness| order.
    pub criticism_indices: Vec<usize>,
    /// Witness-function value at each criticism (parallel to `criticism_indices`).
    pub criticism_witness: Vec<f64>,
    /// Bandwidth produced by `median_bandwidth` and used for the Gaussian kernel.
    pub bandwidth: f64,
}
/// Selects prototypes and criticisms from FPCA score space.
///
/// Prototypes are chosen greedily against the kernel mean embedding of the
/// data; criticisms are the non-prototype observations with the largest
/// absolute witness-function values. `ncomp` is clamped to the number of
/// available score columns, and the criticism count is clamped to the number
/// of non-prototype observations.
///
/// # Errors
///
/// Returns an error when the score matrix is empty, when the effective
/// component count is zero, or when `n_prototypes` is zero or exceeds the
/// number of observations.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn prototype_criticism(
    fpca: &FpcaResult,
    ncomp: usize,
    n_prototypes: usize,
    n_criticisms: usize,
) -> Result<PrototypeCriticismResult, FdarError> {
    let n = fpca.scores.nrows();
    let actual_ncomp = ncomp.min(fpca.scores.ncols());
    if n == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "fpca.scores",
            expected: ">0 rows".into(),
            actual: "0".into(),
        });
    }
    if actual_ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "must be > 0".into(),
        });
    }
    if n_prototypes == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "n_prototypes",
            message: "must be > 0".into(),
        });
    }
    if n_prototypes > n {
        return Err(FdarError::InvalidParameter {
            parameter: "n_prototypes",
            message: format!("n_prototypes {n_prototypes} > n {n}"),
        });
    }
    // Never request more criticisms than there are non-prototype points.
    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
    let mu_data = compute_kernel_mean(&kernel, n);
    let (prototype_indices, is_prototype) =
        greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
    let witness = compute_witness(&kernel, &mu_data, &prototype_indices, n);
    let prototype_witness: Vec<f64> = prototype_indices.iter().map(|&i| witness[i]).collect();
    // Rank the remaining observations by |witness|, largest first. The stable
    // sort keeps index order among ties; NaN comparisons fall back to Equal.
    let mut ranked: Vec<usize> = (0..n).filter(|&i| !is_prototype[i]).collect();
    ranked.sort_by(|&a, &b| {
        witness[b]
            .abs()
            .partial_cmp(&witness[a].abs())
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(n_crit);
    let criticism_witness: Vec<f64> = ranked.iter().map(|&i| witness[i]).collect();
    Ok(PrototypeCriticismResult {
        prototype_indices,
        prototype_witness,
        criticism_indices: ranked,
        criticism_witness,
        bandwidth,
    })
}