Skip to main content

multiscreen_rs/
model.rs

1use crate::{
2    error::{Error, Result},
3    runtime::DefaultAutodiffBackend,
4};
5use burn::{
6    grad_clipping::GradientClippingConfig,
7    module::{Module, Param},
8    optim::{AdamWConfig, GradientsParams, Optimizer},
9    record::{FullPrecisionSettings, NamedMpkFileRecorder},
10    tensor::{
11        activation,
12        backend::{AutodiffBackend, Backend},
13        Int, Tensor, TensorData,
14    },
15};
16use serde::{Deserialize, Serialize};
17use std::f32::consts::PI;
18use std::path::Path;
19
/// Small positive constant added under square roots and to loss denominators
/// so those divisions never see an exact zero.
const EPS: f32 = 1.0e-6;
21
/// Supported neural Multiscreen parameter budgets.
///
/// The final count is approximate because the token embedding scales with the
/// caller's vocabulary size. Use `MultiscreenModelConfig::estimated_parameter_count`
/// to see the exact count for a resolved config.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiscreenParameterBudget {
    /// Targets about 1,000,000 parameters.
    Params1M,
    /// Targets about 5,000,000 parameters.
    Params5M,
    /// Targets about 10,000,000 parameters.
    Params10M,
    /// Targets about 50,000,000 parameters.
    Params50M,
    /// Targets about 100,000,000 parameters.
    Params100M,
}
35
36impl MultiscreenParameterBudget {
37    pub const ALL: [Self; 5] = [
38        Self::Params1M,
39        Self::Params5M,
40        Self::Params10M,
41        Self::Params50M,
42        Self::Params100M,
43    ];
44
45    pub fn label(self) -> &'static str {
46        match self {
47            Self::Params1M => "1M",
48            Self::Params5M => "5M",
49            Self::Params10M => "10M",
50            Self::Params50M => "50M",
51            Self::Params100M => "100M",
52        }
53    }
54
55    pub fn target_parameter_count(self) -> usize {
56        match self {
57            Self::Params1M => 1_000_000,
58            Self::Params5M => 5_000_000,
59            Self::Params10M => 10_000_000,
60            Self::Params50M => 50_000_000,
61            Self::Params100M => 100_000_000,
62        }
63    }
64}
65
/// Paper-faithful neural Multiscreen model configuration.
///
/// Call `validate` (also invoked by `MultiscreenModel::new`) to check the
/// constraints documented on each field.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiscreenModelConfig {
    /// Vocabulary size. Must be greater than zero.
    pub vocab_size: usize,
    /// Context length `T`. Must be greater than zero.
    pub seq_len: usize,
    /// Paper `N_L`: number of residual Multiscreen layers.
    pub layers: usize,
    /// Paper `N_H`: number of gated screening tiles per layer.
    pub tiles: usize,
    /// Paper `d_E`: token embedding/model width.
    pub d_model: usize,
    /// Paper `d_K`: query/key width. Must be at least 2 for MiPE.
    pub d_key: usize,
    /// Paper `d_V`: value/gate width.
    pub d_value: usize,
    /// Paper MiPE threshold `w_th`. Must be positive and finite.
    pub w_th: f32,
}
86
87impl MultiscreenModelConfig {
88    pub fn tiny() -> Self {
89        Self {
90            vocab_size: 64,
91            seq_len: 64,
92            layers: 2,
93            tiles: 4,
94            d_model: 64,
95            d_key: 16,
96            d_value: 32,
97            w_th: 32.0,
98        }
99    }
100
101    pub fn tiny_for_tests() -> Self {
102        Self {
103            vocab_size: 32,
104            seq_len: 8,
105            layers: 1,
106            tiles: 2,
107            d_model: 16,
108            d_key: 4,
109            d_value: 8,
110            w_th: 8.0,
111        }
112    }
113
114    pub fn for_parameter_budget(
115        budget: MultiscreenParameterBudget,
116        vocab_size: usize,
117        seq_len: usize,
118    ) -> Self {
119        match budget {
120            MultiscreenParameterBudget::Params1M => Self::preset_1m(vocab_size, seq_len),
121            MultiscreenParameterBudget::Params5M => Self::preset_5m(vocab_size, seq_len),
122            MultiscreenParameterBudget::Params10M => Self::preset_10m(vocab_size, seq_len),
123            MultiscreenParameterBudget::Params50M => Self::preset_50m(vocab_size, seq_len),
124            MultiscreenParameterBudget::Params100M => Self::preset_100m(vocab_size, seq_len),
125        }
126    }
127
128    pub fn preset_1m(vocab_size: usize, seq_len: usize) -> Self {
129        Self::from_dimensions(vocab_size, seq_len, 2, 2, 128, 32, 64)
130    }
131
132    pub fn preset_5m(vocab_size: usize, seq_len: usize) -> Self {
133        Self::from_dimensions(vocab_size, seq_len, 2, 4, 384, 96, 192)
134    }
135
136    pub fn preset_10m(vocab_size: usize, seq_len: usize) -> Self {
137        Self::from_dimensions(vocab_size, seq_len, 3, 4, 512, 128, 256)
138    }
139
140    pub fn preset_50m(vocab_size: usize, seq_len: usize) -> Self {
141        Self::from_dimensions(vocab_size, seq_len, 6, 4, 960, 240, 480)
142    }
143
144    pub fn preset_100m(vocab_size: usize, seq_len: usize) -> Self {
145        Self::from_dimensions(vocab_size, seq_len, 8, 4, 1216, 304, 608)
146    }
147
148    pub fn paper_10m(vocab_size: usize, seq_len: usize) -> Self {
149        Self::preset_10m(vocab_size, seq_len)
150    }
151
152    pub fn estimated_parameter_count(&self) -> usize {
153        let embedding_params = self.vocab_size.saturating_mul(self.d_model);
154        let per_tile_params = self
155            .d_model
156            .saturating_mul(
157                2usize
158                    .saturating_mul(self.d_key)
159                    .saturating_add(3usize.saturating_mul(self.d_value)),
160            )
161            .saturating_add(3);
162        let tile_params = self
163            .layers
164            .saturating_mul(self.tiles)
165            .saturating_mul(per_tile_params);
166        embedding_params
167            .saturating_add(2)
168            .saturating_add(tile_params)
169    }
170
171    fn from_dimensions(
172        vocab_size: usize,
173        seq_len: usize,
174        layers: usize,
175        tiles: usize,
176        d_model: usize,
177        d_key: usize,
178        d_value: usize,
179    ) -> Self {
180        Self {
181            vocab_size,
182            seq_len,
183            layers,
184            tiles,
185            d_model,
186            d_key,
187            d_value,
188            w_th: 32.0,
189        }
190    }
191
192    pub fn validate(&self) -> Result<()> {
193        ensure(self.vocab_size > 0, "vocab_size must be greater than zero")?;
194        ensure(self.seq_len > 0, "seq_len must be greater than zero")?;
195        ensure(self.layers > 0, "layers must be greater than zero")?;
196        ensure(self.tiles > 0, "tiles must be greater than zero")?;
197        ensure(self.d_model > 0, "d_model must be greater than zero")?;
198        ensure(self.d_key >= 2, "d_key must be at least 2 for MiPE")?;
199        ensure(self.d_value > 0, "d_value must be greater than zero")?;
200        ensure(
201            self.w_th.is_finite() && self.w_th > 0.0,
202            "w_th must be positive and finite",
203        )?;
204        Ok(())
205    }
206}
207
/// Training options for token-sequence training.
#[derive(Clone, Debug)]
pub struct ModelTrainingConfig {
    /// Number of optimizer steps to run (one batch per step).
    pub steps: usize,
    /// Number of training windows per batch. Must be greater than zero.
    pub batch_size: usize,
    /// Learning rate passed to every AdamW step.
    pub learning_rate: f64,
    /// AdamW weight decay (converted to `f32` for the optimizer config).
    pub weight_decay: f64,
    /// Maximum gradient L2 norm; `None` or a non-positive value disables clipping.
    pub grad_clip_norm: Option<f64>,
    /// Padding id handed to training-window construction (presumably used to
    /// fill short windows — confirm against `TrainingWindows::from_sequences`).
    pub pad_token_id: u32,
}
218
219impl Default for ModelTrainingConfig {
220    fn default() -> Self {
221        Self {
222            steps: 100,
223            batch_size: 4,
224            learning_rate: 2e-4,
225            weight_decay: 0.01,
226            grad_clip_norm: Some(1.0),
227            pad_token_id: 0,
228        }
229    }
230}
231
/// Greedy generation options.
#[derive(Clone, Debug)]
pub struct ModelInferenceConfig {
    /// Upper bound on generated tokens (a streaming callback may stop earlier).
    pub max_new_tokens: usize,
    /// Id used to left-pad contexts shorter than the model's `seq_len`.
    pub pad_token_id: u32,
}
238
239impl Default for ModelInferenceConfig {
240    fn default() -> Self {
241        Self {
242            max_new_tokens: 16,
243            pad_token_id: 0,
244        }
245    }
246}
247
/// Summary returned by `MultiscreenModel::train_token_sequences`.
#[derive(Clone, Debug, PartialEq)]
pub struct ModelTrainingReport {
    /// Number of optimizer steps that were requested and run.
    pub steps: usize,
    /// Loss of the last step's batch, measured before that step's update.
    /// When `steps == 0` it is the loss of the first batch with no training.
    pub final_loss: f32,
    /// Number of training windows built from the input sequences.
    pub training_window_count: usize,
    /// Total number of scalar parameters in the trained model.
    pub parameter_count: usize,
}
255
/// Result of evaluating a model on held-out sequences.
#[derive(Clone, Debug)]
pub struct EvaluationResult {
    /// Average cross-entropy loss across all batches (an unweighted mean of the
    /// per-batch masked losses). `NaN` when no evaluation windows exist.
    pub loss: f32,
    /// Perplexity = exp(average_loss).
    pub perplexity: f32,
    /// Fraction of tokens where argmax(logits) == target (next-token accuracy).
    pub accuracy: f64,
    /// Number of batches evaluated.
    pub num_batches: usize,
    /// Total number of (unmasked) tokens evaluated, i.e. positions whose mask
    /// value is at least 0.5.
    pub total_tokens: usize,
}
270
/// Output of greedy generation.
pub struct MultiscreenModelOutput {
    /// The prompt followed by every generated token.
    pub token_ids: Vec<u32>,
}
274
/// Burn-backed neural Multiscreen language model ported from
/// `multiscreen-testing`.
#[derive(Module, Debug)]
pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> {
    /// Architecture hyper-parameters; skipped by the `Module` derive, so not a
    /// trainable parameter.
    #[module(skip)]
    config: MultiscreenModelConfig,
    /// `[vocab_size, d_model]` token table. `forward` row-normalises it and uses
    /// it both to embed inputs and, transposed, as the output projection.
    token_embedding: Param<Tensor<B, 2>>,
    /// Log-scale applied to the input embeddings (multiplied by `exp(s_e)`).
    s_e: Param<Tensor<B, 1>>,
    /// Log-scale applied to the output logits (multiplied by `exp(s_f)`).
    s_f: Param<Tensor<B, 1>>,
    /// Residual Multiscreen layers, applied in order.
    layers: Vec<MultiscreenLayer<B>>,
}
286
/// Convenience alias for the default Burn Flex autodiff model.
///
/// Spelled out explicitly; it names the same type as `MultiscreenModel` with
/// its default backend parameter.
pub type DefaultMultiscreenModel = MultiscreenModel<DefaultAutodiffBackend>;
289
/// One residual Multiscreen layer. Every tile reads the same layer input, and
/// the tile outputs are summed into a single residual update.
#[derive(Module, Debug)]
struct MultiscreenLayer<B: Backend> {
    /// The layer's `N_H` gated screening tiles.
    tiles: Vec<GatedScreeningTile<B>>,
}
294
/// Parameters of a single gated screening tile.
#[derive(Module, Debug)]
struct GatedScreeningTile<B: Backend> {
    /// Query projection, `[d_model, d_key]`.
    w_q: Param<Tensor<B, 2>>,
    /// Key projection, `[d_model, d_key]`.
    w_k: Param<Tensor<B, 2>>,
    /// Value projection, `[d_model, d_value]`.
    w_v: Param<Tensor<B, 2>>,
    /// Gate projection, `[d_model, d_value]`.
    w_g: Param<Tensor<B, 2>>,
    /// Output projection, `[d_value, d_model]`.
    w_o: Param<Tensor<B, 2>>,
    /// Window parameter; the tile's window width is `exp(clamp(s_w, -10, 8)) + 1`.
    s_w: Param<Tensor<B, 1>>,
    /// Acceptance parameter; the forward pass uses `sigmoid(clamp(s_r, -10, 8))`.
    s_r: Param<Tensor<B, 1>>,
    /// Log-scale on the tile output (multiplied by `exp(s_o)`).
    s_o: Param<Tensor<B, 1>>,
    /// MiPE threshold copied from the config; not a trainable parameter.
    #[module(skip)]
    w_th: f32,
}
308
309impl<B: Backend> MultiscreenModel<B> {
310    pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self> {
311        config.validate()?;
312
313        let mut seed = 0x4d55_4c54_4953_4352;
314        let token_embedding = init_matrix(
315            config.vocab_size,
316            config.d_model,
317            0.1 / (config.d_model as f32).sqrt(),
318            &mut seed,
319            device,
320        );
321        let s_e = init_scalar(0.0, device);
322        let s_f = init_scalar((config.d_model as f32).sqrt().ln(), device);
323
324        let mut layers = Vec::with_capacity(config.layers);
325        for _layer_idx in 0..config.layers {
326            let mut tiles = Vec::with_capacity(config.tiles);
327            for tile_idx in 0..config.tiles {
328                let w_q = init_matrix(
329                    config.d_model,
330                    config.d_key,
331                    0.1 / (config.d_key as f32).sqrt(),
332                    &mut seed,
333                    device,
334                );
335                let w_k = init_matrix(
336                    config.d_model,
337                    config.d_key,
338                    0.1 / (config.d_key as f32).sqrt(),
339                    &mut seed,
340                    device,
341                );
342                let w_v = init_matrix(
343                    config.d_model,
344                    config.d_value,
345                    0.1 / (config.d_value as f32).sqrt(),
346                    &mut seed,
347                    device,
348                );
349                let w_g = init_matrix(config.d_model, config.d_value, 0.1, &mut seed, device);
350                let w_o = init_matrix(
351                    config.d_value,
352                    config.d_model,
353                    0.1 / (config.d_model as f32).sqrt(),
354                    &mut seed,
355                    device,
356                );
357
358                let window_frac = if config.tiles == 1 {
359                    0.0
360                } else {
361                    tile_idx as f32 / (config.tiles - 1) as f32
362                };
363                let s_w = init_scalar(window_frac * config.w_th.ln(), device);
364                let s_r = init_scalar(0.0, device);
365                let s_o = init_scalar(-0.5 * ((config.layers * config.tiles) as f32).ln(), device);
366
367                tiles.push(GatedScreeningTile {
368                    w_q,
369                    w_k,
370                    w_v,
371                    w_g,
372                    w_o,
373                    s_w,
374                    s_r,
375                    s_o,
376                    w_th: config.w_th,
377                });
378            }
379            layers.push(MultiscreenLayer { tiles });
380        }
381
382        Ok(Self {
383            config,
384            token_embedding,
385            s_e,
386            s_f,
387            layers,
388        })
389    }
390
391    pub fn config(&self) -> &MultiscreenModelConfig {
392        &self.config
393    }
394
395    pub fn parameter_count(&self) -> usize {
396        self.num_params()
397    }
398
399    pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
400        let [batch, seq_len] = tokens.dims();
401        let embedding = row_unit_normalize(self.token_embedding.val());
402        let one_hot = tokens.one_hot::<3>(self.config().vocab_size).float();
403        let mut x =
404            linear(one_hot, embedding.clone()).reshape([batch, seq_len, self.config().d_model]);
405        let x_dims = x.dims();
406        x = x * expand_scalar3(self.s_e.val().exp(), x_dims);
407
408        for layer in &self.layers {
409            let mut layer_update = Tensor::<B, 3>::zeros(x.dims(), &x.device());
410            for tile in &layer.tiles {
411                layer_update = layer_update + tile.forward(x.clone());
412            }
413            x = x + layer_update;
414        }
415
416        let logits_weight = embedding.swap_dims(0, 1);
417        let logits = linear(x, logits_weight);
418        logits.clone() * expand_scalar3(self.s_f.val().exp(), logits.dims())
419    }
420
421    pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()> {
422        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
423        self.clone()
424            .save_file(path.as_ref().to_path_buf(), &recorder)
425            .map_err(|err| Error::Serialization(err.to_string()))
426    }
427
428    pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()> {
429        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
430        let device =
431            self.devices().into_iter().next().ok_or_else(|| {
432                Error::Serialization("model has no device for parameter load".into())
433            })?;
434        let loaded = self
435            .clone()
436            .load_file(path.as_ref().to_path_buf(), &recorder, &device)
437            .map_err(|err| Error::Serialization(err.to_string()))?;
438        *self = loaded;
439        Ok(())
440    }
441
442    /// Greedy token generation.
443    /// Generate tokens one at a time, invoking a callback for each newly
444    /// produced token. This enables streaming / word-by-word output similar
445    /// to ChatGPT.
446    ///
447    /// The callback receives `(token_id, index)` where `index` is the
448    /// zero-based position of the *new* token (0 = first generated token).
449    /// If the callback returns `false`, generation stops early.
450    ///
451    /// Returns the full output (prompt + generated) token sequence.
452    pub fn infer_tokens_stream(
453        &self,
454        prompt: &[u32],
455        inference: &ModelInferenceConfig,
456        device: &B::Device,
457        mut on_token: impl FnMut(u32, usize) -> bool,
458    ) -> Result<MultiscreenModelOutput> {
459        if prompt.is_empty() {
460            return Err(Error::Inference(
461                "prompt must contain at least one token".to_string(),
462            ));
463        }
464
465        let mut output = prompt.to_vec();
466        for i in 0..inference.max_new_tokens {
467            let next = self.predict_next_token(&output, inference.pad_token_id, device)?;
468            output.push(next);
469            if !on_token(next, i) {
470                break;
471            }
472        }
473
474        Ok(MultiscreenModelOutput { token_ids: output })
475    }
476
477    /// Generate tokens and return them all at once (non-streaming).
478    ///
479    /// For streaming / token-by-token output, use [`Self::infer_tokens_stream`].
480    pub fn infer_tokens(
481        &self,
482        prompt: &[u32],
483        inference: &ModelInferenceConfig,
484        device: &B::Device,
485    ) -> Result<MultiscreenModelOutput> {
486        self.infer_tokens_stream(prompt, inference, device, |_, _| true)
487    }
488
489    pub fn predict_next_token(
490        &self,
491        context: &[u32],
492        pad_token_id: u32,
493        device: &B::Device,
494    ) -> Result<u32> {
495        if context.is_empty() {
496            return Err(Error::Inference(
497                "context must contain at least one token".to_string(),
498            ));
499        }
500
501        let input = context_window(context, self.config().seq_len, pad_token_id);
502        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
503        let logits = self.forward(input);
504        let last_logits = logits
505            .slice([
506                0..1,
507                self.config().seq_len - 1..self.config().seq_len,
508                0..self.config().vocab_size,
509            ])
510            .reshape([self.config().vocab_size]);
511        let values = tensor_to_vec(last_logits)?;
512        argmax(&values).map(|idx| idx as u32)
513    }
514}
515
516impl<B> MultiscreenModel<B>
517where
518    B: AutodiffBackend,
519{
520    /// Trains this model directly on token sequences.
521    ///
522    /// The optional `on_step` callback is invoked after each optimizer step with
523    /// `(step_index, loss_value)`. Use it for progress logging, CSV export, etc.
524    pub fn train_token_sequences(
525        &mut self,
526        sequences: &[Vec<u32>],
527        training: &ModelTrainingConfig,
528        device: &B::Device,
529        mut on_step: impl FnMut(usize, f32),
530    ) -> Result<ModelTrainingReport> {
531        if training.batch_size == 0 {
532            return Err(Error::Training(
533                "batch_size must be greater than zero".to_string(),
534            ));
535        }
536        let windows = TrainingWindows::from_sequences(
537            sequences,
538            self.config().seq_len,
539            training.pad_token_id,
540        )?;
541        if windows.is_empty() {
542            return Err(Error::Training(
543                "training requires at least one sequence with two or more tokens".to_string(),
544            ));
545        }
546
547        let mut optimizer_config =
548            AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
549        if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
550            optimizer_config = optimizer_config
551                .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
552        }
553        let mut optimizer = optimizer_config.init::<B, Self>();
554        let mut model = self.clone();
555        let mut final_loss = f32::NAN;
556
557        for step in 0..training.steps {
558            let batch = windows.batch::<B>(step, training.batch_size, device)?;
559            let logits = model.forward(batch.inputs);
560            let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
561            final_loss = tensor_scalar(loss.clone())?;
562            let grads = loss.backward();
563            let grads = GradientsParams::from_grads(grads, &model);
564            model = optimizer.step(training.learning_rate, model, grads);
565            on_step(step, final_loss);
566        }
567
568        if training.steps == 0 {
569            let batch = windows.batch::<B>(0, training.batch_size, device)?;
570            final_loss = tensor_scalar(cross_entropy_loss_with_mask(
571                model.forward(batch.inputs),
572                batch.targets,
573                batch.loss_mask,
574            ))?;
575        }
576
577        *self = model;
578
579        Ok(ModelTrainingReport {
580            steps: training.steps,
581            final_loss,
582            training_window_count: windows.len(),
583            parameter_count: self.parameter_count(),
584        })
585    }
586
587    /// Evaluates the model on token sequences, returning average loss,
588    /// perplexity, and next-token prediction accuracy.
589    pub fn evaluate_on_sequences(
590        &self,
591        sequences: &[Vec<u32>],
592        seq_len: usize,
593        batch_size: usize,
594        pad_token_id: u32,
595        device: &B::Device,
596    ) -> Result<EvaluationResult> {
597        let windows = TrainingWindows::from_sequences(sequences, seq_len, pad_token_id)?;
598        if windows.is_empty() {
599            return Ok(EvaluationResult {
600                loss: f32::NAN,
601                perplexity: f32::NAN,
602                accuracy: 0.0,
603                num_batches: 0,
604                total_tokens: 0,
605            });
606        }
607
608        let num_batches = windows.len().div_ceil(batch_size);
609        let mut total_loss = 0.0_f64;
610        let mut total_correct = 0_usize;
611        let mut total_tokens = 0_usize;
612
613        for step in 0..num_batches {
614            let batch = windows.batch::<B>(step, batch_size, device)?;
615            let logits = self.forward(batch.inputs); // [B, S, V]
616            let loss = cross_entropy_loss_with_mask(
617                logits.clone(),
618                batch.targets.clone(),
619                batch.loss_mask.clone(),
620            );
621            let loss_val = tensor_scalar(loss)? as f64;
622            total_loss += loss_val;
623
624            // Compute accuracy: argmax(logits) == target, masked
625            let [b, s, v] = logits.dims();
626            let mask_vec: Vec<f32> = batch
627                .loss_mask
628                .clone()
629                .reshape([b * s])
630                .into_data()
631                .into_vec::<f32>()
632                .map_err(|e| Error::Inference(e.to_string()))?;
633            let target_vec: Vec<i32> = batch
634                .targets
635                .clone()
636                .reshape([b * s])
637                .into_data()
638                .into_vec::<i32>()
639                .map_err(|e| Error::Inference(e.to_string()))?;
640            let logit_vec: Vec<f32> = logits
641                .reshape([b * s * v])
642                .into_data()
643                .into_vec::<f32>()
644                .map_err(|e| Error::Inference(e.to_string()))?;
645
646            for bi in 0..b {
647                for si in 0..s {
648                    let mi = bi * s + si;
649                    if mask_vec[mi] < 0.5 {
650                        continue;
651                    }
652                    total_tokens += 1;
653                    let base = bi * s * v + si * v;
654                    let mut best_idx = 0;
655                    let mut best_val = f32::NEG_INFINITY;
656                    for vi in 0..v {
657                        let val = logit_vec[base + vi];
658                        if val > best_val {
659                            best_val = val;
660                            best_idx = vi;
661                        }
662                    }
663                    if best_idx == target_vec[mi] as usize {
664                        total_correct += 1;
665                    }
666                }
667            }
668        }
669
670        let avg_loss = total_loss / num_batches as f64;
671        let perplexity = avg_loss.exp();
672        let accuracy = if total_tokens > 0 {
673            total_correct as f64 / total_tokens as f64
674        } else {
675            0.0
676        };
677
678        Ok(EvaluationResult {
679            loss: avg_loss as f32,
680            perplexity: perplexity as f32,
681            accuracy,
682            num_batches,
683            total_tokens,
684        })
685    }
686}
687
688impl<B: Backend> crate::lm::LanguageModel<B> for MultiscreenModel<B> {
689    fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
690        MultiscreenModel::forward(self, tokens)
691    }
692}
693
694impl<B: Backend> crate::lm::TrainableLanguageModel<B> for MultiscreenModel<B> {}
695
696impl<B: Backend> GatedScreeningTile<B> {
697    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
698        let q = row_unit_normalize(linear(x.clone(), self.w_q.val()));
699        let k = row_unit_normalize(linear(x.clone(), self.w_k.val()));
700        let v = row_unit_normalize(linear(x.clone(), self.w_v.val()));
701        let g = linear(x, self.w_g.val());
702
703        let w = self.s_w.val().clamp(-10.0, 8.0).exp() + 1.0;
704        let r = activation::sigmoid(self.s_r.val().clamp(-10.0, 8.0));
705
706        let q = apply_mipe(q, w.clone(), self.w_th);
707        let k = apply_mipe(k, w.clone(), self.w_th);
708
709        let similarity = q.matmul(k.swap_dims(1, 2));
710        let alpha = trim_and_square_tensor(similarity.clone(), r);
711        let softmask = causal_softmask_tensor::<B>(similarity.dims()[1], w, &similarity.device());
712        let relevance = alpha * softmask.unsqueeze();
713        let h = relevance.matmul(v);
714        let u = tanh_norm(h);
715        let gate = activation::silu(g).tanh();
716        let gated = u * gate;
717        let out = linear(gated, self.w_o.val());
718        out.clone() * expand_scalar3(self.s_o.val().exp(), out.dims())
719    }
720}
721
722#[allow(dead_code)]
723pub fn cross_entropy_loss<B: Backend>(
724    logits: Tensor<B, 3>,
725    targets: Tensor<B, 2, Int>,
726) -> Tensor<B, 1> {
727    let device = logits.device();
728    let [batch, seq_len, _] = logits.dims();
729    let loss_mask = Tensor::<B, 2>::ones([batch, seq_len], &device);
730    cross_entropy_loss_with_mask(logits, targets, loss_mask)
731}
732
733pub fn cross_entropy_loss_with_mask<B: Backend>(
734    logits: Tensor<B, 3>,
735    targets: Tensor<B, 2, Int>,
736    loss_mask: Tensor<B, 2>,
737) -> Tensor<B, 1> {
738    let [batch, seq_len, vocab_size] = logits.dims();
739    let token_count = batch * seq_len;
740    let flat_logits = logits.reshape([token_count, vocab_size]);
741    let flat_targets = targets.reshape([token_count]);
742    let flat_mask = loss_mask.reshape([token_count]);
743    let log_probs = activation::log_softmax(flat_logits, 1);
744    let target_probs = flat_targets.one_hot::<2>(vocab_size).float();
745    let picked = (log_probs * target_probs).sum_dim(1).reshape([token_count]);
746    let denom = flat_mask.clone().sum().add_scalar(EPS);
747    (picked.neg() * flat_mask).sum() / denom
748}
749
750pub fn row_unit_normalize<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
751    let denom = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
752    x / denom
753}
754
755fn trim_and_square_tensor<B: Backend>(similarity: Tensor<B, 3>, r: Tensor<B, 1>) -> Tensor<B, 3> {
756    let distance_from_one = similarity.mul_scalar(-1.0).add_scalar(1.0);
757    let scaled = distance_from_one.clone() / expand_scalar3(r, distance_from_one.clone().dims());
758    scaled
759        .mul_scalar(-1.0)
760        .add_scalar(1.0)
761        .clamp(0.0, 1.0)
762        .square()
763}
764
765fn causal_softmask_tensor<B: Backend>(
766    seq_len: usize,
767    w: Tensor<B, 1>,
768    device: &B::Device,
769) -> Tensor<B, 2> {
770    let mut distances = Vec::with_capacity(seq_len * seq_len);
771    for i in 0..seq_len {
772        for j in 0..seq_len {
773            distances.push(j as f32 - i as f32);
774        }
775    }
776
777    let dist = Tensor::<B, 2>::from_data(TensorData::new(distances, [seq_len, seq_len]), device);
778    let w = expand_scalar2(w, [seq_len, seq_len]);
779    let causal = dist.clone().lower_equal_elem(0.0);
780    let within_window = dist.clone().greater(w.clone().neg());
781    let active = causal.float() * within_window.float();
782    let tapered = (dist / w)
783        .mul_scalar(PI)
784        .cos()
785        .mul_scalar(0.5)
786        .add_scalar(0.5);
787    active * tapered
788}
789
790pub fn tanh_norm<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
791    let norm = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
792    let scale = norm.clone().tanh() / norm;
793    x * scale
794}
795
/// Applies the MiPE positional rotation to `z`.
///
/// `z` is `[batch, seq_len, d_key]` with `d_key >= 2`; `w` is the tile's
/// window width as a single-element tensor. Channels 0 and 1 at position `p`
/// are rotated by `phi = pi * p * gamma / w`, where
/// `gamma = 0.5 * cos(pi * w / w_th) + 0.5` for `w < w_th` and `gamma = 0`
/// otherwise (so windows at or beyond the threshold are not rotated).
/// Channels `2..d_key` pass through unchanged.
fn apply_mipe<B: Backend>(z: Tensor<B, 3>, w: Tensor<B, 1>, w_th: f32) -> Tensor<B, 3> {
    let [_batch, seq_len, d_key] = z.dims();
    debug_assert!(d_key >= 2);

    // Positions 0..seq_len shaped [1, seq_len, 1] so they broadcast over batch.
    let positions = Tensor::<B, 3>::from_data(
        TensorData::new(
            (0..seq_len).map(|idx| idx as f32).collect::<Vec<_>>(),
            [1, seq_len, 1],
        ),
        &z.device(),
    );

    // Raised-cosine strength of the rotation, zeroed once w reaches w_th.
    let gamma_raw = w
        .clone()
        .mul_scalar(PI / w_th)
        .cos()
        .mul_scalar(0.5)
        .add_scalar(0.5);
    let gamma = gamma_raw.mask_fill(w.clone().greater_equal_elem(w_th), 0.0);
    let gamma = expand_scalar3(gamma, [1, seq_len, 1]);
    let w = expand_scalar3(w, [1, seq_len, 1]);
    let phi = (positions * gamma / w).mul_scalar(PI);
    let cos_phi = phi.clone().cos();
    let sin_phi = phi.sin();

    // 2-D rotation of the first two channels.
    let x0 = z.clone().narrow(2, 0, 1);
    let x1 = z.clone().narrow(2, 1, 1);
    let rot0 = x0.clone() * cos_phi.clone() - x1.clone() * sin_phi.clone();
    let rot1 = x0 * sin_phi + x1 * cos_phi;

    if d_key == 2 {
        Tensor::cat(vec![rot0, rot1], 2)
    } else {
        // Remaining channels are concatenated back unrotated.
        let rest = z.narrow(2, 2, d_key - 2);
        Tensor::cat(vec![rot0, rot1, rest], 2)
    }
}
833
834fn linear<B: Backend>(x: Tensor<B, 3>, weight: Tensor<B, 2>) -> Tensor<B, 3> {
835    let [batch, _seq_len, _in_dim] = x.dims();
836    let [weight_in, out_dim] = weight.dims();
837    x.matmul(weight.unsqueeze::<3>().expand([batch, weight_in, out_dim]))
838}
839
840fn expand_scalar2<B: Backend>(value: Tensor<B, 1>, dims: [usize; 2]) -> Tensor<B, 2> {
841    value.unsqueeze::<2>().expand(dims)
842}
843
844fn expand_scalar3<B: Backend>(value: Tensor<B, 1>, dims: [usize; 3]) -> Tensor<B, 3> {
845    value.unsqueeze::<3>().expand(dims)
846}
847
848fn init_scalar<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
849    Param::from_tensor(Tensor::<B, 1>::from_data([value], device))
850}
851
852fn init_matrix<B: Backend>(
853    rows: usize,
854    cols: usize,
855    std: f32,
856    seed: &mut u64,
857    device: &B::Device,
858) -> Param<Tensor<B, 2>> {
859    let values = gaussian_values(rows * cols, std, seed);
860    Param::from_tensor(Tensor::<B, 2>::from_data(
861        TensorData::new(values, [rows, cols]),
862        device,
863    ))
864}
865
866fn gaussian_values(len: usize, std: f32, seed: &mut u64) -> Vec<f32> {
867    let mut values = Vec::with_capacity(len);
868    while values.len() < len {
869        let u1 = next_uniform(seed).max(1e-7);
870        let u2 = next_uniform(seed);
871        let radius = (-2.0 * u1.ln()).sqrt();
872        let theta = 2.0 * PI * u2;
873        values.push(radius * theta.cos() * std);
874        if values.len() < len {
875            values.push(radius * theta.sin() * std);
876        }
877    }
878    values
879}
880
/// Advances the 64-bit linear congruential generator in `seed` (Knuth's MMIX
/// constants) and maps its top 24 bits to an `f32` strictly inside `(0, 1)`.
fn next_uniform(seed: &mut u64) -> f32 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let state = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *seed = state;

    let top_bits = (state >> 40) as u32;
    let denominator = (1u32 << 24) as f32 + 2.0;
    (top_bits as f32 + 1.0) / denominator
}
888
/// Returns exactly `seq_len` ids: the last `seq_len` tokens of `context`,
/// left-padded with `pad_token_id` when the context is shorter.
fn context_window(context: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    let kept = context.len().min(seq_len);
    let tail = &context[context.len() - kept..];
    let mut window = Vec::with_capacity(seq_len);
    window.resize(seq_len - kept, pad_token_id);
    window.extend_from_slice(tail);
    window
}
900
901fn argmax(values: &[f32]) -> Result<usize> {
902    values
903        .iter()
904        .enumerate()
905        .max_by(|left, right| left.1.total_cmp(right.1))
906        .map(|(index, _)| index)
907        .ok_or_else(|| Error::Inference("cannot argmax empty logits".to_string()))
908}
909
910fn tensor_scalar<B: Backend>(tensor: Tensor<B, 1>) -> Result<f32> {
911    tensor
912        .into_data()
913        .into_vec::<f32>()
914        .map_err(|err| Error::Training(err.to_string()))?
915        .into_iter()
916        .next()
917        .ok_or_else(|| Error::Training("expected scalar tensor".to_string()))
918}
919
920fn tensor_to_vec<B: Backend>(tensor: Tensor<B, 1>) -> Result<Vec<f32>> {
921    tensor
922        .into_data()
923        .into_vec::<f32>()
924        .map_err(|err| Error::Inference(err.to_string()))
925}
926
927fn tensor_from_u32<B: Backend, const D: usize>(
928    values: Vec<u32>,
929    shape: [usize; D],
930    device: &B::Device,
931) -> Result<Tensor<B, D, Int>> {
932    let values = values
933        .into_iter()
934        .map(|value| {
935            i32::try_from(value)
936                .map_err(|_| Error::Config(format!("token id {value} exceeds i32::MAX")))
937        })
938        .collect::<Result<Vec<_>>>()?;
939    Ok(Tensor::<B, D, Int>::from_data(
940        TensorData::new(values, shape),
941        device,
942    ))
943}
944
945fn ensure(condition: bool, message: &str) -> Result<()> {
946    if condition {
947        Ok(())
948    } else {
949        Err(Error::Config(message.to_string()))
950    }
951}
952
/// One fixed-length next-token-prediction example.
///
/// All three vectors have the same length. Position `i` pairs `inputs[i]` with the
/// token it should predict, `targets[i]`.
struct TrainingWindow {
    /// Input token ids, padded at the end when the source chunk is short.
    inputs: Vec<u32>,
    /// Target token ids: the inputs shifted left by one token, padded likewise.
    targets: Vec<u32>,
    /// Per-position loss weight: `1.0` for real predictions, `0.0` for padding.
    loss_mask: Vec<f32>,
}
958
959struct TrainingWindows {
960    windows: Vec<TrainingWindow>,
961    seq_len: usize,
962}
963
964impl TrainingWindows {
965    fn from_sequences(sequences: &[Vec<u32>], seq_len: usize, pad_token_id: u32) -> Result<Self> {
966        let mut windows = Vec::new();
967        for sequence in sequences {
968            if sequence.len() < 2 {
969                continue;
970            }
971
972            let mut start = 0;
973            while start + 1 < sequence.len() {
974                let end = (start + seq_len + 1).min(sequence.len());
975                let chunk = &sequence[start..end];
976                let prediction_count = chunk.len() - 1;
977
978                let mut inputs = vec![pad_token_id; seq_len];
979                let mut targets = vec![pad_token_id; seq_len];
980                let mut loss_mask = vec![0.0; seq_len];
981                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
982                targets[..prediction_count].copy_from_slice(&chunk[1..]);
983                loss_mask[..prediction_count].fill(1.0);
984
985                windows.push(TrainingWindow {
986                    inputs,
987                    targets,
988                    loss_mask,
989                });
990
991                if end == sequence.len() {
992                    break;
993                }
994                start += seq_len;
995            }
996        }
997
998        Ok(Self { windows, seq_len })
999    }
1000
1001    fn is_empty(&self) -> bool {
1002        self.windows.is_empty()
1003    }
1004
1005    fn len(&self) -> usize {
1006        self.windows.len()
1007    }
1008
1009    fn batch<B: Backend>(
1010        &self,
1011        step: usize,
1012        batch_size: usize,
1013        device: &B::Device,
1014    ) -> Result<TokenBatch<B>> {
1015        let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
1016        let mut targets = Vec::with_capacity(batch_size * self.seq_len);
1017        let mut loss_mask = Vec::with_capacity(batch_size * self.seq_len);
1018
1019        for batch_idx in 0..batch_size {
1020            let index = (step * batch_size + batch_idx) % self.windows.len();
1021            let window = &self.windows[index];
1022            inputs.extend_from_slice(&window.inputs);
1023            targets.extend_from_slice(&window.targets);
1024            loss_mask.extend_from_slice(&window.loss_mask);
1025        }
1026
1027        Ok(TokenBatch {
1028            inputs: tensor_from_u32(inputs, [batch_size, self.seq_len], device)?,
1029            targets: tensor_from_u32(targets, [batch_size, self.seq_len], device)?,
1030            loss_mask: Tensor::<B, 2>::from_data(
1031                TensorData::new(loss_mask, [batch_size, self.seq_len]),
1032                device,
1033            ),
1034        })
1035    }
1036}
1037
1038struct TokenBatch<B: Backend> {
1039    inputs: Tensor<B, 2, Int>,
1040    targets: Tensor<B, 2, Int>,
1041    loss_mask: Tensor<B, 2>,
1042}
1043
/// Builds a deterministic synthetic next-token batch for tests.
///
/// Row `r` starts at token `(step * 7 + r * 13) % vocab_size` and walks consecutive
/// ids, wrapping modulo `vocab_size`. Targets are the inputs shifted forward by one
/// id. Both tensors are `[batch_size, seq_len]`.
///
/// # Panics
/// Panics (division by zero) when `vocab_size == 0` and `batch_size > 0`.
#[cfg(test)]
// dead_code is allowed because not every test build necessarily calls this helper.
#[allow(dead_code)]
pub fn make_batch<B: Backend>(
    step: usize,
    batch_size: usize,
    seq_len: usize,
    vocab_size: usize,
    device: &B::Device,
) -> Result<(Tensor<B, 2, Int>, Tensor<B, 2, Int>)> {
    // Unwrapped positions for every cell, in row-major [batch_size, seq_len] order.
    let positions: Vec<usize> = (0..batch_size)
        .flat_map(|row| {
            let row_start = (step * 7 + row * 13) % vocab_size;
            (0..seq_len).map(move |col| row_start + col)
        })
        .collect();

    let inputs = positions
        .iter()
        .map(|&pos| (pos % vocab_size) as u32)
        .collect();
    let targets = positions
        .iter()
        .map(|&pos| ((pos + 1) % vocab_size) as u32)
        .collect();

    let shape = [batch_size, seq_len];
    let input_tensor = tensor_from_u32(inputs, shape, device)?;
    let target_tensor = tensor_from_u32(targets, shape, device)?;
    Ok((input_tensor, target_tensor))
}