use crate::array::Array;
use crate::array_view::ArrayView;
use crate::backend::dispatch::get_dispatch_table;
use crate::llo::ElementwiseKind;
use crate::llo::reduction::ReductionKind;
use anyhow::{Result, anyhow};
/// Applies the binary elementwise operation `kind` to `a` and `b`, storing the
/// outcome in `out`.
///
/// The inputs are copied into rank-1 `Array`s of shape `[a.len()]` and routed
/// through the dispatch table's `elementwise` entry.
///
/// # Errors
/// Fails if `a`, `b`, and `out` do not all have the same length, if the backend
/// reports an error, or if the backend returns a result whose length differs
/// from the inputs. `out` is only written on success.
pub fn elementwise_f32(
    a: &[f32],
    b: &[f32],
    out: &mut [f32],
    kind: ElementwiseKind,
) -> Result<()> {
    let n = a.len();
    let lengths_agree = b.len() == n && out.len() == n;
    if !lengths_agree {
        return Err(anyhow!(
            "Length mismatch: a={}, b={}, out={}",
            n,
            b.len(),
            out.len()
        ));
    }

    let lhs = Array::new(vec![n], a.to_vec());
    let rhs = Array::new(vec![n], b.to_vec());
    let dispatch = get_dispatch_table();
    let computed = (dispatch.elementwise)(&lhs, &rhs, kind)?;

    // Guard the copy below: `copy_from_slice` panics on a length mismatch.
    match computed.data.len() == n {
        true => {
            out.copy_from_slice(&computed.data);
            Ok(())
        }
        false => Err(anyhow!("Result length mismatch")),
    }
}
/// Applies a binary elementwise operation to two typed views and writes the
/// result into a caller-provided raw buffer.
///
/// `F32` inputs go through the dispatch table's `elementwise` entry. `F64` and
/// `I32` inputs go through `dispatch_elementwise_generic`. Any other dtype is
/// rejected.
///
/// # Arguments
/// * `a`, `b` - input views. They must have the same dtype and element count.
/// * `out` - start of the output buffer, interpreted as elements of the input
///   dtype. May be null only when `out_len == 0`.
/// * `out_len` - number of *elements* (not bytes) available at `out`.
/// * `kind` - the elementwise operation to apply.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the dtypes differ or are unsupported;
/// * the input and output lengths disagree;
/// * the backend fails or returns a result of the wrong length;
/// * `out` is null, misaligned for the dtype, or describes more than `isize::MAX`
///   bytes.
///
/// On any error, `out` is not written.
///
/// # Caller contract
/// NOTE(review): this function is not `unsafe`, yet it writes through `out`.
/// When `out_len > 0`, the caller must ensure `out` points to `out_len`
/// initialized, writable elements of the input dtype that do not overlap the
/// memory behind `a` or `b`. Null, misaligned, and oversized buffers are
/// rejected, but a dangling or too-short buffer cannot be detected here.
pub fn elementwise_view(
    a: &ArrayView,
    b: &ArrayView,
    out: *mut std::ffi::c_void,
    out_len: usize,
    kind: ElementwiseKind,
) -> Result<()> {
    use crate::array::DType;
    if a.dtype() != b.dtype() {
        return Err(anyhow!("Type mismatch: a is {:?}, b is {:?}", a.dtype(), b.dtype()));
    }
    match a.dtype() {
        DType::F32 => {
            let a_slice = a.as_f32().expect("dtype is F32, so as_f32 must succeed");
            let b_slice = b.as_f32().expect("dtype is F32, so as_f32 must succeed");
            // Validate before touching `out` at all.
            check_view_lengths(a_slice.len(), b_slice.len(), out_len)?;
            let a_arr = Array::new(vec![out_len], a_slice.to_vec());
            let b_arr = Array::new(vec![out_len], b_slice.to_vec());
            let table = get_dispatch_table();
            let result = (table.elementwise)(&a_arr, &b_arr, kind)?;
            check_view_result_len(result.data.len(), out_len)?;
            // SAFETY: the caller contract (see doc comment) guarantees `out` is
            // valid and exclusive for `out_len` f32s. The helper checks null,
            // alignment and byte size. The slice is created only after the inputs
            // have been copied, so it never coexists with a live read of them.
            let out_slice = unsafe { view_out_slice_mut::<f32>(out, out_len) }?;
            out_slice.copy_from_slice(&result.data);
            Ok(())
        }
        DType::F64 => {
            let a_slice = a.as_f64().expect("dtype is F64, so as_f64 must succeed");
            let b_slice = b.as_f64().expect("dtype is F64, so as_f64 must succeed");
            check_view_lengths(a_slice.len(), b_slice.len(), out_len)?;
            let a_arr = Array::new(vec![out_len], a_slice.to_vec());
            let b_arr = Array::new(vec![out_len], b_slice.to_vec());
            let result =
                crate::backend::dispatch::dispatch_elementwise_generic(&a_arr, &b_arr, kind)?;
            check_view_result_len(result.data.len(), out_len)?;
            // SAFETY: same contract as the F32 arm, for `out_len` f64s.
            let out_slice = unsafe { view_out_slice_mut::<f64>(out, out_len) }?;
            out_slice.copy_from_slice(&result.data);
            Ok(())
        }
        DType::I32 => {
            let a_slice = a.as_i32().expect("dtype is I32, so as_i32 must succeed");
            let b_slice = b.as_i32().expect("dtype is I32, so as_i32 must succeed");
            check_view_lengths(a_slice.len(), b_slice.len(), out_len)?;
            let a_arr = Array::new(vec![out_len], a_slice.to_vec());
            let b_arr = Array::new(vec![out_len], b_slice.to_vec());
            let result =
                crate::backend::dispatch::dispatch_elementwise_generic(&a_arr, &b_arr, kind)?;
            check_view_result_len(result.data.len(), out_len)?;
            // SAFETY: same contract as the F32 arm, for `out_len` i32s.
            let out_slice = unsafe { view_out_slice_mut::<i32>(out, out_len) }?;
            out_slice.copy_from_slice(&result.data);
            Ok(())
        }
        _ => Err(anyhow!("Unsupported dtype: {:?}", a.dtype())),
    }
}

