use torsh_core::{
dtype::TensorElement,
error::{Result, TorshError},
};
use torsh_tensor::Tensor;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
/// A fallible conversion from a value of type `T` into [`Transform::Output`].
///
/// Every implementor must also be `Send + Sync` (supertrait bound).
pub trait Transform<T>: Send + Sync {
    /// Value produced by a successful [`Transform::transform`] call.
    type Output;

    /// Applies the transform to a single `input`.
    ///
    /// # Errors
    /// Implementation-defined.
    fn transform(&self, input: T) -> Result<Self::Output>;

    /// Applies [`Transform::transform`] to each element of `inputs`, preserving order.
    ///
    /// # Errors
    /// Returns the error of the first element whose transform fails; the
    /// remaining elements are not transformed.
    fn transform_batch(&self, inputs: Vec<T>) -> Result<Vec<Self::Output>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for item in inputs {
            outputs.push(self.transform(item)?);
        }
        Ok(outputs)
    }

    /// Whether equal inputs always yield equal outputs.
    ///
    /// Defaults to `true`; implementors that are randomized should override it.
    fn is_deterministic(&self) -> bool {
        true
    }
}
/// A builder that is consumed to produce a configured transform value.
pub trait TransformBuilder {
    /// The value produced by [`TransformBuilder::build`].
    type Transform;

    /// Consumes the builder and returns the value it was configured to build.
    fn build(self) -> <Self as TransformBuilder>::Transform;
}
/// Declares a unit struct `$name` implementing `Transform<$input>` with
/// `Output = $output`, whose `transform` returns `Ok($transform_fn(input))`.
///
/// Forms:
/// - `simple_transform!(Name, In, Out, f)`: `is_deterministic` reports `true`.
/// - `simple_transform!(Name, In, Out, f, deterministic = <literal>)`:
///   `is_deterministic` returns the given literal.
///
/// The generated struct derives `Clone`, `Debug` and `Default`.
#[macro_export]
macro_rules! simple_transform {
    // Full form: determinism supplied explicitly by the caller.
    ($name:ident, $input:ty, $output:ty, $transform_fn:expr, deterministic = $det:literal) => {
        #[derive(Clone, Debug, Default)]
        pub struct $name;

        impl $crate::core_framework::Transform<$input> for $name {
            type Output = $output;

            fn transform(&self, input: $input) -> $crate::core_framework::Result<Self::Output> {
                Ok($transform_fn(input))
            }

            fn is_deterministic(&self) -> bool {
                $det
            }
        }
    };
    // Short form: forwards to the full form with `true`, which is exactly the
    // value the trait's default `is_deterministic` would return.
    ($name:ident, $input:ty, $output:ty, $transform_fn:expr) => {
        $crate::simple_transform!($name, $input, $output, $transform_fn, deterministic = true);
    };
}
/// Combinator methods available on every sized, `'static` [`Transform`].
pub trait TransformExt<T>: Transform<T> + Sized + 'static {
    /// Runs `self` first, then feeds its output into `next`.
    fn then<Next>(self, next: Next) -> Chain<Self, Next>
    where
        Next: Transform<Self::Output>,
    {
        Chain::new(self, next)
    }

    /// Wraps `self` so it is applied only to inputs for which `predicate`
    /// returns `true`.
    fn when<Pred>(self, predicate: Pred) -> Conditional<Self, Pred>
    where
        Pred: Fn(&T) -> bool + Send + Sync,
    {
        Conditional::new(self, predicate)
    }

    /// Type-erases `self` into a boxed trait object with the same input and output types.
    fn boxed(self) -> Box<dyn Transform<T, Output = Self::Output> + Send + Sync> {
        Box::new(self)
    }
}

// Every `'static` transform gets the combinators for free.
impl<T, U> TransformExt<T> for U where U: Transform<T> + 'static {}
/// Sequential composition: the output of `first` becomes the input of `second`.
#[derive(Debug, Clone)]
pub struct Chain<T1, T2> {
    first: T1,
    second: T2,
}

