use crate::error::{Error, Result};
use crate::model::audio::kokoro::{AdaINResBlock1, MagPhaseHead, SourceModuleHnNSF, UpsampleBlock};
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, CompareOps, ConvOps, MatmulOps, NormalizationOps, ReduceOps,
ScalarOps, ShapeOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// STFT framing parameters carried by the generator.
///
/// `IStftNetGenerator::new` rejects a value where either field is zero, and
/// `forward_cpu_full` passes both fields to `harmonic_excitation_spec_cpu`,
/// which forwards them to the STFT as `StftOptions::n_fft` / `hop_length`.
#[derive(Debug, Clone, Copy)]
pub struct GeneratorStftParams {
/// FFT size; also used as the length of the Hann window built for the STFT.
pub n_fft: usize,
/// Hop length handed to the STFT.
pub hop_length: usize,
}
/// Generator made of a stack of upsample stages, each followed by
/// `num_kernels` AdaIN residual blocks whose outputs are averaged, and a final
/// head (`conv_post`) that returns a pair of tensors.
///
/// The length relationships between the module lists are validated by
/// `IStftNetGenerator::new` only; every field is `pub`, so code that mutates
/// them afterwards is responsible for keeping those relationships intact.
pub struct IStftNetGenerator<R: Runtime> {
/// Source module; `harmonic_excitation_spec_cpu` calls its `forward` on the
/// upsampled f0 to obtain the excitation signal.
pub m_source: SourceModuleHnNSF<R>,
/// One upsample block per stage (must be non-empty).
pub ups: Vec<UpsampleBlock<R>>,
/// Residual blocks, stage-major: stage `s`, kernel `k` lives at index
/// `s * num_kernels + k`. Expected length is `ups.len() * num_kernels`.
pub resblocks: Vec<AdaINResBlock1<R>>,
/// Per-stage convolution applied to the harmonic spectrogram in
/// `forward_cpu_full`. Either empty or one entry per stage.
pub noise_convs: Vec<crate::nn::Conv1d<R>>,
/// Per-stage residual block applied to the `noise_convs` output together with
/// the style tensor. Either empty or one entry per stage.
pub noise_res: Vec<AdaINResBlock1<R>>,
/// Final head; its `forward` result is what the forward passes return.
pub conv_post: MagPhaseHead<R>,
/// Number of residual blocks evaluated and averaged per stage (> 0).
pub num_kernels: usize,
/// Slope passed to every `leaky_relu` call made by the forward passes.
pub leaky_slope: f64,
/// STFT parameters used for the harmonic excitation in `forward_cpu_full`.
pub stft: GeneratorStftParams,
/// Reflection padding applied to both ends of the last stage's trunk in
/// `forward_cpu_full`; `0` disables it. Not used by the generic `forward`.
pub last_stage_reflect_pad: usize,
/// How many times each f0 frame is repeated before it is fed to `m_source`;
/// values below 1 are treated as 1.
pub f0_upsample_factor: usize,
}
/// Scalar configuration accepted by `IStftNetGenerator::new`.
///
/// Each field is copied verbatim into the generator field of the same name.
/// `new` only validates `num_kernels` and `stft` (all must be non-zero); the
/// remaining fields are stored unchecked.
#[derive(Debug, Clone, Copy)]
pub struct IStftNetGeneratorOpts {
/// Residual blocks per upsample stage; must be > 0.
pub num_kernels: usize,
/// Slope used for every `leaky_relu` in the forward passes.
pub leaky_slope: f64,
/// STFT parameters; both `n_fft` and `hop_length` must be > 0.
pub stft: GeneratorStftParams,
/// Reflection padding (applied on both sides) at the last stage of
/// `forward_cpu_full`; `0` disables it.
pub last_stage_reflect_pad: usize,
/// Repetition factor applied to each f0 frame; values below 1 act as 1.
pub f0_upsample_factor: usize,
}
impl Default for IStftNetGeneratorOpts {
    /// Baseline options: 3 residual blocks per stage, leaky slope 0.1, a
    /// 20-point FFT with hop 5, reflection padding of 3 at the last stage and
    /// an f0 repetition factor of 60.
    ///
    /// NOTE(review): `forward_cpu_full` fails when the noise branch is shorter
    /// than the trunk, so `f0_upsample_factor` has to be consistent with the
    /// upsample blocks and the STFT hop — confirm 60 suits the shipped
    /// checkpoints.
    fn default() -> Self {
        let stft = GeneratorStftParams {
            n_fft: 20,
            hop_length: 5,
        };
        Self {
            stft,
            num_kernels: 3,
            leaky_slope: 0.1,
            last_stage_reflect_pad: 3,
            f0_upsample_factor: 60,
        }
    }
}
impl<R: Runtime> IStftNetGenerator<R> {
pub fn new(
m_source: SourceModuleHnNSF<R>,
ups: Vec<UpsampleBlock<R>>,
resblocks: Vec<AdaINResBlock1<R>>,
noise_convs: Vec<crate::nn::Conv1d<R>>,
noise_res: Vec<AdaINResBlock1<R>>,
conv_post: MagPhaseHead<R>,
opts: IStftNetGeneratorOpts,
) -> Result<Self> {
if ups.is_empty() {
return Err(Error::InvalidArgument {
arg: "ups",
reason: "must have at least one upsample stage".into(),
});
}
if opts.num_kernels == 0 {
return Err(Error::InvalidArgument {
arg: "opts.num_kernels",
reason: "must be > 0".into(),
});
}
if resblocks.len() != ups.len() * opts.num_kernels {
return Err(Error::InvalidArgument {
arg: "resblocks",
reason: format!(
"expected {} resblocks (num_upsamples {} * num_kernels {}), got {}",
ups.len() * opts.num_kernels,
ups.len(),
opts.num_kernels,
resblocks.len()
),
});
}
match (noise_convs.len(), noise_res.len()) {
(0, 0) => {}
(a, b) if a == ups.len() && b == ups.len() => {}
(a, b) => {
return Err(Error::InvalidArgument {
arg: "noise_convs / noise_res",
reason: format!(
"must both be empty OR both match num_upsamples ({}); got ({a}, {b})",
ups.len()
),
});
}
}
if opts.stft.n_fft == 0 || opts.stft.hop_length == 0 {
return Err(Error::InvalidArgument {
arg: "opts.stft",
reason: "n_fft and hop_length must be > 0".into(),
});
}
Ok(Self {
m_source,
ups,
resblocks,
noise_convs,
noise_res,
conv_post,
num_kernels: opts.num_kernels,
leaky_slope: opts.leaky_slope,
stft: opts.stft,
last_stage_reflect_pad: opts.last_stage_reflect_pad,
f0_upsample_factor: opts.f0_upsample_factor,
})
}
pub fn num_upsamples(&self) -> usize {
self.ups.len()
}
#[allow(clippy::type_complexity)]
pub fn forward<C>(
&self,
client: &C,
x: &Tensor<R>,
style: &Tensor<R>,
_f0: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ ConvOps<R>
+ NormalizationOps<R>
+ ActivationOps<R>
+ TensorOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ ShapeOps<R>
+ CompareOps<R>
+ TypeConversionOps<R>
+ UtilityOps<R>,
{
let mut x = x.clone();
for stage in 0..self.num_upsamples() {
x = client
.leaky_relu(&x, self.leaky_slope)
.map_err(Error::Numr)?;
x = self.ups[stage].forward(client, &x)?;
let mut xs: Option<Tensor<R>> = None;
for k in 0..self.num_kernels {
let idx = stage * self.num_kernels + k;
let out = self.resblocks[idx].forward(client, &x, style)?;
xs = Some(match xs {
None => out,
Some(prev) => client.add(&prev, &out).map_err(Error::Numr)?,
});
}
let summed = xs.expect("at least one resblock per stage — validated in new()");
x = client
.mul_scalar(&summed, 1.0 / self.num_kernels as f64)
.map_err(Error::Numr)?;
}
let x = client
.leaky_relu(&x, self.leaky_slope)
.map_err(Error::Numr)?;
self.conv_post.forward(client, &x)
}
}
impl IStftNetGenerator<numr::runtime::cpu::CpuRuntime> {
    /// Returns `true` when both `noise_convs` and `noise_res` are non-empty.
    pub fn has_noise_path(&self) -> bool {
        !self.noise_convs.is_empty() && !self.noise_res.is_empty()
    }

