use crate::error::{KernelError, Result};
use crate::types::Kernel;
use serde::{Deserialize, Serialize};
/// Configuration consumed by `NystromApproximation::fit`.
///
/// `NystromConfig::new` fills in `SamplingMethod::Uniform` and a
/// regularization of `1e-6`; the `with_*` builder methods override them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NystromConfig {
/// Number of landmark points. `NystromConfig::new` rejects zero and
/// `NystromApproximation::fit` rejects values above the dataset size.
pub num_landmarks: usize,
/// Strategy used to choose which data points become landmarks.
pub sampling: SamplingMethod,
/// Value added to every diagonal entry of the landmark kernel matrix
/// before it is inverted.
pub regularization: f64,
}
/// Strategy for picking landmark indices out of the dataset.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SamplingMethod {
/// Evenly strided indices `0, step, 2 * step, ...` where
/// `step = data.len() / num_landmarks` (integer division).
Uniform,
/// The first `num_landmarks` indices, i.e. `0..num_landmarks`.
First,
/// Deterministic greedy selection that starts at index 0 and is driven by
/// the kernel dissimilarity `1 - k(x, y)` to the landmarks chosen so far.
/// Despite the name, no random sampling is involved.
KMeansPlusPlus,
}
impl NystromConfig {
    /// Creates a configuration with `num_landmarks` landmarks, uniform
    /// sampling and a regularization of `1e-6`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `num_landmarks` is zero.
    pub fn new(num_landmarks: usize) -> Result<Self> {
        if num_landmarks == 0 {
            return Err(KernelError::InvalidParameter {
                parameter: "num_landmarks".to_string(),
                value: num_landmarks.to_string(),
                reason: "must be greater than 0".to_string(),
            });
        }
        Ok(Self {
            num_landmarks,
            sampling: SamplingMethod::Uniform,
            regularization: 1e-6,
        })
    }

    /// Returns the configuration with its landmark sampling strategy replaced.
    pub fn with_sampling(mut self, sampling: SamplingMethod) -> Self {
        self.sampling = sampling;
        self
    }

    /// Returns the configuration with its diagonal regularization replaced.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `regularization` is
    /// negative, `NaN` or infinite.
    pub fn with_regularization(mut self, regularization: f64) -> Result<Self> {
        if regularization < 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "regularization".to_string(),
                value: regularization.to_string(),
                reason: "must be non-negative".to_string(),
            });
        }
        // `NaN < 0.0` is false, so NaN (and +inf) slip through the sign check
        // above; either value would corrupt the landmark matrix inversion.
        if !regularization.is_finite() {
            return Err(KernelError::InvalidParameter {
                parameter: "regularization".to_string(),
                value: regularization.to_string(),
                reason: "must be finite".to_string(),
            });
        }
        self.regularization = regularization;
        Ok(self)
    }
}
/// Low-rank Nyström approximation of a kernel matrix, `K ≈ C · W⁻¹ · Cᵀ`.
#[derive(Debug, Clone)]
pub struct NystromApproximation {
    /// Kernel values between every sample (row) and every landmark (column).
    c_matrix: Vec<Vec<f64>>,
    /// Inverse of the regularized landmark-to-landmark kernel matrix.
    w_inv: Vec<Vec<f64>>,
    /// Indices into the fitted dataset of the points used as landmarks.
    landmark_indices: Vec<usize>,
}
impl NystromApproximation {
    /// Fits a Nyström approximation of the kernel matrix of `data`.
    ///
    /// Landmarks are chosen according to `config.sampling`. `C` holds the
    /// kernel value between every sample and every landmark, `W` is the
    /// landmark-to-landmark kernel matrix with `config.regularization` added
    /// to its diagonal, and the stored `W⁻¹` is its inverse.
    ///
    /// # Errors
    ///
    /// * [`KernelError::InvalidParameter`] when `config.num_landmarks` is zero
    ///   or exceeds `data.len()`.
    /// * Any error returned by `kernel.compute`.
    /// * [`KernelError::ComputationError`] when the regularized landmark
    ///   matrix cannot be inverted.
    pub fn fit(data: &[Vec<f64>], kernel: &dyn Kernel, config: NystromConfig) -> Result<Self> {
        let n = data.len();
        let m = config.num_landmarks;
        // The config fields are public (and deserializable), so the check in
        // `NystromConfig::new` can be bypassed; zero landmarks would otherwise
        // reach a division by zero in uniform sampling.
        if m == 0 {
            return Err(KernelError::InvalidParameter {
                parameter: "num_landmarks".to_string(),
                value: m.to_string(),
                reason: "must be greater than 0".to_string(),
            });
        }
        if m > n {
            return Err(KernelError::InvalidParameter {
                parameter: "num_landmarks".to_string(),
                value: m.to_string(),
                reason: format!("cannot exceed dataset size ({})", n),
            });
        }

        let landmark_indices = Self::select_landmarks(data, kernel, &config)?;

        // C[i][j] = k(x_i, landmark_j)
        let mut c_matrix = Vec::with_capacity(n);
        for point in data {
            let row = landmark_indices
                .iter()
                .map(|&idx| kernel.compute(point, &data[idx]))
                .collect::<Result<Vec<f64>>>()?;
            c_matrix.push(row);
        }

        // W[i][j] = k(landmark_i, landmark_j) is exactly row `landmark_i` of C,
        // so the rows are copied instead of evaluating the kernel again
        // (this assumes `kernel.compute` is deterministic).
        let mut w_matrix: Vec<Vec<f64>> = landmark_indices
            .iter()
            .map(|&idx| c_matrix[idx].clone())
            .collect();
        for (i, row) in w_matrix.iter_mut().enumerate() {
            row[i] += config.regularization;
        }

        let w_inv = Self::pseudo_inverse(&w_matrix)?;
        Ok(Self {
            c_matrix,
            w_inv,
            landmark_indices,
        })
    }

