use super::*;
/// Predictor whose trait implementations in this file take both the linear
/// predictor and the mean directly from the prediction input's offset.
pub struct TransformationNormalPredictor {
    /// Optional covariance matrix carried alongside the predictor.
    ///
    /// NOTE(review): none of the methods visible in this section read this
    /// field — confirm where (or whether) it is consumed.
    pub covariance: Option<Array2<f64>>,
}
/// Identity-response prediction transform: the offset supplied with the input
/// is used as both the linear predictor and the mean.
impl PredictionTransform for TransformationNormalPredictor {
    /// Builds the point-prediction state from `input.offset`.
    ///
    /// `eta` and `mean` are both copies of the offset; both standard-error
    /// vectors are zero-filled with the same length as the offset, and
    /// `covariance_corrected_used` starts out `false`.
    fn point_state(&self, input: &PredictInput) -> Result<LinearState, EstimationError> {
        let offset = &input.offset;
        let zero_se = Array1::zeros(offset.len());
        Ok(LinearState {
            eta: offset.clone(),
            mean: offset.clone(),
            eta_se: Some(zero_se.clone()),
            mean_se: Some(zero_se),
            covariance_corrected_used: false,
        })
    }

    /// Returns the point state, additionally recording whether a corrected
    /// covariance is in play.
    ///
    /// The `covariance_corrected_used` flag is only touched on a
    /// `PredictPass::FullUncertainty` pass. There it is `true` exactly when
    /// `covariance_mode` is anything other than
    /// `InferenceCovarianceMode::Conditional` and `fit.covariance_corrected`
    /// is present.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `point_state`.
    fn linear_state(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        pass: PredictPass,
        covariance_mode: InferenceCovarianceMode,
    ) -> Result<LinearState, EstimationError> {
        let mut state = self.point_state(input)?;

        // Every other pass keeps the point state untouched.
        if !matches!(pass, PredictPass::FullUncertainty) {
            return Ok(state);
        }

        let conditional_only = matches!(covariance_mode, InferenceCovarianceMode::Conditional);
        state.covariance_corrected_used =
            !conditional_only && fit.covariance_corrected.is_some();
        Ok(state)
    }

    /// Identity response: returns a copy of `eta`.
    fn response(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
        let response = eta.clone();
        Ok(response)
    }

    /// Always reports `ResponseInterval::IdentityEta`; the pass is ignored.
    fn response_jacobian_rows(&self, _: PredictPass) -> ResponseInterval {
        ResponseInterval::IdentityEta
    }

    /// Always reports `ResponseBounds::UNBOUNDED`.
    fn bounds(&self) -> ResponseBounds {
        ResponseBounds::UNBOUNDED
    }

    /// Always reports `ResponseFamily::Gaussian`.
    fn response_family(&self) -> ResponseFamily {
        ResponseFamily::Gaussian
    }
}
/// Model-level prediction entry points; most of them forward to the shared
/// `*_generic` helpers with `self` as the transform.
impl PredictableModel for TransformationNormalPredictor {
    /// Forwards to `predict_plugin_response_generic`.
    fn predict_plugin_response(
        &self,
        input: &PredictInput,
    ) -> Result<PredictResult, EstimationError> {
        predict_plugin_response_generic(self, input)
    }

    /// Returns the offset as both `eta` and `mean`, with no standard errors.
    ///
    /// NOTE(review): `point_state` reports zero-filled standard errors while
    /// this method reports `None` — confirm the difference is intended.
    fn predict_with_uncertainty(
        &self,
        input: &PredictInput,
    ) -> Result<PredictionWithSE, EstimationError> {
        let offset = &input.offset;
        Ok(PredictionWithSE {
            eta: offset.clone(),
            mean: offset.clone(),
            eta_se: None,
            mean_se: None,
        })
    }

    /// Forwards to `predict_full_uncertainty_generic`.
    fn predict_full_uncertainty(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        options: &PredictUncertaintyOptions,
    ) -> Result<PredictUncertaintyResult, EstimationError> {
        predict_full_uncertainty_generic(self, input, fit, options)
    }

    /// Forwards to `predict_posterior_mean_generic`, clearing the requested
    /// confidence level when the fit has no covariance at all.
    ///
    /// If the fit has neither a corrected nor a conditional covariance, the
    /// options handed to the generic routine have `confidence_level` set to
    /// `None`; otherwise the caller's value is passed through. All other
    /// option fields are copied unchanged.
    fn predict_posterior_mean(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        options: &PosteriorMeanOptions,
    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
        let covariance_missing =
            fit.covariance_corrected.is_none() && fit.covariance_conditional.is_none();
        let confidence_level = if covariance_missing {
            None
        } else {
            options.confidence_level
        };
        let bounded_options = PosteriorMeanOptions {
            confidence_level,
            ..*options
        };
        predict_posterior_mean_generic(self, input, fit, &bounded_options)
    }

    /// This model has exactly one block.
    fn n_blocks(&self) -> usize {
        1
    }

    /// The single block carries the `BlockRole::Mean` role.
    fn block_roles(&self) -> Vec<BlockRole> {
        Vec::from([BlockRole::Mean])
    }
}