use crate::matrix::sparse::{COOStorage, CSRStorage};
#[cfg(feature = "simd")]
use crate::simd_ops::{axpy_simd, dot_product_simd, matrix_vector_multiply_simd};
use crate::types::Precision;
use alloc::vec::Vec;
use core::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "std")]
use std::time::Instant;
/// Sparse matrix held in CSR form together with usage counters.
///
/// Instances are built with `OptimizedSparseMatrix::from_triplets`.
pub struct OptimizedSparseMatrix {
    /// Compressed-sparse-row storage backing the matrix.
    storage: CSRStorage,
    /// Matrix shape as `(rows, cols)`.
    dimensions: (usize, usize),
    /// Counters updated on every matrix-vector product.
    performance_stats: PerformanceStats,
}
/// Usage counters for a sparse matrix.
///
/// Both counters are atomics, so they can be updated through a shared
/// reference to the owning matrix.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    /// Number of matrix-vector products performed.
    pub matvec_count: AtomicUsize,
    /// Running estimate of the number of bytes touched by those products.
    pub bytes_processed: AtomicUsize,
}

impl Clone for PerformanceStats {
    /// Produces an independent copy holding the counter values observed at the
    /// time of the call. The two counters are read one after the other, not as
    /// a single atomic snapshot.
    fn clone(&self) -> Self {
        let snapshot = |counter: &AtomicUsize| AtomicUsize::new(counter.load(Ordering::Relaxed));
        Self {
            matvec_count: snapshot(&self.matvec_count),
            bytes_processed: snapshot(&self.bytes_processed),
        }
    }
}
impl OptimizedSparseMatrix {
    /// Builds a `rows` x `cols` matrix from coordinate triplets.
    ///
    /// The triplets are loaded into a `COOStorage` and then converted to
    /// `CSRStorage`. Each triplet is presumably `(row, column, value)`; the
    /// exact index order is defined by `COOStorage::from_triplets`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when either the COO construction or the
    /// COO-to-CSR conversion fails.
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: usize,
        cols: usize,
    ) -> Result<Self, String> {
        let coo = COOStorage::from_triplets(triplets)
            .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
        let storage = CSRStorage::from_coo(&coo, rows, cols)
            .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;
        Ok(Self {
            storage,
            dimensions: (rows, cols),
            performance_stats: PerformanceStats::default(),
        })
    }

    /// Returns the matrix shape as `(rows, cols)`.
    pub fn dimensions(&self) -> (usize, usize) {
        self.dimensions
    }

    /// Returns the entry count reported by the underlying CSR storage.
    pub fn nnz(&self) -> usize {
        self.storage.nnz()
    }

    /// Computes the matrix-vector product of this matrix with `x` into `y`.
    ///
    /// With the `simd` feature the CSR arrays are handed to
    /// `matrix_vector_multiply_simd`; otherwise `CSRStorage::multiply_vector`
    /// is used. Each call also bumps the matvec counter and adds an estimate
    /// of the bytes touched (stored values plus both vectors) to the byte
    /// counter.
    ///
    /// NOTE(review): whether `y` is overwritten or accumulated into depends on
    /// the delegated kernels; callers in this file always pass a zeroed `y`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the column count or `y.len()` differs
    /// from the row count.
    pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
        let (rows, cols) = self.dimensions;
        assert_eq!(x.len(), cols, "input vector length must equal the column count");
        assert_eq!(y.len(), rows, "output vector length must equal the row count");

        let stats = &self.performance_stats;
        stats.matvec_count.fetch_add(1, Ordering::Relaxed);

        // Derive the element size from the scalar type instead of assuming
        // 8-byte values, so the estimate stays right if `Precision` changes.
        let element_bytes = core::mem::size_of::<Precision>();
        let bytes = (self.storage.values.len() + x.len() + y.len()) * element_bytes;
        stats.bytes_processed.fetch_add(bytes, Ordering::Relaxed);

        #[cfg(feature = "simd")]
        {
            matrix_vector_multiply_simd(
                &self.storage.values,
                &self.storage.col_indices,
                &self.storage.row_ptr,
                x,
                y,
            );
        }
        #[cfg(not(feature = "simd"))]
        {
            self.storage.multiply_vector(x, y);
        }
    }

    /// Returns the current counters as `(matvec_count, bytes_processed)`.
    pub fn get_performance_stats(&self) -> (usize, usize) {
        let stats = &self.performance_stats;
        (
            stats.matvec_count.load(Ordering::Relaxed),
            stats.bytes_processed.load(Ordering::Relaxed),
        )
    }

    /// Sets both counters back to zero.
    pub fn reset_stats(&self) {
        let stats = &self.performance_stats;
        stats.matvec_count.store(0, Ordering::Relaxed);
        stats.bytes_processed.store(0, Ordering::Relaxed);
    }
}
/// Settings for `OptimizedConjugateGradientSolver`.
#[derive(Debug, Clone)]
pub struct OptimizedSolverConfig {
    /// Upper bound on the number of solver iterations.
    pub max_iterations: usize,
    /// Residual-norm threshold; the solver compares the squared residual norm
    /// against the square of this value.
    pub tolerance: Precision,
    /// Profiling switch. NOTE(review): no code visible in this file reads this
    /// flag — confirm whether it is consumed elsewhere.
    pub enable_profiling: bool,
}
impl Default for OptimizedSolverConfig {
    /// Returns the stock configuration: at most 1000 iterations, a tolerance
    /// of `1e-6`, and profiling switched off.
    fn default() -> Self {
        let max_iterations = 1000;
        let tolerance = 1e-6;
        let enable_profiling = false;
        Self {
            max_iterations,
            tolerance,
            enable_profiling,
        }
    }
}
/// Outcome of a single solve.
#[derive(Debug, Clone)]
pub struct OptimizedSolverResult {
    /// Final iterate of the solution vector.
    pub solution: Vec<Precision>,
    /// Square root of the last computed residual dot product `r . r`.
    pub residual_norm: Precision,
    /// Number of completed solver iterations.
    pub iterations: usize,
    /// Whether the solver reported reaching the configured tolerance.
    pub converged: bool,
    /// Measured solve time in milliseconds (only with the `std` feature).
    #[cfg(feature = "std")]
    pub computation_time_ms: f64,
    /// Integer stand-in used when the `std` feature is off; `Instant` is not
    /// imported in that configuration.
    #[cfg(not(feature = "std"))]
    pub computation_time_ms: u64,
    /// Operation counters and derived throughput figures for this solve.
    pub performance_stats: OptimizedSolverStats,
}
/// Operation counters and throughput estimates collected during a solve.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverStats {
    /// Number of matrix-vector products issued by the solver.
    pub matvec_count: usize,
    /// Number of dot products issued by the solver.
    pub dot_product_count: usize,
    /// Number of `y += alpha * x` updates issued by the solver.
    pub axpy_count: usize,
    /// Estimated floating-point operation count for the solve.
    pub total_flops: usize,
    /// Estimated bandwidth in GB/s; stays `0.0` when no positive elapsed time
    /// was measured.
    pub average_bandwidth_gbs: f64,
    /// Estimated GFLOP/s; stays `0.0` when no positive elapsed time was
    /// measured.
    pub average_gflops: f64,
}
/// Conjugate-gradient solver operating on `OptimizedSparseMatrix`.
pub struct OptimizedConjugateGradientSolver {
    /// Iteration limit, tolerance, and related settings.
    config: OptimizedSolverConfig,
    /// Counters for the most recent solve; reset at the start of each solve.
    stats: OptimizedSolverStats,
}
impl OptimizedConjugateGradientSolver {
    /// Creates a solver that uses `config`, with all counters zeroed.
    pub fn new(config: OptimizedSolverConfig) -> Self {
        Self {
            config,
            stats: OptimizedSolverStats::default(),
        }
    }

