use crate::error::{CoreError, Result};
use crate::tensor::Tensor;
use crate::{Float, Scalar};
/// Inner product of two 1-D tensors.
///
/// # Errors
///
/// Forwards the error produced by `check_vectors`: `CoreError::InvalidArgument`
/// when either tensor is not 1-D, `CoreError::DimensionMismatch` when the
/// element counts differ.
pub fn dot<T: Scalar>(x: &Tensor<T>, y: &Tensor<T>) -> Result<T> {
    check_vectors(x, y, "dot").map(|()| dot_slice(x.as_slice(), y.as_slice()))
}
/// Sum of element-wise products of `a` and `b`.
///
/// With the `simd` feature enabled, `f64` and `f32` inputs are routed to the
/// kernels in `crate::simd`. Every other element type (and every build without
/// the feature) uses the sequential loop at the bottom, which stops at the end
/// of the shorter slice.
fn dot_slice<T: Scalar>(a: &[T], b: &[T]) -> T {
    #[cfg(feature = "simd")]
    {
        use crate::simd;
        use std::any::TypeId;

        let elem = TypeId::of::<T>();
        if elem == TypeId::of::<f64>() {
            // SAFETY: the `TypeId` comparison proves `T` is exactly `f64`; the
            // `simd` helpers are assumed to only reinterpret between `T` and
            // `f64` — NOTE(review): confirm against `crate::simd`.
            let total =
                unsafe { simd::f64_ops::dot_f64(simd::slice_as_f64(a), simd::slice_as_f64(b)) };
            return unsafe { simd::f64_to_t(total) };
        }
        if elem == TypeId::of::<f32>() {
            // SAFETY: as above, with `T` proven to be exactly `f32`.
            let total =
                unsafe { simd::f32_ops::dot_f32(simd::slice_as_f32(a), simd::slice_as_f32(b)) };
            return unsafe { simd::f32_to_t(total) };
        }
    }

    // Generic path: accumulate left to right, starting from zero.
    let mut acc = T::zero();
    for (&lhs, &rhs) in a.iter().zip(b.iter()) {
        acc = acc + lhs * rhs;
    }
    acc
}
/// In-place update `y += alpha * x` for two 1-D tensors.
///
/// # Errors
///
/// Forwards the error produced by `check_vectors`: `CoreError::InvalidArgument`
/// when either tensor is not 1-D, `CoreError::DimensionMismatch` when the
/// element counts differ. `y` is not modified in that case.
pub fn axpy<T: Scalar>(alpha: T, x: &Tensor<T>, y: &mut Tensor<T>) -> Result<()> {
    check_vectors(x, y, "axpy").map(|()| axpy_slice(alpha, x.as_slice(), y.as_mut_slice()))
}
/// Adds `alpha * x[i]` to `y[i]` for every index both slices share.
///
/// With the `simd` feature enabled, `f64` and `f32` inputs are routed to the
/// kernels in `crate::simd`; everything else uses the element-wise loop.
fn axpy_slice<T: Scalar>(alpha: T, x: &[T], y: &mut [T]) {
    #[cfg(feature = "simd")]
    {
        use crate::simd;
        use std::any::TypeId;

        let elem = TypeId::of::<T>();
        if elem == TypeId::of::<f64>() {
            // SAFETY: the `TypeId` comparison proves `T` is exactly `f64`; the
            // `simd` helpers are assumed to only reinterpret between `T` and
            // `f64` — NOTE(review): confirm against `crate::simd`.
            unsafe {
                simd::f64_ops::axpy_f64(
                    simd::t_to_f64(alpha),
                    simd::slice_as_f64(x),
                    simd::slice_as_f64_mut(y),
                );
            }
            return;
        }
        if elem == TypeId::of::<f32>() {
            // SAFETY: as above, with `T` proven to be exactly `f32`.
            unsafe {
                simd::f32_ops::axpy_f32(
                    simd::t_to_f32(alpha),
                    simd::slice_as_f32(x),
                    simd::slice_as_f32_mut(y),
                );
            }
            return;
        }
    }

    // Generic path: pairs stop at the end of the shorter slice.
    y.iter_mut()
        .zip(x.iter())
        .for_each(|(target, &source)| *target += alpha * source);
}
/// Euclidean norm of a 1-D tensor: the square root of the sum of squares.
///
/// The generic path squares the elements directly without any rescaling, so
/// intermediate squares can overflow or underflow for extreme magnitudes.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `x` is not 1-D.
pub fn nrm2<T: Float>(x: &Tensor<T>) -> Result<T> {
    check_vector(x, "nrm2")?;

    #[cfg(feature = "simd")]
    {
        use crate::simd;
        use std::any::TypeId;

        let elem = TypeId::of::<T>();
        if elem == TypeId::of::<f64>() {
            // SAFETY: the `TypeId` comparison proves `T` is exactly `f64`; the
            // `simd` helpers are assumed to only reinterpret between `T` and
            // `f64` — NOTE(review): confirm against `crate::simd`.
            let norm =
                unsafe { simd::f64_ops::sum_sq_f64(simd::slice_as_f64(x.as_slice())).sqrt() };
            return Ok(unsafe { simd::f64_to_t(norm) });
        }
        if elem == TypeId::of::<f32>() {
            // SAFETY: as above, with `T` proven to be exactly `f32`.
            let norm =
                unsafe { simd::f32_ops::sum_sq_f32(simd::slice_as_f32(x.as_slice())).sqrt() };
            return Ok(unsafe { simd::f32_to_t(norm) });
        }
    }

    // Generic path: accumulate squares left to right, then take the root.
    let mut squares = T::zero();
    for &value in x.as_slice() {
        squares = squares + value * value;
    }
    Ok(squares.sqrt())
}
/// Sum of the absolute values of a 1-D tensor.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `x` is not 1-D.
pub fn asum<T: Float>(x: &Tensor<T>) -> Result<T> {
    check_vector(x, "asum")?;

    #[cfg(feature = "simd")]
    {
        use crate::simd;
        use std::any::TypeId;

        let elem = TypeId::of::<T>();
        if elem == TypeId::of::<f64>() {
            // SAFETY: the `TypeId` comparison proves `T` is exactly `f64`; the
            // `simd` helpers are assumed to only reinterpret between `T` and
            // `f64` — NOTE(review): confirm against `crate::simd`.
            let total = unsafe { simd::f64_ops::asum_f64(simd::slice_as_f64(x.as_slice())) };
            return Ok(unsafe { simd::f64_to_t(total) });
        }
        if elem == TypeId::of::<f32>() {
            // SAFETY: as above, with `T` proven to be exactly `f32`.
            let total = unsafe { simd::f32_ops::asum_f32(simd::slice_as_f32(x.as_slice())) };
            return Ok(unsafe { simd::f32_to_t(total) });
        }
    }

    // Generic path: accumulate magnitudes left to right.
    let mut total = T::zero();
    for &value in x.as_slice() {
        total = total + value.abs();
    }
    Ok(total)
}
/// Multiplies every element of a 1-D tensor by `alpha`, in place.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `x` is not 1-D; `x` is left
/// untouched in that case.
pub fn scal<T: Scalar>(alpha: T, x: &mut Tensor<T>) -> Result<()> {
    check_vector(x, "scal")?;

    #[cfg(feature = "simd")]
    {
        use crate::simd;
        use std::any::TypeId;

        let elem = TypeId::of::<T>();
        if elem == TypeId::of::<f64>() {
            // SAFETY: the `TypeId` comparison proves `T` is exactly `f64`; the
            // `simd` helpers are assumed to only reinterpret between `T` and
            // `f64` — NOTE(review): confirm against `crate::simd`.
            unsafe {
                simd::f64_ops::scal_f64(
                    simd::t_to_f64(alpha),
                    simd::slice_as_f64_mut(x.as_mut_slice()),
                );
            }
            return Ok(());
        }
        if elem == TypeId::of::<f32>() {
            // SAFETY: as above, with `T` proven to be exactly `f32`.
            unsafe {
                simd::f32_ops::scal_f32(
                    simd::t_to_f32(alpha),
                    simd::slice_as_f32_mut(x.as_mut_slice()),
                );
            }
            return Ok(());
        }
    }

    // Generic path: plain element-wise multiply.
    x.as_mut_slice()
        .iter_mut()
        .for_each(|element| *element *= alpha);
    Ok(())
}
/// Index of the element with the largest absolute value.
///
/// Returns `Ok(None)` for an empty tensor. Ties resolve to the earliest index,
/// because only a strictly larger magnitude replaces the current best.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `x` is not 1-D.
pub fn iamax<T: Float>(x: &Tensor<T>) -> Result<Option<usize>> {
    check_vector(x, "iamax")?;
    if x.is_empty() {
        return Ok(None);
    }

    let data = x.as_slice();
    // Seed with the first element, then scan the rest carrying (index, magnitude).
    let seed = (0_usize, data[0].abs());
    let (winner, _) = data
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(best_idx, best_abs), (idx, &value)| {
            let magnitude = value.abs();
            if magnitude > best_abs {
                (idx, magnitude)
            } else {
                (best_idx, best_abs)
            }
        });
    Ok(Some(winner))
}
/// General matrix-vector product: `y = alpha * a * x + beta * y`.
///
/// `a` must be a 2-D tensor of shape `[m, n]`, read row by row from
/// `as_slice()`; `x` must be 1-D with `n` elements and `y` 1-D with `m`
/// elements.
///
/// When `beta` is zero the previous contents of `y` are ignored rather than
/// multiplied by zero, so stale `NaN`/`Inf` values in `y` cannot leak into the
/// result. This follows the BLAS convention and matches how `gemm` in this
/// module treats `beta == 0`.
///
/// # Errors
///
/// - `CoreError::InvalidArgument` when `a` is not 2-D, or `x` / `y` is not 1-D.
/// - `CoreError::DimensionMismatch` when `x` does not have `n` elements or `y`
///   does not have `m` elements.
#[allow(clippy::many_single_char_names)]
pub fn gemv<T: Scalar>(
    alpha: T,
    a: &Tensor<T>,
    x: &Tensor<T>,
    beta: T,
    y: &mut Tensor<T>,
) -> Result<()> {
    if a.ndim() != 2 {
        return Err(CoreError::InvalidArgument {
            reason: "gemv: `a` must be a 2-D tensor (matrix)",
        });
    }
    if x.ndim() != 1 {
        return Err(CoreError::InvalidArgument {
            reason: "gemv: `x` must be a 1-D tensor (vector)",
        });
    }
    if y.ndim() != 1 {
        return Err(CoreError::InvalidArgument {
            reason: "gemv: `y` must be a 1-D tensor (vector)",
        });
    }

    let m = a.shape()[0];
    let n = a.shape()[1];
    if x.numel() != n {
        return Err(CoreError::DimensionMismatch {
            expected: vec![n],
            got: x.shape().to_vec(),
        });
    }
    if y.numel() != m {
        return Err(CoreError::DimensionMismatch {
            expected: vec![m],
            got: y.shape().to_vec(),
        });
    }

    let a_data = a.as_slice();
    let x_data = x.as_slice();
    let y_data = y.as_mut_slice();

    // With beta == 0 the old y must be overwritten, not scaled: 0 * NaN is NaN.
    let overwrite = beta == T::zero();

    // `take(m)` keeps the row index within the `m` rows of `a`.
    for (i, yi) in y_data.iter_mut().enumerate().take(m) {
        let row_offset = i * n;
        let row = &a_data[row_offset..row_offset + n];
        let sum = dot_slice(row, x_data);
        *yi = if overwrite {
            alpha * sum
        } else {
            alpha * sum + beta * *yi
        };
    }
    Ok(())
}
/// General matrix-matrix product: `c = alpha * a * b + beta * c`.
///
/// All three tensors must be 2-D. With `a` of shape `[m, k]`, `b` must have
/// shape `[k, n]` and `c` shape `[m, n]`; the slices returned by `as_slice()`
/// are indexed row by row.
///
/// `c` is scaled by `beta` first (overwritten with zeros when `beta == 0`,
/// left untouched when `beta == 1`), then the product is accumulated into it
/// tile by tile.
///
/// # Errors
///
/// - `CoreError::InvalidArgument` when any argument is not 2-D.
/// - `CoreError::DimensionMismatch` when `b` does not have `k` rows or `c` is
///   not `[m, n]`.
#[allow(clippy::many_single_char_names, clippy::too_many_lines)]
pub fn gemm<T: Scalar>(
alpha: T,
a: &Tensor<T>,
b: &Tensor<T>,
beta: T,
c: &mut Tensor<T>,
) -> Result<()> {
// Tile extents along the row (MC), inner (KC) and column (NC) dimensions.
const MC: usize = 64; const KC: usize = 256; const NC: usize = 256;
if a.ndim() != 2 || b.ndim() != 2 || c.ndim() != 2 {
return Err(CoreError::InvalidArgument {
reason: "gemm: all arguments must be 2-D tensors (matrices)",
});
}
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
if b.shape()[0] != k {
return Err(CoreError::DimensionMismatch {
expected: vec![k, n],
got: b.shape().to_vec(),
});
}
if c.shape()[0] != m || c.shape()[1] != n {
return Err(CoreError::DimensionMismatch {
expected: vec![m, n],
got: c.shape().to_vec(),
});
}
let a_data = a.as_slice();
let b_data = b.as_slice();
let c_data = c.as_mut_slice();
// Apply beta up front so the tiled loops below only ever accumulate into c.
// beta == 0 assigns zero instead of multiplying, which discards whatever c held.
if beta == T::zero() {
for v in c_data.iter_mut() {
*v = T::zero();
}
} else if beta != T::one() {
for v in c_data.iter_mut() {
*v *= beta;
}
}
// Tile loops: kb / mb / nb are the actual tile extents, shorter at the edges.
for pk in (0..k).step_by(KC) {
let kb = KC.min(k - pk);
for pi in (0..m).step_by(MC) {
let mb = MC.min(m - pi);
for pj in (0..n).step_by(NC) {
let nb = NC.min(n - pj);
// Fast path, compiled only for aarch64 with the `simd` feature, taken only for f64.
#[cfg(all(target_arch = "aarch64", feature = "simd"))]
{
use std::any::TypeId;
if TypeId::of::<T>() == TypeId::of::<f64>() {
// SAFETY (as far as visible here): T is f64 per the TypeId check, so the
// pointer casts keep the element type. Offsets are built from row indices
// below m, inner indices below k and column indices below n, which stay in
// bounds provided the slices hold m*k, k*n and m*n elements.
// NOTE(review): confirm the micro-kernels only touch their 8x4 / 4x4 tile
// over `kb` inner steps.
unsafe {
let a_f64 = a_data.as_ptr().cast::<f64>();
let b_f64 = b_data.as_ptr().cast::<f64>();
let c_f64 = c_data.as_mut_ptr().cast::<f64>();
let alpha_f64 = crate::simd::t_to_f64(alpha);
// nb and mb rounded down to multiples of the kernel width (4) and height (8).
let j4 = nb / 4 * 4;
let i8 = mb / 8 * 8;
// Row groups of 8.
for i in (0..i8).step_by(8) {
for j in (0..j4).step_by(4) {
let a_off = (pi + i) * k + pk;
let b_off = pk * n + (pj + j);
let c_off = (pi + i) * n + (pj + j);
crate::simd::neon_f64_ops::gemm_8x4_f64_neon(
a_f64.add(a_off),
b_f64.add(b_off),
c_f64.add(c_off),
alpha_f64,
// presumably the tile depth followed by the row strides of a, b and c —
// confirm against the kernel signature
kb, k, n, n,
);
}
// Columns left over after the 4-wide kernel: scalar update for these 8 rows.
if j4 < nb {
for ii in 0..8 {
let row_a = (pi + i + ii) * k + pk;
let row_c = (pi + i + ii) * n + pj + j4;
for p in 0..kb {
let scale_f64 = alpha_f64 * *a_f64.add(row_a + p);
for jj in 0..(nb - j4) {
let b_idx = (pk + p) * n + pj + j4 + jj;
*c_f64.add(row_c + jj) += scale_f64 * *b_f64.add(b_idx);
}
}
}
}
}
// Row groups of 4 taken from what the 8-row groups left over.
let i4_start = i8;
let i4_end = i4_start + (mb - i8) / 4 * 4;
for i in (i4_start..i4_end).step_by(4) {
for j in (0..j4).step_by(4) {
let a_off = (pi + i) * k + pk;
let b_off = pk * n + (pj + j);
let c_off = (pi + i) * n + (pj + j);
crate::simd::neon_f64_ops::gemm_4x4_f64_neon(
a_f64.add(a_off),
b_f64.add(b_off),
c_f64.add(c_off),
alpha_f64,
kb, k, n, n,
);
}
// Leftover columns for these 4 rows, same scalar update as above.
if j4 < nb {
for ii in 0..4 {
let row_a = (pi + i + ii) * k + pk;
let row_c = (pi + i + ii) * n + pj + j4;
for p in 0..kb {
let scale_f64 = alpha_f64 * *a_f64.add(row_a + p);
for jj in 0..(nb - j4) {
let b_idx = (pk + p) * n + pj + j4 + jj;
*c_f64.add(row_c + jj) += scale_f64 * *b_f64.add(b_idx);
}
}
}
}
}
// Rows left over after the 8- and 4-row groups: bounds-checked axpy across
// the full tile width.
for i in i4_end..mb {
let row_a = (pi + i) * k + pk;
let row_c = (pi + i) * n + pj;
for p in 0..kb {
let scale = alpha * a_data[row_a + p];
let b_off2 = (pk + p) * n + pj;
let b_row = &b_data[b_off2..b_off2 + nb];
let c_slice = &mut c_data[row_c..row_c + nb];
axpy_slice(scale, b_row, c_slice);
}
}
}
// This tile is fully handled; skip the generic fallback below.
continue;
}
}
// Generic path: for each row i and inner index p of the tile,
// c[i, pj..pj+nb] += (alpha * a[i, p]) * b[p, pj..pj+nb].
for i in 0..mb {
let row_a = (pi + i) * k + pk;
let row_c = (pi + i) * n + pj;
for p in 0..kb {
let scale = alpha * a_data[row_a + p];
let b_off = (pk + p) * n + pj;
let b_row = &b_data[b_off..b_off + nb];
let c_slice = &mut c_data[row_c..row_c + nb];
axpy_slice(scale, b_row, c_slice);
}
}
}
}
}
Ok(())
}
impl<T: Scalar> Tensor<T> {
    /// Matrix-vector product `self * x`, returned as a new tensor.
    ///
    /// The output has one element per entry of the first dimension of `self`
    /// (zero elements when `self` has no dimensions), starts zero-filled, and
    /// is computed by [`gemv`] with `alpha = 1` and `beta = 0`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`gemv`] reports for invalid ranks or shapes.
    pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
        let rows = self.shape().first().map_or(0, |&dim| dim);
        let mut product = Tensor::zeros(vec![rows]);
        gemv(T::one(), self, x, T::zero(), &mut product)?;
        Ok(product)
    }

    /// Matrix-matrix product `self * other`, returned as a new tensor.
    ///
    /// The output shape is taken from the first dimension of `self` and the
    /// second dimension of `other` (a missing dimension counts as 0); it
    /// starts zero-filled and is computed by [`gemm`] with `alpha = 1` and
    /// `beta = 0`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`gemm`] reports for invalid ranks or shapes.
    pub fn matmul(&self, other: &Tensor<T>) -> Result<Tensor<T>> {
        let rows = self.shape().first().map_or(0, |&dim| dim);
        let cols = other.shape().get(1).map_or(0, |&dim| dim);
        let mut product = Tensor::zeros(vec![rows, cols]);
        gemm(T::one(), self, other, T::zero(), &mut product)?;
        Ok(product)
    }

    /// Inner product with `other`; method form of the free function [`dot`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`dot`] reports for invalid ranks or lengths.
    pub fn dot(&self, other: &Tensor<T>) -> Result<T> {
        let rhs = other;
        dot(self, rhs)
    }
}
impl<T: Float> Tensor<T> {
    /// Euclidean norm of a 1-D tensor; method form of [`nrm2`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`nrm2`] reports (for example when `self` is not 1-D).
    pub fn norm(&self) -> Result<T> {
        nrm2(self)
    }

