use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_linalg::solve;
use std::path::PathBuf;
use super::storage::DiskStorage;
use crate::error::InterpolateError;
/// Radial basis function used by the out-of-core RBF interpolator.
///
/// `ε` below is the interpolator's configured `epsilon` and `r` is the Euclidean distance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OocRbfKernel {
    /// `exp(-ε·r²)`.
    Gaussian,
    /// `r²·ln r`, taken as `0` at `r = 0`; does not use `ε`.
    ThinPlate,
    /// `sqrt(r² + ε²)`.
    Multiquadric,
    /// `1 / sqrt(r² + ε²)`.
    InverseMultiquadric,
}
/// Tuning parameters for [`OutOfCoreRbf`].
#[derive(Clone, Debug)]
pub struct OutOfCoreRbfConfig {
    /// Number of landmark centers used by the approximate fit. The dense exact solve is used
    /// only while the number of centers is at most ten times this value. Should be positive.
    pub chunk_size: usize,
    /// Cache budget in megabytes.
    /// NOTE(review): not read by `OutOfCoreRbf` itself — confirm whether anything consumes it.
    pub cache_size_mb: f64,
    /// Directory in which the fitted coefficients are written to disk.
    pub scratch_dir: PathBuf,
    /// Radial basis function used for every kernel evaluation.
    pub kernel: OocRbfKernel,
    /// Shape parameter `ε` of the Gaussian and (inverse) multiquadric kernels; the thin-plate
    /// kernel ignores it.
    pub epsilon: f64,
    /// Ridge term added to the diagonal of the linear system before it is solved.
    pub regularization: f64,
}
impl Default for OutOfCoreRbfConfig {
fn default() -> Self {
Self {
chunk_size: 1000,
cache_size_mb: 100.0,
scratch_dir: std::env::temp_dir(),
kernel: OocRbfKernel::ThinPlate,
epsilon: 1.0,
regularization: 1e-10,
}
}
}
/// Radial basis function interpolator whose fitted coefficients are stored in a file on disk
/// (through `DiskStorage`), while the training centers are kept in memory.
pub struct OutOfCoreRbf {
/// Configuration supplied to `new`.
config: OutOfCoreRbfConfig,
/// Training centers, one row per point; `new` initialises this to an empty 0 × 0 array.
centers: Array2<f64>,
/// Handle to the on-disk coefficient vector; `None` means the model has not been fitted.
coeff_storage: Option<DiskStorage>,
/// Location of the coefficient file inside `config.scratch_dir`, chosen in `new`.
coeff_path: PathBuf,
/// Number of columns (dimensionality) of the training centers.
d_in: usize,
/// Number of training centers, which is also the number of stored coefficients.
n: usize,
}
impl OutOfCoreRbf {
    /// Creates an unfitted interpolator for `config`.
    ///
    /// The coefficient file is placed in `config.scratch_dir` under a name unique to this
    /// process and this instance. Several models can therefore share one scratch directory
    /// (for example the default `std::env::temp_dir()`) without overwriting or deleting each
    /// other's coefficients. The file is not removed automatically; call
    /// [`OutOfCoreRbf::cleanup`] once the model is no longer needed.
    pub fn new(config: OutOfCoreRbfConfig) -> Self {
        // The counter only has to hand out distinct values. `fetch_add` guarantees that under
        // any ordering, so `Relaxed` is sufficient.
        static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let file_name = format!("outofcore_rbf_coeffs_{}_{id}.bin", std::process::id());
        let coeff_path = config.scratch_dir.join(file_name);
        Self {
            config,
            centers: Array2::zeros((0, 0)),
            coeff_storage: None,
            coeff_path,
            d_in: 0,
            n: 0,
        }
    }

    /// Evaluates the configured radial basis function at distance `r`.
    ///
    /// With `ε = config.epsilon`:
    /// * `Gaussian`: `exp(-ε·r²)`
    /// * `ThinPlate`: `r²·ln r`, defined as `0` at `r = 0`
    /// * `Multiquadric`: `sqrt(r² + ε²)`
    /// * `InverseMultiquadric`: `1 / sqrt(r² + ε²)`
    ///
    /// NOTE(review): the thin-plate kernel is only conditionally positive definite and no
    /// polynomial term is appended here, so its system may be ill-conditioned.
    fn eval_kernel(&self, r: f64) -> f64 {
        let eps = self.config.epsilon;
        match self.config.kernel {
            OocRbfKernel::Gaussian => (-eps * r * r).exp(),
            OocRbfKernel::ThinPlate => {
                if r == 0.0 {
                    0.0
                } else {
                    r * r * r.ln()
                }
            }
            OocRbfKernel::Multiquadric => (r * r + eps * eps).sqrt(),
            OocRbfKernel::InverseMultiquadric => 1.0 / (r * r + eps * eps).sqrt(),
        }
    }

