use crate::autograd::{BackwardOp, Tensor};
use crate::trace::{TraceStep, TRACER};
use ndarray::Array1;
use std::cell::RefCell;
use std::rc::Rc;
#[cfg(all(feature = "realizar", feature = "cuda"))]
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(all(feature = "realizar", feature = "cuda"))]
use std::sync::{Mutex, OnceLock};
#[cfg(all(feature = "realizar", feature = "cuda"))]
use realizar::cuda::CudaExecutor;
// Sticky process-wide flag: set after any CUDA init/dispatch failure so the
// CUDA matmul path is skipped for the remainder of the process.
#[cfg(all(feature = "realizar", feature = "cuda"))]
static CUDA_MATMUL_DISABLED: AtomicBool = AtomicBool::new(false);
// Lazily-initialized shared CUDA executor; `None` is cached when creation
// fails so initialization is attempted at most once.
#[cfg(all(feature = "realizar", feature = "cuda"))]
static CUDA_EXECUTOR: OnceLock<Option<Mutex<CudaExecutor>>> = OnceLock::new();
/// Lazily creates the process-wide realizar CUDA executor on GPU 0.
///
/// Returns `None` (and sets `CUDA_MATMUL_DISABLED`) if construction fails;
/// the failed result is cached in `CUDA_EXECUTOR` so initialization is only
/// attempted once per process.
///
/// NOTE(review): the success path calls `TRACER.end` without a matching
/// `TRACER.start` — confirm the tracer tolerates unmatched `end` calls.
#[cfg(all(feature = "realizar", feature = "cuda"))]
fn get_cuda_executor() -> Option<&'static Mutex<CudaExecutor>> {
    CUDA_EXECUTOR
        .get_or_init(|| match CudaExecutor::new(0) {
            Ok(executor) => {
                TRACER.end(TraceStep::Transfer, "realizar CUDA executor initialized on GPU 0");
                Some(Mutex::new(executor))
            }
            Err(_e) => {
                // Init failure permanently disables the CUDA matmul path.
                CUDA_MATMUL_DISABLED.store(true, Ordering::Relaxed);
                None
            }
        })
        .as_ref()
}
/// Transposes a row-major `rows`×`cols` matrix, returning the row-major
/// `cols`×`rows` result. The whole operation (including the output
/// allocation) is wrapped in a `Transpose` trace span.
#[inline]
pub fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    TRACER.start(TraceStep::Transpose);
    const BLOCK_SIZE: usize = 32;
    let mut out = vec![0.0f32; rows * cols];
    // Cache-blocked traversal only pays off once both dimensions span at
    // least one full tile; smaller matrices take the naive double loop.
    if rows < BLOCK_SIZE || cols < BLOCK_SIZE {
        transpose_simple(data, &mut out, rows, cols);
    } else {
        transpose_blocked(data, &mut out, rows, cols, BLOCK_SIZE);
    }
    TRACER.end(TraceStep::Transpose, format!("{rows}x{cols}"));
    out
}
/// Autograd-tracked transpose of a `rows`×`cols` tensor.
///
/// Returns a new tensor holding the transposed data; when `tensor`
/// requires gradients, a `TransposeBackward` node is attached so the
/// upstream gradient flows back (transposed) into `tensor`.
///
/// Panics if the tensor's storage is not contiguous.
pub fn transpose_tracked(tensor: &Tensor, rows: usize, cols: usize) -> Tensor {
    let needs_grad = tensor.requires_grad();
    let data = tensor.data();
    let transposed = transpose(
        data.as_slice().expect("transpose_tracked: tensor must be contiguous"),
        rows,
        cols,
    );
    let mut result = Tensor::from_vec(transposed, needs_grad);
    if needs_grad {
        result.set_backward_op(Rc::new(TransposeBackward {
            original: tensor.clone(),
            rows,
            cols,
            result_grad: result.grad_cell(),
        }));
    }
    result
}
/// Backward node for `transpose_tracked`: transposes the upstream gradient
/// back into the original tensor's layout.
struct TransposeBackward {
    /// The tensor that was transposed; receives the accumulated gradient.
    original: Tensor,
    /// Row count of `original` (the forward output was `cols`×`rows`).
    rows: usize,
    /// Column count of `original`.
    cols: usize,
    /// Shared gradient cell of the transposed result tensor.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
impl BackwardOp for TransposeBackward {
    /// Since Y = Xᵀ, the gradient rule is ∂L/∂X = (∂L/∂Y)ᵀ. The forward
    /// output was `cols`×`rows`, so transposing the upstream gradient with
    /// swapped dimensions restores the original `rows`×`cols` layout.
    fn backward(&self) {
        let guard = self.result_grad.borrow();
        let Some(upstream) = guard.as_ref() else {
            // No gradient ever reached the transposed tensor; nothing to do.
            return;
        };
        let upstream_slice = upstream.as_slice().expect("gradient must be contiguous");
        let restored = transpose(upstream_slice, self.cols, self.rows);
        self.original.accumulate_grad(Array1::from(restored));
        // Continue the chain into whatever op produced `original`.
        if let Some(parent) = self.original.backward_op() {
            parent.backward();
        }
    }
}
/// Cache-blocked transpose: `dst[c * rows + r] = src[r * cols + c]`.
///
/// Walks the index space in `block`×`block` tiles so the strided writes
/// into `dst` stay within a small window, improving cache reuse over a
/// plain row scan. Partial tiles at the right/bottom edges are clamped.
#[inline]
fn transpose_blocked(src: &[f32], dst: &mut [f32], rows: usize, cols: usize, block: usize) {
    for row_start in (0..rows).step_by(block) {
        let row_end = rows.min(row_start + block);
        for col_start in (0..cols).step_by(block) {
            let col_end = cols.min(col_start + block);
            for src_row in row_start..row_end {
                let row_base = src_row * cols;
                for src_col in col_start..col_end {
                    dst[src_col * rows + src_row] = src[row_base + src_col];
                }
            }
        }
    }
}
/// Naive transpose for small matrices where blocking buys nothing:
/// walks `src` row by row and scatters each element to its transposed
/// slot (`dst` is the row-major `cols`×`rows` result).
#[inline]
fn transpose_simple(src: &[f32], dst: &mut [f32], rows: usize, cols: usize) {
    for r in 0..rows {
        let base = r * cols;
        for c in 0..cols {
            dst[c * rows + r] = src[base + c];
        }
    }
}
/// Dense f32 GEMM `C(m×n) = A(m×k) · B(k×n)` on raw row-major slices,
/// CUDA build. Backend preference: realizar CUDA → wgpu (when the `gpu`
/// feature is on, batch mode is off, and the problem is large enough) →
/// CPU. CUDA is skipped once `CUDA_MATMUL_DISABLED` is set.
#[cfg(all(feature = "realizar", feature = "cuda"))]
pub fn matmul_compute(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    if !CUDA_MATMUL_DISABLED.load(Ordering::Relaxed) {
        if let Some(executor_mutex) = get_cuda_executor() {
            // A poisoned/contended lock silently falls through to the
            // CPU/wgpu path without disabling CUDA.
            if let Ok(mut executor) = executor_mutex.lock() {
                match cuda_matmul(&mut executor, a, b, m, k, n) {
                    Ok(result) => return result,
                    Err(_e) => {
                        // First dispatch failure disables CUDA for the rest
                        // of the process; fall through to the other paths.
                        CUDA_MATMUL_DISABLED.store(true, Ordering::Relaxed);
                        TRACER.end(
                            TraceStep::Matmul,
                            "realizar CUDA matmul disabled (JIT failure), using trueno SIMD",
                        );
                    }
                }
            }
        }
    }
    // 32_768 multiply-accumulates (m·k·n) is the size threshold below which
    // GPU dispatch overhead presumably outweighs the win — TODO confirm.
    #[cfg(feature = "gpu")]
    if !WGPU_BATCH_MODE.load(std::sync::atomic::Ordering::Relaxed) && m * k * n > 32_768 {
        if let Some(result) = wgpu_matmul(a, b, m, k, n) {
            return result;
        }
    }
    cpu_matmul(a, b, m, k, n)
}
/// Pre-warms (JIT-compiles) realizar CUDA GEMM kernels for every matmul
/// shape the model loop is expected to hit, so the first real step does
/// not pay compilation latency.
///
/// The shape list is built from the model dimensions, then sorted,
/// deduplicated, and filtered of degenerate (zero-sized) entries.
/// Returns the number of shapes warmed successfully; 0 when no executor
/// is available. If every attempted shape fails, CUDA matmul is disabled
/// process-wide via `CUDA_MATMUL_DISABLED`.
#[cfg(all(feature = "realizar", feature = "cuda"))]
pub fn pre_warm_realizador_gemm(
    seq_len: usize,
    hidden_size: usize,
    kv_hidden_size: usize,
    intermediate_size: usize,
    lora_rank: usize,
    num_classes: usize,
) -> usize {
    let executor_mutex = match get_cuda_executor() {
        Some(e) => e,
        None => return 0,
    };
    let mut executor = match executor_mutex.lock() {
        Ok(e) => e,
        Err(_) => return 0,
    };
    // Short aliases keep the shape table readable.
    let s = seq_len;
    let h = hidden_size;
    let kv = kv_hidden_size;
    let i = intermediate_size;
    let r = lora_rank;
    // (m, k, n) triples covering forward and backward GEMMs (projections,
    // MLP, LoRA adapters, classifier head); duplicates are removed below.
    let mut shapes: Vec<(usize, usize, usize)> = vec![
        (s, h, h), (s, h, kv), (s, h, i), (s, i, h), (s, h, r), (s, r, h), (s, kv, r), (s, r, kv), (s, kv, h), (s, i, h), (s, h, i), (h, s, h), (h, s, kv), (h, s, i), (i, s, h), (s, r, h), (h, s, r), (s, h, r), (r, s, h), (r, s, kv), (1, h, num_classes),
    ];
    shapes.sort_unstable();
    shapes.dedup();
    shapes.retain(|&(m, k, n)| m > 0 && k > 0 && n > 0);
    let mut warmed = 0usize;
    for &(m, k, n) in &shapes {
        // Zero-filled dummy operands: only the shape matters for JIT warm-up.
        let a = vec![0.0f32; m * k];
        let b = vec![0.0f32; k * n];
        match cuda_matmul(&mut executor, &a, &b, m, k, n) {
            Ok(_) => warmed += 1,
            Err(e) => {
                eprintln!("[CUDA] realizador GEMM pre-warm failed for ({m},{k},{n}): {e}");
            }
        }
    }
    if warmed == 0 {
        // Nothing compiled: don't keep retrying CUDA at runtime.
        CUDA_MATMUL_DISABLED.store(true, Ordering::Relaxed);
    }
    warmed
}
/// Single GEMM dispatch on the realizar CUDA executor:
/// `C(m×n) = A(m×k) · B(k×n)`. The allocation and the kernel execution are
/// traced separately. Errors are stringified so callers can decide whether
/// to disable the CUDA path.
///
/// NOTE(review): `gemm` receives dimensions as `(m, n, k)` in `u32` —
/// confirm this matches realizar's parameter order.
#[cfg(all(feature = "realizar", feature = "cuda"))]
fn cuda_matmul(
    executor: &mut CudaExecutor,
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> Result<Vec<f32>, String> {
    TRACER.start(TraceStep::Alloc);
    let mut c = vec![0.0f32; m * n];
    TRACER.end(TraceStep::Alloc, format!("{m}x{n}"));
    TRACER.start(TraceStep::Matmul);
    executor.gemm(a, b, &mut c, m as u32, n as u32, k as u32).map_err(|e| format!("{e:?}"))?;
    TRACER.end(TraceStep::Matmul, format!("{m}x{k}x{n}"));
    Ok(c)
}
/// CPU GEMM `C(m×n) = A(m×k) · B(k×n)` on row-major slices.
///
/// Prefers trueno's BLIS-style kernel; if that reports an error, falls
/// back to the reference triple loop `C[i][j] = Σ_p A[i][p] · B[p][j]`.
fn cpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    if trueno::blis::gemm(m, n, k, a, b, &mut c).is_err() {
        for row in 0..m {
            for col in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += a[row * k + p] * b[p * n + col];
                }
                c[row * n + col] = acc;
            }
        }
    }
    c
}
// When true, per-op wgpu matmul dispatch is suppressed (see
// `suppress_per_op_wgpu` / `unsuppress_per_op_wgpu`); `matmul_compute`
// then stays on the CPU/CUDA paths.
#[cfg(feature = "gpu")]
static WGPU_BATCH_MODE: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
/// Enters batch mode: per-op wgpu matmul dispatch is skipped until
/// `unsuppress_per_op_wgpu` is called.
#[cfg(feature = "gpu")]
pub fn suppress_per_op_wgpu() {
    use std::sync::atomic::Ordering;
    WGPU_BATCH_MODE.store(true, Ordering::Relaxed);
}
/// Leaves batch mode: per-op wgpu matmul dispatch is allowed again.
#[cfg(feature = "gpu")]
pub fn unsuppress_per_op_wgpu() {
    use std::sync::atomic::Ordering;
    WGPU_BATCH_MODE.store(false, Ordering::Relaxed);
}
/// Dense f32 GEMM `C(m×n) = A(m×k) · B(k×n)` on raw row-major slices,
/// non-CUDA build. Tries the wgpu path for large problems (unless per-op
/// GPU use is suppressed via batch mode), otherwise runs on the CPU.
#[cfg(not(all(feature = "realizar", feature = "cuda")))]
pub fn matmul_compute(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    #[cfg(feature = "gpu")]
    {
        let suppressed = WGPU_BATCH_MODE.load(std::sync::atomic::Ordering::Relaxed);
        // Below ~32k multiply-accumulates the GPU dispatch overhead is
        // presumably not worth it — TODO confirm threshold.
        if !suppressed && m * k * n > 32_768 {
            if let Some(result) = wgpu_matmul(a, b, m, k, n) {
                return result;
            }
        }
    }
    cpu_matmul(a, b, m, k, n)
}
/// Attempts a wgpu-backed GEMM. Returns `None` when the GPU is unavailable,
/// device init failed, or a previous dispatch failed — after any failure
/// the path is permanently disabled for this process.
#[cfg(feature = "gpu")]
fn wgpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Option<Vec<f32>> {
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
    use std::sync::OnceLock;
    // Function-local process-wide state: sticky disable flag, one-shot
    // "GPU active" log guard, and a dispatch counter for periodic logging.
    static WGPU_DISABLED: AtomicBool = AtomicBool::new(false);
    static WGPU_LOGGED: AtomicBool = AtomicBool::new(false);
    static WGPU_CALLS: AtomicU64 = AtomicU64::new(0);
    static WGPU_DEVICE: OnceLock<Option<trueno::backends::gpu::GpuDevice>> = OnceLock::new();
    if WGPU_DISABLED.load(Ordering::Relaxed) {
        return None;
    }
    // Lazily initialize the device once; a failed init caches `None`.
    let device_opt = WGPU_DEVICE.get_or_init(|| {
        if !trueno::backends::gpu::GpuBackend::is_available() {
            eprintln!("[wgpu] No GPU available, using CPU");
            return None;
        }
        match trueno::backends::gpu::GpuDevice::new() {
            Ok(d) => {
                eprintln!("[wgpu] GPU device initialized for matmul");
                Some(d)
            }
            Err(e) => {
                eprintln!("[wgpu] GPU init failed: {e}, using CPU");
                None
            }
        }
    });
    let device = match device_opt.as_ref() {
        Some(d) => d,
        None => {
            // Init failed earlier: short-circuit all future calls.
            WGPU_DISABLED.store(true, Ordering::Relaxed);
            return None;
        }
    };
    let mut result = vec![0.0f32; m * n];
    match device.matmul(a, b, &mut result, m, k, n) {
        Ok(()) => {
            let calls = WGPU_CALLS.fetch_add(1, Ordering::Relaxed);
            // Log activation once, then progress every 10k dispatches.
            if !WGPU_LOGGED.swap(true, Ordering::Relaxed) {
                eprintln!("[wgpu] GPU matmul active ({m}x{k}x{n})");
            }
            if calls > 0 && calls.is_multiple_of(10_000) {
                eprintln!("[wgpu] {calls} GPU matmuls completed");
            }
            Some(result)
        }
        Err(_e) => {
            // Any dispatch error permanently disables the wgpu path.
            WGPU_DISABLED.store(true, Ordering::Relaxed);
            None
        }
    }
}
/// Autograd-tracked matrix multiply `C(m×n) = A(m×k) · B(k×n)` over
/// flat row-major tensors.
///
/// When either operand requires gradients, a `MatmulBackward` node is
/// attached so `backward()` propagates into both operands.
///
/// Panics if operand lengths don't match the stated shapes or if the
/// underlying storage is not contiguous.
#[provable_contracts_macros::contract("matmul-v1", equation = "matmul")]
pub fn matmul(a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Tensor {
    assert_eq!(a.len(), m * k, "Matrix A size mismatch");
    assert_eq!(b.len(), k * n, "Matrix B size mismatch");
    let product = {
        let a_data = a.data();
        let b_data = b.data();
        matmul_compute(
            a_data.as_slice().expect("matrix A must be contiguous"),
            b_data.as_slice().expect("matrix B must be contiguous"),
            m,
            k,
            n,
        )
    };
    let track = a.requires_grad() || b.requires_grad();
    let mut result = Tensor::new(Array1::from(product), track);
    if track {
        result.set_backward_op(Rc::new(MatmulBackward {
            a: a.clone(),
            b: b.clone(),
            m,
            k,
            n,
            result_grad: result.grad_cell(),
        }));
    }
    result
}
/// Backward node for `matmul` (C = A·B), holding both operands and the
/// forward shape so gradients can be computed for each side.
struct MatmulBackward {
    /// Left operand A (m×k).
    a: Tensor,
    /// Right operand B (k×n).
    b: Tensor,
    m: usize,
    k: usize,
    n: usize,
    /// Shared gradient cell of the forward result C (m×n).
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
impl BackwardOp for MatmulBackward {
    /// For C = A·B: ∂L/∂A = ∂L/∂C · Bᵀ and ∂L/∂B = Aᵀ · ∂L/∂C.
    /// Gradients are only computed for operands that require them, then the
    /// chain continues into both parents.
    fn backward(&self) {
        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
            let grad_c = grad_output.as_slice().expect("gradient output must be contiguous");
            let a_data = self.a.data();
            let b_data = self.b.data();
            let a_slice = a_data.as_slice().expect("matrix A must be contiguous");
            let b_slice = b_data.as_slice().expect("matrix B must be contiguous");
            if self.a.requires_grad() {
                // grad_A (m×k) = grad_C (m×n) · Bᵀ (n×k)
                let b_t = transpose(b_slice, self.k, self.n);
                let grad_a = matmul_compute(grad_c, &b_t, self.m, self.n, self.k);
                self.a.accumulate_grad(Array1::from(grad_a));
            }
            if self.b.requires_grad() {
                // grad_B (k×n) = Aᵀ (k×m) · grad_C (m×n)
                let a_t = transpose(a_slice, self.m, self.k);
                let grad_b = matmul_compute(&a_t, grad_c, self.k, self.m, self.n);
                self.b.accumulate_grad(Array1::from(grad_b));
            }
            // Recurse into both parents to continue the backward pass.
            if let Some(op) = self.a.backward_op() {
                op.backward();
            }
            if let Some(op) = self.b.backward_op() {
                op.backward();
            }
        }
    }
}
/// Autograd-tracked "NT" matrix multiply `C(m×n) = A(m×k) · Bᵀ` where B is
/// stored row-major as n×k (i.e. B's rows are the columns multiplied
/// against A). Attaches a `MatmulNtBackward` node when either operand
/// requires gradients.
///
/// Panics if operand lengths don't match the stated shapes or if the
/// underlying storage is not contiguous.
#[provable_contracts_macros::contract("matmul-v1", equation = "matmul_nt")]
pub fn matmul_nt(a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Tensor {
    assert_eq!(
        a.len(),
        m * k,
        "Matrix A size mismatch: expected {}×{} = {}, got {}",
        m,
        k,
        m * k,
        a.len()
    );
    assert_eq!(
        b.len(),
        n * k,
        "Matrix B size mismatch: expected {}×{} = {}, got {}",
        n,
        k,
        n * k,
        b.len()
    );
    let a_guard = a.data();
    let b_guard = b.data();
    let product = matmul_nt_compute(
        a_guard.as_slice().expect("matrix A must be contiguous"),
        b_guard.as_slice().expect("matrix B must be contiguous"),
        m,
        k,
        n,
    );
    let track = a.requires_grad() || b.requires_grad();
    let mut result = Tensor::new(Array1::from(product), track);
    if track {
        result.set_backward_op(Rc::new(MatmulNtBackward {
            a: a.clone(),
            b: b.clone(),
            m,
            k,
            n,
            result_grad: result.grad_cell(),
        }));
    }
    result
}
/// `C(m×n) = A(m×k) · Bᵀ` on raw slices, where B is stored row-major as
/// n×k. Materializes Bᵀ (k×n) and reuses the plain CPU GEMM.
///
/// NOTE(review): this always takes the CPU path, unlike `matmul_compute`
/// which may dispatch to CUDA/wgpu — presumably intentional; confirm.
pub fn matmul_nt_compute(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let b_transposed = transpose(b, n, k);
    cpu_matmul(a, &b_transposed, m, k, n)
}
/// Backward node for `matmul_nt` (C = A·Bᵀ with B stored n×k), holding
/// both operands and the forward shape.
struct MatmulNtBackward {
    /// Left operand A (m×k).
    a: Tensor,
    /// Right operand B (n×k); multiplied transposed in the forward pass.
    b: Tensor,
    m: usize,
    k: usize,
    n: usize,
    /// Shared gradient cell of the forward result C (m×n).
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
impl BackwardOp for MatmulNtBackward {
    /// For C = A·Bᵀ (B stored n×k): ∂L/∂A = ∂L/∂C · B and
    /// ∂L/∂B = (∂L/∂C)ᵀ · A — no explicit transpose of B is needed for
    /// grad_A because B's row-major n×k layout already matches.
    fn backward(&self) {
        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
            let grad_c = grad_output.as_slice().expect("gradient output must be contiguous");
            if self.a.requires_grad() {
                // grad_A (m×k) = grad_C (m×n) · B (n×k)
                let b_data = self.b.data();
                let b_slice = b_data.as_slice().expect("matrix B must be contiguous");
                let grad_a = matmul_compute(grad_c, b_slice, self.m, self.n, self.k);
                self.a.accumulate_grad(Array1::from(grad_a));
            }
            if self.b.requires_grad() {
                // grad_B (n×k) = grad_Cᵀ (n×m) · A (m×k)
                let a_data = self.a.data();
                let a_slice = a_data.as_slice().expect("matrix A must be contiguous");
                let grad_c_t = transpose(grad_c, self.m, self.n);
                let grad_b = matmul_compute(&grad_c_t, a_slice, self.n, self.m, self.k);
                self.b.accumulate_grad(Array1::from(grad_b));
            }
            // Recurse into both parents to continue the backward pass.
            if let Some(op) = self.a.backward_op() {
                op.backward();
            }
            if let Some(op) = self.b.backward_op() {
                op.backward();
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- transpose: fixed unit cases ----

    #[test]
    fn test_transpose_identity() {
        // 1×1 transpose is the identity.
        let data = vec![5.0];
        let result = transpose(&data, 1, 1);
        assert_eq!(result, vec![5.0]);
    }

    #[test]
    fn test_transpose_2x3() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = transpose(&data, 2, 3);
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_3x2() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result = transpose(&data, 3, 2);
        assert_eq!(result, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    // ---- matmul_compute: raw-slice GEMM ----

    #[test]
    fn test_matmul_compute_2x2() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = matmul_compute(&a, &b, 2, 2, 2);
        assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matmul_compute_2x3_3x2() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c = matmul_compute(&a, &b, 2, 3, 2);
        assert_eq!(c, vec![58.0, 64.0, 139.0, 154.0]);
    }

    // ---- matmul: autograd wiring ----

    #[test]
    fn test_matmul_no_grad() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
        let c = matmul(&a, &b, 2, 2, 2);
        assert!(!c.requires_grad());
        assert_eq!(
            c.data().as_slice().expect("operation should succeed"),
            &[19.0, 22.0, 43.0, 50.0]
        );
    }

    #[test]
    fn test_matmul_with_grad() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
        let c = matmul(&a, &b, 2, 2, 2);
        assert!(c.requires_grad());
        assert!(c.backward_op().is_some());
    }

    #[test]
    fn test_matmul_backward() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
        let c = matmul(&a, &b, 2, 2, 2);
        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
        if let Some(op) = c.backward_op() {
            op.backward();
        }
        assert!(a.grad().is_some());
        assert!(b.grad().is_some());
    }

    // Gradient should flow only to operands that require it.
    #[test]
    fn test_matmul_a_requires_grad_only() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
        let c = matmul(&a, &b, 2, 2, 2);
        assert!(c.requires_grad());
        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
        if let Some(op) = c.backward_op() {
            op.backward();
        }
        assert!(a.grad().is_some());
        assert!(b.grad().is_none());
    }

    #[test]
    fn test_matmul_b_requires_grad_only() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
        let c = matmul(&a, &b, 2, 2, 2);
        assert!(c.requires_grad());
        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
        if let Some(op) = c.backward_op() {
            op.backward();
        }
        assert!(a.grad().is_none());
        assert!(b.grad().is_some());
    }

    // The expected panic messages come from the #[contract] macro's
    // pre-condition checks, not the inline assert_eq! messages —
    // presumably the macro checks fire first; confirm against the macro.
    #[test]
    #[should_panic(expected = "Contract [matmul] Pre-condition violated: a.len() == m * k")]
    fn test_matmul_size_mismatch_a() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0]), false);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
        let _ = matmul(&a, &b, 2, 2, 2);
    }

    #[test]
    #[should_panic(expected = "Contract [matmul] Pre-condition violated: b.len() == k * n")]
    fn test_matmul_size_mismatch_b() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0]), false);
        let _ = matmul(&a, &b, 2, 2, 2);
    }

    #[test]
    fn test_transpose_double_transpose() {
        // (Xᵀ)ᵀ == X round-trip property.
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t1 = transpose(&data, 2, 3);
        let t2 = transpose(&t1, 3, 2);
        assert_eq!(data, t2);
    }

    // ---- falsification-style checks (fixed cases) ----

    #[test]
    fn falsify_mm_001e_shape_correctness() {
        // Output length must be m·n for several shape combinations.
        for (m, k, n) in [(2, 3, 4), (1, 5, 1), (4, 4, 4), (3, 1, 2)] {
            let result = matmul_compute(&vec![1.0; m * k], &vec![1.0; k * n], m, k, n);
            assert_eq!(
                result.len(),
                m * n,
                "FALSIFIED MM-001e: output len = {}, expected {} for ({m}x{k}) @ ({k}x{n})",
                result.len(),
                m * n
            );
        }
    }

    #[test]
    fn falsify_mm_005e_identity_matrix() {
        // A · I == A.
        let m = 3;
        let k = 4;
        let a: Vec<f32> = (0..m * k).map(|i| (i as f32 + 1.0) * 0.5).collect();
        let mut identity = vec![0.0; k * k];
        for i in 0..k {
            identity[i * k + i] = 1.0;
        }
        let result = matmul_compute(&a, &identity, m, k, k);
        for (i, (&got, &exp)) in result.iter().zip(a.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-5,
                "FALSIFIED MM-005e: (A@I)[{i}] = {got}, expected {exp}"
            );
        }
    }

    #[test]
    fn falsify_mm_002e_numerical_accuracy() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let result = matmul_compute(&a, &b, 2, 3, 2);
        let expected = [58.0, 64.0, 139.0, 154.0];
        for (i, (&got, &exp)) in result.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-4,
                "FALSIFIED MM-002e: result[{i}] = {got}, expected {exp}"
            );
        }
    }

    // ---- matmul_nt: forward and backward ----

    #[test]
    fn test_matmul_nt_compute_2x2() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = matmul_nt_compute(&a, &b, 2, 2, 2);
        assert_eq!(c, vec![17.0, 23.0, 39.0, 53.0]);
    }

    #[test]
    fn test_matmul_nt_compute_2x3_4x3() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        let c = matmul_nt_compute(&a, &b, 2, 3, 4);
        assert_eq!(c, vec![1.0, 2.0, 3.0, 6.0, 4.0, 5.0, 6.0, 15.0]);
    }

    #[test]
    fn test_matmul_nt_equivalence_to_transpose_matmul() {
        // matmul_nt(A, B) must equal matmul(A, Bᵀ).
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let b_t = transpose(&b, 2, 3);
        let c_nt = matmul_nt_compute(&a, &b, 2, 3, 2);
        let c_ref = matmul_compute(&a, &b_t, 2, 3, 2);
        for (i, (&got, &exp)) in c_nt.iter().zip(c_ref.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-5,
                "matmul_nt[{i}] = {got}, matmul(A, B^T)[{i}] = {exp}"
            );
        }
    }

    #[test]
    fn test_matmul_nt_backward_grad_flows_to_b() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
        let c = matmul_nt(&a, &b, 2, 2, 2);
        assert!(c.requires_grad());
        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
        if let Some(op) = c.backward_op() {
            op.backward();
        }
        // grad_B = grad_Cᵀ·A with ones upstream: column sums of A.
        let b_grad = b.grad().expect("KAIZEN-011: B must receive gradient from matmul_nt");
        let expected_grad_b = vec![4.0, 6.0, 4.0, 6.0];
        for (i, (&got, &exp)) in b_grad.iter().zip(expected_grad_b.iter()).enumerate() {
            assert!((got - exp).abs() < 1e-4, "KAIZEN-011: grad_B[{i}] = {got}, expected {exp}");
        }
    }

    #[test]
    fn test_matmul_nt_backward_grad_flows_to_a() {
        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
        let c = matmul_nt(&a, &b, 2, 2, 2);
        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
        if let Some(op) = c.backward_op() {
            op.backward();
        }
        // grad_A = grad_C·B with ones upstream: column sums of B.
        let a_grad = a.grad().expect("A must receive gradient");
        let expected_grad_a = vec![12.0, 14.0, 12.0, 14.0];
        for (i, (&got, &exp)) in a_grad.iter().zip(expected_grad_a.iter()).enumerate() {
            assert!((got - exp).abs() < 1e-4, "grad_A[{i}] = {got}, expected {exp}");
        }
    }

    // ---- property-based falsification via proptest ----

    mod mm_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            // Output length is m·n for arbitrary small shapes.
            #[test]
            fn falsify_mm_001e_prop_shape(
                m in 1..=8usize,
                k in 1..=8usize,
                n in 1..=8usize,
            ) {
                let result = matmul_compute(&vec![1.0; m * k], &vec![1.0; k * n], m, k, n);
                prop_assert_eq!(result.len(), m * n);
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]
            // A · I == A for arbitrary seeded contents.
            #[test]
            fn falsify_mm_005e_prop_identity(
                m in 1..=6usize,
                k in 1..=6usize,
                seed in 0..500u32,
            ) {
                let a: Vec<f32> = (0..m * k)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                    .collect();
                let mut identity = vec![0.0; k * k];
                for i in 0..k {
                    identity[i * k + i] = 1.0;
                }
                let result = matmul_compute(&a, &identity, m, k, k);
                for (i, (&got, &exp)) in result.iter().zip(a.iter()).enumerate() {
                    prop_assert!(
                        (got - exp).abs() < 1e-4,
                        "FALSIFIED MM-005e-prop: (A@I)[{}] = {}, expected {}",
                        i, got, exp
                    );
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]
            // matmul_nt(A, B) == matmul(A, Bᵀ) for arbitrary small shapes.
            #[test]
            fn falsify_mm_nt_equivalence(
                m in 1..=6usize,
                k in 1..=6usize,
                n in 1..=6usize,
                seed in 0..500u32,
            ) {
                let a: Vec<f32> = (0..m * k)
                    .map(|i| ((i as f32 + seed as f32) * 0.31).sin())
                    .collect();
                let b: Vec<f32> = (0..n * k)
                    .map(|i| ((i as f32 + seed as f32 + 100.0) * 0.47).cos())
                    .collect();
                let c_nt = matmul_nt_compute(&a, &b, m, k, n);
                let b_t = transpose(&b, n, k);
                let c_ref = matmul_compute(&a, &b_t, m, k, n);
                for (i, (&got, &exp)) in c_nt.iter().zip(c_ref.iter()).enumerate() {
                    prop_assert!(
                        (got - exp).abs() < 1e-3,
                        "matmul_nt[{}] = {}, expected {}",
                        i, got, exp
                    );
                }
            }
        }
    }

    // ---- transpose_tracked: gradient flow ----

    #[test]
    fn test_transpose_tracked_backward_gradient_flow() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], true);
        let a_t = transpose_tracked(&a, 2, 3);
        assert_eq!(a_t.len(), 6);
        let at_data = a_t.data();
        let at_slice = at_data.as_slice().expect("contiguous");
        assert_eq!(at_slice, &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        // Gradient on the transposed tensor must come back un-transposed.
        a_t.set_grad(Array1::from(vec![10.0, 40.0, 20.0, 50.0, 30.0, 60.0]));
        if let Some(op) = a_t.backward_op() {
            op.backward();
        }
        let grad = a.grad().expect("original tensor should have gradient");
        let grad_slice = grad.as_slice().expect("contiguous");
        assert_eq!(grad_slice, &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);
    }

    #[test]
    fn test_transpose_tracked_lora_gradient_chain() {
        // Chain: x @ transpose(lora_a) — gradient must reach lora_a
        // through both the matmul and the transpose backward nodes.
        let lora_a = Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], true);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let lora_a_t = transpose_tracked(&lora_a, 2, 3);
        let result = matmul(&x, &lora_a_t, 1, 3, 2);
        assert_eq!(result.len(), 2);
        result.set_grad(Array1::from(vec![1.0, 1.0]));
        if let Some(op) = result.backward_op() {
            op.backward();
        }
        let grad = lora_a.grad().expect("LoRA A should receive gradient via transpose_tracked");
        assert_eq!(grad.len(), 6);
        for (i, &val) in grad.as_slice().expect("contiguous").iter().enumerate() {
            assert!(val.is_finite(), "Gradient element {i} is not finite: {val}");
        }
        let grad_sum: f32 = grad.iter().sum();
        assert!(grad_sum.abs() > 1e-6, "Gradient should be non-zero, got sum={grad_sum}");
    }
}