use crate::vision::pipeline::{predicates, ExecutionMode, Pipeline, PipelineBuilder};
use crate::vision::transforms::{
CenterCrop, InterpolationMode, Normalize, RandomCrop, RandomHorizontalFlip, Resize, ToTensor,
};
use num_traits::Float;
/// Ready-made preprocessing pipelines for ImageNet-style classification.
///
/// Both presets finish with `ToTensor` followed by `Normalize::imagenet()` and
/// run in [`ExecutionMode::Sequential`].
pub struct ImageNetPreprocessing;

impl ImageNetPreprocessing {
    /// Training pipeline named `"imagenet_training"`.
    ///
    /// Steps, in order: resize to (256, 256), random (224, 224) crop with
    /// (4, 4) padding, horizontal flip with probability 0.5, `ToTensor`,
    /// ImageNet normalization. Configured with `cache(500)`.
    pub fn training<T>() -> Pipeline<T>
    where
        T: Float + From<f32> + Clone + 'static + std::fmt::Debug,
    {
        let pipeline_name = String::from("imagenet_training");
        PipelineBuilder::new(pipeline_name)
            // Geometry: fixed resize, then a randomized crop for augmentation.
            .transform(Box::new(Resize::new((256, 256))))
            .transform(Box::new(RandomCrop::new((224, 224)).with_padding((4, 4))))
            .transform(Box::new(RandomHorizontalFlip::new(0.5)))
            // Conversion and normalization.
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(500)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }

    /// Validation pipeline named `"imagenet_validation"`.
    ///
    /// Deterministic counterpart of [`ImageNetPreprocessing::training`]: resize
    /// to (256, 256), center crop to (224, 224), `ToTensor`, ImageNet
    /// normalization. Configured with `cache(200)`.
    pub fn validation<T>() -> Pipeline<T>
    where
        T: Float + From<f32> + Clone + 'static + std::fmt::Debug,
    {
        let pipeline_name = String::from("imagenet_validation");
        PipelineBuilder::new(pipeline_name)
            // Geometry: fixed resize, then a deterministic center crop.
            .transform(Box::new(Resize::new((256, 256))))
            .transform(Box::new(CenterCrop::new((224, 224))))
            // Conversion and normalization.
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(200)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }
}
/// Ready-made preprocessing pipelines for CIFAR-style (32x32) classification.
///
/// Training and validation share one set of per-channel normalization
/// statistics, so the two pipelines cannot drift apart.
pub struct CIFARPreprocessing;

impl CIFARPreprocessing {
    /// Per-channel (3-channel) normalization means.
    ///
    /// NOTE(review): these match the CIFAR-10 means commonly used in public
    /// code; confirm they fit the target dataset (other datasets differ).
    const MEAN: [f32; 3] = [0.4914, 0.4822, 0.4465];

    /// Per-channel (3-channel) normalization standard deviations.
    ///
    /// NOTE(review): these are the widely copied CIFAR-10 values, which are
    /// reported to differ from the dataset's per-pixel std (~0.247, 0.243,
    /// 0.261). Kept unchanged so inputs keep matching models trained with them.
    const STD: [f32; 3] = [0.2023, 0.1994, 0.2010];

    /// Panic message if `Normalize::new` ever rejects the constants above,
    /// which would be a programming error rather than a runtime condition.
    const NORMALIZE_INVARIANT: &str =
        "CIFAR mean/std constants are non-empty, equal-length and must be accepted by Normalize";

    /// Converts `MEAN` and `STD` into the pipeline's scalar type `T`.
    ///
    /// Returns `(mean, std)`, each with one entry per channel.
    fn channel_stats<T: From<f32>>() -> (Vec<T>, Vec<T>) {
        let mean = Self::MEAN
            .iter()
            .map(|&v| <T as From<f32>>::from(v))
            .collect();
        let std_dev = Self::STD
            .iter()
            .map(|&v| <T as From<f32>>::from(v))
            .collect();
        (mean, std_dev)
    }

