#![allow(dead_code)]
use crate::reference::kokoro::ops::leaky_relu;
pub mod acoustic;
pub mod decoder;
pub mod diffusion;
pub mod gpu;
pub mod mel;
pub mod model;
pub mod style_encoder;
pub use acoustic::StyleTtsAcoustic;
pub use mel::MelFrontend;
pub use model::StyleTtsModel;
pub use style_encoder::StyleEncoder;
/// Dense `f32` feature map stored in a single flat buffer.
///
/// Element `(ch, y, x)` is read by the routines in this module at
/// `data[ch * h * w + y * w + x]`: channels are laid out one after another and
/// each channel plane is row-major.
#[derive(Clone, Debug, PartialEq)]
pub struct Map {
    /// Flat element buffer; expected to hold exactly `c * h * w` values.
    pub data: Vec<f32>,
    /// Number of channels.
    pub c: usize,
    /// Height of each channel plane (rows).
    pub h: usize,
    /// Width of each channel plane (columns).
    pub w: usize,
}

impl Map {
    /// Wraps `data` as a `c x h x w` map.
    ///
    /// The buffer length is checked with `debug_assert_eq!`, so a mismatch is
    /// only caught in builds that have debug assertions enabled.
    pub fn new(data: Vec<f32>, c: usize, h: usize, w: usize) -> Self {
        debug_assert_eq!(data.len(), c * h * w);
        Self { data, c, h, w }
    }
}
/// Direct 2-D convolution (cross-correlation) over a [`Map`].
///
/// * `x` - input map with `x.c` channels of `x.h` x `x.w` values.
/// * `w` - kernel weights, indexed as
///   `w[((oc * (x.c / groups) + icj) * kh + ky) * kw + kx]`.
/// * `b` - optional per-output-channel bias, indexed by output channel.
/// * `out_c` - number of output channels.
/// * `kh`, `kw` - kernel height and width.
/// * `stride`, `pad` - applied identically to both spatial axes; taps that
///   fall into the padding border contribute nothing (zero padding).
/// * `groups` - output channel `oc` only reads the `x.c / groups` input
///   channels of group `oc / (out_c / groups)`.
///
/// Returns a map of shape `out_c x ho x wo` with
/// `ho = (x.h + 2 * pad - kh) / stride + 1` and likewise for `wo`.
///
/// # Panics
///
/// Panics if `stride` or `groups` is zero, if `groups` does not divide both
/// the input and the output channel count, or if the kernel is larger than
/// the padded input. Without these checks the size arithmetic would wrap in
/// release builds or channels would be silently misread. Indexing also
/// panics if `w` or `b` is too short for the taps that are actually read.
pub fn conv2d(
    x: &Map,
    w: &[f32],
    b: Option<&[f32]>,
    out_c: usize,
    kh: usize,
    kw: usize,
    stride: usize,
    pad: usize,
    groups: usize,
) -> Map {
    let (ic, h, win) = (x.c, x.h, x.w);

    assert!(stride > 0, "conv2d: stride must be non-zero");
    assert!(groups > 0, "conv2d: groups must be non-zero");
    let icpg = ic / groups;
    let ocpg = out_c / groups;
    assert!(
        icpg * groups == ic && ocpg * groups == out_c,
        "conv2d: groups ({groups}) must divide in_channels ({ic}) and out_channels ({out_c})"
    );
    assert!(
        kh <= h + 2 * pad && kw <= win + 2 * pad,
        "conv2d: {}x{} kernel does not fit the padded {}x{} input",
        kh,
        kw,
        h + 2 * pad,
        win + 2 * pad
    );

    let ho = (h + 2 * pad - kh) / stride + 1;
    let wo = (win + 2 * pad - kw) / stride + 1;
    let in_plane = h * win;
    let out_plane = ho * wo;
    let taps = kh * kw;

    let mut out = vec![0f32; out_c * out_plane];
    for oc in 0..out_c {
        // First input channel of the group this output channel belongs to.
        let group_first = (oc / ocpg) * icpg;
        let bias = b.map_or(0.0, |bb| bb[oc]);
        for oy in 0..ho {
            for ox in 0..wo {
                let mut acc = bias;
                for icj in 0..icpg {
                    let wbase = (oc * icpg + icj) * taps;
                    let xbase = (group_first + icj) * in_plane;
                    for ky in 0..kh {
                        // Coordinates are in padded space; skip taps that lie
                        // in the top/bottom border.
                        let iy = oy * stride + ky;
                        if iy < pad || iy >= h + pad {
                            continue;
                        }
                        let row = xbase + (iy - pad) * win;
                        for kx in 0..kw {
                            let ix = ox * stride + kx;
                            if ix < pad || ix >= win + pad {
                                continue;
                            }
                            acc += x.data[row + (ix - pad)] * w[wbase + ky * kw + kx];
                        }
                    }
                }
                out[oc * out_plane + oy * wo + ox] = acc;
            }
        }
    }
    Map::new(out, out_c, ho, wo)
}
/// 2x2 average pooling with stride 2.
///
/// An odd width is handled as if the last column were repeated once, so the
/// output width is `ceil(x.w / 2)`. An odd height is not padded: the output
/// height is `x.h / 2` and the final row is ignored.
///
/// The input is read in place. The right-hand column index is clamped to the
/// last real column, which yields the same values as pooling a copy with the
/// last column duplicated, without allocating that copy.
pub fn avg_pool2d_half(x: &Map) -> Map {
    let (h, w) = (x.h, x.w);
    let ho = h / 2;
    let wo = w / 2 + w % 2;
    // Only used when `wo > 0`, which implies `w >= 1`.
    let last_col = w.saturating_sub(1);

    let mut out = Vec::with_capacity(x.c * ho * wo);
    for c in 0..x.c {
        let base = c * h * w;
        for oy in 0..ho {
            let top = base + 2 * oy * w;
            let bottom = top + w;
            for ox in 0..wo {
                let left = 2 * ox;
                let right = (left + 1).min(last_col);
                // Same left-to-right summation order as a 2x2 window scan.
                let s = x.data[top + left]
                    + x.data[top + right]
                    + x.data[bottom + left]
                    + x.data[bottom + right];
                out.push(s * 0.25);
            }
        }
    }
    Map::new(out, x.c, ho, wo)
}
/// Global average pooling: reduces every channel plane to its mean.
///
/// Returns one value per channel, in channel order. Each mean is the sum of
/// the plane's `x.h * x.w` values divided by that count.
pub fn adaptive_avg_pool2d_1(x: &Map) -> Vec<f32> {
    let plane = x.h * x.w;
    let count = plane as f32;
    let mut means = Vec::with_capacity(x.c);
    for ch in 0..x.c {
        let start = ch * plane;
        let total = x.data[start..start + plane].iter().sum::<f32>();
        means.push(total / count);
    }
    means
}
/// Slope value passed to `leaky_relu` by `ResBlk::forward`.
// NOTE(review): presumably the negative-side slope; confirm against `ops::leaky_relu`.
const LRELU: f32 = 0.2;
/// Parameters of one 2-D convolution layer, bundled so the layer can be
/// applied with a single call to [`Conv::apply`].
pub struct Conv {
    /// Flat kernel weights, passed to `conv2d` as its `w` argument.
    pub w: Vec<f32>,
    /// Optional bias vector, passed to `conv2d` as its `b` argument.
    pub b: Option<Vec<f32>>,
    /// Number of output channels.
    pub oc: usize,
    /// Kernel height.
    pub kh: usize,
    /// Kernel width.
    pub kw: usize,
    /// Stride passed to `conv2d`.
    pub stride: usize,
    /// Padding passed to `conv2d`.
    pub pad: usize,
    /// Group count passed to `conv2d`.
    pub groups: usize,
}