impl<T1, T2> Chain<T1, T2> {
    /// Pairs `first` and `second` so that they run one after the other.
    pub fn new(first: T1, second: T2) -> Self {
        Chain { first, second }
    }
}

impl<T, T1, T2> Transform<T> for Chain<T1, T2>
where
    T1: Transform<T>,
    T2: Transform<T1::Output>,
{
    type Output = T2::Output;

    /// Runs `first`, then `second`; an error from `first` skips `second`.
    fn transform(&self, input: T) -> Result<Self::Output> {
        self.first
            .transform(input)
            .and_then(|intermediate| self.second.transform(intermediate))
    }

    /// Deterministic only if both stages are; `second` is not queried when
    /// `first` is already non-deterministic.
    fn is_deterministic(&self) -> bool {
        if !self.first.is_deterministic() {
            return false;
        }
        self.second.is_deterministic()
    }
}
/// Applies the wrapped transform only when `predicate` accepts the input.
///
/// Inputs rejected by the predicate are returned unchanged, which is why the
/// wrapped transform must map `T` back to `T`.
#[derive(Debug, Clone)]
pub struct Conditional<T, P> {
    transform: T,
    predicate: P,
}

impl<T, P> Conditional<T, P> {
    /// Gates `transform` behind `predicate`.
    pub fn new(transform: T, predicate: P) -> Self {
        Conditional {
            predicate,
            transform,
        }
    }
}

impl<T, U, P> Transform<T> for Conditional<U, P>
where
    U: Transform<T, Output = T>,
    P: Fn(&T) -> bool + Send + Sync,
{
    type Output = T;

    fn transform(&self, input: T) -> Result<Self::Output> {
        if !(self.predicate)(&input) {
            // Predicate rejected the input: pass it through untouched.
            return Ok(input);
        }
        self.transform.transform(input)
    }

    /// Reports the wrapped transform's determinism; the predicate is not consulted.
    fn is_deterministic(&self) -> bool {
        self.transform.is_deterministic()
    }
}
/// A pipeline of boxed `T -> T` transforms applied in insertion order.
pub struct Compose<T> {
    transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>,
}

impl<T> Compose<T> {
    /// Creates a pipeline from an ordered list of boxed transforms.
    pub fn new(transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>) -> Self {
        Compose { transforms }
    }

    /// Appends `transform` to the end of the pipeline.
    pub fn add<U>(&mut self, transform: U)
    where
        U: Transform<T, Output = T> + Send + Sync + 'static,
    {
        let step: Box<dyn Transform<T, Output = T> + Send + Sync> = Box::new(transform);
        self.transforms.push(step);
    }

    /// Number of transforms in the pipeline.
    pub fn len(&self) -> usize {
        self.transforms.len()
    }

    /// `true` when the pipeline has no transforms (it then acts as the identity).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T> Transform<T> for Compose<T> {
    type Output = T;

    /// Threads `input` through every step; the first failing step aborts the run.
    fn transform(&self, input: T) -> Result<Self::Output> {
        self.transforms
            .iter()
            .try_fold(input, |value, step| step.transform(value))
    }

    /// Deterministic unless some step reports otherwise (vacuously `true` when empty).
    fn is_deterministic(&self) -> bool {
        !self
            .transforms
            .iter()
            .any(|step| !step.is_deterministic())
    }
}
/// Normalization parameters: a `mean` and `std` vector of equal length
/// (presumably one entry per channel — confirm against intended usage).
///
/// NOTE(review): the `Transform` impl is currently a passthrough. It returns
/// the input tensor unchanged and never reads `mean` or `std`.
#[derive(Debug, Clone)]
pub struct Normalize<T: TensorElement> {
    // Both fields are stored but not yet read by `transform`; the lint
    // allowances silence the resulting dead-code warnings.
    #[allow(dead_code)]
    mean: Vec<T>,
    #[allow(dead_code)]
    std: Vec<T>,
}

