use std::fmt;
pub mod common;
pub use common::Transpose;
mod faer;
use faer::{random_distance_preserving_matrix_impl, sgemm_impl, svd_into_impl};
use rand::Rng;
/// Identifies one of the three matrix operands of [`sgemm`] in error messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixName {
    /// The left-hand input `a`, which `sgemm` expects to hold `m * k` elements.
    A,
    /// The right-hand input `b`, which `sgemm` expects to hold `k * n` elements.
    B,
    /// The output `c`, which `sgemm` expects to hold `m * n` elements.
    C,
}

impl fmt::Display for MatrixName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            MatrixName::A => "a (m * k)",
            MatrixName::B => "b (k * n)",
            MatrixName::C => "c (m * n)",
        };
        f.write_str(label)
    }
}

/// Error returned by [`sgemm`] when an operand slice is inconsistent with `m`, `n`, `k`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SgemmError {
    /// The slice for `matrix_name` has `actual_len` elements instead of
    /// `expected_rows * expected_cols`.
    InvalidMatrixDimensions {
        matrix_name: MatrixName,
        expected_rows: usize,
        expected_cols: usize,
        actual_len: usize,
    },
    /// The element count `rows * cols` of `matrix_name` does not fit in `usize`.
    DimensionOverflow {
        matrix_name: MatrixName,
        rows: usize,
        cols: usize,
    },
}

impl fmt::Display for SgemmError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SgemmError::InvalidMatrixDimensions {
                matrix_name,
                expected_rows,
                expected_cols,
                actual_len,
            } => {
                // The fields are public, so callers can construct this variant with
                // dimensions whose product overflows `usize`. Multiplying in `u128` (a
                // lossless widening; two `usize` values can never overflow it) keeps
                // formatting panic-free and the printed length correct.
                let expected_len = (*expected_rows as u128) * (*expected_cols as u128);
                write!(
                    f,
                    "expected {expected_rows}x{expected_cols} matrix {matrix_name} to have length {expected_len}, instead got {actual_len}"
                )
            }
            SgemmError::DimensionOverflow {
                matrix_name,
                rows,
                cols,
            } => write!(
                f,
                "dimension overflow in matrix {matrix_name}: {rows} * {cols} would overflow usize"
            ),
        }
    }
}

