use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::{KernelError, Result};
use crate::types::Kernel;
/// Square sparse matrix stored as compressed-sparse-row (CSR) arrays plus a
/// hash-map staging area for entries written since the last `finalize`.
///
/// Only the CSR arrays are serialized: `temp_map` is marked `#[serde(skip)]`,
/// so entries that are still staged when the matrix is serialized are not
/// written out, and a deserialized matrix starts with an empty staging map.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SparseKernelMatrix {
/// Number of rows (and columns); the matrix is `size` x `size`.
size: usize,
/// Row offsets into `col_idx` / `values`; created with `size + 1` entries,
/// row `r` occupying the index range `row_ptr[r]..row_ptr[r + 1]`.
row_ptr: Vec<usize>,
/// Column index of each finalized entry (parallel to `values`).
col_idx: Vec<usize>,
/// Value of each finalized entry (parallel to `col_idx`).
values: Vec<f64>,
/// Staged `(row, col) -> value` entries that have not been folded into the
/// CSR arrays yet. Skipped by serde (see the type-level note).
#[serde(skip)]
temp_map: HashMap<(usize, usize), f64>,
}
impl SparseKernelMatrix {
    /// Magnitude below which a value is treated as zero and not stored.
    const ZERO_TOLERANCE: f64 = 1e-10;

    /// Creates an empty `size` x `size` matrix with no stored entries.
    pub fn new(size: usize) -> Self {
        Self {
            size,
            row_ptr: vec![0; size + 1],
            col_idx: Vec::new(),
            values: Vec::new(),
            temp_map: HashMap::new(),
        }
    }

    /// Moves every finalized (CSR) entry back into the staging map and leaves
    /// the CSR arrays empty.
    ///
    /// Entries that are already staged win over finalized ones, matching the
    /// lookup order of [`get`](Self::get). After this call the staging map is
    /// the single source of truth, so inserts, overwrites and removals all
    /// behave uniformly. No-op when nothing has been finalized.
    fn thaw(&mut self) {
        if self.values.is_empty() {
            return;
        }
        for row in 0..self.size {
            for idx in self.row_ptr[row]..self.row_ptr[row + 1] {
                self.temp_map
                    .entry((row, self.col_idx[idx]))
                    .or_insert(self.values[idx]);
            }
        }
        self.col_idx.clear();
        self.values.clear();
        self.row_ptr.fill(0);
    }

    /// Stores `value` at `(row, col)`.
    ///
    /// Out-of-range coordinates are ignored. A value whose magnitude is below
    /// `1e-10` removes the entry instead of storing it.
    ///
    /// If the matrix has finalized entries they are first moved back into the
    /// staging area (one pass over the stored entries); without this, a later
    /// [`finalize`](Self::finalize) would rebuild the CSR arrays from the
    /// staged entries alone and drop everything finalized earlier, and a
    /// removal could never reach a finalized entry.
    pub fn set(&mut self, row: usize, col: usize, value: f64) {
        if row >= self.size || col >= self.size {
            return;
        }
        self.thaw();
        if value.abs() < Self::ZERO_TOLERANCE {
            self.temp_map.remove(&(row, col));
        } else {
            self.temp_map.insert((row, col), value);
        }
    }

