use crate::{Backend, TruenoError, Vector};
#[cfg(feature = "tracing")]
use tracing::instrument;
/// Dispatches a dot-product computation to the backend selected at runtime.
///
/// Arms are `cfg`-gated per target architecture: when a SIMD backend is
/// requested on a target that cannot run it (e.g. `Backend::SSE2` on a
/// non-x86_64 build, `Backend::WasmSIMD` off wasm32), the call falls back to
/// the scalar kernel. `GPU` and `Auto` also fall through to scalar here —
/// presumably GPU execution and `Auto` resolution happen elsewhere; confirm
/// against the backend-selection code.
macro_rules! dispatch_dot {
    ($backend:expr, $a:expr, $b:expr) => {{
        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};
        // SAFETY: the backend `dot` kernels are `unsafe`; this macro does not
        // itself verify CPU feature availability. NOTE(review): callers are
        // assumed to only select a backend the running CPU supports — confirm
        // upstream selection logic enforces this.
        unsafe {
            match $backend {
                Backend::Scalar => ScalarBackend::dot($a, $b),
                // AVX requests are served by the SSE2 kernel and AVX512 by the
                // AVX2 kernel — presumably no dedicated kernels exist for
                // those tiers yet.
                #[cfg(target_arch = "x86_64")]
                Backend::SSE2 | Backend::AVX => Sse2Backend::dot($a, $b),
                #[cfg(target_arch = "x86_64")]
                Backend::AVX2 | Backend::AVX512 => Avx2Backend::dot($a, $b),
                // x86 backends requested on a non-x86_64 build: scalar fallback.
                #[cfg(not(target_arch = "x86_64"))]
                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
                    ScalarBackend::dot($a, $b)
                }
                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                Backend::NEON => {
                    use crate::backends::neon::NeonBackend;
                    NeonBackend::dot($a, $b)
                }
                // NEON requested off ARM: scalar fallback.
                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
                Backend::NEON => ScalarBackend::dot($a, $b),
                #[cfg(target_arch = "wasm32")]
                Backend::WasmSIMD => {
                    use crate::backends::wasm::WasmBackend;
                    WasmBackend::dot($a, $b)
                }
                // Wasm SIMD requested off wasm32: scalar fallback.
                #[cfg(not(target_arch = "wasm32"))]
                Backend::WasmSIMD => ScalarBackend::dot($a, $b),
                Backend::GPU | Backend::Auto => ScalarBackend::dot($a, $b),
            }
        }
    }};
}
use super::super::Matrix;
impl Matrix<f32> {
    /// Returns a new matrix with rows and columns swapped, preserving the
    /// backend selection of `self`.
    ///
    /// Delegates to the BLIS transpose kernel. A dimension-mismatch error
    /// from that kernel would indicate a bug (both dimension arguments are
    /// derived from `self`), so debug builds assert; release builds recover
    /// by running a naive element-wise transpose so the result stays correct.
    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(dims = %format!("{}x{}", self.rows, self.cols))))]
    pub fn transpose(&self) -> Matrix<f32> {
        let mut result = Matrix::zeros_with_backend(self.cols, self.rows, self.backend);
        if let Err(e) =
            crate::blis::transpose::transpose(self.rows, self.cols, &self.data, &mut result.data)
        {
            debug_assert!(false, "BLIS transpose dimension mismatch: {e}");
            // Release-build fallback: naive O(rows*cols) transpose over the
            // row-major storage.
            for i in 0..self.rows {
                for j in 0..self.cols {
                    result.data[j * self.rows + i] = self.data[i * self.cols + j];
                }
            }
        }
        result
    }

    /// Matrix-vector product `self * v`.
    ///
    /// Output element `i` is the dot product of row `i` with `v`, computed by
    /// the backend selected on `self`. With the `parallel` feature enabled,
    /// rows are processed with rayon once the matrix is tall enough to
    /// amortize scheduling overhead.
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::InvalidInput`] when `v.len() != self.cols`.
    pub fn matvec(&self, v: &Vector<f32>) -> Result<Vector<f32>, TruenoError> {
        if v.len() != self.cols {
            return Err(TruenoError::InvalidInput(format!(
                "Vector length {} does not match matrix columns {} for matrix-vector multiplication",
                v.len(),
                self.cols
            )));
        }
        let v_slice = v.as_slice();
        let mut result_data = vec![0.0; self.rows];
        #[cfg(feature = "parallel")]
        {
            // Below this row count, rayon's scheduling overhead outweighs the
            // per-row dot-product work.
            const PARALLEL_THRESHOLD: usize = 4096;
            if self.rows >= PARALLEL_THRESHOLD {
                use rayon::prelude::*;
                // Each iteration writes exactly one disjoint output element.
                // `par_iter_mut` lets the borrow checker prove that, replacing
                // the previous `Arc<AtomicPtr>` raw-pointer writes, which
                // needed `unsafe` to express the same disjointness.
                result_data
                    .par_iter_mut()
                    .enumerate()
                    .for_each(|(i, out)| {
                        let row_start = i * self.cols;
                        let row = &self.data[row_start..row_start + self.cols];
                        *out = dispatch_dot!(self.backend, row, v_slice);
                    });
                return Ok(Vector::from_slice(&result_data));
            }
        }
        // Sequential path: one backend dot product per row.
        for (i, out) in result_data.iter_mut().enumerate() {
            let row_start = i * self.cols;
            let row = &self.data[row_start..row_start + self.cols];
            *out = dispatch_dot!(self.backend, row, v_slice);
        }
        Ok(Vector::from_slice(&result_data))
    }

    /// Vector-matrix product `vᵀ * m`, yielding a vector of length `m.cols`.
    ///
    /// Delegates to the BLIS GEMV kernel.
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::InvalidInput`] when `v.len() != m.rows`.
    pub fn vecmat(v: &Vector<f32>, m: &Matrix<f32>) -> Result<Vector<f32>, TruenoError> {
        if v.len() != m.rows {
            return Err(TruenoError::InvalidInput(format!(
                "Vector length {} does not match matrix rows {} for vector-matrix multiplication",
                v.len(),
                m.rows
            )));
        }
        let mut result_data = vec![0.0f32; m.cols];
        crate::blis::gemv::gemv(m.rows, m.cols, v.as_slice(), &m.data, &mut result_data);
        Ok(Vector::from_slice(&result_data))
    }
}
#[cfg(test)]
mod tests;