pub mod datasets;
pub mod image;
pub use datasets::{ImageNet, CIFAR10, MNIST};
pub use image::transforms::transforms::{CenterCrop, Normalize, Resize};
pub use image::{
Compose, ImageFolder, ImageToTensor, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip,
TensorToImage,
};
use crate::{dataset::Dataset, transforms::Transform};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
use std::path::{Path, PathBuf};
/// A video clip held in memory as a list of per-frame tensors plus timing
/// metadata.
#[derive(Debug, Clone)]
pub struct VideoFrames {
/// The clip's frames, one tensor per frame.
frames: Vec<Tensor<f32>>,
/// Frame rate supplied by the caller of `VideoFrames::new`.
fps: f32,
/// Clip length derived from the frame count and `fps` by `VideoFrames::new`
/// (seconds, assuming `fps` is expressed in frames per second).
duration: f32,
}
impl VideoFrames {
    /// Creates a clip from its frames and a frame rate.
    ///
    /// The duration is computed as `frames.len() / fps`. When `fps` is not a
    /// finite, strictly positive number the division would yield an infinite,
    /// NaN or negative duration, so the duration is reported as `0.0` instead.
    /// The `fps` value itself is stored unchanged.
    pub fn new(frames: Vec<Tensor<f32>>, fps: f32) -> Self {
        let duration = if fps.is_finite() && fps > 0.0 {
            frames.len() as f32 / fps
        } else {
            // No meaningful duration can be derived from an invalid frame rate.
            0.0
        };
        Self {
            frames,
            fps,
            duration,
        }
    }

    /// Returns the frames of the clip as a slice.
    pub fn frames(&self) -> &[Tensor<f32>] {
        &self.frames
    }

    /// Returns the frame rate the clip was created with.
    pub fn fps(&self) -> f32 {
        self.fps
    }

    /// Returns the duration computed when the clip was created.
    pub fn duration(&self) -> f32 {
        self.duration
    }

    /// Returns the number of frames in the clip.
    pub fn num_frames(&self) -> usize {
        self.frames.len()
    }
}
/// Video classification dataset backed by a directory tree: each
/// sub-directory of `root` is a class, and the video files directly inside it
/// are that class's samples.
pub struct VideoFolder {
/// Root directory that was scanned for class sub-directories.
root: PathBuf,
/// One `(video path, class index)` pair per discovered video file.
samples: Vec<(PathBuf, usize)>,
/// Class names; a sample's class index is a position in this list.
classes: Vec<String>,
/// Optional transform applied to each loaded clip; when `None`,
/// `VideoToTensor` is used instead.
transform: Option<Box<dyn Transform<VideoFrames, Output = Tensor<f32>>>>,
/// Number of frames `load_video` produces per sample.
max_frames: usize,
/// Frame rate attached to loaded clips; `load_video` falls back to 30.0
/// when this is `None`.
frame_rate: Option<f32>,
}
impl VideoFolder {
pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
let root = root.as_ref().to_path_buf();
if !root.exists() {
return Err(TorshError::IoError(format!(
"Directory does not exist: {root:?}"
)));
}
let mut classes = Vec::new();
let mut samples = Vec::new();
for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let path = entry.path();
if path.is_dir() {
let class_name = path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
.to_string();
let class_idx = classes.len();
classes.push(class_name);
for video_entry in
std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
{
let video_entry =
video_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let video_path = video_entry.path();
if Self::is_video_file(&video_path) {
samples.push((video_path, class_idx));
}
}
}
}
Ok(Self {
root,
samples,
classes,
transform: None,
max_frames: 32, frame_rate: None,
})
}
pub fn with_max_frames(mut self, max_frames: usize) -> Self {
self.max_frames = max_frames;
self
}
pub fn with_frame_rate(mut self, fps: f32) -> Self {
self.frame_rate = Some(fps);
self
}
pub fn with_transform<T>(mut self, transform: T) -> Self
where
T: Transform<VideoFrames, Output = Tensor<f32>> + 'static,
{
self.transform = Some(Box::new(transform));
self
}
pub fn classes(&self) -> &[String] {
&self.classes
}
pub fn root(&self) -> &Path {
&self.root
}
pub fn num_samples(&self) -> usize {
self.samples.len()
}
fn is_video_file(path: &Path) -> bool {
if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
matches!(
extension.to_lowercase().as_str(),
"mp4" | "avi" | "mov" | "mkv" | "wmv" | "flv" | "webm"
)
} else {
false
}
}
fn load_video(&self, _path: &Path) -> Result<VideoFrames> {
let mut frames = Vec::new();
for _i in 0..self.max_frames {
let frame = torsh_tensor::creation::rand::<f32>(&[3, 224, 224])?;
frames.push(frame);
}
let fps = self.frame_rate.unwrap_or(30.0);
Ok(VideoFrames::new(frames, fps))
}
}
impl Dataset for VideoFolder {
    /// A sample is the transformed clip tensor paired with its class index.
    type Item = (Tensor<f32>, usize);

