#![allow(dead_code)]
use crate::{Result, VisionError};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use torsh_core::device::{CpuDevice, Device, DeviceType};
use torsh_core::dtype::DType;
use torsh_tensor::Tensor;
pub trait UnifiedTransform: Send + Sync + fmt::Debug {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>>;
fn apply_gpu(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
self.apply(input)
}
fn apply_gpu_f16(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let input_f32 = input.to_dtype(DType::F32)?;
let output_f32 = self.apply_gpu(&input_f32)?;
Ok(output_f32.to_dtype(DType::F16)?)
}
fn name(&self) -> &'static str;
fn parameters(&self) -> HashMap<String, TransformParameter>;
fn supports_gpu(&self) -> bool {
false
}
fn supports_mixed_precision(&self) -> bool {
false
}
fn is_inplace(&self) -> bool {
false
}
fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>> {
Ok(input_shape.to_vec())
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform>;
fn preferred_device(&self) -> Option<&dyn Device> {
None
}
}
#[derive(Debug, Clone)]
pub enum TransformParameter {
Float(f32),
Int(i32),
Usize(usize),
Bool(bool),
String(String),
FloatVec(Vec<f32>),
IntVec(Vec<i32>),
UsizeVec(Vec<usize>),
Tuple2Usize((usize, usize)),
Tuple2Float((f32, f32)),
}
impl fmt::Display for TransformParameter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TransformParameter::Float(v) => write!(f, "{:.3}", v),
TransformParameter::Int(v) => write!(f, "{}", v),
TransformParameter::Usize(v) => write!(f, "{}", v),
TransformParameter::Bool(v) => write!(f, "{}", v),
TransformParameter::String(v) => write!(f, "{}", v),
TransformParameter::FloatVec(v) => write!(f, "{:?}", v),
TransformParameter::IntVec(v) => write!(f, "{:?}", v),
TransformParameter::UsizeVec(v) => write!(f, "{:?}", v),
TransformParameter::Tuple2Usize((a, b)) => write!(f, "({}, {})", a, b),
TransformParameter::Tuple2Float((a, b)) => write!(f, "({:.3}, {:.3})", a, b),
}
}
}
#[derive(Clone)]
pub struct TransformContext {
pub device: Arc<dyn Device>,
pub use_mixed_precision: bool,
pub batch_size: Option<usize>,
pub enable_optimizations: bool,
}
impl std::fmt::Debug for TransformContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TransformContext")
.field("device", &"Arc<dyn Device>")
.field("use_mixed_precision", &self.use_mixed_precision)
.field("batch_size", &self.batch_size)
.field("enable_optimizations", &self.enable_optimizations)
.finish()
}
}
impl Default for TransformContext {
fn default() -> Self {
Self {
device: Arc::new(CpuDevice::new()),
use_mixed_precision: false,
batch_size: None,
enable_optimizations: true,
}
}
}
impl TransformContext {
pub fn new(device: Arc<dyn Device>) -> Self {
Self {
device,
..Default::default()
}
}
pub fn with_mixed_precision(mut self, enabled: bool) -> Self {
self.use_mixed_precision = enabled;
self
}
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
}
pub fn auto_detect() -> Result<Self> {
let _device_type = DeviceType::Cpu;
let device = Arc::new(CpuDevice::new()) as Arc<dyn Device>;
let use_mixed_precision = false;
Ok(Self {
device,
use_mixed_precision,
batch_size: None,
enable_optimizations: true,
})
}
}
#[derive(Debug)]
pub struct UnifiedCompose {
transforms: Vec<Box<dyn UnifiedTransform>>,
context: TransformContext,
}
impl UnifiedCompose {
pub fn new(transforms: Vec<Box<dyn UnifiedTransform>>) -> Self {
Self {
transforms,
context: TransformContext::default(),
}
}
pub fn with_context(mut self, context: TransformContext) -> Self {
self.context = context;
self
}
pub fn add_transform(&mut self, transform: Box<dyn UnifiedTransform>) {
self.transforms.push(transform);
}
pub fn len(&self) -> usize {
self.transforms.len()
}
pub fn is_empty(&self) -> bool {
self.transforms.is_empty()
}
}
impl UnifiedTransform for UnifiedCompose {
fn apply(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut output = input.clone();
for transform in &self.transforms {
output = if self.context.use_mixed_precision && transform.supports_mixed_precision() {
let output_f16 = output.to_dtype(DType::F16)?;
let result_f16 = transform.apply_gpu_f16(&output_f16)?;
result_f16.to_dtype(DType::F32)?
} else if matches!(self.context.device.device_type(), DeviceType::Cuda(_))
&& transform.supports_gpu()
{
transform.apply_gpu(&output)?
} else {
transform.apply(&output)?
};
}
Ok(output)
}
fn apply_gpu(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut output = input.clone();
for transform in &self.transforms {
output = transform.apply_gpu(&output)?;
}
Ok(output)
}
fn apply_gpu_f16(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut output = input.clone();
for transform in &self.transforms {
output = transform.apply_gpu_f16(&output)?;
}
Ok(output)
}
fn name(&self) -> &'static str {
"UnifiedCompose"
}
fn parameters(&self) -> HashMap<String, TransformParameter> {
let mut params = HashMap::new();
params.insert(
"num_transforms".to_string(),
TransformParameter::Usize(self.transforms.len()),
);
for (i, transform) in self.transforms.iter().enumerate() {
params.insert(
format!("transform_{}", i),
TransformParameter::String(transform.name().to_string()),
);
}
params
}
fn supports_gpu(&self) -> bool {
self.transforms.iter().any(|t| t.supports_gpu())
}
fn supports_mixed_precision(&self) -> bool {
self.transforms.iter().any(|t| t.supports_mixed_precision())
}
fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>> {
let mut shape = input_shape.to_vec();
for transform in &self.transforms {
shape = transform.output_shape(&shape)?;
}
Ok(shape)
}
fn clone_transform(&self) -> Box<dyn UnifiedTransform> {
let cloned_transforms = self
.transforms
.iter()
.map(|t| t.clone_transform())
.collect();
Box::new(UnifiedCompose::new(cloned_transforms).with_context(self.context.clone()))
}
}
pub struct TransformBuilder {
transforms: Vec<Box<dyn UnifiedTransform>>,
context: TransformContext,
}
impl TransformBuilder {
pub fn new() -> Self {
Self {
transforms: Vec::new(),
context: TransformContext::default(),
}
}
pub fn with_context(mut self, context: TransformContext) -> Self {
self.context = context;
self
}
pub fn add<T: UnifiedTransform + 'static>(mut self, transform: T) -> Self {
self.transforms.push(Box::new(transform));
self
}
pub fn resize(self, size: (usize, usize)) -> Self {
self.add(crate::transforms::unified::UnifiedResize::new(size))
}
pub fn center_crop(self, size: (usize, usize)) -> Self {
self.add(crate::transforms::unified::UnifiedCenterCrop::new(size))
}
pub fn random_horizontal_flip(self, p: f32) -> Self {
self.add(crate::transforms::unified::UnifiedRandomHorizontalFlip::new(p))
}
pub fn normalize(self, mean: Vec<f32>, std: Vec<f32>) -> Self {
self.add(crate::transforms::unified::UnifiedNormalize::new(mean, std))
}
pub fn build(self) -> UnifiedCompose {
UnifiedCompose::new(self.transforms).with_context(self.context)
}
}
impl Default for TransformBuilder {
fn default() -> Self {
Self::new()
}
}
#[macro_export]
macro_rules! compose_transforms {
($($transform:expr),+ $(,)?) => {
$crate::unified_transforms::UnifiedCompose::new(vec![
$(Box::new($transform),)+
])
};
}
pub mod presets {
use super::*;
pub fn imagenet_preprocessing(size: (usize, usize)) -> UnifiedCompose {
TransformBuilder::new()
.resize(size)
.center_crop(size)
.normalize(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
.build()
}
pub fn training_augmentation(size: (usize, usize)) -> UnifiedCompose {
TransformBuilder::new()
.resize((size.0 + 32, size.1 + 32))
.random_horizontal_flip(0.5)
.center_crop(size)
.normalize(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
.build()
}
pub fn gpu_optimized_transforms(size: (usize, usize)) -> Result<UnifiedCompose> {
let context = TransformContext::auto_detect()?;
Ok(TransformBuilder::new()
.with_context(context)
.resize(size)
.normalize(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
.build())
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::zeros;
#[test]
fn test_transform_builder() {
let transform = TransformBuilder::new()
.resize((224, 224))
.center_crop((224, 224))
.normalize(vec![0.5], vec![0.5])
.build();
assert_eq!(transform.len(), 3);
assert_eq!(transform.name(), "UnifiedCompose");
}
#[test]
fn test_compose_macro() {
let resize = crate::transforms::unified::UnifiedResize::new((224, 224));
let normalize = crate::transforms::unified::UnifiedNormalize::new(vec![0.5], vec![0.5]);
let _transform = compose_transforms![resize, normalize];
}
#[test]
fn test_transform_parameters() {
let param = TransformParameter::Tuple2Usize((224, 224));
assert_eq!(param.to_string(), "(224, 224)");
let param = TransformParameter::FloatVec(vec![0.485, 0.456, 0.406]);
assert!(param.to_string().contains("0.485"));
}
#[test]
fn test_context_auto_detect() {
let context = TransformContext::auto_detect();
assert!(context.is_ok());
}
}