//! multiscreen_rs — model.rs

1use crate::{
2    error::{Error, Result},
3    runtime::DefaultAutodiffBackend,
4};
5use burn::{
6    grad_clipping::GradientClippingConfig,
7    module::{Module, Param},
8    optim::{AdamWConfig, GradientsParams, Optimizer},
9    record::{FullPrecisionSettings, NamedMpkFileRecorder},
10    tensor::{
11        Int, Tensor, TensorData, activation,
12        backend::{AutodiffBackend, Backend},
13    },
14};
15use serde::{Deserialize, Serialize};
16use std::f32::consts::PI;
17use std::path::Path;
18
// Numerical-stability epsilon. Not referenced in this chunk — presumably used
// by the normalization helpers (`row_unit_normalize`, `tanh_norm`) defined
// later in the file; confirm before removing.
const EPS: f32 = 1e-6;
20
/// Supported neural Multiscreen parameter budgets.
///
/// The final count is approximate because the token embedding scales with the
/// caller's vocabulary size. Use `MultiscreenModelConfig::estimated_parameter_count`
/// to see the exact count for a resolved config.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiscreenParameterBudget {
    /// ~1 million parameters.
    Params1M,
    /// ~5 million parameters.
    Params5M,
    /// ~10 million parameters.
    Params10M,
    /// ~50 million parameters.
    Params50M,
    /// ~100 million parameters.
    Params100M,
}
34
35impl MultiscreenParameterBudget {
36    pub const ALL: [Self; 5] = [
37        Self::Params1M,
38        Self::Params5M,
39        Self::Params10M,
40        Self::Params50M,
41        Self::Params100M,
42    ];
43
44    pub fn label(self) -> &'static str {
45        match self {
46            Self::Params1M => "1M",
47            Self::Params5M => "5M",
48            Self::Params10M => "10M",
49            Self::Params50M => "50M",
50            Self::Params100M => "100M",
51        }
52    }
53
54    pub fn target_parameter_count(self) -> usize {
55        match self {
56            Self::Params1M => 1_000_000,
57            Self::Params5M => 5_000_000,
58            Self::Params10M => 10_000_000,
59            Self::Params50M => 50_000_000,
60            Self::Params100M => 100_000_000,
61        }
62    }
63}
64
/// Paper-faithful neural Multiscreen model configuration.
///
/// Invariants are enforced by [`Self::validate`], which is called by
/// `MultiscreenModel::new`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiscreenModelConfig {
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Context length `T`.
    pub seq_len: usize,
    /// Paper `N_L`: number of residual Multiscreen layers.
    pub layers: usize,
    /// Paper `N_H`: number of gated screening tiles per layer.
    pub tiles: usize,
    /// Paper `d_E`: token embedding/model width.
    pub d_model: usize,
    /// Paper `d_K`: query/key width. Must be at least 2 for MiPE.
    pub d_key: usize,
    /// Paper `d_V`: value/gate width.
    pub d_value: usize,
    /// Paper MiPE threshold `w_th`. Must be positive and finite.
    pub w_th: f32,
}
85
86impl MultiscreenModelConfig {
87    pub fn tiny() -> Self {
88        Self {
89            vocab_size: 64,
90            seq_len: 64,
91            layers: 2,
92            tiles: 4,
93            d_model: 64,
94            d_key: 16,
95            d_value: 32,
96            w_th: 32.0,
97        }
98    }
99
100    pub fn tiny_for_tests() -> Self {
101        Self {
102            vocab_size: 32,
103            seq_len: 8,
104            layers: 1,
105            tiles: 2,
106            d_model: 16,
107            d_key: 4,
108            d_value: 8,
109            w_th: 8.0,
110        }
111    }
112
113    pub fn for_parameter_budget(
114        budget: MultiscreenParameterBudget,
115        vocab_size: usize,
116        seq_len: usize,
117    ) -> Self {
118        match budget {
119            MultiscreenParameterBudget::Params1M => Self::preset_1m(vocab_size, seq_len),
120            MultiscreenParameterBudget::Params5M => Self::preset_5m(vocab_size, seq_len),
121            MultiscreenParameterBudget::Params10M => Self::preset_10m(vocab_size, seq_len),
122            MultiscreenParameterBudget::Params50M => Self::preset_50m(vocab_size, seq_len),
123            MultiscreenParameterBudget::Params100M => Self::preset_100m(vocab_size, seq_len),
124        }
125    }
126
127    pub fn preset_1m(vocab_size: usize, seq_len: usize) -> Self {
128        Self::from_dimensions(vocab_size, seq_len, 2, 2, 128, 32, 64)
129    }
130
131    pub fn preset_5m(vocab_size: usize, seq_len: usize) -> Self {
132        Self::from_dimensions(vocab_size, seq_len, 2, 4, 384, 96, 192)
133    }
134
135    pub fn preset_10m(vocab_size: usize, seq_len: usize) -> Self {
136        Self::from_dimensions(vocab_size, seq_len, 3, 4, 512, 128, 256)
137    }
138
139    pub fn preset_50m(vocab_size: usize, seq_len: usize) -> Self {
140        Self::from_dimensions(vocab_size, seq_len, 6, 4, 960, 240, 480)
141    }
142
143    pub fn preset_100m(vocab_size: usize, seq_len: usize) -> Self {
144        Self::from_dimensions(vocab_size, seq_len, 8, 4, 1216, 304, 608)
145    }
146
147    pub fn paper_10m(vocab_size: usize, seq_len: usize) -> Self {
148        Self::preset_10m(vocab_size, seq_len)
149    }
150
151    pub fn estimated_parameter_count(&self) -> usize {
152        let embedding_params = self.vocab_size.saturating_mul(self.d_model);
153        let per_tile_params = self
154            .d_model
155            .saturating_mul(
156                2usize
157                    .saturating_mul(self.d_key)
158                    .saturating_add(3usize.saturating_mul(self.d_value)),
159            )
160            .saturating_add(3);
161        let tile_params = self
162            .layers
163            .saturating_mul(self.tiles)
164            .saturating_mul(per_tile_params);
165        embedding_params
166            .saturating_add(2)
167            .saturating_add(tile_params)
168    }
169
170    fn from_dimensions(
171        vocab_size: usize,
172        seq_len: usize,
173        layers: usize,
174        tiles: usize,
175        d_model: usize,
176        d_key: usize,
177        d_value: usize,
178    ) -> Self {
179        Self {
180            vocab_size,
181            seq_len,
182            layers,
183            tiles,
184            d_model,
185            d_key,
186            d_value,
187            w_th: 32.0,
188        }
189    }
190
191    pub fn validate(&self) -> Result<()> {
192        ensure(self.vocab_size > 0, "vocab_size must be greater than zero")?;
193        ensure(self.seq_len > 0, "seq_len must be greater than zero")?;
194        ensure(self.layers > 0, "layers must be greater than zero")?;
195        ensure(self.tiles > 0, "tiles must be greater than zero")?;
196        ensure(self.d_model > 0, "d_model must be greater than zero")?;
197        ensure(self.d_key >= 2, "d_key must be at least 2 for MiPE")?;
198        ensure(self.d_value > 0, "d_value must be greater than zero")?;
199        ensure(
200            self.w_th.is_finite() && self.w_th > 0.0,
201            "w_th must be positive and finite",
202        )?;
203        Ok(())
204    }
205}
206
/// Training options for token-sequence training.
#[derive(Clone, Debug)]
pub struct ModelTrainingConfig {
    /// Number of optimizer steps to run. `0` still computes a baseline loss.
    pub steps: usize,
    /// Training windows per batch. Must be greater than zero.
    pub batch_size: usize,
    /// AdamW learning rate.
    pub learning_rate: f64,
    /// AdamW weight decay (narrowed to f32 when the optimizer is built).
    pub weight_decay: f64,
    /// Global gradient-norm clipping threshold. `None` or a non-positive
    /// value disables clipping.
    pub grad_clip_norm: Option<f64>,
    /// Token ID used when padding windows (presumably masked out of the loss
    /// by `TrainingWindows` — confirm there).
    pub pad_token_id: u32,
    /// Directory to save checkpoints into. When `None`, no checkpoints are
    /// saved during training.
    pub checkpoint_dir: Option<String>,
    /// Save a checkpoint every N steps. Only used when `checkpoint_dir` is
    /// `Some`. A value of `0` disables periodic snapshots (only `best.mpk`
    /// is kept).
    pub checkpoint_interval: usize,
}
224
225impl Default for ModelTrainingConfig {
226    fn default() -> Self {
227        Self {
228            steps: 100,
229            batch_size: 4,
230            learning_rate: 2e-4,
231            weight_decay: 0.01,
232            grad_clip_norm: Some(1.0),
233            pad_token_id: 0,
234            checkpoint_dir: None,
235            checkpoint_interval: 0,
236        }
237    }
238}
239
/// Greedy generation options.
#[derive(Clone, Debug)]
pub struct ModelInferenceConfig {
    /// Upper bound on the number of tokens generated after the prompt.
    pub max_new_tokens: usize,
    /// Token ID used to pad the context window during the forward pass.
    pub pad_token_id: u32,
}
246
247impl Default for ModelInferenceConfig {
248    fn default() -> Self {
249        Self {
250            max_new_tokens: 16,
251            pad_token_id: 0,
252        }
253    }
254}
255
/// Summary returned by the training entry points.
#[derive(Clone, Debug, PartialEq)]
pub struct ModelTrainingReport {
    /// Number of optimizer steps that were requested (and run).
    pub steps: usize,
    /// Loss of the last optimizer step (or the baseline loss when steps == 0).
    pub final_loss: f32,
    /// The lowest loss observed across all training steps.
    pub best_loss: f32,
    /// The step at which `best_loss` was recorded.
    pub best_loss_step: usize,
    /// Number of training windows extracted from the input sequences.
    pub training_window_count: usize,
    /// Total learnable parameters of the trained model.
    pub parameter_count: usize,
}
267
/// Result of evaluating a model on held-out sequences.
#[derive(Clone, Debug)]
pub struct EvaluationResult {
    /// Average cross-entropy loss across all batches.
    /// NaN when the input produced no evaluation windows.
    pub loss: f32,
    /// Perplexity = exp(average_loss).
    pub perplexity: f32,
    /// Fraction of tokens where argmax(logits) == target (next-token accuracy).
    pub accuracy: f64,
    /// Number of batches evaluated.
    pub num_batches: usize,
    /// Total number of (unmasked) tokens evaluated.
    pub total_tokens: usize,
}
282
/// Generation output: the prompt tokens followed by the generated tokens.
pub struct MultiscreenModelOutput {
    pub token_ids: Vec<u32>,
}
286
/// Burn-backed neural Multiscreen language model ported from
/// `multiscreen-testing`.
#[derive(Module, Debug)]
pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> {
    // Plain configuration value, excluded from the Module parameter tree.
    #[module(skip)]
    config: MultiscreenModelConfig,
    // `[vocab_size, d_model]`; rows are unit-normalized at use time.
    token_embedding: Param<Tensor<B, 2>>,
    // Log-space input gain applied right after the embedding lookup.
    s_e: Param<Tensor<B, 1>>,
    // Log-space output gain applied to the final logits.
    s_f: Param<Tensor<B, 1>>,
    // `config.layers` residual Multiscreen layers.
    layers: Vec<MultiscreenLayer<B>>,
}

