use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
type KernelPerformanceCache = Arc<Mutex<HashMap<(GpuMatrixOp, (usize, usize, usize)), f64>>>;
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
use super::super::{GpuBuffer, GpuContext, GpuContextAlloc};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuMatrixOp {
Gemm,
BatchedGemm,
Gemv,
BatchedGemv,
Symm,
Trmm,
Trsm,
Syrk,
Syr2k,
Transpose,
BatchedTranspose,
}
#[derive(Debug, Clone)]
pub struct GpuMatrixOpConfig {
pub min_gpu_size: usize,
pub mixed_precision: bool,
pub use_tensor_cores: bool,
pub block_size: usize,
pub num_streams: usize,
pub auto_tune: bool,
pub cache_kernels: bool,
}
impl Default for GpuMatrixOpConfig {
fn default() -> Self {
Self {
min_gpu_size: 1024, mixed_precision: true,
use_tensor_cores: true,
block_size: 32,
num_streams: 4,
auto_tune: true,
cache_kernels: true,
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct GemmConfig {
pub trans_a: bool,
pub trans_b: bool,
pub alpha: f64,
pub beta: f64,
}
impl Default for GemmConfig {
fn default() -> Self {
Self {
trans_a: false,
trans_b: false,
alpha: 1.0,
beta: 0.0,
}
}
}
#[derive(Debug, Clone)]
pub struct BatchedGemmConfig {
pub gemm_config: GemmConfig,
pub batch_size: usize,
pub stride_a: usize,
pub stride_b: usize,
pub stride_c: usize,
pub uniform_batch: bool,
}
impl Default for BatchedGemmConfig {
fn default() -> Self {
Self {
gemm_config: GemmConfig::default(),
batch_size: 1,
stride_a: 0,
stride_b: 0,
stride_c: 0,
uniform_batch: true,
}
}
}
pub struct GpuMatrixOperations<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
config: GpuMatrixOpConfig,
kernel_performance: KernelPerformanceCache,
_phantom: PhantomData<T>,
}
impl<T> GpuMatrixOperations<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
pub fn new() -> Self {
Self::with_config(GpuMatrixOpConfig::default())
}
pub fn with_config(config: GpuMatrixOpConfig) -> Self {
Self {
config,
kernel_performance: Arc::new(Mutex::new(HashMap::new())),
_phantom: PhantomData,
}
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn gemm(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
config: Option<GemmConfig>,
) -> LinalgResult<Array2<T>> {
let cfg = config.unwrap_or_default();
let (m, k_a) = a.dim();
let (k_b, n) = b.dim();
let k = if cfg.trans_a { m } else { k_a };
let actual_k_b = if cfg.trans_b { n } else { k_b };
if k != actual_k_b {
return Err(LinalgError::ShapeError(format!(
"Matrix dimension mismatch for GEMM: A is {}x{}, B is {}x{}",
m, k_a, k_b, n
)));
}
let out_m = if cfg.trans_a { k_a } else { m };
let out_n = if cfg.trans_b { k_b } else { n };
let total_elements = out_m * out_n + m * k_a + k_b * n;
if total_elements < self.config.min_gpu_size {
return self.cpu_gemm(a, b, &cfg);
}
self.gpu_gemm_impl(context, a, b, &cfg, out_m, out_n, k)
}
pub fn cpu_gemm(
&self,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
config: &GemmConfig,
) -> LinalgResult<Array2<T>> {
let (m, k_a) = a.dim();
let (k_b, n) = b.dim();
let (out_m, k, out_n) = if config.trans_a && config.trans_b {
(k_a, m, k_b)
} else if config.trans_a {
(k_a, m, n)
} else if config.trans_b {
(m, k_a, k_b)
} else {
(m, k_a, n)
};
let alpha = T::from(config.alpha).ok_or_else(|| {
LinalgError::ComputationError("Failed to convert alpha to target type".to_string())
})?;
let mut result = Array2::zeros((out_m, out_n));
let block_size = 64;
for i_block in (0..out_m).step_by(block_size) {
for j_block in (0..out_n).step_by(block_size) {
for k_block in (0..k).step_by(block_size) {
for i in i_block..std::cmp::min(i_block + block_size, out_m) {
for j in j_block..std::cmp::min(j_block + block_size, out_n) {
let mut sum = T::zero();
for l in k_block..std::cmp::min(k_block + block_size, k) {
let a_val = if config.trans_a { a[[l, i]] } else { a[[i, l]] };
let b_val = if config.trans_b { b[[j, l]] } else { b[[l, j]] };
sum += a_val * b_val;
}
result[[i, j]] += alpha * sum;
}
}
}
}
}
Ok(result)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
fn gpu_gemm_impl(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
config: &GemmConfig,
out_m: usize,
out_n: usize,
_k: usize,
) -> LinalgResult<Array2<T>> {
let _start_time = std::time::Instant::now();
context.synchronize()?;
let result = self.cpu_gemm(a, b, config)?;
context.synchronize()?;
if self.config.auto_tune {
let op_key = (GpuMatrixOp::Gemm, (out_m, out_n, a.dim().1));
let elapsed = _start_time.elapsed().as_secs_f64();
if let Ok(mut cache) = self.kernel_performance.lock() {
cache.insert(op_key, elapsed);
}
}
Ok(result)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn gemv(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
x: &ArrayView1<T>,
alpha: T,
beta: T,
y: Option<&Array1<T>>,
trans: bool,
) -> LinalgResult<Array1<T>> {
let (m, n) = a.dim();
let (out_size, inner_size) = if trans { (n, m) } else { (m, n) };
if x.len() != inner_size {
return Err(LinalgError::ShapeError(format!(
"Vector dimension mismatch for GEMV: A is {}x{}, x has {} elements",
m,
n,
x.len()
)));
}
let total_elements = m * n + x.len();
if total_elements < self.config.min_gpu_size {
return self.cpu_gemv(a, x, alpha, beta, y, trans);
}
self.gpu_gemv_impl(context, a, x, alpha, beta, y, trans, out_size, inner_size)
}
pub fn cpu_gemv(
&self,
a: &ArrayView2<T>,
x: &ArrayView1<T>,
alpha: T,
beta: T,
y: Option<&Array1<T>>,
trans: bool,
) -> LinalgResult<Array1<T>> {
let (m, n) = a.dim();
let out_size = if trans { n } else { m };
let mut result = if let Some(y_val) = y {
let mut r = y_val.clone();
r.mapv_inplace(|v| v * beta);
r
} else {
Array1::zeros(out_size)
};
if trans {
for j in 0..n {
let mut sum = T::zero();
for i in 0..m {
sum += a[[i, j]] * x[i];
}
result[j] += alpha * sum;
}
} else {
for i in 0..m {
let mut sum = T::zero();
for j in 0..n {
sum += a[[i, j]] * x[j];
}
result[i] += alpha * sum;
}
}
Ok(result)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
fn gpu_gemv_impl(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
x: &ArrayView1<T>,
alpha: T,
beta: T,
y: Option<&Array1<T>>,
trans: bool,
_out_size: usize,
_inner_size: usize,
) -> LinalgResult<Array1<T>> {
context.synchronize()?;
let result = self.cpu_gemv(a, x, alpha, beta, y, trans)?;
context.synchronize()?;
Ok(result)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn batched_gemm(
&self,
context: &dyn GpuContext,
a_batch: &[ArrayView2<T>],
b_batch: &[ArrayView2<T>],
config: Option<BatchedGemmConfig>,
) -> LinalgResult<Vec<Array2<T>>> {
if a_batch.len() != b_batch.len() {
return Err(LinalgError::ShapeError(format!(
"Batch size mismatch: A has {} matrices, B has {}",
a_batch.len(),
b_batch.len()
)));
}
if a_batch.is_empty() {
return Ok(Vec::new());
}
let cfg = config.unwrap_or_default();
let batch_size = a_batch.len();
let total_elements: usize = a_batch
.iter()
.zip(b_batch.iter())
.map(|(a, b)| a.len() + b.len())
.sum();
if total_elements < self.config.min_gpu_size * batch_size {
return self.cpu_batched_gemm(a_batch, b_batch, &cfg);
}
self.gpu_batched_gemm_impl(context, a_batch, b_batch, &cfg)
}
pub fn cpu_batched_gemm(
&self,
a_batch: &[ArrayView2<T>],
b_batch: &[ArrayView2<T>],
config: &BatchedGemmConfig,
) -> LinalgResult<Vec<Array2<T>>> {
a_batch
.iter()
.zip(b_batch.iter())
.map(|(a, b)| self.cpu_gemm(a, b, &config.gemm_config))
.collect()
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
fn gpu_batched_gemm_impl(
&self,
context: &dyn GpuContext,
a_batch: &[ArrayView2<T>],
b_batch: &[ArrayView2<T>],
config: &BatchedGemmConfig,
) -> LinalgResult<Vec<Array2<T>>> {
context.synchronize()?;
let results = self.cpu_batched_gemm(a_batch, b_batch, config)?;
context.synchronize()?;
Ok(results)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn batched_gemv(
&self,
context: &dyn GpuContext,
a_batch: &[ArrayView2<T>],
x_batch: &[ArrayView1<T>],
alpha: T,
trans: bool,
) -> LinalgResult<Vec<Array1<T>>> {
if a_batch.len() != x_batch.len() {
return Err(LinalgError::ShapeError(format!(
"Batch size mismatch: A has {} matrices, x has {} vectors",
a_batch.len(),
x_batch.len()
)));
}
if a_batch.is_empty() {
return Ok(Vec::new());
}
context.synchronize()?;
let results: LinalgResult<Vec<_>> = a_batch
.iter()
.zip(x_batch.iter())
.map(|(a, x)| self.cpu_gemv(a, x, alpha, T::zero(), None, trans))
.collect();
context.synchronize()?;
results
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn transpose(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
let (m, n) = a.dim();
if m * n < self.config.min_gpu_size {
return Ok(a.t().to_owned());
}
context.synchronize()?;
let result = a.t().to_owned();
context.synchronize()?;
Ok(result)
}
pub fn strassen_gemm(
&self,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
threshold: usize,
) -> LinalgResult<Array2<T>> {
let (m, k) = a.dim();
let (_, n) = b.dim();
if m <= threshold || k <= threshold || n <= threshold {
return self.cpu_gemm(a, b, &GemmConfig::default());
}
self.cpu_gemm(a, b, &GemmConfig::default())
}
pub fn symmetric_gemm(
&self,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
upper: bool,
) -> LinalgResult<Array2<T>> {
let (m, _) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut sum = T::zero();
for l in 0..m {
let a_val = if upper {
if i <= l {
a[[i, l]]
} else {
a[[l, i]]
}
} else {
if i >= l {
a[[i, l]]
} else {
a[[l, i]]
}
};
sum += a_val * b[[l, j]];
}
result[[i, j]] = sum;
}
}
Ok(result)
}
pub fn triangular_solve(
&self,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
lower: bool,
trans: bool,
unit_diagonal: bool,
) -> LinalgResult<Array2<T>> {
let (m, _) = a.dim();
let (_, n) = b.dim();
let mut x = b.to_owned();
if lower && !trans {
for j in 0..n {
for i in 0..m {
let mut sum = x[[i, j]];
for k in 0..i {
sum -= a[[i, k]] * x[[k, j]];
}
if !unit_diagonal {
sum /= a[[i, i]];
}
x[[i, j]] = sum;
}
}
} else if !lower && !trans {
for j in 0..n {
for i in (0..m).rev() {
let mut sum = x[[i, j]];
for k in (i + 1)..m {
sum -= a[[i, k]] * x[[k, j]];
}
if !unit_diagonal {
sum /= a[[i, i]];
}
x[[i, j]] = sum;
}
}
} else {
let a_t = a.t();
return self.triangular_solve(&a_t, b, !lower, false, unit_diagonal);
}
Ok(x)
}
pub fn get_performance_stats(&self) -> HashMap<(GpuMatrixOp, (usize, usize, usize)), f64> {
self.kernel_performance
.lock()
.map(|guard| guard.clone())
.unwrap_or_default()
}
pub fn clear_performance_cache(&self) {
if let Ok(mut cache) = self.kernel_performance.lock() {
cache.clear();
}
}
}
impl<T> Default for GpuMatrixOperations<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<T> GpuMatrixOperations<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn matmul(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
b: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
self.gemm(context, a, b, None)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn matvec(
&self,
context: &dyn GpuContext,
a: &ArrayView2<T>,
x: &ArrayView1<T>,
) -> LinalgResult<Array1<T>> {
self.gemv(context, a, x, T::one(), T::zero(), None, false)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn batch_matmul(
&self,
context: &dyn GpuContext,
a_batch: &[ArrayView2<T>],
b_batch: &[ArrayView2<T>],
) -> LinalgResult<Vec<Array2<T>>> {
self.batched_gemm(context, a_batch, b_batch, None)
}
}
pub struct NeuralNetworkOps<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
matrix_ops: GpuMatrixOperations<T>,
}
impl<T> NeuralNetworkOps<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
pub fn new() -> Self {
Self {
matrix_ops: GpuMatrixOperations::new(),
}
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn linear_forward(
&self,
context: &dyn GpuContext,
input: &ArrayView2<T>, weight: &ArrayView2<T>, bias: Option<&ArrayView1<T>>,
) -> LinalgResult<Array2<T>> {
let config = GemmConfig {
trans_a: false,
trans_b: true,
alpha: 1.0,
beta: 0.0,
};
let mut output = self.matrix_ops.gemm(context, input, weight, Some(config))?;
if let Some(b) = bias {
for mut row in output.axis_iter_mut(Axis(0)) {
for (j, val) in row.iter_mut().enumerate() {
*val += b[j];
}
}
}
Ok(output)
}
#[cfg(any(
feature = "cuda",
feature = "opencl",
feature = "rocm",
feature = "metal",
feature = "vulkan"
))]
pub fn batched_linear_forward(
&self,
context: &dyn GpuContext,
input_batch: &[ArrayView2<T>],
weight: &ArrayView2<T>,
bias: Option<&ArrayView1<T>>,
) -> LinalgResult<Vec<Array2<T>>> {
let weight_views: Vec<_> = std::iter::repeat_n(weight.view(), input_batch.len()).collect();
let config = BatchedGemmConfig {
gemm_config: GemmConfig {
trans_a: false,
trans_b: true,
alpha: 1.0,
beta: 0.0,
},
batch_size: input_batch.len(),
..Default::default()
};
let mut outputs =
self.matrix_ops
.batched_gemm(context, input_batch, &weight_views, Some(config))?;
if let Some(b) = bias {
for output in &mut outputs {
for mut row in output.axis_iter_mut(Axis(0)) {
for (j, val) in row.iter_mut().enumerate() {
*val += b[j];
}
}
}
}
Ok(outputs)
}
pub fn softmax(&self, input: &ArrayView2<T>, dim: usize) -> LinalgResult<Array2<T>> {
let mut result = input.to_owned();
if dim == 1 {
for mut row in result.axis_iter_mut(Axis(0)) {
let max_val = row
.iter()
.fold(T::neg_infinity(), |acc, &x| if x > acc { x } else { acc });
let mut sum = T::zero();
for val in row.iter_mut() {
*val = (*val - max_val).exp();
sum += *val;
}
for val in row.iter_mut() {
*val /= sum;
}
}
} else {
let (m, n) = input.dim();
for j in 0..n {
let max_val = (0..m).fold(T::neg_infinity(), |acc, i| {
if result[[i, j]] > acc {
result[[i, j]]
} else {
acc
}
});
let mut sum = T::zero();
for i in 0..m {
result[[i, j]] = (result[[i, j]] - max_val).exp();
sum += result[[i, j]];
}
for i in 0..m {
result[[i, j]] /= sum;
}
}
}
Ok(result)
}
pub fn layer_norm(
&self,
input: &ArrayView2<T>,
gamma: &ArrayView1<T>,
beta: &ArrayView1<T>,
eps: T,
) -> LinalgResult<Array2<T>> {
let (_, n) = input.dim();
if gamma.len() != n || beta.len() != n {
return Err(LinalgError::ShapeError(
"Gamma and beta must match feature dimension".to_string(),
));
}
let mut result = Array2::zeros(input.dim());
let n_f = T::from(n).ok_or_else(|| {
LinalgError::ComputationError("Failed to convert dimension".to_string())
})?;
for (i, row) in input.axis_iter(Axis(0)).enumerate() {
let mean = row.sum() / n_f;
let var = row
.iter()
.map(|&x| (x - mean) * (x - mean))
.fold(T::zero(), |a, b| a + b)
/ n_f;
let std_inv = T::one() / (var + eps).sqrt();
for (j, &val) in row.iter().enumerate() {
result[[i, j]] = (val - mean) * std_inv * gamma[j] + beta[j];
}
}
Ok(result)
}
}
impl<T> Default for NeuralNetworkOps<T>
where
T: Float + NumAssign + Zero + Send + Sync + Debug + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_cpu_gemm() {
let ops = GpuMatrixOperations::<f64>::new();
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[5.0, 6.0], [7.0, 8.0]];
let result = ops
.cpu_gemm(&a.view(), &b.view(), &GemmConfig::default())
.expect("GEMM failed");
assert_eq!(result.dim(), (2, 2));
assert!((result[[0, 0]] - 19.0).abs() < 1e-10);
assert!((result[[0, 1]] - 22.0).abs() < 1e-10);
assert!((result[[1, 0]] - 43.0).abs() < 1e-10);
assert!((result[[1, 1]] - 50.0).abs() < 1e-10);
}
#[test]
fn test_cpu_gemv() {
let ops = GpuMatrixOperations::<f64>::new();
let a = array![[1.0, 2.0], [3.0, 4.0]];
let x = array![1.0, 2.0];
let result = ops
.cpu_gemv(&a.view(), &x.view(), 1.0, 0.0, None, false)
.expect("GEMV failed");
assert_eq!(result.len(), 2);
assert!((result[0] - 5.0).abs() < 1e-10);
assert!((result[1] - 11.0).abs() < 1e-10);
}
#[test]
fn test_cpu_gemv_transpose() {
let ops = GpuMatrixOperations::<f64>::new();
let a = array![[1.0, 2.0], [3.0, 4.0]];
let x = array![1.0, 2.0];
let result = ops
.cpu_gemv(&a.view(), &x.view(), 1.0, 0.0, None, true)
.expect("GEMV transpose failed");
assert_eq!(result.len(), 2);
assert!((result[0] - 7.0).abs() < 1e-10);
assert!((result[1] - 10.0).abs() < 1e-10);
}
#[test]
fn test_batched_gemm() {
let ops = GpuMatrixOperations::<f64>::new();
let a1 = array![[1.0, 2.0], [3.0, 4.0]];
let b1 = array![[1.0, 0.0], [0.0, 1.0]];
let a2 = array![[2.0, 0.0], [0.0, 2.0]];
let b2 = array![[1.0, 1.0], [1.0, 1.0]];
let results = ops
.cpu_batched_gemm(
&[a1.view(), a2.view()],
&[b1.view(), b2.view()],
&BatchedGemmConfig::default(),
)
.expect("Batched GEMM failed");
assert_eq!(results.len(), 2);
assert!((results[0][[0, 0]] - 1.0).abs() < 1e-10);
assert!((results[1][[0, 0]] - 2.0).abs() < 1e-10);
}
#[test]
fn test_symmetric_gemm() {
let ops = GpuMatrixOperations::<f64>::new();
let a = array![[4.0, 2.0], [2.0, 3.0]]; let b = array![[1.0, 0.0], [0.0, 1.0]];
let result = ops
.symmetric_gemm(&a.view(), &b.view(), true)
.expect("Symmetric GEMM failed");
assert!((result[[0, 0]] - 4.0).abs() < 1e-10);
assert!((result[[1, 1]] - 3.0).abs() < 1e-10);
}
#[test]
fn test_triangular_solve() {
let ops = GpuMatrixOperations::<f64>::new();
let l = array![[2.0, 0.0], [1.0, 3.0]]; let b = array![[4.0], [7.0]];
let x = ops
.triangular_solve(&l.view(), &b.view(), true, false, false)
.expect("Triangular solve failed");
let check = array![
[l[[0, 0]] * x[[0, 0]]],
[l[[1, 0]] * x[[0, 0]] + l[[1, 1]] * x[[1, 0]]]
];
assert!((check[[0, 0]] - b[[0, 0]]).abs() < 1e-10);
assert!((check[[1, 0]] - b[[1, 0]]).abs() < 1e-10);
}
#[test]
fn test_neural_network_softmax() {
let nn_ops = NeuralNetworkOps::<f64>::new();
let input = array![[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]];
let result = nn_ops.softmax(&input.view(), 1).expect("Softmax failed");
for row in result.axis_iter(Axis(0)) {
let sum: f64 = row.iter().sum();
assert!((sum - 1.0).abs() < 1e-10);
}
}
#[test]
fn test_layer_norm() {
let nn_ops = NeuralNetworkOps::<f64>::new();
let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let gamma = array![1.0, 1.0, 1.0];
let beta = array![0.0, 0.0, 0.0];
let result = nn_ops
.layer_norm(&input.view(), &gamma.view(), &beta.view(), 1e-5)
.expect("Layer norm failed");
for row in result.axis_iter(Axis(0)) {
let mean: f64 = row.iter().sum::<f64>() / row.len() as f64;
assert!(mean.abs() < 1e-10);
}
}
}