use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use crate::vision::{Image, ImageFormat};
use num_traits::Float;
use rand::Rng;
/// An image transformation that maps an input [`Image`] to a new, owned [`Image`].
///
/// `apply` borrows both `self` and the input image, so one transform value can
/// be applied any number of times. The `Debug` supertrait allows boxed
/// transforms (e.g. the `Vec<Box<dyn Transform<T>>>` held by `Compose`) to be
/// formatted with `{:?}`.
pub trait Transform<T: Float>: std::fmt::Debug {
    /// Applies the transform to `image` and returns the resulting image.
    ///
    /// # Errors
    /// Returns an error when the transform cannot be applied to this image;
    /// the exact conditions are implementation-specific.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>>;
}
/// Resize configuration: a fixed `(height, width)` target plus an interpolation mode.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared (e.g. in tests or
/// when de-duplicating pipelines).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Resize {
    /// Target size as `(height, width)`.
    pub size: (usize, usize),
    /// Requested interpolation mode; [`Resize::new`] sets `Bilinear`.
    pub interpolation: InterpolationMode,
}

/// Interpolation modes selectable for [`Resize`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so a configured mode can be compared or
/// used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationMode {
    Nearest,
    Bilinear,
    Bicubic,
}

impl Resize {
    /// Creates a resize to `size` (`(height, width)`) using
    /// [`InterpolationMode::Bilinear`].
    pub fn new(size: (usize, usize)) -> Self {
        Self {
            size,
            interpolation: InterpolationMode::Bilinear,
        }
    }

    /// Returns this configuration with its interpolation mode replaced by `mode`.
    #[must_use]
    pub fn with_interpolation(mut self, mode: InterpolationMode) -> Self {
        self.interpolation = mode;
        self
    }
}
impl<T: Float + From<f32> + 'static + std::fmt::Debug> Transform<T> for Resize {
    /// Builds an image of size `self.size` that keeps the input's channel count
    /// and memory layout.
    ///
    /// NOTE(review): placeholder — the output tensor is all zeros; neither the
    /// input pixels nor `self.interpolation` are used yet.
    ///
    /// # Errors
    /// Propagates any error returned by `Image::new`.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let channels = image.channels;
        let (height, width) = self.size;
        // Order the dimensions to match the input's layout.
        let shape = match image.format {
            ImageFormat::HWC => vec![height, width, channels],
            ImageFormat::CHW => vec![channels, height, width],
        };
        Image::new(Tensor::zeros(&shape), image.format).map_err(Into::into)
    }
}
/// Center-crop configuration: the size of the window to keep.
#[derive(Debug, Clone)]
pub struct CenterCrop {
    /// Crop size as `(height, width)`.
    pub size: (usize, usize),
}

impl CenterCrop {
    /// Creates a center crop whose output is `size` = `(height, width)`.
    pub fn new(size: (usize, usize)) -> Self {
        CenterCrop { size }
    }
}
impl<T: Float + From<f32> + 'static + std::fmt::Debug> Transform<T> for CenterCrop {
    /// Produces a `self.size` image with the input's channel count and layout.
    ///
    /// NOTE(review): placeholder — the centred window origin is computed but not
    /// used, and the returned tensor is all zeros.
    ///
    /// # Errors
    /// Returns `RusTorchError::InvalidTransformParams` when either crop
    /// dimension exceeds the image; propagates any error from `Image::new`.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let (crop_h, crop_w) = self.size;
        let fits = crop_h <= image.height && crop_w <= image.width;
        if !fits {
            let message = format!(
                "Crop size ({}, {}) larger than image size ({}, {})",
                crop_h, crop_w, image.height, image.width
            );
            return Err(RusTorchError::InvalidTransformParams(message).into());
        }
        // Top-left corner of the centred window (cannot underflow: `fits` holds).
        let (_top, _left) = ((image.height - crop_h) / 2, (image.width - crop_w) / 2);
        let shape = match image.format {
            ImageFormat::HWC => vec![crop_h, crop_w, image.channels],
            ImageFormat::CHW => vec![image.channels, crop_h, crop_w],
        };
        Image::new(Tensor::zeros(&shape), image.format).map_err(Into::into)
    }
}
/// Random-crop configuration: window size and optional padding.
#[derive(Debug, Clone)]
pub struct RandomCrop {
    /// Crop size as `(height, width)`.
    pub size: (usize, usize),
    /// Optional padding as `(pad_h, pad_w)`; `None` means no padding.
    pub padding: Option<(usize, usize)>,
}

impl RandomCrop {
    /// Creates a random crop of `size` (`(height, width)`) with no padding.
    pub fn new(size: (usize, usize)) -> Self {
        RandomCrop {
            padding: None,
            size,
        }
    }

