use std::collections::HashMap;
use std::path::{Path, PathBuf};
use mlx_rs::module::{Module, Param};
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::{take_axis, IndexOp};
use mlx_rs::Array;
use tracing::info;
use super::mlx::{build_qlinear, load_all_tensors, QLinear, QuantConfig};
use crate::InferenceError;
/// Hyperparameters describing the Kokoro TTS checkpoint.
///
/// NOTE(review): several fields (`dim_in`, `n_mels`, `n_layer`,
/// `plbert_num_layers`, `plbert_intermediate`, `plbert_embedding_dim`,
/// `plbert_vocab_size`, `plbert_max_position`) are not read by the loaders
/// in this file — they appear to document the checkpoint architecture;
/// confirm before removing any of them.
#[derive(Debug, Clone)]
pub struct KokoroConfig {
    /// Input feature width (unused by this loader).
    pub dim_in: usize,
    /// Width of the shared text hidden representation.
    pub hidden_dim: usize,
    /// Dimensionality of the style/voice embedding vector; used as the size
    /// of the zero fallback style in `StyleEncoder::get_style`.
    pub style_dim: usize,
    /// Number of mel bands (unused by this loader).
    pub n_mels: usize,
    /// Token vocabulary size; used by `text_to_phoneme_ids`.
    pub n_token: usize,
    /// Layer count (unused by this loader).
    pub n_layer: usize,
    /// Virtual PL-BERT layer count (unused directly; see `PLBert::forward`).
    pub plbert_num_layers: usize,
    /// PL-BERT transformer hidden width; with `plbert_num_heads` this fixes
    /// the per-head dimension in `PLBert::load`.
    pub plbert_hidden: usize,
    /// PL-BERT attention head count.
    pub plbert_num_heads: usize,
    /// PL-BERT feed-forward width (unused by this loader).
    pub plbert_intermediate: usize,
    /// PL-BERT embedding width (unused by this loader).
    pub plbert_embedding_dim: usize,
    /// PL-BERT vocabulary size (unused by this loader).
    pub plbert_vocab_size: usize,
    /// Maximum position embedding index (unused by this loader).
    pub plbert_max_position: usize,
    /// Quantization settings applied to `QLinear` layers; `None` means dense.
    pub(crate) quant: Option<QuantConfig>,
    /// Output audio sample rate in Hz.
    pub sample_rate: u32,
}
impl Default for KokoroConfig {
    /// Defaults sized for the Kokoro v1.0 checkpoint loaded by this module.
    /// NOTE(review): these mirror the exported model's architecture — confirm
    /// against the checkpoint's own config if the weights are regenerated.
    fn default() -> Self {
        Self {
            dim_in: 64,
            hidden_dim: 512,
            style_dim: 128,
            n_mels: 80,
            n_token: 178,
            n_layer: 3,
            plbert_num_layers: 12,
            plbert_hidden: 768,
            plbert_num_heads: 12,
            plbert_intermediate: 2048,
            plbert_embedding_dim: 128,
            plbert_vocab_size: 178,
            plbert_max_position: 512,
            // 6-bit weight quantization with one scale per group of 64.
            quant: Some(QuantConfig {
                group_size: 64,
                bits: 6,
            }),
            sample_rate: 24000,
        }
    }
}
/// Look up `key` in the flat tensor map, cloning the hit.
///
/// Returns `InferenceError::InferenceFailed` naming the missing key when the
/// tensor is absent, so load errors point at the exact checkpoint entry.
fn get_tensor(tensors: &HashMap<String, Array>, key: &str) -> Result<Array, InferenceError> {
    match tensors.get(key) {
        Some(arr) => Ok(arr.clone()),
        None => Err(InferenceError::InferenceFailed(format!(
            "missing tensor: {key}"
        ))),
    }
}
/// Assemble a dense (non-quantized) `nn::Linear` from `{prefix}.weight`
/// plus an optional `{prefix}.bias`.
///
/// Fails only when the weight tensor is missing; a missing bias simply
/// yields a bias-free layer.
fn build_dense_linear(
    tensors: &HashMap<String, Array>,
    prefix: &str,
) -> Result<nn::Linear, InferenceError> {
    let weight = Param::new(get_tensor(tensors, &format!("{prefix}.weight"))?);
    let bias = Param::new(tensors.get(&format!("{prefix}.bias")).cloned());
    Ok(nn::Linear { weight, bias })
}
/// Assemble a `LayerNorm` from `{prefix}.weight` (required gamma) and an
/// optional `{prefix}.bias` (beta), with the given epsilon.
fn build_layer_norm(
    tensors: &HashMap<String, Array>,
    prefix: &str,
    eps: f32,
) -> Result<LayerNorm, InferenceError> {
    let weight = get_tensor(tensors, &format!("{prefix}.weight"))?;
    let bias_key = format!("{prefix}.bias");
    Ok(LayerNorm {
        weight,
        bias: tensors.get(&bias_key).cloned(),
        eps,
    })
}
/// 1-D convolution with an optional bias broadcast over the output.
///
/// Thin wrapper over `ops::conv1d` that folds the bias add in, since mlx's
/// functional conv takes no bias argument.
fn conv1d_forward(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    stride: i32,
    padding: i32,
    dilation: i32,
) -> Result<Array, mlx_rs::error::Exception> {
    let y = ops::conv1d(input, weight, stride, padding, dilation, None::<i32>)?;
    match bias {
        Some(b) => ops::add(&y, b),
        None => Ok(y),
    }
}
/// 1-D transposed convolution with an optional bias broadcast over the
/// output (used by the vocoder's upsampling stages).
fn conv_transpose1d_forward(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    stride: i32,
    padding: i32,
) -> Result<Array, mlx_rs::error::Exception> {
    let y = ops::conv_transpose1d(
        input,
        weight,
        stride,
        padding,
        None::<i32>,
        None::<i32>,
        None::<i32>,
    )?;
    match bias {
        Some(b) => ops::add(&y, b),
        None => Ok(y),
    }
}
/// Minimal affine layer normalization over the last axis.
struct LayerNorm {
    /// Per-feature scale (gamma).
    weight: Array,
    /// Optional per-feature shift (beta).
    bias: Option<Array>,
    /// Numerical-stability epsilon added to the variance.
    eps: f32,
}
impl LayerNorm {
    /// Normalize `x` over its last axis: subtract the mean, divide by
    /// `sqrt(var + eps)`, then apply the learned scale and optional shift.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let mu = x.mean_axes(&[-1], true)?;
        let delta = ops::subtract(x, &mu)?;
        let variance = delta.multiply(&delta)?.mean_axes(&[-1], true)?;
        let denom = ops::rsqrt(&ops::add(&variance, &Array::from_f32(self.eps))?)?;
        let scaled = ops::multiply(&ops::multiply(&delta, &denom)?, &self.weight)?;
        match self.bias {
            Some(ref beta) => ops::add(&scaled, beta),
            None => Ok(scaled),
        }
    }
}
/// One ALBERT self-attention sub-block with its output projection and
/// post-residual LayerNorm. All projections may be quantized (`QLinear`).
struct AlbertAttention {
    query: QLinear,
    key: QLinear,
    value: QLinear,
    /// Output projection applied to the merged attention heads.
    dense: QLinear,
    /// Post-residual layer norm (eps 1e-12, set in `PLBert::load`).
    layer_norm: LayerNorm,
    num_heads: usize,
    /// Hidden size per head (`plbert_hidden / plbert_num_heads`).
    head_dim: usize,
}
impl AlbertAttention {
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let shape = x.shape();
let (batch, seq_len, _hidden) = (shape[0] as usize, shape[1] as usize, shape[2] as usize);
let q = self.query.forward(x)?;
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
let reshape_head = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
let r = ops::reshape(
&t,
&[
batch as i32,
seq_len as i32,
self.num_heads as i32,
self.head_dim as i32,
],
)?;
ops::transpose_axes(&r, &[0, 2, 1, 3])
};
let q = reshape_head(q)?;
let k = reshape_head(k)?;
let v = reshape_head(v)?;
let scale = Array::from_f32(1.0 / (self.head_dim as f32).sqrt());
let scores = ops::multiply(
&ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?,
&scale,
)?;
let attn = ops::softmax_axis(&scores, -1, None)?;
let out = ops::matmul(&attn, &v)?;
let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
let out = ops::reshape(
&out,
&[
batch as i32,
seq_len as i32,
(self.num_heads * self.head_dim) as i32,
],
)?;
let projected = self.dense.forward(&out)?;
let residual = ops::add(x, &projected)?;
self.layer_norm.forward(&residual)
}
}
/// One ALBERT feed-forward sub-block (expand → GELU → contract) with its
/// post-residual LayerNorm.
struct AlbertFfn {
    /// Expansion projection.
    ffn: QLinear,
    /// Contraction projection back to the hidden width.
    ffn_output: QLinear,
    /// Post-residual layer norm (eps 1e-12, set in `PLBert::load`).
    layer_norm: LayerNorm,
}
impl AlbertFfn {
    /// Position-wise feed-forward: `ffn_output(gelu(ffn(x)))`, then residual
    /// add and LayerNorm.
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let inner = nn::gelu(&self.ffn.forward(x)?)?;
        let contracted = self.ffn_output.forward(&inner)?;
        self.layer_norm.forward(&ops::add(x, &contracted)?)
    }
}
/// One full transformer layer: attention sub-block then feed-forward
/// sub-block (each with its own residual + LayerNorm).
struct AlbertLayer {
    attention: AlbertAttention,
    ffn: AlbertFfn,
}
impl AlbertLayer {
    /// Run the attention sub-block, then feed its output through the FFN
    /// sub-block.
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let attended = self.attention.forward(x)?;
        self.ffn.forward(&attended)
    }
}
/// PL-BERT text encoder (ALBERT variant with cross-layer weight sharing).
struct PLBert {
    /// Token-id → embedding table.
    word_embeddings: Array,
    /// Absolute position embedding table.
    position_embeddings: Array,
    /// Token-type (segment) embedding table; only type 0 is used here.
    token_type_embeddings: Array,
    /// Maps embedding width up to the transformer hidden width.
    embedding_hidden_mapping: nn::Linear,
    /// Materialized (shared) transformer layers, cycled during `forward`.
    albert_layers: Vec<AlbertLayer>,
    /// Loaded for completeness; the pooled output is unused by the TTS path.
    #[allow(dead_code)]
    pooler: nn::Linear,
    #[allow(dead_code)]
    num_layer_groups: usize,
}
impl PLBert {
    /// Number of virtual transformer layers unrolled at inference time.
    ///
    /// ALBERT shares parameters across layers, so fewer layer groups may be
    /// materialized on disk than are applied at runtime; `forward` cycles
    /// through the loaded layers. Mirrors the default
    /// `KokoroConfig::plbert_num_layers` (12) — previously this was a bare
    /// magic number inside `forward`.
    const TOTAL_VIRTUAL_LAYERS: usize = 12;

