1use burn::prelude::*;
2use burn::tensor::Int;
3use serde::{Deserialize, Serialize};
4
5use crate::baseline::BaselineTransformer;
6use crate::config::{CompressionVariant, DdlConfig};
7use crate::generation::{
8 AutoregressiveModel, GenerationConfig, GenerationError, GenerationResult, generate_tokens,
9};
10use crate::spectral::{ModelDiagnostics, SpectralDiagnostics};
11use crate::transformer::DdlTransformer;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum ModelVariant {
15 Baseline,
16 DdlVector,
17 DdlMatrixToken,
18 DdlMatrixTokenEc,
19 DdlMatrixChannel,
20 DdlMatrixChannelEc,
21}
22
23impl ModelVariant {
24 pub const ALL: [Self; 6] = [
25 Self::Baseline,
26 Self::DdlVector,
27 Self::DdlMatrixToken,
28 Self::DdlMatrixTokenEc,
29 Self::DdlMatrixChannel,
30 Self::DdlMatrixChannelEc,
31 ];
32
33 pub fn all() -> &'static [Self] {
34 &Self::ALL
35 }
36
37 pub fn slug(&self) -> &'static str {
38 match self {
39 Self::Baseline => "baseline",
40 Self::DdlVector => "ddl",
41 Self::DdlMatrixToken => "ddl-tokenconv",
42 Self::DdlMatrixTokenEc => "ddl-ec",
43 Self::DdlMatrixChannel => "ddl-cc",
44 Self::DdlMatrixChannelEc => "ddl-cc-ec",
45 }
46 }
47
48 pub fn uses_ddl(&self) -> bool {
49 !matches!(self, Self::Baseline)
50 }
51
52 pub fn resolve_config(&self, base_config: &DdlConfig) -> DdlConfig {
53 let matrix_d_value = base_config.d_value.max(4);
54
55 match self {
56 Self::Baseline => base_config
57 .clone()
58 .with_d_value(1)
59 .with_embed_conv(false)
60 .with_compression(CompressionVariant::TokenConv),
61 Self::DdlVector => base_config
62 .clone()
63 .with_d_value(1)
64 .with_embed_conv(false)
65 .with_compression(CompressionVariant::TokenConv),
66 Self::DdlMatrixToken => base_config
67 .clone()
68 .with_d_value(matrix_d_value)
69 .with_embed_conv(false)
70 .with_compression(CompressionVariant::TokenConv),
71 Self::DdlMatrixTokenEc => base_config
72 .clone()
73 .with_d_value(matrix_d_value)
74 .with_embed_conv(true)
75 .with_compression(CompressionVariant::TokenConv),
76 Self::DdlMatrixChannel => base_config
77 .clone()
78 .with_d_value(matrix_d_value)
79 .with_embed_conv(false)
80 .with_compression(CompressionVariant::ChannelConv),
81 Self::DdlMatrixChannelEc => base_config
82 .clone()
83 .with_d_value(matrix_d_value)
84 .with_embed_conv(true)
85 .with_compression(CompressionVariant::ChannelConv),
86 }
87 }
88
89 pub fn init_model<B: Backend>(
90 &self,
91 resolved_config: &DdlConfig,
92 device: &B::Device,
93 ) -> ModelInstance<B> {
94 match self {
95 Self::Baseline => {
96 ModelInstance::Baseline(Box::new(BaselineTransformer::new(resolved_config, device)))
97 }
98 _ => ModelInstance::Ddl(Box::new(resolved_config.init(device))),
99 }
100 }
101
102 pub fn build<B: Backend>(
103 &self,
104 base_config: &DdlConfig,
105 device: &B::Device,
106 ) -> (DdlConfig, ModelInstance<B>) {
107 let resolved = self.resolve_config(base_config);
108 let model = self.init_model(&resolved, device);
109 (resolved, model)
110 }
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
114pub enum DiagnosticLevel {
115 None,
116 Summary,
117 #[default]
118 Spectral,
119}
120
121impl DiagnosticLevel {
122 pub fn wants_model_diagnostics(self) -> bool {
123 !matches!(self, Self::None)
124 }
125
126 pub fn wants_spectral(self) -> bool {
127 matches!(self.effective(), Self::Spectral)
128 }
129
130 pub fn effective(self) -> Self {
131 #[cfg(feature = "spectral")]
132 {
133 self
134 }
135
136 #[cfg(not(feature = "spectral"))]
137 {
138 match self {
139 Self::Spectral => Self::Summary,
140 other => other,
141 }
142 }
143 }
144
145 pub fn slug(self) -> &'static str {
146 match self {
147 Self::None => "none",
148 Self::Summary => "summary",
149 Self::Spectral => "spectral",
150 }
151 }
152}
153
154#[derive(Debug, Clone)]
155pub enum ModelInstance<B: Backend> {
156 Baseline(Box<BaselineTransformer<B>>),
157 Ddl(Box<DdlTransformer<B>>),
158}
159
160impl<B: Backend> ModelInstance<B> {
161 pub fn num_params(&self) -> usize {
162 match self {
163 Self::Baseline(model) => model.num_params(),
164 Self::Ddl(model) => model.num_params(),
165 }
166 }
167
168 pub fn forward_logits(
169 &self,
170 input_ids: Tensor<B, 2, Int>,
171 mask: Option<&Tensor<B, 3>>,
172 ) -> Tensor<B, 3> {
173 match self {
174 Self::Baseline(model) => model.forward_logits(input_ids, mask),
175 Self::Ddl(model) => model.forward_logits(input_ids, mask),
176 }
177 }
178
179 pub fn forward_with_optional_diagnostics(
180 &self,
181 input_ids: Tensor<B, 2, Int>,
182 mask: Option<&Tensor<B, 3>>,
183 ) -> ModelOutput<B> {
184 self.forward_with_diagnostics(input_ids, mask, DiagnosticLevel::default())
185 }
186
187 pub fn forward_with_diagnostics(
188 &self,
189 input_ids: Tensor<B, 2, Int>,
190 mask: Option<&Tensor<B, 3>>,
191 diagnostic_level: DiagnosticLevel,
192 ) -> ModelOutput<B> {
193 let diagnostic_level = diagnostic_level.effective();
194
195 match self {
196 Self::Baseline(model) => ModelOutput {
197 logits: model.forward_logits(input_ids, mask),
198 diagnostics: None,
199 spectral: None,
200 },
201 Self::Ddl(model) => match diagnostic_level {
202 DiagnosticLevel::None => ModelOutput {
203 logits: model.forward_logits(input_ids, mask),
204 diagnostics: None,
205 spectral: None,
206 },
207 DiagnosticLevel::Summary => {
208 let (logits, diagnostics) = model.forward(input_ids, mask);
209 ModelOutput {
210 logits,
211 diagnostics: Some(diagnostics),
212 spectral: None,
213 }
214 }
215 DiagnosticLevel::Spectral => {
216 #[cfg(feature = "spectral")]
217 {
218 let (logits, diagnostics, spectral) =
219 model.forward_with_spectral_diagnostics(input_ids, mask);
220 ModelOutput {
221 logits,
222 diagnostics: Some(diagnostics),
223 spectral: Some(spectral),
224 }
225 }
226
227 #[cfg(not(feature = "spectral"))]
228 {
229 let (logits, diagnostics) = model.forward(input_ids, mask);
230 ModelOutput {
231 logits,
232 diagnostics: Some(diagnostics),
233 spectral: None,
234 }
235 }
236 }
237 },
238 }
239 }
240
241 pub fn max_seq_len(&self) -> usize {
242 match self {
243 Self::Baseline(model) => model.max_seq_len(),
244 Self::Ddl(model) => model.max_seq_len(),
245 }
246 }
247
248 pub fn generate(
249 &self,
250 prompt_tokens: &[usize],
251 generation_config: &GenerationConfig,
252 device: &B::Device,
253 ) -> Result<GenerationResult, GenerationError> {
254 generate_tokens(self, prompt_tokens, generation_config, device)
255 }
256}
257
258impl<B: Backend> AutoregressiveModel<B> for ModelInstance<B> {
259 fn forward_logits(
260 &self,
261 input_ids: Tensor<B, 2, Int>,
262 mask: Option<&Tensor<B, 3>>,
263 ) -> Tensor<B, 3> {
264 ModelInstance::forward_logits(self, input_ids, mask)
265 }
266
267 fn max_seq_len(&self) -> usize {
268 ModelInstance::max_seq_len(self)
269 }
270}
271
272pub struct ModelOutput<B: Backend> {
273 pub logits: Tensor<B, 3>,
274 pub diagnostics: Option<ModelDiagnostics>,
275 pub spectral: Option<SpectralDiagnostics>,
276}