use crate::tokenizer::TextTokenizer;
use crate::voxcpm2::model::{
AUDIO_START_TOKEN, REF_AUDIO_END_TOKEN, REF_AUDIO_START_TOKEN, VoxCpm2Model,
};
use crate::VoxCpm2Config;
use burn::prelude::*;
use burn::tensor::{Int, TensorData};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
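/// Cooperative cancellation flag shared between a caller and a running
/// generation. Cloning is cheap (it clones the inner `Arc`).
///
/// ```ignore
/// // Illustrative sketch: cancel a long generation from another thread.
/// let token = CancelToken::new();
/// let opts = GenerateOptions::builder().cancel(token.clone()).build();
/// std::thread::spawn({
///     let token = token.clone();
///     move || {
///         std::thread::sleep(std::time::Duration::from_secs(30));
///         token.cancel();
///     }
/// });
/// // The running generation polls the token and stops early once it fires.
/// ```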
#[derive(Debug, Clone, Default)]
pub struct CancelToken(Arc<AtomicBool>);
impl CancelToken {
pub fn new() -> Self {
Self::default()
}
pub fn cancel(&self) {
self.0.store(true, Ordering::Relaxed);
}
pub fn is_cancelled(&self) -> bool {
self.0.load(Ordering::Relaxed)
}
}
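/// Audio used to prompt the model: a file on disk, encoded bytes in any
/// container the audio loader understands, or raw PCM samples with their
/// sample rate. Every variant is resampled to the VAE encoder's sample
/// rate before encoding.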
#[derive(Debug, Clone)]
pub enum PromptAudio {
File(PathBuf),
Encoded(Vec<u8>),
Pcm {
samples: Vec<f32>,
sample_rate: u32,
},
}
impl From<PathBuf> for PromptAudio {
fn from(p: PathBuf) -> Self {
PromptAudio::File(p)
}
}
impl From<&Path> for PromptAudio {
fn from(p: &Path) -> Self {
PromptAudio::File(p.to_path_buf())
}
}
impl From<&str> for PromptAudio {
fn from(p: &str) -> Self {
PromptAudio::File(PathBuf::from(p))
}
}
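/// Conditioning for a generation call.
///
/// - `None`: unconditioned synthesis.
/// - `Reference`: the audio is encoded and placed before the text between
///   `REF_AUDIO_START_TOKEN` and `REF_AUDIO_END_TOKEN`; it conditions the
///   voice but is not continued.
/// - `Continuation`: `text` is the transcript of `audio`; the transcript is
///   prepended to the input text and the audio features follow it, so the
///   model continues from the prompt audio.
/// - `Combined`: a reference segment plus a transcribed continuation prompt.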
#[derive(Debug, Clone, Default)]
pub enum Prompt {
#[default]
None,
Reference {
audio: PromptAudio,
},
Continuation {
audio: PromptAudio,
text: String,
},
Combined {
reference_audio: PromptAudio,
prompt_audio: PromptAudio,
prompt_text: String,
},
}
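/// Tuning knobs for a single generation call. The `Default` values
/// (CFG 2.0, 10 timesteps, 2..=2000 patches) are a reasonable starting
/// point.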
#[derive(Debug, Clone)]
pub struct GenerateOptions {
    /// Classifier-free guidance strength passed to the diffusion model.
    pub cfg_value: f32,
    /// Diffusion timesteps per decoded patch.
    pub inference_timesteps: usize,
    /// Minimum number of patches to generate before a stop is honored.
    pub min_len: usize,
    /// Maximum number of patches to generate.
    pub max_len: usize,
    /// Optional audio/text conditioning.
    pub prompt: Prompt,
    /// Optional cooperative cancellation token.
    pub cancel: Option<CancelToken>,
    /// Patches decoded per chunk when streaming (clamped to at least 1).
    pub chunk_patches: usize,
    /// If `Some(n)` with `n >= 2`, sentence segments are decoded in batched
    /// groups of `n` (only for `Prompt::None` / `Prompt::Reference`).
    pub parallel_segments: Option<usize>,
}
impl Default for GenerateOptions {
fn default() -> Self {
Self {
cfg_value: 2.0,
inference_timesteps: 10,
min_len: 2,
max_len: 2000,
prompt: Prompt::None,
cancel: None,
chunk_patches: 5,
parallel_segments: None,
}
}
}
impl GenerateOptions {
pub fn builder() -> GenerateOptionsBuilder {
GenerateOptionsBuilder::default()
}
}
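/// Fluent builder for [`GenerateOptions`], starting from the defaults.
///
/// ```ignore
/// // Sketch of typical usage; the values are illustrative.
/// let opts = GenerateOptions::builder()
///     .cfg(2.0)
///     .timesteps(10)
///     .prompt(Prompt::Reference { audio: "ref.wav".into() })
///     .build();
/// ```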
#[derive(Debug, Clone, Default)]
pub struct GenerateOptionsBuilder {
inner: GenerateOptions,
}
impl GenerateOptionsBuilder {
pub fn cfg(mut self, v: f32) -> Self {
self.inner.cfg_value = v;
self
}
pub fn timesteps(mut self, n: usize) -> Self {
self.inner.inference_timesteps = n;
self
}
pub fn min_len(mut self, n: usize) -> Self {
self.inner.min_len = n;
self
}
pub fn max_len(mut self, n: usize) -> Self {
self.inner.max_len = n;
self
}
pub fn prompt(mut self, p: Prompt) -> Self {
self.inner.prompt = p;
self
}
pub fn cancel(mut self, token: CancelToken) -> Self {
self.inner.cancel = Some(token);
self
}
pub fn chunk_patches(mut self, n: usize) -> Self {
self.inner.chunk_patches = n;
self
}
pub fn parallel_segments(mut self, n: usize) -> Self {
self.inner.parallel_segments = Some(n);
self
}
pub fn build(self) -> GenerateOptions {
self.inner
}
}
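/// End-to-end text-to-speech pipeline: a text tokenizer plus the VoxCPM2
/// model and its audio VAE on a single device.
///
/// ```ignore
/// // Minimal sketch, assuming the `ndarray` backend feature and a local
/// // checkpoint directory holding config, tokenizer, and weights.
/// use burn::backend::NdArray;
///
/// let device = Default::default();
/// let tts = VoxCPM::<NdArray>::from_local("models/voxcpm2", &device)?;
/// let pcm = tts.generate("Hello, world.", GenerateOptions::default())?;
/// ```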
#[derive(Debug)]
pub struct VoxCPM<B: Backend> {
pub model: VoxCpm2Model<B>,
pub tokenizer: TextTokenizer,
device: B::Device,
}
impl<B: Backend> VoxCPM<B> {
pub fn from_config(
config: VoxCpm2Config,
tokenizer: TextTokenizer,
device: &B::Device,
) -> Self {
Self {
model: VoxCpm2Model::new(config, device),
tokenizer,
device: device.clone(),
}
}
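    /// Loads config, tokenizer, and pretrained weights from a local
    /// checkpoint directory (expects `config.json` next to the tokenizer
    /// and weight files). Missing, unused, or failed tensors are reported
    /// via `log` rather than treated as hard errors.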
pub fn from_local(path: impl AsRef<Path>, device: &B::Device) -> crate::Result<Self> {
let path = path.as_ref();
let config_bytes = std::fs::read_to_string(path.join("config.json"))?;
let config: VoxCpm2Config = serde_json::from_str(&config_bytes)?;
let tokenizer = TextTokenizer::from_local(path)?;
let mut model = VoxCpm2Model::<B>::new(config, device);
let result = crate::weights::load_pretrained(&mut model, path)?;
log::info!(
"weights loaded — applied={}, skipped={}, missing={}, unused={}, errors={}",
result.applied.len(),
result.skipped.len(),
result.missing.len(),
result.unused.len(),
result.errors.len(),
);
if !result.missing.is_empty() {
log::warn!("missing module params (first 20):");
for (k, ctx) in result.missing.iter().take(20) {
log::warn!(" {k} [{ctx}]");
}
}
if !result.unused.is_empty() {
log::warn!("unused checkpoint tensors (first 20):");
for k in result.unused.iter().take(20) {
log::warn!(" {k}");
}
}
if !result.errors.is_empty() {
log::error!("load errors (first 20):");
for e in result.errors.iter().take(20) {
log::error!(" {e:?}");
}
}
Ok(Self {
model,
tokenizer,
device: device.clone(),
})
}
pub fn sample_rate(&self) -> u32 {
self.model.sample_rate() as u32
}
pub fn audio_vae_decode(&self, feat: Tensor<B, 3>) -> Tensor<B, 3> {
self.model.audio_vae.decode(feat)
}
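    /// Loads (or resamples) the prompt audio to the VAE encoder's sample
    /// rate, zero-pads it to a whole number of patches on the side given
    /// by `padding_mode`, and encodes it into per-patch latent features of
    /// shape `[t, p, d]`.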
fn encode_prompt_audio(
&self,
audio: &PromptAudio,
padding_mode: PadMode,
) -> crate::Result<Tensor<B, 3>> {
let encoder_sr = self.model.audio_vae.sample_rate() as u32;
let mut samples = match audio {
PromptAudio::File(path) => crate::audio::load_audio_as(path, encoder_sr)?,
PromptAudio::Encoded(bytes) => crate::audio::load_audio_bytes_as(bytes, encoder_sr)?,
PromptAudio::Pcm { samples, sample_rate } => {
crate::audio::resample(samples, *sample_rate, encoder_sr)?
}
};
let p = self.model.patch_size();
let chunk = self.model.audio_vae.config.0.chunk_size();
let patch_len = p * chunk;
let n = samples.len();
if n == 0 {
return Err(crate::Error::AudioDecode(
"prompt audio decoded to 0 samples".into(),
));
}
        // Zero-pad so the sample count is a whole number of patches.
        let rem = n % patch_len;
if rem != 0 {
let pad = patch_len - rem;
match padding_mode {
PadMode::Right => samples.resize(n + pad, 0.0),
PadMode::Left => {
let mut new = vec![0.0f32; pad];
new.extend_from_slice(&samples);
samples = new;
}
}
}
let n_padded = samples.len();
let audio: Tensor<B, 3> =
Tensor::from_data(TensorData::new(samples, [1, 1, n_padded]), &self.device);
        // Encode to latents of shape [1, d, t * p], then regroup the time
        // axis into per-patch features of shape [t, p, d].
        let feat = self.model.audio_vae.encode(audio);
        let [_, d, tp] = feat.dims();
debug_assert_eq!(tp % p, 0);
let t = tp / p;
let feat: Tensor<B, 3> = feat.reshape([d, t, p]);
let feat = feat.swap_dims(0, 1).swap_dims(1, 2);
Ok(feat)
}
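    /// Synthesizes `text` to mono PCM at [`Self::sample_rate`].
    ///
    /// When `opts.parallel_segments` is `Some(n)` with `n >= 2` and the
    /// prompt is `None` or `Reference`, the text is split into sentences
    /// and decoded in batched groups of `n`; otherwise the whole text is
    /// decoded in a single pass.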
pub fn generate(&self, text: &str, opts: GenerateOptions) -> crate::Result<Vec<f32>> {
if let Some(parallel_n) = opts.parallel_segments {
if parallel_n >= 2 && matches!(opts.prompt, Prompt::None | Prompt::Reference { .. }) {
let segments = split_sentences(text);
if segments.len() >= 2 {
return self.generate_parallel(&segments, parallel_n, &opts);
}
}
}
let inputs = self.build_inference_inputs(text, &opts.prompt)?;
let cancel_fn: Option<Box<dyn Fn() -> bool>> = opts.cancel.as_ref().map(|c| {
let c = c.clone();
Box::new(move || c.is_cancelled()) as Box<dyn Fn() -> bool>
});
let (latent, _stop_steps) = self.model.inference(
inputs.text_token,
inputs.text_mask,
inputs.feat,
inputs.feat_mask,
opts.min_len,
opts.max_len,
opts.inference_timesteps,
opts.cfg_value as f64,
cancel_fn.as_deref(),
)?;
        decode_latent_to_samples(&self.model.audio_vae, latent)
}
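    /// Starts an incremental generation. The returned [`GenerateStream`]
    /// yields newly decoded PCM in chunks of roughly `chunk_patches`
    /// patches each.
    ///
    /// ```ignore
    /// // Sketch: consume audio as it is produced. `play` stands in for a
    /// // hypothetical PCM sink.
    /// let stream = tts.generate_stream("Hello.", GenerateOptions::default())?;
    /// for chunk in stream {
    ///     play(&chunk?);
    /// }
    /// ```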
pub fn generate_stream(
&self,
text: &str,
opts: GenerateOptions,
) -> crate::Result<GenerateStream<'_, B>> {
let inputs = self.build_inference_inputs(text, &opts.prompt)?;
let state = self.model.prefill(
inputs.text_token,
inputs.text_mask,
inputs.feat,
inputs.feat_mask,
opts.max_len,
);
Ok(GenerateStream {
model: &self.model,
state,
pred_feats: Vec::new(),
samples_emitted: 0,
step: 0,
min_len: opts.min_len,
max_len: opts.max_len,
inference_timesteps: opts.inference_timesteps,
cfg_value: opts.cfg_value as f64,
chunk_patches: opts.chunk_patches.max(1),
cancel: opts.cancel,
finished: false,
})
}
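    /// Assembles the token, mask, and feature tensors for one request:
    /// an optional reference segment, then the (prompt + target) text
    /// ending in `AUDIO_START_TOKEN`, then optional continuation features.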
fn build_inference_inputs(
&self,
text: &str,
prompt: &Prompt,
) -> crate::Result<InferenceInputs<B>> {
let (ref_audio, prompt_audio, prompt_text) = match prompt {
Prompt::None => (None, None, None),
Prompt::Reference { audio } => (Some(audio), None, None),
Prompt::Continuation { audio, text } => (None, Some(audio), Some(text.as_str())),
Prompt::Combined {
reference_audio,
prompt_audio,
prompt_text,
} => (
Some(reference_audio),
Some(prompt_audio),
Some(prompt_text.as_str()),
),
};
let device = &self.device;
let p = self.model.patch_size();
let d = self.model.latent_dim();
let full_text: String = match prompt_text {
Some(pt) => format!("{pt}{text}"),
None => text.to_string(),
};
let mut text_tokens = self.tokenizer.encode(&full_text)?;
text_tokens.push(AUDIO_START_TOKEN);
let text_len = text_tokens.len();
let ref_feat_opt = match ref_audio {
Some(audio) => Some(self.encode_prompt_audio(audio, PadMode::Right)?),
None => None,
};
let prompt_feat_opt = match prompt_audio {
Some(audio) => Some(self.encode_prompt_audio(audio, PadMode::Left)?),
None => None,
};
let z_patch = |n: usize| -> Tensor<B, 3> { Tensor::<B, 3>::zeros([n, p, d], device) };
let mut tokens: Vec<i64> = Vec::new();
let mut t_mask: Vec<f32> = Vec::new();
let mut f_mask: Vec<f32> = Vec::new();
let mut feat_chunks: Vec<Tensor<B, 3>> = Vec::new();
        // Reference segment: [REF_AUDIO_START] <ref feats> [REF_AUDIO_END].
        // `t_mask` marks positions read as text tokens and `f_mask` positions
        // read as audio features; the two are complementary.
        if let Some(ref_feat) = ref_feat_opt {
let ref_len = ref_feat.dims()[0];
tokens.push(REF_AUDIO_START_TOKEN);
tokens.extend(std::iter::repeat_n(0i64, ref_len));
tokens.push(REF_AUDIO_END_TOKEN);
t_mask.push(1.0);
t_mask.extend(std::iter::repeat_n(0.0, ref_len));
t_mask.push(1.0);
f_mask.push(0.0);
f_mask.extend(std::iter::repeat_n(1.0, ref_len));
f_mask.push(0.0);
feat_chunks.push(z_patch(1));
feat_chunks.push(ref_feat);
feat_chunks.push(z_patch(1));
}
tokens.extend_from_slice(&text_tokens);
t_mask.extend(std::iter::repeat_n(1.0, text_len));
f_mask.extend(std::iter::repeat_n(0.0, text_len));
feat_chunks.push(z_patch(text_len));
if let Some(prompt_feat) = prompt_feat_opt {
let prompt_len = prompt_feat.dims()[0];
tokens.extend(std::iter::repeat_n(0i64, prompt_len));
t_mask.extend(std::iter::repeat_n(0.0, prompt_len));
f_mask.extend(std::iter::repeat_n(1.0, prompt_len));
feat_chunks.push(prompt_feat);
}
let s = tokens.len();
let feat_seq = if feat_chunks.len() == 1 {
feat_chunks.pop().unwrap()
} else {
Tensor::cat(feat_chunks, 0)
};
let text_token: Tensor<B, 2, Int> =
Tensor::from_data(TensorData::new(tokens, [1, s]), device);
let text_mask: Tensor<B, 2> =
Tensor::from_data(TensorData::new(t_mask, [1, s]), device);
let feat_mask: Tensor<B, 2> =
Tensor::from_data(TensorData::new(f_mask, [1, s]), device);
let feat: Tensor<B, 4> = feat_seq.unsqueeze_dim(0);
Ok(InferenceInputs {
text_token,
text_mask,
feat,
feat_mask,
})
}
fn generate_one_with_prompt(
&self,
text: &str,
prompt: &Prompt,
opts: &GenerateOptions,
cancel_fn: Option<&dyn Fn() -> bool>,
) -> crate::Result<Vec<f32>> {
let inputs = self.build_inference_inputs(text, prompt)?;
let (latent, _stops) = self.model.inference(
inputs.text_token,
inputs.text_mask,
inputs.feat,
inputs.feat_mask,
opts.min_len,
opts.max_len,
opts.inference_timesteps,
opts.cfg_value as f64,
cancel_fn,
)?;
decode_latent_to_samples(&self.model.audio_vae, latent)
}
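    /// Like [`Self::build_inference_inputs`], but for several texts that
    /// share one optional reference feature. Rows are right-padded to the
    /// longest sequence and the true per-row lengths are returned.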
fn build_batched_inputs(
&self,
texts: &[&str],
ref_feat_opt: Option<&Tensor<B, 3>>,
) -> crate::Result<(
Tensor<B, 2, Int>,
Tensor<B, 2>,
Tensor<B, 4>,
Tensor<B, 2>,
Vec<usize>,
)> {
let device = &self.device;
let p = self.model.patch_size();
let d = self.model.latent_dim();
let mut rows_tt: Vec<Vec<i64>> = Vec::with_capacity(texts.len());
let mut rows_tm: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
let mut rows_fm: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
let mut rows_feat_chunks: Vec<Vec<Tensor<B, 3>>> = Vec::with_capacity(texts.len());
let mut row_lens: Vec<usize> = Vec::with_capacity(texts.len());
let z_patch = |n: usize| -> Tensor<B, 3> { Tensor::<B, 3>::zeros([n, p, d], device) };
let ref_len = ref_feat_opt.map(|f| f.dims()[0]).unwrap_or(0);
for &text in texts {
let mut text_tokens = self.tokenizer.encode(text)?;
text_tokens.push(AUDIO_START_TOKEN);
let text_len = text_tokens.len();
let mut tt: Vec<i64> = Vec::new();
let mut tm: Vec<f32> = Vec::new();
let mut fm: Vec<f32> = Vec::new();
let mut feat_chunks: Vec<Tensor<B, 3>> = Vec::new();
if let Some(rf) = ref_feat_opt {
tt.push(REF_AUDIO_START_TOKEN);
tt.extend(std::iter::repeat_n(0i64, ref_len));
tt.push(REF_AUDIO_END_TOKEN);
tm.push(1.0);
tm.extend(std::iter::repeat_n(0.0, ref_len));
tm.push(1.0);
fm.push(0.0);
fm.extend(std::iter::repeat_n(1.0, ref_len));
fm.push(0.0);
feat_chunks.push(z_patch(1));
feat_chunks.push(rf.clone());
feat_chunks.push(z_patch(1));
}
tt.extend_from_slice(&text_tokens);
tm.extend(std::iter::repeat_n(1.0, text_len));
fm.extend(std::iter::repeat_n(0.0, text_len));
feat_chunks.push(z_patch(text_len));
row_lens.push(tt.len());
rows_tt.push(tt);
rows_tm.push(tm);
rows_fm.push(fm);
rows_feat_chunks.push(feat_chunks);
}
let max_s = *row_lens.iter().max().unwrap_or(&0);
if max_s == 0 {
return Err(crate::Error::Other("no text in any segment".into()));
}
let batch = texts.len();
let mut pad_tt = vec![0i64; batch * max_s];
let mut pad_tm = vec![0.0f32; batch * max_s];
let mut pad_fm = vec![0.0f32; batch * max_s];
let mut row_feats: Vec<Tensor<B, 4>> = Vec::with_capacity(batch);
for (b, ((tt, tm), (fm, mut chunks))) in rows_tt.iter().zip(rows_tm.iter())
.zip(rows_fm.iter().zip(rows_feat_chunks.into_iter()))
.enumerate()
{
let s = tt.len();
for j in 0..s {
pad_tt[b * max_s + j] = tt[j];
pad_tm[b * max_s + j] = tm[j];
pad_fm[b * max_s + j] = fm[j];
}
let feat_row: Tensor<B, 3> = if chunks.len() == 1 {
chunks.pop().unwrap()
} else {
Tensor::cat(chunks, 0)
};
let feat_row = if s == max_s {
feat_row
} else {
let pad = max_s - s;
Tensor::cat(vec![feat_row, z_patch(pad)], 0)
};
            row_feats.push(feat_row.unsqueeze::<4>());
        }
let text_token: Tensor<B, 2, Int> =
Tensor::from_data(TensorData::new(pad_tt, [batch, max_s]), device);
let text_mask: Tensor<B, 2> =
Tensor::from_data(TensorData::new(pad_tm, [batch, max_s]), device);
let feat_mask: Tensor<B, 2> =
Tensor::from_data(TensorData::new(pad_fm, [batch, max_s]), device);
let feat: Tensor<B, 4> = Tensor::cat(row_feats, 0);
Ok((text_token, text_mask, feat, feat_mask, row_lens))
}
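    /// Re-encodes PCM produced at the model's output rate into reference
    /// features, so earlier output can seed the voice of later segments.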
fn pcm_to_ref_feat(&self, samples: &[f32]) -> crate::Result<Tensor<B, 3>> {
let in_sr = self.model.sample_rate() as u32;
let enc_sr = self.model.audio_vae.sample_rate() as u32;
let resampled = if in_sr == enc_sr {
samples.to_vec()
} else {
crate::audio::resample(samples, in_sr, enc_sr)?
};
self.encode_prompt_audio(
&PromptAudio::Pcm { samples: resampled, sample_rate: enc_sr },
PadMode::Right,
)
}
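    /// Starts a batch of independent requests that are padded to a common
    /// length and decoded in one forward pass.
    ///
    /// ```ignore
    /// // Sketch: two utterances, each with its own prompt.
    /// let outputs = tts
    ///     .batch()
    ///     .add("First sentence.", Prompt::None)
    ///     .add("Second one.", Prompt::Reference { audio: "ref.wav".into() })
    ///     .run(GenerateOptions::default())?;
    /// ```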
pub fn batch(&self) -> BatchBuilder<'_, B> {
BatchBuilder { voxcpm: self, items: Vec::new() }
}
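    /// Builds per-request inputs, pads them to the longest sequence in the
    /// batch, runs one batched inference, and trims each output at its own
    /// stop index before VAE decoding.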
fn run_batch(
&self,
items: Vec<(String, Prompt)>,
opts: GenerateOptions,
) -> crate::Result<Vec<Vec<f32>>> {
let device = &self.device;
let p = self.model.patch_size();
let d = self.model.latent_dim();
let mut rows: Vec<InferenceInputs<B>> = Vec::with_capacity(items.len());
let mut lens: Vec<usize> = Vec::with_capacity(items.len());
for (text, prompt) in &items {
let inp = self.build_inference_inputs(text, prompt)?;
lens.push(inp.text_token.dims()[1]);
rows.push(inp);
}
let max_s = *lens.iter().max().unwrap();
let mut tt_rows: Vec<Tensor<B, 2, Int>> = Vec::with_capacity(rows.len());
let mut tm_rows: Vec<Tensor<B, 2>> = Vec::with_capacity(rows.len());
let mut fm_rows: Vec<Tensor<B, 2>> = Vec::with_capacity(rows.len());
let mut feat_rows: Vec<Tensor<B, 4>> = Vec::with_capacity(rows.len());
for (i, inp) in rows.into_iter().enumerate() {
let s = lens[i];
let pad = max_s - s;
let (tt, tm, ft, fm) = if pad == 0 {
(inp.text_token, inp.text_mask, inp.feat, inp.feat_mask)
} else {
let tt_pad: Tensor<B, 2, Int> =
Tensor::zeros([1, pad], device);
let tm_pad: Tensor<B, 2> = Tensor::zeros([1, pad], device);
let fm_pad: Tensor<B, 2> = Tensor::zeros([1, pad], device);
let ft_pad: Tensor<B, 4> = Tensor::zeros([1, pad, p, d], device);
(
Tensor::cat(vec![inp.text_token, tt_pad], 1),
Tensor::cat(vec![inp.text_mask, tm_pad], 1),
Tensor::cat(vec![inp.feat, ft_pad], 1),
Tensor::cat(vec![inp.feat_mask, fm_pad], 1),
)
};
tt_rows.push(tt);
tm_rows.push(tm);
feat_rows.push(ft);
fm_rows.push(fm);
}
let text_token = Tensor::cat(tt_rows, 0);
let text_mask = Tensor::cat(tm_rows, 0);
let feat = Tensor::cat(feat_rows, 0);
let feat_mask = Tensor::cat(fm_rows, 0);
let cancel_fn: Option<Box<dyn Fn() -> bool>> = opts.cancel.as_ref().map(|c| {
let c = c.clone();
Box::new(move || c.is_cancelled()) as Box<dyn Fn() -> bool>
});
let (latent, stops) = self.model.inference_with_lengths(
text_token,
text_mask,
feat,
feat_mask,
opts.min_len,
opts.max_len,
opts.inference_timesteps,
opts.cfg_value as f64,
cancel_fn.as_deref(),
Some(lens),
)?;
let dims = latent.dims();
let mut out: Vec<Vec<f32>> = Vec::with_capacity(items.len());
for i in 0..items.len() {
let stop_i = stops[i];
let pat = (stop_i * p).min(dims[2]);
if pat == 0 {
out.push(Vec::new());
continue;
}
let lat_i = latent.clone().slice([i..i + 1, 0..dims[1], 0..pat]);
let pcm = decode_latent_to_samples(&self.model.audio_vae, lat_i)?;
out.push(pcm);
}
Ok(out)
}
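    /// Decodes sentence segments in batched groups of `parallel_n`, all
    /// conditioned on one shared reference voice, and concatenates the
    /// resulting audio in order.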
fn generate_parallel(
&self,
segments: &[String],
parallel_n: usize,
opts: &GenerateOptions,
) -> crate::Result<Vec<f32>> {
debug_assert!(parallel_n >= 2);
debug_assert!(segments.len() >= 2);
let cancel_fn: Option<Box<dyn Fn() -> bool>> = opts.cancel.as_ref().map(|c| {
let c = c.clone();
Box::new(move || c.is_cancelled()) as Box<dyn Fn() -> bool>
});
let (ref_feat, mut output_audio): (Tensor<B, 3>, Vec<f32>) = match &opts.prompt {
Prompt::Reference { audio } => {
let rf = self.encode_prompt_audio(audio, PadMode::Right)?;
(rf, Vec::new())
}
            Prompt::None => {
                // No reference voice yet: synthesize the first segment
                // unconditioned and reuse its audio as the reference for the
                // remaining segments, keeping the voice consistent.
                let seed_text = segments[0].as_str();
let seed_pcm = self.generate_one_with_prompt(
seed_text,
&Prompt::None,
opts,
cancel_fn.as_deref(),
)?;
let rf = self.pcm_to_ref_feat(&seed_pcm)?;
(rf, seed_pcm)
}
            // `generate` only routes `None` / `Reference` prompts here.
            _ => unreachable!(),
};
let remaining: &[String] = match &opts.prompt {
Prompt::None => &segments[1..],
_ => &segments[..],
};
for group in remaining.chunks(parallel_n) {
let texts_ref: Vec<&str> = group.iter().map(|s| s.as_str()).collect();
let (tt, tm, ft, fm, lens) =
self.build_batched_inputs(&texts_ref, Some(&ref_feat))?;
let (latent, stops) = self.model.inference_with_lengths(
tt, tm, ft, fm,
opts.min_len,
opts.max_len,
opts.inference_timesteps,
opts.cfg_value as f64,
cancel_fn.as_deref(),
Some(lens),
)?;
let dims = latent.dims();
let p = self.model.patch_size();
for i in 0..group.len() {
let stop_i = stops[i];
let pat = (stop_i * p).min(dims[2]);
if pat == 0 {
continue;
}
let lat_i = latent.clone().slice([i..i + 1, 0..dims[1], 0..pat]);
let pcm = decode_latent_to_samples(&self.model.audio_vae, lat_i)?;
output_audio.extend_from_slice(&pcm);
}
}
Ok(output_audio)
}
}
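/// Accumulates `(text, prompt)` pairs for a padded batch run; created via
/// [`VoxCPM::batch`]. A single-item batch falls back to the sequential
/// path.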
pub struct BatchBuilder<'a, B: Backend> {
voxcpm: &'a VoxCPM<B>,
items: Vec<(String, Prompt)>,
}
impl<'a, B: Backend> BatchBuilder<'a, B> {
pub fn add(mut self, text: impl Into<String>, prompt: Prompt) -> Self {
self.items.push((text.into(), prompt));
self
}
pub fn len(&self) -> usize {
self.items.len()
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
pub fn run(self, opts: GenerateOptions) -> crate::Result<Vec<Vec<f32>>> {
if self.items.is_empty() {
return Ok(Vec::new());
}
if self.items.len() == 1 {
let (text, prompt) = self.items.into_iter().next().unwrap();
let cancel_fn: Option<Box<dyn Fn() -> bool>> = opts.cancel.as_ref().map(|c| {
let c = c.clone();
Box::new(move || c.is_cancelled()) as Box<dyn Fn() -> bool>
});
let pcm = self
.voxcpm
.generate_one_with_prompt(&text, &prompt, &opts, cancel_fn.as_deref())?;
return Ok(vec![pcm]);
}
self.voxcpm.run_batch(self.items, opts)
}
}
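/// Splits text into trimmed, non-empty segments at `.`, `!`, `?`, or a
/// newline, keeping each terminator with its sentence; a trailing
/// unterminated fragment becomes a final segment. For example,
/// `"Hi there. Bye!"` yields `["Hi there.", "Bye!"]`.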
pub fn split_sentences(text: &str) -> Vec<String> {
let mut out: Vec<String> = Vec::new();
let mut buf = String::new();
for c in text.chars() {
buf.push(c);
if matches!(c, '.' | '!' | '?' | '\n') {
let trimmed = buf.trim();
if !trimmed.is_empty() {
out.push(trimmed.to_string());
}
buf.clear();
}
}
let trimmed = buf.trim();
if !trimmed.is_empty() {
out.push(trimmed.to_string());
}
out
}
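/// Per-request tensors consumed by the model: token ids and text/feature
/// masks of shape `[1, s]`, plus per-patch features of shape `[1, s, p, d]`.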
struct InferenceInputs<B: Backend> {
text_token: Tensor<B, 2, Int>,
text_mask: Tensor<B, 2>,
feat: Tensor<B, 4>,
feat_mask: Tensor<B, 2>,
}
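/// Decodes a single-request latent `[1, d, t]` through the audio VAE and
/// returns the flattened PCM samples.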
fn decode_latent_to_samples<B: Backend>(
audio_vae: &crate::audiovae::AudioVae<B>,
latent: Tensor<B, 3>,
) -> crate::Result<Vec<f32>> {
let wav = audio_vae.decode(latent);
    // Collapse the decoder output [1, 1, n] down to a flat [n] sample buffer.
    let wav = wav.squeeze_dim::<2>(1);
    let wav = wav.squeeze_dim::<1>(0);
    let data = wav.into_data();
data.convert::<f32>()
.into_vec::<f32>()
.map_err(|_| crate::Error::Other("unexpected VAE output dtype".into()))
}
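/// Incremental generation state returned by [`VoxCPM::generate_stream`].
///
/// As an `Iterator` it yields `Result<Vec<f32>>` chunks of newly decoded
/// PCM. Iteration ends when the model signals a stop after `min_len`
/// patches or `max_len` is reached; cancellation surfaces as
/// `Err(Error::Cancelled)`.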
#[derive(Debug)]
pub struct GenerateStream<'a, B: Backend> {
model: &'a crate::voxcpm2::VoxCpm2Model<B>,
state: crate::voxcpm2::model::InferenceState<B>,
pred_feats: Vec<Tensor<B, 4>>,
samples_emitted: usize,
step: usize,
min_len: usize,
max_len: usize,
inference_timesteps: usize,
cfg_value: f64,
chunk_patches: usize,
cancel: Option<CancelToken>,
finished: bool,
}
impl<B: Backend> GenerateStream<'_, B> {
pub fn sample_rate(&self) -> u32 {
self.model.sample_rate() as u32
}
pub fn steps_taken(&self) -> usize {
self.state.steps_taken
}
fn step_chunk(&mut self) -> crate::Result<Option<Vec<f32>>> {
if self.finished {
return Ok(None);
}
let mut produced_any = false;
for _ in 0..self.chunk_patches {
if self.step >= self.max_len {
self.finished = true;
break;
}
if let Some(c) = &self.cancel
&& c.is_cancelled()
{
self.finished = true;
return Err(crate::Error::Cancelled);
}
let i = self.step;
let crate::voxcpm2::model::DitStep { pred_feat, stops } =
self.model.dit_step(&mut self.state, self.inference_timesteps, self.cfg_value);
self.pred_feats.push(pred_feat.clone());
produced_any = true;
let stop = stops.first().copied().unwrap_or(false);
if i > self.min_len && stop {
self.finished = true;
self.step += 1;
break;
}
self.model.lm_step(&mut self.state, pred_feat);
self.step += 1;
}
if !produced_any {
return Ok(None);
}
let latent = crate::voxcpm2::VoxCpm2Model::stack_pred_feats(&self.pred_feats);
let all = decode_latent_to_samples(&self.model.audio_vae, latent)?;
if all.len() <= self.samples_emitted {
return Ok(Some(Vec::new()));
}
let chunk = all[self.samples_emitted..].to_vec();
self.samples_emitted = all.len();
Ok(Some(chunk))
}
}
impl<B: Backend> Iterator for GenerateStream<'_, B> {
type Item = crate::Result<Vec<f32>>;
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.step_chunk() {
Ok(Some(chunk)) if chunk.is_empty() => {
if self.finished {
return None;
}
continue;
}
Ok(Some(chunk)) => return Some(Ok(chunk)),
Ok(None) => return None,
Err(e) => return Some(Err(e)),
}
}
}
}
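/// Which side of the prompt audio receives zero padding to reach a whole
/// number of patches: reference audio is padded on the right, continuation
/// prompt audio on the left.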
#[derive(Debug, Clone, Copy)]
enum PadMode {
Right,
Left,
}