    /// Build the PL-BERT text encoder from the flat tensor map.
    ///
    /// Discovers the number of layer groups and inner layers by probing key
    /// prefixes (`bert.encoder.albert_layer_groups.{g}.albert_layers.{l}`),
    /// then loads (possibly quantized) attention/FFN weights for each.
    fn load(
        tensors: &HashMap<String, Array>,
        config: &KokoroConfig,
    ) -> Result<Self, InferenceError> {
        let quant = config.quant.as_ref();
        let pfx = "bert";
        let word_embeddings =
            get_tensor(tensors, &format!("{pfx}.embeddings.word_embeddings.weight"))?;
        let position_embeddings = get_tensor(
            tensors,
            &format!("{pfx}.embeddings.position_embeddings.weight"),
        )?;
        let token_type_embeddings = get_tensor(
            tensors,
            &format!("{pfx}.embeddings.token_type_embeddings.weight"),
        )?;
        let embedding_hidden_mapping = build_dense_linear(
            tensors,
            &format!("{pfx}.encoder.embedding_hidden_mapping_in"),
        )?;
        // Count layer groups by scanning for the highest group index present.
        let mut num_groups = 0usize;
        for key in tensors.keys() {
            if let Some(rest) = key.strip_prefix(&format!("{pfx}.encoder.albert_layer_groups.")) {
                if let Some(idx_str) = rest.split('.').next() {
                    if let Ok(idx) = idx_str.parse::<usize>() {
                        num_groups = num_groups.max(idx + 1);
                    }
                }
            }
        }
        if num_groups == 0 {
            num_groups = 1;
        }
        let head_dim = config.plbert_hidden / config.plbert_num_heads;
        let mut albert_layers = Vec::with_capacity(num_groups);
        for g in 0..num_groups {
            // Count inner layers within this group the same way.
            let mut num_inner = 0usize;
            for key in tensors.keys() {
                if let Some(rest) = key.strip_prefix(&format!(
                    "{pfx}.encoder.albert_layer_groups.{g}.albert_layers."
                )) {
                    if let Some(idx_str) = rest.split('.').next() {
                        if let Ok(idx) = idx_str.parse::<usize>() {
                            num_inner = num_inner.max(idx + 1);
                        }
                    }
                }
            }
            if num_inner == 0 {
                num_inner = 1;
            }
            for l in 0..num_inner {
                let lpfx = format!("{pfx}.encoder.albert_layer_groups.{g}.albert_layers.{l}");
                let attention = AlbertAttention {
                    query: build_qlinear(tensors, &format!("{lpfx}.attention.query"), quant)?,
                    key: build_qlinear(tensors, &format!("{lpfx}.attention.key"), quant)?,
                    value: build_qlinear(tensors, &format!("{lpfx}.attention.value"), quant)?,
                    dense: build_qlinear(tensors, &format!("{lpfx}.attention.dense"), quant)?,
                    // 1e-12 is the standard BERT LayerNorm epsilon.
                    layer_norm: build_layer_norm(
                        tensors,
                        &format!("{lpfx}.attention.LayerNorm"),
                        1e-12,
                    )?,
                    num_heads: config.plbert_num_heads,
                    head_dim,
                };
                let ffn = AlbertFfn {
                    ffn: build_qlinear(tensors, &format!("{lpfx}.ffn"), quant)?,
                    ffn_output: build_qlinear(tensors, &format!("{lpfx}.ffn_output"), quant)?,
                    layer_norm: build_layer_norm(
                        tensors,
                        &format!("{lpfx}.full_layer_layer_norm"),
                        1e-12,
                    )?,
                };
                albert_layers.push(AlbertLayer { attention, ffn });
            }
        }
        let pooler = build_dense_linear(tensors, &format!("{pfx}.pooler.dense"))?;
        Ok(Self {
            word_embeddings,
            position_embeddings,
            token_type_embeddings,
            embedding_hidden_mapping,
            albert_layers,
            pooler,
            num_layer_groups: num_groups,
        })
    }

