use crate::tokenizer::TextTokenizer;
use crate::voxcpm2::model::{
AUDIO_START_TOKEN, REF_AUDIO_END_TOKEN, REF_AUDIO_START_TOKEN, VoxCpm2Model,
};
use crate::VoxCpm2Config;
use burn::prelude::*;
use burn::tensor::{Int, TensorData};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Cloneable cancellation flag.
///
/// The flag lives in an `Arc<AtomicBool>`, so every clone of a token observes
/// the same state: calling [`CancelToken::cancel`] on one clone makes
/// [`CancelToken::is_cancelled`] return `true` on all of them.
#[derive(Debug, Clone, Default)]
pub struct CancelToken(Arc<AtomicBool>);

impl CancelToken {
    /// Creates a token whose flag is initially unset.
    pub fn new() -> Self {
        Self(Arc::new(AtomicBool::new(false)))
    }

    /// Sets the shared flag. The flag is never cleared by this type.
    pub fn cancel(&self) {
        let Self(flag) = self;
        // Relaxed ordering: only the boolean itself is communicated.
        flag.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once [`CancelToken::cancel`] has been called on this
    /// token or on any of its clones.
    pub fn is_cancelled(&self) -> bool {
        let Self(flag) = self;
        flag.load(Ordering::Relaxed)
    }
}
/// Where the audio for a prompt comes from.
///
/// Each variant is decoded differently when the prompt is encoded: a file is
/// loaded from disk, encoded bytes are decoded from memory, and raw PCM is
/// resampled from its stated sample rate.
#[derive(Debug, Clone)]
pub enum PromptAudio {
    /// Path of an audio file to load.
    File(PathBuf),
    /// Contents of an encoded audio file held in memory.
    Encoded(Vec<u8>),
    /// Already-decoded samples together with the rate they were sampled at.
    Pcm {
        /// Sample values (channel layout is not specified here — presumably
        /// mono; confirm against `crate::audio::resample`).
        samples: Vec<f32>,
        /// Sample rate of `samples`.
        sample_rate: u32,
    },
}

/// An owned path is treated as a file to load.
impl From<PathBuf> for PromptAudio {
    fn from(p: PathBuf) -> Self {
        Self::File(p)
    }
}

/// A borrowed path is copied into an owned [`PromptAudio::File`].
impl From<&Path> for PromptAudio {
    fn from(p: &Path) -> Self {
        Self::File(p.to_owned())
    }
}

/// A string is interpreted as a file path, not as audio content.
impl From<&str> for PromptAudio {
    fn from(p: &str) -> Self {
        Self::File(p.into())
    }
}
/// Conditioning supplied to generation in addition to the target text.
///
/// The default is [`Prompt::None`].
#[derive(Debug, Clone, Default)]
pub enum Prompt {
    /// No conditioning audio or text.
    #[default]
    None,
    /// Reference audio only. During generation its features are placed
    /// between the reference-audio start and end tokens, ahead of the text.
    Reference {
        /// Audio used as the reference.
        audio: PromptAudio,
    },
    /// Audio plus its text. During generation the text is prepended to the
    /// target text and the audio features follow the text tokens.
    Continuation {
        /// Audio whose features are appended after the text tokens.
        audio: PromptAudio,
        /// Text prepended to the target text (presumably the transcript of
        /// `audio` — confirm with callers).
        text: String,
    },
    /// Both a reference clip and a continuation prompt.
    Combined {
        /// Audio handled like [`Prompt::Reference`]'s `audio`.
        reference_audio: PromptAudio,
        /// Audio handled like [`Prompt::Continuation`]'s `audio`.
        prompt_audio: PromptAudio,
        /// Text handled like [`Prompt::Continuation`]'s `text`.
        prompt_text: String,
    },
}
/// Settings for a single generation call.
///
/// All numeric fields are forwarded unchanged to the model's inference call;
/// their exact meaning is defined there.
#[derive(Debug, Clone)]
pub struct GenerateOptions {
    /// Guidance value passed to inference (presumably the classifier-free
    /// guidance scale — confirm against the model). Defaults to `2.0`.
    pub cfg_value: f32,
    /// Number of inference timesteps. Defaults to `10`.
    pub inference_timesteps: usize,
    /// Minimum length passed to inference. Defaults to `2`.
    pub min_len: usize,
    /// Maximum length passed to inference. Defaults to `2000`.
    pub max_len: usize,
    /// Optional conditioning prompt. Defaults to [`Prompt::None`].
    pub prompt: Prompt,
    /// Optional token polled during inference to stop early. Defaults to `None`.
    pub cancel: Option<CancelToken>,
}

impl GenerateOptions {
    // Default values, named so `Default` reads as a table.
    const DEFAULT_CFG_VALUE: f32 = 2.0;
    const DEFAULT_INFERENCE_TIMESTEPS: usize = 10;
    const DEFAULT_MIN_LEN: usize = 2;
    const DEFAULT_MAX_LEN: usize = 2000;

