mecomp_analysis/
embeddings.rs1use crate::ResampledAudio;
5use log::warn;
6use ort::{
7 execution_providers::CPUExecutionProvider,
8 session::{Session, builder::GraphOptimizationLevel},
9 value::TensorRef,
10};
11use std::path::{Path, PathBuf};
12
13#[cfg(feature = "cuda")]
15use ort::execution_providers::CUDAExecutionProvider;
16#[cfg(target_os = "macos")]
17use ort::execution_providers::CoreMLExecutionProvider;
18#[cfg(target_os = "windows")]
19use ort::execution_providers::DirectMLExecutionProvider;
20#[cfg(feature = "tensorrt")]
21use ort::execution_providers::TensorRTExecutionProvider;
22
23static MODEL_BYTES: &[u8] = include_bytes!("../models/audio_embedding_model.onnx");
24
/// Number of `f32` components in one audio embedding.
pub const DIM_EMBEDDING: usize = 32;

/// [`DIM_EMBEDDING`] as an `i64`, the element type used when comparing against
/// the shape of the model's output tensor.
///
/// Derived from [`DIM_EMBEDDING`] (instead of repeating the literal) so the two
/// constants can never drift apart. The value is tiny, so the cast cannot wrap.
#[allow(clippy::cast_possible_wrap)]
const EMBEDDING_SIZE: i64 = DIM_EMBEDDING as i64;
29
30#[derive(Debug, Default, PartialEq, Clone, Copy)]
31#[repr(transparent)]
32pub struct Embedding(pub [f32; DIM_EMBEDDING]);
33
34impl Embedding {
35 #[inline]
37 #[must_use]
38 pub const fn len(&self) -> usize {
39 self.0.len()
40 }
41
42 #[inline]
46 #[must_use]
47 pub const fn is_empty(&self) -> bool {
48 self.0.is_empty()
49 }
50
51 #[inline]
53 #[must_use]
54 pub const fn as_slice(&self) -> &[f32] {
55 &self.0
56 }
57
58 #[inline]
60 #[must_use]
61 pub const fn as_mut_slice(&mut self) -> &mut [f32] {
62 &mut self.0
63 }
64
65 #[inline]
67 #[must_use]
68 pub const fn inner(&self) -> &[f32; DIM_EMBEDDING] {
69 &self.0
70 }
71}
72
/// Settings consumed by [`AudioEmbeddingModel::load`].
#[derive(Clone, Debug, Default)]
pub struct ModelConfig {
    /// Optional path to an ONNX model file.
    ///
    /// `None` (the default) selects the bundled model. If a path is given but
    /// loading it fails, `load` logs a warning and falls back to the bundled model.
    pub path: Option<PathBuf>,
}
77
78#[derive(Debug)]
80pub struct AudioEmbeddingModel {
81 session: ort::session::Session,
82}
83
84#[allow(clippy::vec_init_then_push)]
88fn build_execution_providers() -> Vec<ort::execution_providers::ExecutionProviderDispatch> {
89 let mut providers = Vec::new();
90
91 #[cfg(feature = "tensorrt")]
93 {
94 providers.push(TensorRTExecutionProvider::default().build());
95 log::info!("TensorRT execution provider enabled");
96 }
97
98 #[cfg(feature = "cuda")]
99 {
100 providers.push(CUDAExecutionProvider::default().build());
101 log::info!("CUDA execution provider enabled");
102 }
103
104 #[cfg(target_os = "windows")]
106 {
107 providers.push(DirectMLExecutionProvider::default().build());
108 log::info!("DirectML execution provider enabled (Windows)");
109 }
110
111 #[cfg(target_os = "macos")]
112 {
113 providers.push(
114 CoreMLExecutionProvider::default()
115 .with_subgraphs(true) .build(),
117 );
118 log::info!("CoreML execution provider enabled (macOS)");
119 }
120
121 providers.push(
127 CPUExecutionProvider::default()
128 .with_arena_allocator(false)
129 .build(),
130 );
131
132 providers
133}
134
135fn session_builder() -> ort::Result<ort::session::builder::SessionBuilder> {
136 let providers = build_execution_providers();
137
138 let builder = Session::builder()?
139 .with_execution_providers(providers)?
140 .with_memory_pattern(false)?
141 .with_optimization_level(GraphOptimizationLevel::Level3)?;
142
143 Ok(builder)
144}
145
146impl AudioEmbeddingModel {
147 #[inline]
152 pub fn load_default() -> ort::Result<Self> {
153 let session = session_builder()?.commit_from_memory(MODEL_BYTES)?;
154
155 Ok(Self { session })
156 }
157
158 #[inline]
163 pub fn load_from_onnx<P: AsRef<Path>>(path: P) -> ort::Result<Self> {
164 let session = session_builder()?.commit_from_file(&path)?;
165
166 Ok(Self { session })
167 }
168
169 #[inline]
177 pub fn load(config: &ModelConfig) -> ort::Result<Self> {
178 config.path.as_ref().map_or_else(Self::load_default, |path|{
179 Self::load_from_onnx(path).or_else(|e| {
180 warn!("failed to load embeddings model from specified path: {e}, falling back to default model.");
181 Self::load_default()
182 })
183 })
184 }
185
186 #[allow(clippy::missing_inline_in_public_items)]
196 pub fn embed(&mut self, audio: &ResampledAudio) -> ort::Result<Embedding> {
197 let inputs = ort::inputs! {
199 "audio" => TensorRef::from_array_view(([1, audio.samples.len()], audio.samples.as_slice()))?,
200 };
201
202 let outputs = self.session.run(inputs)?;
204
205 let (shape, embedding) = outputs["embedding"].try_extract_tensor::<f32>()?;
207
208 let expected_shape = &[1, EMBEDDING_SIZE];
209 if shape.iter().as_slice() != expected_shape {
210 return Err(ort::Error::new(format!(
211 "Unexpected embedding shape: {shape:?}, expected {expected_shape:?}",
212 )));
213 }
214
215 let sized_embedding: [f32; DIM_EMBEDDING] = embedding
216 .try_into()
217 .map_err(|_| ort::Error::new("Failed to convert embedding to fixed-size array"))?;
218
219 Ok(Embedding(sized_embedding))
220 }
221
222 #[allow(clippy::missing_inline_in_public_items)]
234 pub fn embed_batch(&mut self, audios: &[ResampledAudio]) -> ort::Result<Vec<Embedding>> {
235 let max_len = audios.iter().map(|a| a.samples.len()).max().unwrap_or(0);
236
237 let batch_size = audios.len();
238
239 let mut input_data = vec![0f32; batch_size * max_len];
241 for (i, audio) in audios.iter().enumerate() {
242 input_data[i * max_len..i * max_len + audio.samples.len()]
243 .copy_from_slice(&audio.samples);
244 }
245
246 let input = ort::inputs! {
247 "audio" => TensorRef::from_array_view(([batch_size, max_len], &*input_data))?,
248 };
249
250 let outputs = self.session.run(input)?;
252
253 let (shape, embedding_tensor) = outputs["embedding"].try_extract_tensor::<f32>()?;
255 #[allow(clippy::cast_possible_wrap)]
256 let expected_shape = &[batch_size as i64, EMBEDDING_SIZE];
257 if shape.iter().as_slice() != expected_shape {
258 return Err(ort::Error::new(format!(
259 "Unexpected embedding shape: {shape:?}, expected {expected_shape:?}",
260 )));
261 }
262
263 let mut embeddings = Vec::with_capacity(batch_size);
264 for i in 0..batch_size {
265 let start = i * DIM_EMBEDDING;
266 let end = start + DIM_EMBEDDING;
267 let sized_embedding: [f32; DIM_EMBEDDING] = embedding_tensor[start..end]
268 .try_into()
269 .map_err(|_| ort::Error::new("Failed to convert embedding to fixed-size array"))?;
270 embeddings.push(Embedding(sized_embedding));
271 }
272
273 Ok(embeddings)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::{Decoder, MecompDecoder};

    const TEST_AUDIO_PATH: &str = "data/5_mins_of_noise_stereo_48kHz.ogg";

    /// Decodes the shared noise fixture with a fresh decoder.
    fn decode_test_audio() -> ResampledAudio {
        let decoder = MecompDecoder::new().unwrap();
        decoder
            .decode(Path::new(TEST_AUDIO_PATH))
            .expect("Failed to decode test audio")
    }

    /// Loads the bundled embedding model.
    fn load_model() -> AudioEmbeddingModel {
        AudioEmbeddingModel::load_default().expect("Failed to load embedding model")
    }

    #[test]
    fn test_embedding_model() {
        let audio = decode_test_audio();
        let mut model = load_model();

        let embedding = model.embed(&audio).expect("Failed to compute embedding");
        assert_eq!(embedding.len(), DIM_EMBEDDING);
    }

    #[test]
    fn test_embedding_model_batch() {
        let audios = vec![decode_test_audio(); 4];
        let mut model = load_model();

        let embeddings = model
            .embed_batch(&audios)
            .expect("Failed to compute batch embeddings");
        assert_eq!(embeddings.len(), 4);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), DIM_EMBEDDING);
        }

        // Identical inputs in one batch must produce identical rows.
        for embedding in embeddings.iter().skip(1) {
            assert_eq!(embedding, &embeddings[0]);
        }
    }

    #[test]
    fn test_embedding_model_batch_mixed_sizes() {
        let full = decode_test_audio();
        // Second track: the first half of the same samples, so the lengths differ.
        let half = ResampledAudio {
            samples: full.samples[..full.samples.len() / 2].to_vec(),
            path: full.path.clone(),
        };
        let audios = vec![full, half];
        let mut model = load_model();

        let embeddings = model
            .embed_batch(&audios)
            .expect("Failed to compute batch embeddings");
        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), DIM_EMBEDDING);
        }
    }
}