1use crate::{Backend, TruenoError, Vector};
18
19#[cfg(feature = "tracing")]
20use tracing::instrument;
21
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    /// Number of rows (outer dimension).
    rows: usize,
    /// Number of columns (inner, contiguous dimension).
    cols: usize,
    /// Element storage in row-major order: element (i, j) lives at `i * cols + j`.
    data: Vec<T>,
    /// Compute backend used by this matrix's operations (selected at construction).
    backend: Backend,
}
55
56impl Matrix<f32> {
57 pub fn new(rows: usize, cols: usize) -> Self {
78 let backend = Backend::select_best();
79 Matrix {
80 rows,
81 cols,
82 data: vec![0.0; rows * cols],
83 backend,
84 }
85 }
86
87 pub fn from_vec(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self, TruenoError> {
109 if data.len() != rows * cols {
110 return Err(TruenoError::InvalidInput(format!(
111 "Data length {} does not match matrix dimensions {}x{} (expected {})",
112 data.len(),
113 rows,
114 cols,
115 rows * cols
116 )));
117 }
118
119 let backend = Backend::select_best();
120 Ok(Matrix {
121 rows,
122 cols,
123 data,
124 backend,
125 })
126 }
127
128 pub fn from_slice(rows: usize, cols: usize, data: &[f32]) -> Result<Self, TruenoError> {
154 Self::from_vec(rows, cols, data.to_vec())
155 }
156
157 pub fn zeros(rows: usize, cols: usize) -> Self {
168 Matrix::new(rows, cols)
169 }
170
171 fn zeros_with_backend(rows: usize, cols: usize, backend: Backend) -> Self {
174 Matrix {
175 rows,
176 cols,
177 data: vec![0.0; rows * cols],
178 backend,
179 }
180 }
181
182 pub fn identity(n: usize) -> Self {
195 let mut data = vec![0.0; n * n];
196 for i in 0..n {
197 data[i * n + i] = 1.0;
198 }
199 let backend = Backend::select_best();
200 Matrix {
201 rows: n,
202 cols: n,
203 data,
204 backend,
205 }
206 }
207
    /// Returns the number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }
212
    /// Returns the number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }
217
    /// Returns the matrix dimensions as a `(rows, cols)` pair.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }
222
223 pub fn get(&self, row: usize, col: usize) -> Option<&f32> {
227 if row >= self.rows || col >= self.cols {
228 None
229 } else {
230 self.data.get(row * self.cols + col)
231 }
232 }
233
234 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut f32> {
238 if row >= self.rows || col >= self.cols {
239 None
240 } else {
241 let idx = row * self.cols + col;
242 self.data.get_mut(idx)
243 }
244 }
245
    /// Returns the underlying row-major element buffer as a slice.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }
250
    /// Multiplies `self` (m×k) by `other` (k×n), dispatching to the most
    /// suitable kernel for the problem size and build configuration.
    ///
    /// Dispatch order: dimension check → row-vector fast path → GPU (when
    /// compiled with the `gpu` feature and all dimensions are large) →
    /// tiled / blocked-SIMD CPU kernels → naive triple loop for small inputs.
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` when `self.cols != other.rows`.
    #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(dims = %format!("{}x{} @ {}x{}", self.rows, self.cols, other.rows, other.cols))))]
    pub fn matmul(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
        if self.cols != other.rows {
            return Err(TruenoError::InvalidInput(format!(
                "Matrix dimension mismatch for multiplication: {}×{} × {}×{} (inner dimensions {} and {} must match)",
                self.rows, self.cols, other.rows, other.cols, self.cols, other.rows
            )));
        }

        // A 1×k left operand reduces to a vector-matrix product with a
        // cheaper dedicated kernel.
        if self.rows == 1 {
            return self.matmul_vector_matrix(other);
        }

        let mut result = Matrix::zeros_with_backend(self.rows, other.cols, self.backend);

        // NOTE(review): the `#[cfg]` below attaches only to GPU_THRESHOLD;
        // SIMD_THRESHOLD is always compiled in.
        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        const GPU_THRESHOLD: usize = 500;
        const SIMD_THRESHOLD: usize = 64;

        #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
        {
            // Try the GPU only when every dimension is large; on any GPU
            // failure, fall through silently to the CPU paths below.
            if self.rows >= GPU_THRESHOLD
                && self.cols >= GPU_THRESHOLD
                && other.cols >= GPU_THRESHOLD
            {
                if let Ok(gpu_result) = self.matmul_gpu(other) {
                    return Ok(gpu_result);
                }
            }
        }

        if self.rows >= SIMD_THRESHOLD
            || self.cols >= SIMD_THRESHOLD
            || other.cols >= SIMD_THRESHOLD
        {
            const TILED_THRESHOLD: usize = 512;

            let max_dim = self.rows.max(self.cols).max(other.cols);

            if max_dim < TILED_THRESHOLD {
                // Mid-size matrices use the register-accumulator tiled kernel
                // on every architecture.
                self.matmul_wasm_tiled(other, &mut result)?;
            } else {
                #[cfg(target_arch = "wasm32")]
                {
                    self.matmul_wasm_tiled(other, &mut result)?;
                }
                #[cfg(not(target_arch = "wasm32"))]
                {
                    self.matmul_simd(other, &mut result)?;
                }
            }
        } else {
            self.matmul_naive(other, &mut result)?;
        }

        Ok(result)
    }
359
360 #[cfg_attr(feature = "tracing", instrument(skip(a_data, b_data), fields(batch, m, k, n)))]
380 pub fn batched_matmul(
381 a_data: &[f32],
382 b_data: &[f32],
383 batch: usize,
384 m: usize,
385 k: usize,
386 n: usize,
387 ) -> Result<Vec<f32>, TruenoError> {
388 let a_stride = m * k;
389 let b_stride = k * n;
390 let out_stride = m * n;
391
392 if a_data.len() != batch * a_stride {
394 return Err(TruenoError::InvalidInput(format!(
395 "A data size mismatch: expected {} ({}×{}×{}), got {}",
396 batch * a_stride, batch, m, k, a_data.len()
397 )));
398 }
399 if b_data.len() != batch * b_stride {
400 return Err(TruenoError::InvalidInput(format!(
401 "B data size mismatch: expected {} ({}×{}×{}), got {}",
402 batch * b_stride, batch, k, n, b_data.len()
403 )));
404 }
405
406 let mut output = vec![0.0f32; batch * out_stride];
407
408 for ba in 0..batch {
410 let a_offset = ba * a_stride;
411 let b_offset = ba * b_stride;
412 let out_offset = ba * out_stride;
413
414 let a_slice = &a_data[a_offset..a_offset + a_stride];
416 let b_slice = &b_data[b_offset..b_offset + b_stride];
417
418 let a_mat = Matrix::from_slice(m, k, a_slice)?;
420 let b_mat = Matrix::from_slice(k, n, b_slice)?;
421
422 let result = a_mat.matmul(&b_mat)?;
424
425 output[out_offset..out_offset + out_stride].copy_from_slice(result.as_slice());
427 }
428
429 Ok(output)
430 }
431
432 #[cfg_attr(feature = "tracing", instrument(skip(a_data, b_data), fields(batch, heads, m, k, n)))]
450 pub fn batched_matmul_4d(
451 a_data: &[f32],
452 b_data: &[f32],
453 batch: usize,
454 heads: usize,
455 m: usize,
456 k: usize,
457 n: usize,
458 ) -> Result<Vec<f32>, TruenoError> {
459 let a_head_stride = m * k;
460 let b_head_stride = k * n;
461 let out_head_stride = m * n;
462 let total_heads = batch * heads;
463
464 let expected_a = total_heads * a_head_stride;
466 let expected_b = total_heads * b_head_stride;
467 if a_data.len() != expected_a {
468 return Err(TruenoError::InvalidInput(format!(
469 "A data size mismatch: expected {} ({}×{}×{}×{}), got {}",
470 expected_a, batch, heads, m, k, a_data.len()
471 )));
472 }
473 if b_data.len() != expected_b {
474 return Err(TruenoError::InvalidInput(format!(
475 "B data size mismatch: expected {} ({}×{}×{}×{}), got {}",
476 expected_b, batch, heads, k, n, b_data.len()
477 )));
478 }
479
480 let mut output = vec![0.0f32; total_heads * out_head_stride];
481
482 for bh in 0..total_heads {
484 let a_offset = bh * a_head_stride;
485 let b_offset = bh * b_head_stride;
486 let out_offset = bh * out_head_stride;
487
488 let a_slice = &a_data[a_offset..a_offset + a_head_stride];
490 let b_slice = &b_data[b_offset..b_offset + b_head_stride];
491
492 let a_mat = Matrix::from_slice(m, k, a_slice)?;
493 let b_mat = Matrix::from_slice(k, n, b_slice)?;
494
495 let result = a_mat.matmul(&b_mat)?;
497
498 output[out_offset..out_offset + out_head_stride].copy_from_slice(result.as_slice());
500 }
501
502 Ok(output)
503 }
504
505 #[cfg_attr(feature = "tracing", instrument(skip(self, other), fields(k = self.cols, n = other.cols)))]
516 fn matmul_vector_matrix(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
517 debug_assert_eq!(self.rows, 1);
518
519 let k = self.cols; let n = other.cols; let mut result = Matrix::zeros_with_backend(1, n, self.backend);
524
525 for ki in 0..k {
529 let a_k = self.data[ki];
530 if a_k == 0.0 {
531 continue; }
533
534 let b_row_start = ki * n;
536
537 for j in 0..n {
540 result.data[j] += a_k * other.data[b_row_start + j];
541 }
542 }
543
544 Ok(result)
545 }
546
547 fn matmul_naive(
549 &self,
550 other: &Matrix<f32>,
551 result: &mut Matrix<f32>,
552 ) -> Result<(), TruenoError> {
553 for i in 0..self.rows {
556 for j in 0..other.cols {
557 let mut sum = 0.0;
558 for k in 0..self.cols {
559 sum += self
561 .get(i, k)
562 .expect("matmul_naive: A[i,k] bounds validated by loop")
563 * other
564 .get(k, j)
565 .expect("matmul_naive: B[k,j] bounds validated by loop");
566 }
567 *result
568 .get_mut(i, j)
569 .expect("matmul_naive: C[i,j] bounds validated by loop") = sum;
570 }
571 }
572 Ok(())
573 }
574
    /// AVX2/FMA micro-kernel computing four dot products at once: each of
    /// the four rows in `a_rows` is dotted against the single column
    /// `b_col`, and the four sums are written into `results`. Processes
    /// 8 lanes per iteration with fused multiply-add, then handles the
    /// (< 8 element) tail with scalar arithmetic.
    ///
    /// # Safety
    /// The caller must guarantee that the CPU supports AVX2 and FMA, and
    /// that every slice in `a_rows` is at least `b_col.len()` elements
    /// long (the unaligned loads read `b_col.len()` elements from each).
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    unsafe fn matmul_microkernel_4x1_avx2(
        a_rows: [&[f32]; 4],
        b_col: &[f32],
        results: &mut [f32; 4],
    ) {
        use std::arch::x86_64::*;

        let len = b_col.len();
        let chunks = len / 8;
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();

        for i in 0..chunks {
            let offset = i * 8;

            // One shared load of B per iteration, reused by all four rows.
            let b_vec = _mm256_loadu_ps(b_col.as_ptr().add(offset));

            let a0_vec = _mm256_loadu_ps(a_rows[0].as_ptr().add(offset));
            acc0 = _mm256_fmadd_ps(a0_vec, b_vec, acc0);

            let a1_vec = _mm256_loadu_ps(a_rows[1].as_ptr().add(offset));
            acc1 = _mm256_fmadd_ps(a1_vec, b_vec, acc1);

            let a2_vec = _mm256_loadu_ps(a_rows[2].as_ptr().add(offset));
            acc2 = _mm256_fmadd_ps(a2_vec, b_vec, acc2);

            let a3_vec = _mm256_loadu_ps(a_rows[3].as_ptr().add(offset));
            acc3 = _mm256_fmadd_ps(a3_vec, b_vec, acc3);
        }

        // Reduce each 8-lane accumulator to its scalar sum.
        results[0] = Self::horizontal_sum_avx2(acc0);
        results[1] = Self::horizontal_sum_avx2(acc1);
        results[2] = Self::horizontal_sum_avx2(acc2);
        results[3] = Self::horizontal_sum_avx2(acc3);

        // Scalar tail for lengths that are not a multiple of 8.
        let remainder_start = chunks * 8;
        if remainder_start < len {
            for i in remainder_start..len {
                results[0] += a_rows[0][i] * b_col[i];
                results[1] += a_rows[1][i] * b_col[i];
                results[2] += a_rows[2][i] * b_col[i];
                results[3] += a_rows[3][i] * b_col[i];
            }
        }
    }
645
    /// Reduces the 8 lanes of an AVX2 vector to a single scalar sum.
    ///
    /// # Safety
    /// The caller must guarantee that the CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[inline]
    unsafe fn horizontal_sum_avx2(v: std::arch::x86_64::__m256) -> f32 {
        use std::arch::x86_64::*;

        // Fold the high 128-bit half onto the low half: 8 lanes -> 4 lanes.
        let sum128 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));

        // Two horizontal pairwise adds: 4 lanes -> 2 lanes -> 1 lane.
        let sum64 = _mm_hadd_ps(sum128, sum128);

        let sum32 = _mm_hadd_ps(sum64, sum64);

        _mm_cvtss_f32(sum32)
    }
