use axonml_tensor::Tensor;
use rand::Rng;
/// A tensor-to-tensor data transform.
///
/// Implementors receive the input by reference and produce a new tensor.
/// The `Send + Sync` supertraits allow boxed transforms (as stored by
/// `Compose`) to be shared across threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to `input` and returns the resulting tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
/// A pipeline that applies a sequence of transforms in order, feeding each
/// transform's output into the next one.
pub struct Compose {
    transforms: Vec<Box<dyn Transform>>,
}

impl Compose {
    /// Creates a pipeline from an explicit list of boxed transforms, applied
    /// front to back.
    #[must_use]
    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
        Self { transforms }
    }

    /// Creates a pipeline with no transforms; applying it returns a copy of
    /// the input.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            transforms: Vec::new(),
        }
    }

    /// Appends `transform` to the end of the pipeline, builder style.
    ///
    /// Consumes and returns the pipeline, so the result must be used.
    #[must_use]
    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
        self.transforms.push(Box::new(transform));
        self
    }
}

impl Default for Compose {
    /// Equivalent to [`Compose::empty`].
    fn default() -> Self {
        Self::empty()
    }
}

impl Transform for Compose {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        // Start from a copy so an empty pipeline still returns an owned tensor.
        let mut result = input.clone();
        for transform in &self.transforms {
            result = transform.apply(&result);
        }
        result
    }
}
pub struct ToTensor;
impl ToTensor {
#[must_use]
pub fn new() -> Self {
Self
}
}
impl Default for ToTensor {
fn default() -> Self {
Self::new()
}
}
impl Transform for ToTensor {
fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
input.clone()
}
}
/// Element-wise affine normalization computing `(x - mean) / std`.
///
/// `std` is not validated: a value of `0.0` yields infinities (or NaN where
/// `x == mean`) under IEEE-754 division.
pub struct Normalize {
    mean: f32,
    std: f32,
}

impl Normalize {
    /// Creates a normalizer that subtracts `mean` and divides by `std`.
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self { mean, std }
    }

    /// `mean = 0.0`, `std = 1.0`: leaves values unchanged.
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// `mean = 0.5`, `std = 0.5`: maps the range `[0, 1]` onto `[-1, 1]`.
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }
}

impl Transform for Normalize {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let (mean, std_dev) = (self.mean, self.std);
        let values: Vec<f32> = input
            .to_vec()
            .into_iter()
            .map(|v| (v - mean) / std_dev)
            .collect();
        Tensor::from_vec(values, input.shape()).unwrap()
    }
}
/// Adds independent Gaussian noise with standard deviation `std` to every
/// element (zero-mean, sampled via the Box–Muller transform).
pub struct RandomNoise {
    std: f32,
}

impl RandomNoise {
    /// Creates a noise transform. A `std` of exactly `0.0` makes `apply`
    /// return a copy of the input without drawing random numbers.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}

impl Transform for RandomNoise {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if self.std == 0.0 {
            return input.clone();
        }
        let mut rng = rand::thread_rng();
        let noisy: Vec<f32> = input
            .to_vec()
            .into_iter()
            .map(|x| {
                // `gen::<f32>()` is uniform on [0, 1) and can return exactly
                // 0.0, whose logarithm is -inf (giving infinite/NaN noise).
                // `1.0 - u` lies in (0, 1], keeping the Box–Muller radius finite.
                let u1: f32 = 1.0 - rng.r#gen::<f32>();
                let u2: f32 = rng.r#gen();
                let z = (-2.0 * u1.ln()).sqrt() * (std::f32::consts::TAU * u2).cos();
                x + z * self.std
            })
            .collect();
        Tensor::from_vec(noisy, input.shape()).expect("noise preserves the element count")
    }
}
/// Randomly crops the trailing dimensions of a tensor.
///
/// `size` gives the target extent of the last `size.len()` dimensions; any
/// leading dimensions (e.g. batch or channel axes) are kept whole, and every
/// slice along them uses the same crop window. A dimension that is already
/// no larger than its target is kept whole. Inputs with fewer dimensions than
/// `size` (or an empty `size`) are returned unchanged. Works for any rank.
pub struct RandomCrop {
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a crop over the trailing `size.len()` dimensions.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Creates a crop over the last two dimensions (`height`, `width`).
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }

    /// Advances `index` like an odometer bounded by `bounds` (last position
    /// fastest). Returns `false` once every combination has been visited.
    fn advance_index(index: &mut [usize], bounds: &[usize]) -> bool {
        for (digit, &bound) in index.iter_mut().zip(bounds).rev() {
            *digit += 1;
            if *digit < bound {
                return true;
            }
            *digit = 0;
        }
        false
    }
}

