use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use blazen_audio::AudioBackend;
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::snac as upstream;
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use crate::error::{CodecError, Result};
use crate::traits::CodecBackend;
/// Configuration for [`SnacBackend`]: which SNAC checkpoint to fetch from the Hugging Face
/// Hub and where to run it.
///
/// The [`Default`] value points at `hubertsiuzdak/snac_24khz` with `pytorch_model.bin` and
/// `config.json`, no pinned revision, automatic device selection and the default hf-hub cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnacConfig {
    /// Hub repository id; also used to build the backend id (`"snac:<repo_id>"`).
    pub repo_id: String,
    /// Optional revision; when set it is passed to `hf_hub::Repo::with_revision`, otherwise
    /// the repo is opened with `Api::model` (hf-hub's default revision).
    pub revision: Option<String>,
    /// Weights file inside the repo, loaded with `VarBuilder::from_pth` as `F32`.
    pub weights_filename: String,
    /// JSON config file inside the repo, parsed into the candle SNAC `Config`.
    pub config_filename: String,
    /// When `true`, always run on CPU. When `false`, CUDA and then Metal are tried (only if
    /// the crate was built with the `cuda` / `metal` features) before falling back to CPU.
    pub cpu_only: bool,
    /// Overrides the hf-hub cache directory; `None` keeps hf-hub's default location.
    pub cache_dir: Option<PathBuf>,
}
impl Default for SnacConfig {
    /// The 24 kHz SNAC checkpoint (`hubertsiuzdak/snac_24khz`) with its standard file names,
    /// no pinned revision, automatic device selection and the default hf-hub cache.
    fn default() -> Self {
        let repo_id = String::from("hubertsiuzdak/snac_24khz");
        Self {
            repo_id,
            weights_filename: "pytorch_model.bin".into(),
            config_filename: "config.json".into(),
            revision: None,
            cache_dir: None,
            cpu_only: false,
        }
    }
}
/// Chooses the device the model is loaded on.
///
/// With `cpu_only` set the CPU is returned immediately. Otherwise CUDA device 0 is tried
/// (only when built with the `cuda` feature), then Metal device 0 (only with the `metal`
/// feature); the first one that initialises wins, and the CPU is the final fallback.
fn pick_device(cpu_only: bool) -> Device {
    if cpu_only {
        return Device::Cpu;
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(cuda_device) = Device::new_cuda(0) {
            return cuda_device;
        }
    }
    #[cfg(feature = "metal")]
    {
        if let Ok(metal_device) = Device::new_metal(0) {
            return metal_device;
        }
    }
    // No accelerator was requested, compiled in, or able to initialise.
    Device::Cpu
}
/// Greatest common divisor by Euclid's algorithm.
///
/// `gcd(x, 0) == x`, so `gcd(0, 0) == 0`. Recursion depth is logarithmic in the inputs.
const fn gcd(a: usize, b: usize) -> usize {
    if b == 0 {
        a
    } else {
        gcd(b, a % b)
    }
}
/// Least common multiple of every stride in `strides`.
///
/// An empty slice yields 1, the identity for lcm; the previous implementation panicked on
/// `strides[0]`. A zero stride makes the result 0, since lcm(0, x) = 0. Short-circuiting it
/// avoids the division by zero that `gcd(0, 0)` would otherwise trigger.
fn lcm_all(strides: &[usize]) -> usize {
    strides.iter().fold(1, |acc, &stride| {
        if acc == 0 || stride == 0 {
            0
        } else {
            acc / gcd(acc, stride) * stride
        }
    })
}
/// Number of tokens each codebook contributes for a base length of `base_len`.
///
/// Entry `i` is `base_len / strides[i]` (integer division), in the same order as `strides`.
fn per_codebook_lens(strides: &[usize], base_len: usize) -> Vec<usize> {
    strides
        .iter()
        .copied()
        .map(|stride| base_len / stride)
        .collect()
}
/// Total flattened token count across all codebooks for a base length of `base_len`.
///
/// Equals the sum of `base_len / stride` over `strides`, which is the same layout that
/// `per_codebook_lens` reports. It is computed here without building the intermediate
/// vector.
fn total_tokens(strides: &[usize], base_len: usize) -> usize {
    strides.iter().map(|&stride| base_len / stride).sum()
}
/// Inverts the flattened SNAC token layout.
///
/// Given the total number of tokens, this returns the base length (the token count of a
/// stride-1 codebook) such that `total_tokens(strides, base_len) == token_count`. Only base
/// lengths that are multiples of `lcm_all(strides)` are considered, so every codebook
/// receives a whole number of tokens; for strides `[4, 2, 1]` the valid counts are 7, 14,
/// 21, and so on.
///
/// Returns `None` in any of these cases:
/// - `token_count` is 0;
/// - `strides` is empty or contains a zero (previously a panic via indexing or division by
///   zero);
/// - `token_count` is not a multiple of the per-step token count.
fn base_len_from_token_count(strides: &[usize], token_count: usize) -> Option<usize> {
    if token_count == 0 || strides.is_empty() || strides.contains(&0) {
        return None;
    }
    let lcm = lcm_all(strides);
    // Tokens produced by one lcm-sized step across all codebooks (7 for [4, 2, 1]).
    // Non-zero here: every stride divides `lcm`, so each codebook contributes >= 1 token.
    let step = total_tokens(strides, lcm);
    if !token_count.is_multiple_of(step) {
        return None;
    }
    Some((token_count / step) * lcm)
}
/// A loaded SNAC checkpoint together with the metadata read from its `config.json`.
struct LoadedSnac {
    /// The candle SNAC model.
    inner: upstream::Model,
    /// Device the weights were loaded on; input tensors are created on this same device.
    device: Device,
    /// `sampling_rate` from `config.json`; `encode_pcm` rejects PCM at any other rate.
    sample_rate: u32,
    /// Number of codebooks, equal to `vq_strides.len()`.
    num_codebooks: usize,
    /// Per-codebook strides from `config.json`. In the flattened token layout, codebook `i`
    /// holds `base_len / vq_strides[i]` tokens.
    vq_strides: Vec<usize>,
}
/// SNAC neural audio codec backend built on `candle_transformers`.
///
/// Construction is cheap. Weights are downloaded and the model is built on the first call to
/// `load`, `encode_pcm` or `decode_tokens`, and then reused for later calls.
pub struct SnacBackend {
    /// Backend identifier, `"snac:<repo_id>"`.
    id: String,
    /// Configuration this backend was created with.
    config: SnacConfig,
    /// Lazily initialised model, filled by `ensure_loaded`.
    // NOTE(review): nothing in this file clones this `Arc`; confirm whether it is needed.
    loaded: Arc<OnceCell<LoadedSnac>>,
}
impl std::fmt::Debug for SnacBackend {
    /// Shows the id, the config and whether the model has been loaded. The model itself is
    /// not printed, only the boolean load state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, config, loaded } = self;
        let is_loaded = loaded.initialized();
        f.debug_struct("SnacBackend")
            .field("id", id)
            .field("config", config)
            .field("loaded", &is_loaded)
            .finish()
    }
}
impl SnacBackend {
    /// Creates a backend for `config` without touching the network.
    ///
    /// The backend id is `"snac:<repo_id>"`; weights are fetched lazily on first use.
    #[must_use]
    pub fn new(config: SnacConfig) -> Self {
        let id = format!("snac:{}", config.repo_id);
        Self {
            id,
            config,
            loaded: Arc::new(OnceCell::new()),
        }
    }

