use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use blazen_audio::AudioBackend;
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec as upstream;
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use crate::error::{CodecError, Result};
use crate::traits::CodecBackend;
/// Configuration for [`EncodecBackend`]: where to fetch weights from and
/// which device to run on.
///
/// NOTE(review): only the weights location comes from here; the model
/// architecture is always `candle_transformers` EnCodec `Config::default()`
/// (see `load_inner`), so pointing `hf_repo` at a non-24 kHz checkpoint is
/// unlikely to work — confirm before relying on other repos.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncodecConfig {
/// Hugging Face Hub model repository holding the weights
/// (default `facebook/encodec_24khz`). Also embedded in the backend id.
pub hf_repo: String,
/// Safetensors file to download from `hf_repo` (default `model.safetensors`).
pub weights_filename: String,
/// When `true`, skip CUDA/Metal probing and always run on the CPU.
pub cpu_only: bool,
/// Optional hf-hub cache directory; `None` keeps the `ApiBuilder` default.
pub cache_dir: Option<PathBuf>,
}
impl Default for EncodecConfig {
    /// Points at the public 24 kHz EnCodec checkpoint, lets the backend
    /// auto-detect an accelerator, and uses hf-hub's default cache location.
    fn default() -> Self {
        let hf_repo = String::from("facebook/encodec_24khz");
        let weights_filename = String::from("model.safetensors");
        Self {
            hf_repo,
            weights_filename,
            cache_dir: None,
            cpu_only: false,
        }
    }
}
/// Chooses the candle device used for inference.
///
/// With `cpu_only` set, the CPU is returned immediately. Otherwise the first
/// accelerator that initialises successfully is used: CUDA device 0 (only when
/// built with the `cuda` feature), then Metal device 0 (only with the `metal`
/// feature). The CPU is the fallback when neither is compiled in or both fail.
fn pick_device(cpu_only: bool) -> Device {
    // An explicit CPU request short-circuits all accelerator probing.
    if cpu_only {
        return Device::Cpu;
    }

    #[cfg(feature = "cuda")]
    {
        // Initialisation errors are deliberately ignored so the search can
        // fall through to the next candidate.
        if let Ok(cuda_device) = Device::new_cuda(0) {
            return cuda_device;
        }
    }

    #[cfg(feature = "metal")]
    {
        if let Ok(metal_device) = Device::new_metal(0) {
            return metal_device;
        }
    }

    Device::Cpu
}
/// A constructed EnCodec model together with the metadata derived when it
/// was loaded.
struct LoadedModel {
/// The upstream candle EnCodec model.
inner: upstream::Model,
/// Device the weights were mapped onto; input tensors must be created here.
device: Device,
/// Sample rate (Hz) taken from the upstream config; PCM input must match it.
sample_rate: u32,
/// Codebook count computed by `config_num_codebooks` from the same config.
num_codebooks: usize,
}
/// Derives the number of residual-quantizer codebooks implied by `cfg`.
///
/// Computes `floor(1000 * max_bandwidth_kbps) / (frame_rate * 10)`, where
/// `frame_rate = ceil(sampling_rate / hop)` and `hop` is the product of the
/// upsampling ratios. The `* 10` presumably reflects 10 bits per code
/// (1024-entry codebooks) — confirm against the upstream config. For the
/// default 24 kHz config this gives `24_000 / (75 * 10) = 32`.
///
/// Degenerate inputs never panic:
/// - an empty bandwidth list falls back to 24 kbps;
/// - a zero hop uses a frame rate of 1;
/// - the divisor is clamped to at least 1.
// The float -> usize cast is intentional; `as` saturates, so a negative or
// NaN bandwidth yields 0 codebooks rather than wrapping.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn config_num_codebooks(cfg: &upstream::Config) -> usize {
    let max_kbps: f64 = cfg.target_bandwidths.last().map_or(24.0, |&bw| bw);
    let hop: usize = cfg.upsampling_ratios.iter().product();
    let frames_per_sec = match hop {
        0 => 1,
        nonzero => cfg.sampling_rate.div_ceil(nonzero),
    };
    let bit_budget = (max_kbps * 1000.0) as usize;
    bit_budget / frames_per_sec.saturating_mul(10).max(1)
}
/// EnCodec neural audio codec backend built on `candle_transformers`.
///
/// Construction is cheap: weights are downloaded and the model is built
/// lazily on first use, then cached for the lifetime of the backend.
pub struct EncodecBackend {
/// Identifier returned by `AudioBackend::id`: `"encodec:<hf_repo>"`.
id: String,
/// Configuration supplied at construction.
config: EncodecConfig,
/// Lazily-initialised model; filled at most once via `ensure_loaded`.
loaded: Arc<OnceCell<LoadedModel>>,
}
impl std::fmt::Debug for EncodecBackend {
    /// Prints the id and config; the model itself is summarised as a
    /// `loaded` boolean rather than dumped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_loaded = self.loaded.initialized();
        let mut out = f.debug_struct("EncodecBackend");
        out.field("id", &self.id);
        out.field("config", &self.config);
        out.field("loaded", &is_loaded);
        out.finish()
    }
}
impl EncodecBackend {
    /// Creates a backend for `config` without touching the network or disk.
    ///
    /// The backend id is `"encodec:<hf_repo>"`. Weights are fetched and the
    /// model is built on first use (or on an explicit `AudioBackend::load`).
    #[must_use]
    pub fn new(config: EncodecConfig) -> Self {
        let id = format!("encodec:{}", config.hf_repo);
        Self {
            id,
            config,
            loaded: Arc::new(OnceCell::new()),
        }
    }

