use crate::streaming::{self, StreamMask, StreamTensor, StreamingModule};
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
use crate::conv::{StreamableConv1d, StreamableConvTranspose1d};
/// Hyper-parameters shared by the SEANet encoder and decoder.
#[derive(Debug, Clone)]
pub struct Config {
    /// Size of the latent representation produced by the encoder and
    /// consumed by the decoder.
    pub dimension: usize,
    /// Encoder input / decoder output channel count — presumably the
    /// number of audio channels; confirm against callers.
    pub channels: usize,
    /// Use causal (streaming-friendly) convolutions.
    pub causal: bool,
    /// Base number of filters; the channel count doubles after each
    /// encoder downsampling stage (and halves per decoder stage).
    pub n_filters: usize,
    /// Number of residual blocks per up/down-sampling stage.
    pub n_residual_layers: usize,
    /// Per-stage resampling ratios (iterated in reverse by the encoder).
    pub ratios: Vec<usize>,
    /// Activation applied between convolutions.
    pub activation: candle_nn::Activation,
    /// Normalization applied to convolutions unless disabled by
    /// `disable_norm_outer_blocks`.
    pub norm: crate::conv::Norm,
    /// Kernel size of the initial convolution.
    pub kernel_size: usize,
    /// Kernel size used inside the residual blocks.
    pub residual_kernel_size: usize,
    /// Kernel size of the final convolution.
    pub last_kernel_size: usize,
    /// Base of the exponential dilation schedule: residual block `j`
    /// uses dilation `dilation_base^j`.
    pub dilation_base: usize,
    /// Padding mode for all convolutions.
    pub pad_mode: crate::conv::PadMode,
    /// When true, residual blocks use an identity skip connection
    /// instead of a 1x1 convolution shortcut.
    pub true_skip: bool,
    /// Channel compression factor for the hidden convolutions inside
    /// residual blocks.
    pub compress: usize,
    /// Number of LSTM layers; not supported here (constructors bail
    /// when > 0).
    pub lstm: usize,
    /// Disable normalization for this many outermost blocks (the blocks
    /// closest to the audio side on both encoder and decoder).
    pub disable_norm_outer_blocks: usize,
    /// Optional extra activation applied to the decoder output.
    pub final_activation: Option<candle_nn::Activation>,
}
/// Residual block used by the SEANet encoder/decoder: a stack of
/// convolutions whose output is added back to the (optionally
/// 1x1-projected) input.
#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
    // Main branch; the activation is applied before each conv.
    block: Vec<StreamableConv1d>,
    // 1x1 shortcut projection; `None` when `true_skip` (identity skip).
    shortcut: Option<StreamableConv1d>,
    activation: candle_nn::Activation,
    // Streaming-aware element-wise addition for the skip connection.
    skip_op: streaming::StreamingBinOp,
    span: tracing::Span,
}
impl SeaNetResnetBlock {
    /// Builds a residual block over `dim` channels.
    ///
    /// * `k_sizes_and_dilations` — one `(kernel_size, dilation)` pair per
    ///   convolution of the main branch.
    /// * `compress` — channel reduction factor for the hidden convs
    ///   (`hidden = dim / compress`).
    /// * `true_skip` — identity skip connection when set; otherwise a 1x1
    ///   conv shortcut is loaded from `vb("shortcut")`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        dim: usize,
        k_sizes_and_dilations: &[(usize, usize)],
        activation: candle_nn::Activation,
        norm: Option<crate::conv::Norm>,
        causal: bool,
        pad_mode: crate::conv::PadMode,
        compress: usize,
        true_skip: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
        let hidden = dim / compress;
        let vb_b = vb.pp("block");
        for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
            // First conv maps dim -> hidden, last conv maps back to dim;
            // anything in between stays at hidden width.
            let in_c = if i == 0 { dim } else { hidden };
            let out_c = if i == k_sizes_and_dilations.len() - 1 { dim } else { hidden };
            let c = StreamableConv1d::new(
                in_c,
                out_c,
                *k_size,
                1, // stride
                *dilation,
                1, // groups
                true, // bias
                causal,
                norm,
                pad_mode,
                // Odd indices: in the original checkpoint's module list,
                // even slots hold the (weightless) activations.
                vb_b.pp(2 * i + 1),
            )?;
            block.push(c)
        }
        let shortcut = if true_skip {
            None
        } else {
            // 1x1 convolution projecting the input for the skip path.
            let c = StreamableConv1d::new(
                dim,
                dim,
                1, // kernel size
                1,
                1,
                1,
                true,
                causal,
                norm,
                pad_mode,
                vb.pp("shortcut"),
            )?;
            Some(c)
        };
        Ok(Self {
            block,
            shortcut,
            activation,
            // Streaming addition along the last dimension.
            skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
            span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
        })
    }

    /// Resets the streaming state of stream `batch_idx` (within a batch
    /// of `batch_size`) in every stateful sub-module.
    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        for b in self.block.iter_mut() {
            b.reset_batch_idx(batch_idx, batch_size)?;
        }
        if let Some(shortcut) = self.shortcut.as_mut() {
            shortcut.reset_batch_idx(batch_idx, batch_size)?;
        }
        self.skip_op.reset_batch_idx(batch_idx, batch_size)?;
        Ok(())
    }
}
impl Module for SeaNetResnetBlock {
    /// Runs activation + conv for each layer of the main branch, then
    /// adds the (optionally 1x1-projected) input back in.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let residual = self
            .block
            .iter()
            .try_fold(xs.clone(), |t, conv| t.apply(&self.activation)?.apply(conv))?;
        if let Some(shortcut) = self.shortcut.as_ref() {
            residual + xs.apply(shortcut)
        } else {
            residual + xs
        }
    }
}
impl StreamingModule for SeaNetResnetBlock {
    fn reset_state(&mut self) {
        self.skip_op.reset_state();
        self.block.iter_mut().for_each(|conv| conv.reset_state());
        if let Some(shortcut) = self.shortcut.as_mut() {
            shortcut.reset_state()
        }
    }