    /// Returns the stored value at `(row, col)`, or `None` when the
    /// coordinates are out of range or no entry is stored there.
    ///
    /// Staged entries take precedence over finalized ones.
    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
        if row >= self.size || col >= self.size {
            return None;
        }
        if let Some(&staged) = self.temp_map.get(&(row, col)) {
            return Some(staged);
        }
        let span = self.row_ptr[row]..self.row_ptr[row + 1];
        self.col_idx[span.clone()]
            .iter()
            .zip(&self.values[span])
            .find(|&(&stored_col, _)| stored_col == col)
            .map(|(_, &value)| value)
    }

    /// Folds all staged entries into the CSR arrays.
    ///
    /// Entries end up sorted by `(row, col)`. Does nothing when no entries
    /// are staged. Any entries that are already finalized are merged with the
    /// staged ones (staged values win) rather than discarded.
    pub fn finalize(&mut self) {
        if self.temp_map.is_empty() {
            return;
        }
        // Defensive merge: pull existing CSR entries into staging so the
        // rebuild below cannot lose them.
        self.thaw();

        let mut entries: Vec<((usize, usize), f64)> = self.temp_map.drain().collect();
        entries.sort_unstable_by_key(|&(key, _)| key);

        self.row_ptr.clear();
        self.row_ptr.resize(self.size + 1, 0);
        self.col_idx = Vec::with_capacity(entries.len());
        self.values = Vec::with_capacity(entries.len());

        // First count the entries per row (shifted by one slot) ...
        for ((row, col), value) in entries {
            self.row_ptr[row + 1] += 1;
            self.col_idx.push(col);
            self.values.push(value);
        }
        // ... then turn the counts into running offsets.
        for row in 0..self.size {
            self.row_ptr[row + 1] += self.row_ptr[row];
        }
    }

    /// Number of stored entries (finalized plus staged).
    pub fn nnz(&self) -> usize {
        self.values.len() + self.temp_map.len()
    }

    /// Number of rows (equal to the number of columns).
    pub fn size(&self) -> usize {
        self.size
    }

    /// Fraction of the `size * size` cells that hold a stored entry; `0.0`
    /// for an empty (0 x 0) matrix.
    pub fn density(&self) -> f64 {
        if self.size == 0 {
            return 0.0;
        }
        // Multiply in floating point so a very large `size` cannot overflow
        // `usize`.
        let total = self.size as f64 * self.size as f64;
        self.nnz() as f64 / total
    }

    /// Finalizes the matrix and expands it into a dense row-major
    /// `Vec<Vec<f64>>` with `0.0` in every cell that has no stored entry.
    pub fn to_dense(&mut self) -> Vec<Vec<f64>> {
        self.finalize();
        let mut dense = vec![vec![0.0; self.size]; self.size];
        for (row, dense_row) in dense.iter_mut().enumerate() {
            for idx in self.row_ptr[row]..self.row_ptr[row + 1] {
                dense_row[self.col_idx[idx]] = self.values[idx];
            }
        }
        dense
    }

    /// Evaluates `kernel` on every ordered pair of rows of `data` and keeps
    /// the values whose magnitude is at least `threshold`.
    ///
    /// The returned matrix is already finalized.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `kernel.compute`.
    pub fn from_kernel_with_threshold(
        data: &[Vec<f64>],
        kernel: &dyn Kernel,
        threshold: f64,
    ) -> Result<Self> {
        let mut matrix = Self::new(data.len());
        for (i, left) in data.iter().enumerate() {
            for (j, right) in data.iter().enumerate() {
                let value = kernel.compute(left, right)?;
                if value.abs() >= threshold {
                    matrix.set(i, j, value);
                }
            }
        }
        matrix.finalize();
        Ok(matrix)
    }

    /// Finalizes the matrix and returns the `(column, value)` pairs stored in
    /// row `row_idx`, in ascending column order, or `None` when `row_idx` is
    /// out of range.
    pub fn row(&mut self, row_idx: usize) -> Option<Vec<(usize, f64)>> {
        if row_idx >= self.size {
            return None;
        }
        self.finalize();
        let span = self.row_ptr[row_idx]..self.row_ptr[row_idx + 1];
        let row_data = self.col_idx[span.clone()]
            .iter()
            .copied()
            .zip(self.values[span].iter().copied())
            .collect();
        Some(row_data)
    }
}
/// Configurable builder that evaluates a kernel on a data set and produces a
/// `SparseKernelMatrix`.
pub struct SparseKernelMatrixBuilder {
/// Minimum magnitude a kernel value must reach to be kept; `new()` sets
/// this to `1e-10`.
threshold: f64,
/// Optional cap on the number of entries kept per row; `None` (the value
/// set by `new()`) means no cap.
max_entries_per_row: Option<usize>,
}
impl SparseKernelMatrixBuilder {
    /// Creates a builder with a threshold of `1e-10` and no per-row cap.
    pub fn new() -> Self {
        Self {
            threshold: 1e-10,
            max_entries_per_row: None,
        }
    }

