extern crate intel_mkl_src;
use super::common::Transpose;
use thiserror::Error;
/// Single-precision general matrix multiply on dense row-major buffers, backed
/// by Intel MKL's CBLAS `sgemm`.
///
/// Computes `c = alpha * op(a) * op(b) + beta * c`, where `op(x)` is `x` or its
/// transpose according to `atranspose` / `btranspose`. `op(a)` is `m x k`,
/// `op(b)` is `k x n` and `c` is `m x n`. All buffers are stored densely in
/// row-major order, so a transposed `a` is physically `k x m` and a transposed
/// `b` is physically `n x k`.
///
/// `beta == None` is treated as `beta = 0.0`, i.e. `c` is overwritten. Per the
/// BLAS contract, its previous contents are then not read.
///
/// Degenerate shapes are supported: with `k == 0` the product term vanishes
/// and `c` is simply scaled by `beta`.
///
/// # Panics
///
/// Panics if a buffer length does not match its dimensions, if an element
/// count overflows `usize`, or if `m`, `n` or `k` does not fit in `i32`.
// The parameter list deliberately mirrors the BLAS `sgemm` argument list.
#[allow(clippy::too_many_arguments)]
pub(super) fn sgemm_impl(
    atranspose: Transpose,
    btranspose: Transpose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &[f32],
    b: &[f32],
    beta: Option<f32>,
    c: &mut [f32],
) {
    // These checks are what make the unsafe FFI call below sound, so they run
    // in release builds too. `checked_mul` prevents a wrapped product from
    // accidentally matching a too-short buffer.
    let len = |rows: usize, cols: usize| {
        rows.checked_mul(cols)
            .expect("sgemm: matrix element count overflows usize")
    };
    assert_eq!(a.len(), len(m, k), "sgemm: `a` must hold m * k elements");
    assert_eq!(b.len(), len(k, n), "sgemm: `b` must hold k * n elements");
    assert_eq!(c.len(), len(m, n), "sgemm: `c` must hold m * n elements");

    let m: i32 = m.try_into().expect("sgemm: m does not fit in i32");
    let n: i32 = n.try_into().expect("sgemm: n does not fit in i32");
    let k: i32 = k.try_into().expect("sgemm: k does not fit in i32");

    // `Transpose::forward(x, y)` is used to pick `x` for a non-transposed
    // operand and `y` for a transposed one.
    //
    // The row-major leading dimension is the physical row length. CBLAS
    // rejects any leading dimension below 1, even for empty matrices, so each
    // one is clamped to at least 1. Otherwise e.g. `k == 0` would make MKL
    // report a parameter error and leave `c` unscaled.
    let lda = atranspose.forward(k.max(1), m.max(1));
    let ldb = btranspose.forward(n.max(1), k.max(1));
    let ldc = n.max(1);

    // SAFETY: the length assertions above guarantee that `a`, `b` and `c`
    // contain exactly the elements sgemm addresses through `m`, `n`, `k` and
    // the leading dimensions passed here, so MKL never reads or writes out of
    // bounds. `c` is uniquely borrowed for the duration of the call.
    unsafe {
        cblas::sgemm(
            cblas::Layout::RowMajor,
            atranspose.forward(cblas::Transpose::None, cblas::Transpose::Ordinary),
            btranspose.forward(cblas::Transpose::None, cblas::Transpose::Ordinary),
            m,
            n,
            k,
            alpha,
            a,
            lda,
            b,
            ldb,
            beta.unwrap_or(0.0),
            c,
            ldc,
        )
    }
}
#[derive(Debug, Error)]
#[error("lapacke::sgessd failed with return code {error_code}")]
struct SVDError {
error_code: i32,
}
/// Computes the full singular value decomposition `a = u * diag(s) * vt` of a
/// dense row-major `m x n` matrix using LAPACKE's divide-and-conquer `sgesdd`
/// (`jobz = 'A'`).
///
/// Results are written into caller-provided row-major buffers:
/// - `singular_values` receives the `min(m, n)` singular values (LAPACK
///   returns them in descending order);
/// - `u` receives the `m x m` matrix of left singular vectors;
/// - `vt` receives the `n x n` matrix of transposed right singular vectors.
///
/// `a` is used as scratch by LAPACK: its contents are destroyed on return.
///
/// # Errors
///
/// Returns an error carrying the raw LAPACKE return code if `sgesdd` reports
/// a non-zero status (invalid argument, allocation failure, or
/// non-convergence).
///
/// # Panics
///
/// Panics if any buffer length does not match the dimensions, if an element
/// count overflows `usize`, or if `m` or `n` does not fit in `i32`.
pub(super) fn svd_into_impl(
    m: usize,
    n: usize,
    a: &mut [f32],
    singular_values: &mut [f32],
    u: &mut [f32],
    vt: &mut [f32],
) -> Result<(), impl std::error::Error + 'static> {
    // LAPACK addresses these buffers purely through `m`, `n` and the leading
    // dimensions, so a short buffer would mean out-of-bounds writes in the
    // unsafe call below. The checks must therefore hold in release builds
    // too, not only under `debug_assertions`.
    let len = |rows: usize, cols: usize| {
        rows.checked_mul(cols)
            .expect("sgesdd: matrix element count overflows usize")
    };
    assert_eq!(a.len(), len(m, n), "sgesdd: `a` must hold m * n elements");
    assert_eq!(
        singular_values.len(),
        m.min(n),
        "sgesdd: `singular_values` must hold min(m, n) elements"
    );
    assert_eq!(u.len(), len(m, m), "sgesdd: `u` must hold m * m elements");
    assert_eq!(vt.len(), len(n, n), "sgesdd: `vt` must hold n * n elements");

    let m: i32 = m.try_into().expect("sgesdd: m does not fit in i32");
    let n: i32 = n.try_into().expect("sgesdd: n does not fit in i32");

    // SAFETY: the length assertions above guarantee that every buffer holds
    // exactly the elements implied by `m`, `n` and the row-major leading
    // dimensions passed here (n for `a`, m for `u`, n for `vt`). All output
    // buffers are uniquely borrowed for the duration of the call.
    let error_code = unsafe {
        lapacke::sgesdd(
            lapacke::Layout::RowMajor,
            b'A',
            m,
            n,
            a,
            n,
            singular_values,
            u,
            m,
            vt,
            n,
        )
    };
    match error_code {
        0 => Ok(()),
        error_code => Err(SVDError { error_code }),
    }
}