use crate::error::{KernelError, Result};
use crate::types::Kernel;
use std::collections::VecDeque;
use std::sync::Arc;
/// Construction options for [`OnlineKernelMatrix`].
#[derive(Debug, Clone)]
pub struct OnlineConfig {
    /// Capacity reserved up front for the sample list and the outer matrix vector.
    pub initial_capacity: usize,
    /// Flag carried with the configuration; `true` by default.
    ///
    /// NOTE(review): no code visible in this file reads this flag — confirm its intended use.
    pub compute_full_matrix: bool,
}

impl Default for OnlineConfig {
    /// Returns a configuration with room for 64 samples and `compute_full_matrix` enabled.
    fn default() -> Self {
        OnlineConfig {
            initial_capacity: 64,
            compute_full_matrix: true,
        }
    }
}

impl OnlineConfig {
    /// Returns the default configuration with `initial_capacity` replaced by `capacity`.
    pub fn with_capacity(capacity: usize) -> Self {
        let defaults = Self::default();
        Self {
            initial_capacity: capacity,
            ..defaults
        }
    }
}
/// Counters maintained by the kernel-matrix containers in this file.
///
/// All counters start at zero and are reset by each container's `clear`.
#[derive(Debug, Clone, Default)]
pub struct OnlineStats {
    /// Number of samples successfully inserted.
    pub samples_added: usize,
    /// Number of samples removed or evicted.
    pub samples_removed: usize,
    /// Number of kernel evaluations performed while inserting samples.
    pub kernel_computations: usize,
    /// NOTE(review): never updated by any code visible in this file — confirm intended use.
    pub resizes: usize,
}
/// Kernel (Gram) matrix that grows and shrinks one sample at a time.
///
/// `matrix[i][j]` holds the kernel value computed between `samples[i]` and
/// `samples[j]`; both vectors are kept the same length by every mutating method.
pub struct OnlineKernelMatrix {
    /// Kernel used for every evaluation.
    kernel: Box<dyn Kernel>,
    /// Stored samples, in insertion order.
    samples: Vec<Vec<f64>>,
    /// Square matrix of kernel values, one row per stored sample.
    matrix: Vec<Vec<f64>>,
    /// Configuration supplied at construction.
    config: OnlineConfig,
    /// Running operation counters.
    stats: OnlineStats,
}
impl OnlineKernelMatrix {
    /// Creates an empty matrix for `kernel` using [`OnlineConfig::default`].
    pub fn new(kernel: Box<dyn Kernel>) -> Self {
        let config = OnlineConfig::default();
        Self::with_config(kernel, config)
    }

    /// Creates an empty matrix for `kernel` with an explicit configuration.
    ///
    /// `config.initial_capacity` is reserved for both the sample list and the
    /// outer matrix vector.
    pub fn with_config(kernel: Box<dyn Kernel>, config: OnlineConfig) -> Self {
        let capacity = config.initial_capacity;
        Self {
            kernel,
            samples: Vec::with_capacity(capacity),
            matrix: Vec::with_capacity(capacity),
            config,
            stats: OnlineStats::default(),
        }
    }

    /// Appends `sample`, extending the matrix by one row and one column.
    ///
    /// # Errors
    /// Returns `KernelError::DimensionMismatch` when `sample` has a different
    /// length than the first stored sample, and propagates any error from the
    /// kernel. Nothing is inserted on error.
    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
        match self.samples.first() {
            Some(reference) if reference.len() != sample.len() => {
                return Err(KernelError::DimensionMismatch {
                    expected: vec![reference.len()],
                    got: vec![sample.len()],
                    context: "online kernel matrix".to_string(),
                });
            }
            _ => {}
        }

        // One pass over the stored samples followed by the sample itself, so the
        // self-similarity ends up as the last (diagonal) entry of the new row.
        let mut fresh_row = Vec::with_capacity(self.samples.len() + 1);
        for other in self.samples.iter().chain(std::iter::once(&sample)) {
            let value = self.kernel.compute(&sample, other)?;
            fresh_row.push(value);
            self.stats.kernel_computations += 1;
        }

