1use crate::streaming::{self, StreamMask, StreamTensor, StreamingModule};
6use candle::{Module, Result, Tensor};
7use candle_nn::VarBuilder;
8
9use crate::conv::{StreamableConv1d, StreamableConvTranspose1d};
10
/// Hyper-parameters shared by `SeaNetEncoder` and `SeaNetDecoder`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Latent channel count: output channels of the encoder's final conv and input channels of
    /// the decoder's initial conv.
    pub dimension: usize,
    /// Waveform-side channel count: input channels of the encoder's initial conv and output
    /// channels of the decoder's final conv.
    pub channels: usize,
    /// Causal flag passed to the convolution constructors (initial, final and residual-block
    /// convs).
    pub causal: bool,
    /// Base number of filters; every stage's width is a power-of-two multiple of this value.
    pub n_filters: usize,
    /// Number of residual blocks in each resampling stage.
    pub n_residual_layers: usize,
    /// Stride of each resampling stage; the matching conv uses a kernel of `2 * ratio`.
    /// The encoder walks this list in reverse order, the decoder in the given order.
    pub ratios: Vec<usize>,
    /// Activation applied before each resampling conv, before the final conv, and before every
    /// conv inside a residual block.
    pub activation: candle_nn::Activation,
    /// Normalization used by the conv layers unless disabled via `disable_norm_outer_blocks`.
    pub norm: crate::conv::Norm,
    /// Kernel size of the initial conv.
    pub kernel_size: usize,
    /// Kernel size of the first conv of each residual block (the second one uses kernel 1).
    pub residual_kernel_size: usize,
    /// Kernel size of the final conv.
    pub last_kernel_size: usize,
    /// The `j`-th residual block of a stage uses dilation `dilation_base^j`.
    pub dilation_base: usize,
    /// Padding mode forwarded to the 1d convolutions.
    pub pad_mode: crate::conv::PadMode,
    /// When true, residual blocks use an identity skip; otherwise a 1x1 conv shortcut.
    pub true_skip: bool,
    /// Width reduction inside residual blocks: the hidden width is `dim / compress`.
    pub compress: usize,
    /// Number of LSTM layers; must be 0 because construction fails for any other value.
    pub lstm: usize,
    /// Number of outermost blocks, counted from the waveform side, built without normalization
    /// (see the encoder/decoder constructors for the exact thresholds).
    pub disable_norm_outer_blocks: usize,
    /// Optional activation applied to the decoder output after its final conv.
    pub final_activation: Option<candle_nn::Activation>,
}
32
/// Residual block: a chain of convs (each preceded by `activation`) summed with a skip path
/// that is either the identity or a 1x1 conv `shortcut`.
#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
    /// Convolutions applied in sequence on the residual branch.
    block: Vec<StreamableConv1d>,
    /// Optional kernel-1 conv on the skip path; `None` means the input is added unchanged.
    shortcut: Option<StreamableConv1d>,
    /// Activation applied before each conv of `block`.
    activation: candle_nn::Activation,
    /// Streaming addition used by `step` to combine the two branches (built on the last dim).
    skip_op: streaming::StreamingBinOp,
    /// Tracing span entered by `forward` and `step`.
    span: tracing::Span,
}
41
42impl SeaNetResnetBlock {
43 #[allow(clippy::too_many_arguments)]
44 pub fn new(
45 dim: usize,
46 k_sizes_and_dilations: &[(usize, usize)],
47 activation: candle_nn::Activation,
48 norm: Option<crate::conv::Norm>,
49 causal: bool,
50 pad_mode: crate::conv::PadMode,
51 compress: usize,
52 true_skip: bool,
53 vb: VarBuilder,
54 ) -> Result<Self> {
55 let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
56 let hidden = dim / compress;
57 let vb_b = vb.pp("block");
58 for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
59 let in_c = if i == 0 { dim } else { hidden };
60 let out_c = if i == k_sizes_and_dilations.len() - 1 { dim } else { hidden };
61 let c = StreamableConv1d::new(
62 in_c,
63 out_c,
64 *k_size,
65 1,
66 *dilation,
67 1,
68 true,
69 causal,
70 norm,
71 pad_mode,
72 vb_b.pp(2 * i + 1),
73 )?;
74 block.push(c)
75 }
76 let shortcut = if true_skip {
77 None
78 } else {
79 let c = StreamableConv1d::new(
80 dim,
81 dim,
82 1,
83 1,
84 1,
85 1,
86 true,
87 causal,
88 norm,
89 pad_mode,
90 vb.pp("shortcut"),
91 )?;
92 Some(c)
93 };
94 Ok(Self {
95 block,
96 shortcut,
97 activation,
98 skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
99 span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
100 })
101 }
102
103 pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
104 for b in self.block.iter_mut() {
105 b.reset_batch_idx(batch_idx, batch_size)?;
106 }
107 if let Some(shortcut) = self.shortcut.as_mut() {
108 shortcut.reset_batch_idx(batch_idx, batch_size)?;
109 }
110 self.skip_op.reset_batch_idx(batch_idx, batch_size)?;
111 Ok(())
112 }
113}
114
115impl Module for SeaNetResnetBlock {
116 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
117 let _enter = self.span.enter();
118 let mut ys = xs.clone();
119 for block in self.block.iter() {
120 ys = ys.apply(&self.activation)?.apply(block)?;
121 }
122 match self.shortcut.as_ref() {
123 None => ys + xs,
124 Some(shortcut) => ys + xs.apply(shortcut),
125 }
126 }
127}
128
129impl StreamingModule for SeaNetResnetBlock {
130 fn reset_state(&mut self) {
131 self.skip_op.reset_state();
132 for block in self.block.iter_mut() {
133 block.reset_state()
134 }
135 if let Some(shortcut) = self.shortcut.as_mut() {
136 shortcut.reset_state()
137 }
138 }
139
140 fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
141 let _enter = self.span.enter();
142 let mut ys = xs.clone();
143 for block in self.block.iter_mut() {
144 ys = block.step(&ys.apply(&self.activation)?, m)?;
145 }
146 match self.shortcut.as_mut() {
147 None => self.skip_op.step(&ys, xs, m),
148 Some(shortcut) => self.skip_op.step(&ys, &shortcut.step(xs, m)?, m),
149 }
150 }
151}
152
/// One encoder stage: residual blocks followed by a downsampling conv.
#[derive(Debug, Clone)]
struct EncoderLayer {
    /// Residual blocks, applied first and in order.
    residuals: Vec<SeaNetResnetBlock>,
    /// Downsampling conv applied (after the activation) once the residual blocks have run.
    downsample: StreamableConv1d,
}
158
/// SeaNet convolutional encoder: an initial conv, a sequence of encoder stages, and a final conv.
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
    /// First conv, mapping `cfg.channels` to `cfg.n_filters` channels.
    init_conv1d: StreamableConv1d,
    /// Activation applied before each downsampling conv and before the final conv.
    activation: candle_nn::Activation,
    /// Encoder stages, run in order.
    layers: Vec<EncoderLayer>,
    /// Last conv, producing `cfg.dimension` channels.
    final_conv1d: StreamableConv1d,
    /// Tracing span entered by `forward` and `step`.
    span: tracing::Span,
}
167
168impl SeaNetEncoder {
169 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
170 if cfg.lstm > 0 {
171 candle::bail!("seanet lstm is not supported")
172 }
173 let n_blocks = 2 + cfg.ratios.len();
174 let mut mult = 1usize;
175 let init_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) };
176 let mut layer_idx = 0;
177 let vb = vb.pp("model");
178 let init_conv1d = StreamableConv1d::new(
179 cfg.channels,
180 mult * cfg.n_filters,
181 cfg.kernel_size,
182 1,
183 1,
184 1,
185 true,
186 cfg.causal,
187 init_norm,
188 cfg.pad_mode,
189 vb.pp(layer_idx),
190 )?;
191 layer_idx += 1;
192 let mut layers = Vec::with_capacity(cfg.ratios.len());
193
194 for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
195 let norm = if cfg.disable_norm_outer_blocks >= i + 2 { None } else { Some(cfg.norm) };
196 let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
197 for j in 0..cfg.n_residual_layers {
198 let resnet_block = SeaNetResnetBlock::new(
199 mult * cfg.n_filters,
200 &[(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1)],
201 cfg.activation,
202 norm,
203 cfg.causal,
204 cfg.pad_mode,
205 cfg.compress,
206 cfg.true_skip,
207 vb.pp(layer_idx),
208 )?;
209 residuals.push(resnet_block);
210 layer_idx += 1;
211 }
212 let downsample = StreamableConv1d::new(
213 mult * cfg.n_filters,
214 mult * cfg.n_filters * 2,
215 ratio * 2,
216 ratio,
217 1,
218 1,
219 true,
220 true,
221 norm,
222 cfg.pad_mode,
223 vb.pp(layer_idx + 1),
224 )?;
225 layer_idx += 2;
226 let layer = EncoderLayer { downsample, residuals };
227 layers.push(layer);
228 mult *= 2
229 }
230
231 let final_norm =
232 if cfg.disable_norm_outer_blocks >= n_blocks { None } else { Some(cfg.norm) };
233 let final_conv1d = StreamableConv1d::new(
234 mult * cfg.n_filters,
235 cfg.dimension,
236 cfg.last_kernel_size,
237 1,
238 1,
239 1,
240 true,
241 cfg.causal,
242 final_norm,
243 cfg.pad_mode,
244 vb.pp(layer_idx + 1),
245 )?;
246 Ok(Self {
247 init_conv1d,
248 activation: cfg.activation,
249 layers,
250 final_conv1d,
251 span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
252 })
253 }
254
255 pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
256 self.init_conv1d.reset_batch_idx(batch_idx, batch_size)?;
257 self.final_conv1d.reset_batch_idx(batch_idx, batch_size)?;
258 for layer in self.layers.iter_mut() {
259 layer.downsample.reset_batch_idx(batch_idx, batch_size)?;
260 for l in layer.residuals.iter_mut() {
261 l.reset_batch_idx(batch_idx, batch_size)?;
262 }
263 }
264 Ok(())
265 }
266}
267
268impl Module for SeaNetEncoder {
269 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
270 let _enter = self.span.enter();
271 let mut xs = xs.apply(&self.init_conv1d)?;
272 for layer in self.layers.iter() {
273 for residual in layer.residuals.iter() {
274 xs = xs.apply(residual)?
275 }
276 xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
277 }
278 xs.apply(&self.activation)?.apply(&self.final_conv1d)
279 }
280}
281
282impl StreamingModule for SeaNetEncoder {
283 fn reset_state(&mut self) {
284 self.init_conv1d.reset_state();
285 self.layers.iter_mut().for_each(|v| {
286 v.residuals.iter_mut().for_each(|v| v.reset_state());
287 v.downsample.reset_state()
288 });
289 self.final_conv1d.reset_state();
290 }
291
292 fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
293 let _enter = self.span.enter();
294 let mut xs = self.init_conv1d.step(xs, m)?;
295 for layer in self.layers.iter_mut() {
296 for residual in layer.residuals.iter_mut() {
297 xs = residual.step(&xs, m)?;
298 }
299 xs = layer.downsample.step(&xs.apply(&self.activation)?, m)?;
300 }
301 self.final_conv1d.step(&xs.apply(&self.activation)?, m)
302 }
303}
304
/// One decoder stage: an upsampling transposed conv followed by residual blocks.
#[derive(Debug, Clone)]
struct DecoderLayer {
    /// Transposed conv applied first (after the activation).
    upsample: StreamableConvTranspose1d,
    /// Residual blocks applied, in order, after the upsampling conv.
    residuals: Vec<SeaNetResnetBlock>,
}
310
/// SeaNet convolutional decoder: an initial conv, a sequence of decoder stages, a final conv,
/// and an optional output activation.
#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
    /// First conv, mapping `cfg.dimension` channels to the widest stage.
    init_conv1d: StreamableConv1d,
    /// Activation applied before each upsampling conv and before the final conv.
    activation: candle_nn::Activation,
    /// Decoder stages, run in order.
    layers: Vec<DecoderLayer>,
    /// Last conv, producing `cfg.channels` channels.
    final_conv1d: StreamableConv1d,
    /// Optional activation applied to the output of `final_conv1d`.
    final_activation: Option<candle_nn::Activation>,
    /// Tracing span entered by `forward` and `step`.
    span: tracing::Span,
}
320
321impl SeaNetDecoder {
322 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
323 if cfg.lstm > 0 {
324 candle::bail!("seanet lstm is not supported")
325 }
326 let n_blocks = 2 + cfg.ratios.len();
327 let mut mult = 1 << cfg.ratios.len();
328 let init_norm =
329 if cfg.disable_norm_outer_blocks == n_blocks { None } else { Some(cfg.norm) };
330 let mut layer_idx = 0;
331 let vb = vb.pp("model");
332 let init_conv1d = StreamableConv1d::new(
333 cfg.dimension,
334 mult * cfg.n_filters,
335 cfg.kernel_size,
336 1,
337 1,
338 1,
339 true,
340 cfg.causal,
341 init_norm,
342 cfg.pad_mode,
343 vb.pp(layer_idx),
344 )?;
345 layer_idx += 1;
346 let mut layers = Vec::with_capacity(cfg.ratios.len());
347 for (i, &ratio) in cfg.ratios.iter().enumerate() {
348 let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
349 None
350 } else {
351 Some(cfg.norm)
352 };
353 let upsample = StreamableConvTranspose1d::new(
354 mult * cfg.n_filters,
355 mult * cfg.n_filters / 2,
356 ratio * 2,
357 ratio,
358 1,
359 true,
360 true,
361 norm,
362 vb.pp(layer_idx + 1),
363 )?;
364 layer_idx += 2;
365
366 let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
367 for j in 0..cfg.n_residual_layers {
368 let resnet_block = SeaNetResnetBlock::new(
369 mult * cfg.n_filters / 2,
370 &[(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1)],
371 cfg.activation,
372 norm,
373 cfg.causal,
374 cfg.pad_mode,
375 cfg.compress,
376 cfg.true_skip,
377 vb.pp(layer_idx),
378 )?;
379 residuals.push(resnet_block);
380 layer_idx += 1;
381 }
382 let layer = DecoderLayer { upsample, residuals };
383 layers.push(layer);
384 mult /= 2
385 }
386 let final_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) };
387 let final_conv1d = StreamableConv1d::new(
388 cfg.n_filters,
389 cfg.channels,
390 cfg.last_kernel_size,
391 1,
392 1,
393 1,
394 true,
395 cfg.causal,
396 final_norm,
397 cfg.pad_mode,
398 vb.pp(layer_idx + 1),
399 )?;
400 Ok(Self {
401 init_conv1d,
402 activation: cfg.activation,
403 layers,
404 final_conv1d,
405 final_activation: cfg.final_activation,
406 span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
407 })
408 }
409
410 pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
411 self.init_conv1d.reset_batch_idx(batch_idx, batch_size)?;
412 self.final_conv1d.reset_batch_idx(batch_idx, batch_size)?;
413 for layer in self.layers.iter_mut() {
414 layer.upsample.reset_batch_idx(batch_idx, batch_size)?;
415 for l in layer.residuals.iter_mut() {
416 l.reset_batch_idx(batch_idx, batch_size)?;
417 }
418 }
419 Ok(())
420 }
421}
422
423impl Module for SeaNetDecoder {
424 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
425 let _enter = self.span.enter();
426 let mut xs = xs.apply(&self.init_conv1d)?;
427 for layer in self.layers.iter() {
428 xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
429 for residual in layer.residuals.iter() {
430 xs = xs.apply(residual)?
431 }
432 }
433 let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
434 let xs = match self.final_activation.as_ref() {
435 None => xs,
436 Some(act) => xs.apply(act)?,
437 };
438 Ok(xs)
439 }
440}
441
442impl StreamingModule for SeaNetDecoder {
443 fn reset_state(&mut self) {
444 self.init_conv1d.reset_state();
445 self.layers.iter_mut().for_each(|v| {
446 v.residuals.iter_mut().for_each(|v| v.reset_state());
447 v.upsample.reset_state()
448 });
449 self.final_conv1d.reset_state();
450 }
451
452 fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
453 let _enter = self.span.enter();
454 let mut xs = self.init_conv1d.step(xs, m)?;
455 for layer in self.layers.iter_mut() {
456 xs = layer.upsample.step(&xs.apply(&self.activation)?, m)?;
457 for residual in layer.residuals.iter_mut() {
458 xs = residual.step(&xs, m)?;
459 }
460 }
461 let xs = self.final_conv1d.step(&xs.apply(&self.activation)?, m)?;
462 let xs = match self.final_activation.as_ref() {
463 None => xs,
464 Some(act) => xs.apply(act)?,
465 };
466 Ok(xs)
467 }
468}