use axonml_tensor::Tensor;
use rand::Rng;
/// A transformation applied to an `f32` tensor, e.g. for preprocessing or
/// data augmentation. Implementors must be `Send + Sync` so a transform can
/// be shared across data-loading threads.
pub trait Transform: Send + Sync {
/// Applies the transform to `input` and returns the resulting tensor.
/// The input is never modified in place.
fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
/// A sequential pipeline of transforms, applied in insertion order.
pub struct Compose {
    transforms: Vec<Box<dyn Transform>>,
}

impl Compose {
    /// Creates a pipeline from an explicit list of boxed transforms.
    #[must_use]
    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
        Self { transforms }
    }

    /// Creates an empty pipeline (acts as the identity transform).
    #[must_use]
    pub fn empty() -> Self {
        Self {
            transforms: Vec::new(),
        }
    }

    /// Appends `transform` to the end of the pipeline (builder style).
    // `add` consumes `self`; without `#[must_use]` a dropped return value
    // silently discards the whole pipeline.
    #[must_use = "`add` consumes the pipeline and returns the extended one"]
    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
        self.transforms.push(Box::new(transform));
        self
    }
}

impl Default for Compose {
    /// An empty pipeline, equivalent to [`Compose::empty`].
    fn default() -> Self {
        Self::empty()
    }
}
impl Transform for Compose {
    /// Feeds the input through every stage in order; each stage consumes
    /// the previous stage's output. An empty pipeline returns a copy.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        self.transforms
            .iter()
            .fold(input.clone(), |current, stage| stage.apply(&current))
    }
}
/// Identity transform that yields a copy of its input. Provided for API
/// parity with pipelines that expect an explicit to-tensor stage.
#[derive(Default)]
pub struct ToTensor;

impl ToTensor {
    /// Creates a new `ToTensor`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl Transform for ToTensor {
/// Returns a copy of `input` unchanged.
fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
input.clone()
}
}
/// Normalizes tensor values as `(x - mean) / std`, either globally (single
/// mean/std) or per channel.
pub struct Normalize {
    // Per-channel means; a single element means global normalization.
    mean: Vec<f32>,
    // Per-channel standard deviations; always the same length as `mean`.
    std: Vec<f32>,
}

