// deep_delta_learning/config.rs

use burn::config::Config;
use burn::tensor::backend::Backend;

use crate::error::ConfigValidationError;
use crate::transformer::DdlTransformer;
/// Which projection the learned delta mapping is applied to.
///
/// NOTE(review): semantics inferred from the variant names (key map vs.
/// value map); confirm against how `DdlTransformer` consumes this.
#[derive(Config, Debug, Copy, PartialEq, Eq)]
pub enum DdlMapping {
    /// Apply the mapping on the key path (default, see `DdlConfig::mapping`).
    KMap,
    /// Apply the mapping on the value path.
    VMap,
}

/// Short-convolution variant used by the compression path.
///
/// NOTE(review): presumably `TokenConv` convolves along the token (sequence)
/// axis and `ChannelConv` along the channel axis — confirm in the
/// transformer implementation.
#[derive(Config, Debug, Copy, PartialEq, Eq)]
pub enum CompressionVariant {
    /// Convolution over the token axis (default, see `DdlConfig::compression`).
    TokenConv,
    /// Convolution over the channel axis.
    ChannelConv,
}

/// Hyperparameters for the deep-delta-learning transformer.
///
/// Invariants are enforced by [`DdlConfig::validate`]; `try_init`/`init`
/// validate before constructing a [`DdlTransformer`].
#[derive(Config, Debug, PartialEq)]
pub struct DdlConfig {
    /// Model (embedding) width. Must be > 0 and divisible by `num_heads`.
    pub d_model: usize,
    /// Number of transformer layers. Must be > 0.
    pub num_layers: usize,
    /// Number of attention heads. Must be > 0.
    pub num_heads: usize,
    /// Token vocabulary size. Must be > 0.
    pub vocab_size: usize,
    /// Value dimension; must be > 0. Values > 1 switch the model to a
    /// matrix-valued state (see `uses_matrix_state`).
    #[config(default = 1)]
    pub d_value: usize,
    /// Per-head width. `0` is a sentinel meaning "infer as
    /// `d_model / num_heads`" (see `effective_head_dim`).
    #[config(default = 0)]
    pub head_dim: usize,
    /// MLP hidden width. `0` is a sentinel meaning "infer as
    /// `ceil(8 * d_model / 3)`" (see `effective_mlp_hidden`).
    #[config(default = 0)]
    pub mlp_hidden: usize,
    /// Which projection the delta mapping acts on.
    #[config(default = "DdlMapping::KMap")]
    pub mapping: DdlMapping,
    /// Initial beta value; must lie strictly inside (0, 2).
    #[config(default = 0.5)]
    pub beta_init: f64,
    /// Whether beta is produced by a single linear layer.
    /// NOTE(review): exact effect not visible in this file — see transformer.
    #[config(default = true)]
    pub beta_single_linear: bool,
    /// Epsilon applied on the key path; must be > 0.
    /// NOTE(review): presumably a numerical-stability term — confirm usage.
    #[config(default = 1e-6)]
    pub k_eps: f64,
    /// Short-convolution compression variant.
    #[config(default = "CompressionVariant::TokenConv")]
    pub compression: CompressionVariant,
    /// Kernel size of the short convolution; must be > 0.
    #[config(default = 4)]
    pub shortconv_kernel_size: usize,
    /// Whether to apply a convolution on the embeddings.
    /// NOTE(review): effect not visible in this file — see transformer.
    #[config(default = false)]
    pub embed_conv: bool,
    /// Maximum sequence length; must be > 0.
    #[config(default = 128)]
    pub max_seq_len: usize,
    /// Rotary-embedding base frequency; must be > 0.
    #[config(default = 10000.0)]
    pub rope_theta: f32,
}

