use crate::error::{InterpolateError, InterpolateResult};
use crate::neural_enhanced::tiny_mlp::{Activation, TinyMlp};
use crate::rbf_interpolation::{RbfKernel, ScatteredRbf};
use crate::traits::InterpolationFloat;
use scirs2_core::ndarray::{Array1, Array2};
/// Configuration for [`ResidualMlpRbf`]: MLP training hyper-parameters plus
/// the settings forwarded to the underlying RBF interpolant.
#[derive(Debug, Clone)]
pub struct ResidualMlpRbfConfig {
    /// Widths of the MLP hidden layers, in order (input/output sizes are
    /// derived from the data during `fit`).
    pub hidden_sizes: Vec<usize>,
    /// Activation function used by the MLP.
    pub activation: Activation,
    /// Number of training passes over the data; `0` skips MLP training.
    pub epochs: usize,
    /// Learning rate passed to each MLP training step.
    pub lr: f32,
    /// Mini-batch size; clamped into `[1, n]` during fitting.
    pub batch_size: usize,
    /// L2 regularisation strength passed to each MLP training step.
    pub l2: f32,
    /// Seed passed to the MLP constructor (presumably for weight init — see `TinyMlp`).
    pub seed: u64,
    /// Kernel family for the base RBF interpolant.
    pub rbf_kernel: RbfKernel,
    /// Optional shape parameter forwarded to the RBF constructor
    /// (`None` defers to the RBF's own default selection).
    pub rbf_epsilon: Option<f64>,
    /// Nugget (diagonal regularisation) forwarded to the RBF constructor.
    pub rbf_nugget: f64,
}
impl Default for ResidualMlpRbfConfig {
fn default() -> Self {
Self {
hidden_sizes: vec![32, 16],
activation: Activation::Tanh,
epochs: 200,
lr: 1e-3,
batch_size: 16,
l2: 1e-4,
seed: 42,
rbf_kernel: RbfKernel::Gaussian,
rbf_epsilon: None,
rbf_nugget: 1e-3,
}
}
}
/// Hybrid interpolator: an RBF fit to the data, plus a small MLP trained on
/// the RBF's pointwise residuals at the training sites to correct predictions.
#[derive(Debug)]
pub struct ResidualMlpRbf {
    /// Hyper-parameters captured at construction time.
    config: ResidualMlpRbfConfig,
    /// Base RBF interpolant; `Some` only after a successful `fit`.
    rbf: Option<ScatteredRbf<f64>>,
    /// Residual-correction network; `Some` only after a successful `fit`.
    mlp: Option<TinyMlp>,
    /// Per-dimension minima of the training points, used for input normalisation.
    x_min: Vec<f64>,
    /// Per-dimension maxima of the training points (adjusted so ranges are non-zero).
    x_max: Vec<f64>,
    /// True once `fit` has completed successfully.
    is_fitted: bool,
}
impl ResidualMlpRbf {
    /// Creates an unfitted model; call [`fit`](Self::fit) before predicting.
    pub fn new(config: ResidualMlpRbfConfig) -> Self {
        Self {
            config,
            rbf: None,
            mlp: None,
            x_min: Vec::new(),
            x_max: Vec::new(),
            is_fitted: false,
        }
    }

    /// Fits the hybrid model in two stages: an RBF interpolant on
    /// `(points, values)`, then an MLP trained to predict the pointwise RBF
    /// residuals `values[i] - rbf(points[i])` from `[-1, 1]`-normalised inputs.
    ///
    /// `points` is `(n, d)`; `values` must have length `n`.
    ///
    /// # Errors
    /// - An empty-data error when `points` has no rows.
    /// - [`InterpolateError::ShapeMismatch`] when `values.len() != n`.
    /// - Any error propagated from RBF construction/evaluation or MLP training.
    pub fn fit(&mut self, points: &Array2<f64>, values: &Array1<f64>) -> InterpolateResult<()> {
        let n = points.nrows();
        let d = points.ncols();
        if n == 0 {
            return Err(InterpolateError::empty_data("ResidualMlpRbf::fit"));
        }
        if values.len() != n {
            return Err(InterpolateError::ShapeMismatch {
                expected: format!("{n}"),
                actual: format!("{}", values.len()),
                object: "values".to_string(),
            });
        }

        // Stage 1: base RBF fit on the raw (un-normalised) data.
        let rbf = ScatteredRbf::<f64>::new_with_nugget(
            points,
            values,
            self.config.rbf_kernel,
            self.config.rbf_epsilon,
            self.config.rbf_nugget,
        )?;

        // RBF residuals at the training sites; these are the MLP targets.
        let mut residuals = Array1::<f64>::zeros(n);
        for i in 0..n {
            let pt: Vec<f64> = points.row(i).to_vec();
            residuals[i] = values[i] - rbf.evaluate(&pt)?;
        }

        // Per-dimension bounds for the [-1, 1] min-max normalisation of MLP inputs.
        let mut x_min = vec![f64::INFINITY; d];
        let mut x_max = vec![f64::NEG_INFINITY; d];
        for i in 0..n {
            for k in 0..d {
                let v = points[[i, k]];
                x_min[k] = x_min[k].min(v);
                x_max[k] = x_max[k].max(v);
            }
        }
        // Widen degenerate (constant) dimensions so normalisation never divides by ~0.
        for k in 0..d {
            if (x_max[k] - x_min[k]).abs() < 1e-12 {
                x_max[k] = x_min[k] + 1.0;
            }
        }

        // Stage 2: residual MLP with topology d -> hidden_sizes... -> 1.
        let mut layer_sizes = Vec::with_capacity(self.config.hidden_sizes.len() + 2);
        layer_sizes.push(d);
        layer_sizes.extend_from_slice(&self.config.hidden_sizes);
        layer_sizes.push(1);
        let mut mlp = TinyMlp::new(&layer_sizes, self.config.activation, self.config.seed)?;

        if self.config.epochs > 0 {
            // Clamp the batch size into [1, n]. Batches only chunk the
            // per-sample training loop here; no shuffling is performed.
            let bs = self.config.batch_size.max(1).min(n);
            let lr = self.config.lr;
            let l2 = self.config.l2;
            for _epoch in 0..self.config.epochs {
                let mut start = 0;
                while start < n {
                    let end = (start + bs).min(n);
                    for i in start..end {
                        let x_norm = self.normalise_point(points, i, &x_min, &x_max, d);
                        mlp.train_step(&x_norm, residuals[i] as f32, lr, l2)?;
                    }
                    start = end;
                }
            }
        }

        self.rbf = Some(rbf);
        self.mlp = Some(mlp);
        self.x_min = x_min;
        self.x_max = x_max;
        self.is_fitted = true;
        Ok(())
    }