    /// Training pipeline named `"cifar_training"`.
    ///
    /// Steps: random (32, 32) crop with (4, 4) padding, horizontal flip with
    /// probability 0.5, `ToTensor`, per-channel normalization. Runs
    /// sequentially with `cache(1000)`.
    ///
    /// # Panics
    ///
    /// Only if `Normalize::new` rejects the constant statistics (a programming error).
    pub fn training<T: Float + From<f32> + Clone + 'static + std::fmt::Debug>() -> Pipeline<T> {
        let (mean, std_dev) = Self::channel_stats::<T>();
        PipelineBuilder::new("cifar_training".to_string())
            .transform(Box::new(RandomCrop::new((32, 32)).with_padding((4, 4))))
            .transform(Box::new(RandomHorizontalFlip::new(0.5)))
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(
                Normalize::new(mean, std_dev).expect(Self::NORMALIZE_INVARIANT),
            ))
            .cache(1000)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }

    /// Validation pipeline named `"cifar_validation"`.
    ///
    /// Steps: `ToTensor`, then the same per-channel normalization as
    /// [`CIFARPreprocessing::training`] (no augmentation). Runs sequentially
    /// with `cache(500)`.
    ///
    /// # Panics
    ///
    /// Only if `Normalize::new` rejects the constant statistics (a programming error).
    pub fn validation<T: Float + From<f32> + Clone + 'static + std::fmt::Debug>() -> Pipeline<T> {
        let (mean, std_dev) = Self::channel_stats::<T>();
        PipelineBuilder::new("cifar_validation".to_string())
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(
                Normalize::new(mean, std_dev).expect(Self::NORMALIZE_INVARIANT),
            ))
            .cache(500)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }
}
/// Ready-made preprocessing pipelines for object detection.
pub struct ObjectDetectionPreprocessing;

impl ObjectDetectionPreprocessing {
    /// COCO-style detection training pipeline named `"coco_detection_training"`.
    ///
    /// Steps, in order:
    /// 1. `Resize` to (800, 800), applied only when `predicates::min_size(800, 800)`
    ///    holds (registered under the label `"resize_large_images"`);
    /// 2. horizontal flip with probability 0.5;
    /// 3. `ToTensor`;
    /// 4. ImageNet normalization.
    ///
    /// Runs sequentially with `cache(200)`.
    ///
    /// NOTE(review): only image transforms are visible here; if bounding boxes
    /// travel with the image, confirm the flip/resize also update them.
    pub fn coco_training<T>() -> Pipeline<T>
    where
        T: Float + From<f32> + Clone + 'static + std::fmt::Debug,
    {
        let pipeline_name = String::from("coco_detection_training");
        let resize_label = String::from("resize_large_images");
        let gated_resize = predicates::min_size(800, 800);
        PipelineBuilder::new(pipeline_name)
            .conditional_transform(
                Box::new(Resize::new((800, 800))),
                gated_resize,
                resize_label,
            )
            .transform(Box::new(RandomHorizontalFlip::new(0.5)))
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(200)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }
}
/// Ready-made preprocessing pipelines for segmentation tasks.
pub struct SegmentationPreprocessing;

impl SegmentationPreprocessing {
    /// Semantic-segmentation training pipeline named `"segmentation_training"`.
    ///
    /// Steps: resize to (512, 512), horizontal flip with probability 0.5,
    /// `ToTensor`, ImageNet normalization. Runs sequentially with `cache(100)`.
    ///
    /// NOTE(review): segmentation masks must receive the same resize/flip as
    /// the image; that pairing is not visible here — confirm the pipeline
    /// applies it.
    pub fn semantic_training<T>() -> Pipeline<T>
    where
        T: Float + From<f32> + Clone + 'static + std::fmt::Debug,
    {
        let pipeline_name = String::from("segmentation_training");
        PipelineBuilder::new(pipeline_name)
            // Geometry and augmentation.
            .transform(Box::new(Resize::new((512, 512))))
            .transform(Box::new(RandomHorizontalFlip::new(0.5)))
            // Conversion and normalization.
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(100)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }
}
/// Ready-made preprocessing pipelines for medical imaging.
pub struct MedicalImagingPreprocessing;

impl MedicalImagingPreprocessing {
    /// Single-channel normalization mean for X-ray images.
    ///
    /// Appears to be the average of the three standard ImageNet channel means,
    /// (0.485 + 0.456 + 0.406) / 3 ~ 0.449. Presumably this lets grayscale
    /// inputs match ImageNet-pretrained backbones — confirm with the model owner.
    const XRAY_MEAN: f32 = 0.449;

    /// Single-channel normalization standard deviation for X-ray images.
    ///
    /// Appears to be the average of the standard ImageNet channel stds,
    /// (0.229 + 0.224 + 0.225) / 3 ~ 0.226.
    const XRAY_STD: f32 = 0.226;

