use candle_core::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use tracing::{info, warn};
use crate::error::{TtsError, TtsResult};
use crate::mel::{MelConfig, MelSpectrogram};
use crate::traits::ReferenceAudio;
/// Creates a scalar tensor holding `value`, placed on `tensor`'s device and converted to
/// `tensor`'s dtype, so it can be broadcast against `tensor`.
fn scalar_like(tensor: &Tensor, value: f32) -> candle_core::Result<Tensor> {
    let scalar = Tensor::new(value, tensor.device())?;
    scalar.to_dtype(tensor.dtype())
}
/// Multiplies every element of `tensor` by `value`.
///
/// The factor is materialised as a scalar tensor matching `tensor`'s device and dtype and
/// applied with a broadcasting multiply.
fn scale_tensor(tensor: &Tensor, value: f32) -> candle_core::Result<Tensor> {
    let factor = scalar_like(tensor, value)?;
    tensor.broadcast_mul(&factor)
}
/// Leaky ReLU computed as `relu(x) + negative_slope * (x - relu(x))`.
///
/// `x - relu(x)` is zero for positive inputs and equal to `x` for negative ones, so only the
/// negative part is scaled by `negative_slope`.
fn leaky_relu(x: &Tensor, negative_slope: f32) -> candle_core::Result<Tensor> {
    let positive = x.relu()?;
    let below_zero = (x - &positive)?;
    let slope = scalar_like(x, negative_slope)?;
    let negative = below_zero.broadcast_mul(&slope)?;
    &positive + &negative
}
fn avg_pool2d_2x2(x: &Tensor) -> candle_core::Result<Tensor> {
let (b, c, h, w) = x.dims4()?;
let oh = h / 2;
let ow = w / 2;
let x = x.reshape((b, c, oh, 2, w))?.mean(3)?;
x.reshape((b, c, oh, ow, 2))?.mean(4)
}
/// Loads a square-kernel 2-D convolution from `vb`.
///
/// The kernel is read from `weight` when that tensor exists and from `weight_orig` otherwise.
/// `bias` is optional: when it is absent the convolution is built without a bias.
///
/// Tensors that are present but cannot be loaded with the expected shape are reported as
/// errors instead of being silently replaced by a fallback or dropped.
///
/// NOTE(review): `weight_orig` is used as-is. If it comes from a spectral-norm
/// parametrisation it may need normalising before use — confirm against the checkpoint export.
///
/// # Errors
/// Returns an error when the kernel (or a present bias) cannot be loaded with the expected
/// shape.
fn load_conv2d(
    in_c: usize,
    out_c: usize,
    kernel: usize,
    cfg: candle_nn::Conv2dConfig,
    vb: VarBuilder,
) -> candle_core::Result<candle_nn::Conv2d> {
    let shape = (out_c, in_c / cfg.groups, kernel, kernel);

    // Choose the kernel tensor by presence so that a malformed `weight` is not masked by the
    // `weight_orig` fallback.
    let weight_name = if vb.contains_tensor("weight") {
        "weight"
    } else {
        "weight_orig"
    };
    let ws = vb.get(shape, weight_name)?;

    // A missing bias is fine; a bias that exists but does not load is an error.
    let bs = if vb.contains_tensor("bias") {
        Some(vb.get(out_c, "bias")?)
    } else {
        None
    };

    Ok(candle_nn::Conv2d::new(ws, bs, cfg))
}
/// Convolution settings with stride 1 and padding 1 (no dilation, single group).
fn cfg_pad1() -> candle_nn::Conv2dConfig {
    candle_nn::Conv2dConfig {
        stride: 1,
        padding: 1,
        groups: 1,
        dilation: 1,
        ..Default::default()
    }
}
/// Convolution settings with stride 2 and padding 1 (no dilation, single group).
fn cfg_stride2() -> candle_nn::Conv2dConfig {
    candle_nn::Conv2dConfig {
        stride: 2,
        padding: 1,
        groups: 1,
        dilation: 1,
        ..Default::default()
    }
}
/// Convolution settings with stride 1 and no padding (no dilation, single group).
fn cfg_no_pad() -> candle_nn::Conv2dConfig {
    candle_nn::Conv2dConfig {
        stride: 1,
        padding: 0,
        groups: 1,
        dilation: 1,
        ..Default::default()
    }
}
/// Downsampling 2-D residual block: a convolutional residual branch plus an average-pooled
/// shortcut.
struct ResBlk2d {
// First 3×3 convolution of the residual branch (`dim_in` -> `dim_in`).
conv1: candle_nn::Conv2d,
// Second 3×3 convolution of the residual branch (`dim_in` -> `dim_out`).
conv2: candle_nn::Conv2d,
// Stride-2 3×3 convolution that downsamples the residual branch.
downsample_conv: candle_nn::Conv2d,
// 1×1 projection applied on the shortcut; only present when `dim_in != dim_out`.
conv1x1: Option<candle_nn::Conv2d>,
}
impl ResBlk2d {
    /// Loads the block's convolutions from `vb`.
    ///
    /// Weights are read from the `conv1`, `conv2`, `downsample_res.conv` and — only when
    /// `dim_in != dim_out` — `conv1x1` sub-paths.
    fn load(dim_in: usize, dim_out: usize, vb: VarBuilder) -> candle_core::Result<Self> {
        let conv1 = load_conv2d(dim_in, dim_in, 3, cfg_pad1(), vb.pp("conv1"))?;
        let conv2 = load_conv2d(dim_in, dim_out, 3, cfg_pad1(), vb.pp("conv2"))?;

        let downsample_path = vb.pp("downsample_res").pp("conv");
        let downsample_conv = load_conv2d(dim_in, dim_in, 3, cfg_stride2(), downsample_path)?;

        // The shortcut only needs a projection when the channel count changes.
        let conv1x1 = (dim_in != dim_out)
            .then(|| load_conv2d(dim_in, dim_out, 1, cfg_no_pad(), vb.pp("conv1x1")))
            .transpose()?;

        Ok(Self {
            conv1,
            conv2,
            downsample_conv,
            conv1x1,
        })
    }

