use crate::estimate::{EstimationError, UnifiedFitResult};
use crate::inference::predict::{
InferenceCovarianceMode, PredictInput, PredictPosteriorMeanResult, PredictResult,
PredictUncertaintyOptions, PredictUncertaintyResult, PredictionWithSE,
};
use crate::types::ResponseFamily;
use ndarray::Array1;
/// Optional closed interval `[lo, hi]` used to clamp response-scale values.
///
/// The `None` case ([`ResponseBounds::UNBOUNDED`]) leaves values untouched;
/// otherwise values are clamped with [`f64::clamp`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ResponseBounds(Option<(f64, f64)>);

impl ResponseBounds {
    /// No clamping at all.
    pub const UNBOUNDED: Self = Self(None);
    /// The closed probability interval `[0, 1]`.
    pub const UNIT_PROBABILITY: Self = Self(Some((0.0, 1.0)));

    /// Closed interval `[lo, hi]`.
    ///
    /// # Panics
    ///
    /// Panics if `lo > hi` or either endpoint is NaN. `f64::clamp` panics on
    /// such bounds anyway; checking here reports the problem where the bounds
    /// are created rather than deep inside an interval computation.
    pub fn closed(lo: f64, hi: f64) -> Self {
        Self::checked(Some((lo, hi)))
    }

    /// Mean-clamp bounds for `response`, taken from
    /// `ResponseFamily::mean_clamp_bounds` (`None` means unbounded).
    ///
    /// # Panics
    ///
    /// Same conditions as [`ResponseBounds::closed`].
    pub fn for_family(response: &ResponseFamily) -> Self {
        Self::checked(response.mean_clamp_bounds())
    }

    /// Response-support bounds for `response`, taken from
    /// `ResponseFamily::response_support_bounds` (`None` means unbounded).
    ///
    /// # Panics
    ///
    /// Same conditions as [`ResponseBounds::closed`].
    pub fn response_support(response: &ResponseFamily) -> Self {
        Self::checked(response.response_support_bounds())
    }

    /// Wraps `bounds`, asserting that any endpoints are ordered and not NaN.
    fn checked(bounds: Option<(f64, f64)>) -> Self {
        if let Some((lo, hi)) = bounds {
            // `lo <= hi` is false when either side is NaN, so NaN is rejected too.
            assert!(
                lo <= hi,
                "invalid response bounds: lo={lo}, hi={hi} (need lo <= hi, neither NaN)"
            );
        }
        Self(bounds)
    }

    /// Clamps a single value; returns it unchanged when unbounded.
    #[inline]
    pub fn clamp_value(&self, v: f64) -> f64 {
        match self.0 {
            Some((lo, hi)) => v.clamp(lo, hi),
            None => v,
        }
    }

