Skip to main content

deep_delta_learning/
compressor.rs

1use burn::module::{Module, Param};
2use burn::prelude::*;
3
4use crate::shortconv::CausalDepthwiseConv1d;
5
6#[derive(Module, Debug)]
7pub struct TokenConvCompressor<B: Backend> {
8    conv: CausalDepthwiseConv1d<B>,
9    read_vector: Param<Tensor<B, 1>>,
10    d_model: usize,
11    d_value: usize,
12}
13
14impl<B: Backend> TokenConvCompressor<B> {
15    pub fn new(d_model: usize, d_value: usize, kernel_size: usize, device: &B::Device) -> Self {
16        Self {
17            conv: CausalDepthwiseConv1d::<B>::new(d_model * d_value, kernel_size, false, device),
18            read_vector: Param::from_tensor(
19                Tensor::<B, 1>::ones([d_value], device).div_scalar(d_value as f32),
20            ),
21            d_model,
22            d_value,
23        }
24    }
25
26    pub fn forward(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
27        let [batch_size, seq_len, _, _] = state.dims();
28        let flattened = state
29            .reshape([batch_size, seq_len, self.d_model * self.d_value])
30            .swap_dims(1, 2);
31        let convolved = self.conv.forward(flattened).swap_dims(1, 2).reshape([
32            batch_size,
33            seq_len,
34            self.d_model,
35            self.d_value,
36        ]);
37        let read = self
38            .read_vector
39            .val()
40            .clone()
41            .unsqueeze_dim::<2>(0)
42            .unsqueeze_dim::<3>(0)
43            .unsqueeze_dim::<4>(0);
44        (convolved * read).sum_dim(3).squeeze_dim::<3>(3)
45    }
46}
47
48#[derive(Module, Debug)]
49pub struct ChannelConvCompressor<B: Backend> {
50    conv: CausalDepthwiseConv1d<B>,
51    d_model: usize,
52    d_value: usize,
53}
54
55impl<B: Backend> ChannelConvCompressor<B> {
56    pub fn new(d_model: usize, d_value: usize, device: &B::Device) -> Self {
57        Self {
58            conv: CausalDepthwiseConv1d::<B>::new(d_model, d_value, false, device),
59            d_model,
60            d_value,
61        }
62    }
63
64    pub fn forward(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
65        let [batch_size, seq_len, _, _] = state.dims();
66        self.conv
67            .forward_valid(state.reshape([batch_size * seq_len, self.d_model, self.d_value]))
68            .squeeze_dim::<2>(2)
69            .reshape([batch_size, seq_len, self.d_model])
70    }
71}