use crate::{
array::Array,
audio::load::{LoadedAudioModel, base_load_model},
error::{Error, InvariantViolationPayload, Result},
};
/// Name-remapping table of `(&str, &str)` pairs for codec models.
///
/// Currently empty: no remapping entries are defined at this level.
/// NOTE(review): no consumer of this table is visible in this file — confirm
/// which element of each pair is the source name and which the target.
pub const MODEL_REMAPPING: &[(&str, &str)] = &[];
pub trait CodecModel {
fn encode(&self, audio: &Array) -> Result<Array> {
let _ = audio;
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"CodecModel::encode",
"needs `encode` override (per-architecture)",
)))
}
fn decode(&self, codes: &Array) -> Result<Array> {
let _ = codes;
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"CodecModel::decode",
"needs `decode` override (per-architecture)",
)))
}
fn sample_rate(&self) -> u32;
}
/// Loads the model bundle found at `path` and passes it to `constructor`,
/// which builds the concrete [`CodecModel`].
///
/// The architecture-specific construction is delegated to the caller-supplied
/// factory; this function only performs the shared bundle loading.
///
/// # Errors
///
/// Returns any error from [`base_load_model`] (in which case `constructor` is
/// never invoked), otherwise whatever `constructor` returns.
pub fn load_model<F>(path: &str, constructor: F) -> Result<Box<dyn CodecModel>>
where
    F: FnOnce(LoadedAudioModel) -> Result<Box<dyn CodecModel>>,
{
    constructor(base_load_model(path)?)
}
pub fn load<F>(path: &str, constructor: F) -> Result<Box<dyn CodecModel>>
where
F: FnOnce(LoadedAudioModel) -> Result<Box<dyn CodecModel>>,
{
load_model(path, constructor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        fs,
        path::{Path, PathBuf},
    };

    /// Codec stub whose `encode`/`decode` return all-zero arrays, so the
    /// loader plumbing can be exercised without real weights.
    struct FakeCodec;

    impl CodecModel for FakeCodec {
        fn encode(&self, audio: &Array) -> Result<Array> {
            let t = audio.size();
            Array::from_slice::<f32>(&vec![0.0; t], &(1, 1, t))
        }

        fn decode(&self, codes: &Array) -> Result<Array> {
            let t = codes.shape().iter().product::<usize>();
            Array::from_slice::<f32>(&vec![0.0; t], &(t,))
        }

        fn sample_rate(&self) -> u32 {
            24_000
        }
    }

    /// Codec that relies on the trait's default `encode`/`decode`.
    struct SampleRateOnlyCodec;

    impl CodecModel for SampleRateOnlyCodec {
        fn sample_rate(&self) -> u32 {
            16_000
        }
    }

    /// Scratch directory removed on drop, so test runs leave nothing behind in
    /// the system temp dir, even when an assertion panics mid-test.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Creates a fresh, empty directory unique to this process and `name`.
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "mlxrs_audio_codec_load_{}_{}",
                std::process::id(),
                name
            ));
            // Clear anything an earlier run with the same pid may have left.
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a failed delete must not mask a test result.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn load_codec_constructs_via_factory() {
        let dir = TempDir::new("constructs_via_factory");
        let body = r#"{ "model_type": "encodec", "sample_rate": 24000 }"#;
        fs::write(dir.path().join("config.json"), body).unwrap();

        // The factory is `FnOnce`, so a plain mutable capture is enough; the
        // borrow ends as soon as `load` returns.
        let mut captured: Option<PathBuf> = None;
        let model = load(&dir.path().to_string_lossy(), |bundle| {
            captured = Some(bundle.model_path().to_path_buf());
            Ok(Box::new(FakeCodec))
        })
        .expect("load constructs via the supplied factory");

        assert_eq!(captured.as_deref(), Some(dir.path()));
        assert_eq!(model.sample_rate(), 24_000);
        let probe = Array::from_slice::<f32>(&[0.0_f32; 8], &(8,)).unwrap();
        let codes = model.encode(&probe).unwrap();
        let back = model.decode(&codes).unwrap();
        assert_eq!(back.shape().iter().product::<usize>(), 8);
    }

    #[test]
    fn default_encode_decode_report_missing_override() {
        let codec = SampleRateOnlyCodec;
        let probe = Array::from_slice::<f32>(&[0.0_f32; 4], &(4,)).unwrap();
        assert!(matches!(
            codec.encode(&probe),
            Err(Error::InvariantViolation(_))
        ));
        assert!(matches!(
            codec.decode(&probe),
            Err(Error::InvariantViolation(_))
        ));
        assert_eq!(codec.sample_rate(), 16_000);
    }
}