Skip to main content

proteus_lib/
info.rs

1use std::{collections::HashMap, fs::File, path::Path};
2
3use log::debug;
4
5use symphonia::core::{
6    audio::{Channels, Layout}, codecs::CodecParameters, formats::{FormatOptions, Track}, io::{
7        MediaSource, MediaSourceStream, ReadOnlySource
8    }, meta::MetadataOptions, probe::{
9        Hint,
10        ProbeResult
11    }
12};
13
14pub fn get_time_from_frames(codec_params: &CodecParameters) -> f64 {
15    let tb = codec_params.time_base.unwrap();
16    let dur = codec_params.n_frames.map(|frames| codec_params.start_ts + frames).unwrap();
17    let time = tb.calc_time(dur);
18
19    time.seconds as f64 + time.frac
20}
21
22pub fn get_probe_result_from_string(file_path: &str) -> ProbeResult {
23    // Create a hint to help the format registry guess what format reader is appropriate.
24    let mut hint = Hint::new();
25
26    // If the path string is '-' then read from standard input.
27    let source = if file_path == "-" {
28        Box::new(ReadOnlySource::new(std::io::stdin())) as Box<dyn MediaSource>
29    } else {
30        // Othwerise, get a Path from the path string.
31        let path = Path::new(file_path);
32
33        // Provide the file extension as a hint.
34        if let Some(extension) = path.extension() {
35            if let Some(extension_str) = extension.to_str() {
36                hint.with_extension(extension_str);
37            }
38        }
39
40        Box::new(File::open(path).expect("failed to open media file")) as Box<dyn MediaSource>
41    };
42
43    // Create the media source stream using the boxed media source from above.
44    let mss = MediaSourceStream::new(source, Default::default());
45
46    // Use the default options for format readers other than for gapless playback.
47    let format_opts = FormatOptions {
48        // enable_gapless: !args.is_present("no-gapless"),
49        ..Default::default()
50    };
51
52    // Use the default options for metadata readers.
53    let metadata_opts: MetadataOptions = Default::default();
54
55    // Get the value of the track option, if provided.
56    // let track = match args.value_of("track") {
57    //     Some(track_str) => track_str.parse::<usize>().ok(),
58    //     _ => None,
59    // };
60
61    symphonia::default::get_probe().format(&hint, mss, &format_opts, &metadata_opts).unwrap()
62}
63
64fn get_durations(file_path: &str) -> HashMap<u32, f64> {
65    let mut probed = get_probe_result_from_string(file_path);
66
67    let mut durations: Vec<f64> = Vec::new();
68
69    if let Some(metadata_rev) = probed.format.metadata().current() {
70        metadata_rev.tags().iter().for_each(|tag| {
71            if tag.key == "DURATION" {
72                // Convert duration of type 01:12:37.227000000 to 4337.227
73                let duration = tag.value.to_string().clone();
74                let duration_parts = duration.split(":").collect::<Vec<&str>>();
75                let hours = duration_parts[0].parse::<f64>().unwrap();
76                let minutes = duration_parts[1].parse::<f64>().unwrap();
77                let seconds = duration_parts[2].parse::<f64>().unwrap();
78                // let milliseconds = duration_parts[3].parse::<f64>().unwrap();
79                let duration_in_seconds = (hours * 3600.0) + (minutes * 60.0) + seconds;
80
81                durations.push(duration_in_seconds);
82            }
83        });
84    }
85
86    // Convert durations to HashMap with key as index and value as duration
87    let mut duration_map: HashMap<u32, f64> = HashMap::new();
88
89    for (index, track) in probed.format.tracks().iter().enumerate() {
90        if let Some(real_duration) = durations.get(index) {
91            duration_map.insert(track.id, *real_duration);
92            continue;
93        }
94
95        let codec_params = &track.codec_params;
96        let duration = get_time_from_frames(codec_params);
97        duration_map.insert(track.id, duration);
98    }
99
100    duration_map
101}
102
103// impl PartialEq for Layout {
104//     fn eq(&self, other: &Self) -> bool {
105//         // Implement equality comparison logic for Layout
106//         match (self, other) {
107//             (Layout::Mono, Layout::Mono) => true,
108//             (Layout::Stereo, Layout::Stereo) => true,
109//             (Layout::TwoPointOne, Layout::TwoPointOne) => true,
110//             (Layout::FivePointOne, Layout::FivePointOne) => true,
111//             _ => false,
112//         }
113//     }
114// }
115
/// Audio format properties of a track, or of a group of tracks that share them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrackInfo {
    /// Sample rate reported by the codec parameters.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channel_count: u32,
    /// Bits per decoded sample reported by the codec parameters.
    pub bits_per_sample: u32,
}
122
123fn get_track_info(track: &Track) -> TrackInfo {
124    let codec_params = &track.codec_params;
125    let sample_rate = codec_params.sample_rate.unwrap();
126    let bits_per_sample = codec_params.bits_per_sample.unwrap();
127
128    let mut channel_count = match codec_params.channel_layout {
129        Some(Layout::Mono) => 1,
130        Some(Layout::Stereo) => 2,
131        Some(Layout::TwoPointOne) => 3,
132        Some(Layout::FivePointOne) => 6,
133        _ => 0,
134    };
135
136    if channel_count == 0 {
137        channel_count = codec_params.channels.unwrap_or(Channels::FRONT_CENTRE).iter().count() as u32;
138    }
139    
140    TrackInfo {
141        sample_rate,
142        channel_count,
143        bits_per_sample,
144    }
145}
146
147fn reduce_track_infos(track_infos: Vec<TrackInfo>) -> TrackInfo {
148    let info = track_infos.into_iter().fold(None, |acc: Option<TrackInfo>, track_info| {
149        match acc {
150            Some(acc) => {
151                if acc.sample_rate != track_info.sample_rate {
152                    panic!("Sample rates do not match");
153                }
154
155                if acc.channel_count != track_info.channel_count {
156                    panic!("Channel layouts do not match {} != {}", acc.channel_count, track_info.channel_count);
157                }
158
159                if acc.bits_per_sample != track_info.bits_per_sample {
160                    panic!("Bits per sample do not match");
161                }
162
163                Some(acc)
164            },
165            None => Some(track_info),
166        }
167    });
168
169    info.unwrap()
170}
171
172fn gather_track_info(file_path: &str) -> TrackInfo {
173    let probed = get_probe_result_from_string(file_path);
174
175    let tracks = probed.format.tracks();
176    let mut track_infos: Vec<TrackInfo> = Vec::new();
177    for track in tracks {
178        let track_info = get_track_info(track);
179        track_infos.push(track_info);
180    }
181    
182    reduce_track_infos(track_infos)
183}
184
185fn gather_track_info_from_file_paths(file_paths: Vec<String>) -> TrackInfo {
186    let mut track_infos: Vec<TrackInfo> = Vec::new();
187
188    for file_path in file_paths {
189        debug!("File path: {:?}", file_path);
190        let track_info = gather_track_info(&file_path);
191        track_infos.push(track_info);
192    }
193
194    reduce_track_infos(track_infos)
195}
196
197#[derive(Debug, Clone)]
198pub struct Info {
199    pub file_paths: Vec<String>,
200    pub duration_map: HashMap<u32, f64>,
201    pub channels: u32,
202    pub sample_rate: u32,
203    pub bits_per_sample: u32,
204}
205
206impl Info {
207    pub fn new(file_path: String) -> Self {
208        let track_info = gather_track_info(&file_path);
209
210        Self {
211            duration_map: get_durations(&file_path),
212            file_paths: vec![file_path],
213            channels: track_info.channel_count,
214            sample_rate: track_info.sample_rate,
215            bits_per_sample: track_info.bits_per_sample,
216        }
217    }
218
219    pub fn new_from_file_paths(file_paths: Vec<String>) -> Self {
220        let mut duration_map: HashMap<u32, f64> = HashMap::new();
221
222        for (index, file_path) in file_paths.iter().enumerate() {
223            let durations = get_durations(file_path);
224            let longest = durations.iter().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap();
225            duration_map.insert(index as u32, *longest.1);
226        }
227
228        let track_info = gather_track_info_from_file_paths(file_paths.clone());
229
230        Self {
231            duration_map,
232            file_paths,
233            channels: track_info.channel_count,
234            sample_rate: track_info.sample_rate,
235            bits_per_sample: track_info.bits_per_sample,
236        }
237    }
238    
239    pub fn get_duration(&self, index: u32) -> Option<f64> {
240        match self.duration_map.get(&index) {
241            Some(duration) => Some(*duration),
242            None => None,
243        }
244    }
245}