deep_delta_learning/
embed_expander.rs1use burn::module::Module;
2use burn::prelude::*;
3
4use crate::shortconv::CausalDepthwiseConv1d;
5
6#[derive(Module, Debug)]
7pub struct EmbeddingExpander<B: Backend> {
8 conv: Option<CausalDepthwiseConv1d<B>>,
9 d_model: usize,
10 d_value: usize,
11}
12
13impl<B: Backend> EmbeddingExpander<B> {
14 pub fn new(
15 d_model: usize,
16 d_value: usize,
17 embed_conv: bool,
18 kernel_size: usize,
19 device: &B::Device,
20 ) -> Self {
21 let conv = if embed_conv && d_value > 1 {
22 Some(CausalDepthwiseConv1d::<B>::identity(
23 d_model * d_value,
24 kernel_size,
25 device,
26 ))
27 } else {
28 None
29 };
30 Self {
31 conv,
32 d_model,
33 d_value,
34 }
35 }
36
37 pub fn forward(&self, x_emb: Tensor<B, 3>) -> Tensor<B, 4> {
38 let [batch_size, seq_len, _] = x_emb.dims();
39 let repeated = x_emb.unsqueeze_dim::<4>(3).repeat_dim(3, self.d_value);
40 match &self.conv {
41 None => repeated,
42 Some(conv) => conv
43 .forward(
44 repeated
45 .reshape([batch_size, seq_len, self.d_model * self.d_value])
46 .swap_dims(1, 2),
47 )
48 .swap_dims(1, 2)
49 .reshape([batch_size, seq_len, self.d_model, self.d_value]),
50 }
51 }
52}