use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "gpu")]
use std::sync::Arc;
#[cfg(feature = "gpu")]
use crate::Transform;
#[cfg(feature = "gpu")]
use super::context::GpuContext;
#[cfg(feature = "gpu")]
pub struct GpuRandomCrop {
output_width: u32,
output_height: u32,
context: Arc<GpuContext>,
}
#[cfg(feature = "gpu")]
impl GpuRandomCrop {
pub fn new(output_width: u32, output_height: u32, context: Arc<GpuContext>) -> Result<Self> {
Ok(Self {
output_width,
output_height,
context,
})
}
pub async fn crop_tensor(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
Ok(input.clone())
}
}
#[cfg(feature = "gpu")]
impl Transform<f32> for GpuRandomCrop {
fn apply(&self, sample: (Tensor<f32>, Tensor<f32>)) -> Result<(Tensor<f32>, Tensor<f32>)> {
let (image_tensor, label_tensor) = sample;
let cropped_tensor = pollster::block_on(self.crop_tensor(&image_tensor))?;
Ok((cropped_tensor, label_tensor))
}
}
#[cfg(not(feature = "gpu"))]
pub struct GpuRandomCrop;
#[cfg(not(feature = "gpu"))]
impl GpuRandomCrop {
pub fn new(_output_width: u32, _output_height: u32, _context: ()) -> Result<Self> {
Err(TensorError::unsupported_operation_simple(
"GPU transforms require 'gpu' feature to be enabled".to_string(),
))
}
}