use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use blazen_audio::AudioBackend;
use blazen_audio_dac_vendored as upstream;
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use crate::error::{CodecError, Result};
use crate::traits::CodecBackend;
/// User-facing configuration for [`DacBackend`].
///
/// Selects which Hugging Face repository, revision and files are fetched, and
/// whether an accelerator may be probed when the model is loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DacConfig {
/// Hugging Face repository id; also used to build the backend id `dac:<repo_id>`.
pub repo_id: String,
/// Revision to fetch; `None` requests the repository without an explicit revision.
pub revision: Option<String>,
/// Name of the safetensors weights file inside the repository.
pub weights_filename: String,
/// Name of the JSON model configuration file inside the repository.
pub config_filename: String,
/// When `true`, the CPU is always used and no CUDA/Metal device is probed.
pub cpu_only: bool,
/// Optional override for the hub client's cache directory.
pub cache_dir: Option<PathBuf>,
}
impl Default for DacConfig {
    /// Builds the stock configuration: the `descript/dac_44khz` repository
    /// with no pinned revision, the standard weight/config file names,
    /// accelerator probing allowed and no cache directory override.
    fn default() -> Self {
        let repo_id = String::from("descript/dac_44khz");
        let weights_filename = String::from("model.safetensors");
        let config_filename = String::from("config.json");
        Self {
            repo_id,
            revision: None,
            weights_filename,
            config_filename,
            cpu_only: false,
            cache_dir: None,
        }
    }
}
/// Chooses the device that model weights and tensors are placed on.
///
/// When `cpu_only` is set the CPU is returned without probing anything.
/// Otherwise CUDA device 0 is tried first (only when built with the `cuda`
/// feature), then Metal device 0 (only with the `metal` feature). A failed
/// probe is not an error: the function simply falls through to the CPU.
fn pick_device(cpu_only: bool) -> Device {
    if !cpu_only {
        #[cfg(feature = "cuda")]
        {
            if let Ok(gpu) = Device::new_cuda(0) {
                return gpu;
            }
        }
        #[cfg(feature = "metal")]
        {
            if let Ok(gpu) = Device::new_metal(0) {
                return gpu;
            }
        }
    }
    Device::Cpu
}
/// The subset of the checkpoint's JSON configuration that this backend reads.
///
/// Deserialized from the downloaded config file and converted into the
/// vendored model configuration by `into_candle`.
#[derive(Debug, Clone, Deserialize)]
struct HfDacConfig {
/// Sample rate of the model; compared (in Hz) against the rate passed to `encode_pcm`.
sampling_rate: u32,
/// Number of codebooks; forwarded as `num_codebooks`.
n_codebooks: usize,
/// Entries per codebook; forwarded unchanged as `codebook_size`.
codebook_size: usize,
/// Forwarded as the vendored config's `latent_dim`.
hidden_size: usize,
/// Divisor used to derive `frame_rate` from `sampling_rate`; falls back to
/// `default_hop_length()` when the field is absent from the JSON.
#[serde(default = "default_hop_length")]
hop_length: u32,
}
/// Serde fallback for `hop_length` when the JSON config omits the field.
const fn default_hop_length() -> u32 {
    const FALLBACK_HOP_LENGTH: u32 = 512;
    FALLBACK_HOP_LENGTH
}
impl HfDacConfig {
    /// Converts the parsed Hugging Face configuration into the vendored
    /// model's configuration.
    ///
    /// * `frame_rate` is the integer quotient `sampling_rate / hop_length`.
    /// * `model_bitrate` is `n_codebooks * frame_rate * 10`, computed with
    ///   saturating arithmetic; a codebook count that does not fit in `u32`
    ///   is clamped to `u32::MAX` first.
    /// * The remaining fields are forwarded unchanged.
    ///
    /// A `hop_length` of zero (possible with a malformed config file) yields
    /// a `frame_rate`, and therefore a `model_bitrate`, of zero instead of
    /// panicking on the division. Callers that need a meaningful frame rate
    /// should reject a zero hop length before converting.
    fn into_candle(self) -> upstream::Config {
        // `checked_div` returns `None` for a zero divisor, so a bad config
        // can no longer abort the task with a divide-by-zero panic.
        let frame_rate = self
            .sampling_rate
            .checked_div(self.hop_length)
            .unwrap_or(0);
        let codebooks_u32 = u32::try_from(self.n_codebooks).unwrap_or(u32::MAX);
        // NOTE(review): the factor 10 is presumably the bits per code for a
        // 1024-entry codebook; confirm before reusing it for other sizes.
        let model_bitrate = codebooks_u32.saturating_mul(frame_rate).saturating_mul(10);
        upstream::Config {
            num_codebooks: self.n_codebooks,
            model_bitrate,
            codebook_size: self.codebook_size,
            latent_dim: self.hidden_size,
            frame_rate,
            sampling_rate: self.sampling_rate,
        }
    }
}
/// Everything produced by a successful model load, bundled so it can be
/// stored as a single value in the backend's once-cell.
struct LoadedDac {
/// The vendored DAC model built from the downloaded weights.
inner: upstream::Model,
/// Device the weights were loaded onto; input tensors are created on it as well.
device: Device,
/// `sampling_rate` from the checkpoint config; encode input must match it.
sample_rate: u32,
/// `n_codebooks` from the checkpoint config; used to validate code layouts.
num_codebooks: usize,
}
/// Codec backend for a DAC checkpoint that is fetched and loaded on first use.
///
/// Constructing the backend performs no I/O; the model is loaded by the first
/// call that needs it and then cached for the lifetime of the backend.
pub struct DacBackend {
/// Identifier of the form `dac:<repo_id>`.
id: String,
/// Configuration supplied at construction time.
config: DacConfig,
/// Loaded model state; stays empty until a load completes successfully.
loaded: Arc<OnceCell<LoadedDac>>,
}
impl std::fmt::Debug for DacBackend {
    /// Formats the backend without dumping the model itself: the `loaded`
    /// entry is rendered as a boolean saying whether the cell is initialised.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_loaded = self.loaded.initialized();
        let mut out = f.debug_struct("DacBackend");
        out.field("id", &self.id);
        out.field("config", &self.config);
        out.field("loaded", &is_loaded);
        out.finish()
    }
}
impl DacBackend {
#[must_use]
pub fn new(config: DacConfig) -> Self {
let id = format!("dac:{}", config.repo_id);
Self {
id,
config,
loaded: Arc::new(OnceCell::new()),
}
}
#[must_use]
pub fn default_44khz() -> Self {
Self::new(DacConfig::default())
}
#[must_use]
pub fn config(&self) -> &DacConfig {
&self.config
}
#[must_use]
pub fn sample_rate_loaded(&self) -> Option<u32> {
self.loaded.get().map(|m| m.sample_rate)
}
#[must_use]
pub fn num_codebooks_loaded(&self) -> Option<usize> {
self.loaded.get().map(|m| m.num_codebooks)
}
async fn ensure_loaded(&self) -> Result<&LoadedDac> {
self.loaded
.get_or_try_init(|| async { self.load_inner().await })
.await
}
async fn load_inner(&self) -> Result<LoadedDac> {
let repo = self.config.repo_id.clone();
let revision = self.config.revision.clone();
let weights_filename = self.config.weights_filename.clone();
let config_filename = self.config.config_filename.clone();
let cache_dir = self.config.cache_dir.clone();
let (weights_path, config_path) =
tokio::task::spawn_blocking(move || -> Result<(PathBuf, PathBuf)> {
let mut builder = hf_hub::api::sync::ApiBuilder::new();
if let Some(dir) = cache_dir {
builder = builder.with_cache_dir(dir);
}
let api = builder.build().map_err(|e| CodecError::HfHub {
repo: repo.clone(),
source: std::io::Error::other(e.to_string()),
})?;
let model_repo = match revision {
Some(rev) => api.repo(hf_hub::Repo::with_revision(
repo.clone(),
hf_hub::RepoType::Model,
rev,
)),
None => api.model(repo.clone()),
};
let weights = model_repo
.get(&weights_filename)
.map_err(|e| CodecError::HfHub {
repo: repo.clone(),
source: std::io::Error::other(e.to_string()),
})?;
let cfg = model_repo
.get(&config_filename)
.map_err(|e| CodecError::HfHub {
repo,
source: std::io::Error::other(e.to_string()),
})?;
Ok((weights, cfg))
})
.await
.map_err(|e| CodecError::other(format!("blocking task join failed: {e}")))??;
let config_bytes = std::fs::read(&config_path).map_err(CodecError::Io)?;
let hf_config: HfDacConfig = serde_json::from_slice(&config_bytes).map_err(|e| {
CodecError::other(format!(
"failed to parse DAC config.json at {}: {e}",
config_path.display()
))
})?;
let sample_rate = hf_config.sampling_rate;
let num_codebooks = hf_config.n_codebooks;
let candle_cfg = hf_config.into_candle();
let device = pick_device(self.config.cpu_only);
#[allow(unsafe_code)]
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
.map_err(CodecError::from)?
};
let inner = upstream::Model::new(&candle_cfg, vb).map_err(CodecError::from)?;
Ok(LoadedDac {
inner,
device,
sample_rate,
num_codebooks,
})
}
}
#[async_trait]
impl AudioBackend for DacBackend {
    /// Backend identifier of the form `dac:<repo_id>`.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// This backend always reports itself as a `"codec"` provider.
    fn provider_kind(&self) -> &'static str {
        "codec"
    }

    /// Makes sure the model is loaded, converting a codec error into the
    /// audio-layer error type on failure.
    async fn load(&self) -> std::result::Result<(), blazen_audio::AudioError> {
        match self.ensure_loaded().await {
            Ok(_) => Ok(()),
            Err(err) => Err(blazen_audio::AudioError::from(err)),
        }
    }

    /// Reports whether a load has already completed successfully.
    async fn is_loaded(&self) -> bool {
        let cell = &self.loaded;
        cell.initialized()
    }
}
#[async_trait]
impl CodecBackend for DacBackend {
    /// Encodes PCM samples into DAC codes.
    ///
    /// The samples are wrapped in a `(1, 1, len)` tensor, encoded, and the
    /// first batch entry of the resulting `(batch, codebooks, frames)` code
    /// tensor is reshaped into a single flat vector.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` when `samples` is empty (checked before any load) or
    ///   `sample_rate` differs from the loaded model's rate.
    /// * Any load error, any candle error, or an `other` error when the
    ///   encoder returns an unexpected number of codebooks.
    async fn encode_pcm(&self, samples: &[f32], sample_rate: u32) -> Result<Vec<u32>> {
        if samples.is_empty() {
            return Err(CodecError::invalid_input("PCM input is empty"));
        }

        let loaded = self.ensure_loaded().await?;
        let expected_rate = loaded.sample_rate;
        if expected_rate != sample_rate {
            return Err(CodecError::invalid_input(format!(
                "expected sample rate {} Hz, got {} Hz -- resample first",
                expected_rate, sample_rate
            )));
        }

        let shape = (1_usize, 1_usize, samples.len());
        let waveform =
            Tensor::from_slice(samples, shape, &loaded.device).map_err(CodecError::from)?;
        let codes = loaded
            .inner
            .encode_to_codes(&waveform)
            .map_err(CodecError::from)?;

        let (_, n_cb, t_prime) = codes.dims3().map_err(CodecError::from)?;
        let expected_cb = loaded.num_codebooks;
        if n_cb != expected_cb {
            return Err(CodecError::other(format!(
                "DAC encoder produced {n_cb} codebooks, expected {} from loaded config",
                expected_cb
            )));
        }

        // Drop the batch dimension, then flatten (codebooks, frames) to 1-D.
        let first_batch = codes.i(0).map_err(CodecError::from)?;
        let flattened = first_batch
            .reshape((n_cb * t_prime,))
            .map_err(CodecError::from)?;
        flattened.to_vec1::<u32>().map_err(CodecError::from)
    }

