mecomp_analysis/decoder/mod.rs
1#![allow(clippy::missing_inline_in_public_items)]
2
3use std::{
4 cell::RefCell,
5 clone::Clone,
6 marker::Send,
7 num::NonZeroUsize,
8 path::{Path, PathBuf},
9 sync::mpsc::{self, SendError, SyncSender},
10 thread,
11};
12
13use log::{debug, error, trace};
14use rayon::iter::{IntoParallelIterator, ParallelIterator};
15
16use crate::{
17 Analysis, ResampledAudio,
18 embeddings::{AudioEmbeddingModel, Embedding},
19 errors::{AnalysisError, AnalysisResult},
20};
21
22mod mecomp;
23#[allow(clippy::module_name_repetitions)]
24pub use mecomp::{MecompDecoder, SymphoniaSource};
25
26pub type ProcessingCallback =
27 SyncSender<(PathBuf, AnalysisResult<Analysis>, AnalysisResult<Embedding>)>;
28
29/// Trait used to implement your own decoder.
30///
31/// The `decode` function should be implemented so that it
32/// decodes and resample a song to one channel with a sampling rate of 22050 Hz
33/// and a f32le layout.
34/// Once it is implemented, several functions
35/// to perform analysis from path(s) are available, such as
36/// [`analyze_paths_with_cores`](Decoder::analyze_paths_with_cores) and
37/// [`analyze_paths`](Decoder::analyze_paths).
38#[allow(clippy::module_name_repetitions)]
39pub trait Decoder {
40 /// A function that should decode and resample a song, optionally
41 /// extracting the song's metadata such as the artist, the album, etc.
42 ///
43 /// The output sample array should be resampled to f32le, one channel, with a sampling rate
44 /// of 22050 Hz. Anything other than that will yield wrong results.
45 ///
46 /// # Errors
47 ///
48 /// This function will return an error if the file path is invalid, if
49 /// the file path points to a file containing no or corrupted audio stream,
50 /// or if the analysis could not be conducted to the end for some reason.
51 ///
52 /// The error type returned should give a hint as to whether it was a
53 /// decoding or an analysis error.
54 fn decode(&self, path: &Path) -> AnalysisResult<ResampledAudio>;
55
56 /// Returns a decoded song's `Analysis` given a file path, or an error if the song
57 /// could not be analyzed for some reason.
58 ///
59 /// see:
60 /// - [`Decoder::decode`] for how songs are decoded
61 /// - [`Analysis::from_samples`] for how analyses are calculated from [`ResampledAudio`]
62 ///
63 /// # Arguments
64 ///
65 /// * `path` - A [`Path`] holding a valid file path to a valid audio file.
66 ///
67 /// # Errors
68 ///
69 /// This function will return an error if the file path is invalid, if
70 /// the file path points to a file containing no or corrupted audio stream,
71 /// or if the analysis could not be conducted to the end for some reason.
72 ///
73 /// The error type returned should give a hint as to whether it was a
74 /// decoding or an analysis error.
75 #[inline]
76 fn analyze_path<P: AsRef<Path>>(&self, path: P) -> AnalysisResult<Analysis> {
77 Analysis::from_samples(&self.decode(path.as_ref())?)
78 }
79
80 /// Analyze songs in `paths` in parallel across all logical cores,
81 /// and emits the [`AnalysisResult<Analysis>`] objects (along with the [`Path`] they correspond to)
82 /// through the provided [callback channel](`mpsc::Sender`).
83 ///
84 /// This function is blocking, so it should be called in a separate thread
85 /// from where the [receiver](`mpsc::Receiver`) is consumed.
86 ///
87 /// You can cancel the job by dropping the `callback` channel's [receiver](`mpsc::Receiver`).
88 ///
89 /// see [`Decoder::analyze_path`] for more details on how the analyses are generated.
90 ///
91 /// # Example
92 ///
93 /// ```rust
94 /// use mecomp_analysis::decoder::{Decoder as _, MecmopDecoder as Decoder};
95 ///
96 /// let paths = vec![
97 /// "data/piano.wav",
98 /// "data/s32_mono_44_1_kHz.flac"
99 /// ];
100 ///
101 /// let (tx, rx) = std::mpsc::channel();
102 ///
103 /// let handle = std::thread::spawn(move || {
104 /// Decoder::new().unwrap().analyze_paths(paths, tx).unwrap();
105 /// });
106 ///
107 /// for (path, maybe_analysis) = rx {
108 /// if let Ok(analysis) = maybe_analysis {
109 /// println!("{} analyzed successfully!", path.display());
110 /// // do something with the analysis
111 /// } else {
112 /// eprintln!("error analyzing {}!", path.display());
113 /// }
114 /// }
115 /// ```
116 ///
117 /// # Errors
118 ///
119 /// Errors if the `callback` channel is closed.
120 #[inline]
121 fn analyze_paths<P: Into<PathBuf>, I: Send + IntoIterator<Item = P>>(
122 &self,
123 paths: I,
124 callback: mpsc::Sender<(PathBuf, AnalysisResult<Analysis>)>,
125 ) -> Result<(), SendError<()>>
126 where
127 Self: Sync + Send,
128 {
129 let cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
130 self.analyze_paths_with_cores(paths, cores, callback)
131 }
132
133 /// Analyze songs in `paths` in parallel across `number_cores` threads,
134 /// and emits the [`AnalysisResult<Analysis>`] objects (along with the [`Path`] they correspond to)
135 /// through the provided [callback channel](`mpsc::Sender`).
136 ///
137 /// This function is blocking, so it should be called in a separate thread
138 /// from where the [receiver](`mpsc::Receiver`) is consumed.
139 ///
140 /// You can cancel the job by dropping the `callback` channel's [receiver](`mpsc::Receiver`).
141 ///
142 /// See also: [`Decoder::analyze_paths`]
143 ///
144 /// # Errors
145 ///
146 /// Errors if the `callback` channel is closed.
147 fn analyze_paths_with_cores<P: Into<PathBuf>, I: IntoIterator<Item = P>>(
148 &self,
149 paths: I,
150 number_cores: NonZeroUsize,
151 callback: mpsc::Sender<(PathBuf, AnalysisResult<Analysis>)>,
152 ) -> Result<(), SendError<()>>
153 where
154 Self: Sync + Send,
155 {
156 let mut cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
157 if cores > number_cores {
158 cores = number_cores;
159 }
160 let paths: Vec<PathBuf> = paths.into_iter().map(Into::into).collect();
161
162 if paths.is_empty() {
163 return Ok(());
164 }
165
166 let pool = rayon::ThreadPoolBuilder::new()
167 .num_threads(cores.get())
168 .build()
169 .unwrap();
170
171 pool.install(|| {
172 paths.into_par_iter().try_for_each(|path| {
173 debug!("Analyzing file '{}'", path.display());
174 let analysis = self.analyze_path(&path);
175 callback.send((path, analysis)).map_err(|_| SendError(()))
176 })
177 })
178 }
179
180 /// Process raw audio samples in `audios`, and yield the `Analysis` and `Embedding` objects
181 /// through the provided `callback` channel.
182 /// Parallelizes the process across `number_cores` CPU cores.
183 ///
184 /// You can cancel the job by dropping the `callback` channel.
185 ///
186 /// Note: A new [`AudioEmbeddingModel`](crate::embeddings::AudioEmbeddingModel) session will be created
187 /// for each batch processed.
188 ///
189 /// # Errors
190 ///
191 /// Errors if the `callback` channel is closed.
192 #[inline]
193 fn process_songs_with_cores(
194 &self,
195 paths: &[PathBuf],
196 callback: ProcessingCallback,
197 number_cores: NonZeroUsize,
198 model_config: crate::embeddings::ModelConfig,
199 ) -> AnalysisResult<()>
200 where
201 Self: Sync + Send,
202 {
203 let mut cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
204 if cores > number_cores {
205 cores = number_cores;
206 }
207
208 if paths.is_empty() {
209 return Ok(());
210 }
211
212 thread_local! {
213 static MODEL: RefCell<Option<AudioEmbeddingModel>> = const { RefCell::new(None) };
214 }
215
216 let pool = rayon::ThreadPoolBuilder::new()
217 .num_threads(cores.get())
218 .thread_name(|idx| format!("Analyzer {idx}"))
219 .build()
220 .unwrap();
221
222 pool.install(|| {
223 // Process songs in parallel, but each song is fully processed (decode -> analyze -> embed -> send)
224 // before the thread moves to the next song. This prevents memory accumulation from
225 // buffering decoded audio across multiple pipeline stages.
226 paths.into_par_iter().try_for_each(|path| {
227 let thread_name = thread::current().name().unwrap_or("unknown").to_string();
228
229 // Decode the audio file
230 let audio = match self.decode(path) {
231 Ok(audio) => {
232 trace!("Decoded {} in thread {thread_name}", path.display());
233 audio
234 }
235 Err(e) => {
236 error!("Error decoding {}: {e}", path.display());
237 return Ok(()); // Skip this file, continue with others
238 }
239 };
240
241 // Analyze the audio
242 let analysis = Analysis::from_samples(&audio);
243 trace!("Analyzed {} in thread {thread_name}", path.display());
244
245 // Generate embedding
246 let embedding = MODEL.try_with(|model_cell| {
247 let mut model_ref = model_cell.borrow_mut();
248 if model_ref.is_none() {
249 debug!("Loading embedding model in thread {thread_name}");
250 *model_ref = Some(AudioEmbeddingModel::load(&model_config)?);
251 }
252 trace!(
253 "Generating embeddings for {} in thread {thread_name}",
254 path.display()
255 );
256 model_ref
257 .as_mut()
258 .unwrap()
259 .embed(&audio)
260 .map_err(AnalysisError::from)
261 });
262
263 // Flatten the Result<Result<...>> and convert to AnalysisResult
264 let embedding = match embedding {
265 Ok(Ok(e)) => Ok(e),
266 Ok(Err(e)) => Err(e),
267 Err(e) => Err(AnalysisError::AccessError(e)),
268 };
269
270 // Drop the audio samples before sending to free memory immediately
271 drop(audio);
272
273 // Send the results - the bounded channel will apply backpressure
274 // if the consumer is slow, preventing unbounded memory growth
275 callback
276 .send((path.clone(), analysis, embedding))
277 .map_err(|_| AnalysisError::SendError)
278 })
279 })
280 }
281
282 /// Process raw audio samples in `audios`, and yield the `Analysis` and `Embedding` objects
283 /// through the provided `callback` channel.
284 /// Parallelizes the process across all available CPU cores.
285 ///
286 /// You can cancel the job by dropping the `callback` channel.
287 ///
288 /// Note: A new [`AudioEmbeddingModel`](crate::embeddings::AudioEmbeddingModel) session will be created
289 /// for each batch processed.
290 ///
291 /// # Errors
292 /// Errors if the `callback` channel is closed.
293 #[inline]
294 fn process_songs(
295 &self,
296 paths: &[PathBuf],
297 callback: ProcessingCallback,
298 model_config: crate::embeddings::ModelConfig,
299 ) -> AnalysisResult<()>
300 where
301 Self: Sync + Send,
302 {
303 let cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
304 self.process_songs_with_cores(paths, callback, cores, model_config)
305 }
306}