        // Existing rows gain the new column; the zip stops before the diagonal entry.
        for (row, &value) in self.matrix.iter_mut().zip(&fresh_row) {
            row.push(value);
        }
        self.matrix.push(fresh_row);
        self.samples.push(sample);
        self.stats.samples_added += 1;
        Ok(())
    }

    /// Appends every sample in order, stopping at the first failure.
    ///
    /// # Errors
    /// Propagates the first error from [`Self::add_sample`]; samples added
    /// before the failure remain stored.
    pub fn add_samples(&mut self, samples: Vec<Vec<f64>>) -> Result<()> {
        samples
            .into_iter()
            .try_for_each(|sample| self.add_sample(sample))
    }

    /// Removes the sample at `index` together with its row and column, returning it.
    ///
    /// # Errors
    /// Returns `KernelError::ComputationError` when `index` is out of bounds.
    pub fn remove_sample(&mut self, index: usize) -> Result<Vec<f64>> {
        let count = self.samples.len();
        if index >= count {
            return Err(KernelError::ComputationError(format!(
                "Index {} out of bounds for {} samples",
                index, count
            )));
        }

        self.matrix.remove(index);
        self.matrix.iter_mut().for_each(|row| {
            row.remove(index);
        });
        self.stats.samples_removed += 1;
        Ok(self.samples.remove(index))
    }

    /// Borrows the current kernel matrix.
    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
        &self.matrix
    }

    /// Borrows the stored samples in insertion order.
    pub fn get_samples(&self) -> &Vec<Vec<f64>> {
        &self.samples
    }

    /// Returns the matrix entry at row `i`, column `j`, or `None` when out of range.
    pub fn get(&self, i: usize, j: usize) -> Option<f64> {
        self.matrix.get(i)?.get(j).copied()
    }

    /// Number of stored samples.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` when no samples are stored.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Borrows the operation counters.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Drops every sample and matrix entry and resets the counters to zero.
    pub fn clear(&mut self) {
        self.stats = OnlineStats::default();
        self.matrix.clear();
        self.samples.clear();
    }

    /// Borrows the kernel used for evaluations.
    pub fn kernel(&self) -> &dyn Kernel {
        self.kernel.as_ref()
    }

    /// Borrows the configuration supplied at construction.
    pub fn config(&self) -> &OnlineConfig {
        &self.config
    }

    /// Evaluates the kernel between `query` and the stored sample at `sample_idx`.
    ///
    /// # Errors
    /// Returns `KernelError::ComputationError` when `sample_idx` is out of
    /// range, and propagates any error from the kernel.
    pub fn compute_with_sample(&self, query: &[f64], sample_idx: usize) -> Result<f64> {
        match self.samples.get(sample_idx) {
            Some(sample) => self.kernel.compute(query, sample),
            None => Err(KernelError::ComputationError(format!(
                "Sample index {} not found",
                sample_idx
            ))),
        }
    }

    /// Evaluates the kernel between `query` and every stored sample, in storage order.
    ///
    /// # Errors
    /// Returns the first error produced by the kernel.
    pub fn compute_with_all(&self, query: &[f64]) -> Result<Vec<f64>> {
        self.samples
            .iter()
            .map(|sample| self.kernel.compute(query, sample))
            .collect()
    }

    /// Returns an owned copy of the kernel matrix.
    pub fn to_matrix(&self) -> Vec<Vec<f64>> {
        self.matrix.to_vec()
    }
}
/// Kernel matrix over a sliding window of the most recent samples.
///
/// Once `window_size` samples are stored, each insertion evicts the oldest one.
pub struct WindowedKernelMatrix {
    /// Kernel used for every evaluation.
    kernel: Box<dyn Kernel>,
    /// Maximum number of samples kept; positive by construction.
    window_size: usize,
    /// Samples in arrival order, oldest at the front.
    samples: VecDeque<Vec<f64>>,
    /// Square matrix of kernel values aligned with `samples`.
    matrix: Vec<Vec<f64>>,
    /// Running operation counters.
    stats: OnlineStats,
}
impl WindowedKernelMatrix {
    /// Creates an empty window that keeps at most `window_size` samples.
    ///
    /// # Panics
    /// Panics when `window_size` is zero.
    pub fn new(kernel: Box<dyn Kernel>, window_size: usize) -> Self {
        assert!(window_size > 0, "Window size must be positive");
        Self {
            kernel,
            window_size,
            samples: VecDeque::with_capacity(window_size),
            matrix: Vec::with_capacity(window_size),
            stats: OnlineStats::default(),
        }
    }

