use scirs2_core::ndarray::Axis;
use std::collections::HashMap;
use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlExecutor};
use crate::autodiff::ForwardTape;
use crate::memory_pool::TensorPool;
use crate::Scirs2Tensor;
/// Executor state: a table of named tensors plus two optional components,
/// a forward tape and a tensor memory pool.
pub struct Scirs2Exec {
    /// Tensors registered with this executor, keyed by name.
    pub tensors: HashMap<String, Scirs2Tensor>,
    /// Forward tape, if one is attached; `None` otherwise.
    pub(crate) tape: Option<ForwardTape>,
    /// Tensor pool, present only while pooling is enabled.
    pub(crate) pool: Option<TensorPool>,
}
impl Default for Scirs2Exec {
fn default() -> Self {
Self::new()
}
}
impl Scirs2Exec {
    /// Creates an executor with an empty tensor table, no tape and no pool.
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            tape: None,
            pool: None,
        }
    }

    /// Creates an executor with an empty tensor table, no tape, and a freshly
    /// constructed `TensorPool` already installed.
    pub fn with_memory_pool() -> Self {
        let fresh_pool = TensorPool::new();
        Self {
            tensors: HashMap::new(),
            tape: None,
            pool: Some(fresh_pool),
        }
    }

    /// Installs a new `TensorPool` unless one is already present.
    ///
    /// An existing pool is left untouched.
    pub fn enable_pooling(&mut self) {
        if self.pool.is_some() {
            return;
        }
        self.pool = Some(TensorPool::new());
    }

    /// Removes the pool, dropping it if one was installed.
    pub fn disable_pooling(&mut self) {
        drop(self.pool.take());
    }

    /// Returns the pool's statistics, or `None` when no pool is installed.
    pub fn pool_stats(&self) -> Option<crate::memory_pool::PoolStats> {
        let pool = self.pool.as_ref()?;
        Some(pool.stats())
    }

    /// Stores `tensor` under `name`, replacing any tensor already registered
    /// under that name.
    pub fn add_tensor(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
        let key: String = name.into();
        self.tensors.insert(key, tensor);
    }

    /// Looks up the tensor registered under `name`, if any.
    pub fn get_tensor(&self, name: &str) -> Option<&Scirs2Tensor> {
        self.tensors.get(name)
    }
}
impl TlExecutor for Scirs2Exec {
type Tensor = Scirs2Tensor;
type Error = ExecutorError;
fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
if inputs.is_empty() {
return Err(ExecutorError::InvalidEinsumSpec(
"No input tensors provided".to_string(),
));
}
let views: Vec<_> = inputs.iter().map(|t| t.view()).collect();
let view_refs: Vec<_> = views.iter().collect();
scirs2_linalg::einsum(spec, &view_refs)
.map_err(|e| ExecutorError::InvalidEinsumSpec(format!("Einsum error: {}", e)))
}
fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
let result = match op {
ElemOp::Relu => x.mapv(|v| v.max(0.0)),
ElemOp::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
ElemOp::OneMinus => x.mapv(|v| 1.0 - v),
_ => {
return Err(ExecutorError::UnsupportedOperation(format!(
"Unary operation {:?} not supported",
op
)))
}
};
Ok(result)
}
fn elem_op_binary(
&mut self,
op: ElemOp,
x: &Self::Tensor,
y: &Self::Tensor,
) -> Result<Self::Tensor, Self::Error> {
let x_is_scalar = x.ndim() == 0;
let y_is_scalar = y.ndim() == 0;
let (x_broadcast, y_broadcast);
let (x_ref, y_ref) = if x_is_scalar && !y_is_scalar {
let scalar_value = x
.iter()
.next()
.expect("scalar tensor has at least one element");
x_broadcast = scirs2_core::ndarray::Array::from_elem(y.raw_dim(), *scalar_value);
(&x_broadcast.view(), &y.view())
} else if y_is_scalar && !x_is_scalar {
let scalar_value = y
.iter()
.next()
.expect("scalar tensor has at least one element");
y_broadcast = scirs2_core::ndarray::Array::from_elem(x.raw_dim(), *scalar_value);
(&x.view(), &y_broadcast.view())
} else if x.shape() != y.shape() {
return Err(ExecutorError::ShapeMismatch(format!(
"Shape mismatch: {:?} vs {:?}",
x.shape(),
y.shape()
)));
} else {
(&x.view(), &y.view())
};
let result = match op {
ElemOp::Add => x_ref + y_ref,
ElemOp::Subtract => x_ref - y_ref,
ElemOp::Multiply => x_ref * y_ref,
ElemOp::Divide => x_ref / y_ref,
ElemOp::Min => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a.min(b)),
ElemOp::Max => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a.max(b)),
ElemOp::Eq => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if (a - b).abs() < 1e-10 { 1.0 } else { 0.0 }),
ElemOp::Lt => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a < b { 1.0 } else { 0.0 }),
ElemOp::Gt => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a > b { 1.0 } else { 0.0 }),
ElemOp::Lte => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a <= b { 1.0 } else { 0.0 }),
ElemOp::Gte => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| if a >= b { 1.0 } else { 0.0 }),
ElemOp::OrMax => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a.max(b)),
ElemOp::OrProbSum => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a + b - a * b), ElemOp::Nand => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| 1.0 - a * b),
ElemOp::Nor => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| 1.0 - a.max(b)),
ElemOp::Xor => scirs2_core::ndarray::Zip::from(x_ref)
.and(y_ref)
.map_collect(|&a, &b| a + b - 2.0 * a * b),
_ => {
return Err(ExecutorError::UnsupportedOperation(format!(
"Binary operation {:?} not supported",
op
)))
}
};
Ok(result)
}
fn reduce(
&mut self,
op: ReduceOp,
x: &Self::Tensor,
axes: &[usize],
) -> Result<Self::Tensor, Self::Error> {
if axes.is_empty() {
return Ok(x.clone());
}
for &axis in axes {
if axis >= x.ndim() {
return Err(ExecutorError::ShapeMismatch(format!(
"Axis {} out of bounds for tensor with {} dimensions",
axis,
x.ndim()
)));
}
}
let mut result = x.clone();
for &axis in axes.iter().rev() {
result = match op {
ReduceOp::Sum => result.sum_axis(Axis(axis)),
ReduceOp::Max => result.fold_axis(Axis(axis), f64::NEG_INFINITY, |&a, &b| a.max(b)),
ReduceOp::Min => result.fold_axis(Axis(axis), f64::INFINITY, |&a, &b| a.min(b)),
ReduceOp::Mean => result
.mean_axis(Axis(axis))
.expect("axis is valid as validated earlier"),
ReduceOp::Product => result.fold_axis(Axis(axis), 1.0, |&a, &b| a * b),
};
}
Ok(result)
}
}