numrs2 0.3.3

A Rust implementation inspired by NumPy for numerical computing (NumRS2)
Documentation
//! Optimized linear algebra operations with enhanced performance algorithms
//!
//! This module provides high-performance implementations of linear algebra operations
//! using scirs2-linalg's BLAS/LAPACK acceleration. These implementations offer
//! 200-700x speedups compared to pure Rust implementations for large matrices.
//!
//! # SCIRS2 POLICY Compliance
//!
//! Per SCIRS2 POLICY, all BLAS/LAPACK operations are routed through scirs2-linalg
//! rather than using direct dependencies. This module acts as a thin wrapper
//! around the linalg_accelerated module to maintain backward compatibility.
//!
//! # Performance
//!
//! | Operation | Speedup | Notes |
//! |-----------|---------|-------|
//! | gemm      | 200-700x | BLAS Level 3 |
//! | gemv      | 50-200x  | BLAS Level 2 |
//! | dot       | 10-50x   | BLAS Level 1 |
//! | lu        | 200-700x | LAPACK |

use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::linalg_accelerated;
use num_traits::{Float, NumAssign, NumCast};
use scirs2_core::ndarray::ScalarOperand;
use std::fmt::Debug;
use std::iter::Sum;

/// BLAS-accelerated linear algebra operations
///
/// Namespace struct grouping GEMM, GEMV and dot-product wrappers. Each method
/// validates its arguments and then delegates to [`linalg_accelerated`], which
/// routes through scirs2-linalg's BLAS/LAPACK backends.
pub struct OptimizedBlas;