/// Convenience alias for the default Burn Flex autodiff model.
pub type DefaultMultiscreenModel = MultiscreenModel<DefaultAutodiffBackend>;
301
/// One residual layer: the outputs of its tiles are summed and added back to
/// the layer input (see `MultiscreenModel::forward`).
#[derive(Module, Debug)]
struct MultiscreenLayer<B: Backend> {
    tiles: Vec<GatedScreeningTile<B>>,
}
306
/// A single gated screening tile (paper's per-layer head analogue).
#[derive(Module, Debug)]
struct GatedScreeningTile<B: Backend> {
    // Query/key projections: `[d_model, d_key]`.
    w_q: Param<Tensor<B, 2>>,
    w_k: Param<Tensor<B, 2>>,
    // Value and gate projections: `[d_model, d_value]`.
    w_v: Param<Tensor<B, 2>>,
    w_g: Param<Tensor<B, 2>>,
    // Output projection back to the model width: `[d_value, d_model]`.
    w_o: Param<Tensor<B, 2>>,
    // Log-space window-size parameter (w = exp(s_w) + 1 in `forward`).
    s_w: Param<Tensor<B, 1>>,
    // Raw retention parameter (squashed with sigmoid in `forward`).
    s_r: Param<Tensor<B, 1>>,
    // Log-space output gain.
    s_o: Param<Tensor<B, 1>>,
    // MiPE threshold copied from the model config; not a learned parameter.
    #[module(skip)]
    w_th: f32,
}
320
321impl<B: Backend> MultiscreenModel<B> {
    /// Builds a freshly initialized model on `device`.
    ///
    /// Validates `config` first. Initialization is deterministic: all weights
    /// come from a PRNG started at a fixed seed.
    pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self> {
        config.validate()?;

        // Fixed seed -> reproducible parameter initialization.
        let mut seed = 0x4d55_4c54_4953_4352;
        let token_embedding = init_matrix(
            config.vocab_size,
            config.d_model,
            0.1 / (config.d_model as f32).sqrt(),
            &mut seed,
            device,
        );
        // s_e starts at 0 (input gain exp(0) = 1); s_f starts at
        // ln(sqrt(d_model)), i.e. an initial logit gain of sqrt(d_model).
        let s_e = init_scalar(0.0, device);
        let s_f = init_scalar((config.d_model as f32).sqrt().ln(), device);

        let mut layers = Vec::with_capacity(config.layers);
        for _layer_idx in 0..config.layers {
            let mut tiles = Vec::with_capacity(config.tiles);
            for tile_idx in 0..config.tiles {
                let w_q = init_matrix(
                    config.d_model,
                    config.d_key,
                    0.1 / (config.d_key as f32).sqrt(),
                    &mut seed,
                    device,
                );
                let w_k = init_matrix(
                    config.d_model,
                    config.d_key,
                    0.1 / (config.d_key as f32).sqrt(),
                    &mut seed,
                    device,
                );
                let w_v = init_matrix(
                    config.d_model,
                    config.d_value,
                    0.1 / (config.d_value as f32).sqrt(),
                    &mut seed,
                    device,
                );
                let w_g = init_matrix(config.d_model, config.d_value, 0.1, &mut seed, device);
                let w_o = init_matrix(
                    config.d_value,
                    config.d_model,
                    0.1 / (config.d_model as f32).sqrt(),
                    &mut seed,
                    device,
                );

                // Spread initial window sizes log-uniformly across a layer's
                // tiles: s_w ranges from 0 up to ln(w_th).
                let window_frac = if config.tiles == 1 {
                    0.0
                } else {
                    tile_idx as f32 / (config.tiles - 1) as f32
                };
                let s_w = init_scalar(window_frac * config.w_th.ln(), device);
                let s_r = init_scalar(0.0, device);
                // Output gain starts at 1/sqrt(total tile count) so the summed
                // residual updates begin at a moderate scale.
                let s_o = init_scalar(-0.5 * ((config.layers * config.tiles) as f32).ln(), device);

                tiles.push(GatedScreeningTile {
                    w_q,
                    w_k,
                    w_v,
                    w_g,
                    w_o,
                    s_w,
                    s_r,
                    s_o,
                    w_th: config.w_th,
                });
            }
            layers.push(MultiscreenLayer { tiles });
        }

        Ok(Self {
            config,
            token_embedding,
            s_e,
            s_f,
            layers,
        })
    }
402
    /// Returns the resolved configuration this model was built from.
    pub fn config(&self) -> &MultiscreenModelConfig {
        &self.config
    }

    /// Total number of learnable parameters (delegates to Burn's `Module`).
    pub fn parameter_count(&self) -> usize {
        self.num_params()
    }
410
    /// Full forward pass: token IDs `[batch, seq_len]` -> logits
    /// `[batch, seq_len, vocab_size]`.
    pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let [batch, seq_len] = tokens.dims();
        // Embedding rows are re-normalized to unit length on every call.
        let embedding = row_unit_normalize(self.token_embedding.val());
        // Embedding lookup implemented as one-hot x matrix product.
        let one_hot = tokens.one_hot::<3>(self.config().vocab_size).float();
        let mut x =
            linear(one_hot, embedding.clone()).reshape([batch, seq_len, self.config().d_model]);
        let x_dims = x.dims();
        // Learned global input gain (stored in log space).
        x = x * expand_scalar3(self.s_e.val().exp(), x_dims);

        // Residual stack: each layer adds the sum of its tiles' outputs.
        for layer in &self.layers {
            let mut layer_update = Tensor::<B, 3>::zeros(x.dims(), &x.device());
            for tile in &layer.tiles {
                layer_update = layer_update + tile.forward(x.clone());
            }
            x = x + layer_update;
        }

        // Weight-tied output head: project back through the transposed
        // (normalized) embedding, then apply the log-space output gain.
        let logits_weight = embedding.swap_dims(0, 1);
        let logits = linear(x, logits_weight);
        logits.clone() * expand_scalar3(self.s_f.val().exp(), logits.dims())
    }