    /// X-ray pipeline named `"medical_xray"`.
    ///
    /// Steps, in order:
    /// 1. resize to (512, 512);
    /// 2. center crop to (448, 448);
    /// 3. `ToTensor`, applied only when `predicates::channels_eq(1)` holds
    ///    (label `"grayscale_tensor"`);
    /// 4. single-channel normalization with `XRAY_MEAN` / `XRAY_STD`.
    ///
    /// Runs sequentially with `cache(50)`.
    ///
    /// NOTE(review): step 3 is conditional but step 4 is not. Multi-channel
    /// inputs would reach the one-channel `Normalize` without tensor
    /// conversion. Confirm inputs are always grayscale, or convert them first.
    ///
    /// # Panics
    ///
    /// Only if `Normalize::new` rejects the constant statistics (a programming error).
    pub fn xray_preprocessing<T: Float + From<f32> + Clone + 'static + std::fmt::Debug>(
    ) -> Pipeline<T> {
        PipelineBuilder::new("medical_xray".to_string())
            .transform(Box::new(Resize::new((512, 512))))
            .transform(Box::new(CenterCrop::new((448, 448))))
            .conditional_transform(
                Box::new(ToTensor::new()),
                predicates::channels_eq(1),
                "grayscale_tensor".to_string(),
            )
            .transform(Box::new(
                Normalize::new(
                    vec![<T as From<f32>>::from(Self::XRAY_MEAN)],
                    vec![<T as From<f32>>::from(Self::XRAY_STD)],
                )
                .expect("single-channel X-ray mean/std constants must be accepted by Normalize"),
            ))
            .cache(50)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }
}
/// Ready-made preprocessing pipelines intended for on-device inference.
pub struct MobileOptimizedPreprocessing;

impl MobileOptimizedPreprocessing {
    /// Inference pipeline named `"mobile_inference"`.
    ///
    /// Steps: resize to (224, 224), center crop to (224, 224), `ToTensor`,
    /// ImageNet normalization. Unlike the other presets in this module, it
    /// runs with [`ExecutionMode::Batch`] and a small `cache(10)`.
    ///
    /// NOTE(review): if `Resize::new((224, 224))` yields exactly 224x224, the
    /// following `CenterCrop` of the same size is a no-op. Confirm this, and
    /// consider dropping it on this latency-sensitive path.
    pub fn mobile_inference<T>() -> Pipeline<T>
    where
        T: Float + From<f32> + Clone + 'static + std::fmt::Debug,
    {
        let pipeline_name = String::from("mobile_inference");
        PipelineBuilder::new(pipeline_name)
            .transform(Box::new(Resize::new((224, 224))))
            .transform(Box::new(CenterCrop::new((224, 224))))
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(10)
            .execution_mode(ExecutionMode::Batch)
            .build()
    }
}
/// Factory for parameterized preprocessing pipelines.
pub struct CustomPipelineFactory;