    /// Chooses `config.num_landmarks` indices into `data` using
    /// `config.sampling`.
    ///
    /// `fit` guarantees `1 <= num_landmarks <= data.len()` before calling.
    fn select_landmarks(
        data: &[Vec<f64>],
        kernel: &dyn Kernel,
        config: &NystromConfig,
    ) -> Result<Vec<usize>> {
        let m = config.num_landmarks;
        match config.sampling {
            SamplingMethod::First => Ok((0..m).collect()),
            SamplingMethod::Uniform => {
                if m == 0 {
                    // Nothing to select; also avoids dividing by zero below.
                    return Ok(Vec::new());
                }
                let step = data.len() / m;
                Ok((0..m).map(|i| i * step).collect())
            }
            SamplingMethod::KMeansPlusPlus => Self::kmeans_plusplus_sampling(data, kernel, m),
        }
    }

    /// Greedy farthest-point landmark selection.
    ///
    /// Starts from index 0 and repeatedly adds the not-yet-selected point
    /// whose smallest dissimilarity `1 - k(x, landmark)` to the chosen
    /// landmarks is largest (the lowest index wins ties). The procedure is
    /// deterministic and never returns the same index twice, even when every
    /// dissimilarity is zero or negative.
    // Index loops are kept because each index addresses `data`, `selected`
    // and `min_distances` at once.
    #[allow(clippy::needless_range_loop)]
    fn kmeans_plusplus_sampling(
        data: &[Vec<f64>],
        kernel: &dyn Kernel,
        num_landmarks: usize,
    ) -> Result<Vec<usize>> {
        let n = data.len();
        let mut landmarks = Vec::with_capacity(num_landmarks);
        if n == 0 || num_landmarks == 0 {
            return Ok(landmarks);
        }

        // O(1) membership test instead of scanning `landmarks`.
        let mut selected = vec![false; n];
        let mut min_distances = vec![f64::INFINITY; n];
        let mut last_landmark = 0;
        landmarks.push(last_landmark);
        selected[last_landmark] = true;

        while landmarks.len() < num_landmarks {
            // Only the newest landmark can lower a candidate's minimum distance.
            for i in 0..n {
                if selected[i] {
                    continue;
                }
                let similarity = kernel.compute(&data[i], &data[last_landmark])?;
                min_distances[i] = min_distances[i].min(1.0 - similarity);
            }

            // Track the best candidate as an `Option` so that non-positive
            // distances cannot fall back to an already selected index.
            let mut farthest: Option<usize> = None;
            for i in 0..n {
                if selected[i] {
                    continue;
                }
                let improves = match farthest {
                    Some(best) => min_distances[i] > min_distances[best],
                    None => true,
                };
                if improves {
                    farthest = Some(i);
                }
            }

            match farthest {
                Some(next) => {
                    selected[next] = true;
                    landmarks.push(next);
                    last_landmark = next;
                }
                // Every point is already a landmark.
                None => break,
            }
        }
        Ok(landmarks)
    }