impl Conv {
    /// Runs `conv2d` on `x` with this layer's stored parameters and returns
    /// the result.
    pub fn apply(&self, x: &Map) -> Map {
        let Self {
            w,
            b,
            oc,
            kh,
            kw,
            stride,
            pad,
            groups,
        } = self;
        conv2d(x, w, b.as_deref(), *oc, *kh, *kw, *stride, *pad, *groups)
    }
}
/// Residual block whose shortcut branch is average-pooled by two.
///
/// See [`ResBlk::forward`] for the exact order of operations.
pub struct ResBlk {
    /// First convolution of the residual branch.
    pub conv1: Conv,
    /// Convolution applied between `conv1` and the second activation.
    pub down: Conv,
    /// Last convolution of the residual branch.
    pub conv2: Conv,
    /// Optional convolution applied to the input on the shortcut branch.
    pub sc: Option<Conv>,
}

impl ResBlk {
    /// Runs the block on `x`.
    ///
    /// * Shortcut branch: `sc` (when present), then `avg_pool2d_half`.
    /// * Residual branch: `leaky_relu` -> `conv1` -> `down` -> `leaky_relu`
    ///   -> `conv2`.
    ///
    /// The two branches are added element-wise and scaled by `1 / sqrt(2)`.
    /// The result takes its shape from the residual branch. Both branches
    /// must produce the same number of elements; this is only checked when
    /// debug assertions are enabled.
    pub fn forward(&self, x: &Map) -> Map {
        // Without a shortcut convolution the input is pooled by reference;
        // no copy of `x` is needed for this branch.
        let shortcut = match &self.sc {
            Some(c) => avg_pool2d_half(&c.apply(x)),
            None => avg_pool2d_half(x),
        };

        // `leaky_relu` works in place, so the residual branch needs its own
        // copy of the input.
        let mut r = x.clone();
        leaky_relu(&mut r.data, LRELU);
        let r = self.conv1.apply(&r);
        let mut r = self.down.apply(&r);
        leaky_relu(&mut r.data, LRELU);
        let r = self.conv2.apply(&r);

        debug_assert_eq!(shortcut.data.len(), r.data.len());
        let inv = 1.0 / std::f32::consts::SQRT_2;
        let data: Vec<f32> = shortcut
            .data
            .iter()
            .zip(&r.data)
            .map(|(a, b)| (a + b) * inv)
            .collect();
        Map::new(data, r.c, r.h, r.w)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Odd width: the last column is counted twice in the final window.
    #[test]
    fn avg_pool_odd_width_repeats_last_col() {
        let x = Map::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 1, 2, 3);
        let p = avg_pool2d_half(&x);
        assert_eq!((p.c, p.h, p.w), (1, 1, 2));
        assert!((p.data[0] - 3.0).abs() < 1e-6);
        assert!((p.data[1] - 4.5).abs() < 1e-6);
    }