impl Normalize {
    /// Global normalization: one mean/std applied to every element.
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self {
            mean: vec![mean],
            std: vec![std],
        }
    }

    /// Per-channel normalization with one (mean, std) pair per channel.
    ///
    /// # Panics
    /// Panics if `mean` and `std` differ in length or are empty.
    #[must_use]
    pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
        assert_eq!(mean.len(), std.len(), "mean and std must have same length");
        // Fail fast: `apply` indexes `mean[0]`/`std[0]` unconditionally, so
        // an empty stats vector would otherwise panic far from the bad call.
        assert!(!mean.is_empty(), "mean and std must be non-empty");
        Self { mean, std }
    }

    /// No-op normalization (mean 0, std 1).
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Maps `[0, 1]` data to `[-1, 1]` (mean 0.5, std 0.5).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }

    /// Standard ImageNet RGB channel statistics.
    #[must_use]
    pub fn imagenet() -> Self {
        Self::per_channel(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
    }
}
impl Transform for Normalize {
    /// Applies `(x - mean) / std` element-wise. With per-channel statistics
    /// the channel axis is axis 0 for 3-D (CHW) input and axis 1 for 4-D
    /// (NCHW) input; any other layout falls back to using the first
    /// channel's statistics for every element.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();
        let mut data = input.to_vec();
        let channels = self.mean.len();
        // Size of one contiguous per-channel plane, when the layout matches
        // a supported per-channel shape; `None` selects the global fallback.
        let plane = if channels > 1 && shape.len() == 3 && shape[0] == channels {
            Some(shape[1] * shape[2])
        } else if channels > 1 && shape.len() == 4 && shape[1] == channels {
            Some(shape[2] * shape[3])
        } else {
            None
        };
        match plane {
            // Guard against zero-sized spatial dims: `chunks_mut(0)` panics,
            // whereas the fallback loop is a harmless no-op on empty data.
            Some(spatial) if spatial > 0 => {
                // Row-major layout: consecutive planes cycle through the
                // channels (repeating once per sample in the 4-D case).
                for (idx, chunk) in data.chunks_mut(spatial).enumerate() {
                    let c = idx % channels;
                    let (m, s) = (self.mean[c], self.std[c]);
                    for x in chunk {
                        *x = (*x - m) / s;
                    }
                }
            }
            _ => {
                let (m, s) = (self.mean[0], self.std[0]);
                for x in &mut data {
                    *x = (*x - m) / s;
                }
            }
        }
        Tensor::from_vec(data, shape).unwrap()
    }
}
/// Adds zero-mean Gaussian noise to every element.
pub struct RandomNoise {
// Standard deviation of the added noise; 0.0 disables the transform.
std: f32,
}
impl RandomNoise {
/// Creates a noise transform with the given noise standard deviation.
#[must_use]
pub fn new(std: f32) -> Self {
Self { std }
}
}
impl Transform for RandomNoise {
    /// Returns `input` with independent Gaussian noise (mean 0, standard
    /// deviation `self.std`) added to each element, sampled via the
    /// Box-Muller transform. With `std == 0.0` the input is returned
    /// unchanged.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if self.std == 0.0 {
            return input.clone();
        }
        let mut rng = rand::thread_rng();
        let data = input.to_vec();
        let noisy: Vec<f32> = data
            .iter()
            .map(|&x| {
                // `gen()` samples [0, 1); map u1 into (0, 1] so `ln(u1)` is
                // finite — `ln(0)` would produce infinite/NaN noise.
                let u1: f32 = 1.0 - rng.r#gen::<f32>();
                let u2: f32 = rng.r#gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
                x + z * self.std
            })
            .collect();
        Tensor::from_vec(noisy, input.shape()).unwrap()
    }
}
/// Randomly crops the trailing dimensions of a tensor to a fixed size.
pub struct RandomCrop {
    // Target size, one entry per trailing (spatial) dimension.
    size: Vec<usize>,
}

impl RandomCrop {
    /// Crop to `size`, one entry per trailing dimension.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Convenience constructor for a 2-D (height x width) crop.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self {
            size: vec![height, width],
        }
    }
}
impl Transform for RandomCrop {
    /// Crops the trailing dimensions to the configured size at a random
    /// offset. Supported layouts: 1-D with a 1-D crop, and 2-D/3-D/4-D
    /// (HW / CHW / NCHW) with a 2-D crop. Dimensions already at or below
    /// the target size are kept whole; unsupported layouts return a copy.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();
        if shape.len() < self.size.len() {
            return input.clone();
        }
        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();
        // One random offset per cropped dimension; 0 when the input is not
        // larger than the target along that dimension.
        let offsets: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &target)| {
                let dim = shape[spatial_start + i];
                if dim <= target {
                    0
                } else {
                    rng.gen_range(0..=dim - target)
                }
            })
            .collect();
        // Effective crop extents, clamped to the input's actual size.
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();
        let data = input.to_vec();
        if shape.len() == 1 && self.size.len() == 1 {
            let slice = &data[offsets[0]..offsets[0] + crop_sizes[0]];
            return Tensor::from_vec(slice.to_vec(), &[slice.len()]).unwrap();
        }
        if self.size.len() == 2 && (2..=4).contains(&shape.len()) {
            // View every supported layout as [batch, channel, h, w] planes.
            let (batch, chans, h, w) = match shape.len() {
                2 => (1, 1, shape[0], shape[1]),
                3 => (1, shape[0], shape[1], shape[2]),
                _ => (shape[0], shape[1], shape[2], shape[3]),
            };
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);
            let mut out = Vec::with_capacity(batch * chans * crop_h * crop_w);
            for plane in 0..batch * chans {
                let base = plane * h * w;
                for row in off_h..off_h + crop_h {
                    let row_start = base + row * w + off_w;
                    out.extend_from_slice(&data[row_start..row_start + crop_w]);
                }
            }
            let out_shape: Vec<usize> = match shape.len() {
                2 => vec![crop_h, crop_w],
                3 => vec![chans, crop_h, crop_w],
                _ => vec![batch, chans, crop_h, crop_w],
            };
            return Tensor::from_vec(out, &out_shape).unwrap();
        }
        input.clone()
    }
}
/// Flips a tensor along one dimension with a given probability.
pub struct RandomFlip {
    // Dimension index to reverse.
    dim: usize,
    // Flip probability, always clamped into [0, 1].
    probability: f32,
}

