use crate::error::RusTorchResult;
use crate::vision::transforms::Transform;
use crate::vision::{Image, ImageFormat};
use num_traits::Float;
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
/// Running counters collected by a [`Pipeline`] while it processes images.
///
/// All counters start at zero (see the derived `Default`).
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Number of images whose result was computed (cache hits are not counted here).
    pub total_processed: usize,
    /// Running mean of per-image processing time, in microseconds.
    pub avg_processing_time_us: f64,
    /// Number of `apply` calls answered from the result cache.
    pub cache_hits: usize,
    /// Number of `apply` calls with caching enabled that had to compute their result.
    pub cache_misses: usize,
    /// Memory usage estimate in bytes.
    /// NOTE(review): nothing in this file updates this field — confirm where it is maintained.
    pub memory_usage_bytes: usize,
}
/// Strategy used by [`Pipeline::apply`] to run its transforms over a single image.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ExecutionMode {
    /// Apply each transform in order to the previous step's output.
    Sequential,
    /// Parallel dispatch; in this file it currently falls back to sequential execution.
    Parallel,
    /// Route the image through the batch code path as a one-element batch.
    Batch,
}
/// A pipeline step that runs `transform` only when `predicate` accepts the input image.
pub struct ConditionalTransform<T: Float> {
    /// Transform executed when the predicate returns `true`.
    pub transform: Box<dyn Transform<T>>,
    /// Per-image gate deciding whether `transform` runs.
    pub predicate: Box<dyn Fn(&Image<T>) -> bool + Send + Sync>,
    /// Human-readable label for this step; it is what the `Debug` output shows.
    pub name: String,
}
impl<T: Float> fmt::Debug for ConditionalTransform<T> {
    /// Prints the step's `name`. The boxed transform and predicate cannot be printed
    /// generically, so they are elided; `finish_non_exhaustive` renders a trailing `..`
    /// to make that omission visible instead of silently hiding the fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ConditionalTransform")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
impl<T: Float + 'static + std::fmt::Debug> Transform<T> for ConditionalTransform<T> {
    /// Runs the wrapped transform if the predicate accepts `image`; otherwise returns
    /// an unmodified clone of `image`.
    ///
    /// # Errors
    /// Propagates any error from the wrapped transform.
    fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let accepted = (self.predicate)(image);
        if !accepted {
            // Pass-through: the step is skipped for this image.
            return Ok(image.clone());
        }
        self.transform.apply(image)
    }
}
/// An ordered chain of image transforms with optional result caching and statistics.
pub struct Pipeline<T: Float> {
    /// Steps applied in insertion order.
    transforms: Vec<Box<dyn Transform<T>>>,
    /// How `apply` dispatches the work.
    execution_mode: ExecutionMode,
    /// Result cache keyed by a metadata-derived string; shared behind a lock.
    cache: Arc<RwLock<HashMap<String, Image<T>>>>,
    /// Upper bound on the number of cached results.
    max_cache_size: usize,
    /// Whether `apply` consults and fills `cache`.
    cache_enabled: bool,
    /// Counters updated by `apply`; shared behind a lock.
    stats: Arc<RwLock<PipelineStats>>,
    /// Pipeline label; also part of every cache key.
    name: String,
}
impl<T: Float + Clone + 'static + std::fmt::Debug> Pipeline<T> {
    /// Creates an empty pipeline called `name`.
    ///
    /// Defaults: [`ExecutionMode::Sequential`], caching disabled, and a cache capacity of
    /// 1000 entries (the capacity only matters once [`Pipeline::with_cache`] enables caching).
    pub fn new(name: String) -> Self {
        Self {
            transforms: Vec::new(),
            execution_mode: ExecutionMode::Sequential,
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_cache_size: 1000,
            cache_enabled: false,
            stats: Arc::new(RwLock::new(PipelineStats::default())),
            name,
        }
    }

    /// Appends `transform` as the last step of the pipeline.
    pub fn add_transform(mut self, transform: Box<dyn Transform<T>>) -> Self {
        self.transforms.push(transform);
        self
    }

    /// Appends `transform` as the last step, wrapped in a [`ConditionalTransform`] so it
    /// only runs on images for which `predicate` returns `true`; other images pass through
    /// that step unchanged.
    pub fn add_conditional_transform<F>(
        mut self,
        transform: Box<dyn Transform<T>>,
        predicate: F,
        name: String,
    ) -> Self
    where
        F: Fn(&Image<T>) -> bool + Send + Sync + 'static,
    {
        let conditional = ConditionalTransform {
            transform,
            predicate: Box::new(predicate),
            name,
        };
        self.transforms.push(Box::new(conditional));
        self
    }

    /// Sets the strategy used by [`Pipeline::apply`].
    pub fn with_execution_mode(mut self, mode: ExecutionMode) -> Self {
        self.execution_mode = mode;
        self
    }

    /// Enables result caching with room for at most `max_size` entries.
    ///
    /// With `max_size == 0` lookups and miss counting still happen, but nothing is stored.
    pub fn with_cache(mut self, max_size: usize) -> Self {
        self.cache_enabled = true;
        self.max_cache_size = max_size;
        self
    }

    /// Runs `image` through every step of the pipeline.
    ///
    /// When caching is enabled, a cached result for the image's key (see
    /// `generate_cache_key`) is returned directly and counted in `cache_hits`. Such hits are
    /// *not* counted in `total_processed` or the average time. Otherwise the result is
    /// computed, stored in the cache (if enabled) and counted in `cache_misses`,
    /// `total_processed` and the running average time.
    ///
    /// Lock poisoning is treated as best-effort: caching/statistics are skipped, never panicked on.
    ///
    /// # Errors
    /// Returns the first error produced by any transform.
    pub fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let start_time = std::time::Instant::now();
        let cache_key = if self.cache_enabled {
            Some(self.generate_cache_key(image))
        } else {
            None
        };

        if let Some(key) = &cache_key {
            // Clone the hit out so the cache read guard is released before touching `stats`.
            let cached = self
                .cache
                .read()
                .ok()
                .and_then(|cache| cache.get(key).cloned());
            if let Some(cached_image) = cached {
                if let Ok(mut stats) = self.stats.write() {
                    stats.cache_hits += 1;
                }
                return Ok(cached_image);
            }
        }

        let result = match self.execution_mode {
            ExecutionMode::Sequential => self.apply_sequential(image),
            ExecutionMode::Parallel => self.apply_parallel(image),
            // `from_ref` views the image as a one-element slice without cloning it;
            // `apply_batch` yields exactly one output per input, so index 0 exists.
            ExecutionMode::Batch => self
                .apply_batch(std::slice::from_ref(image))
                .map(|mut outputs| outputs.remove(0)),
        }?;

        if let Some(key) = cache_key {
            self.store_in_cache(key, &result);
            // Counted independently of whether the store succeeded: it was still a miss.
            if let Ok(mut stats) = self.stats.write() {
                stats.cache_misses += 1;
            }
        }

        let processing_time = start_time.elapsed().as_micros() as f64;
        if let Ok(mut stats) = self.stats.write() {
            stats.total_processed += 1;
            let total = stats.total_processed as f64;
            // Incremental mean: new_avg = (old_avg * (n - 1) + sample) / n.
            stats.avg_processing_time_us =
                (stats.avg_processing_time_us * (total - 1.0) + processing_time) / total;
        }
        Ok(result)
    }

    /// Inserts `value` under `key`, keeping the cache within `max_cache_size` entries.
    ///
    /// When full, one entry is evicted. `HashMap` iteration order is unspecified, so the
    /// victim is arbitrary (not the oldest). A capacity of 0 stores nothing. A poisoned
    /// lock means the value is simply not cached.
    fn store_in_cache(&self, key: String, value: &Image<T>) {
        if self.max_cache_size == 0 {
            return;
        }
        if let Ok(mut cache) = self.cache.write() {
            // Another caller may have stored this key since our lookup; overwriting it does
            // not grow the map, so only evict when the key is genuinely new.
            if cache.len() >= self.max_cache_size && !cache.contains_key(&key) {
                if let Some(victim) = cache.keys().next().cloned() {
                    cache.remove(&victim);
                }
            }
            cache.insert(key, value.clone());
        }
    }

    /// Applies every step in order, feeding each step the previous step's output.
    fn apply_sequential(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        let mut result = image.clone();
        for transform in &self.transforms {
            result = transform.apply(&result)?;
        }
        Ok(result)
    }

    /// Parallel execution entry point; currently identical to [`Self::apply_sequential`].
    fn apply_parallel(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
        self.apply_sequential(image)
    }

    /// Runs each image through the pipeline sequentially, returning outputs in input order.
    ///
    /// Bypasses the cache and statistics.
    ///
    /// # Errors
    /// Stops at and returns the first error encountered.
    pub fn apply_batch(&self, images: &[Image<T>]) -> RusTorchResult<Vec<Image<T>>> {
        images
            .iter()
            .map(|image| self.apply_sequential(image))
            .collect()
    }

    /// Builds the cache key from the pipeline name and the image's height, width, format
    /// and channel count.
    ///
    /// NOTE(review): pixel data is not part of the key, so two different images with the
    /// same metadata map to the same cache entry and the second receives the first's
    /// result. Including a content hash would fix this — confirm the `Image` data API.
    fn generate_cache_key(&self, image: &Image<T>) -> String {
        format!(
            "{}_{}_{}_{:?}_{}",
            self.name, image.height, image.width, image.format, image.channels
        )
    }

    /// Returns a snapshot of the statistics (all-zero defaults if the lock is poisoned).
    pub fn get_stats(&self) -> PipelineStats {
        self.stats
            .read()
            .map(|stats| (*stats).clone())
            .unwrap_or_default()
    }

    /// Resets all statistics to zero (no-op if the lock is poisoned).
    pub fn reset_stats(&self) {
        if let Ok(mut stats) = self.stats.write() {
            *stats = PipelineStats::default();
        }
    }

    /// Removes every cached result (no-op if the lock is poisoned).
    pub fn clear_cache(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.clear();
        }
    }

    /// Returns `(current_entries, max_entries)`; current is 0 if the lock is poisoned.
    pub fn cache_info(&self) -> (usize, usize) {
        let current = self.cache.read().map(|cache| cache.len()).unwrap_or(0);
        (current, self.max_cache_size)
    }

    /// Number of steps in the pipeline.
    pub fn len(&self) -> usize {
        self.transforms.len()
    }

    /// `true` when the pipeline has no steps.
    pub fn is_empty(&self) -> bool {
        self.transforms.is_empty()
    }

    /// The pipeline's label.
    pub fn name(&self) -> &str {
        &self.name
    }
}
impl<T: Float + 'static + std::fmt::Debug> fmt::Debug for Pipeline<T> {
    /// Prints the pipeline's configuration: name, step count, execution mode and cache
    /// settings. The steps themselves, the cache contents and the statistics are elided;
    /// `finish_non_exhaustive` renders a trailing `..` so the omission is visible.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Pipeline")
            .field("name", &self.name)
            .field("num_transforms", &self.transforms.len())
            .field("execution_mode", &self.execution_mode)
            .field("cache_enabled", &self.cache_enabled)
            .field("max_cache_size", &self.max_cache_size)
            .finish_non_exhaustive()
    }
}
impl<T: Float + 'static + std::fmt::Debug> Transform<T> for Pipeline<T> {
fn apply(&self, image: &Image<T>) -> RusTorchResult<Image<T>> {
self.apply(image)
}
}
/// Fluent builder that assembles a [`Pipeline`] step by step.
pub struct PipelineBuilder<T: Float> {
    /// The pipeline under construction; handed out by `build`.
    pipeline: Pipeline<T>,
}
impl<T: Float + Clone + 'static + std::fmt::Debug> PipelineBuilder<T> {
    /// Starts a builder for a pipeline named `name`, using `Pipeline::new`'s defaults.
    pub fn new(name: String) -> Self {
        let pipeline = Pipeline::new(name);
        Self { pipeline }
    }

    /// Appends an unconditional step.
    pub fn transform(self, transform: Box<dyn Transform<T>>) -> Self {
        Self {
            pipeline: self.pipeline.add_transform(transform),
        }
    }

    /// Appends a step that runs `transform` only on images accepted by `predicate`.
    pub fn conditional_transform<F>(
        self,
        transform: Box<dyn Transform<T>>,
        predicate: F,
        name: String,
    ) -> Self
    where
        F: Fn(&Image<T>) -> bool + Send + Sync + 'static,
    {
        let pipeline = self
            .pipeline
            .add_conditional_transform(transform, predicate, name);
        Self { pipeline }
    }

    /// Chooses how the built pipeline dispatches work.
    pub fn execution_mode(self, mode: ExecutionMode) -> Self {
        Self {
            pipeline: self.pipeline.with_execution_mode(mode),
        }
    }

    /// Enables result caching with capacity `max_size`.
    pub fn cache(self, max_size: usize) -> Self {
        Self {
            pipeline: self.pipeline.with_cache(max_size),
        }
    }

    /// Finishes construction and returns the configured pipeline.
    pub fn build(self) -> Pipeline<T> {
        let Self { pipeline } = self;
        pipeline
    }
}
pub mod predicates {
    //! Ready-made predicates for conditional pipeline steps.
    use super::*;

