use crate::types::{Precision, DimensionType, IndexType};
use crate::error::{SolverError, Result};
use crate::matrix::sparse::{CSRStorage, CSCStorage, COOStorage};
use alloc::{vec::Vec, collections::VecDeque, boxed::Box};
use core::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "std")]
use std::sync::Mutex;
#[cfg(feature = "simd")]
use wide::f64x4;
/// Pool of reusable `Vec<Precision>` scratch buffers, split into three size classes.
///
/// Requests are routed to a class using `SMALL_BUFFER_THRESHOLD` and
/// `MEDIUM_BUFFER_THRESHOLD`. The atomic counters record how the pool is used
/// and are surfaced through `stats()`.
pub struct BufferPool {
    /// Idle buffers in the small size class.
    small_buffers: VecDeque<Vec<Precision>>,
    /// Idle buffers in the medium size class.
    medium_buffers: VecDeque<Vec<Precision>>,
    /// Idle buffers in the large size class.
    large_buffers: VecDeque<Vec<Precision>>,
    /// Number of `get_buffer` calls (hits and misses alike).
    allocations: AtomicUsize,
    /// Number of `return_buffer` calls, whether or not the buffer was kept.
    deallocations: AtomicUsize,
    /// Requests that were satisfied by a pooled buffer.
    cache_hits: AtomicUsize,
    /// Requests that needed a fresh allocation.
    cache_misses: AtomicUsize,
}
/// Largest size, in elements, that is handled by the small-buffer queue.
const SMALL_BUFFER_THRESHOLD: usize = 128;
/// Largest size, in elements, that is handled by the medium-buffer queue;
/// anything bigger goes to the large-buffer queue.
const MEDIUM_BUFFER_THRESHOLD: usize = 8192;
impl BufferPool {
    /// Creates an empty pool with every usage counter at zero.
    ///
    /// The three queues are pre-sized for 16 small, 8 medium and 4 large buffers.
    pub fn new() -> Self {
        let zero = || AtomicUsize::new(0);
        Self {
            small_buffers: VecDeque::with_capacity(16),
            medium_buffers: VecDeque::with_capacity(8),
            large_buffers: VecDeque::with_capacity(4),
            allocations: zero(),
            deallocations: zero(),
            cache_hits: zero(),
            cache_misses: zero(),
        }
    }

    /// Returns a buffer of exactly `min_size` elements, all set to `0.0`.
    ///
    /// The size class is chosen from `min_size`. The first pooled buffer in
    /// that class whose capacity is large enough is reused (a cache hit);
    /// otherwise a new vector is allocated (a cache miss). Every call counts
    /// as one allocation in the statistics.
    pub fn get_buffer(&mut self, min_size: usize) -> Vec<Precision> {
        self.allocations.fetch_add(1, Ordering::Relaxed);

        let queue = match min_size {
            n if n <= SMALL_BUFFER_THRESHOLD => &mut self.small_buffers,
            n if n <= MEDIUM_BUFFER_THRESHOLD => &mut self.medium_buffers,
            _ => &mut self.large_buffers,
        };

        // Locate the first buffer that fits, then rotate it to the front so
        // the buffers that were skipped end up at the back of the queue.
        let fit = queue
            .iter()
            .position(|candidate| candidate.capacity() >= min_size);
        if let Some(index) = fit {
            queue.rotate_left(index);
            if let Some(mut reused) = queue.pop_front() {
                reused.clear();
                reused.resize(min_size, 0.0);
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                return reused;
            }
        }

        self.cache_misses.fetch_add(1, Ordering::Relaxed);
        vec![0.0; min_size]
    }

    /// Hands `buffer` back to the pool for later reuse.
    ///
    /// The size class is chosen from the buffer's capacity. The buffer is
    /// dropped instead of pooled when its class already holds 32 buffers or
    /// its capacity is 1,000,000 elements or more. The call is counted as a
    /// deallocation either way.
    pub fn return_buffer(&mut self, buffer: Vec<Precision>) {
        self.deallocations.fetch_add(1, Ordering::Relaxed);

        let capacity = buffer.capacity();
        let queue = match capacity {
            c if c <= SMALL_BUFFER_THRESHOLD => &mut self.small_buffers,
            c if c <= MEDIUM_BUFFER_THRESHOLD => &mut self.medium_buffers,
            _ => &mut self.large_buffers,
        };

        let has_room = queue.len() < 32;
        let small_enough = capacity < 1_000_000;
        if has_room && small_enough {
            queue.push_back(buffer);
        }
    }