    /// Sets the minimum magnitude a kernel value must have to be stored.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when `threshold` is negative
    /// or NaN. NaN has to be rejected explicitly: it is not `< 0.0`, and a
    /// NaN threshold would make every `value.abs() >= threshold` comparison
    /// false, so `build` would silently return an empty matrix.
    pub fn with_threshold(mut self, threshold: f64) -> Result<Self> {
        if threshold.is_nan() || threshold < 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "threshold".to_string(),
                value: threshold.to_string(),
                reason: "must be non-negative".to_string(),
            });
        }
        self.threshold = threshold;
        Ok(self)
    }

    /// Caps the number of entries kept in each row; when a row exceeds the
    /// cap, the entries with the largest magnitude are kept.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when `max_entries` is zero.
    pub fn with_max_entries_per_row(mut self, max_entries: usize) -> Result<Self> {
        if max_entries == 0 {
            return Err(KernelError::InvalidParameter {
                parameter: "max_entries_per_row".to_string(),
                value: max_entries.to_string(),
                reason: "must be positive".to_string(),
            });
        }
        self.max_entries_per_row = Some(max_entries);
        Ok(self)
    }

    /// Evaluates `kernel` on every ordered pair of rows of `data` and returns
    /// the finalized sparse matrix.
    ///
    /// A value is kept when its magnitude is at least the configured
    /// threshold; if a per-row cap is configured, only the largest-magnitude
    /// entries of each row survive.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `kernel.compute`.
    pub fn build(&self, data: &[Vec<f64>], kernel: &dyn Kernel) -> Result<SparseKernelMatrix> {
        let mut matrix = SparseKernelMatrix::new(data.len());
        for (i, left) in data.iter().enumerate() {
            let mut row_entries = Vec::new();
            for (j, right) in data.iter().enumerate() {
                let value = kernel.compute(left, right)?;
                if value.abs() >= self.threshold {
                    row_entries.push((j, value));
                }
            }
            if let Some(max_entries) = self.max_entries_per_row {
                if row_entries.len() > max_entries {
                    // Descending by magnitude. The sort is stable, so among
                    // equal magnitudes the lower column index is kept.
                    row_entries.sort_by(|(_, a), (_, b)| {
                        b.abs()
                            .partial_cmp(&a.abs())
                            .unwrap_or(std::cmp::Ordering::Equal)
                    });
                    row_entries.truncate(max_entries);
                }
            }
            for (j, value) in row_entries {
                matrix.set(i, j, value);
            }
        }
        matrix.finalize();
        Ok(matrix)
    }
}
impl Default for SparseKernelMatrixBuilder {
fn default() -> Self {
Self::new()
}
}
impl SparseKernelMatrix {
    /// Calls `visit(row, col, value)` once for every entry that
    /// [`get`](Self::get) would report: finalized entries that are not
    /// shadowed by a staged entry, in CSR order, followed by the staged
    /// entries in ascending `(row, col)` order.
    ///
    /// The staged entries are sorted so that callers accumulating floating
    /// point results get the same answer on every run regardless of hash-map
    /// iteration order.
    fn for_each_entry(&self, mut visit: impl FnMut(usize, usize, f64)) {
        for row in 0..self.size {
            for idx in self.row_ptr[row]..self.row_ptr[row + 1] {
                let col = self.col_idx[idx];
                if !self.temp_map.contains_key(&(row, col)) {
                    visit(row, col, self.values[idx]);
                }
            }
        }
        let mut staged: Vec<((usize, usize), f64)> = self
            .temp_map
            .iter()
            .map(|(&key, &value)| (key, value))
            .collect();
        staged.sort_unstable_by_key(|&(key, _)| key);
        for ((row, col), value) in staged {
            visit(row, col, value);
        }
    }