    /// Accepts images at least `min_width` wide and `min_height` tall.
    pub fn min_size<T: Float>(
        min_width: usize,
        min_height: usize,
    ) -> Box<dyn Fn(&Image<T>) -> bool + Send + Sync> {
        Box::new(move |img: &Image<T>| {
            let wide_enough = img.width >= min_width;
            let tall_enough = img.height >= min_height;
            wide_enough && tall_enough
        })
    }

    /// Accepts images at most `max_width` wide and `max_height` tall.
    pub fn max_size<T: Float>(
        max_width: usize,
        max_height: usize,
    ) -> Box<dyn Fn(&Image<T>) -> bool + Send + Sync> {
        Box::new(move |img: &Image<T>| {
            let narrow_enough = img.width <= max_width;
            let short_enough = img.height <= max_height;
            narrow_enough && short_enough
        })
    }

    /// Accepts images whose format equals `target_format`.
    pub fn format_is<T: Float>(
        target_format: ImageFormat,
    ) -> Box<dyn Fn(&Image<T>) -> bool + Send + Sync> {
        let matches_format = move |img: &Image<T>| img.format == target_format;
        Box::new(matches_format)
    }

    /// Accepts images with exactly `target_channels` channels.
    pub fn channels_eq<T: Float>(
        target_channels: usize,
    ) -> Box<dyn Fn(&Image<T>) -> bool + Send + Sync> {
        let matches_channels = move |img: &Image<T>| img.channels == target_channels;
        Box::new(matches_channels)
    }

