1use std::fmt;
7
8pub mod common;
9pub use common::Transpose;
10
11mod faer;
12use faer::{random_distance_preserving_matrix_impl, sgemm_impl, svd_into_impl};
13use rand::Rng;
14
/// Identifies one of the three matrix operands of [`sgemm`].
///
/// The `Display` form names the operand together with the dimensions whose
/// product its slice length must equal, e.g. `a (m * k)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixName {
    /// The `a` operand, holding `m * k` elements.
    A,
    /// The `b` operand, holding `k * n` elements.
    B,
    /// The `c` operand (the mutable output buffer), holding `m * n` elements.
    C,
}

impl fmt::Display for MatrixName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            MatrixName::A => "a (m * k)",
            MatrixName::B => "b (k * n)",
            MatrixName::C => "c (m * n)",
        };
        f.write_str(label)
    }
}

/// Errors reported by [`sgemm`] when an operand slice does not match the
/// requested dimensions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SgemmError {
    /// An operand's slice length differs from `expected_rows * expected_cols`.
    InvalidMatrixDimensions {
        /// The operand that was mis-sized.
        matrix_name: MatrixName,
        /// Number of rows the operand was expected to have.
        expected_rows: usize,
        /// Number of columns the operand was expected to have.
        expected_cols: usize,
        /// Length of the slice that was actually supplied.
        actual_len: usize,
    },
    /// Computing `rows * cols` for an operand overflowed `usize`.
    DimensionOverflow {
        /// The operand whose element count could not be represented.
        matrix_name: MatrixName,
        /// Requested number of rows.
        rows: usize,
        /// Requested number of columns.
        cols: usize,
    },
}

impl fmt::Display for SgemmError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SgemmError::InvalidMatrixDimensions {
                matrix_name,
                expected_rows,
                expected_cols,
                actual_len,
            } => {
                // The fields are public, so a caller can build a value whose
                // product does not fit in `usize`. Widening to `u128` is
                // lossless for every current Rust target (`usize` is at most
                // 64 bits). It keeps `Display` from panicking in debug builds
                // or printing a wrapped length in release builds.
                let expected_len = (*expected_rows as u128) * (*expected_cols as u128);
                write!(
                    f,
                    "expected {}x{} matrix {} to have length {}, instead got {}",
                    expected_rows, expected_cols, matrix_name, expected_len, actual_len
                )
            }
            SgemmError::DimensionOverflow {
                matrix_name,
                rows,
                cols,
            } => write!(
                f,
                "dimension overflow in matrix {}: {} * {} would overflow usize",
                matrix_name, rows, cols
            ),
        }
    }
}

