Skip to main content

deep_delta_learning/
embed_expander.rs

1use burn::module::Module;
2use burn::prelude::*;
3
4use crate::shortconv::CausalDepthwiseConv1d;
5
6#[derive(Module, Debug)]
7pub struct EmbeddingExpander<B: Backend> {
8    conv: Option<CausalDepthwiseConv1d<B>>,
9    d_model: usize,
10    d_value: usize,
11}
12
13impl<B: Backend> EmbeddingExpander<B> {
14    pub fn new(
15        d_model: usize,
16        d_value: usize,
17        embed_conv: bool,
18        kernel_size: usize,
19        device: &B::Device,
20    ) -> Self {
21        let conv = if embed_conv && d_value > 1 {
22            Some(CausalDepthwiseConv1d::<B>::identity(
23                d_model * d_value,
24                kernel_size,
25                device,
26            ))
27        } else {
28            None
29        };
30        Self {
31            conv,
32            d_model,
33            d_value,
34        }
35    }
36
37    pub fn forward(&self, x_emb: Tensor<B, 3>) -> Tensor<B, 4> {
38        let [batch_size, seq_len, _] = x_emb.dims();
39        let repeated = x_emb.unsqueeze_dim::<4>(3).repeat_dim(3, self.d_value);
40        match &self.conv {
41            None => repeated,
42            Some(conv) => conv
43                .forward(
44                    repeated
45                        .reshape([batch_size, seq_len, self.d_model * self.d_value])
46                        .swap_dims(1, 2),
47                )
48                .swap_dims(1, 2)
49                .reshape([batch_size, seq_len, self.d_model, self.d_value]),
50        }
51    }
52}