/// Returns an error unless both inputs and the output buffer hold the same
/// number of elements.
fn check_view_lengths(a_len: usize, b_len: usize, out_len: usize) -> Result<()> {
    if a_len != b_len || a_len != out_len {
        return Err(anyhow!(
            "Length mismatch: a={}, b={}, out={}",
            a_len,
            b_len,
            out_len
        ));
    }
    Ok(())
}

/// Returns an error if the backend produced a different element count than
/// the output buffer expects. This guards the `copy_from_slice` that follows,
/// which would otherwise panic.
fn check_view_result_len(got: usize, expected: usize) -> Result<()> {
    if got != expected {
        return Err(anyhow!(
            "Result length mismatch: expected {}, got {}",
            expected,
            got
        ));
    }
    Ok(())
}

/// Reinterprets a caller-provided raw buffer as a mutable slice of `len`
/// values of `T`.
///
/// Returns an empty slice when `len == 0` without looking at `ptr`, so a null
/// pointer is accepted for an empty output. Otherwise it rejects a null pointer,
/// a pointer not aligned for `T`, or a length whose byte size exceeds
/// `isize::MAX`. These are preconditions of `std::slice::from_raw_parts_mut`.
///
/// # Safety
/// When `len > 0`, `ptr` must point to `len` initialized, writable values of
/// `T`. No other reference may access that memory for the lifetime `'a`.
unsafe fn view_out_slice_mut<'a, T>(
    ptr: *mut std::ffi::c_void,
    len: usize,
) -> Result<&'a mut [T]> {
    if len == 0 {
        // `from_raw_parts_mut` requires a non-null pointer even for empty
        // slices, so return a fresh empty slice instead.
        return Ok(Default::default());
    }
    if ptr.is_null() {
        return Err(anyhow!("Output pointer is null but out_len={}", len));
    }
    let typed = ptr.cast::<T>();
    if (typed as usize) % std::mem::align_of::<T>() != 0 {
        return Err(anyhow!(
            "Output pointer {:p} is not aligned for {}",
            typed,
            std::any::type_name::<T>()
        ));
    }
    let size_ok = matches!(
        len.checked_mul(std::mem::size_of::<T>()),
        Some(bytes) if bytes <= isize::MAX as usize
    );
    if !size_ok {
        return Err(anyhow!("Output buffer of {} elements is too large", len));
    }
    // SAFETY: `typed` is non-null and aligned for `T`, and `len * size_of::<T>()`
    // fits in `isize` (all checked above). The caller contract of this function
    // guarantees that the memory holds `len` initialized, writable, unaliased
    // `T`s for `'a`.
    Ok(unsafe { std::slice::from_raw_parts_mut(typed, len) })
}
/// Reduces `data` to a single `f32` using the reduction `kind`.
///
/// The slice is copied into a rank-1 `Array` and passed to the dispatch
/// table's `reduction` entry with `None` as its second argument. The first
/// element of the backend's result is returned.
///
/// # Errors
/// Fails if `data` is empty, if the backend reports an error, or if the backend
/// returns an empty result.
pub fn reduce_f32(
    data: &[f32],
    kind: ReductionKind,
) -> Result<f32> {
    if data.is_empty() {
        return Err(anyhow!("Cannot reduce empty array"));
    }

    let input = Array::new(vec![data.len()], data.to_vec());
    let dispatch = get_dispatch_table();
    let reduced = (dispatch.reduction)(&input, None, kind)?;

    reduced
        .data
        .first()
        .copied()
        .ok_or_else(|| anyhow!("Reduction returned empty result"))
}
/// Computes the matrix product of `a` (`m x k`) and `b` (`k x n`) into `out`
/// (`m x n`).
///
/// The flat slices are wrapped in `Array`s with shapes `[m, k]` and `[k, n]` and
/// passed to the dispatch table's `matmul` entry. The module's unit tests
/// exercise a row-major layout.
///
/// # Errors
/// Fails if any element count `m * k`, `k * n`, or `m * n` overflows `usize`,
/// if a slice length does not match its declared shape, if the backend
/// reports an error, or if the backend result has the wrong size. `out` is
/// only written on success.
pub fn matmul_f32(
    a: &[f32],
    b: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
) -> Result<()> {
    /// Element count of a `rows x cols` matrix, or an error on overflow.
    fn element_count(rows: usize, cols: usize, what: &str) -> Result<usize> {
        // Without `checked_mul`, overflow panics in debug builds. In release
        // builds it wraps, and a wrapped product could falsely match a slice
        // length and pass validation.
        rows.checked_mul(cols).ok_or_else(|| {
            anyhow!("{} dimensions overflow usize: {} x {}", what, rows, cols)
        })
    }

    let a_len = element_count(m, k, "Matrix A")?;
    let b_len = element_count(k, n, "Matrix B")?;
    let out_len = element_count(m, n, "Output matrix")?;

    if a.len() != a_len {
        return Err(anyhow!("Matrix A size mismatch: expected {}, got {}", a_len, a.len()));
    }
    if b.len() != b_len {
        return Err(anyhow!("Matrix B size mismatch: expected {}, got {}", b_len, b.len()));
    }
    if out.len() != out_len {
        return Err(anyhow!("Output matrix size mismatch: expected {}, got {}", out_len, out.len()));
    }

    let a_arr = Array::new(vec![m, k], a.to_vec());
    let b_arr = Array::new(vec![k, n], b.to_vec());
    let table = get_dispatch_table();
    let result = (table.matmul)(&a_arr, &b_arr)?;
    // Guard the copy below: `copy_from_slice` panics on a length mismatch.
    if result.data.len() != out_len {
        return Err(anyhow!(
            "Result size mismatch: expected {}, got {}",
            out_len,
            result.data.len()
        ));
    }
    out.copy_from_slice(&result.data);
    Ok(())
}
/// Returns the dot product of two equal-length `f32` vectors.
///
/// Both slices are copied into rank-1 `Array`s and handed to the dispatch
/// table's `dot` entry.
///
/// # Errors
/// Fails if `a` and `b` differ in length, or if the backend reports an error.
pub fn dot_f32(
    a: &[f32],
    b: &[f32],
) -> Result<f32> {
    let (a_len, b_len) = (a.len(), b.len());
    if a_len != b_len {
        return Err(anyhow!("Vector length mismatch: a={}, b={}", a_len, b_len));
    }

    let to_array = |values: &[f32]| Array::new(vec![values.len()], values.to_vec());
    let lhs = to_array(a);
    let rhs = to_array(b);

    let dispatch = get_dispatch_table();
    Ok((dispatch.dot)(&lhs, &rhs)?)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the slice-level wrappers. Each happy path is followed by
    //! tests for the validation branches, which previously had no coverage.
    use super::*;

    #[test]
    fn test_elementwise_add() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![10.0, 20.0, 30.0, 40.0];
        let mut out = vec![0.0; 4];
        elementwise_f32(&a, &b, &mut out, ElementwiseKind::Add).unwrap();
        assert_eq!(out, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_elementwise_mul() {
        let a = vec![2.0, 3.0, 4.0, 5.0];
        let b = vec![10.0, 10.0, 10.0, 10.0];
        let mut out = vec![0.0; 4];
        elementwise_f32(&a, &b, &mut out, ElementwiseKind::Mul).unwrap();
        assert_eq!(out, vec![20.0, 30.0, 40.0, 50.0]);
    }

    #[test]
    fn test_elementwise_length_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let mut out = vec![0.0; 3];
        assert!(elementwise_f32(&a, &b, &mut out, ElementwiseKind::Add).is_err());
        // Validation fails before any write, so `out` must be untouched.
        assert_eq!(out, vec![0.0; 3]);

        let b = vec![1.0, 2.0, 3.0];
        let mut short_out = vec![0.0; 2];
        assert!(elementwise_f32(&a, &b, &mut short_out, ElementwiseKind::Add).is_err());
    }

    #[test]
    fn test_reduce_sum() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = reduce_f32(&data, ReductionKind::Sum).unwrap();
        assert_eq!(result, 15.0);
    }

    #[test]
    fn test_reduce_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = reduce_f32(&data, ReductionKind::Mean).unwrap();
        assert_eq!(result, 3.0);
    }

    #[test]
    fn test_reduce_empty_is_error() {
        let data: Vec<f32> = Vec::new();
        assert!(reduce_f32(&data, ReductionKind::Sum).is_err());
    }

    #[test]
    fn test_matmul() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut out = vec![0.0; 4];
        matmul_f32(&a, &b, &mut out, 2, 2, 2).unwrap();
        assert_eq!(out, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matmul_non_square() {
        // A square product cannot reveal swapped m/n or a transposed layout.
        // (1x3) . (3x2) = (1x2)
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut out = vec![0.0; 2];
        matmul_f32(&a, &b, &mut out, 1, 3, 2).unwrap();
        assert_eq!(out, vec![22.0, 28.0]);
    }

    #[test]
    fn test_matmul_shape_mismatch() {
        let b = vec![5.0, 6.0, 7.0, 8.0];

        // `a` is not 2x2.
        let bad_a = vec![1.0, 2.0, 3.0];
        let mut out = vec![0.0; 4];
        assert!(matmul_f32(&bad_a, &b, &mut out, 2, 2, 2).is_err());
        assert_eq!(out, vec![0.0; 4]);

        // `out` is not 2x2.
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let mut small_out = vec![0.0; 3];
        assert!(matmul_f32(&a, &b, &mut small_out, 2, 2, 2).is_err());
    }

    #[test]
    fn test_dot() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![10.0, 20.0, 30.0, 40.0];
        let result = dot_f32(&a, &b).unwrap();
        assert_eq!(result, 300.0);
    }

    #[test]
    fn test_dot_length_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        assert!(dot_f32(&a, &b).is_err());
    }

    #[test]
    fn test_large_arrays() {
        let size = 10000;
        let a = vec![1.0; size];
        let b = vec![2.0; size];
        let mut out = vec![0.0; size];
        elementwise_f32(&a, &b, &mut out, ElementwiseKind::Add).unwrap();
        assert!(out.iter().all(|&x| (x - 3.0).abs() < 1e-6));
    }
}