665
    /// Processes one L3 row block (`iii..i3_end`) of the blocked product
    /// `result += a × b_transposedᵀ`, sequentially.
    ///
    /// This is the per-task work unit of the parallel matmul path: each
    /// rayon task receives a disjoint range of result rows, so tasks never
    /// write the same output cell. `b_transposed` is B with rows and
    /// columns swapped, so every inner dot product reads two contiguous
    /// slices.
    #[cfg(feature = "parallel")]
    #[allow(clippy::too_many_arguments)]
    fn process_l3_row_block_seq(
        iii: usize,
        i3_end: usize,
        a: &Matrix<f32>,
        b_transposed: &Matrix<f32>,
        result: &mut Matrix<f32>,
        l2_block_size: usize,
        l3_block_size: usize,
    ) {
        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};

        // L3 tiles over (j, k), then L2 tiles, then the inner kernels
        // accumulate partial dot products of length `block_size`.
        for jjj in (0..b_transposed.rows).step_by(l3_block_size) {
            let j3_end = (jjj + l3_block_size).min(b_transposed.rows);

            for kkk in (0..a.cols).step_by(l3_block_size) {
                let k3_end = (kkk + l3_block_size).min(a.cols);

                for ii in (iii..i3_end).step_by(l2_block_size) {
                    let i_end = (ii + l2_block_size).min(i3_end);

                    for jj in (jjj..j3_end).step_by(l2_block_size) {
                        let j_end = (jj + l2_block_size).min(j3_end);

                        for kk in (kkk..k3_end).step_by(l2_block_size) {
                            let k_end = (kk + l2_block_size).min(k3_end);
                            let block_size = k_end - kk;

                            #[cfg(target_arch = "x86_64")]
                            let use_microkernel =
                                matches!(a.backend, Backend::AVX2 | Backend::AVX512);

                            #[cfg(target_arch = "x86_64")]
                            if use_microkernel {
                                let mut i = ii;

                                // 4-row micro-kernel: amortizes every B column
                                // load across four A rows.
                                while i + 4 <= i_end {
                                    let row0_start = i * a.cols + kk;
                                    let row1_start = (i + 1) * a.cols + kk;
                                    let row2_start = (i + 2) * a.cols + kk;
                                    let row3_start = (i + 3) * a.cols + kk;

                                    let a_rows = [
                                        &a.data[row0_start..row0_start + block_size],
                                        &a.data[row1_start..row1_start + block_size],
                                        &a.data[row2_start..row2_start + block_size],
                                        &a.data[row3_start..row3_start + block_size],
                                    ];

                                    for j in jj..j_end {
                                        let col_start = j * b_transposed.cols + kk;
                                        let b_col =
                                            &b_transposed.data[col_start..col_start + block_size];

                                        let mut partial_dots = [0.0f32; 4];
                                        // SAFETY: only reached when the selected backend
                                        // is AVX2-class; all slices have length block_size.
                                        unsafe {
                                            Matrix::matmul_microkernel_4x1_avx2(
                                                a_rows,
                                                b_col,
                                                &mut partial_dots,
                                            );
                                        }

                                        result.data[i * result.cols + j] += partial_dots[0];
                                        result.data[(i + 1) * result.cols + j] += partial_dots[1];
                                        result.data[(i + 2) * result.cols + j] += partial_dots[2];
                                        result.data[(i + 3) * result.cols + j] += partial_dots[3];
                                    }

                                    i += 4;
                                }

                                // Leftover rows (< 4) fall back to a plain AVX2 dot.
                                for i in i..i_end {
                                    let row_start = i * a.cols + kk;
                                    let a_row = &a.data[row_start..row_start + block_size];

                                    for j in jj..j_end {
                                        let col_start = j * b_transposed.cols + kk;
                                        let b_col =
                                            &b_transposed.data[col_start..col_start + block_size];

                                        // SAFETY: only reached for AVX2-class backends.
                                        let partial_dot = unsafe { Avx2Backend::dot(a_row, b_col) };
                                        result.data[i * result.cols + j] += partial_dot;
                                    }
                                }
                            } else {
                                // x86_64 without AVX2: dispatch the dot product on
                                // the recorded backend.
                                #[allow(unused_variables)]
                                for i in ii..i_end {
                                    let row_start = i * a.cols + kk;
                                    let a_row = &a.data[row_start..row_start + block_size];

                                    for j in jj..j_end {
                                        let col_start = j * b_transposed.cols + kk;
                                        let b_col =
                                            &b_transposed.data[col_start..col_start + block_size];

                                        // SAFETY: each match arm is compiled/taken only on
                                        // targets where its backend's dot is valid.
                                        let partial_dot = unsafe {
                                            match a.backend {
                                                Backend::Scalar => ScalarBackend::dot(a_row, b_col),
                                                #[cfg(target_arch = "x86_64")]
                                                Backend::SSE2 | Backend::AVX => {
                                                    Sse2Backend::dot(a_row, b_col)
                                                }
                                                #[cfg(not(target_arch = "x86_64"))]
                                                Backend::SSE2
                                                | Backend::AVX
                                                | Backend::AVX2
                                                | Backend::AVX512 => {
                                                    ScalarBackend::dot(a_row, b_col)
                                                }
                                                #[cfg(any(
                                                    target_arch = "aarch64",
                                                    target_arch = "arm"
                                                ))]
                                                Backend::NEON => {
                                                    use crate::backends::neon::NeonBackend;
                                                    NeonBackend::dot(a_row, b_col)
                                                }
                                                #[cfg(not(any(
                                                    target_arch = "aarch64",
                                                    target_arch = "arm"
                                                )))]
                                                Backend::NEON => ScalarBackend::dot(a_row, b_col),
                                                #[cfg(target_arch = "wasm32")]
                                                Backend::WasmSIMD => {
                                                    use crate::backends::wasm::WasmBackend;
                                                    WasmBackend::dot(a_row, b_col)
                                                }
                                                #[cfg(not(target_arch = "wasm32"))]
                                                Backend::WasmSIMD => {
                                                    ScalarBackend::dot(a_row, b_col)
                                                }
                                                _ => ScalarBackend::dot(a_row, b_col),
                                            }
                                        };

                                        result.data[i * result.cols + j] += partial_dot;
                                    }
                                }
                            }

                            // Non-x86_64 targets have no micro-kernel path at all.
                            #[cfg(not(target_arch = "x86_64"))]
                            {
                                for i in ii..i_end {
                                    let row_start = i * a.cols + kk;
                                    let a_row = &a.data[row_start..row_start + block_size];

                                    for j in jj..j_end {
                                        let col_start = j * b_transposed.cols + kk;
                                        let b_col =
                                            &b_transposed.data[col_start..col_start + block_size];

                                        // SAFETY: each match arm is compiled/taken only on
                                        // targets where its backend's dot is valid.
                                        let partial_dot = unsafe {
                                            match a.backend {
                                                Backend::Scalar => ScalarBackend::dot(a_row, b_col),
                                                #[cfg(any(
                                                    target_arch = "aarch64",
                                                    target_arch = "arm"
                                                ))]
                                                Backend::NEON => {
                                                    use crate::backends::neon::NeonBackend;
                                                    NeonBackend::dot(a_row, b_col)
                                                }
                                                #[cfg(not(any(
                                                    target_arch = "aarch64",
                                                    target_arch = "arm"
                                                )))]
                                                Backend::NEON => ScalarBackend::dot(a_row, b_col),
                                                #[cfg(target_arch = "wasm32")]
                                                Backend::WasmSIMD => {
                                                    use crate::backends::wasm::WasmBackend;
                                                    WasmBackend::dot(a_row, b_col)
                                                }
                                                #[cfg(not(target_arch = "wasm32"))]
                                                Backend::WasmSIMD => {
                                                    ScalarBackend::dot(a_row, b_col)
                                                }
                                                _ => ScalarBackend::dot(a_row, b_col),
                                            }
                                        };

                                        result.data[i * result.cols + j] += partial_dot;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