    /// Clamps every element of `values` in place; a no-op when unbounded.
    pub fn clamp_in_place(&self, values: &mut Array1<f64>) {
        if let Some((lo, hi)) = self.0 {
            values.mapv_inplace(|v| v.clamp(lo, hi));
        }
    }
}
/// Two-sided standard-normal critical value for a central interval with
/// coverage `level`, i.e. the normal quantile at `0.5 + 0.5 * level`.
///
/// `level` is not range-checked here (see `validated_central_z`); any error
/// message from `standard_normal_quantile` is wrapped in
/// `EstimationError::InvalidInput`.
pub fn central_z(level: f64) -> Result<f64, EstimationError> {
    let upper_tail_probability = 0.5 + 0.5 * level;
    match crate::probability::standard_normal_quantile(upper_tail_probability) {
        Ok(critical_value) => Ok(critical_value),
        Err(message) => Err(EstimationError::InvalidInput(message)),
    }
}
/// Like `central_z`, but first requires `level` to be finite and strictly
/// inside `(0, 1)`.
///
/// # Errors
///
/// `EstimationError::InvalidInput` for an out-of-range or non-finite level,
/// otherwise whatever `central_z` returns.
pub fn validated_central_z(level: f64) -> Result<f64, EstimationError> {
    let in_open_unit_interval = level.is_finite() && level > 0.0 && level < 1.0;
    if in_open_unit_interval {
        central_z(level)
    } else {
        Err(EstimationError::InvalidInput(format!(
            "confidence_level must be in (0,1), got {level}"
        )))
    }
}
/// Elementwise symmetric interval `(center - z * se, center + z * se)`.
///
/// Uses ndarray's elementwise operators, so `center` and `se` are expected to
/// have matching lengths.
#[inline]
pub fn symmetric_interval(
    center: &Array1<f64>,
    se: &Array1<f64>,
    z: f64,
) -> (Array1<f64>, Array1<f64>) {
    let margin: Array1<f64> = se.mapv(|sd| z * sd);
    let lower = center - &margin;
    let upper = center + &margin;
    (lower, upper)
}
pub fn transform_eta_interval<F>(
eta_lower: &Array1<f64>,
eta_upper: &Array1<f64>,
bounds: ResponseBounds,
response_map: F,
) -> Result<(Array1<f64>, Array1<f64>), EstimationError>
where
F: Fn(&Array1<f64>) -> Result<Array1<f64>, EstimationError>,
{
let transformed_lower = response_map(eta_lower)?;
let transformed_upper = response_map(eta_upper)?;
let mut mean_lower = Array1::from_iter(
transformed_lower
.iter()
.zip(transformed_upper.iter())
.map(|(&lo, &hi)| lo.min(hi)),
);
let mut mean_upper = Array1::from_iter(
transformed_lower
.iter()
.zip(transformed_upper.iter())
.map(|(&lo, &hi)| lo.max(hi)),
);
bounds.clamp_in_place(&mut mean_lower);
bounds.clamp_in_place(&mut mean_upper);
Ok((mean_lower, mean_upper))
}
/// Delta-method interval on the mean scale: `mean ± z * mean_se`, with both
/// endpoints clamped to `bounds`.
pub fn delta_mean_interval(
    mean: &Array1<f64>,
    mean_se: &Array1<f64>,
    z: f64,
    bounds: ResponseBounds,
) -> (Array1<f64>, Array1<f64>) {
    let (raw_lower, raw_upper) = symmetric_interval(mean, mean_se, z);
    let clamped = |mut endpoint: Array1<f64>| {
        bounds.clamp_in_place(&mut endpoint);
        endpoint
    };
    (clamped(raw_lower), clamped(raw_upper))
}
/// Strategy for turning eta-scale interval endpoints into mean-scale bounds.
pub enum MeanBoundMethod<'a> {
    /// Push the eta endpoints through `response_map`, order them and clamp to `bounds`.
    TransformEta {
        bounds: ResponseBounds,
        response_map: &'a (dyn Fn(&Array1<f64>) -> Result<Array1<f64>, EstimationError> + 'a),
    },
    /// Delta method: `mean ± z * mean_se`, clamped to `bounds`.
    Delta {
        mean_se: &'a Array1<f64>,
        bounds: ResponseBounds,
    },
    /// Reuse the eta endpoints unchanged as the mean bounds (no clamping).
    IdentityEta,
}