    /// Forwards to `crate::linalg::solve(self, b)`.
    ///
    /// NOTE(review): presumably solves the linear system with `self` as the
    /// coefficient matrix and `b` as the right-hand side — confirm against the
    /// `linalg` module.
    ///
    /// # Errors
    ///
    /// Returns whatever `crate::linalg::solve` reports.
    pub fn solve(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
        let rhs = b;
        crate::linalg::solve(self, rhs)
    }

    /// Forwards to `crate::linalg::inv(self)`.
    ///
    /// NOTE(review): presumably the matrix inverse — confirm against the
    /// `linalg` module.
    ///
    /// # Errors
    ///
    /// Returns whatever `crate::linalg::inv` reports.
    pub fn inv(&self) -> Result<Tensor<T>> {
        crate::linalg::inv(self)
    }

    /// Forwards to `crate::linalg::det(self)`.
    ///
    /// NOTE(review): presumably the determinant — confirm against the
    /// `linalg` module.
    ///
    /// # Errors
    ///
    /// Returns whatever `crate::linalg::det` reports.
    pub fn det(&self) -> Result<T> {
        crate::linalg::det(self)
    }

    /// Forwards to `crate::linalg::lstsq(self, b)`.
    ///
    /// NOTE(review): presumably a least-squares solve with `b` as the
    /// right-hand side — confirm against the `linalg` module.
    ///
    /// # Errors
    ///
    /// Returns whatever `crate::linalg::lstsq` reports.
    pub fn lstsq(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
        let rhs = b;
        crate::linalg::lstsq(self, rhs)
    }
}
/// Verifies that `x` is a 1-D tensor.
///
/// `name` selects the routine-specific error text; unrecognised names get a
/// generic message.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `x.ndim() != 1`.
fn check_vector<T: Scalar>(x: &Tensor<T>, name: &'static str) -> Result<()> {
    if x.ndim() == 1 {
        return Ok(());
    }

    // The messages are spelled out as literals because `reason` carries a
    // static string.
    let reason = match name {
        "nrm2" => "nrm2: expected a 1-D tensor",
        "asum" => "asum: expected a 1-D tensor",
        "scal" => "scal: expected a 1-D tensor",
        "iamax" => "iamax: expected a 1-D tensor",
        _ => "expected a 1-D tensor",
    };
    Err(CoreError::InvalidArgument { reason })
}
/// Verifies that `x` and `y` are both 1-D and hold the same number of elements.
///
/// `name` selects the routine-specific error text; unrecognised names get a
/// generic message.
///
/// # Errors
///
/// - `CoreError::InvalidArgument` when either tensor is not 1-D.
/// - `CoreError::DimensionMismatch` when the element counts differ; `expected`
///   carries the shape of `x` and `got` the shape of `y`.
fn check_vectors<T: Scalar>(x: &Tensor<T>, y: &Tensor<T>, name: &'static str) -> Result<()> {
    let both_one_dimensional = x.ndim() == 1 && y.ndim() == 1;
    if !both_one_dimensional {
        let reason = match name {
            "dot" => "dot: both arguments must be 1-D tensors",
            "axpy" => "axpy: both arguments must be 1-D tensors",
            _ => "both arguments must be 1-D tensors",
        };
        return Err(CoreError::InvalidArgument { reason });
    }

    if x.numel() == y.numel() {
        Ok(())
    } else {
        Err(CoreError::DimensionMismatch {
            expected: x.shape().to_vec(),
            got: y.shape().to_vec(),
        })
    }
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Builds a 1-D `f64` tensor holding a copy of `values`.
    fn vector(values: &[f64]) -> Tensor<f64> {
        let len = values.len();
        Tensor::from_vec(values.to_vec(), vec![len]).unwrap()
    }

    /// Builds a `rows x cols` `f64` tensor from `values`.
    fn matrix(values: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
        let shape = vec![rows, cols];
        Tensor::from_vec(values.to_vec(), shape).unwrap()
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    // ---- dot ----

    #[test]
    fn test_dot_basic() {
        let lhs = vector(&[1.0, 2.0, 3.0]);
        let rhs = vector(&[4.0, 5.0, 6.0]);
        assert_eq!(dot(&lhs, &rhs).unwrap(), 32.0);
    }

    #[test]
    fn test_dot_single() {
        let lhs = vector(&[3.0]);
        let rhs = vector(&[7.0]);
        assert_eq!(dot(&lhs, &rhs).unwrap(), 21.0);
    }

    #[test]
    fn test_dot_length_mismatch() {
        let short = vector(&[1.0, 2.0]);
        let long = vector(&[1.0, 2.0, 3.0]);
        assert!(dot(&short, &long).is_err());
    }

    #[test]
    fn test_dot_not_1d() {
        let not_a_vector = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let rhs = vector(&[1.0, 2.0]);
        assert!(dot(&not_a_vector, &rhs).is_err());
    }

    // ---- axpy ----

    #[test]
    fn test_axpy() {
        let source = vector(&[1.0, 2.0, 3.0]);
        let mut target = vector(&[10.0, 20.0, 30.0]);
        axpy(2.0, &source, &mut target).unwrap();
        assert_eq!(target.as_slice(), &[12.0, 24.0, 36.0]);
    }

    #[test]
    fn test_axpy_zero_alpha() {
        let source = vector(&[1.0, 2.0, 3.0]);
        let mut target = vector(&[10.0, 20.0, 30.0]);
        axpy(0.0, &source, &mut target).unwrap();
        assert_eq!(target.as_slice(), &[10.0, 20.0, 30.0]);
    }

    // ---- nrm2 / asum ----

    #[test]
    fn test_nrm2() {
        let v = vector(&[3.0, 4.0]);
        assert_close(nrm2(&v).unwrap(), 5.0, 1e-10);
    }

    #[test]
    fn test_nrm2_single() {
        let v = vector(&[-7.0]);
        assert_close(nrm2(&v).unwrap(), 7.0, 1e-10);
    }

    #[test]
    fn test_asum() {
        let v = vector(&[-1.0, 2.0, -3.0, 4.0]);
        assert_close(asum(&v).unwrap(), 10.0, 1e-10);
    }

    // ---- scal ----

    #[test]
    fn test_scal() {
        let mut v = vector(&[1.0, 2.0, 3.0]);
        scal(10.0, &mut v).unwrap();
        assert_eq!(v.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_scal_zero() {
        let mut v = vector(&[1.0, 2.0, 3.0]);
        scal(0.0, &mut v).unwrap();
        assert_eq!(v.as_slice(), &[0.0, 0.0, 0.0]);
    }

    // ---- iamax ----

    #[test]
    fn test_iamax() {
        let v = vector(&[1.0, -5.0, 3.0, -2.0]);
        assert_eq!(iamax(&v).unwrap(), Some(1));
    }

    #[test]
    fn test_iamax_first_is_max() {
        let v = vector(&[100.0, 1.0, 2.0]);
        assert_eq!(iamax(&v).unwrap(), Some(0));
    }

    #[test]
    fn test_iamax_empty() {
        let empty = Tensor::<f64>::zeros(vec![0]);
        assert_eq!(iamax(&empty).unwrap(), None);
    }

    // ---- gemv ----

    #[test]
    fn test_gemv_basic() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vector(&[5.0, 6.0]);
        let mut out = Tensor::<f64>::zeros(vec![2]);
        gemv(1.0, &a, &x, 0.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[17.0, 39.0]);
    }

    #[test]
    fn test_gemv_with_alpha_beta() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vector(&[1.0, 1.0]);
        let mut out = vector(&[10.0, 10.0]);
        gemv(2.0, &a, &x, 3.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[36.0, 44.0]);
    }

    #[test]
    fn test_gemv_rectangular() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let x = vector(&[1.0, 0.0, 1.0]);
        let mut out = Tensor::<f64>::zeros(vec![2]);
        gemv(1.0, &a, &x, 0.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[4.0, 10.0]);
    }

    #[test]
    fn test_gemv_dimension_mismatch() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let too_long = vector(&[1.0, 2.0, 3.0]);
        let mut out = Tensor::<f64>::zeros(vec![2]);
        assert!(gemv(1.0, &a, &too_long, 0.0, &mut out).is_err());
    }

    #[test]
    fn test_gemv_y_dimension_mismatch() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vector(&[1.0, 2.0]);
        let mut wrong_len = Tensor::<f64>::zeros(vec![3]);
        assert!(gemv(1.0, &a, &x, 0.0, &mut wrong_len).is_err());
    }

    // ---- gemm ----

    #[test]
    fn test_gemm_square() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = matrix(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let mut out = Tensor::<f64>::zeros(vec![2, 2]);
        gemm(1.0, &a, &b, 0.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_gemm_rectangular() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = matrix(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
        let mut out = Tensor::<f64>::zeros(vec![2, 2]);
        gemm(1.0, &a, &b, 0.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_gemm_with_alpha_beta() {
        let identity = matrix(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b = matrix(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let mut out = matrix(&[1.0, 1.0, 1.0, 1.0], 2, 2);
        gemm(2.0, &identity, &b, 3.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[13.0, 15.0, 17.0, 19.0]);
    }

    #[test]
    fn test_gemm_identity() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 3, 3);
        let eye = Tensor::<f64>::eye(3);
        let mut out = Tensor::<f64>::zeros(vec![3, 3]);
        gemm(1.0, &a, &eye, 0.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), a.as_slice());
    }

    #[test]
    fn test_gemm_single_element() {
        let a = matrix(&[3.0], 1, 1);
        let b = matrix(&[7.0], 1, 1);
        let mut out = Tensor::<f64>::zeros(vec![1, 1]);
        gemm(1.0, &a, &b, 0.0, &mut out).unwrap();
        assert_eq!(out.as_slice(), &[21.0]);
    }

    #[test]
    fn test_gemm_dimension_mismatch() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let wrong_rows = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let mut out = Tensor::<f64>::zeros(vec![2, 2]);
        assert!(gemm(1.0, &a, &wrong_rows, 0.0, &mut out).is_err());
    }

    #[test]
    fn test_gemm_c_shape_mismatch() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let mut wrong_shape = Tensor::<f64>::zeros(vec![3, 3]);
        assert!(gemm(1.0, &a, &b, 0.0, &mut wrong_shape).is_err());
    }

    // ---- Tensor method wrappers ----

    #[test]
    fn test_matvec() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let x = vector(&[5.0, 6.0]);
        let product = a.matvec(&x).unwrap();
        assert_eq!(product.as_slice(), &[17.0, 39.0]);
    }

    #[test]
    fn test_matmul() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = matrix(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let product = a.matmul(&b).unwrap();
        assert_eq!(product.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_tensor_dot() {
        let lhs = vector(&[1.0, 2.0, 3.0]);
        let rhs = vector(&[4.0, 5.0, 6.0]);
        assert_eq!(lhs.dot(&rhs).unwrap(), 32.0);
    }

    #[test]
    fn test_tensor_norm() {
        let v = vector(&[3.0, 4.0]);
        assert_close(v.norm().unwrap(), 5.0, 1e-10);
    }

    // ---- reference values ----

    #[test]
    fn test_gemm_numpy_reference() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = matrix(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], 3, 2);
        let product = a.matmul(&b).unwrap();
        assert_eq!(product.as_slice(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_gemv_numpy_reference() {
        let a = matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let ones = vector(&[1.0, 1.0, 1.0]);
        let product = a.matvec(&ones).unwrap();
        assert_eq!(product.as_slice(), &[6.0, 15.0]);
    }

    #[test]
    fn test_dot_numpy_reference() {
        let ascending = vector(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let descending = vector(&[5.0, 4.0, 3.0, 2.0, 1.0]);
        assert_eq!(dot(&ascending, &descending).unwrap(), 35.0);
    }

    #[test]
    fn test_nrm2_numpy_reference() {
        let v = vector(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let norm = nrm2(&v).unwrap();
        assert_close(norm, 7.416_198_487_095_663, 1e-12);
    }
}