    /// Takes a snapshot of the usage counters and the current queue lengths.
    pub fn stats(&self) -> BufferPoolStats {
        let load = |counter: &AtomicUsize| counter.load(Ordering::Relaxed);
        BufferPoolStats {
            allocations: load(&self.allocations),
            deallocations: load(&self.deallocations),
            cache_hits: load(&self.cache_hits),
            cache_misses: load(&self.cache_misses),
            small_buffers_pooled: self.small_buffers.len(),
            medium_buffers_pooled: self.medium_buffers.len(),
            large_buffers_pooled: self.large_buffers.len(),
        }
    }

    /// Drops every pooled buffer; the usage counters are left untouched.
    pub fn clear(&mut self) {
        self.large_buffers.clear();
        self.medium_buffers.clear();
        self.small_buffers.clear();
    }
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time snapshot of a buffer pool's usage counters and queue sizes.
#[derive(Debug, Clone, Copy)]
pub struct BufferPoolStats {
    /// Total number of buffer requests.
    pub allocations: usize,
    /// Total number of buffers handed back.
    pub deallocations: usize,
    /// Requests served from a pooled buffer.
    pub cache_hits: usize,
    /// Requests that required a new allocation.
    pub cache_misses: usize,
    /// Buffers currently idle in the small queue.
    pub small_buffers_pooled: usize,
    /// Buffers currently idle in the medium queue.
    pub medium_buffers_pooled: usize,
    /// Buffers currently idle in the large queue.
    pub large_buffers_pooled: usize,
}

impl BufferPoolStats {
    /// Cache hits as a percentage (0–100) of all requests.
    ///
    /// Returns `0.0` when no request has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.allocations {
            0 => 0.0,
            requests => (self.cache_hits as f64 / requests as f64) * 100.0,
        }
    }
}
// Process-wide buffer pool used by `get_global_buffer` / `return_global_buffer`.
// Only compiled when both the `std` and `lazy_static` features are enabled.
// The pool is wrapped in a `Mutex` because its methods take `&mut self`.
#[cfg(all(feature = "std", feature = "lazy_static"))]
lazy_static::lazy_static! {
static ref GLOBAL_BUFFER_POOL: Mutex<BufferPool> = Mutex::new(BufferPool::new());
}
/// Fetches a zero-filled buffer of `min_size` elements from the process-wide pool.
///
/// A poisoned lock is recovered instead of causing a panic. The pool only
/// holds plain buffers and counters, so it stays usable after another thread
/// panicked while holding the lock; without the recovery a single such panic
/// would disable the global pool for the rest of the process.
#[cfg(all(feature = "std", feature = "lazy_static"))]
pub fn get_global_buffer(min_size: usize) -> Vec<Precision> {
    GLOBAL_BUFFER_POOL
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get_buffer(min_size)
}
/// Hands `buffer` back to the process-wide pool so a later request can reuse it.
///
/// A poisoned lock is recovered instead of causing a panic. The pool only
/// holds plain buffers and counters, so it stays usable after another thread
/// panicked while holding the lock.
#[cfg(all(feature = "std", feature = "lazy_static"))]
pub fn return_global_buffer(buffer: Vec<Precision>) {
    GLOBAL_BUFFER_POOL
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .return_buffer(buffer);
}
/// CSR matrix bundled with scratch memory and usage counters for
/// matrix–vector products.
pub struct OptimizedCSRStorage {
    /// The underlying CSR matrix.
    storage: CSRStorage,
    /// Pool behind `get_temp_buffer` / `return_temp_buffer`.
    buffer_pool: BufferPool,
    /// Zero-initialised scratch vector of length `max(rows, cols)`.
    /// NOTE(review): no method visible in this file reads it — confirm it is still needed.
    workspace: Vec<Precision>,
    /// Number of matrix–vector products recorded since the last `reset_stats`.
    matvec_count: AtomicUsize,
    /// Running estimate of bytes touched by matrix–vector products; cleared by `reset_stats`.
    bytes_processed: AtomicUsize,
}
impl OptimizedCSRStorage {
pub fn from_coo(coo: &COOStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
let storage = CSRStorage::from_coo(coo, rows, cols)?;
let workspace_size = rows.max(cols);
Ok(Self {
storage,
buffer_pool: BufferPool::new(),
workspace: vec![0.0; workspace_size],
matvec_count: AtomicUsize::new(0),
bytes_processed: AtomicUsize::new(0),
})
}
#[cfg(feature = "simd")]
pub fn multiply_vector_simd(&self, x: &[Precision], result: &mut [Precision]) {
result.fill(0.0);
self.matvec_count.fetch_add(1, Ordering::Relaxed);
let bytes = (self.storage.values.len() * 8) + (x.len() * 8) + (result.len() * 8);
self.bytes_processed.fetch_add(bytes, Ordering::Relaxed);
for (row, row_result) in result.iter_mut().enumerate() {
let start = self.storage.row_ptr[row] as usize;
let end = self.storage.row_ptr[row + 1] as usize;
if end <= start {
continue;
}
let row_values = &self.storage.values[start..end];
let row_indices = &self.storage.col_indices[start..end];
let simd_chunks = row_values.len() / 4;
let mut sum = f64x4::splat(0.0);
for chunk in 0..simd_chunks {
let val_idx = chunk * 4;
let values = f64x4::new([
row_values[val_idx],
row_values[val_idx + 1],
row_values[val_idx + 2],
row_values[val_idx + 3],
]);
let x_vals = f64x4::new([
x[row_indices[val_idx] as usize],
x[row_indices[val_idx + 1] as usize],
x[row_indices[val_idx + 2] as usize],
x[row_indices[val_idx + 3] as usize],
]);
sum = sum + (values * x_vals);
}
let sum_array = sum.to_array();
*row_result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];
for i in (simd_chunks * 4)..row_values.len() {
let col = row_indices[i] as usize;
*row_result += row_values[i] * x[col];
}
}
}
#[cfg(not(feature = "simd"))]
pub fn multiply_vector_simd(&self, x: &[Precision], result: &mut [Precision]) {
self.multiply_vector_optimized(x, result);
}
pub fn multiply_vector_optimized(&self, x: &[Precision], result: &mut [Precision]) {
result.fill(0.0);
self.matvec_count.fetch_add(1, Ordering::Relaxed);
const BLOCK_SIZE: usize = 64;
for row_block in (0..result.len()).step_by(BLOCK_SIZE) {
let row_end = (row_block + BLOCK_SIZE).min(result.len());
for row in row_block..row_end {
let start = self.storage.row_ptr[row] as usize;
let end = self.storage.row_ptr[row + 1] as usize;
let mut sum = 0.0;
for i in start..end {
let col = self.storage.col_indices[i] as usize;
sum += self.storage.values[i] * x[col];
}
result[row] = sum;
}
}
}
pub fn multiply_vector_streaming<F>(
&self,
x: &[Precision],
mut callback: F,
chunk_size: usize
) -> Result<()>
where
F: FnMut(usize, &[Precision]),
{
let mut result_chunk = vec![0.0; chunk_size];
for chunk_start in (0..self.storage.row_ptr.len() - 1).step_by(chunk_size) {
let chunk_end = (chunk_start + chunk_size).min(self.storage.row_ptr.len() - 1);
let actual_chunk_size = chunk_end - chunk_start;
result_chunk.resize(actual_chunk_size, 0.0);
result_chunk.fill(0.0);
for (local_row, global_row) in (chunk_start..chunk_end).enumerate() {
let start = self.storage.row_ptr[global_row] as usize;
let end = self.storage.row_ptr[global_row + 1] as usize;
let mut sum = 0.0;
for i in start..end {
let col = self.storage.col_indices[i] as usize;
sum += self.storage.values[i] * x[col];
}
result_chunk[local_row] = sum;
}
callback(chunk_start, &result_chunk[..actual_chunk_size]);
}
Ok(())
}
pub fn performance_stats(&self) -> OptimizedMatrixStats {
OptimizedMatrixStats {
matvec_count: self.matvec_count.load(Ordering::Relaxed),
bytes_processed: self.bytes_processed.load(Ordering::Relaxed),
buffer_pool_stats: self.buffer_pool.stats(),
matrix_nnz: self.storage.nnz(),
matrix_rows: self.storage.row_ptr.len() - 1,
workspace_size: self.workspace.len(),
}
}
pub fn reset_stats(&self) {
self.matvec_count.store(0, Ordering::Relaxed);
self.bytes_processed.store(0, Ordering::Relaxed);
}
pub fn get_temp_buffer(&mut self, size: usize) -> Vec<Precision> {
self.buffer_pool.get_buffer(size)
}
pub fn return_temp_buffer(&mut self, buffer: Vec<Precision>) {
self.buffer_pool.return_buffer(buffer);
}
pub fn storage(&self) -> &CSRStorage {
&self.storage
}
}
/// Counters and dimensions reported by `OptimizedCSRStorage::performance_stats`.
#[derive(Debug, Clone)]
pub struct OptimizedMatrixStats {
    /// Matrix–vector products recorded so far.
    pub matvec_count: usize,
    /// Estimated number of bytes touched by those products.
    pub bytes_processed: usize,
    /// Snapshot of the matrix's internal buffer pool.
    pub buffer_pool_stats: BufferPoolStats,
    /// Number of stored non-zero entries.
    pub matrix_nnz: usize,
    /// Number of matrix rows.
    pub matrix_rows: usize,
    /// Length of the matrix's scratch vector.
    pub workspace_size: usize,
}
impl OptimizedMatrixStats {
    /// Average throughput in GiB per second over `total_time_ms` milliseconds.
    ///
    /// Returns `0.0` when the elapsed time is zero or negative.
    pub fn bandwidth_gbs(&self, total_time_ms: f64) -> f64 {
        if total_time_ms <= 0.0 {
            return 0.0;
        }
        // 1_073_741_824 bytes = 1 GiB.
        let gibibytes = self.bytes_processed as f64 / 1_073_741_824.0;
        let seconds = total_time_ms / 1000.0;
        gibibytes / seconds
    }