    /// Predicts at a single point `x`: the RBF value plus the MLP residual
    /// correction. `x` must have the same dimensionality as the training data.
    ///
    /// # Errors
    /// - [`InterpolateError::InvalidState`] when called before `fit`.
    /// - [`InterpolateError::ShapeMismatch`] when `x.len()` differs from the
    ///   fitted dimensionality (previously this could panic on out-of-bounds
    ///   indexing into the stored bounds instead of returning an error).
    /// - Any error propagated from RBF evaluation or the MLP forward pass.
    pub fn predict(&self, x: &Array1<f64>) -> InterpolateResult<f64> {
        if !self.is_fitted {
            return Err(InterpolateError::InvalidState(
                "ResidualMlpRbf is not fitted; call fit() first".to_string(),
            ));
        }
        let rbf = self
            .rbf
            .as_ref()
            .ok_or_else(|| InterpolateError::InvalidState("RBF not fitted".to_string()))?;
        let mlp = self
            .mlp
            .as_ref()
            .ok_or_else(|| InterpolateError::InvalidState("MLP not fitted".to_string()))?;
        // Validate dimension before the normalisation loop indexes x_min/x_max.
        let d = self.x_min.len();
        if x.len() != d {
            return Err(InterpolateError::ShapeMismatch {
                expected: format!("{d}"),
                actual: format!("{}", x.len()),
                object: "x".to_string(),
            });
        }
        let rbf_val = rbf.evaluate(&x.to_vec())?;
        // Same [-1, 1] normalisation used during training.
        let x_norm: Array1<f32> = Array1::from_iter((0..d).map(|k| {
            let range = self.x_max[k] - self.x_min[k];
            ((x[k] - self.x_min[k]) / range * 2.0 - 1.0) as f32
        }));
        let correction = mlp.forward(&x_norm)?[0] as f64;
        Ok(rbf_val + correction)
    }

    /// Maps row `i` of `points` into `[-1, 1]^d` as f32, using the supplied
    /// per-dimension bounds (ranges are guaranteed non-zero by `fit`).
    fn normalise_point(
        &self,
        points: &Array2<f64>,
        i: usize,
        x_min: &[f64],
        x_max: &[f64],
        d: usize,
    ) -> Array1<f32> {
        Array1::from_iter((0..d).map(|k| {
            let range = x_max[k] - x_min[k];
            ((points[[i, k]] - x_min[k]) / range * 2.0 - 1.0) as f32
        }))
    }

    /// Whether `fit` has completed successfully.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// `n` evenly spaced samples of sin on [0, 2π], returned as an (n, 1)
    /// point matrix and the matching value vector.
    fn sin_data(n: usize) -> (Array2<f64>, Array1<f64>) {
        let xs: Vec<f64> = (0..n)
            .map(|i| i as f64 / (n - 1) as f64 * std::f64::consts::PI * 2.0)
            .collect();
        let mut pts = Array2::<f64>::zeros((n, 1));
        for (row, &x) in xs.iter().enumerate() {
            pts[[row, 0]] = x;
        }
        let vals = Array1::from_iter(xs.iter().map(|x| x.sin()));
        (pts, vals)
    }

    #[test]
    fn residual_rbf_is_fitted_after_fit() {
        let (pts, vals) = sin_data(10);
        let mut model = ResidualMlpRbf::new(ResidualMlpRbfConfig::default());
        assert!(!model.is_fitted());
        model.fit(&pts, &vals).expect("fit");
        assert!(model.is_fitted());
    }

    #[test]
    fn residual_rbf_predict_before_fit_returns_error() {
        let model = ResidualMlpRbf::new(ResidualMlpRbfConfig::default());
        assert!(model.predict(&Array1::from(vec![1.0f64])).is_err());
    }
}