use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
use scirs2_core::ndarray::Zip;
use std::collections::HashSet;
/// Executes supported tensor operations in place (mutating the operand's
/// buffer instead of allocating a result), while tracking which tensors are
/// aliased — and therefore must NOT be mutated — plus usage statistics.
#[derive(Debug, Clone)]
pub struct InplaceExecutor {
    // IDs of tensors known to be aliased elsewhere; `can_execute_inplace`
    // reports `false` for these so callers fall back to out-of-place paths.
    aliased_tensors: HashSet<usize>,
    /// Counters describing how often in-place execution succeeded and an
    /// estimate of the allocations it avoided.
    pub stats: InplaceStats,
}
/// Running counters for in-place execution, reset via
/// `InplaceExecutor::reset_stats`.
#[derive(Debug, Clone, Default)]
pub struct InplaceStats {
    /// Number of operations that were executed in place.
    pub inplace_ops: usize,
    /// Number of operations that could not be executed in place
    /// (unsupported op name or shape mismatch).
    pub non_inplace_ops: usize,
    /// Estimated bytes of allocation avoided: element count × `size_of::<f64>()`
    /// accumulated per successful in-place operation.
    pub memory_saved_bytes: usize,
}
impl InplaceExecutor {
pub fn new() -> Self {
Self {
aliased_tensors: HashSet::new(),
stats: InplaceStats::default(),
}
}
pub fn mark_aliased(&mut self, tensor_id: usize) {
self.aliased_tensors.insert(tensor_id);
}
pub fn can_execute_inplace(&self, tensor_id: usize) -> bool {
!self.aliased_tensors.contains(&tensor_id)
}
pub fn execute_inplace_unary(
&mut self,
op: &str,
tensor: &mut Scirs2Tensor,
) -> TlBackendResult<()> {
let element_count = tensor.len();
let bytes_saved = element_count * std::mem::size_of::<f64>();
match op {
"relu" => {
Zip::from(tensor).for_each(|x| {
*x = x.max(0.0);
});
}
"sigmoid" => {
Zip::from(tensor).for_each(|x| {
*x = 1.0 / (1.0 + (-*x).exp());
});
}
"oneminus" => {
Zip::from(tensor).for_each(|x| {
*x = 1.0 - *x;
});
}
"tanh" => {
Zip::from(tensor).for_each(|x| {
*x = x.tanh();
});
}
"abs" => {
Zip::from(tensor).for_each(|x| {
*x = x.abs();
});
}
"neg" => {
Zip::from(tensor).for_each(|x| {
*x = -*x;
});
}
"exp" => {
Zip::from(tensor).for_each(|x| {
*x = x.exp();
});
}
"log" => {
Zip::from(tensor).for_each(|x| {
*x = x.ln();
});
}
"sqrt" => {
Zip::from(tensor).for_each(|x| {
*x = x.sqrt();
});
}
"square" => {
Zip::from(tensor).for_each(|x| {
*x = *x * *x;
});
}
"clip" => {
Zip::from(tensor).for_each(|x| {
*x = x.clamp(0.0, 1.0);
});
}
_ => {
self.stats.non_inplace_ops += 1;
return Err(TlBackendError::unsupported(format!(
"Unsupported in-place unary operation: {}",
op
)));
}
}
self.stats.inplace_ops += 1;
self.stats.memory_saved_bytes += bytes_saved;
Ok(())
}
pub fn execute_inplace_binary(
&mut self,
op: &str,
lhs: &mut Scirs2Tensor,
rhs: &Scirs2Tensor,
) -> TlBackendResult<()> {
if lhs.shape() != rhs.shape() {
self.stats.non_inplace_ops += 1;
return Err(TlBackendError::shape_mismatch(
op,
vec![lhs.shape().to_vec()],
vec![rhs.shape().to_vec()],
));
}
let element_count = lhs.len();
let bytes_saved = element_count * std::mem::size_of::<f64>();
match op {
"add" => {
Zip::from(lhs).and(rhs).for_each(|x, &y| {
*x += y;
});
}
"subtract" | "sub" => {
Zip::from(lhs).and(rhs).for_each(|x, &y| {
*x -= y;
});
}
"multiply" | "mul" => {
Zip::from(lhs).and(rhs).for_each(|x, &y| {
*x *= y;
});
}
"divide" | "div" => {
Zip::from(lhs).and(rhs).for_each(|x, &y| {
*x /= y;
});
}
"min" => {
Zip::from(lhs).and(rhs).for_each(|x, &y| {
*x = x.min(y);
});
}
"max" => {
Zip::from(lhs).and(rhs).for_each(|x, &y| {
*x = x.max(y);
});
}
_ => {
self.stats.non_inplace_ops += 1;
return Err(TlBackendError::unsupported(format!(
"Unsupported in-place binary operation: {}",
op
)));
}
}
self.stats.inplace_ops += 1;
self.stats.memory_saved_bytes += bytes_saved;
Ok(())
}
pub fn execute_inplace_scalar(
&mut self,
op: &str,
tensor: &mut Scirs2Tensor,
scalar: f64,
) -> TlBackendResult<()> {
let element_count = tensor.len();
let bytes_saved = element_count * std::mem::size_of::<f64>();
match op {
"add" | "add_scalar" => {
Zip::from(tensor).for_each(|x| {
*x += scalar;
});
}
"sub" | "sub_scalar" => {
Zip::from(tensor).for_each(|x| {
*x -= scalar;
});
}
"mul" | "mul_scalar" => {
Zip::from(tensor).for_each(|x| {
*x *= scalar;
});
}
"div" | "div_scalar" => {
Zip::from(tensor).for_each(|x| {
*x /= scalar;
});
}
"pow" => {
Zip::from(tensor).for_each(|x| {
*x = x.powf(scalar);
});
}
"clamp_min" => {
Zip::from(tensor).for_each(|x| {
*x = x.max(scalar);
});
}
"clamp_max" => {
Zip::from(tensor).for_each(|x| {
*x = x.min(scalar);
});
}
_ => {
self.stats.non_inplace_ops += 1;
return Err(TlBackendError::unsupported(format!(
"Unsupported in-place scalar operation: {}",
op
)));
}
}
self.stats.inplace_ops += 1;
self.stats.memory_saved_bytes += bytes_saved;
Ok(())
}
pub fn statistics(&self) -> &InplaceStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = InplaceStats::default();
}
pub fn clear_aliasing(&mut self) {
self.aliased_tensors.clear();
}
}
impl Default for InplaceExecutor {
fn default() -> Self {
Self::new()
}
}
impl InplaceStats {
    /// Share of operations that ran in place, as a percentage in
    /// `[0.0, 100.0]`. Returns `0.0` when no operations were recorded.
    pub fn inplace_percentage(&self) -> f64 {
        match self.inplace_ops + self.non_inplace_ops {
            0 => 0.0,
            total => (self.inplace_ops as f64 / total as f64) * 100.0,
        }
    }