    /// Creates a backend for the default `hubertsiuzdak/snac_24khz` checkpoint.
    #[must_use]
    pub fn default_24khz() -> Self {
        Self::new(SnacConfig::default())
    }

    /// The configuration this backend was created with.
    #[must_use]
    pub fn config(&self) -> &SnacConfig {
        &self.config
    }

    /// Sample rate from the checkpoint's config, or `None` before the model is loaded.
    #[must_use]
    pub fn sample_rate_loaded(&self) -> Option<u32> {
        self.loaded.get().map(|m| m.sample_rate)
    }

    /// Number of codebooks in the checkpoint, or `None` before the model is loaded.
    #[must_use]
    pub fn num_codebooks_loaded(&self) -> Option<usize> {
        self.loaded.get().map(|m| m.num_codebooks)
    }

    /// Per-codebook `vq_strides` from the checkpoint, or `None` before the model is loaded.
    #[must_use]
    pub fn vq_strides_loaded(&self) -> Option<&[usize]> {
        self.loaded.get().map(|m| m.vq_strides.as_slice())
    }

    /// Returns the cached model, running `load_inner` to fill the cell when it is empty.
    async fn ensure_loaded(&self) -> Result<&LoadedSnac> {
        self.loaded
            .get_or_try_init(|| async { self.load_inner().await })
            .await
    }