    /// Inverts a square matrix by Gauss-Jordan elimination with partial
    /// pivoting.
    ///
    /// Despite the name this is a plain inverse: a pivot smaller than `1e-10`
    /// in absolute value is reported as an error rather than handled with a
    /// pseudo-inverse.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::ComputationError`] when the matrix is empty,
    /// not square, contains non-finite values, or is (nearly) singular.
    // The pivot search indexes rows by number because it compares two rows.
    #[allow(clippy::needless_range_loop)]
    fn pseudo_inverse(matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let n = matrix.len();
        if n == 0 {
            return Err(KernelError::ComputationError(
                "Cannot invert empty matrix".to_string(),
            ));
        }
        if matrix.iter().any(|row| row.len() != n) {
            return Err(KernelError::ComputationError(
                "Cannot invert non-square matrix".to_string(),
            ));
        }
        // NaN compares false against the pivot threshold and would silently
        // spread through the elimination, so reject it up front.
        if matrix.iter().flatten().any(|value| !value.is_finite()) {
            return Err(KernelError::ComputationError(
                "Matrix contains non-finite values".to_string(),
            ));
        }

        // Augmented system [A | I].
        let mut augmented: Vec<Vec<f64>> = matrix
            .iter()
            .enumerate()
            .map(|(i, row)| {
                let mut extended = Vec::with_capacity(2 * n);
                extended.extend_from_slice(row);
                extended.resize(2 * n, 0.0);
                extended[n + i] = 1.0;
                extended
            })
            .collect();

        for col in 0..n {
            // Partial pivoting: bring the largest remaining entry of this
            // column onto the diagonal.
            let mut pivot_row = col;
            for r in (col + 1)..n {
                if augmented[r][col].abs() > augmented[pivot_row][col].abs() {
                    pivot_row = r;
                }
            }
            augmented.swap(col, pivot_row);

            let pivot = augmented[col][col];
            if pivot.abs() < 1e-10 {
                return Err(KernelError::ComputationError(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }
            for value in augmented[col].iter_mut() {
                *value /= pivot;
            }

            // Clear this column from every other row.
            let normalized = augmented[col].clone();
            for (r, row) in augmented.iter_mut().enumerate() {
                if r == col {
                    continue;
                }
                let factor = row[col];
                for (value, pivot_value) in row.iter_mut().zip(&normalized) {
                    *value -= factor * pivot_value;
                }
            }
        }

        // The right-hand half now holds the inverse.
        Ok(augmented
            .into_iter()
            .map(|row| row[n..].to_vec())
            .collect())
    }

    /// Row `i` of `C · W⁻¹`.
    ///
    /// The caller must ensure `i < self.c_matrix.len()`.
    fn project_row(&self, i: usize) -> Vec<f64> {
        let mut projected = vec![0.0; self.w_inv.len()];
        for (c_ik, w_row) in self.c_matrix[i].iter().zip(&self.w_inv) {
            for (acc, w_kl) in projected.iter_mut().zip(w_row) {
                *acc += c_ik * w_kl;
            }
        }
        projected
    }

    /// Dot product over the common prefix of `a` and `b`.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
    }

    /// Approximate kernel value between samples `i` and `j`, i.e. entry
    /// `(i, j)` of `C · W⁻¹ · Cᵀ`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::ComputationError`] when either index is not
    /// smaller than [`Self::num_samples`].
    pub fn approximate(&self, i: usize, j: usize) -> Result<f64> {
        if i >= self.c_matrix.len() || j >= self.c_matrix.len() {
            return Err(KernelError::ComputationError(format!(
                "Indices out of bounds: i={}, j={}",
                i, j
            )));
        }
        Ok(Self::dot(&self.project_row(i), &self.c_matrix[j]))
    }

    /// Materializes the full `n x n` approximate kernel matrix.
    ///
    /// Each row of `C · W⁻¹` is computed once and reused for the whole output
    /// row, so every entry equals what [`Self::approximate`] returns for it.
    pub fn get_approximate_matrix(&self) -> Result<Vec<Vec<f64>>> {
        let n = self.c_matrix.len();
        let mut matrix = Vec::with_capacity(n);
        for i in 0..n {
            let projected = self.project_row(i);
            let row: Vec<f64> = self
                .c_matrix
                .iter()
                .map(|c_row| Self::dot(&projected, c_row))
                .collect();
            matrix.push(row);
        }
        Ok(matrix)
    }

    /// Number of samples the approximation was fitted on.
    pub fn num_samples(&self) -> usize {
        self.c_matrix.len()
    }

    /// Number of landmarks selected during fitting.
    pub fn num_landmarks(&self) -> usize {
        self.landmark_indices.len()
    }

    /// Indices into the fitted dataset of the selected landmarks.
    pub fn landmark_indices(&self) -> &[usize] {
        &self.landmark_indices
    }

    /// Frobenius norm of `exact_matrix - approximation`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::DimensionMismatch`] when `exact_matrix` is not
    /// an `n x n` matrix with `n == self.num_samples()`.
    pub fn approximation_error(&self, exact_matrix: &[Vec<f64>]) -> Result<f64> {
        let n = exact_matrix.len();
        let samples = self.num_samples();
        // Validate before building the approximation: mismatched or ragged
        // input would otherwise panic on an out-of-bounds index.
        if samples != n {
            return Err(KernelError::DimensionMismatch {
                expected: vec![n, n],
                got: vec![samples, samples],
                context: "approximation error computation".to_string(),
            });
        }
        if let Some(row) = exact_matrix.iter().find(|row| row.len() != n) {
            return Err(KernelError::DimensionMismatch {
                expected: vec![n, n],
                got: vec![n, row.len()],
                context: "approximation error computation".to_string(),
            });
        }

        let approx_matrix = self.get_approximate_matrix()?;
        let squared_error = exact_matrix
            .iter()
            .zip(&approx_matrix)
            .flat_map(|(exact_row, approx_row)| exact_row.iter().zip(approx_row))
            .fold(0.0_f64, |acc, (exact, approx)| {
                let diff = exact - approx;
                acc + diff * diff
            });
        Ok(squared_error.sqrt())
    }

    /// Ratio `n² / (n·m + m²)` between the entry count of the full `n x n`
    /// kernel matrix and the entries stored here (`C` is `n x m`, `W⁻¹` is
    /// `m x m`).
    pub fn compression_ratio(&self) -> f64 {
        let n = self.num_samples() as f64;
        let m = self.num_landmarks() as f64;
        (n * n) / (n * m + m * m)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearKernel;

    /// Deterministic `n x dim` fixture whose entry `(i, j)` is
    /// `sin(i * dim + j)`.
    fn generate_test_data(n: usize, dim: usize) -> Vec<Vec<f64>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f64> = (0..dim).map(|j| ((i * dim + j) as f64).sin()).collect();
            rows.push(row);
        }
        rows
    }

