/// Storage-order selector passed to the CBLAS routines as a raw `i32`.
///
/// The explicit discriminants are the integer codes this module hands to the
/// `order` argument of the `cblas_*` calls via `as i32`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum CblasOrder {
    /// Row-major storage (code 101).
    RowMajor = 101,
    /// Column-major storage (code 102).
    ColMajor = 102,
}
/// Transpose selector passed to the CBLAS routines as a raw `i32`.
///
/// The explicit discriminants are the integer codes this module hands to the
/// `trans*` arguments of the `cblas_*` calls via `as i32`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum CblasTranspose {
    /// Use the matrix as stored (code 111).
    NoTrans = 111,
    /// Use the transpose of the matrix (code 112).
    Trans = 112,
    /// Use the conjugate transpose of the matrix (code 113).
    ConjTrans = 113,
}
/// Memory layout of a dense matrix handed to the wrappers in this module.
///
/// Defaults to [`MatrixLayout::RowMajor`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MatrixLayout {
    /// Elements of one row are stored next to each other.
    #[default]
    RowMajor,
    /// Elements of one column are stored next to each other.
    ColMajor,
}
/// Maps the crate-level layout onto the matching CBLAS order code.
impl From<MatrixLayout> for CblasOrder {
    fn from(value: MatrixLayout) -> Self {
        // Exhaustive on purpose: a new layout variant must pick an order here.
        match value {
            MatrixLayout::ColMajor => Self::ColMajor,
            MatrixLayout::RowMajor => Self::RowMajor,
        }
    }
}
// Raw single-precision CBLAS bindings, linked from the Accelerate framework.
// Only compiled on macOS with the `accelerate` feature enabled.
// Every dimension, stride and order/transpose selector is passed as `i32`;
// callers in this file pass `CblasOrder` / `CblasTranspose` values cast with `as i32`.
// The operation summaries below follow the CBLAS interface, which is not part
// of this file; the tests at the bottom of the file exercise each routine.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
#[link(name = "Accelerate", kind = "framework")]
extern "C" {
/// Matrix-vector product: `y := alpha * op(A) * x + beta * y` (per CBLAS).
///
/// `m` / `n` are the row / column counts of `A`, `lda` its leading dimension,
/// and `incx` / `incy` the element strides of `x` and `y`.
fn cblas_sgemv(
order: i32,
trans: i32,
m: i32,
n: i32,
alpha: f32,
a: *const f32,
lda: i32,
x: *const f32,
incx: i32,
beta: f32,
y: *mut f32,
incy: i32,
);
/// Matrix-matrix product: `C := alpha * op(A) * op(B) + beta * C` (per CBLAS).
///
/// `lda`, `ldb` and `ldc` are the leading dimensions of `A`, `B` and `C`.
fn cblas_sgemm(
order: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: *const f32,
lda: i32,
b: *const f32,
ldb: i32,
beta: f32,
c: *mut f32,
ldc: i32,
);
/// Dot product of the `n`-element vectors `x` and `y` (per CBLAS).
fn cblas_sdot(n: i32, x: *const f32, incx: i32, y: *const f32, incy: i32) -> f32;
/// In-place scaling: `x := alpha * x` over `n` elements (per CBLAS).
fn cblas_sscal(n: i32, alpha: f32, x: *mut f32, incx: i32);
/// Scaled accumulate: `y := alpha * x + y` over `n` elements (per CBLAS).
fn cblas_saxpy(n: i32, alpha: f32, x: *const f32, incx: i32, y: *mut f32, incy: i32);
}
/// Smallest row and column count for which `should_use_accelerate` reports `true`.
const ACCELERATE_MIN_DIM: usize = 256;
/// Smallest element count (`m * n`) for which `should_use_accelerate` reports `true`.
const ACCELERATE_MIN_OPS: usize = 65_536;
/// Reports whether this build links against the Accelerate framework.
///
/// The answer is fixed at compile time: `true` only when the target OS is
/// macOS and the `accelerate` feature is enabled, `false` otherwise.
#[inline(always)]
pub fn is_accelerate_available() -> bool {
    cfg!(all(target_os = "macos", feature = "accelerate"))
}
/// Decides whether an `m x n` problem is large enough to dispatch to Accelerate.
///
/// Returns `true` only when Accelerate is available in this build, both
/// dimensions are at least `ACCELERATE_MIN_DIM`, and the element count
/// `m * n` is at least `ACCELERATE_MIN_OPS`.
///
/// The element count is computed with saturating arithmetic, so very large
/// dimensions cannot overflow (which would panic in debug builds and wrap to a
/// small value in release builds).
#[inline(always)]
pub fn should_use_accelerate(m: usize, n: usize) -> bool {
    is_accelerate_available()
        && m >= ACCELERATE_MIN_DIM
        && n >= ACCELERATE_MIN_DIM
        && m.saturating_mul(n) >= ACCELERATE_MIN_OPS
}
/// Computes the matrix-vector product `y = A * x` through Accelerate's `cblas_sgemv`.
///
/// * `a` - the `m x n` matrix, stored contiguously according to `layout`
///   (`m * n` elements).
/// * `x` - input vector of length `n`.
/// * `y` - output vector of length `m`; its previous contents are overwritten.
/// * `m`, `n` - row and column counts of `a`.
/// * `layout` - whether `a` is row-major or column-major.
///
/// # Panics
///
/// Panics if `m * n` overflows `usize`, if any slice length does not match the
/// given dimensions, or if `m` or `n` exceeds `i32::MAX`. These checks run in
/// every build profile: the BLAS routine trusts the dimensions it is given, so
/// a mismatch would otherwise read or write outside the slices.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
pub fn gemv_accelerate(
    a: &[f32],
    x: &[f32],
    y: &mut [f32],
    m: usize,
    n: usize,
    layout: MatrixLayout,
) {
    let expected_a = m
        .checked_mul(n)
        .expect("Matrix size m * n overflows usize");
    assert_eq!(
        a.len(),
        expected_a,
        "Matrix A size mismatch: expected {}, got {}",
        expected_a,
        a.len()
    );
    assert_eq!(
        x.len(),
        n,
        "Vector x size mismatch: expected {}, got {}",
        n,
        x.len()
    );
    assert_eq!(
        y.len(),
        m,
        "Vector y size mismatch: expected {}, got {}",
        m,
        y.len()
    );
    assert!(
        m <= i32::MAX as usize,
        "Matrix dimension m={} exceeds i32::MAX for BLAS",
        m
    );
    assert!(
        n <= i32::MAX as usize,
        "Matrix dimension n={} exceeds i32::MAX for BLAS",
        n
    );
    // SAFETY: the slice lengths were verified against `m` and `n` above, and
    // both dimensions fit in `i32`, so the unchecked call stays in bounds.
    unsafe {
        gemv_accelerate_unchecked(a, x, y, m, n, layout);
    }
}
/// Computes `y = A * x` through `cblas_sgemv` without validating any argument.
///
/// The call uses `alpha = 1.0`, `beta = 0.0`, no transpose and unit strides.
/// The leading dimension is `n` for row-major input and `m` for column-major
/// input.
///
/// # Safety
///
/// Nothing is checked here. The dimensions are cast to `i32` with `as`, and
/// the raw slice pointers are handed straight to BLAS. The caller must ensure
/// that `a` holds `m * n` elements, `x` holds `n`, `y` holds `m`, and that `m`
/// and `n` both fit in `i32`.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
#[inline(always)]
pub unsafe fn gemv_accelerate_unchecked(
    a: &[f32],
    x: &[f32],
    y: &mut [f32],
    m: usize,
    n: usize,
    layout: MatrixLayout,
) {
    let (rows, cols) = (m as i32, n as i32);
    // The leading dimension is the length of one contiguous row or column.
    let leading_dim = match layout {
        MatrixLayout::ColMajor => rows,
        MatrixLayout::RowMajor => cols,
    };
    cblas_sgemv(
        CblasOrder::from(layout) as i32,
        CblasTranspose::NoTrans as i32,
        rows,
        cols,
        1.0,
        a.as_ptr(),
        leading_dim,
        x.as_ptr(),
        1,
        0.0,
        y.as_mut_ptr(),
        1,
    );
}
/// Computes the transposed matrix-vector product `y = A^T * x` through
/// Accelerate's `cblas_sgemv`.
///
/// * `a` - the `m x n` matrix, stored contiguously according to `layout`
///   (`m * n` elements).
/// * `x` - input vector of length `m`.
/// * `y` - output vector of length `n`; its previous contents are overwritten.
/// * `m`, `n` - row and column counts of `a` (before transposition).
/// * `layout` - whether `a` is row-major or column-major.
///
/// # Panics
///
/// Panics if `m * n` overflows `usize`, if any slice length does not match the
/// given dimensions, or if `m` or `n` exceeds `i32::MAX`. These checks run in
/// every build profile because the BLAS routine trusts the dimensions it is
/// given.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
pub fn gemv_transpose_accelerate(
    a: &[f32],
    x: &[f32],
    y: &mut [f32],
    m: usize,
    n: usize,
    layout: MatrixLayout,
) {
    let expected_a = m
        .checked_mul(n)
        .expect("Matrix size m * n overflows usize");
    assert_eq!(
        a.len(),
        expected_a,
        "Matrix A size mismatch: expected {}, got {}",
        expected_a,
        a.len()
    );
    // With the transpose applied, `x` has one entry per row of `a` and `y` one
    // per column.
    assert_eq!(
        x.len(),
        m,
        "Vector x size mismatch: expected {}, got {}",
        m,
        x.len()
    );
    assert_eq!(
        y.len(),
        n,
        "Vector y size mismatch: expected {}, got {}",
        n,
        y.len()
    );
    assert!(
        m <= i32::MAX as usize,
        "Matrix dimension m={} exceeds i32::MAX for BLAS",
        m
    );
    assert!(
        n <= i32::MAX as usize,
        "Matrix dimension n={} exceeds i32::MAX for BLAS",
        n
    );
    let order = CblasOrder::from(layout) as i32;
    let trans = CblasTranspose::Trans as i32;
    // The leading dimension describes the stored (untransposed) matrix.
    let lda = match layout {
        MatrixLayout::RowMajor => n as i32,
        MatrixLayout::ColMajor => m as i32,
    };
    // SAFETY: the slice lengths were verified against `m` and `n` above and
    // both dimensions fit in `i32`, so BLAS only touches memory inside the
    // slices.
    unsafe {
        cblas_sgemv(
            order,
            trans,
            m as i32,
            n as i32,
            1.0,
            a.as_ptr(),
            lda,
            x.as_ptr(),
            1,
            0.0,
            y.as_mut_ptr(),
            1,
        );
    }
}
/// Computes the scaled update `y = alpha * A * x + beta * y` through
/// Accelerate's `cblas_sgemv`.
///
/// * `a` - the `m x n` matrix, stored contiguously according to `layout`
///   (`m * n` elements).
/// * `x` - input vector of length `n`.
/// * `y` - vector of length `m`, used as input (scaled by `beta`) and output.
/// * `m`, `n` - row and column counts of `a`.
/// * `alpha`, `beta` - scalar factors for the product and for the existing `y`.
/// * `layout` - whether `a` is row-major or column-major.
///
/// # Panics
///
/// Panics if `m * n` overflows `usize`, if any slice length does not match the
/// given dimensions, or if `m` or `n` exceeds `i32::MAX`. These checks run in
/// every build profile because the BLAS routine trusts the dimensions it is
/// given.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
pub fn gemv_scaled_accelerate(
    a: &[f32],
    x: &[f32],
    y: &mut [f32],
    m: usize,
    n: usize,
    alpha: f32,
    beta: f32,
    layout: MatrixLayout,
) {
    let expected_a = m
        .checked_mul(n)
        .expect("Matrix size m * n overflows usize");
    assert_eq!(
        a.len(),
        expected_a,
        "Matrix A size mismatch: expected {}, got {}",
        expected_a,
        a.len()
    );
    assert_eq!(
        x.len(),
        n,
        "Vector x size mismatch: expected {}, got {}",
        n,
        x.len()
    );
    assert_eq!(
        y.len(),
        m,
        "Vector y size mismatch: expected {}, got {}",
        m,
        y.len()
    );
    assert!(
        m <= i32::MAX as usize,
        "Matrix dimension m={} exceeds i32::MAX for BLAS",
        m
    );
    assert!(
        n <= i32::MAX as usize,
        "Matrix dimension n={} exceeds i32::MAX for BLAS",
        n
    );
    let order = CblasOrder::from(layout) as i32;
    let trans = CblasTranspose::NoTrans as i32;
    // The leading dimension is the length of one contiguous row or column.
    let lda = match layout {
        MatrixLayout::RowMajor => n as i32,
        MatrixLayout::ColMajor => m as i32,
    };
    // SAFETY: the slice lengths were verified against `m` and `n` above and
    // both dimensions fit in `i32`, so BLAS only touches memory inside the
    // slices.
    unsafe {
        cblas_sgemv(
            order,
            trans,
            m as i32,
            n as i32,
            alpha,
            a.as_ptr(),
            lda,
            x.as_ptr(),
            1,
            beta,
            y.as_mut_ptr(),
            1,
        );
    }
}
/// Computes the matrix product `C = A * B` through Accelerate's `cblas_sgemm`.
///
/// All three matrices are treated as row-major and untransposed.
///
/// * `a` - the `m x k` left operand (`m * k` elements).
/// * `b` - the `k x n` right operand (`k * n` elements).
/// * `c` - the `m x n` result (`m * n` elements); previous contents are
///   overwritten.
///
/// # Panics
///
/// Panics if any of the element counts overflows `usize`, if a slice length
/// does not match the given dimensions, or if `m`, `k` or `n` exceeds
/// `i32::MAX`. These checks run in every build profile because the BLAS
/// routine trusts the dimensions it is given.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
pub fn gemm_accelerate(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    let expected_a = m
        .checked_mul(k)
        .expect("Matrix size m * k overflows usize");
    let expected_b = k
        .checked_mul(n)
        .expect("Matrix size k * n overflows usize");
    let expected_c = m
        .checked_mul(n)
        .expect("Matrix size m * n overflows usize");
    assert_eq!(
        a.len(),
        expected_a,
        "Matrix A size mismatch: expected {}, got {}",
        expected_a,
        a.len()
    );
    assert_eq!(
        b.len(),
        expected_b,
        "Matrix B size mismatch: expected {}, got {}",
        expected_b,
        b.len()
    );
    assert_eq!(
        c.len(),
        expected_c,
        "Matrix C size mismatch: expected {}, got {}",
        expected_c,
        c.len()
    );
    assert!(
        m <= i32::MAX as usize,
        "Matrix dimension m={} exceeds i32::MAX for BLAS",
        m
    );
    assert!(
        k <= i32::MAX as usize,
        "Matrix dimension k={} exceeds i32::MAX for BLAS",
        k
    );
    assert!(
        n <= i32::MAX as usize,
        "Matrix dimension n={} exceeds i32::MAX for BLAS",
        n
    );
    // SAFETY: the slice lengths were verified against `m`, `k` and `n` above
    // and all three dimensions fit in `i32`, so BLAS only touches memory
    // inside the slices.
    unsafe {
        cblas_sgemm(
            CblasOrder::RowMajor as i32,
            CblasTranspose::NoTrans as i32,
            CblasTranspose::NoTrans as i32,
            m as i32,
            n as i32,
            k as i32,
            1.0,
            a.as_ptr(),
            // Row-major leading dimensions are the column counts of A, B and C.
            k as i32,
            b.as_ptr(),
            n as i32,
            0.0,
            c.as_mut_ptr(),
            n as i32,
        );
    }
}
/// Returns the dot product of `x` and `y` through Accelerate's `cblas_sdot`.
///
/// # Panics
///
/// Panics if the slices differ in length or if the length exceeds `i32::MAX`.
/// These checks run in every build profile: BLAS reads `x.len()` elements from
/// both pointers, so a shorter `y` would be read out of bounds, and a length
/// above `i32::MAX` would be truncated by the `as i32` cast.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
#[inline]
pub fn dot_accelerate(x: &[f32], y: &[f32]) -> f32 {
    assert_eq!(
        x.len(),
        y.len(),
        "Vector length mismatch: x has {}, y has {}",
        x.len(),
        y.len()
    );
    assert!(
        x.len() <= i32::MAX as usize,
        "Vector length {} exceeds i32::MAX for BLAS",
        x.len()
    );
    // SAFETY: both slices hold exactly `x.len()` elements and that count fits
    // in `i32`, so BLAS only reads inside the slices.
    unsafe { cblas_sdot(x.len() as i32, x.as_ptr(), 1, y.as_ptr(), 1) }
}
/// Scales `x` in place by `alpha` through Accelerate's `cblas_sscal`.
///
/// # Panics
///
/// Panics if `x.len()` exceeds `i32::MAX`. Without this check the `as i32`
/// cast would truncate the length and only part of the slice (or none of it)
/// would be scaled.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
#[inline]
pub fn scal_accelerate(x: &mut [f32], alpha: f32) {
    assert!(
        x.len() <= i32::MAX as usize,
        "Vector length {} exceeds i32::MAX for BLAS",
        x.len()
    );
    // SAFETY: the element count passed to BLAS equals the slice length and
    // fits in `i32`, so BLAS only touches memory inside the slice.
    unsafe { cblas_sscal(x.len() as i32, alpha, x.as_mut_ptr(), 1) }
}
/// Performs `y += alpha * x` in place through Accelerate's `cblas_saxpy`.
///
/// # Panics
///
/// Panics if the slices differ in length or if the length exceeds `i32::MAX`.
/// These checks run in every build profile: BLAS processes `x.len()` elements
/// of both vectors, so a shorter `y` would be written out of bounds, and a
/// length above `i32::MAX` would be truncated by the `as i32` cast.
#[cfg(all(target_os = "macos", feature = "accelerate"))]
#[inline]
pub fn axpy_accelerate(x: &[f32], y: &mut [f32], alpha: f32) {
    assert_eq!(
        x.len(),
        y.len(),
        "Vector length mismatch: x has {}, y has {}",
        x.len(),
        y.len()
    );
    assert!(
        x.len() <= i32::MAX as usize,
        "Vector length {} exceeds i32::MAX for BLAS",
        x.len()
    );
    // SAFETY: both slices hold exactly `x.len()` elements and that count fits
    // in `i32`, so BLAS only touches memory inside the slices.
    unsafe { cblas_saxpy(x.len() as i32, alpha, x.as_ptr(), 1, y.as_mut_ptr(), 1) }
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but all arguments are ignored.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn gemv_accelerate(
    _a: &[f32],
    _x: &[f32],
    _y: &mut [f32],
    _m: usize,
    _n: usize,
    _layout: MatrixLayout,
) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but all arguments are ignored.
///
/// # Safety
///
/// This fallback touches no memory; it is `unsafe` only to match the
/// signature of the Accelerate-backed function.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub unsafe fn gemv_accelerate_unchecked(
    _a: &[f32],
    _x: &[f32],
    _y: &mut [f32],
    _m: usize,
    _n: usize,
    _layout: MatrixLayout,
) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but all arguments are ignored.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn gemv_transpose_accelerate(
    _a: &[f32],
    _x: &[f32],
    _y: &mut [f32],
    _m: usize,
    _n: usize,
    _layout: MatrixLayout,
) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but all arguments are ignored.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn gemv_scaled_accelerate(
    _a: &[f32],
    _x: &[f32],
    _y: &mut [f32],
    _m: usize,
    _n: usize,
    _alpha: f32,
    _beta: f32,
    _layout: MatrixLayout,
) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but all arguments are ignored.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn gemm_accelerate(
    _a: &[f32],
    _b: &[f32],
    _c: &mut [f32],
    _m: usize,
    _k: usize,
    _n: usize,
) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but both arguments are ignored and no value is ever
/// returned.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn dot_accelerate(_x: &[f32], _y: &[f32]) -> f32 {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but both arguments are ignored.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn scal_accelerate(_x: &mut [f32], _alpha: f32) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Fallback for builds without Accelerate (any target other than macOS, or the
/// `accelerate` feature disabled).
///
/// Keeps the same signature as the Accelerate-backed function so callers
/// compile everywhere, but all arguments are ignored.
///
/// # Panics
///
/// Always panics.
#[cfg(not(all(target_os = "macos", feature = "accelerate")))]
pub fn axpy_accelerate(_x: &[f32], _y: &mut [f32], _alpha: f32) {
    panic!("Accelerate framework is only available on macOS with 'accelerate' feature enabled");
}
/// Unit tests for the availability/threshold helpers and, on macOS builds with
/// the `accelerate` feature, for the numerical results of each BLAS wrapper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Availability must mirror the compile-time configuration.
    #[test]
    fn test_accelerate_availability() {
        #[cfg(all(target_os = "macos", feature = "accelerate"))]
        assert!(is_accelerate_available());
        #[cfg(not(all(target_os = "macos", feature = "accelerate")))]
        assert!(!is_accelerate_available());
    }

    /// Below-threshold sizes are always rejected; at-threshold sizes are
    /// accepted only when Accelerate is compiled in.
    #[test]
    fn test_should_use_accelerate_thresholds() {
        assert!(!should_use_accelerate(128, 128));
        assert!(!should_use_accelerate(255, 256));
        // The column threshold must be enforced as well as the row threshold.
        assert!(!should_use_accelerate(256, 255));
        #[cfg(all(target_os = "macos", feature = "accelerate"))]
        {
            assert!(should_use_accelerate(256, 256));
            assert!(should_use_accelerate(4096, 4096));
        }
        #[cfg(not(all(target_os = "macos", feature = "accelerate")))]
        {
            assert!(!should_use_accelerate(256, 256));
            assert!(!should_use_accelerate(4096, 4096));
        }
    }

    /// Row-major 2x3 matrix times a vector of ones yields the row sums.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_gemv_accelerate_correctness() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = vec![1.0, 1.0, 1.0];
        let mut y = vec![0.0, 0.0];
        gemv_accelerate(&a, &x, &mut y, 2, 3, MatrixLayout::RowMajor);
        assert!((y[0] - 6.0).abs() < 1e-5);
        assert!((y[1] - 15.0).abs() < 1e-5);
    }

    /// Transposed product with a vector of ones yields the column sums.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_gemv_transpose_correctness() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = vec![1.0, 1.0];
        let mut y = vec![0.0, 0.0, 0.0];
        gemv_transpose_accelerate(&a, &x, &mut y, 2, 3, MatrixLayout::RowMajor);
        assert!((y[0] - 5.0).abs() < 1e-5);
        assert!((y[1] - 7.0).abs() < 1e-5);
        assert!((y[2] - 9.0).abs() < 1e-5);
    }

    /// `alpha = 2`, `beta = 3`: 2*6 + 3*1 = 15 and 2*15 + 3*2 = 36.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_gemv_scaled_correctness() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = vec![1.0, 1.0, 1.0];
        let mut y = vec![1.0, 2.0];
        gemv_scaled_accelerate(&a, &x, &mut y, 2, 3, 2.0, 3.0, MatrixLayout::RowMajor);
        assert!((y[0] - 15.0).abs() < 1e-5);
        assert!((y[1] - 36.0).abs() < 1e-5);
    }

    /// 2x2 row-major product against hand-computed values.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_gemm_accelerate_correctness() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 4];
        gemm_accelerate(&a, &b, &mut c, 2, 2, 2);
        assert!((c[0] - 19.0).abs() < 1e-5);
        assert!((c[1] - 22.0).abs() < 1e-5);
        assert!((c[2] - 43.0).abs() < 1e-5);
        assert!((c[3] - 50.0).abs() < 1e-5);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_dot_accelerate_correctness() {
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = dot_accelerate(&x, &y);
        assert!((result - 32.0).abs() < 1e-5);
    }

    /// Every element is multiplied by `alpha`.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_scal_accelerate_correctness() {
        let mut x = vec![1.0, 2.0, 3.0];
        scal_accelerate(&mut x, 2.0);
        assert!((x[0] - 2.0).abs() < 1e-5);
        assert!((x[1] - 4.0).abs() < 1e-5);
        assert!((x[2] - 6.0).abs() < 1e-5);
    }

    /// `y += 2 * x` element by element.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_axpy_accelerate_correctness() {
        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![4.0, 5.0, 6.0];
        axpy_accelerate(&x, &mut y, 2.0);
        assert!((y[0] - 6.0).abs() < 1e-5);
        assert!((y[1] - 9.0).abs() < 1e-5);
        assert!((y[2] - 12.0).abs() < 1e-5);
    }

    /// Checks every output row of a 512x512 product against a straightforward
    /// reference computed in `f64`, rather than only checking for nonzero output.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_gemv_large_matrix() {
        let m = 512;
        let n = 512;
        let a: Vec<f32> = (0..m * n).map(|i| (i % 10) as f32 * 0.1).collect();
        let x: Vec<f32> = vec![1.0; n];
        let mut y = vec![0.0; m];
        gemv_accelerate(&a, &x, &mut y, m, n, MatrixLayout::RowMajor);
        for (row_idx, (row, &got)) in a.chunks_exact(n).zip(y.iter()).enumerate() {
            let expected: f64 = row
                .iter()
                .zip(x.iter())
                .map(|(&p, &q)| f64::from(p) * f64::from(q))
                .sum();
            // Relative tolerance: BLAS accumulates 512 terms in f32, possibly
            // in a different order than the reference.
            let tolerance = 1e-3 * expected.abs().max(1.0);
            assert!(
                (f64::from(got) - expected).abs() <= tolerance,
                "row {}: expected {}, got {}",
                row_idx,
                expected,
                got
            );
        }
    }

    /// Same 2x3 matrix as the row-major test, stored column by column.
    #[cfg(all(target_os = "macos", feature = "accelerate"))]
    #[test]
    fn test_col_major_layout() {
        let a = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = vec![1.0, 1.0, 1.0];
        let mut y = vec![0.0, 0.0];
        gemv_accelerate(&a, &x, &mut y, 2, 3, MatrixLayout::ColMajor);
        assert!((y[0] - 6.0).abs() < 1e-5);
        assert!((y[1] - 15.0).abs() < 1e-5);
    }
}