use super::residual_model::{PinnError, PinnResult, ResidualModel};
/// Wraps an optional [`ResidualModel`] whose scaled output is added to a
/// solver step by [`PinnCorrector::apply`].
///
/// When no model is attached, or the crate is built without the
/// `pinn_correction` feature, the corrector is inert.
#[derive(Debug, Clone)]
pub struct PinnCorrector {
    /// The attached residual model; always `None` when `pinn_correction` is
    /// compiled out.
    model: Option<ResidualModel>,
}
impl Default for PinnCorrector {
fn default() -> Self {
Self::empty()
}
}
impl PinnCorrector {
    /// Creates a corrector with no residual model attached.
    ///
    /// An empty corrector is never active, and [`PinnCorrector::apply`] on it
    /// returns [`PinnError::FeatureDisabled`].
    pub fn empty() -> Self {
        Self { model: None }
    }

    /// Creates a corrector wrapping `model`.
    ///
    /// When the crate is built without the `pinn_correction` feature, the model
    /// is discarded and the result behaves exactly like
    /// [`PinnCorrector::empty`].
    pub fn new(model: ResidualModel) -> Self {
        #[cfg(feature = "pinn_correction")]
        {
            Self { model: Some(model) }
        }
        #[cfg(not(feature = "pinn_correction"))]
        {
            // Correction is compiled out; drop the model explicitly so the
            // parameter is not reported as unused.
            let _ = model;
            Self { model: None }
        }
    }

    /// Returns `true` only when the crate was built with `pinn_correction`
    /// and a model is attached.
    pub fn is_active(&self) -> bool {
        cfg!(feature = "pinn_correction") && self.model.is_some()
    }

    /// Number of parameters of the attached model, or `None` when no model is
    /// attached.
    pub fn parameter_count(&self) -> Option<usize> {
        self.model.as_ref().map(|m| m.parameter_count())
    }

    /// Applies the residual correction to a solver step.
    ///
    /// Runs the model on `previous_state` and returns, element-wise,
    /// `solver_step[i] + residual_scale * residual[i]`.
    ///
    /// # Errors
    ///
    /// - [`PinnError::FeatureDisabled`] if the crate was built without
    ///   `pinn_correction`, or if no model is attached.
    /// - [`PinnError::InvalidModel`] if `previous_state` and `solver_step`
    ///   differ in length, or if the model's output length does not match
    ///   `solver_step` (previously the extra elements were silently dropped).
    /// - Any error returned by the model's `forward_raw`.
    pub fn apply(&self, previous_state: &[f64], solver_step: &[f64]) -> PinnResult<Vec<f64>> {
        // The feature check comes first so that, with correction compiled out,
        // every call reports `FeatureDisabled` regardless of the inputs.
        if !cfg!(feature = "pinn_correction") {
            return Err(PinnError::FeatureDisabled);
        }
        if previous_state.len() != solver_step.len() {
            return Err(PinnError::InvalidModel(format!(
                "previous_state and solver_step have different lengths: {} vs {}",
                previous_state.len(),
                solver_step.len()
            )));
        }
        let model = self.model.as_ref().ok_or(PinnError::FeatureDisabled)?;
        let raw = model.forward_raw(previous_state)?;
        // `zip` stops at the shorter sequence, so a model whose output size
        // differs from the step would otherwise yield a truncated state.
        if raw.len() != solver_step.len() {
            return Err(PinnError::InvalidModel(format!(
                "residual model output length {} does not match solver_step length {}",
                raw.len(),
                solver_step.len()
            )));
        }
        Ok(solver_step
            .iter()
            .zip(raw.iter())
            .map(|(s, r)| *s + model.residual_scale * *r)
            .collect())
    }

    /// Borrows the attached model, if any.
    pub fn model(&self) -> Option<&ResidualModel> {
        self.model.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pinn::residual_model::{Activation, DenseLayer, ResidualModelConfig};

    /// A 2x2 dense layer with an identity weight matrix, zero bias and
    /// identity activation.
    fn identity_layer() -> DenseLayer {
        DenseLayer::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            Activation::Identity,
        )
        .expect("layer ok")
    }

    /// Builds a 2 -> 2 -> 2 model whose hidden and output layers are both
    /// [`identity_layer`]s.
    fn identity_residual_model() -> ResidualModel {
        let config = ResidualModelConfig {
            input_dim: 2,
            output_dim: 2,
            hidden_widths: vec![2],
            hidden_activation: Activation::Relu,
            output_activation: Activation::Identity,
            description: "identity".to_string(),
        };
        let layers = vec![identity_layer(), identity_layer()];
        ResidualModel::from_layers(config, layers).expect("model ok")
    }

    #[test]
    fn empty_corrector_is_inactive() {
        let corrector = PinnCorrector::empty();
        assert!(!corrector.is_active());
        assert!(corrector.parameter_count().is_none());
    }

    #[test]
    fn apply_without_feature_returns_disabled() {
        let corrector = PinnCorrector::new(identity_residual_model());
        let outcome = corrector.apply(&[1.0, 2.0], &[3.0, 4.0]);
        #[cfg(feature = "pinn_correction")]
        {
            // Expects each component to be `step + previous_state`.
            let corrected = outcome.expect("apply ok");
            assert_eq!(corrected, vec![3.0 + 1.0, 4.0 + 2.0]);
            assert!(corrector.is_active());
        }
        #[cfg(not(feature = "pinn_correction"))]
        {
            assert!(matches!(outcome, Err(PinnError::FeatureDisabled)));
            assert!(!corrector.is_active());
        }
    }

    #[test]
    fn shape_mismatch_between_previous_and_step() {
        let corrector = PinnCorrector::new(identity_residual_model());
        let outcome = corrector.apply(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
        // With the feature compiled out, the disabled check wins over the
        // length check.
        #[cfg(feature = "pinn_correction")]
        assert!(matches!(outcome, Err(PinnError::InvalidModel(_))));
        #[cfg(not(feature = "pinn_correction"))]
        assert!(matches!(outcome, Err(PinnError::FeatureDisabled)));
    }
}