pub mod backward;
pub mod buffer;
pub mod device;
pub mod kernel;
pub mod llm;
pub mod matmul;
pub mod module;
pub mod nn;
pub mod realize;
pub mod reduce;
pub mod safetensors;
pub mod tensor;
pub mod train;
pub use backward::{forward_backward_gpu, fused_forward_backward, FusedKernel};
pub use buffer::GpuBuffer;
pub use device::{GpuDevice, GpuError};
pub use kernel::KernelCache;
pub use llm::{GpuCausalAttention, GpuEmbedding, GpuInterleavedRoPE, GpuKVCache, GpuRMSNorm, GpuRoPE, GpuSwiGLU, GpuTrainTransformerBlock, GpuTrainTransformer, kv_attention_fused, swiglu_fused_pub, interleaved_rope_backward};
pub use module::{GpuAdam, GpuLinear, GpuModule, GpuTrainModule};
pub use nn::{add_tensors, bias_add, gelu, relu, relu_backward, softmax, GpuAttention, GpuLayerNorm, GpuTransformerBlock};
pub use train::{gpu_cross_entropy_loss, gpu_mse_loss, GradScaler, GpuDataLoader, GpuReLULayer, GpuSequential, GpuTanhLayer, GpuTrainer};
pub use realize::{map_elementwise, map_elementwise_multi};
pub use reduce::{reduce_max, reduce_mean, reduce_sum, reduce_sum_all};
pub use safetensors::{load_safetensors, save_safetensors};
pub use tensor::GpuTensor;
#[cfg(test)]
mod tests {
    //! Smoke tests for the crate's re-exported GPU API.
    //!
    //! Every test obtains its device from `get_device`, which panics when
    //! `GpuDevice::new_sync` fails, so the whole module needs a usable GPU.

    use super::*;

    /// Absolute tolerance used by the approximate floating-point comparisons.
    const TOLERANCE: f32 = 0.01;

    /// Creates the device shared by a single test.
    ///
    /// # Panics
    ///
    /// Panics with "GPU device required for tests" when `GpuDevice::new_sync`
    /// does not produce a device.
    fn get_device() -> GpuDevice {
        GpuDevice::new_sync().expect("GPU device required for tests")
    }

    /// Asserts that the leading elements of `actual` match `expected` to within
    /// [`TOLERANCE`].
    ///
    /// Only the first `expected.len()` elements are inspected. `what` names the
    /// value in the failure message.
    ///
    /// # Panics
    ///
    /// Panics when an element differs by `TOLERANCE` or more, or (through slice
    /// indexing) when `actual` is shorter than `expected`.
    fn assert_all_close(actual: &[f32], expected: &[f32], what: &str) {
        for (i, &want) in expected.iter().enumerate() {
            let got = actual[i];
            assert!(
                (got - want).abs() < TOLERANCE,
                "{what}[{i}] = {got}, expected {want} (tolerance {})",
                TOLERANCE
            );
        }
    }

    /// A device can be created at all.
    #[test]
    fn device_creation() {
        let _device = get_device();
    }

    /// Data uploaded with `GpuBuffer::from_slice` is read back unchanged.
    #[test]
    fn buffer_roundtrip() {
        let device = get_device();
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let readback = GpuBuffer::from_slice(&device, &data).to_vec_sync(&device);

        assert_eq!(readback, data);
    }