    /// Encode a `[1, seq_len]` batch of token ids into hidden states.
    ///
    /// Sums word, position, and token-type (all-zero segment) embeddings,
    /// projects to the hidden width, then applies the shared layers
    /// `TOTAL_VIRTUAL_LAYERS` times, cycling through the materialized
    /// layers. Batch size is hardcoded to 1 by the reshapes below.
    fn forward(&mut self, token_ids: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let seq_len = token_ids.shape()[1] as usize;
        let flat_ids = ops::reshape(token_ids, &[-1])?;
        let tok_emb = take_axis(&self.word_embeddings, &flat_ids, 0)?;
        let tok_emb = ops::reshape(&tok_emb, &[1, seq_len as i32, -1])?;
        let pos_ids =
            Array::from_slice(&(0..seq_len as i32).collect::<Vec<_>>(), &[seq_len as i32]);
        let pos_emb = take_axis(&self.position_embeddings, &pos_ids, 0)?;
        let pos_emb = ops::reshape(&pos_emb, &[1, seq_len as i32, -1])?;
        let tt_ids = Array::from_slice(&vec![0i32; seq_len], &[seq_len as i32]);
        let tt_emb = take_axis(&self.token_type_embeddings, &tt_ids, 0)?;
        let tt_emb = ops::reshape(&tt_emb, &[1, seq_len as i32, -1])?;
        let emb = ops::add(&ops::add(&tok_emb, &pos_emb)?, &tt_emb)?;
        let mut h = self.embedding_hidden_mapping.forward(&emb)?;
        let actual_layers = self.albert_layers.len();
        // `load` always materializes at least one layer; this guard only
        // exists to avoid a modulo-by-zero panic if that invariant breaks.
        if actual_layers == 0 {
            return Ok(h);
        }
        for i in 0..Self::TOTAL_VIRTUAL_LAYERS {
            h = self.albert_layers[i % actual_layers].forward(&h)?;
        }
        Ok(h)
    }
}
/// Single projection mapping PL-BERT output into the shared hidden space.
struct BertEncoder {
    proj: QLinear,
}
impl BertEncoder {
    /// Load the single (possibly quantized) projection stored under the
    /// `bert_encoder` key prefix.
    fn load(
        tensors: &HashMap<String, Array>,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        Ok(Self {
            proj: build_qlinear(tensors, "bert_encoder", quant)?,
        })
    }

    /// Project PL-BERT hidden states into the decoder's hidden space.
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        self.proj.forward(x)
    }
}
/// Predicts how many output frames each phoneme should span.
struct DurationPredictor {
    /// Probed conv1d stack: (weight, optional bias) per layer.
    conv_layers: Vec<(Array, Option<Array>)>,
    /// Optional linear head applied after the conv stack.
    output_proj: Option<nn::Linear>,
}
impl DurationPredictor {
    /// Build the duration predictor from the flat tensor map.
    ///
    /// Probes up to five `duration_predictor.conv_layers.{i}` convolutions,
    /// stopping at the first missing index, plus an optional linear head.
    fn load(
        tensors: &HashMap<String, Array>,
        _config: &KokoroConfig,
    ) -> Result<Self, InferenceError> {
        let mut conv_layers = Vec::new();
        for i in 0..5 {
            let w_key = format!("duration_predictor.conv_layers.{i}.weight");
            let b_key = format!("duration_predictor.conv_layers.{i}.bias");
            if let Some(w) = tensors.get(&w_key) {
                let b = tensors.get(&b_key).cloned();
                conv_layers.push((w.clone(), b));
            } else {
                break;
            }
        }
        // The weight key's presence was just checked, so build_dense_linear
        // can only fail on a malformed checkpoint. Propagate that error
        // instead of silently substituting a dummy all-ones projection,
        // which previously produced garbage durations with no diagnostic.
        let output_proj = if tensors.contains_key("duration_predictor.output.weight") {
            Some(build_dense_linear(tensors, "duration_predictor.output")?)
        } else {
            None
        };
        Ok(Self {
            conv_layers,
            output_proj,
        })
    }