    /// Fetches the checkpoint and builds the model on tokio's blocking thread pool.
    ///
    /// Every step is synchronous work: the Hub download, reading and parsing `config.json`,
    /// and loading the whole weight file into tensors. All of it runs inside
    /// `spawn_blocking`, not just the download, so async worker threads are never stalled.
    ///
    /// # Errors
    ///
    /// Propagates every error from `load_blocking`. Returns `CodecError::other` if the
    /// blocking task fails to join (e.g. it panicked).
    async fn load_inner(&self) -> Result<LoadedSnac> {
        let config = self.config.clone();
        tokio::task::spawn_blocking(move || Self::load_blocking(&config))
            .await
            .map_err(|e| CodecError::other(format!("blocking task join failed: {e}")))?
    }

    /// Synchronous body of `load_inner`; must not be called directly on an async worker.
    ///
    /// # Errors
    ///
    /// - `CodecError::HfHub` if the Hub API cannot be built or a file cannot be fetched.
    /// - `CodecError::Io` if the downloaded `config.json` cannot be read.
    /// - `CodecError::other` if the config does not parse, or if `vq_strides` is empty or
    ///   contains a zero.
    /// - Converted candle errors from weight loading or model construction.
    fn load_blocking(config: &SnacConfig) -> Result<LoadedSnac> {
        let (weights_path, config_path) = Self::fetch_checkpoint_files(config)?;

        let config_bytes = std::fs::read(&config_path).map_err(CodecError::Io)?;
        let candle_cfg: upstream::Config = serde_json::from_slice(&config_bytes).map_err(|e| {
            CodecError::other(format!(
                "failed to parse SNAC config.json at {}: {e}",
                config_path.display()
            ))
        })?;
        if candle_cfg.vq_strides.is_empty() {
            return Err(CodecError::other(
                "SNAC config.json has empty vq_strides; checkpoint is malformed",
            ));
        }
        // A zero stride would make the token-layout arithmetic in `decode_tokens` divide by
        // zero, so reject it here rather than panicking later.
        if candle_cfg.vq_strides.contains(&0) {
            return Err(CodecError::other(format!(
                "SNAC config.json has a zero entry in vq_strides {:?}; checkpoint is malformed",
                candle_cfg.vq_strides
            )));
        }

        // Real audio sample rates are far below u32::MAX, so the narrowing cast is accepted.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let sample_rate = candle_cfg.sampling_rate as u32;
        let num_codebooks = candle_cfg.vq_strides.len();
        let vq_strides = candle_cfg.vq_strides.clone();

        let device = pick_device(config.cpu_only);
        let vb =
            VarBuilder::from_pth(&weights_path, DType::F32, &device).map_err(CodecError::from)?;
        let inner = upstream::Model::new(&candle_cfg, vb).map_err(CodecError::from)?;
        Ok(LoadedSnac {
            inner,
            device,
            sample_rate,
            num_codebooks,
            vq_strides,
        })
    }