impl RandomFlip {
    /// Creates a flip along `dim`; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { dim, probability }
    }

    /// 50% flip along dimension 1 (width in an HW layout).
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// 50% flip along dimension 0 (height in an HW layout).
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
impl Transform for RandomFlip {
    /// With probability `self.probability`, reverses the tensor along
    /// `self.dim`; otherwise returns a copy of the input. A `dim` outside
    /// the tensor's rank is a no-op.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let mut rng = rand::thread_rng();
        // `gen()` samples [0, 1). Using `>=` (not `>`) guarantees that
        // probability 0.0 never flips — with `>` a sample of exactly 0.0
        // would slip through and flip anyway.
        if rng.r#gen::<f32>() >= self.probability {
            return input.clone();
        }
        let shape = input.shape();
        if self.dim >= shape.len() {
            return input.clone();
        }
        let data = input.to_vec();
        let ndim = shape.len();
        let total = data.len();
        let mut flipped = vec![0.0f32; total];
        // Row-major strides: strides[i] = product of shape[i+1..].
        let mut strides = vec![1usize; ndim];
        for i in (0..ndim - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        let dim = self.dim;
        let dim_size = shape[dim];
        let dim_stride = strides[dim];
        for i in 0..total {
            // Coordinate of flat index `i` along the flipped dimension.
            let coord_in_dim = (i / dim_stride) % dim_size;
            let flipped_coord = dim_size - 1 - coord_in_dim;
            // Shift only along `dim`, leaving all other coordinates intact.
            let diff = flipped_coord as isize - coord_in_dim as isize;
            let src = (i as isize + diff * dim_stride as isize) as usize;
            flipped[i] = data[src];
        }
        Tensor::from_vec(flipped, shape).unwrap()
    }
}
/// Multiplies every element by a constant factor.
pub struct Scale {
// The scalar multiplier applied in `apply`.
factor: f32,
}
impl Scale {
/// Creates a scaling transform with the given multiplier.
#[must_use]
pub fn new(factor: f32) -> Self {
Self { factor }
}
}
impl Transform for Scale {
/// Returns `input * factor`, element-wise.
fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
input.mul_scalar(self.factor)
}
}
/// Clamps every element into the closed range `[min, max]`.
pub struct Clamp {
    min: f32,
    max: f32,
}

impl Clamp {
    /// Creates a clamp with explicit bounds.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        Self { min, max }
    }

    /// Clamp into `[0, 1]`.
    #[must_use]
    pub fn zero_one() -> Self {
        Self { min: 0.0, max: 1.0 }
    }

    /// Clamp into `[-1, 1]`.
    #[must_use]
    pub fn symmetric() -> Self {
        Self {
            min: -1.0,
            max: 1.0,
        }
    }
}
impl Transform for Clamp {
    /// Clamps each element into `[self.min, self.max]`.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let clamped: Vec<f32> = input
            .to_vec()
            .into_iter()
            .map(|x| x.clamp(self.min, self.max))
            .collect();
        Tensor::from_vec(clamped, input.shape()).unwrap()
    }
}
/// Flattens any tensor into a rank-1 tensor, preserving element order.
#[derive(Default)]
pub struct Flatten;