    /// Matrix–vector products per second over `total_time_ms` milliseconds.
    ///
    /// Returns `0.0` when the elapsed time is zero or negative.
    pub fn ops_per_second(&self, total_time_ms: f64) -> f64 {
        if total_time_ms <= 0.0 {
            return 0.0;
        }
        let products = self.matvec_count as f64;
        let seconds = total_time_ms / 1000.0;
        products / seconds
    }
}
/// CSR matrix paired with the number of worker threads to split
/// matrix–vector products across.
#[cfg(feature = "std")]
pub struct ParallelCSRStorage {
    /// The wrapped matrix.
    storage: OptimizedCSRStorage,
    /// Number of row blocks a parallel product is divided into.
    num_threads: usize,
}
#[cfg(feature = "std")]
impl ParallelCSRStorage {
    /// Wraps `storage` for parallel products.
    ///
    /// With `num_threads == None` the thread count comes from
    /// `std::thread::available_parallelism`, falling back to 1 if that query
    /// fails. A requested count of zero is raised to 1 so the row partitioning
    /// never divides by zero.
    pub fn new(storage: OptimizedCSRStorage, num_threads: Option<usize>) -> Self {
        let num_threads = num_threads
            .unwrap_or_else(|| std::thread::available_parallelism().map_or(1, |p| p.get()))
            .max(1);
        Self {
            storage,
            num_threads,
        }
    }

