torsh-vision 0.1.2

Computer vision utilities for ToRSh deep learning framework
/*!
# ToRSh Vision Transforms

A comprehensive collection of image transformation and data augmentation utilities for computer vision tasks.
This module provides PyTorch-compatible transforms with enhanced functionality and performance optimizations.

## Organization

The transforms module is organized into several sub-modules:

- [`core`] - Core Transform trait and composition utilities
- [`basic`] - Fundamental transforms (Resize, Crop, Normalize, etc.)
- [`random`] - Random/probabilistic transforms for data augmentation
- [`augmentation`] - Advanced augmentation techniques (ColorJitter, RandomErasing, etc.)
- [`mixing`] - Data mixing techniques (MixUp, CutMix)
- [`automated`] - Automated augmentation strategies (AutoAugment, RandAugment)
- [`sophisticated`] - State-of-the-art augmentation methods (AugMix, GridMask, Mosaic)
- [`mod@registry`] - Transform registration and builder patterns
- [`presets`] - Common transform configurations for popular datasets

## Quick Start

### Basic Usage

```rust
use torsh_vision::transforms::{Compose, Normalize, RandomHorizontalFlip, Resize, Transform};

// Create a simple pipeline
let transforms = vec![
    Box::new(Resize::new((224, 224))) as Box<dyn Transform>,
    Box::new(RandomHorizontalFlip::new(0.5)),
    Box::new(Normalize::imagenet()),
];
let pipeline = Compose::new(transforms);
```
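
To make the composition semantics concrete, the following is a minimal, self-contained sketch of sequential chaining. It does not use the `torsh_vision` types: `Step`, `Scale`, `Offset`, and `Chain` are hypothetical stand-ins that operate on `Vec<f32>`, and it assumes (as in torchvision) that a composed pipeline applies its transforms in insertion order and stops at the first error.

```rust
/// Hypothetical stand-in for a transform trait; not the torsh-vision `Transform`.
trait Step {
    fn apply(&self, input: &[f32]) -> Result<Vec<f32>, String>;
}

/// Multiplies every element by a constant factor.
struct Scale(f32);

/// Adds a constant to every element.
struct Offset(f32);

impl Step for Scale {
    fn apply(&self, input: &[f32]) -> Result<Vec<f32>, String> {
        Ok(input.iter().map(|v| v * self.0).collect())
    }
}

impl Step for Offset {
    fn apply(&self, input: &[f32]) -> Result<Vec<f32>, String> {
        Ok(input.iter().map(|v| v + self.0).collect())
    }
}

/// Hypothetical stand-in for a composed pipeline of boxed steps.
struct Chain {
    steps: Vec<Box<dyn Step>>,
}

impl Chain {
    fn new(steps: Vec<Box<dyn Step>>) -> Self {
        Chain { steps }
    }

    /// Feeds the output of each step into the next one, in insertion order.
    /// The `?` operator aborts the chain at the first failing step.
    fn forward(&self, input: &[f32]) -> Result<Vec<f32>, String> {
        let mut current = input.to_vec();
        for step in &self.steps {
            current = step.apply(&current)?;
        }
        Ok(current)
    }
}

fn main() {
    let steps: Vec<Box<dyn Step>> = vec![Box::new(Scale(2.0)), Box::new(Offset(1.0))];
    let chain = Chain::new(steps);
    let out = chain.forward(&[1.0, 2.0]).expect("stand-in steps never fail");
    println!("{:?}", out);
}
```

Because the steps are stored as trait objects, order matters: swapping `Scale` and `Offset` in the vector yields a different result, just as resizing before or after normalization would.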

### Using the Builder Pattern

```rust
use torsh_vision::transforms::{TransformBuilder, presets};

// ImageNet training pipeline
let train_transforms = TransformBuilder::new()
    .resize((256, 256))
    .random_horizontal_flip(0.5)
    .center_crop((224, 224))
    .imagenet_normalize()
    .build();

// Or use a preset (same path that this module's `imagenet_train` convenience function calls)
let train_transforms = presets::presets::imagenet_train(224);
```
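
The fluent style above relies on each builder method consuming the builder and returning it. The sketch below illustrates that mechanism in isolation. `PipelineBuilder` and `Pipeline` are hypothetical stand-ins that only record step descriptions; they are not the torsh-vision `TransformBuilder` or `Compose`, whose internals may differ.

```rust
/// Hypothetical result type; stands in for a composed transform pipeline.
#[derive(Debug)]
struct Pipeline {
    steps: Vec<String>,
}

/// Hypothetical fluent builder; not the torsh-vision `TransformBuilder`.
#[derive(Default)]
struct PipelineBuilder {
    steps: Vec<String>,
}

impl PipelineBuilder {
    fn new() -> Self {
        Self::default()
    }

    /// Takes `self` by value and returns it, which is what enables chaining.
    fn resize(mut self, size: (usize, usize)) -> Self {
        self.steps.push(format!("Resize({}, {})", size.0, size.1));
        self
    }

    fn center_crop(mut self, size: (usize, usize)) -> Self {
        self.steps.push(format!("CenterCrop({}, {})", size.0, size.1));
        self
    }

    /// Finishes the chain and hands the recorded steps to the pipeline.
    fn build(self) -> Pipeline {
        Pipeline { steps: self.steps }
    }
}

fn main() {
    let pipeline = PipelineBuilder::new()
        .resize((256, 256))
        .center_crop((224, 224))
        .build();
    println!("{:?}", pipeline);
}
```

Consuming `self` means a builder cannot be reused after `build()`, which prevents accidentally extending a pipeline that has already been handed out.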

### Advanced Augmentation

```rust
use torsh_vision::transforms::{ColorJitter, RandAugment, RandomErasing, TransformBuilder};

// RandAugment for automated augmentation
let rand_aug = RandAugment::new(2, 5.0);

// Manual augmentation pipeline
let strong_aug = TransformBuilder::new()
    .resize((256, 256))
    .add(ColorJitter::new().brightness(0.4).contrast(0.4))
    .add(RandomErasing::new(0.25))
    .random_horizontal_flip(0.5)
    .center_crop((224, 224))
    .imagenet_normalize()
    .build();
```
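
For intuition about what a brightness strength such as `0.4` means, here is a small sketch under the torchvision convention, where a strength `b` selects a multiplicative factor from `[max(0, 1 - b), 1 + b]`. Whether ToRSh's `ColorJitter` uses exactly this convention is an assumption; `brightness_range` and `adjust_brightness` are hypothetical helpers operating on plain `f32` pixel values in `[0, 1]`.

```rust
/// Range of brightness factors for a jitter strength `b`, following the
/// torchvision convention. Assumed, not confirmed, for torsh-vision.
fn brightness_range(b: f32) -> (f32, f32) {
    ((1.0 - b).max(0.0), 1.0 + b)
}

/// Scales pixel values by `factor` and clamps the result back into [0, 1].
fn adjust_brightness(pixels: &[f32], factor: f32) -> Vec<f32> {
    pixels
        .iter()
        .map(|p| (p * factor).clamp(0.0, 1.0))
        .collect()
}

fn main() {
    let (lo, hi) = brightness_range(0.4);
    println!("factor range: [{lo}, {hi}]");

    // A factor above 1 brightens; values that would exceed 1.0 saturate.
    let brighter = adjust_brightness(&[0.25, 0.5, 1.0], 1.5);
    println!("{:?}", brighter);
}
```

A real augmentation would draw the factor at random from the range on every call; the sketch keeps the factor explicit so the arithmetic is easy to follow.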

## Key Features

- **SciRS2 Integration**: Full compliance with SciRS2 random number generation and array operations
- **Comprehensive Transform Library**: Over 20 different transforms covering basic to advanced techniques
- **Builder Pattern**: Fluent API for creating transform pipelines
- **Preset Configurations**: Ready-to-use configurations for common datasets and tasks
- **Advanced Techniques**: Support for state-of-the-art methods like AugMix, GridMask, and data mixing
- **Type Safety**: Strong typing with comprehensive error handling
- **Performance Optimized**: Efficient implementations with minimal memory allocation
- **Extensive Testing**: Comprehensive test coverage for all transforms
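
As an illustration of the data mixing idea mentioned above, the following sketch shows the standard MixUp arithmetic: a convex combination of two inputs and of their one-hot labels with weight `lam`. The functions `mixup` and `mix_labels` are hypothetical and work on flat `f32` slices; they are not the torsh-vision `MixUp` API. In practice `lam` is usually sampled from a Beta distribution, which is passed in as a fixed value here.

```rust
/// Blends two equally sized inputs: `lam * a + (1 - lam) * b`, element by element.
fn mixup(a: &[f32], b: &[f32], lam: f32) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    a.iter()
        .zip(b)
        .map(|(x, y)| lam * x + (1.0 - lam) * y)
        .collect()
}

/// Builds the soft target: weight `lam` on `label_a` and `1 - lam` on `label_b`.
fn mix_labels(label_a: usize, label_b: usize, num_classes: usize, lam: f32) -> Vec<f32> {
    let mut target = vec![0.0; num_classes];
    target[label_a] += lam;
    target[label_b] += 1.0 - lam;
    target
}

fn main() {
    let ones = [1.0_f32; 4];
    let zeros = [0.0_f32; 4];

    let mixed = mixup(&ones, &zeros, 0.25);
    let target = mix_labels(0, 1, 3, 0.25);

    println!("mixed input:  {:?}", mixed);
    println!("mixed target: {:?}", target);
}
```

Using `+=` in `mix_labels` keeps the target valid when both samples share a label: the two weights then add up to 1.0 on that single class.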
*/

//
// Module declarations
//

pub mod augmentation;
pub mod automated;
pub mod basic;
pub mod core;
pub mod mixing;
pub mod presets;
pub mod random;
pub mod registry;
pub mod sophisticated;
pub mod unified;

//
// Core exports
//

pub use core::{Compose, Transform};

//
// Basic transforms
//

pub use basic::{CenterCrop, Normalize, Pad, Resize, ToTensor};

//
// Random transforms
//

pub use random::{
    RandomCrop, RandomHorizontalFlip, RandomResizedCrop, RandomRotation, RandomVerticalFlip,
    Rotation,
};

//
// Augmentation transforms
//

pub use augmentation::{ColorJitter, Cutout, GaussianBlur, RandomErasing};

//
// Mixing techniques
//

pub use mixing::{CutMix, MixUp};

//
// Automated augmentation
//

pub use automated::{AutoAugment, RandAugment};

//
// Sophisticated techniques
//

pub use sophisticated::{AugMix, GridMask, Mosaic};

//
// Registry and builder patterns
//

pub use registry::{TransformBuilder, TransformIntrospection, TransformRegistry, TransformStats};

//
// Preset configurations
//

pub use presets::*;

//
// Utility functions and convenience constructors
//

/// Create a simple ImageNet training pipeline
///
/// # Arguments
///
/// * `size` - Target image size (will be square)
///
/// # Returns
///
/// A Compose transform ready for ImageNet training
pub fn imagenet_train(size: usize) -> Compose {
    presets::presets::imagenet_train(size)
}

/// Create a simple ImageNet validation pipeline
///
/// # Arguments
///
/// * `size` - Target image size (will be square)
///
/// # Returns
///
/// A Compose transform ready for ImageNet validation/inference
pub fn imagenet_val(size: usize) -> Compose {
    presets::presets::imagenet_val(size)
}

/// Create a CIFAR training pipeline
///
/// # Returns
///
/// A Compose transform ready for CIFAR-10/100 training
pub fn cifar_train() -> Compose {
    presets::presets::cifar_train()
}

/// Create a CIFAR validation pipeline
///
/// # Returns
///
/// A Compose transform ready for CIFAR-10/100 validation/inference
pub fn cifar_val() -> Compose {
    presets::presets::cifar_val()
}

/// Create a strong augmentation pipeline
///
/// # Arguments
///
/// * `size` - Target image size
///
/// # Returns
///
/// A Compose transform with heavy augmentation for robust training
pub fn strong_augment(size: usize) -> Compose {
    presets::presets::strong_augment(size)
}

/// Create a transform builder
///
/// # Returns
///
/// A new TransformBuilder for creating custom pipelines
pub fn builder() -> TransformBuilder {
    TransformBuilder::new()
}

/// Create a transform registry
///
/// # Returns
///
/// A new TransformRegistry with default transforms registered
pub fn registry() -> TransformRegistry {
    TransformRegistry::new()
}

//
// Type aliases for convenience
//

/// Convenience type alias for a boxed transform
pub type BoxedTransform = Box<dyn Transform>;

/// Convenience type alias for a vector of boxed transforms
pub type TransformVec = Vec<BoxedTransform>;

//
// Trait implementations for common conversions
//

impl From<Vec<BoxedTransform>> for Compose {
    fn from(transforms: Vec<BoxedTransform>) -> Self {
        Compose::new(transforms)
    }
}

//
// Re-exports from sub-modules for convenience
//

pub use core::*;

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation;

    #[test]
    fn test_module_exports() {
        // Test that all major exports are accessible
        let _resize = Resize::new((224, 224));
        let _normalize = Normalize::imagenet();
        let _flip = RandomHorizontalFlip::new(0.5);
        let _jitter = ColorJitter::new();
        let _mixup = MixUp::new(1.0);
        let _autoaug = AutoAugment::new();
        let _augmix = AugMix::new();
        let _builder = TransformBuilder::new();
        let _registry = TransformRegistry::new();
    }

    #[test]
    fn test_convenience_functions() {
        // Test convenience constructors
        let train = imagenet_train(224);
        let val = imagenet_val(224);
        let cifar_tr = cifar_train();
        let cifar_v = cifar_val();
        let strong = strong_augment(224);

        assert!(!train.is_empty());
        assert!(!val.is_empty());
        assert!(!cifar_tr.is_empty());
        assert!(!cifar_v.is_empty());
        assert!(!strong.is_empty());

        let _builder = builder();
        let _reg = registry();
    }

    #[test]
    fn test_type_aliases() {
        let transform: BoxedTransform = Box::new(Resize::new((224, 224)));
        let transforms: TransformVec = vec![
            Box::new(Resize::new((224, 224))),
            Box::new(Normalize::imagenet()),
        ];

        assert_eq!(transform.name(), "Resize");
        assert_eq!(transforms.len(), 2);
    }

    #[test]
    fn test_compose_from_vec() {
        let transforms: TransformVec = vec![
            Box::new(Resize::new((224, 224))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(Normalize::imagenet()),
        ];

        let compose: Compose = transforms.into();
        assert_eq!(compose.len(), 3);
    }

    #[test]
    fn test_full_pipeline() {
        // Test a complete pipeline using the public API
        let input = creation::ones(&[3, 256, 256]).expect("creation should succeed");

        let pipeline = builder()
            .resize((224, 224))
            .random_horizontal_flip(0.5)
            .add(ColorJitter::new().brightness(0.1))
            .imagenet_normalize()
            .build();

        let result = pipeline.forward(&input);
        assert!(result.is_ok());

        let output = result.expect("operation should succeed");
        assert_eq!(output.shape().dims(), &[3, 224, 224]);
    }

    #[test]
    fn test_preset_pipelines() {
        let input = creation::ones(&[3, 256, 256]).expect("creation should succeed");

        // Test all preset pipelines
        let presets = vec![imagenet_train(224), imagenet_val(224), strong_augment(224)];

        for preset in presets {
            let result = preset.forward(&input);
            assert!(result.is_ok());
        }
    }

    #[test]
    #[ignore] // Fails in parallel execution due to shared RNG state
    fn test_advanced_transforms() {
        let input = creation::ones(&[3, 224, 224]).expect("creation should succeed");

        // Test advanced transforms work
        let rand_aug = RandAugment::new(2, 5.0);
        let result = rand_aug.forward(&input);
        assert!(result.is_ok());

        let augmix = AugMix::new();
        let result = augmix.forward(&input);
        assert!(result.is_ok());

        let gridmask = GridMask::new();
        let result = gridmask.forward(&input);
        assert!(result.is_ok());
    }

    #[test]
    #[ignore] // Fails in parallel execution due to shared RNG state
    fn test_mixing_transforms() {
        let input1 = creation::ones(&[3, 32, 32]).expect("creation should succeed");
        let input2 = creation::zeros(&[3, 32, 32]).expect("creation should succeed");

        let mixup = MixUp::new(1.0);
        let result = mixup.apply_pair(&input1, &input2, 0, 1, 10);
        assert!(result.is_ok());

        let cutmix = CutMix::new(1.0);
        let result = cutmix.apply_pair(&input1, &input2, 0, 1, 10);
        assert!(result.is_ok());
    }

    #[test]
    fn test_introspection() {
        let pipeline = builder()
            .resize((224, 224))
            .random_horizontal_flip(0.5)
            .imagenet_normalize()
            .build();

        let description = pipeline.describe();
        assert!(description.contains("Resize"));
        assert!(description.contains("RandomHorizontalFlip"));
        assert!(description.contains("Normalize"));

        let stats = pipeline.statistics();
        assert_eq!(stats.total_transforms, 3);

        let validation = pipeline.validate();
        assert!(validation.is_ok());
    }
}