#![allow(dead_code)]
use crate::unified_transforms::{TransformContext, TransformParameter, UnifiedTransform};
use crate::{Result, VisionError};
use scirs2_core::random::{Random, Rng};
use scirs2_core::RngExt;
use std::collections::HashMap;
use std::sync::Arc;
use torsh_core::device::Device;
use torsh_core::dtype::DType;
use torsh_tensor::Tensor;
#[derive(Debug)]
pub struct UnifiedResize {
size: (usize, usize),
device: Option<Arc<dyn Device>>,
}
impl UnifiedResize {
pub fn new(size: (usize, usize)) -> Self {
Self { size, device: None }
}
pub fn with_device(size: (usize, usize), device: Arc<dyn Device>) -> Self {
Self {
size,
device: Some(device),
}
}
}
impl UnifiedTransform for UnifiedResize {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
crate::ops::resize(input, self.size)
}
fn apply_gpu(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
if let Some(device) = &self.device {
if matches!(
device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
) {
self.apply(input)
} else {
self.apply(input)
}
} else {
self.apply(input)
}
}
fn name(&self) -> &'static str {
"UnifiedResize"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert(
"size".to_string(),
TransformParameter::Tuple2Usize(self.size),
);
if let Some(device) = &self.device {
params.insert(
"device".to_string(),
TransformParameter::String(format!("{:?}", device)),
);
}
params
}
fn supports_gpu(&self) -> bool {
self.device.as_ref().map_or(false, |d| {
matches!(d.device_type(), torsh_core::device::DeviceType::Cuda(_))
})
}
fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>> {
if input_shape.len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W), got {}D",
input_shape.len()
)));
}
Ok(vec![input_shape[0], self.size.1, self.size.0])
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(UnifiedResize {
size: self.size,
device: self.device.clone(),
})
}
fn preferred_device(&self) -> Option<&dyn Device> {
self.device.as_ref().map(|d| d.as_ref())
}
}
#[derive(Debug, Clone)]
pub struct UnifiedCenterCrop {
size: (usize, usize),
}
impl UnifiedCenterCrop {
pub fn new(size: (usize, usize)) -> Self {
Self { size }
}
}
impl UnifiedTransform for UnifiedCenterCrop {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
crate::ops::center_crop(input, self.size)
}
fn name(&self) -> &'static str {
"UnifiedCenterCrop"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert(
"size".to_string(),
TransformParameter::Tuple2Usize(self.size),
);
params
}
fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>> {
if input_shape.len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W), got {}D",
input_shape.len()
)));
}
Ok(vec![input_shape[0], self.size.1, self.size.0])
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(self.clone())
}
}
#[derive(Debug, Clone)]
pub struct UnifiedRandomHorizontalFlip {
p: f32,
}
impl UnifiedRandomHorizontalFlip {
pub fn new(p: f32) -> Self {
Self { p }
}
}
impl UnifiedTransform for UnifiedRandomHorizontalFlip {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut rng = Random::seed(42);
if rng.random::<f32>() < self.p {
crate::ops::horizontal_flip(input)
} else {
Ok(input.clone())
}
}
fn name(&self) -> &'static str {
"UnifiedRandomHorizontalFlip"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert("probability".to_string(), TransformParameter::Float(self.p));
params
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(self.clone())
}
}
#[derive(Debug)]
pub struct UnifiedNormalize {
mean: Vec<f32>,
std: Vec<f32>,
device: Option<Arc<dyn Device>>,
}
impl UnifiedNormalize {
pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
Self {
mean,
std,
device: None,
}
}
pub fn with_device(mean: Vec<f32>, std: Vec<f32>, device: Arc<dyn Device>) -> Self {
Self {
mean,
std,
device: Some(device),
}
}
}
impl UnifiedTransform for UnifiedNormalize {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
crate::ops::normalize(
input,
crate::ops::color::NormalizationConfig {
method: crate::ops::color::NormalizationMethod::Custom,
mean: Some(self.mean.clone()),
std: Some(self.std.clone()),
per_channel: true,
eps: 1e-8,
},
)
}
fn apply_gpu(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
if let Some(device) = &self.device {
if matches!(
device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
) {
self.apply(input)
} else {
self.apply(input)
}
} else {
self.apply(input)
}
}
fn name(&self) -> &'static str {
"UnifiedNormalize"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert(
"mean".to_string(),
TransformParameter::FloatVec(self.mean.clone()),
);
params.insert(
"std".to_string(),
TransformParameter::FloatVec(self.std.clone()),
);
if let Some(device) = &self.device {
params.insert(
"device".to_string(),
TransformParameter::String(format!("{:?}", device)),
);
}
params
}
fn supports_gpu(&self) -> bool {
self.device.as_ref().map_or(false, |d| {
matches!(d.device_type(), torsh_core::device::DeviceType::Cuda(_))
})
}
fn supports_mixed_precision(&self) -> bool {
self.supports_gpu()
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(UnifiedNormalize {
mean: self.mean.clone(),
std: self.std.clone(),
device: self.device.clone(),
})
}
fn preferred_device(&self) -> Option<&dyn Device> {
self.device.as_ref().map(|d| d.as_ref())
}
}
#[derive(Debug, Clone)]
pub struct UnifiedRandomCrop {
size: (usize, usize),
padding: Option<usize>,
}
impl UnifiedRandomCrop {
pub fn new(size: (usize, usize)) -> Self {
Self {
size,
padding: None,
}
}
pub fn with_padding(size: (usize, usize), padding: usize) -> Self {
Self {
size,
padding: Some(padding),
}
}
}
impl UnifiedTransform for UnifiedRandomCrop {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let input_to_crop = if let Some(padding) = self.padding {
crate::ops::pad(
input,
(padding, padding, padding, padding),
crate::ops::PaddingMode::Zero,
0.0,
)?
} else {
input.clone()
};
crate::ops::random_crop(&input_to_crop, self.size)
}
fn name(&self) -> &'static str {
"UnifiedRandomCrop"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert(
"size".to_string(),
TransformParameter::Tuple2Usize(self.size),
);
if let Some(padding) = self.padding {
params.insert("padding".to_string(), TransformParameter::Usize(padding));
}
params
}
fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>> {
if input_shape.len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W), got {}D",
input_shape.len()
)));
}
Ok(vec![input_shape[0], self.size.1, self.size.0])
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(self.clone())
}
}
#[derive(Debug)]
pub struct UnifiedColorJitter {
brightness: Option<f32>,
contrast: Option<f32>,
saturation: Option<f32>,
hue: Option<f32>,
device: Option<Arc<dyn Device>>,
}
impl UnifiedColorJitter {
pub fn new() -> Self {
Self {
brightness: None,
contrast: None,
saturation: None,
hue: None,
device: None,
}
}
pub fn brightness(mut self, brightness: f32) -> Self {
self.brightness = Some(brightness);
self
}
pub fn contrast(mut self, contrast: f32) -> Self {
self.contrast = Some(contrast);
self
}
pub fn saturation(mut self, saturation: f32) -> Self {
self.saturation = Some(saturation);
self
}
pub fn hue(mut self, hue: f32) -> Self {
self.hue = Some(hue);
self
}
pub fn with_device(mut self, device: Arc<dyn Device>) -> Self {
self.device = Some(device);
self
}
}
impl Default for UnifiedColorJitter {
fn default() -> Self {
Self::new()
}
}
impl UnifiedTransform for UnifiedColorJitter {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut output = input.clone();
let mut rng = Random::seed(42);
if let Some(brightness) = self.brightness {
let factor = rng.gen_range(1.0 - brightness..=1.0 + brightness);
output = output.mul_scalar(factor)?;
}
if let Some(contrast) = self.contrast {
let factor = rng.gen_range(1.0 - contrast..=1.0 + contrast);
let mean = output.mean(None, false)?;
let mean_val = mean.to_vec()?[0];
output.sub_scalar_(mean_val)?;
output = output.mul_scalar(factor)?;
output = output.add_scalar(mean_val)?;
}
Ok(output)
}
fn apply_gpu(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
if let Some(device) = &self.device {
if matches!(
device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
) {
self.apply(input)
} else {
self.apply(input)
}
} else {
self.apply(input)
}
}
fn name(&self) -> &'static str {
"UnifiedColorJitter"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
if let Some(brightness) = self.brightness {
params.insert(
"brightness".to_string(),
TransformParameter::Float(brightness),
);
}
if let Some(contrast) = self.contrast {
params.insert("contrast".to_string(), TransformParameter::Float(contrast));
}
if let Some(saturation) = self.saturation {
params.insert(
"saturation".to_string(),
TransformParameter::Float(saturation),
);
}
if let Some(hue) = self.hue {
params.insert("hue".to_string(), TransformParameter::Float(hue));
}
if let Some(device) = &self.device {
params.insert(
"device".to_string(),
TransformParameter::String(format!("{:?}", device)),
);
}
params
}
fn supports_gpu(&self) -> bool {
self.device.as_ref().map_or(false, |d| {
matches!(d.device_type(), torsh_core::device::DeviceType::Cuda(_))
})
}
fn supports_mixed_precision(&self) -> bool {
self.supports_gpu()
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(UnifiedColorJitter {
brightness: self.brightness,
contrast: self.contrast,
saturation: self.saturation,
hue: self.hue,
device: self.device.clone(),
})
}
fn preferred_device(&self) -> Option<&dyn Device> {
self.device.as_ref().map(|d| d.as_ref())
}
}
#[derive(Debug, Clone)]
pub struct UnifiedRandomRotation {
degrees: (f32, f32),
}
impl UnifiedRandomRotation {
pub fn new(degrees: (f32, f32)) -> Self {
Self { degrees }
}
}
impl UnifiedTransform for UnifiedRandomRotation {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut rng = Random::seed(42);
let angle = rng.gen_range(self.degrees.0..=self.degrees.1);
crate::ops::rotate(input, angle)
}
fn name(&self) -> &'static str {
"UnifiedRandomRotation"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert(
"degrees".to_string(),
TransformParameter::Tuple2Float(self.degrees),
);
params
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(self.clone())
}
}
#[derive(Debug)]
pub struct TransformBridge<T: crate::transforms::Transform> {
inner: T,
}
impl<T: crate::transforms::Transform + Clone + std::fmt::Debug> TransformBridge<T> {
pub fn new(transform: T) -> Self {
Self { inner: transform }
}
}
impl<T: crate::transforms::Transform + Clone + std::fmt::Debug + 'static> UnifiedTransform
for TransformBridge<T>
{
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
self.inner.forward(input)
}
fn name(&self) -> &'static str {
self.inner.name()
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let old_params = self.inner.parameters();
let mut params = HashMap::new();
for (key, value) in old_params {
params.insert(key.to_string(), TransformParameter::String(value));
}
params
}
fn is_inplace(&self) -> bool {
self.inner.is_inplace()
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
Box::new(TransformBridge::new(self.inner.clone()))
}
}
pub struct UnifiedTransformBridge {
inner: Box<dyn UnifiedTransform>,
}
impl UnifiedTransformBridge {
pub fn new(transform: Box<dyn UnifiedTransform>) -> Self {
Self { inner: transform }
}
}
impl crate::transforms::Transform for UnifiedTransformBridge {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
self.inner.apply(input)
}
fn name(&self) -> &'static str {
self.inner.name()
}
fn is_inplace(&self) -> bool {
self.inner.is_inplace()
}
fn parameters(&self) -> Vec<(&'static str, String)> {
vec![("bridge", "legacy_transform".to_string())]
}
fn clone_transform(&self) -> Box<dyn crate::transforms::Transform> {
Box::new(UnifiedTransformBridge::new(self.inner.clone_transform()))
}
}
pub mod factory {
use super::*;
pub fn create_resize(
size: (usize, usize),
context: &TransformContext,
) -> Box<dyn UnifiedTransform> {
if matches!(
context.device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
) {
Box::new(UnifiedResize::with_device(size, context.device.clone()))
} else {
Box::new(UnifiedResize::new(size))
}
}
pub fn create_normalize(
mean: Vec<f32>,
std: Vec<f32>,
context: &TransformContext,
) -> Box<dyn UnifiedTransform> {
if matches!(
context.device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
) {
Box::new(UnifiedNormalize::with_device(
mean,
std,
context.device.clone(),
))
} else {
Box::new(UnifiedNormalize::new(mean, std))
}
}
pub fn create_color_jitter(context: &TransformContext) -> UnifiedColorJitter {
if matches!(
context.device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
) {
UnifiedColorJitter::new().with_device(context.device.clone())
} else {
UnifiedColorJitter::new()
}
}
pub fn imagenet_preprocessing(
size: (usize, usize),
context: &TransformContext,
) -> crate::unified_transforms::UnifiedCompose {
use crate::unified_transforms::TransformBuilder;
TransformBuilder::new()
.with_context(context.clone())
.add(UnifiedResize::with_device(size, context.device.clone()))
.add(UnifiedCenterCrop::new(size))
.add(UnifiedNormalize::with_device(
vec![0.485, 0.456, 0.406],
vec![0.229, 0.224, 0.225],
context.device.clone(),
))
.build()
}
pub fn training_augmentation(
size: (usize, usize),
context: &TransformContext,
) -> crate::unified_transforms::UnifiedCompose {
use crate::unified_transforms::TransformBuilder;
TransformBuilder::new()
.with_context(context.clone())
.add(UnifiedResize::with_device(
(size.0 + 32, size.1 + 32),
context.device.clone(),
))
.add(UnifiedRandomCrop::with_padding(size, 4))
.add(UnifiedRandomHorizontalFlip::new(0.5))
.add(
UnifiedColorJitter::new()
.brightness(0.2)
.contrast(0.2)
.with_device(context.device.clone()),
)
.add(UnifiedNormalize::with_device(
vec![0.485, 0.456, 0.406],
vec![0.229, 0.224, 0.225],
context.device.clone(),
))
.build()
}
}
pub mod migration {
use super::*;
pub fn convert_transforms(
_transforms: Vec<Box<dyn crate::transforms::Transform>>,
) -> Vec<Box<dyn UnifiedTransform>> {
Vec::new()
}
pub fn analyze_pipeline(_compose: &crate::transforms::Compose) -> String {
let mut suggestions = Vec::new();
suggestions.push("Consider migrating to UnifiedTransform API for:".to_string());
suggestions.push("- Better GPU acceleration support".to_string());
suggestions.push("- Mixed precision training capabilities".to_string());
suggestions.push("- Improved parameter introspection".to_string());
suggestions.push("- Hardware-aware optimization".to_string());
suggestions.join("\n")
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::zeros;
#[test]
fn test_unified_resize() {
let transform = UnifiedResize::new((224, 224));
let input = zeros(&[3, 256, 256]).expect("zeros should succeed");
let output = transform
.apply(&input)
.expect("apply operation should succeed");
assert_eq!(output.shape().dims(), &[3, 224, 224]);
}
#[test]
fn test_unified_parameters() {
let transform = UnifiedResize::new((224, 224));
let params = transform.parameters();
assert!(params.contains_key("size"));
}
#[test]
fn test_transform_bridge() {
let old_transform = crate::transforms::Resize::new((224, 224));
let bridge = TransformBridge::new(old_transform);
let input = zeros(&[3, 256, 256]).expect("zeros should succeed");
let output = bridge
.apply(&input)
.expect("apply operation should succeed");
assert_eq!(output.shape().dims(), &[3, 224, 224]);
}
#[test]
fn test_factory_functions() {
let context = TransformContext::default();
let transform = factory::create_resize((224, 224), &context);
let input = zeros(&[3, 256, 256]).expect("zeros should succeed");
let output = transform
.apply(&input)
.expect("apply operation should succeed");
assert_eq!(output.shape().dims(), &[3, 224, 224]);
}
#[test]
fn test_gpu_context() {
let context = TransformContext::auto_detect().unwrap_or_default();
let transform = factory::create_normalize(
vec![0.485, 0.456, 0.406],
vec![0.229, 0.224, 0.225],
&context,
);
assert_eq!(transform.name(), "UnifiedNormalize");
assert!(
transform.supports_gpu()
|| !matches!(
context.device.device_type(),
torsh_core::device::DeviceType::Cuda(_)
)
);
}
}