use crate::error::{Error, Result};
use crate::model::audio::kokoro::{AdainResBlk1d, IStftNetGenerator};
use crate::nn::Conv1d;
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, CompareOps, ConvOps, MatmulOps, NormalizationOps, ReduceOps,
ScalarOps, ShapeOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Kokoro / StyleTTS2-style iSTFTNet decoder.
///
/// Conditions ASR-aligned features on convolved F0 and N curves, refines them
/// with style-conditioned residual blocks, and hands the result to an
/// [`IStftNetGenerator`].
pub struct Decoder<R: Runtime> {
    /// 1-D conv over the raw ASR features; its output is concatenated into the
    /// input of every decode block.
    pub asr_res: Conv1d<R>,
    /// 1-D conv over the F0 curve (viewed as `[B, 1, T_f0]`).
    pub f0_conv: Conv1d<R>,
    /// 1-D conv over the N curve (viewed as `[B, 1, T_f0]`).
    pub n_conv: Conv1d<R>,
    /// Style-conditioned block applied to `cat([asr, f0, n], dim = 1)`.
    pub encode: AdainResBlk1d<R>,
    /// Style-conditioned blocks applied in order; [`Decoder::new`] rejects an
    /// empty list.
    pub decode: Vec<AdainResBlk1d<R>>,
    /// Final stage consuming the decoded features and the F0 curve.
    pub generator: IStftNetGenerator<R>,
}

impl<R: Runtime> Decoder<R> {
    /// Assembles a decoder from already-loaded sub-modules.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgument`] if `decode` is empty.
    pub fn new(
        asr_res: Conv1d<R>,
        f0_conv: Conv1d<R>,
        n_conv: Conv1d<R>,
        encode: AdainResBlk1d<R>,
        decode: Vec<AdainResBlk1d<R>>,
        generator: IStftNetGenerator<R>,
    ) -> Result<Self> {
        if decode.is_empty() {
            return Err(Error::InvalidArgument {
                arg: "decode",
                reason: "must have at least one decode block".into(),
            });
        }
        Ok(Self {
            asr_res,
            f0_conv,
            n_conv,
            encode,
            decode,
            generator,
        })
    }

    /// Runs the encode/decode stack followed by the generic generator path.
    ///
    /// * `asr_feats` — `[B, C_asr, T]`; `T` must equal the time length that
    ///   `f0_conv` / `n_conv` produce from `T_f0`, otherwise the channel
    ///   concatenation fails.
    /// * `f0_curve`, `n_curve` — `[B, T_f0]`, identical shapes.
    /// * `style` — style tensor forwarded to every block and to the generator.
    ///
    /// Returns the generator's output pair unchanged.
    ///
    /// # Errors
    ///
    /// [`Error::InvalidArgument`] for malformed input shapes; failures from
    /// the sub-modules or tensor ops are propagated.
    // Lint kept from the original signature; the return is a tuple of generic tensors.
    #[allow(clippy::type_complexity)]
    pub fn forward<C>(
        &self,
        client: &C,
        asr_feats: &Tensor<R>,
        f0_curve: &Tensor<R>,
        n_curve: &Tensor<R>,
        style: &Tensor<R>,
    ) -> Result<(Tensor<R>, Tensor<R>)>
    where
        R: Runtime<DType = DType>,
        C: RuntimeClient<R>
            + ConvOps<R>
            + NormalizationOps<R>
            + ActivationOps<R>
            + TensorOps<R>
            + MatmulOps<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ReduceOps<R>
            + ScalarOps<R>
            + ShapeOps<R>
            + CompareOps<R>
            + TypeConversionOps<R>
            + UtilityOps<R>,
    {
        let (x, f0_for_gen) =
            self.prepare_generator_inputs(client, asr_feats, f0_curve, n_curve, style)?;
        self.generator.forward(client, &x, style, &f0_for_gen)
    }