    /// Decodes a flat code vector back into PCM samples.
    ///
    /// `tokens` is interpreted as a `(1, num_codebooks, len / num_codebooks)`
    /// tensor; the first batch entry and first channel of the decoded audio
    /// are returned as a flat vector.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` when `num_codebooks` is zero, when `tokens` is empty
    ///   or not a multiple of `num_codebooks` (both checked before any load),
    ///   or when `num_codebooks` differs from the loaded model's count.
    /// * Any load error or candle error.
    async fn decode_tokens(&self, tokens: &[u32], num_codebooks: usize) -> Result<Vec<f32>> {
        if num_codebooks == 0 {
            return Err(CodecError::invalid_input("num_codebooks must be > 0"));
        }
        let token_count = tokens.len();
        let aligned = token_count.is_multiple_of(num_codebooks);
        if token_count == 0 || !aligned {
            return Err(CodecError::invalid_input(format!(
                "token count {} is not a positive multiple of num_codebooks {}",
                token_count, num_codebooks
            )));
        }

        let loaded = self.ensure_loaded().await?;
        let expected_cb = loaded.num_codebooks;
        if expected_cb != num_codebooks {
            return Err(CodecError::invalid_input(format!(
                "expected num_codebooks {} (from loaded model), got {num_codebooks}",
                expected_cb
            )));
        }

        let frames = token_count / num_codebooks;
        let shape = (1_usize, num_codebooks, frames);
        let codes = Tensor::from_slice(tokens, shape, &loaded.device).map_err(CodecError::from)?;
        let decoded = loaded.inner.decode_codes(&codes).map_err(CodecError::from)?;

        // Select batch 0, then channel 0, and flatten what remains.
        let batch = decoded.i(0).map_err(CodecError::from)?;
        let channel = batch.i(0).map_err(CodecError::from)?;
        let pcm = channel.flatten_all().map_err(CodecError::from)?;
        pcm.to_vec1::<f32>().map_err(CodecError::from)
    }

