#![allow(dead_code)]
use super::KokoroModel;
use super::convblocks::conv1d;
use super::ops::{bilstm, layer_norm_plain, linear, sigmoid};
/// The eight tensors of one LSTM weight set, in the order `load_bilstm` fills them:
/// `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`, followed by the
/// same four names with the `_reverse` suffix. `run_bilstm` passes them to
/// `bilstm` in this same order.
type Bilstm = [Vec<f32>; 8];
impl KokoroModel {
/// Fetches the eight tensors of the LSTM stored under `prefix`.
///
/// Each tensor is looked up as `{prefix}.{suffix}` via `self.t`, using the
/// suffixes listed in `SUFFIXES`. The returned array keeps that order.
pub(crate) fn load_bilstm(&self, prefix: &str) -> Bilstm {
    // Lookup order matters: `run_bilstm` passes the entries to `bilstm` positionally.
    const SUFFIXES: [&str; 8] = [
        "weight_ih_l0",
        "weight_hh_l0",
        "bias_ih_l0",
        "bias_hh_l0",
        "weight_ih_l0_reverse",
        "weight_hh_l0_reverse",
        "bias_ih_l0_reverse",
        "bias_hh_l0_reverse",
    ];
    SUFFIXES.map(|suffix| self.t(&format!("{prefix}.{suffix}")))
}
/// Runs `bilstm` over `x` with the weight set `w`.
///
/// `t`, `in_dim` and `hidden` are forwarded unchanged. The eight entries of
/// `w` are handed to `bilstm` in array order (entries 0-3 first, then 4-7).
pub(crate) fn run_bilstm(
    &self,
    w: &Bilstm,
    x: &[f32],
    t: usize,
    in_dim: usize,
    hidden: usize,
) -> Vec<f32> {
    let [w_ih, w_hh, b_ih, b_hh, w_ih_rev, w_hh_rev, b_ih_rev, b_hh_rev] = w;
    bilstm(
        x, t, in_dim, hidden, w_ih, w_hh, b_ih, b_hh, w_ih_rev, w_hh_rev, b_ih_rev, b_hh_rev,
    )
}
/// Applies the `k.bert_encoder` linear layer to `bert`.
///
/// `linear` receives `t`, `cfg.plbert_hidden` and `cfg.hidden_dim` — presumably
/// the row count, input width and output width; confirm against `ops::linear`.
pub fn bert_encoder(&self, bert: &[f32], t: usize) -> Vec<f32> {
    let in_dim = self.cfg.plbert_hidden;
    let out_dim = self.cfg.hidden_dim;
    let weight = self.t("k.bert_encoder.weight");
    let bias = self.t("k.bert_encoder.bias");
    linear(bert, t, in_dim, &weight, Some(&bias), out_dim)
}
/// Runs the `k.predictor.text_encoder` stack over `be`, conditioned on `style`.
///
/// The working buffer has `t` rows of width `cfg.hidden_dim + cfg.style_dim`:
/// the first `hidden_dim` values of each row come from `be`, the rest are a
/// copy of `style`. For each of `cfg.n_layer` layers the buffer is:
///
/// 1. passed through the BiLSTM stored at `lstms.{2 * layer}`,
/// 2. normalised with `layer_norm_plain` (epsilon `1e-5`),
/// 3. rescaled as `(1 + gamma) * x + beta`, where `gamma`/`beta` are the two
///    halves of the `lstms.{2 * layer + 1}.fc` linear layer applied to `style`,
/// 4. written back with `style` re-appended to every row.
///
/// Returns the final `t * (hidden_dim + style_dim)` buffer.
///
/// # Panics
/// Panics if `style.len() != cfg.style_dim` or if `be` holds fewer than
/// `t * cfg.hidden_dim` values.
pub fn duration_encode(&self, be: &[f32], t: usize, style: &[f32]) -> Vec<f32> {
    let hidden = self.cfg.hidden_dim;
    let style_dim = self.cfg.style_dim;
    let width = hidden + style_dim;

    // Initial rows: [text features | style].
    let mut x = vec![0.0f32; t * width];
    for ti in 0..t {
        let (text, tail) = x[ti * width..(ti + 1) * width].split_at_mut(hidden);
        text.copy_from_slice(&be[ti * hidden..(ti + 1) * hidden]);
        tail.copy_from_slice(style);
    }

    for layer in 0..self.cfg.n_layer {
        // Even slots hold the LSTM, odd slots hold the style-conditioned norm.
        let lstm_idx = 2 * layer;
        let norm_idx = lstm_idx + 1;

        let lstm_w = self.load_bilstm(&format!("k.predictor.text_encoder.lstms.{lstm_idx}"));
        let lstm_out = self.run_bilstm(&lstm_w, &x, t, width, hidden / 2);

        let fc_w = self.t(&format!(
            "k.predictor.text_encoder.lstms.{norm_idx}.fc.weight"
        ));
        let fc_b = self.t(&format!(
            "k.predictor.text_encoder.lstms.{norm_idx}.fc.bias"
        ));
        let scale_shift = linear(style, 1, style_dim, &fc_w, Some(&fc_b), 2 * hidden);
        let (gamma, beta) = scale_shift.split_at(hidden);

        let normed = layer_norm_plain(&lstm_out, t, hidden, 1e-5);
        for ti in 0..t {
            let (text, tail) = x[ti * width..(ti + 1) * width].split_at_mut(hidden);
            let row = &normed[ti * hidden..(ti + 1) * hidden];
            for (c, out) in text.iter_mut().enumerate() {
                *out = (1.0 + gamma[c]) * row[c] + beta[c];
            }
            tail.copy_from_slice(style);
        }
    }
    x
}
/// Predicts a per-token duration from the encoded features `d`.
///
/// `d` is run through the `k.predictor.lstm` BiLSTM and then the
/// `k.predictor.duration_proj.linear_layer` projection, which yields
/// `cfg.max_dur` logits per token. A token's duration is the sum of the
/// sigmoids of its logits, rounded to the nearest integer and raised to at
/// least 1.
///
/// Returns `(logits, durations)`: the raw projection output and one duration
/// for each of the `t` tokens.
pub fn predict_duration(&self, d: &[f32], t: usize) -> (Vec<f32>, Vec<usize>) {
    let hidden = self.cfg.hidden_dim;
    let in_dim = hidden + self.cfg.style_dim;
    let max_dur = self.cfg.max_dur;

    let lstm_w = self.load_bilstm("k.predictor.lstm");
    let lstm_out = self.run_bilstm(&lstm_w, d, t, in_dim, hidden / 2);

    let proj_w = self.t("k.predictor.duration_proj.linear_layer.weight");
    let proj_b = self.t("k.predictor.duration_proj.linear_layer.bias");
    let logits = linear(&lstm_out, t, hidden, &proj_w, Some(&proj_b), max_dur);

    let pred_dur: Vec<usize> = (0..t)
        .map(|ti| {
            let total: f32 = logits[ti * max_dur..(ti + 1) * max_dur]
                .iter()
                .map(|&logit| sigmoid(logit))
                .sum();
            // `max(1.0)` keeps every token at least one frame long.
            total.round().max(1.0) as usize
        })
        .collect();

    (logits, pred_dur)
}
/// Repeats each token's features according to `dur` and returns them channel-major.
///
/// `feat` is read as `t` rows of `c` values (`feat[ti * c + cc]`). Row `ti` is
/// repeated `dur[ti]` times along the frame axis. The output is laid out as
/// `c` rows of `f` frames (`out[cc * f + frame]`), where `f` is the sum of
/// `dur`.
///
/// Returns `(out, f)`.
///
/// Note: `f` is summed over every entry of `dur`, but only the first `t`
/// entries are expanded, so a longer `dur` leaves trailing zero frames.
///
/// # Panics
/// Panics if `dur` has fewer than `t` entries, or if `feat` is too short for a
/// token whose duration is non-zero.
pub fn expand_by_dur_cm(
    &self,
    feat: &[f32],
    t: usize,
    c: usize,
    dur: &[usize],
) -> (Vec<f32>, usize) {
    let total_frames: usize = dur.iter().sum();
    let mut out = vec![0.0f32; c * total_frames];

    let mut start = 0;
    for ti in 0..t {
        let repeat = dur[ti];
        if repeat == 0 {
            // Zero-length tokens contribute nothing and must not touch `feat`.
            continue;
        }
        let token = &feat[ti * c..(ti + 1) * c];
        for (cc, &value) in token.iter().enumerate() {
            let row = cc * total_frames;
            out[row + start..row + start + repeat].fill(value);
        }
        start += repeat;
    }

    (out, total_frames)
}
/// Runs the shared predictor LSTM and then the `F0` and `N` branches.
///
/// `en` is read channel-major as `(cfg.hidden_dim + cfg.style_dim)` rows of
/// `f` frames. It is transposed to frame-major, passed through the
/// `k.predictor.shared` BiLSTM, and transposed back to `cfg.hidden_dim` rows
/// of `f` frames. Each branch then applies `adain_resblk1d` blocks `.0`, `.1`
/// and `.2`, followed by the `{branch}_proj` weights through `conv1d`.
///
/// Only block `.1` receives `true` for the boolean flag, and it is the block
/// that narrows the channel count to `hidden_dim / 2`. The flag presumably
/// enables upsampling — confirm against `adain_resblk1d`. Every block returns
/// the length that the next stage consumes.
///
/// Returns `(f0, n)`: the `F0` branch output followed by the `N` branch output.
pub fn f0_n(&self, en: &[f32], f: usize, style: &[f32]) -> (Vec<f32>, Vec<f32>) {
    // Transposes a row-major `rows x cols` matrix into `cols x rows`.
    fn transposed(src: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut dst = vec![0.0f32; rows * cols];
        for col in 0..cols {
            for row in 0..rows {
                dst[col * rows + row] = src[row * cols + col];
            }
        }
        dst
    }

    let hid = self.cfg.hidden_dim;
    let cat = hid + self.cfg.style_dim;
    let half = hid / 2;

    // The BiLSTM consumes frame-major input; the conv blocks consume channel-major.
    let frames_major = transposed(en, cat, f);
    let shared_w = self.load_bilstm("k.predictor.shared");
    let shared_out = self.run_bilstm(&shared_w, &frames_major, f, cat, hid / 2);
    let x_cm = transposed(&shared_out, f, hid);

    let branch = |which: &str| -> Vec<f32> {
        let (feat, len) = self.adain_resblk1d(
            &format!("k.predictor.{which}.0"),
            &x_cm,
            hid,
            f,
            hid,
            false,
            style,
        );
        let (feat, len) = self.adain_resblk1d(
            &format!("k.predictor.{which}.1"),
            &feat,
            hid,
            len,
            half,
            true,
            style,
        );
        let (feat, len) = self.adain_resblk1d(
            &format!("k.predictor.{which}.2"),
            &feat,
            half,
            len,
            half,
            false,
            style,
        );
        let proj_w = self.t(&format!("k.predictor.{which}_proj.weight"));
        let proj_b = self.t(&format!("k.predictor.{which}_proj.bias"));
        conv1d(&feat, half, len, &proj_w, Some(&proj_b), 1, 1, 1, 0, 1, 1).0
    };

    (branch("F0"), branch("N"))
}
}