432
    /// Saves all parameters to `path` as a full-precision named MessagePack
    /// record.
    ///
    /// `save_file` consumes the module, hence the `clone`.
    pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()> {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
        self.clone()
            .save_file(path.as_ref().to_path_buf(), &recorder)
            .map_err(|err| Error::Serialization(err.to_string()))
    }
439
    /// Loads parameters from `path`, replacing this model's weights in place.
    ///
    /// The record is loaded onto the first device the model currently
    /// reports; errors if the model reports no devices.
    pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()> {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
        let device =
            self.devices().into_iter().next().ok_or_else(|| {
                Error::Serialization("model has no device for parameter load".into())
            })?;
        let loaded = self
            .clone()
            .load_file(path.as_ref().to_path_buf(), &recorder, &device)
            .map_err(|err| Error::Serialization(err.to_string()))?;
        *self = loaded;
        Ok(())
    }
453
454    /// Greedy token generation.
455    /// Generate tokens one at a time, invoking a callback for each newly
456    /// produced token. This enables streaming / word-by-word output similar
457    /// to ChatGPT.
458    ///
459    /// The callback receives `(token_id, index)` where `index` is the
460    /// zero-based position of the *new* token (0 = first generated token).
461    /// If the callback returns `false`, generation stops early.
462    ///
463    /// Returns the full output (prompt + generated) token sequence.
464    pub fn infer_tokens_stream(
465        &self,
466        prompt: &[u32],
467        inference: &ModelInferenceConfig,
468        device: &B::Device,
469        mut on_token: impl FnMut(u32, usize) -> bool,
470    ) -> Result<MultiscreenModelOutput> {
471        if prompt.is_empty() {
472            return Err(Error::Inference(
473                "prompt must contain at least one token".to_string(),
474            ));
475        }
476
477        let mut output = prompt.to_vec();
478        for i in 0..inference.max_new_tokens {
479            let next = self.predict_next_token(&output, inference.pad_token_id, device)?;
480            output.push(next);
481            if !on_token(next, i) {
482                break;
483            }
484        }
485
486        Ok(MultiscreenModelOutput { token_ids: output })
487    }
488
    /// Generate tokens and return them all at once (non-streaming).
    ///
    /// Errors if `prompt` is empty (propagated from the streaming variant).
    /// For streaming / token-by-token output, use [`Self::infer_tokens_stream`].
    pub fn infer_tokens(
        &self,
        prompt: &[u32],
        inference: &ModelInferenceConfig,
        device: &B::Device,
    ) -> Result<MultiscreenModelOutput> {
        // Delegate with an always-continue callback.
        self.infer_tokens_stream(prompt, inference, device, |_, _| true)
    }