    /// Sample rate of the loaded model, or 44 100 while nothing is loaded.
    fn sample_rate(&self) -> u32 {
        self.loaded.get().map_or(44_100, |m| m.sample_rate)
    }

    /// Codebook count of the loaded model, or 9 while nothing is loaded.
    fn num_codebooks(&self) -> usize {
        self.loaded.get().map_or(9, |m| m.num_codebooks)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `err` is `CodecError::InvalidInput` and that its message
    /// contains `needle`; panics with the unexpected variant otherwise.
    #[track_caller]
    fn assert_invalid_input(err: CodecError, needle: &str) {
        match err {
            CodecError::InvalidInput(msg) => assert!(msg.contains(needle)),
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    /// The default configuration targets the 44 kHz checkpoint with the
    /// standard file names and nothing else overridden.
    #[test]
    fn default_config_uses_44khz_repo() {
        let DacConfig {
            repo_id,
            revision,
            weights_filename,
            config_filename,
            cpu_only,
            cache_dir,
        } = DacConfig::default();
        assert_eq!(repo_id, "descript/dac_44khz");
        assert_eq!(weights_filename, "model.safetensors");
        assert_eq!(config_filename, "config.json");
        assert!(revision.is_none());
        assert!(!cpu_only);
        assert!(cache_dir.is_none());
    }

    /// Before loading, the trait accessors report the fallback values while
    /// the `*_loaded` accessors report `None`.
    #[test]
    fn default_44khz_constructor_matches_default_config() {
        let backend = DacBackend::default_44khz();
        assert_eq!(backend.config().repo_id, "descript/dac_44khz");
        assert_eq!(<DacBackend as CodecBackend>::sample_rate(&backend), 44_100);
        assert_eq!(<DacBackend as CodecBackend>::num_codebooks(&backend), 9);
        assert!(backend.sample_rate_loaded().is_none());
        assert!(backend.num_codebooks_loaded().is_none());
    }

    /// The backend id is the repository id prefixed with `dac:`.
    #[test]
    fn id_includes_repo_name() {
        let backend = DacBackend::default_44khz();
        let (id, kind) = (backend.id(), backend.provider_kind());
        assert_eq!(id, "dac:descript/dac_44khz");
        assert_eq!(kind, "codec");
    }

    /// A custom repository and revision are preserved by the backend.
    #[test]
    fn custom_repo_round_trips_through_config() {
        let backend = DacBackend::new(DacConfig {
            repo_id: String::from("descript/dac_24khz"),
            revision: Some(String::from("main")),
            ..DacConfig::default()
        });
        assert_eq!(backend.id(), "dac:descript/dac_24khz");
        assert_eq!(backend.config().revision.as_deref(), Some("main"));
    }

    /// Every field of the parsed JSON config lands in the expected field of
    /// the vendored configuration, including the derived rates.
    #[test]
    fn hf_config_maps_into_candle_config() {
        let raw = r#"{
            "sampling_rate": 44100,
            "n_codebooks": 9,
            "codebook_size": 1024,
            "hidden_size": 1024,
            "hop_length": 512
        }"#;
        let hf: HfDacConfig = serde_json::from_str(raw).unwrap();
        let candle = hf.into_candle();
        assert_eq!(candle.sampling_rate, 44_100);
        assert_eq!(candle.num_codebooks, 9);
        assert_eq!(candle.codebook_size, 1024);
        assert_eq!(candle.latent_dim, 1024);
        assert_eq!(candle.frame_rate, 44_100 / 512);
        assert_eq!(candle.model_bitrate, 9 * (44_100 / 512) * 10);
    }

    /// Empty PCM is rejected before any model load is attempted.
    #[tokio::test]
    async fn encode_rejects_empty_pcm() {
        let backend = DacBackend::default_44khz();
        let outcome = backend.encode_pcm(&[], 44_100).await;
        assert_invalid_input(outcome.unwrap_err(), "empty");
    }

    /// A token count that is not a multiple of the codebook count is rejected.
    #[tokio::test]
    async fn decode_rejects_misaligned_tokens() {
        let backend = DacBackend::default_44khz();
        let outcome = backend.decode_tokens(&[1, 2, 3, 4, 5], 9).await;
        assert_invalid_input(outcome.unwrap_err(), "multiple");
    }

    /// A codebook count of zero is rejected.
    #[tokio::test]
    async fn decode_rejects_zero_codebooks() {
        let backend = DacBackend::default_44khz();
        let outcome = backend.decode_tokens(&[1, 2, 3], 0).await;
        assert_invalid_input(outcome.unwrap_err(), "num_codebooks");
    }

    /// A freshly constructed backend reports that nothing is loaded.
    #[tokio::test]
    async fn is_loaded_starts_false() {
        let backend = DacBackend::default_44khz();
        let loaded = <DacBackend as AudioBackend>::is_loaded(&backend).await;
        assert!(!loaded);
    }

    /// The backend is usable behind a `dyn AudioBackend` trait object.
    #[tokio::test]
    async fn provider_kind_dispatches_through_audio_backend_trait() {
        let backend: Arc<dyn AudioBackend> = Arc::new(DacBackend::default_44khz());
        assert_eq!(backend.provider_kind(), "codec");
        let id = backend.id();
        assert!(id.starts_with("dac:"));
    }

    /// Live round trip: encode one second of a 440 Hz sine wave, check the
    /// code layout, decode it again and sanity-check the waveform.
    #[cfg(feature = "live-models")]
    #[ignore = "upstream candle DAC builder expects weight_norm parametrisation \
                + a different nested-block naming convention than the published \
                descript/dac_44khz HF checkpoint; load fails before encode runs"]
    #[tokio::test]
    async fn live_round_trip_sine_wave_44khz() {
        let backend = DacBackend::default_44khz();
        backend
            .load()
            .await
            .expect("load default 44 kHz DAC weights");

        let num_codebooks = backend
            .num_codebooks_loaded()
            .expect("loaded codebook count");
        let sample_rate = backend.sample_rate_loaded().expect("loaded sample rate");
        assert_eq!(sample_rate, 44_100);
        assert_eq!(num_codebooks, 9);

        // One second of a half-amplitude 440 Hz tone.
        let len = 44_100usize;
        let pcm: Vec<f32> = (0..len)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                let t = i as f32 / 44_100.0;
                0.5 * (2.0 * std::f32::consts::PI * 440.0 * t).sin()
            })
            .collect();

        let codes = backend
            .encode_pcm(&pcm, 44_100)
            .await
            .expect("encode 44 kHz sine wave");
        assert!(!codes.is_empty(), "encoded codes must not be empty");
        assert!(
            codes.len().is_multiple_of(num_codebooks),
            "encoded token count {} should be a multiple of num_codebooks {num_codebooks}",
            codes.len(),
        );

        // The frame count should be close to sample count / hop length.
        let frames_per_cb = codes.len() / num_codebooks;
        let expected_frames = 44_100usize / 512;
        let frame_diff = expected_frames.abs_diff(frames_per_cb);
        assert!(
            frame_diff <= 2,
            "expected ~{expected_frames} frames per codebook, got {frames_per_cb}",
        );
        for &c in &codes {
            assert!(
                c < 1024,
                "DAC code {c} exceeds the 44 kHz codebook size of 1024",
            );
        }

        let decoded = backend
            .decode_tokens(&codes, num_codebooks)
            .await
            .expect("decode DAC codes");
        assert!(!decoded.is_empty(), "decoded PCM must not be empty");
        for s in &decoded {
            assert!(s.is_finite(), "decoded sample must be finite, got {s}");
        }

        let tolerance = 1024usize;
        let diff = len.abs_diff(decoded.len());
        assert!(
            diff <= tolerance,
            "decoded len {} differs from input len {len} by more than {tolerance}",
            decoded.len()
        );

        let energy: f32 = decoded.iter().map(|s| s * s).sum();
        assert!(
            energy > 1.0,
            "decoded waveform has near-zero energy ({energy}); encode likely produced \
             garbage codes",
        );
    }
}