    /// Predict per-phoneme frame durations for a `[1, seq_len, dim]` input.
    ///
    /// Returns one duration (always >= 1 frame) per sequence position.
    /// Falls back to a flat 10 frames per phoneme when the checkpoint
    /// contained no conv layers. `_style` is currently unused.
    fn predict(
        &mut self,
        hidden: &Array,
        _style: &Array,
    ) -> Result<Vec<usize>, mlx_rs::error::Exception> {
        let seq_len = hidden.shape()[1] as usize;
        if !self.conv_layers.is_empty() {
            let mut h = hidden.clone();
            for (weight, bias) in &self.conv_layers {
                // padding=1 keeps the sequence length only for kernel size 3
                // — NOTE(review): confirm the checkpoint's conv kernel widths.
                h = conv1d_forward(&h, weight, bias.as_ref(), 1, 1, 1)?;
                h = nn::relu(&h)?;
            }
            if let Some(ref mut proj) = self.output_proj {
                h = proj.forward(&h)?;
            }
            // Channel 0 is read as the raw (signed) duration value.
            let dur_pred = h.index((.., .., 0));
            mlx_rs::transforms::eval([&dur_pred])?;
            let dur_data: Vec<f32> = dur_pred.as_slice::<f32>().to_vec();
            return Ok(dur_data
                .iter()
                .map(|&d| (d.abs().round() as usize).max(1))
                .collect());
        }
        // No learned predictor in the checkpoint: assume uniform durations.
        let avg_frames = 10usize;
        Ok(vec![avg_frames; seq_len])
    }
}
/// A bank of dilated 1-D convolutions wrapped in a single outer residual
/// connection (HiFi-GAN-style resblock).
struct VocoderResBlock {
    /// (weight, optional bias, dilation) per conv; dilation alternates 1, 3
    /// (see `load`).
    convs: Vec<(Array, Option<Array>, i32)>,
}
impl VocoderResBlock {
    /// Probe up to six convs under `{prefix}.convs.{i}`, stopping at the
    /// first missing weight. Even indices use dilation 1, odd indices 3.
    fn load(tensors: &HashMap<String, Array>, prefix: &str) -> Result<Self, InferenceError> {
        let mut convs = Vec::new();
        for idx in 0..6 {
            let weight = match tensors.get(&format!("{prefix}.convs.{idx}.weight")) {
                Some(w) => w.clone(),
                None => break,
            };
            let bias = tensors.get(&format!("{prefix}.convs.{idx}.bias")).cloned();
            let dilation = if idx % 2 == 0 { 1 } else { 3 };
            convs.push((weight, bias, dilation));
        }
        Ok(Self { convs })
    }

    /// Apply SiLU → dilated conv for each stored conv, with "same" padding,
    /// then add the block input as a residual.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let mut acc = x.clone();
        for (weight, bias, dilation) in &self.convs {
            let activated = nn::silu(&acc)?;
            // NOTE(review): reads axis 1 of the weight as the kernel length —
            // assumes mlx's [out, k, in] conv weight layout; confirm.
            let kernel = weight.shape()[1] as i32;
            let pad = (kernel - 1) * dilation / 2;
            acc = conv1d_forward(&activated, weight, bias.as_ref(), 1, pad, *dilation)?;
        }
        ops::add(x, &acc)
    }
}
/// iSTFTNet-style decoder: `conv_pre` → transposed-conv upsampling stages
/// (each followed by an averaged bank of resblocks) → `conv_post` emitting
/// log-magnitude + phase, which an inline inverse STFT turns into samples.
struct IstftDecoder {
    conv_pre_weight: Array,
    conv_pre_bias: Option<Array>,
    /// (weight, optional bias, stride) per upsampling stage.
    upsample_stages: Vec<(Array, Option<Array>, i32)>,
    /// One bank of resblocks per upsampling stage.
    res_blocks: Vec<Vec<VocoderResBlock>>,
    conv_post_weight: Array,
    conv_post_bias: Option<Array>,
    /// Inverse-STFT frame size (set to 16 in `load`).
    n_fft: usize,
    /// Inverse-STFT hop length (set to 4 in `load`).
    hop_length: usize,
}
impl IstftDecoder {
    /// Build the decoder from the flat tensor map.
    ///
    /// Upsampling stages are probed under `decoder.ups.{i}`; resblocks are
    /// stored flat as `decoder.resblocks.{stage * 3 + j}` with three blocks
    /// per stage.
    fn load(
        tensors: &HashMap<String, Array>,
        config: &KokoroConfig,
    ) -> Result<Self, InferenceError> {
        let pfx = "decoder";
        let conv_pre_weight = get_tensor(tensors, &format!("{pfx}.conv_pre.weight"))?;
        let conv_pre_bias = tensors.get(&format!("{pfx}.conv_pre.bias")).cloned();
        let mut upsample_stages = Vec::new();
        // Strides for the first four stages; any extra stages discovered on
        // disk default to stride 2. NOTE(review): [8, 8, 2, 2] is assumed to
        // match the checkpoint — confirm against the exported model config.
        let upsample_strides = [8, 8, 2, 2];
        for i in 0..8 {
            let w_key = format!("{pfx}.ups.{i}.weight");
            if let Some(w) = tensors.get(&w_key) {
                let b = tensors.get(&format!("{pfx}.ups.{i}.bias")).cloned();
                let stride = if i < upsample_strides.len() {
                    upsample_strides[i]
                } else {
                    2
                };
                upsample_stages.push((w.clone(), b, stride));
            } else {
                break;
            }
        }
        // BUG FIX: the probe previously ran `j in 0..4` against the flat
        // index `i * 3 + j`, so stage i also claimed block `i*3 + 3` — the
        // first resblock of stage i+1 — and adjacent stages shared (and
        // double-applied) one block. Three blocks per stage matches the
        // `* 3` stride in the key layout.
        let mut res_blocks = Vec::new();
        for i in 0..upsample_stages.len() {
            let mut stage_blocks = Vec::new();
            for j in 0..3 {
                let rpfx = format!("{pfx}.resblocks.{}", i * 3 + j);
                if tensors.contains_key(&format!("{rpfx}.convs.0.weight")) {
                    stage_blocks.push(VocoderResBlock::load(tensors, &rpfx)?);
                } else {
                    break;
                }
            }
            res_blocks.push(stage_blocks);
        }
        let conv_post_weight = get_tensor(tensors, &format!("{pfx}.conv_post.weight"))?;
        let conv_post_bias = tensors.get(&format!("{pfx}.conv_post.bias")).cloned();
        // Inverse-STFT head parameters applied after upsampling.
        let n_fft = 16;
        let hop_length = 4;
        let _ = config;
        Ok(Self {
            conv_pre_weight,
            conv_pre_bias,
            upsample_stages,
            res_blocks,
            conv_post_weight,
            conv_post_bias,
            n_fft,
            hop_length,
        })
    }

