use crate::lower::{LoweredOp, OxiOp};
use oxiblas_core::simd::{SimdLevel, SimdRegister, detect_simd_level};
#[cfg(target_arch = "aarch64")]
use oxiblas_core::simd::aarch64::{F64x2, F64x4 as F64x4Aarch};
#[cfg(target_arch = "x86_64")]
use oxiblas_core::simd::x86_64::{F64x2Sse, F64x4};
/// Row count at or above which `eval_chunks_dispatch` hands the batch to the
/// parallel evaluator instead of the single-threaded SIMD loop.
#[cfg(feature = "parallel")]
const PARALLEL_SIMD_THRESHOLD: usize = 512;
/// Evaluates the op program `ops` against every row of `data`, choosing the
/// widest SIMD register type reported by `detect_simd_level` at runtime.
///
/// * `ops`  - the op sequence to evaluate.
/// * `data` - one `Vec<f64>` of variable values per row.
///
/// Returns the per-row results. An empty `data` yields an empty vector without
/// querying the CPU. When the detected level is `Scalar`, or the target is
/// neither `aarch64` nor `x86_64`, the work is delegated to
/// `LoweredOp::eval_batch_scalar_from_ops`.
pub fn eval_batch_simd(ops: &[OxiOp], data: &[Vec<f64>]) -> Vec<f64> {
    // Nothing to evaluate: skip feature detection entirely.
    if data.is_empty() {
        return Vec::new();
    }

    #[cfg(target_arch = "aarch64")]
    {
        let level = detect_simd_level();
        match level {
            SimdLevel::Scalar => LoweredOp::eval_batch_scalar_from_ops(ops, data),
            SimdLevel::Simd128 => eval_chunks_dispatch::<F64x2>(ops, data),
            // 256- and 512-bit levels both use the four-lane register type.
            SimdLevel::Simd256 | SimdLevel::Simd512 => {
                eval_chunks_dispatch::<F64x4Aarch>(ops, data)
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        let level = detect_simd_level();
        match level {
            SimdLevel::Scalar => LoweredOp::eval_batch_scalar_from_ops(ops, data),
            SimdLevel::Simd128 => eval_chunks_dispatch::<F64x2Sse>(ops, data),
            // 256- and 512-bit levels both use the four-lane register type.
            SimdLevel::Simd256 | SimdLevel::Simd512 => eval_chunks_dispatch::<F64x4>(ops, data),
        }
    }

    // No SIMD register types are imported for other architectures.
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        LoweredOp::eval_batch_scalar_from_ops(ops, data)
    }
}
/// Picks between the single-threaded and the parallel chunk evaluator for the
/// register type `V`.
///
/// With the `parallel` feature enabled, batches of at least
/// `PARALLEL_SIMD_THRESHOLD` rows go to `eval_chunks_parallel`; every other
/// case (including builds without the feature) goes to `eval_chunks`.
fn eval_chunks_dispatch<V>(ops: &[OxiOp], data: &[Vec<f64>]) -> Vec<f64>
where
    V: SimdRegister<Scalar = f64>,
{
    #[cfg(feature = "parallel")]
    {
        if data.len() < PARALLEL_SIMD_THRESHOLD {
            eval_chunks::<V>(ops, data)
        } else {
            eval_chunks_parallel::<V>(ops, data)
        }
    }

    #[cfg(not(feature = "parallel"))]
    {
        eval_chunks::<V>(ops, data)
    }
}
/// Evaluates `ops` over all rows of `data` on the current thread, `V::LANES`
/// rows at a time.
///
/// Rows are processed in groups of `V::LANES` through `eval_simd_chunk`; the
/// trailing rows that do not fill a whole register are evaluated one by one
/// with `LoweredOp::eval_ops`. Results are returned in row order.
fn eval_chunks<V>(ops: &[OxiOp], data: &[Vec<f64>]) -> Vec<f64>
where
    V: SimdRegister<Scalar = f64>,
{
    let width = V::LANES;
    let row_count = data.len();
    // Largest multiple of `width` that fits in the batch.
    let simd_rows = (row_count / width) * width;

    let mut out: Vec<f64> = Vec::with_capacity(row_count);

    for base in (0..simd_rows).step_by(width) {
        let packed = eval_simd_chunk::<V>(ops, data, base);
        // Unpack the register lane by lane so output order matches row order.
        out.extend((0..width).map(|lane| packed.extract(lane)));
    }

    // Scalar fallback for the leftover rows.
    out.extend(
        data[simd_rows..]
            .iter()
            .map(|row| LoweredOp::eval_ops(ops, row)),
    );

    out
}
/// Evaluates `ops` as a stack program for `V::LANES` consecutive rows of
/// `data`, starting at row `base`, and returns one result per lane.
///
/// * `Const` pushes the constant broadcast to every lane.
/// * `Var(i)` pushes column `i` of each row; a row that has no column `i`
///   contributes `NaN` to its lane.
/// * `Add`/`Sub`/`Mul`/`Div` use the register's own arithmetic; every other
///   function is applied lane by lane with the corresponding `f64` method.
///
/// A malformed program is tolerated rather than rejected: popping from an
/// empty stack yields an all-zero register, and an empty program returns an
/// all-zero register.
///
/// # Panics
///
/// Panics if a `Var` op is evaluated while `base + V::LANES` exceeds
/// `data.len()`.
fn eval_simd_chunk<V>(ops: &[OxiOp], data: &[Vec<f64>], base: usize) -> V
where
    V: SimdRegister<Scalar = f64>,
{
    let mut stack: Vec<V> = Vec::with_capacity(ops.len());
    for op in ops {
        let value = match op {
            OxiOp::Const(c) => V::splat(*c),
            OxiOp::Var(i) => load_var_lanes::<V>(data, base, *i),
            OxiOp::Add => {
                let (a, b) = pop_operands(&mut stack);
                a.add(b)
            }
            OxiOp::Sub => {
                let (a, b) = pop_operands(&mut stack);
                a.sub(b)
            }
            OxiOp::Mul => {
                let (a, b) = pop_operands(&mut stack);
                a.mul(b)
            }
            OxiOp::Div => {
                let (a, b) = pop_operands(&mut stack);
                a.div(b)
            }
            // Negate by multiplying with -1.0 rather than computing `0 - a`:
            // `0.0 - 0.0` is `+0.0`, which would lose the sign of a negated
            // zero (and flip e.g. `1 / -x` from `-inf` to `+inf` at x = 0).
            // Multiplication by -1.0 is exact and flips the sign of zero too.
            OxiOp::Neg => pop_operand(&mut stack).mul(V::splat(-1.0)),
            OxiOp::Exp => map_lanes(pop_operand(&mut stack), f64::exp),
            OxiOp::Ln => map_lanes(pop_operand(&mut stack), f64::ln),
            OxiOp::Sin => map_lanes(pop_operand(&mut stack), f64::sin),
            OxiOp::Cos => map_lanes(pop_operand(&mut stack), f64::cos),
            OxiOp::Pow => {
                let (a, b) = pop_operands(&mut stack);
                zip_lanes(a, b, f64::powf)
            }
            OxiOp::Tan => map_lanes(pop_operand(&mut stack), f64::tan),
            OxiOp::Sinh => map_lanes(pop_operand(&mut stack), f64::sinh),
            OxiOp::Cosh => map_lanes(pop_operand(&mut stack), f64::cosh),
            OxiOp::Tanh => map_lanes(pop_operand(&mut stack), f64::tanh),
            OxiOp::Arcsin => map_lanes(pop_operand(&mut stack), f64::asin),
            OxiOp::Arccos => map_lanes(pop_operand(&mut stack), f64::acos),
            OxiOp::Arctan => map_lanes(pop_operand(&mut stack), f64::atan),
            OxiOp::Arcsinh => map_lanes(pop_operand(&mut stack), f64::asinh),
            OxiOp::Arccosh => map_lanes(pop_operand(&mut stack), f64::acosh),
            OxiOp::Arctanh => map_lanes(pop_operand(&mut stack), f64::atanh),
        };
        stack.push(value);
    }
    pop_operand(&mut stack)
}

/// Pops the top register, substituting an all-zero register on underflow.
fn pop_operand<V>(stack: &mut Vec<V>) -> V
where
    V: SimdRegister<Scalar = f64>,
{
    stack.pop().unwrap_or_else(V::zero)
}

/// Pops the two operands of a binary op and returns them as `(lhs, rhs)`;
/// the right-hand side is on top of the stack and is popped first.
fn pop_operands<V>(stack: &mut Vec<V>) -> (V, V)
where
    V: SimdRegister<Scalar = f64>,
{
    let rhs = pop_operand(stack);
    let lhs = pop_operand(stack);
    (lhs, rhs)
}

/// Gathers column `var` from rows `base..base + V::LANES` into one register,
/// using `NaN` for rows that are too short to contain that column.
fn load_var_lanes<V>(data: &[Vec<f64>], base: usize, var: usize) -> V
where
    V: SimdRegister<Scalar = f64>,
{
    let mut reg = V::zero();
    for (lane, row) in data[base..base + V::LANES].iter().enumerate() {
        let val = row.get(var).copied().unwrap_or(f64::NAN);
        reg = reg.insert(lane, val);
    }
    reg
}

/// Applies the scalar function `f` to every lane of `input`.
fn map_lanes<V>(input: V, f: impl Fn(f64) -> f64) -> V
where
    V: SimdRegister<Scalar = f64>,
{
    let mut reg = V::zero();
    for lane in 0..V::LANES {
        reg = reg.insert(lane, f(input.extract(lane)));
    }
    reg
}

/// Applies the scalar function `f` to matching lanes of `lhs` and `rhs`.
fn zip_lanes<V>(lhs: V, rhs: V, f: impl Fn(f64, f64) -> f64) -> V
where
    V: SimdRegister<Scalar = f64>,
{
    let mut reg = V::zero();
    for lane in 0..V::LANES {
        reg = reg.insert(lane, f(lhs.extract(lane), rhs.extract(lane)));
    }
    reg
}
/// Evaluates `ops` over all rows of `data` by splitting the batch into
/// contiguous chunks and running `eval_chunks` on each chunk through rayon.
///
/// The chunk size is the per-thread share of the rows, rounded up to a
/// multiple of `V::LANES`. Keeping chunk boundaries lane-aligned means only
/// the final chunk can have a partial register, so the set of rows handled by
/// the scalar tail loop is the same as in the single-threaded path and does
/// not vary with the number of worker threads.
///
/// Results are returned in row order.
#[cfg(feature = "parallel")]
fn eval_chunks_parallel<V>(ops: &[OxiOp], data: &[Vec<f64>]) -> Vec<f64>
where
    V: SimdRegister<Scalar = f64>,
{
    use rayon::prelude::*;
    let num_threads = rayon::current_num_threads().max(1);
    // `.max(1)` keeps the chunk size non-zero for an empty batch, since
    // `par_chunks` rejects a chunk size of zero.
    let chunk_size = data
        .len()
        .div_ceil(num_threads)
        .max(1)
        .next_multiple_of(V::LANES);
    data.par_chunks(chunk_size)
        .flat_map_iter(|chunk| eval_chunks::<V>(ops, chunk))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Canonical;

    /// Shared fixture: the lowered form of `Canonical::exp` applied to
    /// variable 0.
    fn exp_lowered() -> crate::lower::LoweredOp {
        let var0 = crate::tree::EmlTree::var(0);
        Canonical::exp(&var0).lower()
    }

    /// Builds `count` single-column rows where row `k` holds `k * step`.
    fn sample_rows(count: i32, step: f64) -> Vec<Vec<f64>> {
        (0..count).map(|k| vec![f64::from(k) * step]).collect()
    }

    /// SIMD and scalar evaluation agree on a batch of 256 rows.
    #[test]
    fn test_eval_batch_simd_matches_scalar() {
        let program = exp_lowered();
        let ops = program.to_oxiblas_ops();
        let rows = sample_rows(256, 0.01);

        let vectorised = eval_batch_simd(&ops, &rows);
        let reference = LoweredOp::eval_batch_scalar_from_ops(&ops, &rows);

        assert_eq!(vectorised.len(), reference.len());
        for (i, (s, r)) in vectorised.iter().zip(&reference).enumerate() {
            assert!(
                (s - r).abs() < 1e-12,
                "row {i}: SIMD={s} scalar={r} diff={}",
                (s - r).abs()
            );
        }
    }

    /// Lane-wise transcendental ops (`sin(x) + cos(x)`) agree with the scalar
    /// evaluator.
    #[test]
    fn test_eval_batch_simd_transcendentals() {
        use crate::lower::LoweredOp as L;

        let sin_x = L::Sin(Box::new(L::Var(0)));
        let cos_x = L::Cos(Box::new(L::Var(0)));
        let program = L::Add(Box::new(sin_x), Box::new(cos_x));
        let ops = program.to_oxiblas_ops();
        let rows = sample_rows(128, 0.05);

        let vectorised = eval_batch_simd(&ops, &rows);
        let reference = LoweredOp::eval_batch_scalar_from_ops(&ops, &rows);

        assert_eq!(vectorised.len(), 128);
        for (i, (s, r)) in vectorised.iter().zip(&reference).enumerate() {
            assert!(
                (s - r).abs() < 1e-12,
                "sin+cos row {i}: SIMD={s} scalar={r}"
            );
        }
    }

    /// A row count that is not a multiple of the lane width exercises the
    /// scalar tail loop.
    #[test]
    fn test_eval_batch_simd_remainder() {
        let program = exp_lowered();
        let ops = program.to_oxiblas_ops();
        let rows = sample_rows(7, 0.3);

        let vectorised = eval_batch_simd(&ops, &rows);
        let reference = LoweredOp::eval_batch_scalar_from_ops(&ops, &rows);

        assert_eq!(vectorised.len(), 7);
        for (i, (s, r)) in vectorised.iter().zip(&reference).enumerate() {
            assert!((s - r).abs() < 1e-12, "remainder row {i}: {s} vs {r}");
        }
    }

    /// An empty batch produces an empty result.
    #[test]
    fn test_eval_batch_simd_empty() {
        let program = exp_lowered();
        let ops = program.to_oxiblas_ops();

        let out = eval_batch_simd(&ops, &[]);

        assert!(out.is_empty());
    }

    /// 1000 rows is above `PARALLEL_SIMD_THRESHOLD`, so this goes through the
    /// parallel evaluator whenever a SIMD level is detected.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_eval_batch_simd_parallel_matches_scalar() {
        let program = exp_lowered();
        let ops = program.to_oxiblas_ops();
        let rows = sample_rows(1000, 0.001);

        let vectorised = eval_batch_simd(&ops, &rows);
        let reference = LoweredOp::eval_batch_scalar_from_ops(&ops, &rows);

        assert_eq!(vectorised.len(), 1000);
        for (i, (s, r)) in vectorised.iter().zip(&reference).enumerate() {
            assert!((s - r).abs() < 1e-12, "parallel row {i}: {s} vs {r}");
        }
    }
}