887
    /// Cache-blocked SIMD matmul (`result = self × other`) for non-wasm
    /// targets above the small-matrix threshold.
    ///
    /// Strategy: transpose B once so every inner product reads two
    /// contiguous slices, then tile with L2 (64) blocks and, when all
    /// dimensions are ≥ 512, an outer L3 (256) tier. With the `parallel`
    /// feature and all dimensions ≥ 1024, L3 row blocks are distributed
    /// across rayon workers.
    fn matmul_simd(
        &self,
        other: &Matrix<f32>,
        result: &mut Matrix<f32>,
    ) -> Result<(), TruenoError> {
        const L2_BLOCK_SIZE: usize = 64;
        const L3_BLOCK_SIZE: usize = 256;
        const L3_THRESHOLD: usize = 512;

        // Tiny dimensions gain nothing from blocking; use the simple kernel.
        if self.rows <= 32 || self.cols <= 32 || other.cols <= 32 {
            return self.matmul_simd_simple(other, result);
        }

        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};

        let b_transposed = other.transpose();

        let use_l3_blocking =
            self.rows >= L3_THRESHOLD && self.cols >= L3_THRESHOLD && other.cols >= L3_THRESHOLD;

        #[cfg(feature = "parallel")]
        const PARALLEL_THRESHOLD: usize = 1024;
        #[cfg(feature = "parallel")]
        let use_parallel = self.rows >= PARALLEL_THRESHOLD
            && self.cols >= PARALLEL_THRESHOLD
            && other.cols >= PARALLEL_THRESHOLD;
        #[cfg(not(feature = "parallel"))]
        let use_parallel = false;

        if use_l3_blocking {
            if use_parallel {
                #[cfg(feature = "parallel")]
                {
                    use rayon::prelude::*;
                    use std::sync::atomic::{AtomicPtr, Ordering};
                    use std::sync::Arc;

                    // NOTE(review): the `&mut Matrix` is smuggled to worker
                    // threads through a raw pointer. Each task writes only its
                    // own row range `iii..i3_end`, so the stores are disjoint,
                    // but this bypasses the borrow checker — worth confirming
                    // with Miri or restructuring around `par_chunks_mut`.
                    let result_ptr = Arc::new(AtomicPtr::new(result as *mut Matrix<f32>));

                    let num_blocks = self.rows.div_ceil(L3_BLOCK_SIZE);

                    (0..num_blocks).into_par_iter().for_each(|block_idx| {
                        let iii = block_idx * L3_BLOCK_SIZE;
                        let i3_end = (iii + L3_BLOCK_SIZE).min(self.rows);

                        // SAFETY: see the note above — distinct tasks own
                        // disjoint row ranges, so no two threads write the
                        // same output cell.
                        unsafe {
                            let ptr = result_ptr.load(Ordering::Relaxed);
                            Self::process_l3_row_block_seq(
                                iii,
                                i3_end,
                                self,
                                &b_transposed,
                                &mut *ptr,
                                L2_BLOCK_SIZE,
                                L3_BLOCK_SIZE,
                            );
                        }
                    });
                }

                return Ok(());
            }

            // Sequential L3 blocking: the same loop nest the parallel
            // worker runs, but over all row blocks on this thread.
            for iii in (0..self.rows).step_by(L3_BLOCK_SIZE) {
                let i3_end = (iii + L3_BLOCK_SIZE).min(self.rows);

                for jjj in (0..other.cols).step_by(L3_BLOCK_SIZE) {
                    let j3_end = (jjj + L3_BLOCK_SIZE).min(other.cols);

                    for kkk in (0..self.cols).step_by(L3_BLOCK_SIZE) {
                        let k3_end = (kkk + L3_BLOCK_SIZE).min(self.cols);

                        for ii in (iii..i3_end).step_by(L2_BLOCK_SIZE) {
                            let i_end = (ii + L2_BLOCK_SIZE).min(i3_end);

                            for jj in (jjj..j3_end).step_by(L2_BLOCK_SIZE) {
                                let j_end = (jj + L2_BLOCK_SIZE).min(j3_end);

                                for kk in (kkk..k3_end).step_by(L2_BLOCK_SIZE) {
                                    let k_end = (kk + L2_BLOCK_SIZE).min(k3_end);
                                    let block_size = k_end - kk;

                                    #[cfg(target_arch = "x86_64")]
                                    let use_microkernel =
                                        matches!(self.backend, Backend::AVX2 | Backend::AVX512);

                                    #[cfg(target_arch = "x86_64")]
                                    if use_microkernel {
                                        let mut i = ii;

                                        // 4-row micro-kernel; see
                                        // matmul_microkernel_4x1_avx2.
                                        while i + 4 <= i_end {
                                            let row0_start = i * self.cols + kk;
                                            let row1_start = (i + 1) * self.cols + kk;
                                            let row2_start = (i + 2) * self.cols + kk;
                                            let row3_start = (i + 3) * self.cols + kk;

                                            let a_rows = [
                                                &self.data[row0_start..row0_start + block_size],
                                                &self.data[row1_start..row1_start + block_size],
                                                &self.data[row2_start..row2_start + block_size],
                                                &self.data[row3_start..row3_start + block_size],
                                            ];

                                            for j in jj..j_end {
                                                let col_start = j * b_transposed.cols + kk;
                                                let b_col = &b_transposed.data
                                                    [col_start..col_start + block_size];

                                                let mut partial_dots = [0.0f32; 4];
                                                // SAFETY: only reached when the backend is
                                                // AVX2-class; all slices have length block_size.
                                                unsafe {
                                                    Self::matmul_microkernel_4x1_avx2(
                                                        a_rows,
                                                        b_col,
                                                        &mut partial_dots,
                                                    );
                                                }

                                                result.data[i * result.cols + j] += partial_dots[0];
                                                result.data[(i + 1) * result.cols + j] +=
                                                    partial_dots[1];
                                                result.data[(i + 2) * result.cols + j] +=
                                                    partial_dots[2];
                                                result.data[(i + 3) * result.cols + j] +=
                                                    partial_dots[3];
                                            }

                                            i += 4;
                                        }

                                        // Leftover rows (< 4) use the plain AVX2 dot.
                                        for i in i..i_end {
                                            let row_start = i * self.cols + kk;
                                            let a_row =
                                                &self.data[row_start..row_start + block_size];

                                            for j in jj..j_end {
                                                let col_start = j * b_transposed.cols + kk;
                                                let b_col = &b_transposed.data
                                                    [col_start..col_start + block_size];

                                                // SAFETY: AVX2-class backend only.
                                                let partial_dot =
                                                    unsafe { Avx2Backend::dot(a_row, b_col) };
                                                result.data[i * result.cols + j] += partial_dot;
                                            }
                                        }
                                    } else {
                                        #[allow(unused_variables)]
                                        for i in ii..i_end {
                                            let row_start = i * self.cols + kk;
                                            let a_row =
                                                &self.data[row_start..row_start + block_size];

                                            for j in jj..j_end {
                                                let col_start = j * b_transposed.cols + kk;
                                                let b_col = &b_transposed.data
                                                    [col_start..col_start + block_size];

                                                // SAFETY: each arm is compiled/taken only on
                                                // targets where its backend's dot is valid.
                                                let partial_dot = unsafe {
                                                    match self.backend {
                                                        Backend::Scalar => {
                                                            ScalarBackend::dot(a_row, b_col)
                                                        }
                                                        #[cfg(target_arch = "x86_64")]
                                                        Backend::SSE2 | Backend::AVX => {
                                                            Sse2Backend::dot(a_row, b_col)
                                                        }
                                                        #[cfg(not(target_arch = "x86_64"))]
                                                        Backend::SSE2
                                                        | Backend::AVX
                                                        | Backend::AVX2
                                                        | Backend::AVX512 => {
                                                            ScalarBackend::dot(a_row, b_col)
                                                        }
                                                        #[cfg(any(
                                                            target_arch = "aarch64",
                                                            target_arch = "arm"
                                                        ))]
                                                        Backend::NEON => {
                                                            use crate::backends::neon::NeonBackend;
                                                            NeonBackend::dot(a_row, b_col)
                                                        }
                                                        #[cfg(not(any(
                                                            target_arch = "aarch64",
                                                            target_arch = "arm"
                                                        )))]
                                                        Backend::NEON => {
                                                            ScalarBackend::dot(a_row, b_col)
                                                        }
                                                        #[cfg(target_arch = "wasm32")]
                                                        Backend::WasmSIMD => {
                                                            use crate::backends::wasm::WasmBackend;
                                                            WasmBackend::dot(a_row, b_col)
                                                        }
                                                        #[cfg(not(target_arch = "wasm32"))]
                                                        Backend::WasmSIMD => {
                                                            ScalarBackend::dot(a_row, b_col)
                                                        }
                                                        Backend::GPU
                                                        | Backend::Auto
                                                        | Backend::AVX2
                                                        | Backend::AVX512 => {
                                                            ScalarBackend::dot(a_row, b_col)
                                                        }
                                                    }
                                                };

                                                result.data[i * result.cols + j] += partial_dot;
                                            }
                                        }
                                    }

                                    // Non-x86_64: no micro-kernel path at all.
                                    #[cfg(not(target_arch = "x86_64"))]
                                    for i in ii..i_end {
                                        let row_start = i * self.cols + kk;
                                        let a_row = &self.data[row_start..row_start + block_size];

                                        for j in jj..j_end {
                                            let col_start = j * b_transposed.cols + kk;
                                            let b_col = &b_transposed.data
                                                [col_start..col_start + block_size];

                                            // SAFETY: each arm is compiled/taken only on
                                            // targets where its backend's dot is valid.
                                            let partial_dot = unsafe {
                                                match self.backend {
                                                    Backend::Scalar => {
                                                        ScalarBackend::dot(a_row, b_col)
                                                    }
                                                    #[cfg(any(
                                                        target_arch = "aarch64",
                                                        target_arch = "arm"
                                                    ))]
                                                    Backend::NEON => {
                                                        use crate::backends::neon::NeonBackend;
                                                        NeonBackend::dot(a_row, b_col)
                                                    }
                                                    #[cfg(not(any(
                                                        target_arch = "aarch64",
                                                        target_arch = "arm"
                                                    )))]
                                                    Backend::NEON => {
                                                        ScalarBackend::dot(a_row, b_col)
                                                    }
                                                    #[cfg(target_arch = "wasm32")]
                                                    Backend::WasmSIMD => {
                                                        use crate::backends::wasm::WasmBackend;
                                                        WasmBackend::dot(a_row, b_col)
                                                    }
                                                    #[cfg(not(target_arch = "wasm32"))]
                                                    Backend::WasmSIMD => {
                                                        ScalarBackend::dot(a_row, b_col)
                                                    }
                                                    _ => ScalarBackend::dot(a_row, b_col),
                                                }
                                            };

                                            result.data[i * result.cols + j] += partial_dot;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } else {
            // Mid-size matrices: L2-only blocking (no L3 tier).
            for ii in (0..self.rows).step_by(L2_BLOCK_SIZE) {
                let i_end = (ii + L2_BLOCK_SIZE).min(self.rows);

                for jj in (0..other.cols).step_by(L2_BLOCK_SIZE) {
                    let j_end = (jj + L2_BLOCK_SIZE).min(other.cols);

                    for kk in (0..self.cols).step_by(L2_BLOCK_SIZE) {
                        let k_end = (kk + L2_BLOCK_SIZE).min(self.cols);
                        let block_size = k_end - kk;

                        #[cfg(target_arch = "x86_64")]
                        let use_microkernel =
                            matches!(self.backend, Backend::AVX2 | Backend::AVX512);

                        #[cfg(target_arch = "x86_64")]
                        if use_microkernel {
                            let mut i = ii;

                            // 4-row micro-kernel; see matmul_microkernel_4x1_avx2.
                            while i + 4 <= i_end {
                                let row0_start = i * self.cols + kk;
                                let row1_start = (i + 1) * self.cols + kk;
                                let row2_start = (i + 2) * self.cols + kk;
                                let row3_start = (i + 3) * self.cols + kk;

                                let a_rows = [
                                    &self.data[row0_start..row0_start + block_size],
                                    &self.data[row1_start..row1_start + block_size],
                                    &self.data[row2_start..row2_start + block_size],
                                    &self.data[row3_start..row3_start + block_size],
                                ];

                                for j in jj..j_end {
                                    let col_start = j * b_transposed.cols + kk;
                                    let b_col =
                                        &b_transposed.data[col_start..col_start + block_size];

                                    let mut partial_dots = [0.0f32; 4];
                                    // SAFETY: only reached when the backend is AVX2-class;
                                    // all slices have length block_size.
                                    unsafe {
                                        Self::matmul_microkernel_4x1_avx2(
                                            a_rows,
                                            b_col,
                                            &mut partial_dots,
                                        );
                                    }

                                    result.data[i * result.cols + j] += partial_dots[0];
                                    result.data[(i + 1) * result.cols + j] += partial_dots[1];
                                    result.data[(i + 2) * result.cols + j] += partial_dots[2];
                                    result.data[(i + 3) * result.cols + j] += partial_dots[3];
                                }

                                i += 4;
                            }

                            // Leftover rows (< 4) use the plain AVX2 dot.
                            for i in i..i_end {
                                let row_start = i * self.cols + kk;
                                let a_row = &self.data[row_start..row_start + block_size];

                                for j in jj..j_end {
                                    let col_start = j * b_transposed.cols + kk;
                                    let b_col =
                                        &b_transposed.data[col_start..col_start + block_size];

                                    // SAFETY: AVX2-class backend only.
                                    let partial_dot = unsafe { Avx2Backend::dot(a_row, b_col) };
                                    result.data[i * result.cols + j] += partial_dot;
                                }
                            }
                        } else {
                            #[allow(unused_variables)]
                            for i in ii..i_end {
                                let row_start = i * self.cols + kk;
                                let a_row = &self.data[row_start..row_start + block_size];

                                for j in jj..j_end {
                                    let col_start = j * b_transposed.cols + kk;
                                    let b_col =
                                        &b_transposed.data[col_start..col_start + block_size];

                                    // SAFETY: each arm is compiled/taken only on targets
                                    // where its backend's dot is valid.
                                    let partial_dot = unsafe {
                                        match self.backend {
                                            Backend::Scalar => ScalarBackend::dot(a_row, b_col),
                                            #[cfg(target_arch = "x86_64")]
                                            Backend::SSE2 | Backend::AVX => {
                                                Sse2Backend::dot(a_row, b_col)
                                            }
                                            #[cfg(not(target_arch = "x86_64"))]
                                            Backend::SSE2
                                            | Backend::AVX
                                            | Backend::AVX2
                                            | Backend::AVX512 => ScalarBackend::dot(a_row, b_col),
                                            #[cfg(any(
                                                target_arch = "aarch64",
                                                target_arch = "arm"
                                            ))]
                                            Backend::NEON => {
                                                use crate::backends::neon::NeonBackend;
                                                NeonBackend::dot(a_row, b_col)
                                            }
                                            #[cfg(not(any(
                                                target_arch = "aarch64",
                                                target_arch = "arm"
                                            )))]
                                            Backend::NEON => ScalarBackend::dot(a_row, b_col),
                                            #[cfg(target_arch = "wasm32")]
                                            Backend::WasmSIMD => {
                                                use crate::backends::wasm::WasmBackend;
                                                WasmBackend::dot(a_row, b_col)
                                            }
                                            #[cfg(not(target_arch = "wasm32"))]
                                            Backend::WasmSIMD => ScalarBackend::dot(a_row, b_col),
                                            Backend::GPU
                                            | Backend::Auto
                                            | Backend::AVX2
                                            | Backend::AVX512 => ScalarBackend::dot(a_row, b_col),
                                        }
                                    };

                                    result.data[i * result.cols + j] += partial_dot;
                                }
                            }
                        }

                        // Non-x86_64: no micro-kernel path at all.
                        #[cfg(not(target_arch = "x86_64"))]
                        for i in ii..i_end {
                            let row_start = i * self.cols + kk;
                            let a_row = &self.data[row_start..row_start + block_size];

                            for j in jj..j_end {
                                let col_start = j * b_transposed.cols + kk;
                                let b_col = &b_transposed.data[col_start..col_start + block_size];

                                // SAFETY: each arm is compiled/taken only on targets
                                // where its backend's dot is valid.
                                let partial_dot = unsafe {
                                    match self.backend {
                                        Backend::Scalar => ScalarBackend::dot(a_row, b_col),
                                        #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                                        Backend::NEON => {
                                            use crate::backends::neon::NeonBackend;
                                            NeonBackend::dot(a_row, b_col)
                                        }
                                        #[cfg(not(any(
                                            target_arch = "aarch64",
                                            target_arch = "arm"
                                        )))]
                                        Backend::NEON => ScalarBackend::dot(a_row, b_col),
                                        #[cfg(target_arch = "wasm32")]
                                        Backend::WasmSIMD => {
                                            use crate::backends::wasm::WasmBackend;
                                            WasmBackend::dot(a_row, b_col)
                                        }
                                        #[cfg(not(target_arch = "wasm32"))]
                                        Backend::WasmSIMD => ScalarBackend::dot(a_row, b_col),
                                        _ => ScalarBackend::dot(a_row, b_col),
                                    }
                                };

                                result.data[i * result.cols + j] += partial_dot;
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }
1378
    /// Unblocked fallback matmul for small dimensions: transposes B once so
    /// each output cell is a single contiguous dot product, dispatched to
    /// the recorded backend.
    fn matmul_simd_simple(
        &self,
        other: &Matrix<f32>,
        result: &mut Matrix<f32>,
    ) -> Result<(), TruenoError> {
        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};

        let b_transposed = other.transpose();

        for i in 0..self.rows {
            let row_start = i * self.cols;
            let row_end = row_start + self.cols;
            let a_row = &self.data[row_start..row_end];

            for j in 0..other.cols {
                // Row j of the transpose is column j of B, stored contiguously.
                let col_start = j * b_transposed.cols;
                let col_end = col_start + b_transposed.cols;
                let b_col = &b_transposed.data[col_start..col_end];

                // SAFETY: each match arm is compiled/taken only on targets
                // where its backend's dot is valid.
                let dot_result = unsafe {
                    match self.backend {
                        Backend::Scalar => ScalarBackend::dot(a_row, b_col),
                        #[cfg(target_arch = "x86_64")]
                        Backend::SSE2 | Backend::AVX => Sse2Backend::dot(a_row, b_col),
                        #[cfg(target_arch = "x86_64")]
                        Backend::AVX2 | Backend::AVX512 => Avx2Backend::dot(a_row, b_col),
                        #[cfg(not(target_arch = "x86_64"))]
                        Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
                            ScalarBackend::dot(a_row, b_col)
                        }
                        #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                        Backend::NEON => {
                            use crate::backends::neon::NeonBackend;
                            NeonBackend::dot(a_row, b_col)
                        }
                        #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
                        Backend::NEON => ScalarBackend::dot(a_row, b_col),
                        #[cfg(target_arch = "wasm32")]
                        Backend::WasmSIMD => {
                            use crate::backends::wasm::WasmBackend;
                            WasmBackend::dot(a_row, b_col)
                        }
                        #[cfg(not(target_arch = "wasm32"))]
                        Backend::WasmSIMD => ScalarBackend::dot(a_row, b_col),
                        Backend::GPU | Backend::Auto => ScalarBackend::dot(a_row, b_col),
                    }
                };

                result.data[i * result.cols + j] = dot_result;
            }
        }

        Ok(())
    }