    /// Appends `sample`, evicting the oldest sample when the window is full.
    ///
    /// Returns the evicted sample, or `None` when nothing was evicted.
    ///
    /// All kernel evaluations happen before the window is modified, so a
    /// failing evaluation leaves the samples, matrix and counters untouched.
    ///
    /// # Errors
    /// Returns `KernelError::DimensionMismatch` when `sample` has a different
    /// length than the oldest stored sample, and propagates any error from the
    /// kernel.
    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<Option<Vec<f64>>> {
        if let Some(first) = self.samples.front() {
            if sample.len() != first.len() {
                return Err(KernelError::DimensionMismatch {
                    expected: vec![first.len()],
                    got: vec![sample.len()],
                    context: "windowed kernel matrix".to_string(),
                });
            }
        }

        // The oldest sample is about to leave the window, so it is skipped when
        // building the new row. `window_size > 0` guarantees it exists.
        let will_evict = self.samples.len() >= self.window_size;
        let skipped = usize::from(will_evict);

        // Fallible phase: evaluate against the survivors and against itself
        // without touching any stored state.
        let mut new_row = Vec::with_capacity(self.samples.len() + 1 - skipped);
        for existing in self.samples.iter().skip(skipped) {
            new_row.push(self.kernel.compute(&sample, existing)?);
        }
        new_row.push(self.kernel.compute(&sample, &sample)?);
        self.stats.kernel_computations += new_row.len();

        // Commit phase: nothing below can fail.
        let evicted = if will_evict {
            let removed = self.samples.pop_front();
            self.matrix.remove(0);
            for row in &mut self.matrix {
                row.remove(0);
            }
            self.stats.samples_removed += 1;
            removed
        } else {
            None
        };

        // Existing rows gain the new column; the zip stops before the diagonal entry.
        for (row, &value) in self.matrix.iter_mut().zip(&new_row) {
            row.push(value);
        }
        self.matrix.push(new_row);
        self.samples.push_back(sample);
        self.stats.samples_added += 1;
        Ok(evicted)
    }

    /// Borrows the kernel matrix for the samples currently in the window.
    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
        &self.matrix
    }

    /// Borrows the samples currently in the window, oldest first.
    pub fn get_samples(&self) -> &VecDeque<Vec<f64>> {
        &self.samples
    }

    /// Maximum number of samples the window keeps.
    pub fn window_size(&self) -> usize {
        self.window_size
    }

    /// Number of samples currently in the window.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` when the window holds no samples.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Returns `true` when the next insertion will evict the oldest sample.
    pub fn is_full(&self) -> bool {
        self.samples.len() >= self.window_size
    }

    /// Borrows the operation counters.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Drops every sample and matrix entry and resets the counters to zero.
    pub fn clear(&mut self) {
        self.samples.clear();
        self.matrix.clear();
        self.stats = OnlineStats::default();
    }

    /// Evaluates the kernel between `query` and every sample in the window, oldest first.
    ///
    /// # Errors
    /// Returns the first error produced by the kernel.
    pub fn compute_with_all(&self, query: &[f64]) -> Result<Vec<f64>> {
        self.samples
            .iter()
            .map(|sample| self.kernel.compute(query, sample))
            .collect()
    }
}
/// Options controlling how [`ForgetfulKernelMatrix`] ages and discards samples.
#[derive(Debug, Clone)]
pub struct ForgetfulConfig {
    /// Factor every stored weight is multiplied by each time a sample is added.
    pub lambda: f64,
    /// When set, samples whose decayed weight is below this value are removed.
    pub removal_threshold: Option<f64>,
    /// When set, lowest-weight samples are evicted before an insertion until
    /// fewer than this many remain.
    pub max_samples: Option<usize>,
}

