pub(crate) mod helpers;
mod statistics;
mod tensor;
#[cfg(test)]
mod tests {
use crate::ops::{
ActivationOps, BinaryOps, IndexingOps, MatmulOps, NormalizationOps, ReduceOps,
};
use crate::runtime::Runtime;
use crate::runtime::cuda::{CudaDevice, CudaRuntime, is_cuda_available};
use crate::tensor::Tensor;
#[test]
fn test_cuda_tensor_add() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    // Element-wise addition of two 2x2 tensors.
    let lhs = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &dev);
    let rhs = Tensor::<CudaRuntime>::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[2, 2], &dev);
    let total = rt.add(&lhs, &rhs).unwrap();
    assert_eq!(total.shape(), &[2, 2]);
    let host: Vec<f32> = total.to_vec();
    assert_eq!(host, [6.0, 8.0, 10.0, 12.0]);
}
#[test]
fn test_cuda_tensor_matmul_2x2() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
    let lhs = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &dev);
    let rhs = Tensor::<CudaRuntime>::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[2, 2], &dev);
    let product = rt.matmul(&lhs, &rhs).unwrap();
    assert_eq!(product.shape(), &[2, 2]);
    let host: Vec<f32> = product.to_vec();
    assert_eq!(host, [19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_cuda_tensor_matmul_3x2_2x4() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    // Non-square case: (3x2) x (2x4) -> (3x4).
    let lhs_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let lhs = Tensor::<CudaRuntime>::from_slice(&lhs_data, &[3, 2], &dev);
    let rhs = Tensor::<CudaRuntime>::from_slice(&rhs_data, &[2, 4], &dev);
    let product = rt.matmul(&lhs, &rhs).unwrap();
    assert_eq!(product.shape(), &[3, 4]);
    let host: Vec<f32> = product.to_vec();
    let expected = [
        11.0, 14.0, 17.0, 20.0, // row 0
        23.0, 30.0, 37.0, 44.0, // row 1
        35.0, 46.0, 57.0, 68.0, // row 2
    ];
    assert_eq!(host, expected);
}
#[test]
fn test_cuda_tensor_relu() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    // Mix of negative, zero, and positive inputs.
    let input = Tensor::<CudaRuntime>::from_slice(&[-1.0f32, 0.0, 1.0, -2.0], &[4], &dev);
    let activated = rt.relu(&input).unwrap();
    let host: Vec<f32> = activated.to_vec();
    // Negatives clamp to zero; non-negatives pass through unchanged.
    assert_eq!(host, [0.0, 0.0, 1.0, 0.0]);
}
#[test]
fn test_cuda_tensor_sum() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    let input =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &dev);
    // Reduce over the last axis without keeping the reduced dimension:
    // each row of the 2x3 input collapses to one scalar.
    let reduced = rt.sum(&input, &[1], false).unwrap();
    assert_eq!(reduced.shape(), &[2]);
    let host: Vec<f32> = reduced.to_vec();
    assert_eq!(host, [6.0, 15.0]);
}
#[test]
fn test_cuda_tensor_silu() {
    // SiLU (a.k.a. swish): silu(x) = x * sigmoid(x).
    if !is_cuda_available() {
        return;
    }
    let device = CudaDevice::new(0);
    let client = CudaRuntime::default_client(&device);
    let xs = [-2.0f32, -1.0, 0.0, 1.0, 2.0];
    let a = Tensor::<CudaRuntime>::from_slice(&xs, &[5], &device);
    let b = client.silu(&a).unwrap();
    let result: Vec<f32> = b.to_vec();
    // Compare every element against the reference formula computed on the
    // host (previously only 3 of the 5 outputs were asserted).
    for (i, &x) in xs.iter().enumerate() {
        let expected = x / (1.0 + (-x).exp());
        assert!(
            (result[i] - expected).abs() < 1e-4,
            "silu({x}) = {}, expected {expected}",
            result[i]
        );
    }
}
#[test]
fn test_cuda_tensor_gelu() {
    // GELU activation; the tolerances below cover both the exact erf form
    // and the common tanh approximation.
    if !is_cuda_available() {
        return;
    }
    let device = CudaDevice::new(0);
    let client = CudaRuntime::default_client(&device);
    let a = Tensor::<CudaRuntime>::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0], &[5], &device);
    let b = client.gelu(&a).unwrap();
    let result: Vec<f32> = b.to_vec();
    // Spot-check known values.
    assert!((result[2] - 0.0).abs() < 1e-5);
    assert!((result[3] - 0.8413).abs() < 0.01);
    assert!((result[4] - 1.9545).abs() < 0.01);
    // gelu(x) - gelu(-x) == x holds exactly for both the erf and tanh
    // variants, so it pins down the previously unchecked negative inputs.
    assert!((result[3] - result[1] - 1.0).abs() < 1e-3);
    assert!((result[4] - result[0] - 2.0).abs() < 1e-3);
}
#[test]
fn test_cuda_tensor_rms_norm() {
    // RMSNorm over the last axis: y = x / sqrt(mean(x^2) + eps) * weight.
    if !is_cuda_available() {
        return;
    }
    let device = CudaDevice::new(0);
    let client = CudaRuntime::default_client(&device);
    let input = Tensor::<CudaRuntime>::from_slice(
        &[1.0f32, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0],
        &[2, 4],
        &device,
    );
    // Unit weight so the output is the bare normalization.
    let weight = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device);
    let out = client.rms_norm(&input, &weight, 1e-5).unwrap();
    let result: Vec<f32> = out.to_vec();
    // Row 0: sum of squares = 1 + 4 + 9 + 16 = 30.
    let rms1 = (30.0f32 / 4.0 + 1e-5).sqrt();
    assert!((result[0] - 1.0 / rms1).abs() < 1e-3);
    assert!((result[1] - 2.0 / rms1).abs() < 1e-3);
    assert!((result[2] - 3.0 / rms1).abs() < 1e-3);
    assert!((result[3] - 4.0 / rms1).abs() < 1e-3);
    // Row 1: sum of squares = 4 + 16 + 36 + 64 = 120. Check every element
    // (previously only the first element of this row was asserted).
    let rms2 = (120.0f32 / 4.0 + 1e-5).sqrt();
    assert!((result[4] - 2.0 / rms2).abs() < 1e-3);
    assert!((result[5] - 4.0 / rms2).abs() < 1e-3);
    assert!((result[6] - 6.0 / rms2).abs() < 1e-3);
    assert!((result[7] - 8.0 / rms2).abs() < 1e-3);
}
#[test]
fn test_cuda_tensor_layer_norm() {
    // LayerNorm over the last axis: y = (x - mean) / sqrt(var + eps) * w + b,
    // using the population (biased) variance.
    if !is_cuda_available() {
        return;
    }
    let device = CudaDevice::new(0);
    let client = CudaRuntime::default_client(&device);
    let input = Tensor::<CudaRuntime>::from_slice(
        &[1.0f32, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0],
        &[2, 4],
        &device,
    );
    // Identity affine parameters: weight = 1, bias = 0.
    let weight = Tensor::<CudaRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device);
    let bias = Tensor::<CudaRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device);
    let out = client.layer_norm(&input, &weight, &bias, 1e-5).unwrap();
    let result: Vec<f32> = out.to_vec();
    // Row 0: mean 2.5, population variance 1.25.
    let mean1 = 2.5f32;
    let var1 = ((1.0 - mean1).powi(2)
        + (2.0 - mean1).powi(2)
        + (3.0 - mean1).powi(2)
        + (4.0 - mean1).powi(2))
        / 4.0;
    let std1 = (var1 + 1e-5).sqrt();
    assert!((result[0] - (1.0 - mean1) / std1).abs() < 1e-3);
    assert!((result[1] - (2.0 - mean1) / std1).abs() < 1e-3);
    assert!((result[2] - (3.0 - mean1) / std1).abs() < 1e-3);
    assert!((result[3] - (4.0 - mean1) / std1).abs() < 1e-3);
    // A normalized row is zero-mean.
    let row1_sum: f32 = result[0..4].iter().sum();
    assert!(row1_sum.abs() < 1e-3);
    // Row 1: mean 5.0, population variance 5.0 — previously unchecked.
    let mean2 = 5.0f32;
    let var2 = ((2.0 - mean2).powi(2)
        + (4.0 - mean2).powi(2)
        + (6.0 - mean2).powi(2)
        + (8.0 - mean2).powi(2))
        / 4.0;
    let std2 = (var2 + 1e-5).sqrt();
    assert!((result[4] - (2.0 - mean2) / std2).abs() < 1e-3);
    assert!((result[5] - (4.0 - mean2) / std2).abs() < 1e-3);
    assert!((result[6] - (6.0 - mean2) / std2).abs() < 1e-3);
    assert!((result[7] - (8.0 - mean2) / std2).abs() < 1e-3);
    let row2_sum: f32 = result[4..8].iter().sum();
    assert!(row2_sum.abs() < 1e-3);
}
#[test]
fn test_cuda_tensor_argmax() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    // 2x3 input, row-major: [[1, 5, 3], [4, 2, 6]].
    let input =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3], &dev);
    // Along dim 1 (per row), dropping the reduced dimension.
    let per_row = rt.argmax(&input, 1, false).unwrap();
    assert_eq!(per_row.shape(), &[2]);
    let idx: Vec<i64> = per_row.to_vec();
    assert_eq!(idx, [1, 2]);
    // Along dim 0 (per column).
    let per_col = rt.argmax(&input, 0, false).unwrap();
    assert_eq!(per_col.shape(), &[3]);
    let idx: Vec<i64> = per_col.to_vec();
    assert_eq!(idx, [1, 0, 1]);
    // keep-dims variant retains the reduced axis with size 1.
    let kept = rt.argmax(&input, 1, true).unwrap();
    assert_eq!(kept.shape(), &[2, 1]);
    let idx: Vec<i64> = kept.to_vec();
    assert_eq!(idx, [1, 2]);
}
#[test]
fn test_cuda_tensor_argmin() {
    // Skipped silently on machines without a CUDA device.
    if !is_cuda_available() {
        return;
    }
    let dev = CudaDevice::new(0);
    let rt = CudaRuntime::default_client(&dev);
    // 2x3 input, row-major: [[1, 5, 3], [4, 2, 6]].
    let input =
        Tensor::<CudaRuntime>::from_slice(&[1.0f32, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3], &dev);
    // Along dim 1 (per row), dropping the reduced dimension.
    let per_row = rt.argmin(&input, 1, false).unwrap();
    assert_eq!(per_row.shape(), &[2]);
    let idx: Vec<i64> = per_row.to_vec();
    assert_eq!(idx, [0, 1]);
    // Along dim 0 (per column).
    let per_col = rt.argmin(&input, 0, false).unwrap();
    assert_eq!(per_col.shape(), &[3]);
    let idx: Vec<i64> = per_col.to_vec();
    assert_eq!(idx, [0, 1, 0]);
    // keep-dims variant retains the reduced axis with size 1.
    let kept = rt.argmin(&input, 1, true).unwrap();
    assert_eq!(kept.shape(), &[2, 1]);
    let idx: Vec<i64> = kept.to_vec();
    assert_eq!(idx, [0, 1]);
}
}