#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::sync::Arc;
use candle_core::{Result, Tensor};
use candle_nn::{Conv2dConfig, Linear, Module};
use mistralrs_quant::ShardedVarBuilder;
use crate::layers;
use super::config::NemoConvConfig;
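/// NeMo-style convolutional subsampling ("dw_striding"): a stack of stride-2
/// convolutions that downsamples audio features in time by `subsampling_factor`,
/// followed by a linear projection to `feat_out`.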
pub struct NemoConvSubsampling {
conv: Vec<Arc<dyn Module + Send + Sync>>,
out: Linear,
subsampling_factor: usize,
}
impl NemoConvSubsampling {
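/// Builds the `dw_striding` conv stack described by `cfg` and the output
/// projection, loading weights from `vb` (`conv.*` and `out`).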
pub fn new(cfg: &NemoConvConfig, vb: ShardedVarBuilder) -> Result<Self> {
if !cfg.subsampling_factor.is_multiple_of(2) {
candle_core::bail!(
"Subsampling factor must be a multiple of 2, got {}",
cfg.subsampling_factor
);
}
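// Number of stride-2 stages needed to reach the requested factor.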
let sampling_num = (cfg.subsampling_factor as f32).log2() as usize;
let mut in_channels = 1;
let mut layers: Vec<Arc<dyn Module + Send + Sync>> = Vec::new();
let stride = 2;
let kernel_size = 3;
let ceil_mode = false;
if cfg.is_causal {
candle_core::bail!("Causal convolution is not supported");
}
if cfg.subsampling != "dw_striding" {
candle_core::bail!(
"Only `dw_striding` subsampling is supported, got `{}`",
cfg.subsampling
);
}
let left_padding = (kernel_size - 1) / 2;
let right_padding = (kernel_size - 1) / 2;
{
let vb_layers = vb.pp("conv");
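// `idx` mirrors the position in the original PyTorch nn.Sequential, where
// activations also occupy a slot, so the conv weight names (conv.0, conv.2,
// conv.3, ...) line up with the checkpoint.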
let mut idx = 0;
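// Stage 1: standard conv, 1 -> conv_channels, halving time and frequency.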
layers.push(Arc::new(layers::conv2d(
in_channels,
cfg.conv_channels,
kernel_size,
Conv2dConfig {
padding: left_padding,
stride,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
},
vb_layers.pp(idx),
)?));
in_channels = cfg.conv_channels;
idx += 1;
layers.push(Arc::new(cfg.activation));
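// Remaining stages: depthwise stride-2 conv + 1x1 pointwise conv + activation,
// each halving the time and frequency axes.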
for _ in 0..(sampling_num - 1) {
idx += 1;
layers.push(Arc::new(layers::conv2d(
in_channels,
in_channels,
kernel_size,
Conv2dConfig {
padding: left_padding,
stride,
dilation: 1,
groups: in_channels,
cudnn_fwd_algo: None,
},
vb_layers.pp(idx),
)?));
idx += 1;
layers.push(Arc::new(layers::conv2d(
in_channels,
cfg.conv_channels,
1,
Conv2dConfig {
padding: 0,
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
},
vb_layers.pp(idx),
)?));
idx += 1;
layers.push(Arc::new(cfg.activation));
}
}
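// Frequency extent after the conv stack; (conv_channels * out_length) is the
// flattened feature size fed to the output projection.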
let in_length = cfg.feat_in as f32;
let out_length = Self::calc_length(
in_length,
left_padding + right_padding,
kernel_size,
stride,
ceil_mode,
sampling_num,
);
let out = layers::linear_b(
cfg.conv_channels * out_length,
cfg.feat_out,
true,
vb.pp("out"),
)?;
Ok(Self {
conv: layers,
out,
subsampling_factor: cfg.subsampling_factor,
})
}
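/// Applies the standard convolution output-length formula
/// `L' = (L + all_paddings - kernel_size) / stride + 1`
/// (rounded down, or up when `ceil_mode` is set) `repeat_num` times.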
fn calc_length(
mut length: f32,
all_paddings: usize,
kernel_size: usize,
stride: usize,
ceil_mode: bool,
repeat_num: usize,
) -> usize {
let add_pad = all_paddings as f32 - kernel_size as f32;
for _ in 0..repeat_num {
length = (length + add_pad) / (stride as f32) + 1.0;
length = if ceil_mode { length.ceil() } else { length.floor() };
}
length as usize
}
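/// Runs the conv stack and output projection. `x` has shape `(B, T, feat_in)`;
/// the result has shape `(B, T', feat_out)` with `T' ~ T / subsampling_factor`.
/// If a padding `mask` of shape `(B, T)` is given, it is downsampled to match
/// and returned with shape `(B, 1, T')`.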
pub fn forward(&self, x: &Tensor, mask: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
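// (B, T, F) -> (B, 1, T, F): add a channel dimension for Conv2d.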
let mut x = x.unsqueeze(1)?;
for layer in &self.conv {
x = layer.forward(&x)?;
}
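// Collapse (B, C, T', F') -> (B, T', C * F') and project to `feat_out`.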
let (b, c, t, f) = x.dims4()?;
x = x.transpose(1, 2)?.reshape((b, t, c * f))?;
x = x.apply(&self.out)?;
let new_mask = if let Some(mask) = mask {
let max_audio_length = x.dim(1)?;
// Valid (unpadded) frame count per batch element, shape (B, 1).
let feature_lens = mask.sum_keepdim(1)?;
// Downsample the lengths to the subsampled time resolution.
let padding_length =
(feature_lens.to_dtype(candle_core::DType::F32)? / self.subsampling_factor as f64)?
.ceil()?;
let device = x.device();
// Time positions < padding_length are valid. Compare in f32 so the dtypes
// match, and let broadcasting expand (1, T') against (B, 1).
let arange = Tensor::arange(0u32, max_audio_length as u32, device)?
.to_dtype(candle_core::DType::F32)?
.unsqueeze(0)?;
let pad_mask = arange.broadcast_lt(&padding_length)?;
Some(pad_mask.unsqueeze(1)?)
} else {
None
};
Ok((x, new_mask))
}
}