//! deep_delta_learning/config.rs
//!
//! Configuration types for the deep-delta-learning transformer.
1use burn::config::Config;
2use burn::tensor::backend::Backend;
3
4use crate::error::ConfigValidationError;
5use crate::transformer::DdlTransformer;
6
/// Selects which side of the delta-rule update the learned mapping is applied
/// to. Consumed via [`DdlConfig::mapping`]; the concrete semantics live in the
/// transformer module.
#[derive(Config, Debug, Copy, PartialEq, Eq)]
pub enum DdlMapping {
    /// Mapping on the key side — presumably; confirm in the transformer impl.
    KMap,
    /// Mapping on the value side — presumably; confirm in the transformer impl.
    VMap,
}
12
/// Short-convolution compression variant, selected via
/// [`DdlConfig::compression`] and sized by `shortconv_kernel_size`.
#[derive(Config, Debug, Copy, PartialEq, Eq)]
pub enum CompressionVariant {
    /// Convolution along the token axis — presumably; confirm in transformer.
    TokenConv,
    /// Convolution along the channel axis — presumably; confirm in transformer.
    ChannelConv,
}
18
/// Hyper-parameters for the deep-delta-learning transformer.
///
/// Several fields use `0` as a "derive it for me" sentinel; see
/// [`DdlConfig::effective_head_dim`] and [`DdlConfig::effective_mlp_hidden`].
/// Run [`DdlConfig::validate`] (or construct through [`DdlConfig::try_init`])
/// before building a model.
#[derive(Config, Debug, PartialEq)]
pub struct DdlConfig {
    // Width of the residual stream / embeddings. Must be positive and
    // divisible by `num_heads` (enforced by `validate`).
    pub d_model: usize,
    // Number of transformer layers; must be positive.
    pub num_layers: usize,
    // Number of heads; must be positive and divide `d_model`.
    pub num_heads: usize,
    // Token-embedding vocabulary size; must be positive.
    pub vocab_size: usize,
    // Per-head value dimension; values > 1 enable the matrix-valued state
    // (see `uses_matrix_state`). Must be positive.
    #[config(default = 1)]
    pub d_value: usize,
    // Per-head width; `0` means "derive as d_model / num_heads".
    #[config(default = 0)]
    pub head_dim: usize,
    // MLP hidden width; `0` means "derive as ceil(8 * d_model / 3)".
    #[config(default = 0)]
    pub mlp_hidden: usize,
    // Which delta-rule mapping the layers use (see `DdlMapping`).
    #[config(default = "DdlMapping::KMap")]
    pub mapping: DdlMapping,
    // Initial beta gate value; `validate` requires 0 < beta_init < 2.
    #[config(default = 0.5)]
    pub beta_init: f64,
    // Whether beta comes from a single shared linear layer.
    // NOTE(review): inferred from the name — confirm in the transformer impl.
    #[config(default = true)]
    pub beta_single_linear: bool,
    // Numerical-stability epsilon applied to k; must be strictly positive.
    #[config(default = 1e-6)]
    pub k_eps: f64,
    // Short-convolution compression variant (see `CompressionVariant`).
    #[config(default = "CompressionVariant::TokenConv")]
    pub compression: CompressionVariant,
    // Kernel size of the short convolution; must be positive.
    #[config(default = 4)]
    pub shortconv_kernel_size: usize,
    // Whether a convolution is applied at the embedding stage.
    // NOTE(review): inferred from the name — confirm in the transformer impl.
    #[config(default = false)]
    pub embed_conv: bool,
    // Maximum sequence length; must be positive.
    #[config(default = 128)]
    pub max_seq_len: usize,
    // Rotary-embedding base (theta); must be strictly positive.
    #[config(default = 10000.0)]
    pub rope_theta: f32,
}
50
51impl DdlConfig {
52    pub fn validate(&self) -> Result<(), ConfigValidationError> {
53        if self.d_model == 0 {
54            return Err(ConfigValidationError::NonPositiveUsize {
55                field: "d_model",
56                value: self.d_model,
57            });
58        }
59        if self.num_layers == 0 {
60            return Err(ConfigValidationError::NonPositiveUsize {
61                field: "num_layers",
62                value: self.num_layers,
63            });
64        }
65        if self.num_heads == 0 {
66            return Err(ConfigValidationError::NonPositiveUsize {
67                field: "num_heads",
68                value: self.num_heads,
69            });
70        }
71        if self.vocab_size == 0 {
72            return Err(ConfigValidationError::NonPositiveUsize {
73                field: "vocab_size",
74                value: self.vocab_size,
75            });
76        }
77        if self.d_value == 0 {
78            return Err(ConfigValidationError::NonPositiveUsize {
79                field: "d_value",
80                value: self.d_value,
81            });
82        }
83        if !self.d_model.is_multiple_of(self.num_heads) {
84            return Err(ConfigValidationError::DModelNotDivisibleByHeads {
85                d_model: self.d_model,
86                num_heads: self.num_heads,
87            });
88        }
89
90        let head_dim = self.effective_head_dim();
91        let Some(total_head_dim) = head_dim.checked_mul(self.num_heads) else {
92            return Err(ConfigValidationError::HeadDimMismatch {
93                head_dim,
94                num_heads: self.num_heads,
95                d_model: self.d_model,
96            });
97        };
98        if head_dim == 0 || total_head_dim != self.d_model {
99            return Err(ConfigValidationError::HeadDimMismatch {
100                head_dim,
101                num_heads: self.num_heads,
102                d_model: self.d_model,
103            });
104        }
105        if !(self.beta_init > 0.0 && self.beta_init < 2.0) {
106            return Err(ConfigValidationError::BetaInitOutOfRange {
107                value: self.beta_init,
108            });
109        }
110        if self.k_eps <= 0.0 {
111            return Err(ConfigValidationError::NonPositiveF64 {
112                field: "k_eps",
113                value: self.k_eps,
114            });
115        }
116        if self.shortconv_kernel_size == 0 {
117            return Err(ConfigValidationError::NonPositiveUsize {
118                field: "shortconv_kernel_size",
119                value: self.shortconv_kernel_size,
120            });
121        }
122        if self.max_seq_len == 0 {
123            return Err(ConfigValidationError::NonPositiveUsize {
124                field: "max_seq_len",
125                value: self.max_seq_len,
126            });
127        }
128        if self.rope_theta <= 0.0 {
129            return Err(ConfigValidationError::NonPositiveF32 {
130                field: "rope_theta",
131                value: self.rope_theta,
132            });
133        }
134
135        Ok(())
136    }
137
138    pub fn effective_head_dim(&self) -> usize {
139        if self.head_dim == 0 {
140            self.d_model / self.num_heads
141        } else {
142            self.head_dim
143        }
144    }
145
146    pub fn effective_mlp_hidden(&self) -> usize {
147        if self.mlp_hidden == 0 {
148            (8 * self.d_model).div_ceil(3)
149        } else {
150            self.mlp_hidden
151        }
152    }
153
154    pub fn uses_matrix_state(&self) -> bool {
155        self.d_value > 1
156    }
157
158    pub fn try_init<B: Backend>(
159        &self,
160        device: &B::Device,
161    ) -> Result<DdlTransformer<B>, ConfigValidationError> {
162        self.validate()?;
163        Ok(DdlTransformer::new(self.clone(), device))
164    }
165
166    pub fn init<B: Backend>(&self, device: &B::Device) -> DdlTransformer<B> {
167        self.try_init(device)
168            .unwrap_or_else(|error| panic!("invalid ddl configuration: {error}"))
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// A config whose sentinel fields (`head_dim`, `mlp_hidden`) stay at 0.
    fn base_config() -> DdlConfig {
        DdlConfig::new(768, 12, 6, 50_257)
    }

    #[test]
    fn infers_head_dim_when_unset() {
        // 768 / 6 heads = 128 per head.
        assert_eq!(base_config().effective_head_dim(), 128);
    }

    #[test]
    fn infers_swiglu_hidden_when_unset() {
        // ceil(8 * 768 / 3) = 2048.
        assert_eq!(base_config().effective_mlp_hidden(), 2_048);
    }
}