    /// Convenience constructor using [`EncodecConfig::default`] (24 kHz model).
    #[must_use]
    pub fn default_24khz() -> Self {
        Self::new(EncodecConfig::default())
    }

    /// Returns the configuration this backend was created with.
    #[must_use]
    pub fn config(&self) -> &EncodecConfig {
        &self.config
    }

    /// Returns the model's sample rate in Hz, or `None` if not loaded yet.
    #[must_use]
    pub fn sample_rate_loaded(&self) -> Option<u32> {
        self.loaded.get().map(|m| m.sample_rate)
    }

    /// Returns the cached model, loading it on the first call.
    ///
    /// Concurrent callers share a single initialisation. If loading fails the
    /// cell stays empty, so a later call retries.
    async fn ensure_loaded(&self) -> Result<&LoadedModel> {
        self.loaded
            .get_or_try_init(|| async { self.load_inner().await })
            .await
    }

    /// Downloads (or reuses cached) weights and builds the model.
    ///
    /// Every step here blocks the calling thread:
    /// - the synchronous hf-hub download;
    /// - accelerator probing;
    /// - mmap'ing the safetensors file;
    /// - constructing the model (which copies weights to the GPU on
    ///   CUDA/Metal).
    ///
    /// The whole sequence therefore runs on tokio's blocking pool instead of
    /// stalling an async worker thread.
    async fn load_inner(&self) -> Result<LoadedModel> {
        let repo = self.config.hf_repo.clone();
        let filename = self.config.weights_filename.clone();
        let cache_dir = self.config.cache_dir.clone();
        let cpu_only = self.config.cpu_only;
        tokio::task::spawn_blocking(move || -> Result<LoadedModel> {
            let weights_path = Self::fetch_weights(repo, &filename, cache_dir)?;
            Self::build_model(&weights_path, cpu_only)
        })
        .await
        .map_err(|e| CodecError::other(format!("blocking task join failed: {e}")))?
    }

    /// Resolves `filename` from the hub repository `repo` to a local path,
    /// downloading it into the hf-hub cache if needed. Blocking.
    ///
    /// # Errors
    ///
    /// [`CodecError::HfHub`] if the API client cannot be built or the file
    /// cannot be fetched.
    fn fetch_weights(
        repo: String,
        filename: &str,
        cache_dir: Option<PathBuf>,
    ) -> Result<PathBuf> {
        let mut builder = hf_hub::api::sync::ApiBuilder::new();
        if let Some(dir) = cache_dir {
            builder = builder.with_cache_dir(dir);
        }
        let api = builder.build().map_err(|e| CodecError::HfHub {
            repo: repo.clone(),
            source: std::io::Error::other(e.to_string()),
        })?;
        api.model(repo.clone())
            .get(filename)
            .map_err(|e| CodecError::HfHub {
                repo,
                source: std::io::Error::other(e.to_string()),
            })
    }