impl Transform for RandomCrop {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();
        let ndim = shape.len();
        let crop_rank = self.size.len();
        // Nothing to crop, or the input has fewer dimensions than the crop.
        if crop_rank == 0 || ndim < crop_rank {
            return input.clone();
        }
        let spatial_start = ndim - crop_rank;

        // Choose the window: `offsets[d]` is where dimension `d` starts and
        // `out_shape[d]` is how many entries are kept. Leading dimensions and
        // cropped dimensions that already fit keep offset 0 and full extent,
        // and draw no random number.
        let mut rng = rand::thread_rng();
        let mut offsets = vec![0usize; ndim];
        let mut out_shape = shape.to_vec();
        for (d, &target) in (spatial_start..ndim).zip(&self.size) {
            if shape[d] > target {
                offsets[d] = rng.gen_range(0..=shape[d] - target);
                out_shape[d] = target;
            }
        }

        let data = input.to_vec();
        let total: usize = out_shape.iter().product();
        let mut cropped = Vec::with_capacity(total);
        if total > 0 {
            // Row-major strides of the input tensor.
            let mut strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                strides[d] = strides[d + 1] * shape[d + 1];
            }
            // The last dimension is contiguous, so copy one cropped row per
            // combination of the outer (all-but-last) output indices.
            let last = ndim - 1;
            let row_len = out_shape[last];
            let mut outer = vec![0usize; last];
            loop {
                let row_start = offsets[last]
                    + outer
                        .iter()
                        .zip(&offsets)
                        .zip(&strides)
                        .map(|((&i, &off), &stride)| (i + off) * stride)
                        .sum::<usize>();
                cropped.extend_from_slice(&data[row_start..row_start + row_len]);
                if !Self::advance_index(&mut outer, &out_shape[..last]) {
                    break;
                }
            }
        }
        Tensor::from_vec(cropped, &out_shape)
            .expect("cropped data length always matches the cropped shape")
    }
}
/// Randomly reverses a tensor along one of its dimensions.
///
/// With probability `probability` the elements are reversed along axis `dim`
/// of the input tensor; otherwise, or when `dim` is out of range for the
/// input's rank, a copy of the input is returned. Any rank is supported.
///
/// `horizontal()` (axis 1) and `vertical()` (axis 0) name the axes of a 2-D
/// `[rows, cols]` tensor. For a tensor laid out as `[C, H, W]`, a horizontal
/// flip is `RandomFlip::new(2, p)`.
pub struct RandomFlip {
    dim: usize,
    probability: f32,
}

impl RandomFlip {
    /// Creates a flip along axis `dim`; `probability` is clamped to `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        Self {
            dim,
            probability: probability.clamp(0.0, 1.0),
        }
    }

    /// Flips axis 1 (columns of a 2-D tensor) with probability 0.5.
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// Flips axis 0 (rows of a 2-D tensor) with probability 0.5.
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}

