#![allow(dead_code)]
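//! Video utilities: frame containers, simple in-memory readers and writers,
//! per-frame and temporal transforms, dense Lucas-Kanade optical flow, and
//! placeholder 3D-convolution / temporal-pooling building blocks.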
use crate::hardware::HardwareContext;
use crate::transforms::Transform;
use crate::{Result, VisionError};
use std::path::Path;
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;
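/// A single video frame: CHW (`[channels, height, width]`) `f32` pixel data
/// plus a presentation timestamp in seconds and the frame's index in its
/// source video.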
#[derive(Debug, Clone)]
pub struct VideoFrame {
pub data: Tensor<f32>,
pub timestamp: f64,
pub frame_index: usize,
}
impl VideoFrame {
pub fn new(data: Tensor<f32>, timestamp: f64, frame_index: usize) -> Self {
Self {
data,
timestamp,
frame_index,
}
}
pub fn width(&self) -> usize {
self.data.shape().dims()[2]
}
pub fn height(&self) -> usize {
self.data.shape().dims()[1]
}
pub fn channels(&self) -> usize {
self.data.shape().dims()[0]
}
}
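/// Sequential access to video frames. `read_frame` returns `Ok(None)` once
/// the end of the stream is reached.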
pub trait VideoReader {
fn read_frame(&mut self) -> Result<Option<VideoFrame>>;
fn seek(&mut self, frame_index: usize) -> Result<()>;
fn total_frames(&self) -> usize;
fn fps(&self) -> f32;
fn duration(&self) -> f64;
}
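/// In-memory reader that loads every frame up front (from image files or
/// pre-built tensors) and tags each with a timestamp derived from a fixed
/// frame rate. Intended for tests and short clips, not streaming decode.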
pub struct SimpleVideoReader {
frames: Vec<VideoFrame>,
current_frame: usize,
fps: f32,
}
impl SimpleVideoReader {
pub fn from_images<P: AsRef<Path>>(image_paths: &[P], fps: f32) -> Result<Self> {
let mut frames = Vec::new();
for (i, path) in image_paths.iter().enumerate() {
let image = image::open(path).map_err(VisionError::ImageError)?;
let tensor = crate::ops::to_tensor(&image)?;
let timestamp = i as f64 / fps as f64;
frames.push(VideoFrame::new(tensor, timestamp, i));
}
Ok(Self {
frames,
current_frame: 0,
fps,
})
}
pub fn from_tensors(tensors: Vec<Tensor<f32>>, fps: f32) -> Self {
let frames = tensors
.into_iter()
.enumerate()
.map(|(i, tensor)| {
let timestamp = i as f64 / fps as f64;
VideoFrame::new(tensor, timestamp, i)
})
.collect();
Self {
frames,
current_frame: 0,
fps,
}
}
}
impl VideoReader for SimpleVideoReader {
fn read_frame(&mut self) -> Result<Option<VideoFrame>> {
if self.current_frame >= self.frames.len() {
Ok(None)
} else {
let frame = self.frames[self.current_frame].clone();
self.current_frame += 1;
Ok(Some(frame))
}
}
fn seek(&mut self, frame_index: usize) -> Result<()> {
if frame_index >= self.frames.len() {
return Err(VisionError::InvalidArgument(format!(
"Frame index {} out of bounds (total frames: {})",
frame_index,
self.frames.len()
)));
}
self.current_frame = frame_index;
Ok(())
}
fn total_frames(&self) -> usize {
self.frames.len()
}
fn fps(&self) -> f32 {
self.fps
}
fn duration(&self) -> f64 {
self.frames.len() as f64 / self.fps as f64
}
}
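/// Sink for processed frames; call `finalize` after the last frame.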
pub trait VideoWriter {
fn write_frame(&mut self, frame: &VideoFrame) -> Result<()>;
fn finalize(&mut self) -> Result<()>;
}
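/// Buffers frames in memory. `finalize` is currently a stub that only
/// reports how many frames would be written; no video encoding happens.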
pub struct SimpleVideoWriter {
frames: Vec<VideoFrame>,
output_path: String,
fps: f32,
}
impl SimpleVideoWriter {
pub fn new<P: AsRef<Path>>(output_path: P, fps: f32) -> Self {
Self {
frames: Vec::new(),
output_path: output_path.as_ref().to_string_lossy().to_string(),
fps,
}
}
}
impl VideoWriter for SimpleVideoWriter {
fn write_frame(&mut self, frame: &VideoFrame) -> Result<()> {
self.frames.push(frame.clone());
Ok(())
}
fn finalize(&mut self) -> Result<()> {
println!(
"Writing {} frames to {}",
self.frames.len(),
self.output_path
);
Ok(())
}
}
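/// Lifts a per-frame `Transform` to video: `apply` transforms one frame
/// (preserving its timestamp and index), and `apply_to_video` streams every
/// frame from a reader through the transform into a writer.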
pub struct VideoTransform<T: Transform> {
transform: T,
}
impl<T: Transform> VideoTransform<T> {
pub fn new(transform: T) -> Self {
Self { transform }
}
pub fn apply(&self, frame: &VideoFrame) -> Result<VideoFrame> {
let transformed_data = self.transform.forward(&frame.data)?;
Ok(VideoFrame::new(
transformed_data,
frame.timestamp,
frame.frame_index,
))
}
pub fn apply_to_video<R: VideoReader, W: VideoWriter>(
&self,
reader: &mut R,
writer: &mut W,
) -> Result<()> {
while let Some(frame) = reader.read_frame()? {
let transformed_frame = self.apply(&frame)?;
writer.write_frame(&transformed_frame)?;
}
writer.finalize()?;
Ok(())
}
}
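/// Videos sliced into fixed-length, optionally overlapping frame sequences,
/// e.g. for training temporal models. `overlap` must be smaller than
/// `sequence_length`.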
pub struct VideoDataset {
videos: Vec<Box<dyn VideoReader>>,
current_video: usize,
sequence_length: usize,
overlap: usize,
}
impl VideoDataset {
    pub fn new(videos: Vec<Box<dyn VideoReader>>, sequence_length: usize, overlap: usize) -> Self {
        assert!(
            overlap < sequence_length,
            "overlap ({overlap}) must be smaller than sequence_length ({sequence_length})"
        );
Self {
videos,
current_video: 0,
sequence_length,
overlap,
}
}
pub fn get_sequence(
&mut self,
video_index: usize,
start_frame: usize,
) -> Result<Vec<VideoFrame>> {
if video_index >= self.videos.len() {
            return Err(VisionError::InvalidArgument(format!(
                "Video index {} out of bounds (total videos: {})",
                video_index,
                self.videos.len()
            )));
}
let reader = &mut self.videos[video_index];
reader.seek(start_frame)?;
let mut sequence = Vec::new();
for _ in 0..self.sequence_length {
if let Some(frame) = reader.read_frame()? {
sequence.push(frame);
} else {
break;
}
}
Ok(sequence)
}
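    /// Total number of length-`sequence_length` windows across all videos.
    /// With stride `sequence_length - overlap`, a video of `n` frames
    /// (`n >= sequence_length`) contributes
    /// `(n - sequence_length) / (sequence_length - overlap) + 1` windows.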
pub fn total_sequences(&self) -> usize {
self.videos
.iter()
.map(|reader| {
let total_frames = reader.total_frames();
if total_frames >= self.sequence_length {
(total_frames - self.sequence_length) / (self.sequence_length - self.overlap)
+ 1
} else {
0
}
})
.sum()
}
}
pub struct OpticalFlow {
device: DeviceType,
}
impl OpticalFlow {
pub fn new(device: DeviceType) -> Self {
Self { device }
}
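    /// Dense Lucas-Kanade optical flow from `frame1` to `frame2`.
    ///
    /// For every pixel outside a `window_size / 2` border (borders stay
    /// zero), solves the 2x2 least-squares system over a 5x5 window:
    ///
    /// ```text
    /// [ Σ Ix²    Σ IxIy ] [u]   [ -Σ IxIt ]
    /// [ Σ IxIy   Σ Iy²  ] [v] = [ -Σ IyIt ]
    /// ```
    ///
    /// where Ix/Iy are Sobel gradients of `frame1` and It = frame2 - frame1.
    /// RGB `(3, H, W)` input is converted to grayscale first; frames must be
    /// at least `window_size` pixels in each dimension. Returns the `(u, v)`
    /// flow fields as two `(H, W)` tensors.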
pub fn lucas_kanade(
&self,
frame1: &Tensor<f32>,
frame2: &Tensor<f32>,
) -> Result<(Tensor<f32>, Tensor<f32>)> {
let gray1 = if frame1.shape().dims().len() == 3 && frame1.shape().dims()[0] == 3 {
crate::ops::rgb_to_grayscale(frame1)?
} else {
frame1.clone()
};
let gray2 = if frame2.shape().dims().len() == 3 && frame2.shape().dims()[0] == 3 {
crate::ops::rgb_to_grayscale(frame2)?
} else {
frame2.clone()
};
let grad_x1 = self.compute_gradient_x(&gray1)?;
let grad_y1 = self.compute_gradient_y(&gray1)?;
let grad_t = gray2.sub(&gray1)?;
let gray1_shape = gray1.shape();
let gray1_dims = gray1_shape.dims();
let (height, width) = if gray1_dims.len() == 2 {
(gray1_dims[0], gray1_dims[1])
} else if gray1_dims.len() == 3 {
(gray1_dims[1], gray1_dims[2])
} else {
return Err(VisionError::TensorError(
torsh_core::error::TorshError::InvalidArgument(format!(
"Expected 2D or 3D tensor, got {}D",
gray1_dims.len()
)),
));
};
let flow_x = Tensor::zeros(&[height, width], self.device)?;
let flow_y = Tensor::zeros(&[height, width], self.device)?;
let window_size = 5;
let half_window = window_size / 2;
for y in half_window..(height - half_window) {
for x in half_window..(width - half_window) {
let (fx, fy) = self.compute_optical_flow_at_point(
&grad_x1,
&grad_y1,
&grad_t,
x,
y,
window_size,
)?;
flow_x.set(&[y, x], fx)?;
flow_y.set(&[y, x], fy)?;
}
}
Ok((flow_x, flow_y))
}
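    /// Horizontal image gradient using the 3x3 Sobel kernel.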
fn compute_gradient_x(&self, image: &Tensor<f32>) -> Result<Tensor<f32>> {
let kernel = Tensor::from_data(
vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0],
vec![3, 3],
self.device,
)?;
self.convolve_2d(image, &kernel)
}
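    /// Vertical image gradient using the 3x3 Sobel kernel.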
fn compute_gradient_y(&self, image: &Tensor<f32>) -> Result<Tensor<f32>> {
let kernel = Tensor::from_data(
vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0],
vec![3, 3],
self.device,
)?;
self.convolve_2d(image, &kernel)
}
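    /// Naive spatial filtering on the first channel of a 2D or 3D (CHW)
    /// image. The kernel is applied without flipping, so strictly this is
    /// cross-correlation (the usual deep-learning convention); the Sobel
    /// kernels above only change sign under a flip. Pixels within
    /// `kernel_size / 2` of the border are left at zero.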
fn convolve_2d(&self, image: &Tensor<f32>, kernel: &Tensor<f32>) -> Result<Tensor<f32>> {
let image_shape = image.shape();
let image_dims = image_shape.dims();
let is_2d = image_dims.len() == 2;
let (h, w) = if is_2d {
(image_dims[0], image_dims[1])
} else if image_dims.len() == 3 {
(image_dims[1], image_dims[2])
} else {
return Err(VisionError::TensorError(
torsh_core::error::TorshError::InvalidArgument(format!(
"Expected 2D or 3D tensor, got {}D",
image_dims.len()
)),
));
};
let (kh, kw) = (kernel.shape().dims()[0], kernel.shape().dims()[1]);
let output = Tensor::zeros(&[h, w], self.device)?;
let pad_h = kh / 2;
let pad_w = kw / 2;
for y in pad_h..(h - pad_h) {
for x in pad_w..(w - pad_w) {
let mut sum = 0.0;
for ky in 0..kh {
for kx in 0..kw {
let img_y = y + ky - pad_h;
let img_x = x + kx - pad_w;
let img_val = if is_2d {
image.get(&[img_y, img_x])?
} else {
image.get(&[0, img_y, img_x])?
};
let kernel_val = kernel.get(&[ky, kx])?;
sum += img_val * kernel_val;
}
}
output.set(&[y, x], sum)?;
}
}
Ok(output)
}
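    /// Accumulates the structure matrix `A = [[ΣIx², ΣIxIy], [ΣIxIy, ΣIy²]]`
    /// and right-hand side `b = [-ΣIxIt, -ΣIyIt]` over the window, then
    /// solves `A · (u, v) = b` by Cramer's rule. Returns zero flow when
    /// `|det A| < 1e-6`, i.e. in flat or one-dimensionally textured regions
    /// where the aperture problem leaves the flow underdetermined.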
fn compute_optical_flow_at_point(
&self,
grad_x: &Tensor<f32>,
grad_y: &Tensor<f32>,
grad_t: &Tensor<f32>,
x: usize,
y: usize,
window_size: usize,
) -> Result<(f32, f32)> {
let half_window = window_size / 2;
let mut a11 = 0.0;
let mut a12 = 0.0;
let mut a22 = 0.0;
let mut b1 = 0.0;
let mut b2 = 0.0;
for dy in 0..window_size {
for dx in 0..window_size {
let px = x + dx - half_window;
let py = y + dy - half_window;
let gx = grad_x.get(&[py, px])?;
let gy = grad_y.get(&[py, px])?;
let gt = grad_t.get(&[py, px])?;
a11 += gx * gx;
a12 += gx * gy;
a22 += gy * gy;
b1 += -gx * gt;
b2 += -gy * gt;
}
}
let det = a11 * a22 - a12 * a12;
if det.abs() < 1e-6 {
return Ok((0.0, 0.0));
}
let flow_x = (a22 * b1 - a12 * b2) / det;
let flow_y = (a11 * b2 - a12 * b1) / det;
Ok((flow_x, flow_y))
}
}
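/// Video-model building blocks (3D convolution, temporal pooling) that hold
/// a `HardwareContext` for device selection.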
pub struct VideoModel {
context: HardwareContext,
}
impl VideoModel {
pub fn new(context: HardwareContext) -> Self {
Self { context }
}
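    /// 3D convolution over `(N, C, T, H, W)` input with a
    /// `(C_out, C_in, kT, kH, kW)` kernel. Output extents follow
    /// `out = (in + 2 * pad - kernel) / stride + 1` per dimension.
    ///
    /// Currently a shape-only placeholder: the returned tensor has the
    /// correct output shape but is zero-initialized; no convolution is
    /// computed yet.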
pub fn conv3d(
&self,
input: &Tensor<f32>,
kernel: &Tensor<f32>,
stride: (usize, usize, usize),
padding: (usize, usize, usize),
) -> Result<Tensor<f32>> {
let input_shape = input.shape();
let kernel_shape = kernel.shape();
if input_shape.dims().len() != 5 || kernel_shape.dims().len() != 5 {
return Err(VisionError::InvalidShape(
"Expected 5D tensors for 3D convolution (N, C, T, H, W)".to_string(),
));
}
let (batch_size, _in_channels, time_steps, height, width) = (
input_shape.dims()[0],
input_shape.dims()[1],
input_shape.dims()[2],
input_shape.dims()[3],
input_shape.dims()[4],
);
let (out_channels, _, kernel_t, kernel_h, kernel_w) = (
kernel_shape.dims()[0],
kernel_shape.dims()[1],
kernel_shape.dims()[2],
kernel_shape.dims()[3],
kernel_shape.dims()[4],
);
let (stride_t, stride_h, stride_w) = stride;
let (pad_t, pad_h, pad_w) = padding;
let out_t = (time_steps + 2 * pad_t - kernel_t) / stride_t + 1;
let out_h = (height + 2 * pad_h - kernel_h) / stride_h + 1;
let out_w = (width + 2 * pad_w - kernel_w) / stride_w + 1;
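        // Placeholder: allocate a correctly shaped, zero-initialized output
        // without performing the actual convolution.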
let output = Tensor::zeros(
&[batch_size, out_channels, out_t, out_h, out_w],
input.device(),
)?;
Ok(output)
}
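    /// Collapses the temporal dimension (dim 2) of a `(N, C, T, H, W)`
    /// tensor by averaging, taking the element-wise maximum over time, or
    /// selecting the last frame.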
pub fn temporal_pool(
&self,
input: &Tensor<f32>,
pool_type: TemporalPoolType,
) -> Result<Tensor<f32>> {
let shape = input.shape();
if shape.dims().len() != 5 {
return Err(VisionError::InvalidShape(
"Expected 5D tensor (N, C, T, H, W)".to_string(),
));
}
match pool_type {
TemporalPoolType::Average => {
let pooled = input.mean(Some(&[2]), false)?;
Ok(pooled)
}
TemporalPoolType::Max => {
let pooled = input.max_dim(2, false)?;
Ok(pooled)
}
TemporalPoolType::Last => {
let last_frame = input.select(2, -1)?;
Ok(last_frame)
}
}
}
}
#[derive(Debug, Clone, Copy)]
pub enum TemporalPoolType {
Average,
Max,
Last,
}
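/// Composable video augmentation. Temporal transforms (e.g. frame sampling)
/// run over the whole sequence first; per-frame transforms are then applied
/// to each remaining frame.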
pub struct VideoAugmentation {
transforms: Vec<Box<dyn Transform>>,
temporal_transforms: Vec<Box<dyn TemporalTransform>>,
}
impl VideoAugmentation {
pub fn new() -> Self {
Self {
transforms: Vec::new(),
temporal_transforms: Vec::new(),
}
}
pub fn add_frame_transform<T: Transform + 'static>(&mut self, transform: T) {
self.transforms.push(Box::new(transform));
}
pub fn add_temporal_transform<T: TemporalTransform + 'static>(&mut self, transform: T) {
self.temporal_transforms.push(Box::new(transform));
}
pub fn apply_to_sequence(&self, sequence: &[VideoFrame]) -> Result<Vec<VideoFrame>> {
let mut result = sequence.to_vec();
for temporal_transform in &self.temporal_transforms {
result = temporal_transform.apply_temporal(&result)?;
}
for transform in &self.transforms {
for frame in &mut result {
frame.data = transform.forward(&frame.data)?;
}
}
Ok(result)
}
}
impl Default for VideoAugmentation {
fn default() -> Self {
Self::new()
}
}
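/// A transform over an entire frame sequence rather than a single frame.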
pub trait TemporalTransform: Send + Sync {
fn apply_temporal(&self, sequence: &[VideoFrame]) -> Result<Vec<VideoFrame>>;
}
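/// Selects `num_frames` frames from a sequence; sequences that are already
/// short enough pass through unchanged. `SamplingStrategy::Random` uses a
/// fixed seed (42), so its output is deterministic across calls.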
pub struct TemporalSampling {
num_frames: usize,
strategy: SamplingStrategy,
}
impl TemporalSampling {
pub fn new(num_frames: usize, strategy: SamplingStrategy) -> Self {
Self {
num_frames,
strategy,
}
}
}
impl TemporalTransform for TemporalSampling {
fn apply_temporal(&self, sequence: &[VideoFrame]) -> Result<Vec<VideoFrame>> {
if sequence.len() <= self.num_frames {
return Ok(sequence.to_vec());
}
let indices = match self.strategy {
SamplingStrategy::Uniform => {
let step = sequence.len() as f32 / self.num_frames as f32;
(0..self.num_frames)
.map(|i| (i as f32 * step) as usize)
.collect()
}
SamplingStrategy::Random => {
use scirs2_core::random::Random;
                let mut rng = Random::seed(42);
                let mut indices: Vec<usize> = (0..sequence.len()).collect();
for i in (1..indices.len()).rev() {
let j = rng.gen_range(0..=i);
indices.swap(i, j);
}
indices.truncate(self.num_frames);
indices.sort();
indices
}
SamplingStrategy::Center => {
let start = (sequence.len() - self.num_frames) / 2;
(start..start + self.num_frames).collect()
}
};
Ok(indices.into_iter().map(|i| sequence[i].clone()).collect())
}
}
#[derive(Debug, Clone, Copy)]
pub enum SamplingStrategy {
Uniform,
Random,
Center,
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::zeros;
#[test]
fn test_video_frame() {
let tensor = zeros(&[3, 224, 224]).unwrap();
let frame = VideoFrame::new(tensor, 0.0, 0);
assert_eq!(frame.width(), 224);
assert_eq!(frame.height(), 224);
assert_eq!(frame.channels(), 3);
}
#[test]
fn test_simple_video_reader() {
let tensors = vec![
zeros(&[3, 64, 64]).unwrap(),
zeros(&[3, 64, 64]).unwrap(),
zeros(&[3, 64, 64]).unwrap(),
];
let mut reader = SimpleVideoReader::from_tensors(tensors, 30.0);
assert_eq!(reader.total_frames(), 3);
assert_eq!(reader.fps(), 30.0);
let frame1 = reader.read_frame().unwrap().unwrap();
assert_eq!(frame1.frame_index, 0);
let frame2 = reader.read_frame().unwrap().unwrap();
assert_eq!(frame2.frame_index, 1);
reader.seek(0).unwrap();
let frame_again = reader.read_frame().unwrap().unwrap();
assert_eq!(frame_again.frame_index, 0);
}
#[test]
fn test_video_transform() {
let transform = crate::transforms::Resize::new((128, 128));
let video_transform = VideoTransform::new(transform);
let tensor = zeros(&[3, 64, 64]).unwrap();
let frame = VideoFrame::new(tensor, 0.0, 0);
let transformed_frame = video_transform.apply(&frame).unwrap();
assert_eq!(transformed_frame.data.shape().dims(), &[3, 128, 128]);
}
#[test]
fn test_optical_flow() {
let device = DeviceType::Cpu;
let optical_flow = OpticalFlow::new(device);
let frame1 = zeros(&[32, 32]).unwrap();
let frame2 = zeros(&[32, 32]).unwrap();
let (flow_x, flow_y) = optical_flow.lucas_kanade(&frame1, &frame2).unwrap();
assert_eq!(flow_x.shape().dims(), &[32, 32]);
assert_eq!(flow_y.shape().dims(), &[32, 32]);
}
#[test]
fn test_temporal_sampling() {
let frames = (0..10)
.map(|i| VideoFrame::new(zeros(&[3, 32, 32]).unwrap(), i as f64 * 0.1, i))
.collect::<Vec<_>>();
let sampling = TemporalSampling::new(5, SamplingStrategy::Uniform);
let sampled = sampling.apply_temporal(&frames).unwrap();
assert_eq!(sampled.len(), 5);
}
#[test]
fn test_video_augmentation() {
let mut augmentation = VideoAugmentation::new();
augmentation.add_frame_transform(crate::transforms::Resize::new((64, 64)));
augmentation.add_temporal_transform(TemporalSampling::new(3, SamplingStrategy::Center));
let frames = (0..5)
.map(|i| VideoFrame::new(zeros(&[3, 32, 32]).unwrap(), i as f64 * 0.1, i))
.collect::<Vec<_>>();
let augmented = augmentation.apply_to_sequence(&frames).unwrap();
assert_eq!(augmented.len(), 3);
assert_eq!(augmented[0].data.shape().dims(), &[3, 64, 64]);
}
}