    /// Everything before the generator, shared by [`Decoder::forward`] and the
    /// CPU-only `forward_cpu_full`.
    ///
    /// Validates input shapes, runs `encode` and every `decode` block, and
    /// returns `(decoded_features, f0_curve viewed as [B, T_f0, 1])`.
    #[allow(clippy::type_complexity)]
    fn prepare_generator_inputs<C>(
        &self,
        client: &C,
        asr_feats: &Tensor<R>,
        f0_curve: &Tensor<R>,
        n_curve: &Tensor<R>,
        style: &Tensor<R>,
    ) -> Result<(Tensor<R>, Tensor<R>)>
    where
        R: Runtime<DType = DType>,
        C: RuntimeClient<R>
            + ConvOps<R>
            + NormalizationOps<R>
            + ActivationOps<R>
            + TensorOps<R>
            + MatmulOps<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ReduceOps<R>
            + ScalarOps<R>
            + ShapeOps<R>
            + CompareOps<R>
            + TypeConversionOps<R>
            + UtilityOps<R>,
    {
        let f0_shape = f0_curve.shape();
        let n_shape = n_curve.shape();
        if f0_shape != n_shape || f0_shape.len() != 2 {
            return Err(Error::InvalidArgument {
                arg: "f0_curve/n_curve",
                reason: format!(
                    "both must be [B, T_f0] with matching shapes, got {f0_shape:?} vs {n_shape:?}"
                ),
            });
        }
        let (b, t_f0) = (f0_shape[0], f0_shape[1]);

        // The channel concat below already requires this; checking up front
        // gives a descriptive error instead of an opaque tensor-op failure.
        let asr_shape = asr_feats.shape();
        if asr_shape.len() != 3 || asr_shape[0] != b {
            return Err(Error::InvalidArgument {
                arg: "asr_feats",
                reason: format!("must be [B, C, T] with B = {b}, got {asr_shape:?}"),
            });
        }

        let f0 = f0_curve.reshape(&[b, 1, t_f0]).map_err(Error::Numr)?;
        let n = n_curve.reshape(&[b, 1, t_f0]).map_err(Error::Numr)?;
        let f0 = self.f0_conv.forward_inference(client, &f0)?;
        let n = self.n_conv.forward_inference(client, &n)?;

        let x = client.cat(&[asr_feats, &f0, &n], 1).map_err(Error::Numr)?;
        let mut x = self.encode.forward(client, &x, style)?;

        let asr_res_proj = self.asr_res.forward_inference(client, asr_feats)?;
        for block in &self.decode {
            // Channel order must match the reference StyleTTS2/Kokoro decoder,
            // whose decode-block weights were trained on
            // cat([x, asr_res, F0, N]) along the channel axis.
            // NOTE(review): the reference stops re-injecting these tensors once
            // an upsampling decode block has run; here every block receives
            // them — confirm against the configured block stack.
            let cat = client
                .cat(&[&x, &asr_res_proj, &f0, &n], 1)
                .map_err(Error::Numr)?;
            x = block.forward(client, &cat, style)?;
        }

        let f0_for_gen = f0_curve.reshape(&[b, t_f0, 1]).map_err(Error::Numr)?;
        Ok((x, f0_for_gen))
    }
}