    /// Resolves the weights and config files through the hf-hub cache, downloading them if
    /// needed. Blocking.
    ///
    /// Returns `(weights_path, config_path)`.
    ///
    /// # Errors
    ///
    /// `CodecError::HfHub` (tagged with the repo id) if the API cannot be built or either
    /// file cannot be fetched.
    fn fetch_checkpoint_files(config: &SnacConfig) -> Result<(PathBuf, PathBuf)> {
        let repo = &config.repo_id;
        let hub_error = |message: String| CodecError::HfHub {
            repo: repo.clone(),
            source: std::io::Error::other(message),
        };

        let mut builder = hf_hub::api::sync::ApiBuilder::new();
        if let Some(dir) = config.cache_dir.clone() {
            builder = builder.with_cache_dir(dir);
        }
        let api = builder.build().map_err(|e| hub_error(e.to_string()))?;
        let model_repo = match config.revision.clone() {
            Some(rev) => api.repo(hf_hub::Repo::with_revision(
                repo.clone(),
                hf_hub::RepoType::Model,
                rev,
            )),
            None => api.model(repo.clone()),
        };
        let weights = model_repo
            .get(&config.weights_filename)
            .map_err(|e| hub_error(e.to_string()))?;
        let cfg = model_repo
            .get(&config.config_filename)
            .map_err(|e| hub_error(e.to_string()))?;
        Ok((weights, cfg))
    }
}
#[async_trait]
impl AudioBackend for SnacBackend {
    /// Backend identifier, `"snac:<repo_id>"`.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// SNAC is exposed under the `"codec"` provider kind.
    fn provider_kind(&self) -> &'static str {
        "codec"
    }

    /// Eagerly fetches and builds the model, which would otherwise happen on the first
    /// encode or decode. Codec errors are converted into `AudioError`.
    async fn load(&self) -> std::result::Result<(), blazen_audio::AudioError> {
        match self.ensure_loaded().await {
            Ok(_) => Ok(()),
            Err(err) => Err(blazen_audio::AudioError::from(err)),
        }
    }

    /// `true` once a model has been stored in the lazy cell.
    async fn is_loaded(&self) -> bool {
        self.loaded.initialized()
    }
}
#[async_trait]
impl CodecBackend for SnacBackend {
    /// Encodes mono PCM samples into a flat SNAC token stream.
    ///
    /// The samples are fed to the model as a `(1, 1, len)` tensor. The per-codebook codes it
    /// returns are concatenated one codebook after another. `decode_tokens` splits the
    /// stream back apart in `vq_strides` order; that round trip relies on
    /// `upstream::Model::encode` returning codebooks in the same order (NOTE(review): assumed,
    /// not checked here).
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `samples` is empty or `sample_rate` differs from the checkpoint's.
    /// - Any model-loading error, or a converted candle error from the tensor operations.
    async fn encode_pcm(&self, samples: &[f32], sample_rate: u32) -> Result<Vec<u32>> {
        if samples.is_empty() {
            return Err(CodecError::invalid_input("PCM input is empty"));
        }
        let model = self.ensure_loaded().await?;
        if sample_rate != model.sample_rate {
            return Err(CodecError::invalid_input(format!(
                "expected sample rate {} Hz, got {} Hz -- resample first",
                model.sample_rate, sample_rate
            )));
        }
        let xs = Tensor::from_slice(samples, (1, 1, samples.len()), &model.device)
            .map_err(CodecError::from)?;
        let codes_per_cb = model.inner.encode(&xs).map_err(CodecError::from)?;
        // Summing element counts cannot fail, so no `Result` plumbing is needed here.
        let mut flat = Vec::with_capacity(codes_per_cb.iter().map(Tensor::elem_count).sum());
        for codes in &codes_per_cb {
            let codes_vec = codes
                .i(0)
                .map_err(CodecError::from)?
                .flatten_all()
                .map_err(CodecError::from)?
                .to_vec1::<u32>()
                .map_err(CodecError::from)?;
            flat.extend_from_slice(&codes_vec);
        }
        Ok(flat)
    }

    /// Decodes a flat token stream (as produced by `encode_pcm`) back into PCM samples.
    ///
    /// `tokens.len()` must be a positive multiple of the multi-scale step, which is 7 for
    /// strides `[4, 2, 1]`. The stream is split codebook by codebook: codebook `i` receives
    /// `base_len / vq_strides[i]` tokens. The returned samples are batch 0, channel 0 of the
    /// decoded audio.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `num_codebooks` is 0 or differs from the loaded model's count,
    ///   if `tokens` is empty, or if `tokens.len()` does not fit the stride layout.
    /// - Any model-loading error, or a converted candle error from the tensor operations.
    async fn decode_tokens(&self, tokens: &[u32], num_codebooks: usize) -> Result<Vec<f32>> {
        if num_codebooks == 0 {
            return Err(CodecError::invalid_input("num_codebooks must be > 0"));
        }
        if tokens.is_empty() {
            return Err(CodecError::invalid_input("tokens input is empty"));
        }
        let model = self.ensure_loaded().await?;
        if num_codebooks != model.num_codebooks {
            return Err(CodecError::invalid_input(format!(
                "expected num_codebooks {} (from loaded model), got {num_codebooks}",
                model.num_codebooks
            )));
        }
        let base_len =
            base_len_from_token_count(&model.vq_strides, tokens.len()).ok_or_else(|| {
                let lcm = lcm_all(&model.vq_strides);
                let step = total_tokens(&model.vq_strides, lcm);
                CodecError::invalid_input(format!(
                    "token count {} is not a positive multiple of the SNAC multi-scale step {step} \
                     (vq_strides = {:?}, lcm = {lcm})",
                    tokens.len(),
                    model.vq_strides,
                ))
            })?;
        let mut remaining = tokens;
        let mut tensors = Vec::with_capacity(num_codebooks);
        for len in per_codebook_lens(&model.vq_strides, base_len) {
            // `base_len` comes from `base_len_from_token_count`, so the per-codebook lengths
            // sum to exactly `tokens.len()` and `split_at` never runs past the end.
            let (chunk, rest) = remaining.split_at(len);
            let tensor =
                Tensor::from_slice(chunk, (1, len), &model.device).map_err(CodecError::from)?;
            tensors.push(tensor);
            remaining = rest;
        }
        let refs: Vec<&Tensor> = tensors.iter().collect();
        let audio = model.inner.decode(&refs).map_err(CodecError::from)?;
        let audio = audio
            .i(0)
            .map_err(CodecError::from)?
            .i(0)
            .map_err(CodecError::from)?
            .flatten_all()
            .map_err(CodecError::from)?;
        audio.to_vec1::<f32>().map_err(CodecError::from)
    }

