use crate::types::Precision;
#[derive(Debug, Clone)]
pub struct FastCSRMatrix {
pub values: Vec<Precision>,
pub col_indices: Vec<u32>,
pub row_ptr: Vec<u32>,
pub rows: usize,
pub cols: usize,
}
impl FastCSRMatrix {
pub fn from_triplets(
triplets: Vec<(usize, usize, Precision)>,
rows: usize,
cols: usize,
) -> Self {
let mut sorted_triplets = triplets;
sorted_triplets.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
let nnz = sorted_triplets.len();
let mut values = Vec::with_capacity(nnz);
let mut col_indices = Vec::with_capacity(nnz);
let mut row_ptr = vec![0u32; rows + 1];
let mut current_row = 0;
for (row, col, val) in sorted_triplets {
while current_row <= row {
row_ptr[current_row] = values.len() as u32;
current_row += 1;
}
values.push(val);
col_indices.push(col as u32);
}
while current_row <= rows {
row_ptr[current_row] = values.len() as u32;
current_row += 1;
}
Self {
values,
col_indices,
row_ptr,
rows,
cols,
}
}
pub fn multiply_vector_fast(&self, x: &[Precision], y: &mut [Precision]) {
y.fill(0.0);
for row in 0..self.rows {
let start = self.row_ptr[row] as usize;
let end = self.row_ptr[row + 1] as usize;
if start >= end {
continue;
}
let nnz = end - start;
let values = &self.values[start..end];
let indices = &self.col_indices[start..end];
if nnz <= 4 {
let mut sum = 0.0;
for i in 0..nnz {
sum += values[i] * x[indices[i] as usize];
}
y[row] = sum;
} else {
let chunks = nnz / 4;
let remainder = nnz % 4;
let mut sum = 0.0;
for chunk in 0..chunks {
let base = chunk * 4;
sum += values[base] * x[indices[base] as usize]
+ values[base + 1] * x[indices[base + 1] as usize]
+ values[base + 2] * x[indices[base + 2] as usize]
+ values[base + 3] * x[indices[base + 3] as usize];
}
for i in (chunks * 4)..(chunks * 4 + remainder) {
sum += values[i] * x[indices[i] as usize];
}
y[row] = sum;
}
}
}
}
pub struct FastConjugateGradient {
pub max_iterations: usize,
pub tolerance: Precision,
}
impl FastConjugateGradient {
pub fn new(max_iterations: usize, tolerance: Precision) -> Self {
Self {
max_iterations,
tolerance,
}
}
pub fn solve(&self, matrix: &FastCSRMatrix, b: &[Precision]) -> Vec<Precision> {
let n = matrix.rows;
assert_eq!(b.len(), n);
assert_eq!(matrix.rows, matrix.cols);
let mut x = vec![0.0; n];
let mut r = vec![0.0; n];
let mut p = vec![0.0; n];
let mut ap = vec![0.0; n];
r.copy_from_slice(b);
p.copy_from_slice(b);
let mut rsold = Self::dot_product_fast(&r, &r);
let tolerance_sq = self.tolerance * self.tolerance;
for _iteration in 0..self.max_iterations {
if rsold <= tolerance_sq {
break;
}
matrix.multiply_vector_fast(&p, &mut ap);
let pap = Self::dot_product_fast(&p, &ap);
if pap.abs() < 1e-16 {
break;
}
let alpha = rsold / pap;
Self::axpy_fast(alpha, &p, &mut x);
Self::axpy_fast(-alpha, &ap, &mut r);
let rsnew = Self::dot_product_fast(&r, &r);
let beta = rsnew / rsold;
for i in 0..n {
p[i] = r[i] + beta * p[i];
}
rsold = rsnew;
}
x
}
#[inline]
fn dot_product_fast(x: &[Precision], y: &[Precision]) -> Precision {
let n = x.len();
let chunks = n / 4;
let remainder = n % 4;
let mut sum = 0.0;
for chunk in 0..chunks {
let base = chunk * 4;
sum += x[base] * y[base]
+ x[base + 1] * y[base + 1]
+ x[base + 2] * y[base + 2]
+ x[base + 3] * y[base + 3];
}
for i in (chunks * 4)..(chunks * 4 + remainder) {
sum += x[i] * y[i];
}
sum
}
#[inline]
fn axpy_fast(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
let n = x.len();
let chunks = n / 4;
let remainder = n % 4;
for chunk in 0..chunks {
let base = chunk * 4;
y[base] += alpha * x[base];
y[base + 1] += alpha * x[base + 1];
y[base + 2] += alpha * x[base + 2];
y[base + 3] += alpha * x[base + 3];
}
for i in (chunks * 4)..(chunks * 4 + remainder) {
y[i] += alpha * x[i];
}
}
}
pub struct VectorPool {
buffers: Vec<Vec<Precision>>,
size: usize,
}
impl VectorPool {
pub fn new(size: usize, capacity: usize) -> Self {
let buffers = (0..capacity)
.map(|_| vec![0.0; size])
.collect();
Self { buffers, size }
}
pub fn get_buffer(&mut self) -> Vec<Precision> {
self.buffers.pop().unwrap_or_else(|| vec![0.0; self.size])
}
pub fn return_buffer(&mut self, mut buffer: Vec<Precision>) {
if buffer.len() == self.size && self.buffers.len() < 8 {
buffer.fill(0.0);
self.buffers.push(buffer);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fast_csr_matrix() {
let triplets = vec![
(0, 0, 4.0), (0, 1, 1.0),
(1, 0, 2.0), (1, 1, 3.0),
];
let matrix = FastCSRMatrix::from_triplets(triplets, 2, 2);
let x = vec![1.0, 2.0];
let mut y = vec![0.0; 2];
matrix.multiply_vector_fast(&x, &mut y);
assert_eq!(y, vec![6.0, 8.0]); }
#[test]
fn test_fast_conjugate_gradient() {
let triplets = vec![
(0, 0, 4.0), (0, 1, 1.0),
(1, 0, 1.0), (1, 1, 3.0),
];
let matrix = FastCSRMatrix::from_triplets(triplets, 2, 2);
let b = vec![1.0, 2.0];
let solver = FastConjugateGradient::new(1000, 1e-10);
let solution = solver.solve(&matrix, &b);
let mut result = vec![0.0; 2];
matrix.multiply_vector_fast(&solution, &mut result);
let error = ((result[0] - b[0]).powi(2) + (result[1] - b[1]).powi(2)).sqrt();
assert!(error < 1e-8);
}
}