use scirs2_core::ndarray::{Array1, Array2};
use std::path::PathBuf;
use super::rbf::{OocRbfKernel, OutOfCoreRbf, OutOfCoreRbfConfig};
use crate::error::InterpolateError;
/// Settings for [`OutOfCoreKriging`].
///
/// Every field is handed to the out-of-core Gaussian RBF interpolator that backs the
/// kriging model when the model is constructed.
#[derive(Clone, Debug)]
pub struct OutOfCoreKrigingConfig {
    /// Chunk size forwarded unchanged to the underlying RBF interpolator.
    pub chunk_size: usize,
    /// Cache budget forwarded to the underlying RBF interpolator
    /// (presumably megabytes, per the field name — confirm against `OutOfCoreRbf`).
    pub cache_size_mb: f64,
    /// Directory forwarded to the underlying RBF interpolator as its scratch location.
    pub scratch_dir: PathBuf,
    /// Kernel length scale; forwarded as the Gaussian RBF `epsilon`.
    pub length_scale: f64,
    /// Nugget term; forwarded as the RBF `regularization`.
    pub nugget: f64,
}
impl Default for OutOfCoreKrigingConfig {
fn default() -> Self {
Self {
chunk_size: 1000,
cache_size_mb: 100.0,
scratch_dir: std::env::temp_dir(),
length_scale: 1.0,
nugget: 1e-6,
}
}
}
/// Kriging interpolator backed by an out-of-core RBF model with a Gaussian kernel.
///
/// All fitting and prediction work is delegated to the wrapped [`OutOfCoreRbf`].
/// The kriging-level configuration is retained so callers can inspect it.
pub struct OutOfCoreKriging {
    /// The configuration this model was created from.
    config: OutOfCoreKrigingConfig,
    /// Underlying out-of-core RBF interpolator that performs the actual work.
    inner: OutOfCoreRbf,
}
impl OutOfCoreKriging {
pub fn new(config: OutOfCoreKrigingConfig) -> Self {
let rbf_cfg = OutOfCoreRbfConfig {
chunk_size: config.chunk_size,
cache_size_mb: config.cache_size_mb,
scratch_dir: config.scratch_dir.clone(),
kernel: OocRbfKernel::Gaussian,
epsilon: config.length_scale,
regularization: config.nugget,
};
Self {
inner: OutOfCoreRbf::new(rbf_cfg),
config,
}
}
pub fn fit(
&mut self,
centers: &Array2<f64>,
values: &Array1<f64>,
) -> Result<(), InterpolateError> {
self.inner.fit(centers, values)
}
pub fn predict(&self, query_points: &Array2<f64>) -> Result<Array1<f64>, InterpolateError> {
self.inner.predict(query_points)
}
pub fn cleanup(&self) -> Result<(), InterpolateError> {
self.inner.cleanup()
}
pub fn n_centers(&self) -> usize {
self.inner.n_centers()
}
pub fn dim(&self) -> usize {
self.inner.dim()
}
pub fn config(&self) -> &OutOfCoreKrigingConfig {
&self.config
}
}