use crate::types::Precision;
use crate::matrix::sparse::CSRStorage;
use alloc::vec::Vec;
use core::mem;
#[cfg(feature = "simd")]
use wide::f64x4;
/// Sparse matrix in compressed-sparse-row form, bundled with a pool of
/// reusable scratch vectors and a counter of matrix-vector products.
pub struct OptimizedCSR {
/// Underlying CSR arrays (`row_ptr`, `col_indices`, `values`) and dimensions.
storage: CSRStorage,
/// Scratch vectors handed out by `get_temp_vector` and taken back by
/// `return_temp_vector`.
temp_pool: Vec<Vec<Precision>>,
/// Number of calls made to `multiply_vector_optimized` so far.
matvec_count: usize,
}
impl OptimizedCSR {
    /// Number of scratch vectors pre-allocated by `from_triplets`.
    const INITIAL_POOL_SIZE: usize = 4;

    /// Maximum number of scratch vectors `return_temp_vector` keeps around.
    const MAX_POOL_SIZE: usize = 8;

    /// Builds a CSR matrix from `(row, col, value)` triplets.
    ///
    /// The triplets may be given in any order; they are sorted by row and then
    /// by column (stable sort). Duplicate `(row, col)` positions are stored as
    /// separate entries, so they contribute additively to a matrix-vector
    /// product.
    ///
    /// # Errors
    ///
    /// Returns `Err` when
    /// * a triplet addresses a row `>= rows` or a column `>= cols`,
    /// * a column index does not fit into the 32-bit index storage, or
    /// * there are more triplets than a 32-bit row offset can address.
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: usize,
        cols: usize,
    ) -> Result<Self, String> {
        let nnz = triplets.len();
        // Offsets and column indices are stored as `u32`; reject anything that
        // would be silently truncated by the casts below.
        let index_limit = u32::MAX as usize;
        if nnz > index_limit {
            return Err(format!(
                "too many non-zero entries for 32-bit CSR offsets: {} (limit {})",
                nnz, index_limit
            ));
        }
        for &(row, col, _) in &triplets {
            if row >= rows || col >= cols {
                return Err(format!(
                    "triplet index ({}, {}) is outside of a {}x{} matrix",
                    row, col, rows, cols
                ));
            }
            if col > index_limit {
                return Err(format!(
                    "column index {} does not fit into 32-bit CSR storage",
                    col
                ));
            }
        }

        // NOTE(review): the offsets below are written by index, which assumes
        // `CSRStorage::with_capacity` returns `row_ptr` already sized to
        // `rows + 1` -- confirm against `CSRStorage`.
        let mut storage = CSRStorage::with_capacity(rows, cols, nnz);

        let mut entries = triplets;
        entries.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));

        // `next_row` is the first row whose start offset has not been written.
        let mut next_row = 0;
        for &(row, col, value) in &entries {
            // Close every row up to and including `row`; rows without entries
            // get the same start offset as their successor.
            while next_row <= row {
                storage.row_ptr[next_row] = storage.values.len() as u32;
                next_row += 1;
            }
            storage.values.push(value);
            storage.col_indices.push(col as u32);
        }
        // Trailing empty rows plus the final end-of-data sentinel.
        while next_row <= rows {
            storage.row_ptr[next_row] = storage.values.len() as u32;
            next_row += 1;
        }

        // Build each buffer separately: `vec![Vec::with_capacity(rows); n]`
        // clones the element, and cloning a `Vec` does not keep its capacity.
        let temp_pool: Vec<Vec<Precision>> = (0..Self::INITIAL_POOL_SIZE)
            .map(|_| Vec::with_capacity(rows))
            .collect();

        Ok(Self {
            storage,
            temp_pool,
            matvec_count: 0,
        })
    }

    /// Computes `y = A * x`, overwriting every element of `y`.
    ///
    /// Each call increments the product counter reported by `get_stats`.
    /// The kernel is chosen at compile time by the `simd` feature.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the column count or `y.len()` differs
    /// from the row count.
    pub fn multiply_vector_optimized(&mut self, x: &[Precision], y: &mut [Precision]) {
        assert_eq!(x.len(), self.storage.cols);
        assert_eq!(y.len(), self.storage.rows);
        self.matvec_count += 1;
        #[cfg(feature = "simd")]
        {
            self.multiply_simd_optimized(x, y);
        }
        #[cfg(not(feature = "simd"))]
        {
            self.multiply_scalar_optimized(x, y);
        }
    }

    /// Returns the `[start, end)` positions of `row` inside `values` and
    /// `col_indices`.
    fn row_bounds(&self, row: usize) -> (usize, usize) {
        let offsets = &self.storage.row_ptr;
        (offsets[row] as usize, offsets[row + 1] as usize)
    }

    /// Row-wise product using 4-lane `f64x4` accumulation for rows with at
    /// least eight stored entries and a plain scalar loop for shorter rows.
    ///
    /// Every `y[row]` is written exactly once (empty rows get `0.0`), so no
    /// separate zeroing pass over `y` is needed.
    #[cfg(feature = "simd")]
    fn multiply_simd_optimized(&self, x: &[Precision], y: &mut [Precision]) {
        for row in 0..self.storage.rows {
            let (start, end) = self.row_bounds(row);
            let values = &self.storage.values[start..end];
            let indices = &self.storage.col_indices[start..end];

            if values.len() >= 8 {
                let value_chunks = values.chunks_exact(4);
                let index_chunks = indices.chunks_exact(4);
                let value_tail = value_chunks.remainder();
                let index_tail = index_chunks.remainder();

                let mut lanes = f64x4::splat(0.0);
                for (v, c) in value_chunks.zip(index_chunks) {
                    let vals = f64x4::new([v[0], v[1], v[2], v[3]]);
                    // The entries of `x` are gathered by column index.
                    let x_vals = f64x4::new([
                        x[c[0] as usize],
                        x[c[1] as usize],
                        x[c[2] as usize],
                        x[c[3] as usize],
                    ]);
                    lanes = lanes + (vals * x_vals);
                }

                // Horizontal sum of the four lanes, then the leftover entries.
                let partial = lanes.to_array();
                let mut sum = partial[0] + partial[1] + partial[2] + partial[3];
                for (&v, &c) in value_tail.iter().zip(index_tail) {
                    sum += v * x[c as usize];
                }
                y[row] = sum;
            } else {
                let mut sum: Precision = 0.0;
                for (&v, &c) in values.iter().zip(indices) {
                    sum += v * x[c as usize];
                }
                y[row] = sum;
            }
        }
    }

    /// Row-wise product with the inner loop unrolled four entries at a time.
    ///
    /// Every `y[row]` is written exactly once (empty rows get `0.0`), so no
    /// separate zeroing pass over `y` is needed.
    #[cfg(not(feature = "simd"))]
    fn multiply_scalar_optimized(&self, x: &[Precision], y: &mut [Precision]) {
        for row in 0..self.storage.rows {
            let (start, end) = self.row_bounds(row);
            let values = &self.storage.values[start..end];
            let indices = &self.storage.col_indices[start..end];

            let value_chunks = values.chunks_exact(4);
            let index_chunks = indices.chunks_exact(4);
            let value_tail = value_chunks.remainder();
            let index_tail = index_chunks.remainder();

            let mut sum: Precision = 0.0;
            for (v, c) in value_chunks.zip(index_chunks) {
                sum += v[0] * x[c[0] as usize]
                    + v[1] * x[c[1] as usize]
                    + v[2] * x[c[2] as usize]
                    + v[3] * x[c[3] as usize];
            }
            for (&v, &c) in value_tail.iter().zip(index_tail) {
                sum += v * x[c as usize];
            }
            y[row] = sum;
        }
    }

    /// Hands out a zero-filled vector whose length equals the row count.
    ///
    /// A pooled buffer is reused when one is available; otherwise a fresh
    /// vector is allocated.
    pub fn get_temp_vector(&mut self) -> Vec<Precision> {
        let rows = self.storage.rows;
        match self.temp_pool.pop() {
            Some(mut buffer) => {
                buffer.clear();
                buffer.resize(rows, 0.0);
                buffer
            }
            None => vec![0.0; rows],
        }
    }

    /// Gives a scratch vector back to the pool.
    ///
    /// The vector is dropped instead when the pool already holds the maximum
    /// number of buffers.
    pub fn return_temp_vector(&mut self, vec: Vec<Precision>) {
        if self.temp_pool.len() < Self::MAX_POOL_SIZE {
            self.temp_pool.push(vec);
        }
    }

    /// Returns `(rows, cols)`.
    pub fn dimensions(&self) -> (usize, usize) {
        (self.storage.rows, self.storage.cols)
    }

    /// Returns the number of stored entries (duplicates are counted separately).
    pub fn nnz(&self) -> usize {
        self.storage.values.len()
    }

    /// Returns `(matrix-vector products performed, scratch vectors currently pooled)`.
    pub fn get_stats(&self) -> (usize, usize) {
        (self.matvec_count, self.temp_pool.len())
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Multiplies a fully populated 2x2 matrix by a known vector and checks
    /// the product.
    #[test]
    fn test_optimized_csr() {
        // Matrix [[4, 1], [2, 3]], listed row by row.
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 2.0), (1, 1, 3.0)];
        let mut csr = OptimizedCSR::from_triplets(entries, 2, 2).unwrap();

        let input = [1.0, 2.0];
        let mut output = [0.0; 2];
        csr.multiply_vector_optimized(&input, &mut output);

        // Row 0: 4*1 + 1*2 = 6; row 1: 2*1 + 3*2 = 8.
        assert_eq!(output, [6.0, 8.0]);
    }
}