    /// Run the vocoder over `[1, frames, channels]` features and return a
    /// `[batch, samples]` waveform.
    ///
    /// After `conv_post`, the channel axis is split into log-magnitude and
    /// phase halves and an inverse STFT is performed on the CPU with a Hann
    /// window and overlap-add. If `conv_post` emits fewer than
    /// `2 * (n_fft/2 + 1)` channels, channel 0 is returned directly as the
    /// waveform.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        // NOTE(review): kernel size is read from weight axis 1 — assumes
        // mlx's [out, k, in] conv weight layout; confirm.
        let kernel_size = self.conv_pre_weight.shape()[1] as i32;
        let pad = (kernel_size - 1) / 2;
        let mut h = conv1d_forward(
            x,
            &self.conv_pre_weight,
            self.conv_pre_bias.as_ref(),
            1,
            pad,
            1,
        )?;
        for (i, (up_w, up_b, stride)) in self.upsample_stages.iter().enumerate() {
            h = nn::silu(&h)?;
            let up_kernel = up_w.shape()[1] as i32;
            let up_pad = (up_kernel - *stride) / 2;
            h = conv_transpose1d_forward(&h, up_w, up_b.as_ref(), *stride, up_pad)?;
            // Multi-receptive-field fusion: average the stage's resblocks.
            if i < self.res_blocks.len() {
                let mut sum = Array::from_f32(0.0);
                let mut count = 0;
                for block in &self.res_blocks[i] {
                    let block_out = block.forward(&h)?;
                    sum = if count == 0 {
                        block_out
                    } else {
                        ops::add(&sum, &block_out)?
                    };
                    count += 1;
                }
                if count > 0 {
                    let scale = Array::from_f32(1.0 / count as f32);
                    h = ops::multiply(&sum, &scale)?;
                }
            }
        }
        h = nn::silu(&h)?;
        let post_kernel = self.conv_post_weight.shape()[1] as i32;
        let post_pad = (post_kernel - 1) / 2;
        h = conv1d_forward(
            &h,
            &self.conv_post_weight,
            self.conv_post_bias.as_ref(),
            1,
            post_pad,
            1,
        )?;
        let n_fft_half = (self.n_fft / 2 + 1) as i32;
        let out_channels = h.shape()[2];
        if out_channels >= 2 * n_fft_half {
            // First half of the channels is log-magnitude, second is phase.
            let magnitude = h.index((.., .., ..n_fft_half));
            let phase = h.index((.., .., n_fft_half..2 * n_fft_half));
            let mag_linear = ops::exp(&magnitude)?;
            let real = ops::multiply(&mag_linear, &ops::cos(&phase)?)?;
            let imag = ops::multiply(&mag_linear, &ops::sin(&phase)?)?;
            let batch = h.shape()[0] as usize;
            let n_frames = real.shape()[1] as usize;
            let out_len = n_frames * self.hop_length + self.n_fft;
            // Hann window applied to each synthesized frame.
            let mut window = vec![0.0f32; self.n_fft];
            for i in 0..self.n_fft {
                window[i] =
                    0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / self.n_fft as f32).cos());
            }
            // Inverse real-DFT bases, pre-windowed and scaled by 2/N
            // (positive and negative frequency bins folded together).
            let mut basis_cos = vec![0.0f32; (self.n_fft / 2 + 1) * self.n_fft];
            let mut basis_sin = vec![0.0f32; (self.n_fft / 2 + 1) * self.n_fft];
            let n_fft_f = self.n_fft as f32;
            for k in 0..=self.n_fft / 2 {
                for n in 0..self.n_fft {
                    let angle = 2.0 * std::f32::consts::PI * k as f32 * n as f32 / n_fft_f;
                    basis_cos[k * self.n_fft + n] = angle.cos() * window[n] * 2.0 / n_fft_f;
                    basis_sin[k * self.n_fft + n] = angle.sin() * window[n] * 2.0 / n_fft_f;
                }
            }
            // DC (k=0) and Nyquist (k=N/2) bins have no conjugate partner,
            // so undo their doubling.
            for n in 0..self.n_fft {
                basis_cos[n] /= 2.0;
                if self.n_fft / 2 > 0 {
                    basis_cos[(self.n_fft / 2) * self.n_fft + n] /= 2.0;
                }
                basis_sin[n] /= 2.0;
                if self.n_fft / 2 > 0 {
                    basis_sin[(self.n_fft / 2) * self.n_fft + n] /= 2.0;
                }
            }
            let basis_cos_arr = Array::from_slice(
                &basis_cos,
                &[(self.n_fft / 2 + 1) as i32, self.n_fft as i32],
            );
            let basis_sin_arr = Array::from_slice(
                &basis_sin,
                &[(self.n_fft / 2 + 1) as i32, self.n_fft as i32],
            );
            // Per-frame time-domain samples: real·cos + imag·sin.
            let frame_real = ops::matmul(&real, &basis_cos_arr)?;
            let frame_imag = ops::matmul(&imag, &basis_sin_arr)?;
            let frame_samples = ops::add(&frame_real, &frame_imag)?;
            mlx_rs::transforms::eval([&frame_samples])?;
            let frame_data: Vec<f32> = frame_samples.as_slice::<f32>().to_vec();
            // Overlap-add on the CPU, accumulating squared-window weights
            // for normalization.
            let mut output = vec![0.0f32; batch * out_len];
            let mut window_sum = vec![0.0f32; out_len];
            for b in 0..batch {
                for f in 0..n_frames {
                    let offset = f * self.hop_length;
                    for n in 0..self.n_fft {
                        if offset + n < out_len {
                            let idx = b * n_frames * self.n_fft + f * self.n_fft + n;
                            output[b * out_len + offset + n] += frame_data[idx];
                            window_sum[offset + n] += window[n] * window[n];
                        }
                    }
                }
            }
            // NOTE(review): window_sum accumulates over every batch item but
            // is indexed per-sample below, so batch > 1 would over-count;
            // the pipeline only ever produces batch == 1 here.
            for b in 0..batch {
                for i in 0..out_len {
                    if window_sum[i] > 1e-8 {
                        output[b * out_len + i] /= window_sum[i];
                    }
                }
            }
            let waveform = Array::from_slice(&output, &[batch as i32, out_len as i32]);
            Ok(waveform)
        } else {
            // Degenerate head: treat channel 0 as raw samples.
            let squeezed = h.index((.., .., 0));
            Ok(squeezed)
        }
    }
}
/// Loads per-voice style vectors from disk.
///
/// `proj` is read from the checkpoint when present but is never applied in
/// this file — only `get_style`'s file loading is used.
struct StyleEncoder {
    #[allow(dead_code)]
    proj: Option<nn::Linear>,
}
impl StyleEncoder {
    /// Build the (optional, currently unused) style projection from the
    /// flat tensor map.
    fn load(
        tensors: &HashMap<String, Array>,
        _config: &KokoroConfig,
    ) -> Result<Self, InferenceError> {
        let proj = if tensors.contains_key("style_encoder.proj.weight") {
            Some(build_dense_linear(tensors, "style_encoder.proj")?)
        } else {
            None
        };
        Ok(Self { proj })
    }

