use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "gpu")]
use std::sync::Arc;
#[cfg(feature = "gpu")]
use crate::Transform;
#[cfg(feature = "gpu")]
use super::context::GpuContext;
#[cfg(feature = "gpu")]
/// Rotation data-augmentation transform backed by a shared [`GpuContext`].
///
/// NOTE(review): as currently implemented, `rotate_tensor` returns its input
/// unchanged, so this transform is effectively an identity op until the GPU
/// rotation is written.
pub struct GpuRotation {
/// Angle range to sample rotations from, stored exactly as passed to `new`.
/// NOTE(review): unit (degrees vs. radians) and bound ordering are not
/// established here — confirm against the intended kernel.
angle_range: (f32, f32),
/// Shared GPU context; reference-counted so several transforms can reuse it.
context: Arc<GpuContext>,
}
#[cfg(feature = "gpu")]
impl GpuRotation {
    /// Creates a rotation transform configured with `angle_range` and a shared
    /// GPU `context`.
    ///
    /// `angle_range` is stored as given; it is not validated here.
    /// NOTE(review): unit (degrees vs. radians) and bound ordering are not
    /// established by this code — confirm with the GPU kernel once written.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type leaves room for
    /// construction-time validation or GPU resource setup.
    pub fn new(angle_range: (f32, f32), context: Arc<GpuContext>) -> Result<Self> {
        Ok(Self {
            angle_range,
            context,
        })
    }

    /// Returns the angle range this transform was constructed with.
    pub fn angle_range(&self) -> (f32, f32) {
        self.angle_range
    }

    /// Returns the GPU context shared by this transform.
    pub fn context(&self) -> &Arc<GpuContext> {
        &self.context
    }

    /// Rotates `input` using the GPU context.
    ///
    /// NOTE(review): this is currently a placeholder — it returns a clone of
    /// `input` unchanged; neither `angle_range` nor the GPU context is used yet.
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub async fn rotate_tensor(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        Ok(input.clone())
    }
}
#[cfg(feature = "gpu")]
impl Transform<f32> for GpuRotation {
    /// Applies `rotate_tensor` to the image tensor of an `(image, label)` sample
    /// and returns the label tensor untouched.
    ///
    /// `Transform::apply` is synchronous, so the async `rotate_tensor` future is
    /// driven to completion with `pollster::block_on`, which blocks the calling
    /// thread until it finishes.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rotate_tensor`.
    fn apply(
        &self,
        (image, label): (Tensor<f32>, Tensor<f32>),
    ) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let pending = self.rotate_tensor(&image);
        pollster::block_on(pending).map(|rotated| (rotated, label))
    }
}
#[cfg(not(feature = "gpu"))]
/// Placeholder for `GpuRotation` used when the crate is built without the
/// `gpu` feature; its constructor always returns an error.
pub struct GpuRotation;
#[cfg(not(feature = "gpu"))]
impl GpuRotation {
    /// Fallback constructor compiled when the `gpu` feature is disabled.
    ///
    /// Both arguments are ignored; presumably they exist so call sites look
    /// similar to the GPU-enabled build — confirm with callers.
    ///
    /// # Errors
    ///
    /// Always returns an unsupported-operation error, because GPU transforms
    /// are not available in this build.
    pub fn new(_angle_range: (f32, f32), _context: ()) -> Result<Self> {
        let message = String::from("GPU transforms require 'gpu' feature to be enabled");
        Err(TensorError::unsupported_operation_simple(message))
    }
}