/// Computes `(mean_lower, mean_upper)` according to `method`.
///
/// `mean` and `z` are only consulted by [`MeanBoundMethod::Delta`]; the other
/// methods work from `eta_lower` / `eta_upper`.
///
/// # Errors
///
/// Only [`MeanBoundMethod::TransformEta`] can fail; it propagates errors from
/// `transform_eta_interval`.
pub fn mean_bounds(
    eta_lower: &Array1<f64>,
    eta_upper: &Array1<f64>,
    mean: &Array1<f64>,
    z: f64,
    method: MeanBoundMethod<'_>,
) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
    match method {
        MeanBoundMethod::IdentityEta => Ok((eta_lower.clone(), eta_upper.clone())),
        MeanBoundMethod::Delta { bounds, mean_se } => {
            let interval = delta_mean_interval(mean, mean_se, z, bounds);
            Ok(interval)
        }
        MeanBoundMethod::TransformEta {
            response_map,
            bounds,
        } => transform_eta_interval(eta_lower, eta_upper, bounds, response_map),
    }
}
/// How the eta-scale interval endpoints are formed.
pub enum EtaInterval {
    /// `eta ± z * eta_se`.
    Symmetric,
    /// Zero width: both endpoints are copies of `eta` (the SE is ignored).
    Collapsed,
}
impl EtaInterval {
    /// Returns `(eta_lower, eta_upper)` for this policy at critical value `z`.
    fn endpoints(
        &self,
        eta: &Array1<f64>,
        eta_se: &Array1<f64>,
        z: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        match *self {
            Self::Collapsed => (eta.to_owned(), eta.to_owned()),
            Self::Symmetric => symmetric_interval(eta, eta_se, z),
        }
    }
}
/// Inputs for the optional observation interval built around the mean.
pub struct ObservationInterval<'a> {
    /// Bounds applied to both observation-interval endpoints.
    pub bounds: ResponseBounds,
    /// Noise standard deviation per row; the half-width is `z * noise_sd`.
    pub noise_sd: &'a Array1<f64>,
}
/// Covariance provenance copied unchanged into the uncertainty result.
pub struct UncertaintyProvenance {
    /// Whether a corrected covariance was actually used.
    pub covariance_corrected_used: bool,
    /// Covariance mode the caller asked for.
    pub covariance_mode_requested: InferenceCovarianceMode,
}
#[allow(clippy::too_many_arguments)]
pub fn assemble_uncertainty_result(
confidence_level: f64,
eta: Array1<f64>,
mean: Array1<f64>,
eta_standard_error: Array1<f64>,
mean_standard_error: Array1<f64>,
eta_interval: EtaInterval,
method: MeanBoundMethod<'_>,
observation: Option<ObservationInterval<'_>>,
provenance: UncertaintyProvenance,
) -> Result<PredictUncertaintyResult, EstimationError> {
let z = validated_central_z(confidence_level)?;
let (eta_lower, eta_upper) = eta_interval.endpoints(&eta, &eta_standard_error, z);
let (mean_lower, mean_upper) = mean_bounds(&eta_lower, &eta_upper, &mean, z, method)?;
let (observation_lower, observation_upper) = match observation {
Some(obs) => {
let half = obs.noise_sd.mapv(|s| z * s);
let mut lower = &mean - ½
let mut upper = &mean + ½
obs.bounds.clamp_in_place(&mut lower);
obs.bounds.clamp_in_place(&mut upper);
(Some(lower), Some(upper))
}
None => (None, None),
};
Ok(PredictUncertaintyResult {
eta,
mean,
eta_standard_error,
mean_standard_error,
eta_lower,
eta_upper,
mean_lower,
mean_upper,
observation_lower,
observation_upper,
covariance_mode_requested: provenance.covariance_mode_requested,
covariance_corrected_used: provenance.covariance_corrected_used,
})
}
/// Fills `result.mean_lower` / `result.mean_upper` when a confidence level is
/// supplied.
///
/// With `confidence_level == None`, `result` is left untouched and `Ok(())` is
/// returned. On error, `result` is also left untouched.
///
/// # Errors
///
/// Invalid confidence level, or any error from the mean-bound method.
pub fn assemble_posterior_mean_bounds(
    result: &mut PredictPosteriorMeanResult,
    confidence_level: Option<f64>,
    eta_interval: EtaInterval,
    method: MeanBoundMethod<'_>,
) -> Result<(), EstimationError> {
    if let Some(level) = confidence_level {
        let z = validated_central_z(level)?;
        let (eta_lo, eta_hi) =
            eta_interval.endpoints(&result.eta, &result.eta_standard_error, z);
        let (lo, hi) = mean_bounds(&eta_lo, &eta_hi, &result.mean, z, method)?;
        result.mean_lower = Some(lo);
        result.mean_upper = Some(hi);
    }
    Ok(())
}
/// Which prediction pass a transform is asked to compute.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PredictPass {
    /// Pass used by `predict_full_uncertainty_generic`.
    FullUncertainty,
    /// Pass used by `predict_posterior_mean_generic`.
    PosteriorMean,
}
/// Interval policy, translated by `eta_interval_for` and `mean_bound_method_for`.
pub enum ResponseInterval {
    /// Symmetric eta interval mapped through the transform's `response`.
    TransformEta,
    /// Symmetric eta interval reused unchanged as the mean interval.
    IdentityEta,
    /// Zero-width eta interval; delta-method mean interval (`mean ± z * mean_se`).
    CollapsedDelta,
    /// Symmetric eta interval; delta-method mean interval.
    SymmetricDelta,
}
/// Per-row prediction state produced by a `PredictionTransform`.
pub struct LinearState {
    /// Linear predictor.
    pub eta: Array1<f64>,
    /// Mean prediction (presumably `response(eta)` — confirm per transform).
    pub mean: Array1<f64>,
    /// Standard error of `eta`, if the transform computed it.
    pub eta_se: Option<Array1<f64>>,
    /// Standard error of `mean`, if the transform computed it.
    pub mean_se: Option<Array1<f64>>,
    /// Whether a corrected covariance was used; copied into result provenance.
    pub covariance_corrected_used: bool,
}
/// Model-specific hooks used by the generic prediction drivers in this module.
pub trait PredictionTransform {
    /// Point-prediction state for `input`.
    fn point_state(&self, input: &PredictInput) -> Result<LinearState, EstimationError>;