    /// Runs the block: `(shortcut(x) + residual(x)) / sqrt(2)`.
    ///
    /// The shortcut is the (optionally 1×1-projected) input after 2×2 average pooling. The
    /// residual branch is LeakyReLU(0.2) → `conv1` → stride-2 conv → LeakyReLU(0.2) → `conv2`.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        // Shortcut branch.
        let projected = match self.conv1x1.as_ref() {
            Some(proj) => Some(proj.forward(x)?),
            None => None,
        };
        let shortcut = avg_pool2d_2x2(projected.as_ref().unwrap_or(x))?;

        // Residual branch.
        let activated = leaky_relu(x, 0.2)?;
        let features = self.conv1.forward(&activated)?;
        let downsampled = self.downsample_conv.forward(&features)?;
        let activated = leaky_relu(&downsampled, 0.2)?;
        let residual = self.conv2.forward(&activated)?;

        // Scaling by 1/sqrt(2) after summing the two branches.
        let merged: Tensor = (shortcut + residual)?;
        scale_tensor(&merged, std::f32::consts::FRAC_1_SQRT_2)
    }
}
/// One convolutional style encoder: an input convolution, a stack of residual blocks, a final
/// convolution and a linear projection to the style vector.
struct SingleStyleEncoder {
// 3×3 convolution lifting the single input channel to `dim_in` channels.
initial_conv: candle_nn::Conv2d,
// Downsampling residual blocks applied in order.
res_blocks: Vec<ResBlk2d>,
// Unpadded 5×5 convolution applied after the residual stack.
final_conv: candle_nn::Conv2d,
// Linear layer mapping the pooled features to `style_dim` outputs.
fc: candle_nn::Linear,
}
impl SingleStyleEncoder {
fn load(
dim_in: usize,
style_dim: usize,
max_conv_dim: usize,
vb: VarBuilder,
) -> candle_core::Result<Self> {
let vb_shared = vb.pp("shared");
let initial_conv = load_conv2d(1, dim_in, 3, cfg_pad1(), vb_shared.pp("0"))?;
let mut cur_dim = dim_in;
let mut res_blocks = Vec::with_capacity(4);
for i in 0..4u32 {
let next_dim = (cur_dim * 2).min(max_conv_dim);
let blk = ResBlk2d::load(cur_dim, next_dim, vb_shared.pp((i + 1).to_string()))?;
res_blocks.push(blk);
cur_dim = next_dim;
}
let final_conv = load_conv2d(cur_dim, cur_dim, 5, cfg_no_pad(), vb_shared.pp("6"))?;
let fc = candle_nn::linear(cur_dim, style_dim, vb.pp("unshared"))?;
Ok(Self {
initial_conv,
res_blocks,
final_conv,
fc,
})
}
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let mut h = self.initial_conv.forward(x)?;
for blk in &self.res_blocks {
h = blk.forward(&h)?;
}
h = leaky_relu(&h, 0.2)?;
h = self.final_conv.forward(&h)?;
h = h.mean_keepdim(candle_core::D::Minus1)?;
h = h.mean_keepdim(candle_core::D::Minus2)?;
h = h.flatten_from(1)?;
self.fc.forward(&h)
}
}
/// Pair of style encoders plus the mel front-end used to turn reference audio into a style
/// tensor.
pub struct StyleEncoder {
// Encoder loaded from the `style_encoder` weights.
acoustic: SingleStyleEncoder,
// Encoder loaded from the `predictor_encoder` weights.
prosodic: SingleStyleEncoder,
// Mel-spectrogram extractor applied to the reference audio.
mel: MelSpectrogram,
// Sample rate that reference audio is resampled to before feature extraction.
target_sample_rate: u32,
}
impl StyleEncoder {
    /// Tries to build both style encoders from `vb`.
    ///
    /// * Returns `Ok(None)` (after logging a warning) when the `style_encoder` weights cannot
    ///   be loaded.
    /// * Returns `Err(TtsError::WeightLoadError)` when `style_encoder` loads but
    ///   `predictor_encoder` does not.
    /// * Otherwise returns `Ok(Some(_))`.
    ///
    /// `sample_rate` is stored as the rate reference audio is resampled to in [`Self::encode`];
    /// `device` is used to construct the mel-spectrogram front-end.
    pub fn try_load(
        dim_in: usize,
        style_dim: usize,
        max_conv_dim: usize,
        sample_rate: u32,
        vb: &VarBuilder,
        device: &Device,
    ) -> TtsResult<Option<Self>> {
        let acoustic =
            match SingleStyleEncoder::load(dim_in, style_dim, max_conv_dim, vb.pp("style_encoder"))
            {
                Ok(enc) => enc,
                Err(e) => {
                    warn!(
                        "Style encoder weights not found (voice cloning unavailable): {}",
                        e
                    );
                    return Ok(None);
                }
            };

        // Once the first encoder is present, a missing second one is a hard error.
        let prosodic =
            SingleStyleEncoder::load(dim_in, style_dim, max_conv_dim, vb.pp("predictor_encoder"))
                .map_err(|e| {
                    TtsError::WeightLoadError(format!(
                        "Found style_encoder but not predictor_encoder: {}",
                        e
                    ))
                })?;

        let mel = MelSpectrogram::new(MelConfig::kokoro(), device)?;
        info!("Style encoders loaded — voice cloning available");

        Ok(Some(Self {
            acoustic,
            prosodic,
            mel,
            target_sample_rate: sample_rate,
        }))
    }

