// Single-precision CBLAS routines linked from the system BLAS when the `blas`
// feature is enabled. Signatures mirror the C `cblas.h` interface: the
// enum-like arguments (`order`, `trans*`) are plain `i32`s carrying the
// constants defined in this file (`ROW_MAJOR`, `NO_TRANS`, `TRANS`, ...).
// Without the feature, pure-Rust fallbacks with the same names are used.
#[cfg(feature = "blas")]
unsafe extern "C" {
    /// `C := alpha * op(A) * op(B) + beta * C`.
    fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
    /// `y := alpha * op(A) * x + beta * y`.
    fn cblas_sgemv(
        order: i32,
        trans: i32,
        m: i32,
        n: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        x: *const f32,
        incx: i32,
        beta: f32,
        y: *mut f32,
        incy: i32,
    );
    /// Rank-1 update `A := alpha * x * yᵀ + A`.
    fn cblas_sger(
        order: i32,
        m: i32,
        n: i32,
        alpha: f32,
        x: *const f32,
        incx: i32,
        y: *const f32,
        incy: i32,
        a: *mut f32,
        lda: i32,
    );
    /// `x := alpha * x`.
    fn cblas_sscal(n: i32, alpha: f32, x: *mut f32, incx: i32);
}
/// Portable reference fallback for `cblas_sgemm` (row-major storage only).
///
/// Computes `C := alpha * op(A) * op(B) + beta * C`, where `op(X)` is `X` when
/// the corresponding `trans*` argument equals `NO_TRANS` and `Xᵀ` otherwise.
/// `op(A)` is `m×k`, `op(B)` is `k×n` and `C` is `m×n`; `lda`/`ldb`/`ldc` are
/// the row strides of the *stored* matrices. `_order` is ignored.
///
/// Like reference BLAS, `C` is not read when `beta == 0`, and `A`/`B` are not
/// read when `alpha == 0`. Stale NaN/Inf values in an output buffer therefore
/// cannot leak into the result.
///
/// # Safety
/// `a`, `b` and `c` must point to matrices large enough for the given
/// dimensions and leading dimensions, and `c` must not overlap `a` or `b`.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
#[inline]
unsafe fn cblas_sgemm(
    _order: i32,
    transa: i32,
    transb: i32,
    m: i32,
    n: i32,
    k: i32,
    alpha: f32,
    a: *const f32,
    lda: i32,
    b: *const f32,
    ldb: i32,
    beta: f32,
    c: *mut f32,
    ldc: i32,
) {
    let m = m as usize;
    let n = n as usize;
    let k = k as usize;
    let lda = lda as usize;
    let ldb = ldb as usize;
    let ldc = ldc as usize;
    let trans_a = transa != NO_TRANS;
    let trans_b = transb != NO_TRANS;
    for i in 0..m {
        for j in 0..n {
            let mut acc: f32 = 0.0;
            // BLAS: when alpha == 0, A and B are not referenced.
            if alpha != 0.0 {
                for p in 0..k {
                    // SAFETY: the caller guarantees (i, p) / (p, j) lie inside
                    // the stored A / B described by lda / ldb.
                    let av = unsafe {
                        if trans_a {
                            *a.add(p * lda + i)
                        } else {
                            *a.add(i * lda + p)
                        }
                    };
                    let bv = unsafe {
                        if trans_b {
                            *b.add(j * ldb + p)
                        } else {
                            *b.add(p * ldb + j)
                        }
                    };
                    acc += av * bv;
                }
            }
            // SAFETY: (i, j) lies inside the m×n output described by ldc.
            unsafe {
                let cp = c.add(i * ldc + j);
                // BLAS: when beta == 0, C is write-only (old contents ignored).
                *cp = if beta == 0.0 {
                    alpha * acc
                } else {
                    alpha * acc + beta * *cp
                };
            }
        }
    }
}
/// Portable reference fallback for `cblas_sgemv` (row-major storage only).
///
/// `A` is an `m×n` matrix with row stride `lda`. With `trans == NO_TRANS` this
/// computes `y := alpha * A * x + beta * y` (`x` has length `n`, `y` length
/// `m`). Otherwise it computes `y := alpha * Aᵀ * x + beta * y` (`x` has
/// length `m`, `y` length `n`), matching the BLAS definition.
///
/// Like reference BLAS, `y` is not read when `beta == 0`, and `A`/`x` are not
/// read when `alpha == 0`. `_order` is ignored, and the vector increments
/// `_incx`/`_incy` are assumed to be 1 (all callers in this file pass 1).
///
/// # Safety
/// `a`, `x` and `y` must be valid for the accesses implied by the dimensions
/// above, and `y` must not overlap `a` or `x`.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
#[inline]
unsafe fn cblas_sgemv(
    _order: i32,
    trans: i32,
    m: i32,
    n: i32,
    alpha: f32,
    a: *const f32,
    lda: i32,
    x: *const f32,
    _incx: i32,
    beta: f32,
    y: *mut f32,
    _incy: i32,
) {
    let m = m as usize;
    let n = n as usize;
    let lda = lda as usize;
    let trans_a = trans != NO_TRANS;
    // For Aᵀ the output runs over A's columns and the input over A's rows.
    let (out_len, in_len) = if trans_a { (n, m) } else { (m, n) };
    for i in 0..out_len {
        let mut acc = 0f32;
        // BLAS: when alpha == 0, A and x are not referenced.
        if alpha != 0.0 {
            for j in 0..in_len {
                // SAFETY: for Aᵀ, element (j, i) has j < m rows and i < n
                // columns; otherwise (i, j) has i < m and j < n. Both are inside A.
                let av = unsafe {
                    if trans_a {
                        *a.add(j * lda + i)
                    } else {
                        *a.add(i * lda + j)
                    }
                };
                acc += av * unsafe { *x.add(j) };
            }
        }
        // SAFETY: i < out_len, which is the length of y.
        unsafe {
            let yp = y.add(i);
            // BLAS: when beta == 0, y is write-only (old contents ignored).
            *yp = if beta == 0.0 {
                alpha * acc
            } else {
                alpha * acc + beta * *yp
            };
        }
    }
}
/// Portable fallback for `cblas_sger`: the rank-1 update `A += alpha * x * yᵀ`.
///
/// `A` is `m×n`, stored row-major with row stride `lda`. `x` has `m` entries
/// and `y` has `n`. The storage order and the vector increments are ignored
/// (contiguous vectors are assumed).
///
/// # Safety
/// `x`, `y` and `a` must be valid for the accesses implied by `m`, `n`, `lda`.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
#[inline]
unsafe fn cblas_sger(
    _order: i32,
    m: i32,
    n: i32,
    alpha: f32,
    x: *const f32,
    _incx: i32,
    y: *const f32,
    _incy: i32,
    a: *mut f32,
    lda: i32,
) {
    let (rows, cols, stride) = (m as usize, n as usize, lda as usize);
    for r in 0..rows {
        // Hoist alpha * x[r]; the product order (alpha * x) * y matches BLAS.
        let scaled = alpha * unsafe { x.add(r).read() };
        let row_start = r * stride;
        for col in 0..cols {
            // SAFETY: r < m and col < n, so both reads/writes stay in bounds.
            unsafe {
                let cell = a.add(row_start + col);
                *cell += scaled * y.add(col).read();
            }
        }
    }
}
/// Portable fallback for `cblas_sscal`: scales the first `n` contiguous
/// elements of `x` by `alpha` in place (the increment is ignored).
///
/// # Safety
/// `x` must be valid for reads and writes of `n` consecutive `f32`s.
#[cfg(not(feature = "blas"))]
#[inline]
unsafe fn cblas_sscal(n: i32, alpha: f32, x: *mut f32, _incx: i32) {
    let len = n as usize;
    (0..len).for_each(|idx| {
        // SAFETY: idx < n, which the caller guarantees is in bounds.
        unsafe {
            let slot = x.add(idx);
            *slot *= alpha;
        }
    });
}
// Double-precision CBLAS plus Fortran LAPACK routines, linked when the `blas`
// feature is enabled. The LAPACK entries bind the Fortran symbols (trailing
// underscore via `link_name`) and therefore take every argument by pointer
// and operate on column-major storage; the Rust wrappers transpose as needed.
// NOTE(review): gfortran-built LAPACK also expects hidden trailing length
// arguments for CHARACTER parameters (`uplo`, `jobz`, `jobu`, ...). These
// declarations omit them. That usually works in practice, but it is
// technically an ABI mismatch — confirm for the LAPACK actually linked.
#[cfg(feature = "blas")]
unsafe extern "C" {
    /// `C := alpha * op(A) * op(B) + beta * C` (CBLAS).
    fn cblas_dgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f64,
        a: *const f64,
        lda: i32,
        b: *const f64,
        ldb: i32,
        beta: f64,
        c: *mut f64,
        ldc: i32,
    );
    /// LAPACK `dgesv`: solves `A X = B` via LU with partial pivoting.
    #[link_name = "dgesv_"]
    fn lapack_dgesv(
        n: *const i32,
        nrhs: *const i32,
        a: *mut f64,
        lda: *const i32,
        ipiv: *mut i32,
        b: *mut f64,
        ldb: *const i32,
        info_out: *mut i32,
    );
    /// LAPACK `sgesv`: single-precision counterpart of `dgesv`.
    #[link_name = "sgesv_"]
    fn lapack_sgesv(
        n: *const i32,
        nrhs: *const i32,
        a: *mut f32,
        lda: *const i32,
        ipiv: *mut i32,
        b: *mut f32,
        ldb: *const i32,
        info_out: *mut i32,
    );
    /// LAPACK `dpotrf`: Cholesky factorisation of a symmetric positive-definite matrix.
    #[link_name = "dpotrf_"]
    fn lapack_dpotrf(
        uplo: *const i8,
        n: *const i32,
        a: *mut f64,
        lda: *const i32,
        info_out: *mut i32,
    );
    /// LAPACK `dsyevd`: symmetric eigen-decomposition (divide and conquer).
    #[link_name = "dsyevd_"]
    fn lapack_dsyevd(
        jobz: *const i8,
        uplo: *const i8,
        n: *const i32,
        a: *mut f64,
        lda: *const i32,
        w: *mut f64,
        work: *mut f64,
        lwork: *const i32,
        iwork: *mut i32,
        liwork: *const i32,
        info_out: *mut i32,
    );
    /// LAPACK `dgeqrf`: Householder QR factorisation.
    #[link_name = "dgeqrf_"]
    fn lapack_dgeqrf(
        m: *const i32,
        n: *const i32,
        a: *mut f64,
        lda: *const i32,
        tau: *mut f64,
        work: *mut f64,
        lwork: *const i32,
        info_out: *mut i32,
    );
    /// LAPACK `dorgqr`: forms the explicit Q from `dgeqrf` reflectors.
    #[link_name = "dorgqr_"]
    fn lapack_dorgqr(
        m: *const i32,
        n: *const i32,
        k: *const i32,
        a: *mut f64,
        lda: *const i32,
        tau: *const f64,
        work: *mut f64,
        lwork: *const i32,
        info_out: *mut i32,
    );
    /// LAPACK `dgesvd`: singular value decomposition.
    #[link_name = "dgesvd_"]
    fn lapack_dgesvd(
        jobu: *const i8,
        jobvt: *const i8,
        m: *const i32,
        n: *const i32,
        a: *mut f64,
        lda: *const i32,
        s: *mut f64,
        u: *mut f64,
        ldu: *const i32,
        vt: *mut f64,
        ldvt: *const i32,
        work: *mut f64,
        lwork: *const i32,
        info_out: *mut i32,
    );
    /// CBLAS triangular solve with multiple right-hand sides.
    fn cblas_dtrsm(
        order: i32,
        side: i32,
        uplo: i32,
        transa: i32,
        diag: i32,
        m: i32,
        n: i32,
        alpha: f64,
        a: *const f64,
        lda: i32,
        b: *mut f64,
        ldb: i32,
    );
}
/// Portable reference fallback for `cblas_dgemm` (row-major storage only).
///
/// Computes `C := alpha * op(A) * op(B) + beta * C`, where `op(X)` is `X` when
/// the corresponding `trans*` argument equals `NO_TRANS` and `Xᵀ` otherwise.
/// `op(A)` is `m×k`, `op(B)` is `k×n` and `C` is `m×n`; `lda`/`ldb`/`ldc` are
/// the row strides of the *stored* matrices. `_order` is ignored.
///
/// Like reference BLAS, `C` is not read when `beta == 0`, and `A`/`B` are not
/// read when `alpha == 0`. Stale NaN/Inf values in an output buffer therefore
/// cannot leak into the result.
///
/// # Safety
/// `a`, `b` and `c` must point to matrices large enough for the given
/// dimensions and leading dimensions, and `c` must not overlap `a` or `b`.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
#[inline]
unsafe fn cblas_dgemm(
    _order: i32,
    transa: i32,
    transb: i32,
    m: i32,
    n: i32,
    k: i32,
    alpha: f64,
    a: *const f64,
    lda: i32,
    b: *const f64,
    ldb: i32,
    beta: f64,
    c: *mut f64,
    ldc: i32,
) {
    let m = m as usize;
    let n = n as usize;
    let k = k as usize;
    let lda = lda as usize;
    let ldb = ldb as usize;
    let ldc = ldc as usize;
    let trans_a = transa != NO_TRANS;
    let trans_b = transb != NO_TRANS;
    for i in 0..m {
        for j in 0..n {
            let mut acc: f64 = 0.0;
            // BLAS: when alpha == 0, A and B are not referenced.
            if alpha != 0.0 {
                for p in 0..k {
                    // SAFETY: the caller guarantees (i, p) / (p, j) lie inside
                    // the stored A / B described by lda / ldb.
                    let av = unsafe {
                        if trans_a {
                            *a.add(p * lda + i)
                        } else {
                            *a.add(i * lda + p)
                        }
                    };
                    let bv = unsafe {
                        if trans_b {
                            *b.add(j * ldb + p)
                        } else {
                            *b.add(p * ldb + j)
                        }
                    };
                    acc += av * bv;
                }
            }
            // SAFETY: (i, j) lies inside the m×n output described by ldc.
            unsafe {
                let cp = c.add(i * ldc + j);
                // BLAS: when beta == 0, C is write-only (old contents ignored).
                *cp = if beta == 0.0 {
                    alpha * acc
                } else {
                    alpha * acc + beta * *cp
                };
            }
        }
    }
}
// Portable fallback for LAPACK `dgesv` when the `blas` feature is disabled.
// Solves A X = B for column-major A (n×n, leading dimension *lda) and
// B (n×nrhs, leading dimension *ldb) using Gaussian elimination with partial
// pivoting. A is overwritten with its LU factors (multipliers below the
// diagonal), B with the solution, and ipiv[k] receives the 1-based pivot row
// chosen at step k, as in LAPACK.
// *info_out is 0 on success, or k+1 when the k-th pivot column is exactly
// zero. In that case the routine returns immediately, leaving A, B and the
// remaining ipiv entries partially processed.
// NOTE(review): negative dimensions are not validated (LAPACK would report
// info < 0); callers in this file always pass non-negative values.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_dgesv(
    n: *const i32,
    nrhs: *const i32,
    a: *mut f64,
    lda: *const i32,
    ipiv: *mut i32,
    b: *mut f64,
    ldb: *const i32,
    info_out: *mut i32,
) {
    let nn = unsafe { *n } as usize;
    let nrhs = unsafe { *nrhs } as usize;
    let lda = unsafe { *lda } as usize;
    let ldb = unsafe { *ldb } as usize;
    // Column-major element pointer: A(i, j) lives at a[j * lda + i].
    let aij = |a: *mut f64, i: usize, j: usize| unsafe { a.add(j * lda + i) };
    for k in 0..nn {
        // Partial pivoting: pick the row with the largest |A(i, k)| for i >= k
        // (first one wins on ties).
        let mut piv = k;
        let mut max_abs = unsafe { *aij(a, k, k) }.abs();
        for i in (k + 1)..nn {
            let v = unsafe { *aij(a, i, k) }.abs();
            if v > max_abs {
                max_abs = v;
                piv = i;
            }
        }
        // LAPACK reports pivots 1-based.
        unsafe {
            *ipiv.add(k) = (piv + 1) as i32;
        }
        // Exactly singular: report the 1-based column and stop.
        if max_abs == 0.0 {
            unsafe {
                *info_out = (k + 1) as i32;
            }
            return;
        }
        // Swap entire rows k and piv in A (including stored multipliers) and in B.
        if piv != k {
            for j in 0..nn {
                let p1 = aij(a, k, j);
                let p2 = aij(a, piv, j);
                unsafe {
                    std::ptr::swap(p1, p2);
                }
            }
            for j in 0..nrhs {
                let p1 = unsafe { b.add(j * ldb + k) };
                let p2 = unsafe { b.add(j * ldb + piv) };
                unsafe {
                    std::ptr::swap(p1, p2);
                }
            }
        }
        // Eliminate below the pivot, storing each multiplier in place of the
        // zeroed entry, and apply the same row operation to B.
        let akk = unsafe { *aij(a, k, k) };
        for i in (k + 1)..nn {
            let factor = unsafe { *aij(a, i, k) } / akk;
            unsafe {
                *aij(a, i, k) = factor;
            }
            for j in (k + 1)..nn {
                let v = unsafe { *aij(a, i, j) } - factor * unsafe { *aij(a, k, j) };
                unsafe {
                    *aij(a, i, j) = v;
                }
            }
            for j in 0..nrhs {
                let v = unsafe { *b.add(j * ldb + i) } - factor * unsafe { *b.add(j * ldb + k) };
                unsafe {
                    *b.add(j * ldb + i) = v;
                }
            }
        }
    }
    // Back substitution with the upper-triangular U, one right-hand side at a time.
    for j in 0..nrhs {
        for i in (0..nn).rev() {
            let mut sum = unsafe { *b.add(j * ldb + i) };
            for k in (i + 1)..nn {
                sum -= unsafe { *aij(a, i, k) } * unsafe { *b.add(j * ldb + k) };
            }
            unsafe {
                *b.add(j * ldb + i) = sum / *aij(a, i, i);
            }
        }
    }
    unsafe {
        *info_out = 0;
    }
}
// Portable fallback for LAPACK `sgesv` (single precision) when the `blas`
// feature is disabled. The algorithm is the same as the `lapack_dgesv`
// fallback: column-major LU with partial pivoting, followed by forward
// elimination on B and back substitution. A receives its LU factors, B the
// solution, ipiv the 1-based pivot rows. *info_out is 0 on success, or k+1 on
// an exactly zero pivot column, in which case the routine returns early.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_sgesv(
    n: *const i32,
    nrhs: *const i32,
    a: *mut f32,
    lda: *const i32,
    ipiv: *mut i32,
    b: *mut f32,
    ldb: *const i32,
    info_out: *mut i32,
) {
    let nn = unsafe { *n } as usize;
    let nrhs = unsafe { *nrhs } as usize;
    let lda = unsafe { *lda } as usize;
    let ldb = unsafe { *ldb } as usize;
    // Column-major element pointer: A(i, j) lives at a[j * lda + i].
    let aij = |a: *mut f32, i: usize, j: usize| unsafe { a.add(j * lda + i) };
    for k in 0..nn {
        // Partial pivoting on column k (first maximum wins on ties).
        let mut piv = k;
        let mut max_abs = unsafe { *aij(a, k, k) }.abs();
        for i in (k + 1)..nn {
            let v = unsafe { *aij(a, i, k) }.abs();
            if v > max_abs {
                max_abs = v;
                piv = i;
            }
        }
        // LAPACK reports pivots 1-based.
        unsafe {
            *ipiv.add(k) = (piv + 1) as i32;
        }
        // Exactly singular: report the 1-based column and stop.
        if max_abs == 0.0 {
            unsafe {
                *info_out = (k + 1) as i32;
            }
            return;
        }
        // Swap entire rows k and piv in A and in B.
        if piv != k {
            for j in 0..nn {
                let p1 = aij(a, k, j);
                let p2 = aij(a, piv, j);
                unsafe {
                    std::ptr::swap(p1, p2);
                }
            }
            for j in 0..nrhs {
                let p1 = unsafe { b.add(j * ldb + k) };
                let p2 = unsafe { b.add(j * ldb + piv) };
                unsafe {
                    std::ptr::swap(p1, p2);
                }
            }
        }
        // Eliminate below the pivot, keep multipliers in place, update B.
        let akk = unsafe { *aij(a, k, k) };
        for i in (k + 1)..nn {
            let factor = unsafe { *aij(a, i, k) } / akk;
            unsafe {
                *aij(a, i, k) = factor;
            }
            for j in (k + 1)..nn {
                let v = unsafe { *aij(a, i, j) } - factor * unsafe { *aij(a, k, j) };
                unsafe {
                    *aij(a, i, j) = v;
                }
            }
            for j in 0..nrhs {
                let v = unsafe { *b.add(j * ldb + i) } - factor * unsafe { *b.add(j * ldb + k) };
                unsafe {
                    *b.add(j * ldb + i) = v;
                }
            }
        }
    }
    // Back substitution with U for each right-hand side.
    for j in 0..nrhs {
        for i in (0..nn).rev() {
            let mut sum = unsafe { *b.add(j * ldb + i) };
            for k in (i + 1)..nn {
                sum -= unsafe { *aij(a, i, k) } * unsafe { *b.add(j * ldb + k) };
            }
            unsafe {
                *b.add(j * ldb + i) = sum / *aij(a, i, i);
            }
        }
    }
    unsafe {
        *info_out = 0;
    }
}
/// Row-major `C = A * B` in double precision (`A`: `m×k`, `B`: `k×n`, `C`: `m×n`).
///
/// `C` is overwritten (`beta = 0`), so its previous contents are irrelevant.
///
/// # Panics
/// Panics if any slice is too short for the stated dimensions. Without this
/// check, a safe call could make BLAS read or write out of bounds.
#[inline]
pub fn dgemm(a: &[f64], b: &[f64], c: &mut [f64], m: usize, k: usize, n: usize) {
    assert!(a.len() >= m * k, "dgemm: A must hold m×k elements");
    assert!(b.len() >= k * n, "dgemm: B must hold k×n elements");
    assert!(c.len() >= m * n, "dgemm: C must hold m×n elements");
    // SAFETY: the slice lengths were checked above. Leading dimensions are
    // clamped to >= 1 because CBLAS rejects 0 even for empty matrices.
    unsafe {
        cblas_dgemm(
            ROW_MAJOR,
            NO_TRANS,
            NO_TRANS,
            m as i32,
            n as i32,
            k as i32,
            1.0,
            a.as_ptr(),
            k.max(1) as i32,
            b.as_ptr(),
            n.max(1) as i32,
            0.0,
            c.as_mut_ptr(),
            n.max(1) as i32,
        );
    }
}
/// Solves `A X = B` for row-major `A` (`n×n`) and `B` (`n×nrhs`) in double precision.
///
/// On return `b` holds `X` in row-major order. This also holds on failure:
/// the layout is always restored, so `b` never comes back transposed. `a` is
/// overwritten with its LU factors in column-major order.
///
/// Returns the LAPACK `info` code: `0` on success, `> 0` if `U(info, info)` is
/// exactly zero (singular matrix), `< 0` for an illegal argument.
///
/// # Panics
/// If `a.len() != n*n` or `b.len() != n*nrhs`.
pub fn dgesv(a: &mut [f64], b: &mut [f64], n: usize, nrhs: usize) -> i32 {
    assert_eq!(a.len(), n * n, "dgesv: A must be n×n");
    assert_eq!(b.len(), n * nrhs, "dgesv: B must be n×nrhs");
    // Row-major → column-major: transpose A in place.
    for i in 0..n {
        for j in (i + 1)..n {
            a.swap(i * n + j, j * n + i);
        }
    }
    // A single right-hand side has the same layout in both orders.
    if nrhs > 1 {
        let mut tmp = vec![0f64; n * nrhs];
        for i in 0..n {
            for j in 0..nrhs {
                tmp[j * n + i] = b[i * nrhs + j];
            }
        }
        b.copy_from_slice(&tmp);
    }
    let mut ipiv = vec![0i32; n];
    let mut info: i32 = 0;
    let nn = n as i32;
    // LAPACK requires LDA, LDB >= max(1, N); passing 0 for n == 0 is an error.
    let ld = n.max(1) as i32;
    let nrhs_i = nrhs as i32;
    // SAFETY: a is n×n and b is n×nrhs (asserted), both in column-major order
    // with leading dimension n; ipiv has n entries.
    unsafe {
        lapack_dgesv(
            &nn,
            &nrhs_i,
            a.as_mut_ptr(),
            &ld,
            ipiv.as_mut_ptr(),
            b.as_mut_ptr(),
            &ld,
            &mut info,
        );
    }
    // Column-major → row-major for B, regardless of `info`.
    if nrhs > 1 {
        let mut tmp = vec![0f64; n * nrhs];
        for j in 0..nrhs {
            for i in 0..n {
                tmp[i * nrhs + j] = b[j * n + i];
            }
        }
        b.copy_from_slice(&tmp);
    }
    info
}
/// Solves `A X = B` for row-major `A` (`n×n`) and `B` (`n×nrhs`) in single precision.
///
/// On return `b` holds `X` in row-major order. This also holds on failure:
/// the layout is always restored, so `b` never comes back transposed. `a` is
/// overwritten with its LU factors in column-major order.
///
/// Returns the LAPACK `info` code: `0` on success, `> 0` if `U(info, info)` is
/// exactly zero (singular matrix), `< 0` for an illegal argument.
///
/// # Panics
/// If `a.len() != n*n` or `b.len() != n*nrhs`.
pub fn sgesv(a: &mut [f32], b: &mut [f32], n: usize, nrhs: usize) -> i32 {
    assert_eq!(a.len(), n * n, "sgesv: A must be n×n");
    assert_eq!(b.len(), n * nrhs, "sgesv: B must be n×nrhs");
    // Row-major → column-major: transpose A in place.
    for i in 0..n {
        for j in (i + 1)..n {
            a.swap(i * n + j, j * n + i);
        }
    }
    // A single right-hand side has the same layout in both orders.
    if nrhs > 1 {
        let mut tmp = vec![0f32; n * nrhs];
        for i in 0..n {
            for j in 0..nrhs {
                tmp[j * n + i] = b[i * nrhs + j];
            }
        }
        b.copy_from_slice(&tmp);
    }
    let mut ipiv = vec![0i32; n];
    let mut info: i32 = 0;
    let nn = n as i32;
    // LAPACK requires LDA, LDB >= max(1, N); passing 0 for n == 0 is an error.
    let ld = n.max(1) as i32;
    let nrhs_i = nrhs as i32;
    // SAFETY: a is n×n and b is n×nrhs (asserted), both in column-major order
    // with leading dimension n; ipiv has n entries.
    unsafe {
        lapack_sgesv(
            &nn,
            &nrhs_i,
            a.as_mut_ptr(),
            &ld,
            ipiv.as_mut_ptr(),
            b.as_mut_ptr(),
            &ld,
            &mut info,
        );
    }
    // Column-major → row-major for B, regardless of `info`.
    if nrhs > 1 {
        let mut tmp = vec![0f32; n * nrhs];
        for j in 0..nrhs {
            for i in 0..n {
                tmp[i * nrhs + j] = b[j * n + i];
            }
        }
        b.copy_from_slice(&tmp);
    }
    info
}
// CBLAS enumeration values (as in the C `cblas.h` header). They are passed as
// plain `i32` both to the linked CBLAS routines and to the Rust fallbacks, so
// they must match the header exactly.
/// `CblasRowMajor`: matrices are stored row by row.
const ROW_MAJOR: i32 = 101;
/// `CblasNoTrans`: use the operand as stored.
const NO_TRANS: i32 = 111;
/// `CblasTrans`: use the transpose of the stored operand.
const TRANS: i32 = 112;
/// `CblasLeft`: triangular matrix appears on the left (`op(A) X = B`).
const CBLAS_LEFT: i32 = 141;
/// `CblasRight`: triangular matrix appears on the right (currently unused).
#[allow(dead_code)]
const CBLAS_RIGHT: i32 = 142;
/// `CblasUpper`: reference the upper triangle.
const CBLAS_UPPER: i32 = 121;
/// `CblasLower`: reference the lower triangle.
const CBLAS_LOWER: i32 = 122;
/// `CblasNonUnit`: the triangular matrix's diagonal is read from storage.
const CBLAS_NON_UNIT: i32 = 131;
/// `CblasUnit`: assume a unit diagonal (currently unused).
#[allow(dead_code)]
const CBLAS_UNIT: i32 = 132;
/// Placeholder for LAPACK `dpotrf` when the `blas` feature is disabled.
///
/// No pure-Rust Cholesky is provided: this records `info = -1` and then panics.
///
/// # Safety
/// `info` must be valid for writing one `i32`; every other argument is ignored.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_dpotrf(
    _uplo: *const i8,
    _n: *const i32,
    _a: *mut f64,
    _lda: *const i32,
    info: *mut i32,
) {
    // SAFETY: the caller guarantees `info` is writable.
    unsafe { info.write(-1) };
    panic!("rlx-cpu: dpotrf requires the `blas` feature (LAPACK)");
}
/// Placeholder for LAPACK `dsyevd` when the `blas` feature is disabled.
///
/// No pure-Rust eigensolver is provided: this records `info = -1` and then panics.
///
/// # Safety
/// `info` must be valid for writing one `i32`; every other argument is ignored.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_dsyevd(
    _jobz: *const i8,
    _uplo: *const i8,
    _n: *const i32,
    _a: *mut f64,
    _lda: *const i32,
    _w: *mut f64,
    _work: *mut f64,
    _lwork: *const i32,
    _iwork: *mut i32,
    _liwork: *const i32,
    info: *mut i32,
) {
    // SAFETY: the caller guarantees `info` is writable.
    unsafe { info.write(-1) };
    panic!("rlx-cpu: dsyevd requires the `blas` feature (LAPACK)");
}
/// Placeholder for LAPACK `dgeqrf` when the `blas` feature is disabled.
///
/// No pure-Rust QR is provided: this records `info = -1` and then panics.
///
/// # Safety
/// `info` must be valid for writing one `i32`; every other argument is ignored.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_dgeqrf(
    _m: *const i32,
    _n: *const i32,
    _a: *mut f64,
    _lda: *const i32,
    _tau: *mut f64,
    _work: *mut f64,
    _lwork: *const i32,
    info: *mut i32,
) {
    // SAFETY: the caller guarantees `info` is writable.
    unsafe { info.write(-1) };
    panic!("rlx-cpu: dgeqrf requires the `blas` feature (LAPACK)");
}
/// Placeholder for LAPACK `dorgqr` when the `blas` feature is disabled.
///
/// No pure-Rust Q formation is provided: this records `info = -1` and then panics.
///
/// # Safety
/// `info` must be valid for writing one `i32`; every other argument is ignored.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_dorgqr(
    _m: *const i32,
    _n: *const i32,
    _k: *const i32,
    _a: *mut f64,
    _lda: *const i32,
    _tau: *const f64,
    _work: *mut f64,
    _lwork: *const i32,
    info: *mut i32,
) {
    // SAFETY: the caller guarantees `info` is writable.
    unsafe { info.write(-1) };
    panic!("rlx-cpu: dorgqr requires the `blas` feature (LAPACK)");
}
/// Placeholder for LAPACK `dgesvd` when the `blas` feature is disabled.
///
/// No pure-Rust SVD is provided: this records `info = -1` and then panics.
///
/// # Safety
/// `info` must be valid for writing one `i32`; every other argument is ignored.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn lapack_dgesvd(
    _jobu: *const i8,
    _jobvt: *const i8,
    _m: *const i32,
    _n: *const i32,
    _a: *mut f64,
    _lda: *const i32,
    _s: *mut f64,
    _u: *mut f64,
    _ldu: *const i32,
    _vt: *mut f64,
    _ldvt: *const i32,
    _work: *mut f64,
    _lwork: *const i32,
    info: *mut i32,
) {
    // SAFETY: the caller guarantees `info` is writable.
    unsafe { info.write(-1) };
    panic!("rlx-cpu: dgesvd requires the `blas` feature (LAPACK)");
}
/// Placeholder for `cblas_dtrsm` when the `blas` feature is disabled.
///
/// No pure-Rust triangular solve is provided; every call panics.
///
/// # Safety
/// Trivially safe to call: no argument is dereferenced.
#[cfg(not(feature = "blas"))]
#[allow(non_snake_case, clippy::too_many_arguments)]
unsafe fn cblas_dtrsm(
    _order: i32,
    _side: i32,
    _uplo: i32,
    _transa: i32,
    _diag: i32,
    _m: i32,
    _n: i32,
    _alpha: f64,
    _a: *const f64,
    _lda: i32,
    _b: *mut f64,
    _ldb: i32,
) {
    panic!("rlx-cpu: cblas_dtrsm requires the `blas` feature");
}
/// Cholesky factorisation of a symmetric positive-definite, row-major `n×n` matrix.
///
/// With `lower == true`, `a` is overwritten with `L` such that `A = L Lᵀ`, and
/// the strictly upper triangle is zeroed. With `lower == false` it receives
/// `U` such that `A = Uᵀ U`, and the strictly lower triangle is zeroed.
///
/// Returns the LAPACK `info` code: `0` on success, `> 0` if the leading minor
/// of that order is not positive definite (the triangles are then left
/// un-zeroed), `< 0` for an illegal argument.
///
/// # Panics
/// If `a.len() != n*n`, or always when built without the `blas` feature.
pub fn dpotrf(a: &mut [f64], n: usize, lower: bool) -> i32 {
    assert_eq!(a.len(), n * n, "dpotrf: A must be n×n");
    // LAPACK sees the row-major buffer as its transpose (column-major). For
    // symmetric A, the column-major upper factor U (A = UᵀU), read back
    // row-major, is exactly the lower factor L = Uᵀ. So the flags are swapped
    // on purpose.
    let uplo: i8 = if lower { b'U' as i8 } else { b'L' as i8 };
    let nn = n as i32;
    // LAPACK requires LDA >= max(1, N); passing 0 for n == 0 is an error.
    let lda = n.max(1) as i32;
    let mut info: i32 = 0;
    // SAFETY: a holds n*n elements (asserted) with leading dimension n.
    unsafe {
        lapack_dpotrf(&uplo, &nn, a.as_mut_ptr(), &lda, &mut info);
    }
    if info != 0 {
        return info;
    }
    // LAPACK leaves the unreferenced triangle untouched; clear it so the
    // result is a clean triangular factor.
    if lower {
        for i in 0..n {
            for j in (i + 1)..n {
                a[i * n + j] = 0.0;
            }
        }
    } else {
        for i in 1..n {
            for j in 0..i {
                a[i * n + j] = 0.0;
            }
        }
    }
    info
}
/// Eigen-decomposition of a real symmetric `n×n` matrix via LAPACK `dsyevd`.
///
/// On success `w` receives the eigenvalues, which LAPACK returns in ascending
/// order. `a` is overwritten with the eigenvectors: LAPACK stores eigenvector
/// `j` as column `j` of its column-major output, which in this row-major
/// buffer is row `j`. Only one triangle of the input is read (`uplo = 'U'` on
/// the column-major view). For symmetric input the storage order does not matter.
///
/// Returns the LAPACK `info` code: `0` on success, `< 0` for an illegal
/// argument, `> 0` if the algorithm failed to converge.
///
/// # Panics
/// If `a.len() != n*n` or `w.len() != n`, or always without the `blas` feature.
pub fn dsyevd(a: &mut [f64], w: &mut [f64], n: usize) -> i32 {
    assert_eq!(a.len(), n * n);
    assert_eq!(w.len(), n);
    let jobz: i8 = b'V' as i8;
    let uplo: i8 = b'U' as i8;
    let nn = n as i32;
    // LAPACK requires LDA >= max(1, N); passing 0 for n == 0 is an error.
    let lda = n.max(1) as i32;
    let mut info: i32 = 0;
    // Minimum workspace sizes documented for JOBZ = 'V'.
    let lwork = (1 + 6 * n + 2 * n * n) as i32;
    let liwork = (3 + 5 * n) as i32;
    let mut work = vec![0f64; lwork.max(1) as usize];
    let mut iwork = vec![0i32; liwork.max(1) as usize];
    // SAFETY: all buffers are sized as LAPACK requires for these arguments.
    unsafe {
        lapack_dsyevd(
            &jobz,
            &uplo,
            &nn,
            a.as_mut_ptr(),
            &lda,
            w.as_mut_ptr(),
            work.as_mut_ptr(),
            &lwork,
            iwork.as_mut_ptr(),
            &liwork,
            &mut info,
        );
    }
    info
}
/// Thin QR factorisation `A = Q R` of a row-major `m×n` matrix.
///
/// Writes `Q` (`m×k`, orthonormal columns) to `q_out` and the upper-trapezoidal
/// `R` (`k×n`) to `r_out`, both row-major, where `k = min(m, n)`. `a` itself is
/// not modified; it is copied into column-major scratch first.
///
/// Returns `0` on success, or the first non-zero LAPACK `info` from
/// `dgeqrf`/`dorgqr`. Empty inputs (`k == 0`) return `0` immediately, since
/// every output is empty. Passing LAPACK a zero leading dimension would be
/// rejected as an illegal argument.
///
/// # Panics
/// On mismatched slice lengths, or (for non-empty input) without the `blas` feature.
pub fn dgeqrf_full(a: &mut [f64], m: usize, n: usize, q_out: &mut [f64], r_out: &mut [f64]) -> i32 {
    assert_eq!(a.len(), m * n, "dgeqrf: A must be m×n");
    let k = m.min(n);
    assert_eq!(q_out.len(), m * k, "Q must be m×min(m,n)");
    assert_eq!(r_out.len(), k * n, "R must be min(m,n)×n");
    if k == 0 {
        return 0;
    }
    let mut a_col = transpose_to_col(a, m, n);
    let mut tau = vec![0f64; k];
    let mm = m as i32;
    let nn = n as i32;
    let kk = k as i32;
    // Minimum workspace: dgeqrf needs max(1, n); dorgqr needs max(1, k) <= n.
    let lwork = (n.max(1)) as i32;
    let mut work = vec![0f64; lwork.max(1) as usize];
    let mut info: i32 = 0;
    // SAFETY: a_col is m×n column-major with leading dimension m >= 1; tau has k entries.
    unsafe {
        lapack_dgeqrf(
            &mm,
            &nn,
            a_col.as_mut_ptr(),
            &mm,
            tau.as_mut_ptr(),
            work.as_mut_ptr(),
            &lwork,
            &mut info,
        );
    }
    if info != 0 {
        return info;
    }
    // R is the upper triangle of the first k rows of the factored matrix.
    for i in 0..k {
        for j in 0..n {
            let v = if i <= j { a_col[j * m + i] } else { 0.0 };
            r_out[i * n + j] = v;
        }
    }
    // Expand the k Householder reflectors into the explicit m×k Q.
    let mut work2 = vec![0f64; lwork.max(1) as usize];
    let mut info2: i32 = 0;
    // SAFETY: the first k columns of a_col hold the reflectors; m >= k.
    unsafe {
        lapack_dorgqr(
            &mm,
            &kk,
            &kk,
            a_col.as_mut_ptr(),
            &mm,
            tau.as_ptr(),
            work2.as_mut_ptr(),
            &lwork,
            &mut info2,
        );
    }
    if info2 != 0 {
        return info2;
    }
    for i in 0..m {
        for j in 0..k {
            q_out[i * k + j] = a_col[j * m + i];
        }
    }
    0
}
/// Thin SVD `A = U diag(s) Vᵀ` of a row-major `m×n` matrix.
///
/// With `k = min(m, n)`: `s` receives the `k` singular values (LAPACK orders
/// them descending), `u` receives `U` (`m×k`, row-major) and `vt` receives `Vᵀ`
/// (`k×n`, row-major). `a` is not modified; it is copied to column-major
/// scratch first.
///
/// Returns the LAPACK `info` code (`0` on success). Empty inputs (`k == 0`)
/// return `0` immediately, since all outputs are empty and LAPACK would reject
/// the zero leading dimensions.
///
/// # Panics
/// On mismatched slice lengths, or (for non-empty input) without the `blas` feature.
pub fn dgesvd_thin(
    a: &mut [f64],
    m: usize,
    n: usize,
    s: &mut [f64],
    u: &mut [f64],
    vt: &mut [f64],
) -> i32 {
    assert_eq!(a.len(), m * n);
    let k = m.min(n);
    assert_eq!(s.len(), k);
    assert_eq!(u.len(), m * k);
    assert_eq!(vt.len(), k * n);
    if k == 0 {
        return 0;
    }
    let mut a_col = transpose_to_col(a, m, n);
    let mut u_col = vec![0f64; m * k];
    let mut vt_col = vec![0f64; k * n];
    // 'S': compute the first min(m, n) singular vectors (thin SVD).
    let jobu = b'S' as i8;
    let jobvt = b'S' as i8;
    let mm = m as i32;
    let nn = n as i32;
    let ldu = m as i32;
    let ldvt = k as i32;
    // Documented minimum: max(1, 3*min(M,N) + max(M,N), 5*min(M,N)).
    let lwork = (((3 * k + m.max(n)).max(5 * k)) as i32).max(1);
    let mut work = vec![0f64; lwork as usize];
    let mut info: i32 = 0;
    // SAFETY: every buffer matches the dimensions and leading dimensions
    // passed (all >= 1 because k > 0).
    unsafe {
        lapack_dgesvd(
            &jobu,
            &jobvt,
            &mm,
            &nn,
            a_col.as_mut_ptr(),
            &mm,
            s.as_mut_ptr(),
            u_col.as_mut_ptr(),
            &ldu,
            vt_col.as_mut_ptr(),
            &ldvt,
            work.as_mut_ptr(),
            &lwork,
            &mut info,
        );
    }
    if info != 0 {
        return info;
    }
    // Column-major → row-major for both factors.
    for i in 0..m {
        for j in 0..k {
            u[i * k + j] = u_col[j * m + i];
        }
    }
    for i in 0..k {
        for j in 0..n {
            vt[i * n + j] = vt_col[j * k + i];
        }
    }
    0
}
/// Solves `op(A) X = B` in place for a triangular, row-major `n×n` matrix `A`.
///
/// `b` (`n×nrhs`, row-major) is overwritten with `X`. `lower` selects which
/// triangle of `A` is referenced. `transpose_a` solves with `Aᵀ` instead of
/// `A`. The diagonal is read from storage (non-unit).
///
/// Empty problems (`n == 0` or `nrhs == 0`) return immediately: there is
/// nothing to solve, and CBLAS would reject the zero leading dimension.
///
/// # Panics
/// On mismatched slice lengths, or (for non-empty input) without the `blas` feature.
pub fn dtrsm_lower_or_upper(
    a: &[f64],
    b: &mut [f64],
    n: usize,
    nrhs: usize,
    lower: bool,
    transpose_a: bool,
) {
    assert_eq!(a.len(), n * n);
    assert_eq!(b.len(), n * nrhs);
    if n == 0 || nrhs == 0 {
        return;
    }
    // SAFETY: a is n×n and b is n×nrhs (asserted), both row-major.
    unsafe {
        cblas_dtrsm(
            ROW_MAJOR,
            CBLAS_LEFT,
            if lower { CBLAS_LOWER } else { CBLAS_UPPER },
            if transpose_a { TRANS } else { NO_TRANS },
            CBLAS_NON_UNIT,
            n as i32,
            nrhs as i32,
            1.0,
            a.as_ptr(),
            n as i32,
            b.as_mut_ptr(),
            nrhs as i32,
        );
    }
}
/// Returns a column-major copy of the row-major `m×n` matrix `a_row`.
///
/// Element `(i, j)` moves from `a_row[i * n + j]` to `out[j * m + i]`.
/// Panics if `a_row` holds fewer than `m * n` elements.
fn transpose_to_col(a_row: &[f64], m: usize, n: usize) -> Vec<f64> {
    // Emit columns in order; within each column, walk down the rows.
    (0..n)
        .flat_map(|col| (0..m).map(move |row| a_row[row * n + col]))
        .collect()
}
/// Row-major `C = A * B` in single precision (`A`: `m×k`, `B`: `k×n`, `C`: `m×n`).
///
/// `C` is overwritten (`beta = 0`), so its previous contents are irrelevant.
///
/// # Panics
/// Panics if any slice is too short for the stated dimensions. Without this
/// check, a safe call could make BLAS read or write out of bounds.
#[inline]
pub fn sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    assert!(a.len() >= m * k, "sgemm: A must hold m×k elements");
    assert!(b.len() >= k * n, "sgemm: B must hold k×n elements");
    assert!(c.len() >= m * n, "sgemm: C must hold m×n elements");
    // SAFETY: the slice lengths were checked above. Leading dimensions are
    // clamped to >= 1 because CBLAS rejects 0 even for empty matrices.
    unsafe {
        cblas_sgemm(
            ROW_MAJOR,
            NO_TRANS,
            NO_TRANS,
            m as i32,
            n as i32,
            k as i32,
            1.0,
            a.as_ptr(),
            k.max(1) as i32,
            b.as_ptr(),
            n.max(1) as i32,
            0.0,
            c.as_mut_ptr(),
            n.max(1) as i32,
        );
    }
}
/// Row-major `C += A * B` in single precision (`A`: `m×k`, `B`: `k×n`, `C`: `m×n`).
///
/// # Panics
/// Panics if any slice is too short for the stated dimensions. Without this
/// check, a safe call could make BLAS read or write out of bounds.
#[inline]
pub fn sgemm_accumulate(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    assert!(a.len() >= m * k, "sgemm_accumulate: A must hold m×k elements");
    assert!(b.len() >= k * n, "sgemm_accumulate: B must hold k×n elements");
    assert!(c.len() >= m * n, "sgemm_accumulate: C must hold m×n elements");
    // SAFETY: the slice lengths were checked above. Leading dimensions are
    // clamped to >= 1 because CBLAS rejects 0 even for empty matrices.
    unsafe {
        cblas_sgemm(
            ROW_MAJOR,
            NO_TRANS,
            NO_TRANS,
            m as i32,
            n as i32,
            k as i32,
            1.0,
            a.as_ptr(),
            k.max(1) as i32,
            b.as_ptr(),
            n.max(1) as i32,
            1.0,
            c.as_mut_ptr(),
            n.max(1) as i32,
        );
    }
}
/// Row-major `C = alpha * A * Bᵀ` (`A`: `m×k`, `B` stored as `n×k`, `C`: `m×n`).
///
/// Typical use is attention scores `Q Kᵀ`, where both operands are row-major
/// with feature dimension `k`. `C` is overwritten (`beta = 0`).
///
/// # Panics
/// Panics if any slice is too short for the stated dimensions.
#[inline]
pub fn sgemm_bt(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize, alpha: f32) {
    assert!(a.len() >= m * k, "sgemm_bt: A must hold m×k elements");
    assert!(b.len() >= n * k, "sgemm_bt: B must hold n×k elements");
    assert!(c.len() >= m * n, "sgemm_bt: C must hold m×n elements");
    // SAFETY: the slice lengths were checked above. Leading dimensions are
    // clamped to >= 1 because CBLAS rejects 0 even for empty matrices.
    unsafe {
        cblas_sgemm(
            ROW_MAJOR,
            NO_TRANS,
            TRANS,
            m as i32,
            n as i32,
            k as i32,
            alpha,
            a.as_ptr(),
            k.max(1) as i32,
            b.as_ptr(),
            k.max(1) as i32,
            0.0,
            c.as_mut_ptr(),
            n.max(1) as i32,
        );
    }
}
/// `y := alpha * Aᵀ x + beta * y` for a row-major, square `n×n` matrix `A`.
///
/// # Panics
/// Panics if `a` holds fewer than `n*n` elements, or `x`/`y` fewer than `n`.
#[inline]
pub fn sgemv_at(a: &[f32], x: &[f32], y: &mut [f32], n: usize, alpha: f32, beta: f32) {
    assert!(a.len() >= n * n, "sgemv_at: A must hold n×n elements");
    assert!(x.len() >= n, "sgemv_at: x must hold n elements");
    assert!(y.len() >= n, "sgemv_at: y must hold n elements");
    // SAFETY: lengths checked above; lda is clamped to >= 1 as CBLAS requires.
    unsafe {
        cblas_sgemv(
            ROW_MAJOR,
            TRANS,
            n as i32,
            n as i32,
            alpha,
            a.as_ptr(),
            n.max(1) as i32,
            x.as_ptr(),
            1,
            beta,
            y.as_mut_ptr(),
            1,
        );
    }
}
/// Rank-1 update `A += alpha * x yᵀ` on a row-major, square `n×n` matrix `A`.
///
/// # Panics
/// Panics if `a` holds fewer than `n*n` elements, or `x`/`y` fewer than `n`.
#[inline]
pub fn sger(a: &mut [f32], x: &[f32], y: &[f32], n: usize, alpha: f32) {
    assert!(a.len() >= n * n, "sger: A must hold n×n elements");
    assert!(x.len() >= n, "sger: x must hold n elements");
    assert!(y.len() >= n, "sger: y must hold n elements");
    // SAFETY: lengths checked above; lda is clamped to >= 1 as CBLAS requires.
    unsafe {
        cblas_sger(
            ROW_MAJOR,
            n as i32,
            n as i32,
            alpha,
            x.as_ptr(),
            1,
            y.as_ptr(),
            1,
            a.as_mut_ptr(),
            n.max(1) as i32,
        );
    }
}
/// Scales every element of `x` by `alpha` in place (`x := alpha * x`).
///
/// Empty slices are a no-op and never reach BLAS.
#[inline]
pub fn sscal(x: &mut [f32], alpha: f32) {
    let len = x.len();
    if len > 0 {
        // SAFETY: `x` is a valid, contiguous buffer of exactly `len` elements.
        unsafe { cblas_sscal(len as i32, alpha, x.as_mut_ptr(), 1) }
    }
}
/// Row-major `C = A * B` where `A` and `C` may be sub-views of larger buffers.
///
/// `A` is `m×k` with row stride `lda`, `B` is a dense `k×n`, and `C` is `m×n`
/// with row stride `ldc`. Only the `m×n` window of `C` is written (`beta = 0`).
///
/// # Panics
/// Panics if `lda < k`, `ldc < n`, or a slice is too short for its view.
/// Without these checks, a safe call could make BLAS access memory out of bounds.
#[inline]
pub fn sgemm_strided(
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    lda: usize,
    ldc: usize,
) {
    /// Smallest buffer holding `rows` rows of `cols` elements at stride `ld`.
    fn min_len(rows: usize, ld: usize, cols: usize) -> usize {
        if rows == 0 || cols == 0 {
            0
        } else {
            (rows - 1) * ld + cols
        }
    }
    assert!(lda >= k, "sgemm_strided: lda must be >= k");
    assert!(ldc >= n, "sgemm_strided: ldc must be >= n");
    assert!(a.len() >= min_len(m, lda, k), "sgemm_strided: A too short for m×k at stride lda");
    assert!(b.len() >= k * n, "sgemm_strided: B must hold k×n elements");
    assert!(c.len() >= min_len(m, ldc, n), "sgemm_strided: C too short for m×n at stride ldc");
    // SAFETY: the views were checked above. Leading dimensions are clamped to
    // >= 1 because CBLAS rejects 0 even for empty matrices.
    unsafe {
        cblas_sgemm(
            ROW_MAJOR,
            NO_TRANS,
            NO_TRANS,
            m as i32,
            n as i32,
            k as i32,
            1.0,
            a.as_ptr(),
            lda.max(1) as i32,
            b.as_ptr(),
            n.max(1) as i32,
            0.0,
            c.as_mut_ptr(),
            ldc.max(1) as i32,
        );
    }
}
#[cfg(target_arch = "aarch64")]
pub fn bias_add(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
use std::arch::aarch64::*;
let chunks = n / 4;
unsafe {
for row in 0..m {
let base = row * n;
for c in 0..chunks {
let off = base + c * 4;
let v = vld1q_f32(data.as_ptr().add(off));
let b = vld1q_f32(bias.as_ptr().add(c * 4));
vst1q_f32(data.as_mut_ptr().add(off), vaddq_f32(v, b));
}
for i in (chunks * 4)..n {
data[base + i] += bias[i];
}
}
}
}
/// Adds `bias` (length `n`) to every row of the row-major `m×n` matrix `data`.
///
/// Portable scalar version. Indexing is bounds-checked, so short slices panic.
#[cfg(not(target_arch = "aarch64"))]
pub fn bias_add(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    (0..m).for_each(|row| {
        let offset = row * n;
        (0..n).for_each(|col| data[offset + col] += bias[col]);
    });
}
/// Fully general row-major SGEMM: `C := alpha * op(A) * op(B) + beta * C`.
///
/// `op(A)` is `m×k` and `op(B)` is `k×n` (each optionally transposed via
/// `trans_a`/`trans_b`). `C` is `m×n`. `lda`/`ldb`/`ldc` are the row strides of
/// the stored matrices.
///
/// # Safety
/// The caller must ensure `a`, `b` and `c` point to buffers large enough for
/// the stated dimensions and strides, that `c` is valid for writes and does
/// not overlap `a` or `b`, and that every dimension fits in an `i32`.
#[inline]
pub unsafe fn sgemm_general(
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    beta: f32,
    lda: usize,
    ldb: usize,
    ldc: usize,
    trans_a: bool,
    trans_b: bool,
) {
    let op_a = if trans_a { TRANS } else { NO_TRANS };
    let op_b = if trans_b { TRANS } else { NO_TRANS };
    let (rows, cols, inner) = (m as i32, n as i32, k as i32);
    let (stride_a, stride_b, stride_c) = (lda as i32, ldb as i32, ldc as i32);
    // SAFETY: forwarded verbatim from this function's own safety contract.
    unsafe {
        cblas_sgemm(
            ROW_MAJOR, op_a, op_b, rows, cols, inner, alpha, a, stride_a, b, stride_b, beta, c,
            stride_c,
        )
    }
}
/// Row-major `C = A * B + bias` (bias broadcast across rows).
///
/// Shapes for which the hardware cost model prefers it go to the fused NEON
/// kernel. Everything else runs a plain SGEMM followed by [`bias_add`].
#[inline]
pub fn sgemm_bias(a: &[f32], b: &[f32], bias: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    let fused = crate::cost::hw_model().prefer_neon_sgemm(m, k, n);
    if !fused {
        sgemm(a, b, c, m, k, n);
        bias_add(c, bias, m, n);
        return;
    }
    crate::kernels::neon_sgemm_bias_small(a, b, bias, c, m, k, n);
}
/// Row-major `C = epilogue(A * B)`, applying `epilogue` to every element of `c`
/// after the product is computed.
#[inline]
pub fn sgemm_epilogue<E: Fn(f32) -> f32>(
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    epilogue: E,
) {
    sgemm(a, b, c, m, k, n);
    c.iter_mut().for_each(|slot| *slot = epilogue(*slot));
}
/// Row-major `C = activation(A * B + bias)`. The bias is broadcast across
/// rows, and the activation is applied elementwise in row-major order.
#[inline]
pub fn sgemm_bias_epilogue<E: Fn(f32) -> f32>(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    activation: E,
) {
    sgemm(a, b, c, m, k, n);
    for (row_idx, row) in (0..m).map(|i| (i, i * n)) {
        let _ = row_idx;
        let out = &mut c[row..row + n];
        for (col, cell) in out.iter_mut().enumerate() {
            let pre = *cell + bias[col];
            *cell = activation(pre);
        }
    }
}
/// Row-major `C = A * B`, dispatched to the best backend for the shape.
///
/// With the `parity-gemm` feature everything goes through the `gemm` crate.
/// Otherwise:
/// - at most 8 rows, when the cost model favours NEON: the small NEON kernel;
/// - fewer than 32 rows: [`par_sgemm`] (which may itself fall back to serial);
/// - anything larger: serial [`sgemm`].
#[inline]
pub fn sgemm_auto(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    #[cfg(feature = "parity-gemm")]
    {
        sgemm_via_gemm_crate(a, b, c, m, k, n);
    }
    #[cfg(not(feature = "parity-gemm"))]
    {
        let tiny_neon = m <= 8 && crate::cost::hw_model().prefer_neon_sgemm(m, k, n);
        if tiny_neon {
            crate::kernels::neon_sgemm_small(a, b, c, m, k, n);
        } else if m < 32 {
            par_sgemm(a, b, c, m, k, n);
        } else {
            sgemm(a, b, c, m, k, n);
        }
    }
}
/// Row-major `C = A * B` computed with the third-party `gemm` crate.
///
/// Used by `sgemm_auto` when the `parity-gemm` feature is enabled. Threading
/// uses rayon with `pool_workers + 1` threads, or is single-threaded when that
/// count is 1.
#[cfg(feature = "parity-gemm")]
fn sgemm_via_gemm_crate(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use gemm::{Parallelism, gemm};
    let cfg = crate::config::RuntimeConfig::global();
    // NOTE(review): presumably "+ 1" counts the calling thread alongside the pool workers — confirm.
    let workers = cfg.pool_workers + 1;
    let par = if workers > 1 {
        Parallelism::Rayon(workers)
    } else {
        Parallelism::None
    };
    // NOTE(review): per the `gemm` crate's documented signature the arguments are
    // (m, n, k, dst, dst_cs, dst_rs, read_dst, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
    //  alpha, beta, conj_dst, conj_lhs, conj_rhs, parallelism), computing
    // dst := alpha*dst + beta*lhs*rhs. Column stride 1 and row stride = width encode
    // row-major storage. read_dst = false / alpha = 0.0 discard C's old contents,
    // beta = 1.0 scales the product. Verify against the pinned crate version.
    // Slice lengths are not checked here.
    unsafe {
        gemm(
            m,
            n,
            k,
            c.as_mut_ptr(),
            1, n as isize, false, a.as_ptr(),
            1, k as isize, b.as_ptr(),
            1, n as isize, 0.0, 1.0, false, false, false, par,
        );
    }
}
pub fn par_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
let cfg = crate::config::RuntimeConfig::global();
let workers = cfg.pool_workers + 1;
let total_flops = (m * k * n) as u64;
if m >= 32 || total_flops < 2_000_000 || n < workers * 32 {
sgemm(a, b, c, m, k, n);
return;
}
let chunk = n / workers;
let a_addr = a.as_ptr() as usize;
let b_addr = b.as_ptr() as usize;
let c_addr = c.as_mut_ptr() as usize;
crate::pool::par_for(workers, 1, &|off, cnt| {
for w in off..off + cnt {
let n_start = w * chunk;
let n_end = if w + 1 == workers { n } else { (w + 1) * chunk };
let local_n = n_end - n_start;
if local_n == 0 {
continue;
}
unsafe {
cblas_sgemm(
101,
111,
111,
m as i32,
local_n as i32,
k as i32,
1.0,
a_addr as *const f32,
k as i32,
(b_addr as *const f32).add(n_start),
n as i32,
0.0,
(c_addr as *mut f32).add(n_start),
n as i32,
);
}
}
});
}
pub fn par_sgemm_bias(
a: &[f32],
b: &[f32],
bias: &[f32],
c: &mut [f32],
m: usize,
k: usize,
n: usize,
) {
let cfg = crate::config::RuntimeConfig::global();
let workers = cfg.pool_workers + 1; let total_flops = (m * k * n) as u64;
if m >= 32 || total_flops < 2_000_000 || n < workers * 32 {
sgemm_bias(a, b, bias, c, m, k, n);
return;
}
let chunk = n / workers;
let a_addr = a.as_ptr() as usize;
let b_addr = b.as_ptr() as usize;
let bias_addr = bias.as_ptr() as usize;
let c_addr = c.as_mut_ptr() as usize;
crate::pool::par_for(workers, 1, &|off, cnt| {
for w in off..off + cnt {
let n_start = w * chunk;
let n_end = if w + 1 == workers { n } else { (w + 1) * chunk };
let local_n = n_end - n_start;
if local_n == 0 {
continue;
}
unsafe {
cblas_sgemm(
101,
111,
111, m as i32,
local_n as i32,
k as i32,
1.0,
a_addr as *const f32,
k as i32,
(b_addr as *const f32).add(n_start),
n as i32,
0.0,
(c_addr as *mut f32).add(n_start),
n as i32,
);
let local_bias =
std::slice::from_raw_parts((bias_addr as *const f32).add(n_start), local_n);
let local_c =
std::slice::from_raw_parts_mut((c_addr as *mut f32).add(n_start), m * n);
#[cfg(target_arch = "aarch64")]
{
use std::arch::aarch64::*;
let chunks = local_n / 4;
for row in 0..m {
let base = row * n;
for c in 0..chunks {
let off = base + c * 4;
let v = vld1q_f32(local_c.as_ptr().add(off));
let bv = vld1q_f32(local_bias.as_ptr().add(c * 4));
vst1q_f32(local_c.as_mut_ptr().add(off), vaddq_f32(v, bv));
}
for i in (chunks * 4)..local_n {
local_c[base + i] += local_bias[i];
}
}
}
#[cfg(not(target_arch = "aarch64"))]
for row in 0..m {
let base = row * n;
for i in 0..local_n {
local_c[base + i] += local_bias[i];
}
}
}
}
});
}
// Unit tests for the row-major BLAS/LAPACK wrappers. They run against either
// the linked BLAS or the pure-Rust fallbacks, depending on the `blas` feature.
#[cfg(test)]
mod tests {
    use super::*;
    // Left-multiplying by the identity must return B unchanged.
    #[test]
    fn sgemm_identity() {
        let a = [1.0, 0.0, 0.0, 1.0f32];
        let b = [3.0, 4.0, 5.0, 6.0f32];
        let mut c = [0.0f32; 4];
        sgemm(&a, &b, &mut c, 2, 2, 2);
        assert_eq!(c, [3.0, 4.0, 5.0, 6.0]);
    }
    // Non-square product: (2×3) · (3×2) → 2×2.
    #[test]
    fn sgemm_rectangular() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0f32]; let b = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0f32]; let mut c = [0.0f32; 4];
        sgemm(&a, &b, &mut c, 2, 3, 2);
        assert_eq!(c, [4.0, 2.0, 10.0, 5.0]);
    }
    // The bias is broadcast across rows after the product.
    #[test]
    fn sgemm_bias_test() {
        let a = [1.0, 0.0, 0.0, 1.0f32]; let b = [3.0, 4.0, 5.0, 6.0f32]; let bias = [10.0, 20.0f32]; let mut c = [0.0f32; 4];
        sgemm_bias(&a, &b, &bias, &mut c, 2, 2, 2);
        assert_eq!(c, [13.0, 24.0, 15.0, 26.0]);
    }
    // Double-precision identity check.
    #[test]
    fn dgemm_identity() {
        let a = [1.0, 0.0, 0.0, 1.0f64];
        let b = [3.0, 4.0, 5.0, 6.0f64];
        let mut c = [0.0f64; 4];
        dgemm(&a, &b, &mut c, 2, 2, 2);
        assert_eq!(c, [3.0, 4.0, 5.0, 6.0]);
    }
    // [[2,1],[1,3]] x = [5,10] has the exact solution x = [1,3].
    #[test]
    fn dgesv_2x2_known_solution() {
        let mut a = [2.0, 1.0, 1.0, 3.0_f64];
        let mut b = [5.0, 10.0_f64];
        let info = dgesv(&mut a, &mut b, 2, 1);
        assert_eq!(info, 0, "dgesv signaled singular: info={info}");
        let want = [1.0, 3.0_f64];
        for (i, (g, w)) in b.iter().zip(want.iter()).enumerate() {
            assert!((g - w).abs() < 1e-12, "x[{i}] = {g}, expected {w}");
        }
    }
    // Tridiagonal 3×3 system, verified by checking the residual A·x ≈ b
    // against an untouched copy of A (dgesv overwrites `a` with LU factors).
    #[test]
    fn dgesv_3x3_general() {
        let a_orig = [4.0, -1.0, 0.0, -1.0, 4.0, -1.0, 0.0, -1.0, 4.0_f64];
        let mut a = a_orig;
        let mut b = [1.0, 0.0, -1.0_f64];
        let info = dgesv(&mut a, &mut b, 3, 1);
        assert_eq!(info, 0);
        let mut residual = [0.0_f64; 3];
        for i in 0..3 {
            for j in 0..3 {
                residual[i] += a_orig[i * 3 + j] * b[j];
            }
        }
        let want_b = [1.0, 0.0, -1.0_f64];
        for i in 0..3 {
            assert!(
                (residual[i] - want_b[i]).abs() < 1e-12,
                "residual[{i}] = {} vs {}",
                residual[i],
                want_b[i]
            );
        }
    }
    // Q·Kᵀ with orthonormal rows should give the identity.
    #[test]
    fn sgemm_bt_test() {
        let q = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0f32]; let k = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0f32]; let mut scores = [0.0f32; 4];
        sgemm_bt(&q, &k, &mut scores, 2, 3, 2, 1.0);
        assert_eq!(scores, [1.0, 0.0, 0.0, 1.0]);
    }
    // An identity epilogue must match plain sgemm; a ReLU epilogue clamps negatives.
    #[test]
    fn sgemm_epilogue_matches_post_pass() {
        let a = [1.0f32, -2.0, 3.0, -4.0]; let b = [1.0f32, 0.0, 0.0, 1.0]; let mut c1 = [0f32; 4];
        let mut c2 = [0f32; 4];
        sgemm(&a, &b, &mut c1, 2, 2, 2);
        sgemm_epilogue(&a, &b, &mut c2, 2, 2, 2, |x| x);
        assert_eq!(c1, c2);
        let mut c3 = [0f32; 4];
        sgemm_epilogue(&a, &b, &mut c3, 2, 2, 2, |x| x.max(0.0));
        assert_eq!(c3, [1.0, 0.0, 3.0, 0.0]);
    }
    // Bias is added before the activation is applied.
    #[test]
    fn sgemm_bias_epilogue_matches_reference() {
        let a = [1.0f32, 2.0, 3.0, 4.0]; let b = [1.0f32, 0.0, 0.0, 1.0]; let bias = [10.0f32, 100.0];
        let mut c = [0f32; 4];
        sgemm_bias_epilogue(&a, &b, &bias, &mut c, 2, 2, 2, |x| x.max(0.0));
        assert_eq!(c, [11.0, 102.0, 13.0, 104.0]);
    }
}