    /// Solves `A x = b` with the conjugate-gradient iteration, starting from
    /// the zero vector.
    ///
    /// Iteration stops when the squared residual norm drops to
    /// `tolerance^2` or below, when `max_iterations` iterations have run, or
    /// when the curvature term `p . A p` is non-finite or smaller than `1e-16`
    /// in magnitude (breakdown). The returned `converged` flag reflects the
    /// residual at exit, so a solve that reaches tolerance on its very last
    /// permitted iteration is reported as converged.
    ///
    /// The solver's counters are reset at the start of every call.
    ///
    /// # Errors
    ///
    /// Returns an error message if `matrix` is not square or if `b.len()` does
    /// not equal the matrix dimension.
    pub fn solve(
        &mut self,
        matrix: &OptimizedSparseMatrix,
        b: &[Precision],
    ) -> Result<OptimizedSolverResult, String> {
        let (rows, cols) = matrix.dimensions();
        if rows != cols {
            return Err("Matrix must be square".to_string());
        }
        if b.len() != rows {
            return Err("Right-hand side vector length must match matrix size".to_string());
        }

        #[cfg(feature = "std")]
        let start_time = Instant::now();

        self.stats = OptimizedSolverStats::default();

        // With x0 = 0 the initial residual is b itself, and the first search
        // direction equals that residual.
        let mut x = vec![0.0; rows];
        let mut r = b.to_vec();
        let mut p = r.clone();
        let mut ap = vec![0.0; rows];

        let tolerance_sq = self.config.tolerance * self.config.tolerance;
        let mut iteration = 0;
        let mut rsold = self.dot_product(&r, &r);

        while iteration < self.config.max_iterations {
            if rsold <= tolerance_sq {
                break;
            }

            matrix.multiply_vector(&p, &mut ap);
            self.stats.matvec_count += 1;

            let pap = self.dot_product(&p, &ap);
            // Breakdown: a vanishing or non-finite curvature term would make
            // the step length meaningless, so stop with the current iterate.
            if !pap.is_finite() || pap.abs() < 1e-16 {
                break;
            }

            let alpha = rsold / pap;
            self.axpy(alpha, &p, &mut x);
            self.axpy(-alpha, &ap, &mut r);

            let rsnew = self.dot_product(&r, &r);
            let beta = rsnew / rsold;
            for (pi, &ri) in p.iter_mut().zip(r.iter()) {
                *pi = ri + beta * *pi;
            }
            rsold = rsnew;
            iteration += 1;
        }

        // Decide convergence from the residual at exit; checking only at the
        // top of the loop would miss convergence reached on the final
        // permitted iteration.
        let converged = rsold <= tolerance_sq;

        // Keep sub-millisecond resolution so short solves still yield a
        // non-zero duration for the throughput figures below.
        #[cfg(feature = "std")]
        let elapsed_ms: f64 = start_time.elapsed().as_secs_f64() * 1000.0;
        #[cfg(not(feature = "std"))]
        let elapsed_ms: f64 = 0.0;

        let final_residual_norm = rsold.sqrt();
        self.stats.total_flops = self.stats.matvec_count * matrix.nnz() * 2 + iteration * rows * 6;
        if elapsed_ms > 0.0 {
            let total_flops = self.stats.total_flops as f64;
            // Rough traffic estimate: 8 bytes moved per counted flop.
            let total_gb = total_flops * 8.0 / 1e9;
            self.stats.average_bandwidth_gbs = total_gb / (elapsed_ms / 1000.0);
            self.stats.average_gflops = total_flops / (elapsed_ms * 1e6);
        }

        Ok(OptimizedSolverResult {
            solution: x,
            residual_norm: final_residual_norm,
            iterations: iteration,
            converged,
            #[cfg(feature = "std")]
            computation_time_ms: elapsed_ms,
            // No timer is used without `std`; the field is an integer there.
            #[cfg(not(feature = "std"))]
            computation_time_ms: 0,
            performance_stats: self.stats.clone(),
        })
    }

