use crate::{Result, VisionError};
use torsh_tensor::Tensor;
/// A single tensor-to-tensor processing step that can be chained into a pipeline.
///
/// Implementors must be `Send + Sync`, so a transform can be shared across threads.
/// Because the trait is object-safe, heterogeneous transforms can be stored as
/// `Box<dyn Transform>`.
pub trait Transform: Send + Sync {
    /// Applies this transform to `input` and returns the resulting tensor.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the underlying tensor operation fails.
    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>>;

    /// Identifier for this transform.
    ///
    /// The default implementation returns the implementing type's name as reported
    /// by `std::any::type_name`. Implementors may override it with a shorter label.
    fn name(&self) -> &'static str {
        let type_label: &'static str = std::any::type_name::<Self>();
        type_label
    }

    /// Reports whether this transform is in-place.
    ///
    /// The default implementation always answers `false`.
    fn is_inplace(&self) -> bool {
        let inplace = false;
        inplace
    }

    /// Returns this transform's configuration as `(key, formatted value)` pairs.
    ///
    /// The default implementation reports no parameters.
    fn parameters(&self) -> Vec<(&'static str, String)> {
        vec![]
    }

    /// Produces a boxed clone of this transform.
    ///
    /// This allows a `Box<dyn Transform>` to be duplicated without knowing its
    /// concrete type.
    fn clone_transform(&self) -> Box<dyn Transform>;
}
/// An ordered collection of boxed transforms that is itself a `Transform`.
///
/// Its `Transform` implementation runs the contained transforms one after another,
/// in the order they are stored.
pub struct Compose {
    /// Contained transforms, in application order.
    pub(crate) transforms: Vec<Box<dyn Transform>>,
}

