use scirs2_core::ndarray::{ArrayD, Axis};
use std::collections::HashMap;
use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlExecutor};
pub type Scirs2Tensor32 = ArrayD<f32>;
pub struct Scirs2Exec32 {
pub tensors: HashMap<String, Scirs2Tensor32>,
}
impl Default for Scirs2Exec32 {
fn default() -> Self {
Self::new()
}
}
impl Scirs2Exec32 {
pub fn new() -> Self {
Scirs2Exec32 {
tensors: HashMap::new(),
}
}
pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor32) {
self.tensors.insert(name.into(), tensor);
}
pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor32> {
self.tensors.get(name)
}
}
impl TlExecutor for Scirs2Exec32 {
type Tensor = Scirs2Tensor32;
type Error = ExecutorError;
fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
if inputs.is_empty() {
return Err(ExecutorError::InvalidEinsumSpec(
"No input tensors provided".to_string(),
));
}
let views: Vec<_> = inputs.iter().map(|t| t.view()).collect();
let view_refs: Vec<_> = views.iter().collect();
scirs2_linalg::einsum(spec, &view_refs)
.map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
}
fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
let result = match op {
ElemOp::Relu => x.mapv(|v| v.max(0.0_f32)),
ElemOp::Sigmoid => x.mapv(|v| 1.0_f32 / (1.0_f32 + (-v).exp())),
ElemOp::OneMinus => x.mapv(|v| 1.0_f32 - v),
_ => {
return Err(ExecutorError::UnsupportedOperation(format!(
"Unary operation {:?} not supported",
op
)))
}
};
Ok(result)
}
fn elem_op_binary(
&mut self,
op: ElemOp,
x: &Self::Tensor,
y: &Self::Tensor,
) -> Result<Self::Tensor, Self::Error> {
let x_is_scalar = x.ndim() == 0;
let y_is_scalar = y.ndim() == 0;
let (x_broadcast, y_broadcast);
let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
let scalar_value = x
.iter()
.next()
.expect("scalar tensor has at least one element");
x_broadcast = scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
(&x_broadcast.view(), &y.view())
} else if y_is_scalar && !x_is_scalar {
let scalar_value = y
.iter()
.next()
.expect("scalar tensor has at least one element");
y_broadcast = scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
(&x.view(), &y_broadcast.view())
} else if x.shape() != y.shape() {
return Err(ExecutorError::ShapeMismatch(format!(
"Shape mismatch: {:?} vs {:?}",
x.shape(),
y.shape()
)));
} else {
(&x.view(), &y.view())
};
let result = match op {
ElemOp::Add => x_ref + y_ref,
ElemOp::Subtract => x_ref - y_ref,
ElemOp::Multiply => x_ref * y_ref,
ElemOp::Divide => x_ref / y_ref,
ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a.min(b)),
ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a.max(b)),
ElemOp::Eq => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if (a - b).abs() < 1e-7_f32 { 1.0 } else { 0.0 }),
ElemOp::Lt => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a < b { 1.0 } else { 0.0 }),
ElemOp::Gt => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a > b { 1.0 } else { 0.0 }),
ElemOp::Lte => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a <= b { 1.0 } else { 0.0 }),
ElemOp::Gte => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a >= b { 1.0 } else { 0.0 }),
ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a.max(b)),
ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a + b - a * b),
ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| 1.0_f32 - a * b),
ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| 1.0_f32 - a.max(b)),
ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a + b - 2.0_f32 * a * b),
_ => {
return Err(ExecutorError::UnsupportedOperation(format!(
"Binary operation {:?} not supported",
op
)))
}
};
Ok(result)
}
fn reduce(
&mut self,
op: ReduceOp,
x: &Self::Tensor,
axes: &[usize],
) -> Result<Self::Tensor, Self::Error> {
if axes.is_empty() {
return Ok(x.clone());
}
for &axis in axes {
if axis >= x.ndim() {
return Err(ExecutorError::ShapeMismatch(format!(
"Axis {} out of bounds for tensor with {} dimensions",
axis,
x.ndim()
)));
}
}
let mut result = x.clone();
for &axis in axes.iter().rev() {
result = match op {
ReduceOp::Sum => result.sum_axis(Axis(axis)),
ReduceOp::Max => result.fold_axis(Axis(axis), f32::NEG_INFINITY, |&a, &b| a.max(b)),
ReduceOp::Min => result.fold_axis(Axis(axis), f32::INFINITY, |&a, &b| a.min(b)),
ReduceOp::Mean => result
.mean_axis(Axis(axis))
.expect("axis is valid as validated earlier"),
ReduceOp::Product => result.fold_axis(Axis(axis), 1.0_f32, |&a, &b| a * b),
};
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::ArrayD;
fn make_tensor(shape: &[usize], data: Vec<f32>) -> ArrayD<f32> {
ArrayD::from_shape_vec(shape, data).expect("valid shape/data for test tensor")
}
#[test]
fn test_exec32_einsum_matmul() {
let a = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = make_tensor(&[3, 2], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.einsum("ij,jk->ik", &[a, b]).expect("einsum matmul");
assert_eq!(result.shape(), &[2, 2]);
let data: Vec<f32> = result.iter().copied().collect();
assert!(
(data[0] - 58.0).abs() < 1e-4,
"expected 58, got {}",
data[0]
);
assert!(
(data[1] - 64.0).abs() < 1e-4,
"expected 64, got {}",
data[1]
);
assert!(
(data[2] - 139.0).abs() < 1e-4,
"expected 139, got {}",
data[2]
);
assert!(
(data[3] - 154.0).abs() < 1e-4,
"expected 154, got {}",
data[3]
);
}
#[test]
fn test_exec32_relu() {
let x = make_tensor(&[4], vec![-1.0, 0.0, 1.0, 2.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op(ElemOp::Relu, &x).expect("relu");
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data[0], 0.0);
assert_eq!(data[1], 0.0);
assert_eq!(data[2], 1.0);
assert_eq!(data[3], 2.0);
}
#[test]
fn test_exec32_sigmoid() {
let x = make_tensor(&[4], vec![-2.0, 0.0, 1.0, 10.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op(ElemOp::Sigmoid, &x).expect("sigmoid");
for &v in result.iter() {
assert!(v > 0.0 && v < 1.0, "sigmoid output {} not in (0,1)", v);
}
let data: Vec<f32> = result.iter().copied().collect();
assert!((data[1] - 0.5).abs() < 1e-5, "sigmoid(0) should be 0.5");
}
#[test]
fn test_exec32_one_minus() {
let x = make_tensor(&[3], vec![0.0, 0.3, 1.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op(ElemOp::OneMinus, &x).expect("one_minus");
let data: Vec<f32> = result.iter().copied().collect();
assert!((data[0] - 1.0).abs() < 1e-6);
assert!((data[1] - 0.7).abs() < 1e-5);
assert!((data[2] - 0.0).abs() < 1e-6);
}
#[test]
fn test_exec32_add() {
let x = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
let y = make_tensor(&[3], vec![4.0, 5.0, 6.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op_binary(ElemOp::Add, &x, &y).expect("add");
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data, vec![5.0, 7.0, 9.0]);
}
#[test]
fn test_exec32_sub() {
let x = make_tensor(&[3], vec![10.0, 5.0, 3.0]);
let y = make_tensor(&[3], vec![1.0, 2.0, 3.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op_binary(ElemOp::Subtract, &x, &y).expect("sub");
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data, vec![9.0, 3.0, 0.0]);
}
#[test]
fn test_exec32_mul() {
let x = make_tensor(&[3], vec![2.0, 3.0, 4.0]);
let y = make_tensor(&[3], vec![5.0, 6.0, 7.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op_binary(ElemOp::Multiply, &x, &y).expect("mul");
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data, vec![10.0, 18.0, 28.0]);
}
#[test]
fn test_exec32_div() {
let x = make_tensor(&[3], vec![6.0, 9.0, 12.0]);
let y = make_tensor(&[3], vec![2.0, 3.0, 4.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.elem_op_binary(ElemOp::Divide, &x, &y).expect("div");
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data, vec![3.0, 3.0, 3.0]);
}
#[test]
fn test_exec32_reduce_sum() {
let x = make_tensor(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.reduce(ReduceOp::Sum, &x, &[0]).expect("reduce_sum");
assert_eq!(result.shape(), &[3]);
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data, vec![5.0, 7.0, 9.0]);
}
#[test]
fn test_exec32_reduce_max() {
let x = make_tensor(&[2, 3], vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.reduce(ReduceOp::Max, &x, &[0]).expect("reduce_max");
assert_eq!(result.shape(), &[3]);
let data: Vec<f32> = result.iter().copied().collect();
assert_eq!(data, vec![4.0, 5.0, 6.0]);
}
#[test]
fn test_exec32_reduce_mean() {
let x = make_tensor(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
let mut exec = Scirs2Exec32::new();
let result = exec.reduce(ReduceOp::Mean, &x, &[0]).expect("reduce_mean");
assert_eq!(result.shape(), &[4]);
let data: Vec<f32> = result.iter().copied().collect();
for (got, expected) in data.iter().zip([3.0_f32, 4.0, 5.0, 6.0].iter()) {
assert!(
(got - expected).abs() < 1e-5,
"mean mismatch: {} vs {}",
got,
expected
);
}
}
#[test]
fn test_exec32_zeros() {
let zeros: ArrayD<f32> = ArrayD::zeros(vec![2, 3]);
assert_eq!(zeros.shape(), &[2, 3]);
assert!(zeros.iter().all(|&v| v == 0.0_f32));
}
#[test]
fn test_exec32_ones() {
let ones: ArrayD<f32> = ArrayD::ones(vec![2, 3]);
assert_eq!(ones.shape(), &[2, 3]);
assert!(ones.iter().all(|&v| v == 1.0_f32));
}
#[test]
fn test_exec32_from_data() {
let data = vec![1.5_f32, 2.5, 3.5, 4.5];
let tensor = ArrayD::from_shape_vec(vec![2, 2], data.clone())
.expect("valid shape for from_data test");
assert_eq!(tensor.shape(), &[2, 2]);
let roundtrip: Vec<f32> = tensor.iter().copied().collect();
assert_eq!(roundtrip, data);
}
#[test]
fn test_exec32_memory_half_of_f64() {
let f32_tensor: ArrayD<f32> = ArrayD::zeros(vec![4, 4]);
let f64_tensor: ArrayD<f64> = ArrayD::zeros(vec![4, 4]);
let f32_bytes = f32_tensor.len() * std::mem::size_of::<f32>();
let f64_bytes = f64_tensor.len() * std::mem::size_of::<f64>();
assert_eq!(f32_bytes * 2, f64_bytes);
}
}