pub mod datasets;
pub mod pipeline;
pub mod presets;
pub mod transforms;
pub mod utils;
pub use datasets::*;
pub use pipeline::*;
pub use transforms::*;
use crate::tensor::Tensor;
use num_traits::Float;
/// An image (or a stack of images) stored in a [`Tensor`] together with its
/// spatial size and axis layout.
///
/// Construct values with [`Image::new`], which derives `height`, `width` and
/// `channels` from the tensor shape according to `format`. The fields are
/// public, so code that mutates them directly is responsible for keeping them
/// consistent with `data`.
#[derive(Debug, Clone)]
pub struct Image<T: Float> {
    /// Pixel data laid out according to `format`: rank 3, or rank 4 with one
    /// extra leading axis (the shapes accepted by `Image::new`).
    pub data: Tensor<T>,
    /// Extent of the height (H) axis of `data`.
    pub height: usize,
    /// Extent of the width (W) axis of `data`.
    pub width: usize,
    /// Extent of the channel (C) axis of `data`.
    pub channels: usize,
    /// Axis order of `data`.
    pub format: ImageFormat,
}
/// Axis order of an image tensor.
///
/// A rank-4 tensor carries one extra leading axis in front of the three listed
/// here (presumably a batch dimension — `Image::new` does not record its size).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImageFormat {
    /// Channels, height, width: `[C, H, W]` (or `[_, C, H, W]` for rank 4).
    CHW,
    /// Height, width, channels: `[H, W, C]` (or `[_, H, W, C]` for rank 4).
    HWC,
}
impl<T: Float + 'static> Image<T> {
pub fn new(data: Tensor<T>, format: ImageFormat) -> crate::error::RusTorchResult<Self> {
let shape = data.shape();
let (height, width, channels) = match (format, shape.len()) {
(ImageFormat::CHW, 3) => (shape[1], shape[2], shape[0]),
(ImageFormat::HWC, 3) => (shape[0], shape[1], shape[2]),
(ImageFormat::CHW, 4) => (shape[2], shape[3], shape[1]), (ImageFormat::HWC, 4) => (shape[1], shape[2], shape[3]), _ => return Err(crate::error::RusTorchError::invalid_image_shape(shape)),
};
Ok(Image {
data,
height,
width,
channels,
format,
})
}
pub fn to_format(&self, target_format: ImageFormat) -> crate::error::RusTorchResult<Image<T>> {
if self.format == target_format {
return Ok(self.clone());
}
let mut new_image = self.clone();
new_image.format = target_format;
Ok(new_image)
}
pub fn size(&self) -> (usize, usize) {
(self.height, self.width)
}
}
/// Alias for the crate-wide `crate::error::RusTorchResult`, so it can also be
/// referred to through this module's path.
pub type RusTorchResult<T> = crate::error::RusTorchResult<T>;