    /// A fresh config carries the requested landmark count and uniform sampling.
    #[test]
    fn test_nystrom_config() {
        let cfg = NystromConfig::new(10).expect("unwrap");
        assert_eq!(cfg.num_landmarks, 10);
        assert_eq!(cfg.sampling, SamplingMethod::Uniform);
    }

    /// Zero landmarks is rejected by the constructor.
    #[test]
    fn test_nystrom_config_invalid() {
        assert!(NystromConfig::new(0).is_err());
    }

    /// Fitting reports the dataset size and the landmark count.
    #[test]
    fn test_nystrom_approximation_basic() {
        let points = generate_test_data(50, 10);
        let linear = LinearKernel::new();
        let cfg = NystromConfig::new(10).expect("unwrap");
        let fitted = NystromApproximation::fit(&points, &linear, cfg).expect("unwrap");
        assert_eq!(fitted.num_samples(), 50);
        assert_eq!(fitted.num_landmarks(), 10);
    }

    /// A single approximated entry is a finite number.
    #[test]
    fn test_nystrom_approximation_value() {
        let points = generate_test_data(20, 5);
        let linear = LinearKernel::new();
        let cfg = NystromConfig::new(5).expect("unwrap");
        let fitted = NystromApproximation::fit(&points, &linear, cfg).expect("unwrap");
        let entry = fitted.approximate(0, 1).expect("unwrap");
        assert!(entry.is_finite());
    }