    /// Streaming variant of `forward`: one step of the main branch
    /// followed by the streaming skip addition.
    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let mut ys = xs.clone();
        for conv in self.block.iter_mut() {
            let activated = ys.apply(&self.activation)?;
            ys = conv.step(&activated, m)?;
        }
        if let Some(shortcut) = self.shortcut.as_mut() {
            let projected = shortcut.step(xs, m)?;
            self.skip_op.step(&ys, &projected, m)
        } else {
            self.skip_op.step(&ys, xs, m)
        }
    }
}
/// One encoder stage: a stack of residual blocks followed by a strided
/// downsampling convolution.
#[derive(Debug, Clone)]
struct EncoderLayer {
    residuals: Vec<SeaNetResnetBlock>,
    downsample: StreamableConv1d,
}
/// SEANet encoder: initial convolution, a downsampling stage per ratio,
/// and a final convolution projecting to the latent dimension.
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
    init_conv1d: StreamableConv1d,
    activation: candle_nn::Activation,
    layers: Vec<EncoderLayer>,
    final_conv1d: StreamableConv1d,
    span: tracing::Span,
}
impl SeaNetEncoder {
    /// Builds the encoder from `cfg`, loading weights from `vb`.
    ///
    /// Weight paths mirror the flattened module list of the original
    /// checkpoint under `"model"`, where activation layers occupy slots
    /// of their own — hence the `layer_idx` skips below.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // The LSTM variant of SEANet is not implemented.
        if cfg.lstm > 0 {
            candle::bail!("seanet lstm is not supported")
        }
        // Outer-block count for the norm-disabling logic:
        // init conv + one per ratio + final conv.
        let n_blocks = 2 + cfg.ratios.len();
        // Channel multiplier; doubled after each downsampling stage.
        let mut mult = 1usize;
        let init_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) };
        let mut layer_idx = 0;
        let vb = vb.pp("model");
        let init_conv1d = StreamableConv1d::new(
            cfg.channels,
            mult * cfg.n_filters,
            cfg.kernel_size,
            1, // stride
            1, // dilation
            1, // groups
            true, // bias
            cfg.causal,
            init_norm,
            cfg.pad_mode,
            vb.pp(layer_idx),
        )?;
        layer_idx += 1;
        let mut layers = Vec::with_capacity(cfg.ratios.len());
        // Ratios are listed from the decoder's perspective; the encoder
        // applies them in reverse order.
        for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
            let norm = if cfg.disable_norm_outer_blocks >= i + 2 { None } else { Some(cfg.norm) };
            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
            for j in 0..cfg.n_residual_layers {
                let resnet_block = SeaNetResnetBlock::new(
                    mult * cfg.n_filters,
                    // Exponentially growing dilation; the trailing (1, 1)
                    // is the 1x1 squeeze convolution.
                    &[(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1)],
                    cfg.activation,
                    norm,
                    cfg.causal,
                    cfg.pad_mode,
                    cfg.compress,
                    cfg.true_skip,
                    vb.pp(layer_idx),
                )?;
                residuals.push(resnet_block);
                layer_idx += 1;
            }
            let downsample = StreamableConv1d::new(
                mult * cfg.n_filters,
                mult * cfg.n_filters * 2,
                ratio * 2, // kernel size
                ratio,     // stride
                1,
                1,
                true,
                // NOTE(review): the downsampling conv is always causal,
                // regardless of cfg.causal — this appears deliberate;
                // confirm against the reference implementation.
                true,
                norm,
                cfg.pad_mode,
                // +1 skips the activation slot preceding this conv.
                vb.pp(layer_idx + 1),
            )?;
            layer_idx += 2;
            let layer = EncoderLayer { downsample, residuals };
            layers.push(layer);
            mult *= 2
        }
        let final_norm =
            if cfg.disable_norm_outer_blocks >= n_blocks { None } else { Some(cfg.norm) };
        let final_conv1d = StreamableConv1d::new(
            mult * cfg.n_filters,
            cfg.dimension,
            cfg.last_kernel_size,
            1,
            1,
            1,
            true,
            cfg.causal,
            final_norm,
            cfg.pad_mode,
            // +1 again skips the final activation slot.
            vb.pp(layer_idx + 1),
        )?;
        Ok(Self {
            init_conv1d,
            activation: cfg.activation,
            layers,
            final_conv1d,
            span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
        })
    }

    /// Resets the streaming state of stream `batch_idx` (within a batch
    /// of `batch_size`) across all sub-modules.
    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.init_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        self.final_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        for layer in self.layers.iter_mut() {
            layer.downsample.reset_batch_idx(batch_idx, batch_size)?;
            for l in layer.residuals.iter_mut() {
                l.reset_batch_idx(batch_idx, batch_size)?;
            }
        }
        Ok(())
    }
}
impl Module for SeaNetEncoder {
    /// Non-streaming forward pass over a full sequence.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut ys = self.init_conv1d.forward(xs)?;
        for layer in self.layers.iter() {
            ys = layer
                .residuals
                .iter()
                .try_fold(ys, |t, residual| residual.forward(&t))?;
            ys = layer.downsample.forward(&self.activation.forward(&ys)?)?;
        }
        self.final_conv1d.forward(&self.activation.forward(&ys)?)
    }
}
impl StreamingModule for SeaNetEncoder {
    fn reset_state(&mut self) {
        self.init_conv1d.reset_state();
        for layer in self.layers.iter_mut() {
            for residual in layer.residuals.iter_mut() {
                residual.reset_state()
            }
            layer.downsample.reset_state()
        }
        self.final_conv1d.reset_state();
    }

