// deep_delta_learning/compressor.rs
use burn::module::{Module, Param};
use burn::prelude::*;

use crate::shortconv::CausalDepthwiseConv1d;

/// Compresses a 4-D state tensor `(batch, seq, d_model, d_value)` down to
/// `(batch, seq, d_model)`: a causal depthwise convolution mixes information
/// along the token axis over `d_model * d_value` flattened channels, then a
/// learned read vector contracts the `d_value` axis (see `forward`).
#[derive(Module, Debug)]
pub struct TokenConvCompressor<B: Backend> {
    // Causal depthwise conv over d_model * d_value channels (token axis).
    conv: CausalDepthwiseConv1d<B>,
    // Learned length-d_value contraction vector; initialized to the uniform
    // average 1 / d_value, so the module starts as a mean over the value axis.
    read_vector: Param<Tensor<B, 1>>,
    // Cached dimensions needed to reshape inside `forward`.
    d_model: usize,
    d_value: usize,
}
13
14impl<B: Backend> TokenConvCompressor<B> {
15 pub fn new(d_model: usize, d_value: usize, kernel_size: usize, device: &B::Device) -> Self {
16 Self {
17 conv: CausalDepthwiseConv1d::<B>::new(d_model * d_value, kernel_size, false, device),
18 read_vector: Param::from_tensor(
19 Tensor::<B, 1>::ones([d_value], device).div_scalar(d_value as f32),
20 ),
21 d_model,
22 d_value,
23 }
24 }
25
26 pub fn forward(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
27 let [batch_size, seq_len, _, _] = state.dims();
28 let flattened = state
29 .reshape([batch_size, seq_len, self.d_model * self.d_value])
30 .swap_dims(1, 2);
31 let convolved = self.conv.forward(flattened).swap_dims(1, 2).reshape([
32 batch_size,
33 seq_len,
34 self.d_model,
35 self.d_value,
36 ]);
37 let read = self
38 .read_vector
39 .val()
40 .clone()
41 .unsqueeze_dim::<2>(0)
42 .unsqueeze_dim::<3>(0)
43 .unsqueeze_dim::<4>(0);
44 (convolved * read).sum_dim(3).squeeze_dim::<3>(3)
45 }
46}
47
/// Compresses a 4-D state tensor `(batch, seq, d_model, d_value)` to
/// `(batch, seq, d_model)` with a single depthwise "valid" convolution whose
/// kernel spans the whole `d_value` axis — one learned weighted sum per
/// channel (the length-1 conv output is squeezed away in `forward`).
#[derive(Module, Debug)]
pub struct ChannelConvCompressor<B: Backend> {
    // Depthwise conv with d_model channels and kernel size d_value.
    conv: CausalDepthwiseConv1d<B>,
    // Cached dimensions needed to reshape inside `forward`.
    d_model: usize,
    d_value: usize,
}
54
55impl<B: Backend> ChannelConvCompressor<B> {
56 pub fn new(d_model: usize, d_value: usize, device: &B::Device) -> Self {
57 Self {
58 conv: CausalDepthwiseConv1d::<B>::new(d_model, d_value, false, device),
59 d_model,
60 d_value,
61 }
62 }
63
64 pub fn forward(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
65 let [batch_size, seq_len, _, _] = state.dims();
66 self.conv
67 .forward_valid(state.reshape([batch_size * seq_len, self.d_model, self.d_value]))
68 .squeeze_dim::<2>(2)
69 .reshape([batch_size, seq_len, self.d_model])
70 }
71}