impl std::error::Error for SgemmError {}
82
83#[cfg(test)]
85mod reference;
86
87#[allow(clippy::too_many_arguments)]
145pub fn sgemm(
146 atranspose: Transpose,
147 btranspose: Transpose,
148 m: usize,
149 n: usize,
150 k: usize,
151 alpha: f32,
152 a: &[f32],
153 b: &[f32],
154 beta: Option<f32>,
155 c: &mut [f32],
156) -> Result<(), SgemmError> {
157 let expected_a_len = m.checked_mul(k).ok_or(SgemmError::DimensionOverflow {
159 matrix_name: MatrixName::A,
160 rows: m,
161 cols: k,
162 })?;
163
164 if a.len() != expected_a_len {
165 return Err(SgemmError::InvalidMatrixDimensions {
166 matrix_name: MatrixName::A,
167 expected_rows: m,
168 expected_cols: k,
169 actual_len: a.len(),
170 });
171 }
172
173 let expected_b_len = k.checked_mul(n).ok_or(SgemmError::DimensionOverflow {
174 matrix_name: MatrixName::B,
175 rows: k,
176 cols: n,
177 })?;
178
179 if b.len() != expected_b_len {
180 return Err(SgemmError::InvalidMatrixDimensions {
181 matrix_name: MatrixName::B,
182 expected_rows: k,
183 expected_cols: n,
184 actual_len: b.len(),
185 });
186 }
187
188 let expected_c_len = m.checked_mul(n).ok_or(SgemmError::DimensionOverflow {
189 matrix_name: MatrixName::C,
190 rows: m,
191 cols: n,
192 })?;
193
194 if c.len() != expected_c_len {
195 return Err(SgemmError::InvalidMatrixDimensions {
196 matrix_name: MatrixName::C,
197 expected_rows: m,
198 expected_cols: n,
199 actual_len: c.len(),
200 });
201 }
202
203 sgemm_impl(atranspose, btranspose, m, n, k, alpha, a, b, beta, c);
205 Ok(())
206}
207
208pub fn svd_into(
239 m: usize,
240 n: usize,
241 a: &mut [f32],
242 singular_values: &mut [f32],
243 u: &mut [f32],
244 vt: &mut [f32],
245) -> Result<(), impl std::error::Error + 'static> {
246 assert_eq!(a.len(), m * n);
248 assert_eq!(singular_values.len(), m.min(n));
249 assert_eq!(u.len(), m * m);
250 assert_eq!(vt.len(), n * n);
251
252 svd_into_impl(m, n, a, singular_values, u, vt)
254}
255
256pub fn random_distance_preserving_matrix<T: Rng + ?Sized>(dim: usize, rng: &mut T) -> Vec<f32> {
261 random_distance_preserving_matrix_impl(dim, rng)
262}
263
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
    use rand_distr::StandardNormal;
    use serde::Deserialize;

    use super::*;
    use crate::reference;

    /// Runs `sgemm` over every problem in the reference suite and fails on the
    /// first mismatch, reporting which problem failed.
    #[test]
    fn test_reference_implementation() {
        let problems = reference::test_sgemm_problems();
        for (i, problem) in problems.iter().enumerate() {
            if let Err(err) = problem.check(sgemm) {
                panic!("{} on iteration {}. Problem: {:?}", err, i, problem);
            }
        }
    }

    #[test]
    fn test_sgemm_invalid_matrix_a_dimensions() {
        let mut c = [0.0f32; 6];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 5],
            &[0.0; 12],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            "expected 2x4 matrix a (m * k) to have length 8, instead got 5"
        );
    }

    #[test]
    fn test_sgemm_invalid_matrix_b_dimensions() {
        let mut c = [0.0f32; 6];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 8],
            &[0.0; 10],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            "expected 4x3 matrix b (k * n) to have length 12, instead got 10"
        );
    }

    #[test]
    fn test_sgemm_invalid_matrix_c_dimensions() {
        let mut c = [0.0f32; 5];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 8],
            &[0.0; 12],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            "expected 2x3 matrix c (m * n) to have length 6, instead got 5"
        );
    }

    #[test]
    fn test_sgemm_m_times_k_overflow() {
        let mut c = [0.0f32];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            usize::MAX,
            1,
            2,
            1.0,
            &[],
            &[0.0],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix a (m * k): {} * 2 would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_k_times_n_overflow() {
        let mut c = vec![0.0f32; 10];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            1,
            usize::MAX,
            10,
            1.0,
            &[0.0f32; 10],
            &[],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix b (k * n): 10 * {} would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_m_times_n_overflow() {
        let mut c = [];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            usize::MAX,
            0,
            1.0,
            &[],
            &[],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix c (m * n): 2 * {} would overflow usize",
                usize::MAX
            )
        );
    }

    /// Pins the size of `sgemm`'s return type, so that any growth of
    /// `SgemmError` has to be made deliberately.
    #[test]
    fn test_sgemm_result_size() {
        let mut c = [0.0f32; 6];
        let result = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 5],
            &[0.0; 12],
            None,
            &mut c,
        );

        let result_size = std::mem::size_of_val(&result);
        const EXPECTED_RESULT_SIZE: usize = 32;
        assert_eq!(
            result_size, EXPECTED_RESULT_SIZE,
            "Result size is {} bytes, does not match the expected size of {} bytes.",
            result_size, EXPECTED_RESULT_SIZE
        );
    }

    /// Returns the path of `name` inside this crate's `test_data` directory.
    fn test_file_path(name: &str) -> String {
        format!("{}/test_data/{}", env!("CARGO_MANIFEST_DIR"), name)
    }

    /// JSON file holding the reference SVD test cases.
    const SVD_INPUT_FILE: &str = "reference_svd_inputs.json";

    /// One reference SVD problem: an `m x n` row-major `matrix` and the
    /// singular values it is expected to have.
    #[derive(Deserialize, Debug)]
    struct SVDTestCase {
        m: usize,
        n: usize,
        matrix: Vec<f32>,
        singular_values: Vec<f32>,
    }

    impl SVDTestCase {
        /// Short human-readable description used in failure messages.
        fn summary(&self) -> String {
            format!("svd test case with dimension {}x{}", self.m, self.n)
        }
    }

    /// Acceptance thresholds for a computed value. A value passes when it is
    /// within *either* the absolute or the relative tolerance.
    struct SVDTolerance {
        absolute: f32,
        relative: f32,
    }

    impl SVDTolerance {
        fn check(&self, absolute: f32, relative: f32) -> bool {
            absolute <= self.absolute || relative <= self.relative
        }
    }

    /// Returns the `(absolute, relative)` error of `got` against `expected`.
    ///
    /// The relative error is divided by `|expected|`. With a signed divisor, a
    /// negative reference value gives a negative ratio, and a negative ratio
    /// satisfies any relative tolerance (even `0.0`), so the entry would never
    /// really be checked. When `expected == 0.0` the ratio is `inf` or `NaN`.
    /// Neither passes a relative tolerance, so the absolute tolerance decides.
    fn error_metrics(got: f32, expected: f32) -> (f32, f32) {
        let diff = (got - expected).abs();
        (diff, diff / expected.abs())
    }

    /// Expands `singular_values` into a row-major `m x n` matrix with those
    /// values on the main diagonal and zeros elsewhere.
    fn materialize_singular_values(singular_values: &[f32], m: usize, n: usize) -> Vec<f32> {
        assert_eq!(singular_values.len(), m.min(n));
        let mut output = vec![0.0; m * n];

        for (i, &s) in singular_values.iter().enumerate() {
            output[n * i + i] = s;
        }
        output
    }

    /// Decomposes `case.matrix` with `svd_into` and checks two things, each
    /// against its own tolerance:
    ///
    /// 1. the computed singular values match the reference values;
    /// 2. `u * diag(singular_values) * vt` reconstructs the original matrix.
    ///
    /// `context` is appended to every failure message.
    fn test_svd(
        case: &SVDTestCase,
        singular_value_tolerance: &SVDTolerance,
        reconstructed_tolerance: &SVDTolerance,
        context: &dyn std::fmt::Display,
    ) {
        let (m, n) = (case.m, case.n);
        let mut singular_values = vec![0.0; m.min(n)];
        let mut u = vec![0.0; m * m];
        let mut vt = vec![0.0; n * n];

        // `svd_into` takes its input mutably, so give it a copy and keep the
        // original for the reconstruction check below.
        let mut matrix = case.matrix.clone();
        svd_into(m, n, &mut matrix, &mut singular_values, &mut u, &mut vt).unwrap();

        // `zip` silently stops at the shorter slice, so first make sure every
        // reference value will actually be compared.
        assert_eq!(
            singular_values.len(),
            case.singular_values.len(),
            "singular value count differs from the reference: {}",
            context
        );
        for (i, (&got, &expected)) in singular_values
            .iter()
            .zip(&case.singular_values)
            .enumerate()
        {
            let (diff, relative) = error_metrics(got, expected);
            assert!(
                singular_value_tolerance.check(diff, relative),
                "got {} but expected {} (diff: {}, relative: {}) at position {}: {}",
                got,
                expected,
                diff,
                relative,
                i,
                context
            );
        }

        // Reconstruct in two products:
        // temp = u (m x m) * sigma (m x n), then output = temp (m x n) * vt (n x n).
        let full_singular_values = materialize_singular_values(&singular_values, m, n);
        let mut temp = vec![0.0; m * n];
        sgemm(
            Transpose::None,
            Transpose::None,
            m,
            n,
            m,
            1.0,
            &u,
            &full_singular_values,
            None,
            &mut temp,
        )
        .unwrap();

        let mut output = vec![0.0; m * n];
        sgemm(
            Transpose::None,
            Transpose::None,
            m,
            n,
            n,
            1.0,
            &temp,
            &vt,
            None,
            &mut output,
        )
        .unwrap();

        for row in 0..m {
            for col in 0..n {
                let got = output[n * row + col];
                let expected = case.matrix[n * row + col];
                let (diff, relative) = error_metrics(got, expected);
                assert!(
                    reconstructed_tolerance.check(diff, relative),
                    "mismatch in reconstructed matrix at (row, col) = ({}, {}). \
                     Got {}, expected {} (diff: {}, relative: {}). {}",
                    row,
                    col,
                    got,
                    expected,
                    diff,
                    relative,
                    context
                );
            }
        }
    }

    /// Runs `test_svd` over every case stored in `SVD_INPUT_FILE`.
    #[test]
    fn test_svd_implementation() {
        let path = test_file_path(SVD_INPUT_FILE);
        let file = std::fs::File::open(&path)
            .unwrap_or_else(|err| panic!("failed to open file {path}: {err}"));

        let reader = std::io::BufReader::new(file);
        let cases: Vec<SVDTestCase> = serde_json::from_reader(reader)
            .unwrap_or_else(|err| panic!("failed to parse {path}: {err}"));
        // An empty data file would otherwise make this test pass vacuously.
        assert!(!cases.is_empty(), "no SVD test cases found in {path}");

        let singular_values_tolerance = SVDTolerance {
            absolute: 2.0e-6,
            relative: 3.0e-6,
        };

        // A relative tolerance of 0.0 means only the absolute bound applies.
        let reconstructed_tolerance = SVDTolerance {
            absolute: 5.0e-5,
            relative: 0.0,
        };

        for (i, case) in cases.iter().enumerate() {
            let context = format!(
                "while processing case {} of {}: {}",
                i + 1,
                cases.len(),
                case.summary()
            );
            test_svd(
                case,
                &singular_values_tolerance,
                &reconstructed_tolerance,
                &context,
            );
        }
    }

    /// Absolute and relative tolerance for the distance-preserving matrix checks.
    const EPSILON: f32 = 1e-5;

    /// Draws one random `dim x dim` matrix `Q` and checks two properties:
    ///
    /// 1. `Q * Qᵀ` is the identity, i.e. the rows are orthonormal;
    /// 2. multiplying random Gaussian vectors by `Q` preserves their squared
    ///    L2 norm while actually changing the vector.
    fn test_distance_preserving_matrix_impl(dim: usize, rng: &mut StdRng) {
        let q = random_distance_preserving_matrix(dim, rng);

        // View the flat buffer as a row-major `dim x dim` matrix.
        let qm = ::faer::mat::MatRef::from_row_major_slice(&q, dim, dim);

        let m = qm * qm.transpose();
        for j in 0..dim {
            for i in 0..dim {
                let expected: f32 = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(m[(i, j)], expected, epsilon = EPSILON);
            }
        }

        // Gaussian samples point in random directions; each one must keep its norm.
        const RANDOM_TRIALS: usize = 100;
        let mut v = vec![0.0f32; dim];
        for _ in 0..RANDOM_TRIALS {
            v.iter_mut()
                .for_each(|x| *x = StandardNormal.sample(rng));
            let vm = ::faer::mat::MatRef::from_row_major_slice(&v, dim, 1);
            let v_norm = vm.squared_norm_l2();
            let t = qm * vm;
            let t_norm = t.squared_norm_l2();

            assert_relative_eq!(v_norm, t_norm, epsilon = EPSILON, max_relative = EPSILON);
            // A map that returned its input unchanged would pass the norm check
            // trivially, so also require the vector to have moved.
            assert_ne!(vm, t);
        }
    }

    /// Exercises several dimensions, including the minimal 2x2 case, with a
    /// fixed seed so that failures are reproducible.
    #[test]
    fn test_rotation_matrix() {
        let mut rng = StdRng::seed_from_u64(0xc0ff33);
        let num_trials = 5;
        for dim in [2, 100, 256] {
            for _ in 0..num_trials {
                test_distance_preserving_matrix_impl(dim, &mut rng);
            }
        }
    }
}