    /// Returns the size of axis 2 of `tensor`, or `InvalidArgument` (tagged
    /// with `arg`) when the tensor has fewer than three axes, so callers get
    /// an error instead of an out-of-bounds panic.
    fn axis2_len(
        tensor: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        arg: &'static str,
    ) -> Result<usize> {
        let shape = tensor.shape();
        if shape.len() < 3 {
            return Err(Error::InvalidArgument {
                arg,
                reason: format!("expected a tensor with at least 3 axes, got shape {shape:?}"),
            });
        }
        Ok(shape[2])
    }

    /// Builds the harmonic excitation spectrogram from an f0 track.
    ///
    /// Steps: repeat every f0 frame `f0_upsample_factor` times (values below 1
    /// act as 1), run `m_source` on the result, flatten the excitation to
    /// `[shape[0], shape[1]]`, take a centred STFT with a Hann window of
    /// length `n_fft`, and concatenate the two STFT outputs along axis 1.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgument`] when `f0` is not `[B, T, 1]`, when
    /// `n_fft` or `hop_length` is zero (the same rule `new` applies to
    /// `opts.stft`), or when the excitation returned by `m_source` has fewer
    /// than two axes. Errors from `m_source`, the STFT and the tensor ops are
    /// propagated.
    pub fn harmonic_excitation_spec_cpu(
        &self,
        client: &numr::runtime::cpu::CpuClient,
        f0: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        n_fft: usize,
        hop_length: usize,
    ) -> Result<numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>> {
        let f0_shape = f0.shape();
        if f0_shape.len() != 3 || f0_shape[2] != 1 {
            return Err(Error::InvalidArgument {
                arg: "f0",
                reason: format!("expected [B, T, 1], got {f0_shape:?}"),
            });
        }
        // This method is public and can be called with arbitrary STFT
        // parameters, so enforce the same non-zero rule as the constructor.
        if n_fft == 0 || hop_length == 0 {
            return Err(Error::InvalidArgument {
                arg: "n_fft / hop_length",
                reason: "n_fft and hop_length must be > 0".into(),
            });
        }

        let scale = self.f0_upsample_factor.max(1);
        let (b, t) = (f0_shape[0], f0_shape[1]);
        // Broadcasting the trailing axis to `scale` and flattening it into the
        // time axis repeats each frame `scale` times.
        let f0_audio = if scale == 1 {
            f0.clone()
        } else {
            f0.broadcast_to(&[b, t, scale])
                .map_err(Error::Numr)?
                .contiguous()?
                .reshape(&[b, t * scale, 1])
                .map_err(Error::Numr)?
        };

        let excitation = self.m_source.forward(client, &f0_audio)?;
        let exc_shape = excitation.shape();
        // Guard the indexing below; element-count mismatches are still
        // reported by `reshape` itself.
        if exc_shape.len() < 2 {
            return Err(Error::InvalidArgument {
                arg: "m_source",
                reason: format!("excitation must have at least 2 axes, got shape {exc_shape:?}"),
            });
        }
        let (bb, t_audio) = (exc_shape[0], exc_shape[1]);
        let waveform = excitation.reshape(&[bb, t_audio]).map_err(Error::Numr)?;

        let hann = crate::model::audio::kokoro::hann_window(n_fft, f0.device());
        let (mag, phase) = crate::model::audio::stft::stft(
            &waveform,
            &hann,
            crate::model::audio::stft::StftOptions {
                n_fft,
                hop_length,
                center: true,
            },
        )?;
        use numr::ops::ShapeOps;
        client.cat(&[&mag, &phase], 1).map_err(Error::Numr)
    }

