// trueno/matrix/ops/arithmetic.rs
1//! Matrix arithmetic operations
2//!
3//! This module provides matrix multiplication and related operations:
4//! - `matmul()` - Standard matrix multiplication with SIMD optimization
5//! - `batched_matmul()` - Batched 3D tensor multiplication
6//! - `batched_matmul_4d()` - 4D tensor multiplication for attention
7//!
8//! ## Domain Separation (PMAT-018)
9//!
10//! Arithmetic operations (multiplication, addition) are separate from storage
11//! operations (allocation, indexing). This allows optimizing compute kernels
12//! independently of memory layout decisions.
13//!
14//! ## Performance Hierarchy
15//!
16//! 1. GPU for large matrices (≥500×500) - 2-10x speedup
17//! 2. BLIS/SIMD for medium-large matrices (>64×64) - 2-8x speedup
18//! 3. Naive for small matrices - lowest overhead
19
20use crate::TruenoError;
21
22#[cfg(feature = "tracing")]
23use tracing::instrument;
24
25use super::super::Matrix;
26
27impl Matrix<f32> {
    /// Matrix multiplication (matmul)
    ///
    /// Computes `C = A × B` where A is `m×n`, B is `n×p`, and C is `m×p`.
    ///
    /// # Arguments
    ///
    /// * `other` - The matrix to multiply with (right operand)
    ///
    /// # Returns
    ///
    /// A new matrix containing the result of matrix multiplication
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if matrix dimensions are incompatible
    /// (i.e., `self.cols != other.rows`)
    ///
    /// # Example
    ///
    /// ```
    /// use trueno::Matrix;
    ///
    /// let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// let b = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    /// let c = a.matmul(&b).unwrap();
    ///
    /// // [[1, 2],   [[5, 6],   [[19, 22],
    /// //  [3, 4]] ×  [7, 8]] =  [43, 50]]
    /// assert_eq!(c.get(0, 0), Some(&19.0));
    /// assert_eq!(c.get(0, 1), Some(&22.0));
    /// assert_eq!(c.get(1, 0), Some(&43.0));
    /// assert_eq!(c.get(1, 1), Some(&50.0));
    /// ```
    // =========================================================================
    // HOT PATH - PERFORMANCE CRITICAL
    // =========================================================================
    // Core matrix operation used in neural network forward passes.
    // Changes to inner loops REQUIRE benchmark verification: make bench-check
    // =========================================================================
    #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(dims = %format!("{}x{} @ {}x{}", self.rows, self.cols, other.rows, other.cols))))]
    pub fn matmul(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        // Shape check: A is rows×cols; B's row count must equal A's col count.
        if self.cols != other.rows {
            return Err(TruenoError::InvalidInput(format!(
                "Matrix dimension mismatch for multiplication: {}×{} × {}×{} (inner dimensions {} and {} must match)",
                self.rows, self.cols, other.rows, other.cols, self.cols, other.rows
            )));
        }

        // Fast path for vector-matrix multiply (rows=1)
        if self.rows == 1 {
            return self.matmul_vector_matrix(other);
        }

        // NOTE: zeros required — BLIS GEMM accumulates (c += A*B) via load_c_tile.
        let mut result = Matrix::zeros_with_backend(self.rows, other.cols, self.backend);

        // Dispatch thresholds — see the module-level "Performance Hierarchy".
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = 500;
        const SIMD_THRESHOLD: usize = 64;

        // Try GPU first for very large matrices. Requires ALL three dimensions
        // to be large; a GPU failure (no device, driver error) is swallowed and
        // execution falls through to the CPU paths below.
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            if self.rows >= GPU_THRESHOLD
                && self.cols >= GPU_THRESHOLD
                && other.cols >= GPU_THRESHOLD
            {
                if let Ok(gpu_result) = self.matmul_gpu(other) {
                    return Ok(gpu_result);
                }
            }
        }

        // Use SIMD for medium-large matrices. Note `||`: ANY large dimension
        // routes to the BLIS/tiled path, unlike the all-dims GPU gate above.
        if self.rows >= SIMD_THRESHOLD
            || self.cols >= SIMD_THRESHOLD
            || other.cols >= SIMD_THRESHOLD
        {
            #[cfg(target_arch = "wasm32")]
            {
                self.matmul_wasm_tiled(other, &mut result)?;
            }
            #[cfg(not(target_arch = "wasm32"))]
            {
                // Parallel BLIS-style GEMM: C(rows × other.cols), K = self.cols.
                crate::blis::parallel::gemm_blis_parallel(
                    self.rows,
                    other.cols,
                    self.cols,
                    &self.data,
                    &other.data,
                    &mut result.data,
                )?;
            }
        } else {
            // Small matrices: the naive triple loop has the lowest overhead.
            self.matmul_naive(other, &mut result)?;
        }

        Ok(result)
    }