500
    /// Greedy next-token prediction: windows `context` to `seq_len`, runs a
    /// forward pass, and returns the argmax over the final position's logits.
    pub fn predict_next_token(
        &self,
        context: &[u32],
        pad_token_id: u32,
        device: &B::Device,
    ) -> Result<u32> {
        if context.is_empty() {
            return Err(Error::Inference(
                "context must contain at least one token".to_string(),
            ));
        }

        // Pad/truncate the context to exactly `seq_len` (see `context_window`).
        let input = context_window(context, self.config().seq_len, pad_token_id);
        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
        let logits = self.forward(input);
        // Keep only the last position's logits: `[1, 1, vocab]` -> `[vocab]`.
        let last_logits = logits
            .slice([
                0..1,
                self.config().seq_len - 1..self.config().seq_len,
                0..self.config().vocab_size,
            ])
            .reshape([self.config().vocab_size]);
        let values = tensor_to_vec(last_logits)?;
        argmax(&values).map(|idx| idx as u32)
    }
526
527    /// Run a forward pass and return the full logit tensor.
528    ///
529    /// The returned tensor has shape `[1, seq_len, vocab_size]`.
530    /// This is useful for sampling-based generation (top-k, temperature, etc.)
531    /// where you need access to the raw logit values, not just the argmax.
532    ///
533    /// The `context` is padded/truncated to `seq_len` automatically.
534    pub fn forward_logits(
535        &self,
536        context: &[u32],
537        pad_token_id: u32,
538        device: &B::Device,
539    ) -> Result<Tensor<B, 3>> {
540        if context.is_empty() {
541            return Err(Error::Inference(
542                "context must contain at least one token".to_string(),
543            ));
544        }
545
546        let input = context_window(context, self.config().seq_len, pad_token_id);
547        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
548        let logits = self.forward(input);
549        Ok(logits)
550    }
551
    /// Evaluates the model on token sequences, returning average loss,
    /// perplexity, and next-token prediction accuracy.
    ///
    /// This method works on any `Backend` (including non-autodiff), which makes
    /// it safe to call on an inference-only model without VRAM growth.
    pub fn evaluate_on_sequences(
        &self,
        sequences: &[Vec<u32>],
        seq_len: usize,
        batch_size: usize,
        pad_token_id: u32,
        device: &B::Device,
    ) -> Result<EvaluationResult> {
        let windows = TrainingWindows::from_sequences(sequences, seq_len, pad_token_id)?;
        if windows.is_empty() {
            // Nothing to evaluate: NaN loss/perplexity and zero counts.
            return Ok(EvaluationResult {
                loss: f32::NAN,
                perplexity: f32::NAN,
                accuracy: 0.0,
                num_batches: 0,
                total_tokens: 0,
            });
        }

        let num_batches = windows.len().div_ceil(batch_size);
        let mut total_loss = 0.0_f64;
        let mut total_correct = 0_usize;
        let mut total_tokens = 0_usize;

        for step in 0..num_batches {
            let batch = windows.batch::<B>(step, batch_size, device)?;
            let logits = self.forward(batch.inputs); // [B, S, V]
            let loss = cross_entropy_loss_with_mask(
                logits.clone(),
                batch.targets.clone(),
                batch.loss_mask.clone(),
            );
            let loss_val = tensor_scalar(loss)? as f64;
            total_loss += loss_val;

            // Compute accuracy: argmax(logits) == target, masked.
            // Everything is pulled back to host vectors and scanned on the CPU.
            let [b, s, v] = logits.dims();
            let mask_vec: Vec<f32> = batch
                .loss_mask
                .clone()
                .reshape([b * s])
                .into_data()
                .into_vec::<f32>()
                .map_err(|e| Error::Inference(e.to_string()))?;
            let target_vec: Vec<i32> = batch
                .targets
                .clone()
                .reshape([b * s])
                .into_data()
                .into_vec::<i32>()
                .map_err(|e| Error::Inference(e.to_string()))?;
            let logit_vec: Vec<f32> = logits
                .reshape([b * s * v])
                .into_data()
                .into_vec::<f32>()
                .map_err(|e| Error::Inference(e.to_string()))?;

            for bi in 0..b {
                for si in 0..s {
                    let mi = bi * s + si;
                    // Positions with mask < 0.5 do not count toward accuracy.
                    if mask_vec[mi] < 0.5 {
                        continue;
                    }
                    total_tokens += 1;
                    // Manual argmax over the vocab slice for this position.
                    let base = bi * s * v + si * v;
                    let mut best_idx = 0;
                    let mut best_val = f32::NEG_INFINITY;
                    for vi in 0..v {
                        let val = logit_vec[base + vi];
                        if val > best_val {
                            best_val = val;
                            best_idx = vi;
                        }
                    }
                    if best_idx == target_vec[mi] as usize {
                        total_correct += 1;
                    }
                }
            }
        }

        // NOTE(review): batches are weighted equally here, so a smaller final
        // batch counts the same as a full one — confirm this is intended.
        let avg_loss = total_loss / num_batches as f64;
        let perplexity = avg_loss.exp();
        let accuracy = if total_tokens > 0 {
            total_correct as f64 / total_tokens as f64
        } else {
            0.0
        };

        Ok(EvaluationResult {
            loss: avg_loss as f32,
            perplexity: perplexity as f32,
            accuracy,
            num_batches,
            total_tokens,
        })
    }
