Skip to main content

proteus_lib/container/
info.rs

1//! Container metadata helpers and duration probing.
2
3use std::{
4    collections::HashMap,
5    fs::File,
6    io::{Read, Seek, SeekFrom},
7    path::Path,
8};
9
10use log::debug;
11
12use symphonia::core::{
13    audio::{AudioBufferRef, Channels, Layout},
14    codecs::{CodecParameters, DecoderOptions, CODEC_TYPE_NULL},
15    errors::Error,
16    formats::{FormatOptions, Track},
17    io::{MediaSource, MediaSourceStream, ReadOnlySource},
18    meta::MetadataOptions,
19    probe::{Hint, ProbeResult},
20    units::TimeBase,
21};
22use symphonia::core::sample::SampleFormat;
23
24/// Convert Symphonia codec parameters to seconds using time base and frames.
25pub fn get_time_from_frames(codec_params: &CodecParameters) -> f64 {
26    let tb = match codec_params.time_base {
27        Some(tb) => tb,
28        None => return 0.0,
29    };
30    let dur = match codec_params.n_frames {
31        Some(frames) => codec_params.start_ts + frames,
32        None => return 0.0,
33    };
34    let time = tb.calc_time(dur);
35
36    time.seconds as f64 + time.frac
37}
38
39/// Probe a media file (or stdin `-`) and return the Symphonia probe result.
40pub fn get_probe_result_from_string(file_path: &str) -> Result<ProbeResult, Error> {
41    // If the path string is '-' then read from standard input.
42    if file_path == "-" {
43        let source = Box::new(ReadOnlySource::new(std::io::stdin())) as Box<dyn MediaSource>;
44        return probe_with_hint(source, None);
45    }
46
47    let path = Path::new(file_path);
48    let ext = path.extension().and_then(|ext| ext.to_str()).map(|ext| ext.to_string());
49    let mut hints: Vec<Option<String>> = Vec::new();
50
51    if let Some(ext) = ext.clone() {
52        let ext_lc = ext.to_lowercase();
53        if ext_lc == "prot" {
54            hints.push(Some("mka".to_string()));
55        }
56        if ext_lc == "aiff" {
57            hints.push(Some("aiff".to_string()));
58            hints.push(Some("aif".to_string()));
59        } else {
60            hints.push(Some(ext_lc));
61        }
62    }
63
64    // Always try without a hint as a fallback.
65    hints.push(None);
66
67    for hint in hints {
68        let source = Box::new(File::open(path).expect("failed to open media file")) as Box<dyn MediaSource>;
69        if let Ok(probed) = probe_with_hint(source, hint.as_deref()) {
70            return Ok(probed);
71        }
72    }
73
74    Err(Error::IoError(std::io::Error::new(
75        std::io::ErrorKind::Other,
76        "Failed to probe media file",
77    )))
78}
79
80fn probe_with_hint(
81    source: Box<dyn MediaSource>,
82    extension_hint: Option<&str>,
83) -> Result<ProbeResult, Error> {
84    let mut hint = Hint::new();
85    if let Some(extension_str) = extension_hint {
86        hint.with_extension(extension_str);
87    }
88
89    let mss = MediaSourceStream::new(source, Default::default());
90    let format_opts = FormatOptions {
91        ..Default::default()
92    };
93    let metadata_opts: MetadataOptions = Default::default();
94
95    symphonia::default::get_probe().format(&hint, mss, &format_opts, &metadata_opts)
96}
97
98/// Best-effort duration mapping per track using metadata or frame counts.
99///
100/// For container files, this may be approximate if metadata is inaccurate.
101pub fn get_durations(file_path: &str) -> HashMap<u32, f64> {
102    let mut probed = match get_probe_result_from_string(file_path) {
103        Ok(probed) => probed,
104        Err(_) => return fallback_durations(file_path),
105    };
106
107    let mut durations: Vec<f64> = Vec::new();
108
109    if let Some(metadata_rev) = probed.format.metadata().current() {
110        metadata_rev.tags().iter().for_each(|tag| {
111            if tag.key == "DURATION" {
112                // Convert duration of type 01:12:37.227000000 to 4337.227
113                let duration = tag.value.to_string().clone();
114                let duration_parts = duration.split(':').collect::<Vec<&str>>();
115                if duration_parts.len() >= 3 {
116                    let hours = duration_parts[0].parse::<f64>().unwrap_or(0.0);
117                    let minutes = duration_parts[1].parse::<f64>().unwrap_or(0.0);
118                    let seconds = duration_parts[2].parse::<f64>().unwrap_or(0.0);
119                    let duration_in_seconds = (hours * 3600.0) + (minutes * 60.0) + seconds;
120                    durations.push(duration_in_seconds);
121                }
122            }
123        });
124    }
125
126    // Convert durations to HashMap with key as index and value as duration
127    let mut duration_map: HashMap<u32, f64> = HashMap::new();
128
129    if probed.format.tracks().is_empty() {
130        return fallback_durations(file_path);
131    }
132
133    for (index, track) in probed.format.tracks().iter().enumerate() {
134        if let Some(real_duration) = durations.get(index) {
135            duration_map.insert(track.id, *real_duration);
136            continue;
137        }
138
139        let codec_params = &track.codec_params;
140        let duration = get_time_from_frames(codec_params);
141        duration_map.insert(track.id, duration);
142    }
143
144    duration_map
145}
146
147fn get_durations_best_effort(file_path: &str) -> HashMap<u32, f64> {
148    let metadata_durations = std::panic::catch_unwind(|| get_durations(file_path)).ok();
149    if let Some(durations) = metadata_durations {
150        let all_zero = durations.values().all(|value| *value <= 0.0);
151        if !durations.is_empty() && !all_zero {
152            return durations;
153        }
154    }
155
156    get_durations_by_scan(file_path)
157}
158
159/// Scan all packets to compute per-track durations (accurate but slower).
160pub fn get_durations_by_scan(file_path: &str) -> HashMap<u32, f64> {
161    let mut probed = match get_probe_result_from_string(file_path) {
162        Ok(probed) => probed,
163        Err(_) => return fallback_durations(file_path),
164    };
165    if probed.format.tracks().is_empty() {
166        return fallback_durations(file_path);
167    }
168    let mut max_ts: HashMap<u32, u64> = HashMap::new();
169    let mut time_bases: HashMap<u32, Option<TimeBase>> = HashMap::new();
170    let mut sample_rates: HashMap<u32, Option<u32>> = HashMap::new();
171
172    for track in probed.format.tracks().iter() {
173        max_ts.insert(track.id, 0);
174        time_bases.insert(track.id, track.codec_params.time_base);
175        sample_rates.insert(track.id, track.codec_params.sample_rate);
176    }
177
178    loop {
179        match probed.format.next_packet() {
180            Ok(packet) => {
181                let entry = max_ts.entry(packet.track_id()).or_insert(0);
182                if packet.ts() > *entry {
183                    *entry = packet.ts();
184                }
185            }
186            Err(_) => break,
187        }
188    }
189
190    let mut duration_map: HashMap<u32, f64> = HashMap::new();
191    for (track_id, ts) in max_ts {
192        let seconds = if let Some(time_base) = time_bases.get(&track_id).copied().flatten() {
193            let time = time_base.calc_time(ts);
194            time.seconds as f64 + time.frac
195        } else if let Some(sample_rate) = sample_rates.get(&track_id).copied().flatten() {
196            ts as f64 / sample_rate as f64
197        } else {
198            0.0
199        };
200        duration_map.insert(track_id, seconds);
201    }
202
203    duration_map
204}
205
206// impl PartialEq for Layout {
207//     fn eq(&self, other: &Self) -> bool {
208//         // Implement equality comparison logic for Layout
209//         match (self, other) {
210//             (Layout::Mono, Layout::Mono) => true,
211//             (Layout::Stereo, Layout::Stereo) => true,
212//             (Layout::TwoPointOne, Layout::TwoPointOne) => true,
213//             (Layout::FivePointOne, Layout::FivePointOne) => true,
214//             _ => false,
215//         }
216//     }
217// }
218
/// Aggregate codec information for a track.
///
/// A value of `0` in any field means the property could not be determined.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TrackInfo {
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channel_count: u32,
    /// Bit depth of a single sample.
    pub bits_per_sample: u32,
}
226
227fn get_track_info(track: &Track) -> TrackInfo {
228    let codec_params = &track.codec_params;
229    let sample_rate = codec_params.sample_rate.unwrap_or(0);
230    let bits_per_sample = codec_params
231        .bits_per_sample
232        .unwrap_or_else(|| bits_from_sample_format(codec_params.sample_format));
233
234    let mut channel_count = match codec_params.channel_layout {
235        Some(Layout::Mono) => 1,
236        Some(Layout::Stereo) => 2,
237        Some(Layout::TwoPointOne) => 3,
238        Some(Layout::FivePointOne) => 6,
239        _ => 0,
240    };
241
242    if channel_count == 0 {
243        channel_count = codec_params
244            .channels
245            .unwrap_or(Channels::FRONT_CENTRE)
246            .iter()
247            .count() as u32;
248    }
249
250    TrackInfo {
251        sample_rate,
252        channel_count,
253        bits_per_sample,
254    }
255}
256
257fn bits_from_sample_format(sample_format: Option<SampleFormat>) -> u32 {
258    match sample_format {
259        Some(SampleFormat::U8 | SampleFormat::S8) => 8,
260        Some(SampleFormat::U16 | SampleFormat::S16) => 16,
261        Some(SampleFormat::U24 | SampleFormat::S24) => 24,
262        Some(SampleFormat::U32 | SampleFormat::S32 | SampleFormat::F32) => 32,
263        Some(SampleFormat::F64) => 64,
264        None => 0,
265    }
266}
267
268fn bits_from_decode(file_path: &str) -> u32 {
269    let mut probed = match get_probe_result_from_string(file_path) {
270        Ok(probed) => probed,
271        Err(_) => return 0,
272    };
273
274    let (track_id, codec_params) = match probed
275        .format
276        .tracks()
277        .iter()
278        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
279    {
280        Some(track) => (track.id, track.codec_params.clone()),
281        None => return 0,
282    };
283
284    let dec_opts: DecoderOptions = Default::default();
285    let mut decoder = match symphonia::default::get_codecs().make(&codec_params, &dec_opts)
286    {
287        Ok(decoder) => decoder,
288        Err(_) => return 0,
289    };
290
291    loop {
292        let packet = match probed.format.next_packet() {
293            Ok(packet) => packet,
294            Err(_) => return 0,
295        };
296
297        if packet.track_id() != track_id {
298            continue;
299        }
300
301        match decoder.decode(&packet) {
302            Ok(decoded) => {
303                return match decoded {
304                    AudioBufferRef::U8(_) => 8,
305                    AudioBufferRef::S8(_) => 8,
306                    AudioBufferRef::U16(_) => 16,
307                    AudioBufferRef::S16(_) => 16,
308                    AudioBufferRef::U24(_) => 24,
309                    AudioBufferRef::S24(_) => 24,
310                    AudioBufferRef::U32(_) => 32,
311                    AudioBufferRef::S32(_) => 32,
312                    AudioBufferRef::F32(_) => 32,
313                    AudioBufferRef::F64(_) => 64,
314                };
315            }
316            Err(Error::DecodeError(_)) => continue,
317            Err(_) => return 0,
318        }
319    }
320}
321
/// Fields read from the `COMM` chunk of an AIFF/AIFC file.
#[derive(Clone, Copy, Debug)]
struct AiffInfo {
    /// Number of channels.
    channels: u16,
    /// Sample rate, decoded from the chunk's 80-bit extended float.
    sample_rate: f64,
    /// Sample size in bits.
    bits_per_sample: u16,
    /// Number of sample frames.
    sample_frames: u32,
}
329
330fn fallback_track_info(file_path: &str) -> TrackInfo {
331    if let Some(info) = parse_aiff_info(file_path) {
332        return TrackInfo {
333            sample_rate: info.sample_rate.round() as u32,
334            channel_count: info.channels as u32,
335            bits_per_sample: info.bits_per_sample as u32,
336        };
337    }
338
339    TrackInfo {
340        sample_rate: 0,
341        channel_count: 0,
342        bits_per_sample: 0,
343    }
344}
345
346fn fallback_durations(file_path: &str) -> HashMap<u32, f64> {
347    if let Some(info) = parse_aiff_info(file_path) {
348        let duration = if info.sample_rate > 0.0 {
349            info.sample_frames as f64 / info.sample_rate
350        } else {
351            0.0
352        };
353        let mut map = HashMap::new();
354        map.insert(0, duration);
355        return map;
356    }
357
358    HashMap::new()
359}
360
361fn parse_aiff_info(file_path: &str) -> Option<AiffInfo> {
362    let path = Path::new(file_path);
363    let ext = path.extension().and_then(|ext| ext.to_str())?.to_lowercase();
364    if ext != "aiff" && ext != "aif" && ext != "aifc" {
365        return None;
366    }
367
368    let mut file = File::open(path).ok()?;
369    let mut header = [0u8; 12];
370    file.read_exact(&mut header).ok()?;
371    if &header[0..4] != b"FORM" {
372        return None;
373    }
374    let form_type = &header[8..12];
375    if form_type != b"AIFF" && form_type != b"AIFC" {
376        return None;
377    }
378
379    loop {
380        let mut chunk_header = [0u8; 8];
381        if file.read_exact(&mut chunk_header).is_err() {
382            break;
383        }
384        let chunk_id = &chunk_header[0..4];
385        let chunk_size = u32::from_be_bytes([
386            chunk_header[4],
387            chunk_header[5],
388            chunk_header[6],
389            chunk_header[7],
390        ]) as u64;
391
392        if chunk_id == b"COMM" {
393            if chunk_size < 18 {
394                return None;
395            }
396            let mut comm = vec![0u8; chunk_size as usize];
397            file.read_exact(&mut comm).ok()?;
398            let channels = u16::from_be_bytes([comm[0], comm[1]]);
399            let sample_frames = u32::from_be_bytes([comm[2], comm[3], comm[4], comm[5]]);
400            let bits_per_sample = u16::from_be_bytes([comm[6], comm[7]]);
401            let mut rate_bytes = [0u8; 10];
402            rate_bytes.copy_from_slice(&comm[8..18]);
403            let sample_rate = extended_80_to_f64(rate_bytes);
404
405            return Some(AiffInfo {
406                channels,
407                sample_rate,
408                bits_per_sample,
409                sample_frames,
410            });
411        }
412
413        let skip = chunk_size + (chunk_size % 2);
414        if file.seek(SeekFrom::Current(skip as i64)).is_err() {
415            break;
416        }
417    }
418
419    None
420}
421
/// Decode a big-endian 80-bit extended-precision float (the format of the
/// AIFF `COMM` sample rate) into an `f64`.
///
/// An all-zero exponent and mantissa decodes to `0.0`, and the all-ones
/// exponent decodes to `NaN`. Mantissa bits beyond `f64` precision are rounded
/// away.
fn extended_80_to_f64(bytes: [u8; 10]) -> f64 {
    let negative = (bytes[0] & 0x80) != 0;
    let biased_exponent = u16::from_be_bytes([bytes[0] & 0x7F, bytes[1]]);
    let significand = bytes[2..]
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));

    if biased_exponent == 0 && significand == 0 {
        return 0.0;
    }
    if biased_exponent == 0x7FFF {
        return f64::NAN;
    }

    // The significand has an explicit integer bit at bit 63, so dividing by
    // 2^63 puts it in [0, 2) before applying the unbiased exponent.
    let scale = 2f64.powi(i32::from(biased_exponent) - 16383);
    let magnitude = scale * (significand as f64 / (1u64 << 63) as f64);

    if negative {
        -magnitude
    } else {
        magnitude
    }
}
442
443fn reduce_track_infos(track_infos: Vec<TrackInfo>) -> TrackInfo {
444    if track_infos.is_empty() {
445        return TrackInfo {
446            sample_rate: 0,
447            channel_count: 0,
448            bits_per_sample: 0,
449        };
450    }
451
452    let info = track_infos
453        .into_iter()
454        .fold(None, |acc: Option<TrackInfo>, track_info| match acc {
455            Some(acc) => {
456                if acc.sample_rate != 0
457                    && track_info.sample_rate != 0
458                    && acc.sample_rate != track_info.sample_rate
459                {
460                    panic!("Sample rates do not match");
461                }
462
463                if acc.channel_count != 0
464                    && track_info.channel_count != 0
465                    && acc.channel_count != track_info.channel_count
466                {
467                    panic!(
468                        "Channel layouts do not match {} != {}",
469                        acc.channel_count, track_info.channel_count
470                    );
471                }
472
473                if acc.bits_per_sample != 0
474                    && track_info.bits_per_sample != 0
475                    && acc.bits_per_sample != track_info.bits_per_sample
476                {
477                    panic!("Bits per sample do not match");
478                }
479
480                Some(TrackInfo {
481                    sample_rate: if acc.sample_rate == 0 {
482                        track_info.sample_rate
483                    } else {
484                        acc.sample_rate
485                    },
486                    channel_count: if acc.channel_count == 0 {
487                        track_info.channel_count
488                    } else {
489                        acc.channel_count
490                    },
491                    bits_per_sample: if acc.bits_per_sample == 0 {
492                        track_info.bits_per_sample
493                    } else {
494                        acc.bits_per_sample
495                    },
496                })
497            }
498            None => Some(track_info),
499        });
500
501    info.unwrap()
502}
503
504fn gather_track_info(file_path: &str) -> TrackInfo {
505    let probed = match get_probe_result_from_string(file_path) {
506        Ok(probed) => probed,
507        Err(_) => return fallback_track_info(file_path),
508    };
509
510    let tracks = probed.format.tracks();
511    if tracks.is_empty() {
512        return fallback_track_info(file_path);
513    }
514    let mut track_infos: Vec<TrackInfo> = Vec::new();
515    for track in tracks {
516        let track_info = get_track_info(track);
517        track_infos.push(track_info);
518    }
519
520    let mut info = reduce_track_infos(track_infos);
521    if info.bits_per_sample == 0 {
522        let decoded_bits = bits_from_decode(file_path);
523        if decoded_bits > 0 {
524            info.bits_per_sample = decoded_bits;
525        }
526    }
527    if info.sample_rate == 0 && info.channel_count == 0 && info.bits_per_sample == 0 {
528        return fallback_track_info(file_path);
529    }
530    info
531}
532
533fn gather_track_info_from_file_paths(file_paths: Vec<String>) -> TrackInfo {
534    let mut track_infos: Vec<TrackInfo> = Vec::new();
535
536    for file_path in file_paths {
537        debug!("File path: {:?}", file_path);
538        let track_info = gather_track_info(&file_path);
539        track_infos.push(track_info);
540    }
541
542    reduce_track_infos(track_infos)
543}
544
/// Combined container info (track list, durations, sample format).
#[derive(Clone, Debug)]
pub struct Info {
    /// Source files this info was built from.
    pub file_paths: Vec<String>,
    /// Durations in seconds. `Info::new` keys them by track id;
    /// `Info::new_from_file_paths` keys them by file index.
    pub duration_map: HashMap<u32, f64>,
    /// Channel count (`0` if unknown).
    pub channels: u32,
    /// Sample rate (`0` if unknown).
    pub sample_rate: u32,
    /// Bits per sample (`0` if unknown).
    pub bits_per_sample: u32,
}
554
555impl Info {
556    /// Build info for a single container file by scanning all packets.
557    pub fn new(file_path: String) -> Self {
558        let track_info = gather_track_info(&file_path);
559
560        Self {
561            duration_map: get_durations_by_scan(&file_path),
562            file_paths: vec![file_path],
563            channels: track_info.channel_count,
564            sample_rate: track_info.sample_rate,
565            bits_per_sample: track_info.bits_per_sample,
566        }
567    }
568
569    /// Build info for a list of standalone files.
570    ///
571    /// Uses metadata when available and falls back to scanning.
572    pub fn new_from_file_paths(file_paths: Vec<String>) -> Self {
573        let mut duration_map: HashMap<u32, f64> = HashMap::new();
574
575        for (index, file_path) in file_paths.iter().enumerate() {
576            let durations = get_durations_best_effort(file_path);
577            let longest = durations
578                .iter()
579                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
580                .map(|entry| *entry.1)
581                .unwrap_or(0.0);
582            duration_map.insert(index as u32, longest);
583        }
584
585        let track_info = gather_track_info_from_file_paths(file_paths.clone());
586
587        Self {
588            duration_map,
589            file_paths,
590            channels: track_info.channel_count,
591            sample_rate: track_info.sample_rate,
592            bits_per_sample: track_info.bits_per_sample,
593        }
594    }
595
596    /// Get the duration for the given track index, if known.
597    pub fn get_duration(&self, index: u32) -> Option<f64> {
598        match self.duration_map.get(&index) {
599            Some(duration) => Some(*duration),
600            None => None,
601        }
602    }
603}