impl Transform for RandomFlip {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let mut rng = rand::thread_rng();
        // `gen::<f32>()` is uniform on [0, 1), so flipping when it is
        // `< probability` happens with exactly that probability: never for
        // 0.0 (the draw can be exactly 0.0), always for 1.0.
        if rng.r#gen::<f32>() >= self.probability {
            return input.clone();
        }
        let shape = input.shape();
        if self.dim >= shape.len() {
            return input.clone();
        }
        let data = input.to_vec();
        // View the row-major data as [outer, len, inner] around the flip axis
        // and reverse the middle index, copying contiguous `inner` runs.
        let outer: usize = shape[..self.dim].iter().product();
        let len = shape[self.dim];
        let inner: usize = shape[self.dim + 1..].iter().product();
        let mut flipped = Vec::with_capacity(data.len());
        for o in 0..outer {
            let block_start = o * len * inner;
            for i in (0..len).rev() {
                let start = block_start + i * inner;
                flipped.extend_from_slice(&data[start..start + inner]);
            }
        }
        Tensor::from_vec(flipped, shape).expect("flipping preserves the element count")
    }
}
/// Multiplies every element by a constant factor.
pub struct Scale {
    factor: f32,
}

impl Scale {
    /// Creates a scaling transform with multiplicative `factor`.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Scale { factor }
    }
}

impl Transform for Scale {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let multiplier = self.factor;
        input.mul_scalar(multiplier)
    }
}
/// Clamps every element into the closed interval `[min, max]`.
///
/// NaN elements stay NaN (per `f32::clamp`).
pub struct Clamp {
    min: f32,
    max: f32,
}

impl Clamp {
    /// Creates a clamp to `[min, max]`.
    ///
    /// # Panics
    ///
    /// Panics if `min > max` or either bound is NaN. `f32::clamp` would
    /// otherwise panic later, on every call to `apply`.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        // `!(min <= max)` covers both `min > max` and NaN bounds.
        assert!(
            min <= max,
            "Clamp requires min <= max and non-NaN bounds, got min={min}, max={max}"
        );
        Self { min, max }
    }

    /// Clamps to `[0, 1]`.
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Clamps to `[-1, 1]`.
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}

impl Transform for Clamp {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let (min, max) = (self.min, self.max);
        let clamped: Vec<f32> = input
            .to_vec()
            .into_iter()
            .map(|x| x.clamp(min, max))
            .collect();
        Tensor::from_vec(clamped, input.shape()).expect("clamping preserves the element count")
    }
}
/// Flattens a tensor of any shape into a 1-D tensor, preserving the element
/// order returned by `to_vec`.
pub struct Flatten;

impl Flatten {
    /// Creates the (stateless) flatten transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for Flatten {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        // Read the length first so `data` can be moved instead of cloned.
        let len = data.len();
        Tensor::from_vec(data, &[len]).expect("a 1-D shape of data.len() always fits")
    }
}
/// Reinterprets the input's elements (in `to_vec` order) under a new shape.
///
/// If the product of the target `shape` differs from the input's element
/// count, the input is returned unchanged; no error is raised.
pub struct Reshape {
    shape: Vec<usize>,
}

impl Reshape {
    /// Creates a reshape to the target `shape`.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}

impl Transform for Reshape {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let target_len: usize = self.shape.iter().product();
        let values = input.to_vec();
        if values.len() == target_len {
            Tensor::from_vec(values, &self.shape).unwrap()
        } else {
            input.clone()
        }
    }
}
/// Inverted dropout applied as a data transform.
///
/// Each element is independently zeroed with probability `probability`.
/// Surviving elements are multiplied by `1 / (1 - probability)`, so each
/// element's expected value is unchanged.
pub struct DropoutTransform {
    probability: f32,
}

impl DropoutTransform {
    /// Creates a dropout transform; `probability` is clamped to `[0, 1]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
        }
    }
}

