mecomp_analysis/decoder/mod.rs
1#![allow(clippy::missing_inline_in_public_items)]
2
3use std::{
4 cell::RefCell,
5 clone::Clone,
6 marker::Send,
7 num::NonZeroUsize,
8 path::{Path, PathBuf},
9 sync::mpsc::{self, SendError, SyncSender},
10 thread,
11};
12
13use log::{debug, error, trace};
14use rayon::iter::{IntoParallelIterator, ParallelIterator};
15
16use crate::{
17 Analysis, ResampledAudio,
18 embeddings::{AudioEmbeddingModel, Embedding},
19 errors::{AnalysisError, AnalysisResult},
20};
21
22mod mecomp;
23#[allow(clippy::module_name_repetitions)]
24pub use mecomp::{MecompDecoder, SymphoniaSource};
25
/// Sending half of the bounded ([`SyncSender`]) channel through which
/// [`Decoder::process_songs`] and [`Decoder::process_songs_with_cores`] report their results.
///
/// Each message is a tuple of:
/// - the path of the song that was processed,
/// - the outcome of computing its [`Analysis`],
/// - the outcome of computing its [`Embedding`].
pub type ProcessingCallback =
    SyncSender<(PathBuf, AnalysisResult<Analysis>, AnalysisResult<Embedding>)>;
