use crate::bridge;
use crate::error::{Error, Result};
/// Integer codes selecting the memory layout of a matrix argument.
///
/// The values appear to mirror the `CBLAS_ORDER` enum of the CBLAS C
/// interface — confirm against the header the bridge is built with.
pub mod blas_order {
    /// Matrix elements are stored row by row.
    pub const ROW_MAJOR: i32 = 101;
    /// Matrix elements are stored column by column.
    pub const COL_MAJOR: i32 = 102;
}
/// Integer codes selecting how a matrix operand is transformed before use.
///
/// The values appear to mirror the `CBLAS_TRANSPOSE` enum of the CBLAS C
/// interface — confirm against the header the bridge is built with.
pub mod blas_transpose {
    /// Use the matrix as stored.
    pub const NO_TRANS: i32 = 111;
    /// Use the transpose of the matrix.
    pub const TRANS: i32 = 112;
    /// Use the conjugate transpose of the matrix.
    pub const CONJ_TRANS: i32 = 113;
}
fn i32_len(value: usize) -> Result<i32> {
i32::try_from(value).map_err(|_| Error::OperationFailed("dimension exceeds i32"))
}
/// Single-precision dot product of `x` and `y`, delegated to
/// `bridge::acc_blas_sdot`.
///
/// # Errors
///
/// * `Error::InvalidLength` when the two slices differ in length
///   (`expected` is `x.len()`, `actual` is `y.len()`).
/// * `Error::OperationFailed` when the common length does not fit in `i32`.
pub fn sdot(x: &[f32], y: &[f32]) -> Result<f32> {
    let (x_len, y_len) = (x.len(), y.len());
    if x_len != y_len {
        return Err(Error::InvalidLength {
            expected: x_len,
            actual: y_len,
        });
    }

    let count = i32_len(x_len)?;

    // SAFETY: both pointers come from live slices that each hold exactly
    // `count` elements (verified above). Assumes the bridge reads at most
    // `count` contiguous `f32`s from each pointer — confirm against the
    // bridge declaration.
    let dot = unsafe { bridge::acc_blas_sdot(count, x.as_ptr(), y.as_ptr()) };
    Ok(dot)
}
pub fn sgemv_row_major(
rows: usize,
columns: usize,
alpha: f32,
matrix: &[f32],
x: &[f32],
beta: f32,
y: &mut [f32],
) -> Result<()> {
let expected_matrix = rows
.checked_mul(columns)
.ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
if matrix.len() != expected_matrix {
return Err(Error::InvalidLength {
expected: expected_matrix,
actual: matrix.len(),
});
}
if x.len() != columns {
return Err(Error::InvalidLength {
expected: columns,
actual: x.len(),
});
}
if y.len() != rows {
return Err(Error::InvalidLength {
expected: rows,
actual: y.len(),
});
}
let rows_i32 = i32_len(rows)?;
let columns_i32 = i32_len(columns)?;
let ok = unsafe {
bridge::acc_blas_sgemv_row_major(
rows_i32,
columns_i32,
alpha,
matrix.as_ptr(),
x.as_ptr(),
beta,
y.as_mut_ptr(),
)
};
if ok {
Ok(())
} else {
Err(Error::OperationFailed("CBLAS sgemv failed"))
}
}
#[allow(clippy::too_many_arguments)]
pub fn sgemm_row_major(
rows: usize,
columns: usize,
inner_dimension: usize,
alpha: f32,
lhs: &[f32],
rhs: &[f32],
beta: f32,
output: &mut [f32],
) -> Result<()> {
let expected_lhs = rows
.checked_mul(inner_dimension)
.ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
let expected_rhs = inner_dimension
.checked_mul(columns)
.ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
let expected_output = rows
.checked_mul(columns)
.ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
if lhs.len() != expected_lhs {
return Err(Error::InvalidLength {
expected: expected_lhs,
actual: lhs.len(),
});
}
if rhs.len() != expected_rhs {
return Err(Error::InvalidLength {
expected: expected_rhs,
actual: rhs.len(),
});
}
if output.len() != expected_output {
return Err(Error::InvalidLength {
expected: expected_output,
actual: output.len(),
});
}
let rows_i32 = i32_len(rows)?;
let columns_i32 = i32_len(columns)?;
let inner_dimension_i32 = i32_len(inner_dimension)?;
let ok = unsafe {
bridge::acc_blas_sgemm_row_major(
rows_i32,
columns_i32,
inner_dimension_i32,
alpha,
lhs.as_ptr(),
rhs.as_ptr(),
beta,
output.as_mut_ptr(),
)
};
if ok {
Ok(())
} else {
Err(Error::OperationFailed("CBLAS sgemm failed"))
}
}