impl Decoder<numr::runtime::cpu::CpuRuntime> {
    /// CPU-only counterpart of [`Decoder::forward`].
    ///
    /// The encode/decode stage is identical; only the final step differs,
    /// finishing with `IStftNetGenerator::forward_cpu_full` instead of the
    /// generic generator path.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Decoder::forward`].
    #[allow(clippy::type_complexity)]
    pub fn forward_cpu_full(
        &self,
        client: &numr::runtime::cpu::CpuClient,
        asr_feats: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        f0_curve: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        n_curve: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        style: &numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
    ) -> Result<(
        numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
        numr::tensor::Tensor<numr::runtime::cpu::CpuRuntime>,
    )> {
        let (x, f0_for_gen) =
            self.prepare_generator_inputs(client, asr_feats, f0_curve, n_curve, style)?;
        self.generator
            .forward_cpu_full(client, &x, style, &f0_for_gen)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::audio::kokoro::AdaINResBlock1;
    use crate::model::audio::kokoro::{
        KokoroAdaIn1d, MagPhaseHead, SineGen, SourceModuleHnNSF, UpsampleBlock,
    };
    use crate::test_utils::cpu_setup;
    use numr::ops::PaddingMode;
    use numr::runtime::cpu::CpuRuntime;

    /// All-zero f32 tensor of the given shape.
    fn zeros(shape: &[usize], device: &<CpuRuntime as Runtime>::Device) -> Tensor<CpuRuntime> {
        let n: usize = shape.iter().product();
        Tensor::<CpuRuntime>::from_slice(&vec![0.0f32; n], shape, device)
    }

    /// All-one f32 tensor of the given shape.
    fn ones(shape: &[usize], device: &<CpuRuntime as Runtime>::Device) -> Tensor<CpuRuntime> {
        let n: usize = shape.iter().product();
        Tensor::<CpuRuntime>::from_slice(&vec![1.0f32; n], shape, device)
    }

    /// Zero-weight, zero-bias `Conv1d` with `Same` padding.
    fn conv(
        c_out: usize,
        c_in: usize,
        k: usize,
        stride: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> Conv1d<CpuRuntime> {
        Conv1d::new(
            zeros(&[c_out, c_in, k], device),
            Some(zeros(&[c_out], device)),
            stride,
            PaddingMode::Same,
            1,
            1,
            false,
        )
    }

    /// AdaIN layer over `c` channels with a zero style projection.
    fn adain(
        c: usize,
        s: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> KokoroAdaIn1d<CpuRuntime> {
        KokoroAdaIn1d::new(
            zeros(&[2 * c, s], device),
            zeros(&[2 * c], device),
            ones(&[c], device),
            zeros(&[c], device),
            1e-5,
        )
        .unwrap()
    }

    /// `AdainResBlk1d` mapping `c_in` to `c_out` channels, with a 1x1 shortcut
    /// only when the widths differ.
    fn resblk1d(
        c_in: usize,
        c_out: usize,
        s: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> AdainResBlk1d<CpuRuntime> {
        let needs_shortcut = c_in != c_out;
        AdainResBlk1d::new(
            adain(c_in, s, device),
            adain(c_out, s, device),
            conv(c_out, c_in, 3, 1, device),
            conv(c_out, c_out, 3, 1, device),
            if needs_shortcut {
                Some(conv(c_out, c_in, 1, 1, device))
            } else {
                None
            },
            None,
            0.2,
        )
    }

    /// Three-stage `AdaINResBlock1` over `c` channels.
    fn resblock1(
        c: usize,
        s: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> AdaINResBlock1<CpuRuntime> {
        AdaINResBlock1::new(
            [
                conv(c, c, 3, 1, device),
                conv(c, c, 3, 1, device),
                conv(c, c, 3, 1, device),
            ],
            [
                conv(c, c, 3, 1, device),
                conv(c, c, 3, 1, device),
                conv(c, c, 3, 1, device),
            ],
            [
                adain(c, s, device),
                adain(c, s, device),
                adain(c, s, device),
            ],
            [
                adain(c, s, device),
                adain(c, s, device),
                adain(c, s, device),
            ],
            [
                ones(&[1, c, 1], device),
                ones(&[1, c, 1], device),
                ones(&[1, c, 1], device),
            ],
            [
                ones(&[1, c, 1], device),
                ones(&[1, c, 1], device),
                ones(&[1, c, 1], device),
            ],
            1e-9,
        )
        .unwrap()
    }

    /// Minimal generator: one 4 -> 4 upsample stage, `num_kernels` resblocks,
    /// and a mag/phase head for `n_fft = 4` (3 bins for magnitude and phase).
    fn tiny_generator(
        style_dim: usize,
        num_kernels: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> IStftNetGenerator<CpuRuntime> {
        let source = SourceModuleHnNSF::new(
            SineGen::new(24_000.0, 1),
            zeros(&[1, 2], device),
            zeros(&[1], device),
        )
        .unwrap();
        let ups = vec![UpsampleBlock::new(
            zeros(&[4, 4, 1], device),
            None,
            1,
            PaddingMode::Valid,
            0,
            1,
            1,
            0.1,
        )];
        let resblocks: Vec<_> = (0..num_kernels)
            .map(|_| resblock1(4, style_dim, device))
            .collect();
        let mag_phase = MagPhaseHead::new(conv(6, 4, 3, 1, device), 4).unwrap();
        IStftNetGenerator::new(
            source,
            ups,
            resblocks,
            Vec::new(),
            Vec::new(),
            mag_phase,
            crate::model::audio::kokoro::IStftNetGeneratorOpts {
                num_kernels,
                last_stage_reflect_pad: 0,
                ..Default::default()
            },
        )
        .unwrap()
    }

    /// Decoder over 8-channel ASR features. `encode` maps `8 + 1 + 1 = 10`
    /// channels to 8 and `asr_res` projects to 4, so each decode block must
    /// accept `x_channels + 4 + 1 + 1` channels and the last must emit 4 (the
    /// generator's input width).
    fn tiny_decoder(
        style_dim: usize,
        decode: Vec<AdainResBlk1d<CpuRuntime>>,
        num_kernels: usize,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> Result<Decoder<CpuRuntime>> {
        Decoder::new(
            conv(4, 8, 1, 1, device),
            conv(1, 1, 3, 1, device),
            conv(1, 1, 3, 1, device),
            resblk1d(10, 8, style_dim, device),
            decode,
            tiny_generator(style_dim, num_kernels, device),
        )
    }

    /// Single-decode-block decoder (`style_dim = 2`) for validation tests.
    fn validation_decoder(device: &<CpuRuntime as Runtime>::Device) -> Decoder<CpuRuntime> {
        tiny_decoder(2, vec![resblk1d(14, 4, 2, device)], 1, device).unwrap()
    }

    #[test]
    fn forward_returns_mag_phase() {
        let (client, device) = cpu_setup();
        let style_dim = 4;
        let decode = vec![
            resblk1d(14, 8, style_dim, &device),
            resblk1d(14, 4, style_dim, &device),
        ];
        let decoder = tiny_decoder(style_dim, decode, 2, &device).unwrap();
        let t = 5;
        let asr = zeros(&[1, 8, t], &device);
        let f0 = zeros(&[1, t], &device);
        let ne = zeros(&[1, t], &device);
        let style = zeros(&[1, style_dim], &device);
        let (mag, phase) = decoder.forward(&client, &asr, &f0, &ne, &style).unwrap();
        assert_eq!(mag.shape(), &[1, 3, 5]);
        assert_eq!(phase.shape(), &[1, 3, 5]);
    }

    #[test]
    fn new_rejects_empty_decode() {
        let (_client, device) = cpu_setup();
        assert!(tiny_decoder(2, Vec::new(), 1, &device).is_err());
    }

    #[test]
    fn forward_rejects_mismatched_f0_n_shapes() {
        let (client, device) = cpu_setup();
        let decoder = validation_decoder(&device);
        let asr = zeros(&[1, 8, 5], &device);
        let f0 = zeros(&[1, 5], &device);
        let n = zeros(&[1, 7], &device);
        let style = zeros(&[1, 2], &device);
        assert!(decoder.forward(&client, &asr, &f0, &n, &style).is_err());
    }

    #[test]
    fn forward_rejects_non_2d_curves() {
        let (client, device) = cpu_setup();
        let decoder = validation_decoder(&device);
        let asr = zeros(&[1, 8, 5], &device);
        let f0 = zeros(&[1, 1, 5], &device);
        let n = zeros(&[1, 1, 5], &device);
        let style = zeros(&[1, 2], &device);
        assert!(decoder.forward(&client, &asr, &f0, &n, &style).is_err());
    }

    #[test]
    fn forward_cpu_full_rejects_mismatched_f0_n_shapes() {
        let (client, device) = cpu_setup();
        let decoder = validation_decoder(&device);
        let asr = zeros(&[1, 8, 5], &device);
        let f0 = zeros(&[1, 5], &device);
        let n = zeros(&[1, 7], &device);
        let style = zeros(&[1, 2], &device);
        assert!(decoder
            .forward_cpu_full(&client, &asr, &f0, &n, &style)
            .is_err());
    }
}