    /// Picks a device and builds the EnCodec model from the safetensors file
    /// at `weights_path`. Blocking.
    ///
    /// NOTE(review): the architecture is always the upstream default config,
    /// independent of `hf_repo`. Weights for a different architecture would
    /// fail in `Model::new` or decode incorrectly.
    ///
    /// # Errors
    ///
    /// Any candle error from mapping the weights or constructing the model.
    fn build_model(weights_path: &std::path::Path, cpu_only: bool) -> Result<LoadedModel> {
        let device = pick_device(cpu_only);
        let cfg = upstream::Config::default();
        // The default EnCodec sampling rate (24 000) fits comfortably in u32.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let sample_rate = cfg.sampling_rate as u32;
        let num_codebooks = config_num_codebooks(&cfg);
        // SAFETY: `from_mmaped_safetensors` memory-maps the file, which is only
        // sound while nothing truncates or rewrites it. We assume the hf-hub
        // cache file is not modified by another process while the model is
        // alive — confirm if the cache directory is shared or pruned.
        #[allow(unsafe_code)]
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
        }
        .map_err(CodecError::from)?;
        let inner = upstream::Model::new(&cfg, vb).map_err(CodecError::from)?;
        Ok(LoadedModel {
            inner,
            device,
            sample_rate,
            num_codebooks,
        })
    }
}
#[async_trait]
impl AudioBackend for EncodecBackend {
    /// Returns `"encodec:<hf_repo>"`, fixed at construction.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// This backend always reports the `"codec"` provider kind.
    fn provider_kind(&self) -> &'static str {
        "codec"
    }

    /// Eagerly downloads and builds the model; returns immediately if it is
    /// already cached.
    ///
    /// # Errors
    ///
    /// Any load failure, converted into `AudioError`.
    async fn load(&self) -> std::result::Result<(), blazen_audio::AudioError> {
        self.ensure_loaded()
            .await
            .map(|_model| ())
            .map_err(blazen_audio::AudioError::from)
    }

    /// Reports whether the model cell has been initialised; never triggers a
    /// load.
    async fn is_loaded(&self) -> bool {
        self.loaded.initialized()
    }
}
#[async_trait]
impl CodecBackend for EncodecBackend {
async fn encode_pcm(&self, samples: &[f32], sample_rate: u32) -> Result<Vec<u32>> {
if samples.is_empty() {
return Err(CodecError::invalid_input("PCM input is empty"));
}
let model = self.ensure_loaded().await?;
if sample_rate != model.sample_rate {
return Err(CodecError::invalid_input(format!(
"expected sample rate {} Hz, got {} Hz -- resample first",
model.sample_rate, sample_rate
)));
}
let xs = Tensor::from_slice(samples, (1, 1, samples.len()), &model.device)
.map_err(CodecError::from)?;
let codes = model.inner.encode(&xs).map_err(CodecError::from)?;
let codes = codes
.i(0)
.map_err(CodecError::from)?
.flatten_all()
.map_err(CodecError::from)?;
codes.to_vec1::<u32>().map_err(CodecError::from)
}
async fn decode_tokens(&self, tokens: &[u32], num_codebooks: usize) -> Result<Vec<f32>> {
if num_codebooks == 0 {
return Err(CodecError::invalid_input("num_codebooks must be > 0"));
}
if tokens.is_empty() || !tokens.len().is_multiple_of(num_codebooks) {
return Err(CodecError::invalid_input(format!(
"token count {} is not a positive multiple of num_codebooks {}",
tokens.len(),
num_codebooks
)));
}
let model = self.ensure_loaded().await?;
let seqlen = tokens.len() / num_codebooks;
let codes = Tensor::from_slice(tokens, (1, num_codebooks, seqlen), &model.device)
.map_err(CodecError::from)?;
let audio = model.inner.decode(&codes).map_err(CodecError::from)?;
let audio = audio
.i(0)
.map_err(CodecError::from)?
.i(0)
.map_err(CodecError::from)?
.flatten_all()
.map_err(CodecError::from)?;
audio.to_vec1::<f32>().map_err(CodecError::from)
}
fn sample_rate(&self) -> u32 {
self.sample_rate_loaded().unwrap_or(24_000)
}
fn num_codebooks(&self) -> usize {
if let Some(model) = self.loaded.get() {
model.num_codebooks
} else {
config_num_codebooks(&upstream::Config::default())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets the public 24 kHz checkpoint with
    /// accelerator auto-detection and no cache override.
    #[test]
    fn default_config_uses_24khz_repo() {
        let defaults = EncodecConfig::default();
        assert_eq!(defaults.hf_repo, "facebook/encodec_24khz");
        assert_eq!(defaults.weights_filename, "model.safetensors");
        assert!(!defaults.cpu_only);
        assert!(defaults.cache_dir.is_none());
    }

    /// Before loading, the backend reports metadata from the default config.
    #[test]
    fn default_24khz_constructor_matches_default_config() {
        let codec = EncodecBackend::default_24khz();
        assert_eq!(codec.config().hf_repo, "facebook/encodec_24khz");
        assert_eq!(CodecBackend::sample_rate(&codec), 24_000);
        assert_eq!(CodecBackend::num_codebooks(&codec), 32);
        assert!(codec.sample_rate_loaded().is_none());
    }

    /// 24 kbps at 75 frames/s with 10 bits per code gives 32 codebooks.
    #[test]
    fn config_num_codebooks_matches_default_24kbps() {
        let upstream_cfg = upstream::Config::default();
        assert_eq!(config_num_codebooks(&upstream_cfg), 32);
    }

    #[test]
    fn id_includes_repo_name() {
        let codec = EncodecBackend::default_24khz();
        assert_eq!(codec.id(), "encodec:facebook/encodec_24khz");
        assert_eq!(codec.provider_kind(), "codec");
    }

    /// Empty input is rejected before any weights are fetched.
    #[tokio::test]
    async fn encode_rejects_empty_pcm() {
        let codec = EncodecBackend::default_24khz();
        let failure = codec.encode_pcm(&[], 24_000).await.unwrap_err();
        match failure {
            CodecError::InvalidInput(msg) => assert!(msg.contains("empty")),
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    /// Five tokens cannot be split evenly across four codebooks.
    #[tokio::test]
    async fn decode_rejects_misaligned_tokens() {
        let codec = EncodecBackend::default_24khz();
        let failure = codec
            .decode_tokens(&[1, 2, 3, 4, 5], 4)
            .await
            .unwrap_err();
        match failure {
            CodecError::InvalidInput(msg) => {
                assert!(msg.contains("multiple"));
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn decode_rejects_zero_codebooks() {
        let codec = EncodecBackend::default_24khz();
        let failure = codec.decode_tokens(&[1, 2, 3], 0).await.unwrap_err();
        match failure {
            CodecError::InvalidInput(msg) => assert!(msg.contains("num_codebooks")),
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn is_loaded_starts_false() {
        let codec = EncodecBackend::default_24khz();
        assert!(!AudioBackend::is_loaded(&codec).await);
    }

    /// Downloads real weights: encodes one second of a 440 Hz tone and checks
    /// that decoding yields finite audio of roughly the same length.
    #[cfg(feature = "live-models")]
    #[tokio::test]
    async fn live_round_trip_sine_wave_24khz() {
        let codec = EncodecBackend::default_24khz();
        AudioBackend::load(&codec)
            .await
            .expect("load default 24 kHz EnCodec weights");
        let sample_rate = codec.sample_rate_loaded().expect("loaded sample rate");
        assert_eq!(sample_rate, 24_000);

        let len = 24_000usize;
        let pcm: Vec<f32> = (0..len)
            .map(|idx| {
                #[allow(clippy::cast_precision_loss)]
                let t = idx as f32 / 24_000.0;
                0.5 * (2.0 * std::f32::consts::PI * 440.0 * t).sin()
            })
            .collect();

        let codes = codec
            .encode_pcm(&pcm, 24_000)
            .await
            .expect("encode 24 kHz sine wave");
        assert!(!codes.is_empty(), "encoded codes must not be empty");

        let num_codebooks = CodecBackend::num_codebooks(&codec);
        let decoded = codec
            .decode_tokens(&codes, num_codebooks)
            .await
            .expect("decode EnCodec codes");
        assert!(!decoded.is_empty(), "decoded PCM must not be empty");
        decoded
            .iter()
            .for_each(|s| assert!(s.is_finite(), "decoded sample must be finite, got {s}"));

        // Conv padding/framing can shift the output length slightly.
        let tolerance = 2048usize;
        let diff = decoded.len().abs_diff(len);
        assert!(
            diff <= tolerance,
            "decoded len {} differs from input len {len} by more than {tolerance}",
            decoded.len()
        );
    }
}