use crate::transforms::Transform;
use torsh_core::error::Result;
use torsh_core::{
dtype::{FloatElement, TensorElement},
error::TorshError,
};
use torsh_tensor::Tensor;
use scirs2_core::random::thread_rng;
/// Randomly flips an image tensor horizontally with probability `prob`.
///
/// NOTE(review): the flip itself is currently a placeholder that returns the
/// input unchanged; only the probability draw and rank check are real.
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip {
    /// Chance, within the closed interval `[0, 1]`, that an input is flipped.
    prob: f32,
}

impl RandomHorizontalFlip {
    /// Builds a flip transform that fires with probability `prob`.
    ///
    /// # Panics
    /// Panics when `prob` lies outside `[0, 1]`; NaN is rejected as well,
    /// because it is not contained in the range.
    pub fn new(prob: f32) -> Self {
        let in_range = (0.0..=1.0).contains(&prob);
        if !in_range {
            panic!("Probability must be between 0 and 1");
        }
        RandomHorizontalFlip { prob }
    }
}
impl<T: FloatElement> Transform<Tensor<T>> for RandomHorizontalFlip {
    type Output = Tensor<T>;

    /// Applies the horizontal flip when a random `f32` draw falls below
    /// `self.prob`; otherwise hands the input back untouched.
    ///
    /// The draw is assumed to be uniform in `[0, 1)` (so `prob == 1.0` always
    /// flips and `prob == 0.0` never does) — TODO confirm against
    /// `scirs2_core::random`.
    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
        let mut generator = thread_rng();
        let draw = generator.random::<f32>();
        if draw >= self.prob {
            return Ok(input);
        }
        self.horizontal_flip(input)
    }

    /// Always `false`: whether the flip happens depends on a random draw.
    fn is_deterministic(&self) -> bool {
        false
    }
}
impl RandomHorizontalFlip {
    /// Checks that `input` has at least two axes and returns it.
    ///
    /// NOTE(review): no flipping is performed yet — the tensor is returned
    /// exactly as received, so this is currently only a rank check.
    ///
    /// # Errors
    /// `TorshError::InvalidArgument` when `input` has fewer than 2 dimensions.
    fn horizontal_flip<T: FloatElement>(&self, input: Tensor<T>) -> Result<Tensor<T>> {
        let rank = input.shape().dims().len();
        if rank >= 2 {
            return Ok(input);
        }
        Err(TorshError::InvalidArgument(
            "Input tensor must have at least 2 dimensions for horizontal flip".to_string(),
        ))
    }
}
/// Crops a randomly positioned `(height, width)` window out of an image tensor.
///
/// An optional padding (set with [`RandomCrop::with_padding`]) allows crops
/// that are larger than the input.
#[derive(Debug, Clone)]
pub struct RandomCrop {
    /// Requested crop extent as `(height, width)`.
    size: (usize, usize),
    /// Padding amount; `None` means oversized crops are rejected.
    padding: Option<usize>,
}

impl RandomCrop {
    /// Creates a random crop of `size` = `(height, width)` without padding.
    pub fn new(size: (usize, usize)) -> Self {
        RandomCrop {
            padding: None,
            size,
        }
    }

    /// Consumes the transform and returns it with `padding` configured,
    /// replacing any padding set earlier.
    pub fn with_padding(self, padding: usize) -> Self {
        RandomCrop {
            padding: Some(padding),
            ..self
        }
    }
}
impl<T: TensorElement> Transform<Tensor<T>> for RandomCrop {
    type Output = Tensor<T>;