    /// Accepts each image independently with probability `prob`, using a fresh draw in
    /// `[0, 1)` from the thread-local RNG per call (so `prob >= 1.0` always accepts).
    pub fn probability<T: Float>(prob: f64) -> Box<dyn Fn(&Image<T>) -> bool + Send + Sync> {
        use rand::Rng;
        Box::new(move |_img: &Image<T>| {
            let draw: f64 = rand::thread_rng().gen();
            draw < prob
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;
    use crate::vision::transforms::{Resize, ToTensor};

    /// Builds a 3-channel, `side` x `side` CHW image filled with 0.5.
    fn square_chw_image(side: usize) -> Image<f32> {
        let pixels = vec![0.5f32; 3 * side * side];
        let tensor = Tensor::from_vec(pixels, vec![3, side, side]);
        Image::new(tensor, ImageFormat::CHW).unwrap()
    }

    #[test]
    fn test_pipeline_creation() {
        let fresh = Pipeline::<f32>::new("test_pipeline".to_string());
        assert_eq!(fresh.name(), "test_pipeline");
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    #[test]
    fn test_pipeline_builder() {
        let built = PipelineBuilder::<f32>::new("test".to_string())
            .transform(Box::new(Resize::new((224, 224))))
            .transform(Box::new(ToTensor::new()))
            .cache(100)
            .execution_mode(ExecutionMode::Sequential)
            .build();
        assert_eq!(built.len(), 2);
        assert_eq!(built.name(), "test");
        assert!(!built.is_empty());
    }

    #[test]
    fn test_conditional_transform() {
        let at_least_100 = predicates::min_size::<f32>(100, 100);
        let small = square_chw_image(50);
        assert!(!at_least_100(&small));
        let large = square_chw_image(200);
        assert!(at_least_100(&large));
    }

    #[test]
    fn test_pipeline_stats() {
        let snapshot = Pipeline::<f32>::new("stats_test".to_string()).get_stats();
        assert_eq!(snapshot.total_processed, 0);
        assert_eq!(snapshot.cache_hits, 0);
        assert_eq!(snapshot.cache_misses, 0);
    }

    #[test]
    fn test_cache_info() {
        let cached = Pipeline::<f32>::new("cache_test".to_string()).with_cache(50);
        let (entries, capacity) = cached.cache_info();
        assert_eq!(entries, 0);
        assert_eq!(capacity, 50);
    }

    #[test]
    fn test_execution_modes() {
        for mode in [
            ExecutionMode::Sequential,
            ExecutionMode::Parallel,
            ExecutionMode::Batch,
        ] {
            let _pipeline =
                Pipeline::<f32>::new("mode_test".to_string()).with_execution_mode(mode);
        }
    }
}