    /// A tensor keeps its shape, element count and contents across an
    /// upload/download cycle.
    #[test]
    fn tensor_roundtrip() {
        let device = get_device();
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        let tensor = GpuTensor::from_slice(&device, &data, &[2, 3]);

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.to_vec_sync(&device), data);
    }

    /// `map_elementwise` with a two-input addition closure adds element by element.
    #[test]
    fn elementwise_add() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let lhs = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0], &[3]);
        let rhs = GpuTensor::from_slice(&device, &[4.0, 5.0, 6.0], &[3]);

        let sum = map_elementwise(&device, &mut cache, &[&lhs, &rhs], |args| args[0] + args[1]);

        assert_eq!(sum.to_vec_sync(&device), vec![5.0, 7.0, 9.0]);
    }

    /// A closure combining several operations, here `(a + b)^2`, evaluates
    /// correctly through a single `map_elementwise` call.
    #[test]
    fn elementwise_fused() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let lhs = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0], &[3]);
        let rhs = GpuTensor::from_slice(&device, &[4.0, 5.0, 6.0], &[3]);

        let squared_sum = map_elementwise(&device, &mut cache, &[&lhs, &rhs], |args| {
            let sum = args[0] + args[1];
            sum * sum
        });

        assert_eq!(squared_sum.to_vec_sync(&device), vec![25.0, 49.0, 81.0]);
    }

    /// `relu` maps the negative input to zero and leaves the others unchanged.
    #[test]
    fn elementwise_relu() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let input = GpuTensor::from_slice(&device, &[-1.0, 0.0, 1.0, 2.0], &[4]);

        let activated = relu(&device, &mut cache, &input);

        assert_eq!(activated.to_vec_sync(&device), vec![0.0, 0.0, 1.0, 2.0]);
    }

    /// `[[1, 2], [3, 4]] x [[5, 6], [7, 8]]` yields `[[19, 22], [43, 50]]`
    /// when both operands are read row-major.
    #[test]
    fn matmul_2x2() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let a = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = GpuTensor::from_slice(&device, &[5.0, 6.0, 7.0, 8.0], &[2, 2]);

        let product = matmul::matmul(&device, &mut cache, &a, &b);
        let result = product.to_vec_sync(&device);

        assert_eq!(product.shape(), &[2, 2]);
        assert_all_close(&result, &[19.0, 22.0, 43.0, 50.0], "matmul result");
    }

    /// Reducing a `[2, 3]` tensor over axis 0 leaves the three column sums.
    #[test]
    fn reduce_sum_axis0() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let input = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let column_sums = reduce_sum(&device, &mut cache, &input, 0);
        let result = column_sums.to_vec_sync(&device);

        assert_eq!(column_sums.shape(), &[3]);
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    /// Reducing a `[2, 3]` tensor over axis 1 leaves the two row sums.
    #[test]
    fn reduce_sum_axis1() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let input = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let row_sums = reduce_sum(&device, &mut cache, &input, 1);
        let result = row_sums.to_vec_sync(&device);

        assert_eq!(row_sums.shape(), &[2]);
        assert_eq!(result, vec![6.0, 15.0]);
    }

    /// For `f(x, y, z) = x^2 + y^2 + z^2` at `(1, 2, 3)` the value is 14 and
    /// the gradient is `(2, 4, 6)`.
    #[test]
    fn fused_forward_backward_quadratic() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let kernel = fused_forward_backward(3, |vars| {
            vars[0] * vars[0] + vars[1] * vars[1] + vars[2] * vars[2]
        });

        let (loss, grads) = kernel.run(&device, &mut cache, &[1.0, 2.0, 3.0]);

        // Inline comparisons are kept here because the scalar type returned by
        // `run` is not visible from this module.
        assert!((loss - 14.0).abs() < 0.01, "loss = {loss}");
        assert!((grads[0] - 2.0).abs() < 0.01, "grad[0] = {}", grads[0]);
        assert!((grads[1] - 4.0).abs() < 0.01, "grad[1] = {}", grads[1]);
        assert!((grads[2] - 6.0).abs() < 0.01, "grad[2] = {}", grads[2]);
    }

    /// For `f(x, y) = x * y` at `(3, 5)` the value is 15 and the gradient is
    /// `(5, 3)`.
    #[test]
    fn fused_forward_backward_product() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let kernel = fused_forward_backward(2, |vars| vars[0] * vars[1]);

        let (loss, grads) = kernel.run(&device, &mut cache, &[3.0, 5.0]);

        assert!((loss - 15.0).abs() < 0.01, "loss = {loss}");
        assert!((grads[0] - 5.0).abs() < 0.01, "grad[0] = {}", grads[0]);
        assert!((grads[1] - 3.0).abs() < 0.01, "grad[1] = {}", grads[1]);
    }

    /// Softmax of a strictly increasing vector sums to one and stays strictly
    /// increasing.
    #[test]
    fn softmax_sums_to_one() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let input = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[4]);

        let probs = softmax(&device, &mut cache, &input).to_vec_sync(&device);

        let sum: f32 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax sum = {sum}, expected 1.0"
        );
        // Each adjacent pair must preserve the ordering of the inputs.
        for pair in probs.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    /// On a `[2, 3]` input each row is normalised separately; the second row's
    /// large first logit must receive almost all of that row's mass.
    #[test]
    fn softmax_2d_rows() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let input = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 10.0, 0.0, 0.0], &[2, 3]);

        let result = softmax(&device, &mut cache, &input).to_vec_sync(&device);

        let row1_sum: f32 = result[0..3].iter().sum();
        let row2_sum: f32 = result[3..6].iter().sum();
        assert!((row1_sum - 1.0).abs() < 1e-5, "row1 sum = {row1_sum}");
        assert!((row2_sum - 1.0).abs() < 1e-5, "row2 sum = {row2_sum}");
        assert!(result[3] > 0.99, "row2[0] should dominate, got {}", result[3]);
    }

    /// A freshly constructed `GpuLayerNorm` produces output with mean close to
    /// zero and standard deviation close to one.
    #[test]
    fn layer_norm_output() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let layer_norm = GpuLayerNorm::new(&device, 4, 1e-5);
        let input = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[4]);

        let result = layer_norm
            .forward(&device, &mut cache, &input)
            .to_vec_sync(&device);

        let count = result.len() as f32;
        let mean: f32 = result.iter().sum::<f32>() / count;
        assert!(mean.abs() < 1e-5, "layernorm mean = {mean}");

        let var: f32 = result.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / count;
        let std = var.sqrt();
        assert!(
            (std - 1.0).abs() < 0.1,
            "layernorm std = {std}, expected ~1.0"
        );
    }

    /// A bias-free linear layer maps `[3, 4]` to `[3, 4, 7]`.
    ///
    /// The expected values are consistent with the six weights being read as a
    /// row-major `[3, 2]` matrix `[[1, 0], [0, 1], [1, 1]]`; the trailing `2`
    /// and `3` arguments are presumably input and output width (confirm
    /// against `GpuLinear::new`).
    #[test]
    fn linear_forward() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let linear = GpuLinear::new(
            &device,
            &[1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
            &[0.0, 0.0, 0.0],
            2,
            3,
        );
        let input = GpuTensor::from_slice(&device, &[3.0, 4.0], &[2]);

        let result = linear
            .forward(&device, &mut cache, &input)
            .to_vec_sync(&device);

        assert_all_close(&result, &[3.0, 4.0, 7.0], "linear output");
    }

    /// Tensors written with `save_safetensors` are restored by
    /// `load_safetensors` with identical names, shapes and values.
    #[test]
    fn safetensors_roundtrip() {
        /// Removes the wrapped file when dropped, so the temporary file is
        /// cleaned up even if an assertion in the test panics.
        struct TempFile(std::path::PathBuf);

        impl Drop for TempFile {
            fn drop(&mut self) {
                // Best effort: the file does not exist if saving itself failed.
                let _ = std::fs::remove_file(&self.0);
            }
        }

        let device = get_device();
        let mut tensors = std::collections::HashMap::new();
        tensors.insert(
            "weight".to_string(),
            GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]),
        );
        tensors.insert(
            "bias".to_string(),
            GpuTensor::from_slice(&device, &[0.5, 1.5], &[2]),
        );

        // The process id keeps concurrently running test binaries from
        // overwriting each other's file in the shared temp directory.
        let file = TempFile(std::env::temp_dir().join(format!(
            "tang_gpu_test_{}.safetensors",
            std::process::id()
        )));
        let path = &file.0;

        save_safetensors(&tensors, &device, path).unwrap();
        let loaded = load_safetensors(&device, path).unwrap();
        assert_eq!(loaded.len(), 2);

        let weight = loaded.get("weight").unwrap();
        assert_eq!(weight.shape(), &[2, 2]);
        assert_eq!(weight.to_vec_sync(&device), vec![1.0, 2.0, 3.0, 4.0]);

        let bias = loaded.get("bias").unwrap();
        assert_eq!(bias.shape(), &[2]);
        assert_eq!(bias.to_vec_sync(&device), vec![0.5, 1.5]);
    }

    /// One Adam step with positive gradients lowers every parameter.
    #[test]
    fn adam_step_changes_params() {
        let device = get_device();
        let mut cache = KernelCache::new();
        let mut adam = GpuAdam::new(0.01);
        let mut param = GpuTensor::from_slice(&device, &[1.0, 2.0, 3.0], &[3]);
        let grads = vec![GpuTensor::from_slice(&device, &[0.1, 0.2, 0.3], &[3])];

        let before = param.to_vec_sync(&device);
        adam.step(&device, &mut cache, &mut [&mut param], &grads);
        let after = param.to_vec_sync(&device);

        assert_ne!(before, after);
        // Guard the zip below against silently skipping elements.
        assert_eq!(
            after.len(),
            before.len(),
            "optimizer step changed the number of elements"
        );
        for (i, (new, old)) in after.iter().zip(&before).enumerate() {
            assert!(new < old, "param[{i}] should decrease");
        }
    }
}