    /// Validates `input` against the crop size and draws a random crop origin.
    ///
    /// Height and width are read from the two trailing axes, so `[H, W]`,
    /// channel-first `[C, H, W]` and batched `[N, C, H, W]` inputs are all
    /// measured correctly; leading axes are ignored.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `input` has fewer than 2
    /// dimensions, or if the crop is larger than the input and no padding was
    /// configured.
    ///
    /// NOTE(review): padding and the actual crop are not implemented yet — on
    /// success the input tensor is returned unchanged.
    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();
        if rank < 2 {
            return Err(TorshError::InvalidArgument(
                "Input tensor must have at least 2 dimensions for random crop".to_string(),
            ));
        }
        // The spatial axes are always the trailing two; indexing `dims[1..=2]`
        // for every rank >= 3 misread 4-D (batched) input as (C, H).
        let (input_height, input_width) = (dims[rank - 2], dims[rank - 1]);
        let (crop_height, crop_width) = self.size;

        if crop_height > input_height || crop_width > input_width {
            if self.padding.is_none() {
                return Err(TorshError::InvalidArgument(
                    format!("Crop size ({crop_height}, {crop_width}) is larger than input size ({input_height}, {input_width}) and no padding specified"),
                ));
            }
            // TODO: pad each spatial axis up to the crop size (plus `padding` on
            // each side) and crop; until then the input passes through.
            return Ok(input);
        }

        let mut rng = thread_rng();
        // Subtraction cannot underflow: the oversized case returned above.
        let max_y = input_height - crop_height;
        let max_x = input_width - crop_width;
        // Only consult the RNG when an axis has more than one valid offset.
        let _start_y = if max_y > 0 {
            rng.gen_range(0..=max_y)
        } else {
            0
        };
        let _start_x = if max_x > 0 {
            rng.gen_range(0..=max_x)
        } else {
            0
        };
        // TODO: slice the `crop_height x crop_width` window at
        // (`_start_y`, `_start_x`) once tensor slicing is wired up.
        Ok(input)
    }

    /// Always `false`: the crop origin is chosen randomly.
    fn is_deterministic(&self) -> bool {
        false
    }
}
/// Interpolation strategy requested from [`Resize`].
///
/// `Bilinear` is the default, selected via `#[default]` on the derived
/// `Default` impl rather than a hand-written one.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum InterpolationMode {
    /// Nearest-neighbour interpolation.
    Nearest,
    /// Linear interpolation.
    Linear,
    /// Bilinear interpolation (the default).
    #[default]
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
}
/// Resizes the spatial extent of an image tensor to a fixed `(height, width)`.
#[derive(Debug, Clone)]
pub struct Resize {
    /// Target extent as `(height, width)`.
    size: (usize, usize),
    /// Interpolation strategy; `Bilinear` unless overridden.
    interpolation: InterpolationMode,
}

impl Resize {
    /// Creates a resize to `size` = `(height, width)` using bilinear interpolation.
    pub fn new(size: (usize, usize)) -> Self {
        let interpolation = InterpolationMode::Bilinear;
        Resize {
            size,
            interpolation,
        }
    }

    /// Consumes the transform and returns it with `mode` as its interpolation.
    pub fn with_interpolation(self, mode: InterpolationMode) -> Self {
        Resize {
            interpolation: mode,
            ..self
        }
    }
}
impl<T: FloatElement> Transform<Tensor<T>> for Resize {
    type Output = Tensor<T>;

    /// Resizes the two trailing (spatial) axes of `input` to `self.size`.
    ///
    /// Inputs already of the target size are returned as-is. Height and width
    /// are taken from the last two axes, so `[H, W]`, `[C, H, W]` and
    /// `[N, C, H, W]` are all handled consistently.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `input` has fewer than 2 dimensions.
    ///
    /// NOTE(review): no interpolation mode is implemented yet — every mode
    /// returns `input` unchanged, so the output does not have `self.size`.
    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();
        if rank < 2 {
            return Err(TorshError::InvalidArgument(
                "Input tensor must have at least 2 dimensions for resize".to_string(),
            ));
        }
        // The trailing two axes are the spatial ones for every supported rank;
        // `dims[1..=2]` was wrong for batched 4-D input.
        let (input_height, input_width) = (dims[rank - 2], dims[rank - 1]);
        let (target_height, target_width) = self.size;
        if input_height == target_height && input_width == target_width {
            return Ok(input);
        }
        // Exhaustive on purpose (no `_` arm) so a new mode forces a decision here.
        match self.interpolation {
            InterpolationMode::Nearest
            | InterpolationMode::Linear
            | InterpolationMode::Bilinear
            | InterpolationMode::Bicubic => Ok(input),
        }
    }

    /// Always `true`: resizing involves no randomness.
    fn is_deterministic(&self) -> bool {
        true
    }
}
/// Crops a centred `(height, width)` window out of an image tensor.
#[derive(Debug, Clone)]
pub struct CenterCrop {
    /// Requested crop extent as `(height, width)`.
    size: (usize, usize),
}