    /// Computes `result = A * x`, splitting the rows into roughly
    /// `num_threads` contiguous blocks processed in parallel with rayon.
    ///
    /// An empty `result` returns immediately. This method does not update the
    /// wrapped matrix's usage counters.
    ///
    /// # Panics
    /// Panics if `result` has more entries than the matrix has rows, or if a
    /// stored column index is out of bounds for `x`.
    #[cfg(feature = "rayon")]
    pub fn multiply_vector_parallel(&self, x: &[Precision], result: &mut [Precision]) {
        use rayon::prelude::*;

        let rows = result.len();
        if rows == 0 {
            // `par_chunks_mut` panics on a chunk size of zero.
            return;
        }
        let threads = self.num_threads.max(1);
        // Ceiling division written so that it cannot overflow; at least 1 because `rows >= 1`.
        let chunk_size = rows / threads + usize::from(rows % threads != 0);

        let csr = &self.storage.storage;
        result
            .par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(chunk_idx, result_chunk)| {
                let first_row = chunk_idx * chunk_size;
                for (offset, out) in result_chunk.iter_mut().enumerate() {
                    let row = first_row + offset;
                    let start = csr.row_ptr[row] as usize;
                    let end = csr.row_ptr[row + 1] as usize;
                    let mut sum = 0.0;
                    for i in start..end {
                        let col = csr.col_indices[i] as usize;
                        sum += csr.values[i] * x[col];
                    }
                    *out = sum;
                }
            });
    }
}
/// Sparse matrix stored as a sequence of consecutive row blocks, each held
/// as its own CSR matrix.
pub struct StreamingMatrix {
    /// Row blocks in row order.
    chunks: Vec<OptimizedCSRStorage>,
    /// Number of rows per block; the last block may hold fewer.
    chunk_size: usize,
    /// Total number of rows across all blocks.
    total_rows: usize,
    /// Number of columns.
    /// NOTE(review): stored but not read by any method visible in this file.
    total_cols: usize,
    /// Memory budget in bytes that was used to choose `chunk_size`.
    /// NOTE(review): stored but not read by any method visible in this file.
    memory_limit: usize,
}
impl StreamingMatrix {
    /// Builds a row-chunked matrix from `(row, col, value)` triplets.
    ///
    /// The chunk height is chosen so that an average chunk is estimated to use
    /// about half of `memory_limit_mb` MiB, and is always between 1 row and
    /// `rows`. Triplets are stably sorted by row and each chunk takes its
    /// contiguous slice of the sorted list, with row indices rebased to the
    /// chunk. Triplets whose row is `>= rows` are ignored. With `rows == 0` the
    /// result has no chunks.
    ///
    /// # Errors
    /// Propagates errors from `COOStorage::from_triplets` and
    /// `OptimizedCSRStorage::from_coo`.
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: usize,
        cols: usize,
        memory_limit_mb: usize,
    ) -> Result<Self> {
        let memory_limit = memory_limit_mb.saturating_mul(1_048_576);

        // Rough per-row footprint: 8-byte value plus 4-byte column index per
        // non-zero, plus a 4-byte row pointer (so never zero).
        let avg_nnz_per_row = if rows > 0 { triplets.len() / rows } else { 0 };
        let bytes_per_row = avg_nnz_per_row * (8 + 4) + 4;
        let target_chunk_size = (memory_limit / (bytes_per_row * 2)).max(1);
        // Keep the height at least 1 so an empty matrix cannot cause a
        // division by zero below.
        let chunk_size = target_chunk_size.min(rows).max(1);
        let num_chunks = rows / chunk_size + usize::from(rows % chunk_size != 0);

        let mut sorted_triplets = triplets;
        sorted_triplets.sort_by_key(|&(row, _, _)| row);

        let mut chunks = Vec::with_capacity(num_chunks);
        // Index of the first triplet not yet assigned to a chunk. Because the
        // list is sorted by row, every chunk owns one contiguous slice.
        let mut cursor = 0;
        for chunk_idx in 0..num_chunks {
            let chunk_start_row = chunk_idx * chunk_size;
            let chunk_end_row = (chunk_start_row + chunk_size).min(rows);
            let chunk_rows = chunk_end_row - chunk_start_row;

            let remaining = &sorted_triplets[cursor..];
            let in_chunk = remaining.partition_point(|triplet| triplet.0 < chunk_end_row);
            let chunk_triplets: Vec<(usize, usize, Precision)> = remaining[..in_chunk]
                .iter()
                .map(|&(row, col, val)| (row - chunk_start_row, col, val))
                .collect();
            cursor += in_chunk;

            // A chunk without entries is still stored so that chunk index and
            // row offset stay in step.
            let coo = COOStorage::from_triplets(chunk_triplets)?;
            chunks.push(OptimizedCSRStorage::from_coo(&coo, chunk_rows, cols)?);
        }

        Ok(Self {
            chunks,
            chunk_size,
            total_rows: rows,
            total_cols: cols,
            memory_limit,
        })
    }

    /// Computes `A * x` chunk by chunk, calling `callback(first_row, values)`
    /// once per chunk in row order.
    ///
    /// One result buffer is reused for all chunks.
    ///
    /// # Panics
    /// Panics if a stored column index is out of bounds for `x`.
    pub fn multiply_vector_streaming<F>(
        &self,
        x: &[Precision],
        mut callback: F,
    ) -> Result<()>
    where
        F: FnMut(usize, &[Precision]),
    {
        let mut result: Vec<Precision> =
            Vec::with_capacity(self.chunk_size.min(self.total_rows));
        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
            let start_row = chunk_idx * self.chunk_size;
            let end_row = (start_row + self.chunk_size).min(self.total_rows);

            result.clear();
            result.resize(end_row - start_row, 0.0);
            chunk.multiply_vector_optimized(x, &mut result);
            callback(start_row, &result);
        }
        Ok(())
    }

    /// Estimated memory footprint in bytes: 12 bytes per stored non-zero plus
    /// 4 bytes per row, summed over all chunks.
    pub fn memory_usage(&self) -> usize {
        self.chunks
            .iter()
            .map(|chunk| {
                let stats = chunk.performance_stats();
                stats.matrix_nnz * 12 + stats.matrix_rows * 4
            })
            .sum()
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// A returned buffer is reused for a later, smaller request.
    #[test]
    fn test_buffer_pool() {
        let mut pool = BufferPool::new();

        let buffer1 = pool.get_buffer(100);
        assert_eq!(buffer1.len(), 100);
        pool.return_buffer(buffer1);

        let buffer2 = pool.get_buffer(50);
        assert_eq!(buffer2.len(), 50);
        assert!(buffer2.iter().all(|&v| v == 0.0));

        let stats = pool.stats();
        assert_eq!(stats.allocations, 2);
        assert_eq!(stats.deallocations, 1);
        // The first request had to allocate; the second reused the returned buffer.
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 1);
    }

    /// The scalar kernel produces the exact product and counts the call.
    #[test]
    fn test_optimized_csr_performance() {
        let triplets = vec![
            (0, 0, 2.0), (0, 1, 1.0),
            (1, 0, 1.0), (1, 1, 3.0),
        ];
        let coo = COOStorage::from_triplets(triplets).unwrap();
        let optimized = OptimizedCSRStorage::from_coo(&coo, 2, 2).unwrap();

        let x = vec![1.0, 2.0];
        let mut result = vec![0.0; 2];
        optimized.multiply_vector_optimized(&x, &mut result);
        assert_eq!(result, vec![4.0, 7.0]);

        let stats = optimized.performance_stats();
        assert_eq!(stats.matvec_count, 1);
    }

    /// Chunks arrive in row order and together form the full product.
    #[test]
    fn test_streaming_matrix() {
        let triplets = vec![
            (0, 0, 1.0), (0, 1, 2.0),
            (1, 0, 3.0), (1, 1, 4.0),
            (2, 0, 5.0), (2, 1, 6.0),
        ];
        let streaming = StreamingMatrix::from_triplets(triplets, 3, 2, 1).unwrap();

        let x = vec![1.0, 1.0];
        let mut results = Vec::new();
        streaming.multiply_vector_streaming(&x, |start_row, chunk_result| {
            // Each chunk must start exactly where the previous one ended.
            assert_eq!(start_row, results.len());
            results.extend_from_slice(chunk_result);
        }).unwrap();

        assert_eq!(results, vec![3.0, 7.0, 11.0]);
    }
}