Skip to main content

multiscreen_rs/
model.rs

1use crate::{
2    error::{Error, Result},
3    runtime::DefaultAutodiffBackend,
4};
5use burn::{
6    grad_clipping::GradientClippingConfig,
7    module::{Module, Param},
8    optim::{AdamWConfig, GradientsParams, Optimizer},
9    record::{FullPrecisionSettings, NamedMpkFileRecorder},
10    tensor::{
11        activation,
12        backend::{AutodiffBackend, Backend},
13        Int, Tensor, TensorData,
14    },
15};
16use serde::{Deserialize, Serialize};
17use std::f32::consts::PI;
18use std::path::Path;
19
/// Small positive constant added to denominators and under square roots so that
/// normalisation and masked averaging never divide by exactly zero.
const EPS: f32 = 1.0e-6;
21
22/// Supported neural Multiscreen parameter budgets.
23///
24/// The final count is approximate because the token embedding scales with the
25/// caller's vocabulary size. Use `MultiscreenModelConfig::estimated_parameter_count`
26/// to see the exact count for a resolved config.
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
28pub enum MultiscreenParameterBudget {
29    Params1M,
30    Params5M,
31    Params10M,
32    Params50M,
33    Params100M,
34}
35
36impl MultiscreenParameterBudget {
37    pub const ALL: [Self; 5] = [
38        Self::Params1M,
39        Self::Params5M,
40        Self::Params10M,
41        Self::Params50M,
42        Self::Params100M,
43    ];
44
45    pub fn label(self) -> &'static str {
46        match self {
47            Self::Params1M => "1M",
48            Self::Params5M => "5M",
49            Self::Params10M => "10M",
50            Self::Params50M => "50M",
51            Self::Params100M => "100M",
52        }
53    }
54
55    pub fn target_parameter_count(self) -> usize {
56        match self {
57            Self::Params1M => 1_000_000,
58            Self::Params5M => 5_000_000,
59            Self::Params10M => 10_000_000,
60            Self::Params50M => 50_000_000,
61            Self::Params100M => 100_000_000,
62        }
63    }
64}
65
66/// Paper-faithful neural Multiscreen model configuration.
67#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
68pub struct MultiscreenModelConfig {
69    /// Vocabulary size.
70    pub vocab_size: usize,
71    /// Context length `T`.
72    pub seq_len: usize,
73    /// Paper `N_L`: number of residual Multiscreen layers.
74    pub layers: usize,
75    /// Paper `N_H`: number of gated screening tiles per layer.
76    pub tiles: usize,
77    /// Paper `d_E`: token embedding/model width.
78    pub d_model: usize,
79    /// Paper `d_K`: query/key width. Must be at least 2 for MiPE.
80    pub d_key: usize,
81    /// Paper `d_V`: value/gate width.
82    pub d_value: usize,
83    /// Paper MiPE threshold `w_th`.
84    pub w_th: f32,
85}
86
87impl MultiscreenModelConfig {
88    pub fn tiny() -> Self {
89        Self {
90            vocab_size: 64,
91            seq_len: 64,
92            layers: 2,
93            tiles: 4,
94            d_model: 64,
95            d_key: 16,
96            d_value: 32,
97            w_th: 32.0,
98        }
99    }
100
101    pub fn tiny_for_tests() -> Self {
102        Self {
103            vocab_size: 32,
104            seq_len: 8,
105            layers: 1,
106            tiles: 2,
107            d_model: 16,
108            d_key: 4,
109            d_value: 8,
110            w_th: 8.0,
111        }
112    }
113
114    pub fn for_parameter_budget(
115        budget: MultiscreenParameterBudget,
116        vocab_size: usize,
117        seq_len: usize,
118    ) -> Self {
119        match budget {
120            MultiscreenParameterBudget::Params1M => Self::preset_1m(vocab_size, seq_len),
121            MultiscreenParameterBudget::Params5M => Self::preset_5m(vocab_size, seq_len),
122            MultiscreenParameterBudget::Params10M => Self::preset_10m(vocab_size, seq_len),
123            MultiscreenParameterBudget::Params50M => Self::preset_50m(vocab_size, seq_len),
124            MultiscreenParameterBudget::Params100M => Self::preset_100m(vocab_size, seq_len),
125        }
126    }
127
128    pub fn preset_1m(vocab_size: usize, seq_len: usize) -> Self {
129        Self::from_dimensions(vocab_size, seq_len, 2, 2, 128, 32, 64)
130    }
131
132    pub fn preset_5m(vocab_size: usize, seq_len: usize) -> Self {
133        Self::from_dimensions(vocab_size, seq_len, 2, 4, 384, 96, 192)
134    }
135
136    pub fn preset_10m(vocab_size: usize, seq_len: usize) -> Self {
137        Self::from_dimensions(vocab_size, seq_len, 3, 4, 512, 128, 256)
138    }
139
140    pub fn preset_50m(vocab_size: usize, seq_len: usize) -> Self {
141        Self::from_dimensions(vocab_size, seq_len, 6, 4, 960, 240, 480)
142    }
143
144    pub fn preset_100m(vocab_size: usize, seq_len: usize) -> Self {
145        Self::from_dimensions(vocab_size, seq_len, 8, 4, 1216, 304, 608)
146    }
147
148    pub fn paper_10m(vocab_size: usize, seq_len: usize) -> Self {
149        Self::preset_10m(vocab_size, seq_len)
150    }
151
152    pub fn estimated_parameter_count(&self) -> usize {
153        let embedding_params = self.vocab_size.saturating_mul(self.d_model);
154        let per_tile_params = self
155            .d_model
156            .saturating_mul(
157                2usize
158                    .saturating_mul(self.d_key)
159                    .saturating_add(3usize.saturating_mul(self.d_value)),
160            )
161            .saturating_add(3);
162        let tile_params = self
163            .layers
164            .saturating_mul(self.tiles)
165            .saturating_mul(per_tile_params);
166        embedding_params
167            .saturating_add(2)
168            .saturating_add(tile_params)
169    }
170
171    fn from_dimensions(
172        vocab_size: usize,
173        seq_len: usize,
174        layers: usize,
175        tiles: usize,
176        d_model: usize,
177        d_key: usize,
178        d_value: usize,
179    ) -> Self {
180        Self {
181            vocab_size,
182            seq_len,
183            layers,
184            tiles,
185            d_model,
186            d_key,
187            d_value,
188            w_th: 32.0,
189        }
190    }
191
192    pub fn validate(&self) -> Result<()> {
193        ensure(self.vocab_size > 0, "vocab_size must be greater than zero")?;
194        ensure(self.seq_len > 0, "seq_len must be greater than zero")?;
195        ensure(self.layers > 0, "layers must be greater than zero")?;
196        ensure(self.tiles > 0, "tiles must be greater than zero")?;
197        ensure(self.d_model > 0, "d_model must be greater than zero")?;
198        ensure(self.d_key >= 2, "d_key must be at least 2 for MiPE")?;
199        ensure(self.d_value > 0, "d_value must be greater than zero")?;
200        ensure(
201            self.w_th.is_finite() && self.w_th > 0.0,
202            "w_th must be positive and finite",
203        )?;
204        Ok(())
205    }
206}
207
/// Options shared by the token-sequence and chat-pair training entry points.
#[derive(Clone, Debug)]
pub struct ModelTrainingConfig {
    /// Number of optimizer steps; with `0` only the first batch's loss is measured.
    pub steps: usize,
    /// Training windows per step; must be greater than zero.
    pub batch_size: usize,
    /// Learning rate handed to every AdamW step.
    pub learning_rate: f64,
    /// AdamW weight decay (narrowed to `f32` when the optimizer is built).
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold; `None` or a non-positive value disables it.
    pub grad_clip_norm: Option<f64>,
    /// Padding token id passed to training-window construction.
    pub pad_token_id: u32,
}

