1use crate::TruenoError;
21
22#[cfg(feature = "tracing")]
23use tracing::instrument;
24
25use super::super::Matrix;
26
27impl Matrix<f32> {
28 #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(dims = %format!("{}x{} @ {}x{}", self.rows, self.cols, other.rows, other.cols))))]
68 pub fn matmul(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
69 if self.cols != other.rows {
70 return Err(TruenoError::InvalidInput(format!(
71 "Matrix dimension mismatch for multiplication: {}×{} × {}×{} (inner dimensions {} and {} must match)",
72 self.rows, self.cols, other.rows, other.cols, self.cols, other.rows
73 )));
74 }
75
76 if self.rows == 1 {
78 return self.matmul_vector_matrix(other);
79 }
80
81 let mut result = Matrix::zeros_with_backend(self.rows, other.cols, self.backend);
83
84 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
85 const GPU_THRESHOLD: usize = 500;
86 const SIMD_THRESHOLD: usize = 64;
87
88 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
90 {
91 if self.rows >= GPU_THRESHOLD
92 && self.cols >= GPU_THRESHOLD
93 && other.cols >= GPU_THRESHOLD
94 {
95 if let Ok(gpu_result) = self.matmul_gpu(other) {
96 return Ok(gpu_result);
97 }
98 }
99 }
100
101 if self.rows >= SIMD_THRESHOLD
103 || self.cols >= SIMD_THRESHOLD
104 || other.cols >= SIMD_THRESHOLD
105 {
106 #[cfg(target_arch = "wasm32")]
107 {
108 self.matmul_wasm_tiled(other, &mut result)?;
109 }
110 #[cfg(not(target_arch = "wasm32"))]
111 {
112 crate::blis::parallel::gemm_blis_parallel(
113 self.rows,
114 other.cols,
115 self.cols,
116 &self.data,
117 &other.data,
118 &mut result.data,
119 )?;
120 }
121 } else {
122 self.matmul_naive(other, &mut result)?;
123 }
124
125 Ok(result)
126 }
127
128 #[cfg_attr(feature = "tracing", instrument(skip(a_data, b_data), fields(batch, m, k, n)))]
132 pub fn batched_matmul(
133 a_data: &[f32],
134 b_data: &[f32],
135 batch: usize,
136 m: usize,
137 k: usize,
138 n: usize,
139 ) -> Result<Vec<f32>, TruenoError> {
140 let a_stride = m * k;
141 let b_stride = k * n;
142 let out_stride = m * n;
143
144 if a_data.len() != batch * a_stride {
145 return Err(TruenoError::InvalidInput(format!(
146 "A data size mismatch: expected {} ({}×{}×{}), got {}",
147 batch * a_stride,
148 batch,
149 m,
150 k,
151 a_data.len()
152 )));
153 }
154 if b_data.len() != batch * b_stride {
155 return Err(TruenoError::InvalidInput(format!(
156 "B data size mismatch: expected {} ({}×{}×{}), got {}",
157 batch * b_stride,
158 batch,
159 k,
160 n,
161 b_data.len()
162 )));
163 }
164
165 let mut output = vec![0.0f32; batch * out_stride];
167
168 for ba in 0..batch {
172 let a_offset = ba * a_stride;
173 let b_offset = ba * b_stride;
174 let out_offset = ba * out_stride;
175
176 let a_slice = &a_data[a_offset..a_offset + a_stride];
177 let b_slice = &b_data[b_offset..b_offset + b_stride];
178 let c_slice = &mut output[out_offset..out_offset + out_stride];
179
180 #[cfg(not(target_arch = "wasm32"))]
181 {
182 crate::blis::gemm_blis(m, n, k, a_slice, b_slice, c_slice, None)?;
183 }
184 #[cfg(target_arch = "wasm32")]
185 {
186 let a_mat = Matrix::from_slice(m, k, a_slice)?;
187 let b_mat = Matrix::from_slice(k, n, b_slice)?;
188 let result = a_mat.matmul(&b_mat)?;
189 c_slice.copy_from_slice(result.as_slice());
190 }
191 }
192
193 Ok(output)
194 }
195
196 #[cfg_attr(
200 feature = "tracing",
201 instrument(skip(a_data, b_data), fields(batch, heads, m, k, n))
202 )]
203 pub fn batched_matmul_4d(
204 a_data: &[f32],
205 b_data: &[f32],
206 batch: usize,
207 heads: usize,
208 m: usize,
209 k: usize,
210 n: usize,
211 ) -> Result<Vec<f32>, TruenoError> {
212 let a_head_stride = m * k;
213 let b_head_stride = k * n;
214 let out_head_stride = m * n;
215 let total_heads = batch * heads;
216
217 let expected_a = total_heads * a_head_stride;
218 let expected_b = total_heads * b_head_stride;
219 if a_data.len() != expected_a {
220 return Err(TruenoError::InvalidInput(format!(
221 "A data size mismatch: expected {} ({}×{}×{}×{}), got {}",
222 expected_a,
223 batch,
224 heads,
225 m,
226 k,
227 a_data.len()
228 )));
229 }
230 if b_data.len() != expected_b {
231 return Err(TruenoError::InvalidInput(format!(
232 "B data size mismatch: expected {} ({}×{}×{}×{}), got {}",
233 expected_b,
234 batch,
235 heads,
236 k,
237 n,
238 b_data.len()
239 )));
240 }
241
242 let mut output = vec![0.0f32; total_heads * out_head_stride];
244
245 for bh in 0..total_heads {
248 let a_offset = bh * a_head_stride;
249 let b_offset = bh * b_head_stride;
250 let out_offset = bh * out_head_stride;
251
252 let a_slice = &a_data[a_offset..a_offset + a_head_stride];
253 let b_slice = &b_data[b_offset..b_offset + b_head_stride];
254 let c_slice = &mut output[out_offset..out_offset + out_head_stride];
255
256 #[cfg(not(target_arch = "wasm32"))]
257 {
258 crate::blis::gemm_blis(m, n, k, a_slice, b_slice, c_slice, None)?;
259 }
260 #[cfg(target_arch = "wasm32")]
261 {
262 let a_mat = Matrix::from_slice(m, k, a_slice)?;
263 let b_mat = Matrix::from_slice(k, n, b_slice)?;
264 let result = a_mat.matmul(&b_mat)?;
265 c_slice.copy_from_slice(result.as_slice());
266 }
267 }
268
269 Ok(output)
270 }
271
272 #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(k = self.cols, n = other.cols)))]
280 fn matmul_vector_matrix(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
281 debug_assert_eq!(self.rows, 1);
282
283 let k = self.cols;
284 let n = other.cols;
285 let mut c = vec![0.0f32; n];
287
288 crate::blis::gemv::gemv(k, n, &self.data, &other.data, &mut c);
289
290 Matrix::from_vec(1, n, c)
291 }
292
293 fn matmul_naive(
295 &self,
296 other: &Matrix<f32>,
297 result: &mut Matrix<f32>,
298 ) -> Result<(), TruenoError> {
299 let m = self.rows;
300 let k = self.cols;
301 let n = other.cols;
302 let a = &self.data;
305 let b = &other.data;
306 let c = &mut result.data;
307
308 for i in 0..m {
309 let a_row = i * k;
310 let c_row = i * n;
311 for j in 0..n {
312 let mut sum = 0.0f32;
313 for kk in 0..k {
314 sum += a[a_row + kk] * b[kk * n + j];
316 }
317 c[c_row + j] = sum;
318 }
319 }
320 Ok(())
321 }
322
323 #[allow(dead_code)]
325 fn matmul_wasm_tiled(
326 &self,
327 other: &Matrix<f32>,
328 result: &mut Matrix<f32>,
329 ) -> Result<(), TruenoError> {
330 let m = self.rows;
331 let k = self.cols;
332 let n = other.cols;
333
334 for i in 0..m {
335 let a_row_start = i * k;
336 let result_row_start = i * n;
337
338 let simd_width = 8;
339 let n_simd = (n / simd_width) * simd_width;
340
341 #[allow(clippy::needless_range_loop)]
342 for j0 in (0..n_simd).step_by(simd_width) {
343 let mut acc = [0.0f32; 8];
344
345 for kk in 0..k {
346 let a_val = self.data[a_row_start + kk];
347 let b_row_start = kk * n + j0;
348
349 for jj in 0..simd_width {
350 acc[jj] += a_val * other.data[b_row_start + jj];
351 }
352 }
353
354 for jj in 0..simd_width {
355 result.data[result_row_start + j0 + jj] = acc[jj];
356 }
357 }
358
359 for j in n_simd..n {
360 let mut sum = 0.0f32;
361 for kk in 0..k {
362 sum += self.data[a_row_start + kk] * other.data[kk * n + j];
363 }
364 result.data[result_row_start + j] = sum;
365 }
366 }
367
368 Ok(())
369 }
370
371 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
373 fn matmul_gpu(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
374 #[cfg(feature = "cuda")]
377 {
378 if let Ok(result) = self.matmul_cublas(other) {
379 return Ok(result);
380 }
381 }
383
384 use crate::backends::gpu::GpuBackend;
385
386 if !GpuBackend::is_available() {
387 return Err(TruenoError::InvalidInput("GPU not available".to_string()));
388 }
389
390 let mut gpu = GpuBackend::new();
391 let result_data = gpu
392 .matmul(&self.data, &other.data, self.rows, self.cols, other.cols)
393 .map_err(|e| TruenoError::InvalidInput(format!("GPU matmul failed: {}", e)))?;
394
395 let mut result = Matrix::zeros(self.rows, other.cols);
396 result.data = result_data;
397
398 Ok(result)
399 }
400
401 #[cfg(feature = "cuda")]
404 fn matmul_cublas(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
405 use trueno_gpu::driver::{CublasHandle, CudaContext, CudaStream, GemmOp, GpuBuffer};
406
407 let m = self.rows;
408 let k = self.cols;
409 let n = other.cols;
410
411 let ctx = CudaContext::new(0)
412 .map_err(|e| TruenoError::InvalidInput(format!("CUDA init: {e}")))?;
413 let stream = CudaStream::new(&ctx)
414 .map_err(|e| TruenoError::InvalidInput(format!("CUDA stream: {e}")))?;
415 let handle = CublasHandle::new(&ctx)
416 .map_err(|e| TruenoError::InvalidInput(format!("cuBLAS init: {e}")))?;
417 handle
418 .set_stream(&stream)
419 .map_err(|e| TruenoError::InvalidInput(format!("cuBLAS stream: {e}")))?;
420
421 let a_buf = GpuBuffer::from_host(&ctx, &self.data)
422 .map_err(|e| TruenoError::InvalidInput(format!("GPU alloc A: {e}")))?;
423 let b_buf = GpuBuffer::from_host(&ctx, &other.data)
424 .map_err(|e| TruenoError::InvalidInput(format!("GPU alloc B: {e}")))?;
425 let c_data = vec![0.0f32; m * n];
426 let c_buf = GpuBuffer::from_host(&ctx, &c_data)
427 .map_err(|e| TruenoError::InvalidInput(format!("GPU alloc C: {e}")))?;
428
429 handle
430 .gemm_f32_row_major(
431 m as i32,
432 n as i32,
433 k as i32,
434 1.0,
435 a_buf.as_ptr(),
436 b_buf.as_ptr(),
437 0.0,
438 c_buf.as_ptr(),
439 )
440 .map_err(|e| TruenoError::InvalidInput(format!("cuBLAS GEMM: {e}")))?;
441
442 stream.synchronize().map_err(|e| TruenoError::InvalidInput(format!("CUDA sync: {e}")))?;
443
444 let mut result_data = vec![0.0f32; m * n];
445 c_buf
446 .copy_to_host(&mut result_data)
447 .map_err(|e| TruenoError::InvalidInput(format!("GPU readback: {e}")))?;
448
449 Ok(Matrix { rows: m, cols: n, data: result_data, backend: self.backend })
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// 2×2 sanity check against hand-computed values.
    #[test]
    fn test_matmul_basic() {
        let lhs = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let rhs = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        assert_eq!(product.get(0, 0), Some(&19.0));
        assert_eq!(product.get(0, 1), Some(&22.0));
        assert_eq!(product.get(1, 0), Some(&43.0));
        assert_eq!(product.get(1, 1), Some(&50.0));
    }

    /// Inner dimensions disagree (3 vs 2), so matmul must return an error.
    #[test]
    fn test_matmul_dimension_mismatch() {
        let lhs = Matrix::from_vec(2, 3, vec![1.0; 6]).unwrap();
        let rhs = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
        assert!(lhs.matmul(&rhs).is_err());
    }

    /// Multiplying by the identity must reproduce the input exactly.
    #[test]
    fn test_matmul_identity() {
        let base = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let eye = Matrix::identity(2);
        let product = base.matmul(&eye).unwrap();
        assert_eq!(product.as_slice(), base.as_slice());
    }

    /// Two batches, each multiplied by a 2×2 identity: output equals input.
    #[test]
    fn test_batched_matmul() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let rhs = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let out = Matrix::batched_matmul(&lhs, &rhs, 2, 2, 2, 2).unwrap();
        assert_eq!(out, lhs);
    }

    /// A holds 3 elements but 2×2×2 = 8 are required.
    #[test]
    fn test_batched_matmul_a_size_mismatch() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let out = Matrix::batched_matmul(&lhs, &rhs, 2, 2, 2, 2);
        assert!(matches!(out, Err(TruenoError::InvalidInput(_))));
    }

    /// B holds 2 elements but 2×2×2 = 8 are required.
    #[test]
    fn test_batched_matmul_b_size_mismatch() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let rhs = vec![1.0, 0.0];
        let out = Matrix::batched_matmul(&lhs, &rhs, 2, 2, 2, 2);
        assert!(matches!(out, Err(TruenoError::InvalidInput(_))));
    }

    /// batch = 1 with non-square shapes (3×2 × 2×4); pins the output size.
    #[test]
    fn test_batched_matmul_single_batch() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let rhs = vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let out = Matrix::batched_matmul(&lhs, &rhs, 1, 3, 2, 4).unwrap();
        assert_eq!(out.len(), 12);
    }

    /// Single batch, single head, identity right-hand side.
    #[test]
    fn test_batched_matmul_4d_basic() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![1.0, 0.0, 0.0, 1.0];
        let out = Matrix::batched_matmul_4d(&lhs, &rhs, 1, 1, 2, 2, 2).unwrap();
        assert_eq!(out, lhs);
    }

    /// A is far too short for batch=2, heads=2, m=3, k=4.
    #[test]
    fn test_batched_matmul_4d_a_size_mismatch() {
        let lhs = vec![1.0];
        let rhs: Vec<f32> = (0..80).map(|x| x as f32 * 0.1).collect();
        let out = Matrix::batched_matmul_4d(&lhs, &rhs, 2, 2, 3, 4, 5);
        assert!(matches!(out, Err(TruenoError::InvalidInput(_))));
    }

    /// B is far too short for batch=2, heads=2, k=4, n=5.
    #[test]
    fn test_batched_matmul_4d_b_size_mismatch() {
        let lhs: Vec<f32> = (0..48).map(|x| x as f32 * 0.1).collect();
        let rhs = vec![1.0];
        let out = Matrix::batched_matmul_4d(&lhs, &rhs, 2, 2, 3, 4, 5);
        assert!(matches!(out, Err(TruenoError::InvalidInput(_))));
    }

    /// Four heads of all-ones 2×2 matrices: every output entry equals k = 2.
    #[test]
    fn test_batched_matmul_4d_multi_head() {
        let total = 4 * 2 * 2;
        let lhs: Vec<f32> = (0..total).map(|_| 1.0).collect();
        let rhs: Vec<f32> = (0..total).map(|_| 1.0).collect();
        let out = Matrix::batched_matmul_4d(&lhs, &rhs, 1, 4, 2, 2, 2).unwrap();
        assert_eq!(out.len(), total);
        for entry in &out {
            assert!((*entry - 2.0).abs() < 1e-5);
        }
    }

    /// A 1×4 left operand must take the GEMV fast path and still be correct.
    #[test]
    fn test_matmul_vector_matrix_path() {
        let row = Matrix::from_vec(1, 4, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let rhs = Matrix::from_vec(
            4,
            3,
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
        )
        .unwrap();
        let out = row.matmul(&rhs).unwrap();
        assert_eq!(out.rows(), 1);
        assert_eq!(out.cols(), 3);
        assert!((out.get(0, 0).unwrap() - 5.0).abs() < 1e-5);
        assert!((out.get(0, 1).unwrap() - 6.0).abs() < 1e-5);
        assert!((out.get(0, 2).unwrap() - 7.0).abs() < 1e-5);
    }

    /// Zero entries in the vector must annihilate the matching B rows.
    #[test]
    fn test_matmul_vector_matrix_with_zeros() {
        let row = Matrix::from_vec(1, 3, vec![0.0, 2.0, 0.0]).unwrap();
        let rhs = Matrix::from_vec(3, 2, vec![100.0, 200.0, 3.0, 4.0, 500.0, 600.0]).unwrap();
        let out = row.matmul(&rhs).unwrap();
        assert!((out.get(0, 0).unwrap() - 6.0).abs() < 1e-5);
        assert!((out.get(0, 1).unwrap() - 8.0).abs() < 1e-5);
    }

    /// n = 3 < 8, so only the scalar remainder loop of the tiled kernel runs.
    #[test]
    fn test_matmul_wasm_tiled_small_no_simd() {
        let lhs = Matrix::from_vec(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let rhs = Matrix::from_vec(
            4,
            3,
            vec![1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 2.0, 0.0],
        )
        .unwrap();
        let mut out = Matrix::zeros(2, 3);
        lhs.matmul_wasm_tiled(&rhs, &mut out).unwrap();

        assert!((out.get(0, 0).unwrap() - 7.0).abs() < 1e-5);
        assert!((out.get(0, 1).unwrap() - 10.0).abs() < 1e-5);
        assert!((out.get(0, 2).unwrap() - 5.0).abs() < 1e-5);
        assert!((out.get(1, 0).unwrap() - 19.0).abs() < 1e-5);
        assert!((out.get(1, 1).unwrap() - 22.0).abs() < 1e-5);
        assert!((out.get(1, 2).unwrap() - 17.0).abs() < 1e-5);
    }

    /// n = 8 exactly fills one SIMD block; compare against the naive kernel.
    #[test]
    fn test_matmul_wasm_tiled_exact_simd_width() {
        let lhs = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let rhs_data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let rhs = Matrix::from_vec(3, 8, rhs_data).unwrap();
        let mut tiled = Matrix::zeros(2, 8);
        lhs.matmul_wasm_tiled(&rhs, &mut tiled).unwrap();

        let mut reference = Matrix::zeros(2, 8);
        lhs.matmul_naive(&rhs, &mut reference).unwrap();
        for i in 0..2 {
            for j in 0..8 {
                assert!(
                    (tiled.get(i, j).unwrap() - reference.get(i, j).unwrap()).abs() < 1e-4,
                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
                    i,
                    j,
                    tiled.get(i, j).unwrap(),
                    reference.get(i, j).unwrap()
                );
            }
        }
    }

    /// n = 11 exercises one full SIMD block plus a 3-column remainder.
    #[test]
    fn test_matmul_wasm_tiled_simd_plus_remainder() {
        let lhs_data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let lhs = Matrix::from_vec(3, 4, lhs_data).unwrap();
        let rhs_data: Vec<f32> = (1..=44).map(|x| x as f32 * 0.1).collect();
        let rhs = Matrix::from_vec(4, 11, rhs_data).unwrap();
        let mut tiled = Matrix::zeros(3, 11);
        lhs.matmul_wasm_tiled(&rhs, &mut tiled).unwrap();

        let mut reference = Matrix::zeros(3, 11);
        lhs.matmul_naive(&rhs, &mut reference).unwrap();
        for i in 0..3 {
            for j in 0..11 {
                assert!(
                    (tiled.get(i, j).unwrap() - reference.get(i, j).unwrap()).abs() < 1e-3,
                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
                    i,
                    j,
                    tiled.get(i, j).unwrap(),
                    reference.get(i, j).unwrap()
                );
            }
        }
    }

    /// n = 16 exercises two full SIMD blocks with no remainder.
    #[test]
    fn test_matmul_wasm_tiled_multiple_simd_blocks() {
        let lhs = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let rhs_data: Vec<f32> = (1..=32).map(|x| x as f32).collect();
        let rhs = Matrix::from_vec(2, 16, rhs_data).unwrap();
        let mut tiled = Matrix::zeros(2, 16);
        lhs.matmul_wasm_tiled(&rhs, &mut tiled).unwrap();

        let mut reference = Matrix::zeros(2, 16);
        lhs.matmul_naive(&rhs, &mut reference).unwrap();
        for i in 0..2 {
            for j in 0..16 {
                assert!(
                    (tiled.get(i, j).unwrap() - reference.get(i, j).unwrap()).abs() < 1e-3,
                    "Mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    /// A single-row left operand through the tiled kernel (n = 10: one block
    /// plus remainder).
    #[test]
    fn test_matmul_wasm_tiled_single_row() {
        let lhs = Matrix::from_vec(1, 5, vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let rhs_data: Vec<f32> = (1..=50).map(|x| x as f32 * 0.1).collect();
        let rhs = Matrix::from_vec(5, 10, rhs_data).unwrap();
        let mut tiled = Matrix::zeros(1, 10);
        lhs.matmul_wasm_tiled(&rhs, &mut tiled).unwrap();

        let mut reference = Matrix::zeros(1, 10);
        lhs.matmul_naive(&rhs, &mut reference).unwrap();
        for j in 0..10 {
            assert!(
                (tiled.get(0, j).unwrap() - reference.get(0, j).unwrap()).abs() < 1e-3,
                "Mismatch at col {}: wasm_tiled={}, naive={}",
                j,
                tiled.get(0, j).unwrap(),
                reference.get(0, j).unwrap()
            );
        }
    }

    /// The tiled kernel times the identity must reproduce the input exactly.
    #[test]
    fn test_matmul_wasm_tiled_identity() {
        let lhs = Matrix::from_vec(
            4,
            4,
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
                16.0,
            ],
        )
        .unwrap();
        let identity = Matrix::identity(4);
        let mut out = Matrix::zeros(4, 4);
        lhs.matmul_wasm_tiled(&identity, &mut out).unwrap();

        assert_eq!(out.as_slice(), lhs.as_slice());
    }

    /// Larger mixed shapes (5×10 × 10×19): two SIMD blocks plus remainder.
    #[test]
    fn test_matmul_wasm_tiled_large_mixed() {
        let lhs_data: Vec<f32> = (0..50).map(|x| (x as f32) * 0.1).collect();
        let lhs = Matrix::from_vec(5, 10, lhs_data).unwrap();
        let rhs_data: Vec<f32> = (0..190).map(|x| (x as f32) * 0.01).collect();
        let rhs = Matrix::from_vec(10, 19, rhs_data).unwrap();
        let mut tiled = Matrix::zeros(5, 19);
        lhs.matmul_wasm_tiled(&rhs, &mut tiled).unwrap();

        let mut reference = Matrix::zeros(5, 19);
        lhs.matmul_naive(&rhs, &mut reference).unwrap();
        for i in 0..5 {
            for j in 0..19 {
                assert!(
                    (tiled.get(i, j).unwrap() - reference.get(i, j).unwrap()).abs() < 1e-2,
                    "Mismatch at ({}, {}): wasm_tiled={}, naive={}",
                    i,
                    j,
                    tiled.get(i, j).unwrap(),
                    reference.get(i, j).unwrap()
                );
            }
        }
    }

    /// MM-001: an m×p by p×n product always yields an m×n result.
    #[test]
    fn falsify_mm_001_shape_correctness() {
        for (m, p, n) in [(1, 1, 1), (2, 3, 4), (16, 32, 8), (1, 100, 1), (64, 1, 64)] {
            let lhs = Matrix::from_vec(m, p, vec![1.0; m * p]).unwrap();
            let rhs = Matrix::from_vec(p, n, vec![1.0; p * n]).unwrap();
            let c = lhs.matmul(&rhs).unwrap();
            assert_eq!(
                (c.rows(), c.cols()),
                (m, n),
                "FALSIFIED MM-001: matmul([{m},{p}], [{p},{n}]) shape = [{},{}], expected [{m},{n}]",
                c.rows(),
                c.cols()
            );
        }
    }

    /// MM-005: the identity is both a left and a right unit.
    #[test]
    fn falsify_mm_005_identity_matrix() {
        let base = Matrix::from_vec(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();
        let eye =
            Matrix::from_vec(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap();

        let ai = base.matmul(&eye).unwrap();
        let ia = eye.matmul(&base).unwrap();

        for i in 0..3 {
            for j in 0..3 {
                let expected = base.get(i, j).unwrap();
                assert!(
                    (*ai.get(i, j).unwrap() - expected).abs() < 1e-6,
                    "FALSIFIED MM-005: (A*I)[{i},{j}] = {}, expected {expected}",
                    ai.get(i, j).unwrap()
                );
                assert!(
                    (*ia.get(i, j).unwrap() - expected).abs() < 1e-6,
                    "FALSIFIED MM-005: (I*A)[{i},{j}] = {}, expected {expected}",
                    ia.get(i, j).unwrap()
                );
            }
        }
    }

    /// MM-002: exact values of a 2×2 product against hand-computed results.
    #[test]
    fn falsify_mm_002_numerical_accuracy() {
        let lhs = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let rhs = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        let expected = [19.0, 22.0, 43.0, 50.0];
        for (i, &exp) in expected.iter().enumerate() {
            let row = i / 2;
            let col = i % 2;
            let val = *product.get(row, col).unwrap();
            assert!(
                (val - exp).abs() < 1e-5,
                "FALSIFIED MM-002: C[{row},{col}] = {val}, expected {exp}"
            );
        }
    }

    /// MM-002b: the zero matrix annihilates any right-hand side.
    #[test]
    fn falsify_mm_002b_zero_annihilation() {
        let zero = Matrix::from_vec(3, 4, vec![0.0; 12]).unwrap();
        let rhs = Matrix::from_vec(4, 2, vec![1.0; 8]).unwrap();
        let product = zero.matmul(&rhs).unwrap();

        for i in 0..3 {
            for j in 0..2 {
                let val = *product.get(i, j).unwrap();
                assert!(
                    val.abs() < 1e-10,
                    "FALSIFIED MM-002b: zeros*B [{i},{j}] = {val}, expected 0"
                );
            }
        }
    }
}
864
#[cfg(test)]
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
mod gpu_tests {
    use super::*;

    /// A 500×500 matrix (at the GPU threshold) times the identity must
    /// round-trip unchanged; spot-checks the corners and the center.
    #[test]
    fn test_matmul_gpu_identity() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_identity");
            return;
        }

        let n = 500;
        let a_data: Vec<f32> = (0..n * n).map(|i| (i % 100) as f32 * 0.01).collect();

        // Build an n×n identity by setting the diagonal of a zero buffer.
        let mut i_data = vec![0.0f32; n * n];
        for d in 0..n {
            i_data[d * n + d] = 1.0;
        }

        let a = Matrix::from_vec(n, n, a_data.clone()).expect("valid matrix A");
        let identity = Matrix::from_vec(n, n, i_data).expect("valid identity matrix");

        let product = a.matmul(&identity).expect("matmul should succeed");

        assert_eq!(product.rows(), n);
        assert_eq!(product.cols(), n);

        // Spot-check rather than scanning all 250k entries.
        for &(r, c) in &[(0, 0), (0, n - 1), (n - 1, 0), (n - 1, n - 1), (n / 2, n / 2)] {
            let expected = a_data[r * n + c];
            let actual = *product.get(r, c).unwrap();
            assert!(
                (actual - expected).abs() < 1e-2,
                "A*I mismatch at ({},{}): gpu={}, expected={}",
                r,
                c,
                actual,
                expected
            );
        }
    }

    /// ones(500×500) × ones(500×500): every entry equals k = 500.
    #[test]
    fn test_matmul_gpu_ones() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_ones");
            return;
        }

        let (m, k, n) = (500, 500, 500);

        let lhs = Matrix::from_vec(m, k, vec![1.0f32; m * k]).expect("valid matrix A");
        let rhs = Matrix::from_vec(k, n, vec![1.0f32; k * n]).expect("valid matrix B");

        let product = lhs.matmul(&rhs).expect("matmul should succeed");

        assert_eq!(product.rows(), m);
        assert_eq!(product.cols(), n);

        // Check a 10×10 corner; loose tolerance for GPU float accumulation.
        let expected = k as f32;
        for i in 0..10 {
            for j in 0..10 {
                assert!(
                    (product.get(i, j).unwrap() - expected).abs() < 1.0,
                    "C[{},{}] = {}, expected {}",
                    i,
                    j,
                    product.get(i, j).unwrap(),
                    expected
                );
            }
        }
    }

    /// Calls matmul_gpu directly (bypassing the size threshold) on a small
    /// 2×3 × 3×2 product with hand-computed expected values.
    #[test]
    fn test_matmul_gpu_direct() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            eprintln!("GPU not available, skipping test_matmul_gpu_direct");
            return;
        }

        let lhs = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("valid A");
        let rhs = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("valid B");

        let product = lhs.matmul_gpu(&rhs).expect("matmul_gpu should succeed");

        assert_eq!(product.rows(), 2);
        assert_eq!(product.cols(), 2);

        assert!(
            (product.get(0, 0).unwrap() - 58.0).abs() < 1e-2,
            "Expected 58.0, got {}",
            product.get(0, 0).unwrap()
        );
        assert!(
            (product.get(0, 1).unwrap() - 64.0).abs() < 1e-2,
            "Expected 64.0, got {}",
            product.get(0, 1).unwrap()
        );
        assert!(
            (product.get(1, 0).unwrap() - 139.0).abs() < 1e-2,
            "Expected 139.0, got {}",
            product.get(1, 0).unwrap()
        );
        assert!(
            (product.get(1, 1).unwrap() - 154.0).abs() < 1e-2,
            "Expected 154.0, got {}",
            product.get(1, 1).unwrap()
        );
    }

    /// Without a GPU, matmul_gpu must report an error (the inverse of the
    /// skip guard used by the tests above).
    #[test]
    fn test_matmul_gpu_not_available_path() {
        use crate::backends::gpu::GpuBackend;

        if !GpuBackend::is_available() {
            let lhs = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
            let rhs = Matrix::from_vec(2, 2, vec![1.0; 4]).unwrap();
            let outcome = lhs.matmul_gpu(&rhs);
            assert!(outcome.is_err(), "matmul_gpu should fail without GPU");
        }
    }
}