use candle_core::{Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use crate::layers::conv::{ChannelNorm, Conv1d};
use crate::layers::lstm::Lstm;
/// Builds a scalar tensor holding `value`, placed on the same device and
/// cast to the same dtype as `tensor`, so it can broadcast against it.
fn scalar_like(tensor: &Tensor, value: f32) -> Result<Tensor> {
    let scalar = Tensor::new(value, tensor.device())?;
    scalar.to_dtype(tensor.dtype())
}
/// Leaky ReLU: identity for positive inputs, `negative_slope * x` otherwise.
/// Implemented as `max(x, negative_slope * x)`, which matches leaky ReLU
/// whenever `0 <= negative_slope <= 1` (it is called with 0.2 below).
fn leaky_relu(x: &Tensor, negative_slope: f32) -> Result<Tensor> {
    let slope = scalar_like(x, negative_slope)?;
    let damped = x.broadcast_mul(&slope)?;
    x.maximum(&damped)
}
/// Text encoder: a token embedding followed by a stack of conv+norm blocks
/// and an LSTM. Built from pretrained weights via [`TextEncoder::load`].
pub struct TextEncoder {
// Token-id -> vector lookup table (n_symbols x channels).
embedding: candle_nn::Embedding,
// Conv1d + ChannelNorm stages applied in sequence (see `CnnBlock`).
cnn: Vec<CnnBlock>,
// Recurrent layer applied after the conv stack.
lstm: Lstm,
// Feature width of the encoder output; exposed via `channels()`.
channels: usize,
}
/// One stage of the conv stack: a Conv1d followed by a ChannelNorm.
/// In `forward` the conv runs in channel-major layout and the norm is
/// applied after a transpose (see the loop in `TextEncoder::forward`).
struct CnnBlock {
conv: Conv1d,
norm: ChannelNorm,
}
impl TextEncoder {
    /// Loads the encoder weights from `vb`.
    ///
    /// * `channels` — embedding / feature width of the encoder.
    /// * `kernel_size` — conv kernel width; padding is `kernel_size / 2`
    ///   (i.e. "same" length for stride 1 when the kernel is odd).
    /// * `depth` — number of conv+norm stages.
    /// * `n_symbols` — vocabulary size of the embedding table.
    /// * `_device` — unused; kept for interface compatibility.
    pub fn load(
        channels: usize,
        kernel_size: usize,
        depth: usize,
        n_symbols: usize,
        vb: VarBuilder,
        _device: &Device,
    ) -> Result<Self> {
        let embedding = candle_nn::embedding(n_symbols, channels, vb.pp("embedding"))?;
        // Padding is the same for every stage; compute it once.
        let padding = kernel_size / 2;
        let cnn = (0..depth)
            .map(|idx| {
                let block_vb = vb.pp("cnn").pp(idx.to_string());
                let conv = Conv1d::load(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    padding,
                    1,
                    1,
                    true,
                    block_vb.pp("0"),
                )?;
                let norm = ChannelNorm::load(channels, block_vb.pp("1"))?;
                Ok(CnnBlock { conv, norm })
            })
            .collect::<Result<Vec<_>>>()?;
        // Hidden size is channels / 2 with the boolean flag set — presumably a
        // bidirectional LSTM whose concatenated output is `channels` wide;
        // TODO(review): confirm against `Lstm::load`'s signature.
        let lstm = Lstm::load(1, channels, channels / 2, true, vb.pp("lstm"))?;
        Ok(Self {
            embedding,
            cnn,
            lstm,
            channels,
        })
    }

    /// Encodes `input_ids` into features. `text_mask` selects positions to
    /// zero out (assumed 1 = masked/padded — confirm with the caller);
    /// `_input_lengths` is accepted but unused. The result is produced in
    /// channel-major layout (the embedding is transposed on axes 1/2 before
    /// the conv stack and the LSTM output is transposed back).
    pub fn forward(
        &self,
        input_ids: &Tensor,
        _input_lengths: &Tensor,
        text_mask: &Tensor,
    ) -> Result<Tensor> {
        // Embed tokens, then swap axes 1 and 2 for the conv stack.
        let embedded = self.embedding.forward(input_ids)?.transpose(1, 2)?;
        // keep = 1 - mask, broadcast over the channel axis.
        let mask = text_mask.unsqueeze(1)?.to_dtype(embedded.dtype())?;
        let keep = mask.neg()?.add(&Tensor::ones_like(&mask)?)?;
        let mut hidden = embedded.broadcast_mul(&keep)?;
        for block in &self.cnn {
            hidden = block.conv.forward(&hidden)?;
            // ChannelNorm runs in the transposed layout, then we swap back.
            hidden = block
                .norm
                .forward(&hidden.transpose(1, 2)?)?
                .transpose(1, 2)?;
            hidden = leaky_relu(&hidden, 0.2)?;
            // Re-zero masked positions after each stage.
            hidden = hidden.broadcast_mul(&keep)?;
        }
        // The LSTM consumes the transposed layout; transpose its output back
        // and apply the mask one final time.
        let lstm_out = self.lstm.forward(&hidden.transpose(1, 2)?)?;
        lstm_out.transpose(1, 2)?.broadcast_mul(&keep)
    }

    /// Feature width of the encoder output.
    pub fn channels(&self) -> usize {
        self.channels
    }
}