1442
1443 fn matmul_wasm_tiled(
1453 &self,
1454 other: &Matrix<f32>,
1455 result: &mut Matrix<f32>,
1456 ) -> Result<(), TruenoError> {
1457 let m = self.rows;
1458 let k = self.cols;
1459 let n = other.cols;
1460
1461 for i in 0..m {
1463 let a_row_start = i * k;
1464 let result_row_start = i * n;
1465
1466 let simd_width = 8; let n_simd = (n / simd_width) * simd_width;
1477
1478 #[allow(clippy::needless_range_loop)]
1482 for j0 in (0..n_simd).step_by(simd_width) {
1483 let mut acc = [0.0f32; 8];
1484
1485 for kk in 0..k {
1486 let a_val = self.data[a_row_start + kk];
1487 let b_row_start = kk * n + j0;
1488
1489 for jj in 0..simd_width {
1491 acc[jj] += a_val * other.data[b_row_start + jj];
1492 }
1493 }
1494
1495 for jj in 0..simd_width {
1497 result.data[result_row_start + j0 + jj] = acc[jj];
1498 }
1499 }
1500
1501 for j in n_simd..n {
1503 let mut sum = 0.0f32;
1504 for kk in 0..k {
1505 sum += self.data[a_row_start + kk] * other.data[kk * n + j];
1506 }
1507 result.data[result_row_start + j] = sum;
1508 }
1509 }
1510
1511 Ok(())
1512 }
1513
1514 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
1516 fn matmul_gpu(&self, other: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
1517 use crate::backends::gpu::GpuBackend;
1518
1519 if !GpuBackend::is_available() {
1521 return Err(TruenoError::InvalidInput("GPU not available".to_string()));
1522 }
1523
1524 let mut gpu = GpuBackend::new();
1526
1527 let result_data = gpu
1529 .matmul(&self.data, &other.data, self.rows, self.cols, other.cols)
1530 .map_err(|e| TruenoError::InvalidInput(format!("GPU matmul failed: {}", e)))?;
1531
1532 let mut result = Matrix::zeros(self.rows, other.cols);
1534 result.data = result_data;
1535
1536 Ok(result)
1537 }
1538
1539 #[cfg_attr(feature = "tracing", instrument(skip(self), fields(dims = %format!("{}x{}", self.rows, self.cols))))]
1566 pub fn transpose(&self) -> Matrix<f32> {
1567 let mut result = Matrix::zeros_with_backend(self.cols, self.rows, self.backend);
1568
1569 const BLOCK_SIZE: usize = 64;
1572
1573 for i_block in (0..self.rows).step_by(BLOCK_SIZE) {
1575 for j_block in (0..self.cols).step_by(BLOCK_SIZE) {
1576 let i_end = (i_block + BLOCK_SIZE).min(self.rows);
1578 let j_end = (j_block + BLOCK_SIZE).min(self.cols);
1579
1580 for i in i_block..i_end {
1581 let src_row_start = i * self.cols;
1583 for j in j_block..j_end {
1584 result.data[j * result.cols + i] = self.data[src_row_start + j];
1587 }
1588 }
1589 }
1590 }
1591
1592 result
1593 }
1594
    /// Matrix-vector product `self × v`, returning a vector of length
    /// `self.rows`.
    ///
    /// Each output element is the dot product of one matrix row with `v`,
    /// dispatched to the recorded SIMD backend on x86_64 (scalar
    /// elsewhere). With the `parallel` feature and ≥ 4096 rows, rows are
    /// distributed across rayon workers.
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` when `v.len() != self.cols`.
    pub fn matvec(&self, v: &Vector<f32>) -> Result<Vector<f32>, TruenoError> {
        if v.len() != self.cols {
            return Err(TruenoError::InvalidInput(format!(
                "Vector length {} does not match matrix columns {} for matrix-vector multiplication",
                v.len(),
                self.cols
            )));
        }

        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};

        let v_slice = v.as_slice();

        let mut result_data = vec![0.0; self.rows];

        #[cfg(feature = "parallel")]
        {
            const PARALLEL_THRESHOLD: usize = 4096;

            if self.rows >= PARALLEL_THRESHOLD {
                use rayon::prelude::*;
                use std::sync::atomic::{AtomicPtr, Ordering};
                use std::sync::Arc;

                // NOTE(review): the output buffer is shared across workers
                // through a raw pointer. Each task stores only index `i`,
                // so the writes are disjoint, but this sidesteps the borrow
                // checker — `result_data.par_iter_mut()` would express the
                // same parallelism safely.
                let result_ptr = Arc::new(AtomicPtr::new(result_data.as_mut_ptr()));

                (0..self.rows).into_par_iter().for_each(|i| {
                    let row_start = i * self.cols;
                    let row = &self.data[row_start..(row_start + self.cols)];

                    // SAFETY: each arm runs only for backends valid on this target.
                    let dot_result = unsafe {
                        #[cfg(target_arch = "x86_64")]
                        {
                            match self.backend {
                                Backend::AVX2 | Backend::AVX512 => Avx2Backend::dot(row, v_slice),
                                Backend::SSE2 | Backend::AVX => Sse2Backend::dot(row, v_slice),
                                _ => ScalarBackend::dot(row, v_slice),
                            }
                        }
                        #[cfg(not(target_arch = "x86_64"))]
                        {
                            ScalarBackend::dot(row, v_slice)
                        }
                    };

                    // SAFETY: `i` is unique per task, so this store cannot race.
                    unsafe {
                        let ptr = result_ptr.load(Ordering::Relaxed);
                        *ptr.add(i) = dot_result;
                    }
                });

                return Ok(Vector::from_slice(&result_data));
            }
        }

        // Sequential fallback: one backend-dispatched dot product per row.
        for (i, result) in result_data.iter_mut().enumerate() {
            let row_start = i * self.cols;
            let row = &self.data[row_start..(row_start + self.cols)];

            // SAFETY: each arm runs only for backends valid on this target.
            *result = unsafe {
                #[cfg(target_arch = "x86_64")]
                {
                    match self.backend {
                        Backend::AVX2 | Backend::AVX512 => Avx2Backend::dot(row, v_slice),
                        Backend::SSE2 | Backend::AVX => Sse2Backend::dot(row, v_slice),
                        _ => ScalarBackend::dot(row, v_slice),
                    }
                }
                #[cfg(not(target_arch = "x86_64"))]
                {
                    ScalarBackend::dot(row, v_slice)
                }
            };
        }

        Ok(Vector::from_slice(&result_data))
    }
1718
1719 pub fn vecmat(v: &Vector<f32>, m: &Matrix<f32>) -> Result<Vector<f32>, TruenoError> {
1759 if v.len() != m.rows {
1760 return Err(TruenoError::InvalidInput(format!(
1761 "Vector length {} does not match matrix rows {} for vector-matrix multiplication",
1762 v.len(),
1763 m.rows
1764 )));
1765 }
1766
1767 let mut result = Vector::from_slice(&vec![0.0; m.cols]);
1777 let v_slice = v.as_slice();
1778
1779 for (i, &scalar) in v_slice.iter().enumerate().take(m.rows) {
1781 let row_start = i * m.cols;
1782 let row = &m.data[row_start..(row_start + m.cols)];
1783
1784 let row_vec = Vector::from_slice(row);
1786
1787 let scaled_row = row_vec.scale(scalar)?;
1789 result = result.add(&scaled_row)?;
1790 }
1791
1792 Ok(result)
1793 }
1794
1795 pub fn convolve2d(&self, kernel: &Matrix<f32>) -> Result<Matrix<f32>, TruenoError> {
1845 if kernel.rows > self.rows || kernel.cols > self.cols {
1847 return Err(TruenoError::InvalidInput(format!(
1848 "Kernel size ({}x{}) larger than input ({}x{})",
1849 kernel.rows, kernel.cols, self.rows, self.cols
1850 )));
1851 }
1852
1853 let output_rows = self.rows - kernel.rows + 1;
1855 let output_cols = self.cols - kernel.cols + 1;
1856
1857 let mut result = Matrix::zeros_with_backend(output_rows, output_cols, self.backend);
1859
1860 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
1866 const GPU_THRESHOLD: usize = 10_000;
1867
1868 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
1870 {
1871 if output_rows * output_cols >= GPU_THRESHOLD {
1872 use crate::backends::gpu::GpuBackend;
1873
1874 if GpuBackend::is_available() {
1875 if let Ok(gpu_result) =
1876 self.convolve2d_gpu(kernel, &mut result, output_rows, output_cols)
1877 {
1878 return Ok(gpu_result);
1879 }
1880 }
1882 }
1883 }
1884
1885 for out_row in 0..output_rows {
1888 for out_col in 0..output_cols {
1889 let mut sum = 0.0;
1890
1891 for k_row in 0..kernel.rows {
1893 for k_col in 0..kernel.cols {
1894 let in_row = out_row + k_row;
1895 let in_col = out_col + k_col;
1896
1897 let input_val = self
1899 .get(in_row, in_col)
1900 .expect("convolve2d: input bounds validated by output dimensions");
1901 let kernel_val = kernel
1902 .get(k_row, k_col)
1903 .expect("convolve2d: kernel bounds validated by loop");
1904
1905 sum += input_val * kernel_val;
1906 }
1907 }
1908
1909 *result
1910 .get_mut(out_row, out_col)
1911 .expect("convolve2d: output bounds validated by allocation") = sum;
1912 }
1913 }
1914
1915 Ok(result)
1916 }
1917
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    // GPU path for `convolve2d`: writes the convolution into `result`'s
    // backing buffer and returns an owned copy of it.
    //
    // `_output_rows`/`_output_cols` are accepted for symmetry with the
    // caller but unused here — presumably the GPU kernel derives the output
    // size from the input and kernel dimensions; confirm against GpuDevice.
    //
    // NOTE(review): the final `result.clone()` duplicates the whole output
    // buffer to satisfy both the `&mut` out-parameter and the owned return
    // value — check whether the out-parameter is still needed by callers.
    fn convolve2d_gpu(
        &self,
        kernel: &Matrix<f32>,
        result: &mut Matrix<f32>,
        _output_rows: usize,
        _output_cols: usize,
    ) -> Result<Matrix<f32>, TruenoError> {
        use crate::backends::gpu::GpuDevice;

        // Device acquisition and kernel launch errors are both surfaced as
        // InvalidInput; the caller treats any Err as "fall back to CPU".
        let gpu = GpuDevice::new().map_err(TruenoError::InvalidInput)?;

        gpu.convolve2d(
            self.as_slice(),
            kernel.as_slice(),
            result.data.as_mut_slice(),
            self.rows,
            self.cols,
            kernel.rows,
            kernel.cols,
        )
        .map_err(TruenoError::InvalidInput)?;

        Ok(result.clone())
    }
1944
1945 pub fn embedding_lookup(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
1985 for (i, &idx) in indices.iter().enumerate() {
1987 if idx >= self.rows {
1988 return Err(TruenoError::InvalidInput(format!(
1989 "Index {} at position {} is out of bounds for embedding table with {} rows",
1990 idx, i, self.rows
1991 )));
1992 }
1993 }
1994
1995 if indices.is_empty() {
1997 return Ok(Matrix::zeros_with_backend(0, self.cols, self.backend));
1998 }
1999
2000 let seq_len = indices.len();
2002 let embed_dim = self.cols;
2003 let mut result = Matrix::zeros_with_backend(seq_len, embed_dim, self.backend);
2004
2005 for (out_row, &idx) in indices.iter().enumerate() {
2007 let src_start = idx * embed_dim;
2008 let dst_start = out_row * embed_dim;
2009
2010 result.data[dst_start..dst_start + embed_dim]
2012 .copy_from_slice(&self.data[src_start..src_start + embed_dim]);
2013 }
2014
2015 Ok(result)
2016 }
2017
2018 pub fn embedding_lookup_sparse(
2036 &self,
2037 indices: &[usize],
2038 ) -> Result<(Matrix<f32>, Vec<usize>), TruenoError> {
2039 let embeddings = self.embedding_lookup(indices)?;
2040
2041 let mut unique: Vec<usize> = indices.to_vec();
2043 unique.sort_unstable();
2044 unique.dedup();
2045
2046 Ok((embeddings, unique))
2047 }
2048}
2049
2050#[cfg(test)]
2051mod tests {
2052 use super::*;
2053
2054 #[test]
2055 fn test_matrix_new() {
2056 let m = Matrix::new(3, 4);
2057 assert_eq!(m.rows(), 3);
2058 assert_eq!(m.cols(), 4);
2059 assert_eq!(m.shape(), (3, 4));
2060 assert_eq!(m.as_slice().len(), 12);
2061 }
2062
2063 #[test]
2064 fn test_matrix_from_vec() {
2065 let data = vec![1.0, 2.0, 3.0, 4.0];
2066 let m = Matrix::from_vec(2, 2, data).unwrap();
2067 assert_eq!(m.rows(), 2);
2068 assert_eq!(m.cols(), 2);
2069 assert_eq!(m.get(0, 0), Some(&1.0));
2070 assert_eq!(m.get(0, 1), Some(&2.0));
2071 assert_eq!(m.get(1, 0), Some(&3.0));
2072 assert_eq!(m.get(1, 1), Some(&4.0));
2073 }
2074
2075 #[test]
2076 fn test_matrix_from_vec_invalid_size() {
2077 let data = vec![1.0, 2.0, 3.0];
2078 let result = Matrix::from_vec(2, 2, data);
2079 assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
2080 }
2081
2082 #[test]
2083 fn test_matrix_from_slice() {
2084 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
2086 let m = Matrix::from_slice(2, 3, &data).unwrap();
2087 assert_eq!(m.rows(), 2);
2088 assert_eq!(m.cols(), 3);
2089 assert_eq!(m.get(0, 0), Some(&1.0));
2090 assert_eq!(m.get(1, 2), Some(&6.0));
2091 }
2092
2093 #[test]
2094 fn test_matrix_from_slice_invalid() {
2095 let data = [1.0, 2.0, 3.0];
2097 let result = Matrix::from_slice(2, 2, &data);
2098 assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
2099 }
2100
2101 #[test]
2102 fn test_matrix_zeros() {
2103 let m = Matrix::zeros(2, 3);
2104 assert_eq!(m.rows(), 2);
2105 assert_eq!(m.cols(), 3);
2106 for &val in m.as_slice() {
2107 assert_eq!(val, 0.0);
2108 }
2109 }
2110
2111 #[test]
2112 fn test_matrix_identity() {
2113 let m = Matrix::identity(3);
2114 assert_eq!(m.rows(), 3);
2115 assert_eq!(m.cols(), 3);
2116
2117 assert_eq!(m.get(0, 0), Some(&1.0));
2119 assert_eq!(m.get(1, 1), Some(&1.0));
2120 assert_eq!(m.get(2, 2), Some(&1.0));
2121
2122 assert_eq!(m.get(0, 1), Some(&0.0));
2124 assert_eq!(m.get(0, 2), Some(&0.0));
2125 assert_eq!(m.get(1, 0), Some(&0.0));
2126 assert_eq!(m.get(1, 2), Some(&0.0));
2127 assert_eq!(m.get(2, 0), Some(&0.0));
2128 assert_eq!(m.get(2, 1), Some(&0.0));
2129 }
2130
2131 #[test]
2132 fn test_matrix_get_out_of_bounds() {
2133 let m = Matrix::new(2, 2);
2134 assert_eq!(m.get(2, 0), None);
2135 assert_eq!(m.get(0, 2), None);
2136 assert_eq!(m.get(2, 2), None);
2137 }
2138
2139 #[test]
2142 fn test_matmul_basic() {
2143 let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2146 let b = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
2147 let c = a.matmul(&b).unwrap();
2148
2149 assert_eq!(c.rows(), 2);
2150 assert_eq!(c.cols(), 2);
2151 assert_eq!(c.get(0, 0), Some(&19.0));
2152 assert_eq!(c.get(0, 1), Some(&22.0));
2153 assert_eq!(c.get(1, 0), Some(&43.0));
2154 assert_eq!(c.get(1, 1), Some(&50.0));
2155 }
2156
2157 #[test]
2158 fn test_matmul_identity() {
2159 let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2161 let identity = Matrix::identity(2);
2162 let result = a.matmul(&identity).unwrap();
2163
2164 assert_eq!(result.get(0, 0), Some(&1.0));
2165 assert_eq!(result.get(0, 1), Some(&2.0));
2166 assert_eq!(result.get(1, 0), Some(&3.0));
2167 assert_eq!(result.get(1, 1), Some(&4.0));
2168 }
2169
2170 #[test]
2171 fn test_matmul_zeros() {
2172 let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2174 let zeros = Matrix::zeros(2, 2);
2175 let result = a.matmul(&zeros).unwrap();
2176
2177 for &val in result.as_slice() {
2178 assert_eq!(val, 0.0);
2179 }
2180 }
2181
2182 #[test]
2183 fn test_matmul_dimension_mismatch() {
2184 let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2186 let b = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2187 let result = a.matmul(&b);
2188
2189 assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
2190 }
2191
2192 #[test]
2193 fn test_matmul_non_square() {
2194 let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2199 let b = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
2200 let c = a.matmul(&b).unwrap();
2201
2202 assert_eq!(c.rows(), 2);
2203 assert_eq!(c.cols(), 2);
2204 assert_eq!(c.get(0, 0), Some(&58.0));
2205 assert_eq!(c.get(0, 1), Some(&64.0));
2206 assert_eq!(c.get(1, 0), Some(&139.0));
2207 assert_eq!(c.get(1, 1), Some(&154.0));
2208 }
2209
2210 #[test]
2211 fn test_matmul_single_element() {
2212 let a = Matrix::from_vec(1, 1, vec![3.0]).unwrap();
2214 let b = Matrix::from_vec(1, 1, vec![4.0]).unwrap();
2215 let c = a.matmul(&b).unwrap();
2216
2217 assert_eq!(c.rows(), 1);
2218 assert_eq!(c.cols(), 1);
2219 assert_eq!(c.get(0, 0), Some(&12.0));
2220 }
2221
2222 #[test]
2223 fn test_matmul_remainder_rows() {
2224 let a = Matrix::from_vec(5, 8, (0..40).map(|i| (i + 1) as f32).collect()).unwrap();
2228 let b = Matrix::from_vec(8, 6, (0..48).map(|i| (i + 1) as f32).collect()).unwrap();
2229 let c = a.matmul(&b).unwrap();
2230
2231 assert_eq!(c.rows(), 5);
2232 assert_eq!(c.cols(), 6);
2233
2234 let expected_00 = (1..=8)
2237 .zip((0..48).step_by(6).map(|i| (i + 1) as f32))
2238 .map(|(a, b)| a as f32 * b)
2239 .sum::<f32>();
2240 assert!((c.get(0, 0).unwrap() - expected_00).abs() < 1.0);
2241 }
2242
2243 #[test]
2244 fn test_matmul_remainder_rows_7() {
2245 let a = Matrix::from_vec(7, 8, (0..56).map(|_| 1.0f32).collect()).unwrap();
2247 let b = Matrix::from_vec(8, 5, (0..40).map(|_| 1.0f32).collect()).unwrap();
2248 let c = a.matmul(&b).unwrap();
2249
2250 assert_eq!(c.rows(), 7);
2251 assert_eq!(c.cols(), 5);
2252 for &val in c.as_slice() {
2254 assert!((val - 8.0).abs() < 1e-5);
2255 }
2256 }
2257
2258 #[test]
2261 fn test_matmul_simd_equivalence_small() {
2262 let a = Matrix::from_vec(8, 8, (0..64).map(|i| i as f32).collect()).unwrap();
2264 let b = Matrix::from_vec(8, 8, (0..64).map(|i| (i * 2) as f32).collect()).unwrap();
2265
2266 let mut result_naive = Matrix::zeros(8, 8);
2267 let mut result_simd = Matrix::zeros(8, 8);
2268
2269 a.matmul_naive(&b, &mut result_naive).unwrap();
2270 a.matmul_simd(&b, &mut result_simd).unwrap();
2271
2272 for i in 0..8 {
2274 for j in 0..8 {
2275 let naive_val = result_naive.get(i, j).unwrap();
2276 let simd_val = result_simd.get(i, j).unwrap();
2277 assert!(
2278 (naive_val - simd_val).abs() < 1e-5,
2279 "Mismatch at ({}, {}): naive={}, simd={}",
2280 i,
2281 j,
2282 naive_val,
2283 simd_val
2284 );
2285 }
2286 }
2287 }
2288
2289 #[test]
2290 fn test_matmul_simd_equivalence_large() {
2291 let size = 128;
2293 let a = Matrix::from_vec(
2294 size,
2295 size,
2296 (0..size * size).map(|i| (i % 100) as f32).collect(),
2297 )
2298 .unwrap();
2299 let b = Matrix::from_vec(
2300 size,
2301 size,
2302 (0..size * size).map(|i| ((i * 2) % 100) as f32).collect(),
2303 )
2304 .unwrap();
2305
2306 let mut result_naive = Matrix::zeros(size, size);
2307 let mut result_simd = Matrix::zeros(size, size);
2308
2309 a.matmul_naive(&b, &mut result_naive).unwrap();
2310 a.matmul_simd(&b, &mut result_simd).unwrap();
2311
2312 for i in 0..size {
2314 for j in 0..size {
2315 let naive_val = result_naive.get(i, j).unwrap();
2316 let simd_val = result_simd.get(i, j).unwrap();
2317 assert!(
2318 (naive_val - simd_val).abs() < 1e-3,
2319 "Mismatch at ({}, {}): naive={}, simd={}",
2320 i,
2321 j,
2322 naive_val,
2323 simd_val
2324 );
2325 }
2326 }
2327 }
2328
2329 #[test]
2330 fn test_matmul_simd_equivalence_rectangular() {
2331 let a = Matrix::from_vec(64, 128, (0..64 * 128).map(|i| i as f32).collect()).unwrap();
2333 let b = Matrix::from_vec(128, 32, (0..128 * 32).map(|i| (i * 3) as f32).collect()).unwrap();
2334
2335 let mut result_naive = Matrix::zeros(64, 32);
2336 let mut result_simd = Matrix::zeros(64, 32);
2337
2338 a.matmul_naive(&b, &mut result_naive).unwrap();
2339 a.matmul_simd(&b, &mut result_simd).unwrap();
2340
2341 for i in 0..64 {
2343 for j in 0..32 {
2344 let naive_val = result_naive.get(i, j).unwrap();
2345 let simd_val = result_simd.get(i, j).unwrap();
2346 let diff = (naive_val - simd_val).abs();
2347 let tolerance = if naive_val.abs() > 1.0 {
2348 naive_val.abs() * 1e-5 } else {
2350 1e-5 };
2352 assert!(
2353 diff < tolerance,
2354 "Mismatch at ({}, {}): naive={}, simd={}, diff={}",
2355 i,
2356 j,
2357 naive_val,
2358 simd_val,
2359 diff
2360 );
2361 }
2362 }
2363 }
2364
2365 #[test]
2368 fn test_matmul_blocking_small_matrices() {
2369 let sizes = vec![8, 16, 32];
2371 for size in sizes {
2372 let a =
2373 Matrix::from_vec(size, size, (0..size * size).map(|i| i as f32).collect()).unwrap();
2374 let b = Matrix::from_vec(
2375 size,
2376 size,
2377 (0..size * size).map(|i| (i * 2) as f32).collect(),
2378 )
2379 .unwrap();
2380
2381 let mut result_naive = Matrix::zeros(size, size);
2382 let mut result_simd = Matrix::zeros(size, size);
2383
2384 a.matmul_naive(&b, &mut result_naive).unwrap();
2385 a.matmul_simd(&b, &mut result_simd).unwrap();
2386
2387 for i in 0..size {
2389 for j in 0..size {
2390 let naive_val = result_naive.get(i, j).unwrap();
2391 let simd_val = result_simd.get(i, j).unwrap();
2392 let diff = (naive_val - simd_val).abs();
2393 let tolerance = if naive_val.abs() > 1.0 {
2394 naive_val.abs() * 1e-4
2395 } else {
2396 1e-4
2397 };
2398 assert!(
2399 diff < tolerance,
2400 "Size {}: Mismatch at ({}, {}): naive={}, simd={}, diff={}",
2401 size,
2402 i,
2403 j,
2404 naive_val,
2405 simd_val,
2406 diff
2407 );
2408 }
2409 }
2410 }
2411 }
2412
2413 #[test]
2414 fn test_matmul_blocking_medium_matrices() {
2415 let sizes = vec![64, 128, 256];
2417 for size in sizes {
2418 let a = Matrix::from_vec(
2419 size,
2420 size,
2421 (0..size * size).map(|i| (i % 100) as f32).collect(),
2422 )
2423 .unwrap();
2424 let b = Matrix::from_vec(
2425 size,
2426 size,
2427 (0..size * size).map(|i| ((i * 3) % 100) as f32).collect(),
2428 )
2429 .unwrap();
2430
2431 let mut result_naive = Matrix::zeros(size, size);
2432 let mut result_simd = Matrix::zeros(size, size);
2433
2434 a.matmul_naive(&b, &mut result_naive).unwrap();
2435 a.matmul_simd(&b, &mut result_simd).unwrap();
2436
2437 for i in 0..size {
2439 for j in 0..size {
2440 let naive_val = result_naive.get(i, j).unwrap();
2441 let simd_val = result_simd.get(i, j).unwrap();
2442 let diff = (naive_val - simd_val).abs();
2443 let tolerance = if naive_val.abs() > 1.0 {
2444 naive_val.abs() * 1e-3 } else {
2446 1e-3
2447 };
2448 assert!(
2449 diff < tolerance,
2450 "Size {}: Mismatch at ({}, {}): naive={}, simd={}, diff={}",
2451 size,
2452 i,
2453 j,
2454 naive_val,
2455 simd_val,
2456 diff
2457 );
2458 }
2459 }
2460 }
2461 }
2462
2463 #[test]
2464 fn test_matmul_blocking_non_aligned_sizes() {
2465 let test_cases = vec![
2467 (33, 33, 33), (65, 65, 65), (100, 100, 100), (127, 127, 127), ];
2472
2473 for (m, k, n) in test_cases {
2474 let a = Matrix::from_vec(m, k, (0..m * k).map(|i| (i % 50) as f32).collect()).unwrap();
2475 let b = Matrix::from_vec(k, n, (0..k * n).map(|i| ((i * 2) % 50) as f32).collect())
2476 .unwrap();
2477
2478 let mut result_naive = Matrix::zeros(m, n);
2479 let mut result_simd = Matrix::zeros(m, n);
2480
2481 a.matmul_naive(&b, &mut result_naive).unwrap();
2482 a.matmul_simd(&b, &mut result_simd).unwrap();
2483
2484 for i in 0..m {
2486 for j in 0..n {
2487 let naive_val = result_naive.get(i, j).unwrap();
2488 let simd_val = result_simd.get(i, j).unwrap();
2489 let diff = (naive_val - simd_val).abs();
2490 let tolerance = if naive_val.abs() > 1.0 {
2491 naive_val.abs() * 1e-3
2492 } else {
2493 1e-3
2494 };
2495 assert!(
2496 diff < tolerance,
2497 "Size {}×{}×{}: Mismatch at ({}, {}): naive={}, simd={}, diff={}",
2498 m,
2499 k,
2500 n,
2501 i,
2502 j,
2503 naive_val,
2504 simd_val,
2505 diff
2506 );
2507 }
2508 }
2509 }
2510 }
2511
2512 #[test]
2513 fn test_matmul_blocking_large_matrices() {
2514 let size = 256;
2517 let a = Matrix::from_vec(
2518 size,
2519 size,
2520 (0..size * size)
2521 .map(|i| ((i % 100) as f32) / 10.0)
2522 .collect(),
2523 )
2524 .unwrap();
2525 let b = Matrix::from_vec(
2526 size,
2527 size,
2528 (0..size * size)
2529 .map(|i| (((i * 7) % 100) as f32) / 10.0)
2530 .collect(),
2531 )
2532 .unwrap();
2533
2534 let mut result_naive = Matrix::zeros(size, size);
2535 let mut result_simd = Matrix::zeros(size, size);
2536
2537 a.matmul_naive(&b, &mut result_naive).unwrap();
2538 a.matmul_simd(&b, &mut result_simd).unwrap();
2539
2540 let mut max_diff = 0.0f32;
2542 let mut mismatches = 0;
2543 for i in 0..size {
2544 for j in 0..size {
2545 let naive_val = result_naive.get(i, j).unwrap();
2546 let simd_val = result_simd.get(i, j).unwrap();
2547 let diff = (naive_val - simd_val).abs();
2548 let tolerance = if naive_val.abs() > 1.0 {
2549 naive_val.abs() * 1e-2 } else {
2551 1e-2
2552 };
2553 if diff >= tolerance {
2554 mismatches += 1;
2555 if mismatches <= 5 {
2556 eprintln!(
2557 "Mismatch at ({}, {}): naive={}, simd={}, diff={}, tolerance={}",
2558 i, j, naive_val, simd_val, diff, tolerance
2559 );
2560 }
2561 }
2562 max_diff = max_diff.max(diff);
2563 }
2564 }
2565 assert_eq!(
2566 mismatches, 0,
2567 "Found {} mismatches in {}×{} matmul, max_diff={}",
2568 mismatches, size, size, max_diff
2569 );
2570 }
2571
2572 #[test]
2573 fn test_matmul_3level_blocking() {
2574 let size = 512; let a = Matrix::from_vec(
2578 size,
2579 size,
2580 (0..size * size)
2581 .map(|i| ((i % 100) as f32) / 10.0)
2582 .collect(),
2583 )
2584 .unwrap();
2585 let b = Matrix::from_vec(
2586 size,
2587 size,
2588 (0..size * size)
2589 .map(|i| (((i * 7) % 100) as f32) / 10.0)
2590 .collect(),
2591 )
2592 .unwrap();
2593
2594 let mut result_naive = Matrix::zeros(size, size);
2595 let mut result_simd = Matrix::zeros(size, size);
2596
2597 a.matmul_naive(&b, &mut result_naive).unwrap();
2598 a.matmul_simd(&b, &mut result_simd).unwrap();
2599
2600 let mut max_diff = 0.0f32;
2602 let mut mismatches = 0;
2603 for i in 0..size {
2604 for j in 0..size {
2605 let naive_val = result_naive.get(i, j).unwrap();
2606 let simd_val = result_simd.get(i, j).unwrap();
2607 let diff = (naive_val - simd_val).abs();
2608 let tolerance = if naive_val.abs() > 1.0 {
2609 naive_val.abs() * 1e-2
2610 } else {
2611 1e-2
2612 };
2613 if diff >= tolerance {
2614 mismatches += 1;
2615 if mismatches <= 5 {
2616 eprintln!(
2617 "Mismatch at ({}, {}): naive={}, simd={}, diff={}, tolerance={}",
2618 i, j, naive_val, simd_val, diff, tolerance
2619 );
2620 }
2621 }
2622 max_diff = max_diff.max(diff);
2623 }
2624 }
2625 assert_eq!(
2626 mismatches, 0,
2627 "Found {} mismatches in {}×{} matmul (3-level blocking), max_diff={}",
2628 mismatches, size, size, max_diff
2629 );
2630 }
2631
2632 #[test]
2633 #[cfg(feature = "parallel")]
2634 fn test_matmul_parallel_1024() {
2635 let size = 1024;
2638 let a = Matrix::from_vec(
2639 size,
2640 size,
2641 (0..size * size)
2642 .map(|i| ((i % 100) as f32) / 10.0)
2643 .collect(),
2644 )
2645 .unwrap();
2646 let b = Matrix::from_vec(
2647 size,
2648 size,
2649 (0..size * size)
2650 .map(|i| (((i * 7) % 100) as f32) / 10.0)
2651 .collect(),
2652 )
2653 .unwrap();
2654
2655 let mut result_naive = Matrix::zeros(size, size);
2656 let mut result_parallel = Matrix::zeros(size, size);
2657
2658 a.matmul_naive(&b, &mut result_naive).unwrap();
2659 a.matmul_simd(&b, &mut result_parallel).unwrap(); let mut max_diff = 0.0f32;
2663 let mut mismatches = 0;
2664 for i in 0..size {
2665 for j in 0..size {
2666 let naive_val = result_naive.get(i, j).unwrap();
2667 let parallel_val = result_parallel.get(i, j).unwrap();
2668 let diff = (naive_val - parallel_val).abs();
2669 let tolerance = if naive_val.abs() > 1.0 {
2670 naive_val.abs() * 1e-2
2671 } else {
2672 1e-2
2673 };
2674 if diff >= tolerance {
2675 mismatches += 1;
2676 if mismatches <= 5 {
2677 eprintln!(
2678 "Mismatch at ({}, {}): naive={}, parallel={}, diff={}, tolerance={}",
2679 i, j, naive_val, parallel_val, diff, tolerance
2680 );
2681 }
2682 }
2683 max_diff = max_diff.max(diff);
2684 }
2685 }
2686 assert_eq!(
2687 mismatches, 0,
2688 "Found {} mismatches in {}×{} parallel matmul, max_diff={}",
2689 mismatches, size, size, max_diff
2690 );
2691 }
2692
2693 #[test]
2694 #[cfg(feature = "parallel")]
2695 fn test_matvec_parallel_4096() {
2696 let rows = 4096;
2699 let cols = 512;
2700
2701 let matrix = Matrix::from_vec(
2702 rows,
2703 cols,
2704 (0..rows * cols)
2705 .map(|i| ((i % 100) as f32) / 10.0)
2706 .collect(),
2707 )
2708 .unwrap();
2709
2710 let vector = Vector::from_slice(
2711 &(0..cols)
2712 .map(|i| ((i % 50) as f32) / 5.0)
2713 .collect::<Vec<f32>>(),
2714 );
2715
2716 let result = matrix.matvec(&vector).unwrap();
2718
2719 assert_eq!(result.len(), rows);
2721
2722 for sample_row in [0, 1024, 2048, 3072, 4095] {
2725 let row_start = sample_row * cols;
2726 let row = &matrix.data[row_start..(row_start + cols)];
2727
2728 let expected: f32 = row
2730 .iter()
2731 .zip(vector.as_slice().iter())
2732 .map(|(a, b)| a * b)
2733 .sum();
2734
2735 let actual = result.as_slice()[sample_row];
2736 let diff = (expected - actual).abs();
2737 let tolerance = if expected.abs() > 1.0 {
2738 expected.abs() * 1e-3
2739 } else {
2740 1e-3
2741 };
2742
2743 assert!(
2744 diff < tolerance,
2745 "Mismatch at row {}: expected={}, actual={}, diff={}",
2746 sample_row,
2747 expected,
2748 actual,
2749 diff
2750 );
2751 }
2752 }
2753
2754 #[test]
2757 #[cfg(target_arch = "x86_64")]
2758 fn test_horizontal_sum_avx2() {
2759 if !is_x86_feature_detected!("avx2") {
2761 println!("Skipping AVX2 horizontal sum test (CPU doesn't support AVX2)");
2762 return;
2763 }
2764
2765 use std::arch::x86_64::*;
2766
2767 unsafe {
2768 let v = _mm256_set1_ps(1.0);
2770 let sum = Matrix::<f32>::horizontal_sum_avx2(v);
2771 assert!((sum - 8.0).abs() < 1e-6, "Expected 8.0, got {}", sum);
2772
2773 let v = _mm256_setr_ps(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
2775 let sum = Matrix::<f32>::horizontal_sum_avx2(v);
2776 assert!((sum - 36.0).abs() < 1e-6, "Expected 36.0, got {}", sum);
2777
2778 let v = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);
2780 let sum = Matrix::<f32>::horizontal_sum_avx2(v);
2781 assert!(sum.abs() < 1e-6, "Expected ~0.0, got {}", sum);
2782
2783 let v = _mm256_setr_ps(100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700.0, 800.0);
2785 let sum = Matrix::<f32>::horizontal_sum_avx2(v);
2786 assert!((sum - 3600.0).abs() < 1e-3, "Expected 3600.0, got {}", sum);
2787
2788 let v = _mm256_setr_ps(10.5, -5.25, 3.75, -8.0, 12.0, -6.5, 4.25, -2.75);
2790 let expected = 10.5 - 5.25 + 3.75 - 8.0 + 12.0 - 6.5 + 4.25 - 2.75;
2791 let sum = Matrix::<f32>::horizontal_sum_avx2(v);
2792 assert!(
2793 (sum - expected).abs() < 1e-5,
2794 "Expected {}, got {}",
2795 expected,
2796 sum
2797 );
2798 }
2799 }
2800
2801 #[test]
2802 #[cfg(target_arch = "x86_64")]
2803 fn test_matmul_microkernel_4x1_avx2() {
2804 if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
2806 println!("Skipping AVX2 micro-kernel test (CPU doesn't support AVX2/FMA)");
2807 return;
2808 }
2809
2810 {
2815 let row0: Vec<f32> = (1..=16).map(|x| x as f32).collect();
2816 let row1: Vec<f32> = (17..=32).map(|x| x as f32).collect();
2817 let row2: Vec<f32> = (33..=48).map(|x| x as f32).collect();
2818 let row3: Vec<f32> = (49..=64).map(|x| x as f32).collect();
2819 let b_col = vec![1.0f32; 16];
2820
2821 let a_rows = [
2822 row0.as_slice(),
2823 row1.as_slice(),
2824 row2.as_slice(),
2825 row3.as_slice(),
2826 ];
2827 let mut results = [0.0f32; 4];
2828
2829 unsafe {
2830 Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, &b_col, &mut results);
2831 }
2832
2833 let expected = [
2835 (1..=16).sum::<i32>() as f32,
2836 (17..=32).sum::<i32>() as f32,
2837 (33..=48).sum::<i32>() as f32,
2838 (49..=64).sum::<i32>() as f32,
2839 ];
2840
2841 for i in 0..4 {
2842 assert!(
2843 (results[i] - expected[i]).abs() < 1e-3,
2844 "Row {}: expected {}, got {}",
2845 i,
2846 expected[i],
2847 results[i]
2848 );
2849 }
2850 }
2851
2852 {
2855 let row0 = vec![
2856 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2857 ];
2858 let row1 = vec![
2859 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2860 ];
2861 let row2 = vec![
2862 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2863 ];
2864 let row3 = vec![
2865 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
2866 ];
2867 let b_col: Vec<f32> = (1..=16).map(|x| x as f32).collect();
2868
2869 let a_rows = [
2870 row0.as_slice(),
2871 row1.as_slice(),
2872 row2.as_slice(),
2873 row3.as_slice(),
2874 ];
2875 let mut results = [0.0f32; 4];
2876
2877 unsafe {
2878 Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, &b_col, &mut results);
2879 }
2880
2881 let expected = [1.0, 2.0, 3.0, 4.0];
2883 for i in 0..4 {
2884 assert!(
2885 (results[i] - expected[i]).abs() < 1e-6,
2886 "Row {}: expected {}, got {}",
2887 i,
2888 expected[i],
2889 results[i]
2890 );
2891 }
2892 }
2893
2894 {
2897 let row0: Vec<f32> = (1..=10).map(|x| x as f32).collect();
2898 let row1: Vec<f32> = (11..=20).map(|x| x as f32).collect();
2899 let row2: Vec<f32> = (21..=30).map(|x| x as f32).collect();
2900 let row3: Vec<f32> = (31..=40).map(|x| x as f32).collect();
2901 let b_col = vec![2.0f32; 10];
2902
2903 let a_rows = [
2904 row0.as_slice(),
2905 row1.as_slice(),
2906 row2.as_slice(),
2907 row3.as_slice(),
2908 ];
2909 let mut results = [0.0f32; 4];
2910
2911 unsafe {
2912 Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, &b_col, &mut results);
2913 }
2914
2915 let expected = [
2917 2.0 * (1..=10).sum::<i32>() as f32,
2918 2.0 * (11..=20).sum::<i32>() as f32,
2919 2.0 * (21..=30).sum::<i32>() as f32,
2920 2.0 * (31..=40).sum::<i32>() as f32,
2921 ];
2922
2923 for i in 0..4 {
2924 assert!(
2925 (results[i] - expected[i]).abs() < 1e-3,
2926 "Row {}: expected {}, got {}",
2927 i,
2928 expected[i],
2929 results[i]
2930 );
2931 }
2932 }
2933
2934 {
2936 let row0 = vec![
2937 1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0, -14.0,
2938 15.0, -16.0,
2939 ];
2940 let row1 = vec![
2941 2.0, -4.0, 6.0, -8.0, 10.0, -12.0, 14.0, -16.0, 18.0, -20.0, 22.0, -24.0, 26.0,
2942 -28.0, 30.0, -32.0,
2943 ];
2944 let row2 = vec![
2945 0.5, -1.0, 1.5, -2.0, 2.5, -3.0, 3.5, -4.0, 4.5, -5.0, 5.5, -6.0, 6.5, -7.0, 7.5,
2946 -8.0,
2947 ];
2948 let row3 = vec![
2949 10.0, -10.0, 10.0, -10.0, 10.0, -10.0, 10.0, -10.0, 10.0, -10.0, 10.0, -10.0, 10.0,
2950 -10.0, 10.0, -10.0,
2951 ];
2952 let b_col = vec![
2953 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
2954 ];
2955
2956 let a_rows = [
2957 row0.as_slice(),
2958 row1.as_slice(),
2959 row2.as_slice(),
2960 row3.as_slice(),
2961 ];
2962 let mut results = [0.0f32; 4];
2963
2964 unsafe {
2965 Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, &b_col, &mut results);
2966 }
2967
2968 let expected = [
2970 row0.iter().sum::<f32>(),
2971 row1.iter().sum::<f32>(),
2972 row2.iter().sum::<f32>(),
2973 row3.iter().sum::<f32>(),
2974 ];
2975
2976 for i in 0..4 {
2977 assert!(
2978 (results[i] - expected[i]).abs() < 1e-4,
2979 "Row {}: expected {}, got {}",
2980 i,
2981 expected[i],
2982 results[i]
2983 );
2984 }
2985 }
2986
2987 {
2989 let row0 = vec![0.0f32; 16];
2990 let row1 = vec![0.0f32; 16];
2991 let row2 = vec![0.0f32; 16];
2992 let row3 = vec![0.0f32; 16];
2993 let b_col: Vec<f32> = (1..=16).map(|x| x as f32).collect();
2994
2995 let a_rows = [
2996 row0.as_slice(),
2997 row1.as_slice(),
2998 row2.as_slice(),
2999 row3.as_slice(),
3000 ];
3001 let mut results = [0.0f32; 4];
3002
3003 unsafe {
3004 Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, &b_col, &mut results);
3005 }
3006
3007 for (i, &result) in results.iter().enumerate() {
3008 assert!(
3009 result.abs() < 1e-6,
3010 "Row {}: expected 0.0, got {}",
3011 i,
3012 result
3013 );
3014 }
3015 }
3016
3017 {
3020 let row0 = vec![
3021 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
3022 16.0,
3023 ];
3024 let row1 = vec![
3025 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0,
3026 30.0, 32.0,
3027 ];
3028 let row2 = vec![
3029 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0,
3030 ];
3031 let row3 = vec![
3032 3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 39.0, 42.0,
3033 45.0, 48.0,
3034 ];
3035 let b_col = vec![
3036 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
3037 ];
3038
3039 let a_rows = [
3040 row0.as_slice(),
3041 row1.as_slice(),
3042 row2.as_slice(),
3043 row3.as_slice(),
3044 ];
3045 let mut results = [0.0f32; 4];
3046
3047 unsafe {
3048 Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, &b_col, &mut results);
3049 }
3050
3051 let expected = [
3053 0.5 * row0.iter().sum::<f32>(),
3054 0.5 * row1.iter().sum::<f32>(),
3055 0.5 * row2.iter().sum::<f32>(),
3056 0.5 * row3.iter().sum::<f32>(),
3057 ];
3058
3059 for i in 0..4 {
3060 assert!(
3061 (results[i] - expected[i]).abs() < 1e-3,
3062 "Row {}: expected {}, got {}",
3063 i,
3064 expected[i],
3065 results[i]
3066 );
3067 }
3068 }
3069 }
3070
3071 #[test]
3074 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
3075 fn test_gpu_availability() {
3076 use crate::backends::gpu::GpuBackend;
3077 let _available = GpuBackend::is_available();
3079 }
3081
3082 #[test]
3083 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
3084 #[ignore] fn test_gpu_matmul_basic() {
3086 use crate::backends::gpu::GpuBackend;
3087
3088 if !GpuBackend::is_available() {
3089 eprintln!("GPU not available, skipping test");
3090 return;
3091 }
3092
3093 let a = Matrix::from_vec(
3095 4,
3096 4,
3097 vec![
3098 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
3099 16.0,
3100 ],
3101 )
3102 .unwrap();
3103
3104 let b = Matrix::from_vec(
3105 4,
3106 4,
3107 vec![
3108 16.0, 15.0, 14.0, 13.0, 12.0, 11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0,
3109 1.0,
3110 ],
3111 )
3112 .unwrap();
3113
3114 let result = a.matmul_gpu(&b);
3116
3117 if let Ok(c) = result {
3118 assert_eq!(c.rows(), 4);
3120 assert_eq!(c.cols(), 4);
3121
3122 assert!((c.get(0, 0).unwrap() - 80.0).abs() < 1e-4);
3125 } else {
3126 eprintln!("GPU matmul failed: {:?}", result);
3127 }
3128 }
3129
3130 #[test]
3133 fn test_transpose_basic() {
3134 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3138 let t = m.transpose();
3139
3140 assert_eq!(t.rows(), 3);
3141 assert_eq!(t.cols(), 2);
3142 assert_eq!(t.get(0, 0), Some(&1.0));
3143 assert_eq!(t.get(0, 1), Some(&4.0));
3144 assert_eq!(t.get(1, 0), Some(&2.0));
3145 assert_eq!(t.get(1, 1), Some(&5.0));
3146 assert_eq!(t.get(2, 0), Some(&3.0));
3147 assert_eq!(t.get(2, 1), Some(&6.0));
3148 }
3149
3150 #[test]
3151 fn test_transpose_square() {
3152 let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3155 let t = m.transpose();
3156
3157 assert_eq!(t.rows(), 2);
3158 assert_eq!(t.cols(), 2);
3159 assert_eq!(t.get(0, 0), Some(&1.0));
3160 assert_eq!(t.get(0, 1), Some(&3.0));
3161 assert_eq!(t.get(1, 0), Some(&2.0));
3162 assert_eq!(t.get(1, 1), Some(&4.0));
3163 }
3164
3165 #[test]
3166 fn test_transpose_single_row() {
3167 let m = Matrix::from_vec(1, 3, vec![1.0, 2.0, 3.0]).unwrap();
3171 let t = m.transpose();
3172
3173 assert_eq!(t.rows(), 3);
3174 assert_eq!(t.cols(), 1);
3175 assert_eq!(t.get(0, 0), Some(&1.0));
3176 assert_eq!(t.get(1, 0), Some(&2.0));
3177 assert_eq!(t.get(2, 0), Some(&3.0));
3178 }
3179
3180 #[test]
3181 fn test_transpose_single_col() {
3182 let m = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).unwrap();
3186 let t = m.transpose();
3187
3188 assert_eq!(t.rows(), 1);
3189 assert_eq!(t.cols(), 3);
3190 assert_eq!(t.get(0, 0), Some(&1.0));
3191 assert_eq!(t.get(0, 1), Some(&2.0));
3192 assert_eq!(t.get(0, 2), Some(&3.0));
3193 }
3194
3195 #[test]
3196 fn test_transpose_single_element() {
3197 let m = Matrix::from_vec(1, 1, vec![5.0]).unwrap();
3199 let t = m.transpose();
3200
3201 assert_eq!(t.rows(), 1);
3202 assert_eq!(t.cols(), 1);
3203 assert_eq!(t.get(0, 0), Some(&5.0));
3204 }
3205
3206 #[test]
3207 fn test_transpose_identity() {
3208 let identity = Matrix::identity(3);
3210 let t = identity.transpose();
3211
3212 assert_eq!(t.rows(), 3);
3213 assert_eq!(t.cols(), 3);
3214
3215 for i in 0..3 {
3217 for j in 0..3 {
3218 let expected = if i == j { 1.0 } else { 0.0 };
3219 assert_eq!(t.get(i, j), Some(&expected));
3220 }
3221 }
3222 }
3223}
3224
3225#[cfg(test)]
3227mod property_tests {
3228 use super::*;
3229 use proptest::prelude::*;
3230
    /// Proptest strategy producing a `rows x cols` f32 matrix whose entries
    /// are drawn from [-100.0, 100.0).
    ///
    /// `from_vec` cannot fail here because the generated vector's length is
    /// exactly `rows * cols`, hence the `unwrap`.
    fn matrix_strategy(rows: usize, cols: usize) -> impl Strategy<Value = Matrix<f32>> {
        proptest::collection::vec(-100.0f32..100.0, rows * cols)
            .prop_map(move |data| Matrix::from_vec(rows, cols, data).unwrap())
    }
3236
    proptest! {
        // Run 100 random cases per property.
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Property: (A·B)·C ≈ A·(B·C). Associative in exact arithmetic; the
        // two orderings accumulate f32 rounding differently, so comparison
        // uses a magnitude-aware tolerance.
        #[test]
        fn test_matmul_associative(
            a in matrix_strategy(3, 4),
            b in matrix_strategy(4, 5),
            c in matrix_strategy(5, 3)
        ) {
            let ab = a.matmul(&b).unwrap();
            let ab_c = ab.matmul(&c).unwrap();

            let bc = b.matmul(&c).unwrap();
            let a_bc = a.matmul(&bc).unwrap();

            prop_assert_eq!(ab_c.rows(), a_bc.rows());
            prop_assert_eq!(ab_c.cols(), a_bc.cols());

            for i in 0..ab_c.rows() {
                for j in 0..ab_c.cols() {
                    let val1 = ab_c.get(i, j).unwrap();
                    let val2 = a_bc.get(i, j).unwrap();
                    let diff = (val1 - val2).abs();
                    let max_val = val1.abs().max(val2.abs());

                    // Absolute tolerance near zero, 5% relative otherwise —
                    // generous because entries are sums of products of
                    // values up to |100| across two chained multiplies.
                    let tolerance = if max_val < 1.0 {
                        1e-3
                    } else {
                        max_val * 5e-2
                    };

                    prop_assert!(
                        diff < tolerance,
                        "Associativity failed at ({}, {}): {} != {} (diff: {}, tolerance: {})",
                        i, j, val1, val2, diff, tolerance
                    );
                }
            }
        }

        // Property: A·I == A for a randomly shaped A and conforming identity.
        #[test]
        fn test_matmul_identity_property(
            rows in 1usize..10,
            cols in 1usize..10,
            data in proptest::collection::vec(-100.0f32..100.0, 1..100)
        ) {
            let size = rows * cols;
            // The generated vector may be shorter than rows * cols;
            // discard those cases (returning Ok skips the case).
            if data.len() < size {
                return Ok(());
            }
            let matrix_data = data[0..size].to_vec();

            let a = Matrix::from_vec(rows, cols, matrix_data).unwrap();
            let identity = Matrix::identity(cols);
            let result = a.matmul(&identity).unwrap();

            prop_assert_eq!(result.rows(), a.rows());
            prop_assert_eq!(result.cols(), a.cols());

            for i in 0..rows {
                for j in 0..cols {
                    let original = a.get(i, j).unwrap();
                    let multiplied = result.get(i, j).unwrap();
                    let diff = (original - multiplied).abs();
                    prop_assert!(
                        diff < 1e-5,
                        "Identity property failed at ({}, {}): {} != {} (diff: {})",
                        i, j, original, multiplied, diff
                    );
                }
            }
        }

        // Property: an (m x n) times (n x p) product is always (m x p).
        #[test]
        fn test_matmul_dimension_property(
            m in 1usize..10,
            n in 1usize..10,
            p in 1usize..10
        ) {
            let a = Matrix::zeros(m, n);
            let b = Matrix::zeros(n, p);
            let c = a.matmul(&b).unwrap();

            prop_assert_eq!(c.rows(), m);
            prop_assert_eq!(c.cols(), p);
        }

        // Property: transposing twice restores the matrix exactly
        // (transpose moves values without arithmetic, so equality is exact).
        #[test]
        fn test_transpose_double_transpose(
            a in matrix_strategy(5, 7)
        ) {
            let t = a.transpose();
            let tt = t.transpose();

            prop_assert_eq!(tt.rows(), a.rows());
            prop_assert_eq!(tt.cols(), a.cols());

            for i in 0..a.rows() {
                for j in 0..a.cols() {
                    prop_assert_eq!(tt.get(i, j), a.get(i, j));
                }
            }
        }

        // Property: transposing swaps the row and column counts.
        #[test]
        fn test_transpose_dimension_swap(
            m in 1usize..20,
            n in 1usize..20
        ) {
            let a = Matrix::zeros(m, n);
            let t = a.transpose();

            prop_assert_eq!(t.rows(), n);
            prop_assert_eq!(t.cols(), m);
        }

        // Property: (A·B)ᵀ ≈ Bᵀ·Aᵀ — the transpose-of-product identity,
        // with a small tolerance for f32 rounding.
        #[test]
        fn test_transpose_of_product(
            a in matrix_strategy(3, 4),
            b in matrix_strategy(4, 5)
        ) {
            let ab = a.matmul(&b).unwrap();
            let ab_t = ab.transpose();

            let b_t = b.transpose();
            let a_t = a.transpose();
            let bt_at = b_t.matmul(&a_t).unwrap();

            prop_assert_eq!(ab_t.rows(), bt_at.rows());
            prop_assert_eq!(ab_t.cols(), bt_at.cols());

            for i in 0..ab_t.rows() {
                for j in 0..ab_t.cols() {
                    let val1 = ab_t.get(i, j).unwrap();
                    let val2 = bt_at.get(i, j).unwrap();
                    let diff = (val1 - val2).abs();
                    let max_val = val1.abs().max(val2.abs());

                    // Tighter than the associativity test: only a single
                    // multiplication on each side of the comparison.
                    let tolerance = if max_val < 1.0 {
                        1e-3
                    } else {
                        max_val * 1e-3
                    };

                    prop_assert!(
                        diff < tolerance,
                        "Transpose of product failed at ({}, {}): {} != {} (diff: {}, tolerance: {})",
                        i, j, val1, val2, diff, tolerance
                    );
                }
            }
        }

        // Property: (A·B)·v ≈ A·(B·v) — matrix-vector application
        // associates with matrix multiplication.
        #[test]
        fn test_matvec_associativity(
            a in matrix_strategy(3, 4),
            b in matrix_strategy(4, 5),
            v_data in prop::collection::vec(-10.0f32..10.0, 5)
        ) {
            let v = Vector::from_slice(&v_data);

            let ab = a.matmul(&b).unwrap();
            let ab_v = ab.matvec(&v).unwrap();

            let b_v = b.matvec(&v).unwrap();
            let a_bv = a.matvec(&b_v).unwrap();

            prop_assert_eq!(ab_v.len(), a_bv.len());

            for i in 0..ab_v.len() {
                let diff = (ab_v.as_slice()[i] - a_bv.as_slice()[i]).abs();
                let max_val = ab_v.as_slice()[i].abs().max(a_bv.as_slice()[i].abs());
                // Mixed absolute/relative tolerance, as in the matmul test.
                let tolerance = if max_val < 1.0 { 1e-2 } else { max_val * 2e-2 };

                prop_assert!(
                    diff < tolerance,
                    "Associativity failed at index {}: {} != {} (diff: {}, tolerance: {})",
                    i, ab_v.as_slice()[i], a_bv.as_slice()[i], diff, tolerance
                );
            }
        }

        // Property: v·(A·B) ≈ (v·A)·B — the row-vector analogue of the
        // matvec associativity property above.
        #[test]
        fn test_vecmat_associativity(
            a in matrix_strategy(3, 4),
            b in matrix_strategy(4, 5),
            v_data in prop::collection::vec(-10.0f32..10.0, 3)
        ) {
            let v = Vector::from_slice(&v_data);

            let ab = a.matmul(&b).unwrap();
            let v_ab = Matrix::vecmat(&v, &ab).unwrap();

            let v_a = Matrix::vecmat(&v, &a).unwrap();
            let va_b = Matrix::vecmat(&v_a, &b).unwrap();

            prop_assert_eq!(v_ab.len(), va_b.len());

            for i in 0..v_ab.len() {
                let diff = (v_ab.as_slice()[i] - va_b.as_slice()[i]).abs();
                let max_val = v_ab.as_slice()[i].abs().max(va_b.as_slice()[i].abs());
                let tolerance = if max_val < 1.0 { 1e-2 } else { max_val * 1e-2 };

                prop_assert!(
                    diff < tolerance,
                    "Associativity failed at index {}: {} != {} (diff: {}, tolerance: {})",
                    i, v_ab.as_slice()[i], va_b.as_slice()[i], diff, tolerance
                );
            }
        }
    }
3484
3485 #[test]
3487 fn test_matvec_basic() {
3488 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3489 let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
3490 let result = m.matvec(&v).unwrap();
3491
3492 assert_eq!(result.len(), 2);
3496 assert!((result.as_slice()[0] - 14.0).abs() < 1e-6);
3497 assert!((result.as_slice()[1] - 32.0).abs() < 1e-6);
3498 }
3499
3500 #[test]
3501 fn test_matvec_identity() {
3502 let m = Matrix::identity(3);
3503 let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
3504 let result = m.matvec(&v).unwrap();
3505
3506 assert_eq!(result.as_slice(), v.as_slice());
3508 }
3509
3510 #[test]
3511 fn test_matvec_dimension_mismatch() {
3512 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3513 let v = Vector::from_slice(&[1.0, 2.0]); assert!(m.matvec(&v).is_err());
3516 }
3517
3518 #[test]
3519 fn test_vecmat_basic() {
3520 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3521 let v = Vector::from_slice(&[1.0, 2.0]);
3522 let result = Matrix::vecmat(&v, &m).unwrap();
3523
3524 assert_eq!(result.len(), 3);
3527 assert!((result.as_slice()[0] - 9.0).abs() < 1e-6);
3528 assert!((result.as_slice()[1] - 12.0).abs() < 1e-6);
3529 assert!((result.as_slice()[2] - 15.0).abs() < 1e-6);
3530 }
3531
3532 #[test]
3533 fn test_vecmat_identity() {
3534 let m = Matrix::identity(3);
3535 let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
3536 let result = Matrix::vecmat(&v, &m).unwrap();
3537
3538 assert_eq!(result.as_slice(), v.as_slice());
3540 }
3541
3542 #[test]
3543 fn test_vecmat_dimension_mismatch() {
3544 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3545 let v = Vector::from_slice(&[1.0, 2.0, 3.0]); assert!(Matrix::vecmat(&v, &m).is_err());
3548 }
3549
3550 #[test]
3551 fn test_matvec_zero_vector() {
3552 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3553 let v = Vector::from_slice(&[0.0, 0.0, 0.0]);
3554 let result = m.matvec(&v).unwrap();
3555
3556 assert_eq!(result.as_slice(), &[0.0, 0.0]);
3558 }
3559
3560 #[test]
3561 fn test_vecmat_zero_vector() {
3562 let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3563 let v = Vector::from_slice(&[0.0, 0.0]);
3564 let result = Matrix::vecmat(&v, &m).unwrap();
3565
3566 assert_eq!(result.as_slice(), &[0.0, 0.0, 0.0]);
3568 }
3569
3570 #[test]
3571 fn test_matvec_transpose_equivalence() {
3572 let m = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3584 let v = Vector::from_slice(&[1.0, 2.0]); let av = m.matvec(&v).unwrap();
3588
3589 let m_t = m.transpose(); let v_mt = Matrix::vecmat(&v, &m_t).unwrap();
3592
3593 assert_eq!(av.as_slice(), v_mt.as_slice());
3595 }
3596
3597 #[test]
3600 fn test_convolve2d_basic_3x3() {
3601 let input =
3603 Matrix::from_vec(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();
3604
3605 let kernel = Matrix::from_vec(1, 1, vec![1.0]).unwrap();
3607
3608 let result = input.convolve2d(&kernel).unwrap();
3609
3610 assert_eq!(result.rows(), 3);
3612 assert_eq!(result.cols(), 3);
3613 assert_eq!(result.as_slice(), input.as_slice());
3614 }
3615
    /// A 3x3 horizontal-edge kernel over a 4x4 input produces a
    /// (4 - 3 + 1) x (4 - 3 + 1) = 2x2 output; only the shape is asserted.
    #[test]
    fn test_convolve2d_edge_detection() {
        // 4x4 image: a brighter 2x2 square (value 2.0) inside a 1.0 border.
        let input = Matrix::from_vec(
            4,
            4,
            vec![
                1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            ],
        )
        .unwrap();

        // Edge-detection kernel: -1s on top, +1s on bottom, zero middle row.
        #[rustfmt::skip]
        let kernel = Matrix::from_vec(
            3,
            3,
            vec![
                -1.0, -1.0, -1.0,
                0.0, 0.0, 0.0,
                1.0, 1.0, 1.0,
            ],
        )
        .unwrap();

        let result = input.convolve2d(&kernel).unwrap();

        // Output shrinks by (kernel - 1) along each axis.
        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 2);
    }
3650
3651 #[test]
3652 fn test_convolve2d_averaging_filter() {
3653 let input = Matrix::from_vec(
3655 5,
3656 5,
3657 vec![
3658 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
3664 )
3665 .unwrap();
3666
3667 let kernel_val = 1.0 / 9.0;
3669 let kernel = Matrix::from_vec(
3670 3,
3671 3,
3672 vec![
3673 kernel_val, kernel_val, kernel_val, kernel_val, kernel_val, kernel_val, kernel_val, kernel_val, kernel_val, ],
3677 )
3678 .unwrap();
3679
3680 let result = input.convolve2d(&kernel).unwrap();
3681
3682 assert_eq!(result.rows(), 3);
3684 assert_eq!(result.cols(), 3);
3685
3686 assert!((result.get(1, 1).unwrap() - 1.0).abs() < 1e-5);
3688 }
3689
3690 #[test]
3691 fn test_convolve2d_invalid_kernel() {
3692 let input = Matrix::from_vec(3, 3, vec![1.0; 9]).unwrap();
3693
3694 let kernel = Matrix::from_vec(4, 4, vec![1.0; 16]).unwrap();
3696
3697 assert!(input.convolve2d(&kernel).is_err());
3698 }
3699
3700 #[test]
3703 fn test_embedding_lookup_basic() {
3704 let embeddings = Matrix::from_vec(
3706 4,
3707 3,
3708 vec![
3709 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ],
3714 )
3715 .unwrap();
3716
3717 let result = embeddings.embedding_lookup(&[1, 3, 0]).unwrap();
3719
3720 assert_eq!(result.rows(), 3);
3721 assert_eq!(result.cols(), 3);
3722
3723 assert_eq!(result.get(0, 0), Some(&4.0));
3725 assert_eq!(result.get(0, 1), Some(&5.0));
3726 assert_eq!(result.get(0, 2), Some(&6.0));
3727
3728 assert_eq!(result.get(1, 0), Some(&10.0));
3730 assert_eq!(result.get(1, 1), Some(&11.0));
3731 assert_eq!(result.get(1, 2), Some(&12.0));
3732
3733 assert_eq!(result.get(2, 0), Some(&1.0));
3735 assert_eq!(result.get(2, 1), Some(&2.0));
3736 assert_eq!(result.get(2, 2), Some(&3.0));
3737 }
3738
3739 #[test]
3740 fn test_embedding_lookup_single_index() {
3741 let embeddings = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3742
3743 let result = embeddings.embedding_lookup(&[1]).unwrap();
3744
3745 assert_eq!(result.rows(), 1);
3746 assert_eq!(result.cols(), 2);
3747 assert_eq!(result.get(0, 0), Some(&3.0));
3748 assert_eq!(result.get(0, 1), Some(&4.0));
3749 }
3750
3751 #[test]
3752 fn test_embedding_lookup_repeated_indices() {
3753 let embeddings = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3754
3755 let result = embeddings.embedding_lookup(&[0, 0, 1, 0]).unwrap();
3757
3758 assert_eq!(result.rows(), 4);
3759 assert_eq!(result.cols(), 3);
3760
3761 assert_eq!(result.get(0, 0), result.get(1, 0));
3763 assert_eq!(result.get(0, 0), result.get(3, 0));
3764 }
3765
3766 #[test]
3767 fn test_embedding_lookup_empty_indices() {
3768 let embeddings = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3769
3770 let result = embeddings.embedding_lookup(&[]).unwrap();
3771
3772 assert_eq!(result.rows(), 0);
3773 assert_eq!(result.cols(), 2);
3774 }
3775
3776 #[test]
3777 fn test_embedding_lookup_out_of_bounds() {
3778 let embeddings = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
3779
3780 let result = embeddings.embedding_lookup(&[0, 5, 1]);
3782
3783 assert!(result.is_err());
3784 let err = result.unwrap_err();
3785 assert!(err.to_string().contains("out of bounds"));
3786 }
3787
3788 #[test]
3789 fn test_embedding_lookup_sparse() {
3790 let embeddings =
3791 Matrix::from_vec(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
3792
3793 let (result, unique) = embeddings
3795 .embedding_lookup_sparse(&[1, 3, 1, 0, 3])
3796 .unwrap();
3797
3798 assert_eq!(result.rows(), 5);
3799 assert_eq!(result.cols(), 2);
3800
3801 assert_eq!(unique, vec![0, 1, 3]);
3803 }
3804
3805 #[test]
3806 fn test_embedding_lookup_large_embeddings() {
3807 let vocab_size = 1000;
3809 let embed_dim = 256;
3810 let data: Vec<f32> = (0..vocab_size * embed_dim).map(|i| i as f32).collect();
3811 let embeddings = Matrix::from_vec(vocab_size, embed_dim, data).unwrap();
3812
3813 let indices: Vec<usize> = vec![0, 500, 999, 42, 100];
3815 let result = embeddings.embedding_lookup(&indices).unwrap();
3816
3817 assert_eq!(result.rows(), 5);
3818 assert_eq!(result.cols(), embed_dim);
3819
3820 assert_eq!(result.get(0, 0), Some(&0.0)); assert_eq!(result.get(1, 0), Some(&(500.0 * 256.0))); assert_eq!(result.get(2, 0), Some(&(999.0 * 256.0))); }
3825
3826 #[test]
3829 fn test_batched_matmul_basic() {
3830 let batch = 2;
3832 let m = 2;
3833 let k = 3;
3834 let n = 2;
3835
3836 let a_data: Vec<f32> = vec![
3839 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ];
3842 let b_data: Vec<f32> = vec![
3843 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ];
3846
3847 let result = Matrix::batched_matmul(&a_data, &b_data, batch, m, k, n).unwrap();
3848
3849 assert_eq!(result.len(), batch * m * n);
3850
3851 assert!((result[0] - 22.0).abs() < 1e-5);
3853 assert!((result[1] - 28.0).abs() < 1e-5);
3854 assert!((result[2] - 49.0).abs() < 1e-5);
3855 assert!((result[3] - 64.0).abs() < 1e-5);
3856
3857 assert!((result[4] - 220.0).abs() < 1e-5);
3863 assert!((result[5] - 244.0).abs() < 1e-5);
3864 assert!((result[6] - 301.0).abs() < 1e-5);
3865 assert!((result[7] - 334.0).abs() < 1e-5);
3866 }
3867
3868 #[test]
3869 fn test_batched_matmul_single_batch() {
3870 let batch = 1;
3871 let m = 2;
3872 let k = 2;
3873 let n = 2;
3874
3875 let a_data = vec![1.0, 0.0, 0.0, 1.0]; let b_data = vec![5.0, 6.0, 7.0, 8.0];
3877
3878 let result = Matrix::batched_matmul(&a_data, &b_data, batch, m, k, n).unwrap();
3879
3880 assert!((result[0] - 5.0).abs() < 1e-5);
3882 assert!((result[1] - 6.0).abs() < 1e-5);
3883 assert!((result[2] - 7.0).abs() < 1e-5);
3884 assert!((result[3] - 8.0).abs() < 1e-5);
3885 }
3886
3887 #[test]
3888 fn test_batched_matmul_a_size_mismatch() {
3889 let batch = 2;
3890 let m = 2;
3891 let k = 3;
3892 let n = 2;
3893
3894 let a_data = vec![1.0; 10]; let b_data = vec![1.0; batch * k * n];
3896
3897 let result = Matrix::batched_matmul(&a_data, &b_data, batch, m, k, n);
3898 assert!(result.is_err());
3899 assert!(result.unwrap_err().to_string().contains("A data size mismatch"));
3900 }
3901
3902 #[test]
3903 fn test_batched_matmul_b_size_mismatch() {
3904 let batch = 2;
3905 let m = 2;
3906 let k = 3;
3907 let n = 2;
3908
3909 let a_data = vec![1.0; batch * m * k];
3910 let b_data = vec![1.0; 10]; let result = Matrix::batched_matmul(&a_data, &b_data, batch, m, k, n);
3913 assert!(result.is_err());
3914 assert!(result.unwrap_err().to_string().contains("B data size mismatch"));
3915 }
3916
3917 #[test]
3918 fn test_batched_matmul_4d_basic() {
3919 let batch = 1;
3921 let heads = 2;
3922 let m = 2;
3923 let k = 2;
3924 let n = 2;
3925
3926 let a_data: Vec<f32> = vec![
3929 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ];
3932 let b_data: Vec<f32> = vec![
3933 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, ];
3936
3937 let result = Matrix::batched_matmul_4d(&a_data, &b_data, batch, heads, m, k, n).unwrap();
3938
3939 assert_eq!(result.len(), batch * heads * m * n);
3940
3941 assert!((result[0] - 1.0).abs() < 1e-5);
3943 assert!((result[1] - 2.0).abs() < 1e-5);
3944 assert!((result[2] - 3.0).abs() < 1e-5);
3945 assert!((result[3] - 4.0).abs() < 1e-5);
3946
3947 assert!((result[4] - 5.0).abs() < 1e-5);
3949 assert!((result[5] - 6.0).abs() < 1e-5);
3950 assert!((result[6] - 7.0).abs() < 1e-5);
3951 assert!((result[7] - 8.0).abs() < 1e-5);
3952 }
3953
3954 #[test]
3955 fn test_batched_matmul_4d_attention_pattern() {
3956 let batch = 1;
3958 let heads = 2;
3959 let seq_len = 4;
3960 let head_dim = 8;
3961
3962 let q_data: Vec<f32> = (0..batch * heads * seq_len * head_dim)
3963 .map(|i| (i as f32) * 0.01)
3964 .collect();
3965 let kt_data: Vec<f32> = (0..batch * heads * head_dim * seq_len)
3966 .map(|i| (i as f32) * 0.01)
3967 .collect();
3968
3969 let result = Matrix::batched_matmul_4d(
3970 &q_data,
3971 &kt_data,
3972 batch,
3973 heads,
3974 seq_len,
3975 head_dim,
3976 seq_len,
3977 )
3978 .unwrap();
3979
3980 assert_eq!(result.len(), batch * heads * seq_len * seq_len);
3982 }
3983
3984 #[test]
3985 fn test_batched_matmul_4d_a_size_mismatch() {
3986 let batch = 1;
3987 let heads = 2;
3988 let m = 4;
3989 let k = 8;
3990 let n = 4;
3991
3992 let a_data = vec![1.0; 50]; let b_data = vec![1.0; batch * heads * k * n];
3994
3995 let result = Matrix::batched_matmul_4d(&a_data, &b_data, batch, heads, m, k, n);
3996 assert!(result.is_err());
3997 assert!(result.unwrap_err().to_string().contains("A data size mismatch"));
3998 }
3999
4000 #[test]
4001 fn test_batched_matmul_4d_b_size_mismatch() {
4002 let batch = 1;
4003 let heads = 2;
4004 let m = 4;
4005 let k = 8;
4006 let n = 4;
4007
4008 let a_data = vec![1.0; batch * heads * m * k];
4009 let b_data = vec![1.0; 50]; let result = Matrix::batched_matmul_4d(&a_data, &b_data, batch, heads, m, k, n);
4012 assert!(result.is_err());
4013 assert!(result.unwrap_err().to_string().contains("B data size mismatch"));
4014 }
4015
    /// Property-based tests for 2D convolution.
    #[cfg(test)]
    mod conv_property_tests {
        use super::*;

        proptest! {
            // Property: output size is (input - kernel + 1) along each axis.
            #[test]
            fn test_convolve2d_output_size(
                input_rows in 3usize..20,
                input_cols in 3usize..20,
                kernel_rows in 1usize..5,
                kernel_cols in 1usize..5,
            ) {
                // Only defined when the kernel fits inside the input;
                // other generated combinations are silently skipped.
                if kernel_rows <= input_rows && kernel_cols <= input_cols {
                    let input = Matrix::from_vec(input_rows, input_cols, vec![1.0; input_rows * input_cols]).unwrap();
                    let kernel = Matrix::from_vec(kernel_rows, kernel_cols, vec![1.0; kernel_rows * kernel_cols]).unwrap();

                    let result = input.convolve2d(&kernel).unwrap();

                    prop_assert_eq!(result.rows(), input_rows - kernel_rows + 1);
                    prop_assert_eq!(result.cols(), input_cols - kernel_cols + 1);
                }
            }

            // Property: a 1x1 kernel containing 1.0 is the identity — the
            // output matches the input in shape and data.
            #[test]
            fn test_convolve2d_identity_kernel(
                input_rows in 3usize..10,
                input_cols in 3usize..10,
                values in prop::collection::vec(-100.0f32..100.0, 9..100)
            ) {
                // Skip generated cases without enough values to fill the input.
                if values.len() >= input_rows * input_cols {
                    let data: Vec<f32> = values.iter().take(input_rows * input_cols).copied().collect();
                    let input = Matrix::from_vec(input_rows, input_cols, data.clone()).unwrap();
                    let kernel = Matrix::from_vec(1, 1, vec![1.0]).unwrap();

                    let result = input.convolve2d(&kernel).unwrap();

                    prop_assert_eq!(result.rows(), input_rows);
                    prop_assert_eq!(result.cols(), input_cols);
                    prop_assert_eq!(result.as_slice(), input.as_slice());
                }
            }

            // Property: an all-zero kernel maps any input to all zeros.
            #[test]
            fn test_convolve2d_zero_kernel(
                input_rows in 3usize..10,
                input_cols in 3usize..10,
                kernel_rows in 1usize..4,
                kernel_cols in 1usize..4,
            ) {
                if kernel_rows <= input_rows && kernel_cols <= input_cols {
                    let input = Matrix::from_vec(input_rows, input_cols, vec![5.0; input_rows * input_cols]).unwrap();
                    let kernel = Matrix::from_vec(kernel_rows, kernel_cols, vec![0.0; kernel_rows * kernel_cols]).unwrap();

                    let result = input.convolve2d(&kernel).unwrap();

                    for &val in result.as_slice() {
                        prop_assert!((val - 0.0).abs() < 1e-5);
                    }
                }
            }

            // Property: convolution is linear in the kernel — scaling every
            // kernel weight by `scalar` scales every output by `scalar`.
            #[test]
            fn test_convolve2d_scalar_multiplication(
                input_rows in 3usize..10,
                input_cols in 3usize..10,
                scalar in -10.0f32..10.0,
            ) {
                let input = Matrix::from_vec(input_rows, input_cols, vec![2.0; input_rows * input_cols]).unwrap();
                let kernel = Matrix::from_vec(3, 3, vec![1.0; 9]).unwrap();
                let kernel_scaled = Matrix::from_vec(3, 3, vec![scalar; 9]).unwrap();

                let result1 = input.convolve2d(&kernel).unwrap();
                let result2 = input.convolve2d(&kernel_scaled).unwrap();

                for (v1, v2) in result1.as_slice().iter().zip(result2.as_slice().iter()) {
                    prop_assert!((v1 * scalar - v2).abs() < 1e-3);
                }
            }
        }
    }
4102}