    /// Full CPU forward pass including the noise/harmonic branch.
    ///
    /// Without a noise path (see [`Self::has_noise_path`]) this delegates to
    /// the generic `forward`, which neither reads `f0` nor applies
    /// `last_stage_reflect_pad`.
    ///
    /// Otherwise, per stage: `leaky_relu` on the trunk; `noise_convs[stage]`
    /// then `noise_res[stage]` on the harmonic spectrogram; upsample the
    /// trunk; trim the noise branch to the trunk length along axis 2 and add
    /// it; on the last stage optionally reflection-pad both ends; finally
    /// average the stage's residual blocks. A last `leaky_relu` precedes
    /// `conv_post`.
    ///
    /// NOTE(review): the PyTorch iSTFTNet reference pads before adding the
    /// noise branch and uses the default slope (0.01) for the final
    /// `leaky_relu` — confirm whether parity with it is expected.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgument`] when the trunk or the noise branch
    /// has fewer than three axes, or when the noise branch is shorter than the
    /// trunk. All sub-module and tensor-op errors are propagated.
    ///
    /// # Panics
    ///
    /// Panics if the public fields were changed after construction so that
    /// `num_kernels` is zero or a module list is too short.
    #[allow(clippy::type_complexity)]
    pub fn forward_cpu_full(
        &self,
        client: &numr::runtime::cpu::CpuClient,
        x: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        style: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        f0: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
    ) -> Result<(
        numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
    )> {
        if !self.has_noise_path() {
            return self.forward(client, x, style, f0);
        }

        let har =
            self.harmonic_excitation_spec_cpu(client, f0, self.stft.n_fft, self.stft.hop_length)?;
        let last_stage = self.num_upsamples().saturating_sub(1);
        let mut x = x.clone();

        for stage in 0..self.num_upsamples() {
            x = client
                .leaky_relu(&x, self.leaky_slope)
                .map_err(Error::Numr)?;
            let noise_c = self.noise_convs[stage].forward_inference(client, &har)?;
            let x_source = self.noise_res[stage].forward(client, &noise_c, style)?;
            x = self.ups[stage].forward(client, &x)?;

            // Align the noise branch with the trunk along the time axis.
            let trunk_t = Self::axis2_len(&x, "x")?;
            let source_t = Self::axis2_len(&x_source, "x_source")?;
            let x_source = match source_t.cmp(&trunk_t) {
                std::cmp::Ordering::Greater => x_source
                    .narrow(2, 0, trunk_t)
                    .map_err(Error::Numr)?
                    .contiguous()?,
                std::cmp::Ordering::Less => {
                    return Err(Error::InvalidArgument {
                        arg: "x_source",
                        reason: format!(
                            "noise residual is shorter ({source_t}) than trunk ({trunk_t}); \
                             check f0_upsample_factor vs upsample_ratios config"
                        ),
                    });
                }
                std::cmp::Ordering::Equal => x_source,
            };
            x = client.add(&x, &x_source).map_err(Error::Numr)?;

            if stage == last_stage && self.last_stage_reflect_pad > 0 {
                x = crate::model::audio::reflection_pad::reflection_pad_1d(
                    &x,
                    self.last_stage_reflect_pad,
                    self.last_stage_reflect_pad,
                )?;
            }

            // Average the outputs of this stage's residual blocks.
            let mut xs: Option<numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>> = None;
            for k in 0..self.num_kernels {
                let idx = stage * self.num_kernels + k;
                let out = self.resblocks[idx].forward(client, &x, style)?;
                xs = Some(match xs {
                    None => out,
                    Some(prev) => client.add(&prev, &out).map_err(Error::Numr)?,
                });
            }
            let summed = xs.expect("at least one resblock per stage — validated in new()");
            x = client
                .mul_scalar(&summed, 1.0 / self.num_kernels as f64)
                .map_err(Error::Numr)?;
        }

        let x = client
            .leaky_relu(&x, self.leaky_slope)
            .map_err(Error::Numr)?;
        self.conv_post.forward(client, &x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::audio::kokoro::{KokoroAdaIn1d, PoolParams, SineGen};
    use crate::nn::Conv1d;
    use crate::test_utils::cpu_setup;
    use numr::ops::PaddingMode;
    use numr::runtime::cpu::CpuRuntime;

    /// Device handle of the CPU runtime, shared by every fixture below.
    type CpuDevice = <CpuRuntime as Runtime>::Device;

    /// `f32` tensor of the given shape with every element set to `value`.
    fn filled(value: f32, shape: &[usize], device: &CpuDevice) -> Tensor<CpuRuntime> {
        let len: usize = shape.iter().product();
        Tensor::<CpuRuntime>::from_slice(&vec![value; len], shape, device)
    }

    /// All-zero `f32` tensor.
    fn zeros(shape: &[usize], device: &CpuDevice) -> Tensor<CpuRuntime> {
        filled(0.0, shape, device)
    }

    /// All-one `f32` tensor.
    fn ones(shape: &[usize], device: &CpuDevice) -> Tensor<CpuRuntime> {
        filled(1.0, shape, device)
    }

    /// Zero-weight, zero-bias `Conv1d` with `PaddingMode::Same`.
    fn conv(c_out: usize, c_in: usize, k: usize, device: &CpuDevice) -> Conv1d<CpuRuntime> {
        let weight = zeros(&[c_out, c_in, k], device);
        let bias = Some(zeros(&[c_out], device));
        Conv1d::new(weight, bias, 1, PaddingMode::Same, 1, 1, false)
    }

    /// AdaIN layer for `c` channels and style width `s`, built from constant
    /// tensors.
    fn adain(c: usize, s: usize, device: &CpuDevice) -> KokoroAdaIn1d<CpuRuntime> {
        let built = KokoroAdaIn1d::new(
            zeros(&[2 * c, s], device),
            zeros(&[2 * c], device),
            ones(&[c], device),
            zeros(&[c], device),
            1e-5,
        );
        built.unwrap()
    }

    /// Residual block whose six convs, six AdaIN layers and six `[1, c, 1]`
    /// tensors are all built from the constant fixtures above.
    fn resblock(c: usize, s: usize, device: &CpuDevice) -> AdaINResBlock1<CpuRuntime> {
        let conv3 = || conv(c, c, 3, device);
        let norm = || adain(c, s, device);
        let unit = || ones(&[1, c, 1], device);
        AdaINResBlock1::new(
            [conv3(), conv3(), conv3()],
            [conv3(), conv3(), conv3()],
            [norm(), norm(), norm()],
            [norm(), norm(), norm()],
            [unit(), unit(), unit()],
            [unit(), unit(), unit()],
            1e-9,
        )
        .unwrap()
    }

    /// Source module fixture shared by every test.
    fn tiny_source(device: &CpuDevice) -> SourceModuleHnNSF<CpuRuntime> {
        let sine = SineGen::new(24_000.0, 1);
        SourceModuleHnNSF::new(sine, zeros(&[1, 2], device), zeros(&[1], device)).unwrap()
    }

    /// Head fixture: a 4-input-channel conv with `2 * (n_fft / 2 + 1)` outputs.
    fn tiny_head(n_fft: usize, device: &CpuDevice) -> MagPhaseHead<CpuRuntime> {
        MagPhaseHead::new(conv(2 * (n_fft / 2 + 1), 4, 3, device), n_fft).unwrap()
    }

    /// Two-stage generator (2 kernels per stage, no noise path, no reflection
    /// padding) small enough for shape tests.
    fn build_tiny_generator(device: &CpuDevice) -> IStftNetGenerator<CpuRuntime> {
        let style_dim = 4;
        let n_fft = 4;

        // The two stages differ only in weight shape and third argument.
        let upsample = |weight_shape: &[usize], third| {
            UpsampleBlock::new(
                zeros(weight_shape, device),
                None,
                third,
                PaddingMode::Valid,
                0,
                1,
                1,
                0.1,
            )
        };
        let ups = vec![upsample(&[8, 4, 2], 2), upsample(&[4, 4, 1], 1)];

        let resblocks = vec![
            resblock(4, style_dim, device),
            resblock(4, style_dim, device),
            resblock(4, style_dim, device),
            resblock(4, style_dim, device),
        ];
        let source = tiny_source(device);
        let mag_phase = tiny_head(n_fft, device);
        let opts = IStftNetGeneratorOpts {
            num_kernels: 2,
            last_stage_reflect_pad: 0,
            ..Default::default()
        };

        IStftNetGenerator::new(
            source,
            ups,
            resblocks,
            Vec::new(),
            Vec::new(),
            mag_phase,
            opts,
        )
        .unwrap()
    }

    #[test]
    fn forward_returns_mag_phase_shapes() {
        let (client, device) = cpu_setup();
        let generator = build_tiny_generator(&device);
        let x = zeros(&[1, 8, 3], &device);
        let style = zeros(&[1, 4], &device);
        let f0 = zeros(&[1, 3, 1], &device);

        let (mag, phase) = generator.forward(&client, &x, &style, &f0).unwrap();

        assert_eq!(mag.shape(), &[1, 3, 6]);
        assert_eq!(phase.shape(), &[1, 3, 6]);
    }

    #[test]
    fn new_rejects_wrong_resblock_count() {
        let (_client, device) = cpu_setup();
        let single_stage = vec![UpsampleBlock::new(
            zeros(&[4, 4, 1], &device),
            None,
            1,
            PaddingMode::Valid,
            0,
            1,
            1,
            0.1,
        )];
        // One stage with three kernels needs three blocks; supply only two.
        let two_blocks = vec![resblock(4, 2, &device), resblock(4, 2, &device)];
        let source = tiny_source(&device);
        let mag_phase = tiny_head(4, &device);
        let opts = IStftNetGeneratorOpts {
            num_kernels: 3,
            last_stage_reflect_pad: 0,
            ..Default::default()
        };

        let outcome = IStftNetGenerator::new(
            source,
            single_stage,
            two_blocks,
            Vec::new(),
            Vec::new(),
            mag_phase,
            opts,
        );

        assert!(outcome.is_err());
    }

    #[test]
    fn new_rejects_empty_ups() {
        let (_client, device) = cpu_setup();
        let source = tiny_source(&device);
        let mag_phase = tiny_head(4, &device);
        let opts = IStftNetGeneratorOpts {
            num_kernels: 1,
            last_stage_reflect_pad: 0,
            ..Default::default()
        };

        let outcome = IStftNetGenerator::new(
            source,
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            mag_phase,
            opts,
        );

        assert!(outcome.is_err());
    }

    /// Compile-time check that `PoolParams` can be named from this module.
    #[test]
    fn _pool_params_type_is_in_scope() {
        let _: Option<PoolParams<CpuRuntime>> = None;
    }
}