    /// Finalizes the matrix and computes the matrix-vector product `A * x`.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when `x.len()` differs from
    /// the matrix size.
    pub fn spmv(&mut self, x: &[f64]) -> Result<Vec<f64>> {
        if x.len() != self.size {
            return Err(KernelError::InvalidParameter {
                parameter: "x".to_string(),
                value: x.len().to_string(),
                reason: format!("vector length must match matrix size {}", self.size),
            });
        }
        self.finalize();
        let y: Vec<f64> = (0..self.size)
            .map(|row| {
                // Explicit fold from +0.0 keeps empty rows at exactly 0.0.
                (self.row_ptr[row]..self.row_ptr[row + 1])
                    .fold(0.0, |acc, idx| acc + self.values[idx] * x[self.col_idx[idx]])
            })
            .collect();
        Ok(y)
    }

    /// Returns the transpose as a new, finalized matrix.
    ///
    /// Staged entries are included, so the matrix does not have to be
    /// finalized first.
    pub fn transpose(&self) -> Result<Self> {
        let mut transposed = Self::new(self.size);
        self.for_each_entry(|row, col, value| transposed.set(col, row, value));
        transposed.finalize();
        Ok(transposed)
    }

    /// Finalizes `self` and returns `self + other` as a new, finalized
    /// matrix.
    ///
    /// `other` is read in place (staged entries included) instead of being
    /// cloned. Sums whose magnitude falls below the storage tolerance of
    /// `set` are not stored in the result.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when the two matrices differ
    /// in size.
    pub fn add(&mut self, other: &Self) -> Result<Self> {
        if self.size != other.size {
            return Err(KernelError::InvalidParameter {
                parameter: "other".to_string(),
                value: other.size.to_string(),
                reason: format!("matrix sizes must match: {} vs {}", self.size, other.size),
            });
        }
        self.finalize();
        let mut result = Self::new(self.size);
        self.for_each_entry(|row, col, value| result.set(row, col, value));
        other.for_each_entry(|row, col, value| {
            let existing = result.get(row, col).unwrap_or(0.0);
            result.set(row, col, existing + value);
        });
        result.finalize();
        Ok(result)
    }

    /// Frobenius norm: the square root of the sum of squares of all stored
    /// entries, staged entries included.
    pub fn frobenius_norm(&self) -> f64 {
        let mut sum_squares = 0.0_f64;
        self.for_each_entry(|_, _, value| sum_squares += value * value);
        sum_squares.sqrt()
    }

    /// Finalizes the matrix and returns an iterator over its stored entries
    /// as `(row, col, value)` triples in row-major order.
    pub fn iter_nonzeros(&mut self) -> SparseMatrixIterator<'_> {
        self.finalize();
        SparseMatrixIterator {
            matrix: self,
            current_row: 0,
            current_idx: 0,
        }
    }

    /// Multiplies every stored entry (finalized and staged) by `scalar` in
    /// place. The set of stored positions is left unchanged.
    pub fn scale(&mut self, scalar: f64) {
        self.values
            .iter_mut()
            .chain(self.temp_map.values_mut())
            .for_each(|value| *value *= scalar);
    }
}
/// Iterator over the finalized (CSR) entries of a `SparseKernelMatrix`,
/// yielding `(row, col, value)` triples.
pub struct SparseMatrixIterator<'a> {
/// Matrix being iterated; only its CSR arrays are read.
matrix: &'a SparseKernelMatrix,
/// Row that the next entry is expected to belong to.
current_row: usize,
/// Position of the next entry in `col_idx` / `values`.
current_idx: usize,
}
impl<'a> Iterator for SparseMatrixIterator<'a> {
    type Item = (usize, usize, f64);