51impl DdlConfig {
52 pub fn validate(&self) -> Result<(), ConfigValidationError> {
53 if self.d_model == 0 {
54 return Err(ConfigValidationError::NonPositiveUsize {
55 field: "d_model",
56 value: self.d_model,
57 });
58 }
59 if self.num_layers == 0 {
60 return Err(ConfigValidationError::NonPositiveUsize {
61 field: "num_layers",
62 value: self.num_layers,
63 });
64 }
65 if self.num_heads == 0 {
66 return Err(ConfigValidationError::NonPositiveUsize {
67 field: "num_heads",
68 value: self.num_heads,
69 });
70 }
71 if self.vocab_size == 0 {
72 return Err(ConfigValidationError::NonPositiveUsize {
73 field: "vocab_size",
74 value: self.vocab_size,
75 });
76 }
77 if self.d_value == 0 {
78 return Err(ConfigValidationError::NonPositiveUsize {
79 field: "d_value",
80 value: self.d_value,
81 });
82 }
83 if !self.d_model.is_multiple_of(self.num_heads) {
84 return Err(ConfigValidationError::DModelNotDivisibleByHeads {
85 d_model: self.d_model,
86 num_heads: self.num_heads,
87 });
88 }
89
90 let head_dim = self.effective_head_dim();
91 let Some(total_head_dim) = head_dim.checked_mul(self.num_heads) else {
92 return Err(ConfigValidationError::HeadDimMismatch {
93 head_dim,
94 num_heads: self.num_heads,
95 d_model: self.d_model,
96 });
97 };
98 if head_dim == 0 || total_head_dim != self.d_model {
99 return Err(ConfigValidationError::HeadDimMismatch {
100 head_dim,
101 num_heads: self.num_heads,
102 d_model: self.d_model,
103 });
104 }
105 if !(self.beta_init > 0.0 && self.beta_init < 2.0) {
106 return Err(ConfigValidationError::BetaInitOutOfRange {
107 value: self.beta_init,
108 });
109 }
110 if self.k_eps <= 0.0 {
111 return Err(ConfigValidationError::NonPositiveF64 {
112 field: "k_eps",
113 value: self.k_eps,
114 });
115 }
116 if self.shortconv_kernel_size == 0 {
117 return Err(ConfigValidationError::NonPositiveUsize {
118 field: "shortconv_kernel_size",
119 value: self.shortconv_kernel_size,
120 });
121 }
122 if self.max_seq_len == 0 {
123 return Err(ConfigValidationError::NonPositiveUsize {
124 field: "max_seq_len",
125 value: self.max_seq_len,
126 });
127 }
128 if self.rope_theta <= 0.0 {
129 return Err(ConfigValidationError::NonPositiveF32 {
130 field: "rope_theta",
131 value: self.rope_theta,
132 });
133 }
134
135 Ok(())
136 }
137
138 pub fn effective_head_dim(&self) -> usize {
139 if self.head_dim == 0 {
140 self.d_model / self.num_heads
141 } else {
142 self.head_dim
143 }
144 }
145
146 pub fn effective_mlp_hidden(&self) -> usize {
147 if self.mlp_hidden == 0 {
148 (8 * self.d_model).div_ceil(3)
149 } else {
150 self.mlp_hidden
151 }
152 }
153
154 pub fn uses_matrix_state(&self) -> bool {
155 self.d_value > 1
156 }
157
158 pub fn try_init<B: Backend>(
159 &self,
160 device: &B::Device,
161 ) -> Result<DdlTransformer<B>, ConfigValidationError> {
162 self.validate()?;
163 Ok(DdlTransformer::new(self.clone(), device))
164 }
165
166 pub fn init<B: Backend>(&self, device: &B::Device) -> DdlTransformer<B> {
167 self.try_init(device)
168 .unwrap_or_else(|error| panic!("invalid ddl configuration: {error}"))
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline fixture: d_model=768, 12 layers, 6 heads, GPT-2-sized vocab.
    fn base_config() -> DdlConfig {
        DdlConfig::new(768, 12, 6, 50_257)
    }

    #[test]
    fn infers_head_dim_when_unset() {
        // head_dim defaults to the 0 sentinel, so 768 / 6 is inferred.
        assert_eq!(base_config().effective_head_dim(), 128);
    }

    #[test]
    fn infers_swiglu_hidden_when_unset() {
        // mlp_hidden defaults to the 0 sentinel: ceil(8 * 768 / 3) = 2048.
        assert_eq!(base_config().effective_mlp_hidden(), 2_048);
    }
}
187}