impl OptimizedBlas {
    /// Matrix-matrix multiplication using BLAS GEMM
    ///
    /// Computes `C = α·op(A)·op(B) + β·C`, where `op(X)` is `X` or `Xᵀ`.
    ///
    /// # Performance
    /// - 200-700x faster than naive implementation for large matrices
    /// - Operands whose `trans_*` flag is `false` are handed to the backend by
    ///   reference; only operands that must be transposed are copied
    ///   (via `Array::transpose`).
    ///
    /// # Arguments
    /// * `a` - Matrix A
    /// * `b` - Matrix B
    /// * `c` - Matrix C; must have the shape of `op(A)·op(B)` and is overwritten
    ///   with the result
    /// * `alpha` - Scalar α
    /// * `beta` - Scalar β
    /// * `trans_a` - Use `Aᵀ` instead of `A`
    /// * `trans_b` - Use `Bᵀ` instead of `B`
    ///
    /// # Errors
    /// Returns `NumRs2Error::DimensionMismatch` if any operand is not 2-D, if the
    /// inner dimensions of `op(A)` and `op(B)` differ, or if `c` is not
    /// `rows(op(A)) × cols(op(B))`. Errors from the backend are propagated.
    pub fn gemm<T>(
        a: &Array<T>,
        b: &Array<T>,
        c: &mut Array<T>,
        alpha: T,
        beta: T,
        trans_a: bool,
        trans_b: bool,
    ) -> Result<()>
    where
        T: Float + NumAssign + Clone + Debug + NumCast + 'static,
    {
        let a_shape = a.shape();
        let b_shape = b.shape();
        let c_shape = c.shape();

        if a_shape.len() != 2 || b_shape.len() != 2 || c_shape.len() != 2 {
            return Err(NumRs2Error::DimensionMismatch(
                "GEMM requires 2D matrices".to_string(),
            ));
        }

        // Logical shapes of op(A) (m × k) and op(B) (k × n), derived without
        // materialising any transpose.
        let (m, k_a) = if trans_a {
            (a_shape[1], a_shape[0])
        } else {
            (a_shape[0], a_shape[1])
        };
        let (k_b, n) = if trans_b {
            (b_shape[1], b_shape[0])
        } else {
            (b_shape[0], b_shape[1])
        };

        if k_a != k_b {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "GEMM inner dimensions differ: op(A) is {}x{}, op(B) is {}x{}",
                m, k_a, k_b, n
            )));
        }
        if c_shape[0] != m || c_shape[1] != n {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "GEMM output C must be {}x{}, got {}x{}",
                m, n, c_shape[0], c_shape[1]
            )));
        }

        // Allocate only when a transpose is actually requested; otherwise borrow
        // the caller's matrix instead of deep-copying it.
        let a_transposed;
        let a_op: &Array<T> = if trans_a {
            a_transposed = a.transpose();
            &a_transposed
        } else {
            a
        };
        let b_transposed;
        let b_op: &Array<T> = if trans_b {
            b_transposed = b.transpose();
            &b_transposed
        } else {
            b
        };

        let result = linalg_accelerated::gemm(alpha, a_op, b_op, beta, c)?;
        *c = result;
        Ok(())
    }

    /// Matrix-vector multiplication using BLAS GEMV
    ///
    /// Computes `y = α·op(A)·x + β·y`, where `op(A)` is `A` or `Aᵀ`.
    ///
    /// # Performance
    /// - 50-200x faster than naive implementation for large matrices
    /// - `A` is passed to the backend by reference unless `trans` is set, in
    ///   which case a transposed copy is made via `Array::transpose`.
    ///
    /// # Arguments
    /// * `a` - Matrix A (m × n)
    /// * `x` - Vector x (length = columns of `op(A)`)
    /// * `y` - Vector y (length = rows of `op(A)`); overwritten with the result
    /// * `alpha` - Scalar α
    /// * `beta` - Scalar β
    /// * `trans` - Use `Aᵀ` instead of `A`
    ///
    /// # Errors
    /// Returns `NumRs2Error::DimensionMismatch` if `a` is not 2-D, if `x`/`y` are
    /// not 1-D, or if their lengths do not match `op(A)`. Errors from the
    /// backend are propagated.
    pub fn gemv<T>(
        a: &Array<T>,
        x: &Array<T>,
        y: &mut Array<T>,
        alpha: T,
        beta: T,
        trans: bool,
    ) -> Result<()>
    where
        T: Float + NumAssign + Clone + Debug + NumCast + 'static,
    {
        let a_shape = a.shape();
        let x_shape = x.shape();
        let y_shape = y.shape();

        if a_shape.len() != 2 || x_shape.len() != 1 || y_shape.len() != 1 {
            return Err(NumRs2Error::DimensionMismatch(
                "GEMV requires 2D matrix and 1D vectors".to_string(),
            ));
        }

        // Logical shape of op(A), derived without materialising the transpose.
        let (rows, cols) = if trans {
            (a_shape[1], a_shape[0])
        } else {
            (a_shape[0], a_shape[1])
        };
        if x_shape[0] != cols || y_shape[0] != rows {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "GEMV shape mismatch: op(A) is {}x{}, x has {} elements, y has {} elements",
                rows, cols, x_shape[0], y_shape[0]
            )));
        }

        // Allocate only when a transpose is requested; otherwise borrow `a`.
        let a_transposed;
        let a_op: &Array<T> = if trans {
            a_transposed = a.transpose();
            &a_transposed
        } else {
            a
        };

        let result = linalg_accelerated::gemv(alpha, a_op, x, beta, y)?;
        *y = result;
        Ok(())
    }

    /// Vector dot product using BLAS
    ///
    /// Computes `x · y` for two 1-D arrays of equal length.
    ///
    /// # Performance
    /// - 10-50x faster than naive implementation with SIMD + BLAS
    ///
    /// # Errors
    /// Returns `NumRs2Error::DimensionMismatch` unless both inputs are 1-D with
    /// the same length. Errors from the backend are propagated.
    pub fn dot<T>(x: &Array<T>, y: &Array<T>) -> Result<T>
    where
        T: Float + NumAssign + Clone + Debug + NumCast + 'static,
    {
        let x_shape = x.shape();
        let y_shape = y.shape();

        if x_shape.len() != 1 || y_shape.len() != 1 || x_shape[0] != y_shape[0] {
            return Err(NumRs2Error::DimensionMismatch(
                "Dot product requires equal-length vectors".to_string(),
            ));
        }

        linalg_accelerated::dot(x, y)
    }
}