    /// Returns this configuration with `padding` (`(pad_h, pad_w)`) enabled.
    pub fn with_padding(self, padding: (usize, usize)) -> Self {
        RandomCrop {
            padding: Some(padding),
            ..self
        }
    }
}
impl<T: Float + From<f32> + 'static + std::fmt::Debug> Transform<T> for RandomCrop {
    /// Crops a randomly positioned `self.size` = `(height, width)` window.
    ///
    /// With `padding = Some((pad_h, pad_w))` the window is sampled from the
    /// image as if `pad_h` rows were added above and below it and `pad_w`
    /// columns to its left and right. Crops up to
    /// `(height + 2 * pad_h, width + 2 * pad_w)` are therefore accepted.
    ///
    /// NOTE(review): placeholder — the padding is not materialized, the sampled
    /// offsets are unused, and the returned tensor is all zeros.
    ///
    /// # Errors
    /// Returns `RusTorchError::InvalidTransformParams` when the crop is larger
    /// than the (padded) image; propagates any error from `Image::new`.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let (crop_height, crop_width) = self.size;

        // Only the dimensions are needed here, so the image is not cloned.
        // Saturating arithmetic keeps absurd padding values from overflowing.
        let (pad_h, pad_w) = self.padding.unwrap_or((0, 0));
        let padded_height = image.height.saturating_add(pad_h.saturating_mul(2));
        let padded_width = image.width.saturating_add(pad_w.saturating_mul(2));

        if crop_height > padded_height || crop_width > padded_width {
            return Err(RusTorchError::InvalidTransformParams(format!(
                "Crop size ({}, {}) larger than image size ({}, {})",
                crop_height, crop_width, padded_height, padded_width
            ))
            .into());
        }

        // Top-left corner of the window within the padded image; both ranges are
        // non-empty because of the check above.
        let mut rng = rand::thread_rng();
        let _start_y = rng.gen_range(0..=padded_height - crop_height);
        let _start_x = rng.gen_range(0..=padded_width - crop_width);

        let new_shape = match image.format {
            ImageFormat::CHW => vec![image.channels, crop_height, crop_width],
            ImageFormat::HWC => vec![crop_height, crop_width, image.channels],
        };
        let cropped_data = Tensor::zeros(&new_shape);
        Image::new(cropped_data, image.format).map_err(|e| e.into())
    }
}
/// Horizontal-flip configuration applied with a given probability.
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip {
    /// Flip probability; the transform flips when a uniform `f64` sample is
    /// below this value.
    pub probability: f64,
}

impl RandomHorizontalFlip {
    /// Creates a flip transform that fires with `probability`.
    pub fn new(probability: f64) -> Self {
        RandomHorizontalFlip { probability }
    }
}
impl<T: Float + From<f32> + 'static + std::fmt::Debug> Transform<T> for RandomHorizontalFlip {
    /// Samples whether to flip, then returns a copy of `image`.
    ///
    /// NOTE(review): placeholder — the flip decision is drawn from the thread
    /// RNG but not acted on; both outcomes return an unmodified clone.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let _should_flip = rand::thread_rng().gen::<f64>() < self.probability;
        Ok(image.clone())
    }
}
/// Random-rotation configuration: the angle range and an optional fill value.
#[derive(Debug, Clone)]
pub struct RandomRotation {
    /// Inclusive `(min, max)` range, in degrees, the angle is sampled from.
    pub degrees: (f64, f64),
    /// Optional fill value (presumably for pixels uncovered by the rotation);
    /// `None` unless set via `with_fill`.
    pub fill: Option<f64>,
}

impl RandomRotation {
    /// Creates a rotation sampling from `degrees` with no fill value.
    pub fn new(degrees: (f64, f64)) -> Self {
        RandomRotation {
            fill: None,
            degrees,
        }
    }

    /// Returns this configuration with `fill` set.
    pub fn with_fill(self, fill: f64) -> Self {
        RandomRotation {
            fill: Some(fill),
            ..self
        }
    }
}
impl<T: Float + From<f32> + From<f64> + 'static + std::fmt::Debug> Transform<T> for RandomRotation {
fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
let mut rng = rand::thread_rng();
let _angle = rng.gen_range(self.degrees.0..=self.degrees.1);
Ok(image.clone())
}
}
/// Per-channel normalization parameters (conventionally applied as
/// `(x - mean[c]) / std[c]`).
///
/// The fields are public, so a value can also be built without [`Normalize::new`].
/// That path bypasses its validation.
#[derive(Debug, Clone)]
pub struct Normalize<T: Float> {
    /// Per-channel means.
    pub mean: Vec<T>,
    /// Per-channel standard deviations.
    pub std: Vec<T>,
}

impl<T: Float + From<f32> + Copy> Normalize<T> {
    /// Creates a normalization from per-channel `mean` and `std` vectors.
    ///
    /// # Errors
    /// Returns `RusTorchError::InvalidTransformParams` when `mean` and `std`
    /// differ in length, or when any `std` entry is zero (which would cause a
    /// division by zero during normalization).
    pub fn new(mean: Vec<T>, std: Vec<T>) -> RusTorchResult<Self> {
        if mean.len() != std.len() {
            return Err(RusTorchError::InvalidTransformParams(
                "Mean and std must have same length".to_string(),
            )
            .into());
        }
        if std.iter().any(|s| s.is_zero()) {
            return Err(RusTorchError::InvalidTransformParams(
                "Std values must be non-zero".to_string(),
            )
            .into());
        }
        Ok(Self { mean, std })
    }

