mecomp_analysis/embeddings.rs

//! Module for creating vector embeddings of audio data using, in principle, any ONNX model,
//! but specifically designed for use with the model in `models/audio_embedding_model.onnx`.

use crate::ResampledAudio;
use log::warn;
use ort::{
    execution_providers::CPUExecutionProvider,
    session::{Session, builder::GraphOptimizationLevel},
    value::TensorRef,
};
use std::path::{Path, PathBuf};

// Conditionally import execution providers based on enabled features/platform
#[cfg(feature = "cuda")]
use ort::execution_providers::CUDAExecutionProvider;
#[cfg(target_os = "macos")]
use ort::execution_providers::CoreMLExecutionProvider;
#[cfg(target_os = "windows")]
use ort::execution_providers::DirectMLExecutionProvider;
#[cfg(feature = "tensorrt")]
use ort::execution_providers::TensorRTExecutionProvider;

static MODEL_BYTES: &[u8] = include_bytes!("../models/audio_embedding_model.onnx");

/// The size of the embedding produced by the audio embedding model, as an `i64` (ONNX tensor shapes use `i64` dimensions).
const EMBEDDING_SIZE: i64 = 32;
/// The size of the embedding produced by the audio embedding model, as a `usize`.
pub const DIM_EMBEDDING: usize = 32;

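/// A fixed-size vector embedding of an audio track.
///
/// A minimal usage sketch; the crate/module path is assumed from this file's location:
///
/// ```
/// use mecomp_analysis::embeddings::{Embedding, DIM_EMBEDDING};
///
/// // Construct an embedding directly from an array (all-zero values here, for illustration).
/// let embedding = Embedding([0.0; DIM_EMBEDDING]);
/// assert_eq!(embedding.len(), DIM_EMBEDDING);
/// assert!(!embedding.is_empty());
/// ```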
#[derive(Debug, Default, PartialEq, Clone, Copy)]
#[repr(transparent)]
pub struct Embedding(pub [f32; DIM_EMBEDDING]);

impl Embedding {
    /// Get the length of the embedding vector.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        self.0.len()
    }

    /// Check if the embedding is empty.
    ///
    /// Always returns `false`, since embeddings have a fixed, non-zero size.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Get a reference to the embedding as a slice.
    #[inline]
    #[must_use]
    pub const fn as_slice(&self) -> &[f32] {
        &self.0
    }

    /// Get a mutable reference to the embedding as a slice.
    #[inline]
    #[must_use]
    pub const fn as_mut_slice(&mut self) -> &mut [f32] {
        &mut self.0
    }

    /// Get a reference to the inner array of the embedding.
    #[inline]
    #[must_use]
    pub const fn inner(&self) -> &[f32; DIM_EMBEDDING] {
        &self.0
    }
}

/// Configuration for loading an [`AudioEmbeddingModel`].
#[derive(Debug, Clone, Default)]
pub struct ModelConfig {
    /// Optional path to a custom ONNX model; if `None` (or if loading from it fails), the bundled default model is used.
    pub path: Option<PathBuf>,
}

/// Struct representing an audio embedding model loaded from an ONNX file.
#[derive(Debug)]
pub struct AudioEmbeddingModel {
    session: ort::session::Session,
}

/// Build a list of execution providers in priority order based on available features/platform.
/// Providers are tried in order; if one fails to register, the next is attempted.
/// CPU is always the final fallback.
#[allow(clippy::vec_init_then_push)]
fn build_execution_providers() -> Vec<ort::execution_providers::ExecutionProviderDispatch> {
    let mut providers = Vec::new();

    // GPU providers (feature-gated, require user to have appropriate drivers)
    #[cfg(feature = "tensorrt")]
    {
        providers.push(TensorRTExecutionProvider::default().build());
        log::info!("TensorRT execution provider enabled");
    }

    #[cfg(feature = "cuda")]
    {
        providers.push(CUDAExecutionProvider::default().build());
        log::info!("CUDA execution provider enabled");
    }

    // Platform-specific zero-dependency providers
    #[cfg(target_os = "windows")]
    {
        providers.push(DirectMLExecutionProvider::default().build());
        log::info!("DirectML execution provider enabled (Windows)");
    }

    #[cfg(target_os = "macos")]
    {
        providers.push(
            CoreMLExecutionProvider::default()
                .with_subgraphs(true) // Enable CoreML on subgraphs for better coverage
                .build(),
        );
        log::info!("CoreML execution provider enabled (macOS)");
    }

    // CPU is always the final fallback.
    // Disable the CPU memory arena to prevent memory accumulation when processing
    // variable-length audio inputs. The arena allocator grows to accommodate the
    // largest input seen but never shrinks, which causes memory to accumulate
    // when processing many songs of varying lengths.
    providers.push(
        CPUExecutionProvider::default()
            .with_arena_allocator(false)
            .build(),
    );

    providers
}

fn session_builder() -> ort::Result<ort::session::builder::SessionBuilder> {
    let providers = build_execution_providers();

    let builder = Session::builder()?
        .with_execution_providers(providers)?
        .with_memory_pattern(false)?
        .with_optimization_level(GraphOptimizationLevel::Level3)?;

    Ok(builder)
}

impl AudioEmbeddingModel {
    /// Load the default audio embedding model included in the package.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded for some reason.
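    ///
    /// # Example
    ///
    /// A minimal sketch; the crate/module path is assumed from this file's location:
    ///
    /// ```no_run
    /// use mecomp_analysis::embeddings::AudioEmbeddingModel;
    ///
    /// // Loads the model bundled into the binary via `include_bytes!`.
    /// let model = AudioEmbeddingModel::load_default().expect("failed to load bundled model");
    /// ```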
    #[inline]
    pub fn load_default() -> ort::Result<Self> {
        let session = session_builder()?.commit_from_memory(MODEL_BYTES)?;

        Ok(Self { session })
    }

    /// Load an audio embedding model from the specified ONNX file path.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded for some reason.
    #[inline]
    pub fn load_from_onnx<P: AsRef<Path>>(path: P) -> ort::Result<Self> {
        let session = session_builder()?.commit_from_file(&path)?;

        Ok(Self { session })
    }

    /// Load an audio embedding model with the specified configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - The configuration for how the model should be loaded.
    ///
    /// # Errors
    /// Fails if the model cannot be loaded for some reason.
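    ///
    /// # Example
    ///
    /// A minimal sketch; the path below is hypothetical, and if the file cannot
    /// be loaded this falls back to the bundled default model:
    ///
    /// ```no_run
    /// use mecomp_analysis::embeddings::{AudioEmbeddingModel, ModelConfig};
    ///
    /// let config = ModelConfig {
    ///     // Hypothetical path; on failure, `load` falls back to the default model.
    ///     path: Some("custom_model.onnx".into()),
    /// };
    /// let model = AudioEmbeddingModel::load(&config).expect("failed to load model");
    /// ```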
    #[inline]
    pub fn load(config: &ModelConfig) -> ort::Result<Self> {
        config.path.as_ref().map_or_else(Self::load_default, |path| {
            Self::load_from_onnx(path).or_else(|e| {
                warn!("failed to load embeddings model from specified path: {e}, falling back to default model.");
                Self::load_default()
            })
        })
    }

    /// Compute an embedding from raw audio samples (f32, mono, 22050 Hz).
    /// This call blocks while inference runs.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// * the audio cannot be converted into a tensor,
    /// * the model inference fails,
    /// * the output is missing or has an unexpected shape (it should be named "embedding" and have shape `[1, 32]`).
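    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `audio` is a `ResampledAudio` produced by
    /// this crate's decoder:
    ///
    /// ```no_run
    /// # use mecomp_analysis::{ResampledAudio, embeddings::AudioEmbeddingModel};
    /// # fn demo(audio: &ResampledAudio) -> ort::Result<()> {
    /// let mut model = AudioEmbeddingModel::load_default()?;
    /// let embedding = model.embed(audio)?;
    /// assert_eq!(embedding.len(), 32);
    /// # Ok(())
    /// # }
    /// ```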
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn embed(&mut self, audio: &ResampledAudio) -> ort::Result<Embedding> {
        // Create input with batch dimension
        let inputs = ort::inputs! {
            "audio" => TensorRef::from_array_view(([1, audio.samples.len()], audio.samples.as_slice()))?,
        };

        // Run inference
        let outputs = self.session.run(inputs)?;

        // Extract embedding
        let (shape, embedding) = outputs["embedding"].try_extract_tensor::<f32>()?;

        let expected_shape = &[1, EMBEDDING_SIZE];
        if shape.iter().as_slice() != expected_shape {
            return Err(ort::Error::new(format!(
                "Unexpected embedding shape: {shape:?}, expected {expected_shape:?}",
            )));
        }

        let sized_embedding: [f32; DIM_EMBEDDING] = embedding
            .try_into()
            .map_err(|_| ort::Error::new("Failed to convert embedding to fixed-size array"))?;

        Ok(Embedding(sized_embedding))
    }

    /// Compute embeddings for a batch of raw audio samples (f32, mono, 22050 Hz).
    /// This call blocks while inference runs.
    ///
    /// Inputs are zero-padded to the length of the longest sample in the batch,
    /// so for efficiency all audio samples should be similar in length.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// * the audio cannot be converted into a tensor,
    /// * the model inference fails,
    /// * the output is missing or has an unexpected shape (it should be named "embedding" and have shape `[batch_size, 32]`).
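    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `audios` is a slice of decoded
    /// `ResampledAudio` values:
    ///
    /// ```no_run
    /// # use mecomp_analysis::{ResampledAudio, embeddings::AudioEmbeddingModel};
    /// # fn demo(audios: &[ResampledAudio]) -> ort::Result<()> {
    /// let mut model = AudioEmbeddingModel::load_default()?;
    /// // One embedding is returned per input, in order.
    /// let embeddings = model.embed_batch(audios)?;
    /// assert_eq!(embeddings.len(), audios.len());
    /// # Ok(())
    /// # }
    /// ```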
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn embed_batch(&mut self, audios: &[ResampledAudio]) -> ort::Result<Vec<Embedding>> {
        let max_len = audios.iter().map(|a| a.samples.len()).max().unwrap_or(0);

        let batch_size = audios.len();

        // Prepare input tensor with zero-padding
        let mut input_data = vec![0f32; batch_size * max_len];
        for (i, audio) in audios.iter().enumerate() {
            input_data[i * max_len..i * max_len + audio.samples.len()]
                .copy_from_slice(&audio.samples);
        }

        let input = ort::inputs! {
            "audio" => TensorRef::from_array_view(([batch_size, max_len], &*input_data))?,
        };

        // Run inference
        let outputs = self.session.run(input)?;

        // Extract embeddings
        let (shape, embedding_tensor) = outputs["embedding"].try_extract_tensor::<f32>()?;
        #[allow(clippy::cast_possible_wrap)]
        let expected_shape = &[batch_size as i64, EMBEDDING_SIZE];
        if shape.iter().as_slice() != expected_shape {
            return Err(ort::Error::new(format!(
                "Unexpected embedding shape: {shape:?}, expected {expected_shape:?}",
            )));
        }

        let mut embeddings = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let start = i * DIM_EMBEDDING;
            let end = start + DIM_EMBEDDING;
            let sized_embedding: [f32; DIM_EMBEDDING] = embedding_tensor[start..end]
                .try_into()
                .map_err(|_| ort::Error::new("Failed to convert embedding to fixed-size array"))?;
            embeddings.push(Embedding(sized_embedding));
        }

        Ok(embeddings)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::Decoder;
    use crate::decoder::MecompDecoder;

    const TEST_AUDIO_PATH: &str = "data/5_mins_of_noise_stereo_48kHz.ogg";

    #[test]
    fn test_embedding_model() {
        let decoder = MecompDecoder::new().unwrap();
        let audio = decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio");

        let mut model =
            AudioEmbeddingModel::load_default().expect("Failed to load embedding model");
        let embedding = model.embed(&audio).expect("Failed to compute embedding");
        assert_eq!(embedding.len(), DIM_EMBEDDING);
    }

    #[test]
    fn test_embedding_model_batch() {
        let decoder = MecompDecoder::new().unwrap();
        let audio = decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio");

        let audios = vec![audio.clone(); 4];

        let mut model =
            AudioEmbeddingModel::load_default().expect("Failed to load embedding model");
        let embeddings = model
            .embed_batch(&audios)
            .expect("Failed to compute batch embeddings");
        assert_eq!(embeddings.len(), 4);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), DIM_EMBEDDING);
        }

        // since all the audios are the same, all embeddings should be the same
        for embedding in &embeddings[1..] {
            assert_eq!(embedding, &embeddings[0]);
        }
    }

    #[test]
    fn test_embedding_model_batch_mixed_sizes() {
        let decoder = MecompDecoder::new().unwrap();
        let audio1 = decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio");

        // create a shorter audio by taking only the first half of the samples
        let audio2 = ResampledAudio {
            samples: audio1.samples[..audio1.samples.len() / 2].to_vec(),
            path: audio1.path.clone(),
        };

        let audios = vec![audio1.clone(), audio2.clone()];

        let mut model =
            AudioEmbeddingModel::load_default().expect("Failed to load embedding model");
        let embeddings = model
            .embed_batch(&audios)
            .expect("Failed to compute batch embeddings");
        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), DIM_EMBEDDING);
        }
    }
}