Skip to main content

deep_delta_learning/
variant.rs

1use burn::prelude::*;
2use burn::tensor::Int;
3use serde::{Deserialize, Serialize};
4
5use crate::baseline::BaselineTransformer;
6use crate::config::{CompressionVariant, DdlConfig};
7use crate::generation::{
8    AutoregressiveModel, GenerationConfig, GenerationError, GenerationResult, generate_tokens,
9};
10use crate::spectral::{ModelDiagnostics, SpectralDiagnostics};
11use crate::transformer::DdlTransformer;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum ModelVariant {
15    Baseline,
16    DdlVector,
17    DdlMatrixToken,
18    DdlMatrixTokenEc,
19    DdlMatrixChannel,
20    DdlMatrixChannelEc,
21}
22
23impl ModelVariant {
24    pub const ALL: [Self; 6] = [
25        Self::Baseline,
26        Self::DdlVector,
27        Self::DdlMatrixToken,
28        Self::DdlMatrixTokenEc,
29        Self::DdlMatrixChannel,
30        Self::DdlMatrixChannelEc,
31    ];
32
33    pub fn all() -> &'static [Self] {
34        &Self::ALL
35    }
36
37    pub fn slug(&self) -> &'static str {
38        match self {
39            Self::Baseline => "baseline",
40            Self::DdlVector => "ddl",
41            Self::DdlMatrixToken => "ddl-tokenconv",
42            Self::DdlMatrixTokenEc => "ddl-ec",
43            Self::DdlMatrixChannel => "ddl-cc",
44            Self::DdlMatrixChannelEc => "ddl-cc-ec",
45        }
46    }
47
48    pub fn uses_ddl(&self) -> bool {
49        !matches!(self, Self::Baseline)
50    }
51
52    pub fn resolve_config(&self, base_config: &DdlConfig) -> DdlConfig {
53        let matrix_d_value = base_config.d_value.max(4);
54
55        match self {
56            Self::Baseline => base_config
57                .clone()
58                .with_d_value(1)
59                .with_embed_conv(false)
60                .with_compression(CompressionVariant::TokenConv),
61            Self::DdlVector => base_config
62                .clone()
63                .with_d_value(1)
64                .with_embed_conv(false)
65                .with_compression(CompressionVariant::TokenConv),
66            Self::DdlMatrixToken => base_config
67                .clone()
68                .with_d_value(matrix_d_value)
69                .with_embed_conv(false)
70                .with_compression(CompressionVariant::TokenConv),
71            Self::DdlMatrixTokenEc => base_config
72                .clone()
73                .with_d_value(matrix_d_value)
74                .with_embed_conv(true)
75                .with_compression(CompressionVariant::TokenConv),
76            Self::DdlMatrixChannel => base_config
77                .clone()
78                .with_d_value(matrix_d_value)
79                .with_embed_conv(false)
80                .with_compression(CompressionVariant::ChannelConv),
81            Self::DdlMatrixChannelEc => base_config
82                .clone()
83                .with_d_value(matrix_d_value)
84                .with_embed_conv(true)
85                .with_compression(CompressionVariant::ChannelConv),
86        }
87    }
88
89    pub fn init_model<B: Backend>(
90        &self,
91        resolved_config: &DdlConfig,
92        device: &B::Device,
93    ) -> ModelInstance<B> {
94        match self {
95            Self::Baseline => {
96                ModelInstance::Baseline(Box::new(BaselineTransformer::new(resolved_config, device)))
97            }
98            _ => ModelInstance::Ddl(Box::new(resolved_config.init(device))),
99        }
100    }
101
102    pub fn build<B: Backend>(
103        &self,
104        base_config: &DdlConfig,
105        device: &B::Device,
106    ) -> (DdlConfig, ModelInstance<B>) {
107        let resolved = self.resolve_config(base_config);
108        let model = self.init_model(&resolved, device);
109        (resolved, model)
110    }
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
114pub enum DiagnosticLevel {
115    None,
116    Summary,
117    #[default]
118    Spectral,
119}
120
121impl DiagnosticLevel {
122    pub fn wants_model_diagnostics(self) -> bool {
123        !matches!(self, Self::None)
124    }
125
126    pub fn wants_spectral(self) -> bool {
127        matches!(self.effective(), Self::Spectral)
128    }
129
130    pub fn effective(self) -> Self {
131        #[cfg(feature = "spectral")]
132        {
133            self
134        }
135
136        #[cfg(not(feature = "spectral"))]
137        {
138            match self {
139                Self::Spectral => Self::Summary,
140                other => other,
141            }
142        }
143    }
144
145    pub fn slug(self) -> &'static str {
146        match self {
147            Self::None => "none",
148            Self::Summary => "summary",
149            Self::Spectral => "spectral",
150        }
151    }
152}
153
154#[derive(Debug, Clone)]
155pub enum ModelInstance<B: Backend> {
156    Baseline(Box<BaselineTransformer<B>>),
157    Ddl(Box<DdlTransformer<B>>),
158}
159
160impl<B: Backend> ModelInstance<B> {
161    pub fn num_params(&self) -> usize {
162        match self {
163            Self::Baseline(model) => model.num_params(),
164            Self::Ddl(model) => model.num_params(),
165        }
166    }
167
168    pub fn forward_logits(
169        &self,
170        input_ids: Tensor<B, 2, Int>,
171        mask: Option<&Tensor<B, 3>>,
172    ) -> Tensor<B, 3> {
173        match self {
174            Self::Baseline(model) => model.forward_logits(input_ids, mask),
175            Self::Ddl(model) => model.forward_logits(input_ids, mask),
176        }
177    }
178
179    pub fn forward_with_optional_diagnostics(
180        &self,
181        input_ids: Tensor<B, 2, Int>,
182        mask: Option<&Tensor<B, 3>>,
183    ) -> ModelOutput<B> {
184        self.forward_with_diagnostics(input_ids, mask, DiagnosticLevel::default())
185    }
186
187    pub fn forward_with_diagnostics(
188        &self,
189        input_ids: Tensor<B, 2, Int>,
190        mask: Option<&Tensor<B, 3>>,
191        diagnostic_level: DiagnosticLevel,
192    ) -> ModelOutput<B> {
193        let diagnostic_level = diagnostic_level.effective();
194
195        match self {
196            Self::Baseline(model) => ModelOutput {
197                logits: model.forward_logits(input_ids, mask),
198                diagnostics: None,
199                spectral: None,
200            },
201            Self::Ddl(model) => match diagnostic_level {
202                DiagnosticLevel::None => ModelOutput {
203                    logits: model.forward_logits(input_ids, mask),
204                    diagnostics: None,
205                    spectral: None,
206                },
207                DiagnosticLevel::Summary => {
208                    let (logits, diagnostics) = model.forward(input_ids, mask);
209                    ModelOutput {
210                        logits,
211                        diagnostics: Some(diagnostics),
212                        spectral: None,
213                    }
214                }
215                DiagnosticLevel::Spectral => {
216                    #[cfg(feature = "spectral")]
217                    {
218                        let (logits, diagnostics, spectral) =
219                            model.forward_with_spectral_diagnostics(input_ids, mask);
220                        ModelOutput {
221                            logits,
222                            diagnostics: Some(diagnostics),
223                            spectral: Some(spectral),
224                        }
225                    }
226
227                    #[cfg(not(feature = "spectral"))]
228                    {
229                        let (logits, diagnostics) = model.forward(input_ids, mask);
230                        ModelOutput {
231                            logits,
232                            diagnostics: Some(diagnostics),
233                            spectral: None,
234                        }
235                    }
236                }
237            },
238        }
239    }
240
241    pub fn max_seq_len(&self) -> usize {
242        match self {
243            Self::Baseline(model) => model.max_seq_len(),
244            Self::Ddl(model) => model.max_seq_len(),
245        }
246    }
247
248    pub fn generate(
249        &self,
250        prompt_tokens: &[usize],
251        generation_config: &GenerationConfig,
252        device: &B::Device,
253    ) -> Result<GenerationResult, GenerationError> {
254        generate_tokens(self, prompt_tokens, generation_config, device)
255    }
256}
257
258impl<B: Backend> AutoregressiveModel<B> for ModelInstance<B> {
259    fn forward_logits(
260        &self,
261        input_ids: Tensor<B, 2, Int>,
262        mask: Option<&Tensor<B, 3>>,
263    ) -> Tensor<B, 3> {
264        ModelInstance::forward_logits(self, input_ids, mask)
265    }
266
267    fn max_seq_len(&self) -> usize {
268        ModelInstance::max_seq_len(self)
269    }
270}
271
272pub struct ModelOutput<B: Backend> {
273    pub logits: Tensor<B, 3>,
274    pub diagnostics: Option<ModelDiagnostics>,
275    pub spectral: Option<SpectralDiagnostics>,
276}