127
128    /// Batched matrix multiplication for 3D tensors.
129    ///
130    /// Computes `[batch, m, k] @ [batch, k, n] -> [batch, m, n]` using SIMD for each batch.
131    #[cfg_attr(feature = "tracing", instrument(skip(a_data, b_data), fields(batch, m, k, n)))]
132    pub fn batched_matmul(
133        a_data: &[f32],
134        b_data: &[f32],
135        batch: usize,
136        m: usize,
137        k: usize,
138        n: usize,
139    ) -> Result<Vec<f32>, TruenoError> {
140        let a_stride = m * k;
141        let b_stride = k * n;
142        let out_stride = m * n;
143
144        if a_data.len() != batch * a_stride {
145            return Err(TruenoError::InvalidInput(format!(
146                "A data size mismatch: expected {} ({}×{}×{}), got {}",
147                batch * a_stride,
148                batch,
149                m,
150                k,
151                a_data.len()
152            )));
153        }
154        if b_data.len() != batch * b_stride {
155            return Err(TruenoError::InvalidInput(format!(
156                "B data size mismatch: expected {} ({}×{}×{}), got {}",
157                batch * b_stride,
158                batch,
159                k,
160                n,
161                b_data.len()
162            )));
163        }
164
165        // NOTE: zeros required — gemm_blis accumulates (c += A*B) via load_c_tile.
166        let mut output = vec![0.0f32; batch * out_stride];
167
168        // KAIZEN-039: Call gemm_blis directly on sub-slices instead of
169        // Matrix::from_slice (which copies data). Eliminates 2 × batch Vec
170        // allocations per call (e.g., 64 copies for 32-head attention).
171        for ba in 0..batch {
172            let a_offset = ba * a_stride;
173            let b_offset = ba * b_stride;
174            let out_offset = ba * out_stride;
175
176            let a_slice = &a_data[a_offset..a_offset + a_stride];
177            let b_slice = &b_data[b_offset..b_offset + b_stride];
178            let c_slice = &mut output[out_offset..out_offset + out_stride];
179
180            #[cfg(not(target_arch = "wasm32"))]
181            {
182                crate::blis::gemm_blis(m, n, k, a_slice, b_slice, c_slice, None)?;
183            }
184            #[cfg(target_arch = "wasm32")]
185            {
186                let a_mat = Matrix::from_slice(m, k, a_slice)?;
187                let b_mat = Matrix::from_slice(k, n, b_slice)?;
188                let result = a_mat.matmul(&b_mat)?;
189                c_slice.copy_from_slice(result.as_slice());
190            }
191        }
192
193        Ok(output)
194    }
195
196    /// Batched matrix multiplication for 4D tensors (attention pattern).
197    ///
198    /// Computes `[batch, heads, m, k] @ [batch, heads, k, n] -> [batch, heads, m, n]`
199    #[cfg_attr(
200        feature = "tracing",
201        instrument(skip(a_data, b_data), fields(batch, heads, m, k, n))
202    )]
203    pub fn batched_matmul_4d(
204        a_data: &[f32],
205        b_data: &[f32],
206        batch: usize,
207        heads: usize,
208        m: usize,
209        k: usize,
210        n: usize,
211    ) -> Result<Vec<f32>, TruenoError> {
212        let a_head_stride = m * k;
213        let b_head_stride = k * n;
214        let out_head_stride = m * n;
215        let total_heads = batch * heads;
216
217        let expected_a = total_heads * a_head_stride;
218        let expected_b = total_heads * b_head_stride;
219        if a_data.len() != expected_a {
220            return Err(TruenoError::InvalidInput(format!(
221                "A data size mismatch: expected {} ({}×{}×{}×{}), got {}",
222                expected_a,
223                batch,
224                heads,
225                m,
226                k,
227                a_data.len()
228            )));
229        }
230        if b_data.len() != expected_b {
231            return Err(TruenoError::InvalidInput(format!(
232                "B data size mismatch: expected {} ({}×{}×{}×{}), got {}",
233                expected_b,
234                batch,
235                heads,
236                k,
237                n,
238                b_data.len()
239            )));
240        }
241
242        // NOTE: zeros required — gemm_blis accumulates (c += A*B) via load_c_tile.
243        let mut output = vec![0.0f32; total_heads * out_head_stride];
244
245        // KAIZEN-039: Call gemm_blis directly — eliminates 2 × total_heads
246        // Vec copies per call (e.g., 64 copies for batch=1, heads=32).
247        for bh in 0..total_heads {
248            let a_offset = bh * a_head_stride;
249            let b_offset = bh * b_head_stride;
250            let out_offset = bh * out_head_stride;
251
252            let a_slice = &a_data[a_offset..a_offset + a_head_stride];
253            let b_slice = &b_data[b_offset..b_offset + b_head_stride];
254            let c_slice = &mut output[out_offset..out_offset + out_head_stride];
255
256            #[cfg(not(target_arch = "wasm32"))]
257            {
258                crate::blis::gemm_blis(m, n, k, a_slice, b_slice, c_slice, None)?;
259            }
260            #[cfg(target_arch = "wasm32")]
261            {
262                let a_mat = Matrix::from_slice(m, k, a_slice)?;
263                let b_mat = Matrix::from_slice(k, n, b_slice)?;
264                let result = a_mat.matmul(&b_mat)?;
265                c_slice.copy_from_slice(result.as_slice());
266            }
267        }
268
269        Ok(output)
270    }
271
272    /// Fast path for vector-matrix multiplication (1×K @ K×N → 1×N)
273    ///
274    /// Dispatches to AVX2 SIMD GEMV kernel when available (explicit VFMADD
275    /// with 4-way K-unrolling), falls back to scalar 4-way axpy.
276    /// Bypasses BLIS packing which dominates for M=1.
277    ///
278    /// Contract: matvec-kernel-v1, equation "matvec"
279    #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(k = self.cols, n = other.cols)))]
280    fn matmul_vector_matrix(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
281        debug_assert_eq!(self.rows, 1);
282
283        let k = self.cols;
284        let n = other.cols;
285        // NOTE: zeros required — gemv accumulates (c[j] += a[k]*b[k*n+j]).
286        let mut c = vec![0.0f32; n];
287
288        crate::blis::gemv::gemv(k, n, &self.data, &other.data, &mut c);
289
290        Matrix::from_vec(1, n, c)
291    }
292
293    /// Naive O(n³) matrix multiplication (baseline for small matrices < 64)
294    fn matmul_naive(
295        &self,
296        other: &Matrix<f32>,
297        result: &mut Matrix<f32>,
298    ) -> Result<(), TruenoError> {
299        let m = self.rows;
300        let k = self.cols;
301        let n = other.cols;
302        // Direct slice access — eliminates bounds-check + Option::expect
303        // per element in the innermost loop (~30% overhead for small matrices).
304        let a = &self.data;
305        let b = &other.data;
306        let c = &mut result.data;
307
308        for i in 0..m {
309            let a_row = i * k;
310            let c_row = i * n;
311            for j in 0..n {
312                let mut sum = 0.0f32;
313                for kk in 0..k {
314                    // a[i,kk] * b[kk,j] — row-major layout
315                    sum += a[a_row + kk] * b[kk * n + j];
316                }
317                c[c_row + j] = sum;
318            }
319        }
320        Ok(())
321    }
322
    /// WASM-optimized tiled matrix multiplication
    ///
    /// Processes each output row in 8-column blocks, accumulating into a
    /// fixed-size `[f32; 8]` array so the eight running dot products can stay
    /// in registers (on wasm32 this layout is presumably intended to
    /// autovectorize to SIMD128 — TODO confirm codegen). Columns past the
    /// last full 8-wide block are handled by a scalar remainder loop.
    ///
    /// Dispatched only on `target_arch = "wasm32"` by `matmul`; kept callable
    /// on other targets (hence `#[allow(dead_code)]`) so native tests can
    /// cover it.
    #[allow(dead_code)]
    fn matmul_wasm_tiled(
        &self,
        other: &Matrix<f32>,
        result: &mut Matrix<f32>,
    ) -> Result<(), TruenoError> {
        let m = self.rows;
        let k = self.cols;
        let n = other.cols;

        for i in 0..m {
            let a_row_start = i * k;
            let result_row_start = i * n;

            // n_simd = number of columns covered by full 8-wide blocks.
            let simd_width = 8;
            let n_simd = (n / simd_width) * simd_width;

            #[allow(clippy::needless_range_loop)]
            for j0 in (0..n_simd).step_by(simd_width) {
                // One accumulator per output column j0..j0+8.
                let mut acc = [0.0f32; 8];

                for kk in 0..k {
                    // Broadcast A[i,kk] across 8 consecutive columns of B row kk.
                    let a_val = self.data[a_row_start + kk];
                    let b_row_start = kk * n + j0;

                    for jj in 0..simd_width {
                        acc[jj] += a_val * other.data[b_row_start + jj];
                    }
                }

                // Flush the block of 8 results to C.
                for jj in 0..simd_width {
                    result.data[result_row_start + j0 + jj] = acc[jj];
                }
            }

            // Scalar remainder for the last n % 8 columns.
            for j in n_simd..n {
                let mut sum = 0.0f32;
                for kk in 0..k {
                    sum += self.data[a_row_start + kk] * other.data[kk * n + j];
                }
                result.data[result_row_start + j] = sum;
            }
        }

        Ok(())
    }