    /// Returns `[φ(‖xi − c_j‖)]_j`: the kernel between `xi` and every row `c_j` of `centers`.
    fn kernel_row(&self, xi: ArrayView1<f64>, centers: &Array2<f64>) -> Array1<f64> {
        debug_assert_eq!(xi.len(), centers.ncols(), "kernel_row: dimension mismatch");
        centers
            .outer_iter()
            .map(|center| {
                let sq = xi
                    .iter()
                    .zip(center.iter())
                    .fold(0.0_f64, |acc, (a, b)| {
                        let diff = a - b;
                        acc + diff * diff
                    });
                self.eval_kernel(sq.sqrt())
            })
            .collect()
    }

    /// Solves `a · x = b`, mapping linear-algebra failures to `InterpolateError::LinalgError`.
    fn solve_system(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>, InterpolateError> {
        solve(&a.view(), &b.view(), None)
            .map_err(|e| InterpolateError::LinalgError(format!("OOC RBF solve: {e}")))
    }

    /// Exact fit: solves the dense `n × n` system `(K + reg·I) α = values`.
    fn fit_direct(
        &self,
        centers: &Array2<f64>,
        values: &Array1<f64>,
    ) -> Result<Array1<f64>, InterpolateError> {
        let n = centers.nrows();
        let reg = self.config.regularization;
        let mut gram = Array2::<f64>::zeros((n, n));
        for (i, mut row) in gram.outer_iter_mut().enumerate() {
            row.assign(&self.kernel_row(centers.row(i), centers));
            row[i] += reg;
        }
        Self::solve_system(&gram, values)
    }

    /// Approximate fit for large `n`, using `m = min(chunk_size, n)` evenly strided landmarks.
    ///
    /// Solves the subset-of-regressors normal equations
    /// `(K_nmᵀ K_nm + reg·I) α_m = K_nmᵀ y`, where `K_nm[j, i] = φ(‖c_j − l_i‖)` relates every
    /// center `c_j` to every landmark `l_i`. Both `K_nmᵀ K_nm` and `K_nmᵀ y` are accumulated
    /// over blocks of `chunk_size` centers, so at most a `chunk_size × m` slab of `K_nm` is
    /// in memory at a time.
    ///
    /// The returned length-`n` vector holds `α_m` at the landmark positions and zero elsewhere.
    /// A prediction that sums over all centers therefore evaluates `Σ_i α_m[i]·φ(‖x − l_i‖)`.
    /// Note that the normal equations square the condition number of `K_nm`.
    fn fit_landmark(
        &self,
        centers: &Array2<f64>,
        values: &Array1<f64>,
    ) -> Result<Array1<f64>, InterpolateError> {
        let n = centers.nrows();
        if n == 0 {
            return Ok(Array1::zeros(0));
        }
        // `fit` rejects `chunk_size == 0`; clamping keeps this helper panic-free on its own.
        let chunk = self.config.chunk_size.max(1);
        let m = chunk.min(n);
        let step = n / m; // >= 1 because 1 <= m <= n, so indices are distinct
        let landmark_idx: Vec<usize> = (0..m).map(|k| (k * step).min(n - 1)).collect();
        let landmarks = centers.select(Axis(0), &landmark_idx);
        let reg = self.config.regularization;

        // The normal matrix must be K_nmᵀK_nm, not K_mm. Solving K_mm α = K_nmᵀ y instead
        // inflates predictions by roughly n/m (e.g. constant data comes back scaled by ~n).
        let mut gram = Array2::<f64>::zeros((m, m));
        let mut rhs = Array1::<f64>::zeros(m);
        let center_blocks = centers.axis_chunks_iter(Axis(0), chunk);
        let value_blocks = values.axis_chunks_iter(Axis(0), chunk);
        for (block, y_block) in center_blocks.zip(value_blocks) {
            let mut k_block = Array2::<f64>::zeros((block.nrows(), m));
            for (mut dst, point) in k_block.outer_iter_mut().zip(block.outer_iter()) {
                dst.assign(&self.kernel_row(point, &landmarks));
            }
            gram += &k_block.t().dot(&k_block);
            rhs += &k_block.t().dot(&y_block);
        }
        for i in 0..m {
            gram[[i, i]] += reg;
        }

        let alpha_m = Self::solve_system(&gram, &rhs)?;
        let mut full_alpha = Array1::<f64>::zeros(n);
        for (&idx, &alpha) in landmark_idx.iter().zip(alpha_m.iter()) {
            full_alpha[idx] = alpha;
        }
        Ok(full_alpha)
    }

    /// Fits the interpolant to `values` sampled at `centers` (one row per point) and writes
    /// the coefficients to the scratch file.
    ///
    /// A dense exact solve is used when `n <= 10 * chunk_size`; otherwise the landmark
    /// approximation is used.
    ///
    /// The linear solve runs before any state is modified, so a failed solve leaves a
    /// previously fitted model untouched. A failure while writing the file leaves the model
    /// unfitted rather than inconsistent.
    ///
    /// # Errors
    /// * `InvalidInput` if `centers` is empty or `config.chunk_size` is zero.
    /// * `ShapeMismatch` if `values.len() != centers.nrows()`.
    /// * `LinalgError` if the linear solve fails.
    /// * `IoError`, or errors propagated from `DiskStorage`, if the scratch directory or the
    ///   coefficient file cannot be created or written.
    pub fn fit(
        &mut self,
        centers: &Array2<f64>,
        values: &Array1<f64>,
    ) -> Result<(), InterpolateError> {
        let n = centers.nrows();
        let d = centers.ncols();
        if n == 0 {
            return Err(InterpolateError::InvalidInput {
                message: "OutOfCoreRbf::fit: empty center set".into(),
            });
        }
        if values.len() != n {
            return Err(InterpolateError::ShapeMismatch {
                expected: format!("{n} values"),
                actual: format!("{} values", values.len()),
                object: "OutOfCoreRbf::fit values vs centers".into(),
            });
        }
        if self.config.chunk_size == 0 {
            return Err(InterpolateError::InvalidInput {
                message: "OutOfCoreRbf::fit: chunk_size must be positive".into(),
            });
        }

        let use_direct = n <= self.config.chunk_size.saturating_mul(10);
        let coefficients = if use_direct {
            self.fit_direct(centers, values)?
        } else {
            self.fit_landmark(centers, values)?
        };
        let coeff_slice = coefficients.as_slice().ok_or_else(|| {
            InterpolateError::ComputationError(
                "OutOfCoreRbf: coefficient array is non-contiguous".into(),
            )
        })?;

        std::fs::create_dir_all(&self.config.scratch_dir).map_err(|e| {
            InterpolateError::IoError(format!(
                "OutOfCoreRbf: cannot create scratch_dir '{}': {e}",
                self.config.scratch_dir.display()
            ))
        })?;
        // Drop the previous handle before the file is recreated at the same path. This also
        // makes the model read as unfitted if writing the new coefficients fails.
        self.coeff_storage = None;
        let storage = DiskStorage::create(&self.coeff_path, n, 1)?;
        storage.write_rows(0, coeff_slice)?;

        self.n = n;
        self.d_in = d;
        self.centers = centers.to_owned();
        self.coeff_storage = Some(storage);
        Ok(())
    }

    /// Evaluates the fitted interpolant at each row of `query_points`.
    ///
    /// # Errors
    /// * `ComputationError` if the model has not been fitted, or if the coefficient file does
    ///   not contain exactly `n_centers()` values.
    /// * `DimensionMismatch` if `query_points.ncols()` differs from the training dimension.
    /// * Errors propagated from `DiskStorage` while reading the coefficients.
    pub fn predict(&self, query_points: &Array2<f64>) -> Result<Array1<f64>, InterpolateError> {
        let storage = self.coeff_storage.as_ref().ok_or_else(|| {
            InterpolateError::ComputationError("OutOfCoreRbf: model not fitted".into())
        })?;
        if query_points.ncols() != self.d_in {
            return Err(InterpolateError::DimensionMismatch(format!(
                "query dim {} != training dim {}",
                query_points.ncols(),
                self.d_in
            )));
        }
        let coeff_raw = storage.read_rows(0, self.n)?;
        // Guard the `dot` below, which would panic on a length mismatch.
        if coeff_raw.len() != self.n {
            return Err(InterpolateError::ComputationError(format!(
                "OutOfCoreRbf: expected {} stored coefficients, read {}",
                self.n,
                coeff_raw.len()
            )));
        }
        let coefficients = Array1::from_vec(coeff_raw);
        let predictions: Array1<f64> = query_points
            .outer_iter()
            .map(|q| self.kernel_row(q, &self.centers).dot(&coefficients))
            .collect();
        Ok(predictions)
    }

    /// Deletes this model's coefficient file; a missing file is not an error.
    ///
    /// The in-memory storage handle is not cleared (this method takes `&self`), so `predict`
    /// should not be called afterwards; what it does then depends on `DiskStorage`.
    pub fn cleanup(&self) -> Result<(), InterpolateError> {
        match std::fs::remove_file(&self.coeff_path) {
            Ok(()) => Ok(()),
            // Removing directly avoids the race of a separate `exists()` check.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(InterpolateError::IoError(format!(
                "OutOfCoreRbf::cleanup: {e}"
            ))),
        }
    }

    /// Number of centers in the last successful fit (0 before the first fit).
    pub fn n_centers(&self) -> usize {
        self.n
    }

    /// Input dimensionality of the last successful fit (0 before the first fit).
    pub fn dim(&self) -> usize {
        self.d_in
    }

    /// Configuration this interpolator was created with.
    pub fn config(&self) -> &OutOfCoreRbfConfig {
        &self.config
    }
}