use crate::error::{Error, Result};
use crate::model::audio::kokoro::{
BertEncoder, Decoder, IStftOptions, KokoroConfig, ProsodyPredictor, TextEncoder,
decode_prosody_durations, istft, length_regulator, split_voice_style,
};
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, CompareOps, ConvOps, IndexingOps, MatmulOps, NormalizationOps,
ReduceOps, ScalarOps, ShapeOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::cpu::{CpuClient, CpuRuntime};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Kokoro synthesis model: four sub-networks plus the configuration they share.
///
/// Generic over the numr [`Runtime`] `R` that parameterizes every sub-module.
/// The fields are public; the pipeline methods on this type only read them.
pub struct KokoroModelV2<R: Runtime> {
/// Encoder applied to the token ids; its output drives duration prediction
/// and, after length regulation, the F0 / energy prediction.
pub bert: BertEncoder<R>,
/// Second encoder applied to the same token ids; its length-regulated,
/// transposed output is what the decoder receives as its first tensor input.
pub text_encoder: TextEncoder<R>,
/// Prosody predictor queried twice per utterance: once for duration logits
/// and once for the F0 / energy pair.
pub predictor: ProsodyPredictor<R>,
/// Decoder that produces the `(mag, phase)` pair returned by the
/// spectrogram methods.
pub decoder: Decoder<R>,
/// Model hyper-parameters. The methods in this file read `style_dim`,
/// `max_dur`, `n_fft` and `hop_length`.
pub config: KokoroConfig,
}
impl<R: Runtime> KokoroModelV2<R> {
#[allow(clippy::type_complexity)]
pub fn forward_to_spectrogram<C>(
&self,
client: &C,
token_ids: &Tensor<R>,
voice_row: &Tensor<R>,
min_frames_per_phoneme: u32,
) -> Result<(Tensor<R>, Tensor<R>, Vec<u32>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ IndexingOps<R>
+ MatmulOps<R>
+ NormalizationOps<R>
+ ActivationOps<R>
+ TensorOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ConvOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ ShapeOps<R>
+ CompareOps<R>
+ TypeConversionOps<R>
+ UtilityOps<R>,
R::Client: IndexingOps<R>,
{
let (decoder_style, predictor_style) = split_voice_style(voice_row, self.config.style_dim)?;
let d_en = self.bert.forward(client, token_ids)?;
let dur_logits = self
.predictor
.predict_duration(client, &d_en, &predictor_style)?;
let logits_shape = dur_logits.shape();
let (b, t_phon) = (logits_shape[0], logits_shape[1]);
if b != 1 {
return Err(Error::InvalidArgument {
arg: "token_ids",
reason: "synthesis is single-utterance; batch > 1 not supported".into(),
});
}
let logits_flat: Vec<f32> = dur_logits.contiguous()?.to_vec();
let durations = decode_prosody_durations(
&logits_flat,
t_phon,
self.config.max_dur,
min_frames_per_phoneme,
);
let frames = length_regulator(client, &d_en, &durations)?;
let (f0, n_energy) = self
.predictor
.predict_f0_n(client, &frames, &predictor_style)?;
let t_en = self.text_encoder.forward(client, token_ids)?;
let asr_bt_d = length_regulator(client, &t_en, &durations)?; let asr = asr_bt_d
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous()?;
let (mag, phase) = self
.decoder
.forward(client, &asr, &f0, &n_energy, &decoder_style)?;
Ok((mag, phase, durations))
}
}
impl KokoroModelV2<CpuRuntime> {
#[allow(clippy::type_complexity)]
pub fn forward_to_spectrogram_cpu(
&self,
client: &CpuClient,
token_ids: &Tensor<CpuRuntime>,
voice_row: &Tensor<CpuRuntime>,
min_frames_per_phoneme: u32,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Vec<u32>)> {
let (decoder_style, predictor_style) = split_voice_style(voice_row, self.config.style_dim)?;
let d_en = self.bert.forward(client, token_ids)?;
let dur_logits = self
.predictor
.predict_duration(client, &d_en, &predictor_style)?;
let logits_shape = dur_logits.shape();
let (b, t_phon) = (logits_shape[0], logits_shape[1]);
if b != 1 {
return Err(Error::InvalidArgument {
arg: "token_ids",
reason: "synthesis is single-utterance; batch > 1 not supported".into(),
});
}
let logits_flat: Vec<f32> = dur_logits.contiguous()?.to_vec();
let durations = decode_prosody_durations(
&logits_flat,
t_phon,
self.config.max_dur,
min_frames_per_phoneme,
);
let frames = length_regulator(client, &d_en, &durations)?;
let (f0, n_energy) = self
.predictor
.predict_f0_n(client, &frames, &predictor_style)?;
let t_en = self.text_encoder.forward(client, token_ids)?;
let asr_bt_d = length_regulator(client, &t_en, &durations)?;
let asr = asr_bt_d
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous()?;
let (mag, phase) =
self.decoder
.forward_cpu_full(client, &asr, &f0, &n_energy, &decoder_style)?;
Ok((mag, phase, durations))
}
pub fn synthesize_cpu(
&self,
client: &CpuClient,
token_ids: &Tensor<CpuRuntime>,
voice_row: &Tensor<CpuRuntime>,
min_frames_per_phoneme: u32,
) -> Result<Tensor<CpuRuntime>> {
let (mag, phase, _durations) =
self.forward_to_spectrogram_cpu(client, token_ids, voice_row, min_frames_per_phoneme)?;
let window = super::window::hann_window(self.config.n_fft, voice_row.device());
let opts = IStftOptions {
hop_length: self.config.hop_length,
center: true,
eps: 1e-8,
};
istft(client, &mag, &phase, &window, opts)
}
}
pub fn alignment_matrix_from_durations<R: Runtime<DType = DType>>(
durations: &[u32],
reference: &Tensor<R>,
) -> Result<Tensor<R>> {
let t_phon = durations.len();
let t_frames: u32 = durations.iter().sum();
if t_frames == 0 {
return Err(Error::InvalidArgument {
arg: "durations",
reason: "total frame count must be > 0".into(),
});
}
let t_frames = t_frames as usize;
let mut data = vec![0.0f32; t_phon * t_frames];
let mut cursor = 0usize;
for (p, &d) in durations.iter().enumerate() {
for f in 0..(d as usize) {
data[p * t_frames + cursor + f] = 1.0;
}
cursor += d as usize;
}
Ok(Tensor::<R>::from_slice(
&data,
&[t_phon, t_frames],
reference.device(),
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::CpuDevice;

    /// One-element CPU tensor whose only job is to hand its device to
    /// `alignment_matrix_from_durations`.
    fn probe_tensor(device: &CpuDevice) -> Tensor<CpuRuntime> {
        Tensor::<CpuRuntime>::from_slice(&[0.0f32], &[1], device)
    }

    /// Three phonemes with durations 2, 1, 3 give a 3x6 staircase of ones.
    #[test]
    fn alignment_matrix_shape_and_placement() {
        let device = CpuDevice::new();
        let probe = probe_tensor(&device);

        let matrix = alignment_matrix_from_durations(&[2, 1, 3], &probe).unwrap();
        assert_eq!(matrix.shape(), &[3, 6]);

        let flat: Vec<f32> = matrix.to_vec();
        // One line per phoneme row.
        #[rustfmt::skip]
        let expected = vec![
            1.0, 1.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
        ];
        assert_eq!(flat, expected);
    }

    /// Durations that sum to zero must be rejected.
    #[test]
    fn alignment_matrix_rejects_zero_total() {
        let device = CpuDevice::new();
        let probe = probe_tensor(&device);

        let outcome = alignment_matrix_from_durations(&[0, 0], &probe);
        assert!(outcome.is_err());
    }

    /// A single phoneme owns every frame.
    #[test]
    fn alignment_matrix_single_phoneme() {
        let device = CpuDevice::new();
        let probe = probe_tensor(&device);

        let matrix = alignment_matrix_from_durations(&[4], &probe).unwrap();
        assert_eq!(matrix.shape(), &[1, 4]);

        let flat: Vec<f32> = matrix.to_vec();
        assert_eq!(flat, vec![1.0, 1.0, 1.0, 1.0]);
    }
}