Skip to main content

mecomp_analysis/decoder/
mod.rs

1#![allow(clippy::missing_inline_in_public_items)]
2
3use std::{
4    cell::RefCell,
5    clone::Clone,
6    marker::Send,
7    num::NonZeroUsize,
8    path::{Path, PathBuf},
9    sync::mpsc::{self, SendError, SyncSender},
10    thread,
11};
12
13use log::{debug, error, trace};
14use rayon::iter::{IntoParallelIterator, ParallelIterator};
15
16use crate::{
17    Analysis, ResampledAudio,
18    embeddings::{AudioEmbeddingModel, Embedding},
19    errors::{AnalysisError, AnalysisResult},
20};
21
22mod mecomp;
23#[allow(clippy::module_name_repetitions)]
24pub use mecomp::{MecompDecoder, SymphoniaSource};
25
26pub type ProcessingCallback =
27    SyncSender<(PathBuf, AnalysisResult<Analysis>, AnalysisResult<Embedding>)>;
28
29/// Trait used to implement your own decoder.
30///
31/// The `decode` function should be implemented so that it
32/// decodes and resample a song to one channel with a sampling rate of 22050 Hz
33/// and a f32le layout.
34/// Once it is implemented, several functions
35/// to perform analysis from path(s) are available, such as
36/// [`analyze_paths_with_cores`](Decoder::analyze_paths_with_cores) and
37/// [`analyze_paths`](Decoder::analyze_paths).
38#[allow(clippy::module_name_repetitions)]
39pub trait Decoder {
40    /// A function that should decode and resample a song, optionally
41    /// extracting the song's metadata such as the artist, the album, etc.
42    ///
43    /// The output sample array should be resampled to f32le, one channel, with a sampling rate
44    /// of 22050 Hz. Anything other than that will yield wrong results.
45    ///
46    /// # Errors
47    ///
48    /// This function will return an error if the file path is invalid, if
49    /// the file path points to a file containing no or corrupted audio stream,
50    /// or if the analysis could not be conducted to the end for some reason.
51    ///
52    /// The error type returned should give a hint as to whether it was a
53    /// decoding or an analysis error.
54    fn decode(&self, path: &Path) -> AnalysisResult<ResampledAudio>;
55
56    /// Returns a decoded song's `Analysis` given a file path, or an error if the song
57    /// could not be analyzed for some reason.
58    ///
59    /// see:
60    /// - [`Decoder::decode`] for how songs are decoded
61    /// - [`Analysis::from_samples`] for how analyses are calculated from [`ResampledAudio`]
62    ///
63    /// # Arguments
64    ///
65    /// * `path` - A [`Path`] holding a valid file path to a valid audio file.
66    ///
67    /// # Errors
68    ///
69    /// This function will return an error if the file path is invalid, if
70    /// the file path points to a file containing no or corrupted audio stream,
71    /// or if the analysis could not be conducted to the end for some reason.
72    ///
73    /// The error type returned should give a hint as to whether it was a
74    /// decoding or an analysis error.
75    #[inline]
76    fn analyze_path<P: AsRef<Path>>(&self, path: P) -> AnalysisResult<Analysis> {
77        Analysis::from_samples(&self.decode(path.as_ref())?)
78    }
79
80    /// Analyze songs in `paths` in parallel across all logical cores,
81    /// and emits the [`AnalysisResult<Analysis>`] objects (along with the [`Path`] they correspond to)
82    /// through the provided [callback channel](`mpsc::Sender`).
83    ///
84    /// This function is blocking, so it should be called in a separate thread
85    /// from where the [receiver](`mpsc::Receiver`) is consumed.
86    ///
87    /// You can cancel the job by dropping the `callback` channel's [receiver](`mpsc::Receiver`).
88    ///
89    /// see [`Decoder::analyze_path`] for more details on how the analyses are generated.
90    ///
91    /// # Example
92    ///
93    /// ```rust
94    /// use mecomp_analysis::decoder::{Decoder as _, MecmopDecoder as Decoder};
95    ///
96    /// let paths = vec![
97    ///     "data/piano.wav",
98    ///     "data/s32_mono_44_1_kHz.flac"
99    /// ];
100    ///
101    /// let (tx, rx) = std::mpsc::channel();
102    ///
103    /// let handle = std::thread::spawn(move || {
104    ///     Decoder::new().unwrap().analyze_paths(paths, tx).unwrap();
105    /// });
106    ///
107    /// for (path, maybe_analysis) = rx {
108    ///     if let Ok(analysis) = maybe_analysis {
109    ///         println!("{} analyzed successfully!", path.display());
110    ///         // do something with the analysis
111    ///     } else {
112    ///         eprintln!("error analyzing {}!", path.display());
113    ///     }
114    /// }
115    /// ```
116    ///
117    /// # Errors
118    ///
119    /// Errors if the `callback` channel is closed.
120    #[inline]
121    fn analyze_paths<P: Into<PathBuf>, I: Send + IntoIterator<Item = P>>(
122        &self,
123        paths: I,
124        callback: mpsc::Sender<(PathBuf, AnalysisResult<Analysis>)>,
125    ) -> Result<(), SendError<()>>
126    where
127        Self: Sync + Send,
128    {
129        let cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
130        self.analyze_paths_with_cores(paths, cores, callback)
131    }
132
133    /// Analyze songs in `paths` in parallel across `number_cores` threads,
134    /// and emits the [`AnalysisResult<Analysis>`] objects (along with the [`Path`] they correspond to)
135    /// through the provided [callback channel](`mpsc::Sender`).
136    ///
137    /// This function is blocking, so it should be called in a separate thread
138    /// from where the [receiver](`mpsc::Receiver`) is consumed.
139    ///
140    /// You can cancel the job by dropping the `callback` channel's [receiver](`mpsc::Receiver`).
141    ///
142    /// See also: [`Decoder::analyze_paths`]
143    ///
144    /// # Errors
145    ///
146    /// Errors if the `callback` channel is closed.
147    fn analyze_paths_with_cores<P: Into<PathBuf>, I: IntoIterator<Item = P>>(
148        &self,
149        paths: I,
150        number_cores: NonZeroUsize,
151        callback: mpsc::Sender<(PathBuf, AnalysisResult<Analysis>)>,
152    ) -> Result<(), SendError<()>>
153    where
154        Self: Sync + Send,
155    {
156        let mut cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
157        if cores > number_cores {
158            cores = number_cores;
159        }
160        let paths: Vec<PathBuf> = paths.into_iter().map(Into::into).collect();
161
162        if paths.is_empty() {
163            return Ok(());
164        }
165
166        let pool = rayon::ThreadPoolBuilder::new()
167            .num_threads(cores.get())
168            .build()
169            .unwrap();
170
171        pool.install(|| {
172            paths.into_par_iter().try_for_each(|path| {
173                debug!("Analyzing file '{}'", path.display());
174                let analysis = self.analyze_path(&path);
175                callback.send((path, analysis)).map_err(|_| SendError(()))
176            })
177        })
178    }
179
180    /// Process raw audio samples in `audios`, and yield the `Analysis` and `Embedding` objects
181    /// through the provided `callback` channel.
182    /// Parallelizes the process across `number_cores` CPU cores.
183    ///
184    /// You can cancel the job by dropping the `callback` channel.
185    ///
186    /// Note: A new [`AudioEmbeddingModel`](crate::embeddings::AudioEmbeddingModel) session will be created
187    /// for each thread.
188    ///
189    /// # Errors
190    ///
191    /// Errors if the `callback` channel is closed.
192    #[inline]
193    fn process_songs_with_cores(
194        &self,
195        paths: &[PathBuf],
196        callback: ProcessingCallback,
197        number_cores: NonZeroUsize,
198        model_config: crate::embeddings::ModelConfig,
199    ) -> AnalysisResult<()>
200    where
201        Self: Sync + Send,
202    {
203        let mut cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
204        if cores > number_cores {
205            cores = number_cores;
206        }
207
208        if paths.is_empty() {
209            return Ok(());
210        }
211
212        thread_local! {
213            static MODEL: RefCell<Option<AudioEmbeddingModel>> = const { RefCell::new(None) };
214        }
215
216        let pool = rayon::ThreadPoolBuilder::new()
217            .num_threads(cores.get())
218            .thread_name(|idx| format!("Analyzer {idx}"))
219            .exit_handler(|thread_id| {
220                // Clean up thread-local model to free memory
221                debug!("Cleaning up model in thread Analyzer {thread_id}");
222                let _ = MODEL.take();
223            })
224            .build()
225            .unwrap();
226
227        pool.install(|| {
228            // Process songs in parallel, but each song is fully processed (decode -> analyze -> embed -> send)
229            // before the thread moves to the next song. This prevents memory accumulation from
230            // buffering decoded audio across multiple pipeline stages.
231            paths.into_par_iter().try_for_each(|path| {
232                let thread_name = thread::current().name().unwrap_or("unknown").to_string();
233
234                // Decode the audio file
235                let audio = match self.decode(path) {
236                    Ok(audio) => {
237                        trace!("Decoded {} in thread {thread_name}", path.display());
238                        audio
239                    }
240                    Err(e) => {
241                        error!("Error decoding {}: {e}", path.display());
242                        return Ok(()); // Skip this file, continue with others
243                    }
244                };
245
246                let (analysis, embedding) = pool.join(
247                    || {
248                        // Analyze the audio
249                        let analysis = Analysis::from_samples(&audio);
250                        trace!("Analyzed {} in thread {thread_name}", path.display());
251                        analysis
252                    },
253                    || {
254                        // Load or get the model for this thread, then generate the embedding
255                        let embedding = MODEL.with(|model_cell| {
256                            let mut model_ref = model_cell.borrow_mut();
257                            if model_ref.is_none() {
258                                debug!("Loading model in thread {thread_name}");
259                                *model_ref = Some(AudioEmbeddingModel::load(&model_config)?);
260                            }
261                            let model = model_ref.as_mut().unwrap();
262                            model.embed(&audio).map_err(AnalysisError::from)
263                        });
264                        trace!(
265                            "Generated embeddings for {} in thread {thread_name}",
266                            path.display()
267                        );
268                        embedding
269                    },
270                );
271
272                // Drop the audio samples before sending to free memory immediately
273                drop(audio);
274
275                // Send the results - the bounded channel will apply backpressure
276                // if the consumer is slow, preventing unbounded memory growth
277                callback
278                    .send((path.clone(), analysis, embedding))
279                    .map_err(|_| AnalysisError::SendError)
280            })
281        })
282    }
283
284    /// Process raw audio samples in `audios`, and yield the `Analysis` and `Embedding` objects
285    /// through the provided `callback` channel.
286    /// Parallelizes the process across all available CPU cores.
287    ///
288    /// You can cancel the job by dropping the `callback` channel.
289    ///
290    /// Note: A new [`AudioEmbeddingModel`](crate::embeddings::AudioEmbeddingModel) session will be created
291    /// for each thread.
292    ///
293    /// # Errors
294    /// Errors if the `callback` channel is closed.
295    #[inline]
296    fn process_songs(
297        &self,
298        paths: &[PathBuf],
299        callback: ProcessingCallback,
300        model_config: crate::embeddings::ModelConfig,
301    ) -> AnalysisResult<()>
302    where
303        Self: Sync + Send,
304    {
305        let cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
306        self.process_songs_with_cores(paths, callback, cores, model_config)
307    }
308}