    /// State for the requested prediction `pass`.
    ///
    /// The default ignores `fit` and `covariance_mode`. It delegates
    /// `FullUncertainty` to [`point_state`](Self::point_state) and rejects
    /// `PosteriorMean` with `EstimationError::InvalidInput`. Transforms that
    /// need the fit or covariance mode should override it.
    fn linear_state(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        pass: PredictPass,
        covariance_mode: InferenceCovarianceMode,
    ) -> Result<LinearState, EstimationError> {
        // Deliberately unused by the default. `let _` silences the lint without
        // a runtime `assert!(size_of_val(..) > 0)`, which checks nothing useful
        // and would panic if the type were zero-sized.
        let _ = (fit, covariance_mode);
        match pass {
            PredictPass::FullUncertainty => self.point_state(input),
            PredictPass::PosteriorMean => Err(EstimationError::InvalidInput(
                "this transform does not implement the posterior-mean pass".to_string(),
            )),
        }
    }

    /// Maps eta-scale values to the mean scale; used for `TransformEta` bounds.
    fn response(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;

    /// Interval policy for `pass` (despite its name, this returns a [`ResponseInterval`]).
    fn response_jacobian_rows(&self, pass: PredictPass) -> ResponseInterval;

    /// Clamp bounds applied to mean-scale and observation intervals.
    fn bounds(&self) -> ResponseBounds;

    /// Observation noise SD per row; `None` means no observation interval is
    /// produced. The default returns `Ok(None)`.
    fn observation_noise(
        &self,
        input: &PredictInput,
    ) -> Result<Option<Array1<f64>>, EstimationError> {
        let _ = input;
        Ok(None)
    }
}
/// Translates an interval `policy` into the [`MeanBoundMethod`] for `transform`.
///
/// `transform.bounds()` is queried only by the clamping policies
/// (`TransformEta` and both delta variants); `IdentityEta` needs no bounds.
fn mean_bound_method_for<'a, T: PredictionTransform>(
    transform: &'a T,
    policy: &ResponseInterval,
    response_map: &'a (dyn Fn(&Array1<f64>) -> Result<Array1<f64>, EstimationError> + 'a),
    mean_se: &'a Array1<f64>,
) -> MeanBoundMethod<'a> {
    let clamp_bounds = || transform.bounds();
    match policy {
        ResponseInterval::IdentityEta => MeanBoundMethod::IdentityEta,
        ResponseInterval::TransformEta => MeanBoundMethod::TransformEta {
            bounds: clamp_bounds(),
            response_map,
        },
        ResponseInterval::SymmetricDelta | ResponseInterval::CollapsedDelta => {
            MeanBoundMethod::Delta {
                mean_se,
                bounds: clamp_bounds(),
            }
        }
    }
}
/// Eta-interval shape for a policy: only `CollapsedDelta` collapses to `eta`.
fn eta_interval_for(policy: &ResponseInterval) -> EtaInterval {
    match policy {
        ResponseInterval::SymmetricDelta => EtaInterval::Symmetric,
        ResponseInterval::IdentityEta => EtaInterval::Symmetric,
        ResponseInterval::TransformEta => EtaInterval::Symmetric,
        ResponseInterval::CollapsedDelta => EtaInterval::Collapsed,
    }
}
pub fn predict_full_uncertainty_generic<T: PredictionTransform>(
transform: &T,
input: &PredictInput,
fit: &UnifiedFitResult,
options: &PredictUncertaintyOptions,
) -> Result<PredictUncertaintyResult, EstimationError> {
let state = transform.linear_state(
input,
fit,
PredictPass::FullUncertainty,
options.covariance_mode,
)?;
let covariance_corrected_used = state.covariance_corrected_used;
let eta_se = state.eta_se.ok_or_else(|| {
EstimationError::InvalidInput(
"full uncertainty requires covariance (eta_se unavailable)".to_string(),
)
})?;
let mean_se = state.mean_se.ok_or_else(|| {
EstimationError::InvalidInput(
"full uncertainty requires covariance (mean_se unavailable)".to_string(),
)
})?;
let policy = transform.response_jacobian_rows(PredictPass::FullUncertainty);
let response_map = move |eta: &Array1<f64>| transform.response(eta);
let observation = if options.includeobservation_interval {
transform.observation_noise(input)?
} else {
None
};
assemble_uncertainty_result(
options.confidence_level,
state.eta,
state.mean,
eta_se,
mean_se.clone(),
eta_interval_for(&policy),
mean_bound_method_for(transform, &policy, &response_map, &mean_se),
observation.as_ref().map(|noise_sd| ObservationInterval {
noise_sd,
bounds: transform.bounds(),
}),
UncertaintyProvenance {
covariance_mode_requested: options.covariance_mode,
covariance_corrected_used,
},
)
}
pub fn predict_posterior_mean_generic<T: PredictionTransform>(
transform: &T,
input: &PredictInput,
fit: &UnifiedFitResult,
confidence_level: Option<f64>,
) -> Result<PredictPosteriorMeanResult, EstimationError> {
let state = transform.linear_state(
input,
fit,
PredictPass::PosteriorMean,
InferenceCovarianceMode::Conditional,
)?;
let eta_se = state
.eta_se
.unwrap_or_else(|| Array1::zeros(state.eta.len()));
let policy = transform.response_jacobian_rows(PredictPass::PosteriorMean);
let mean_se = state.mean_se.clone().unwrap_or_else(|| eta_se.clone());
let response_map = move |eta: &Array1<f64>| transform.response(eta);
let mut result = PredictPosteriorMeanResult {
eta: state.eta,
eta_standard_error: eta_se,
mean: state.mean,
mean_lower: None,
mean_upper: None,
};
assemble_posterior_mean_bounds(
&mut result,
confidence_level,
eta_interval_for(&policy),
mean_bound_method_for(transform, &policy, &response_map, &mean_se),
)?;
Ok(result)
}
pub fn predict_plugin_response_generic<T: PredictionTransform>(
transform: &T,
input: &PredictInput,
) -> Result<PredictResult, EstimationError> {
let state = transform.point_state(input)?;
Ok(PredictResult {
eta: state.eta,
mean: state.mean,
})
}
pub fn predict_with_uncertainty_generic<T: PredictionTransform>(
transform: &T,
input: &PredictInput,
) -> Result<PredictionWithSE, EstimationError> {
let state = transform.point_state(input)?;
Ok(PredictionWithSE {
eta: state.eta,
mean: state.mean,
eta_se: state.eta_se,
mean_se: state.mean_se,
})
}
// Parity tests: each test computes the expected result inline with plain
// ndarray arithmetic and checks that the shared assembly functions reproduce
// it to within an absolute 1e-12.
#[cfg(test)]
mod parity_tests {
    use super::*;
    use ndarray::array;
    // Confidence level shared by all tests.
    const LEVEL: f64 = 0.95;
    // Two-sided critical value for `LEVEL`.
    fn z95() -> f64 {
        central_z(LEVEL).expect("0.95 is a valid level")
    }
    // Asserts equal lengths and elementwise agreement within 1e-12; `tag`
    // labels the quantity in failure messages.
    fn assert_close(a: &Array1<f64>, b: &Array1<f64>, tag: &str) {
        assert_eq!(a.len(), b.len(), "{tag}: length mismatch");
        for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < 1e-12,
                "{tag}: row {i} mismatch: engine={x}, reference={y}"
            );
        }
    }
    // Delta method with a symmetric eta interval: mean bounds are
    // `mean ± z * mean_se` clamped to [0, 1]; points and SEs pass through.
    #[test]
    fn delta_symmetric_matches_inline() {
        let eta = array![0.2, -0.5, 1.3];
        let mean = array![0.55, 0.38, 0.78];
        let eta_se = array![0.1, 0.2, 0.15];
        let mean_se = array![0.04, 0.06, 0.05];
        let z = z95();
        let ref_eta_lower = &eta - &eta_se.mapv(|s| z * s);
        let ref_eta_upper = &eta + &eta_se.mapv(|s| z * s);
        let ref_mean_lower = (&mean - &mean_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
        let ref_mean_upper = (&mean + &mean_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
        let out = assemble_uncertainty_result(
            LEVEL,
            eta.clone(),
            mean.clone(),
            eta_se.clone(),
            mean_se.clone(),
            EtaInterval::Symmetric,
            MeanBoundMethod::Delta {
                mean_se: &mean_se,
                bounds: ResponseBounds::UNIT_PROBABILITY,
            },
            None,
            UncertaintyProvenance {
                covariance_mode_requested: InferenceCovarianceMode::Conditional,
                covariance_corrected_used: false,
            },
        )
        .expect("engine assembly");
        assert_close(&out.eta, &eta, "eta point");
        assert_close(&out.mean, &mean, "mean point");
        assert_close(&out.eta_standard_error, &eta_se, "eta SE");
        assert_close(&out.mean_standard_error, &mean_se, "mean SE");
        assert_close(&out.eta_lower, &ref_eta_lower, "eta lower");
        assert_close(&out.eta_upper, &ref_eta_upper, "eta upper");
        assert_close(&out.mean_lower, &ref_mean_lower, "mean lower");
        assert_close(&out.mean_upper, &ref_mean_upper, "mean upper");
        // No observation input was given, so no observation interval is produced.
        assert!(out.observation_lower.is_none());
        assert!(out.observation_upper.is_none());
        assert!(!out.covariance_corrected_used);
    }
    // TransformEta with a logistic map: mean bounds equal the logistic of the
    // symmetric eta bounds.
    #[test]
    fn transform_eta_symmetric_matches_inline() {
        let eta = array![0.2, -0.5, 1.3];
        let logistic = |e: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
            Ok(e.mapv(|x| 1.0 / (1.0 + (-x).exp())))
        };
        let mean = logistic(&eta).unwrap();
        let eta_se = array![0.1, 0.2, 0.15];
        let mean_se = array![0.02, 0.05, 0.03];
        let z = z95();
        let ref_eta_lower = &eta - &eta_se.mapv(|s| z * s);
        let ref_eta_upper = &eta + &eta_se.mapv(|s| z * s);
        let ref_mean_lower = logistic(&ref_eta_lower).unwrap();
        let ref_mean_upper = logistic(&ref_eta_upper).unwrap();
        let out = assemble_uncertainty_result(
            LEVEL,
            eta.clone(),
            mean.clone(),
            eta_se.clone(),
            mean_se.clone(),
            EtaInterval::Symmetric,
            MeanBoundMethod::TransformEta {
                bounds: ResponseBounds::UNIT_PROBABILITY,
                response_map: &logistic,
            },
            None,
            UncertaintyProvenance {
                covariance_mode_requested: InferenceCovarianceMode::Conditional,
                covariance_corrected_used: false,
            },
        )
        .expect("engine assembly");
        assert_close(&out.eta_lower, &ref_eta_lower, "eta lower");
        assert_close(&out.eta_upper, &ref_eta_upper, "eta upper");
        assert_close(&out.mean_lower, &ref_mean_lower, "mean lower");
        assert_close(&out.mean_upper, &ref_mean_upper, "mean upper");
    }
    // IdentityEta plus an unbounded observation interval: mean bounds equal
    // eta bounds, and observation bounds are `mean ± z * sigma`.
    #[test]
    fn identity_eta_with_observation_matches_inline() {
        let eta = array![1.0, 2.0, -1.0];
        let mean = eta.clone();
        let eta_se = array![0.3, 0.1, 0.25];
        let sigma = array![0.5, 0.4, 0.6];
        let z = z95();
        let ref_eta_lower = &eta - &eta_se.mapv(|s| z * s);
        let ref_eta_upper = &eta + &eta_se.mapv(|s| z * s);
        let ref_obs_lower = &mean - &sigma.mapv(|s| z * s);
        let ref_obs_upper = &mean + &sigma.mapv(|s| z * s);
        let out = assemble_uncertainty_result(
            LEVEL,
            eta.clone(),
            mean.clone(),
            eta_se.clone(),
            eta_se.clone(),
            EtaInterval::Symmetric,
            MeanBoundMethod::IdentityEta,
            Some(ObservationInterval {
                noise_sd: &sigma,
                bounds: ResponseBounds::UNBOUNDED,
            }),
            UncertaintyProvenance {
                covariance_mode_requested: InferenceCovarianceMode::Conditional,
                covariance_corrected_used: false,
            },
        )
        .expect("engine assembly");
        assert_close(&out.mean_standard_error, &eta_se, "mean SE == eta SE");
        assert_close(&out.eta_lower, &ref_eta_lower, "eta lower");
        assert_close(&out.eta_upper, &ref_eta_upper, "eta upper");
        assert_close(&out.mean_lower, &ref_eta_lower, "mean lower == eta lower");
        assert_close(&out.mean_upper, &ref_eta_upper, "mean upper == eta upper");
        assert_close(
            out.observation_lower.as_ref().expect("obs lower"),
            &ref_obs_lower,
            "observation lower",
        );
        assert_close(
            out.observation_upper.as_ref().expect("obs upper"),
            &ref_obs_upper,
            "observation upper",
        );
    }
    // Collapsed eta interval with delta-method mean bounds: eta bounds equal
    // eta itself, and mean bounds are the clamped delta interval.
    #[test]
    fn delta_collapsed_eta_matches_inline() {
        let eta = array![-0.3, 0.7, 0.1];
        let mean = array![0.42, 0.66, 0.51];
        let mean_se = array![0.05, 0.08, 0.04];
        let z = z95();
        let ref_mean_lower = (&mean - &mean_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
        let ref_mean_upper = (&mean + &mean_se.mapv(|s| z * s)).mapv(|v| v.clamp(0.0, 1.0));
        let out = assemble_uncertainty_result(
            LEVEL,
            eta.clone(),
            mean.clone(),
            mean_se.clone(),
            mean_se.clone(),
            EtaInterval::Collapsed,
            MeanBoundMethod::Delta {
                mean_se: &mean_se,
                bounds: ResponseBounds::UNIT_PROBABILITY,
            },
            None,
            UncertaintyProvenance {
                covariance_mode_requested: InferenceCovarianceMode::Conditional,
                covariance_corrected_used: false,
            },
        )
        .expect("engine assembly");
        assert_close(&out.eta_lower, &eta, "eta lower == eta");
        assert_close(&out.eta_upper, &eta, "eta upper == eta");
        assert_close(&out.mean_lower, &ref_mean_lower, "mean lower");
        assert_close(&out.mean_upper, &ref_mean_upper, "mean upper");
    }
    // Posterior-mean bounds: `None` level leaves bounds unset; `Some(LEVEL)`
    // fills them with the logistic of the symmetric eta bounds.
    #[test]
    fn posterior_mean_bounds_match_inline() {
        let eta = array![0.2, -0.4];
        let mean = array![0.55, 0.40];
        let eta_se = array![0.1, 0.2];
        let z = z95();
        let logistic = |e: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
            Ok(e.mapv(|x| 1.0 / (1.0 + (-x).exp())))
        };
        let mut none_result = PredictPosteriorMeanResult {
            eta: eta.clone(),
            eta_standard_error: eta_se.clone(),
            mean: mean.clone(),
            mean_lower: None,
            mean_upper: None,
        };
        assemble_posterior_mean_bounds(
            &mut none_result,
            None,
            EtaInterval::Symmetric,
            MeanBoundMethod::TransformEta {
                bounds: ResponseBounds::UNIT_PROBABILITY,
                response_map: &logistic,
            },
        )
        .expect("engine assembly");
        assert!(none_result.mean_lower.is_none());
        assert!(none_result.mean_upper.is_none());
        let ref_eta_lower = &eta - &eta_se.mapv(|s| z * s);
        let ref_eta_upper = &eta + &eta_se.mapv(|s| z * s);
        let ref_mean_lower = logistic(&ref_eta_lower).unwrap();
        let ref_mean_upper = logistic(&ref_eta_upper).unwrap();
        let mut some_result = PredictPosteriorMeanResult {
            eta: eta.clone(),
            eta_standard_error: eta_se.clone(),
            mean: mean.clone(),
            mean_lower: None,
            mean_upper: None,
        };
        assemble_posterior_mean_bounds(
            &mut some_result,
            Some(LEVEL),
            EtaInterval::Symmetric,
            MeanBoundMethod::TransformEta {
                bounds: ResponseBounds::UNIT_PROBABILITY,
                response_map: &logistic,
            },
        )
        .expect("engine assembly");
        assert_close(
            some_result.mean_lower.as_ref().expect("mean lower"),
            &ref_mean_lower,
            "posterior mean lower",
        );
        assert_close(
            some_result.mean_upper.as_ref().expect("mean upper"),
            &ref_mean_upper,
            "posterior mean upper",
        );
    }
    // A decreasing response map swaps the mapped endpoints; the engine must
    // still return ordered bounds inside [0, 1].
    #[test]
    fn transform_eta_non_monotone_orders_bounds() {
        let eta = array![0.0, 0.5];
        let decreasing = |e: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
            Ok(e.mapv(|x| 1.0 / (1.0 + x.exp())))
        };
        let mean = decreasing(&eta).unwrap();
        let eta_se = array![0.2, 0.3];
        let out = assemble_uncertainty_result(
            LEVEL,
            eta.clone(),
            mean.clone(),
            eta_se.clone(),
            eta_se.clone(),
            EtaInterval::Symmetric,
            MeanBoundMethod::TransformEta {
                bounds: ResponseBounds::UNIT_PROBABILITY,
                response_map: &decreasing,
            },
            None,
            UncertaintyProvenance {
                covariance_mode_requested: InferenceCovarianceMode::Conditional,
                covariance_corrected_used: false,
            },
        )
        .expect("engine assembly");
        for (lo, hi) in out.mean_lower.iter().zip(out.mean_upper.iter()) {
            assert!(
                lo <= hi,
                "decreasing map must still return ordered bounds: {lo} > {hi}"
            );
            assert!((0.0..=1.0).contains(lo) && (0.0..=1.0).contains(hi));
        }
    }
}