impl Flatten {
    /// Creates a new `Flatten`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl Transform for Flatten {
    /// Returns the input's elements as a rank-1 tensor of length equal to
    /// the total element count, in the original storage order.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        // Capture the length first so `data` can be moved into the new
        // tensor — the original cloned the whole buffer just to read `len`.
        let len = data.len();
        Tensor::from_vec(data, &[len]).unwrap()
    }
}
/// Reshapes a tensor to a fixed target shape (when element counts match).
pub struct Reshape {
// Target shape applied by `apply`.
shape: Vec<usize>,
}
impl Reshape {
/// Creates a reshape transform targeting `shape`.
#[must_use]
pub fn new(shape: Vec<usize>) -> Self {
Self { shape }
}
}
impl Transform for Reshape {
    /// Returns the input's data viewed with the target shape. If the
    /// element count does not match the target shape's product, the input
    /// is returned unchanged instead of panicking.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        if self.shape.iter().product::<usize>() == data.len() {
            Tensor::from_vec(data, &self.shape).unwrap()
        } else {
            input.clone()
        }
    }
}
/// Inverted dropout applied as a data transform: while in training mode,
/// each element is zeroed with probability `probability` and survivors are
/// scaled by `1 / (1 - probability)`.
pub struct DropoutTransform {
    // Drop probability, clamped into [0, 1] at construction.
    probability: f32,
    // Runtime-togglable training flag; atomic so the transform stays `Sync`.
    training: std::sync::atomic::AtomicBool,
}