    /// Encodes reference audio into a style tensor.
    ///
    /// The audio is resampled to the target sample rate when needed, converted to a mel
    /// spectrogram, given a channel axis at dim 1, converted to `dtype`, and passed through
    /// both encoders. The acoustic and prosodic outputs are concatenated along dim 1, in that
    /// order.
    ///
    /// NOTE(review): the sample tensor is always created on `Device::Cpu`; confirm that
    /// `MelSpectrogram::compute` places its output on the device the encoder weights use.
    ///
    /// # Errors
    /// Propagates any tensor or mel-spectrogram error as a `TtsError`.
    pub fn encode(&self, audio: &ReferenceAudio, dtype: DType) -> TtsResult<Tensor> {
        // Borrow the caller's samples when no resampling is needed instead of cloning them.
        let resampled;
        let samples = if audio.sample_rate != self.target_sample_rate {
            info!(
                "Resampling reference audio from {} Hz to {} Hz",
                audio.sample_rate, self.target_sample_rate
            );
            resampled = crate::mel::resample_linear(
                &audio.samples,
                audio.sample_rate,
                self.target_sample_rate,
            );
            resampled.as_slice()
        } else {
            audio.samples.as_slice()
        };

        let audio_tensor = Tensor::new(samples, &Device::Cpu)?;
        let mel_spec = self.mel.compute(&audio_tensor)?;
        let mel_input = mel_spec.unsqueeze(1)?.to_dtype(dtype)?;

        let acoustic_style = self.acoustic.forward(&mel_input)?;
        let prosodic_style = self.prosodic.forward(&mel_input)?;

        Tensor::cat(&[&acoustic_style, &prosodic_style], 1).map_err(TtsError::from)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` matches `expected` element-wise within `1e-5`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    /// Negative inputs are scaled by the slope; non-negative inputs pass through.
    #[test]
    fn test_leaky_relu() {
        let input = Tensor::new(&[-2.0f32, -1.0, 0.0, 1.0, 2.0], &Device::Cpu).unwrap();
        let output: Vec<f32> = leaky_relu(&input, 0.2).unwrap().to_vec1().unwrap();
        assert_close(&output, &[-0.4, -0.2, 0.0, 1.0, 2.0]);
    }

    /// A 4×4 ramp of 1..=16 pools to the mean of each 2×2 tile.
    #[test]
    fn test_avg_pool2d_2x2() {
        let ramp: Vec<f32> = (1..=16).map(|v| v as f32).collect();
        let input = Tensor::new(ramp.as_slice(), &Device::Cpu)
            .unwrap()
            .reshape((1, 1, 4, 4))
            .unwrap();

        let pooled = avg_pool2d_2x2(&input).unwrap();
        assert_eq!(pooled.dims(), &[1, 1, 2, 2]);

        let output: Vec<f32> = pooled.flatten_all().unwrap().to_vec1().unwrap();
        assert_close(&output, &[3.5, 5.5, 11.5, 13.5]);
    }
}