#[cfg(feature = "memory_efficient")]
mod tests {
    //! Tests for `OpFusion` and `register_fusion`, driven by small
    //! `f64 -> f64` fixture operations that implement `FusedOp`.

    use scirs2_core::error::{CoreError, ErrorContext};
    use scirs2_core::memory_efficient::{register_fusion, FusedOp, OpFusion};
    use std::any::{Any, TypeId};
    use std::sync::Arc;

    /// Downcasts a type-erased fixture input to `f64`.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::InvalidArgument` ("Expected f64") when `input` is
    /// not an `f64`. Every `f64` fixture below shares this error.
    fn expect_f64(input: &dyn Any) -> Result<f64, CoreError> {
        input
            .downcast_ref::<f64>()
            .copied()
            .ok_or_else(|| CoreError::InvalidArgument(ErrorContext::new("Expected f64")))
    }

    /// Fixture: squares an `f64`.
    ///
    /// Declares itself fusable only with `SqrtOp`; the pair fuses into an
    /// `IdentityOp`.
    ///
    /// NOTE(review): that fusion is exact only for non-negative inputs.
    /// `sqrt(x * x) == |x|`, and `SqrtOp` rejects negative input, so neither
    /// composition order is a true identity for `x < 0`. The tests in this
    /// module only feed non-negative values, so they are unaffected.
    #[derive(Clone)]
    struct SquareOp;

    impl FusedOp for SquareOp {
        fn name(&self) -> &str {
            "SquareOp"
        }

        fn input_type(&self) -> TypeId {
            TypeId::of::<f64>()
        }

        fn output_type(&self) -> TypeId {
            TypeId::of::<f64>()
        }

        fn can_fuse_with(&self, other: &dyn FusedOp) -> bool {
            other.name() == "SqrtOp"
        }

        fn fuse_with(&self, other: &dyn FusedOp) -> Arc<dyn FusedOp> {
            if other.name() == "SqrtOp" {
                Arc::new(IdentityOp)
            } else {
                self.clone_op()
            }
        }

        fn apply(&self, input: &dyn Any) -> Result<Box<dyn Any>, CoreError> {
            let x = expect_f64(input)?;
            Ok(Box::new(x * x))
        }

        fn clone_op(&self) -> Arc<dyn FusedOp> {
            Arc::new(self.clone())
        }
    }

    /// Fixture: square root of an `f64`, rejecting negative input.
    ///
    /// Declares itself fusable only with `SquareOp`; the pair fuses into an
    /// `IdentityOp` (see the NOTE on `SquareOp` about negative inputs).
    #[derive(Clone)]
    struct SqrtOp;

    impl FusedOp for SqrtOp {
        fn name(&self) -> &str {
            "SqrtOp"
        }

        fn input_type(&self) -> TypeId {
            TypeId::of::<f64>()
        }

        fn output_type(&self) -> TypeId {
            TypeId::of::<f64>()
        }

        fn can_fuse_with(&self, other: &dyn FusedOp) -> bool {
            other.name() == "SquareOp"
        }

        fn fuse_with(&self, other: &dyn FusedOp) -> Arc<dyn FusedOp> {
            if other.name() == "SquareOp" {
                Arc::new(IdentityOp)
            } else {
                self.clone_op()
            }
        }

        fn apply(&self, input: &dyn Any) -> Result<Box<dyn Any>, CoreError> {
            let x = expect_f64(input)?;
            // NaN is not `< 0.0`, so it passes through to `sqrt` (yielding NaN).
            if x < 0.0 {
                return Err(CoreError::InvalidArgument(ErrorContext::new(
                    "Cannot take sqrt of negative number",
                )));
            }
            Ok(Box::new(x.sqrt()))
        }

        fn clone_op(&self) -> Arc<dyn FusedOp> {
            Arc::new(self.clone())
        }
    }

    /// Fixture: returns its `f64` input unchanged.
    ///
    /// Fusable with any op; fusing yields a clone of the other op.
    #[derive(Clone)]
    struct IdentityOp;

    impl FusedOp for IdentityOp {
        fn name(&self) -> &str {
            "IdentityOp"
        }

        fn input_type(&self) -> TypeId {
            TypeId::of::<f64>()
        }

        fn output_type(&self) -> TypeId {
            TypeId::of::<f64>()
        }

        fn can_fuse_with(&self, _other: &dyn FusedOp) -> bool {
            true
        }

        fn fuse_with(&self, other: &dyn FusedOp) -> Arc<dyn FusedOp> {
            other.clone_op()
        }

        fn apply(&self, input: &dyn Any) -> Result<Box<dyn Any>, CoreError> {
            let x = expect_f64(input)?;
            Ok(Box::new(x))
        }