impl DropoutTransform {
    /// Creates a dropout transform in training mode; `probability` is
    /// clamped into `[0, 1]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
            training: std::sync::atomic::AtomicBool::new(true),
        }
    }

    /// Switches between training (dropout active) and eval (no-op) mode.
    pub fn set_training(&self, training: bool) {
        use std::sync::atomic::Ordering;
        self.training.store(training, Ordering::Relaxed);
    }

    /// Whether dropout is currently active.
    pub fn is_training(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.training.load(Ordering::Relaxed)
    }
}
impl Transform for DropoutTransform {
    /// In training mode, independently zeroes each element with probability
    /// `self.probability` and rescales survivors by `1 / (1 - p)` so the
    /// expected value is preserved. In eval mode, or with `p == 0`, this is
    /// the identity.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if !self.is_training() || self.probability == 0.0 {
            return input.clone();
        }
        let mut rng = rand::thread_rng();
        let keep_scale = 1.0 / (1.0 - self.probability);
        let dropped: Vec<f32> = input
            .to_vec()
            .into_iter()
            .map(|x| {
                let drop = rng.r#gen::<f32>() < self.probability;
                if drop { 0.0 } else { x * keep_scale }
            })
            .collect();
        Tensor::from_vec(dropped, input.shape()).unwrap()
    }
}
/// Wraps an arbitrary closure as a [`Transform`].
pub struct Lambda<F>
where
F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
// The wrapped function invoked by `Transform::apply`.
func: F,
}
impl<F> Lambda<F>
where
F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
/// Creates a transform from the given closure.
pub fn new(func: F) -> Self {
Self { func }
}
}
impl<F> Transform for Lambda<F>
where
F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
/// Applies the wrapped closure to `input`.
fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
(self.func)(input)
}
}
#[cfg(test)]
mod tests {
// Unit tests covering every transform in this module.
use super::*;
// Global normalization: (x - 2.5) / 0.5.
#[test]
fn test_normalize() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
let normalize = Normalize::new(2.5, 0.5);
let output = normalize.apply(&input);
let expected = [-3.0, -1.0, 1.0, 3.0];
let result = output.to_vec();
for (a, b) in result.iter().zip(expected.iter()) {
assert!((a - b).abs() < 1e-6);
}
}
// Per-channel normalization on a CHW tensor (channel axis 0).
#[test]
fn test_normalize_per_channel() {
let input =
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], &[2, 2, 2]).unwrap();
let normalize = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]);
let output = normalize.apply(&input);
let result = output.to_vec();
assert!((result[0] - 1.0).abs() < 1e-6);
assert!((result[3] - 4.0).abs() < 1e-6);
assert!((result[4] - 0.0).abs() < 1e-6); assert!((result[5] - 1.0).abs() < 1e-6); }
#[test]
fn test_scale() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let scale = Scale::new(2.0);
let output = scale.apply(&input);
assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
}
#[test]
fn test_clamp() {
let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
let clamp = Clamp::zero_one();
let output = clamp.apply(&input);
assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
}
#[test]
fn test_flatten() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let flatten = Flatten::new();
let output = flatten.apply(&input);
assert_eq!(output.shape(), &[4]);
assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_reshape() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
let reshape = Reshape::new(vec![2, 3]);
let output = reshape.apply(&input);
assert_eq!(output.shape(), &[2, 3]);
}
// Pipelines: explicit list constructor and builder style.
#[test]
fn test_compose() {
let normalize = Normalize::new(0.0, 1.0);
let scale = Scale::new(2.0);
let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let output = compose.apply(&input);
assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
}
#[test]
fn test_compose_builder() {
let compose = Compose::empty()
.add(Normalize::new(0.0, 1.0))
.add(Scale::new(2.0));
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let output = compose.apply(&input);
assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
}
// std == 0.0 must be an exact identity (no RNG involvement).
#[test]
fn test_random_noise() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let noise = RandomNoise::new(0.0);
let output = noise.apply(&input);
assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
}
// Flip tests use probability 1.0 so the outcome is deterministic.
#[test]
fn test_random_flip_1d() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
let flip = RandomFlip::new(0, 1.0);
let output = flip.apply(&input);
assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
}
#[test]
fn test_random_flip_2d_horizontal() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let flip = RandomFlip::new(1, 1.0);
let output = flip.apply(&input);
assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
}
#[test]
fn test_random_flip_2d_vertical() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let flip = RandomFlip::new(0, 1.0);
let output = flip.apply(&input);
assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
}
#[test]
fn test_random_flip_3d() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
let flip = RandomFlip::new(2, 1.0);
let output = flip.apply(&input);
assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
assert_eq!(output.shape(), &[1, 2, 2]);
}
#[test]
fn test_random_flip_4d() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
let flip = RandomFlip::new(2, 1.0);
let output = flip.apply(&input);
assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
assert_eq!(output.shape(), &[1, 1, 2, 2]);
}
// Eval mode must be an exact identity after dropping in training mode.
#[test]
fn test_dropout_eval_mode() {
let input = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap();
let dropout = DropoutTransform::new(0.5);
let output_train = dropout.apply(&input);
let zeros_train = output_train.to_vec().iter().filter(|&&x| x == 0.0).count();
assert!(zeros_train > 0, "Training mode should drop elements");
dropout.set_training(false);
let output_eval = dropout.apply(&input);
assert_eq!(output_eval.to_vec(), vec![1.0; 100]);
}
// Statistical check: ~50% zeros, survivors scaled by 1 / (1 - p) = 2.
#[test]
fn test_dropout_transform() {
let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
let dropout = DropoutTransform::new(0.5);
let output = dropout.apply(&input);
let output_vec = output.to_vec();
let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
assert!(
zeros > 300 && zeros < 700,
"Expected ~500 zeros, got {zeros}"
);
let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
for x in nonzeros {
assert!((x - 2.0).abs() < 1e-6);
}
}
#[test]
fn test_lambda() {
let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let output = lambda.apply(&input);
assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
}
#[test]
fn test_to_tensor() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let to_tensor = ToTensor::new();
let output = to_tensor.apply(&input);
assert_eq!(output.to_vec(), input.to_vec());
}
#[test]
fn test_normalize_variants() {
let standard = Normalize::standard();
assert_eq!(standard.mean, vec![0.0]);
assert_eq!(standard.std, vec![1.0]);
let zero_centered = Normalize::zero_centered();
assert_eq!(zero_centered.mean, vec![0.5]);
assert_eq!(zero_centered.std, vec![0.5]);
}
// Crop tests assert shape only — the offset is random.
#[test]
fn test_random_crop_1d() {
let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
let crop = RandomCrop::new(vec![3]);
let output = crop.apply(&input);
assert_eq!(output.shape(), &[3]);
}
#[test]
fn test_random_crop_2d() {
let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
let crop = RandomCrop::new_2d(2, 2);
let output = crop.apply(&input);
assert_eq!(output.shape(), &[2, 2]);
let vals = output.to_vec();
assert_eq!(vals.len(), 4);
}
#[test]
fn test_random_crop_3d() {
let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
let crop = RandomCrop::new_2d(2, 2);
let output = crop.apply(&input);
assert_eq!(output.shape(), &[2, 2, 2]); }
}