    /// Sample rate of the loaded checkpoint. Before loading it returns 24 000 regardless of
    /// the configured repo, which matches the default 24 kHz checkpoint.
    fn sample_rate(&self) -> u32 {
        self.sample_rate_loaded().unwrap_or(24_000)
    }

    /// Codebook count of the loaded checkpoint. Before loading it returns 3 regardless of the
    /// configured repo, which matches the default 24 kHz checkpoint.
    fn num_codebooks(&self) -> usize {
        self.num_codebooks_loaded().unwrap_or(3)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `err` is `CodecError::InvalidInput` whose message contains `needle`.
    fn expect_invalid_input(err: CodecError, needle: &str) {
        match err {
            CodecError::InvalidInput(msg) => assert!(msg.contains(needle)),
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn default_config_uses_24khz_repo() {
        let SnacConfig {
            repo_id,
            revision,
            weights_filename,
            config_filename,
            cpu_only,
            cache_dir,
        } = SnacConfig::default();
        assert_eq!(repo_id, "hubertsiuzdak/snac_24khz");
        assert_eq!(weights_filename, "pytorch_model.bin");
        assert_eq!(config_filename, "config.json");
        assert!(revision.is_none());
        assert!(!cpu_only);
        assert!(cache_dir.is_none());
    }

    #[test]
    fn default_24khz_constructor_matches_default_config() {
        let backend = SnacBackend::default_24khz();
        assert_eq!(backend.config().repo_id, "hubertsiuzdak/snac_24khz");
        // Trait defaults apply because nothing has been loaded yet.
        assert_eq!(CodecBackend::sample_rate(&backend), 24_000);
        assert_eq!(CodecBackend::num_codebooks(&backend), 3);
        assert!(backend.sample_rate_loaded().is_none());
        assert!(backend.num_codebooks_loaded().is_none());
        assert!(backend.vq_strides_loaded().is_none());
    }

    #[test]
    fn id_includes_repo_name() {
        let backend = SnacBackend::default_24khz();
        assert_eq!(backend.id(), "snac:hubertsiuzdak/snac_24khz");
        assert_eq!(backend.provider_kind(), "codec");
    }

    #[test]
    fn custom_repo_round_trips_through_config() {
        let backend = SnacBackend::new(SnacConfig {
            repo_id: "hubertsiuzdak/snac_32khz".to_string(),
            revision: Some("main".to_string()),
            ..SnacConfig::default()
        });
        assert_eq!(backend.id(), "snac:hubertsiuzdak/snac_32khz");
        assert_eq!(backend.config().revision.as_deref(), Some("main"));
    }

    #[test]
    fn lcm_all_matches_known_snac_strides() {
        let cases = [
            (&[4, 2, 1][..], 4),
            (&[8, 4, 2, 1][..], 8),
            (&[3, 5][..], 15),
        ];
        for (strides, expected) in cases {
            assert_eq!(lcm_all(strides), expected);
        }
    }

    #[test]
    fn per_codebook_lens_split_evenly_at_step_boundary() {
        let strides = [4usize, 2, 1];
        for (base_len, expected) in [(4, vec![1, 2, 4]), (8, vec![2, 4, 8])] {
            assert_eq!(per_codebook_lens(&strides, base_len), expected);
        }
    }

    #[test]
    fn total_tokens_for_canonical_24khz_strides() {
        for (base_len, expected) in [(4, 7), (8, 14), (12, 21)] {
            assert_eq!(total_tokens(&[4, 2, 1], base_len), expected);
        }
    }

    #[test]
    fn base_len_from_token_count_round_trips_for_valid_inputs() {
        let strides = [4usize, 2, 1];
        for (token_count, base_len) in [(7, 4), (14, 8), (21, 12)] {
            assert_eq!(
                base_len_from_token_count(&strides, token_count),
                Some(base_len)
            );
        }
    }

    #[test]
    fn base_len_from_token_count_rejects_non_step_multiples() {
        let strides = [4usize, 2, 1];
        for token_count in [6, 8, 13, 0] {
            assert_eq!(
                base_len_from_token_count(&strides, token_count),
                None,
                "token count {token_count}"
            );
        }
    }

    #[tokio::test]
    async fn encode_rejects_empty_pcm() {
        let backend = SnacBackend::default_24khz();
        let err = backend.encode_pcm(&[], 24_000).await.unwrap_err();
        expect_invalid_input(err, "empty");
    }

    #[tokio::test]
    async fn decode_rejects_empty_tokens() {
        let backend = SnacBackend::default_24khz();
        let err = backend.decode_tokens(&[], 3).await.unwrap_err();
        expect_invalid_input(err, "empty");
    }

    #[tokio::test]
    async fn decode_rejects_zero_codebooks() {
        let backend = SnacBackend::default_24khz();
        let err = backend.decode_tokens(&[1, 2, 3], 0).await.unwrap_err();
        expect_invalid_input(err, "num_codebooks");
    }

    #[tokio::test]
    async fn is_loaded_starts_false() {
        let backend = SnacBackend::default_24khz();
        assert!(!AudioBackend::is_loaded(&backend).await);
    }

    #[tokio::test]
    async fn provider_kind_dispatches_through_audio_backend_trait() {
        let backend: Arc<dyn AudioBackend> = Arc::new(SnacBackend::default_24khz());
        assert_eq!(backend.provider_kind(), "codec");
        assert!(backend.id().starts_with("snac:"));
    }

    /// Downloads the real 24 kHz checkpoint and round-trips one second of a 440 Hz sine.
    #[cfg(feature = "live-models")]
    #[tokio::test]
    async fn live_round_trip_sine_wave_24khz() {
        let backend = SnacBackend::default_24khz();
        backend
            .load()
            .await
            .expect("load default 24 kHz SNAC weights");

        let num_codebooks = backend
            .num_codebooks_loaded()
            .expect("loaded codebook count");
        let sample_rate = backend.sample_rate_loaded().expect("loaded sample rate");
        let strides = backend
            .vq_strides_loaded()
            .expect("loaded vq_strides")
            .to_vec();
        assert_eq!(sample_rate, 24_000);
        assert_eq!(num_codebooks, 3);
        assert_eq!(strides, vec![4, 2, 1]);

        let len = 24_000usize;
        let pcm: Vec<f32> = (0..len)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                let t = i as f32 / 24_000.0;
                0.5 * (2.0 * std::f32::consts::PI * 440.0 * t).sin()
            })
            .collect();

        let codes = backend
            .encode_pcm(&pcm, 24_000)
            .await
            .expect("encode 24 kHz sine wave");
        assert!(!codes.is_empty(), "encoded codes must not be empty");
        let step = total_tokens(&strides, lcm_all(&strides));
        assert!(
            codes.len().is_multiple_of(step),
            "encoded token count {} should be a multiple of the SNAC step {step}",
            codes.len()
        );

        let decoded = backend
            .decode_tokens(&codes, num_codebooks)
            .await
            .expect("decode SNAC codes");
        assert!(!decoded.is_empty(), "decoded PCM must not be empty");
        if let Some(s) = decoded.iter().find(|s| !s.is_finite()) {
            panic!("decoded sample must be finite, got {s}");
        }

        // Codec padding/hop alignment may change the length slightly.
        let tolerance = 4096usize;
        let diff = decoded.len().abs_diff(len);
        assert!(
            diff <= tolerance,
            "decoded len {} differs from input len {len} by more than {tolerance}",
            decoded.len()
        );
    }
}