use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use crate::error::{LinalgError, LinalgResult};
/// Inner (dot) product of two vectors, `sum_i x[i] * y[i]`.
///
/// # Errors
/// Returns [`LinalgError::ShapeError`] when `x` and `y` have different lengths.
#[allow(dead_code)]
pub fn dot<F>(x: &ArrayView1<F>, y: &ArrayView1<F>) -> LinalgResult<F>
where
    F: Float + NumAssign + 'static,
{
    let (len_x, len_y) = (x.len(), y.len());
    if len_x == len_y {
        Ok(x.dot(y))
    } else {
        Err(LinalgError::ShapeError(format!(
            "Vectors must have the same length for dot product, got {} and {}",
            len_x, len_y
        )))
    }
}
/// Euclidean (L2) norm of a vector, computed without spurious overflow or underflow.
///
/// Entries are divided by the largest magnitude before squaring. Vectors whose entries
/// are near the limits of `F` (e.g. `1e200` or `1e-200` in `f64`) therefore still give
/// an accurate finite result, where a naive sum of squares would give `inf` or `0`.
/// NaN entries still produce NaN, and infinite entries still produce `inf` (or NaN if a
/// NaN is also present), exactly as with the direct formula.
///
/// # Errors
/// Returns [`LinalgError::InvalidInputError`] if `x` is empty.
#[allow(dead_code)]
pub fn norm<F>(x: &ArrayView1<F>) -> LinalgResult<F>
where
    F: Float + NumAssign + 'static,
{
    if x.is_empty() {
        return Err(LinalgError::InvalidInputError(
            "Cannot compute norm of an empty vector".to_string(),
        ));
    }

    // `Float::max` returns the non-NaN argument, so NaN never becomes the scale;
    // NaN entries still poison the sum computed below.
    let scale = x
        .iter()
        .fold(F::zero(), |acc, &v| Float::max(acc, Float::abs(v)));

    if scale == F::zero() || !Float::is_finite(scale) {
        // All-zero (or all-NaN) input, or an infinite entry: scaling would divide by
        // zero or infinity. The direct formula already yields the right 0 / NaN / inf.
        let sum = x.iter().fold(F::zero(), |acc, &v| acc + v * v);
        return Ok(Float::sqrt(sum));
    }

    // Every scaled entry lies in [-1, 1], so the squares can neither overflow
    // nor all underflow to zero.
    let scaled_sum = x.iter().fold(F::zero(), |acc, &v| {
        let r = v / scale;
        acc + r * r
    });
    Ok(scale * Float::sqrt(scaled_sum))
}
#[allow(dead_code)]
pub fn gemv<F>(
alpha: F,
a: &ArrayView2<F>,
x: &ArrayView1<F>,
beta: F,
y: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
F: Float + NumAssign + 'static,
{
if a.ncols() != x.len() {
return Err(LinalgError::ShapeError(format!(
"Matrix columns ({}) must match vector length ({}) for gemv",
a.ncols(),
x.len()
)));
}
if a.nrows() != y.len() {
return Err(LinalgError::ShapeError(format!(
"Matrix rows ({}) must match result vector length ({}) for gemv",
a.nrows(),
y.len()
)));
}
let mut result = y.to_owned();
if beta != F::one() {
result.map_inplace(|v| *v *= beta);
}
let ax = a.dot(x);
result.zip_mut_with(&ax, |y_i, &ax_i| *y_i += alpha * ax_i);
Ok(result)
}
#[allow(dead_code)]
pub fn gemm<F>(
alpha: F,
a: &ArrayView2<F>,
b: &ArrayView2<F>,
beta: F,
c: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + 'static,
{
if a.ncols() != b.nrows() {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions not compatible for multiplication: a.ncols ({}) != b.nrows ({})",
a.ncols(),
b.nrows()
)));
}
if a.nrows() != c.nrows() || b.ncols() != c.ncols() {
return Err(LinalgError::ShapeError(format!(
"Output matrix dimensions ({},{}) don't match expected ({},{})",
c.nrows(),
c.ncols(),
a.nrows(),
b.ncols()
)));
}
let mut result = c.to_owned();
if beta != F::one() {
result.map_inplace(|v| *v *= beta);
}
let ab = a.dot(b);
result.zip_mut_with(&ab, |c_ij, &ab_ij| *c_ij += alpha * ab_ij);
Ok(result)
}
/// Plain matrix product `A * B`.
///
/// # Errors
/// Returns [`LinalgError::ShapeError`] when the inner dimensions disagree
/// (`a.ncols() != b.nrows()`).
#[allow(dead_code)]
pub fn matmul<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + 'static,
{
    let (inner_a, inner_b) = (a.ncols(), b.nrows());
    if inner_a == inner_b {
        return Ok(a.dot(b));
    }
    Err(LinalgError::ShapeError(format!(
        "Matrix dimensions not compatible for multiplication: a.ncols ({}) != b.nrows ({})",
        inner_a, inner_b
    )))
}
/// Solves the square linear system `A x = b` by Gaussian elimination with partial pivoting.
///
/// The singularity test is relative to the magnitude of `A`: a pivot is rejected when
/// `|pivot| <= eps * max|a_ij|`. Uniformly rescaling `A` (e.g. by `1e-20`) therefore
/// does not change whether the system is reported singular.
///
/// # Errors
/// - [`LinalgError::ShapeError`] if `a` is not square or `b.len() != a.nrows()`.
/// - [`LinalgError::SingularMatrixError`] if `A` is singular to working precision.
#[allow(dead_code)]
pub fn solve<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + 'static,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square for solve, got shape {:?}",
            a.shape()
        )));
    }
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix rows ({}) must match vector length ({}) for solve",
            a.nrows(),
            b.len()
        )));
    }
    let n = a.nrows();

    // Augmented matrix [A | b]; column `n` holds the right-hand side.
    let mut aug = Array2::<F>::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Pivot threshold relative to the largest |a_ij|. `Float::max` skips NaN. For
    // non-finite input (where eps * scale would be inf), only exact-zero pivots are
    // rejected, so NaN/inf still propagate into the result as before.
    let scale = a
        .iter()
        .fold(F::zero(), |m, &v| Float::max(m, Float::abs(v)));
    let tol = if Float::is_finite(scale) {
        F::epsilon() * scale
    } else {
        F::zero()
    };

    // Forward elimination to upper-triangular form.
    for i in 0..n {
        // Partial pivoting: the first row at or below `i` with the largest |entry| in column i.
        let (max_row, max_val) = ((i + 1)..n).fold(
            (i, Float::abs(aug[[i, i]])),
            |(best_row, best_val), r| {
                let v = Float::abs(aug[[r, i]]);
                if v > best_val {
                    (r, v)
                } else {
                    (best_row, best_val)
                }
            },
        );
        if max_val <= tol {
            return Err(LinalgError::SingularMatrixError(
                "Matrix is singular or nearly singular".to_string(),
            ));
        }
        if max_row != i {
            // Columns left of `i` are already zero in rows i.., so swapping them is a no-op.
            for j in i..=n {
                aug.swap([i, j], [max_row, j]);
            }
        }

        let pivot = aug[[i, i]];
        for r in (i + 1)..n {
            let factor = aug[[r, i]] / pivot;
            aug[[r, i]] = F::zero();
            for k in (i + 1)..=n {
                // Read the pivot-row value first: `aug[..] -= aug[..]` would need
                // overlapping mutable and shared borrows for a generic `F`.
                let upper = aug[[i, k]];
                aug[[r, k]] -= factor * upper;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum -= aug[[i, j]] * x[j];
        }
        x[i] = sum / aug[[i, i]];
    }
    Ok(x)
}
/// Inverts a square matrix by Gauss-Jordan elimination with partial pivoting on `[A | I]`.
///
/// The singularity test is relative to the magnitude of `A`: a pivot is rejected when
/// `|pivot| <= eps * max|a_ij|`. Uniformly rescaling `A` therefore does not change
/// whether it is reported singular.
///
/// # Errors
/// - [`LinalgError::ShapeError`] if `a` is not square.
/// - [`LinalgError::SingularMatrixError`] if `A` is singular to working precision.
#[allow(dead_code)]
pub fn inv<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + 'static,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square for inverse, got shape {:?}",
            a.shape()
        )));
    }
    let n = a.nrows();
    let width = 2 * n;

    // Augmented matrix [A | I]; when elimination finishes, the right half holds A^-1.
    let mut aug = Array2::<F>::zeros((n, width));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, i + n]] = F::one();
    }

    // Pivot threshold relative to the largest |a_ij|. `Float::max` skips NaN. For
    // non-finite input, only exact-zero pivots are rejected.
    let scale = a
        .iter()
        .fold(F::zero(), |m, &v| Float::max(m, Float::abs(v)));
    let tol = if Float::is_finite(scale) {
        F::epsilon() * scale
    } else {
        F::zero()
    };

    for i in 0..n {
        // Partial pivoting: the first row at or below `i` with the largest |entry| in column i.
        let (max_row, max_val) = ((i + 1)..n).fold(
            (i, Float::abs(aug[[i, i]])),
            |(best_row, best_val), r| {
                let v = Float::abs(aug[[r, i]]);
                if v > best_val {
                    (r, v)
                } else {
                    (best_row, best_val)
                }
            },
        );
        if max_val <= tol {
            return Err(LinalgError::SingularMatrixError(
                "Matrix is singular or nearly singular".to_string(),
            ));
        }

        // Columns left of `i` are already reduced and are never read again: later steps
        // only look at columns >= their own pivot, and the result is the right half.
        // So every row operation below starts at column `i`.
        if max_row != i {
            for j in i..width {
                aug.swap([i, j], [max_row, j]);
            }
        }

        let pivot = aug[[i, i]];
        for j in i..width {
            aug[[i, j]] /= pivot;
        }

        for r in 0..n {
            if r == i {
                continue;
            }
            let factor = aug[[r, i]];
            for k in i..width {
                // Read the pivot-row value first to avoid overlapping borrows of `aug`.
                let pivot_row_val = aug[[i, k]];
                aug[[r, k]] -= factor * pivot_row_val;
            }
        }
    }

    Ok(Array2::from_shape_fn((n, n), |(r, c)| aug[[r, c + n]]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Array1, Array2};

    #[test]
    fn test_dot() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![4.0, 5.0, 6.0];
        let result = dot(&x.view(), &y.view()).expect("Operation failed");
        assert_relative_eq!(result, 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_dot_length_mismatch() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![4.0, 5.0];
        assert!(matches!(
            dot(&x.view(), &y.view()),
            Err(LinalgError::ShapeError(_))
        ));
    }

    #[test]
    fn test_norm() {
        let x = array![3.0, 4.0];
        let result = norm(&x.view()).expect("Operation failed");
        assert_relative_eq!(result, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_norm_empty() {
        let x = Array1::<f64>::zeros(0);
        assert!(matches!(
            norm(&x.view()),
            Err(LinalgError::InvalidInputError(_))
        ));
    }

    #[test]
    fn test_gemv() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let x = array![2.0, 3.0];
        let y = Array1::<f64>::zeros(2);
        let result = gemv(1.0, &a.view(), &x.view(), 0.0, &y.view()).expect("Operation failed");
        assert_relative_eq!(result[0], 8.0, epsilon = 1e-10);
        assert_relative_eq!(result[1], 18.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gemv_with_alpha_and_beta() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let x = array![1.0, 1.0];
        let y = array![1.0, -1.0];
        // 2 * y + 3 * (A x) = 2 * [1, -1] + 3 * [3, 7]
        let result = gemv(3.0, &a.view(), &x.view(), 2.0, &y.view()).expect("Operation failed");
        assert_relative_eq!(result[0], 11.0, epsilon = 1e-10);
        assert_relative_eq!(result[1], 19.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gemv_shape_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let x = array![1.0, 2.0];
        let y = Array1::<f64>::zeros(2);
        let bad_x = array![1.0, 2.0, 3.0];
        let bad_y = Array1::<f64>::zeros(3);
        assert!(matches!(
            gemv(1.0, &a.view(), &bad_x.view(), 0.0, &y.view()),
            Err(LinalgError::ShapeError(_))
        ));
        assert!(matches!(
            gemv(1.0, &a.view(), &x.view(), 0.0, &bad_y.view()),
            Err(LinalgError::ShapeError(_))
        ));
    }

    #[test]
    fn test_gemm() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let c = Array2::<f64>::zeros((2, 2));
        let result = gemm(1.0, &a.view(), &b.view(), 0.0, &c.view()).expect("Operation failed");
        assert_relative_eq!(result[[0, 0]], 19.0, epsilon = 1e-10);
        assert_relative_eq!(result[[0, 1]], 22.0, epsilon = 1e-10);
        assert_relative_eq!(result[[1, 0]], 43.0, epsilon = 1e-10);
        assert_relative_eq!(result[[1, 1]], 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gemm_with_alpha_and_beta() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let c = Array2::<f64>::eye(2);
        // 2 * (A B) + 1 * I
        let result = gemm(2.0, &a.view(), &b.view(), 1.0, &c.view()).expect("Operation failed");
        assert_relative_eq!(result[[0, 0]], 39.0, epsilon = 1e-10);
        assert_relative_eq!(result[[0, 1]], 44.0, epsilon = 1e-10);
        assert_relative_eq!(result[[1, 0]], 86.0, epsilon = 1e-10);
        assert_relative_eq!(result[[1, 1]], 101.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gemm_shape_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let b_bad = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let c = Array2::<f64>::zeros((2, 2));
        let c_bad = Array2::<f64>::zeros((3, 3));
        assert!(matches!(
            gemm(1.0, &a.view(), &b_bad.view(), 0.0, &c.view()),
            Err(LinalgError::ShapeError(_))
        ));
        assert!(matches!(
            gemm(1.0, &a.view(), &b.view(), 0.0, &c_bad.view()),
            Err(LinalgError::ShapeError(_))
        ));
    }

    #[test]
    fn test_matmul() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let result = matmul(&a.view(), &b.view()).expect("Operation failed");
        assert_relative_eq!(result[[0, 0]], 19.0, epsilon = 1e-10);
        assert_relative_eq!(result[[0, 1]], 22.0, epsilon = 1e-10);
        assert_relative_eq!(result[[1, 0]], 43.0, epsilon = 1e-10);
        assert_relative_eq!(result[[1, 1]], 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_matmul_shape_mismatch() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(matches!(
            matmul(&a.view(), &b.view()),
            Err(LinalgError::ShapeError(_))
        ));
    }

    #[test]
    fn test_solve() {
        let a = array![[3.0, 1.0], [1.0, 2.0]];
        let b = array![9.0, 8.0];
        let x = solve(&a.view(), &b.view()).expect("Operation failed");
        assert_relative_eq!(x[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(x[1], 3.0, epsilon = 1e-10);
        let b_check = a.dot(&x);
        assert_relative_eq!(b_check[0], b[0], epsilon = 1e-10);
        assert_relative_eq!(b_check[1], b[1], epsilon = 1e-10);
    }

    #[test]
    fn test_solve_requires_pivoting() {
        // Zero on the leading diagonal forces a row swap.
        let a = array![[0.0, 1.0], [1.0, 0.0]];
        let b = array![2.0, 3.0];
        let x = solve(&a.view(), &b.view()).expect("Operation failed");
        assert_relative_eq!(x[0], 3.0, epsilon = 1e-10);
        assert_relative_eq!(x[1], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_solve_errors() {
        let singular = array![[1.0, 2.0], [2.0, 4.0]];
        let rhs = array![1.0, 2.0];
        assert!(matches!(
            solve(&singular.view(), &rhs.view()),
            Err(LinalgError::SingularMatrixError(_))
        ));

        let non_square = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(matches!(
            solve(&non_square.view(), &rhs.view()),
            Err(LinalgError::ShapeError(_))
        ));

        let a = array![[3.0, 1.0], [1.0, 2.0]];
        let rhs_bad = array![1.0, 2.0, 3.0];
        assert!(matches!(
            solve(&a.view(), &rhs_bad.view()),
            Err(LinalgError::ShapeError(_))
        ));
    }

    #[test]
    fn test_inv() {
        let a = array![[4.0, 7.0], [2.0, 6.0]];
        let a_inv = inv(&a.view()).expect("Operation failed");
        assert_relative_eq!(a_inv[[0, 0]], 0.6, epsilon = 1e-10);
        assert_relative_eq!(a_inv[[0, 1]], -0.7, epsilon = 1e-10);
        assert_relative_eq!(a_inv[[1, 0]], -0.2, epsilon = 1e-10);
        assert_relative_eq!(a_inv[[1, 1]], 0.4, epsilon = 1e-10);
        let identity = a.dot(&a_inv);
        assert_relative_eq!(identity[[0, 0]], 1.0, epsilon = 1e-10);
        assert_relative_eq!(identity[[0, 1]], 0.0, epsilon = 1e-10);
        assert_relative_eq!(identity[[1, 0]], 0.0, epsilon = 1e-10);
        assert_relative_eq!(identity[[1, 1]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_inv_errors() {
        let singular = array![[1.0, 2.0], [2.0, 4.0]];
        assert!(matches!(
            inv(&singular.view()),
            Err(LinalgError::SingularMatrixError(_))
        ));

        let non_square = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(matches!(
            inv(&non_square.view()),
            Err(LinalgError::ShapeError(_))
        ));
    }
}