use crate::error::{Error, Result};
use crate::nn::{BiLstm, Conv1d, Embedding};
use numr::dtype::DType;
#[allow(unused_imports)]
use numr::ops::{
ActivationOps, BinaryOps, ConvOps, IndexingOps, MatmulOps, NormalizationOps, PaddingMode,
ReduceOps, ScalarOps, TensorOps, UnaryOps, UtilityOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub struct ConvBlock<R: Runtime> {
conv: Conv1d<R>,
ln_weight: Tensor<R>, ln_bias: Tensor<R>, eps: f32,
leaky_slope: f64,
}
impl<R: Runtime> ConvBlock<R> {
pub fn new(
conv: Conv1d<R>,
ln_weight: Tensor<R>,
ln_bias: Tensor<R>,
eps: f32,
leaky_slope: f64,
) -> Self {
Self {
conv,
ln_weight,
ln_bias,
eps,
leaky_slope,
}
}
pub fn forward<C>(&self, client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + ConvOps<R> + NormalizationOps<R> + ActivationOps<R> + TensorOps<R>,
{
let y = self.conv.forward_inference(client, x)?;
let y_bt_c = y.transpose(1, 2).map_err(Error::Numr)?.contiguous()?;
let y_ln = client
.layer_norm(&y_bt_c, &self.ln_weight, &self.ln_bias, self.eps)
.map_err(Error::Numr)?;
let y_bct = y_ln.transpose(1, 2).map_err(Error::Numr)?.contiguous()?;
client
.leaky_relu(&y_bct, self.leaky_slope)
.map_err(Error::Numr)
}
}
/// Hyper-parameters describing a text encoder.
///
/// Field order is significant for the derived `Debug` output.
#[derive(Debug, Clone, Copy)]
pub struct TextEncoderConfig {
    /// Size of the phoneme vocabulary.
    pub n_symbols: usize,
    /// Model width.
    pub channels: usize,
    /// Convolution kernel size.
    pub kernel_size: usize,
    /// Number of convolutional blocks.
    pub depth: usize,
    /// Layer-norm epsilon.
    pub eps: f32,
    /// Leaky-ReLU negative slope.
    pub leaky_slope: f64,
}

impl Default for TextEncoderConfig {
    /// 178 symbols, 512 channels, 3 blocks of kernel size 5,
    /// `eps = 1e-5`, `leaky_slope = 0.2`.
    fn default() -> Self {
        TextEncoderConfig {
            channels: 512,
            depth: 3,
            kernel_size: 5,
            n_symbols: 178,
            leaky_slope: 0.2,
            eps: 1e-5,
        }
    }
}
pub struct TextEncoder<R: Runtime> {
embedding: Embedding<R>,
conv_blocks: Vec<ConvBlock<R>>,
lstm: BiLstm<R>,
channels: usize,
}
impl<R: Runtime> TextEncoder<R> {
pub fn new(
embedding: Embedding<R>,
conv_blocks: Vec<ConvBlock<R>>,
lstm: BiLstm<R>,
channels: usize,
) -> Result<Self> {
if 2 * lstm.hidden_size() != channels {
return Err(Error::InvalidArgument {
arg: "lstm",
reason: format!(
"BiLSTM total output width must equal text encoder channels ({channels}), \
got 2 * {}",
lstm.hidden_size()
),
});
}
Ok(Self {
embedding,
conv_blocks,
lstm,
channels,
})
}
pub fn channels(&self) -> usize {
self.channels
}
pub fn forward<C>(&self, client: &C, phoneme_ids: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ IndexingOps<R>
+ ConvOps<R>
+ NormalizationOps<R>
+ ActivationOps<R>
+ TensorOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ UtilityOps<R>,
R::Client: IndexingOps<R>,
{
let shape = phoneme_ids.shape();
if shape.len() != 2 {
return Err(Error::InvalidArgument {
arg: "phoneme_ids",
reason: format!("expected [B, T], got {shape:?}"),
});
}
let embedded = self.embedding.forward(client, phoneme_ids)?;
let mut h = embedded
.tensor()
.transpose(1, 2)
.map_err(Error::Numr)?
.contiguous()?;
for block in &self.conv_blocks {
h = block.forward(client, &h)?;
}
let h_btc = h.transpose(1, 2).map_err(Error::Numr)?.contiguous()?;
self.lstm.forward(client, &h_btc)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::{BiLstm, Conv1d, Embedding, Lstm};
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;

    /// Device type of the CPU runtime used throughout these tests.
    type Device = <CpuRuntime as Runtime>::Device;

    /// f32 CPU tensor of `shape` whose elements all equal `value`.
    fn filled(shape: &[usize], value: f32, device: &Device) -> Tensor<CpuRuntime> {
        let numel = shape.iter().product::<usize>();
        Tensor::<CpuRuntime>::from_slice(&vec![value; numel], shape, device)
    }

    fn zeros(shape: &[usize], device: &Device) -> Tensor<CpuRuntime> {
        filled(shape, 0.0, device)
    }

    fn ones(shape: &[usize], device: &Device) -> Tensor<CpuRuntime> {
        filled(shape, 1.0, device)
    }

    /// All-zero unidirectional LSTM with `input` features and `hidden` units.
    fn zero_lstm(input: usize, hidden: usize, device: &Device) -> Lstm<CpuRuntime> {
        let gates = 4 * hidden;
        Lstm::new(
            zeros(&[gates, input], device),
            zeros(&[gates, hidden], device),
            zeros(&[gates], device),
            zeros(&[gates], device),
        )
        .unwrap()
    }

    /// Conv block with zero conv weights/bias and identity layer-norm affine.
    fn zero_conv_block(channels: usize, kernel_size: usize, device: &Device) -> ConvBlock<CpuRuntime> {
        let conv = Conv1d::new(
            zeros(&[channels, channels, kernel_size], device),
            Some(zeros(&[channels], device)),
            1,
            PaddingMode::Same,
            1,
            1,
            false,
        );
        ConvBlock::new(
            conv,
            ones(&[channels], device),
            zeros(&[channels], device),
            1e-5,
            0.2,
        )
    }

    /// Four-channel encoder with two conv blocks and all-zero weights.
    fn build_tiny_encoder(device: &Device) -> TextEncoder<CpuRuntime> {
        const CHANNELS: usize = 4;
        const KERNEL_SIZE: usize = 3;

        let embedding = Embedding::new(zeros(&[8, CHANNELS], device), false);
        let blocks = (0..2)
            .map(|_| zero_conv_block(CHANNELS, KERNEL_SIZE, device))
            .collect();
        let hidden = CHANNELS / 2;
        let bilstm = BiLstm::new(
            zero_lstm(CHANNELS, hidden, device),
            zero_lstm(CHANNELS, hidden, device),
        )
        .unwrap();
        TextEncoder::new(embedding, blocks, bilstm, CHANNELS).unwrap()
    }

    #[test]
    fn forward_shape_is_b_t_c() {
        let (client, device) = cpu_setup();
        let encoder = build_tiny_encoder(&device);
        let ids = Tensor::<CpuRuntime>::from_slice(&[1i64, 2, 3, 4, 5, 6], &[2, 3], &device);
        let out = encoder.forward(&client, &ids).unwrap();
        assert_eq!(out.shape(), &[2, 3, 4]);
    }

    #[test]
    fn zero_weights_yield_finite_output() {
        let (client, device) = cpu_setup();
        let encoder = build_tiny_encoder(&device);
        let ids = Tensor::<CpuRuntime>::from_slice(&[0i64; 5], &[1, 5], &device);
        let values = encoder.forward(&client, &ids).unwrap().to_vec::<f32>();
        for value in values {
            assert!(value.is_finite(), "got non-finite value {value}");
        }
    }

    #[test]
    fn rejects_rank_other_than_2() {
        let (client, device) = cpu_setup();
        let encoder = build_tiny_encoder(&device);
        let ids = Tensor::<CpuRuntime>::from_slice(&[0i64, 1, 2], &[3], &device);
        assert!(encoder.forward(&client, &ids).is_err());
    }

    #[test]
    fn new_rejects_lstm_width_mismatch() {
        let (_client, device) = cpu_setup();
        let embedding = Embedding::new(zeros(&[8, 4], &device), false);
        // 2 * 3 = 6 != 4 channels, so construction must fail.
        let hidden = 3;
        let bilstm = BiLstm::new(
            zero_lstm(4, hidden, &device),
            zero_lstm(4, hidden, &device),
        )
        .unwrap();
        assert!(TextEncoder::new(embedding, Vec::new(), bilstm, 4).is_err());
    }
}