    /// Renders `memory_saved_bytes` with a human-readable binary unit:
    /// raw bytes below 1 KiB, otherwise KB/MB/GB with two decimals.
    pub fn format_memory_saved(&self) -> String {
        const KB: usize = 1024;
        const MB: usize = 1024 * KB;
        const GB: usize = 1024 * MB;
        let bytes = self.memory_saved_bytes;
        if bytes < KB {
            format!("{} bytes", bytes)
        } else if bytes < MB {
            format!("{:.2} KB", bytes as f64 / KB as f64)
        } else if bytes < GB {
            format!("{:.2} MB", bytes as f64 / MB as f64)
        } else {
            format!("{:.2} GB", bytes as f64 / GB as f64)
        }
    }
}
/// Returns `true` when `op` is one of the operation names the
/// `InplaceExecutor` can apply in place (unary, binary, or scalar).
///
/// Fix: previously this predicate omitted the scalar-only operations
/// ("pow", "clamp_min", "clamp_max" and the `*_scalar` aliases) even though
/// `InplaceExecutor::execute_inplace_scalar` supports them, so callers
/// consulting it would needlessly fall back to out-of-place execution.
pub fn can_execute_inplace(op: &str) -> bool {
    matches!(
        op,
        // unary ops (execute_inplace_unary)
        "relu"
            | "sigmoid"
            | "oneminus"
            | "tanh"
            | "abs"
            | "neg"
            | "exp"
            | "log"
            | "sqrt"
            | "square"
            | "clip"
            // binary ops (execute_inplace_binary); the short names double
            // as scalar ops in execute_inplace_scalar
            | "add"
            | "subtract"
            | "sub"
            | "multiply"
            | "mul"
            | "divide"
            | "div"
            | "min"
            | "max"
            // scalar-only ops (execute_inplace_scalar)
            | "add_scalar"
            | "sub_scalar"
            | "mul_scalar"
            | "div_scalar"
            | "pow"
            | "clamp_min"
            | "clamp_max"
    )
}
/// Returns `true` unless `op` is a reduction, i.e. an operation whose
/// output shape differs from its input shape.
pub fn is_shape_preserving(op: &str) -> bool {
    match op {
        "sum" | "mean" | "max_reduce" | "min_reduce" | "product" => false,
        _ => true,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    // A fresh executor starts with all counters zeroed.
    #[test]
    fn test_inplace_executor_new() {
        let executor = InplaceExecutor::new();
        assert_eq!(executor.stats.inplace_ops, 0);
        assert_eq!(executor.stats.non_inplace_ops, 0);
        assert_eq!(executor.stats.memory_saved_bytes, 0);
    }

    // Marking a tensor id as aliased blocks in-place execution for that id
    // only; other ids remain eligible.
    #[test]
    fn test_can_execute_inplace() {
        let mut executor = InplaceExecutor::new();
        assert!(executor.can_execute_inplace(0));
        executor.mark_aliased(0);
        assert!(!executor.can_execute_inplace(0));
        assert!(executor.can_execute_inplace(1));
    }

    // relu clamps negatives to 0 and leaves non-negatives unchanged, and
    // a successful call bumps the in-place counter.
    #[test]
    fn test_inplace_unary_relu() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![3], vec![-1.0, 0.0, 1.0]).expect("unwrap");
        executor
            .execute_inplace_unary("relu", &mut tensor)
            .expect("unwrap");
        assert_eq!(tensor[[0]], 0.0);
        assert_eq!(tensor[[1]], 0.0);
        assert_eq!(tensor[[2]], 1.0);
        assert_eq!(executor.stats.inplace_ops, 1);
    }

    // sigmoid(0) == 0.5 exactly; sigmoid(1) ≈ 0.731 (checked loosely
    // because the value is irrational).
    #[test]
    fn test_inplace_unary_sigmoid() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![2], vec![0.0, 1.0]).expect("unwrap");
        executor
            .execute_inplace_unary("sigmoid", &mut tensor)
            .expect("unwrap");
        assert!((tensor[[0]] - 0.5).abs() < 1e-6);
        assert!((tensor[[1]] - 0.731).abs() < 0.01);
        assert_eq!(executor.stats.inplace_ops, 1);
    }

    // oneminus maps x to 1 - x element-wise.
    #[test]
    fn test_inplace_unary_oneminus() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![3], vec![0.0, 0.5, 1.0]).expect("unwrap");
        executor
            .execute_inplace_unary("oneminus", &mut tensor)
            .expect("unwrap");
        assert_eq!(tensor[[0]], 1.0);
        assert_eq!(tensor[[1]], 0.5);
        assert_eq!(tensor[[2]], 0.0);
    }

    // Binary add accumulates rhs into lhs element-wise.
    #[test]
    fn test_inplace_binary_add() {
        let mut executor = InplaceExecutor::new();
        let mut lhs = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        let rhs = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).expect("unwrap");
        executor
            .execute_inplace_binary("add", &mut lhs, &rhs)
            .expect("unwrap");
        assert_eq!(lhs[[0]], 5.0);
        assert_eq!(lhs[[1]], 7.0);
        assert_eq!(lhs[[2]], 9.0);
        assert_eq!(executor.stats.inplace_ops, 1);
    }

    // Binary multiply scales lhs by rhs element-wise.
    #[test]
    fn test_inplace_binary_multiply() {
        let mut executor = InplaceExecutor::new();
        let mut lhs = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        let rhs = ArrayD::from_shape_vec(vec![3], vec![2.0, 3.0, 4.0]).expect("unwrap");
        executor
            .execute_inplace_binary("multiply", &mut lhs, &rhs)
            .expect("unwrap");
        assert_eq!(lhs[[0]], 2.0);
        assert_eq!(lhs[[1]], 6.0);
        assert_eq!(lhs[[2]], 12.0);
    }

    // Mismatched shapes are rejected (no broadcasting) and counted as a
    // non-in-place operation.
    #[test]
    fn test_inplace_binary_shape_mismatch() {
        let mut executor = InplaceExecutor::new();
        let mut lhs = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        let rhs = ArrayD::from_shape_vec(vec![2], vec![4.0, 5.0]).expect("unwrap");
        let result = executor.execute_inplace_binary("add", &mut lhs, &rhs);
        assert!(result.is_err());
        assert_eq!(executor.stats.non_inplace_ops, 1);
    }

    // Scalar add offsets every element by the scalar.
    #[test]
    fn test_inplace_scalar_add() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        executor
            .execute_inplace_scalar("add", &mut tensor, 10.0)
            .expect("unwrap");
        assert_eq!(tensor[[0]], 11.0);
        assert_eq!(tensor[[1]], 12.0);
        assert_eq!(tensor[[2]], 13.0);
    }

    // Scalar multiply scales every element by the scalar.
    #[test]
    fn test_inplace_scalar_multiply() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        executor
            .execute_inplace_scalar("mul", &mut tensor, 2.0)
            .expect("unwrap");
        assert_eq!(tensor[[0]], 2.0);
        assert_eq!(tensor[[1]], 4.0);
        assert_eq!(tensor[[2]], 6.0);
    }

    // Two successful in-place ops yield inplace_ops == 2, a positive
    // memory-saved estimate, and a 100% in-place ratio.
    #[test]
    fn test_inplace_stats() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        executor
            .execute_inplace_unary("relu", &mut tensor)
            .expect("unwrap");
        executor
            .execute_inplace_scalar("add", &mut tensor, 1.0)
            .expect("unwrap");
        assert_eq!(executor.stats.inplace_ops, 2);
        assert!(executor.stats.memory_saved_bytes > 0);
        assert_eq!(executor.stats.inplace_percentage(), 100.0);
    }

    // The free-function predicate (distinct from the executor method of
    // the same name) recognizes supported op names and rejects others.
    #[test]
    fn test_can_execute_inplace_func() {
        assert!(can_execute_inplace("relu"));
        assert!(can_execute_inplace("sigmoid"));
        assert!(can_execute_inplace("add"));
        assert!(!can_execute_inplace("unknown_op"));
    }

    // Element-wise ops preserve shape; reductions do not.
    #[test]
    fn test_is_shape_preserving() {
        assert!(is_shape_preserving("relu"));
        assert!(is_shape_preserving("add"));
        assert!(!is_shape_preserving("sum"));
        assert!(!is_shape_preserving("mean"));
    }

    // Formatting switches units at the 1024-based KB/MB/GB thresholds
    // with two decimal places above the bytes range.
    #[test]
    fn test_format_memory_saved() {
        let stats = InplaceStats {
            memory_saved_bytes: 512,
            ..Default::default()
        };
        assert_eq!(stats.format_memory_saved(), "512 bytes");
        let stats = InplaceStats {
            memory_saved_bytes: 2048,
            ..Default::default()
        };
        assert_eq!(stats.format_memory_saved(), "2.00 KB");
        let stats = InplaceStats {
            memory_saved_bytes: 2 * 1024 * 1024,
            ..Default::default()
        };
        assert_eq!(stats.format_memory_saved(), "2.00 MB");
        let stats = InplaceStats {
            memory_saved_bytes: 3 * 1024 * 1024 * 1024,
            ..Default::default()
        };
        assert_eq!(stats.format_memory_saved(), "3.00 GB");
    }

    // reset_stats zeroes the counters accumulated by earlier operations.
    #[test]
    fn test_reset_stats() {
        let mut executor = InplaceExecutor::new();
        let mut tensor = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap");
        executor
            .execute_inplace_unary("relu", &mut tensor)
            .expect("unwrap");
        assert_eq!(executor.stats.inplace_ops, 1);
        executor.reset_stats();
        assert_eq!(executor.stats.inplace_ops, 0);
        assert_eq!(executor.stats.memory_saved_bytes, 0);
    }

    // clear_aliasing re-enables in-place execution for all previously
    // marked tensor ids.
    #[test]
    fn test_clear_aliasing() {
        let mut executor = InplaceExecutor::new();
        executor.mark_aliased(0);
        executor.mark_aliased(1);
        assert!(!executor.can_execute_inplace(0));
        assert!(!executor.can_execute_inplace(1));
        executor.clear_aliasing();
        assert!(executor.can_execute_inplace(0));
        assert!(executor.can_execute_inplace(1));
    }
}