    /// Streaming forward pass: one step through every stage.
    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let mut ys = self.init_conv1d.step(xs, m)?;
        for layer in self.layers.iter_mut() {
            for residual in layer.residuals.iter_mut() {
                ys = residual.step(&ys, m)?;
            }
            let activated = ys.apply(&self.activation)?;
            ys = layer.downsample.step(&activated, m)?;
        }
        let activated = ys.apply(&self.activation)?;
        self.final_conv1d.step(&activated, m)
    }
}
/// One decoder stage: a transposed (upsampling) convolution followed by
/// a stack of residual blocks.
#[derive(Debug, Clone)]
struct DecoderLayer {
    upsample: StreamableConvTranspose1d,
    residuals: Vec<SeaNetResnetBlock>,
}
/// SEANet decoder: initial convolution from the latent dimension, an
/// upsampling stage per ratio, a final convolution back to `channels`,
/// and an optional output activation.
#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
    init_conv1d: StreamableConv1d,
    activation: candle_nn::Activation,
    layers: Vec<DecoderLayer>,
    final_conv1d: StreamableConv1d,
    final_activation: Option<candle_nn::Activation>,
    span: tracing::Span,
}
impl SeaNetDecoder {
    /// Builds the decoder from `cfg`, loading weights from `vb`.
    ///
    /// Weight paths mirror the flattened module list of the original
    /// checkpoint under `"model"` (activations occupy slots of their own,
    /// hence the `layer_idx` skips).
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // The LSTM variant of SEANet is not implemented.
        if cfg.lstm > 0 {
            candle::bail!("seanet lstm is not supported")
        }
        // Outer-block count for the norm-disabling logic.
        let n_blocks = 2 + cfg.ratios.len();
        // Start at the widest multiplier; halved after each stage.
        let mut mult = 1 << cfg.ratios.len();
        // The decoder counts outer blocks from its output side, so the
        // init conv is the *innermost* block here.
        let init_norm =
            if cfg.disable_norm_outer_blocks == n_blocks { None } else { Some(cfg.norm) };
        let mut layer_idx = 0;
        let vb = vb.pp("model");
        let init_conv1d = StreamableConv1d::new(
            cfg.dimension,
            mult * cfg.n_filters,
            cfg.kernel_size,
            1, // stride
            1, // dilation
            1, // groups
            true, // bias
            cfg.causal,
            init_norm,
            cfg.pad_mode,
            vb.pp(layer_idx),
        )?;
        layer_idx += 1;
        let mut layers = Vec::with_capacity(cfg.ratios.len());
        for (i, &ratio) in cfg.ratios.iter().enumerate() {
            // Equivalent to disable_norm_outer_blocks >= n_blocks - (i + 1):
            // stages closer to the output lose their norm first.
            let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
                None
            } else {
                Some(cfg.norm)
            };
            let upsample = StreamableConvTranspose1d::new(
                mult * cfg.n_filters,
                mult * cfg.n_filters / 2,
                ratio * 2, // kernel size
                ratio,     // stride
                1,
                // NOTE(review): the two booleans presumably map to bias and
                // causal (always-causal upsampling) — confirm against
                // StreamableConvTranspose1d::new.
                true,
                true,
                norm,
                // +1 skips the activation slot preceding this conv.
                vb.pp(layer_idx + 1),
            )?;
            layer_idx += 2;
            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
            for j in 0..cfg.n_residual_layers {
                let resnet_block = SeaNetResnetBlock::new(
                    mult * cfg.n_filters / 2,
                    // Exponentially growing dilation; the trailing (1, 1)
                    // is the 1x1 squeeze convolution.
                    &[(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1)],
                    cfg.activation,
                    norm,
                    cfg.causal,
                    cfg.pad_mode,
                    cfg.compress,
                    cfg.true_skip,
                    vb.pp(layer_idx),
                )?;
                residuals.push(resnet_block);
                layer_idx += 1;
            }
            let layer = DecoderLayer { upsample, residuals };
            layers.push(layer);
            mult /= 2
        }
        let final_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) };
        // mult has been halved down to 1 by now, so the input width is
        // exactly cfg.n_filters.
        let final_conv1d = StreamableConv1d::new(
            cfg.n_filters,
            cfg.channels,
            cfg.last_kernel_size,
            1,
            1,
            1,
            true,
            cfg.causal,
            final_norm,
            cfg.pad_mode,
            // +1 again skips the final activation slot.
            vb.pp(layer_idx + 1),
        )?;
        Ok(Self {
            init_conv1d,
            activation: cfg.activation,
            layers,
            final_conv1d,
            final_activation: cfg.final_activation,
            span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
        })
    }

    /// Resets the streaming state of stream `batch_idx` (within a batch
    /// of `batch_size`) across all sub-modules.
    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.init_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        self.final_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        for layer in self.layers.iter_mut() {
            layer.upsample.reset_batch_idx(batch_idx, batch_size)?;
            for l in layer.residuals.iter_mut() {
                l.reset_batch_idx(batch_idx, batch_size)?;
            }
        }
        Ok(())
    }
}
impl Module for SeaNetDecoder {
    /// Non-streaming forward pass over a full latent sequence.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut ys = self.init_conv1d.forward(xs)?;
        for layer in self.layers.iter() {
            ys = layer.upsample.forward(&self.activation.forward(&ys)?)?;
            ys = layer
                .residuals
                .iter()
                .try_fold(ys, |t, residual| residual.forward(&t))?;
        }
        let ys = self.final_conv1d.forward(&self.activation.forward(&ys)?)?;
        match self.final_activation.as_ref() {
            Some(act) => ys.apply(act),
            None => Ok(ys),
        }
    }
}
impl StreamingModule for SeaNetDecoder {
    fn reset_state(&mut self) {
        self.init_conv1d.reset_state();
        for layer in self.layers.iter_mut() {
            for residual in layer.residuals.iter_mut() {
                residual.reset_state()
            }
            layer.upsample.reset_state()
        }
        self.final_conv1d.reset_state();
    }

    /// Streaming forward pass: one step through every stage, then the
    /// optional output activation.
    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let mut ys = self.init_conv1d.step(xs, m)?;
        for layer in self.layers.iter_mut() {
            let activated = ys.apply(&self.activation)?;
            ys = layer.upsample.step(&activated, m)?;
            for residual in layer.residuals.iter_mut() {
                ys = residual.step(&ys, m)?;
            }
        }
        let activated = ys.apply(&self.activation)?;
        let ys = self.final_conv1d.step(&activated, m)?;
        match self.final_activation.as_ref() {
            Some(act) => ys.apply(act),
            None => Ok(ys),
        }
    }
}