impl Default for ForgetfulConfig {
    /// Decay factor 0.99, removal threshold 0.01 and no sample cap.
    fn default() -> Self {
        ForgetfulConfig {
            lambda: 0.99,
            removal_threshold: Some(0.01),
            max_samples: None,
        }
    }
}
impl ForgetfulConfig {
    /// Builds the default configuration with the decay factor set to `lambda`.
    ///
    /// # Errors
    /// Returns `KernelError::InvalidParameter` when `lambda` is NaN or lies
    /// outside the half-open interval `(0, 1]`.
    pub fn with_lambda(lambda: f64) -> Result<Self> {
        // NaN compares false against everything, so the range test alone would
        // let it through; reject it explicitly.
        if lambda.is_nan() || lambda <= 0.0 || lambda > 1.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "lambda".to_string(),
                value: lambda.to_string(),
                reason: "lambda must be in (0, 1]".to_string(),
            });
        }
        Ok(Self {
            lambda,
            ..Default::default()
        })
    }

    /// Caps the number of stored samples at `max` (builder style).
    pub fn with_max_samples(mut self, max: usize) -> Self {
        self.max_samples = Some(max);
        self
    }

    /// Sets the weight below which samples are removed (builder style).
    ///
    /// The value is stored as given; no validation is performed here.
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.removal_threshold = Some(threshold);
        self
    }
}
/// Kernel matrix whose samples carry weights that decay with every insertion.
///
/// `samples`, `weights` and the rows/columns of `matrix` are index-aligned.
pub struct ForgetfulKernelMatrix {
    /// Kernel used for every evaluation.
    kernel: Box<dyn Kernel>,
    /// Decay factor, removal threshold and optional sample cap.
    config: ForgetfulConfig,
    /// Stored samples, in insertion order.
    samples: Vec<Vec<f64>>,
    /// Current weight of each stored sample; new samples start at 1.0.
    weights: Vec<f64>,
    /// Unweighted square matrix of kernel values.
    matrix: Vec<Vec<f64>>,
    /// Running operation counters.
    stats: OnlineStats,
}
impl ForgetfulKernelMatrix {
    /// Creates an empty matrix for `kernel` governed by `config`.
    pub fn new(kernel: Box<dyn Kernel>, config: ForgetfulConfig) -> Self {
        Self {
            kernel,
            config,
            samples: Vec::new(),
            weights: Vec::new(),
            matrix: Vec::new(),
            stats: OnlineStats::default(),
        }
    }

    /// Ages the stored samples by one step and appends `sample` with weight 1.0.
    ///
    /// Every existing weight is multiplied by `lambda`. Samples whose weight is
    /// then below the removal threshold are dropped, and if a sample cap is
    /// configured the lowest-weight samples (earliest on ties) are evicted
    /// until fewer than `max_samples` remain.
    ///
    /// Decay and evictions are planned on scratch data and committed only
    /// after every kernel evaluation has succeeded, so an error leaves the
    /// weights, samples, matrix and counters exactly as they were.
    ///
    /// # Errors
    /// Returns `KernelError::DimensionMismatch` when `sample` has a different
    /// length than the first stored sample, and propagates any error from the
    /// kernel.
    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
        if let Some(first) = self.samples.first() {
            if sample.len() != first.len() {
                return Err(KernelError::DimensionMismatch {
                    expected: vec![first.len()],
                    got: vec![sample.len()],
                    context: "forgetful kernel matrix".to_string(),
                });
            }
        }

        // Phase 1 (infallible, scratch only): decide the decayed weights and
        // which stored samples will be dropped.
        let lambda = self.config.lambda;
        let decayed: Vec<f64> = self.weights.iter().map(|weight| weight * lambda).collect();

        let mut evict: Vec<bool> = match self.config.removal_threshold {
            Some(threshold) => decayed.iter().map(|&weight| weight < threshold).collect(),
            None => vec![false; decayed.len()],
        };
        let mut survivors = evict.iter().filter(|&&gone| !gone).count();