impl<T: TensorElement> Normalize<T> {
    /// Builds a `Normalize` from `mean` and `std`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` when `mean` and `std` differ in length.
    pub fn new(mean: Vec<T>, std: Vec<T>) -> Result<Self> {
        let lengths_match = mean.len() == std.len();
        if lengths_match {
            Ok(Normalize { mean, std })
        } else {
            Err(TorshError::InvalidArgument(String::from(
                "Mean and std vectors must have the same length",
            )))
        }
    }
}

impl<T: TensorElement> Transform<Tensor<T>> for Normalize<T> {
    type Output = Tensor<T>;

    /// Placeholder: returns `input` untouched; no normalization is applied yet.
    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
        Ok(input)
    }
}
/// Element-type conversion from `Tensor<Src>` to `Tensor<Dst>`.
///
/// Not implemented yet: [`Transform::transform`] always returns
/// `TorshError::InvalidArgument`.
#[derive(Debug, Clone)]
pub struct ToType<Src, Dst> {
    _phantom: core::marker::PhantomData<(Src, Dst)>,
}

impl<Src, Dst> Default for ToType<Src, Dst> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Src, Dst> ToType<Src, Dst> {
    /// Creates the (stateless) conversion marker.
    pub fn new() -> Self {
        ToType {
            _phantom: core::marker::PhantomData,
        }
    }
}

impl<Src: TensorElement, Dst: TensorElement> Transform<Tensor<Src>> for ToType<Src, Dst> {
    type Output = Tensor<Dst>;

    /// Always fails: the conversion has not been implemented.
    fn transform(&self, _input: Tensor<Src>) -> Result<Self::Output> {
        let reason = String::from("Type conversion not yet implemented");
        Err(TorshError::InvalidArgument(reason))
    }
}
/// Adapts a closure `Fn(T) -> Result<O>` into a [`Transform`].
///
/// A closure cannot be inspected, so its determinism is an explicit flag.
/// It defaults to `true` via [`Lambda::new`]. Closures that use randomness or
/// external state should be built with [`Lambda::non_deterministic`] or
/// re-flagged with [`Lambda::with_deterministic`]. Otherwise combinators that
/// AND their parts' determinism (chains, compositions) misreport.
#[derive(Debug, Clone)]
pub struct Lambda<F> {
    func: F,
    // Value returned by `is_deterministic`.
    deterministic: bool,
}

impl<F> Lambda<F> {
    /// Wraps `func` as a transform that reports itself as deterministic.
    pub fn new(func: F) -> Self {
        Self {
            func,
            deterministic: true,
        }
    }

    /// Wraps `func` as a transform that reports itself as non-deterministic.
    pub fn non_deterministic(func: F) -> Self {
        Self {
            func,
            deterministic: false,
        }
    }

    /// Returns `self` with the reported determinism set to `deterministic`.
    #[must_use]
    pub fn with_deterministic(mut self, deterministic: bool) -> Self {
        self.deterministic = deterministic;
        self
    }
}

impl<T, O, F> Transform<T> for Lambda<F>
where
    F: Fn(T) -> Result<O> + Send + Sync,
{
    type Output = O;

    /// Calls the wrapped closure and returns its result unchanged.
    fn transform(&self, input: T) -> Result<Self::Output> {
        (self.func)(input)
    }

    /// Returns the flag chosen at construction (`true` unless overridden).
    fn is_deterministic(&self) -> bool {
        self.deterministic
    }
}
/// Shorthand for [`Normalize::new`].
///
/// # Errors
/// Propagates the error from [`Normalize::new`] (mismatched `mean`/`std` lengths).
pub fn normalize<T: TensorElement>(mean: Vec<T>, std: Vec<T>) -> Result<Normalize<T>> {
    Normalize::<T>::new(mean, std)
}

/// Shorthand for [`ToType::new`].
pub fn to_type<Src: TensorElement, Dst: TensorElement>() -> ToType<Src, Dst> {
    ToType::<Src, Dst>::new()
}

/// Shorthand for [`Lambda::new`].
pub fn lambda<F>(func: F) -> Lambda<F> {
    Lambda::<F>::new(func)
}