impl Transform for DropoutTransform {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let p = self.probability;
        if p == 0.0 {
            return input.clone();
        }
        // With p == 1.0 this is infinite, but every draw from [0, 1) is then
        // below p, so no element is ever multiplied by it.
        let keep_scale = 1.0 / (1.0 - p);
        let mut rng = rand::thread_rng();
        let mut values = input.to_vec();
        for v in &mut values {
            *v = if rng.r#gen::<f32>() < p { 0.0 } else { *v * keep_scale };
        }
        Tensor::from_vec(values, input.shape()).unwrap()
    }
}
/// Wraps an arbitrary closure or function as a [`Transform`].
///
/// The bounds live on the impls rather than the struct definition, so the
/// type itself places no requirements on `F`. Constructing and applying a
/// `Lambda` still requires `F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync`.
pub struct Lambda<F> {
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Wraps `func` so it can be used wherever a `Transform` is expected.
    #[must_use]
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts equal lengths and element-wise closeness within `1e-6`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (a, b) in actual.iter().zip(expected) {
            assert!((a - b).abs() < 1e-6, "expected {b}, got {a}");
        }
    }

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);
        let output = normalize.apply(&input);
        assert_close(&output.to_vec(), &[-3.0, -1.0, 1.0, 3.0]);
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);
        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();
        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_clamp_symmetric() {
        let input = Tensor::from_vec(vec![-2.0, 0.0, 2.0], &[3]).unwrap();
        let output = Clamp::symmetric().apply(&input);
        assert_eq!(output.to_vec(), vec![-1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();
        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);
        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch_returns_input() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let output = Reshape::new(vec![4]).apply(&input);
        assert_eq!(output.shape(), &[6]);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);
        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_noise_is_finite() {
        let input = Tensor::from_vec(vec![0.0; 256], &[256]).unwrap();
        let output = RandomNoise::new(1.0).apply(&input);
        assert_eq!(output.shape(), &[256]);
        assert!(output.to_vec().iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_probability_zero_is_identity() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let output = RandomFlip::new(0, 0.0).apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_random_flip_out_of_range_dim_is_identity() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = RandomFlip::new(3, 1.0).apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);
        let output = dropout.apply(&input);
        let output_vec = output.to_vec();
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );
        // Survivors are scaled by 1 / (1 - 0.5) = 2.
        for &x in output_vec.iter().filter(|&&x| x != 0.0) {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dropout_zero_probability_is_identity() {
        let input = Tensor::from_vec(vec![1.0, -2.0, 3.0], &[3]).unwrap();
        let output = DropoutTransform::new(0.0).apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, -2.0, 3.0]);
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();
        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);
        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);
        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
        // The crop must be a contiguous window starting at 1, 2 or 3.
        let vals = output.to_vec();
        let start = vals[0];
        assert!((1.0_f32..=3.0).contains(&start), "unexpected start {start}");
        assert_eq!(vals, vec![start, start + 1.0, start + 2.0]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);
        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // Values are row-major 1..=16, so a 2x2 window is v, v+1, v+4, v+5.
        let vals = output.to_vec();
        let v0 = vals[0];
        assert_eq!(vals, vec![v0, v0 + 1.0, v0 + 4.0, v0 + 5.0]);
    }

    #[test]
    fn test_random_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);
        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]);
        // Both channels share one window; channel 1 is channel 0 plus 16.
        let vals = output.to_vec();
        let v0 = vals[0];
        let window = [v0, v0 + 1.0, v0 + 4.0, v0 + 5.0];
        let expected: Vec<f32> = window
            .iter()
            .copied()
            .chain(window.iter().map(|v| v + 16.0))
            .collect();
        assert_eq!(vals, expected);
    }

    #[test]
    fn test_random_crop_4d() {
        let input = Tensor::from_vec((1..=18).map(|x| x as f32).collect(), &[2, 1, 3, 3]).unwrap();
        let output = RandomCrop::new_2d(2, 2).apply(&input);
        assert_eq!(output.shape(), &[2, 1, 2, 2]);
        // 3x3 planes: a 2x2 window is v, v+1, v+3, v+4; batch 1 is batch 0 plus 9.
        let vals = output.to_vec();
        let v0 = vals[0];
        let window = [v0, v0 + 1.0, v0 + 3.0, v0 + 4.0];
        let expected: Vec<f32> = window
            .iter()
            .copied()
            .chain(window.iter().map(|v| v + 9.0))
            .collect();
        assert_eq!(vals, expected);
    }

    #[test]
    fn test_random_crop_larger_than_input_keeps_everything() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = RandomCrop::new_2d(3, 3).apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }
}