    /// Decode raw voice-file bytes into a `[1, N]` style vector.
    ///
    /// Understands the NumPy `.npy` container (v1: u16 header length at
    /// bytes 8..10, payload at 10+len; v2/v3: u32 length at 8..12, payload
    /// at 12+len) and otherwise treats the whole file as packed floats.
    /// Returns `None` for truncated or malformed input instead of panicking.
    /// NOTE(review): assumes a little-endian f32 payload — the npy header's
    /// dtype field is not parsed; confirm voice files are exported as `<f4`.
    fn parse_voice_bytes(data: &[u8]) -> Option<Array> {
        let payload = if data.len() > 10 && data[0] == 0x93 && data[1] == b'N' {
            // Byte 6 is the npy format's major version.
            let offset = if data[6] == 1 {
                10 + u16::from_le_bytes([data[8], data[9]]) as usize
            } else if data.len() >= 12 {
                12 + u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
            } else {
                // Too short to hold a v2+ header length — previously this
                // indexed past the end of the buffer and panicked.
                return None;
            };
            // Checked slice: a corrupt header length must not panic either.
            data.get(offset..)?
        } else {
            data
        };
        let floats: Vec<f32> = payload
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        if floats.is_empty() {
            // Previously a valid npy header with an empty payload produced a
            // degenerate [1, 0] array; fall back to the zero style instead.
            return None;
        }
        Some(Array::from_slice(&floats, &[1, floats.len() as i32]))
    }