impl Default for ModelTrainingConfig {
    /// 100 steps of batch 4 at learning rate 2e-4 with weight decay 0.01,
    /// clipping at norm 1.0 and padding with token 0.
    fn default() -> Self {
        let learning_rate = 2e-4;
        let weight_decay = 0.01;
        Self {
            steps: 100,
            batch_size: 4,
            learning_rate,
            weight_decay,
            grad_clip_norm: Some(1.0),
            pad_token_id: 0,
        }
    }
}
231
/// Options for greedy (arg-max) token generation.
#[derive(Clone, Debug)]
pub struct ModelInferenceConfig {
    /// Upper bound on generated tokens; the streaming callback may stop earlier.
    pub max_new_tokens: usize,
    /// Padding token id used when building the fixed-length context window.
    pub pad_token_id: u32,
}

impl Default for ModelInferenceConfig {
    /// Up to 16 new tokens, padding with token 0.
    fn default() -> Self {
        Self {
            pad_token_id: 0,
            max_new_tokens: 16,
        }
    }
}
247
/// Summary returned by the training entry points.
#[derive(Clone, Debug, PartialEq)]
pub struct ModelTrainingReport {
    /// Number of optimizer steps that were requested and run.
    pub steps: usize,
    /// Loss of the last step's batch, measured before that step's update; with
    /// zero steps, the loss of the first batch of the untouched model.
    pub final_loss: f32,
    /// Number of training windows built from the input.
    pub training_window_count: usize,
    /// Parameter count of the trained model.
    pub parameter_count: usize,
}
255
/// Metrics produced by evaluating a model on held-out sequences.
#[derive(Clone, Debug)]
pub struct EvaluationResult {
    /// Mean of the per-batch masked cross-entropy losses (each batch weighs the
    /// same); NaN when no evaluation window could be built.
    pub loss: f32,
    /// `exp(loss)`; NaN when `loss` is NaN.
    pub perplexity: f32,
    /// Fraction of counted positions whose arg-max logit equals the target.
    pub accuracy: f64,
    /// Number of batches that were evaluated.
    pub num_batches: usize,
    /// Number of positions counted for accuracy (loss-mask value of at least 0.5).
    pub total_tokens: usize,
}
270
/// Token ids returned by generation: the prompt followed by the generated tokens.
///
/// Derives the common value traits so callers can clone, compare and debug-print
/// generation results.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MultiscreenModelOutput {
    /// Prompt tokens followed by every newly generated token, in order.
    pub token_ids: Vec<u32>,
}
274
275/// Burn-backed neural Multiscreen language model ported from
276/// `multiscreen-testing`.
277#[derive(Module, Debug)]
278pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> {
279    #[module(skip)]
280    config: MultiscreenModelConfig,
281    token_embedding: Param<Tensor<B, 2>>,
282    s_e: Param<Tensor<B, 1>>,
283    s_f: Param<Tensor<B, 1>>,
284    layers: Vec<MultiscreenLayer<B>>,
285}
286
287/// Convenience alias for the default Burn Flex autodiff model.
288pub type DefaultMultiscreenModel = MultiscreenModel<DefaultAutodiffBackend>;
289
290#[derive(Module, Debug)]
291struct MultiscreenLayer<B: Backend> {
292    tiles: Vec<GatedScreeningTile<B>>,
293}
294
295#[derive(Module, Debug)]
296struct GatedScreeningTile<B: Backend> {
297    w_q: Param<Tensor<B, 2>>,
298    w_k: Param<Tensor<B, 2>>,
299    w_v: Param<Tensor<B, 2>>,
300    w_g: Param<Tensor<B, 2>>,
301    w_o: Param<Tensor<B, 2>>,
302    s_w: Param<Tensor<B, 1>>,
303    s_r: Param<Tensor<B, 1>>,
304    s_o: Param<Tensor<B, 1>>,
305    #[module(skip)]
306    w_th: f32,
307}
308
309impl<B: Backend> MultiscreenModel<B> {
310    pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self> {
311        config.validate()?;
312
313        let mut seed = 0x4d55_4c54_4953_4352;
314        let token_embedding = init_matrix(
315            config.vocab_size,
316            config.d_model,
317            0.1 / (config.d_model as f32).sqrt(),
318            &mut seed,
319            device,
320        );
321        let s_e = init_scalar(0.0, device);
322        let s_f = init_scalar((config.d_model as f32).sqrt().ln(), device);
323
324        let mut layers = Vec::with_capacity(config.layers);
325        for _layer_idx in 0..config.layers {
326            let mut tiles = Vec::with_capacity(config.tiles);
327            for tile_idx in 0..config.tiles {
328                let w_q = init_matrix(
329                    config.d_model,
330                    config.d_key,
331                    0.1 / (config.d_key as f32).sqrt(),
332                    &mut seed,
333                    device,
334                );
335                let w_k = init_matrix(
336                    config.d_model,
337                    config.d_key,
338                    0.1 / (config.d_key as f32).sqrt(),
339                    &mut seed,
340                    device,
341                );
342                let w_v = init_matrix(
343                    config.d_model,
344                    config.d_value,
345                    0.1 / (config.d_value as f32).sqrt(),
346                    &mut seed,
347                    device,
348                );
349                let w_g = init_matrix(config.d_model, config.d_value, 0.1, &mut seed, device);
350                let w_o = init_matrix(
351                    config.d_value,
352                    config.d_model,
353                    0.1 / (config.d_model as f32).sqrt(),
354                    &mut seed,
355                    device,
356                );
357
358                let window_frac = if config.tiles == 1 {
359                    0.0
360                } else {
361                    tile_idx as f32 / (config.tiles - 1) as f32
362                };
363                let s_w = init_scalar(window_frac * config.w_th.ln(), device);
364                let s_r = init_scalar(0.0, device);
365                let s_o = init_scalar(-0.5 * ((config.layers * config.tiles) as f32).ln(), device);
366
367                tiles.push(GatedScreeningTile {
368                    w_q,
369                    w_k,
370                    w_v,
371                    w_g,
372                    w_o,
373                    s_w,
374                    s_r,
375                    s_o,
376                    w_th: config.w_th,
377                });
378            }
379            layers.push(MultiscreenLayer { tiles });
380        }
381
382        Ok(Self {
383            config,
384            token_embedding,
385            s_e,
386            s_f,
387            layers,
388        })
389    }
390
391    pub fn config(&self) -> &MultiscreenModelConfig {
392        &self.config
393    }
394
395    pub fn parameter_count(&self) -> usize {
396        self.num_params()
397    }
398
399    pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
400        let [batch, seq_len] = tokens.dims();
401        let embedding = row_unit_normalize(self.token_embedding.val());
402        let one_hot = tokens.one_hot::<3>(self.config().vocab_size).float();
403        let mut x =
404            linear(one_hot, embedding.clone()).reshape([batch, seq_len, self.config().d_model]);
405        let x_dims = x.dims();
406        x = x * expand_scalar3(self.s_e.val().exp(), x_dims);
407
408        for layer in &self.layers {
409            let mut layer_update = Tensor::<B, 3>::zeros(x.dims(), &x.device());
410            for tile in &layer.tiles {
411                layer_update = layer_update + tile.forward(x.clone());
412            }
413            x = x + layer_update;
414        }
415
416        let logits_weight = embedding.swap_dims(0, 1);
417        let logits = linear(x, logits_weight);
418        logits.clone() * expand_scalar3(self.s_f.val().exp(), logits.dims())
419    }
420
421    pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()> {
422        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
423        self.clone()
424            .save_file(path.as_ref().to_path_buf(), &recorder)
425            .map_err(|err| Error::Serialization(err.to_string()))
426    }
427
428    pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()> {
429        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
430        let device =
431            self.devices().into_iter().next().ok_or_else(|| {
432                Error::Serialization("model has no device for parameter load".into())
433            })?;
434        let loaded = self
435            .clone()
436            .load_file(path.as_ref().to_path_buf(), &recorder, &device)
437            .map_err(|err| Error::Serialization(err.to_string()))?;
438        *self = loaded;
439        Ok(())
440    }
441
442    /// Greedy token generation.
443    /// Generate tokens one at a time, invoking a callback for each newly
444    /// produced token. This enables streaming / word-by-word output similar
445    /// to ChatGPT.
446    ///
447    /// The callback receives `(token_id, index)` where `index` is the
448    /// zero-based position of the *new* token (0 = first generated token).
449    /// If the callback returns `false`, generation stops early.
450    ///
451    /// Returns the full output (prompt + generated) token sequence.
452    pub fn infer_tokens_stream(
453        &self,
454        prompt: &[u32],
455        inference: &ModelInferenceConfig,
456        device: &B::Device,
457        mut on_token: impl FnMut(u32, usize) -> bool,
458    ) -> Result<MultiscreenModelOutput> {
459        if prompt.is_empty() {
460            return Err(Error::Inference(
461                "prompt must contain at least one token".to_string(),
462            ));
463        }
464
465        let mut output = prompt.to_vec();
466        for i in 0..inference.max_new_tokens {
467            let next = self.predict_next_token(&output, inference.pad_token_id, device)?;
468            output.push(next);
469            if !on_token(next, i) {
470                break;
471            }
472        }
473
474        Ok(MultiscreenModelOutput { token_ids: output })
475    }
476
477    /// Generate tokens and return them all at once (non-streaming).
478    ///
479    /// For streaming / token-by-token output, use [`Self::infer_tokens_stream`].
480    pub fn infer_tokens(
481        &self,
482        prompt: &[u32],
483        inference: &ModelInferenceConfig,
484        device: &B::Device,
485    ) -> Result<MultiscreenModelOutput> {
486        self.infer_tokens_stream(prompt, inference, device, |_, _| true)
487    }
488
489    pub fn predict_next_token(
490        &self,
491        context: &[u32],
492        pad_token_id: u32,
493        device: &B::Device,
494    ) -> Result<u32> {
495        if context.is_empty() {
496            return Err(Error::Inference(
497                "context must contain at least one token".to_string(),
498            ));
499        }
500
501        let input = context_window(context, self.config().seq_len, pad_token_id);
502        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
503        let logits = self.forward(input);
504        let last_logits = logits
505            .slice([
506                0..1,
507                self.config().seq_len - 1..self.config().seq_len,
508                0..self.config().vocab_size,
509            ])
510            .reshape([self.config().vocab_size]);
511        let values = tensor_to_vec(last_logits)?;
512        argmax(&values).map(|idx| idx as u32)
513    }
514
515    /// Run a forward pass and return the full logit tensor.
516    ///
517    /// The returned tensor has shape `[1, seq_len, vocab_size]`.
518    /// This is useful for sampling-based generation (top-k, temperature, etc.)
519    /// where you need access to the raw logit values, not just the argmax.
520    ///
521    /// The `context` is padded/truncated to `seq_len` automatically.
522    pub fn forward_logits(
523        &self,
524        context: &[u32],
525        pad_token_id: u32,
526        device: &B::Device,
527    ) -> Result<Tensor<B, 3>> {
528        if context.is_empty() {
529            return Err(Error::Inference(
530                "context must contain at least one token".to_string(),
531            ));
532        }
533
534        let input = context_window(context, self.config().seq_len, pad_token_id);
535        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
536        let logits = self.forward(input);
537        Ok(logits)
538    }
539
540    /// Evaluates the model on token sequences, returning average loss,
541    /// perplexity, and next-token prediction accuracy.
542    ///
543    /// This method works on any `Backend` (including non-autodiff), which makes
544    /// it safe to call on an inference-only model without VRAM growth.
545    pub fn evaluate_on_sequences(
546        &self,
547        sequences: &[Vec<u32>],
548        seq_len: usize,
549        batch_size: usize,
550        pad_token_id: u32,
551        device: &B::Device,
552    ) -> Result<EvaluationResult> {
553        let windows = TrainingWindows::from_sequences(sequences, seq_len, pad_token_id)?;
554        if windows.is_empty() {
555            return Ok(EvaluationResult {
556                loss: f32::NAN,
557                perplexity: f32::NAN,
558                accuracy: 0.0,
559                num_batches: 0,
560                total_tokens: 0,
561            });
562        }
563
564        let num_batches = windows.len().div_ceil(batch_size);
565        let mut total_loss = 0.0_f64;
566        let mut total_correct = 0_usize;
567        let mut total_tokens = 0_usize;
568
569        for step in 0..num_batches {
570            let batch = windows.batch::<B>(step, batch_size, device)?;
571            let logits = self.forward(batch.inputs); // [B, S, V]
572            let loss = cross_entropy_loss_with_mask(
573                logits.clone(),
574                batch.targets.clone(),
575                batch.loss_mask.clone(),
576            );
577            let loss_val = tensor_scalar(loss)? as f64;
578            total_loss += loss_val;
579
580            // Compute accuracy: argmax(logits) == target, masked
581            let [b, s, v] = logits.dims();
582            let mask_vec: Vec<f32> = batch
583                .loss_mask
584                .clone()
585                .reshape([b * s])
586                .into_data()
587                .into_vec::<f32>()
588                .map_err(|e| Error::Inference(e.to_string()))?;
589            let target_vec: Vec<i32> = batch
590                .targets
591                .clone()
592                .reshape([b * s])
593                .into_data()
594                .into_vec::<i32>()
595                .map_err(|e| Error::Inference(e.to_string()))?;
596            let logit_vec: Vec<f32> = logits
597                .reshape([b * s * v])
598                .into_data()
599                .into_vec::<f32>()
600                .map_err(|e| Error::Inference(e.to_string()))?;
601
602            for bi in 0..b {
603                for si in 0..s {
604                    let mi = bi * s + si;
605                    if mask_vec[mi] < 0.5 {
606                        continue;
607                    }
608                    total_tokens += 1;
609                    let base = bi * s * v + si * v;
610                    let mut best_idx = 0;
611                    let mut best_val = f32::NEG_INFINITY;
612                    for vi in 0..v {
613                        let val = logit_vec[base + vi];
614                        if val > best_val {
615                            best_val = val;
616                            best_idx = vi;
617                        }
618                    }
619                    if best_idx == target_vec[mi] as usize {
620                        total_correct += 1;
621                    }
622                }
623            }
624        }
625
626        let avg_loss = total_loss / num_batches as f64;
627        let perplexity = avg_loss.exp();
628        let accuracy = if total_tokens > 0 {
629            total_correct as f64 / total_tokens as f64
630        } else {
631            0.0
632        };
633
634        Ok(EvaluationResult {
635            loss: avg_loss as f32,
636            perplexity: perplexity as f32,
637            accuracy,
638            num_batches,
639            total_tokens,
640        })
641    }
642}
643
644impl<B> MultiscreenModel<B>
645where
646    B: AutodiffBackend,
647{
648    /// Trains this model directly on token sequences.
649    ///
650    /// The optional `on_step` callback is invoked after each optimizer step with
651    /// `(step_index, loss_value)`. Use it for progress logging, CSV export, etc.
652    pub fn train_token_sequences(
653        &mut self,
654        sequences: &[Vec<u32>],
655        training: &ModelTrainingConfig,
656        device: &B::Device,
657        mut on_step: impl FnMut(usize, f32),
658    ) -> Result<ModelTrainingReport> {
659        if training.batch_size == 0 {
660            return Err(Error::Training(
661                "batch_size must be greater than zero".to_string(),
662            ));
663        }
664        let windows = TrainingWindows::from_sequences(
665            sequences,
666            self.config().seq_len,
667            training.pad_token_id,
668        )?;
669        if windows.is_empty() {
670            return Err(Error::Training(
671                "training requires at least one sequence with two or more tokens".to_string(),
672            ));
673        }
674
675        let mut optimizer_config =
676            AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
677        if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
678            optimizer_config = optimizer_config
679                .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
680        }
681        let mut optimizer = optimizer_config.init::<B, Self>();
682        let mut model = self.clone();
683        let mut final_loss = f32::NAN;
684
685        for step in 0..training.steps {
686            let batch = windows.batch::<B>(step, training.batch_size, device)?;
687            let logits = model.forward(batch.inputs);
688            let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
689            final_loss = tensor_scalar(loss.clone())?;
690            let grads = loss.backward();
691            let grads = GradientsParams::from_grads(grads, &model);
692            model = optimizer.step(training.learning_rate, model, grads);
693            on_step(step, final_loss);
694        }
695
696        if training.steps == 0 {
697            let batch = windows.batch::<B>(0, training.batch_size, device)?;
698            final_loss = tensor_scalar(cross_entropy_loss_with_mask(
699                model.forward(batch.inputs),
700                batch.targets,
701                batch.loss_mask,
702            ))?;
703        }
704
705        *self = model;
706
707        Ok(ModelTrainingReport {
708            steps: training.steps,
709            final_loss,
710            training_window_count: windows.len(),
711            parameter_count: self.parameter_count(),
712        })
713    }
714
715    /// Trains this model on chat-style (prompt, response) token-ID pairs.
716    ///
717    /// This is the chat-aware counterpart of [`train_token_sequences`]. The model
718    /// sees the full context (prompt + response) but loss is computed **only** on
719    /// the response tokens, preventing the model from learning to generate role
720    /// labels like `system:`, `user:`, or `assistant:`.
721    ///
722    /// Each element of `chat_pairs` is `(prompt_token_ids, response_token_ids)`.
723    /// The caller is responsible for appending an EOS token to the response IDs
724    /// when desired — the EOS token will receive `loss_mask = 1.0` like any other
725    /// response token.
726    pub fn train_chat_sequences(
727        &mut self,
728        chat_pairs: &[(Vec<u32>, Vec<u32>)],
729        training: &ModelTrainingConfig,
730        device: &B::Device,
731        mut on_step: impl FnMut(usize, f32),
732    ) -> Result<ModelTrainingReport> {
733        if training.batch_size == 0 {
734            return Err(Error::Training(
735                "batch_size must be greater than zero".to_string(),
736            ));
737        }
738        let windows = TrainingWindows::from_chat_sequences(
739            chat_pairs,
740            self.config().seq_len,
741            training.pad_token_id,
742        )?;
743        if windows.is_empty() {
744            return Err(Error::Training(
745                "training requires at least one chat pair that produces two or more tokens"
746                    .to_string(),
747            ));
748        }
749
750        let mut optimizer_config =
751            AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
752        if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
753            optimizer_config = optimizer_config
754                .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
755        }
756        let mut optimizer = optimizer_config.init::<B, Self>();
757        let mut model = self.clone();
758        let mut final_loss = f32::NAN;
759
760        for step in 0..training.steps {
761            let batch = windows.batch::<B>(step, training.batch_size, device)?;
762            let logits = model.forward(batch.inputs);
763            let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
764            final_loss = tensor_scalar(loss.clone())?;
765            let grads = loss.backward();
766            let grads = GradientsParams::from_grads(grads, &model);
767            model = optimizer.step(training.learning_rate, model, grads);
768            on_step(step, final_loss);
769        }
770
771        if training.steps == 0 {
772            let batch = windows.batch::<B>(0, training.batch_size, device)?;
773            final_loss = tensor_scalar(cross_entropy_loss_with_mask(
774                model.forward(batch.inputs),
775                batch.targets,
776                batch.loss_mask,
777            ))?;
778        }
779
780        *self = model;
781
782        Ok(ModelTrainingReport {
783            steps: training.steps,
784            final_loss,
785            training_window_count: windows.len(),
786            parameter_count: self.parameter_count(),
787        })
788    }
789}
790
791impl<B: Backend> crate::lm::LanguageModel<B> for MultiscreenModel<B> {
792    fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
793        MultiscreenModel::forward(self, tokens)
794    }
795}
796
797impl<B: Backend> crate::lm::TrainableLanguageModel<B> for MultiscreenModel<B> {}
798
799impl<B: Backend> GatedScreeningTile<B> {
800    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
801        let q = row_unit_normalize(linear(x.clone(), self.w_q.val()));
802        let k = row_unit_normalize(linear(x.clone(), self.w_k.val()));
803        let v = row_unit_normalize(linear(x.clone(), self.w_v.val()));
804        let g = linear(x, self.w_g.val());
805
806        let w = self.s_w.val().clamp(-10.0, 8.0).exp() + 1.0;
807        let r = activation::sigmoid(self.s_r.val().clamp(-10.0, 8.0));
808
809        let q = apply_mipe(q, w.clone(), self.w_th);
810        let k = apply_mipe(k, w.clone(), self.w_th);
811
812        let similarity = q.matmul(k.swap_dims(1, 2));
813        let alpha = trim_and_square_tensor(similarity.clone(), r);
814        let softmask = causal_softmask_tensor::<B>(similarity.dims()[1], w, &similarity.device());
815        let relevance = alpha * softmask.unsqueeze();
816        let h = relevance.matmul(v);
817        let u = tanh_norm(h);
818        let gate = activation::silu(g).tanh();
819        let gated = u * gate;
820        let out = linear(gated, self.w_o.val());
821        out.clone() * expand_scalar3(self.s_o.val().exp(), out.dims())
822    }
823}
824
825#[allow(dead_code)]
826pub fn cross_entropy_loss<B: Backend>(
827    logits: Tensor<B, 3>,
828    targets: Tensor<B, 2, Int>,
829) -> Tensor<B, 1> {
830    let device = logits.device();
831    let [batch, seq_len, _] = logits.dims();
832    let loss_mask = Tensor::<B, 2>::ones([batch, seq_len], &device);
833    cross_entropy_loss_with_mask(logits, targets, loss_mask)
834}
835
836pub fn cross_entropy_loss_with_mask<B: Backend>(
837    logits: Tensor<B, 3>,
838    targets: Tensor<B, 2, Int>,
839    loss_mask: Tensor<B, 2>,
840) -> Tensor<B, 1> {
841    let [batch, seq_len, vocab_size] = logits.dims();
842    let token_count = batch * seq_len;
843    let flat_logits = logits.reshape([token_count, vocab_size]);
844    let flat_targets = targets.reshape([token_count]);
845    let flat_mask = loss_mask.reshape([token_count]);
846    let log_probs = activation::log_softmax(flat_logits, 1);
847    let target_probs = flat_targets.one_hot::<2>(vocab_size).float();
848    let picked = (log_probs * target_probs).sum_dim(1).reshape([token_count]);
849    let denom = flat_mask.clone().sum().add_scalar(EPS);
850    (picked.neg() * flat_mask).sum() / denom
851}
852
853pub fn row_unit_normalize<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
854    let denom = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
855    x / denom
856}
857
858fn trim_and_square_tensor<B: Backend>(similarity: Tensor<B, 3>, r: Tensor<B, 1>) -> Tensor<B, 3> {
859    let distance_from_one = similarity.mul_scalar(-1.0).add_scalar(1.0);
860    let scaled = distance_from_one.clone() / expand_scalar3(r, distance_from_one.clone().dims());
861    scaled
862        .mul_scalar(-1.0)
863        .add_scalar(1.0)
864        .clamp(0.0, 1.0)
865        .square()
866}
867
868fn causal_softmask_tensor<B: Backend>(
869    seq_len: usize,
870    w: Tensor<B, 1>,
871    device: &B::Device,
872) -> Tensor<B, 2> {
873    let mut distances = Vec::with_capacity(seq_len * seq_len);
874    for i in 0..seq_len {
875        for j in 0..seq_len {
876            distances.push(j as f32 - i as f32);
877        }
878    }
879
880    let dist = Tensor::<B, 2>::from_data(TensorData::new(distances, [seq_len, seq_len]), device);
881    let w = expand_scalar2(w, [seq_len, seq_len]);
882    let causal = dist.clone().lower_equal_elem(0.0);
883    let within_window = dist.clone().greater(w.clone().neg());
884    let active = causal.float() * within_window.float();
885    let tapered = (dist / w)
886        .mul_scalar(PI)
887        .cos()
888        .mul_scalar(0.5)
889        .add_scalar(0.5);
890    active * tapered
891}
892
893pub fn tanh_norm<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
894    let norm = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
895    let scale = norm.clone().tanh() / norm;
896    x * scale
897}
898
899fn apply_mipe<B: Backend>(z: Tensor<B, 3>, w: Tensor<B, 1>, w_th: f32) -> Tensor<B, 3> {
900    let [_batch, seq_len, d_key] = z.dims();
901    debug_assert!(d_key >= 2);
902
903    let positions = Tensor::<B, 3>::from_data(
904        TensorData::new(
905            (0..seq_len).map(|idx| idx as f32).collect::<Vec<_>>(),
906            [1, seq_len, 1],
907        ),
908        &z.device(),
909    );
910
911    let gamma_raw = w
912        .clone()
913        .mul_scalar(PI / w_th)
914        .cos()
915        .mul_scalar(0.5)
916        .add_scalar(0.5);
917    let gamma = gamma_raw.mask_fill(w.clone().greater_equal_elem(w_th), 0.0);
918    let gamma = expand_scalar3(gamma, [1, seq_len, 1]);
919    let w = expand_scalar3(w, [1, seq_len, 1]);
920    let phi = (positions * gamma / w).mul_scalar(PI);
921    let cos_phi = phi.clone().cos();
922    let sin_phi = phi.sin();
923
924    let x0 = z.clone().narrow(2, 0, 1);
925    let x1 = z.clone().narrow(2, 1, 1);
926    let rot0 = x0.clone() * cos_phi.clone() - x1.clone() * sin_phi.clone();
927    let rot1 = x0 * sin_phi + x1 * cos_phi;
928
929    if d_key == 2 {
930        Tensor::cat(vec![rot0, rot1], 2)
931    } else {
932        let rest = z.narrow(2, 2, d_key - 2);
933        Tensor::cat(vec![rot0, rot1, rest], 2)
934    }
935}
936
937fn linear<B: Backend>(x: Tensor<B, 3>, weight: Tensor<B, 2>) -> Tensor<B, 3> {
938    let [batch, _seq_len, _in_dim] = x.dims();
939    let [weight_in, out_dim] = weight.dims();
940    x.matmul(weight.unsqueeze::<3>().expand([batch, weight_in, out_dim]))
941}
942
943fn expand_scalar2<B: Backend>(value: Tensor<B, 1>, dims: [usize; 2]) -> Tensor<B, 2> {
944    value.unsqueeze::<2>().expand(dims)
945}
946
947fn expand_scalar3<B: Backend>(value: Tensor<B, 1>, dims: [usize; 3]) -> Tensor<B, 3> {
948    value.unsqueeze::<3>().expand(dims)
949}
950
951fn init_scalar<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
952    Param::from_tensor(Tensor::<B, 1>::from_data([value], device))
953}
954
955fn init_matrix<B: Backend>(
956    rows: usize,
957    cols: usize,
958    std: f32,
959    seed: &mut u64,
960    device: &B::Device,
961) -> Param<Tensor<B, 2>> {
962    let values = gaussian_values(rows * cols, std, seed);
963    Param::from_tensor(Tensor::<B, 2>::from_data(
964        TensorData::new(values, [rows, cols]),
965        device,
966    ))
967}
968
969fn gaussian_values(len: usize, std: f32, seed: &mut u64) -> Vec<f32> {
970    let mut values = Vec::with_capacity(len);
971    while values.len() < len {
972        let u1 = next_uniform(seed).max(1e-7);
973        let u2 = next_uniform(seed);
974        let radius = (-2.0 * u1.ln()).sqrt();
975        let theta = 2.0 * PI * u2;
976        values.push(radius * theta.cos() * std);
977        if values.len() < len {
978            values.push(radius * theta.sin() * std);
979        }
980    }
981    values
982}
983
/// Advances the 64-bit LCG state in `seed` and returns a uniform sample in the
/// open interval `(0, 1)`.
///
/// The top 24 bits of the new state are mapped to `(bits + 1) / (2^24 + 2)`,
/// so neither `0` nor `1` can be produced.
fn next_uniform(seed: &mut u64) -> f32 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    const DENOMINATOR: f32 = (1u32 << 24) as f32 + 2.0;

    *seed = MULTIPLIER.wrapping_mul(*seed).wrapping_add(INCREMENT);
    let top_bits = (*seed >> 40) as u32;
    (top_bits as f32 + 1.0) / DENOMINATOR
}
991
/// Returns a `seq_len`-long model input holding the most recent tokens of
/// `context`, right-aligned and left-padded with `pad_token_id`.
///
/// When `context` is longer than `seq_len`, only its last `seq_len` tokens are kept.
fn context_window(context: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    let kept = context.len().min(seq_len);
    let recent = &context[context.len() - kept..];
    std::iter::repeat(pad_token_id)
        .take(seq_len - kept)
        .chain(recent.iter().copied())
        .collect()
}
1003
1004fn argmax(values: &[f32]) -> Result<usize> {
1005    values
1006        .iter()
1007        .enumerate()
1008        .max_by(|left, right| left.1.total_cmp(right.1))
1009        .map(|(index, _)| index)
1010        .ok_or_else(|| Error::Inference("cannot argmax empty logits".to_string()))
1011}
1012
1013fn tensor_scalar<B: Backend>(tensor: Tensor<B, 1>) -> Result<f32> {
1014    tensor
1015        .into_data()
1016        .into_vec::<f32>()
1017        .map_err(|err| Error::Training(err.to_string()))?
1018        .into_iter()
1019        .next()
1020        .ok_or_else(|| Error::Training("expected scalar tensor".to_string()))
1021}
1022
1023fn tensor_to_vec<B: Backend>(tensor: Tensor<B, 1>) -> Result<Vec<f32>> {
1024    tensor
1025        .into_data()
1026        .into_vec::<f32>()
1027        .map_err(|err| Error::Inference(err.to_string()))
1028}
1029
1030fn tensor_from_u32<B: Backend, const D: usize>(
1031    values: Vec<u32>,
1032    shape: [usize; D],
1033    device: &B::Device,
1034) -> Result<Tensor<B, D, Int>> {
1035    let values = values
1036        .into_iter()
1037        .map(|value| {
1038            i32::try_from(value)
1039                .map_err(|_| Error::Config(format!("token id {value} exceeds i32::MAX")))
1040        })
1041        .collect::<Result<Vec<_>>>()?;
1042    Ok(Tensor::<B, D, Int>::from_data(
1043        TensorData::new(values, shape),
1044        device,
1045    ))
1046}
1047
1048fn ensure(condition: bool, message: &str) -> Result<()> {
1049    if condition {
1050        Ok(())
1051    } else {
1052        Err(Error::Config(message.to_string()))
1053    }
1054}
1055
/// One fixed-length next-token training example.
///
/// All three vectors have length `seq_len`. Positions past the real data hold
/// the pad token (`inputs`/`targets`) and `0.0` (`loss_mask`).
#[derive(Clone, Debug, PartialEq)]
struct TrainingWindow {
    /// Token ids fed to the model.
    inputs: Vec<u32>,
    /// Expected next token for each input position.
    targets: Vec<u32>,
    /// `1.0` where the target contributes to the loss, `0.0` otherwise.
    loss_mask: Vec<f32>,
}