28
29/// Trait used to implement your own decoder.
30///
31/// The `decode` function should be implemented so that it
32/// decodes and resample a song to one channel with a sampling rate of 22050 Hz
33/// and a f32le layout.
34/// Once it is implemented, several functions
35/// to perform analysis from path(s) are available, such as
36/// [`analyze_paths_with_cores`](Decoder::analyze_paths_with_cores) and
37/// [`analyze_paths`](Decoder::analyze_paths).
38#[allow(clippy::module_name_repetitions)]
39pub trait Decoder {
40 /// A function that should decode and resample a song, optionally
41 /// extracting the song's metadata such as the artist, the album, etc.
42 ///
43 /// The output sample array should be resampled to f32le, one channel, with a sampling rate
44 /// of 22050 Hz. Anything other than that will yield wrong results.
45 ///
46 /// # Errors
47 ///
48 /// This function will return an error if the file path is invalid, if
49 /// the file path points to a file containing no or corrupted audio stream,
50 /// or if the analysis could not be conducted to the end for some reason.
51 ///
52 /// The error type returned should give a hint as to whether it was a
53 /// decoding or an analysis error.
54 fn decode(&self, path: &Path) -> AnalysisResult<ResampledAudio>;
55
56 /// Returns a decoded song's `Analysis` given a file path, or an error if the song
57 /// could not be analyzed for some reason.
58 ///
59 /// see:
60 /// - [`Decoder::decode`] for how songs are decoded
61 /// - [`Analysis::from_samples`] for how analyses are calculated from [`ResampledAudio`]
62 ///
63 /// # Arguments
64 ///
65 /// * `path` - A [`Path`] holding a valid file path to a valid audio file.
66 ///
67 /// # Errors
68 ///
69 /// This function will return an error if the file path is invalid, if
70 /// the file path points to a file containing no or corrupted audio stream,
71 /// or if the analysis could not be conducted to the end for some reason.
72 ///
73 /// The error type returned should give a hint as to whether it was a
74 /// decoding or an analysis error.
75 #[inline]
76 fn analyze_path<P: AsRef<Path>>(&self, path: P) -> AnalysisResult<Analysis> {
77 Analysis::from_samples(&self.decode(path.as_ref())?)
78 }
79
80 /// Analyze songs in `paths` in parallel across all logical cores,
81 /// and emits the [`AnalysisResult<Analysis>`] objects (along with the [`Path`] they correspond to)
82 /// through the provided [callback channel](`mpsc::Sender`).
83 ///
84 /// This function is blocking, so it should be called in a separate thread
85 /// from where the [receiver](`mpsc::Receiver`) is consumed.
86 ///
87 /// You can cancel the job by dropping the `callback` channel's [receiver](`mpsc::Receiver`).
88 ///
89 /// see [`Decoder::analyze_path`] for more details on how the analyses are generated.
90 ///
91 /// # Example
92 ///
93 /// ```rust
94 /// use mecomp_analysis::decoder::{Decoder as _, MecmopDecoder as Decoder};
95 ///
96 /// let paths = vec![
97 /// "data/piano.wav",
98 /// "data/s32_mono_44_1_kHz.flac"
99 /// ];
100 ///
101 /// let (tx, rx) = std::mpsc::channel();
102 ///
103 /// let handle = std::thread::spawn(move || {
104 /// Decoder::new().unwrap().analyze_paths(paths, tx).unwrap();
105 /// });
106 ///
107 /// for (path, maybe_analysis) = rx {
108 /// if let Ok(analysis) = maybe_analysis {
109 /// println!("{} analyzed successfully!", path.display());
110 /// // do something with the analysis
111 /// } else {
112 /// eprintln!("error analyzing {}!", path.display());
113 /// }
114 /// }
115 /// ```
116 ///
117 /// # Errors
118 ///
119 /// Errors if the `callback` channel is closed.
120 #[inline]
121 fn analyze_paths<P: Into<PathBuf>, I: Send + IntoIterator<Item = P>>(
122 &self,
123 paths: I,
124 callback: mpsc::Sender<(PathBuf, AnalysisResult<Analysis>)>,
125 ) -> Result<(), SendError<()>>
126 where
127 Self: Sync + Send,
128 {
129 let cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
130 self.analyze_paths_with_cores(paths, cores, callback)
131 }
132
133 /// Analyze songs in `paths` in parallel across `number_cores` threads,
134 /// and emits the [`AnalysisResult<Analysis>`] objects (along with the [`Path`] they correspond to)
135 /// through the provided [callback channel](`mpsc::Sender`).
136 ///
137 /// This function is blocking, so it should be called in a separate thread
138 /// from where the [receiver](`mpsc::Receiver`) is consumed.
139 ///
140 /// You can cancel the job by dropping the `callback` channel's [receiver](`mpsc::Receiver`).
141 ///
142 /// See also: [`Decoder::analyze_paths`]
143 ///
144 /// # Errors
145 ///
146 /// Errors if the `callback` channel is closed.
147 fn analyze_paths_with_cores<P: Into<PathBuf>, I: IntoIterator<Item = P>>(
148 &self,
149 paths: I,
150 number_cores: NonZeroUsize,
151 callback: mpsc::Sender<(PathBuf, AnalysisResult<Analysis>)>,
152 ) -> Result<(), SendError<()>>
153 where
154 Self: Sync + Send,
155 {
156 let mut cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
157 if cores > number_cores {
158 cores = number_cores;
159 }
160 let paths: Vec<PathBuf> = paths.into_iter().map(Into::into).collect();
161
162 if paths.is_empty() {
163 return Ok(());
164 }
165
166 let pool = rayon::ThreadPoolBuilder::new()
167 .num_threads(cores.get())
168 .build()
169 .unwrap();
170
171 pool.install(|| {
172 paths.into_par_iter().try_for_each(|path| {
173 debug!("Analyzing file '{}'", path.display());
174 let analysis = self.analyze_path(&path);
175 callback.send((path, analysis)).map_err(|_| SendError(()))
176 })
177 })
178 }
179
180 /// Process raw audio samples in `audios`, and yield the `Analysis` and `Embedding` objects
181 /// through the provided `callback` channel.
182 /// Parallelizes the process across `number_cores` CPU cores.
183 ///
184 /// You can cancel the job by dropping the `callback` channel.
185 ///
186 /// Note: A new [`AudioEmbeddingModel`](crate::embeddings::AudioEmbeddingModel) session will be created
187 /// for each thread.
188 ///
189 /// # Errors
190 ///
191 /// Errors if the `callback` channel is closed.
192 #[inline]
193 fn process_songs_with_cores(
194 &self,
195 paths: &[PathBuf],
196 callback: ProcessingCallback,
197 number_cores: NonZeroUsize,
198 model_config: crate::embeddings::ModelConfig,
199 ) -> AnalysisResult<()>
200 where
201 Self: Sync + Send,
202 {
203 let mut cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
204 if cores > number_cores {
205 cores = number_cores;
206 }
207
208 if paths.is_empty() {
209 return Ok(());
210 }
211
212 thread_local! {
213 static MODEL: RefCell<Option<AudioEmbeddingModel>> = const { RefCell::new(None) };
214 }
215
216 let pool = rayon::ThreadPoolBuilder::new()
217 .num_threads(cores.get())
218 .thread_name(|idx| format!("Analyzer {idx}"))
219 .exit_handler(|thread_id| {
220 // Clean up thread-local model to free memory
221 debug!("Cleaning up model in thread Analyzer {thread_id}");
222 let _ = MODEL.take();
223 })
224 .build()
225 .unwrap();
226
227 pool.install(|| {
228 // Process songs in parallel, but each song is fully processed (decode -> analyze -> embed -> send)
229 // before the thread moves to the next song. This prevents memory accumulation from
230 // buffering decoded audio across multiple pipeline stages.
231 paths.into_par_iter().try_for_each(|path| {
232 let thread_name = thread::current().name().unwrap_or("unknown").to_string();
233
234 // Decode the audio file
235 let audio = match self.decode(path) {
236 Ok(audio) => {
237 trace!("Decoded {} in thread {thread_name}", path.display());
238 audio
239 }
240 Err(e) => {
241 error!("Error decoding {}: {e}", path.display());
242 return Ok(()); // Skip this file, continue with others
243 }
244 };
245
246 let (analysis, embedding) = pool.join(
247 || {
248 // Analyze the audio
249 let analysis = Analysis::from_samples(&audio);
250 trace!("Analyzed {} in thread {thread_name}", path.display());
251 analysis
252 },
253 || {
254 // Load or get the model for this thread, then generate the embedding
255 let embedding = MODEL.with(|model_cell| {
256 let mut model_ref = model_cell.borrow_mut();
257 if model_ref.is_none() {
258 debug!("Loading model in thread {thread_name}");
259 *model_ref = Some(AudioEmbeddingModel::load(&model_config)?);
260 }
261 let model = model_ref.as_mut().unwrap();
262 model.embed(&audio).map_err(AnalysisError::from)
263 });
264 trace!(
265 "Generated embeddings for {} in thread {thread_name}",
266 path.display()
267 );
268 embedding
269 },
270 );
271
272 // Drop the audio samples before sending to free memory immediately
273 drop(audio);
274
275 // Send the results - the bounded channel will apply backpressure
276 // if the consumer is slow, preventing unbounded memory growth
277 callback
278 .send((path.clone(), analysis, embedding))
279 .map_err(|_| AnalysisError::SendError)
280 })
281 })
282 }
283
284 /// Process raw audio samples in `audios`, and yield the `Analysis` and `Embedding` objects
285 /// through the provided `callback` channel.
286 /// Parallelizes the process across all available CPU cores.
287 ///
288 /// You can cancel the job by dropping the `callback` channel.
289 ///
290 /// Note: A new [`AudioEmbeddingModel`](crate::embeddings::AudioEmbeddingModel) session will be created
291 /// for each thread.
292 ///
293 /// # Errors
294 /// Errors if the `callback` channel is closed.
295 #[inline]
296 fn process_songs(
297 &self,
298 paths: &[PathBuf],
299 callback: ProcessingCallback,
300 model_config: crate::embeddings::ModelConfig,
301 ) -> AnalysisResult<()>
302 where
303 Self: Sync + Send,
304 {
305 let cores = thread::available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap());
306 self.process_songs_with_cores(paths, callback, cores, model_config)
307 }
308}