    /// Starts a builder pre-populated with the default options.
    pub fn builder() -> GenerateOptionsBuilder {
        Default::default()
    }
}

impl Default for GenerateOptions {
    fn default() -> Self {
        Self {
            cfg_value: Self::DEFAULT_CFG_VALUE,
            inference_timesteps: Self::DEFAULT_INFERENCE_TIMESTEPS,
            min_len: Self::DEFAULT_MIN_LEN,
            max_len: Self::DEFAULT_MAX_LEN,
            prompt: Prompt::None,
            cancel: None,
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct GenerateOptionsBuilder {
inner: GenerateOptions,
}
impl GenerateOptionsBuilder {
pub fn cfg(mut self, v: f32) -> Self {
self.inner.cfg_value = v;
self
}
pub fn timesteps(mut self, n: usize) -> Self {
self.inner.inference_timesteps = n;
self
}
pub fn min_len(mut self, n: usize) -> Self {
self.inner.min_len = n;
self
}
pub fn max_len(mut self, n: usize) -> Self {
self.inner.max_len = n;
self
}
pub fn prompt(mut self, p: Prompt) -> Self {
self.inner.prompt = p;
self
}
pub fn cancel(mut self, token: CancelToken) -> Self {
self.inner.cancel = Some(token);
self
}
pub fn build(self) -> GenerateOptions {
self.inner
}
}
/// Text-to-speech front end pairing a [`VoxCpm2Model`] with its tokenizer.
///
/// The device given at construction is stored and reused for every tensor
/// this type creates.
#[derive(Debug)]
pub struct VoxCPM<B: Backend> {
    /// The underlying model, including its audio VAE.
    pub model: VoxCpm2Model<B>,
    /// Tokenizer used to turn input text into token ids.
    pub tokenizer: TextTokenizer,
    // Device on which input tensors are allocated.
    device: B::Device,
}
impl<B: Backend> VoxCPM<B> {
/// Builds an instance from a configuration, without loading any weights.
///
/// The model is created with `VoxCpm2Model::new(config, device)`; `device`
/// is cloned and kept for later tensor creation.
pub fn from_config(
    config: VoxCpm2Config,
    tokenizer: TextTokenizer,
    device: &B::Device,
) -> Self {
    let model = VoxCpm2Model::new(config, device);
    let device = device.clone();
    Self {
        model,
        tokenizer,
        device,
    }
}
/// Loads configuration, tokenizer and pretrained weights from a directory.
///
/// Reads `config.json` under `path`, loads the tokenizer from the same
/// directory, builds the model on `device`, then applies the checkpoint via
/// `crate::weights::load_pretrained`.
///
/// A summary of the load result is logged at info level. Missing parameters
/// and unused checkpoint tensors are logged as warnings and per-tensor load
/// errors as errors (at most 20 entries each). None of these abort loading.
///
/// # Errors
///
/// Returns an error if `config.json` cannot be read or parsed, if the
/// tokenizer cannot be loaded, or if `load_pretrained` itself fails.
pub fn from_local(path: impl AsRef<Path>, device: &B::Device) -> crate::Result<Self> {
    let root = path.as_ref();

    let config_json = std::fs::read_to_string(root.join("config.json"))?;
    let config: VoxCpm2Config = serde_json::from_str(&config_json)?;
    let tokenizer = TextTokenizer::from_local(root)?;

    let mut model = VoxCpm2Model::<B>::new(config, device);
    let report = crate::weights::load_pretrained(&mut model, root)?;

    log::info!(
        "weights loaded — applied={}, skipped={}, missing={}, unused={}, errors={}",
        report.applied.len(),
        report.skipped.len(),
        report.missing.len(),
        report.unused.len(),
        report.errors.len(),
    );

    // Only the first 20 entries of each category are printed to keep logs short.
    let has_missing = !report.missing.is_empty();
    if has_missing {
        log::warn!("missing module params (first 20):");
        for (k, ctx) in report.missing.iter().take(20) {
            log::warn!(" {k} [{ctx}]");
        }
    }

    let has_unused = !report.unused.is_empty();
    if has_unused {
        log::warn!("unused checkpoint tensors (first 20):");
        for k in report.unused.iter().take(20) {
            log::warn!(" {k}");
        }
    }

    let has_errors = !report.errors.is_empty();
    if has_errors {
        log::error!("load errors (first 20):");
        for e in report.errors.iter().take(20) {
            log::error!(" {e:?}");
        }
    }

    let device = device.clone();
    Ok(Self {
        model,
        tokenizer,
        device,
    })
}
/// Returns the model's sample rate, cast to `u32`.
pub fn sample_rate(&self) -> u32 {
    let rate = self.model.sample_rate();
    rate as u32
}
/// Runs the model's audio VAE decoder on `feat` and returns its output.
pub fn audio_vae_decode(&self, feat: Tensor<B, 3>) -> Tensor<B, 3> {
    let vae = &self.model.audio_vae;
    vae.decode(feat)
}
/// Encodes prompt audio into patch features with dimensions `[t, p, d]`.
///
/// Steps:
/// 1. Obtain samples at the audio VAE's sample rate (load a file, decode
///    bytes, or resample PCM, depending on the variant).
/// 2. Zero-pad so the sample count is a multiple of `patch_size * chunk_size`;
///    `padding_mode` selects which side receives the zeros.
/// 3. Encode a `[1, 1, n]` tensor with the audio VAE, giving `[_, d, tp]`.
/// 4. Reshape to `[d, t, p]` with `t = tp / p`, then swap dims to `[t, p, d]`.
///
/// # Errors
///
/// Returns `Error::AudioDecode` when the audio yields zero samples, and
/// propagates any error from loading or resampling.
fn encode_prompt_audio(
    &self,
    audio: &PromptAudio,
    padding_mode: PadMode,
) -> crate::Result<Tensor<B, 3>> {
    let vae = &self.model.audio_vae;
    let target_sr = vae.sample_rate() as u32;

    let decoded = match audio {
        PromptAudio::File(path) => crate::audio::load_audio_as(path, target_sr)?,
        PromptAudio::Encoded(bytes) => crate::audio::load_audio_bytes_as(bytes, target_sr)?,
        PromptAudio::Pcm { samples, sample_rate } => {
            crate::audio::resample(samples, *sample_rate, target_sr)?
        }
    };

    let patch = self.model.patch_size();
    let frame_len = patch * vae.config.0.chunk_size();

    if decoded.is_empty() {
        return Err(crate::Error::AudioDecode(
            "prompt audio decoded to 0 samples".into(),
        ));
    }

    // Pad up to the next multiple of `frame_len`; already-aligned input is left as is.
    let padded: Vec<f32> = match decoded.len() % frame_len {
        0 => decoded,
        rem => {
            let pad = frame_len - rem;
            match padding_mode {
                PadMode::Right => {
                    let mut tail_padded = decoded;
                    tail_padded.extend(std::iter::repeat_n(0.0f32, pad));
                    tail_padded
                }
                PadMode::Left => std::iter::repeat_n(0.0f32, pad).chain(decoded).collect(),
            }
        }
    };

    let total = padded.len();
    let wave: Tensor<B, 3> =
        Tensor::from_data(TensorData::new(padded, [1, 1, total]), &self.device);

    let latent = vae.encode(wave);
    let [_, dim, steps] = latent.dims();
    debug_assert_eq!(steps % patch, 0);
    let n_patches = steps / patch;

    // [_, d, tp] -> [d, t, p] -> [t, d, p] -> [t, p, d]
    let patches: Tensor<B, 3> = latent.reshape([dim, n_patches, patch]);
    Ok(patches.swap_dims(0, 1).swap_dims(1, 2))
}
pub fn generate(&self, text: &str, opts: GenerateOptions) -> crate::Result<Vec<f32>> {
let (ref_audio, prompt_audio, prompt_text) = match &opts.prompt {
Prompt::None => (None, None, None),
Prompt::Reference { audio } => (Some(audio), None, None),
Prompt::Continuation { audio, text } => (None, Some(audio), Some(text.as_str())),
Prompt::Combined {
reference_audio,
prompt_audio,
prompt_text,
} => (
Some(reference_audio),
Some(prompt_audio),
Some(prompt_text.as_str()),
),
};
let device = &self.device;
let p = self.model.patch_size();
let d = self.model.latent_dim();
let full_text: String = match prompt_text {
Some(pt) => format!("{pt}{text}"),
None => text.to_string(),
};
let mut text_tokens = self.tokenizer.encode(&full_text)?;
text_tokens.push(AUDIO_START_TOKEN);
let text_len = text_tokens.len();
let ref_feat_opt = match ref_audio {
Some(audio) => Some(self.encode_prompt_audio(audio, PadMode::Right)?),
None => None,
};
let prompt_feat_opt = match prompt_audio {
Some(audio) => Some(self.encode_prompt_audio(audio, PadMode::Left)?),
None => None,
};
let z_patch = |n: usize| -> Tensor<B, 3> { Tensor::<B, 3>::zeros([n, p, d], device) };
let mut tokens: Vec<i64> = Vec::new();
let mut t_mask: Vec<f32> = Vec::new();
let mut f_mask: Vec<f32> = Vec::new();
let mut feat_chunks: Vec<Tensor<B, 3>> = Vec::new();
if let Some(ref_feat) = ref_feat_opt {
let ref_len = ref_feat.dims()[0];
tokens.push(REF_AUDIO_START_TOKEN);
tokens.extend(std::iter::repeat_n(0i64, ref_len));
tokens.push(REF_AUDIO_END_TOKEN);
t_mask.push(1.0);
t_mask.extend(std::iter::repeat_n(0.0, ref_len));
t_mask.push(1.0);
f_mask.push(0.0);
f_mask.extend(std::iter::repeat_n(1.0, ref_len));
f_mask.push(0.0);
feat_chunks.push(z_patch(1));
feat_chunks.push(ref_feat);
feat_chunks.push(z_patch(1));
}
tokens.extend_from_slice(&text_tokens);
t_mask.extend(std::iter::repeat_n(1.0, text_len));
f_mask.extend(std::iter::repeat_n(0.0, text_len));
feat_chunks.push(z_patch(text_len));
if let Some(prompt_feat) = prompt_feat_opt {
let prompt_len = prompt_feat.dims()[0];
tokens.extend(std::iter::repeat_n(0i64, prompt_len));
t_mask.extend(std::iter::repeat_n(0.0, prompt_len));
f_mask.extend(std::iter::repeat_n(1.0, prompt_len));
feat_chunks.push(prompt_feat);
}
let s = tokens.len();
let feat_seq = if feat_chunks.len() == 1 {
feat_chunks.pop().unwrap()
} else {
Tensor::cat(feat_chunks, 0)
};
let text_token_t: Tensor<B, 2, Int> =
Tensor::from_data(TensorData::new(tokens, [1, s]), device);
let text_mask_t: Tensor<B, 2> =
Tensor::from_data(TensorData::new(t_mask, [1, s]), device);
let feat_mask_t: Tensor<B, 2> =
Tensor::from_data(TensorData::new(f_mask, [1, s]), device);
let feat_t: Tensor<B, 4> = feat_seq.unsqueeze_dim(0);
let cancel_fn: Option<Box<dyn Fn() -> bool>> = opts
.cancel
.as_ref()
.map(|c| {
let c = c.clone();
Box::new(move || c.is_cancelled()) as Box<dyn Fn() -> bool>
});
let latent = self.model.inference(
text_token_t,
text_mask_t,
feat_t,
feat_mask_t,
opts.min_len,
opts.max_len,
opts.inference_timesteps,
opts.cfg_value as f64,
cancel_fn.as_deref(),
)?;
let wav = self.model.audio_vae.decode(latent);
let wav = wav.squeeze_dim::<2>(1); let wav = wav.squeeze_dim::<1>(0); let data = wav.into_data();
let samples: Vec<f32> = data
.convert::<f32>()
.into_vec::<f32>()
.map_err(|_| crate::Error::Other("unexpected VAE output dtype".into()))?;
Ok(samples)
}
}
/// Side on which prompt audio is zero-padded to a whole number of patches.
#[derive(Debug, Clone, Copy)]
enum PadMode {
    /// Zeros are appended after the samples.
    Right,
    /// Zeros are inserted before the samples.
    Left,
}