use scirs2_core::ndarray::{
Array, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis, ScalarOperand,
};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
pub mod attention;
pub use attention::{
batch_flash_attention, batch_multi_head_attention, batch_multi_query_attention,
};
pub mod operations;
pub use operations::{
batch_cholesky, batch_det, batch_frobenius_norm, batch_inverse, batch_matmul_pairwise,
batch_solve, batch_solve_matrix, batch_trace, batch_transpose,
};
/// Minimum per-matrix work size (`m * k * n` multiply-adds) above which
/// `batch_matmul` dispatches to the BLAS-accelerated path; smaller products
/// use the naive triple loop, where BLAS call overhead would dominate.
const BATCH_MATMUL_SIMD_THRESHOLD: usize = 64;
/// Multiplies every matrix in a 3-D batch by the same 2-D matrix.
///
/// `batch_a` has shape `(batch, m, k)` and `b` has shape `(k, n)`; the result
/// has shape `(batch, m, n)`. Products with at least
/// `BATCH_MATMUL_SIMD_THRESHOLD` multiply-adds per matrix are routed through
/// the BLAS-accelerated matmul; smaller ones use a plain triple loop.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the inner dimensions disagree.
#[allow(dead_code)]
pub fn batch_matmul<F>(batch_a: &ArrayView3<F>, b: &ArrayView2<F>) -> LinalgResult<Array3<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    let (num_batches, rows, inner) = batch_a.dim();
    let (b_rows, cols) = b.dim();
    if inner != b_rows {
        return Err(LinalgError::ShapeError(format!(
            "Inner dimensions mismatch for batch_matmul: {} vs {}",
            inner, b_rows
        )));
    }

    // The dispatch decision is invariant across the batch, so hoist it.
    let use_blas = rows * inner * cols >= BATCH_MATMUL_SIMD_THRESHOLD;
    let mut out = Array::zeros((num_batches, rows, cols));

    for idx in 0..num_batches {
        let lhs = batch_a.slice(scirs2_core::ndarray::s![idx, .., ..]);
        if use_blas {
            let prod = crate::blas_accelerated::matmul(&lhs, b)?;
            out.slice_mut(scirs2_core::ndarray::s![idx, .., ..])
                .assign(&prod);
        } else {
            // Naive triple loop for tiny matrices.
            for r in 0..rows {
                for c in 0..cols {
                    let mut acc = F::zero();
                    for t in 0..inner {
                        acc += lhs[[r, t]] * b[[t, c]];
                    }
                    out[[idx, r, c]] = acc;
                }
            }
        }
    }
    Ok(out)
}
/// Applies each matrix in a 3-D batch to the same vector.
///
/// `batch_a` has shape `(batch, m, n)` and `x` has length `n`; the result has
/// shape `(batch, m)`. Rows of at least 8 elements use the SIMD dot product;
/// shorter rows fall back to a scalar accumulation.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the matrix width and vector length
/// differ.
#[allow(dead_code)]
pub fn batch_matvec<F>(batch_a: &ArrayView3<F>, x: &ArrayView1<F>) -> LinalgResult<Array2<F>>
where
    F: Float
        + NumAssign
        + Sum
        + Send
        + Sync
        + ScalarOperand
        + scirs2_core::simd_ops::SimdUnifiedOps
        + 'static,
{
    let (num_batches, rows, cols) = batch_a.dim();
    if cols != x.len() {
        return Err(LinalgError::ShapeError(format!(
            "Dimension mismatch for batch_matvec: matrix width {} does not match vector length {}",
            cols,
            x.len()
        )));
    }

    // SIMD dot products only pay off once a row is wide enough; the cutoff is
    // invariant across the batch, so decide once.
    let use_simd = cols >= 8;
    let mut out = Array::zeros((num_batches, rows));

    for idx in 0..num_batches {
        for r in 0..rows {
            let row = batch_a.slice(scirs2_core::ndarray::s![idx, r, ..]);
            out[[idx, r]] = if use_simd {
                F::simd_dot(&row, x)
            } else {
                // Scalar dot product for narrow rows.
                row.iter()
                    .zip(x.iter())
                    .fold(F::zero(), |acc, (&a, &b)| acc + a * b)
            };
        }
    }
    Ok(out)
}
/// Adds a vector to every matrix in a batch, broadcast along the given axis.
///
/// * `axis == 0` (column-wise): `v` must have length `m` (number of rows);
///   `v[i]` is added to every element of row `i` of each matrix.
/// * `axis == 1` (row-wise): `v` must have length `n` (number of columns);
///   `v[j]` is added to every element of column `j` of each matrix.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when `v`'s length does not match the
/// selected dimension, and `LinalgError::InvalidInputError` for any axis
/// other than 0 or 1.
#[allow(dead_code)]
pub fn batch_add<F>(
    batch_a: &ArrayView3<F>,
    v: &ArrayView1<F>,
    axis: usize,
) -> LinalgResult<Array3<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    let (batchsize, m, n) = batch_a.dim();

    // Resolve the expected vector length once; a single match replaces the
    // separate validate/dispatch matches (and the unreachable!() arm).
    let (expected, dim_name) = match axis {
        0 => (m, "rows"),
        1 => (n, "columns"),
        _ => {
            return Err(LinalgError::InvalidInputError(format!(
                "Invalid axis {axis}: must be 0 (column-wise) or 1 (row-wise)"
            )))
        }
    };
    if v.len() != expected {
        // Report expected and actual lengths once each (the old message
        // interpolated the expected length twice, which read ambiguously).
        return Err(LinalgError::ShapeError(format!(
            "batch_add along axis {axis} requires vector length to match number of {dim_name} ({expected}), got {}",
            v.len()
        )));
    }

    let mut result = batch_a.to_owned();
    for batch_idx in 0..batchsize {
        for i in 0..m {
            for j in 0..n {
                // Column-wise (axis 0) broadcasts v[i]; row-wise (axis 1) v[j].
                let offset = if axis == 0 { v[i] } else { v[j] };
                result[[batch_idx, i, j]] += offset;
            }
        }
    }
    Ok(result)
}
/// Sums a batch of matrices element-wise over the batch (first) axis.
///
/// Given a 3-D view of shape `(batch, m, n)`, returns the `(m, n)` matrix
/// whose entry `(i, j)` is the sum of `batch_a[[b, i, j]]` over all `b`.
/// Per ndarray's `sum_axis` semantics, an empty batch yields a zeroed
/// `(m, n)` array.
#[allow(dead_code)]
pub fn batch_sum<F>(batch_a: &ArrayView3<F>) -> Array2<F>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    // Delegate to ndarray's axis reduction; Axis(0) is the batch dimension.
    batch_a.sum_axis(Axis(0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Array3};

    /// Batch of two 2x2 matrices: [[1,2],[3,4]] and [[5,6],[7,8]].
    fn batch_2x2x2() -> Array3<f64> {
        array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
    }

    /// Batch of two 2x3 matrices holding 1..=12 in row-major order.
    fn batch_2x2x3() -> Array3<f64> {
        array![
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]
        ]
    }

    #[test]
    fn test_batch_matmul() {
        let lhs = batch_2x2x2();
        let rhs = array![[1.0, 2.0], [3.0, 4.0]];
        let product = batch_matmul(&lhs.view(), &rhs.view()).expect("Operation failed");

        assert_eq!(product.shape(), &[2, 2, 2]);
        // Batch 0: [[1,2],[3,4]] * rhs
        assert_relative_eq!(product[[0, 0, 0]], 7.0);
        assert_relative_eq!(product[[0, 0, 1]], 10.0);
        assert_relative_eq!(product[[0, 1, 0]], 15.0);
        assert_relative_eq!(product[[0, 1, 1]], 22.0);
        // Batch 1: [[5,6],[7,8]] * rhs
        assert_relative_eq!(product[[1, 0, 0]], 23.0);
        assert_relative_eq!(product[[1, 0, 1]], 34.0);
        assert_relative_eq!(product[[1, 1, 0]], 31.0);
        assert_relative_eq!(product[[1, 1, 1]], 46.0);
    }

    #[test]
    fn test_batch_matvec() {
        let mats = batch_2x2x3();
        let x = array![1.0, 2.0, 3.0];
        let out = batch_matvec(&mats.view(), &x.view()).expect("Operation failed");

        assert_eq!(out.shape(), &[2, 2]);
        assert_relative_eq!(out[[0, 0]], 14.0);
        assert_relative_eq!(out[[0, 1]], 32.0);
        assert_relative_eq!(out[[1, 0]], 50.0);
        assert_relative_eq!(out[[1, 1]], 68.0);
    }

    #[test]
    fn test_batch_add_row_wise() {
        let mats = batch_2x2x2();
        let offsets = array![10.0, 20.0];
        // axis = 1: offsets[j] is added to column j of every row.
        let shifted = batch_add(&mats.view(), &offsets.view(), 1).expect("Operation failed");

        assert_eq!(shifted.shape(), &[2, 2, 2]);
        assert_relative_eq!(shifted[[0, 0, 0]], 11.0);
        assert_relative_eq!(shifted[[0, 0, 1]], 22.0);
        assert_relative_eq!(shifted[[0, 1, 0]], 13.0);
        assert_relative_eq!(shifted[[0, 1, 1]], 24.0);
        assert_relative_eq!(shifted[[1, 0, 0]], 15.0);
        assert_relative_eq!(shifted[[1, 0, 1]], 26.0);
        assert_relative_eq!(shifted[[1, 1, 0]], 17.0);
        assert_relative_eq!(shifted[[1, 1, 1]], 28.0);
    }

    #[test]
    fn test_batch_add_column_wise() {
        let mats = batch_2x2x2();
        let offsets = array![10.0, 20.0];
        // axis = 0: offsets[i] is added to every element of row i.
        let shifted = batch_add(&mats.view(), &offsets.view(), 0).expect("Operation failed");

        assert_eq!(shifted.shape(), &[2, 2, 2]);
        assert_relative_eq!(shifted[[0, 0, 0]], 11.0);
        assert_relative_eq!(shifted[[0, 0, 1]], 12.0);
        assert_relative_eq!(shifted[[0, 1, 0]], 23.0);
        assert_relative_eq!(shifted[[0, 1, 1]], 24.0);
        assert_relative_eq!(shifted[[1, 0, 0]], 15.0);
        assert_relative_eq!(shifted[[1, 0, 1]], 16.0);
        assert_relative_eq!(shifted[[1, 1, 0]], 27.0);
        assert_relative_eq!(shifted[[1, 1, 1]], 28.0);
    }

    #[test]
    fn test_batch_sum() {
        let mats = batch_2x2x2();
        let total = batch_sum(&mats.view());

        // Element-wise sum of the two stacked matrices.
        assert_eq!(total.shape(), &[2, 2]);
        assert_relative_eq!(total[[0, 0]], 6.0);
        assert_relative_eq!(total[[0, 1]], 8.0);
        assert_relative_eq!(total[[1, 0]], 10.0);
        assert_relative_eq!(total[[1, 1]], 12.0);
    }

    #[test]
    fn test_batch_matmul_dimension_error() {
        // Inner dimensions disagree: (2,2,3) x (2,2).
        let lhs = batch_2x2x3();
        let rhs = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(batch_matmul(&lhs.view(), &rhs.view()).is_err());
    }

    #[test]
    fn test_batch_matvec_dimension_error() {
        // Matrix width 3 vs vector length 2.
        let mats = batch_2x2x3();
        let x = array![1.0, 2.0];
        assert!(batch_matvec(&mats.view(), &x.view()).is_err());
    }

    #[test]
    fn test_batch_add_dimension_error() {
        // Row-wise add needs a length-2 vector; 3 elements must be rejected.
        let mats = batch_2x2x2();
        let offsets = array![10.0, 20.0, 30.0];
        assert!(batch_add(&mats.view(), &offsets.view(), 1).is_err());
    }

    #[test]
    fn test_batch_add_invalid_axis() {
        // Only axes 0 and 1 are valid for a matrix broadcast.
        let mats = batch_2x2x2();
        let offsets = array![10.0, 20.0];
        assert!(batch_add(&mats.view(), &offsets.view(), 2).is_err());
    }
}