mecomp_analysis/embeddings.rs

//! Module for creating vector embeddings of audio data using, in theory, any ONNX model,
//! but specifically designed for use with the model in `models/audio_embedding_model.onnx`.

use crate::ResampledAudio;
use log::warn;
use ort::{
    execution_providers::CPUExecutionProvider,
    session::{Session, builder::GraphOptimizationLevel},
    value::TensorRef,
};
use std::path::{Path, PathBuf};

// Conditionally import execution providers based on enabled features/platform
#[cfg(feature = "cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(target_os = "macos")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(target_os = "windows")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;

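/// The default audio embedding model, bundled into the binary at compile time.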
static MODEL_BYTES: &[u8] = include_bytes!("../models/audio_embedding_model.onnx");

/// The size of the embedding produced by the audio embedding model; ONNX expects
/// tensor dimensions as `i64`.
const EMBEDDING_SIZE: i64 = 32;
/// The size of the embedding produced by the audio embedding model, as a `usize`.
pub const DIM_EMBEDDING: usize = 32;

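/// A fixed-size vector embedding of an audio track, as produced by
/// [`AudioEmbeddingModel`].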
#[derive(Debug, Default, PartialEq, Clone, Copy)]
#[repr(transparent)]
pub struct Embedding(pub [f32; DIM_EMBEDDING]);

impl Embedding {
    /// Get the length of the embedding vector.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        self.0.len()
    }

    /// Check if the embedding is empty.
    ///
    /// Should always return false since embeddings have a fixed size.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Get a reference to the embedding as a slice.
    #[inline]
    #[must_use]
    pub const fn as_slice(&self) -> &[f32] {
        &self.0
    }

    /// Get a mutable reference to the embedding as a slice.
    #[inline]
    #[must_use]
    pub const fn as_mut_slice(&mut self) -> &mut [f32] {
        &mut self.0
    }

    /// Get the inner array of the embedding.
    #[inline]
    #[must_use]
    pub const fn inner(&self) -> &[f32; DIM_EMBEDDING] {
        &self.0
    }
}

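/// Configuration for loading an [`AudioEmbeddingModel`].
///
/// If `path` is `Some`, [`AudioEmbeddingModel::load`] tries that ONNX file first
/// and falls back to the bundled default model on failure; if `None`, the default
/// model is used directly.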
#[derive(Debug, Clone, Default)]
pub struct ModelConfig {
    pub path: Option<PathBuf>,
}

/// Struct representing an audio embedding model loaded from an ONNX file.
#[derive(Debug)]
pub struct AudioEmbeddingModel {
    session: ort::session::Session,
}

/// Build a list of execution providers in priority order based on available features/platform.
/// Providers are tried in order; if one fails to register, the next is attempted.
/// CPU is always the final fallback.
#[allow(clippy::vec_init_then_push)]
fn build_execution_providers() -> Vec<ort::execution_providers::ExecutionProviderDispatch> {
    let mut providers = Vec::new();

    // GPU providers (feature-gated, require user to have appropriate drivers)
    #[cfg(feature = "tensorrt")]
    {
        providers.push(TensorRTExecutionProvider::default().build());
        log::info!("TensorRT execution provider enabled");
    }

    #[cfg(feature = "cuda")]
    {
        providers.push(CUDAExecutionProvider::default().build());
        log::info!("CUDA execution provider enabled");
    }

    // Platform-specific zero-dependency providers
    #[cfg(target_os = "windows")]
    {
        providers.push(DirectMLExecutionProvider::default().build());
        log::info!("DirectML execution provider enabled (Windows)");
    }

    #[cfg(target_os = "macos")]
    {
        providers.push(
            CoreMLExecutionProvider::default()
                .with_subgraphs(true) // Enable CoreML on subgraphs for better coverage
                .build(),
        );
        log::info!("CoreML execution provider enabled (macOS)");
    }

    // CPU is always the final fallback
    // Disable the CPU memory arena to prevent memory accumulation when processing
    // variable-length audio inputs. The arena allocator grows to accommodate the
    // largest input seen but never shrinks, which causes memory to accumulate
    // when processing many songs of varying lengths.
    providers.push(
        CPUExecutionProvider::default()
            .with_arena_allocator(false)
            .build(),
    );

    providers
}

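/// Create a pre-configured session builder: registers the execution providers from
/// [`build_execution_providers`], disables memory patterns (input lengths vary
/// between runs), and enables full graph optimization.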
fn session_builder() -> ort::Result<ort::session::builder::SessionBuilder> {
    let providers = build_execution_providers();

    let builder = Session::builder()?
        .with_execution_providers(providers)?
        .with_memory_pattern(false)?
        .with_optimization_level(GraphOptimizationLevel::Level3)?;

    Ok(builder)
}

impl AudioEmbeddingModel {
    /// Load the default audio embedding model included in the package.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded for some reason.
    #[inline]
    pub fn load_default() -> ort::Result<Self> {
        let session = session_builder()?.commit_from_memory(MODEL_BYTES)?;

        Ok(Self { session })
    }

    /// Load an audio embedding model from the specified ONNX file path.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded for some reason.
    #[inline]
    pub fn load_from_onnx<P: AsRef<Path>>(path: P) -> ort::Result<Self> {
        let session = session_builder()?.commit_from_file(&path)?;

        Ok(Self { session })
    }

    /// Load an audio embedding model with the specified configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - The configuration for how the model should be loaded.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded for some reason.
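    ///
    /// # Examples
    ///
    /// An illustrative sketch (not from the original source); the `custom.onnx`
    /// path is hypothetical:
    ///
    /// ```ignore
    /// use mecomp_analysis::embeddings::{AudioEmbeddingModel, ModelConfig};
    /// use std::path::PathBuf;
    ///
    /// // No path set: loads the bundled default model.
    /// let model = AudioEmbeddingModel::load(&ModelConfig::default())?;
    ///
    /// // With a path: tries that file first, falling back to the default model.
    /// let config = ModelConfig { path: Some(PathBuf::from("custom.onnx")) };
    /// let model = AudioEmbeddingModel::load(&config)?;
    /// # Ok::<(), ort::Error>(())
    /// ```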
    #[inline]
    pub fn load(config: &ModelConfig) -> ort::Result<Self> {
        config.path.as_ref().map_or_else(Self::load_default, |path| {
            Self::load_from_onnx(path).or_else(|e| {
                warn!("failed to load embeddings model from specified path: {e}, falling back to default model.");
                Self::load_default()
            })
        })
    }

    /// Compute an embedding from raw audio samples (f32, mono, 22050 Hz).
    /// Blocks during execution.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// * the audio cannot be converted into a tensor,
    /// * the model inference fails,
    /// * the output is missing or has an unexpected shape (should be named "embedding" and have shape `[1, 32]`).
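    ///
    /// # Examples
    ///
    /// A minimal sketch (not from the original source), assuming a `ResampledAudio`
    /// obtained from this crate's decoder and a hypothetical `song.ogg` path:
    ///
    /// ```ignore
    /// use mecomp_analysis::decoder::{Decoder, MecompDecoder};
    /// use mecomp_analysis::embeddings::{AudioEmbeddingModel, DIM_EMBEDDING};
    /// use std::path::Path;
    ///
    /// let audio = MecompDecoder::new()?.decode(Path::new("song.ogg"))?;
    /// let mut model = AudioEmbeddingModel::load_default()?;
    /// let embedding = model.embed(&audio)?;
    /// assert_eq!(embedding.len(), DIM_EMBEDDING);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```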
    #[inline]
    pub fn embed(&mut self, audio: &ResampledAudio) -> ort::Result<Embedding> {
        // Create input with batch dimension
        let inputs = ort::inputs! {
            "audio" => TensorRef::from_array_view(([1, audio.samples.len()], audio.samples.as_slice()))?,
        };

        // Run inference
        let outputs = self.session.run(inputs)?;

        // Extract embedding
        let (shape, embedding) = outputs["embedding"].try_extract_tensor::<f32>()?;

        let expected_shape = &[1, EMBEDDING_SIZE];
        if shape.iter().as_slice() != expected_shape {
            return Err(ort::Error::new(format!(
                "Unexpected embedding shape: {shape:?}, expected {expected_shape:?}",
            )));
        }

        let sized_embedding: [f32; DIM_EMBEDDING] = embedding
            .try_into()
            .map_err(|_| ort::Error::new("Failed to convert embedding to fixed-size array"))?;

        Ok(Embedding(sized_embedding))
    }

    /// Compute embeddings for a batch of raw audio samples (f32, mono, 22050 Hz).
    /// Blocks during execution.
    ///
    /// For efficiency, all audio samples should be similar in length, since shorter
    /// inputs are zero-padded to the length of the longest in the batch.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// * the audio cannot be converted into a tensor,
    /// * the model inference fails,
    /// * the output is missing or has an unexpected shape (should be named "embedding" and have shape `[batch_size, 32]`).
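    ///
    /// # Examples
    ///
    /// A minimal sketch (not from the original source), assuming `ResampledAudio`
    /// values obtained from this crate's decoder; the paths are hypothetical:
    ///
    /// ```ignore
    /// use mecomp_analysis::decoder::{Decoder, MecompDecoder};
    /// use mecomp_analysis::embeddings::AudioEmbeddingModel;
    /// use std::path::Path;
    ///
    /// // Batches are most efficient when the inputs are similar in length.
    /// let decoder = MecompDecoder::new()?;
    /// let audios = vec![
    ///     decoder.decode(Path::new("a.ogg"))?,
    ///     decoder.decode(Path::new("b.ogg"))?,
    /// ];
    ///
    /// let mut model = AudioEmbeddingModel::load_default()?;
    /// let embeddings = model.embed_batch(&audios)?;
    /// assert_eq!(embeddings.len(), audios.len());
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```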
    #[inline]
    pub fn embed_batch(&mut self, audios: &[ResampledAudio]) -> ort::Result<Vec<Embedding>> {
        let max_len = audios.iter().map(|a| a.samples.len()).max().unwrap_or(0);

        let batch_size = audios.len();

        // Prepare input tensor with zero-padding
        let mut input_data = vec![0f32; batch_size * max_len];
        for (i, audio) in audios.iter().enumerate() {
            input_data[i * max_len..i * max_len + audio.samples.len()]
                .copy_from_slice(&audio.samples);
        }

        let input = ort::inputs! {
            "audio" => TensorRef::from_array_view(([batch_size, max_len], &*input_data))?,
        };

        // Run inference
        let outputs = self.session.run(input)?;

        // Extract embeddings
        let (shape, embedding_tensor) = outputs["embedding"].try_extract_tensor::<f32>()?;
        #[allow(clippy::cast_possible_wrap)]
        let expected_shape = &[batch_size as i64, EMBEDDING_SIZE];
        if shape.iter().as_slice() != expected_shape {
            return Err(ort::Error::new(format!(
                "Unexpected embedding shape: {shape:?}, expected {expected_shape:?}",
            )));
        }

        let mut embeddings = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let start = i * DIM_EMBEDDING;
            let end = start + DIM_EMBEDDING;
            let sized_embedding: [f32; DIM_EMBEDDING] = embedding_tensor[start..end]
                .try_into()
                .map_err(|_| ort::Error::new("Failed to convert embedding to fixed-size array"))?;
            embeddings.push(Embedding(sized_embedding));
        }

        Ok(embeddings)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::Decoder;
    use crate::decoder::MecompDecoder;

    const TEST_AUDIO_PATH: &str = "data/5_mins_of_noise_stereo_48kHz.ogg";

    #[test]
    fn test_embedding_model() {
        let decoder = MecompDecoder::new().unwrap();
        let audio = decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio");

        let mut model =
            AudioEmbeddingModel::load_default().expect("Failed to load embedding model");
        let embedding = model.embed(&audio).expect("Failed to compute embedding");
        assert_eq!(embedding.len(), DIM_EMBEDDING);
    }

    #[test]
    fn test_embedding_model_batch() {
        let decoder = MecompDecoder::new().unwrap();
        let audio = decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio");

        let audios = vec![audio.clone(); 4];

        let mut model =
            AudioEmbeddingModel::load_default().expect("Failed to load embedding model");
        let embeddings = model
            .embed_batch(&audios)
            .expect("Failed to compute batch embeddings");
        assert_eq!(embeddings.len(), 4);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), DIM_EMBEDDING);
        }

        // since all the audios are the same, all embeddings should be the same
        for embedding in &embeddings[1..] {
            assert_eq!(embedding, &embeddings[0]);
        }
    }

    #[test]
    fn test_embedding_model_batch_mixed_sizes() {
        let decoder = MecompDecoder::new().unwrap();
        let audio1 = decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio");

        // create a shorter audio by taking only the first half of the samples
        let audio2 = ResampledAudio {
            samples: audio1.samples[..audio1.samples.len() / 2].to_vec(),
            path: audio1.path.clone(),
        };

        let audios = vec![audio1.clone(), audio2.clone()];

        let mut model =
            AudioEmbeddingModel::load_default().expect("Failed to load embedding model");
        let embeddings = model
            .embed_batch(&audios)
            .expect("Failed to compute batch embeddings");
        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), DIM_EMBEDDING);
        }
    }
}