Skip to main content

hanzo_engine/
video_input.rs

1use image::DynamicImage;
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4
5/// Decoded video input: a sequence of frames with metadata for timestamp generation.
6///
7/// Create from pre-decoded frames with [`VideoInput::from_frames`], or use the
8/// server-core `parse_video_url` helper to decode from a video file (requires FFmpeg
9/// for non-GIF formats).
10#[derive(Clone, PartialEq)]
11pub struct VideoInput {
12    /// Decoded video frames (RGB images).
13    pub frames: Vec<DynamicImage>,
14    /// Frames per second of the *original* video. Used to compute per-frame
15    /// timestamps for the prompt (e.g. `"00:05"`). Defaults to 24.0.
16    pub fps: f64,
17    /// Total number of frames in the original video before sampling.
18    pub total_num_frames: usize,
19    /// Indices of the frames that were sampled from the original video.
20    /// Length must equal `frames.len()`.
21    pub sampled_indices: Vec<usize>,
22}
23
24impl std::fmt::Debug for VideoInput {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("VideoInput")
27            .field("num_frames", &self.frames.len())
28            .field("fps", &self.fps)
29            .field("total_num_frames", &self.total_num_frames)
30            .finish()
31    }
32}
33
34impl VideoInput {
35    /// Create a `VideoInput` from pre-decoded frames.
36    ///
37    /// `fps` is the original video frame rate (used for timestamp generation).
38    /// If the frames were not sampled (i.e. all frames are provided), pass `None`
39    /// for `sampled_indices` and they will default to `0..frames.len()`.
40    pub fn from_frames(
41        frames: Vec<DynamicImage>,
42        fps: f64,
43        sampled_indices: Option<Vec<usize>>,
44    ) -> Self {
45        let n = frames.len();
46        let sampled_indices = sampled_indices.unwrap_or_else(|| (0..n).collect());
47        Self {
48            frames,
49            fps,
50            total_num_frames: *sampled_indices.last().unwrap_or(&0) + 1,
51            sampled_indices,
52        }
53    }
54
55    /// Compute per-frame timestamps in seconds.
56    #[allow(clippy::cast_precision_loss)]
57    pub fn timestamps_secs(&self) -> Vec<f64> {
58        self.sampled_indices
59            .iter()
60            .map(|&idx| idx as f64 / self.fps)
61            .collect()
62    }
63
64    /// Format timestamps as `"mm:ss"` strings.
65    #[allow(clippy::cast_possible_truncation)]
66    pub fn timestamp_strings(&self) -> Vec<String> {
67        self.timestamps_secs()
68            .iter()
69            .map(|&secs| {
70                let minutes = (secs / 60.0) as u32;
71                let seconds = (secs % 60.0) as u32;
72                format!("{minutes:02}:{seconds:02}")
73            })
74            .collect()
75    }
76
77    /// Compute a content hash for each frame (for prefix caching).
78    pub fn frame_hashes(&self) -> Vec<u64> {
79        self.frames
80            .iter()
81            .map(|img| {
82                let mut hasher = DefaultHasher::new();
83                img.as_bytes().hash(&mut hasher);
84                hasher.finish()
85            })
86            .collect()
87    }
88
89    /// Compute a single hash representing the entire video (for prefix caching).
90    pub fn video_hash(&self) -> u64 {
91        let mut hasher = DefaultHasher::new();
92        for frame in &self.frames {
93            frame.as_bytes().hash(&mut hasher);
94        }
95        self.fps.to_bits().hash(&mut hasher);
96        hasher.finish()
97    }
98}
99
/// Pick `num_frames` evenly spaced frame indices out of a video that has
/// `total_frames` frames.
///
/// The request is capped at `total_frames`, and an empty vector is returned
/// when either argument is zero. Index `i` of the result is
/// `floor(i * total_frames / n)`, where `n` is the capped count.
///
/// Matches the HF reference: `torch.arange(0, total, total / num_frames).int()`
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
pub fn sample_frame_indices(total_frames: usize, num_frames: usize) -> Vec<usize> {
    // Never hand out more indices than there are frames.
    let count = num_frames.min(total_frames);
    if count == 0 {
        return Vec::new();
    }

    let total = total_frames as f64;
    let denom = count as f64;
    let mut picked = Vec::with_capacity(count);
    for slot in 0..count {
        // The float-to-integer cast truncates, which is the floor for these
        // non-negative values.
        picked.push(((slot as f64) * total / denom) as usize);
    }
    picked
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// 96 frames reduced to 32 samples gives a stride of exactly 3.
    #[test]
    fn test_sample_frame_indices() {
        let picked = sample_frame_indices(96, 32);
        assert_eq!(picked.len(), 32);
        assert_eq!((picked[0], picked[1], picked[31]), (0, 3, 93));
    }

    /// Asking for every frame returns every index in order.
    #[test]
    fn test_sample_frame_indices_equal() {
        assert_eq!(sample_frame_indices(5, 5), [0, 1, 2, 3, 4]);
    }

    /// A request larger than the video is capped at the video length.
    #[test]
    fn test_sample_frame_indices_more_than_total() {
        let picked = sample_frame_indices(3, 10);
        assert_eq!(picked.len(), 3);
    }

    /// At 24 fps, every 720th frame lands on a 30-second boundary.
    #[test]
    fn test_timestamp_strings() {
        let clip = VideoInput {
            frames: Vec::new(),
            fps: 24.0,
            total_num_frames: 2880,
            sampled_indices: vec![0, 720, 1440, 2160],
        };
        let expected = ["00:00", "00:30", "01:00", "01:30"];
        assert_eq!(clip.timestamp_strings(), expected);
    }
}