pub mod augmentation;
pub mod automated;
pub mod basic;
pub mod core;
pub mod mixing;
pub mod presets;
pub mod random;
pub mod registry;
pub mod sophisticated;
pub mod unified;
pub use core::{Compose, Transform};
pub use basic::{CenterCrop, Normalize, Pad, Resize, ToTensor};
pub use random::{
RandomCrop, RandomHorizontalFlip, RandomResizedCrop, RandomRotation, RandomVerticalFlip,
Rotation,
};
pub use augmentation::{ColorJitter, Cutout, GaussianBlur, RandomErasing};
pub use mixing::{CutMix, MixUp};
pub use automated::{AutoAugment, RandAugment};
pub use sophisticated::{AugMix, GridMask, Mosaic};
pub use registry::{TransformBuilder, TransformIntrospection, TransformRegistry, TransformStats};
pub use presets::*;
/// Builds the ImageNet training preset pipeline.
///
/// Thin wrapper that forwards `size` unchanged to `presets::presets::imagenet_train`;
/// see that function for the exact transforms applied. `size` is presumably the target
/// spatial side length (e.g. `224`) — confirm against the preset definition.
#[must_use = "the returned pipeline has no effect unless it is applied to an input"]
pub fn imagenet_train(size: usize) -> Compose {
    presets::presets::imagenet_train(size)
}
/// Builds the ImageNet validation preset pipeline.
///
/// Thin wrapper that forwards `size` unchanged to `presets::presets::imagenet_val`;
/// see that function for the exact transforms applied. `size` is presumably the target
/// spatial side length (e.g. `224`) — confirm against the preset definition.
#[must_use = "the returned pipeline has no effect unless it is applied to an input"]
pub fn imagenet_val(size: usize) -> Compose {
    presets::presets::imagenet_val(size)
}
/// Builds the CIFAR training preset pipeline.
///
/// Thin wrapper around `presets::presets::cifar_train`. It takes no size argument, so any
/// dimensions are fixed inside the preset itself (presumably CIFAR's native resolution —
/// confirm against the preset definition).
#[must_use = "the returned pipeline has no effect unless it is applied to an input"]
pub fn cifar_train() -> Compose {
    presets::presets::cifar_train()
}
/// Builds the CIFAR validation preset pipeline.
///
/// Thin wrapper around `presets::presets::cifar_val`. It takes no size argument, so any
/// dimensions are fixed inside the preset itself (presumably CIFAR's native resolution —
/// confirm against the preset definition).
#[must_use = "the returned pipeline has no effect unless it is applied to an input"]
pub fn cifar_val() -> Compose {
    presets::presets::cifar_val()
}
/// Builds the "strong augmentation" preset pipeline.
///
/// Thin wrapper that forwards `size` unchanged to `presets::presets::strong_augment`;
/// see that function for which augmentations it stacks. `size` is presumably the target
/// spatial side length — confirm against the preset definition.
#[must_use = "the returned pipeline has no effect unless it is applied to an input"]
pub fn strong_augment(size: usize) -> Compose {
    presets::presets::strong_augment(size)
}
/// Returns a fresh [`TransformBuilder`], equivalent to `TransformBuilder::new()`.
///
/// Intended for fluent pipeline construction, e.g.
/// `builder().resize((224, 224)).random_horizontal_flip(0.5).imagenet_normalize().build()`,
/// which yields a [`Compose`].
#[must_use = "a builder does nothing until `.build()` is called and the result is used"]
pub fn builder() -> TransformBuilder {
    TransformBuilder::new()
}
/// Returns a newly constructed [`TransformRegistry`], equivalent to `TransformRegistry::new()`.
///
/// Every call creates a separate registry instance; nothing is shared between calls.
#[must_use = "the registry is discarded immediately if the result is not used"]
pub fn registry() -> TransformRegistry {
    TransformRegistry::new()
}
pub type BoxedTransform = Box<dyn Transform>;
pub type TransformVec = Vec<BoxedTransform>;
impl From<Vec<BoxedTransform>> for Compose {
fn from(transforms: Vec<BoxedTransform>) -> Self {
Compose::new(transforms)
}
}
pub use core::*;
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation;

    /// Every re-exported transform type and helper should be constructible through this module.
    #[test]
    fn test_module_exports() {
        let _resize = Resize::new((224, 224));
        let _normalize = Normalize::imagenet();
        let _flip = RandomHorizontalFlip::new(0.5);
        let _jitter = ColorJitter::new();
        let _mixup = MixUp::new(1.0);
        let _autoaug = AutoAugment::new();
        let _augmix = AugMix::new();
        let _builder = TransformBuilder::new();
        let _registry = TransformRegistry::new();
    }

    /// The preset wrappers must yield non-empty pipelines.
    #[test]
    fn test_convenience_functions() {
        let train = imagenet_train(224);
        let val = imagenet_val(224);
        let cifar_tr = cifar_train();
        let cifar_v = cifar_val();
        let strong = strong_augment(224);
        assert!(!train.is_empty(), "imagenet_train preset should not be empty");
        assert!(!val.is_empty(), "imagenet_val preset should not be empty");
        assert!(!cifar_tr.is_empty(), "cifar_train preset should not be empty");
        assert!(!cifar_v.is_empty(), "cifar_val preset should not be empty");
        assert!(!strong.is_empty(), "strong_augment preset should not be empty");
        let _builder = builder();
        let _reg = registry();
    }

    #[test]
    fn test_type_aliases() {
        let transform: BoxedTransform = Box::new(Resize::new((224, 224)));
        let transforms: TransformVec = vec![
            Box::new(Resize::new((224, 224))),
            Box::new(Normalize::imagenet()),
        ];
        assert_eq!(transform.name(), "Resize");
        assert_eq!(transforms.len(), 2);
    }

    #[test]
    fn test_compose_from_vec() {
        let transforms: TransformVec = vec![
            Box::new(Resize::new((224, 224))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(Normalize::imagenet()),
        ];
        let compose: Compose = transforms.into();
        assert_eq!(compose.len(), 3);
    }

    #[test]
    fn test_full_pipeline() {
        let input = creation::ones(&[3, 256, 256]).expect("creation should succeed");
        let pipeline = builder()
            .resize((224, 224))
            .random_horizontal_flip(0.5)
            .add(ColorJitter::new().brightness(0.1))
            .imagenet_normalize()
            .build();
        // A single `expect` both asserts success and reports the underlying error on failure,
        // instead of a bare `assert!(is_ok())` that hides it.
        let output = pipeline
            .forward(&input)
            .expect("pipeline forward should succeed");
        assert_eq!(output.shape().dims(), &[3, 224, 224]);
    }

    #[test]
    fn test_preset_pipelines() {
        let input = creation::ones(&[3, 256, 256]).expect("creation should succeed");
        // Pair each preset with its name so a failure says which one broke.
        let presets = [
            ("imagenet_train", imagenet_train(224)),
            ("imagenet_val", imagenet_val(224)),
            ("strong_augment", strong_augment(224)),
        ];
        for (name, preset) in &presets {
            if let Err(err) = preset.forward(&input) {
                panic!("preset `{}` failed on a 3x256x256 input: {:?}", name, err);
            }
        }
    }

    // NOTE(review): ignored without a recorded reason; prefer `#[ignore = "..."]` stating
    // why, so the test can be re-enabled deliberately.
    #[test]
    #[ignore]
    fn test_advanced_transforms() {
        let input = creation::ones(&[3, 224, 224]).expect("creation should succeed");

        let rand_aug = RandAugment::new(2, 5.0);
        let result = rand_aug.forward(&input);
        assert!(result.is_ok(), "RandAugment forward should succeed");

        let augmix = AugMix::new();
        let result = augmix.forward(&input);
        assert!(result.is_ok(), "AugMix forward should succeed");

        let gridmask = GridMask::new();
        let result = gridmask.forward(&input);
        assert!(result.is_ok(), "GridMask forward should succeed");
    }

    // NOTE(review): ignored without a recorded reason; prefer `#[ignore = "..."]` stating
    // why, so the test can be re-enabled deliberately.
    #[test]
    #[ignore]
    fn test_mixing_transforms() {
        let input1 = creation::ones(&[3, 32, 32]).expect("creation should succeed");
        let input2 = creation::zeros(&[3, 32, 32]).expect("creation should succeed");

        let mixup = MixUp::new(1.0);
        let result = mixup.apply_pair(&input1, &input2, 0, 1, 10);
        assert!(result.is_ok(), "MixUp apply_pair should succeed");

        let cutmix = CutMix::new(1.0);
        let result = cutmix.apply_pair(&input1, &input2, 0, 1, 10);
        assert!(result.is_ok(), "CutMix apply_pair should succeed");
    }

    #[test]
    fn test_introspection() {
        let pipeline = builder()
            .resize((224, 224))
            .random_horizontal_flip(0.5)
            .imagenet_normalize()
            .build();

        let description = pipeline.describe();
        assert!(description.contains("Resize"), "description should mention Resize");
        assert!(
            description.contains("RandomHorizontalFlip"),
            "description should mention RandomHorizontalFlip"
        );
        assert!(description.contains("Normalize"), "description should mention Normalize");

        let stats = pipeline.statistics();
        assert_eq!(stats.total_transforms, 3);

        let validation = pipeline.validate();
        assert!(validation.is_ok(), "a well-formed pipeline should validate");
    }
}