impl CustomPipelineFactory {
    /// High-resolution pipeline named `"high_resolution_custom"`.
    ///
    /// Steps: bicubic `Resize` to `target_size`, `CenterCrop` to `target_size`,
    /// an optional horizontal flip (probability 0.3), `ToTensor`, ImageNet
    /// normalization. Runs sequentially with `cache(50)`.
    ///
    /// # Arguments
    ///
    /// * `target_size` - size passed to both the resize and the center crop.
    /// * `enable_augmentation` - when `true`, inserts the horizontal flip right
    ///   after the crop.
    pub fn high_resolution<T: Float + From<f32> + Clone + 'static + std::fmt::Debug>(
        target_size: (usize, usize),
        enable_augmentation: bool,
    ) -> Pipeline<T> {
        let builder = PipelineBuilder::new("high_resolution_custom".to_string())
            .transform(Box::new(
                Resize::new(target_size).with_interpolation(InterpolationMode::Bicubic),
            ))
            // NOTE(review): if the resize is exact, this equal-size crop is a no-op.
            .transform(Box::new(CenterCrop::new(target_size)));
        let builder = if enable_augmentation {
            builder.transform(Box::new(RandomHorizontalFlip::new(0.3)))
        } else {
            builder
        };
        builder
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(50)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }

    /// Pipeline named `"probabilistic_augmentation"` whose augmentations are
    /// each gated by an independent `predicates::probability` check.
    ///
    /// Steps: `Resize` to `base_size`, then a gated random crop (padding (8, 8)),
    /// then a gated horizontal flip, then `ToTensor` and ImageNet normalization.
    /// Runs sequentially with `cache(200)`.
    ///
    /// # Arguments
    ///
    /// * `base_size` - size used for both the resize and the random crop.
    /// * `augment_probability` - probability for each gated augmentation.
    ///   Values outside `[0.0, 1.0]` are clamped into that range, and NaN is
    ///   treated as `0.0` (never augment).
    pub fn probabilistic_augmentation<T: Float + From<f32> + Clone + 'static + std::fmt::Debug>(
        base_size: (usize, usize),
        augment_probability: f64,
    ) -> Pipeline<T> {
        let p = Self::sanitize_probability(augment_probability);
        PipelineBuilder::new("probabilistic_augmentation".to_string())
            .transform(Box::new(Resize::new(base_size)))
            .conditional_transform(
                Box::new(RandomCrop::new(base_size).with_padding((8, 8))),
                predicates::probability(p),
                "random_crop_probabilistic".to_string(),
            )
            .conditional_transform(
                // The flip itself always fires (1.0); the predicate alone
                // decides whether this step runs.
                Box::new(RandomHorizontalFlip::new(1.0)),
                predicates::probability(p),
                "random_flip_probabilistic".to_string(),
            )
            .transform(Box::new(ToTensor::new()))
            .transform(Box::new(Normalize::imagenet()))
            .cache(200)
            .execution_mode(ExecutionMode::Sequential)
            .build()
    }

    /// Maps `p` into `[0.0, 1.0]`, turning NaN into `0.0`.
    ///
    /// This keeps NaN and out-of-range values away from `predicates::probability`,
    /// whose handling of such inputs is not visible from this module.
    fn sanitize_probability(p: f64) -> f64 {
        if p.is_nan() {
            0.0
        } else {
            p.clamp(0.0, 1.0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_imagenet_preprocessing() {
        let training_pipeline = ImageNetPreprocessing::training::<f32>();
        let validation_pipeline = ImageNetPreprocessing::validation::<f32>();
        assert_eq!(training_pipeline.name(), "imagenet_training");
        assert_eq!(validation_pipeline.name(), "imagenet_validation");
        assert_eq!(training_pipeline.len(), 5);
        assert_eq!(validation_pipeline.len(), 4);
        // The second element of `cache_info()` is the configured cache capacity.
        assert_eq!(training_pipeline.cache_info().1, 500);
        assert_eq!(validation_pipeline.cache_info().1, 200);
    }

    #[test]
    fn test_cifar_preprocessing() {
        let training_pipeline = CIFARPreprocessing::training::<f32>();
        let validation_pipeline = CIFARPreprocessing::validation::<f32>();
        assert_eq!(training_pipeline.name(), "cifar_training");
        assert_eq!(validation_pipeline.name(), "cifar_validation");
        // Training: crop, flip, to-tensor, normalize. Validation: to-tensor, normalize.
        assert_eq!(training_pipeline.len(), 4);
        assert_eq!(validation_pipeline.len(), 2);
        assert_eq!(training_pipeline.cache_info().1, 1000);
        assert_eq!(validation_pipeline.cache_info().1, 500);
    }

    #[test]
    fn test_cifar_preprocessing_f64() {
        // Exercises the `From<f32>` conversion of the normalization statistics
        // for a scalar type other than f32.
        let training_pipeline = CIFARPreprocessing::training::<f64>();
        let validation_pipeline = CIFARPreprocessing::validation::<f64>();
        assert_eq!(training_pipeline.len(), 4);
        assert_eq!(validation_pipeline.len(), 2);
    }

    #[test]
    fn test_object_detection_preprocessing() {
        let coco_pipeline = ObjectDetectionPreprocessing::coco_training::<f32>();
        assert_eq!(coco_pipeline.name(), "coco_detection_training");
    }

    #[test]
    fn test_segmentation_preprocessing() {
        let seg_pipeline = SegmentationPreprocessing::semantic_training::<f32>();
        assert_eq!(seg_pipeline.name(), "segmentation_training");
        assert_eq!(seg_pipeline.len(), 4);
        assert_eq!(seg_pipeline.cache_info().1, 100);
    }

    #[test]
    fn test_custom_pipeline_factory() {
        let high_res_pipeline = CustomPipelineFactory::high_resolution::<f32>((512, 512), true);
        let prob_aug_pipeline =
            CustomPipelineFactory::probabilistic_augmentation::<f32>((224, 224), 0.5);
        assert_eq!(high_res_pipeline.name(), "high_resolution_custom");
        assert_eq!(prob_aug_pipeline.name(), "probabilistic_augmentation");
        // resize, crop, flip, to-tensor, normalize
        assert_eq!(high_res_pipeline.len(), 5);
    }

    #[test]
    fn test_high_resolution_without_augmentation() {
        let pipeline = CustomPipelineFactory::high_resolution::<f32>((512, 512), false);
        assert_eq!(pipeline.name(), "high_resolution_custom");
        // The optional flip is omitted: resize, crop, to-tensor, normalize.
        assert_eq!(pipeline.len(), 4);
        assert_eq!(pipeline.cache_info().1, 50);
    }

    #[test]
    fn test_medical_imaging_preprocessing() {
        let xray_pipeline = MedicalImagingPreprocessing::xray_preprocessing::<f32>();
        assert_eq!(xray_pipeline.name(), "medical_xray");
    }

    #[test]
    fn test_mobile_optimized_preprocessing() {
        let mobile_pipeline = MobileOptimizedPreprocessing::mobile_inference::<f32>();
        assert_eq!(mobile_pipeline.name(), "mobile_inference");
        assert_eq!(mobile_pipeline.len(), 4);
        let (_, max_cache) = mobile_pipeline.cache_info();
        assert_eq!(max_cache, 10);
    }
}