    /// Number of samples in the dataset.
    fn len(&self) -> usize {
        self.samples.len()
    }

    /// Loads the sample at `index`, applies the configured transform (or
    /// `VideoToTensor` when none is set) and returns it with its class index.
    ///
    /// Returns `TorshError::IndexError` when `index` is out of range.
    fn get(&self, index: usize) -> Result<Self::Item> {
        let (path, label) = match self.samples.get(index) {
            Some((path, label)) => (path, *label),
            None => {
                return Err(TorshError::IndexError {
                    index,
                    size: self.samples.len(),
                })
            }
        };

        let clip = self.load_video(path)?;
        let tensor = match &self.transform {
            Some(custom) => custom.transform(clip)?,
            None => VideoToTensor.transform(clip)?,
        };
        Ok((tensor, label))
    }
}
/// Transform that concatenates the frames of a `VideoFrames` clip into a
/// single 4-D tensor laid out as (T, C, H, W).
pub struct VideoToTensor;
impl Transform<VideoFrames> for VideoToTensor {
    type Output = Tensor<f32>;

    /// Concatenates the clip's frames into one tensor of shape (T, C, H, W),
    /// created on the CPU device.
    ///
    /// # Errors
    ///
    /// * `TorshError::InvalidArgument` when the clip has no frames.
    /// * `TorshError::InvalidShape` when the first frame is not 3-D, or when
    ///   any frame's shape differs from the first frame's shape.
    /// * Any error produced while reading a frame's data or building the
    ///   output tensor.
    fn transform(&self, input: VideoFrames) -> Result<Self::Output> {
        let frames = input.frames();
        if frames.is_empty() {
            return Err(TorshError::InvalidArgument(
                "VideoFrames cannot be empty".to_string(),
            ));
        }

        // The first frame defines the (C, H, W) layout every frame must share.
        let frame_shape = frames[0].shape();
        let dims = frame_shape.dims();
        if dims.len() != 3 {
            return Err(TorshError::InvalidShape(
                "Expected 3D frame tensors (C, H, W)".to_string(),
            ));
        }
        let (channels, height, width) = (dims[0], dims[1], dims[2]);
        let num_frames = frames.len();

        let mut video_data = Vec::with_capacity(num_frames * channels * height * width);
        for (idx, frame) in frames.iter().enumerate() {
            // Without this check, differently shaped frames would be
            // concatenated into a buffer that no longer matches the declared
            // (T, C, H, W) shape.
            let current_shape = frame.shape();
            let current = current_shape.dims();
            if current.len() != 3
                || current[0] != channels
                || current[1] != height
                || current[2] != width
            {
                return Err(TorshError::InvalidShape(format!(
                    "Frame {idx} does not match the shape of the first frame \
                     ({channels}, {height}, {width})"
                )));
            }
            video_data.extend(frame.to_vec()?);
        }

        Tensor::from_data(
            video_data,
            vec![num_frames, channels, height, width],
            torsh_core::device::DeviceType::Cpu,
        )
    }
}
/// Transform that splits a 4-D (T, C, H, W) tensor back into a `VideoFrames`
/// clip.
pub struct TensorToVideo {
    /// Frame rate assigned to every clip this transform produces.
    fps: f32,
}

impl TensorToVideo {
    /// Creates a transform whose output clips carry the given frame rate.
    pub fn new(fps: f32) -> Self {
        TensorToVideo { fps }
    }
}
impl Transform<Tensor<f32>> for TensorToVideo {
    type Output = VideoFrames;

    /// Splits a (T, C, H, W) tensor along its first dimension into T frame
    /// tensors of shape (C, H, W), each created on the CPU device, and wraps
    /// them in a `VideoFrames` carrying this transform's frame rate.
    ///
    /// Returns `TorshError::InvalidShape` when the input is not 4-D, and
    /// propagates any error from reading the input data or building a frame.
    fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
        let shape = input.shape();
        let dims = shape.dims();
        if dims.len() != 4 {
            return Err(TorshError::InvalidShape(
                "Expected 4D tensor (T, C, H, W)".to_string(),
            ));
        }

        let num_frames = dims[0];
        let (channels, height, width) = (dims[1], dims[2], dims[3]);
        let frame_size = channels * height * width;
        let data = input.to_vec()?;

        // Each frame occupies one contiguous run of `frame_size` values.
        let frames = (0..num_frames)
            .map(|t| {
                let offset = t * frame_size;
                Tensor::from_data(
                    data[offset..offset + frame_size].to_vec(),
                    vec![channels, height, width],
                    torsh_core::device::DeviceType::Cpu,
                )
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(VideoFrames::new(frames, self.fps))
    }
}