/// LU decomposition with partial pivoting using LAPACK
///
/// Decomposes `A = P·L·U` where:
/// - `P` is a permutation matrix
/// - `L` is lower triangular with unit diagonal
/// - `U` is upper triangular
///
/// The backend ([`linalg_accelerated::lu`]) yields `P` as a dense matrix. This
/// function compresses it into a permutation index vector `perm` of length `n`,
/// where `perm[i]` is the column of the `1` in row `i` of `P`. Given
/// `A = P·L·U`, row `i` of `A` equals row `perm[i]` of `L·U`. If a row of `P`
/// has no entry above 0.5, `perm[i]` is left as `i`.
///
/// # Performance
/// - 200-700x faster than pure Rust implementation
///
/// # Arguments
/// * `a` - Square matrix to decompose
///
/// # Returns
/// Tuple `(L, U, perm)`: the `L` and `U` factors, and the permutation as a 1-D
/// `Array<usize>` of row indices (not a permutation matrix).
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` if `a` is not a square 2-D matrix.
/// Errors from the backend factorisation or from element access are propagated.
pub fn lu_optimized<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>, Array<usize>)>
where
    T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
    let shape = a.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(NumRs2Error::DimensionMismatch(
            "LU decomposition requires a square matrix".to_string(),
        ));
    }
    let n = shape[0];

    // Use accelerated LU decomposition
    let (p, l, u) = linalg_accelerated::lu(a)?;

    // Start from the identity so a row of P without a unit entry keeps its own index.
    let mut perm: Vec<usize> = (0..n).collect();
    for (i, slot) in perm.iter_mut().enumerate() {
        for j in 0..n {
            // P should hold only 0s and 1s; the 0.5 threshold tolerates rounding.
            if p.get(&[i, j])?.to_f64().unwrap_or(0.0) > 0.5 {
                *slot = j;
                break;
            }
        }
    }

    Ok((l, u, Array::from_vec(perm)))
}

/// Transposes a 2-D matrix into a newly allocated `n × m` array.
///
/// The copy walks the input in square tiles of up to 64 × 64 elements rather
/// than a whole row at a time. The intent (assuming row-major storage) is to
/// keep the reads and writes of each tile close together in memory, improving
/// cache reuse for large matrices compared with a naive transpose.
///
/// # Arguments
/// * `a` - Matrix to transpose (shape `m × n`)
///
/// # Returns
/// The transposed matrix (shape `n × m`).
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` if `a` is not 2-D. Element-access
/// errors are propagated.
pub fn transpose_optimized<T>(a: &Array<T>) -> Result<Array<T>>
where
    T: Float + Clone + Debug,
{
    // Tile edge length, measured in elements (not bytes).
    const TILE: usize = 64;

    let dims = a.shape();
    let (rows, cols) = match dims.len() {
        2 => (dims[0], dims[1]),
        _ => {
            return Err(NumRs2Error::DimensionMismatch(
                "Transpose requires a 2D matrix".to_string(),
            ))
        }
    };

    let mut out = Array::zeros(&[cols, rows]);

    let mut row_start = 0;
    while row_start < rows {
        let row_stop = rows.min(row_start + TILE);
        let mut col_start = 0;
        while col_start < cols {
            let col_stop = cols.min(col_start + TILE);
            for r in row_start..row_stop {
                for c in col_start..col_stop {
                    let value = a.get(&[r, c])?;
                    out.set(&[c, r], value)?;
                }
            }
            col_start += TILE;
        }
        row_start += TILE;
    }

    Ok(out)
}

/// Multiplies two matrices, returning `C = A · B`.
///
/// Thin convenience wrapper over [`linalg_accelerated::matmul`]. This function
/// performs no checks of its own.
///
/// # Errors
/// Propagates any error reported by [`linalg_accelerated::matmul`].
pub fn matmul_optimized<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
    T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
    let product = linalg_accelerated::matmul(a, b)?;
    Ok(product)
}