    /// Every sampling strategy yields a usable approximation.
    #[test]
    fn test_nystrom_sampling_methods() {
        let points = generate_test_data(30, 5);
        let linear = LinearKernel::new();

        // `First` always begins with index 0.
        let first_cfg = NystromConfig::new(10)
            .expect("unwrap")
            .with_sampling(SamplingMethod::First);
        let first_fit = NystromApproximation::fit(&points, &linear, first_cfg).expect("unwrap");
        assert_eq!(first_fit.landmark_indices()[0], 0);

        let uniform_cfg = NystromConfig::new(10)
            .expect("unwrap")
            .with_sampling(SamplingMethod::Uniform);
        let uniform_fit =
            NystromApproximation::fit(&points, &linear, uniform_cfg).expect("unwrap");
        assert_eq!(uniform_fit.num_landmarks(), 10);

        let greedy_cfg = NystromConfig::new(10)
            .expect("unwrap")
            .with_sampling(SamplingMethod::KMeansPlusPlus);
        let greedy_fit = NystromApproximation::fit(&points, &linear, greedy_cfg).expect("unwrap");
        assert_eq!(greedy_fit.num_landmarks(), 10);
    }

    /// 100 samples with 20 landmarks compress by more than a factor of three.
    #[test]
    fn test_nystrom_compression_ratio() {
        let points = generate_test_data(100, 5);
        let linear = LinearKernel::new();
        let cfg = NystromConfig::new(20).expect("unwrap");
        let fitted = NystromApproximation::fit(&points, &linear, cfg).expect("unwrap");
        assert!(fitted.compression_ratio() > 3.0);
    }

    /// The Frobenius error against the exact kernel matrix stays bounded.
    #[test]
    fn test_nystrom_approximation_error() {
        let points = generate_test_data(30, 5);
        let linear = LinearKernel::new();
        let exact_kernel = linear.compute_matrix(&points).expect("unwrap");
        let cfg = NystromConfig::new(20).expect("unwrap");
        let fitted = NystromApproximation::fit(&points, &linear, cfg).expect("unwrap");
        let frobenius = fitted.approximation_error(&exact_kernel).expect("unwrap");
        assert!(frobenius < 10.0);
    }

    /// Asking for more landmarks than samples fails at fit time.
    #[test]
    fn test_nystrom_too_many_landmarks() {
        let points = generate_test_data(10, 5);
        let linear = LinearKernel::new();
        let cfg = NystromConfig::new(20).expect("unwrap");
        assert!(NystromApproximation::fit(&points, &linear, cfg).is_err());
    }

    /// A custom regularization still produces a queryable approximation.
    #[test]
    fn test_nystrom_regularization() {
        let points = generate_test_data(20, 5);
        let linear = LinearKernel::new();
        let cfg = NystromConfig::new(5)
            .expect("unwrap")
            .with_regularization(1e-4)
            .expect("unwrap");
        let fitted = NystromApproximation::fit(&points, &linear, cfg).expect("unwrap");
        assert!(fitted.approximate(0, 1).is_ok());
    }

    /// Negative regularization is rejected by the builder.
    #[test]
    fn test_nystrom_invalid_regularization() {
        let outcome = NystromConfig::new(10)
            .expect("unwrap")
            .with_regularization(-0.1);
        assert!(outcome.is_err());
    }
}