    /// Load the style vector for `voice_path`, falling back to an all-zero
    /// `[1, style_dim]` vector when no usable voice file is available.
    fn get_style(
        &mut self,
        voice_path: Option<&Path>,
        config: &KokoroConfig,
    ) -> Result<Array, InferenceError> {
        if let Some(path) = voice_path {
            if path.exists() {
                let data = std::fs::read(path)
                    .map_err(|e| InferenceError::InferenceFailed(format!("read voice: {e}")))?;
                if let Some(style) = Self::parse_voice_bytes(&data) {
                    return Ok(style);
                }
            }
        }
        Ok(Array::from_slice(
            &vec![0.0f32; config.style_dim],
            &[1, config.style_dim as i32],
        ))
    }
}
/// Fuses style into text hidden states, projects them, and expands each
/// phoneme over time according to its predicted duration.
struct TextToMel {
    /// Optional projection applied to the (styled) text hidden states.
    text_proj: Option<QLinear>,
    /// Optional projection mapping the style vector into the hidden space.
    style_proj: Option<QLinear>,
}
impl TextToMel {
    /// Build the optional projections from the flat tensor map. Either
    /// projection may be stored quantized (`.scales` key) or dense
    /// (`.weight` key) — probe for both.
    fn load(
        tensors: &HashMap<String, Array>,
        config: &KokoroConfig,
    ) -> Result<Self, InferenceError> {
        let quant = config.quant.as_ref();
        let text_proj = if tensors.contains_key("text_to_mel.text_proj.weight")
            || tensors.contains_key("text_to_mel.text_proj.scales")
        {
            Some(build_qlinear(tensors, "text_to_mel.text_proj", quant)?)
        } else {
            None
        };
        let style_proj = if tensors.contains_key("text_to_mel.style_proj.weight")
            || tensors.contains_key("text_to_mel.style_proj.scales")
        {
            Some(build_qlinear(tensors, "text_to_mel.style_proj", quant)?)
        } else {
            None
        };
        Ok(Self {
            text_proj,
            style_proj,
        })
    }
    /// Add style to each position of `text_hidden`, optionally project, then
    /// repeat each position `durations[i]` times along the time axis.
    ///
    /// Batch size is assumed to be 1 (hardcoded in the reshapes and the
    /// expansion below). When no `style_proj` exists the raw style is added
    /// only if its width already matches the hidden width; otherwise the
    /// style is ignored. Returns `[1, sum(durations), out_dim]`.
    fn forward(
        &mut self,
        text_hidden: &Array,
        style: &Array,
        durations: &[usize],
    ) -> Result<Array, mlx_rs::error::Exception> {
        let seq_len = text_hidden.shape()[1] as usize;
        let hidden_dim = text_hidden.shape()[2] as usize;
        let styled = if let Some(ref mut sp) = self.style_proj {
            // Broadcast the style vector across the sequence, project it into
            // the hidden space, and add.
            let style_expanded = ops::broadcast_to(
                style,
                &[1, seq_len as i32, style.shape()[style.shape().len() - 1]],
            )?;
            let style_proj = sp.forward(&style_expanded)?;
            ops::add(text_hidden, &style_proj)?
        } else {
            if style.shape().last().copied() == Some(hidden_dim as i32) {
                let style_expanded =
                    ops::broadcast_to(style, &[1, seq_len as i32, hidden_dim as i32])?;
                ops::add(text_hidden, &style_expanded)?
            } else {
                // Style width doesn't match and there's no projection: skip.
                text_hidden.clone()
            }
        };
        let projected = if let Some(ref mut tp) = self.text_proj {
            tp.forward(&styled)?
        } else {
            styled
        };
        // Duration expansion is done on the CPU: each phoneme's feature row
        // is copied `dur` times into the output frame buffer.
        let total_frames: usize = durations.iter().sum();
        let out_dim = projected.shape()[2] as usize;
        mlx_rs::transforms::eval([&projected])?;
        let proj_data: Vec<f32> = projected.as_slice::<f32>().to_vec();
        let mut expanded = vec![0.0f32; total_frames * out_dim];
        let mut frame_offset = 0;
        for (phone_idx, &dur) in durations.iter().enumerate() {
            // Durations past seq_len leave zero frames (buffer is pre-zeroed)
            // but still advance the frame offset.
            if phone_idx < seq_len {
                let src_offset = phone_idx * out_dim;
                for f in 0..dur {
                    for d in 0..out_dim {
                        expanded[(frame_offset + f) * out_dim + d] = proj_data[src_offset + d];
                    }
                }
            }
            frame_offset += dur;
        }
        Ok(Array::from_slice(
            &expanded,
            &[1, total_frames as i32, out_dim as i32],
        ))
    }
}
/// Write `samples` to `path` as a mono 16-bit little-endian PCM WAV file.
///
/// Samples are clamped to [-1, 1] and scaled to i16 full range. The 44-byte
/// RIFF header is assembled by hand; the whole file is written in one call.
fn write_wav(path: &Path, samples: &[f32], sample_rate: u32) -> Result<(), InferenceError> {
    const CHANNELS: u16 = 1;
    const BYTES_PER_SAMPLE: u16 = 2;
    let data_size = samples.len() as u32 * BYTES_PER_SAMPLE as u32;
    let byte_rate = sample_rate * BYTES_PER_SAMPLE as u32 * CHANNELS as u32;
    let block_align = BYTES_PER_SAMPLE * CHANNELS;
    let mut buf = Vec::with_capacity(44 + data_size as usize);
    // RIFF chunk descriptor.
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&(36 + data_size).to_le_bytes());
    buf.extend_from_slice(b"WAVE");
    // "fmt " sub-chunk (PCM, 16 bytes).
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes());
    buf.extend_from_slice(&1u16.to_le_bytes()); // audio format 1 = PCM
    buf.extend_from_slice(&CHANNELS.to_le_bytes());
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&byte_rate.to_le_bytes());
    buf.extend_from_slice(&block_align.to_le_bytes());
    buf.extend_from_slice(&(BYTES_PER_SAMPLE * 8).to_le_bytes()); // bits/sample
    // "data" sub-chunk.
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_size.to_le_bytes());
    for &sample in samples {
        let pcm = (sample.clamp(-1.0, 1.0) * 32767.0) as i16;
        buf.extend_from_slice(&pcm.to_le_bytes());
    }
    std::fs::write(path, &buf)
        .map_err(|e| InferenceError::InferenceFailed(format!("write WAV: {e}")))?;
    Ok(())
}
/// Map raw text to a token-id sequence framed by BOS (1) and EOS (0).
///
/// ASCII characters are hashed into `[2, vocab_size)` by modulo; non-ASCII
/// characters are skipped. NOTE(review): this is a byte-level placeholder,
/// not a real grapheme-to-phoneme frontend — output quality depends on the
/// checkpoint tolerating these ids.
///
/// Fix: previously `vocab_size <= 2` made the modulus zero and panicked;
/// such a degenerate vocabulary now simply yields `[BOS, EOS]`.
fn text_to_phoneme_ids(text: &str, vocab_size: usize) -> Vec<i32> {
    let mut ids = Vec::with_capacity(text.len() + 2);
    ids.push(1); // BOS
    // Ids 0 and 1 are reserved for EOS/BOS, leaving `span` content ids.
    let span = vocab_size.saturating_sub(2);
    if span > 0 {
        for ch in text.chars() {
            let code = ch as u32;
            if code < 128 {
                let token_id = (code as usize % span) + 2;
                ids.push(token_id as i32);
            }
        }
    }
    ids.push(0); // EOS
    ids
}
/// End-to-end Kokoro TTS backend: text → phoneme ids → PL-BERT →
/// hidden projection → durations → duration-expanded features →
/// iSTFT decoder → normalized WAV samples.
pub struct KokoroBackend {
    plbert: PLBert,
    bert_encoder: BertEncoder,
    duration_predictor: DurationPredictor,
    style_encoder: StyleEncoder,
    text_to_mel: TextToMel,
    decoder: IstftDecoder,
    config: KokoroConfig,
    /// Checkpoint directory; also searched for `voices/{name}.npy` files.
    model_dir: PathBuf,
}
// SAFETY(review): asserted so the backend can be moved/shared across
// threads. This is only sound if mlx_rs `Array` and module handles are safe
// to access from other threads — that is NOT established anywhere in this
// file. Confirm against mlx-rs's documented thread-safety guarantees before
// relying on concurrent use; `synthesize` takes `&mut self`, so Rust at
// least prevents concurrent mutation through one handle.
unsafe impl Send for KokoroBackend {}
unsafe impl Sync for KokoroBackend {}
impl KokoroBackend {
    /// Load every pipeline component from `model_dir`.
    ///
    /// Selects the MLX compute device (compile-time default, overridable at
    /// runtime via `CAR_MLX_DEVICE=cpu|gpu`), then reads either the
    /// single-file `kokoro-v1_0.safetensors` checkpoint or a sharded
    /// safetensors index via `load_all_tensors`.
    pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
        let config = KokoroConfig::default();
        info!(
            hidden = config.hidden_dim,
            style_dim = config.style_dim,
            n_token = config.n_token,
            "loading Kokoro TTS model via MLX"
        );
        // Compile-time device default: GPU when built with Metal support.
        #[cfg(feature = "mlx-metal")]
        let default_device = mlx_rs::Device::gpu();
        #[cfg(not(feature = "mlx-metal"))]
        let default_device = mlx_rs::Device::cpu();
        // Runtime override; note this mutates MLX's process-global default
        // device, affecting all subsequent MLX work in this process.
        match std::env::var("CAR_MLX_DEVICE").ok().as_deref() {
            Some("cpu") => mlx_rs::Device::set_default(&mlx_rs::Device::cpu()),
            #[cfg(feature = "mlx-metal")]
            Some("gpu") => mlx_rs::Device::set_default(&mlx_rs::Device::gpu()),
            _ => mlx_rs::Device::set_default(&default_device),
        }
        // Prefer the single-file checkpoint; fall back to a sharded index.
        let weights_path = model_dir.join("kokoro-v1_0.safetensors");
        let tensors = if weights_path.exists() {
            info!("loading kokoro-v1_0.safetensors");
            let t = Array::load_safetensors(&weights_path)
                .map_err(|e| InferenceError::InferenceFailed(format!("load safetensors: {e}")))?;
            let mut map = HashMap::new();
            for (name, array) in t {
                map.insert(name, array);
            }
            map
        } else {
            info!("loading safetensors via index");
            load_all_tensors(model_dir)?
        };
        info!(tensors = tensors.len(), "Kokoro tensors loaded");
        let plbert = PLBert::load(&tensors, &config)?;
        info!("PLBert text encoder loaded");
        let bert_encoder = BertEncoder::load(&tensors, config.quant.as_ref())?;
        info!("BERT → hidden projection loaded");
        let duration_predictor = DurationPredictor::load(&tensors, &config)?;
        info!("Duration predictor loaded");
        let style_encoder = StyleEncoder::load(&tensors, &config)?;
        info!("Style encoder loaded");
        let text_to_mel = TextToMel::load(&tensors, &config)?;
        info!("Text-to-mel synthesizer loaded");
        let decoder = IstftDecoder::load(&tensors, &config)?;
        info!("iSTFTNet decoder loaded");
        info!("Kokoro TTS model loaded successfully");
        Ok(Self {
            plbert,
            bert_encoder,
            duration_predictor,
            style_encoder,
            text_to_mel,
            decoder,
            config,
            model_dir: model_dir.to_path_buf(),
        })
    }
    /// Synthesize `text` into a WAV file at `output_path`.
    ///
    /// `voice` selects a style file (`voices/{voice}.npy` under the model
    /// dir, falling back to `{voice}.npy`); `None` or a missing file uses
    /// the zero style. The waveform is peak-normalized to 0.95 before being
    /// written. Returns the output path on success.
    pub fn synthesize(
        &mut self,
        text: &str,
        voice: Option<&str>,
        output_path: &Path,
    ) -> Result<PathBuf, InferenceError> {
        let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
        info!(
            text_len = text.len(),
            voice = voice.unwrap_or("default"),
            output = %output_path.display(),
            "synthesizing speech with Kokoro"
        );
        // Text front-end: placeholder byte-level tokenization (see
        // `text_to_phoneme_ids`), then PL-BERT encoding and projection.
        let phoneme_ids = text_to_phoneme_ids(text, self.config.n_token);
        let seq_len = phoneme_ids.len() as i32;
        let token_ids = Array::from_slice(&phoneme_ids, &[1, seq_len]);
        let bert_out = self.plbert.forward(&token_ids).map_err(map_err)?;
        let text_hidden = self.bert_encoder.forward(&bert_out).map_err(map_err)?;
        // Resolve the voice file: prefer voices/{v}.npy, else {v}.npy.
        let voice_path = voice.map(|v| {
            let npy_path = self.model_dir.join(format!("voices/{v}.npy"));
            if npy_path.exists() {
                npy_path
            } else {
                self.model_dir.join(format!("{v}.npy"))
            }
        });
        let style = self
            .style_encoder
            .get_style(voice_path.as_deref(), &self.config)?;
        let durations = self
            .duration_predictor
            .predict(&text_hidden, &style)
            .map_err(map_err)?;
        info!(
            phones = phoneme_ids.len(),
            total_frames = durations.iter().sum::<usize>(),
            "duration prediction complete"
        );
        // Expand to frame rate, vocode, and materialize the waveform.
        let mel_features = self
            .text_to_mel
            .forward(&text_hidden, &style, &durations)
            .map_err(map_err)?;
        let waveform = self.decoder.forward(&mel_features).map_err(map_err)?;
        mlx_rs::transforms::eval([&waveform]).map_err(map_err)?;
        let wave_data: Vec<f32> = waveform.as_slice::<f32>().to_vec();
        // Peak-normalize to 0.95; leave silence untouched to avoid
        // amplifying numerical noise.
        let max_abs = wave_data.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        let normalized: Vec<f32> = if max_abs > 1e-8 {
            wave_data.iter().map(|s| s / max_abs * 0.95).collect()
        } else {
            wave_data
        };
        write_wav(output_path, &normalized, self.config.sample_rate)?;
        info!(
            path = %output_path.display(),
            samples = normalized.len(),
            duration_secs = normalized.len() as f32 / self.config.sample_rate as f32,
            "WAV file written"
        );
        Ok(output_path.to_path_buf())
    }
}