impl std::error::Error for SgemmError {}
#[cfg(test)]
mod reference;
#[allow(clippy::too_many_arguments)]
pub fn sgemm(
atranspose: Transpose,
btranspose: Transpose,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: &[f32],
b: &[f32],
beta: Option<f32>,
c: &mut [f32],
) -> Result<(), SgemmError> {
let expected_a_len = m.checked_mul(k).ok_or(SgemmError::DimensionOverflow {
matrix_name: MatrixName::A,
rows: m,
cols: k,
})?;
if a.len() != expected_a_len {
return Err(SgemmError::InvalidMatrixDimensions {
matrix_name: MatrixName::A,
expected_rows: m,
expected_cols: k,
actual_len: a.len(),
});
}
let expected_b_len = k.checked_mul(n).ok_or(SgemmError::DimensionOverflow {
matrix_name: MatrixName::B,
rows: k,
cols: n,
})?;
if b.len() != expected_b_len {
return Err(SgemmError::InvalidMatrixDimensions {
matrix_name: MatrixName::B,
expected_rows: k,
expected_cols: n,
actual_len: b.len(),
});
}
let expected_c_len = m.checked_mul(n).ok_or(SgemmError::DimensionOverflow {
matrix_name: MatrixName::C,
rows: m,
cols: n,
})?;
if c.len() != expected_c_len {
return Err(SgemmError::InvalidMatrixDimensions {
matrix_name: MatrixName::C,
expected_rows: m,
expected_cols: n,
actual_len: c.len(),
});
}
sgemm_impl(atranspose, btranspose, m, n, k, alpha, a, b, beta, c);
Ok(())
}
/// Computes the singular value decomposition of the `m x n` matrix stored in `a`, writing
/// the factors into caller-provided buffers.
///
/// On success, `u` (`m x m`), `singular_values` (length `min(m, n)`) and `vt` (`n x n`)
/// reconstruct the input as `u * diag(singular_values) * vt`, where the diagonal matrix is
/// `m x n`. This file's tests index all buffers as row-major.
///
/// NOTE(review): `a` is passed mutably to `svd_into_impl`. Assume its contents are
/// clobbered (the tests pass a copy); confirm against the implementation.
///
/// # Panics
///
/// Panics if `a.len() != m * n`, `singular_values.len() != min(m, n)`,
/// `u.len() != m * m` or `vt.len() != n * n`. Also panics if any of those products
/// overflows `usize`.
///
/// # Errors
///
/// Returns whatever error `svd_into_impl` reports.
pub fn svd_into(
    m: usize,
    n: usize,
    a: &mut [f32],
    singular_values: &mut [f32],
    u: &mut [f32],
    vt: &mut [f32],
) -> Result<(), impl std::error::Error + 'static> {
    /// `rows * cols`, panicking on overflow. Plain `*` wraps in release builds, so an
    /// overflowed product could coincidentally equal a buffer length and let a malformed
    /// call through to `svd_into_impl`.
    fn element_count(buffer: &str, rows: usize, cols: usize) -> usize {
        rows.checked_mul(cols).unwrap_or_else(|| {
            panic!("svd_into: {buffer} dimensions {rows} * {cols} overflow usize")
        })
    }

    assert_eq!(
        a.len(),
        element_count("a", m, n),
        "svd_into: a must hold m * n elements"
    );
    assert_eq!(
        singular_values.len(),
        m.min(n),
        "svd_into: singular_values must hold min(m, n) elements"
    );
    assert_eq!(
        u.len(),
        element_count("u", m, m),
        "svd_into: u must hold m * m elements"
    );
    assert_eq!(
        vt.len(),
        element_count("vt", n, n),
        "svd_into: vt must hold n * n elements"
    );
    svd_into_impl(m, n, a, singular_values, u, vt)
}
/// Samples a random `dim x dim` distance-preserving (orthogonal) matrix `Q`, i.e. one
/// with `Q * Q^T ≈ I` up to floating-point error.
///
/// The result is a flat `Vec` of `dim * dim` elements in row-major order. `rng` is
/// forwarded unchanged to `random_distance_preserving_matrix_impl`, which does the
/// actual sampling.
pub fn random_distance_preserving_matrix<T>(dim: usize, rng: &mut T) -> Vec<f32>
where
    T: Rng + ?Sized,
{
    random_distance_preserving_matrix_impl(dim, rng)
}
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
    use rand_distr::StandardNormal;
    use serde::Deserialize;

    use super::*;
    use crate::reference;

    #[test]
    fn test_reference_implementation() {
        let problems = reference::test_sgemm_problems();
        for (i, problem) in problems.iter().enumerate() {
            let result = problem.check(sgemm);
            if let Err(err) = result {
                panic!("{} on iteration {}. Problem: {:?}", err, i, problem);
            }
        }
    }

    #[test]
    fn test_sgemm_invalid_matrix_a_dimensions() {
        let mut c = [0.0f32; 6];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 5],
            &[0.0; 12],
            None,
            &mut c,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected 2x4 matrix a (m * k) to have length 8, instead got 5"
        );
    }

    #[test]
    fn test_sgemm_invalid_matrix_b_dimensions() {
        let mut c = [0.0f32; 6];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 8],
            &[0.0; 10],
            None,
            &mut c,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected 4x3 matrix b (k * n) to have length 12, instead got 10"
        );
    }

    #[test]
    fn test_sgemm_invalid_matrix_c_dimensions() {
        let mut c = [0.0f32; 5];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 8],
            &[0.0; 12],
            None,
            &mut c,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected 2x3 matrix c (m * n) to have length 6, instead got 5"
        );
    }

    #[test]
    fn test_sgemm_m_times_k_overflow() {
        let mut c = [0.0f32];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            usize::MAX,
            1,
            2,
            1.0,
            &[],
            &[0.0],
            None,
            &mut c,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix a (m * k): {} * 2 would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_k_times_n_overflow() {
        let mut c = vec![0.0f32; 10];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            1,
            usize::MAX,
            10,
            1.0,
            &[0.0f32; 10],
            &[],
            None,
            &mut c,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix b (k * n): 10 * {} would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_m_times_n_overflow() {
        let mut c = [];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            usize::MAX,
            0,
            1.0,
            &[],
            &[],
            None,
            &mut c,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix c (m * n): 2 * {} would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_result_size() {
        let mut c = [0.0f32; 6];
        let result = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 5],
            &[0.0; 12],
            None,
            &mut c,
        );
        let result_size = std::mem::size_of_val(&result);
        // The largest error variant holds three `usize` fields plus a one-byte `MatrixName`.
        // With the discriminant byte and padding, that is four machine words on current
        // rustc: 32 bytes on 64-bit targets. Stating the bound in words keeps the check
        // meaningful on 32-bit targets too.
        const EXPECTED_RESULT_SIZE: usize = 4 * std::mem::size_of::<usize>();
        assert_eq!(
            result_size, EXPECTED_RESULT_SIZE,
            "Result size is {} bytes, does not match the expected size of {} bytes.",
            result_size, EXPECTED_RESULT_SIZE
        );
    }

    fn test_file_path(name: &str) -> String {
        format!("{}/test_data/{}", env!("CARGO_MANIFEST_DIR"), name)
    }

    const SVD_INPUT_FILE: &str = "reference_svd_inputs.json";

    /// One reference case: an `m x n` row-major `matrix` and its expected singular values.
    #[derive(Deserialize, Debug)]
    struct SVDTestCase {
        m: usize,
        n: usize,
        matrix: Vec<f32>,
        singular_values: Vec<f32>,
    }

    impl SVDTestCase {
        fn summary(&self) -> String {
            format!("svd test case with dimension {}x{}", self.m, self.n)
        }
    }

    /// A value passes if it is within EITHER the absolute OR the relative tolerance.
    struct SVDTolerance {
        absolute: f32,
        relative: f32,
    }

    impl SVDTolerance {
        fn check(&self, absolute: f32, relative: f32) -> bool {
            absolute <= self.absolute || relative <= self.relative
        }
    }

    /// Returns `(|got - expected|, |got - expected| / |expected|)`.
    ///
    /// The relative error must divide by `|expected|`. Dividing by a negative (or `-0.0`)
    /// `expected` yields a non-positive relative error, which passes any
    /// `relative <= tolerance` check and would silently disable the comparison for
    /// negative entries.
    fn abs_and_relative_error(got: f32, expected: f32) -> (f32, f32) {
        let diff = (got - expected).abs();
        (diff, diff / expected.abs())
    }

    /// Expands `min(m, n)` singular values into a row-major `m x n` matrix with those
    /// values on the diagonal.
    fn materialize_singular_values(singular_values: &[f32], m: usize, n: usize) -> Vec<f32> {
        assert_eq!(singular_values.len(), m.min(n));
        let mut output = vec![0.0; m * n];
        for (i, &s) in singular_values.iter().enumerate() {
            output[n * i + i] = s;
        }
        output
    }

    /// Runs `svd_into` on `case`, compares the singular values against the reference, and
    /// checks that `u * diag(s) * vt` reconstructs the input matrix.
    fn test_svd(
        case: &SVDTestCase,
        singular_value_tolerance: &SVDTolerance,
        reconstructed_tolerance: &SVDTolerance,
        context: &dyn std::fmt::Display,
    ) {
        let mut singular_values = vec![0.0; case.m.min(case.n)];
        let mut u = vec![0.0; case.m * case.m];
        let mut vt = vec![0.0; case.n * case.n];
        // `svd_into` takes `a` mutably, so hand it a copy and keep the original for
        // the reconstruction comparison below.
        svd_into(
            case.m,
            case.n,
            &mut case.matrix.clone(),
            &mut singular_values,
            &mut u,
            &mut vt,
        )
        .unwrap();

        for (i, (&got, &expected)) in
            std::iter::zip(singular_values.iter(), case.singular_values.iter()).enumerate()
        {
            let (diff, relative) = abs_and_relative_error(got, expected);
            assert!(
                singular_value_tolerance.check(diff, relative),
                "got {} but expected {} (diff: {}, relative: {}) at position {}: {}",
                got,
                expected,
                diff,
                relative,
                i,
                context
            );
        }

        // temp = U (m x m) * S (m x n)
        let full_singular_values = materialize_singular_values(&singular_values, case.m, case.n);
        let mut temp = vec![0.0; case.m * case.n];
        sgemm(
            Transpose::None,
            Transpose::None,
            case.m,
            case.n,
            case.m,
            1.0,
            &u,
            &full_singular_values,
            None,
            &mut temp,
        )
        .unwrap();

        // output = temp (m x n) * Vt (n x n)
        let mut output = vec![0.0; case.m * case.n];
        sgemm(
            Transpose::None,
            Transpose::None,
            case.m,
            case.n,
            case.n,
            1.0,
            &temp,
            &vt,
            None,
            &mut output,
        )
        .unwrap();

        for row in 0..case.m {
            for col in 0..case.n {
                let got = output[case.n * row + col];
                let expected = case.matrix[case.n * row + col];
                let (diff, relative) = abs_and_relative_error(got, expected);
                assert!(
                    reconstructed_tolerance.check(diff, relative),
                    "mismatch in reconstructed matrix at (row, col) = ({}, {}). \
                     Got {}, expected {} (diff: {}, relative: {}). {}",
                    row,
                    col,
                    got,
                    expected,
                    diff,
                    relative,
                    context
                );
            }
        }
    }

    #[test]
    fn test_svd_implementation() {
        let path = test_file_path(SVD_INPUT_FILE);
        let file =
            std::fs::File::open(&path).unwrap_or_else(|_| panic!("failed to open file {path}"));
        let reader = std::io::BufReader::new(file);
        let cases: Vec<SVDTestCase> = serde_json::from_reader(reader).unwrap();
        let singular_values_tolerance = SVDTolerance {
            absolute: 2.0e-6,
            relative: 3.0e-6,
        };
        // A relative tolerance of 0.0 means only the absolute bound can accept a mismatch.
        let reconstructed_tolerance = SVDTolerance {
            absolute: 5.0e-5,
            relative: 0.0,
        };
        for (i, case) in cases.iter().enumerate() {
            let context = format!(
                "while processing case {} of {}: {}",
                i + 1,
                cases.len(),
                case.summary()
            );
            test_svd(
                case,
                &singular_values_tolerance,
                &reconstructed_tolerance,
                &context,
            );
        }
    }

    const EPSILON: f32 = 1e-5;

    /// Checks that a sampled matrix `Q` satisfies `Q * Q^T ≈ I`, preserves the squared
    /// L2 norm of random Gaussian vectors, and is not the identity on them.
    fn test_distance_preserving_matrix_impl(dim: usize, rng: &mut StdRng) {
        let q = random_distance_preserving_matrix(dim, rng);
        let qm = ::faer::mat::MatRef::from_row_major_slice(&q, dim, dim);
        let m = qm * qm.transpose();
        for j in 0..dim {
            for i in 0..dim {
                if i == j {
                    assert_abs_diff_eq!(m[(i, j)], 1.0, epsilon = EPSILON);
                } else {
                    assert_abs_diff_eq!(m[(i, j)], 0.0, epsilon = EPSILON);
                }
            }
        }

        const RANDOM_TRIALS: usize = 100;
        let mut v = vec![0.0f32; dim];
        for _ in 0..RANDOM_TRIALS {
            v.iter_mut()
                .for_each(|i| *i = StandardNormal {}.sample(rng));
            let vm = ::faer::mat::MatRef::from_row_major_slice(&v, dim, 1);
            let v_norm = vm.squared_norm_l2();
            let t = qm * vm;
            let t_norm = t.squared_norm_l2();
            assert_relative_eq!(v_norm, t_norm, epsilon = EPSILON, max_relative = EPSILON);
            assert_ne!(vm, t);
        }
    }

    #[test]
    fn test_rotation_matrix() {
        let mut rng = StdRng::seed_from_u64(0xc0ff33);
        let num_trials = 5;
        for dim in [2, 100, 256] {
            for _ in 0..num_trials {
                test_distance_preserving_matrix_impl(dim, &mut rng);
            }
        }
    }
}