370
    /// GPU-accelerated matrix multiplication
    ///
    /// Tries cuBLAS first (when the `cuda` feature is enabled), then falls
    /// back to the wgpu compute backend. A cuBLAS failure is swallowed so a
    /// missing device/driver degrades gracefully to the portable path.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if no GPU is available or the wgpu dispatch
    /// fails; `matmul` treats any error as a signal to use the CPU paths.
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    fn matmul_gpu(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        // Track 1 (CGP-DBUF): Try cuBLAS first when cuda feature is enabled.
        // cuBLAS at 105-150 TFLOP/s vs wgpu shader at ~5 TFLOP/s — 20-30× faster.
        #[cfg(feature = "cuda")]
        {
            if let Ok(result) = self.matmul_cublas(other) {
                return Ok(result);
            }
            // cuBLAS unavailable (no GPU, driver error) — fall through to wgpu
        }

        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            return Err(TruenoError::InvalidInput("GPU not available".to_string()));
        }

        let mut gpu = GpuBackend::new();
        let result_data = gpu
            .matmul(&self.data, &other.data, self.rows, self.cols, other.cols)
            .map_err(|e| TruenoError::InvalidInput(format!("GPU matmul failed: {}", e)))?;

        // Build the result shell, then adopt the GPU output buffer directly
        // (no element-wise copy).
        let mut result = Matrix::zeros(self.rows, other.cols);
        result.data = result_data;

        Ok(result)
    }