654}
655
656impl<B> MultiscreenModel<B>
657where
658    B: AutodiffBackend,
659{
    /// Trains this model directly on token sequences.
    ///
    /// The optional `on_step` callback is invoked after each optimizer step with
    /// `(step_index, loss_value)`. Use it for progress logging, CSV export, etc.
    ///
    /// When `training.checkpoint_dir` is set, `best.mpk` tracks the lowest
    /// loss seen and periodic `step_NNNNNN.mpk` snapshots are written every
    /// `checkpoint_interval` steps.
    pub fn train_token_sequences(
        &mut self,
        sequences: &[Vec<u32>],
        training: &ModelTrainingConfig,
        device: &B::Device,
        mut on_step: impl FnMut(usize, f32),
    ) -> Result<ModelTrainingReport> {
        if training.batch_size == 0 {
            return Err(Error::Training(
                "batch_size must be greater than zero".to_string(),
            ));
        }
        let windows = TrainingWindows::from_sequences(
            sequences,
            self.config().seq_len,
            training.pad_token_id,
        )?;
        if windows.is_empty() {
            return Err(Error::Training(
                "training requires at least one sequence with two or more tokens".to_string(),
            ));
        }

        // AdamW with optional global gradient-norm clipping.
        let mut optimizer_config =
            AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
        if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
            optimizer_config = optimizer_config
                .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
        }
        let mut optimizer = optimizer_config.init::<B, Self>();
        // Train on a clone; `self` is only replaced once training completes.
        let mut model = self.clone();
        let mut final_loss = f32::NAN;
        let mut best_loss = f32::MAX;
        let mut best_loss_step: usize = 0;

        let ckpt_dir = training.checkpoint_dir.as_deref().map(Path::new);
        if let Some(dir) = &ckpt_dir {
            std::fs::create_dir_all(dir).map_err(|e| {
                Error::Io(format!(
                    "failed to create checkpoint directory {:?}: {}",
                    dir, e
                ))
            })?;
        }

        for step in 0..training.steps {
            let batch = windows.batch::<B>(step, training.batch_size, device)?;
            let logits = model.forward(batch.inputs);
            let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
            final_loss = tensor_scalar(loss.clone())?;
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optimizer.step(training.learning_rate, model, grads);

            // --- Periodic + best checkpoint saving ---
            // `best.mpk` is rewritten whenever a new lowest loss is observed.
            if final_loss < best_loss {
                best_loss = final_loss;
                best_loss_step = step;
                if let Some(dir) = &ckpt_dir {
                    let path = dir.join("best.mpk");
                    model.save_parameters(&path)?;
                }
            }
            // Periodic snapshots use 1-based step numbers in the filename.
            if training.checkpoint_interval > 0
                && (step + 1) % training.checkpoint_interval == 0
                && let Some(dir) = &ckpt_dir
            {
                let path = dir.join(format!("step_{:06}.mpk", step + 1));
                model.save_parameters(&path)?;
            }

            on_step(step, final_loss);
        }

        // With zero steps, still compute the (untrained) loss on one batch so
        // the report carries meaningful numbers; no checkpoints are written.
        if training.steps == 0 {
            let batch = windows.batch::<B>(0, training.batch_size, device)?;
            final_loss = tensor_scalar(cross_entropy_loss_with_mask(
                model.forward(batch.inputs),
                batch.targets,
                batch.loss_mask,
            ))?;
            best_loss = final_loss;
            best_loss_step = 0;
        }

        *self = model;

        Ok(ModelTrainingReport {
            steps: training.steps,
            final_loss,
            best_loss,
            best_loss_step,
            training_window_count: windows.len(),
            parameter_count: self.parameter_count(),
        })
    }
760
    /// Trains this model on chat-style (prompt, response) token-ID pairs.
    ///
    /// This is the chat-aware counterpart of [`MultiscreenModel::train_token_sequences`]. The model
    /// sees the full context (prompt + response) but loss is computed **only** on
    /// the response tokens, preventing the model from learning to generate role
    /// labels like `system:`, `user:`, or `assistant:`.
    ///
    /// Each element of `chat_pairs` is `(prompt_token_ids, response_token_ids)`.
    /// The caller is responsible for appending an EOS token to the response IDs
    /// when desired — the EOS token will receive `loss_mask = 1.0` like any other
    /// response token.
    ///
    /// NOTE(review): the training loop below mirrors `train_token_sequences`
    /// step for step; only the window construction differs. A shared private
    /// helper would remove the duplication.
    pub fn train_chat_sequences(
        &mut self,
        chat_pairs: &[(Vec<u32>, Vec<u32>)],
        training: &ModelTrainingConfig,
        device: &B::Device,
        mut on_step: impl FnMut(usize, f32),
    ) -> Result<ModelTrainingReport> {
        if training.batch_size == 0 {
            return Err(Error::Training(
                "batch_size must be greater than zero".to_string(),
            ));
        }
        let windows = TrainingWindows::from_chat_sequences(
            chat_pairs,
            self.config().seq_len,
            training.pad_token_id,
        )?;
        if windows.is_empty() {
            return Err(Error::Training(
                "training requires at least one chat pair that produces two or more tokens"
                    .to_string(),
            ));
        }

        // AdamW with optional global gradient-norm clipping.
        let mut optimizer_config =
            AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
        if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
            optimizer_config = optimizer_config
                .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
        }
        let mut optimizer = optimizer_config.init::<B, Self>();
        // Train on a clone; `self` is only replaced once training completes.
        let mut model = self.clone();
        let mut final_loss = f32::NAN;
        let mut best_loss = f32::MAX;
        let mut best_loss_step: usize = 0;

        let ckpt_dir = training.checkpoint_dir.as_deref().map(Path::new);
        if let Some(dir) = &ckpt_dir {
            std::fs::create_dir_all(dir).map_err(|e| {
                Error::Io(format!(
                    "failed to create checkpoint directory {:?}: {}",
                    dir, e
                ))
            })?;
        }

        for step in 0..training.steps {
            let batch = windows.batch::<B>(step, training.batch_size, device)?;
            let logits = model.forward(batch.inputs);
            let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
            final_loss = tensor_scalar(loss.clone())?;
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optimizer.step(training.learning_rate, model, grads);

            // --- Periodic + best checkpoint saving ---
            // `best.mpk` is rewritten whenever a new lowest loss is observed.
            if final_loss < best_loss {
                best_loss = final_loss;
                best_loss_step = step;
                if let Some(dir) = &ckpt_dir {
                    let path = dir.join("best.mpk");
                    model.save_parameters(&path)?;
                }
            }
            // Periodic snapshots use 1-based step numbers in the filename.
            if training.checkpoint_interval > 0
                && (step + 1) % training.checkpoint_interval == 0
                && let Some(dir) = &ckpt_dir
            {
                let path = dir.join(format!("step_{:06}.mpk", step + 1));
                model.save_parameters(&path)?;
            }

            on_step(step, final_loss);
        }

        // With zero steps, still compute the (untrained) loss on one batch so
        // the report carries meaningful numbers; no checkpoints are written.
        if training.steps == 0 {
            let batch = windows.batch::<B>(0, training.batch_size, device)?;
            final_loss = tensor_scalar(cross_entropy_loss_with_mask(
                model.forward(batch.inputs),
                batch.targets,
                batch.loss_mask,
            ))?;
            best_loss = final_loss;
            best_loss_step = 0;
        }

        *self = model;

        Ok(ModelTrainingReport {
            steps: training.steps,
            final_loss,
            best_loss,
            best_loss_step,
            training_window_count: windows.len(),
            parameter_count: self.parameter_count(),
        })
    }
