Skip to main content

proteus_lib/container/
info.rs

1//! Container metadata helpers and duration probing.
2
3use std::{
4    collections::HashMap,
5    fs::File,
6    io::{Read, Seek, SeekFrom},
7    path::Path,
8};
9
10use log::debug;
11
12use symphonia::core::sample::SampleFormat;
13use symphonia::core::{
14    audio::{AudioBufferRef, Channels, Layout},
15    codecs::{CodecParameters, DecoderOptions, CODEC_TYPE_NULL},
16    errors::Error,
17    formats::{FormatOptions, Track},
18    io::{MediaSource, MediaSourceStream, ReadOnlySource},
19    meta::MetadataOptions,
20    probe::{Hint, ProbeResult},
21    units::TimeBase,
22};
23
24/// Convert Symphonia codec parameters to seconds using time base and frames.
25pub fn get_time_from_frames(codec_params: &CodecParameters) -> f64 {
26    let tb = match codec_params.time_base {
27        Some(tb) => tb,
28        None => return 0.0,
29    };
30    let dur = match codec_params.n_frames {
31        Some(frames) => codec_params.start_ts + frames,
32        None => return 0.0,
33    };
34    let time = tb.calc_time(dur);
35
36    time.seconds as f64 + time.frac
37}
38
39/// Probe a media file (or stdin `-`) and return the Symphonia probe result.
40pub fn get_probe_result_from_string(file_path: &str) -> Result<ProbeResult, Error> {
41    // If the path string is '-' then read from standard input.
42    if file_path == "-" {
43        let source = Box::new(ReadOnlySource::new(std::io::stdin())) as Box<dyn MediaSource>;
44        return probe_with_hint(source, None);
45    }
46
47    let path = Path::new(file_path);
48    let ext = path
49        .extension()
50        .and_then(|ext| ext.to_str())
51        .map(|ext| ext.to_string());
52    let mut hints: Vec<Option<String>> = Vec::new();
53
54    if let Some(ext) = ext.clone() {
55        let ext_lc = ext.to_lowercase();
56        if ext_lc == "prot" {
57            hints.push(Some("mka".to_string()));
58        }
59        if ext_lc == "aiff" || ext_lc == "aif" || ext_lc == "aifc" || ext_lc == "aaif" {
60            hints.push(Some("aiff".to_string()));
61            hints.push(Some("aif".to_string()));
62            hints.push(Some("aifc".to_string()));
63        } else {
64            hints.push(Some(ext_lc));
65        }
66    }
67
68    // Always try without a hint as a fallback.
69    hints.push(None);
70
71    for hint in hints {
72        let source =
73            Box::new(File::open(path).expect("failed to open media file")) as Box<dyn MediaSource>;
74        if let Ok(probed) = probe_with_hint(source, hint.as_deref()) {
75            return Ok(probed);
76        }
77    }
78
79    Err(Error::IoError(std::io::Error::new(
80        std::io::ErrorKind::Other,
81        "Failed to probe media file",
82    )))
83}
84
85fn probe_with_hint(
86    source: Box<dyn MediaSource>,
87    extension_hint: Option<&str>,
88) -> Result<ProbeResult, Error> {
89    let mut hint = Hint::new();
90    if let Some(extension_str) = extension_hint {
91        hint.with_extension(extension_str);
92    }
93
94    let mss = MediaSourceStream::new(source, Default::default());
95    let format_opts = FormatOptions {
96        ..Default::default()
97    };
98    let metadata_opts: MetadataOptions = Default::default();
99
100    symphonia::default::get_probe().format(&hint, mss, &format_opts, &metadata_opts)
101}
102
103/// Best-effort duration mapping per track using metadata or frame counts.
104///
105/// For container files, this may be approximate if metadata is inaccurate.
106pub fn get_durations(file_path: &str) -> HashMap<u32, f64> {
107    let mut probed = match get_probe_result_from_string(file_path) {
108        Ok(probed) => probed,
109        Err(_) => return fallback_durations(file_path),
110    };
111
112    let mut durations: Vec<f64> = Vec::new();
113
114    if let Some(metadata_rev) = probed.format.metadata().current() {
115        metadata_rev.tags().iter().for_each(|tag| {
116            if tag.key == "DURATION" {
117                // Convert duration of type 01:12:37.227000000 to 4337.227
118                let duration = tag.value.to_string().clone();
119                let duration_parts = duration.split(':').collect::<Vec<&str>>();
120                if duration_parts.len() >= 3 {
121                    let hours = duration_parts[0].parse::<f64>().unwrap_or(0.0);
122                    let minutes = duration_parts[1].parse::<f64>().unwrap_or(0.0);
123                    let seconds = duration_parts[2].parse::<f64>().unwrap_or(0.0);
124                    let duration_in_seconds = (hours * 3600.0) + (minutes * 60.0) + seconds;
125                    durations.push(duration_in_seconds);
126                }
127            }
128        });
129    }
130
131    // Convert durations to HashMap with key as index and value as duration
132    let mut duration_map: HashMap<u32, f64> = HashMap::new();
133
134    if probed.format.tracks().is_empty() {
135        return fallback_durations(file_path);
136    }
137
138    for (index, track) in probed.format.tracks().iter().enumerate() {
139        if let Some(real_duration) = durations.get(index) {
140            duration_map.insert(track.id, *real_duration);
141            continue;
142        }
143
144        let codec_params = &track.codec_params;
145        let duration = get_time_from_frames(codec_params);
146        duration_map.insert(track.id, duration);
147    }
148
149    duration_map
150}
151
152fn get_durations_best_effort(file_path: &str) -> HashMap<u32, f64> {
153    let metadata_durations = std::panic::catch_unwind(|| get_durations(file_path)).ok();
154    if let Some(durations) = metadata_durations {
155        let all_zero = durations.values().all(|value| *value <= 0.0);
156        if !durations.is_empty() && !all_zero {
157            return durations;
158        }
159    }
160
161    get_durations_by_scan(file_path)
162}
163
164/// Scan all packets to compute per-track durations (accurate but slower).
165pub fn get_durations_by_scan(file_path: &str) -> HashMap<u32, f64> {
166    let mut probed = match get_probe_result_from_string(file_path) {
167        Ok(probed) => probed,
168        Err(_) => return fallback_durations(file_path),
169    };
170    if probed.format.tracks().is_empty() {
171        return fallback_durations(file_path);
172    }
173    let mut max_ts: HashMap<u32, u64> = HashMap::new();
174    let mut time_bases: HashMap<u32, Option<TimeBase>> = HashMap::new();
175    let mut sample_rates: HashMap<u32, Option<u32>> = HashMap::new();
176
177    for track in probed.format.tracks().iter() {
178        max_ts.insert(track.id, 0);
179        time_bases.insert(track.id, track.codec_params.time_base);
180        sample_rates.insert(track.id, track.codec_params.sample_rate);
181    }
182
183    loop {
184        match probed.format.next_packet() {
185            Ok(packet) => {
186                let entry = max_ts.entry(packet.track_id()).or_insert(0);
187                if packet.ts() > *entry {
188                    *entry = packet.ts();
189                }
190            }
191            Err(_) => break,
192        }
193    }
194
195    let mut duration_map: HashMap<u32, f64> = HashMap::new();
196    for (track_id, ts) in max_ts {
197        let seconds = if let Some(time_base) = time_bases.get(&track_id).copied().flatten() {
198            let time = time_base.calc_time(ts);
199            time.seconds as f64 + time.frac
200        } else if let Some(sample_rate) = sample_rates.get(&track_id).copied().flatten() {
201            ts as f64 / sample_rate as f64
202        } else {
203            0.0
204        };
205        duration_map.insert(track_id, seconds);
206    }
207
208    duration_map
209}
210
211// impl PartialEq for Layout {
212//     fn eq(&self, other: &Self) -> bool {
213//         // Implement equality comparison logic for Layout
214//         match (self, other) {
215//             (Layout::Mono, Layout::Mono) => true,
216//             (Layout::Stereo, Layout::Stereo) => true,
217//             (Layout::TwoPointOne, Layout::TwoPointOne) => true,
218//             (Layout::FivePointOne, Layout::FivePointOne) => true,
219//             _ => false,
220//         }
221//     }
222// }
223
/// Aggregate codec information for a track.
///
/// A value of `0` in any field means the property is unknown; the helpers in
/// this module treat zero as "not set" when merging several tracks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct TrackInfo {
    /// Sample rate reported by the codec parameters, or `0` if unknown.
    pub sample_rate: u32,
    /// Number of audio channels, or `0` if unknown.
    pub channel_count: u32,
    /// Bit depth of a single sample, or `0` if unknown.
    pub bits_per_sample: u32,
}
231
232fn get_track_info(track: &Track) -> TrackInfo {
233    let codec_params = &track.codec_params;
234    let sample_rate = codec_params.sample_rate.unwrap_or(0);
235    let bits_per_sample = codec_params
236        .bits_per_sample
237        .unwrap_or_else(|| bits_from_sample_format(codec_params.sample_format));
238
239    let mut channel_count = match codec_params.channel_layout {
240        Some(Layout::Mono) => 1,
241        Some(Layout::Stereo) => 2,
242        Some(Layout::TwoPointOne) => 3,
243        Some(Layout::FivePointOne) => 6,
244        _ => 0,
245    };
246
247    if channel_count == 0 {
248        channel_count = codec_params
249            .channels
250            .unwrap_or(Channels::FRONT_CENTRE)
251            .iter()
252            .count() as u32;
253    }
254
255    TrackInfo {
256        sample_rate,
257        channel_count,
258        bits_per_sample,
259    }
260}
261
262fn bits_from_sample_format(sample_format: Option<SampleFormat>) -> u32 {
263    match sample_format {
264        Some(SampleFormat::U8 | SampleFormat::S8) => 8,
265        Some(SampleFormat::U16 | SampleFormat::S16) => 16,
266        Some(SampleFormat::U24 | SampleFormat::S24) => 24,
267        Some(SampleFormat::U32 | SampleFormat::S32 | SampleFormat::F32) => 32,
268        Some(SampleFormat::F64) => 64,
269        None => 0,
270    }
271}
272
273fn bits_from_decode(file_path: &str) -> u32 {
274    let mut probed = match get_probe_result_from_string(file_path) {
275        Ok(probed) => probed,
276        Err(_) => return 0,
277    };
278
279    let (track_id, codec_params) = match probed
280        .format
281        .tracks()
282        .iter()
283        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
284    {
285        Some(track) => (track.id, track.codec_params.clone()),
286        None => return 0,
287    };
288
289    let dec_opts: DecoderOptions = Default::default();
290    let mut decoder = match symphonia::default::get_codecs().make(&codec_params, &dec_opts) {
291        Ok(decoder) => decoder,
292        Err(_) => return 0,
293    };
294
295    loop {
296        let packet = match probed.format.next_packet() {
297            Ok(packet) => packet,
298            Err(_) => return 0,
299        };
300
301        if packet.track_id() != track_id {
302            continue;
303        }
304
305        match decoder.decode(&packet) {
306            Ok(decoded) => {
307                return match decoded {
308                    AudioBufferRef::U8(_) => 8,
309                    AudioBufferRef::S8(_) => 8,
310                    AudioBufferRef::U16(_) => 16,
311                    AudioBufferRef::S16(_) => 16,
312                    AudioBufferRef::U24(_) => 24,
313                    AudioBufferRef::S24(_) => 24,
314                    AudioBufferRef::U32(_) => 32,
315                    AudioBufferRef::S32(_) => 32,
316                    AudioBufferRef::F32(_) => 32,
317                    AudioBufferRef::F64(_) => 64,
318                };
319            }
320            Err(Error::DecodeError(_)) => continue,
321            Err(_) => return 0,
322        }
323    }
324}
325
/// Fields read from the `COMM` chunk of an AIFF/AIFC file.
#[derive(Debug, Clone, Copy)]
struct AiffInfo {
    /// Channel count, stored as a big-endian `u16` at the start of the chunk.
    channels: u16,
    /// Sample rate, decoded from the chunk's 80-bit extended float.
    sample_rate: f64,
    /// Bit depth of a single sample.
    bits_per_sample: u16,
    /// Number of sample frames in the file.
    sample_frames: u32,
}
333
334fn fallback_track_info(file_path: &str) -> TrackInfo {
335    if let Some(info) = parse_aiff_info(file_path) {
336        return TrackInfo {
337            sample_rate: info.sample_rate.round() as u32,
338            channel_count: info.channels as u32,
339            bits_per_sample: info.bits_per_sample as u32,
340        };
341    }
342
343    TrackInfo {
344        sample_rate: 0,
345        channel_count: 0,
346        bits_per_sample: 0,
347    }
348}
349
350fn fallback_durations(file_path: &str) -> HashMap<u32, f64> {
351    if let Some(info) = parse_aiff_info(file_path) {
352        let duration = if info.sample_rate > 0.0 {
353            info.sample_frames as f64 / info.sample_rate
354        } else {
355            0.0
356        };
357        let mut map = HashMap::new();
358        map.insert(0, duration);
359        return map;
360    }
361
362    HashMap::new()
363}
364
365fn parse_aiff_info(file_path: &str) -> Option<AiffInfo> {
366    let path = Path::new(file_path);
367    let ext = path
368        .extension()
369        .and_then(|ext| ext.to_str())?
370        .to_lowercase();
371    if ext != "aiff" && ext != "aif" && ext != "aifc" && ext != "aaif" {
372        return None;
373    }
374
375    let mut file = File::open(path).ok()?;
376    let mut header = [0u8; 12];
377    file.read_exact(&mut header).ok()?;
378    if &header[0..4] != b"FORM" {
379        return None;
380    }
381    let form_type = &header[8..12];
382    if form_type != b"AIFF" && form_type != b"AIFC" {
383        return None;
384    }
385
386    loop {
387        let mut chunk_header = [0u8; 8];
388        if file.read_exact(&mut chunk_header).is_err() {
389            break;
390        }
391        let chunk_id = &chunk_header[0..4];
392        let chunk_size = u32::from_be_bytes([
393            chunk_header[4],
394            chunk_header[5],
395            chunk_header[6],
396            chunk_header[7],
397        ]) as u64;
398
399        if chunk_id == b"COMM" {
400            if chunk_size < 18 {
401                return None;
402            }
403            let mut comm = vec![0u8; chunk_size as usize];
404            file.read_exact(&mut comm).ok()?;
405            let channels = u16::from_be_bytes([comm[0], comm[1]]);
406            let sample_frames = u32::from_be_bytes([comm[2], comm[3], comm[4], comm[5]]);
407            let bits_per_sample = u16::from_be_bytes([comm[6], comm[7]]);
408            let mut rate_bytes = [0u8; 10];
409            rate_bytes.copy_from_slice(&comm[8..18]);
410            let sample_rate = extended_80_to_f64(rate_bytes);
411
412            return Some(AiffInfo {
413                channels,
414                sample_rate,
415                bits_per_sample,
416                sample_frames,
417            });
418        }
419
420        let skip = chunk_size + (chunk_size % 2);
421        if file.seek(SeekFrom::Current(skip as i64)).is_err() {
422            break;
423        }
424    }
425
426    None
427}
428
/// Decode a big-endian 80-bit extended-precision float into an `f64`.
///
/// Layout read from `bytes`: the top bit of byte 0 is the sign, the remaining
/// 15 bits of bytes 0-1 are the exponent (biased by 16383), and bytes 2-9 are
/// the 64-bit mantissa, whose top bit is the integer part.
///
/// Special cases: a zero exponent with a zero mantissa yields `0.0`, and an
/// all-ones exponent (`0x7FFF`) always yields NaN.
fn extended_80_to_f64(bytes: [u8; 10]) -> f64 {
    let is_negative = bytes[0] & 0x80 != 0;
    let biased_exponent = u16::from_be_bytes([bytes[0] & 0x7F, bytes[1]]);
    let mantissa = bytes[2..]
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));

    match (biased_exponent, mantissa) {
        (0, 0) => return 0.0,
        (0x7FFF, _) => return f64::NAN,
        _ => {}
    }

    // The mantissa carries an explicit integer bit, so dividing by 2^63 turns
    // it into a value in [0, 2).
    let scale = 2f64.powi(i32::from(biased_exponent) - 16383);
    let fraction = mantissa as f64 / (1u64 << 63) as f64;
    let magnitude = scale * fraction;

    if is_negative {
        -magnitude
    } else {
        magnitude
    }
}
453
454fn reduce_track_infos(track_infos: Vec<TrackInfo>) -> TrackInfo {
455    if track_infos.is_empty() {
456        return TrackInfo {
457            sample_rate: 0,
458            channel_count: 0,
459            bits_per_sample: 0,
460        };
461    }
462
463    let info = track_infos
464        .into_iter()
465        .fold(None, |acc: Option<TrackInfo>, track_info| match acc {
466            Some(acc) => {
467                if acc.sample_rate != 0
468                    && track_info.sample_rate != 0
469                    && acc.sample_rate != track_info.sample_rate
470                {
471                    panic!("Sample rates do not match");
472                }
473
474                if acc.channel_count != 0
475                    && track_info.channel_count != 0
476                    && acc.channel_count != track_info.channel_count
477                {
478                    panic!(
479                        "Channel layouts do not match {} != {}",
480                        acc.channel_count, track_info.channel_count
481                    );
482                }
483
484                if acc.bits_per_sample != 0
485                    && track_info.bits_per_sample != 0
486                    && acc.bits_per_sample != track_info.bits_per_sample
487                {
488                    panic!("Bits per sample do not match");
489                }
490
491                Some(TrackInfo {
492                    sample_rate: if acc.sample_rate == 0 {
493                        track_info.sample_rate
494                    } else {
495                        acc.sample_rate
496                    },
497                    channel_count: if acc.channel_count == 0 {
498                        track_info.channel_count
499                    } else {
500                        acc.channel_count
501                    },
502                    bits_per_sample: if acc.bits_per_sample == 0 {
503                        track_info.bits_per_sample
504                    } else {
505                        acc.bits_per_sample
506                    },
507                })
508            }
509            None => Some(track_info),
510        });
511
512    info.unwrap()
513}
514
515fn gather_track_info(file_path: &str) -> TrackInfo {
516    let probed = match get_probe_result_from_string(file_path) {
517        Ok(probed) => probed,
518        Err(_) => return fallback_track_info(file_path),
519    };
520
521    let tracks = probed.format.tracks();
522    if tracks.is_empty() {
523        return fallback_track_info(file_path);
524    }
525    let mut track_infos: Vec<TrackInfo> = Vec::new();
526    for track in tracks {
527        let track_info = get_track_info(track);
528        track_infos.push(track_info);
529    }
530
531    let mut info = reduce_track_infos(track_infos);
532    if info.bits_per_sample == 0 {
533        let decoded_bits = bits_from_decode(file_path);
534        if decoded_bits > 0 {
535            info.bits_per_sample = decoded_bits;
536        }
537    }
538    if info.sample_rate == 0 && info.channel_count == 0 && info.bits_per_sample == 0 {
539        return fallback_track_info(file_path);
540    }
541    info
542}
543
544fn gather_track_info_from_file_paths(file_paths: Vec<String>) -> TrackInfo {
545    let mut track_infos: Vec<TrackInfo> = Vec::new();
546
547    for file_path in file_paths {
548        debug!("File path: {:?}", file_path);
549        let track_info = gather_track_info(&file_path);
550        track_infos.push(track_info);
551    }
552
553    reduce_track_infos(track_infos)
554}
555
/// Combined container info (track list, durations, sample format).
#[derive(Debug, Clone)]
pub struct Info {
    /// Paths of the files this info was built from.
    pub file_paths: Vec<String>,
    /// Duration in seconds per entry. The key is a track id when built from a
    /// single container, or the file's position when built from a file list.
    pub duration_map: HashMap<u32, f64>,
    /// Number of audio channels (`0` if unknown).
    pub channels: u32,
    /// Sample rate (`0` if unknown).
    pub sample_rate: u32,
    /// Bit depth of a single sample (`0` if unknown).
    pub bits_per_sample: u32,
}
565
566impl Info {
567    /// Build info for a single container file by scanning all packets.
568    pub fn new(file_path: String) -> Self {
569        let track_info = gather_track_info(&file_path);
570
571        Self {
572            duration_map: get_durations_by_scan(&file_path),
573            file_paths: vec![file_path],
574            channels: track_info.channel_count,
575            sample_rate: track_info.sample_rate,
576            bits_per_sample: track_info.bits_per_sample,
577        }
578    }
579
580    /// Build info for a list of standalone files.
581    ///
582    /// Uses metadata when available and falls back to scanning.
583    pub fn new_from_file_paths(file_paths: Vec<String>) -> Self {
584        let mut duration_map: HashMap<u32, f64> = HashMap::new();
585
586        for (index, file_path) in file_paths.iter().enumerate() {
587            let durations = get_durations_best_effort(file_path);
588            let longest = durations
589                .iter()
590                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
591                .map(|entry| *entry.1)
592                .unwrap_or(0.0);
593            duration_map.insert(index as u32, longest);
594        }
595
596        let track_info = gather_track_info_from_file_paths(file_paths.clone());
597
598        Self {
599            duration_map,
600            file_paths,
601            channels: track_info.channel_count,
602            sample_rate: track_info.sample_rate,
603            bits_per_sample: track_info.bits_per_sample,
604        }
605    }
606
607    /// Get the duration for the given track index, if known.
608    pub fn get_duration(&self, index: u32) -> Option<f64> {
609        match self.duration_map.get(&index) {
610            Some(duration) => Some(*duration),
611            None => None,
612        }
613    }
614}