tenflowers-dataset 0.1.1

Data pipeline and dataset utilities for TenfloweRS
Documentation
//! GPU-accelerated random crop transform
//!
//! This module provides GPU-accelerated random cropping using WGPU compute shaders
//! for data augmentation and image preprocessing.

use tenflowers_core::{Result, Tensor, TensorError};

#[cfg(feature = "gpu")]
use std::sync::Arc;

#[cfg(feature = "gpu")]
use crate::Transform;

#[cfg(feature = "gpu")]
use super::context::GpuContext;

/// GPU-accelerated random crop transform
#[cfg(feature = "gpu")]
pub struct GpuRandomCrop {
    output_width: u32,
    output_height: u32,
    context: Arc<GpuContext>,
}

#[cfg(feature = "gpu")]
impl GpuRandomCrop {
    /// Create a new GPU random crop transform
    pub fn new(output_width: u32, output_height: u32, context: Arc<GpuContext>) -> Result<Self> {
        Ok(Self {
            output_width,
            output_height,
            context,
        })
    }

    /// Apply random crop to image tensor using GPU
    pub async fn crop_tensor(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        // Placeholder implementation - would contain full GPU cropping logic
        Ok(input.clone())
    }
}

#[cfg(feature = "gpu")]
impl Transform<f32> for GpuRandomCrop {
    fn apply(&self, sample: (Tensor<f32>, Tensor<f32>)) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let (image_tensor, label_tensor) = sample;
        let cropped_tensor = pollster::block_on(self.crop_tensor(&image_tensor))?;
        Ok((cropped_tensor, label_tensor))
    }
}

/// CPU fallback crop transform when GPU is not available
#[cfg(not(feature = "gpu"))]
pub struct GpuRandomCrop;

#[cfg(not(feature = "gpu"))]
impl GpuRandomCrop {
    /// Create a new crop transform (fallback to CPU)
    pub fn new(_output_width: u32, _output_height: u32, _context: ()) -> Result<Self> {
        Err(TensorError::unsupported_operation_simple(
            "GPU transforms require 'gpu' feature to be enabled".to_string(),
        ))
    }
}