    /// Yields the next stored `(row, col, value)` triple, skipping rows that
    /// have no remaining entries, or `None` once every row is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let matrix = self.matrix;
        loop {
            if self.current_row >= matrix.size {
                return None;
            }
            let row_end = matrix.row_ptr[self.current_row + 1];
            if self.current_idx < row_end {
                let idx = self.current_idx;
                self.current_idx = idx + 1;
                return Some((self.current_row, matrix.col_idx[idx], matrix.values[idx]));
            }
            // Current row is exhausted: advance and restart at the next
            // row's first entry.
            self.current_row += 1;
            self.current_idx = match matrix.row_ptr.get(self.current_row) {
                Some(&row_start) => row_start,
                None => 0,
            };
        }
    }
}
impl SparseKernelMatrixBuilder {
    /// Parallel counterpart of `build`: rows are computed concurrently with
    /// rayon and then assembled into a finalized matrix.
    ///
    /// Thresholding and the optional per-row cap are applied exactly as in
    /// `build`.
    ///
    /// # Errors
    ///
    /// Returns an error produced by `kernel.compute`. Kernel failures are
    /// propagated rather than skipped, so a failing kernel can no longer
    /// yield a silently incomplete matrix. When several rows fail, which of
    /// their errors is returned is not deterministic.
    pub fn build_parallel(
        &self,
        data: &[Vec<f64>],
        kernel: &dyn Kernel,
    ) -> Result<SparseKernelMatrix> {
        use rayon::prelude::*;

        let n = data.len();
        let row_data = (0..n)
            .into_par_iter()
            .map(|i| -> Result<Vec<(usize, f64)>> {
                let mut row_entries = Vec::new();
                for (j, right) in data.iter().enumerate() {
                    let value = kernel.compute(&data[i], right)?;
                    if value.abs() >= self.threshold {
                        row_entries.push((j, value));
                    }
                }
                if let Some(max_entries) = self.max_entries_per_row {
                    if row_entries.len() > max_entries {
                        // Descending by magnitude; the stable sort keeps the
                        // lower column index among equal magnitudes.
                        row_entries.sort_by(|(_, a), (_, b)| {
                            b.abs()
                                .partial_cmp(&a.abs())
                                .unwrap_or(std::cmp::Ordering::Equal)
                        });
                        row_entries.truncate(max_entries);
                    }
                }
                Ok(row_entries)
            })
            // Collecting into `Result` short-circuits on the first error seen.
            .collect::<Result<Vec<_>>>()?;

        let mut matrix = SparseKernelMatrix::new(n);
        for (i, row_entries) in row_data.into_iter().enumerate() {
            for (j, value) in row_entries {
                matrix.set(i, j, value);
            }
        }
        matrix.finalize();
        Ok(matrix)
    }
}
// Unit tests for `SparseKernelMatrix`, `SparseKernelMatrixBuilder` and
// `SparseMatrixIterator`.
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_kernels::LinearKernel;
// A freshly created matrix reports its size and holds no entries.
#[test]
fn test_sparse_matrix_creation() {
let matrix = SparseKernelMatrix::new(3);
assert_eq!(matrix.size(), 3);
assert_eq!(matrix.nnz(), 0);
}
// Staged (not yet finalized) entries are visible through `get`; untouched
// cells read as `None`.
#[test]
fn test_sparse_matrix_set_get() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
assert_eq!(matrix.get(0, 1), Some(0.8));
assert_eq!(matrix.get(1, 2), Some(0.6));
assert_eq!(matrix.get(0, 2), None);
}
// Entries remain readable after being folded into the CSR arrays.
#[test]
fn test_sparse_matrix_finalize() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
matrix.set(2, 0, 0.4);
matrix.finalize();
assert_eq!(matrix.get(0, 1), Some(0.8));
assert_eq!(matrix.get(1, 2), Some(0.6));
assert_eq!(matrix.get(2, 0), Some(0.4));
}
// `nnz` counts staged entries.
#[test]
fn test_sparse_matrix_nnz() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
assert_eq!(matrix.nnz(), 2);
}
// Two entries in a 3 x 3 matrix give a density of 2/9.
#[test]
fn test_sparse_matrix_density() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
let density = matrix.density();
assert!((density - 2.0 / 9.0).abs() < 1e-10);
}
// `to_dense` places stored values at their coordinates and zero elsewhere.
#[test]
fn test_sparse_matrix_to_dense() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
let dense = matrix.to_dense();
assert_eq!(dense.len(), 3);
assert!((dense[0][1] - 0.8).abs() < 1e-10);
assert!((dense[1][2] - 0.6).abs() < 1e-10);
assert!(dense[0][0].abs() < 1e-10);
}
// Smoke test: building from a kernel with a 0.1 threshold yields a non-empty
// 3 x 3 matrix (exact values depend on `LinearKernel`, defined elsewhere).
#[test]
fn test_sparse_matrix_from_kernel() {
let kernel = LinearKernel::new();
let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
let mut matrix =
SparseKernelMatrix::from_kernel_with_threshold(&data, &kernel, 0.1).expect("unwrap");
assert!(matrix.nnz() > 0);
let dense = matrix.to_dense();
assert_eq!(dense.len(), 3);
}
// `row` returns every `(column, value)` pair stored in the requested row.
#[test]
fn test_sparse_matrix_row() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(0, 2, 0.6);
let row = matrix.row(0).expect("unwrap");
assert_eq!(row.len(), 2);
assert!(row.contains(&(1, 0.8)));
assert!(row.contains(&(2, 0.6)));
}
// Smoke test for the builder with its default settings.
#[test]
fn test_sparse_matrix_builder() {
let builder = SparseKernelMatrixBuilder::new();
let kernel = LinearKernel::new();
let data = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let matrix = builder.build(&data, &kernel).expect("unwrap");
assert!(matrix.nnz() > 0);
}
// Smoke test for the builder with a 0.5 threshold.
#[test]
fn test_sparse_matrix_builder_with_threshold() {
let builder = SparseKernelMatrixBuilder::new()
.with_threshold(0.5)
.expect("unwrap");
let kernel = LinearKernel::new();
let data = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let matrix = builder.build(&data, &kernel).expect("unwrap");
assert!(matrix.nnz() > 0);
}
// A negative threshold is rejected.
#[test]
fn test_sparse_matrix_builder_invalid_threshold() {
let result = SparseKernelMatrixBuilder::new().with_threshold(-0.1);
assert!(result.is_err());
}
// With a cap of two entries per row, no row of the built matrix exceeds it.
// The clone is needed because `row` takes `&mut self`.
#[test]
fn test_sparse_matrix_builder_max_entries() {
let builder = SparseKernelMatrixBuilder::new()
.with_max_entries_per_row(2)
.expect("unwrap");
let kernel = LinearKernel::new();
let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
let matrix = builder.build(&data, &kernel).expect("unwrap");
for i in 0..matrix.size() {
let mut temp_matrix = matrix.clone();
let row = temp_matrix.row(i).expect("unwrap");
assert!(row.len() <= 2);
}
}
// A per-row cap of zero is rejected.
#[test]
fn test_sparse_matrix_builder_invalid_max_entries() {
let result = SparseKernelMatrixBuilder::new().with_max_entries_per_row(0);
assert!(result.is_err());
}
// A value below the 1e-10 storage tolerance of `set` is not stored.
#[test]
fn test_sparse_matrix_zero_threshold() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 1e-11); matrix.finalize();
assert_eq!(matrix.nnz(), 0);
}
// Matrix-vector product: rows [2 0 1], [0 3 0], [1 0 2] times [1, 2, 3]
// gives [2*1 + 1*3, 3*2, 1*1 + 2*3] = [5, 6, 7].
#[test]
fn test_sparse_matrix_spmv() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 0, 2.0);
matrix.set(0, 2, 1.0);
matrix.set(1, 1, 3.0);
matrix.set(2, 0, 1.0);
matrix.set(2, 2, 2.0);
let x = vec![1.0, 2.0, 3.0];
let y = matrix.spmv(&x).expect("unwrap");
assert_eq!(y.len(), 3);
assert!((y[0] - 5.0).abs() < 1e-10); assert!((y[1] - 6.0).abs() < 1e-10); assert!((y[2] - 7.0).abs() < 1e-10); }
// `spmv` rejects a vector whose length differs from the matrix size.
#[test]
fn test_sparse_matrix_spmv_invalid_size() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 0, 1.0);
let x = vec![1.0, 2.0]; let result = matrix.spmv(&x);
assert!(result.is_err());
}
// Transposing a finalized matrix swaps the row and column of each entry.
#[test]
fn test_sparse_matrix_transpose() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
matrix.set(2, 0, 0.4);
matrix.finalize();
let transposed = matrix.transpose().expect("unwrap");
assert_eq!(transposed.get(1, 0), Some(0.8));
assert_eq!(transposed.get(2, 1), Some(0.6));
assert_eq!(transposed.get(0, 2), Some(0.4));
}
// Addition sums overlapping entries ((0, 1): 2 + 1 = 3) and carries over
// entries present in only one operand.
#[test]
fn test_sparse_matrix_add() {
let mut matrix1 = SparseKernelMatrix::new(3);
matrix1.set(0, 0, 1.0);
matrix1.set(0, 1, 2.0);
matrix1.set(1, 1, 3.0);
let mut matrix2 = SparseKernelMatrix::new(3);
matrix2.set(0, 1, 1.0);
matrix2.set(1, 2, 4.0);
matrix2.set(2, 2, 5.0);
let result = matrix1.add(&matrix2).expect("unwrap");
assert_eq!(result.get(0, 0), Some(1.0));
assert_eq!(result.get(0, 1), Some(3.0)); assert_eq!(result.get(1, 1), Some(3.0));
assert_eq!(result.get(1, 2), Some(4.0));
assert_eq!(result.get(2, 2), Some(5.0));
}
// Adding matrices of different sizes is an error.
#[test]
fn test_sparse_matrix_add_invalid_size() {
let mut matrix1 = SparseKernelMatrix::new(3);
matrix1.set(0, 0, 1.0);
let matrix2 = SparseKernelMatrix::new(2);
let result = matrix1.add(&matrix2);
assert!(result.is_err());
}
// Frobenius norm of a finalized matrix with entries 3 and 4 is
// sqrt(9 + 16) = 5.
#[test]
fn test_sparse_matrix_frobenius_norm() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 0, 3.0);
matrix.set(1, 1, 4.0);
matrix.finalize();
let norm = matrix.frobenius_norm();
assert!((norm - 5.0).abs() < 1e-10); }
// `iter_nonzeros` finalizes the matrix and yields every stored entry once.
#[test]
fn test_sparse_matrix_iterator() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 1, 0.8);
matrix.set(1, 2, 0.6);
matrix.set(2, 0, 0.4);
let entries: Vec<_> = matrix.iter_nonzeros().collect();
assert_eq!(entries.len(), 3);
assert!(entries.contains(&(0, 1, 0.8)));
assert!(entries.contains(&(1, 2, 0.6)));
assert!(entries.contains(&(2, 0, 0.4)));
}
// Scaling by 0.5 halves every finalized entry.
#[test]
fn test_sparse_matrix_scale() {
let mut matrix = SparseKernelMatrix::new(3);
matrix.set(0, 0, 2.0);
matrix.set(1, 1, 4.0);
matrix.finalize();
matrix.scale(0.5);
assert_eq!(matrix.get(0, 0), Some(1.0));
assert_eq!(matrix.get(1, 1), Some(2.0));
}
// The parallel build stores the same number of entries as the sequential
// build on the same input.
#[test]
fn test_sparse_matrix_builder_parallel() {
let builder = SparseKernelMatrixBuilder::new();
let kernel = LinearKernel::new();
let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
let matrix = builder.build_parallel(&data, &kernel).expect("unwrap");
assert!(matrix.nnz() > 0);
let matrix_seq = builder.build(&data, &kernel).expect("unwrap");
assert_eq!(matrix.nnz(), matrix_seq.nnz());
}
// Smoke test for the parallel build with a 0.5 threshold.
#[test]
fn test_sparse_matrix_parallel_with_threshold() {
let builder = SparseKernelMatrixBuilder::new()
.with_threshold(0.5)
.expect("unwrap");
let kernel = LinearKernel::new();
let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
let matrix = builder.build_parallel(&data, &kernel).expect("unwrap");
assert!(matrix.nnz() > 0);
}
}