Skip to main content

proteus_lib/container/
info.rs

1//! Container metadata helpers and duration probing.
2
3use std::{collections::HashMap, fs::File, path::Path};
4
5use log::debug;
6
7use symphonia::core::{
8    audio::{Channels, Layout},
9    codecs::CodecParameters,
10    formats::{FormatOptions, Track},
11    io::{MediaSource, MediaSourceStream, ReadOnlySource},
12    meta::MetadataOptions,
13    probe::{Hint, ProbeResult},
14    units::TimeBase,
15};
16
17/// Convert Symphonia codec parameters to seconds using time base and frames.
18pub fn get_time_from_frames(codec_params: &CodecParameters) -> f64 {
19    let tb = codec_params.time_base.unwrap();
20    let dur = codec_params
21        .n_frames
22        .map(|frames| codec_params.start_ts + frames)
23        .unwrap();
24    let time = tb.calc_time(dur);
25
26    time.seconds as f64 + time.frac
27}
28
29/// Probe a media file (or stdin `-`) and return the Symphonia probe result.
30pub fn get_probe_result_from_string(file_path: &str) -> ProbeResult {
31    // Create a hint to help the format registry guess what format reader is appropriate.
32    let mut hint = Hint::new();
33
34    // If the path string is '-' then read from standard input.
35    let source = if file_path == "-" {
36        Box::new(ReadOnlySource::new(std::io::stdin())) as Box<dyn MediaSource>
37    } else {
38        // Othwerise, get a Path from the path string.
39        let path = Path::new(file_path);
40
41        // Provide the file extension as a hint.
42        if let Some(extension) = path.extension() {
43            if let Some(extension_str) = extension.to_str() {
44                hint.with_extension(extension_str);
45            }
46        }
47
48        Box::new(File::open(path).expect("failed to open media file")) as Box<dyn MediaSource>
49    };
50
51    // Create the media source stream using the boxed media source from above.
52    let mss = MediaSourceStream::new(source, Default::default());
53
54    // Use the default options for format readers other than for gapless playback.
55    let format_opts = FormatOptions {
56        // enable_gapless: !args.is_present("no-gapless"),
57        ..Default::default()
58    };
59
60    // Use the default options for metadata readers.
61    let metadata_opts: MetadataOptions = Default::default();
62
63    // Get the value of the track option, if provided.
64    // let track = match args.value_of("track") {
65    //     Some(track_str) => track_str.parse::<usize>().ok(),
66    //     _ => None,
67    // };
68
69    symphonia::default::get_probe()
70        .format(&hint, mss, &format_opts, &metadata_opts)
71        .unwrap()
72}
73
74/// Best-effort duration mapping per track using metadata or frame counts.
75///
76/// For container files, this may be approximate if metadata is inaccurate.
77pub fn get_durations(file_path: &str) -> HashMap<u32, f64> {
78    let mut probed = get_probe_result_from_string(file_path);
79
80    let mut durations: Vec<f64> = Vec::new();
81
82    if let Some(metadata_rev) = probed.format.metadata().current() {
83        metadata_rev.tags().iter().for_each(|tag| {
84            if tag.key == "DURATION" {
85                // Convert duration of type 01:12:37.227000000 to 4337.227
86                let duration = tag.value.to_string().clone();
87                let duration_parts = duration.split(":").collect::<Vec<&str>>();
88                let hours = duration_parts[0].parse::<f64>().unwrap();
89                let minutes = duration_parts[1].parse::<f64>().unwrap();
90                let seconds = duration_parts[2].parse::<f64>().unwrap();
91                // let milliseconds = duration_parts[3].parse::<f64>().unwrap();
92                let duration_in_seconds = (hours * 3600.0) + (minutes * 60.0) + seconds;
93
94                durations.push(duration_in_seconds);
95            }
96        });
97    }
98
99    // Convert durations to HashMap with key as index and value as duration
100    let mut duration_map: HashMap<u32, f64> = HashMap::new();
101
102    for (index, track) in probed.format.tracks().iter().enumerate() {
103        if let Some(real_duration) = durations.get(index) {
104            duration_map.insert(track.id, *real_duration);
105            continue;
106        }
107
108        let codec_params = &track.codec_params;
109        let duration = get_time_from_frames(codec_params);
110        duration_map.insert(track.id, duration);
111    }
112
113    duration_map
114}
115
116fn get_durations_best_effort(file_path: &str) -> HashMap<u32, f64> {
117    let metadata_durations = std::panic::catch_unwind(|| get_durations(file_path)).ok();
118    if let Some(durations) = metadata_durations {
119        let all_zero = durations.values().all(|value| *value <= 0.0);
120        if !durations.is_empty() && !all_zero {
121            return durations;
122        }
123    }
124
125    get_durations_by_scan(file_path)
126}
127
128/// Scan all packets to compute per-track durations (accurate but slower).
129pub fn get_durations_by_scan(file_path: &str) -> HashMap<u32, f64> {
130    let mut probed = get_probe_result_from_string(file_path);
131    let mut max_ts: HashMap<u32, u64> = HashMap::new();
132    let mut time_bases: HashMap<u32, Option<TimeBase>> = HashMap::new();
133    let mut sample_rates: HashMap<u32, Option<u32>> = HashMap::new();
134
135    for track in probed.format.tracks().iter() {
136        max_ts.insert(track.id, 0);
137        time_bases.insert(track.id, track.codec_params.time_base);
138        sample_rates.insert(track.id, track.codec_params.sample_rate);
139    }
140
141    loop {
142        match probed.format.next_packet() {
143            Ok(packet) => {
144                let entry = max_ts.entry(packet.track_id()).or_insert(0);
145                if packet.ts() > *entry {
146                    *entry = packet.ts();
147                }
148            }
149            Err(_) => break,
150        }
151    }
152
153    let mut duration_map: HashMap<u32, f64> = HashMap::new();
154    for (track_id, ts) in max_ts {
155        let seconds = if let Some(time_base) = time_bases.get(&track_id).copied().flatten() {
156            let time = time_base.calc_time(ts);
157            time.seconds as f64 + time.frac
158        } else if let Some(sample_rate) = sample_rates.get(&track_id).copied().flatten() {
159            ts as f64 / sample_rate as f64
160        } else {
161            0.0
162        };
163        duration_map.insert(track_id, seconds);
164    }
165
166    duration_map
167}
168
169// impl PartialEq for Layout {
170//     fn eq(&self, other: &Self) -> bool {
171//         // Implement equality comparison logic for Layout
172//         match (self, other) {
173//             (Layout::Mono, Layout::Mono) => true,
174//             (Layout::Stereo, Layout::Stereo) => true,
175//             (Layout::TwoPointOne, Layout::TwoPointOne) => true,
176//             (Layout::FivePointOne, Layout::FivePointOne) => true,
177//             _ => false,
178//         }
179//     }
180// }
181
/// Aggregate codec information for a track.
///
/// All fields are plain integers, so the type is cheap to copy and can be
/// compared for equality directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrackInfo {
    /// Sample rate as reported by the codec parameters.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channel_count: u32,
    /// Bits per sample as reported by the codec parameters.
    pub bits_per_sample: u32,
}
189
190fn get_track_info(track: &Track) -> TrackInfo {
191    let codec_params = &track.codec_params;
192    let sample_rate = codec_params.sample_rate.unwrap();
193    let bits_per_sample = codec_params.bits_per_sample.unwrap();
194
195    let mut channel_count = match codec_params.channel_layout {
196        Some(Layout::Mono) => 1,
197        Some(Layout::Stereo) => 2,
198        Some(Layout::TwoPointOne) => 3,
199        Some(Layout::FivePointOne) => 6,
200        _ => 0,
201    };
202
203    if channel_count == 0 {
204        channel_count = codec_params
205            .channels
206            .unwrap_or(Channels::FRONT_CENTRE)
207            .iter()
208            .count() as u32;
209    }
210
211    TrackInfo {
212        sample_rate,
213        channel_count,
214        bits_per_sample,
215    }
216}
217
218fn reduce_track_infos(track_infos: Vec<TrackInfo>) -> TrackInfo {
219    let info = track_infos
220        .into_iter()
221        .fold(None, |acc: Option<TrackInfo>, track_info| match acc {
222            Some(acc) => {
223                if acc.sample_rate != track_info.sample_rate {
224                    panic!("Sample rates do not match");
225                }
226
227                if acc.channel_count != track_info.channel_count {
228                    panic!(
229                        "Channel layouts do not match {} != {}",
230                        acc.channel_count, track_info.channel_count
231                    );
232                }
233
234                if acc.bits_per_sample != track_info.bits_per_sample {
235                    panic!("Bits per sample do not match");
236                }
237
238                Some(acc)
239            }
240            None => Some(track_info),
241        });
242
243    info.unwrap()
244}
245
246fn gather_track_info(file_path: &str) -> TrackInfo {
247    let probed = get_probe_result_from_string(file_path);
248
249    let tracks = probed.format.tracks();
250    let mut track_infos: Vec<TrackInfo> = Vec::new();
251    for track in tracks {
252        let track_info = get_track_info(track);
253        track_infos.push(track_info);
254    }
255
256    reduce_track_infos(track_infos)
257}
258
259fn gather_track_info_from_file_paths(file_paths: Vec<String>) -> TrackInfo {
260    let mut track_infos: Vec<TrackInfo> = Vec::new();
261
262    for file_path in file_paths {
263        debug!("File path: {:?}", file_path);
264        let track_info = gather_track_info(&file_path);
265        track_infos.push(track_info);
266    }
267
268    reduce_track_infos(track_infos)
269}
270
/// Everything known about a container or a set of standalone files:
/// the source paths, per-entry durations and the shared sample format.
#[derive(Clone, Debug)]
pub struct Info {
    /// Paths of the file(s) this info was built from.
    pub file_paths: Vec<String>,
    /// Duration in seconds per entry. The key is a track id when built with
    /// `Info::new` and a position in `file_paths` when built with
    /// `Info::new_from_file_paths`.
    pub duration_map: HashMap<u32, f64>,
    /// Channel count shared by all tracks.
    pub channels: u32,
    /// Sample rate shared by all tracks.
    pub sample_rate: u32,
    /// Bits per sample shared by all tracks.
    pub bits_per_sample: u32,
}
280
281impl Info {
282    /// Build info for a single container file by scanning all packets.
283    pub fn new(file_path: String) -> Self {
284        let track_info = gather_track_info(&file_path);
285
286        Self {
287            duration_map: get_durations_by_scan(&file_path),
288            file_paths: vec![file_path],
289            channels: track_info.channel_count,
290            sample_rate: track_info.sample_rate,
291            bits_per_sample: track_info.bits_per_sample,
292        }
293    }
294
295    /// Build info for a list of standalone files.
296    ///
297    /// Uses metadata when available and falls back to scanning.
298    pub fn new_from_file_paths(file_paths: Vec<String>) -> Self {
299        let mut duration_map: HashMap<u32, f64> = HashMap::new();
300
301        for (index, file_path) in file_paths.iter().enumerate() {
302            let durations = get_durations_best_effort(file_path);
303            let longest = durations
304                .iter()
305                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
306                .map(|entry| *entry.1)
307                .unwrap_or(0.0);
308            duration_map.insert(index as u32, longest);
309        }
310
311        let track_info = gather_track_info_from_file_paths(file_paths.clone());
312
313        Self {
314            duration_map,
315            file_paths,
316            channels: track_info.channel_count,
317            sample_rate: track_info.sample_rate,
318            bits_per_sample: track_info.bits_per_sample,
319        }
320    }
321
322    /// Get the duration for the given track index, if known.
323    pub fn get_duration(&self, index: u32) -> Option<f64> {
324        match self.duration_map.get(&index) {
325            Some(duration) => Some(*duration),
326            None => None,
327        }
328    }
329}