use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
/// Computes the dot (inner) product `Σ x[i] * y[i]` of two real vectors.
///
/// Products are accumulated in index order, starting from zero, so an empty
/// pair of vectors yields zero.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
#[allow(dead_code)]
pub fn dot<F>(x: &ArrayView1<F>, y: &ArrayView1<F>) -> F
where
    F: Float + NumAssign,
{
    assert!(
        x.len() == y.len(),
        "Vectors must have the same length for dot product"
    );
    x.iter()
        .zip(y.iter())
        .fold(F::zero(), |acc, (&xi, &yi)| acc + xi * yi)
}
/// Computes the Euclidean (L2) norm `sqrt(Σ x[i]²)` of a vector.
///
/// Uses the scaled sum-of-squares recurrence from reference BLAS `dnrm2` /
/// LAPACK `dlassq`. The vector is implicitly divided by its running maximum
/// magnitude `scale`, and the result is `scale * sqrt(ssq)`. This keeps
/// intermediate squares from overflowing to infinity or underflowing to zero.
/// For example, a naive `Σ x[i]²` returns `inf` for `[1e200, 1e200]` and `0`
/// for `[1e-200, 1e-200]`.
///
/// Non-finite input follows the same outcome as the naive formula:
/// - any NaN element yields NaN;
/// - otherwise, any infinite element yields `+inf`.
///
/// An empty or all-zero vector yields zero.
#[allow(dead_code)]
pub fn nrm2<F>(x: &ArrayView1<F>) -> F
where
    F: Float + NumAssign,
{
    let mut scale = F::zero();
    let mut ssq = F::one();
    let mut saw_infinity = false;
    for &v in x.iter() {
        let a = v.abs();
        if a.is_nan() {
            return F::nan();
        }
        if a.is_infinite() {
            // Keep scanning: a NaN later in the vector must still win.
            saw_infinity = true;
            continue;
        }
        if a == F::zero() {
            // Zeros add nothing, and skipping them avoids 0/0 while scale is still zero.
            continue;
        }
        if scale < a {
            // New maximum: rescale the accumulated sum to the new reference magnitude.
            let r = scale / a;
            ssq = F::one() + ssq * r * r;
            scale = a;
        } else {
            let r = a / scale;
            ssq += r * r;
        }
    }
    if saw_infinity {
        F::infinity()
    } else {
        scale * ssq.sqrt()
    }
}
/// Returns the sum of absolute values `Σ |x[i]|` of a real vector.
///
/// Terms are added in index order, starting from zero, so an empty vector
/// yields zero.
#[allow(dead_code)]
pub fn asum<F>(x: &ArrayView1<F>) -> F
where
    F: Float + NumAssign,
{
    x.iter()
        .map(|&v| v.abs())
        .fold(F::zero(), |total, magnitude| total + magnitude)
}
/// Returns the index of the element with the largest absolute value.
///
/// Only a strictly greater magnitude replaces the current best, so ties
/// resolve to the lowest index. Because the comparison is `>`, a NaN element
/// never replaces the current best, and a NaN at index 0 is never displaced.
///
/// # Panics
///
/// Panics if `x` is empty.
#[allow(dead_code)]
pub fn iamax<F>(x: &ArrayView1<F>) -> usize
where
    F: Float + NumAssign,
{
    assert!(!x.is_empty(), "Cannot find maximum of an empty vector");
    let (winner, _) = x.iter().enumerate().skip(1).fold(
        (0usize, x[0].abs()),
        |(best_idx, best_mag), (idx, &v)| {
            let mag = v.abs();
            if mag > best_mag {
                (idx, mag)
            } else {
                (best_idx, best_mag)
            }
        },
    );
    winner
}
/// Performs the in-place update `y ← alpha · x + y`, element by element.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
#[allow(dead_code)]
pub fn axpy<F>(alpha: F, x: &ArrayView1<F>, y: &mut Array1<F>)
where
    F: Float + NumAssign,
{
    assert!(
        x.len() == y.len(),
        "Vectors must have the same length for axpy operation"
    );
    for (dst, &src) in y.iter_mut().zip(x.iter()) {
        *dst += alpha * src;
    }
}
/// General matrix-vector product: `y ← alpha · A · x + beta · y`.
///
/// The scalar shortcuts follow reference BLAS `?gemv` semantics:
/// - `beta == 0`: `y` is overwritten with zeros before accumulation, so its
///   input contents (including NaN or infinity) cannot leak into the result.
///   Multiplying by zero instead would keep `NaN * 0 == NaN`.
/// - `beta == 1`: `y` is left untouched before accumulation.
/// - `alpha == 0`: the product `A · x` is skipped once `y` has been scaled.
///
/// # Panics
///
/// Panics if `a.ncols() != x.len()` or `a.nrows() != y.len()`.
#[allow(dead_code)]
pub fn gemv<F>(alpha: F, a: &ArrayView2<F>, x: &ArrayView1<F>, beta: F, y: &mut Array1<F>)
where
    F: Float + NumAssign,
{
    if a.ncols() != x.len() || a.nrows() != y.len() {
        panic!("Incompatible dimensions for matrix-vector multiplication");
    }
    if beta == F::zero() {
        y.iter_mut().for_each(|yi| *yi = F::zero());
    } else if beta != F::one() {
        y.iter_mut().for_each(|yi| *yi *= beta);
    }
    if alpha == F::zero() {
        return;
    }
    // Row i of A dotted with x, accumulated in column order, scaled into y[i].
    for (yi, row) in y.iter_mut().zip(a.outer_iter()) {
        let sum = row
            .iter()
            .zip(x.iter())
            .fold(F::zero(), |acc, (&aij, &xj)| acc + aij * xj);
        *yi += alpha * sum;
    }
}
/// General matrix-matrix product: `C ← alpha · A · B + beta · C`.
///
/// The scalar shortcuts follow reference BLAS `?gemm` semantics:
/// - `beta == 0`: `C` is overwritten with zeros before accumulation, so its
///   input contents (including NaN or infinity) cannot leak into the result.
///   Multiplying by zero instead would keep `NaN * 0 == NaN`.
/// - `beta == 1`: `C` is left untouched before accumulation.
/// - `alpha == 0`: the product `A · B` is skipped once `C` has been scaled.
///
/// # Panics
///
/// Panics unless `a` is `m × k`, `b` is `k × n`, and `c` is `m × n`.
#[allow(dead_code)]
pub fn gemm<F>(alpha: F, a: &ArrayView2<F>, b: &ArrayView2<F>, beta: F, c: &mut Array2<F>)
where
    F: Float + NumAssign,
{
    if a.ncols() != b.nrows() || a.nrows() != c.nrows() || b.ncols() != c.ncols() {
        panic!("Incompatible dimensions for matrix-matrix multiplication");
    }
    if beta == F::zero() {
        c.iter_mut().for_each(|cij| *cij = F::zero());
    } else if beta != F::one() {
        c.iter_mut().for_each(|cij| *cij *= beta);
    }
    if alpha == F::zero() {
        return;
    }
    let inner = a.ncols();
    for i in 0..a.nrows() {
        for j in 0..b.ncols() {
            let mut sum = F::zero();
            for k in 0..inner {
                sum += a[[i, k]] * b[[k, j]];
            }
            c[[i, j]] += alpha * sum;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Array1, Array2};

    #[test]
    fn test_dot() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6
        let got = dot(&x.view(), &y.view());
        assert_relative_eq!(got, 32.0);
    }

    #[test]
    fn test_nrm2() {
        // Classic 3-4-5 right triangle.
        let x = array![3.0, 4.0];
        let got = nrm2(&x.view());
        assert_relative_eq!(got, 5.0);
    }

    #[test]
    fn test_asum() {
        let x = array![1.0, -2.0, 3.0];
        let got = asum(&x.view());
        assert_relative_eq!(got, 6.0);
    }

    #[test]
    fn test_iamax() {
        // |-5| is the largest magnitude, at index 1.
        let x = array![1.0, -5.0, 3.0];
        let got = iamax(&x.view());
        assert_eq!(got, 1);
    }

    #[test]
    fn test_axpy() {
        let x = array![1.0, 2.0, 3.0];
        let mut y = array![4.0, 5.0, 6.0];
        axpy(2.0, &x.view(), &mut y);
        for (got, want) in y.iter().zip([6.0, 9.0, 12.0]) {
            assert_relative_eq!(*got, want);
        }
    }

    #[test]
    fn test_gemv() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let x = array![2.0, 3.0];
        let mut y = Array1::zeros(2);
        gemv(1.0, &a.view(), &x.view(), 0.0, &mut y);
        for (got, want) in y.iter().zip([8.0, 18.0]) {
            assert_relative_eq!(*got, want);
        }
    }

    #[test]
    fn test_gemm() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let mut c = Array2::zeros((2, 2));
        gemm(1.0, &a.view(), &b.view(), 0.0, &mut c);
        let expected = [[19.0, 22.0], [43.0, 50.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_relative_eq!(c[[i, j]], want);
            }
        }
    }
}