impl Compose {
    /// Creates a pipeline from already-boxed transforms and keeps their order unchanged.
    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
        Compose { transforms }
    }

    /// Number of transforms currently in the pipeline.
    pub fn len(&self) -> usize {
        self.transforms().len()
    }

    /// Returns `true` when the pipeline contains no transforms.
    pub fn is_empty(&self) -> bool {
        self.transforms().is_empty()
    }

    /// Appends `transform` to the end of the pipeline, so it runs after every existing step.
    pub fn add<T: Transform + 'static>(&mut self, transform: T) {
        let boxed: Box<dyn Transform> = Box::new(transform);
        self.transforms.push(boxed);
    }

    /// Borrows the contained transforms in application order.
    pub fn transforms(&self) -> &[Box<dyn Transform>] {
        self.transforms.as_slice()
    }
}
impl Transform for Compose {
    /// Runs every contained transform in insertion order.
    ///
    /// Each transform is fed the output of the previous one. The chain starts
    /// from a clone of `input`, so an empty pipeline returns a copy of its input.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a contained transform. The remaining
    /// transforms are not run.
    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        // NOTE(review): the up-front clone is kept on purpose. It presumably shields the
        // caller's tensor from transforms that report `is_inplace()`. Whether it is a deep
        // copy depends on `Tensor::clone`; confirm before trying to elide it.
        self.transforms
            .iter()
            .try_fold(input.clone(), |output, transform| transform.forward(&output))
    }

    fn name(&self) -> &'static str {
        "Compose"
    }

    /// Reports the number of contained transforms under the key `"num_transforms"`.
    fn parameters(&self) -> Vec<(&'static str, String)> {
        vec![("num_transforms", self.transforms.len().to_string())]
    }

    /// Builds a new `Compose` by cloning each contained transform through
    /// `clone_transform`, preserving their order.
    fn clone_transform(&self) -> Box<dyn Transform> {
        let cloned = self
            .transforms
            .iter()
            .map(|transform| transform.clone_transform())
            .collect();
        Box::new(Compose::new(cloned))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation;

    /// Test double that multiplies every element by a fixed factor.
    struct MockTransform {
        name: &'static str,
        multiplier: f32,
    }

    impl MockTransform {
        fn new(name: &'static str, multiplier: f32) -> Self {
            Self { name, multiplier }
        }
    }

    impl Transform for MockTransform {
        fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
            input
                .mul_scalar(self.multiplier)
                .map_err(VisionError::TensorError)
        }

        fn name(&self) -> &'static str {
            self.name
        }

        fn parameters(&self) -> Vec<(&'static str, String)> {
            vec![("multiplier", format!("{:.2}", self.multiplier))]
        }

        fn clone_transform(&self) -> Box<dyn Transform> {
            Box::new(MockTransform::new(self.name, self.multiplier))
        }
    }

    /// Reads the element at `[0, 0, 0]`, panicking if retrieval fails.
    fn first_element(tensor: &Tensor<f32>) -> f32 {
        tensor
            .get(&[0, 0, 0])
            .expect("element retrieval should succeed for valid index")
    }

    /// Creates the all-ones `[3, 32, 32]` tensor used as input by every test.
    fn ones_input() -> Tensor<f32> {
        creation::ones(&[3, 32, 32]).expect("creation should succeed")
    }

    #[test]
    fn test_transform_trait() {
        let transform = MockTransform::new("TestTransform", 2.0);
        let input = ones_input();
        let result = transform
            .forward(&input)
            .expect("forward pass should succeed");
        assert_eq!(first_element(&result), 2.0);
        assert_eq!(transform.name(), "TestTransform");
        assert!(!transform.is_inplace());
        let params = transform.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "multiplier");
        assert_eq!(params[0].1, "2.00");
    }

    #[test]
    fn test_compose_empty() {
        let compose = Compose::new(vec![]);
        assert!(compose.is_empty());
        assert_eq!(compose.len(), 0);
        let input = ones_input();
        let result = compose
            .forward(&input)
            .expect("forward pass should succeed");
        assert_eq!(first_element(&result), 1.0);
    }

    #[test]
    fn test_compose_single_transform() {
        let transforms = vec![Box::new(MockTransform::new("Double", 2.0)) as Box<dyn Transform>];
        let compose = Compose::new(transforms);
        assert_eq!(compose.len(), 1);
        assert!(!compose.is_empty());
        let input = ones_input();
        let result = compose
            .forward(&input)
            .expect("forward pass should succeed");
        assert_eq!(first_element(&result), 2.0);
    }

    #[test]
    fn test_compose_multiple_transforms() {
        let transforms = vec![
            Box::new(MockTransform::new("Double", 2.0)) as Box<dyn Transform>,
            Box::new(MockTransform::new("Triple", 3.0)) as Box<dyn Transform>,
        ];
        let compose = Compose::new(transforms);
        assert_eq!(compose.len(), 2);
        let input = ones_input();
        let result = compose
            .forward(&input)
            .expect("forward pass should succeed");
        assert_eq!(first_element(&result), 6.0);
        // The pipeline must leave the caller's tensor untouched.
        assert_eq!(first_element(&input), 1.0);
    }

    #[test]
    fn test_compose_add_transform() {
        let mut compose = Compose::new(vec![]);
        assert_eq!(compose.len(), 0);
        compose.add(MockTransform::new("Double", 2.0));
        assert_eq!(compose.len(), 1);
        compose.add(MockTransform::new("Triple", 3.0));
        assert_eq!(compose.len(), 2);
        let names: Vec<&str> = compose.transforms().iter().map(|t| t.name()).collect();
        assert_eq!(names, ["Double", "Triple"]);
    }

    #[test]
    fn test_compose_clone() {
        let transforms = vec![
            Box::new(MockTransform::new("Double", 2.0)) as Box<dyn Transform>,
            Box::new(MockTransform::new("Triple", 3.0)) as Box<dyn Transform>,
        ];
        let compose = Compose::new(transforms);
        let cloned = compose.clone_transform();
        assert_eq!(cloned.name(), "Compose");
        assert_eq!(cloned.parameters(), compose.parameters());
        let input = ones_input();
        let original_result = compose
            .forward(&input)
            .expect("forward pass should succeed");
        let cloned_result = cloned.forward(&input).expect("forward pass should succeed");
        assert_eq!(first_element(&original_result), first_element(&cloned_result));
    }

    #[test]
    fn test_compose_parameters() {
        let transforms = vec![
            Box::new(MockTransform::new("Double", 2.0)) as Box<dyn Transform>,
            Box::new(MockTransform::new("Triple", 3.0)) as Box<dyn Transform>,
        ];
        let compose = Compose::new(transforms);
        let params = compose.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "num_transforms");
        assert_eq!(params[0].1, "2");
    }
}