impl CenterCrop {
    /// Creates a center crop producing a window of `size` = `(height, width)`.
    pub fn new(size: (usize, usize)) -> Self {
        CenterCrop { size }
    }
}
impl<T: TensorElement> Transform<Tensor<T>> for CenterCrop {
    type Output = Tensor<T>;

    /// Validates `input` against the crop size and computes the centred origin.
    ///
    /// Height and width are read from the two trailing axes, so `[H, W]`,
    /// `[C, H, W]` and `[N, C, H, W]` inputs are all measured correctly.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `input` has fewer than 2
    /// dimensions or if the crop exceeds the input along either spatial axis.
    ///
    /// NOTE(review): the crop itself is not implemented yet — on success the
    /// input tensor is returned unchanged.
    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
        let shape = input.shape();
        let dims = shape.dims();
        let rank = dims.len();
        if rank < 2 {
            return Err(TorshError::InvalidArgument(
                "Input tensor must have at least 2 dimensions for center crop".to_string(),
            ));
        }
        // Trailing two axes are spatial; `dims[1..=2]` misread batched 4-D input.
        let (input_height, input_width) = (dims[rank - 2], dims[rank - 1]);
        let (crop_height, crop_width) = self.size;
        if crop_height > input_height || crop_width > input_width {
            return Err(TorshError::InvalidArgument(
                format!("Crop size ({crop_height}, {crop_width}) is larger than input size ({input_height}, {input_width})"),
            ));
        }
        // Centred origin (floor division biases odd leftovers toward the top-left).
        // TODO: slice the window at these offsets once tensor slicing is wired up.
        let _start_y = (input_height - crop_height) / 2;
        let _start_x = (input_width - crop_width) / 2;
        Ok(input)
    }

    /// Always `true`: the crop position depends only on the input shape.
    fn is_deterministic(&self) -> bool {
        true
    }
}
/// Shorthand for [`RandomHorizontalFlip::new`].
///
/// # Panics
/// Panics when `prob` is outside `[0, 1]`.
pub fn random_horizontal_flip(prob: f32) -> RandomHorizontalFlip {
    RandomHorizontalFlip::new(prob)
}

/// Shorthand for [`RandomCrop::new`]; `size` is `(height, width)`, no padding.
pub fn random_crop(size: (usize, usize)) -> RandomCrop {
    RandomCrop::new(size)
}

/// Shorthand for [`Resize::new`]; `size` is `(height, width)`, bilinear mode.
pub fn resize(size: (usize, usize)) -> Resize {
    Resize::new(size)
}

