Skip to main content

mistralrs_core/
video_input.rs

1use image::DynamicImage;
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4
5/// Decoded video input: a sequence of frames with metadata for timestamp generation.
6///
7/// Create from pre-decoded frames with [`VideoInput::from_frames`], or use the
8/// server-core `parse_video_url` helper to decode from a video file (requires FFmpeg
9/// for non-GIF formats).
10#[derive(Clone, PartialEq)]
11pub struct VideoInput {
12    /// Decoded video frames (RGB images).
13    pub frames: Vec<DynamicImage>,
14    /// Frames per second of the *original* video. Used to compute per-frame
15    /// timestamps for the prompt (e.g. `"00:05"`). Defaults to 24.0.
16    pub fps: f64,
17    /// Total number of frames in the original video before sampling.
18    pub total_num_frames: usize,
19    /// Indices of the frames that were sampled from the original video.
20    /// Length must equal `frames.len()`.
21    pub sampled_indices: Vec<usize>,
22}
23
24impl std::fmt::Debug for VideoInput {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("VideoInput")
27            .field("num_frames", &self.frames.len())
28            .field("fps", &self.fps)
29            .field("total_num_frames", &self.total_num_frames)
30            .finish()
31    }
32}
33
34impl VideoInput {
35    /// Create a `VideoInput` from pre-decoded frames.
36    ///
37    /// `fps` is the original video frame rate (used for timestamp generation).
38    /// If the frames were not sampled (i.e. all frames are provided), pass `None`
39    /// for `sampled_indices` and they will default to `0..frames.len()`.
40    pub fn from_frames(
41        frames: Vec<DynamicImage>,
42        fps: f64,
43        sampled_indices: Option<Vec<usize>>,
44    ) -> Self {
45        let n = frames.len();
46        let sampled_indices = sampled_indices.unwrap_or_else(|| (0..n).collect());
47        Self {
48            frames,
49            fps,
50            total_num_frames: *sampled_indices.last().unwrap_or(&0) + 1,
51            sampled_indices,
52        }
53    }
54
55    /// Compute per-frame timestamps in seconds.
56    pub fn timestamps_secs(&self) -> Vec<f64> {
57        self.sampled_indices
58            .iter()
59            .map(|&idx| idx as f64 / self.fps)
60            .collect()
61    }
62
63    /// Format timestamps as `"mm:ss"` strings.
64    pub fn timestamp_strings(&self) -> Vec<String> {
65        self.timestamps_secs()
66            .iter()
67            .map(|&secs| {
68                let minutes = (secs / 60.0) as u32;
69                let seconds = (secs % 60.0) as u32;
70                format!("{minutes:02}:{seconds:02}")
71            })
72            .collect()
73    }
74
75    /// Compute a content hash for each frame (for prefix caching).
76    pub fn frame_hashes(&self) -> Vec<u64> {
77        self.frames
78            .iter()
79            .map(|img| {
80                let mut hasher = DefaultHasher::new();
81                img.as_bytes().hash(&mut hasher);
82                hasher.finish()
83            })
84            .collect()
85    }
86
87    /// Compute a single hash representing the entire video (for prefix caching).
88    pub fn video_hash(&self) -> u64 {
89        let mut hasher = DefaultHasher::new();
90        for frame in &self.frames {
91            frame.as_bytes().hash(&mut hasher);
92        }
93        self.fps.to_bits().hash(&mut hasher);
94        hasher.finish()
95    }
96}
97
/// Pick `num_frames` evenly spaced frame indices out of a video with `total_frames` frames.
///
/// With `n = min(num_frames, total_frames)`, the `i`-th index is
/// `floor(i * total_frames / n)`. An empty vector is returned when either argument is
/// zero.
///
/// Follows the HF reference `torch.arange(0, total, total / num_frames).int()` when
/// `num_frames <= total_frames`. Larger requests are clamped here, so every frame is
/// returned exactly once instead of being repeated.
pub fn sample_frame_indices(total_frames: usize, num_frames: usize) -> Vec<usize> {
    let count = match (total_frames, num_frames) {
        (0, _) | (_, 0) => return Vec::new(),
        (total, wanted) => wanted.min(total),
    };

    let total = total_frames as f64;
    let denom = count as f64;
    let mut picked = Vec::with_capacity(count);
    for step in 0..count {
        // Multiply before dividing so integer-valued spacings (e.g. 96/32) stay exact.
        picked.push((step as f64 * total / denom) as usize);
    }
    picked
}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a frameless `VideoInput` for exercising the timestamp helpers.
    fn video_with_indices(fps: f64, sampled_indices: Vec<usize>) -> VideoInput {
        VideoInput {
            frames: Vec::new(),
            fps,
            total_num_frames: sampled_indices.iter().max().map_or(0, |&m| m + 1),
            sampled_indices,
        }
    }

    #[test]
    fn test_sample_frame_indices() {
        let indices = sample_frame_indices(96, 32);
        assert_eq!(indices.len(), 32);
        assert_eq!(indices[0], 0);
        assert_eq!(indices[1], 3);
        assert_eq!(indices[31], 93);
    }

    #[test]
    fn test_sample_frame_indices_equal() {
        let indices = sample_frame_indices(5, 5);
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_sample_frame_indices_more_than_total() {
        // Requests beyond the video length are clamped: every frame exactly once.
        let indices = sample_frame_indices(3, 10);
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    fn test_sample_frame_indices_uneven_spacing_truncates() {
        // 10 / 3 = 3.33.. spacing: indices are floored, not rounded.
        assert_eq!(sample_frame_indices(10, 3), vec![0, 3, 6]);
    }

    #[test]
    fn test_sample_frame_indices_zero_inputs() {
        assert!(sample_frame_indices(0, 8).is_empty());
        assert!(sample_frame_indices(8, 0).is_empty());
        assert!(sample_frame_indices(0, 0).is_empty());
    }

    #[test]
    fn test_timestamps_secs() {
        let vi = video_with_indices(2.0, vec![0, 1, 5]);
        assert_eq!(vi.timestamps_secs(), vec![0.0, 0.5, 2.5]);
    }

    #[test]
    fn test_timestamp_strings() {
        let vi = VideoInput {
            frames: Vec::new(),
            fps: 24.0,
            total_num_frames: 2880,
            sampled_indices: vec![0, 720, 1440, 2160],
        };
        let ts = vi.timestamp_strings();
        assert_eq!(ts, vec!["00:00", "00:30", "01:00", "01:30"]);
    }

    #[test]
    fn test_timestamp_strings_truncates_and_exceeds_an_hour() {
        // 1439 / 24 = 59.958s truncates to 59; 108000 / 24 = 4500s = 75 minutes.
        let vi = video_with_indices(24.0, vec![1439, 108_000]);
        assert_eq!(vi.timestamp_strings(), vec!["00:59", "75:00"]);
    }
}