400
    /// cuBLAS FP32 GEMM via trueno-gpu's own FFI bindings (OWN THE STACK).
    /// Uses cublasGemmEx with CUBLAS_COMPUTE_32F for numerical safety.
    ///
    /// Flow: create context/stream/handle, upload A and B, run a row-major
    /// `C = 1.0·A·B + 0.0·C`, synchronize the stream, and copy C back to host.
    ///
    /// NOTE(review): context/stream/handle are created per call — likely
    /// expensive for repeated GEMMs; consider caching. Confirm against the
    /// trueno-gpu driver API before changing.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` wrapping the driver error if any CUDA/cuBLAS
    /// step (init, alloc, GEMM, sync, readback) fails; `matmul_gpu` treats
    /// that as a signal to fall back to the wgpu path.
    #[cfg(feature = "cuda")]
    fn matmul_cublas(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        use trueno_gpu::driver::{CublasHandle, CudaContext, CudaStream, GemmOp, GpuBuffer};

        let m = self.rows;
        let k = self.cols;
        let n = other.cols;

        let ctx = CudaContext::new(0)
            .map_err(|e| TruenoError::InvalidInput(format!("CUDA init: {e}")))?;
        let stream = CudaStream::new(&ctx)
            .map_err(|e| TruenoError::InvalidInput(format!("CUDA stream: {e}")))?;
        let handle = CublasHandle::new(&ctx)
            .map_err(|e| TruenoError::InvalidInput(format!("cuBLAS init: {e}")))?;
        handle
            .set_stream(&stream)
            .map_err(|e| TruenoError::InvalidInput(format!("cuBLAS stream: {e}")))?;

        // Upload operands to device memory.
        let a_buf = GpuBuffer::from_host(&ctx, &self.data)
            .map_err(|e| TruenoError::InvalidInput(format!("GPU alloc A: {e}")))?;
        let b_buf = GpuBuffer::from_host(&ctx, &other.data)
            .map_err(|e| TruenoError::InvalidInput(format!("GPU alloc B: {e}")))?;
        // NOTE(review): C is uploaded zero-filled even though beta = 0.0 below
        // means its initial contents are ignored — a device-only allocation
        // would skip one host→device copy, if the API offers one.
        let c_data = vec![0.0f32; m * n];
        let c_buf = GpuBuffer::from_host(&ctx, &c_data)
            .map_err(|e| TruenoError::InvalidInput(format!("GPU alloc C: {e}")))?;

        // alpha = 1.0, beta = 0.0: plain product, no accumulation into C.
        handle
            .gemm_f32_row_major(
                m as i32,
                n as i32,
                k as i32,
                1.0,
                a_buf.as_ptr(),
                b_buf.as_ptr(),
                0.0,
                c_buf.as_ptr(),
            )
            .map_err(|e| TruenoError::InvalidInput(format!("cuBLAS GEMM: {e}")))?;

        // GEMM is asynchronous on the stream — wait before reading back.
        stream.synchronize().map_err(|e| TruenoError::InvalidInput(format!("CUDA sync: {e}")))?;

        let mut result_data = vec![0.0f32; m * n];
        c_buf
            .copy_to_host(&mut result_data)
            .map_err(|e| TruenoError::InvalidInput(format!("GPU readback: {e}")))?;

        Ok(Matrix { rows: m, cols: n, data: result_data, backend: self.backend })
    }