        fn clone_op(&self) -> Arc<dyn FusedOp> {
            Arc::new(self.clone())
        }
    }

    /// A freshly constructed pipeline holds no operations.
    #[test]
    fn test_op_fusion_creation() {
        let fusion = OpFusion::new();
        assert!(fusion.is_empty());
        assert_eq!(fusion.num_ops(), 0);
    }

    /// Adding one operation makes the pipeline non-empty.
    #[test]
    fn test_op_fusion_add_op() {
        let mut fusion = OpFusion::new();
        fusion
            .add_op(Arc::new(SquareOp))
            .expect("Test: operation failed");
        assert!(!fusion.is_empty());
        assert_eq!(fusion.num_ops(), 1);
    }

    /// Appending an op whose input type does not match the previous op's
    /// output type is rejected.
    #[test]
    fn test_op_fusion_type_mismatch() {
        /// Fixture: an `i32 -> i32` op that never fuses.
        struct MismatchOp;

        impl FusedOp for MismatchOp {
            fn name(&self) -> &str {
                "MismatchOp"
            }

            fn input_type(&self) -> TypeId {
                TypeId::of::<i32>()
            }

            fn output_type(&self) -> TypeId {
                TypeId::of::<i32>()
            }

            fn can_fuse_with(&self, _other: &dyn FusedOp) -> bool {
                false
            }

            fn fuse_with(&self, _other: &dyn FusedOp) -> Arc<dyn FusedOp> {
                Arc::new(MismatchOp)
            }

            fn apply(&self, _input: &dyn Any) -> Result<Box<dyn Any>, CoreError> {
                Ok(Box::new(0))
            }

            fn clone_op(&self) -> Arc<dyn FusedOp> {
                Arc::new(MismatchOp)
            }
        }

        let mut fusion = OpFusion::new();
        fusion
            .add_op(Arc::new(SquareOp))
            .expect("Test: operation failed");

        // SquareOp produces f64 while MismatchOp consumes i32, so this must fail.
        let result = fusion.add_op(Arc::new(MismatchOp));
        assert!(result.is_err());
    }

    /// `optimize` fuses the Square/Sqrt pair into a single op, and the
    /// optimized pipeline still produces the same result.
    #[test]
    fn test_op_fusion_optimize() {
        let mut fusion = OpFusion::new();
        fusion
            .add_op(Arc::new(SquareOp))
            .expect("Test: operation failed");
        fusion
            .add_op(Arc::new(SqrtOp))
            .expect("Test: operation failed");
        assert_eq!(fusion.num_ops(), 2);

        fusion.optimize().expect("Test: operation failed");
        assert_eq!(fusion.num_ops(), 1);

        // sqrt(3.0 * 3.0) == 3.0 with or without fusion, so this checks the
        // optimized pipeline still computes the right value.
        let result = fusion.apply(3.0_f64).expect("Test: operation failed");
        let output = result
            .downcast_ref::<f64>()
            .expect("Test: operation failed");
        assert_eq!(*output, 3.0);
    }

    /// Applying a single-op pipeline runs that op on the input.
    #[test]
    fn test_op_fusion_apply() {
        let mut fusion = OpFusion::new();
        fusion
            .add_op(Arc::new(SquareOp))
            .expect("Test: operation failed");
        let input = 3.0;
        let result = fusion.apply(input).expect("Test: operation failed");
        let output = result
            .downcast_ref::<f64>()
            .expect("Test: operation failed");
        assert_eq!(*output, 9.0);
    }

    /// Registering a fixture op for `f64` succeeds.
    #[test]
    fn test_op_fusion_register() {
        register_fusion::<f64>(Arc::new(SquareOp)).expect("Test: operation failed");
    }

    /// Optimizing an empty pipeline succeeds and leaves it empty.
    #[test]
    fn test_empty_op_fusion_optimize() {
        let mut fusion = OpFusion::new();
        fusion.optimize().expect("Test: operation failed");
        assert!(fusion.is_empty());
    }

    /// Optimizing a single-op pipeline leaves that op in place.
    #[test]
    fn test_single_op_fusion_optimize() {
        let mut fusion = OpFusion::new();
        fusion
            .add_op(Arc::new(SquareOp))
            .expect("Test: operation failed");
        fusion.optimize().expect("Test: operation failed");
        assert_eq!(fusion.num_ops(), 1);
    }

    /// Feeding an `i32` to an `f64` pipeline yields an error, not a panic.
    #[test]
    fn test_op_fusion_apply_type_mismatch() {
        let mut fusion = OpFusion::new();
        fusion
            .add_op(Arc::new(SquareOp))
            .expect("Test: operation failed");
        let input = 3i32;
        let result = fusion.apply(input);
        assert!(result.is_err());
    }
}