use crate::reference::kokoro::ops::{layer_norm_plain, linear, softmax};
use std::collections::HashMap;
/// Width of the vector being sampled (`x`); also the size of the network output.
const C: usize = 256;
/// Per-token width of the `emb` conditioning sequence.
const E: usize = 768;
/// Transformer feature width: every token row is `[x (C) | emb row (E)]`.
const F: usize = C + E;
/// Length of the style vector fed to the AdaLN and feature projections.
const S: usize = 256;
/// Number of attention heads.
const HEADS: usize = 8;
/// Width of a single attention head.
const HD: usize = 64;
/// Total projected attention width across all heads.
const MID: usize = HEADS * HD;
/// Number of transformer blocks applied by the network.
const NBLK: usize = 3;
/// Epsilon passed to the layer norm inside AdaLN.
const EPS: f32 = 1e-5;
/// Applies the erf-form GELU, `x * 0.5 * (1 + erf(x / sqrt(2)))`, in place.
///
/// `erf` is evaluated with the Abramowitz–Stegun 7.1.26 rational approximation
/// (absolute error about 1.5e-7 in exact arithmetic) rather than the `tanh` GELU variant.
fn gelu_exact(v: &mut [f32]) {
    // Polynomial coefficients a5, a4, a3, a2, a1 (highest degree first) for Horner's rule.
    const COEFFS: [f32; 5] = [1.061_405_4, -1.453_152, 1.421_413_7, -0.284_496_74, 0.254_829_6];
    const P: f32 = 0.327_591_1;

    v.iter_mut().for_each(|x| {
        let z = *x / std::f32::consts::SQRT_2;
        let t = 1.0 / (1.0 + P * z.abs());
        let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
        let magnitude = 1.0 - poly * t * (-z * z).exp();
        // erf is odd: mirror the positive-side approximation for negative arguments.
        let erf = if z >= 0.0 { magnitude } else { -magnitude };
        *x *= 0.5 * (1.0 + erf);
    });
}
/// Diffusion sampler for style vectors, evaluated directly from a flat weight table.
///
/// The struct only borrows the weights; the remaining fields are scalar settings of the
/// noise schedule and sampler.
pub struct StyleDiffusion<'a> {
    /// Flat tensor table; network weights live under `diffusion.*` keys.
    w: &'a HashMap<String, Vec<f32>>,
    /// Number of schedule sigmas (the sampler performs `steps - 1` updates).
    steps: usize,
    /// Smallest non-zero sigma of the schedule.
    sigma_min: f32,
    /// Largest sigma of the schedule (the initial noise scale).
    sigma_max: f32,
    /// Exponent controlling the spacing of the schedule.
    rho: f32,
    /// Data standard deviation used by the denoiser preconditioning.
    sigma_data: f32,
}
impl<'a> StyleDiffusion<'a> {
pub fn new(w: &'a HashMap<String, Vec<f32>>) -> Self {
let f = |k: &str, d: f32| w.get(k).and_then(|v| v.first().copied()).unwrap_or(d);
Self {
w,
sigma_data: f("diff_sigma_data", 0.2),
sigma_min: f("diff_sigma_min", 1e-4),
sigma_max: f("diff_sigma_max", 3.0),
rho: f("diff_rho", 9.0),
steps: w
.get("diff_steps")
.and_then(|v| v.first().copied())
.unwrap_or(5.0) as usize,
}
}
fn g(&self, name: &str) -> &[f32] {
self.w
.get(&format!("diffusion.{name}"))
.unwrap_or_else(|| panic!("missing diffusion.{name}"))
}
fn ada_ln(&self, h: &[f32], l: usize, s: &[f32], fc: &str) -> Vec<f32> {
let gb = linear(
s,
1,
S,
self.g(&format!("{fc}.weight")),
Some(self.g(&format!("{fc}.bias"))),
2 * F,
);
let (gamma, beta) = (&gb[..F], &gb[F..]);
let ln = layer_norm_plain(h, l, F, EPS);
let mut out = vec![0f32; l * F];
for t in 0..l {
for c in 0..F {
out[t * F + c] = (1.0 + gamma[c]) * ln[t * F + c] + beta[c];
}
}
out
}
fn block(&self, h: &mut [f32], l: usize, s: &[f32], i: usize) {
let p = |s: &str| format!("blocks.{i}.{s}");
let xn = self.ada_ln(h, l, s, &p("attention.norm.fc"));
let cn = self.ada_ln(h, l, s, &p("attention.norm_context.fc"));
let q = linear(&xn, l, F, self.g(&p("attention.to_q.weight")), None, MID);
let kv = linear(
&cn,
l,
F,
self.g(&p("attention.to_kv.weight")),
None,
2 * MID,
);
let scale = (HD as f32).powf(-0.5);
let mut ctx = vec![0f32; l * MID]; for hd in 0..HEADS {
let off = hd * HD;
for i_q in 0..l {
let mut sim = vec![0f32; l];
for j in 0..l {
let mut dot = 0.0;
for d in 0..HD {
dot += q[i_q * MID + off + d] * kv[j * 2 * MID + off + d];
}
sim[j] = dot * scale;
}
softmax(&mut sim);
for d in 0..HD {
let mut acc = 0.0;
for j in 0..l {
acc += sim[j] * kv[j * 2 * MID + MID + off + d]; }
ctx[i_q * MID + off + d] = acc;
}
}
}
let attn = linear(
&ctx,
l,
MID,
self.g(&p("attention.attention.to_out.weight")),
Some(self.g(&p("attention.attention.to_out.bias"))),
F,
);
for k in 0..l * F {
h[k] += attn[k];
}
let mut ff = linear(
h,
l,
F,
self.g(&p("feed_forward.0.weight")),
Some(self.g(&p("feed_forward.0.bias"))),
2 * F,
);
gelu_exact(&mut ff);
let ff = linear(
&ff,
l,
2 * F,
self.g(&p("feed_forward.2.weight")),
Some(self.g(&p("feed_forward.2.bias"))),
F,
);
for k in 0..l * F {
h[k] += ff[k];
}
}
fn net(&self, x: &[f32], time: f32, emb: &[f32], l: usize, s: &[f32]) -> Vec<f32> {
let mut tpos = vec![0f32; 257];
tpos[0] = time;
let tw = self.g("to_time.0.0.weights"); for j in 0..128 {
let f = time * tw[j] * 2.0 * std::f32::consts::PI;
tpos[1 + j] = f.sin();
tpos[1 + 128 + j] = f.cos();
}
let mut t_emb = linear(
&tpos,
1,
257,
self.g("to_time.0.1.weight"),
Some(self.g("to_time.0.1.bias")),
F,
);
gelu_exact(&mut t_emb);
let mut f_emb = linear(
s,
1,
S,
self.g("to_features.0.weight"),
Some(self.g("to_features.0.bias")),
F,
);
gelu_exact(&mut f_emb);
let mut mapping: Vec<f32> = (0..F).map(|k| t_emb[k] + f_emb[k]).collect();
mapping = linear(
&mapping,
1,
F,
self.g("to_mapping.0.weight"),
Some(self.g("to_mapping.0.bias")),
F,
);
gelu_exact(&mut mapping);
mapping = linear(
&mapping,
1,
F,
self.g("to_mapping.2.weight"),
Some(self.g("to_mapping.2.bias")),
F,
);
gelu_exact(&mut mapping);
let mut h = vec![0f32; l * F];
for t in 0..l {
h[t * F..t * F + C].copy_from_slice(&x[..C]);
h[t * F + C..t * F + F].copy_from_slice(&emb[t * E..t * E + E]);
}
for i in 0..NBLK {
for t in 0..l {
for k in 0..F {
h[t * F + k] += mapping[k];
}
}
self.block(&mut h, l, s, i);
}
let mut pooled = vec![0f32; F];
for t in 0..l {
for k in 0..F {
pooled[k] += h[t * F + k];
}
}
for v in pooled.iter_mut() {
*v /= l as f32;
}
linear(
&pooled,
1,
F,
self.g("to_out.1.weight"),
Some(self.g("to_out.1.bias")),
C,
)
}
fn denoise(&self, x: &[f32], sigma: f32, emb: &[f32], l: usize, s: &[f32]) -> Vec<f32> {
let sd = self.sigma_data;
let c_skip = sd * sd / (sigma * sigma + sd * sd);
let c_out = sigma * sd / (sd * sd + sigma * sigma).sqrt();
let c_in = 1.0 / (sigma * sigma + sd * sd).sqrt();
let c_noise = sigma.ln() * 0.25;
let xin: Vec<f32> = x.iter().map(|v| c_in * v).collect();
let pred = self.net(&xin, c_noise, emb, l, s);
(0..C).map(|k| c_skip * x[k] + c_out * pred[k]).collect()
}
fn karras_sigmas(&self) -> Vec<f32> {
let inv = 1.0 / self.rho;
let (a, b) = (self.sigma_max.powf(inv), self.sigma_min.powf(inv));
let mut s: Vec<f32> = (0..self.steps)
.map(|i| (a + (i as f32 / (self.steps - 1) as f32) * (b - a)).powf(self.rho))
.collect();
s.push(0.0);
s
}
pub fn sample(
&self,
noise_init: &[f32],
noises: &[Vec<f32>],
emb: &[f32],
l: usize,
ref_s: &[f32],
) -> Vec<f32> {
let sig = self.karras_sigmas();
let mut x: Vec<f32> = noise_init.iter().map(|v| sig[0] * v).collect();
for i in 0..self.steps - 1 {
let (s, sn) = (sig[i], sig[i + 1]);
let sigma_up = (sn * sn * (s * s - sn * sn) / (s * s)).sqrt();
let sigma_down = (sn * sn - sigma_up * sigma_up).sqrt();
let sigma_mid = (s + sigma_down) * 0.5; let dn = self.denoise(&x, s, emb, l, ref_s);
let d: Vec<f32> = (0..C).map(|k| (x[k] - dn[k]) / s).collect();
let x_mid: Vec<f32> = (0..C).map(|k| x[k] + d[k] * (sigma_mid - s)).collect();
let dn_mid = self.denoise(&x_mid, sigma_mid, emb, l, ref_s);
let d_mid: Vec<f32> = (0..C).map(|k| (x_mid[k] - dn_mid[k]) / sigma_mid).collect();
let nz = &noises[i];
for k in 0..C {
x[k] = x[k] + d_mid[k] * (sigma_down - s) + nz[k] * sigma_up;
}
}
x }
pub fn net_eval(&self, x: &[f32], time: f32, emb: &[f32], l: usize, ref_s: &[f32]) -> Vec<f32> {
self.net(x, time, emb, l, ref_s)
}
}