use runmat_builtins::Tensor as Matrix;
/// Copies `matrix.data` into a fresh buffer, reading it as row-major
/// (`data[row * cols + col]`) and writing it column-major (`out[col * rows + row]`),
/// which is the order the BLAS calls in this module are given.
///
/// The returned buffer has the same length as `matrix.data`. Panics if `data` holds
/// fewer than `rows * cols` elements (and both dimensions are non-zero).
///
/// NOTE(review): this assumes `Matrix::data` is stored row-major; confirm against the
/// storage order of `runmat_builtins::Tensor`.
fn transpose_to_column_major(matrix: &Matrix) -> Vec<f64> {
    let rows = matrix.rows();
    let cols = matrix.cols();
    let src = &matrix.data;
    let mut out = vec![0.0; src.len()];
    for col in 0..cols {
        for row in 0..rows {
            out[col * rows + row] = src[row * cols + col];
        }
    }
    out
}
pub fn blas_matrix_mul(a: &Matrix, b: &Matrix) -> Result<Matrix, String> {
if a.cols() != b.rows() {
return Err(format!(
"Inner matrix dimensions must agree: {}x{} * {}x{}",
a.rows(),
a.cols(),
b.rows(),
b.cols()
));
}
let m = a.rows() as i32;
let n = b.cols() as i32;
let k = a.cols() as i32;
let a_col_major = transpose_to_column_major(a);
let b_col_major = transpose_to_column_major(b);
let mut c_col_major = vec![0.0; (m * n) as usize];
unsafe {
blas::dgemm(
b'N', b'N', m, n, k, 1.0, &a_col_major, m, &b_col_major, k, 0.0, &mut c_col_major, m, );
}
Matrix::new_2d(c_col_major, a.rows(), b.cols())
}
pub fn blas_matrix_vector_mul(matrix: &Matrix, vector: &[f64]) -> Result<Vec<f64>, String> {
if matrix.cols() != vector.len() {
return Err(format!(
"Matrix columns {} must match vector length {}",
matrix.cols(),
vector.len()
));
}
let m = matrix.rows() as i32;
let n = matrix.cols() as i32;
let mut result = vec![0.0; matrix.rows()];
let matrix_col_major = transpose_to_column_major(matrix);
unsafe {
blas::dgemv(
b'N', m, n, 1.0, &matrix_col_major, m, vector, 1, 0.0, &mut result, 1, );
}
Ok(result)
}
pub fn blas_dot_product(a: &[f64], b: &[f64]) -> Result<f64, String> {
if a.len() != b.len() {
return Err(format!(
"Vector lengths must match: {} vs {}",
a.len(),
b.len()
));
}
let n = a.len() as i32;
unsafe { Ok(blas::ddot(n, a, 1, b, 1)) }
}
/// Returns the Euclidean (L2) norm of `vector`, computed with BLAS `dnrm2`.
///
/// BLAS lengths are `i32`, so the slice is processed in chunks of at most `i32::MAX`
/// elements. The partial norms are combined with `hypot`, i.e. `sqrt(acc² + part²)`, which
/// avoids overflow. Any vector that fits in one chunk is handled by a single `dnrm2` call,
/// and an empty vector yields `0.0`.
pub fn blas_vector_norm(vector: &[f64]) -> f64 {
    // Largest element count a single BLAS call can address with its `i32` length.
    const MAX_CHUNK: usize = i32::MAX as usize;
    vector.chunks(MAX_CHUNK).fold(0.0_f64, |acc, chunk| {
        // SAFETY: `chunk.len() <= i32::MAX`, so the cast is lossless, and `dnrm2` reads
        // exactly that many elements with unit stride.
        let part = unsafe { blas::dnrm2(chunk.len() as i32, chunk, 1) };
        acc.hypot(part)
    })
}
/// Scales `vector` in place by `alpha` (`x *= alpha`) with BLAS `dscal`.
///
/// BLAS lengths are `i32`, so the slice is processed in chunks of at most `i32::MAX`
/// elements. Previously an `as` cast silently truncated longer lengths. An empty vector is
/// left untouched.
pub fn blas_scale_vector(vector: &mut [f64], alpha: f64) {
    // Largest element count a single BLAS call can address with its `i32` length.
    const MAX_CHUNK: usize = i32::MAX as usize;
    for chunk in vector.chunks_mut(MAX_CHUNK) {
        // SAFETY: `chunk.len() <= i32::MAX`, so the cast is lossless, and `dscal` writes
        // exactly that many elements with unit stride.
        unsafe {
            blas::dscal(chunk.len() as i32, alpha, chunk, 1);
        }
    }
}
pub fn blas_vector_add(alpha: f64, x: &[f64], y: &mut [f64]) -> Result<(), String> {
if x.len() != y.len() {
return Err(format!(
"Vector lengths must match: {} vs {}",
x.len(),
y.len()
));
}
let n = x.len() as i32;
unsafe {
blas::daxpy(n, alpha, x, 1, y, 1);
}
Ok(())
}