use crate::tolerances::ZERO_TOL;
#[derive(Debug, Clone)]
pub struct SparseVec {
pub indices: Vec<usize>,
pub values: Vec<f64>,
pub len: usize, }
impl SparseVec {
pub fn new(len: usize) -> Self {
Self {
indices: Vec::new(),
values: Vec::new(),
len,
}
}
pub fn from_dense(dense: &[f64]) -> Self {
let mut indices = Vec::new();
let mut values = Vec::new();
for (i, &v) in dense.iter().enumerate() {
if v.abs() > ZERO_TOL {
indices.push(i);
values.push(v);
}
}
Self {
indices,
values,
len: dense.len(),
}
}
pub fn to_dense(&self) -> Vec<f64> {
let mut dense = vec![0.0; self.len];
for (k, &idx) in self.indices.iter().enumerate() {
dense[idx] = self.values[k];
}
dense
}
pub fn to_dense_into(&self, buf: &mut [f64]) {
for v in buf.iter_mut() {
*v = 0.0;
}
for (k, &idx) in self.indices.iter().enumerate() {
buf[idx] = self.values[k];
}
}
pub fn get(&self, idx: usize) -> f64 {
match self.indices.binary_search(&idx) {
Ok(pos) => self.values[pos],
Err(_) => 0.0,
}
}
pub fn set(&mut self, idx: usize, val: f64) {
match self.indices.binary_search(&idx) {
Ok(pos) => {
if val.abs() <= ZERO_TOL {
self.indices.remove(pos);
self.values.remove(pos);
} else {
self.values[pos] = val;
}
}
Err(pos) => {
if val.abs() > ZERO_TOL {
self.indices.insert(pos, idx);
self.values.insert(pos, val);
}
}
}
}
pub fn axpy(&mut self, alpha: f64, other: &SparseVec) {
let mut new_indices = Vec::new();
let mut new_values = Vec::new();
let (mut i, mut j) = (0, 0);
while i < self.indices.len() && j < other.indices.len() {
if self.indices[i] == other.indices[j] {
let val = self.values[i] + alpha * other.values[j];
if val.abs() > ZERO_TOL {
new_indices.push(self.indices[i]);
new_values.push(val);
}
i += 1;
j += 1;
} else if self.indices[i] < other.indices[j] {
new_indices.push(self.indices[i]);
new_values.push(self.values[i]);
i += 1;
} else {
let val = alpha * other.values[j];
if val.abs() > ZERO_TOL {
new_indices.push(other.indices[j]);
new_values.push(val);
}
j += 1;
}
}
while i < self.indices.len() {
new_indices.push(self.indices[i]);
new_values.push(self.values[i]);
i += 1;
}
while j < other.indices.len() {
let val = alpha * other.values[j];
if val.abs() > ZERO_TOL {
new_indices.push(other.indices[j]);
new_values.push(val);
}
j += 1;
}
self.indices = new_indices;
self.values = new_values;
}
pub fn dot(&self, other: &SparseVec) -> f64 {
let mut result = 0.0;
let (mut i, mut j) = (0, 0);
while i < self.indices.len() && j < other.indices.len() {
if self.indices[i] == other.indices[j] {
result += self.values[i] * other.values[j];
i += 1;
j += 1;
} else if self.indices[i] < other.indices[j] {
i += 1;
} else {
j += 1;
}
}
result
}
pub fn dot_dense(&self, dense: &[f64]) -> f64 {
let mut result = 0.0;
for (k, &idx) in self.indices.iter().enumerate() {
if idx < dense.len() {
result += self.values[k] * dense[idx];
}
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vec_from_dense_to_dense() {
        let dense = vec![1.0, 0.0, 0.0, 3.5, 0.0, -2.0];
        let sv = SparseVec::from_dense(&dense);
        assert_eq!(sv.len, 6);
        assert_eq!(sv.indices, vec![0, 3, 5]);
        assert_eq!(sv.values, vec![1.0, 3.5, -2.0]);
        let back = sv.to_dense();
        assert_eq!(back, dense);
    }

    #[test]
    fn test_to_dense_into_overwrites_buffer() {
        // Stale contents must be cleared, including slots past `len` when the buffer
        // is longer than the vector.
        let sv = SparseVec::from_dense(&[0.0, 2.0, 0.0, -1.0]);
        let mut buf = vec![9.0; 4];
        sv.to_dense_into(&mut buf);
        assert_eq!(buf, vec![0.0, 2.0, 0.0, -1.0]);
        let mut longer = vec![9.0; 6];
        sv.to_dense_into(&mut longer);
        assert_eq!(longer, vec![0.0, 2.0, 0.0, -1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sparse_vec_get_set() {
        let mut sv = SparseVec::new(5);
        assert_eq!(sv.get(0), 0.0);
        assert_eq!(sv.get(4), 0.0);
        sv.set(2, 7.0);
        sv.set(4, -1.0);
        assert_eq!(sv.get(2), 7.0);
        assert_eq!(sv.get(4), -1.0);
        assert_eq!(sv.get(3), 0.0);
        sv.set(2, 3.0);
        assert_eq!(sv.get(2), 3.0);
        sv.set(2, 0.0);
        assert_eq!(sv.get(2), 0.0);
        assert_eq!(sv.indices, vec![4]);
    }

    #[test]
    fn test_set_zero_on_absent_index_is_noop() {
        let mut sv = SparseVec::new(3);
        sv.set(1, 0.0);
        assert!(sv.indices.is_empty());
        assert!(sv.values.is_empty());
    }

    #[test]
    fn test_sparse_vec_dot() {
        let a = SparseVec::from_dense(&[1.0, 0.0, 3.0, 0.0]);
        let b = SparseVec::from_dense(&[2.0, 5.0, 4.0, 0.0]);
        assert!((a.dot(&b) - 14.0).abs() < 1e-10);
        let dense = vec![2.0, 5.0, 4.0, 0.0];
        assert!((a.dot_dense(&dense) - 14.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_dense_skips_out_of_range_indices() {
        // Index 2 lies beyond the one-element dense slice and must be ignored.
        let a = SparseVec::from_dense(&[1.0, 0.0, 3.0]);
        assert!((a.dot_dense(&[2.0]) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_vec_axpy() {
        let mut a = SparseVec::from_dense(&[1.0, 0.0, 3.0]);
        let b = SparseVec::from_dense(&[0.0, 2.0, 1.0]);
        a.axpy(2.0, &b);
        let dense = a.to_dense();
        assert!((dense[0] - 1.0).abs() < 1e-10);
        assert!((dense[1] - 4.0).abs() < 1e-10);
        assert!((dense[2] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_axpy_cancellation_removes_entry() {
        // 2.0 + (-1.0 * 2.0) is exactly zero, so index 0 must no longer be stored.
        let mut a = SparseVec::from_dense(&[2.0, 1.0]);
        let b = SparseVec::from_dense(&[2.0, 0.0]);
        a.axpy(-1.0, &b);
        assert_eq!(a.indices, vec![1]);
        assert_eq!(a.values, vec![1.0]);
    }

    #[test]
    fn test_dot_different_len() {
        let a = SparseVec { indices: vec![0, 2], values: vec![1.0, 2.0], len: 3 };
        let b = SparseVec {
            indices: vec![0, 1, 2, 3, 4],
            values: vec![3.0, 4.0, 5.0, 6.0, 7.0],
            len: 5,
        };
        assert!((a.dot(&b) - 13.0).abs() < 1e-10);
        let empty_a = SparseVec::new(3);
        let empty_b = SparseVec::new(5);
        assert_eq!(empty_a.dot(&empty_b), 0.0);
    }

    #[test]
    fn test_axpy_different_len() {
        let mut a = SparseVec { indices: vec![0], values: vec![1.0], len: 3 };
        let b = SparseVec { indices: vec![0, 2], values: vec![2.0, 3.0], len: 5 };
        a.axpy(1.0, &b);
        assert!((a.get(0) - 3.0).abs() < 1e-10, "index 0: expected 3.0, got {}", a.get(0));
        assert!((a.get(2) - 3.0).abs() < 1e-10, "index 2: expected 3.0, got {}", a.get(2));
        assert_eq!(a.get(1), 0.0, "index 1 should remain 0");
        let mut empty = SparseVec::new(3);
        let src = SparseVec { indices: vec![1, 2], values: vec![4.0, 5.0], len: 3 };
        empty.axpy(1.0, &src);
        assert!((empty.get(1) - 4.0).abs() < 1e-10, "index 1: expected 4.0, got {}", empty.get(1));
        assert!((empty.get(2) - 5.0).abs() < 1e-10, "index 2: expected 5.0, got {}", empty.get(2));
        assert_eq!(empty.get(0), 0.0, "index 0 should be 0");
    }
}