Skip to main content

entrenar/autograd/ops/
matmul.rs

1//! Matrix multiplication autograd operations
2//!
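//! `matmul_compute` dispatches to realizar's CUDA executor when the `realizar` and `cuda`
//! features are enabled, then to an optional wgpu path (`gpu` feature), and finally falls
//! back to trueno SIMD GEMM on CPU.
//! `matmul` routes its forward AND backward GEMMs through that dispatch; `matmul_nt`
//! computes its forward product on the CPU and dispatches only its backward GEMMs.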
5//! Instrumented with TRACER for empirical overhead analysis.
6
7use crate::autograd::{BackwardOp, Tensor};
8use crate::trace::{TraceStep, TRACER};
9use ndarray::Array1;
10use std::cell::RefCell;
11use std::rc::Rc;
12
13#[cfg(all(feature = "realizar", feature = "cuda"))]
14use std::sync::atomic::{AtomicBool, Ordering};
15#[cfg(all(feature = "realizar", feature = "cuda"))]
16use std::sync::{Mutex, OnceLock};
17
18#[cfg(all(feature = "realizar", feature = "cuda"))]
19use realizar::cuda::CudaExecutor;
20
21/// Once a realizador CUDA matmul fails (typically JIT OOM after GPU VRAM is filled
22/// by NF4 block upload), disable all further attempts. Without this flag, every
23/// matmul call re-attempts CUDA, fails, and falls back to CPU — producing thousands
24/// of log lines per training step and adding ~100ms overhead per call.
25#[cfg(all(feature = "realizar", feature = "cuda"))]
26static CUDA_MATMUL_DISABLED: AtomicBool = AtomicBool::new(false);
27
28/// Global CUDA executor (singleton, initialized once)
29#[cfg(all(feature = "realizar", feature = "cuda"))]
30static CUDA_EXECUTOR: OnceLock<Option<Mutex<CudaExecutor>>> = OnceLock::new();
31
/// Returns the process-wide CUDA executor, creating it on first use.
///
/// The executor is constructed with `CudaExecutor::new(0)`. The outcome is cached
/// in `CUDA_EXECUTOR`: on failure `None` is stored, `CUDA_MATMUL_DISABLED` is set,
/// and every later call returns `None` without retrying.
#[cfg(all(feature = "realizar", feature = "cuda"))]
fn get_cuda_executor() -> Option<&'static Mutex<CudaExecutor>> {
    CUDA_EXECUTOR
        .get_or_init(|| {
            CudaExecutor::new(0)
                .map(|executor| {
                    TRACER.end(TraceStep::Transfer, "realizar CUDA executor initialized on GPU 0");
                    Mutex::new(executor)
                })
                .map_err(|_init_error| CUDA_MATMUL_DISABLED.store(true, Ordering::Relaxed))
                .ok()
        })
        .as_ref()
}
48
49/// Transpose a row-major matrix (rows x cols) to (cols x rows)
50/// Uses cache-efficient blocked transpose for large matrices
51#[inline]
52pub fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
53    contract_pre_transpose!(data);
54    TRACER.start(TraceStep::Transpose);
55    let mut transposed = vec![0.0f32; rows * cols];
56
57    const BLOCK_SIZE: usize = 32;
58    if rows >= BLOCK_SIZE && cols >= BLOCK_SIZE {
59        transpose_blocked(data, &mut transposed, rows, cols, BLOCK_SIZE);
60    } else {
61        transpose_simple(data, &mut transposed, rows, cols);
62    }
63
64    TRACER.end(TraceStep::Transpose, format!("{rows}x{cols}"));
65    transposed
66}
67
68/// Autograd-aware transpose that preserves the backward chain (KAIZEN-018).
69///
70/// Creates a new tensor with transposed data AND a backward op that
71/// accumulates the inverse-transposed gradient on the original tensor.
72/// This ensures gradient flow through LoRA weight transposes.
73///
74/// # Contract (C-LORA-GRAD-001)
75///
76/// - **Precondition**: `tensor` has shape (rows, cols) in row-major layout
77/// - **Postcondition**: Returns tensor with shape (cols, rows), backward chain connected
78/// - **Invariant**: `original.grad()` receives the correctly transposed gradient
79pub fn transpose_tracked(tensor: &Tensor, rows: usize, cols: usize) -> Tensor {
80    contract_pre_transpose_tracked!();
81    let data = tensor.data();
82    let slice = data.as_slice().expect("transpose_tracked: tensor must be contiguous");
83    let transposed_data = transpose(slice, rows, cols);
84    let mut result = Tensor::from_vec(transposed_data, tensor.requires_grad());
85
86    if tensor.requires_grad() {
87        let backward_op = Rc::new(TransposeBackward {
88            original: tensor.clone(),
89            rows,
90            cols,
91            result_grad: result.grad_cell(),
92        });
93        result.set_backward_op(backward_op);
94    }
95
96    result
97}
98
99/// Backward op for autograd-aware transpose (KAIZEN-018).
100///
101/// Given forward: result = transpose(original, rows, cols)
102/// Backward: grad_original = transpose(grad_result, cols, rows)
103/// (The inverse of an (r,c) transpose is a (c,r) transpose.)
104struct TransposeBackward {
105    original: Tensor,
106    rows: usize,
107    cols: usize,
108    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
109}
110
111impl BackwardOp for TransposeBackward {
112    fn backward(&self) {
113        if let Some(grad) = self.result_grad.borrow().as_ref() {
114            let grad_slice = grad.as_slice().expect("gradient must be contiguous");
115            // Inverse transpose: (cols, rows) → (rows, cols)
116            let grad_original = transpose(grad_slice, self.cols, self.rows);
117            self.original.accumulate_grad(Array1::from(grad_original));
118            if let Some(op) = self.original.backward_op() {
119                op.backward();
120            }
121        }
122    }
123}
124
/// Tile-by-tile transpose of a row-major `rows x cols` matrix into `dst`.
///
/// The matrix is walked in `block x block` tiles (edge tiles are clipped to the
/// matrix bounds); within a tile, `dst[c * rows + r] = src[r * cols + c]`.
/// `block` must be non-zero.
#[inline]
fn transpose_blocked(src: &[f32], dst: &mut [f32], rows: usize, cols: usize, block: usize) {
    for row_start in (0..rows).step_by(block) {
        let row_stop = rows.min(row_start + block);
        for col_start in (0..cols).step_by(block) {
            let col_stop = cols.min(col_start + block);
            for row in row_start..row_stop {
                let src_base = row * cols;
                for col in col_start..col_stop {
                    dst[col * rows + row] = src[src_base + col];
                }
            }
        }
    }
}
140
/// Straightforward element-by-element transpose of a row-major `rows x cols`
/// matrix into `dst`, visiting `src` in row-major order.
#[inline]
fn transpose_simple(src: &[f32], dst: &mut [f32], rows: usize, cols: usize) {
    (0..rows).for_each(|row| {
        let row_offset = row * cols;
        (0..cols).for_each(|col| dst[col * rows + row] = src[row_offset + col]);
    });
}
150
/// Compute `C = A @ B` using realizar CUDA if available, else wgpu, else SIMD CPU.
///
/// `a` is `m x k`, `b` is `k x n`, both row-major; the result is `m x n`.
///
/// After the first CUDA failure (typically JIT OOM when VRAM is occupied by NF4
/// block uploads), `CUDA_MATMUL_DISABLED` is set and all subsequent calls skip
/// CUDA entirely. If the executor mutex cannot be locked, this call falls through
/// to the other backends without disabling CUDA.
#[cfg(all(feature = "realizar", feature = "cuda"))]
pub fn matmul_compute(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    contract_pre_matmul!(a);
    // Fast path: skip CUDA entirely once disabled (common during QLoRA training
    // where NF4 blocks fill VRAM before realizador can JIT-compile gemm_tiled)
    if !CUDA_MATMUL_DISABLED.load(Ordering::Relaxed) {
        if let Some(executor_mutex) = get_cuda_executor() {
            if let Ok(mut executor) = executor_mutex.lock() {
                match cuda_matmul(&mut executor, a, b, m, k, n) {
                    Ok(result) => return result,
                    Err(_e) => {
                        // First failure: disable all future CUDA matmul attempts
                        CUDA_MATMUL_DISABLED.store(true, Ordering::Relaxed);
                        TRACER.end(
                            TraceStep::Matmul,
                            "realizar CUDA matmul disabled (JIT failure), using trueno SIMD",
                        );
                    }
                }
            }
        }
    }

    // wgpu GPU fallback (AMD/Intel/Apple GPUs via Vulkan/Metal/DX12)
    // KAIZEN-004: Skip per-op wgpu when batched forward pass is active.
    // The work estimate saturates instead of overflowing, so a huge problem can
    // never wrap around to a small value (or panic in debug builds).
    #[cfg(feature = "gpu")]
    if !WGPU_BATCH_MODE.load(std::sync::atomic::Ordering::Relaxed)
        && m.saturating_mul(k).saturating_mul(n) > 32_768
    {
        if let Some(result) = wgpu_matmul(a, b, m, k, n) {
            return result;
        }
    }

    // trueno SIMD fallback (rayon-parallel if trueno/parallel enabled)
    cpu_matmul(a, b, m, k, n)
}
190
/// Pre-warm realizador's CUDA GEMM kernels for all training shapes.
///
/// Realizador JIT-compiles `gemm_tiled` per unique (M,K,N) shape. If compilation
/// happens after transformer block upload fills VRAM, JIT fails with
/// CUDA_ERROR_ILLEGAL_ADDRESS and `CUDA_MATMUL_DISABLED` gets set, forcing ALL
/// matmul to CPU SIMD (~100x slower).
///
/// This function pre-warms with every (M,K,N) triplet used during training:
/// - Forward: linear projections (Q,K,V,O), FFN (gate,up,down)
/// - Backward: transposed shapes for grad_A and grad_B
/// - LoRA: A and B projection shapes
/// - Classifier head
///
/// Call this BEFORE uploading transformer blocks (C-PREWARM-001).
///
/// # Returns
///
/// The number of distinct shapes whose zero-filled GEMM succeeded. Returns 0
/// without touching any state when the executor is unavailable or its mutex
/// cannot be locked.
///
/// # Side effects
///
/// Each failed shape is reported on stderr. `CUDA_MATMUL_DISABLED` is set only
/// when at least one shape was attempted and none succeeded; a configuration
/// whose shapes are all filtered out (zero dimensions) leaves CUDA enabled,
/// because nothing has actually failed.
#[cfg(all(feature = "realizar", feature = "cuda"))]
pub fn pre_warm_realizador_gemm(
    seq_len: usize,
    hidden_size: usize,
    kv_hidden_size: usize,
    intermediate_size: usize,
    lora_rank: usize,
    num_classes: usize,
) -> usize {
    let Some(executor_mutex) = get_cuda_executor() else {
        return 0;
    };
    let Ok(mut executor) = executor_mutex.lock() else {
        return 0;
    };

    // Collect all unique (M, K, N) shapes used during training
    let s = seq_len;
    let h = hidden_size;
    let kv = kv_hidden_size;
    let i = intermediate_size;
    let r = lora_rank;

    let mut shapes: Vec<(usize, usize, usize)> = vec![
        // Forward linear projections
        (s, h, h),  // Q, O projections
        (s, h, kv), // K, V projections
        (s, h, i),  // FFN gate, up
        (s, i, h),  // FFN down
        // LoRA forward
        (s, h, r),  // LoRA A (Q/O/gate/up)
        (s, r, h),  // LoRA B (Q/O)
        (s, kv, r), // LoRA A (K/V) — if kv != h
        (s, r, kv), // LoRA B (K/V)
        // Backward: grad_A = grad_C @ B^T → (M, N_fwd, K_fwd)
        // For (s,h,h): grad_A is (s,h,h) — same
        (s, kv, h), // K/V backward grad_A: (s, kv) @ (kv, h)
        (s, i, h),  // Gate/Up backward grad_A — same as FFN down forward
        (s, h, i),  // Down backward grad_A — same as FFN gate forward
        // Backward: grad_B = A^T @ grad_C → (K_fwd, M, N_fwd)
        (h, s, h),  // Q/O backward grad_B: (h, s) @ (s, h)
        (h, s, kv), // K/V backward grad_B: (h, s) @ (s, kv)
        (h, s, i),  // Gate/Up backward grad_B: (h, s) @ (s, i)
        (i, s, h),  // Down backward grad_B: (i, s) @ (s, h)
        // LoRA backward
        (s, r, h),  // LoRA A backward grad_A — same as LoRA B forward
        (h, s, r),  // LoRA A backward grad_B
        (s, h, r),  // LoRA B backward grad_A — same as LoRA A forward
        (r, s, h),  // LoRA B backward grad_B
        (r, s, kv), // LoRA B (K/V) backward grad_B
        // Classifier head
        (1, h, num_classes),
    ];

    // Drop zero-dimension shapes, then deduplicate what is left.
    shapes.retain(|&(m, k, n)| m > 0 && k > 0 && n > 0);
    shapes.sort_unstable();
    shapes.dedup();

    let mut warmed = 0usize;
    for &(m, k, n) in &shapes {
        // Zero-filled operands: only the kernel launch for this shape matters.
        let a = vec![0.0f32; m * k];
        let b = vec![0.0f32; k * n];
        match cuda_matmul(&mut executor, &a, &b, m, k, n) {
            Ok(_) => warmed += 1,
            Err(e) => {
                eprintln!("[CUDA] realizador GEMM pre-warm failed for ({m},{k},{n}): {e}");
            }
        }
    }

    // Only give up on CUDA when shapes were actually attempted and all failed.
    if warmed == 0 && !shapes.is_empty() {
        CUDA_MATMUL_DISABLED.store(true, Ordering::Relaxed);
    }

    warmed
}
285
/// CUDA matrix multiplication via realizar's `CudaExecutor`.
///
/// Computes the `m x n` product of row-major `a` (`m x k`) and `b` (`k x n`).
///
/// # Errors
///
/// Returns `Err` with a description when
/// - any of `m`, `n`, `k` does not fit in `u32` (the executor's `gemm` is called
///   with `u32` dimensions, and a truncated dimension would silently compute the
///   wrong product), or
/// - `executor.gemm` itself fails (the error is rendered with `{:?}`).
#[cfg(all(feature = "realizar", feature = "cuda"))]
fn cuda_matmul(
    executor: &mut CudaExecutor,
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> Result<Vec<f32>, String> {
    // Checked narrowing: refuse oversized dimensions rather than truncating them.
    // Done before any allocation or trace event so an early error leaves nothing open.
    let to_u32 = |name: &str, value: usize| -> Result<u32, String> {
        u32::try_from(value)
            .map_err(|_| format!("GEMM dimension {name}={value} does not fit in u32"))
    };
    let m32 = to_u32("m", m)?;
    let n32 = to_u32("n", n)?;
    let k32 = to_u32("k", k)?;

    TRACER.start(TraceStep::Alloc);
    let mut c = vec![0.0f32; m * n];
    TRACER.end(TraceStep::Alloc, format!("{m}x{n}"));

    TRACER.start(TraceStep::Matmul);
    executor.gemm(a, b, &mut c, m32, n32, k32).map_err(|e| format!("{e:?}"))?;
    TRACER.end(TraceStep::Matmul, format!("{m}x{k}x{n}"));
    Ok(c)
}
305
306/// CPU fallback using trueno SIMD GEMM
307fn cpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
308    let mut c = vec![0.0f32; m * n];
309
310    if let Err(_e) = trueno::blis::gemm(m, n, k, a, b, &mut c) {
311        // Naive triple-loop fallback (trueno BLIS should never fail in practice)
312        for i in 0..m {
313            for j in 0..n {
314                let mut sum = 0.0;
315                for p in 0..k {
316                    sum += a[i * k + p] * b[p * n + j];
317                }
318                c[i * n + j] = sum;
319            }
320        }
321    }
322
323    c
324}
325
326/// KAIZEN-004: When WgpuForwardPass is handling the forward pass in batch mode,
327/// suppress per-op wgpu matmul. Attention matmuls go to CPU SIMD instead,
328/// avoiding buffer upload/download overhead and GPU contention with the batched FFN path.
329#[cfg(feature = "gpu")]
330static WGPU_BATCH_MODE: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
331
/// Turn off per-op wgpu matmul so that `matmul_compute` skips the wgpu path.
///
/// Sets `WGPU_BATCH_MODE` to `true`. Intended to be called before running
/// attention on CPU while WgpuForwardPass handles FFN: per-op wgpu adds ~3-5ms
/// overhead per matmul (buffer upload/compute/download), which for 144 attention
/// matmuls per sample is 430-720ms of pure overhead.
#[cfg(feature = "gpu")]
pub fn suppress_per_op_wgpu() {
    use std::sync::atomic::Ordering::Relaxed;
    WGPU_BATCH_MODE.store(true, Relaxed);
}
342
/// Turn per-op wgpu matmul back on by clearing `WGPU_BATCH_MODE`.
#[cfg(feature = "gpu")]
pub fn unsuppress_per_op_wgpu() {
    use std::sync::atomic::Ordering::Relaxed;
    WGPU_BATCH_MODE.store(false, Relaxed);
}
348
349/// CPU/wgpu path (no CUDA feature)
350///
351/// Tries wgpu GPU matmul first (Vulkan/Metal/DX12), falls back to rayon-parallel
352/// trueno BLIS GEMM on CPU. The wgpu path uses trueno's GpuDevice for cross-platform
353/// GPU compute on AMD, Intel, and Apple GPUs.
354#[cfg(not(all(feature = "realizar", feature = "cuda")))]
355pub fn matmul_compute(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
356    #[cfg(feature = "gpu")]
357    {
358        // KAIZEN-004: Skip per-op wgpu when batched forward pass is active.
359        // Attention matmuls use CPU SIMD instead — equally fast, no buffer overhead.
360        if !WGPU_BATCH_MODE.load(std::sync::atomic::Ordering::Relaxed) && m * k * n > 32_768 {
361            if let Some(result) = wgpu_matmul(a, b, m, k, n) {
362                return result;
363            }
364        }
365    }
366    cpu_matmul(a, b, m, k, n)
367}
368
/// wgpu GPU matmul via trueno GpuDevice (Vulkan/Metal/DX12)
///
/// Uses a singleton GpuDevice to avoid per-call device creation overhead.
/// Returns `None` if the GPU is unavailable or the matmul fails, so the caller
/// can fall back to CPU. Any such failure sets a sticky flag: later calls
/// return `None` immediately without touching the device.
///
/// Logs to stderr once on device initialisation, once on the first successful
/// matmul, and then after every 10_000th successful matmul.
#[cfg(feature = "gpu")]
fn wgpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Option<Vec<f32>> {
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
    use std::sync::OnceLock;
    static WGPU_DISABLED: AtomicBool = AtomicBool::new(false);
    static WGPU_LOGGED: AtomicBool = AtomicBool::new(false);
    static WGPU_CALLS: AtomicU64 = AtomicU64::new(0);
    static WGPU_DEVICE: OnceLock<Option<trueno::backends::gpu::GpuDevice>> = OnceLock::new();

    if WGPU_DISABLED.load(Ordering::Relaxed) {
        return None;
    }

    let device_opt = WGPU_DEVICE.get_or_init(|| {
        if !trueno::backends::gpu::GpuBackend::is_available() {
            eprintln!("[wgpu] No GPU available, using CPU");
            return None;
        }
        match trueno::backends::gpu::GpuDevice::new() {
            Ok(d) => {
                eprintln!("[wgpu] GPU device initialized for matmul");
                Some(d)
            }
            Err(e) => {
                eprintln!("[wgpu] GPU init failed: {e}, using CPU");
                None
            }
        }
    });

    let Some(device) = device_opt.as_ref() else {
        // No usable device: remember that so later calls return early.
        WGPU_DISABLED.store(true, Ordering::Relaxed);
        return None;
    };

    let mut result = vec![0.0f32; m * n];
    match device.matmul(a, b, &mut result, m, k, n) {
        Ok(()) => {
            // `fetch_add` returns the value BEFORE the increment; add one so the
            // count includes the matmul that just finished.
            let completed = WGPU_CALLS.fetch_add(1, Ordering::Relaxed) + 1;
            if !WGPU_LOGGED.swap(true, Ordering::Relaxed) {
                eprintln!("[wgpu] GPU matmul active ({m}x{k}x{n})");
            }
            // KAIZEN-003: Demote to 10k intervals; previous 1k floods logs
            if completed.is_multiple_of(10_000) {
                eprintln!("[wgpu] {completed} GPU matmuls completed");
            }
            Some(result)
        }
        Err(_e) => {
            WGPU_DISABLED.store(true, Ordering::Relaxed);
            None
        }
    }
}
430
431/// Matrix multiplication
432///
433/// Computes C = A @ B where:
434/// - A is m×k (flattened to length m*k)
435/// - B is k×n (flattened to length k*n)
436/// - C is m×n (flattened to length m*n)
437///
438/// Uses GPU acceleration when available (requires `gpu` feature).
439///
440/// # Arguments
441/// * `a` - Left matrix (m×k flattened)
442/// * `b` - Right matrix (k×n flattened)
443/// * `m` - Number of rows in A
444/// * `k` - Number of columns in A (= rows in B)
445/// * `n` - Number of columns in B
446#[provable_contracts_macros::contract("matmul-v1", equation = "matmul")]
447pub fn matmul(a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Tensor {
448    assert_eq!(a.len(), m * k, "Matrix A size mismatch");
449    assert_eq!(b.len(), k * n, "Matrix B size mismatch");
450
451    // Compute C = A @ B using GPU if available
452    let result_data = matmul_compute(
453        a.data().as_slice().expect("matrix A must be contiguous"),
454        b.data().as_slice().expect("matrix B must be contiguous"),
455        m,
456        k,
457        n,
458    );
459
460    let requires_grad = a.requires_grad() || b.requires_grad();
461    let mut result = Tensor::new(Array1::from(result_data), requires_grad);
462
463    if requires_grad {
464        let a_clone = a.clone();
465        let b_clone = b.clone();
466        let backward_op = Rc::new(MatmulBackward {
467            a: a_clone,
468            b: b_clone,
469            m,
470            k,
471            n,
472            result_grad: result.grad_cell(),
473        });
474        result.set_backward_op(backward_op);
475    }
476
477    contract_post_matmul!(result.data().as_slice().unwrap_or(&[]));
478    result
479}
480
481struct MatmulBackward {
482    a: Tensor,
483    b: Tensor,
484    m: usize,
485    k: usize,
486    n: usize,
487    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
488}
489
490impl BackwardOp for MatmulBackward {
491    fn backward(&self) {
492        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
493            // ∂L/∂A = ∂L/∂C @ B^T  (m×n) @ (n×k) = (m×k)
494            // ∂L/∂B = A^T @ ∂L/∂C  (k×m) @ (m×n) = (k×n)
495
496            let grad_c = grad_output.as_slice().expect("gradient output must be contiguous");
497            let a_data = self.a.data();
498            let b_data = self.b.data();
499            let a_slice = a_data.as_slice().expect("matrix A must be contiguous");
500            let b_slice = b_data.as_slice().expect("matrix B must be contiguous");
501
502            if self.a.requires_grad() {
503                // grad_A = grad_C @ B^T
504                // grad_C is (m, n), B is (k, n), B^T is (n, k)
505                // Result: (m, n) @ (n, k) = (m, k)
506                let b_t = transpose(b_slice, self.k, self.n);
507                let grad_a = matmul_compute(grad_c, &b_t, self.m, self.n, self.k);
508                self.a.accumulate_grad(Array1::from(grad_a));
509            }
510
511            if self.b.requires_grad() {
512                // grad_B = A^T @ grad_C
513                // A is (m, k), A^T is (k, m), grad_C is (m, n)
514                // Result: (k, m) @ (m, n) = (k, n)
515                let a_t = transpose(a_slice, self.m, self.k);
516                let grad_b = matmul_compute(&a_t, grad_c, self.k, self.m, self.n);
517                self.b.accumulate_grad(Array1::from(grad_b));
518            }
519
520            // Recursively call backward on inputs
521            if let Some(op) = self.a.backward_op() {
522                op.backward();
523            }
524            if let Some(op) = self.b.backward_op() {
525                op.backward();
526            }
527        }
528    }
529}
530
531/// Matrix multiply with B transposed: C = A @ B^T (KAIZEN-011)
532///
533/// # Contract (C-MATMUL-NT-001)
534///
535/// - **Precondition**: A is (M, K), B is (N, K) — B's second dim matches A's second dim
536/// - **Postcondition**: Output is (M, N) where C(i,j) = Σ_k A(i,k) * B(j,k)
537/// - **Invariant**: Gradients flow to BOTH A and B (not transposed copies)
538///
539/// This is essential for LoRA where A_lora is stored as (rank, d_in) and we need
540/// `x @ A_lora^T` without creating a transposed copy that breaks gradient flow.
541///
542/// # Backward
543///
544/// - `∂L/∂A = ∂L/∂C @ B`  — (M,N) @ (N,K) = (M,K)
545/// - `∂L/∂B = ∂L/∂C^T @ A` — (N,M) @ (M,K) = (N,K)
546#[provable_contracts_macros::contract("matmul-v1", equation = "matmul_nt")]
547pub fn matmul_nt(a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Tensor {
548    assert_eq!(
549        a.len(),
550        m * k,
551        "Matrix A size mismatch: expected {}×{} = {}, got {}",
552        m,
553        k,
554        m * k,
555        a.len()
556    );
557    assert_eq!(
558        b.len(),
559        n * k,
560        "Matrix B size mismatch: expected {}×{} = {}, got {}",
561        n,
562        k,
563        n * k,
564        b.len()
565    );
566
567    let a_slice = a.data();
568    let b_slice = b.data();
569    let a_data = a_slice.as_slice().expect("matrix A must be contiguous");
570    let b_data = b_slice.as_slice().expect("matrix B must be contiguous");
571
572    // C = A @ B^T: C(i,j) = Σ_k A(i,k) * B(j,k)
573    let result_data = matmul_nt_compute(a_data, b_data, m, k, n);
574
575    let requires_grad = a.requires_grad() || b.requires_grad();
576    let mut result = Tensor::new(Array1::from(result_data), requires_grad);
577
578    if requires_grad {
579        let a_clone = a.clone();
580        let b_clone = b.clone();
581        let backward_op = Rc::new(MatmulNtBackward {
582            a: a_clone,
583            b: b_clone,
584            m,
585            k,
586            n,
587            result_grad: result.grad_cell(),
588        });
589        result.set_backward_op(backward_op);
590    }
591
592    contract_post_matmul!(result.data().as_slice().unwrap_or(&[]));
593    result
594}
595
596/// Raw compute for C = A @ B^T using trueno SIMD GEMM
597///
598/// A is (M, K), B is (N, K), output is (M, N)
599pub fn matmul_nt_compute(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
600    // Transpose B to (K, N) then use standard matmul
601    let b_t = transpose(b, n, k); // (N,K) → (K,N)
602    cpu_matmul(a, &b_t, m, k, n)
603}
604
605struct MatmulNtBackward {
606    a: Tensor,
607    b: Tensor,
608    m: usize,
609    k: usize,
610    n: usize,
611    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
612}
613
614impl BackwardOp for MatmulNtBackward {
615    fn backward(&self) {
616        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
617            // C = A @ B^T where A is (M,K), B is (N,K), C is (M,N)
618            //
619            // ∂L/∂A = ∂L/∂C @ B     (M,N) @ (N,K) = (M,K)
620            // ∂L/∂B = ∂L/∂C^T @ A   (N,M) @ (M,K) = (N,K)
621
622            let grad_c = grad_output.as_slice().expect("gradient output must be contiguous");
623
624            if self.a.requires_grad() {
625                // grad_A = grad_C @ B  (standard matmul)
626                let b_data = self.b.data();
627                let b_slice = b_data.as_slice().expect("matrix B must be contiguous");
628                let grad_a = matmul_compute(grad_c, b_slice, self.m, self.n, self.k);
629                self.a.accumulate_grad(Array1::from(grad_a));
630            }
631
632            if self.b.requires_grad() {
633                // grad_B = grad_C^T @ A
634                let a_data = self.a.data();
635                let a_slice = a_data.as_slice().expect("matrix A must be contiguous");
636                let grad_c_t = transpose(grad_c, self.m, self.n);
637                let grad_b = matmul_compute(&grad_c_t, a_slice, self.n, self.m, self.k);
638                self.b.accumulate_grad(Array1::from(grad_b));
639            }
640
641            // Recursively propagate
642            if let Some(op) = self.a.backward_op() {
643                op.backward();
644            }
645            if let Some(op) = self.b.backward_op() {
646                op.backward();
647            }
648        }
649    }
650}
651
652#[cfg(test)]
653mod tests {
654    use super::*;
655
656    #[test]
657    fn test_transpose_identity() {
658        // 1x1 matrix
659        let data = vec![5.0];
660        let result = transpose(&data, 1, 1);
661        assert_eq!(result, vec![5.0]);
662    }
663
664    #[test]
665    fn test_transpose_2x3() {
666        // 2x3 matrix
667        // [1, 2, 3]
668        // [4, 5, 6]
669        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
670        let result = transpose(&data, 2, 3);
671        // Expected 3x2:
672        // [1, 4]
673        // [2, 5]
674        // [3, 6]
675        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
676    }
677
678    #[test]
679    fn test_transpose_3x2() {
680        // 3x2 matrix
681        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
682        let result = transpose(&data, 3, 2);
683        // Expected 2x3:
684        assert_eq!(result, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
685    }
686
687    #[test]
688    fn test_matmul_compute_2x2() {
689        // A = [[1, 2], [3, 4]] (2x2)
690        // B = [[5, 6], [7, 8]] (2x2)
691        // C = A @ B = [[19, 22], [43, 50]]
692        let a = vec![1.0, 2.0, 3.0, 4.0];
693        let b = vec![5.0, 6.0, 7.0, 8.0];
694        let c = matmul_compute(&a, &b, 2, 2, 2);
695        assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
696    }
697
698    #[test]
699    fn test_matmul_compute_2x3_3x2() {
700        // A = [[1, 2, 3], [4, 5, 6]] (2x3)
701        // B = [[7, 8], [9, 10], [11, 12]] (3x2)
702        // C = A @ B (2x2)
703        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
704        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
705        let c = matmul_compute(&a, &b, 2, 3, 2);
706        // [1*7+2*9+3*11, 1*8+2*10+3*12] = [58, 64]
707        // [4*7+5*9+6*11, 4*8+5*10+6*12] = [139, 154]
708        assert_eq!(c, vec![58.0, 64.0, 139.0, 154.0]);
709    }
710
711    #[test]
712    fn test_matmul_no_grad() {
713        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
714        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
715        let c = matmul(&a, &b, 2, 2, 2);
716        assert!(!c.requires_grad());
717        assert_eq!(
718            c.data().as_slice().expect("operation should succeed"),
719            &[19.0, 22.0, 43.0, 50.0]
720        );
721    }
722
723    #[test]
724    fn test_matmul_with_grad() {
725        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
726        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
727        let c = matmul(&a, &b, 2, 2, 2);
728        assert!(c.requires_grad());
729        assert!(c.backward_op().is_some());
730    }
731
732    #[test]
733    fn test_matmul_backward() {
734        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
735        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
736        let c = matmul(&a, &b, 2, 2, 2);
737
738        // Set gradient of output
739        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
740
741        // Trigger backward
742        if let Some(op) = c.backward_op() {
743            op.backward();
744        }
745
746        // Check gradients are accumulated
747        assert!(a.grad().is_some());
748        assert!(b.grad().is_some());
749    }
750
751    #[test]
752    fn test_matmul_a_requires_grad_only() {
753        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true);
754        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
755        let c = matmul(&a, &b, 2, 2, 2);
756        assert!(c.requires_grad());
757
758        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
759        if let Some(op) = c.backward_op() {
760            op.backward();
761        }
762
763        assert!(a.grad().is_some());
764        assert!(b.grad().is_none());
765    }
766
767    #[test]
768    fn test_matmul_b_requires_grad_only() {
769        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
770        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true);
771        let c = matmul(&a, &b, 2, 2, 2);
772        assert!(c.requires_grad());
773
774        c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
775        if let Some(op) = c.backward_op() {
776            op.backward();
777        }
778
779        assert!(a.grad().is_none());
780        assert!(b.grad().is_some());
781    }
782
783    #[test]
784    #[should_panic(expected = "Pre-condition violated")]
785    fn test_matmul_size_mismatch_a() {
786        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0]), false);
787        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false);
788        let _ = matmul(&a, &b, 2, 2, 2);
789    }
790
791    #[test]
792    #[should_panic(expected = "Pre-condition violated")]
793    fn test_matmul_size_mismatch_b() {
794        let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false);
795        let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0]), false);
796        let _ = matmul(&a, &b, 2, 2, 2);
797    }
798
799    #[test]
800    fn test_transpose_double_transpose() {
801        // Transpose twice should give original
802        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
803        let t1 = transpose(&data, 2, 3);
804        let t2 = transpose(&t1, 3, 2);
805        assert_eq!(data, t2);
806    }
807
808    // =========================================================================
809    // FALSIFY-MM: matmul-kernel-v1.yaml contract (entrenar autograd matmul)
810    //
811    // Five-Whys (PMAT-354):
812    //   Why 1: entrenar had 10 matmul tests but zero FALSIFY-MM-* tests
813    //   Why 2: unit tests verify 2x2 cases and backward, not invariants
814    //   Why 3: no mapping from matmul-kernel-v1.yaml to entrenar test names
815    //   Why 4: entrenar predates the provable-contracts YAML convention
816    //   Why 5: matmul was "obviously correct" (textbook GEMM + autograd)
817    //
818    // References:
819    //   - provable-contracts/contracts/matmul-kernel-v1.yaml
820    // =========================================================================
821
/// FALSIFY-MM-001e: Shape correctness — a (m x k) @ (k x n) product holds m * n elements
#[test]
fn falsify_mm_001e_shape_correctness() {
    let shapes = [(2, 3, 4), (1, 5, 1), (4, 4, 4), (3, 1, 2)];
    for (m, k, n) in shapes {
        // Only the output length is checked, so all-ones operands suffice.
        let lhs = vec![1.0; m * k];
        let rhs = vec![1.0; k * n];
        let product = matmul_compute(&lhs, &rhs, m, k, n);
        let expected_len = m * n;
        assert_eq!(
            product.len(),
            expected_len,
            "FALSIFIED MM-001e: output len = {}, expected {} for ({m}x{k}) @ ({k}x{n})",
            product.len(),
            expected_len
        );
    }
}
836
/// FALSIFY-MM-005e: Identity matrix — A @ I = A
///
/// Multiplies a 3x4 matrix by the 4x4 identity and requires the product to
/// have the same length as `A` and to match it element-wise within 1e-5.
#[test]
fn falsify_mm_005e_identity_matrix() {
    let m = 3;
    let k = 4;
    let a: Vec<f32> = (0..m * k).map(|i| (i as f32 + 1.0) * 0.5).collect();
    // Row-major k x k identity: ones on the diagonal, zeros elsewhere.
    let mut identity = vec![0.0; k * k];
    for i in 0..k {
        identity[i * k + i] = 1.0;
    }
    let result = matmul_compute(&a, &identity, m, k, k);
    // `zip` stops at the shorter side, so without this check a truncated or
    // empty result would pass the element-wise loop vacuously.
    assert_eq!(
        result.len(),
        a.len(),
        "FALSIFIED MM-005e: (A@I) has {} elements, expected {}",
        result.len(),
        a.len()
    );
    for (i, (&got, &exp)) in result.iter().zip(a.iter()).enumerate() {
        assert!(
            (got - exp).abs() < 1e-5,
            "FALSIFIED MM-005e: (A@I)[{i}] = {got}, expected {exp}"
        );
    }
}
855
/// FALSIFY-MM-002e: Numerical accuracy against reference
///
/// Compares a 2x3 @ 3x2 product with hand-computed values, requiring the same
/// number of elements and agreement within 1e-4.
#[test]
fn falsify_mm_002e_numerical_accuracy() {
    // 2x3 @ 3x2 known result:
    //   [[1,2,3],[4,5,6]] @ [[7,8],[9,10],[11,12]] = [[58,64],[139,154]]
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let result = matmul_compute(&a, &b, 2, 3, 2);
    let expected = [58.0, 64.0, 139.0, 154.0];
    // `zip` stops at the shorter side; guard against a short result passing
    // the element-wise comparison without ever being compared in full.
    assert_eq!(
        result.len(),
        expected.len(),
        "FALSIFIED MM-002e: result has {} elements, expected {}",
        result.len(),
        expected.len()
    );
    for (i, (&got, &exp)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - exp).abs() < 1e-4,
            "FALSIFIED MM-002e: result[{i}] = {got}, expected {exp}"
        );
    }
}
871
872    // =========================================================================
873    // matmul_nt tests (KAIZEN-011)
874    // =========================================================================
875
/// `matmul_nt_compute` on two 2x2 operands: C = A @ B^T.
#[test]
fn test_matmul_nt_compute_2x2() {
    // A = [[1, 2], [3, 4]]
    // B = [[5, 6], [7, 8]]  =>  B^T = [[5, 7], [6, 8]]
    // C = A @ B^T
    //   = [[1*5 + 2*6, 1*7 + 2*8], [3*5 + 4*6, 3*7 + 4*8]]
    //   = [[17, 23], [39, 53]]
    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];
    let product = matmul_nt_compute(&lhs, &rhs, 2, 2, 2);
    assert_eq!(product, vec![17.0, 23.0, 39.0, 53.0]);
}
889
/// `matmul_nt_compute` with non-square operands: a 2x3 `A` against a 4x3 `B`.
#[test]
fn test_matmul_nt_compute_2x3_4x3() {
    // A = [[1,2,3],[4,5,6]] (2x3)
    // B = [[1,0,0],[0,1,0],[0,0,1],[1,1,1]] (4x3)
    // C = A @ B^T is 2x4; C(i,j) is the dot product of row i of A with row j of B.
    // The first three rows of B pick out single entries, the last row sums them:
    //   C row 0 = [1, 2, 3, 1+2+3] = [1, 2, 3, 6]
    //   C row 1 = [4, 5, 6, 4+5+6] = [4, 5, 6, 15]
    let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
    let product = matmul_nt_compute(&lhs, &rhs, 2, 3, 4);
    assert_eq!(product, vec![1.0, 2.0, 3.0, 6.0, 4.0, 5.0, 6.0, 15.0]);
}
908
/// `matmul_nt_compute(A, B)` must agree with `matmul_compute(A, B^T)` when
/// `B^T` is materialised through `transpose`.
#[test]
fn test_matmul_nt_equivalence_to_transpose_matmul() {
    // Verify: matmul_nt(A, B, m, k, n) == matmul(A, B^T, m, k, n)
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
    let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; // 2x3
    let b_t = transpose(&b, 2, 3); // 3x2

    let c_nt = matmul_nt_compute(&a, &b, 2, 3, 2);
    let c_ref = matmul_compute(&a, &b_t, 2, 3, 2);

    // `zip` stops at the shorter side, so differing output lengths would
    // otherwise go unnoticed by the element-wise comparison.
    assert_eq!(
        c_nt.len(),
        c_ref.len(),
        "matmul_nt has {} elements, matmul(A, B^T) has {}",
        c_nt.len(),
        c_ref.len()
    );
    for (i, (&got, &exp)) in c_nt.iter().zip(c_ref.iter()).enumerate() {
        assert!(
            (got - exp).abs() < 1e-5,
            "matmul_nt[{i}] = {got}, matmul(A, B^T)[{i}] = {exp}"
        );
    }
}
926
/// KAIZEN-011: the backward op of `matmul_nt` must deposit the gradient on
/// the original `B` tensor, with the expected length and values.
#[test]
fn test_matmul_nt_backward_grad_flows_to_b() {
    // KAIZEN-011: Verify gradients flow to the ORIGINAL B tensor (not a copy)
    let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), false); // 2x2
    let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), true); // 2x2, requires_grad

    let c = matmul_nt(&a, &b, 2, 2, 2);
    assert!(c.requires_grad());

    // Seed the output gradient with ones and run the recorded backward op.
    c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
    if let Some(op) = c.backward_op() {
        op.backward();
    }

    // B must have received gradient
    let b_grad = b.grad().expect("KAIZEN-011: B must receive gradient from matmul_nt");

    // grad_B = grad_C^T @ A = [[1,1],[1,1]]^T @ [[1,2],[3,4]]
    // = [[1,1],[1,1]] @ [[1,2],[3,4]] = [[4,6],[4,6]]
    let expected_grad_b = vec![4.0, 6.0, 4.0, 6.0];
    // `zip` stops at the shorter side; a gradient of the wrong length must
    // fail here rather than be compared only partially.
    assert_eq!(
        b_grad.len(),
        expected_grad_b.len(),
        "KAIZEN-011: grad_B has {} elements, expected {}",
        b_grad.len(),
        expected_grad_b.len()
    );
    for (i, (&got, &exp)) in b_grad.iter().zip(expected_grad_b.iter()).enumerate() {
        assert!((got - exp).abs() < 1e-4, "KAIZEN-011: grad_B[{i}] = {got}, expected {exp}");
    }
}
951
/// The backward op of `matmul_nt` must deposit the gradient on `A`, with the
/// expected length and values.
#[test]
fn test_matmul_nt_backward_grad_flows_to_a() {
    let a = Tensor::new(Array1::from(vec![1.0, 2.0, 3.0, 4.0]), true); // 2x2
    let b = Tensor::new(Array1::from(vec![5.0, 6.0, 7.0, 8.0]), false); // 2x2

    let c = matmul_nt(&a, &b, 2, 2, 2);
    // Seed the output gradient with ones and run the recorded backward op.
    c.set_grad(Array1::from(vec![1.0, 1.0, 1.0, 1.0]));
    if let Some(op) = c.backward_op() {
        op.backward();
    }

    let a_grad = a.grad().expect("A must receive gradient");

    // grad_A = grad_C @ B = [[1,1],[1,1]] @ [[5,6],[7,8]] = [[12,14],[12,14]]
    let expected_grad_a = vec![12.0, 14.0, 12.0, 14.0];
    // `zip` stops at the shorter side; a gradient of the wrong length must
    // fail here rather than be compared only partially.
    assert_eq!(
        a_grad.len(),
        expected_grad_a.len(),
        "grad_A has {} elements, expected {}",
        a_grad.len(),
        expected_grad_a.len()
    );
    for (i, (&got, &exp)) in a_grad.iter().zip(expected_grad_a.iter()).enumerate() {
        assert!((got - exp).abs() < 1e-4, "grad_A[{i}] = {got}, expected {exp}");
    }
}
971
972    mod mm_proptest_falsify {
973        use super::*;
974        use proptest::prelude::*;
975
976        // FALSIFY-MM-001e-prop: Shape correctness for random dimensions
977        proptest! {
978            #![proptest_config(ProptestConfig::with_cases(100))]
979
980            #[test]
981            fn falsify_mm_001e_prop_shape(
982                m in 1..=8usize,
983                k in 1..=8usize,
984                n in 1..=8usize,
985            ) {
986                let result = matmul_compute(&vec![1.0; m * k], &vec![1.0; k * n], m, k, n);
987                prop_assert_eq!(result.len(), m * n);
988            }
989        }
990
991        // FALSIFY-MM-005e-prop: Identity matrix for random dimensions
992        proptest! {
993            #![proptest_config(ProptestConfig::with_cases(50))]
994
995            #[test]
996            fn falsify_mm_005e_prop_identity(
997                m in 1..=6usize,
998                k in 1..=6usize,
999                seed in 0..500u32,
1000            ) {
1001                let a: Vec<f32> = (0..m * k)
1002                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
1003                    .collect();
1004                let mut identity = vec![0.0; k * k];
1005                for i in 0..k {
1006                    identity[i * k + i] = 1.0;
1007                }
1008                let result = matmul_compute(&a, &identity, m, k, k);
1009                for (i, (&got, &exp)) in result.iter().zip(a.iter()).enumerate() {
1010                    prop_assert!(
1011                        (got - exp).abs() < 1e-4,
1012                        "FALSIFIED MM-005e-prop: (A@I)[{}] = {}, expected {}",
1013                        i, got, exp
1014                    );
1015                }
1016            }
1017        }
1018
1019        // FALSIFY-MM-NT-001: matmul_nt equivalence to manual transpose
1020        proptest! {
1021            #![proptest_config(ProptestConfig::with_cases(50))]
1022
1023            #[test]
1024            fn falsify_mm_nt_equivalence(
1025                m in 1..=6usize,
1026                k in 1..=6usize,
1027                n in 1..=6usize,
1028                seed in 0..500u32,
1029            ) {
1030                let a: Vec<f32> = (0..m * k)
1031                    .map(|i| ((i as f32 + seed as f32) * 0.31).sin())
1032                    .collect();
1033                let b: Vec<f32> = (0..n * k)
1034                    .map(|i| ((i as f32 + seed as f32 + 100.0) * 0.47).cos())
1035                    .collect();
1036
1037                let c_nt = matmul_nt_compute(&a, &b, m, k, n);
1038                let b_t = transpose(&b, n, k);
1039                let c_ref = matmul_compute(&a, &b_t, m, k, n);
1040
1041                for (i, (&got, &exp)) in c_nt.iter().zip(c_ref.iter()).enumerate() {
1042                    prop_assert!(
1043                        (got - exp).abs() < 1e-3,
1044                        "matmul_nt[{}] = {}, expected {}",
1045                        i, got, exp
1046                    );
1047                }
1048            }
1049        }
1050    }
1051
/// KAIZEN-018: the backward op of `transpose_tracked` must hand the gradient
/// back to the source tensor, transposed into the source layout.
#[test]
fn test_transpose_tracked_backward_gradient_flow() {
    // Source: 2×3 matrix [1,2,3,4,5,6] with requires_grad=true
    let source = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], true);

    // Forward pass: the tracked transpose is 3×2 and holds the same six values reordered
    let transposed = transpose_tracked(&source, 2, 3);
    assert_eq!(transposed.len(), 6);
    let forward = transposed.data();
    assert_eq!(forward.as_slice().expect("contiguous"), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

    // Seed a 3×2 gradient on the transposed tensor, as an upstream backward pass would
    let upstream = Array1::from(vec![10.0, 40.0, 20.0, 50.0, 30.0, 60.0]);
    transposed.set_grad(upstream);

    // Backward pass: the op should transpose the gradient (3×2 → 2×3) onto the source
    if let Some(op) = transposed.backward_op() {
        op.backward();
    }

    // Transpose of 3×2 [10,40,20,50,30,60] = 2×3 [10,20,30,40,50,60]
    let grad = source.grad().expect("original tensor should have gradient");
    assert_eq!(grad.as_slice().expect("contiguous"), &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);
}
1083
/// KAIZEN-018: chaining `transpose_tracked` into `matmul` must still deliver
/// a gradient to the original (non-transposed) LoRA parameter.
#[test]
fn test_transpose_tracked_lora_gradient_chain() {
    // LoRA-style forward pass y = x @ A^T, with A of shape (rank=2, d_in=3)
    // and x a single 1×3 input row.
    let lora_a = Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], true);
    let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);

    // A^T is (d_in=3, rank=2), so the product (1, 3) @ (3, 2) is (1, 2).
    let lora_a_t = transpose_tracked(&lora_a, 2, 3);
    let result = matmul(&x, &lora_a_t, 1, 3, 2);
    assert_eq!(result.len(), 2);

    // Seed the output gradient, as a loss backward pass would, then run the chain:
    // result → matmul backward → lora_a_t → transpose backward → lora_a
    result.set_grad(Array1::from(vec![1.0, 1.0]));
    if let Some(op) = result.backward_op() {
        op.backward();
    }

    // The original parameter must now carry a gradient with one entry per element.
    let grad = lora_a.grad().expect("LoRA A should receive gradient via transpose_tracked");
    assert_eq!(grad.len(), 6);

    // Every entry must be finite, and the entries must not sum to zero.
    let entries = grad.as_slice().expect("contiguous");
    for (i, &val) in entries.iter().enumerate() {
        assert!(val.is_finite(), "Gradient element {i} is not finite: {val}");
    }
    let grad_sum: f32 = grad.iter().sum();
    assert!(grad_sum.abs() > 1e-6, "Gradient should be non-zero, got sum={grad_sum}");
}
1118}