/// Shorthand for [`Compose::new`].
pub fn compose<T>(transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>) -> Compose<T> {
    Compose::<T>::new(transforms)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 2x2 `f32` CPU tensor fixture.
    fn mock_tensor() -> Tensor<f32> {
        Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![2, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .unwrap()
    }

    #[test]
    fn test_chain_transform() {
        let lambda1 = lambda(|x: i32| Ok(x * 2));
        let lambda2 = lambda(|x: i32| Ok(x + 1));
        let chained = lambda1.then(lambda2);
        let result = chained.transform(5).unwrap();
        // (5 * 2) + 1
        assert_eq!(result, 11);
    }

    #[test]
    fn test_chain_propagates_first_error() {
        let failing = lambda(|_x: i32| -> Result<i32> {
            Err(TorshError::InvalidArgument("boom".to_string()))
        });
        let chained = failing.then(lambda(|x: i32| Ok(x + 1)));
        assert!(chained.transform(1).is_err());
    }

    #[test]
    fn test_conditional_transform() {
        let double = lambda(|x: i32| Ok(x * 2));
        let conditional = double.when(|&x| x > 5);
        // Inputs failing the predicate pass through unchanged.
        assert_eq!(conditional.transform(3).unwrap(), 3);
        assert_eq!(conditional.transform(7).unwrap(), 14);
        assert!(conditional.is_deterministic());
    }

    #[test]
    fn test_compose_transform() {
        let lambda1 = lambda(|x: i32| Ok(x + 1));
        let lambda2 = lambda(|x: i32| Ok(x * 2));
        let mut composition = Compose::new(vec![]);
        composition.add(lambda1);
        composition.add(lambda2);
        assert_eq!(composition.len(), 2);
        let result = composition.transform(5).unwrap();
        // (5 + 1) * 2
        assert_eq!(result, 12);
    }

    #[test]
    fn test_compose_empty_is_identity() {
        let empty: Compose<i32> = compose(Vec::new());
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.transform(9).unwrap(), 9);
        assert!(empty.is_deterministic());
    }

    #[test]
    fn test_compose_of_boxed_transforms() {
        let pipeline: Compose<i32> = compose(vec![
            lambda(|x: i32| -> Result<i32> { Ok(x + 1) }).boxed(),
            lambda(|x: i32| -> Result<i32> { Ok(x * 10) }).boxed(),
        ]);
        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline.transform(4).unwrap(), 50);
    }

    #[test]
    fn test_transform_batch() {
        let double = lambda(|x: i32| Ok(x * 2));
        assert_eq!(double.transform_batch(vec![1, 2, 3]).unwrap(), vec![2, 4, 6]);
        assert!(double.transform_batch(Vec::new()).unwrap().is_empty());
    }

    #[test]
    fn test_transform_batch_stops_on_error() {
        let non_negative = lambda(|x: i32| -> Result<i32> {
            if x < 0 {
                Err(TorshError::InvalidArgument("negative input".to_string()))
            } else {
                Ok(x)
            }
        });
        assert!(non_negative.transform_batch(vec![1, -2, 3]).is_err());
    }

    #[test]
    fn test_normalize_creation() {
        let mean = vec![0.485f32, 0.456, 0.406];
        let std = vec![0.229f32, 0.224, 0.225];
        let normalize_transform = normalize(mean, std);
        assert!(normalize_transform.is_ok());
    }

    #[test]
    fn test_normalize_invalid_dimensions() {
        let mean = vec![0.485f32, 0.456];
        let std = vec![0.229f32, 0.224, 0.225];
        let normalize_transform = normalize(mean, std);
        assert!(normalize_transform.is_err());
    }

    #[test]
    fn test_to_type_reports_unimplemented() {
        let convert = to_type::<f32, f32>();
        assert!(convert.transform(mock_tensor()).is_err());
    }

    #[test]
    fn test_determinism() {
        let deterministic = lambda(|x: i32| Ok(x + 1));
        assert!(deterministic.is_deterministic());
        let chain = deterministic.then(lambda(|x: i32| Ok(x * 2)));
        assert!(chain.is_deterministic());
    }
}