869}
870
/// Adapts the inherent `forward` to the crate-wide `LanguageModel` trait so
/// the model can be driven through the generic LM interface.
impl<B: Backend> crate::lm::LanguageModel<B> for MultiscreenModel<B> {
    fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        // Fully-qualified call selects the inherent method, not this trait
        // method (which would recurse).
        MultiscreenModel::forward(self, tokens)
    }
}
876
// Marker impl: the empty body means every item of `TrainableLanguageModel`
// is supplied by trait defaults.
impl<B: Backend> crate::lm::TrainableLanguageModel<B> for MultiscreenModel<B> {}
878
879impl<B: Backend> GatedScreeningTile<B> {
880    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
881        let q = row_unit_normalize(linear(x.clone(), self.w_q.val()));
882        let k = row_unit_normalize(linear(x.clone(), self.w_k.val()));
883        let v = row_unit_normalize(linear(x.clone(), self.w_v.val()));
884        let g = linear(x, self.w_g.val());
885
886        let w = self.s_w.val().clamp(-10.0, 8.0).exp() + 1.0;
887        let r = activation::sigmoid(self.s_r.val().clamp(-10.0, 8.0));
888
889        let q = apply_mipe(q, w.clone(), self.w_th);
890        let k = apply_mipe(k, w.clone(), self.w_th);
891
892        let similarity = q.matmul(k.swap_dims(1, 2));
893        let alpha = trim_and_square_tensor(similarity.clone(), r);
894        let softmask = causal_softmask_tensor::<B>(similarity.dims()[1], w, &similarity.device());
895        let relevance = alpha * softmask.unsqueeze();
896        let h = relevance.matmul(v);
897        let u = tanh_norm(h);
898        let gate = activation::silu(g).tanh();
899        let gated = u * gate;
900        let out = linear(gated, self.w_o.val());
901        out.clone() * expand_scalar3(self.s_o.val().exp(), out.dims())
902    }
903}
904
905#[allow(dead_code)]
906pub fn cross_entropy_loss<B: Backend>(
907    logits: Tensor<B, 3>,
908    targets: Tensor<B, 2, Int>,
909) -> Tensor<B, 1> {
910    let device = logits.device();
911    let [batch, seq_len, _] = logits.dims();
912    let loss_mask = Tensor::<B, 2>::ones([batch, seq_len], &device);
913    cross_entropy_loss_with_mask(logits, targets, loss_mask)
914}
915
916pub fn cross_entropy_loss_with_mask<B: Backend>(
917    logits: Tensor<B, 3>,
918    targets: Tensor<B, 2, Int>,
919    loss_mask: Tensor<B, 2>,
920) -> Tensor<B, 1> {
921    let [batch, seq_len, vocab_size] = logits.dims();
922    let token_count = batch * seq_len;
923    let flat_logits = logits.reshape([token_count, vocab_size]);
924    let flat_targets = targets.reshape([token_count]);
925    let flat_mask = loss_mask.reshape([token_count]);
926    let log_probs = activation::log_softmax(flat_logits, 1);
927    let target_probs = flat_targets.one_hot::<2>(vocab_size).float();
928    let picked = (log_probs * target_probs).sum_dim(1).reshape([token_count]);
929    let masked_nll = (picked.neg() * flat_mask.clone()).sum();
930    // Guard against all-zero masks: return a safe "no loss" scalar instead
931    // of near-zero due to EPS / EPS. This prevents the optimizer from
932    // seeing a bogus 0-loss that corrupts best-loss tracking.
933    let mask_sum = flat_mask.sum();
934    let denom = mask_sum.add_scalar(EPS);
935    masked_nll / denom
936}
937
/// Normalizes each row (last axis) of `x` to unit L2 length.
pub fn row_unit_normalize<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    // EPS guards against division by zero for all-zero rows.
    let denom = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
    x / denom
}
942
943fn trim_and_square_tensor<B: Backend>(similarity: Tensor<B, 3>, r: Tensor<B, 1>) -> Tensor<B, 3> {
944    let distance_from_one = similarity.mul_scalar(-1.0).add_scalar(1.0);
945    let scaled = distance_from_one.clone() / expand_scalar3(r, distance_from_one.clone().dims());
946    scaled
947        .mul_scalar(-1.0)
948        .add_scalar(1.0)
949        .clamp(0.0, 1.0)
950        .square()
951}
952
953fn causal_softmask_tensor<B: Backend>(
954    seq_len: usize,
955    w: Tensor<B, 1>,
956    device: &B::Device,
957) -> Tensor<B, 2> {
958    let mut distances = Vec::with_capacity(seq_len * seq_len);
959    for i in 0..seq_len {
960        for j in 0..seq_len {
961            distances.push(j as f32 - i as f32);
962        }
963    }
964
965    let dist = Tensor::<B, 2>::from_data(TensorData::new(distances, [seq_len, seq_len]), device);
966    let w = expand_scalar2(w, [seq_len, seq_len]);
967    let causal = dist.clone().lower_equal_elem(0.0);
968    let within_window = dist.clone().greater(w.clone().neg());
969    let active = causal.float() * within_window.float();
970    let tapered = (dist / w)
971        .mul_scalar(PI)
972        .cos()
973        .mul_scalar(0.5)
974        .add_scalar(0.5);
975    active * tapered
976}
977
/// Soft length normalization: rescales each row (last axis) so its L2 norm
/// becomes `tanh(norm)` — near-identity for small rows, saturating toward 1
/// for large ones.
pub fn tanh_norm<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    // EPS keeps the division finite for all-zero rows.
    let norm = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
    let scale = norm.clone().tanh() / norm;
    x * scale
}
983
/// Applies MiPE — a rotary-style positional rotation — to the first two
/// feature channels of `z`, leaving any remaining channels untouched.
///
/// The rotation angle at position `p` is `phi = PI * p * gamma / w`, where
/// `gamma` is a raised-cosine taper of the window size `w` that is clamped to
/// exactly 0 once `w >= w_th` (wide windows receive no rotation).
fn apply_mipe<B: Backend>(z: Tensor<B, 3>, w: Tensor<B, 1>, w_th: f32) -> Tensor<B, 3> {
    let [_batch, seq_len, d_key] = z.dims();
    // The 2-D rotation below needs at least two feature channels.
    debug_assert!(d_key >= 2);

    // Position indices 0..seq_len as a [1, seq_len, 1] tensor, broadcastable
    // over batch and feature dims.
    let positions = Tensor::<B, 3>::from_data(
        TensorData::new(
            (0..seq_len).map(|idx| idx as f32).collect::<Vec<_>>(),
            [1, seq_len, 1],
        ),
        &z.device(),
    );

    // gamma = 0.5 * (1 + cos(PI * w / w_th)): equals 1 at w = 0 and 0 at w = w_th.
    let gamma_raw = w
        .clone()
        .mul_scalar(PI / w_th)
        .cos()
        .mul_scalar(0.5)
        .add_scalar(0.5);
    // Beyond the threshold the cosine would oscillate back up, so force 0.
    let gamma = gamma_raw.mask_fill(w.clone().greater_equal_elem(w_th), 0.0);
    let gamma = expand_scalar3(gamma, [1, seq_len, 1]);
    let w = expand_scalar3(w, [1, seq_len, 1]);
    let phi = (positions * gamma / w).mul_scalar(PI);
    let cos_phi = phi.clone().cos();
    let sin_phi = phi.sin();

    // Standard 2-D rotation of channels 0 and 1 by angle phi.
    let x0 = z.clone().narrow(2, 0, 1);
    let x1 = z.clone().narrow(2, 1, 1);
    let rot0 = x0.clone() * cos_phi.clone() - x1.clone() * sin_phi.clone();
    let rot1 = x0 * sin_phi + x1 * cos_phi;

    if d_key == 2 {
        Tensor::cat(vec![rot0, rot1], 2)
    } else {
        // Channels at index >= 2 pass through unrotated.
        let rest = z.narrow(2, 2, d_key - 2);
        Tensor::cat(vec![rot0, rot1, rest], 2)
    }
}
1021
1022fn linear<B: Backend>(x: Tensor<B, 3>, weight: Tensor<B, 2>) -> Tensor<B, 3> {
1023    let [batch, _seq_len, _in_dim] = x.dims();
1024    let [weight_in, out_dim] = weight.dims();
1025    x.matmul(weight.unsqueeze::<3>().expand([batch, weight_in, out_dim]))
1026}
1027
/// Broadcasts a 1-element tensor to the given 2-D shape.
fn expand_scalar2<B: Backend>(value: Tensor<B, 1>, dims: [usize; 2]) -> Tensor<B, 2> {
    value.unsqueeze::<2>().expand(dims)
}
1031
/// Broadcasts a 1-element tensor to the given 3-D shape.
fn expand_scalar3<B: Backend>(value: Tensor<B, 1>, dims: [usize; 3]) -> Tensor<B, 3> {
    value.unsqueeze::<3>().expand(dims)
}
1035
/// Wraps a single f32 value as a learnable one-element parameter tensor.
fn init_scalar<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
    Param::from_tensor(Tensor::<B, 1>::from_data([value], device))
}
1039
1040fn init_matrix<B: Backend>(
1041    rows: usize,
1042    cols: usize,
1043    std: f32,
1044    seed: &mut u64,
1045    device: &B::Device,
1046) -> Param<Tensor<B, 2>> {
1047    let values = gaussian_values(rows * cols, std, seed);
1048    Param::from_tensor(Tensor::<B, 2>::from_data(
1049        TensorData::new(values, [rows, cols]),
1050        device,
1051    ))
1052}
1053
1054fn gaussian_values(len: usize, std: f32, seed: &mut u64) -> Vec<f32> {
1055    let mut values = Vec::with_capacity(len);
1056    while values.len() < len {
1057        let u1 = next_uniform(seed).max(1e-7);
1058        let u2 = next_uniform(seed);
1059        let radius = (-2.0 * u1.ln()).sqrt();
1060        let theta = 2.0 * PI * u2;
1061        values.push(radius * theta.cos() * std);
1062        if values.len() < len {
1063            values.push(radius * theta.sin() * std);
1064        }
1065    }
1066    values
1067}
1068
/// Advances a 64-bit linear congruential generator in place and maps its top
/// 24 bits to a uniform f32 strictly inside the open interval (0, 1).
fn next_uniform(seed: &mut u64) -> f32 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    *seed = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    // High bits have the best statistical quality in an LCG; the +1 / +2
    // offsets keep the result strictly away from both 0 and 1.
    let top_bits = (*seed >> 40) as u32 as f32;
    (top_bits + 1.0) / ((1u32 << 24) as f32 + 2.0)
}
1076
/// Right-aligns the most recent `seq_len` tokens of `context` into a
/// fixed-length window, left-padding with `pad_token_id` when the context is
/// shorter than `seq_len`.
fn context_window(context: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    let keep = context.len().min(seq_len);
    let tail = &context[context.len() - keep..];
    let mut window = vec![pad_token_id; seq_len - keep];
    window.extend_from_slice(tail);
    window
}
1088
1089fn argmax(values: &[f32]) -> Result<usize> {
1090    values
1091        .iter()
1092        .enumerate()
1093        .max_by(|left, right| left.1.total_cmp(right.1))
1094        .map(|(index, _)| index)
1095        .ok_or_else(|| Error::Inference("cannot argmax empty logits".to_string()))
1096}
1097
1098fn tensor_scalar<B: Backend>(tensor: Tensor<B, 1>) -> Result<f32> {
1099    tensor
1100        .into_data()
1101        .into_vec::<f32>()
1102        .map_err(|err| Error::Training(err.to_string()))?
1103        .into_iter()
1104        .next()
1105        .ok_or_else(|| Error::Training("expected scalar tensor".to_string()))
1106}
1107
1108fn tensor_to_vec<B: Backend>(tensor: Tensor<B, 1>) -> Result<Vec<f32>> {
1109    tensor
1110        .into_data()
1111        .into_vec::<f32>()
1112        .map_err(|err| Error::Inference(err.to_string()))
1113}
1114
1115fn tensor_from_u32<B: Backend, const D: usize>(
1116    values: Vec<u32>,
1117    shape: [usize; D],
1118    device: &B::Device,
1119) -> Result<Tensor<B, D, Int>> {
1120    let values = values
1121        .into_iter()
1122        .map(|value| {
1123            i32::try_from(value)
1124                .map_err(|_| Error::Config(format!("token id {value} exceeds i32::MAX")))
1125        })
1126        .collect::<Result<Vec<_>>>()?;
1127    Ok(Tensor::<B, D, Int>::from_data(
1128        TensorData::new(values, shape),
1129        device,
1130    ))
1131}
1132
1133fn ensure(condition: bool, message: &str) -> Result<()> {
1134    if condition {
1135        Ok(())
1136    } else {
1137        Err(Error::Config(message.to_string()))
1138    }
1139}
1140
/// One fixed-length next-token training example.
struct TrainingWindow {
    /// Input token ids, left-aligned and right-padded to `seq_len`.
    inputs: Vec<u32>,
    /// Target token ids (inputs shifted left by one), padded to `seq_len`.
    targets: Vec<u32>,
    /// Per-position loss weights: 1.0 contributes to the loss, 0.0 is
    /// context/padding only.
    loss_mask: Vec<f32>,
}
1146
/// A flat collection of training windows that all share one `seq_len`.
struct TrainingWindows {
    /// Windows from every source sequence, in generation order.
    windows: Vec<TrainingWindow>,
    /// Length every window's inputs/targets/loss_mask are padded to.
    seq_len: usize,
}
1151
1152impl TrainingWindows {
1153    /// Creates training windows from chat-style (prompt, response) pairs with loss masking.
1154    ///
1155    /// Tokens belonging to the prompt portion have `loss_mask = 0.0` so the model
1156    /// sees them as context but does not learn to generate them. Response tokens
1157    /// (including any EOS appended by the caller) have `loss_mask = 1.0`.
1158    fn from_chat_sequences(
1159        chat_pairs: &[(Vec<u32>, Vec<u32>)],
1160        seq_len: usize,
1161        pad_token_id: u32,
1162    ) -> Result<Self> {
1163        let mut windows = Vec::new();
1164        for (prompt_ids, response_ids) in chat_pairs {
1165            let mut full_seq = prompt_ids.clone();
1166            full_seq.extend_from_slice(response_ids);
1167
1168            if full_seq.len() < 2 {
1169                continue;
1170            }
1171
1172            let prompt_len = prompt_ids.len();
1173
1174            let mut start = 0;
1175            while start + 1 < full_seq.len() {
1176                let end = (start + seq_len + 1).min(full_seq.len());
1177                let chunk = &full_seq[start..end];
1178                let prediction_count = chunk.len() - 1;
1179
1180                let mut inputs = vec![pad_token_id; seq_len];
1181                let mut targets = vec![pad_token_id; seq_len];
1182                let mut loss_mask = vec![0.0; seq_len];
1183                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
1184                targets[..prediction_count].copy_from_slice(&chunk[1..]);
1185
1186                // Mask: only compute loss when the *target* token falls within the
1187                // response portion of the full sequence.
1188                let mut has_unmasked = false;
1189                for (i, mask) in loss_mask.iter_mut().enumerate().take(prediction_count) {
1190                    let target_global_idx = start + i + 1;
1191                    if target_global_idx >= prompt_len {
1192                        *mask = 1.0;
1193                        has_unmasked = true;
1194                    }
1195                }
1196
1197                // Skip windows that have NO response tokens at all (pure prompt).
1198                // These produce loss ≈ 0 due to EPS normalization, which corrupts
1199                // best-loss tracking and wastes compute.
1200                if !has_unmasked {
1201                    if end == full_seq.len() {
1202                        break;
1203                    }
1204                    start += seq_len;
1205                    continue;
1206                }
1207
1208                windows.push(TrainingWindow {
1209                    inputs,
1210                    targets,
1211                    loss_mask,
1212                });
1213
1214                if end == full_seq.len() {
1215                    break;
1216                }
1217                start += seq_len;
1218            }
1219        }
1220
1221        Ok(Self { windows, seq_len })
1222    }
1223
1224    fn from_sequences(sequences: &[Vec<u32>], seq_len: usize, pad_token_id: u32) -> Result<Self> {
1225        let mut windows = Vec::new();
1226        for sequence in sequences {
1227            if sequence.len() < 2 {
1228                continue;
1229            }
1230
1231            let mut start = 0;
1232            while start + 1 < sequence.len() {
1233                let end = (start + seq_len + 1).min(sequence.len());
1234                let chunk = &sequence[start..end];
1235                let prediction_count = chunk.len() - 1;
1236
1237                let mut inputs = vec![pad_token_id; seq_len];
1238                let mut targets = vec![pad_token_id; seq_len];
1239                let mut loss_mask = vec![0.0; seq_len];
1240                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
1241                targets[..prediction_count].copy_from_slice(&chunk[1..]);
1242                loss_mask[..prediction_count].fill(1.0);
1243
1244                windows.push(TrainingWindow {
1245                    inputs,
1246                    targets,
1247                    loss_mask,
1248                });
1249
1250                if end == sequence.len() {
1251                    break;
1252                }
1253                start += seq_len;
1254            }
1255        }
1256
1257        Ok(Self { windows, seq_len })
1258    }
1259
1260    fn is_empty(&self) -> bool {
1261        self.windows.is_empty()
1262    }
1263
1264    fn len(&self) -> usize {
1265        self.windows.len()
1266    }
1267
1268    fn batch<B: Backend>(
1269        &self,
1270        step: usize,
1271        batch_size: usize,
1272        device: &B::Device,
1273    ) -> Result<TokenBatch<B>> {
1274        let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
1275        let mut targets = Vec::with_capacity(batch_size * self.seq_len);
1276        let mut loss_mask = Vec::with_capacity(batch_size * self.seq_len);
1277
1278        for batch_idx in 0..batch_size {
1279            let index = (step * batch_size + batch_idx) % self.windows.len();
1280            let window = &self.windows[index];
1281            inputs.extend_from_slice(&window.inputs);
1282            targets.extend_from_slice(&window.targets);
1283            loss_mask.extend_from_slice(&window.loss_mask);
1284        }
1285
1286        Ok(TokenBatch {
1287            inputs: tensor_from_u32(inputs, [batch_size, self.seq_len], device)?,
1288            targets: tensor_from_u32(targets, [batch_size, self.seq_len], device)?,
1289            loss_mask: Tensor::<B, 2>::from_data(
1290                TensorData::new(loss_mask, [batch_size, self.seq_len]),
1291                device,
1292            ),
1293        })
1294    }
1295}
1296
/// A device-resident mini-batch of `[batch_size, seq_len]` tensors.
struct TokenBatch<B: Backend> {
    /// Input token ids.
    inputs: Tensor<B, 2, Int>,
    /// Target token ids, aligned position-wise with `inputs`.
    targets: Tensor<B, 2, Int>,
    /// Float loss weights aligned with `targets`.
    loss_mask: Tensor<B, 2>,
}
1302
#[cfg(test)]
#[allow(dead_code)]
/// Deterministic synthetic batch for tests: each row is a wrap-around ramp
/// over the vocabulary whose start offset depends on (step, batch index), and
/// targets are the inputs shifted by one.
pub fn make_batch<B: Backend>(
    step: usize,
    batch_size: usize,
    seq_len: usize,
    vocab_size: usize,
    device: &B::Device,
) -> Result<(Tensor<B, 2, Int>, Tensor<B, 2, Int>)> {
    let capacity = batch_size * seq_len;
    let mut inputs = Vec::with_capacity(capacity);
    let mut targets = Vec::with_capacity(capacity);

    for batch in 0..batch_size {
        let offset = (step * 7 + batch * 13) % vocab_size;
        for pos in 0..seq_len {
            inputs.push(((offset + pos) % vocab_size) as u32);
            targets.push(((offset + pos + 1) % vocab_size) as u32);
        }
    }

    let input_tensor = tensor_from_u32(inputs, [batch_size, seq_len], device)?;
    let target_tensor = tensor_from_u32(targets, [batch_size, seq_len], device)?;
    Ok((input_tensor, target_tensor))
}