use candle_core::{Device, Result, Tensor};
use candle_nn::VarBuilder;
use crate::layers::conv::{AdaIn1d, AdaLayerNorm, Conv1d, ConvTranspose1d, LinearNorm};
use crate::layers::lstm::Lstm;
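/// Builds a zero-dimensional tensor holding `value`, matching the device and
/// dtype of `tensor` so it can be used in broadcast arithmetic against it.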
fn scalar_like(tensor: &Tensor, value: f32) -> Result<Tensor> {
Tensor::new(value, tensor.device())?.to_dtype(tensor.dtype())
}
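/// Multiplies every element of `tensor` by the constant `value`.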
fn scale_tensor(tensor: &Tensor, value: f32) -> Result<Tensor> {
tensor.broadcast_mul(&scalar_like(tensor, value)?)
}
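/// Upsamples a (batch, channels, length) tensor to `target_len` along the last
/// dimension. When `target_len` is an exact multiple of `length`, each frame is
/// repeated in place; otherwise the tensor falls back to nearest-neighbour
/// interpolation on the CPU (not every backend implements
/// `upsample_nearest1d`) and is moved back to the original device afterwards.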
fn upsample_1d_repeat(x: &Tensor, target_len: usize) -> Result<Tensor> {
let (batch, channels, length) = x.dims3()?;
if target_len == length {
return Ok(x.clone());
}
let scale = target_len / length;
if scale * length == target_len && scale > 1 {
let x = x.unsqueeze(3)?;
let x = x.repeat(&[1, 1, 1, scale])?;
x.reshape((batch, channels, target_len))
} else {
let device = x.device().clone();
let x_cpu = x.to_device(&Device::Cpu)?;
let upsampled = x_cpu.upsample_nearest1d(target_len)?;
upsampled.to_device(&device)
}
}
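/// Leaky ReLU: `max(x, negative_slope * x)`, valid for slopes in (0, 1).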
fn leaky_relu(x: &Tensor, negative_slope: f32) -> Result<Tensor> {
let scaled = scale_tensor(x, negative_slope)?;
x.maximum(&scaled)
}
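/// Residual block conditioned on a style vector through adaptive instance
/// normalization (AdaIN), optionally doubling the sequence length, in the
/// style of StyleTTS2's `AdainResBlk1d`.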
pub struct AdainResBlk1d {
conv1: Conv1d,
conv2: Conv1d,
norm1: AdaIn1d,
norm2: AdaIn1d,
conv1x1: Option<Conv1d>,
pool: Option<ConvTranspose1d>,
upsample: bool,
}
impl AdainResBlk1d {
pub fn load(
dim_in: usize,
dim_out: usize,
style_dim: usize,
upsample: bool,
vb: VarBuilder,
) -> Result<Self> {
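// Two kernel-size-3 convolutions that preserve the sequence length.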
let conv1 = Conv1d::load(dim_in, dim_out, 3, 1, 1, 1, 1, true, vb.pp("conv1"))?;
let conv2 = Conv1d::load(dim_out, dim_out, 3, 1, 1, 1, 1, true, vb.pp("conv2"))?;
let norm1 = AdaIn1d::load(style_dim, dim_in, vb.pp("norm1"))?;
let norm2 = AdaIn1d::load(style_dim, dim_out, vb.pp("norm2"))?;
let conv1x1 = if dim_in != dim_out {
Some(Conv1d::load(
dim_in,
dim_out,
1,
1,
0,
1,
1,
false,
vb.pp("conv1x1"),
)?)
} else {
None
};
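// Learned transposed convolution that doubles the sequence length when the
// block upsamples.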
let pool = if upsample {
Some(ConvTranspose1d::load(
dim_in,
dim_in,
3,
2,
1,
1,
dim_in,
true,
vb.pp("pool"),
)?)
} else {
None
};
Ok(Self {
conv1,
conv2,
norm1,
norm2,
conv1x1,
pool,
upsample,
})
}
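/// Identity branch: optional 2x nearest upsample, then a 1x1 convolution when
/// the channel count changes.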
fn shortcut(&self, x: &Tensor) -> Result<Tensor> {
let mut out = x.clone();
if self.upsample {
let (_b, _c, length) = out.dims3()?;
out = upsample_1d_repeat(&out, length * 2)?;
}
if let Some(ref sc) = self.conv1x1 {
out = sc.forward(&out)?;
}
Ok(out)
}
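/// Main branch: AdaIN -> LeakyReLU -> (optional transposed-conv upsample) ->
/// conv -> AdaIN -> LeakyReLU -> conv.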
fn residual(&self, x: &Tensor, s: &Tensor) -> Result<Tensor> {
let mut out = self.norm1.forward(x, s)?;
out = leaky_relu(&out, 0.2)?;
if let Some(ref pool) = self.pool {
out = pool.forward(&out)?;
}
out = self.conv1.forward(&out)?;
out = self.norm2.forward(&out, s)?;
out = leaky_relu(&out, 0.2)?;
self.conv2.forward(&out)
}
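/// Adds the two branches and rescales by 1/sqrt(2) to keep the output variance
/// comparable to the input's.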
pub fn forward(&self, x: &Tensor, s: &Tensor) -> Result<Tensor> {
let residual = self.shortcut(x)?;
let out = self.residual(x, s)?;
let combined = out.add(&residual)?;
scale_tensor(&combined, std::f32::consts::FRAC_1_SQRT_2)
}
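/// Reports the upsampling mode used by this block: "nearest" or "none".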
pub fn upsample_type(&self) -> &str {
if self.upsample {
"nearest"
} else {
"none"
}
}
}
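/// Stack of bidirectional LSTMs interleaved with style-conditioned layer norms
/// (`AdaLayerNorm`), used to encode text features before duration and prosody
/// prediction.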
pub struct DurationEncoder {
lstms: Vec<Lstm>,
ada_norms: Vec<AdaLayerNorm>,
sty_dim: usize,
}
impl DurationEncoder {
pub fn load(
sty_dim: usize,
d_model: usize,
nlayers: usize,
vb: VarBuilder,
_device: &Device,
) -> Result<Self> {
let mut lstms = Vec::with_capacity(nlayers);
let mut ada_norms = Vec::with_capacity(nlayers);
for i in 0..nlayers {
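// The checkpoint stores LSTMs and AdaLayerNorms interleaved in a single
// "lstms" module list: even indices are LSTMs, odd indices are norms. Each
// LSTM is bidirectional, so a hidden size of d_model / 2 yields d_model
// output features.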
let lstm = Lstm::load(
1,
d_model + sty_dim,
d_model / 2,
true,
vb.pp("lstms").pp((i * 2).to_string()),
)?;
lstms.push(lstm);
let norm =
AdaLayerNorm::load(sty_dim, d_model, vb.pp("lstms").pp((i * 2 + 1).to_string()))?;
ada_norms.push(norm);
}
Ok(Self {
lstms,
ada_norms,
sty_dim,
})
}
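/// Encodes `x` of shape (batch, channels, seq_len), concatenating the style
/// vector to the features at every step and zeroing the padded positions
/// flagged by `mask`. Returns a batch-first (batch, seq_len, features) tensor.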
pub fn forward(
&self,
x: &Tensor,
style: &Tensor,
_text_lengths: &Tensor,
mask: &Tensor,
) -> Result<Tensor> {
let (batch, _channels, seq_len) = x.dims3()?;
let mut x = x.permute((2, 0, 1))?;
let s = style
.unsqueeze(0)?
.broadcast_as((seq_len, batch, self.sty_dim))?;
x = Tensor::cat(&[&x, &s], 2)?;
// Zero out padded positions by multiplying with (1 - mask).
let mask_f = mask.to_dtype(x.dtype())?.transpose(0, 1)?.unsqueeze(2)?;
let inv_mask = mask_f.affine(-1.0, 1.0)?;
x = x.broadcast_mul(&inv_mask)?;
x = x.transpose(0, 1)?;
// The broadcast style tensor is loop-invariant, so build it once.
let s_batch = style
.unsqueeze(1)?
.broadcast_as((batch, seq_len, self.sty_dim))?;
for (lstm, norm) in self.lstms.iter().zip(self.ada_norms.iter()) {
x = lstm.forward(&x)?;
x = norm.forward(&x, style)?;
x = Tensor::cat(&[&x, &s_batch], 2)?;
// Re-apply the padding mask after each layer.
let batch_mask = mask.unsqueeze(2)?.to_dtype(x.dtype())?;
let inv = batch_mask.affine(-1.0, 1.0)?;
x = x.broadcast_mul(&inv)?;
}
Ok(x)
}
}
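/// Predicts phoneme durations plus frame-level F0 and N contours from text
/// features and a style vector, following the StyleTTS2 prosody-predictor
/// layout: a shared LSTM feeding two `AdainResBlk1d` stacks, each with a 1x1
/// output projection.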
pub struct ProsodyPredictor {
pub text_encoder: DurationEncoder,
pub lstm: Lstm,
pub duration_proj: LinearNorm,
pub shared: Lstm,
pub f0_blocks: Vec<AdainResBlk1d>,
pub f0_proj: Conv1d,
pub n_blocks: Vec<AdainResBlk1d>,
pub n_proj: Conv1d,
}
impl ProsodyPredictor {
pub fn load(
style_dim: usize,
d_hid: usize,
nlayers: usize,
max_dur: usize,
vb: VarBuilder,
device: &Device,
) -> Result<Self> {
let text_encoder =
DurationEncoder::load(style_dim, d_hid, nlayers, vb.pp("text_encoder"), device)?;
let lstm = Lstm::load(1, d_hid + style_dim, d_hid / 2, true, vb.pp("lstm"))?;
let duration_proj =
LinearNorm::load(d_hid, max_dur, vb.pp("duration_proj").pp("linear_layer"))?;
let shared = Lstm::load(1, d_hid + style_dim, d_hid / 2, true, vb.pp("shared"))?;
let f0_0 = AdainResBlk1d::load(d_hid, d_hid, style_dim, false, vb.pp("F0").pp("0"))?;
let f0_1 =
AdainResBlk1d::load(d_hid, d_hid / 2, style_dim, true, vb.pp("F0").pp("1"))?;
let f0_2 =
AdainResBlk1d::load(d_hid / 2, d_hid / 2, style_dim, false, vb.pp("F0").pp("2"))?;
let f0_proj = Conv1d::load(d_hid / 2, 1, 1, 1, 0, 1, 1, true, vb.pp("F0_proj"))?;
let n_0 = AdainResBlk1d::load(d_hid, d_hid, style_dim, false, vb.pp("N").pp("0"))?;
let n_1 =
AdainResBlk1d::load(d_hid, d_hid / 2, style_dim, true, vb.pp("N").pp("1"))?;
let n_2 = AdainResBlk1d::load(d_hid / 2, d_hid / 2, style_dim, false, vb.pp("N").pp("2"))?;
let n_proj = Conv1d::load(d_hid / 2, 1, 1, 1, 0, 1, 1, true, vb.pp("N_proj"))?;
Ok(Self {
text_encoder,
lstm,
duration_proj,
shared,
f0_blocks: vec![f0_0, f0_1, f0_2],
f0_proj,
n_blocks: vec![n_0, n_1, n_2],
n_proj,
})
}
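/// Runs the duration head: an LSTM over the encoded features, a projection to
/// `max_dur` bins, a sigmoid, then a sum over the bins to obtain one expected
/// duration per step.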
pub fn predict_duration(&self, d: &Tensor, _s: &Tensor) -> Result<Tensor> {
let x = self.lstm.forward(d)?;
let dur = self.duration_proj.forward(&x)?;
let dur = candle_nn::ops::sigmoid(&dur)?;
dur.sum(2)
}
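/// Runs the shared LSTM once, then the F0 and N convolutional stacks on its
/// output, returning one (batch, frames) contour for each.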
pub fn f0_n_predict(&self, x: &Tensor, s: &Tensor) -> Result<(Tensor, Tensor)> {
let x_t = x.transpose(1, 2)?;
let shared_out = self.shared.forward(&x_t)?;
let mut f0 = shared_out.transpose(1, 2)?;
for block in &self.f0_blocks {
f0 = block.forward(&f0, s)?;
}
let f0 = self.f0_proj.forward(&f0)?.squeeze(1)?;
let mut n = shared_out.transpose(1, 2)?;
for block in &self.n_blocks {
n = block.forward(&n, s)?;
}
let n = self.n_proj.forward(&n)?.squeeze(1)?;
Ok((f0, n))
}
}
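// A minimal smoke-test sketch showing how the predictor is wired together. It
// uses candle's `VarBuilder::zeros`, which hands back zero-filled tensors for
// any requested name and shape, so no checkpoint is needed. The dimensions
// below are illustrative, not taken from a real model, and the expected output
// shape assumes the crate's `Lstm::forward` is batch-first, as the rest of
// this file does.
#[cfg(test)]
mod tests {
use super::*;
use candle_core::DType;

#[test]
fn prosody_predictor_loads_and_predicts_durations() -> Result<()> {
let device = Device::Cpu;
// Zero-initialized weights are enough to exercise shapes and wiring.
let vb = VarBuilder::zeros(DType::F32, &device);
let (style_dim, d_hid, nlayers, max_dur) = (128, 512, 3, 50);
let predictor = ProsodyPredictor::load(style_dim, d_hid, nlayers, max_dur, vb, &device)?;
// `predict_duration` expects d_hid + style_dim features per step.
let d = Tensor::zeros((1, 16, d_hid + style_dim), DType::F32, &device)?;
let s = Tensor::zeros((1, style_dim), DType::F32, &device)?;
let dur = predictor.predict_duration(&d, &s)?;
assert_eq!(dur.dims(), &[1, 16]);
Ok(())
}
}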