        if let Some(max) = self.config.max_samples {
            while survivors >= max && survivors > 0 {
                // `min_by` yields the first of several equal minima, so ties
                // evict the earliest-stored sample.
                let weakest = decayed
                    .iter()
                    .enumerate()
                    .filter(|(idx, _)| !evict[*idx])
                    .min_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(idx, _)| idx);
                match weakest {
                    Some(idx) => {
                        evict[idx] = true;
                        survivors -= 1;
                    }
                    None => break,
                }
            }
        }

        // Phase 2 (fallible, read-only): evaluate the kernel against the
        // survivors and against the sample itself.
        let mut new_row = Vec::with_capacity(survivors + 1);
        for (existing, &gone) in self.samples.iter().zip(&evict) {
            if gone {
                continue;
            }
            new_row.push(self.kernel.compute(&sample, existing)?);
        }
        new_row.push(self.kernel.compute(&sample, &sample)?);

        // Phase 3 (infallible): commit. Removing from the back keeps the
        // remaining planned indices valid.
        self.stats.kernel_computations += new_row.len();
        self.weights = decayed;
        for (idx, &gone) in evict.iter().enumerate().rev() {
            if gone {
                self.remove_at(idx);
            }
        }

        // Existing rows gain the new column; the zip stops before the diagonal entry.
        for (row, &value) in self.matrix.iter_mut().zip(&new_row) {
            row.push(value);
        }
        self.matrix.push(new_row);
        self.samples.push(sample);
        self.weights.push(1.0);
        self.stats.samples_added += 1;
        Ok(())
    }

    /// Removes the sample at `index` along with its weight, row and column.
    ///
    /// Panics when `index` is out of bounds; callers pass indices of stored samples.
    fn remove_at(&mut self, index: usize) {
        self.samples.remove(index);
        self.weights.remove(index);
        self.matrix.remove(index);
        for row in &mut self.matrix {
            row.remove(index);
        }
        self.stats.samples_removed += 1;
    }

    /// Borrows the unweighted kernel matrix.
    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
        &self.matrix
    }

    /// Returns the kernel matrix with entry `(i, j)` scaled by `sqrt(w_i) * sqrt(w_j)`.
    pub fn get_weighted_matrix(&self) -> Vec<Vec<f64>> {
        let n = self.matrix.len();
        // Each square root is needed n times; compute it once per weight.
        let roots: Vec<f64> = self.weights.iter().map(|weight| weight.sqrt()).collect();
        let mut weighted = vec![vec![0.0; n]; n];
        for (i, (row, &root_i)) in self.matrix.iter().zip(&roots).enumerate() {
            for (j, (&k_val, &root_j)) in row.iter().zip(&roots).enumerate() {
                weighted[i][j] = k_val * root_i * root_j;
            }
        }
        weighted
    }

    /// Borrows the current sample weights, index-aligned with the samples.
    pub fn get_weights(&self) -> &Vec<f64> {
        &self.weights
    }

    /// Borrows the stored samples in insertion order.
    pub fn get_samples(&self) -> &Vec<Vec<f64>> {
        &self.samples
    }

    /// Number of stored samples.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` when no samples are stored.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Borrows the operation counters.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// The configured decay factor.
    pub fn lambda(&self) -> f64 {
        self.config.lambda
    }

    /// Drops every sample, weight and matrix entry and resets the counters to zero.
    pub fn clear(&mut self) {
        self.samples.clear();
        self.weights.clear();
        self.matrix.clear();
        self.stats = OnlineStats::default();
    }

    /// Evaluates the kernel between `query` and each stored sample, scaling
    /// each value by the square root of that sample's weight.
    ///
    /// # Errors
    /// Returns the first error produced by the kernel.
    pub fn compute_weighted(&self, query: &[f64]) -> Result<Vec<f64>> {
        let mut result = Vec::with_capacity(self.samples.len());
        for (sample, weight) in self.samples.iter().zip(&self.weights) {
            let k = self.kernel.compute(query, sample)?;
            result.push(k * weight.sqrt());
        }
        Ok(result)
    }

    /// Sum of all current sample weights.
    pub fn effective_size(&self) -> f64 {
        self.weights.iter().sum()
    }
}
/// Kernel matrix that re-tunes its kernel's bandwidth as samples arrive.
pub struct AdaptiveKernelMatrix {
    /// Factory that builds a kernel for a given bandwidth value.
    kernel: Arc<dyn Fn(f64) -> Box<dyn Kernel + Send + Sync> + Send + Sync>,
    /// Bandwidth the current inner matrix was built with.
    current_bandwidth: f64,
    /// Running sum of Euclidean distances between each new sample and the stored ones.
    distance_sum: f64,
    /// Number of distances accumulated in `distance_sum`.
    distance_count: usize,
    /// Matrix holding the samples under the current kernel.
    inner: OnlineKernelMatrix,
    /// Blend factor between the previous bandwidth and the newly estimated one.
    adaptation_rate: f64,
}
impl AdaptiveKernelMatrix {
pub fn new<F>(kernel_factory: F, initial_bandwidth: f64, adaptation_rate: f64) -> Self
where
F: Fn(f64) -> Box<dyn Kernel + Send + Sync> + Send + Sync + 'static,
{
let factory = Arc::new(kernel_factory);
let kernel = factory(initial_bandwidth);
Self {
kernel: factory,
current_bandwidth: initial_bandwidth,
distance_sum: 0.0,
distance_count: 0,
inner: OnlineKernelMatrix::new(kernel),
adaptation_rate,
}
}
pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
for existing in self.inner.get_samples() {
let dist_sq: f64 = sample
.iter()
.zip(existing.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum();
let dist = dist_sq.sqrt();
self.distance_sum += dist;
self.distance_count += 1;
}
if self.distance_count > 0 {
let mean_dist = self.distance_sum / self.distance_count as f64;
let new_bandwidth = mean_dist / 2.0_f64.sqrt();
self.current_bandwidth = (1.0 - self.adaptation_rate) * self.current_bandwidth
+ self.adaptation_rate * new_bandwidth;
let new_kernel = (self.kernel)(self.current_bandwidth);
let samples: Vec<Vec<f64>> = self.inner.get_samples().clone();
self.inner = OnlineKernelMatrix::new(new_kernel);
for s in samples {
self.inner.add_sample(s)?;
}
}
self.inner.add_sample(sample)
}
pub fn bandwidth(&self) -> f64 {
self.current_bandwidth
}
pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
self.inner.get_matrix()
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
// Unit tests for the online, windowed, forgetful and adaptive kernel matrices.
#[cfg(test)]
// The symmetry test indexes `matrix[i][j]` and `matrix[j][i]` by range.
#[allow(clippy::needless_range_loop)]
mod tests {
use super::*;
use crate::{LinearKernel, RbfKernel, RbfKernelConfig};
// Adding samples grows both the sample count and the matrix dimensions.
#[test]
fn test_online_kernel_matrix_basic() {
let kernel = LinearKernel::new();
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
assert!(online.is_empty());
online.add_sample(vec![1.0, 2.0]).expect("unwrap");
assert_eq!(online.len(), 1);
online.add_sample(vec![3.0, 4.0]).expect("unwrap");
assert_eq!(online.len(), 2);
let matrix = online.get_matrix();
assert_eq!(matrix.len(), 2);
assert_eq!(matrix[0].len(), 2);
}
// Two orthogonal unit vectors: expects 1 on the diagonal and 0 off it.
#[test]
fn test_online_kernel_matrix_values() {
let kernel = LinearKernel::new();
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
online.add_sample(vec![1.0, 0.0]).expect("unwrap");
online.add_sample(vec![0.0, 1.0]).expect("unwrap");
let matrix = online.get_matrix();
assert!((matrix[0][0] - 1.0).abs() < 1e-10);
assert!((matrix[1][1] - 1.0).abs() < 1e-10);
assert!((matrix[0][1]).abs() < 1e-10);
assert!((matrix[1][0]).abs() < 1e-10);
}
// The incrementally built matrix must equal its transpose.
#[test]
fn test_online_kernel_matrix_symmetry() {
let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("unwrap");
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
online.add_sample(vec![1.0, 2.0, 3.0]).expect("unwrap");
online.add_sample(vec![4.0, 5.0, 6.0]).expect("unwrap");
online.add_sample(vec![7.0, 8.0, 9.0]).expect("unwrap");
let matrix = online.get_matrix();
for i in 0..3 {
for j in 0..3 {
assert!(
(matrix[i][j] - matrix[j][i]).abs() < 1e-10,
"Matrix not symmetric at ({}, {})",
i,
j
);
}
}
}
// Removing the middle sample returns it and shrinks the matrix to 2x2.
#[test]
fn test_online_kernel_matrix_remove() {
let kernel = LinearKernel::new();
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
online.add_sample(vec![1.0]).expect("unwrap");
online.add_sample(vec![2.0]).expect("unwrap");
online.add_sample(vec![3.0]).expect("unwrap");
let removed = online.remove_sample(1).expect("unwrap");
assert_eq!(removed, vec![2.0]);
assert_eq!(online.len(), 2);
let matrix = online.get_matrix();
assert_eq!(matrix.len(), 2);
assert_eq!(matrix[0].len(), 2);
}
// A sample whose length differs from the stored ones is rejected.
#[test]
fn test_online_kernel_matrix_dimension_mismatch() {
let kernel = LinearKernel::new();
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
online.add_sample(vec![1.0, 2.0]).expect("unwrap");
let result = online.add_sample(vec![1.0, 2.0, 3.0]);
assert!(result.is_err());
}
// Query [1, 1] against the two unit vectors: expects 1 for each.
#[test]
fn test_online_kernel_matrix_compute_with_all() {
let kernel = LinearKernel::new();
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
online.add_sample(vec![1.0, 0.0]).expect("unwrap");
online.add_sample(vec![0.0, 1.0]).expect("unwrap");
let query = vec![1.0, 1.0];
let result = online.compute_with_all(&query).expect("unwrap");
assert!((result[0] - 1.0).abs() < 1e-10);
assert!((result[1] - 1.0).abs() < 1e-10);
}
// Three insertions cost 1 + 2 + 3 = 6 kernel evaluations.
#[test]
fn test_online_kernel_matrix_stats() {
let kernel = LinearKernel::new();
let mut online = OnlineKernelMatrix::new(Box::new(kernel));
online.add_sample(vec![1.0]).expect("unwrap");
online.add_sample(vec![2.0]).expect("unwrap");
online.add_sample(vec![3.0]).expect("unwrap");
let stats = online.stats();
assert_eq!(stats.samples_added, 3);
assert_eq!(stats.kernel_computations, 6);
}
// A window of size 3 reports full only once three samples are stored.
#[test]
fn test_windowed_kernel_matrix_basic() {
let kernel = LinearKernel::new();
let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
assert_eq!(windowed.window_size(), 3);
assert!(!windowed.is_full());
windowed.add_sample(vec![1.0]).expect("unwrap");
windowed.add_sample(vec![2.0]).expect("unwrap");
windowed.add_sample(vec![3.0]).expect("unwrap");
assert!(windowed.is_full());
assert_eq!(windowed.len(), 3);
}
// Adding to a full window returns the oldest sample and keeps the two newest.
#[test]
fn test_windowed_kernel_matrix_eviction() {
let kernel = LinearKernel::new();
let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 2);
windowed.add_sample(vec![1.0]).expect("unwrap");
windowed.add_sample(vec![2.0]).expect("unwrap");
let evicted = windowed.add_sample(vec![3.0]).expect("unwrap");
assert_eq!(evicted, Some(vec![1.0]));
assert_eq!(windowed.len(), 2);
let samples: Vec<_> = windowed.get_samples().iter().cloned().collect();
assert_eq!(samples, vec![vec![2.0], vec![3.0]]);
}
// Matrix entries are checked before and after an eviction shifts the window.
#[test]
fn test_windowed_kernel_matrix_values() {
let kernel = LinearKernel::new();
let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 2);
windowed.add_sample(vec![1.0, 0.0]).expect("unwrap");
windowed.add_sample(vec![0.0, 1.0]).expect("unwrap");
let matrix = windowed.get_matrix();
assert!((matrix[0][0] - 1.0).abs() < 1e-10);
assert!((matrix[1][1] - 1.0).abs() < 1e-10);
assert!((matrix[0][1]).abs() < 1e-10);
// The window now holds [0, 1] and [1, 1].
windowed.add_sample(vec![1.0, 1.0]).expect("unwrap");
let matrix = windowed.get_matrix();
assert!((matrix[0][0] - 1.0).abs() < 1e-10);
assert!((matrix[1][1] - 2.0).abs() < 1e-10);
assert!((matrix[0][1] - 1.0).abs() < 1e-10);
}
// A sample whose length differs from the stored ones is rejected.
#[test]
fn test_windowed_kernel_matrix_dimension_mismatch() {
let kernel = LinearKernel::new();
let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
windowed.add_sample(vec![1.0, 2.0]).expect("unwrap");
let result = windowed.add_sample(vec![1.0]);
assert!(result.is_err());
}
// Construction through `with_lambda` keeps both samples and reports the lambda.
#[test]
fn test_forgetful_kernel_matrix_basic() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig::with_lambda(0.9).expect("unwrap");
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![2.0]).expect("unwrap");
assert_eq!(forgetful.len(), 2);
assert!((forgetful.lambda() - 0.9).abs() < 1e-10);
}
// With lambda 0.8 the weights after three insertions are 0.64, 0.8 and 1.0.
#[test]
fn test_forgetful_kernel_matrix_weights() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig {
lambda: 0.8,
removal_threshold: None,
max_samples: None,
};
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![2.0]).expect("unwrap");
forgetful.add_sample(vec![3.0]).expect("unwrap");
let weights = forgetful.get_weights();
assert!((weights[2] - 1.0).abs() < 1e-10);
assert!((weights[1] - 0.8).abs() < 1e-10);
assert!((weights[0] - 0.64).abs() < 1e-10);
}
// Weighted entries scale the kernel value by sqrt(w_i) * sqrt(w_j).
#[test]
fn test_forgetful_kernel_matrix_weighted_matrix() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig {
lambda: 0.5,
removal_threshold: None,
max_samples: None,
};
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![1.0]).expect("unwrap");
let weighted = forgetful.get_weighted_matrix();
assert!((weighted[0][0] - 0.5).abs() < 1e-10);
assert!((weighted[1][1] - 1.0).abs() < 1e-10);
assert!((weighted[0][1] - 0.5_f64.sqrt()).abs() < 1e-10);
}
// On the third insertion the first sample decays to 0.25 < 0.3 and is removed.
#[test]
fn test_forgetful_kernel_matrix_removal_threshold() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig {
lambda: 0.5,
removal_threshold: Some(0.3),
max_samples: None,
};
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![2.0]).expect("unwrap");
forgetful.add_sample(vec![3.0]).expect("unwrap");
assert_eq!(forgetful.len(), 2);
}
// With no decay and a cap of 2, the third insertion evicts one sample.
#[test]
fn test_forgetful_kernel_matrix_max_samples() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig {
lambda: 1.0, removal_threshold: None,
max_samples: Some(2),
};
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![2.0]).expect("unwrap");
forgetful.add_sample(vec![3.0]).expect("unwrap");
assert_eq!(forgetful.len(), 2);
}
// Effective size is the weight sum: 0.81 + 0.9 + 1.0 = 2.71.
#[test]
fn test_forgetful_kernel_matrix_effective_size() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig {
lambda: 0.9,
removal_threshold: None,
max_samples: None,
};
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![2.0]).expect("unwrap");
forgetful.add_sample(vec![3.0]).expect("unwrap");
let eff_size = forgetful.effective_size();
assert!((eff_size - 2.71).abs() < 1e-10);
}
// `with_lambda` rejects values outside (0, 1].
#[test]
fn test_forgetful_kernel_matrix_invalid_lambda() {
let result = ForgetfulConfig::with_lambda(0.0);
assert!(result.is_err());
let result = ForgetfulConfig::with_lambda(1.5);
assert!(result.is_err());
}
// A sample whose length differs from the stored ones is rejected.
#[test]
fn test_forgetful_kernel_matrix_dimension_mismatch() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig::with_lambda(0.9).expect("unwrap");
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0, 2.0]).expect("unwrap");
let result = forgetful.add_sample(vec![1.0]);
assert!(result.is_err());
}
// Three insertions succeed and the bandwidth stays positive.
#[test]
fn test_adaptive_kernel_matrix_basic() {
let mut adaptive = AdaptiveKernelMatrix::new(
|gamma| Box::new(RbfKernel::new(RbfKernelConfig::new(gamma)).expect("unwrap")),
1.0,
0.1,
);
adaptive.add_sample(vec![1.0, 2.0]).expect("unwrap");
adaptive.add_sample(vec![3.0, 4.0]).expect("unwrap");
adaptive.add_sample(vec![5.0, 6.0]).expect("unwrap");
assert_eq!(adaptive.len(), 3);
assert!(adaptive.bandwidth() > 0.0);
}
// Two samples 10 apart with adaptation rate 0.5 must move the bandwidth off its initial value.
#[test]
fn test_adaptive_kernel_matrix_bandwidth_update() {
let mut adaptive = AdaptiveKernelMatrix::new(
|gamma| Box::new(RbfKernel::new(RbfKernelConfig::new(gamma)).expect("unwrap")),
1.0,
0.5, );
let initial = adaptive.bandwidth();
adaptive.add_sample(vec![0.0]).expect("unwrap");
adaptive.add_sample(vec![10.0]).expect("unwrap");
let after = adaptive.bandwidth();
assert_ne!(initial, after);
}
// A freshly constructed matrix exposes empty samples and an empty matrix.
#[test]
fn test_online_empty_operations() {
let kernel = LinearKernel::new();
let online = OnlineKernelMatrix::new(Box::new(kernel));
assert!(online.is_empty());
assert!(online.get_matrix().is_empty());
assert!(online.get_samples().is_empty());
}
// `clear` empties a windowed matrix.
#[test]
fn test_windowed_clear() {
let kernel = LinearKernel::new();
let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
windowed.add_sample(vec![1.0]).expect("unwrap");
windowed.add_sample(vec![2.0]).expect("unwrap");
windowed.clear();
assert!(windowed.is_empty());
assert_eq!(windowed.len(), 0);
}
// `clear` empties a forgetful matrix.
#[test]
fn test_forgetful_clear() {
let kernel = LinearKernel::new();
let config = ForgetfulConfig::with_lambda(0.9).expect("unwrap");
let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
forgetful.add_sample(vec![1.0]).expect("unwrap");
forgetful.add_sample(vec![2.0]).expect("unwrap");
forgetful.clear();
assert!(forgetful.is_empty());
assert_eq!(forgetful.len(), 0);
}
}