    /// Returns the commonly used 3-channel ImageNet mean/std statistics.
    pub fn imagenet() -> Self {
        Self {
            mean: vec![
                <T as From<f32>>::from(0.485),
                <T as From<f32>>::from(0.456),
                <T as From<f32>>::from(0.406),
            ],
            std: vec![
                <T as From<f32>>::from(0.229),
                <T as From<f32>>::from(0.224),
                <T as From<f32>>::from(0.225),
            ],
        }
    }
}
impl<T: Float + From<f32> + Copy + 'static + std::fmt::Debug> Transform<T> for Normalize<T> {
    /// Checks that the parameters match the image's channel count, then
    /// returns a copy of `image`.
    ///
    /// NOTE(review): placeholder — the mean/std normalization itself is not yet
    /// applied; the image is returned unchanged.
    ///
    /// # Errors
    /// Returns `RusTorchError::InvalidTransformParams` when `mean.len()` or
    /// `std.len()` differs from `image.channels`.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        if self.mean.len() != image.channels {
            return Err(RusTorchError::InvalidTransformParams(format!(
                "Mean length {} doesn't match image channels {}",
                self.mean.len(),
                image.channels
            ))
            .into());
        }
        // `mean`/`std` are public fields, so a value constructed without
        // `Normalize::new` may have a `std` of a different length; check it too.
        if self.std.len() != image.channels {
            return Err(RusTorchError::InvalidTransformParams(format!(
                "Std length {} doesn't match image channels {}",
                self.std.len(),
                image.channels
            ))
            .into());
        }
        Ok(image.clone())
    }
}
/// Layout-conversion configuration: the memory format to convert images into.
#[derive(Debug, Clone)]
pub struct ToTensor {
    /// Target layout; `ImageFormat::CHW` unless changed via `with_format`.
    pub format: ImageFormat,
}

impl ToTensor {
    /// Creates a converter targeting `ImageFormat::CHW`.
    pub fn new() -> Self {
        ToTensor {
            format: ImageFormat::CHW,
        }
    }

    /// Returns this converter retargeted to `format`.
    pub fn with_format(self, format: ImageFormat) -> Self {
        ToTensor { format }
    }
}

impl Default for ToTensor {
    /// Equivalent to [`ToTensor::new`].
    fn default() -> Self {
        ToTensor::new()
    }
}
impl<T: Float + From<f32> + 'static + std::fmt::Debug> Transform<T> for ToTensor {
    /// Converts `image` to `self.format` by delegating to `Image::to_format`.
    ///
    /// # Errors
    /// Propagates any error returned by `Image::to_format`.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let target = self.format;
        image.to_format(target).map_err(Into::into)
    }
}
/// A pipeline of transforms applied in order. Each transform receives the
/// previous one's output.
#[derive(Debug)]
pub struct Compose<T: Float> {
    /// Transforms, applied first to last.
    pub transforms: Vec<Box<dyn Transform<T>>>,
}

impl<T: Float> Compose<T> {
    /// Builds a pipeline from `transforms`.
    pub fn new(transforms: Vec<Box<dyn Transform<T>>>) -> Self {
        Compose { transforms }
    }
}

impl<T: Float + 'static + std::fmt::Debug> Transform<T> for Compose<T> {
    /// Runs every transform in sequence and stops at the first error.
    ///
    /// An empty pipeline returns a clone of `image`.
    ///
    /// # Errors
    /// Returns the first error produced by any transform in the pipeline.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        self.transforms
            .iter()
            .try_fold(image.clone(), |current, step| step.apply(&current))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resize_creation() {
        let Resize { size, .. } = Resize::new((224, 224));
        assert_eq!(size, (224, 224));
    }

    #[test]
    fn test_center_crop_creation() {
        let CenterCrop { size } = CenterCrop::new((224, 224));
        assert_eq!(size, (224, 224));
    }

    #[test]
    fn test_random_crop_creation() {
        let RandomCrop { size, padding } = RandomCrop::new((224, 224)).with_padding((4, 4));
        assert_eq!(size, (224, 224));
        assert_eq!(padding, Some((4, 4)));
    }

    #[test]
    fn test_normalize_creation() {
        let Normalize { mean, std: stdev } =
            Normalize::new(vec![0.5f32], vec![0.5f32]).expect("equal-length mean/std");
        assert_eq!(mean, [0.5f32]);
        assert_eq!(stdev, [0.5f32]);
    }

    #[test]
    fn test_normalize_imagenet() {
        let Normalize { mean, std: stdev } = Normalize::<f32>::imagenet();
        assert_eq!((mean.len(), stdev.len()), (3, 3));
    }

    #[test]
    fn test_to_tensor_creation() {
        assert_eq!(ToTensor::new().format, ImageFormat::CHW);
    }
}