451}
452
453#[cfg(test)]
454mod tests {
455    use super::*;
456
457    #[test]
458    fn test_matmul_basic() {
459        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
460        let b = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
461        let c = a.matmul(&b).unwrap();
462
463        assert_eq!(c.get(0, 0), Some(&19.0));
464        assert_eq!(c.get(0, 1), Some(&22.0));
465        assert_eq!(c.get(1, 0), Some(&43.0));
466        assert_eq!(c.get(1, 1), Some(&50.0));
467    }
468
469    #[test]
470    fn test_matmul_dimension_mismatch() {
471        let a = Matrix::from_vec(2, 3, vec![1.0; 6]).unwrap();
472        let b = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
473        assert!(a.matmul(&b).is_err());
474    }
475
476    #[test]
477    fn test_matmul_identity() {
478        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
479        let i = Matrix::identity(2);
480        let result = a.matmul(&i).unwrap();
481        assert_eq!(result.as_slice(), a.as_slice());
482    }
483
484    #[test]
485    fn test_batched_matmul() {
486        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // 2 batches of 2×2
487        let b = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]; // 2 identity matrices
488        let result = Matrix::batched_matmul(&a, &b, 2, 2, 2, 2).unwrap();
489        assert_eq!(result, a); // A × I = A
490    }
491
492    #[test]
493    fn test_batched_matmul_a_size_mismatch() {
494        let a = vec![1.0, 2.0, 3.0]; // Wrong size: should be 2*2*2=8
495        let b = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
496        let result = Matrix::batched_matmul(&a, &b, 2, 2, 2, 2);
497        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
498    }
499
500    #[test]
501    fn test_batched_matmul_b_size_mismatch() {
502        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
503        let b = vec![1.0, 0.0]; // Wrong size: should be 2*2*2=8
504        let result = Matrix::batched_matmul(&a, &b, 2, 2, 2, 2);
505        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
506    }
507
508    #[test]
509    fn test_batched_matmul_single_batch() {
510        // Single batch 3x2 @ 2x4 = 3x4
511        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
512        let b = vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]; // 2x4
513        let result = Matrix::batched_matmul(&a, &b, 1, 3, 2, 4).unwrap();
514        assert_eq!(result.len(), 12); // 3x4
515    }
516
517    #[test]
518    fn test_batched_matmul_4d_basic() {
519        // batch=1, heads=1, m=2, k=2, n=2
520        let a = vec![1.0, 2.0, 3.0, 4.0]; // 2x2
521        let b = vec![1.0, 0.0, 0.0, 1.0]; // identity
522        let result = Matrix::batched_matmul_4d(&a, &b, 1, 1, 2, 2, 2).unwrap();
523        assert_eq!(result, a);
524    }
525
526    #[test]
527    fn test_batched_matmul_4d_a_size_mismatch() {
528        let a = vec![1.0]; // Wrong: should be 2*2*3*4=48
529        let b: Vec<f32> = (0..80).map(|x| x as f32 * 0.1).collect();
530        let result = Matrix::batched_matmul_4d(&a, &b, 2, 2, 3, 4, 5);
531        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
532    }
533
534    #[test]
535    fn test_batched_matmul_4d_b_size_mismatch() {
536        let a: Vec<f32> = (0..48).map(|x| x as f32 * 0.1).collect();
537        let b = vec![1.0]; // Wrong: should be 2*2*4*5=80
538        let result = Matrix::batched_matmul_4d(&a, &b, 2, 2, 3, 4, 5);
539        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
540    }
541
542    #[test]
543    fn test_batched_matmul_4d_multi_head() {
544        // batch=1, heads=4, m=2, k=2, n=2 (like attention heads)
545        let total = 4 * 2 * 2; // 16 elements for A
546        let a: Vec<f32> = (0..total).map(|_| 1.0).collect();
547        let b: Vec<f32> = (0..total).map(|_| 1.0).collect();
548        let result = Matrix::batched_matmul_4d(&a, &b, 1, 4, 2, 2, 2).unwrap();
549        assert_eq!(result.len(), total);
550        // Each element should be 2.0 (dot product of two 1.0 vectors of length 2)
551        for val in &result {
552            assert!((*val - 2.0).abs() < 1e-5);
553        }
554    }
555
556    #[test]
557    fn test_matmul_vector_matrix_path() {
558        // 1×K @ K×N triggers the vector-matrix fast path
559        let a = Matrix::from_vec(1, 4, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
560        let b = Matrix::from_vec(
561            4,
562            3,
563            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
564        )
565        .unwrap();
566        let result = a.matmul(&b).unwrap();
567        assert_eq!(result.rows(), 1);
568        assert_eq!(result.cols(), 3);
569        // [1*1+2*0+3*0+4*1, 1*0+2*1+3*0+4*1, 1*0+2*0+3*1+4*1] = [5, 6, 7]
570        assert!((result.get(0, 0).unwrap() - 5.0).abs() < 1e-5);
571        assert!((result.get(0, 1).unwrap() - 6.0).abs() < 1e-5);
572        assert!((result.get(0, 2).unwrap() - 7.0).abs() < 1e-5);
573    }
574
575    #[test]
576    fn test_matmul_vector_matrix_with_zeros() {
577        // Test that zero elements in the vector skip computation
578        let a = Matrix::from_vec(1, 3, vec![0.0, 2.0, 0.0]).unwrap();
579        let b = Matrix::from_vec(3, 2, vec![100.0, 200.0, 3.0, 4.0, 500.0, 600.0]).unwrap();
580        let result = a.matmul(&b).unwrap();
581        // Only the second row of B contributes: [2*3, 2*4] = [6, 8]
582        assert!((result.get(0, 0).unwrap() - 6.0).abs() < 1e-5);
583        assert!((result.get(0, 1).unwrap() - 8.0).abs() < 1e-5);
584    }
585
586    // =========================================================================
587    // matmul_wasm_tiled tests
588    // =========================================================================
589    // These tests call the private WASM-tiled matmul directly (not behind
590    // #[cfg(target_arch = "wasm32")]) to achieve coverage on non-WASM hosts.
591
592    #[test]
593    fn test_matmul_wasm_tiled_small_no_simd() {
594        // n=3 < simd_width(8), so only the remainder path executes.
595        // 2x4 @ 4x3 = 2x3
596        let a = Matrix::from_vec(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
597        let b = Matrix::from_vec(
598            4,
599            3,
600            vec![1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 2.0, 0.0],
601        )
602        .unwrap();
603        let mut result = Matrix::zeros(2, 3);
604        a.matmul_wasm_tiled(&b, &mut result).unwrap();
605
606        // Row 0: [1*1+2*0+3*2+4*0, 1*0+2*1+3*0+4*2, 1*2+2*0+3*1+4*0] = [7, 10, 5]
607        assert!((result.get(0, 0).unwrap() - 7.0).abs() < 1e-5);
608        assert!((result.get(0, 1).unwrap() - 10.0).abs() < 1e-5);
609        assert!((result.get(0, 2).unwrap() - 5.0).abs() < 1e-5);
610
611        // Row 1: [5*1+6*0+7*2+8*0, 5*0+6*1+7*0+8*2, 5*2+6*0+7*1+8*0] = [19, 22, 17]
612        assert!((result.get(1, 0).unwrap() - 19.0).abs() < 1e-5);
613        assert!((result.get(1, 1).unwrap() - 22.0).abs() < 1e-5);
614        assert!((result.get(1, 2).unwrap() - 17.0).abs() < 1e-5);
615    }
616
617    #[test]
618    fn test_matmul_wasm_tiled_exact_simd_width() {
619        // n=8 exactly equals simd_width, so the SIMD path handles all columns
620        // and the remainder path has zero iterations.
621        // 2x3 @ 3x8 = 2x8
622        let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
623        let b_data: Vec<f32> = (1..=24).map(|x| x as f32).collect(); // 3x8
624        let b = Matrix::from_vec(3, 8, b_data).unwrap();
625        let mut result = Matrix::zeros(2, 8);
626        a.matmul_wasm_tiled(&b, &mut result).unwrap();
627
628        // Verify against naive matmul
629        let mut expected = Matrix::zeros(2, 8);
630        a.matmul_naive(&b, &mut expected).unwrap();
631        for i in 0..2 {
632            for j in 0..8 {
633                assert!(
634                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-4,
635                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
636                    i,
637                    j,
638                    result.get(i, j).unwrap(),
639                    expected.get(i, j).unwrap()
640                );
641            }
642        }
643    }
644
645    #[test]
646    fn test_matmul_wasm_tiled_simd_plus_remainder() {
647        // n=11 => n_simd=8 (SIMD path handles columns 0..8),
648        // remainder path handles columns 8..11. Both paths exercise.
649        // 3x4 @ 4x11 = 3x11
650        let a_data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
651        let a = Matrix::from_vec(3, 4, a_data).unwrap();
652        let b_data: Vec<f32> = (1..=44).map(|x| x as f32 * 0.1).collect();
653        let b = Matrix::from_vec(4, 11, b_data).unwrap();
654        let mut result = Matrix::zeros(3, 11);
655        a.matmul_wasm_tiled(&b, &mut result).unwrap();
656
657        // Verify against naive
658        let mut expected = Matrix::zeros(3, 11);
659        a.matmul_naive(&b, &mut expected).unwrap();
660        for i in 0..3 {
661            for j in 0..11 {
662                assert!(
663                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-3,
664                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
665                    i,
666                    j,
667                    result.get(i, j).unwrap(),
668                    expected.get(i, j).unwrap()
669                );
670            }
671        }
672    }
673
674    #[test]
675    fn test_matmul_wasm_tiled_multiple_simd_blocks() {
676        // n=16 => two full SIMD blocks (0..8 and 8..16), no remainder.
677        // 2x2 @ 2x16 = 2x16
678        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
679        let b_data: Vec<f32> = (1..=32).map(|x| x as f32).collect();
680        let b = Matrix::from_vec(2, 16, b_data).unwrap();
681        let mut result = Matrix::zeros(2, 16);
682        a.matmul_wasm_tiled(&b, &mut result).unwrap();
683
684        let mut expected = Matrix::zeros(2, 16);
685        a.matmul_naive(&b, &mut expected).unwrap();
686        for i in 0..2 {
687            for j in 0..16 {
688                assert!(
689                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-3,
690                    "Mismatch at ({}, {})",
691                    i,
692                    j,
693                );
694            }
695        }
696    }
697
698    #[test]
699    fn test_matmul_wasm_tiled_single_row() {
700        // m=1, n=10 => SIMD block 0..8 + remainder 8..10
701        // 1x5 @ 5x10 = 1x10
702        let a = Matrix::from_vec(1, 5, vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
703        let b_data: Vec<f32> = (1..=50).map(|x| x as f32 * 0.1).collect();
704        let b = Matrix::from_vec(5, 10, b_data).unwrap();
705        let mut result = Matrix::zeros(1, 10);
706        a.matmul_wasm_tiled(&b, &mut result).unwrap();
707
708        let mut expected = Matrix::zeros(1, 10);
709        a.matmul_naive(&b, &mut expected).unwrap();
710        for j in 0..10 {
711            assert!(
712                (result.get(0, j).unwrap() - expected.get(0, j).unwrap()).abs() < 1e-3,
713                "Mismatch at col {}: wasm_tiled={}, naive={}",
714                j,
715                result.get(0, j).unwrap(),
716                expected.get(0, j).unwrap()
717            );
718        }
719    }
720
721    #[test]
722    fn test_matmul_wasm_tiled_identity() {
723        // Multiplying by identity should return the original matrix.
724        // 4x4 identity, n=4 < 8 so only remainder path.
725        let a = Matrix::from_vec(
726            4,
727            4,
728            vec![
729                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
730                16.0,
731            ],
732        )
733        .unwrap();
734        let identity = Matrix::identity(4);
735        let mut result = Matrix::zeros(4, 4);
736        a.matmul_wasm_tiled(&identity, &mut result).unwrap();
737
738        assert_eq!(result.as_slice(), a.as_slice());
739    }
740
741    #[test]
742    fn test_matmul_wasm_tiled_large_mixed() {
743        // Larger test: 5x10 @ 10x19 = 5x19
744        // n=19 => n_simd=16, remainder 16..19
745        // Exercises multiple SIMD blocks (0..8, 8..16) plus remainder (16..19).
746        let a_data: Vec<f32> = (0..50).map(|x| (x as f32) * 0.1).collect();
747        let a = Matrix::from_vec(5, 10, a_data).unwrap();
748        let b_data: Vec<f32> = (0..190).map(|x| (x as f32) * 0.01).collect();
749        let b = Matrix::from_vec(10, 19, b_data).unwrap();
750        let mut result = Matrix::zeros(5, 19);
751        a.matmul_wasm_tiled(&b, &mut result).unwrap();
752
753        let mut expected = Matrix::zeros(5, 19);
754        a.matmul_naive(&b, &mut expected).unwrap();
755        for i in 0..5 {
756            for j in 0..19 {
757                assert!(
758                    (result.get(i, j).unwrap() - expected.get(i, j).unwrap()).abs() < 1e-2,
759                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
760                    i,
761                    j,
762                    result.get(i, j).unwrap(),
763                    expected.get(i, j).unwrap()
764                );
765            }
766        }
767    }
768
769    // =========================================================================
770    // FALSIFY-MM: matmul-kernel-v1.yaml contract (trueno Matrix::matmul)
771    //
772    // Five-Whys (PMAT-354):
773    //   Why 1: trueno had 10+ matmul tests but zero FALSIFY-MM-* tests
774    //   Why 2: unit tests verify known products, not mathematical invariants
775    //   Why 3: no mapping from matmul-kernel-v1.yaml to trueno test names
776    //   Why 4: trueno predates the provable-contracts YAML convention
777    //   Why 5: matmul was "obviously correct" (standard GEMM)
778    //
779    // References:
780    //   - provable-contracts/contracts/matmul-kernel-v1.yaml
781    // =========================================================================
782
783    /// FALSIFY-MM-001: Shape correctness — matmul(A[m,p], B[p,n]) = [m,n]
784    #[test]
785    fn falsify_mm_001_shape_correctness() {
786        for &(m, p, n) in &[(1, 1, 1), (2, 3, 4), (16, 32, 8), (1, 100, 1), (64, 1, 64)] {
787            let a = Matrix::from_vec(m, p, vec![1.0; m * p]).unwrap();
788            let b = Matrix::from_vec(p, n, vec![1.0; p * n]).unwrap();
789            let c = a.matmul(&b).unwrap();
790            assert_eq!(
791                (c.rows(), c.cols()),
792                (m, n),
793                "FALSIFIED MM-001: matmul([{m},{p}], [{p},{n}]) shape = [{},{}], expected [{m},{n}]",
794                c.rows(),
795                c.cols()
796            );
797        }
798    }
799
800    /// FALSIFY-MM-005: Identity matrix — matmul(A, I) = A
801    #[test]
802    fn falsify_mm_005_identity_matrix() {
803        let a = Matrix::from_vec(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();
804        let eye =
805            Matrix::from_vec(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
806
807        let ai = a.matmul(&eye).unwrap();
808        let ia = eye.matmul(&a).unwrap();
809
810        for i in 0..3 {
811            for j in 0..3 {
812                let expected = a.get(i, j).unwrap();
813                assert!(
814                    (*ai.get(i, j).unwrap() - expected).abs() < 1e-6,
815                    "FALSIFIED MM-005: (A*I)[{i},{j}] = {}, expected {expected}",
816                    ai.get(i, j).unwrap()
817                );
818                assert!(
819                    (*ia.get(i, j).unwrap() - expected).abs() < 1e-6,
820                    "FALSIFIED MM-005: (I*A)[{i},{j}] = {}, expected {expected}",
821                    ia.get(i, j).unwrap()
822                );
823            }
824        }
825    }
826
827    /// FALSIFY-MM-002: Numerical accuracy — known product verified
828    #[test]
829    fn falsify_mm_002_numerical_accuracy() {
830        let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
831        let b = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
832        let c = a.matmul(&b).unwrap();
833
834        let expected = [19.0, 22.0, 43.0, 50.0];
835        for (i, &exp) in expected.iter().enumerate() {
836            let row = i / 2;
837            let col = i % 2;
838            let val = *c.get(row, col).unwrap();
839            assert!(
840                (val - exp).abs() < 1e-5,
841                "FALSIFIED MM-002: C[{row},{col}] = {val}, expected {exp}"
842            );
843        }
844    }
845
846    /// FALSIFY-MM-002b: matmul(zeros, B) = zeros
847    #[test]
848    fn falsify_mm_002b_zero_annihilation() {
849        let zero = Matrix::from_vec(3, 4, vec![0.0; 12]).unwrap();
850        let b = Matrix::from_vec(4, 2, vec![1.0; 8]).unwrap();
851        let c = zero.matmul(&b).unwrap();
852
853        for i in 0..3 {
854            for j in 0..2 {
855                let val = *c.get(i, j).unwrap();
856                assert!(
857                    val.abs() < 1e-10,
858                    "FALSIFIED MM-002b: zeros*B [{i},{j}] = {val}, expected 0"
859                );
860            }
861        }
862    }
863}
864
#[cfg(test)]
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
mod gpu_tests {
    use super::*;

    /// Test matmul_gpu via public API with matrices large enough to exceed
    /// GPU_THRESHOLD (all dimensions >= 500).
    /// Uses identity multiplication: A * I = A.
    #[test]
    fn test_matmul_gpu_identity() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_identity");
            return;
        }

        let n = 500; // Meets GPU_THRESHOLD for all three dimensions

        // Create a simple test matrix: A[i,j] = (i*n + j) mod 100 * 0.01
        let a_data: Vec<f32> = (0..n * n).map(|i| (i % 100) as f32 * 0.01).collect();

        // Identity matrix
        let mut i_data = vec![0.0f32; n * n];
        for i in 0..n {
            i_data[i * n + i] = 1.0;
        }

        let a = Matrix::from_vec(n, n, a_data.clone()).expect("valid matrix A");
        let identity = Matrix::from_vec(n, n, i_data).expect("valid identity matrix");

        let result = a.matmul(&identity).expect("matmul should succeed");

        assert_eq!(result.rows(), n);
        assert_eq!(result.cols(), n);

        // A * I = A: sample verification (check corners and center)
        let check_indices = [(0, 0), (0, n - 1), (n - 1, 0), (n - 1, n - 1), (n / 2, n / 2)];
        for &(r, c) in &check_indices {
            let expected = a_data[r * n + c];
            let actual = *result.get(r, c).unwrap();
            // Loose tolerance (1e-2): GPU f32 accumulation order differs from CPU.
            assert!(
                (actual - expected).abs() < 1e-2,
                "A*I mismatch at ({},{}): gpu={}, expected={}",
                r,
                c,
                actual,
                expected
            );
        }
    }

    /// Test matmul_gpu with all-ones matrices: result should be all-K.
    #[test]
    fn test_matmul_gpu_ones() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_ones");
            return;
        }

        let m = 500;
        let k = 500;
        let n = 500;

        let a = Matrix::from_vec(m, k, vec![1.0f32; m * k]).expect("valid matrix A");
        let b = Matrix::from_vec(k, n, vec![1.0f32; k * n]).expect("valid matrix B");

        let result = a.matmul(&b).expect("matmul should succeed");

        assert_eq!(result.rows(), m);
        assert_eq!(result.cols(), n);

        // Each element of C should be K (dot product of K ones with K ones).
        // Tolerance of 1.0 absorbs f32 accumulation error over K=500 terms.
        let expected = k as f32;
        for i in 0..10 {
            for j in 0..10 {
                assert!(
                    (result.get(i, j).unwrap() - expected).abs() < 1.0,
                    "C[{},{}] = {}, expected {}",
                    i,
                    j,
                    result.get(i, j).unwrap(),
                    expected
                );
            }
        }
    }

    /// Test matmul_gpu directly via the private helper method.
    #[test]
    fn test_matmul_gpu_direct() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_direct");
            return;
        }

        // Small matrix for direct private method test
        let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid A");
        let b = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("valid B");

        let result = a.matmul_gpu(&b).expect("matmul_gpu should succeed");

        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 2);

        // C = A * B
        // C[0,0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
        // C[0,1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
        // C[1,0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
        // C[1,1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
        let expected = [[58.0f32, 64.0], [139.0, 154.0]];
        for row in 0..2 {
            for col in 0..2 {
                let exp = expected[row][col];
                assert!(
                    (result.get(row, col).unwrap() - exp).abs() < 1e-2,
                    "Expected {}, got {}",
                    exp,
                    result.get(row, col).unwrap()
                );
            }
        }
    }

    /// Test matmul_gpu's availability check on both branches: it must return
    /// an error when no GPU is present and succeed when one is.
    ///
    /// Previously this test silently did nothing on machines with a GPU,
    /// despite its own comment claiming the success path was exercised.
    #[test]
    fn test_matmul_gpu_not_available_path() {
        use crate::backends::gpu::GpuBackend;

        let a = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
        let b = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
        let result = a.matmul_gpu(&b);

        if GpuBackend::is_available() {
            // With a GPU present, the same call must take the success path.
            assert!(result.is_ok(), "matmul_gpu should succeed with GPU");
        } else {
            // Without a GPU, matmul_gpu must return an error rather than panic.
            assert!(result.is_err(), "matmul_gpu should fail without GPU");
        }
    }
}