/// The full set of training windows for a dataset, all `seq_len` tokens long.
#[derive(Clone, Debug, PartialEq)]
struct TrainingWindows {
    windows: Vec<TrainingWindow>,
    seq_len: usize,
}
1066
1067impl TrainingWindows {
1068    /// Creates training windows from chat-style (prompt, response) pairs with loss masking.
1069    ///
1070    /// Tokens belonging to the prompt portion have `loss_mask = 0.0` so the model
1071    /// sees them as context but does not learn to generate them. Response tokens
1072    /// (including any EOS appended by the caller) have `loss_mask = 1.0`.
1073    fn from_chat_sequences(
1074        chat_pairs: &[(Vec<u32>, Vec<u32>)],
1075        seq_len: usize,
1076        pad_token_id: u32,
1077    ) -> Result<Self> {
1078        let mut windows = Vec::new();
1079        for (prompt_ids, response_ids) in chat_pairs {
1080            let mut full_seq = prompt_ids.clone();
1081            full_seq.extend_from_slice(response_ids);
1082
1083            if full_seq.len() < 2 {
1084                continue;
1085            }
1086
1087            let prompt_len = prompt_ids.len();
1088
1089            let mut start = 0;
1090            while start + 1 < full_seq.len() {
1091                let end = (start + seq_len + 1).min(full_seq.len());
1092                let chunk = &full_seq[start..end];
1093                let prediction_count = chunk.len() - 1;
1094
1095                let mut inputs = vec![pad_token_id; seq_len];
1096                let mut targets = vec![pad_token_id; seq_len];
1097                let mut loss_mask = vec![0.0; seq_len];
1098                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
1099                targets[..prediction_count].copy_from_slice(&chunk[1..]);
1100
1101                // Mask: only compute loss when the *target* token falls within the
1102                // response portion of the full sequence.
1103                for (i, mask) in loss_mask.iter_mut().enumerate().take(prediction_count) {
1104                    let target_global_idx = start + i + 1;
1105                    if target_global_idx >= prompt_len {
1106                        *mask = 1.0;
1107                    }
1108                }
1109
1110                windows.push(TrainingWindow {
1111                    inputs,
1112                    targets,
1113                    loss_mask,
1114                });
1115
1116                if end == full_seq.len() {
1117                    break;
1118                }
1119                start += seq_len;
1120            }
1121        }
1122
1123        Ok(Self { windows, seq_len })
1124    }
1125
1126    fn from_sequences(sequences: &[Vec<u32>], seq_len: usize, pad_token_id: u32) -> Result<Self> {
1127        let mut windows = Vec::new();
1128        for sequence in sequences {
1129            if sequence.len() < 2 {
1130                continue;
1131            }
1132
1133            let mut start = 0;
1134            while start + 1 < sequence.len() {
1135                let end = (start + seq_len + 1).min(sequence.len());
1136                let chunk = &sequence[start..end];
1137                let prediction_count = chunk.len() - 1;
1138
1139                let mut inputs = vec![pad_token_id; seq_len];
1140                let mut targets = vec![pad_token_id; seq_len];
1141                let mut loss_mask = vec![0.0; seq_len];
1142                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
1143                targets[..prediction_count].copy_from_slice(&chunk[1..]);
1144                loss_mask[..prediction_count].fill(1.0);
1145
1146                windows.push(TrainingWindow {
1147                    inputs,
1148                    targets,
1149                    loss_mask,
1150                });
1151
1152                if end == sequence.len() {
1153                    break;
1154                }
1155                start += seq_len;
1156            }
1157        }
1158
1159        Ok(Self { windows, seq_len })
1160    }
1161
1162    fn is_empty(&self) -> bool {
1163        self.windows.is_empty()
1164    }
1165
1166    fn len(&self) -> usize {
1167        self.windows.len()
1168    }
1169
1170    fn batch<B: Backend>(
1171        &self,
1172        step: usize,
1173        batch_size: usize,
1174        device: &B::Device,
1175    ) -> Result<TokenBatch<B>> {
1176        let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
1177        let mut targets = Vec::with_capacity(batch_size * self.seq_len);
1178        let mut loss_mask = Vec::with_capacity(batch_size * self.seq_len);
1179
1180        for batch_idx in 0..batch_size {
1181            let index = (step * batch_size + batch_idx) % self.windows.len();
1182            let window = &self.windows[index];
1183            inputs.extend_from_slice(&window.inputs);
1184            targets.extend_from_slice(&window.targets);
1185            loss_mask.extend_from_slice(&window.loss_mask);
1186        }
1187
1188        Ok(TokenBatch {
1189            inputs: tensor_from_u32(inputs, [batch_size, self.seq_len], device)?,
1190            targets: tensor_from_u32(targets, [batch_size, self.seq_len], device)?,
1191            loss_mask: Tensor::<B, 2>::from_data(
1192                TensorData::new(loss_mask, [batch_size, self.seq_len]),
1193                device,
1194            ),
1195        })
1196    }
1197}
1198
/// One training batch built by `TrainingWindows::batch` on the requested device.
///
/// `TrainingWindows::batch` creates all three tensors with shape
/// `[batch_size, seq_len]`.
struct TokenBatch<B: Backend> {
    /// Input token ids fed to the model.
    inputs: Tensor<B, 2, Int>,
    /// Next-token ids: position `t` holds the token that follows `inputs` at `t`.
    targets: Tensor<B, 2, Int>,
    /// Per-position loss weights: `1.0` where the target contributes to the loss,
    /// `0.0` for padding (and, for chat data, for targets inside the prompt).
    loss_mask: Tensor<B, 2>,
}
1204
/// Test helper: builds a synthetic `(inputs, targets)` batch of cyclic token runs.
///
/// Row `b` starts at `(step * 7 + b * 13) % vocab_size` and counts upward modulo
/// `vocab_size`. Each target is the input token plus one (mod `vocab_size`).
/// Both tensors have shape `[batch_size, seq_len]`.
#[cfg(test)]
#[allow(dead_code)]
pub fn make_batch<B: Backend>(
    step: usize,
    batch_size: usize,
    seq_len: usize,
    vocab_size: usize,
    device: &B::Device,
) -> Result<(Tensor<B, 2, Int>, Tensor<B, 2, Int>)> {
    let (inputs, targets): (Vec<u32>, Vec<u32>) = (0..batch_size)
        .flat_map(|row| {
            let offset = (step * 7 + row * 13) % vocab_size;
            (offset..offset + seq_len).map(move |token| {
                (
                    (token % vocab_size) as u32,
                    ((token + 1) % vocab_size) as u32,
                )
            })
        })
        .unzip();

    let shape = [batch_size, seq_len];
    let input_tensor = tensor_from_u32(inputs, shape, device)?;
    let target_tensor = tensor_from_u32(targets, shape, device)?;
    Ok((input_tensor, target_tensor))
}