/// Multiplies a matrix by a vector, returning `y = A · x`.
///
/// Thin convenience wrapper over [`linalg_accelerated::matvec`]. This function
/// performs no checks of its own.
///
/// # Errors
/// Propagates any error reported by [`linalg_accelerated::matvec`].
pub fn matvec_optimized<T>(a: &Array<T>, x: &Array<T>) -> Result<Array<T>>
where
    T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
    let image = linalg_accelerated::matvec(a, x)?;
    Ok(image)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts that `m` matches `expected`, given in row-major order with `cols` columns.
    fn assert_matrix_close(m: &Array<f64>, expected: &[f64], cols: usize) {
        for (idx, &want) in expected.iter().enumerate() {
            let got = m.get(&[idx / cols, idx % cols]).expect("valid index");
            assert_relative_eq!(got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_optimized_gemm() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
        let mut c = Array::zeros(&[2, 2]);

        OptimizedBlas::gemm(&a, &b, &mut c, 1.0, 0.0, false, false).expect("gemm should succeed");

        // Expected result: [[19, 22], [43, 50]]
        assert_relative_eq!(c.get(&[0, 0]).expect("valid index"), 19.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[0, 1]).expect("valid index"), 22.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[1, 0]).expect("valid index"), 43.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[1, 1]).expect("valid index"), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimized_gemm_transposed_operands() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]).reshape(&[2, 2]);

        // Aᵀ·B = [[1,3],[2,4]]·[[5,6],[7,8]]
        let mut c = Array::zeros(&[2, 2]);
        OptimizedBlas::gemm(&a, &b, &mut c, 1.0, 0.0, true, false).expect("gemm should succeed");
        assert_matrix_close(&c, &[26.0, 30.0, 38.0, 44.0], 2);

        // A·Bᵀ = [[1,2],[3,4]]·[[5,7],[6,8]]
        let mut c = Array::zeros(&[2, 2]);
        OptimizedBlas::gemm(&a, &b, &mut c, 1.0, 0.0, false, true).expect("gemm should succeed");
        assert_matrix_close(&c, &[17.0, 23.0, 39.0, 53.0], 2);
    }

    #[test]
    fn test_optimized_gemm_alpha_beta() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
        let mut c = Array::from_vec(vec![1.0f64; 4]).reshape(&[2, 2]);

        // C = 2·A·B + 1·C with C initially all ones.
        OptimizedBlas::gemm(&a, &b, &mut c, 2.0, 1.0, false, false).expect("gemm should succeed");
        assert_matrix_close(&c, &[39.0, 45.0, 87.0, 101.0], 2);
    }

    #[test]
    fn test_optimized_gemv() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let x = Array::from_vec(vec![1.0f64, 2.0]);
        let mut y = Array::zeros(&[2]);

        OptimizedBlas::gemv(&a, &x, &mut y, 1.0, 0.0, false).expect("gemv should succeed");

        // Expected result: [5, 11]
        assert_relative_eq!(y.get(&[0]).expect("valid index"), 5.0, epsilon = 1e-10);
        assert_relative_eq!(y.get(&[1]).expect("valid index"), 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimized_gemv_transposed() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let x = Array::from_vec(vec![1.0f64, 2.0]);
        let mut y = Array::zeros(&[2]);

        OptimizedBlas::gemv(&a, &x, &mut y, 1.0, 0.0, true).expect("gemv should succeed");

        // Aᵀ·x = [[1,3],[2,4]]·[1,2] = [7, 10]
        assert_relative_eq!(y.get(&[0]).expect("valid index"), 7.0, epsilon = 1e-10);
        assert_relative_eq!(y.get(&[1]).expect("valid index"), 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimized_dot() {
        let x = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let y = Array::from_vec(vec![4.0f64, 5.0, 6.0]);

        let result = OptimizedBlas::dot(&x, &y).expect("dot product should succeed");

        // Expected result: 1*4 + 2*5 + 3*6 = 32
        assert_relative_eq!(result, 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimized_dot_rejects_length_mismatch() {
        let x = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let y = Array::from_vec(vec![4.0f64, 5.0]);

        assert!(OptimizedBlas::dot(&x, &y).is_err());
    }

    #[test]
    fn test_lu_optimized() {
        let a = Array::from_vec(vec![2.0f64, 1.0, 1.0, 3.0]).reshape(&[2, 2]);

        let (l, u, _p) = lu_optimized(&a).expect("LU decomposition should succeed");

        // Verify L is lower triangular with 1s on diagonal
        assert_relative_eq!(l.get(&[0, 0]).expect("valid index"), 1.0, epsilon = 1e-10);
        assert_relative_eq!(l.get(&[1, 1]).expect("valid index"), 1.0, epsilon = 1e-10);
        assert_relative_eq!(l.get(&[0, 1]).expect("valid index"), 0.0, epsilon = 1e-10);

        // Verify U is upper triangular
        assert_relative_eq!(u.get(&[1, 0]).expect("valid index"), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lu_optimized_reconstructs_input() {
        // |3| > |1| in the first column, so partial pivoting would swap rows. For a
        // 2x2 matrix a swap is its own inverse, so the check below does not
        // depend on whether the backend uses A = P·L·U or P·A = L·U.
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);

        let (l, u, perm) = lu_optimized(&a).expect("LU decomposition should succeed");

        // Row i of A must equal row perm[i] of L·U.
        for i in 0..2 {
            let src = perm.get(&[i]).expect("valid index");
            for j in 0..2 {
                let lu_entry: f64 = (0..2)
                    .map(|k| {
                        l.get(&[src, k]).expect("valid index")
                            * u.get(&[k, j]).expect("valid index")
                    })
                    .sum();
                assert_relative_eq!(
                    lu_entry,
                    a.get(&[i, j]).expect("valid index"),
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_lu_optimized_rejects_non_square() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);

        assert!(lu_optimized(&a).is_err());
    }

    #[test]
    fn test_transpose_optimized() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);

        let result = transpose_optimized(&a).expect("transpose should succeed");

        assert_relative_eq!(
            result.get(&[0, 0]).expect("valid index"),
            1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            result.get(&[0, 1]).expect("valid index"),
            3.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            result.get(&[1, 0]).expect("valid index"),
            2.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            result.get(&[1, 1]).expect("valid index"),
            4.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_transpose_optimized_non_square_across_tiles() {
        // 70 rows exceeds the 64-element tile, so more than one row tile is visited.
        let (rows, cols) = (70usize, 3usize);
        let a = Array::from_vec((0..rows * cols).map(|v| v as f64).collect::<Vec<_>>())
            .reshape(&[rows, cols]);

        let t = transpose_optimized(&a).expect("transpose should succeed");

        for i in 0..rows {
            for j in 0..cols {
                assert_eq!(
                    t.get(&[j, i]).expect("valid index"),
                    a.get(&[i, j]).expect("valid index")
                );
            }
        }
    }

    #[test]
    fn test_transpose_optimized_rejects_1d() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0]);

        assert!(transpose_optimized(&a).is_err());
    }

    #[test]
    fn test_matmul_optimized() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]).reshape(&[2, 2]);

        let c = matmul_optimized(&a, &b).expect("matmul should succeed");

        assert_relative_eq!(c.get(&[0, 0]).expect("valid index"), 19.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[0, 1]).expect("valid index"), 22.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[1, 0]).expect("valid index"), 43.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[1, 1]).expect("valid index"), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_matvec_optimized() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let x = Array::from_vec(vec![1.0f64, 2.0]);

        let y = matvec_optimized(&a, &x).expect("matvec should succeed");

        assert_relative_eq!(y.get(&[0]).expect("valid index"), 5.0, epsilon = 1e-10);
        assert_relative_eq!(y.get(&[1]).expect("valid index"), 11.0, epsilon = 1e-10);
    }
}