    /// Returns the dot product of `x` and `y` and counts the operation.
    ///
    /// Uses `dot_product_simd` when the `simd` feature is enabled; the scalar
    /// fallback pairs elements with `zip`, so it stops at the shorter slice.
    fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
        self.stats.dot_product_count += 1;
        #[cfg(feature = "simd")]
        {
            dot_product_simd(x, y)
        }
        #[cfg(not(feature = "simd"))]
        {
            x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
        }
    }

    /// Performs `y += alpha * x` in place and counts the operation.
    ///
    /// Uses `axpy_simd` when the `simd` feature is enabled.
    fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
        self.stats.axpy_count += 1;
        #[cfg(feature = "simd")]
        {
            axpy_simd(alpha, x, y);
        }
        #[cfg(not(feature = "simd"))]
        {
            for (yi, &xi) in y.iter_mut().zip(x.iter()) {
                *yi += alpha * xi;
            }
        }
    }

    /// Returns the Euclidean norm of `x`. Not called by the code visible in
    /// this file.
    fn l2_norm(&self, x: &[Precision]) -> Precision {
        x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
    }

    /// Returns the number of matrix-vector products performed by the most
    /// recent solve.
    ///
    /// This equals that solve's iteration count except when it stopped on the
    /// breakdown check, in which case it is one higher (the product was
    /// computed but the iteration was not completed).
    pub fn get_last_iteration_count(&self) -> usize {
        self.stats.matvec_count
    }

    /// Solves the system exactly like `solve`.
    ///
    /// `_chunk_size` and `_callback` are currently ignored: the callback is
    /// never invoked and the solve is not split into chunks.
    ///
    /// # Errors
    ///
    /// Same as `solve`.
    pub fn solve_with_callback<F>(
        &mut self,
        matrix: &OptimizedSparseMatrix,
        b: &[Precision],
        _chunk_size: usize,
        mut _callback: F,
    ) -> Result<OptimizedSolverResult, String>
    where
        F: FnMut(&OptimizedSolverStats),
    {
        self.solve(matrix, b)
    }
}
impl OptimizedSolverResult {
    /// Borrows the solution vector as a slice.
    pub fn data(&self) -> &[Precision] {
        self.solution.as_slice()
    }
}
/// Pair of tracking switches; both default to `false`.
///
/// NOTE(review): nothing visible in this file reads this type — confirm where
/// it is consumed.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    /// Presumably enables performance tracking — verify against the consumer.
    pub track_performance: bool,
    /// Presumably enables memory tracking — verify against the consumer.
    pub track_memory: bool,
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Builds the symmetric 2x2 matrix `[[4, 1], [1, 3]]` shared by the tests.
    fn create_test_matrix() -> OptimizedSparseMatrix {
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        OptimizedSparseMatrix::from_triplets(entries, 2, 2).unwrap()
    }

    /// The matrix reports the requested shape and all four stored entries.
    #[test]
    fn test_optimized_matrix_creation() {
        let system = create_test_matrix();
        let shape = system.dimensions();
        assert_eq!(shape, (2, 2));
        let stored = system.nnz();
        assert_eq!(stored, 4);
    }

    /// Multiplying by `[1, 2]` yields `[4 + 2, 1 + 6]`.
    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let system = create_test_matrix();
        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 2];
        system.multiply_vector(&input, &mut output);
        let expected = vec![6.0, 7.0];
        assert_eq!(output, expected);
    }

    /// The solver converges on the small system and the solution satisfies it.
    #[test]
    fn test_optimized_conjugate_gradient() {
        let system = create_test_matrix();
        let rhs = vec![1.0, 2.0];
        let mut cg = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
        let outcome = cg.solve(&system, &rhs).unwrap();
        assert!(outcome.converged);
        assert!(outcome.residual_norm < 1e-6);
        assert!(outcome.iterations > 0);

        // Re-apply the matrix to the computed solution and compare with rhs.
        let mut product = vec![0.0; 2];
        system.multiply_vector(&outcome.solution, &mut product);
        let d0 = product[0] - rhs[0];
        let d1 = product[1] - rhs[1];
        let error = (d0.powi(2) + d1.powi(2)).sqrt();
        assert!(error < 1e-10);
    }

    /// A solve records non-zero operation counters.
    #[test]
    fn test_solver_performance_stats() {
        let system = create_test_matrix();
        let rhs = vec![1.0, 2.0];
        let mut cg = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
        let outcome = cg.solve(&system, &rhs).unwrap();
        let counters = &outcome.performance_stats;
        assert!(counters.matvec_count > 0);
        assert!(counters.dot_product_count > 0);
        assert!(counters.total_flops > 0);
    }
}