kizzasi_model/transformer.rs

//! Transformer: Standard Multi-Head Attention Baseline
//!
//! This module implements a standard Transformer architecture for comparison
//! with State Space Models. Although Transformers require O(N²) attention
//! computation during training and O(N) time and memory per inference step
//! (the KV cache grows with context length), they serve as a strong quality
//! baseline.
//!
//! # Architecture
//!
//! ```text
//! Input → [Embedding] → [LayerNorm] → [Multi-Head Attention] → [Add] →
//!                          ↓                                      ↓
//!                       [LayerNorm] → [Feed Forward] → [Add] → Output
//! ```
//!
//! # Comparison with SSMs
//!
//! | Model       | Per-Step Time | Per-Step Memory | Training  | Context |
//! |-------------|---------------|-----------------|-----------|---------|
//! | Transformer | O(N)          | O(N)            | O(N²)     | Limited |
//! | Mamba/RWKV  | O(1)          | O(1)            | O(N)      | ∞       |
//! | S4/S4D      | O(1)          | O(1)            | O(N log N)| ∞       |
//!
//! # Purpose
//!
//! This implementation serves as a quality baseline to validate that SSM
//! architectures (Mamba2, RWKV, S4D) achieve competitive or better performance
//! while maintaining their efficiency advantages.
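//!
//! # Example
//!
//! A minimal usage sketch, mirroring the unit tests below. The crate/module
//! path in the `use` lines is an assumption and may need adjusting to the
//! actual re-exports:
//!
//! ```rust,ignore
//! use kizzasi_model::transformer::{Transformer, TransformerConfig};
//! use kizzasi_core::SignalPredictor;
//! use scirs2_core::ndarray::Array1;
//!
//! let config = TransformerConfig::new()
//!     .hidden_dim(64)
//!     .num_heads(4)
//!     .num_layers(2)
//!     .max_seq_len(128);
//! let mut model = Transformer::new(config).expect("failed to create model");
//!
//! // Single-step autoregressive prediction over a scalar signal
//! // (input_dim defaults to 1).
//! let input = Array1::from_vec(vec![0.5]);
//! let prediction = model.step(&input).expect("step failed");
//! assert_eq!(prediction.len(), 1);
//! ```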

use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{gelu, softmax, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
use std::collections::VecDeque;
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

/// Configuration for Transformer
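///
/// Builder-style setters keep `head_dim` consistent with `hidden_dim` and
/// `num_heads`. A usage sketch mirroring the unit tests (the `use` path is an
/// assumption):
///
/// ```rust,ignore
/// use kizzasi_model::transformer::TransformerConfig;
///
/// let config = TransformerConfig::new().hidden_dim(256).num_heads(8);
/// assert_eq!(config.head_dim, 32); // derived as hidden_dim / num_heads
/// assert!(config.validate().is_ok());
/// ```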
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TransformerConfig {
    /// Input dimension
    pub input_dim: usize,
    /// Hidden dimension (d_model)
    pub hidden_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Head dimension (derived: hidden_dim / num_heads)
    pub head_dim: usize,
    /// Feed-forward intermediate dimension (typically 4x hidden_dim)
    pub ff_dim: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Maximum context window
    pub max_seq_len: usize,
    /// Dropout rate
    pub dropout: f32,
    /// Use RMSNorm instead of LayerNorm
    pub use_rms_norm: bool,
    /// Use causal masking (autoregressive)
    pub causal: bool,
}

impl Default for TransformerConfig {
    fn default() -> Self {
        let hidden_dim = 512;
        let num_heads = 8;
        Self {
            input_dim: 1,
            hidden_dim,
            num_heads,
            head_dim: hidden_dim / num_heads,
            ff_dim: hidden_dim * 4,
            num_layers: 6,
            max_seq_len: 2048,
            dropout: 0.1,
            use_rms_norm: true,
            causal: true,
        }
    }
}

impl TransformerConfig {
    /// Create a new Transformer configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set input dimension
    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

    /// Set hidden dimension
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        self.head_dim = dim / self.num_heads;
        self
    }

    /// Set number of heads
    pub fn num_heads(mut self, n: usize) -> Self {
        self.num_heads = n;
        self.head_dim = self.hidden_dim / n;
        self
    }

    /// Set number of layers
    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }

    /// Set maximum sequence length
    pub fn max_seq_len(mut self, len: usize) -> Self {
        self.max_seq_len = len;
        self
    }

    /// Validate the configuration
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.max_seq_len == 0 {
            return Err(ModelError::invalid_config("max_seq_len must be > 0"));
        }
        Ok(())
    }
}

/// Multi-Head Self-Attention
///
/// Operates one position at a time: each call to `forward` projects the
/// incoming hidden vector to a query, key, and value, appends the key/value
/// pair to a rolling cache bounded by `max_cache_len`, and lets the query
/// attend over every cached position.
struct MultiHeadAttention {
    num_heads: usize,
    head_dim: usize,
    hidden_dim: usize,

    /// Query, Key, Value projections
    q_proj: Array2<f32>,
    k_proj: Array2<f32>,
    v_proj: Array2<f32>,

    /// Output projection
    o_proj: Array2<f32>,

    /// Cached keys and values for autoregressive generation
    key_cache: VecDeque<Array1<f32>>,
    value_cache: VecDeque<Array1<f32>>,
    max_cache_len: usize,
}

impl MultiHeadAttention {
    fn new(config: &TransformerConfig) -> ModelResult<Self> {
        let mut rng = rng();
        let scale = (2.0 / config.hidden_dim as f32).sqrt();

        let q_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let k_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let v_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let o_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            num_heads: config.num_heads,
            head_dim: config.head_dim,
            hidden_dim: config.hidden_dim,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            key_cache: VecDeque::new(),
            value_cache: VecDeque::new(),
            max_cache_len: config.max_seq_len,
        })
    }

    fn forward(&mut self, x: &Array1<f32>, _causal: bool) -> CoreResult<Array1<f32>> {
        // Working output length for this step (a single position, not a batch).
        let out_dim = x.len().min(self.hidden_dim);

        // Project to Q, K, V
        let q = self.project(x, &self.q_proj);
        let k = self.project(x, &self.k_proj);
        let v = self.project(x, &self.v_proj);

        // Add to cache
        self.key_cache.push_back(k.clone());
        self.value_cache.push_back(v.clone());

        // Maintain cache size
        while self.key_cache.len() > self.max_cache_len {
            self.key_cache.pop_front();
            self.value_cache.pop_front();
        }

        // Compute attention over cached context
        let seq_len = self.key_cache.len();
        let scale = (self.head_dim as f32).sqrt();

        let mut attention_output = Array1::zeros(out_dim);

        // For each head
        for h in 0..self.num_heads {
            let head_start = h * self.head_dim;

            // Compute attention scores with all cached positions.
            // Causal masking is implicit here: with single-step decoding the
            // rolling cache only ever holds the current and earlier positions,
            // so every cached entry may be attended to.
            let mut scores = Vec::with_capacity(seq_len);
            for pos in 0..seq_len {
                let k_cached = &self.key_cache[pos];
                let mut score = 0.0;

                // Q · K^T for this head
                for i in 0..self.head_dim {
                    let q_idx = head_start + i;
                    let k_idx = head_start + i;
                    if q_idx < q.len() && k_idx < k_cached.len() {
                        score += q[q_idx] * k_cached[k_idx];
                    }
                }
                score /= scale;
                scores.push(score);
            }

            // Softmax over positions
            let attention_weights = softmax(&Array1::from_vec(scores));

            // Weighted sum of values
            for i in 0..self.head_dim {
                let out_idx = head_start + i;
                if out_idx >= attention_output.len() {
                    break;
                }

                let mut weighted_value = 0.0;
                for (pos, &weight) in attention_weights.iter().enumerate() {
                    let v_cached = &self.value_cache[pos];
                    let v_idx = head_start + i;
                    if v_idx < v_cached.len() {
                        weighted_value += weight * v_cached[v_idx];
                    }
                }
                attention_output[out_idx] = weighted_value;
            }
        }

        // Output projection
        let output = self.project(&attention_output, &self.o_proj);
        Ok(output)
    }

    fn project(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
        let out_dim = weight.shape()[0];
        let mut output = Array1::zeros(out_dim.min(x.len()));
        for i in 0..output.len() {
            let mut sum = 0.0;
            for j in 0..x.len().min(weight.shape()[1]) {
                sum += weight[[i, j]] * x[j];
            }
            output[i] = sum;
        }
        output
    }

    fn reset(&mut self) {
        self.key_cache.clear();
        self.value_cache.clear();
    }
}

/// Feed-Forward Network
struct FeedForward {
    fc1: Array2<f32>,
    fc2: Array2<f32>,
}

impl FeedForward {
    fn new(config: &TransformerConfig) -> ModelResult<Self> {
        let mut rng = rng();
        let scale1 = (2.0 / config.hidden_dim as f32).sqrt();
        let scale2 = (2.0 / config.ff_dim as f32).sqrt();

        let fc1 = Array2::from_shape_fn((config.hidden_dim, config.ff_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale1
        });
        let fc2 = Array2::from_shape_fn((config.ff_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale2
        });

        Ok(Self { fc1, fc2 })
    }

    fn forward(&self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // First layer
        let mut hidden = Array1::zeros(self.fc1.shape()[1]);
        for i in 0..hidden.len() {
            let mut sum = 0.0;
            for j in 0..x.len().min(self.fc1.shape()[0]) {
                sum += self.fc1[[j, i]] * x[j];
            }
            hidden[i] = sum;
        }

        // Activation (GELU)
        hidden = gelu(&hidden);

        // Second layer
        let mut output = Array1::zeros(x.len().min(self.fc2.shape()[1]));
        for i in 0..output.len() {
            let mut sum = 0.0;
            for j in 0..hidden.len().min(self.fc2.shape()[0]) {
                sum += self.fc2[[j, i]] * hidden[j];
            }
            output[i] = sum;
        }

        Ok(output)
    }
}

/// Transformer Layer
struct TransformerLayer {
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: MultiHeadAttention,
    feed_forward: FeedForward,
    causal: bool,
}

impl TransformerLayer {
    fn new(config: &TransformerConfig) -> ModelResult<Self> {
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };

        let ln1 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let ln2 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let attention = MultiHeadAttention::new(config)?;
        let feed_forward = FeedForward::new(config)?;

        Ok(Self {
            ln1,
            ln2,
            attention,
            feed_forward,
            causal: config.causal,
        })
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Pre-norm: LayerNorm → Attention → Residual
        let x_norm = self.ln1.forward(x);
        let attn_out = self.attention.forward(&x_norm, self.causal)?;
        let mut x_attn = x.clone();
        for i in 0..x_attn.len().min(attn_out.len()) {
            x_attn[i] += attn_out[i];
        }

        // Pre-norm: LayerNorm → FFN → Residual
        let x_norm2 = self.ln2.forward(&x_attn);
        let ff_out = self.feed_forward.forward(&x_norm2)?;
        let mut output = x_attn;
        for i in 0..output.len().min(ff_out.len()) {
            output[i] += ff_out[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        self.attention.reset();
    }
}

/// Transformer model
pub struct Transformer {
    config: TransformerConfig,
    layers: Vec<TransformerLayer>,
    ln_out: LayerNorm,
    input_proj: Array2<f32>,
    output_proj: Array2<f32>,
}

impl Transformer {
    /// Create a new Transformer model
    pub fn new(config: TransformerConfig) -> ModelResult<Self> {
        config.validate()?;

        // Initialize layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(TransformerLayer::new(&config)?);
        }

        // Output layer normalization
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);

        // Initialize input/output projections
        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            ln_out,
            input_proj,
            output_proj,
        })
    }

    /// Get the configuration
    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }

    /// Load weights from a SafeTensors model file
    ///
    /// # Weight Naming Convention
    ///
    /// The following tensor names are expected:
    /// - `input_proj`: Input projection matrix (input_dim, hidden_dim)
    /// - `output_proj`: Output projection matrix (hidden_dim, input_dim)
    /// - `ln_out.weight`: Output layer norm weight (gamma)
    /// - `ln_out.bias`: Output layer norm bias (beta, optional)
    ///
    /// For each layer i:
    /// - `layers.{i}.ln1.weight`: Attention layer norm weight
    /// - `layers.{i}.ln1.bias`: Attention layer norm bias (optional)
    /// - `layers.{i}.ln2.weight`: Feed-forward layer norm weight
    /// - `layers.{i}.ln2.bias`: Feed-forward layer norm bias (optional)
    ///
    /// Multi-head attention parameters:
    /// - `layers.{i}.attention.q_proj`: Query projection
    /// - `layers.{i}.attention.k_proj`: Key projection
    /// - `layers.{i}.attention.v_proj`: Value projection
    /// - `layers.{i}.attention.o_proj`: Output projection
    ///
    /// Feed-forward parameters:
    /// - `layers.{i}.feed_forward.fc1`: First linear layer
    /// - `layers.{i}.feed_forward.fc2`: Second linear layer
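    ///
    /// # Example
    ///
    /// A sketch of loading pretrained weights. How a
    /// [`crate::loader::ModelLoader`] is constructed is assumed here
    /// (`ModelLoader::open` is hypothetical); adjust to the actual loader API:
    ///
    /// ```rust,ignore
    /// let config = TransformerConfig::new().hidden_dim(256).num_heads(8);
    /// let mut model = Transformer::new(config)?;
    ///
    /// // Hypothetical constructor; substitute the real ModelLoader entry point.
    /// let loader = ModelLoader::open("model.safetensors")?;
    /// model.load_weights(&loader)?;
    /// ```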
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        // Load input/output projections
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        // Load output layer norm
        if loader.has_tensor("ln_out.weight") {
            let weight = loader.load_array1("ln_out.weight")?;
            self.ln_out.set_gamma(weight);
        }
        if loader.has_tensor("ln_out.bias") {
            let bias = loader.load_array1("ln_out.bias")?;
            self.ln_out.set_beta(bias);
        }

        // Load each layer's weights
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            // Load layer norm 1 (attention)
            if loader.has_tensor(&format!("{}.ln1.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln1.weight", prefix))?;
                layer.ln1.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln1.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln1.bias", prefix))?;
                layer.ln1.set_beta(bias);
            }

            // Load layer norm 2 (feed-forward)
            if loader.has_tensor(&format!("{}.ln2.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln2.weight", prefix))?;
                layer.ln2.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln2.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln2.bias", prefix))?;
                layer.ln2.set_beta(bias);
            }

            // Load attention parameters
            let attn_prefix = format!("{}.attention", prefix);
            if loader.has_tensor(&format!("{}.q_proj", attn_prefix)) {
                layer.attention.q_proj = loader.load_array2(&format!("{}.q_proj", attn_prefix))?;
            }
            if loader.has_tensor(&format!("{}.k_proj", attn_prefix)) {
                layer.attention.k_proj = loader.load_array2(&format!("{}.k_proj", attn_prefix))?;
            }
            if loader.has_tensor(&format!("{}.v_proj", attn_prefix)) {
                layer.attention.v_proj = loader.load_array2(&format!("{}.v_proj", attn_prefix))?;
            }
            if loader.has_tensor(&format!("{}.o_proj", attn_prefix)) {
                layer.attention.o_proj = loader.load_array2(&format!("{}.o_proj", attn_prefix))?;
            }

            // Load feed-forward parameters
            let ff_prefix = format!("{}.feed_forward", prefix);
            if loader.has_tensor(&format!("{}.fc1", ff_prefix)) {
                layer.feed_forward.fc1 = loader.load_array2(&format!("{}.fc1", ff_prefix))?;
            }
            if loader.has_tensor(&format!("{}.fc2", ff_prefix)) {
                layer.feed_forward.fc2 = loader.load_array2(&format!("{}.fc2", ff_prefix))?;
            }
        }

        Ok(())
    }

    /// Save weights to a SafeTensors model file (stub for future implementation)
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        // TODO: Implement SafeTensors saving
        Err(ModelError::simple_load_error(
            "Transformer save_weights not yet implemented".to_string(),
        ))
    }
}

impl SignalPredictor for Transformer {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Project input to hidden dimension
        let mut hidden = input.dot(&self.input_proj);

        // Pass through each layer
        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        // Final layer normalization
        hidden = self.ln_out.forward(&hidden);

        // Project back to input dimension
        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        self.config.max_seq_len
    }
}

impl AutoregressiveModel for Transformer {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        // Per-position state size. Unlike SSMs, the Transformer's KV cache
        // grows with sequence length, so total state per layer is
        // seq_len * hidden_dim rather than a fixed hidden_dim.
        self.config.hidden_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Transformer
    }

    fn get_states(&self) -> Vec<HiddenState> {
        // Return KV cache state for each layer
        self.layers
            .iter()
            .map(|layer| {
                let cache_len = layer.attention.key_cache.len();
                let mut combined = Array2::zeros((cache_len.max(1), self.config.hidden_dim));

                // Store key cache (value cache could be stored similarly)
                for (i, k) in layer.attention.key_cache.iter().enumerate() {
                    for j in 0..k.len().min(self.config.hidden_dim) {
                        combined[[i, j]] = k[j];
                    }
                }

                let mut hs = HiddenState::new(combined.shape()[0], combined.shape()[1]);
                hs.update(combined);
                hs
            })
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Transformer",
                self.config.num_layers,
                states.len(),
            ));
        }

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            let combined = states[layer_idx].state();

            // Restore the key cache. Values are not serialized in this
            // simplified snapshot, so they are reset to zeros of matching
            // length: the attention loop indexes both caches in lockstep,
            // and leaving the value cache stale would desynchronize them.
            layer.attention.key_cache.clear();
            layer.attention.value_cache.clear();
            for i in 0..combined.shape()[0] {
                let mut k = Array1::zeros(self.config.hidden_dim);
                for j in 0..self.config.hidden_dim.min(combined.shape()[1]) {
                    k[j] = combined[[i, j]];
                }
                layer.attention.key_cache.push_back(k);
                layer
                    .attention
                    .value_cache
                    .push_back(Array1::zeros(self.config.hidden_dim));
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_config() {
        let config = TransformerConfig::new()
            .hidden_dim(256)
            .num_heads(8)
            .num_layers(4);

        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 32);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_transformer_creation() {
        let config = TransformerConfig::new().hidden_dim(128).num_heads(4);
        let model = Transformer::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_transformer_forward() {
        let config = TransformerConfig::new()
            .hidden_dim(64)
            .num_heads(4)
            .num_layers(2)
            .max_seq_len(128);
        let mut model = Transformer::new(config).expect("Failed to create Transformer");

        let input = Array1::from_vec(vec![0.5]);
        let output = model.step(&input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_invalid_heads() {
        let config = TransformerConfig::new().hidden_dim(100).num_heads(3); // Not divisible
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_context_window() {
        // Use smaller configuration for faster test
        // Default has hidden_dim=512, num_layers=6 which is slow to initialize
        let config = TransformerConfig::new()
            .hidden_dim(64)
            .num_heads(4)
            .num_layers(2)
            .max_seq_len(512);
        let model = Transformer::new(config).expect("Failed to create Transformer");
        assert_eq!(model.context_window(), 512);
    }
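
    // Illustrative sketch (not part of the original suite): snapshot and
    // restore per-layer state through `get_states` / `set_states`. Only the
    // key cache is carried across, mirroring the simplified state handling
    // above.
    #[test]
    fn test_state_roundtrip_sketch() {
        let config = TransformerConfig::new()
            .hidden_dim(64)
            .num_heads(4)
            .num_layers(2)
            .max_seq_len(128);
        let mut model = Transformer::new(config).expect("Failed to create Transformer");

        // Advance one step so each layer's KV cache is non-empty.
        let input = Array1::from_vec(vec![0.5]);
        model.step(&input).expect("step failed");

        // One HiddenState per layer, restorable back into the model.
        let states = model.get_states();
        assert_eq!(states.len(), 2);
        assert!(model.set_states(states).is_ok());
    }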
}