/// Shorthand for [`CenterCrop::new`]; `size` is `(height, width)`.
pub fn center_crop(size: (usize, usize)) -> CenterCrop {
    CenterCrop::new(size)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Builds a 2x2 `f32` tensor on the CPU.
    fn mock_tensor_2d() -> Tensor<f32> {
        Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu).unwrap()
    }

    /// Builds a 2x2x2 `f32` tensor on the CPU.
    fn mock_tensor_3d() -> Tensor<f32> {
        Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2],
            DeviceType::Cpu,
        )
        .unwrap()
    }

    #[test]
    fn test_random_horizontal_flip_creation() {
        let flip = RandomHorizontalFlip::new(0.5);
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &flip;
        assert!(!transform.is_deterministic());
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_random_horizontal_flip_invalid_prob() {
        RandomHorizontalFlip::new(1.5);
    }

    #[test]
    fn test_random_horizontal_flip_boundary_probs() {
        // Both ends of the closed interval must be accepted.
        let _never = RandomHorizontalFlip::new(0.0);
        let _always = RandomHorizontalFlip::new(1.0);
    }

    #[test]
    fn test_random_crop_creation() {
        let crop = RandomCrop::new((224, 224));
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(!transform.is_deterministic());
    }

    #[test]
    fn test_random_crop_with_padding() {
        let crop = RandomCrop::new((224, 224)).with_padding(10);
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(!transform.is_deterministic());
    }

    #[test]
    fn test_random_crop_oversized_with_padding_is_accepted() {
        // With padding configured, a crop larger than the input is not rejected.
        let crop = RandomCrop::new((3, 3)).with_padding(1);
        assert!(crop.transform(mock_tensor_2d()).is_ok());
    }

    #[test]
    fn test_random_crop_within_bounds() {
        let tensor = mock_tensor_2d();
        // Exact fit (single valid offset) and a strictly smaller window.
        assert!(RandomCrop::new((2, 2)).transform(tensor.clone()).is_ok());
        assert!(RandomCrop::new((1, 1)).transform(tensor).is_ok());
    }

    #[test]
    fn test_resize_creation() {
        let resize_transform = Resize::new((224, 224));
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &resize_transform;
        assert!(transform.is_deterministic());
    }

    #[test]
    fn test_resize_with_interpolation() {
        let resize_transform =
            Resize::new((224, 224)).with_interpolation(InterpolationMode::Nearest);
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &resize_transform;
        assert!(transform.is_deterministic());
    }

    #[test]
    fn test_center_crop_creation() {
        let crop = CenterCrop::new((224, 224));
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(transform.is_deterministic());
    }

    #[test]
    fn test_interpolation_mode_default() {
        assert_eq!(InterpolationMode::default(), InterpolationMode::Bilinear);
    }

    #[test]
    fn test_tensor_transforms_2d() {
        let tensor = mock_tensor_2d();

        // prob = 1.0 forces the flip branch.
        let flip = RandomHorizontalFlip::new(1.0);
        assert!(flip.transform(tensor.clone()).is_ok());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor.clone()).is_ok());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor).is_ok());
    }

    #[test]
    fn test_tensor_transforms_3d() {
        let tensor = mock_tensor_3d();

        // prob = 0.0 forces the pass-through branch.
        let flip = RandomHorizontalFlip::new(0.0);
        assert!(flip.transform(tensor.clone()).is_ok());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor.clone()).is_ok());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor).is_ok());
    }

    #[test]
    fn test_convenience_functions() {
        let _flip = random_horizontal_flip(0.5);
        let _crop = random_crop((224, 224));
        let _resize = resize((256, 256));
        let _center = center_crop((224, 224));
    }

    #[test]
    fn test_invalid_tensor_dimensions() {
        let tensor_1d = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu).unwrap();

        let flip = RandomHorizontalFlip::new(1.0);
        assert!(flip.transform(tensor_1d.clone()).is_err());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor_1d.clone()).is_err());

        let rand_crop = RandomCrop::new((1, 1));
        assert!(rand_crop.transform(tensor_1d.clone()).is_err());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor_1d).is_err());
    }

    #[test]
    fn test_crop_size_validation() {
        let tensor = mock_tensor_2d();

        let crop = CenterCrop::new((3, 3));
        assert!(crop.transform(tensor.clone()).is_err());

        // Named to avoid shadowing the module-level `random_crop` helper.
        let rand_crop = RandomCrop::new((3, 3));
        assert!(rand_crop.transform(tensor).is_err());
    }
}