    /// Even width, odd height: the unpaired last row is ignored.
    #[test]
    fn avg_pool_even_width_drops_last_row() {
        let x = Map::new(vec![1.0, 2.0, 3.0, 4.0, 50.0, 60.0], 1, 3, 2);
        let p = avg_pool2d_half(&x);
        assert_eq!((p.c, p.h, p.w), (1, 1, 1));
        assert_eq!(p.data, vec![2.5]);
    }

    /// A 1x1 kernel with weight 2 doubles every element.
    #[test]
    fn conv2d_identity_1x1() {
        let x = Map::new(vec![1.0, 2.0, 3.0, 4.0], 1, 2, 2);
        let w = vec![2.0];
        let y = conv2d(&x, &w, None, 1, 1, 1, 1, 0, 1);
        assert_eq!(y.data, vec![2.0, 4.0, 6.0, 8.0]);
    }

    /// A 3x3 all-ones kernel with pad 1 on a 2x2 input sees the whole input
    /// at every position; border taps contribute zero and the bias is added.
    #[test]
    fn conv2d_zero_padding_and_bias() {
        let x = Map::new(vec![1.0, 2.0, 3.0, 4.0], 1, 2, 2);
        let w = vec![1.0; 9];
        let b = vec![0.5];
        let y = conv2d(&x, &w, Some(b.as_slice()), 1, 3, 3, 1, 1, 1);
        assert_eq!((y.c, y.h, y.w), (1, 2, 2));
        assert_eq!(y.data, vec![10.5, 10.5, 10.5, 10.5]);
    }

    /// Stride 2 with a 2x2 all-ones kernel sums non-overlapping windows.
    #[test]
    fn conv2d_stride_two() {
        let x = Map::new((0..16).map(|v| v as f32).collect(), 1, 4, 4);
        let w = vec![1.0; 4];
        let y = conv2d(&x, &w, None, 1, 2, 2, 2, 0, 1);
        assert_eq!((y.c, y.h, y.w), (1, 2, 2));
        assert_eq!(y.data, vec![10.0, 18.0, 42.0, 50.0]);
    }

    /// With `groups == channels` each output channel reads only its own
    /// input channel.
    #[test]
    fn conv2d_groups_keep_channels_separate() {
        let x = Map::new(vec![1.0, 2.0, 10.0, 20.0], 2, 1, 2);
        let w = vec![2.0, 3.0];
        let y = conv2d(&x, &w, None, 2, 1, 1, 1, 0, 2);
        assert_eq!(y.data, vec![2.0, 4.0, 30.0, 60.0]);
    }

    /// Global pooling returns one mean per channel.
    #[test]
    fn adaptive_avg_pool_is_per_channel_mean() {
        let x = Map::new(vec![1.0, 3.0, 10.0, 30.0], 2, 1, 2);
        assert_eq!(adaptive_avg_pool2d_1(&x), vec![2.0, 20.0]);
    }
}