1use crate::{
2 error::{Error, Result},
3 runtime::DefaultAutodiffBackend,
4};
5use burn::{
6 grad_clipping::GradientClippingConfig,
7 module::{Module, Param},
8 optim::{AdamWConfig, GradientsParams, Optimizer},
9 record::{FullPrecisionSettings, NamedMpkFileRecorder},
10 tensor::{
11 Int, Tensor, TensorData, activation,
12 backend::{AutodiffBackend, Backend},
13 },
14};
15use serde::{Deserialize, Serialize};
16use std::f32::consts::PI;
17use std::path::Path;
18
/// Small constant added to denominators (norms, mask sums) to avoid division by zero.
const EPS: f32 = 1e-6;
20
/// Discrete model-size presets, named by their approximate trainable-parameter count.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiscreenParameterBudget {
    /// ~1 million parameters.
    Params1M,
    /// ~5 million parameters.
    Params5M,
    /// ~10 million parameters.
    Params10M,
    /// ~50 million parameters.
    Params50M,
    /// ~100 million parameters.
    Params100M,
}
34
35impl MultiscreenParameterBudget {
36 pub const ALL: [Self; 5] = [
37 Self::Params1M,
38 Self::Params5M,
39 Self::Params10M,
40 Self::Params50M,
41 Self::Params100M,
42 ];
43
44 pub fn label(self) -> &'static str {
45 match self {
46 Self::Params1M => "1M",
47 Self::Params5M => "5M",
48 Self::Params10M => "10M",
49 Self::Params50M => "50M",
50 Self::Params100M => "100M",
51 }
52 }
53
54 pub fn target_parameter_count(self) -> usize {
55 match self {
56 Self::Params1M => 1_000_000,
57 Self::Params5M => 5_000_000,
58 Self::Params10M => 10_000_000,
59 Self::Params50M => 50_000_000,
60 Self::Params100M => 100_000_000,
61 }
62 }
63}
64
/// Architectural hyper-parameters for a [`MultiscreenModel`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiscreenModelConfig {
    /// Number of distinct token ids (embedding rows / logit columns).
    pub vocab_size: usize,
    /// Fixed context window length the model is trained and evaluated on.
    pub seq_len: usize,
    /// Number of stacked layers.
    pub layers: usize,
    /// Number of gated screening tiles per layer.
    pub tiles: usize,
    /// Residual-stream width.
    pub d_model: usize,
    /// Query/key projection width (must be >= 2 for the MiPE rotation).
    pub d_key: usize,
    /// Value/gate projection width.
    pub d_value: usize,
    /// Window threshold used by MiPE and tile window initialization.
    pub w_th: f32,
}
85
impl MultiscreenModelConfig {
    /// Small demo configuration (not tied to any parameter budget).
    pub fn tiny() -> Self {
        Self {
            vocab_size: 64,
            seq_len: 64,
            layers: 2,
            tiles: 4,
            d_model: 64,
            d_key: 16,
            d_value: 32,
            w_th: 32.0,
        }
    }

    /// Minimal configuration intended for fast unit tests.
    pub fn tiny_for_tests() -> Self {
        Self {
            vocab_size: 32,
            seq_len: 8,
            layers: 1,
            tiles: 2,
            d_model: 16,
            d_key: 4,
            d_value: 8,
            w_th: 8.0,
        }
    }

    /// Dispatches to the preset matching `budget`, with caller-provided
    /// vocabulary size and sequence length.
    pub fn for_parameter_budget(
        budget: MultiscreenParameterBudget,
        vocab_size: usize,
        seq_len: usize,
    ) -> Self {
        match budget {
            MultiscreenParameterBudget::Params1M => Self::preset_1m(vocab_size, seq_len),
            MultiscreenParameterBudget::Params5M => Self::preset_5m(vocab_size, seq_len),
            MultiscreenParameterBudget::Params10M => Self::preset_10m(vocab_size, seq_len),
            MultiscreenParameterBudget::Params50M => Self::preset_50m(vocab_size, seq_len),
            MultiscreenParameterBudget::Params100M => Self::preset_100m(vocab_size, seq_len),
        }
    }

    /// ~1M-parameter preset (layers, tiles, d_model, d_key, d_value).
    pub fn preset_1m(vocab_size: usize, seq_len: usize) -> Self {
        Self::from_dimensions(vocab_size, seq_len, 2, 2, 128, 32, 64)
    }

    /// ~5M-parameter preset.
    pub fn preset_5m(vocab_size: usize, seq_len: usize) -> Self {
        Self::from_dimensions(vocab_size, seq_len, 2, 4, 384, 96, 192)
    }

    /// ~10M-parameter preset.
    pub fn preset_10m(vocab_size: usize, seq_len: usize) -> Self {
        Self::from_dimensions(vocab_size, seq_len, 3, 4, 512, 128, 256)
    }

    /// ~50M-parameter preset.
    pub fn preset_50m(vocab_size: usize, seq_len: usize) -> Self {
        Self::from_dimensions(vocab_size, seq_len, 6, 4, 960, 240, 480)
    }

    /// ~100M-parameter preset.
    pub fn preset_100m(vocab_size: usize, seq_len: usize) -> Self {
        Self::from_dimensions(vocab_size, seq_len, 8, 4, 1216, 304, 608)
    }

    /// Alias kept for compatibility; identical to [`Self::preset_10m`].
    pub fn paper_10m(vocab_size: usize, seq_len: usize) -> Self {
        Self::preset_10m(vocab_size, seq_len)
    }

    /// Analytic parameter count matching the tensors allocated in
    /// `MultiscreenModel::new`: an embedding of `vocab_size * d_model`, two
    /// global scalars (s_e, s_f), and per tile five projection matrices
    /// (w_q/w_k of d_model*d_key, w_v/w_g of d_model*d_value, w_o of
    /// d_value*d_model) plus three scalars (s_w, s_r, s_o).
    pub fn estimated_parameter_count(&self) -> usize {
        let embedding_params = self.vocab_size.saturating_mul(self.d_model);
        // d_model * (2*d_key + 3*d_value) covers the five matrices; +3 scalars.
        let per_tile_params = self
            .d_model
            .saturating_mul(
                2usize
                    .saturating_mul(self.d_key)
                    .saturating_add(3usize.saturating_mul(self.d_value)),
            )
            .saturating_add(3);
        let tile_params = self
            .layers
            .saturating_mul(self.tiles)
            .saturating_mul(per_tile_params);
        // +2 for the global s_e / s_f scalars.
        embedding_params
            .saturating_add(2)
            .saturating_add(tile_params)
    }

    /// Internal constructor shared by the presets; always uses w_th = 32.0.
    fn from_dimensions(
        vocab_size: usize,
        seq_len: usize,
        layers: usize,
        tiles: usize,
        d_model: usize,
        d_key: usize,
        d_value: usize,
    ) -> Self {
        Self {
            vocab_size,
            seq_len,
            layers,
            tiles,
            d_model,
            d_key,
            d_value,
            w_th: 32.0,
        }
    }

    /// Checks every field for validity; returns `Error::Config` on the first
    /// violated invariant.
    pub fn validate(&self) -> Result<()> {
        ensure(self.vocab_size > 0, "vocab_size must be greater than zero")?;
        ensure(self.seq_len > 0, "seq_len must be greater than zero")?;
        ensure(self.layers > 0, "layers must be greater than zero")?;
        ensure(self.tiles > 0, "tiles must be greater than zero")?;
        ensure(self.d_model > 0, "d_model must be greater than zero")?;
        // MiPE rotates the first two key dimensions, hence the minimum of 2.
        ensure(self.d_key >= 2, "d_key must be at least 2 for MiPE")?;
        ensure(self.d_value > 0, "d_value must be greater than zero")?;
        ensure(
            self.w_th.is_finite() && self.w_th > 0.0,
            "w_th must be positive and finite",
        )?;
        Ok(())
    }
}
206
/// Options controlling a training run.
#[derive(Clone, Debug)]
pub struct ModelTrainingConfig {
    /// Number of optimizer steps to run.
    pub steps: usize,
    /// Windows per batch; must be > 0.
    pub batch_size: usize,
    /// AdamW learning rate.
    pub learning_rate: f64,
    /// AdamW weight decay.
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold; `None` or non-positive disables clipping.
    pub grad_clip_norm: Option<f64>,
    /// Token id used for padding inputs/targets.
    pub pad_token_id: u32,
    /// Directory for checkpoints; `None` disables checkpointing entirely.
    pub checkpoint_dir: Option<String>,
    /// Save a periodic checkpoint every N steps; 0 disables periodic saves
    /// (the best-loss checkpoint is still written when a dir is set).
    pub checkpoint_interval: usize,
}
224
impl Default for ModelTrainingConfig {
    /// Conservative defaults: short run, small batch, clipping enabled,
    /// no checkpointing.
    fn default() -> Self {
        Self {
            steps: 100,
            batch_size: 4,
            learning_rate: 2e-4,
            weight_decay: 0.01,
            grad_clip_norm: Some(1.0),
            pad_token_id: 0,
            checkpoint_dir: None,
            checkpoint_interval: 0,
        }
    }
}
239
/// Options controlling greedy generation.
#[derive(Clone, Debug)]
pub struct ModelInferenceConfig {
    /// Maximum number of tokens appended after the prompt.
    pub max_new_tokens: usize,
    /// Token id used to left-pad short contexts.
    pub pad_token_id: u32,
}
246
impl Default for ModelInferenceConfig {
    /// Defaults: generate at most 16 tokens, pad with token id 0.
    fn default() -> Self {
        Self {
            max_new_tokens: 16,
            pad_token_id: 0,
        }
    }
}
255
/// Summary returned after a training run.
#[derive(Clone, Debug, PartialEq)]
pub struct ModelTrainingReport {
    /// Number of optimizer steps that were requested.
    pub steps: usize,
    /// Loss of the last executed step (or the initial loss when steps == 0).
    pub final_loss: f32,
    /// Lowest per-step loss observed during the run.
    pub best_loss: f32,
    /// Step index at which `best_loss` occurred.
    pub best_loss_step: usize,
    /// Number of training windows derived from the input sequences.
    pub training_window_count: usize,
    /// Total trainable parameters in the model.
    pub parameter_count: usize,
}
267
/// Metrics produced by `evaluate_on_sequences`.
#[derive(Clone, Debug)]
pub struct EvaluationResult {
    /// Mean masked cross-entropy loss over batches (NaN when no windows).
    pub loss: f32,
    /// exp(loss).
    pub perplexity: f32,
    /// Greedy next-token accuracy over unmasked positions.
    pub accuracy: f64,
    /// Number of batches evaluated.
    pub num_batches: usize,
    /// Number of unmasked target tokens counted toward accuracy.
    pub total_tokens: usize,
}
282
/// Generation result: the prompt tokens followed by generated tokens.
pub struct MultiscreenModelOutput {
    pub token_ids: Vec<u32>,
}
286
/// The multiscreen language model: a token embedding (also used, transposed,
/// as the output head) plus a stack of layers of gated screening tiles.
#[derive(Module, Debug)]
pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> {
    // Plain config; excluded from the module's parameter tree.
    #[module(skip)]
    config: MultiscreenModelConfig,
    // [vocab_size, d_model]; rows are unit-normalized at forward time.
    token_embedding: Param<Tensor<B, 2>>,
    // Log-scale applied to embeddings entering the residual stream.
    s_e: Param<Tensor<B, 1>>,
    // Log-scale applied to the final logits.
    s_f: Param<Tensor<B, 1>>,
    layers: Vec<MultiscreenLayer<B>>,
}
298
/// Convenience alias using the crate's default autodiff backend.
pub type DefaultMultiscreenModel = MultiscreenModel<DefaultAutodiffBackend>;
301
/// One layer: a set of tiles whose outputs are summed into the residual stream.
#[derive(Module, Debug)]
struct MultiscreenLayer<B: Backend> {
    tiles: Vec<GatedScreeningTile<B>>,
}
306
/// A single attention-like tile with windowed screening and a SiLU gate.
#[derive(Module, Debug)]
struct GatedScreeningTile<B: Backend> {
    // Projections: query/key [d_model, d_key]; value/gate [d_model, d_value];
    // output [d_value, d_model].
    w_q: Param<Tensor<B, 2>>,
    w_k: Param<Tensor<B, 2>>,
    w_v: Param<Tensor<B, 2>>,
    w_g: Param<Tensor<B, 2>>,
    w_o: Param<Tensor<B, 2>>,
    // Log window size (exp'd and +1 in forward).
    s_w: Param<Tensor<B, 1>>,
    // Pre-sigmoid relevance sharpness.
    s_r: Param<Tensor<B, 1>>,
    // Log output scale.
    s_o: Param<Tensor<B, 1>>,
    // Window threshold copied from the model config; not a parameter.
    #[module(skip)]
    w_th: f32,
}
320
impl<B: Backend> MultiscreenModel<B> {
    /// Builds a model with deterministically seeded Gaussian parameters.
    ///
    /// Returns `Error::Config` when `config.validate()` fails.
    pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self> {
        config.validate()?;

        // Fixed seed makes initialization reproducible across runs.
        let mut seed = 0x4d55_4c54_4953_4352;
        let token_embedding = init_matrix(
            config.vocab_size,
            config.d_model,
            0.1 / (config.d_model as f32).sqrt(),
            &mut seed,
            device,
        );
        let s_e = init_scalar(0.0, device);
        // exp(s_f) starts at sqrt(d_model), a standard logit scale.
        let s_f = init_scalar((config.d_model as f32).sqrt().ln(), device);

        let mut layers = Vec::with_capacity(config.layers);
        for _layer_idx in 0..config.layers {
            let mut tiles = Vec::with_capacity(config.tiles);
            for tile_idx in 0..config.tiles {
                let w_q = init_matrix(
                    config.d_model,
                    config.d_key,
                    0.1 / (config.d_key as f32).sqrt(),
                    &mut seed,
                    device,
                );
                let w_k = init_matrix(
                    config.d_model,
                    config.d_key,
                    0.1 / (config.d_key as f32).sqrt(),
                    &mut seed,
                    device,
                );
                let w_v = init_matrix(
                    config.d_model,
                    config.d_value,
                    0.1 / (config.d_value as f32).sqrt(),
                    &mut seed,
                    device,
                );
                let w_g = init_matrix(config.d_model, config.d_value, 0.1, &mut seed, device);
                let w_o = init_matrix(
                    config.d_value,
                    config.d_model,
                    0.1 / (config.d_model as f32).sqrt(),
                    &mut seed,
                    device,
                );

                // Spread initial window sizes geometrically from 1 up to w_th
                // across the tiles of a layer.
                let window_frac = if config.tiles == 1 {
                    0.0
                } else {
                    tile_idx as f32 / (config.tiles - 1) as f32
                };
                let s_w = init_scalar(window_frac * config.w_th.ln(), device);
                let s_r = init_scalar(0.0, device);
                // Scale outputs down by sqrt(total tile count) initially.
                let s_o = init_scalar(-0.5 * ((config.layers * config.tiles) as f32).ln(), device);

                tiles.push(GatedScreeningTile {
                    w_q,
                    w_k,
                    w_v,
                    w_g,
                    w_o,
                    s_w,
                    s_r,
                    s_o,
                    w_th: config.w_th,
                });
            }
            layers.push(MultiscreenLayer { tiles });
        }

        Ok(Self {
            config,
            token_embedding,
            s_e,
            s_f,
            layers,
        })
    }

    /// Read-only access to the architectural configuration.
    pub fn config(&self) -> &MultiscreenModelConfig {
        &self.config
    }

    /// Total number of trainable parameters (from Burn's `Module::num_params`).
    pub fn parameter_count(&self) -> usize {
        self.num_params()
    }

    /// Full forward pass: `[batch, seq_len]` token ids -> `[batch, seq_len,
    /// vocab_size]` logits. The (row-normalized) embedding matrix is shared
    /// between input lookup and the output head.
    pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let [batch, seq_len] = tokens.dims();
        let embedding = row_unit_normalize(self.token_embedding.val());
        // Embedding lookup implemented as one-hot matmul.
        let one_hot = tokens.one_hot::<3>(self.config().vocab_size).float();
        let mut x =
            linear(one_hot, embedding.clone()).reshape([batch, seq_len, self.config().d_model]);
        let x_dims = x.dims();
        x = x * expand_scalar3(self.s_e.val().exp(), x_dims);

        // Residual stack: each layer adds the sum of its tiles' outputs.
        for layer in &self.layers {
            let mut layer_update = Tensor::<B, 3>::zeros(x.dims(), &x.device());
            for tile in &layer.tiles {
                layer_update = layer_update + tile.forward(x.clone());
            }
            x = x + layer_update;
        }

        // Tied output head: project back through the transposed embedding.
        let logits_weight = embedding.swap_dims(0, 1);
        let logits = linear(x, logits_weight);
        logits.clone() * expand_scalar3(self.s_f.val().exp(), logits.dims())
    }

    /// Saves all parameters to `path` in Burn's named-MessagePack format.
    pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()> {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
        self.clone()
            .save_file(path.as_ref().to_path_buf(), &recorder)
            .map_err(|err| Error::Serialization(err.to_string()))
    }

    /// Loads parameters from `path`, replacing this model's weights in place.
    pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()> {
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
        // Reuse whichever device the current parameters live on.
        let device =
            self.devices().into_iter().next().ok_or_else(|| {
                Error::Serialization("model has no device for parameter load".into())
            })?;
        let loaded = self
            .clone()
            .load_file(path.as_ref().to_path_buf(), &recorder, &device)
            .map_err(|err| Error::Serialization(err.to_string()))?;
        *self = loaded;
        Ok(())
    }

    /// Greedy generation with a per-token callback. `on_token(token, index)`
    /// is invoked after each generated token; returning `false` stops early.
    /// The returned ids include the prompt.
    pub fn infer_tokens_stream(
        &self,
        prompt: &[u32],
        inference: &ModelInferenceConfig,
        device: &B::Device,
        mut on_token: impl FnMut(u32, usize) -> bool,
    ) -> Result<MultiscreenModelOutput> {
        if prompt.is_empty() {
            return Err(Error::Inference(
                "prompt must contain at least one token".to_string(),
            ));
        }

        let mut output = prompt.to_vec();
        for i in 0..inference.max_new_tokens {
            let next = self.predict_next_token(&output, inference.pad_token_id, device)?;
            output.push(next);
            if !on_token(next, i) {
                break;
            }
        }

        Ok(MultiscreenModelOutput { token_ids: output })
    }

    /// Greedy generation without a callback; see [`Self::infer_tokens_stream`].
    pub fn infer_tokens(
        &self,
        prompt: &[u32],
        inference: &ModelInferenceConfig,
        device: &B::Device,
    ) -> Result<MultiscreenModelOutput> {
        self.infer_tokens_stream(prompt, inference, device, |_, _| true)
    }

    /// Argmax prediction for the token following `context`. The context is
    /// truncated/left-padded to `seq_len` before the forward pass.
    pub fn predict_next_token(
        &self,
        context: &[u32],
        pad_token_id: u32,
        device: &B::Device,
    ) -> Result<u32> {
        if context.is_empty() {
            return Err(Error::Inference(
                "context must contain at least one token".to_string(),
            ));
        }

        let input = context_window(context, self.config().seq_len, pad_token_id);
        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
        let logits = self.forward(input);
        // Only the final position's logits matter for next-token prediction.
        let last_logits = logits
            .slice([
                0..1,
                self.config().seq_len - 1..self.config().seq_len,
                0..self.config().vocab_size,
            ])
            .reshape([self.config().vocab_size]);
        let values = tensor_to_vec(last_logits)?;
        argmax(&values).map(|idx| idx as u32)
    }

    /// Runs a forward pass over the padded context window and returns the raw
    /// `[1, seq_len, vocab_size]` logits.
    pub fn forward_logits(
        &self,
        context: &[u32],
        pad_token_id: u32,
        device: &B::Device,
    ) -> Result<Tensor<B, 3>> {
        if context.is_empty() {
            return Err(Error::Inference(
                "context must contain at least one token".to_string(),
            ));
        }

        let input = context_window(context, self.config().seq_len, pad_token_id);
        let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
        let logits = self.forward(input);
        Ok(logits)
    }

    /// Computes loss, perplexity, and greedy accuracy over `sequences`.
    /// Returns NaN loss / zero counts when no training windows can be built.
    pub fn evaluate_on_sequences(
        &self,
        sequences: &[Vec<u32>],
        seq_len: usize,
        batch_size: usize,
        pad_token_id: u32,
        device: &B::Device,
    ) -> Result<EvaluationResult> {
        let windows = TrainingWindows::from_sequences(sequences, seq_len, pad_token_id)?;
        if windows.is_empty() {
            return Ok(EvaluationResult {
                loss: f32::NAN,
                perplexity: f32::NAN,
                accuracy: 0.0,
                num_batches: 0,
                total_tokens: 0,
            });
        }

        let num_batches = windows.len().div_ceil(batch_size);
        let mut total_loss = 0.0_f64;
        let mut total_correct = 0_usize;
        let mut total_tokens = 0_usize;

        for step in 0..num_batches {
            let batch = windows.batch::<B>(step, batch_size, device)?;
            let logits = self.forward(batch.inputs);
            let loss = cross_entropy_loss_with_mask(
                logits.clone(),
                batch.targets.clone(),
                batch.loss_mask.clone(),
            );
            let loss_val = tensor_scalar(loss)? as f64;
            total_loss += loss_val;

            // Pull logits/targets/mask to host memory to count greedy hits.
            let [b, s, v] = logits.dims();
            let mask_vec: Vec<f32> = batch
                .loss_mask
                .clone()
                .reshape([b * s])
                .into_data()
                .into_vec::<f32>()
                .map_err(|e| Error::Inference(e.to_string()))?;
            let target_vec: Vec<i32> = batch
                .targets
                .clone()
                .reshape([b * s])
                .into_data()
                .into_vec::<i32>()
                .map_err(|e| Error::Inference(e.to_string()))?;
            let logit_vec: Vec<f32> = logits
                .reshape([b * s * v])
                .into_data()
                .into_vec::<f32>()
                .map_err(|e| Error::Inference(e.to_string()))?;

            for bi in 0..b {
                for si in 0..s {
                    let mi = bi * s + si;
                    // Skip positions excluded from the loss (padding / prompt).
                    if mask_vec[mi] < 0.5 {
                        continue;
                    }
                    total_tokens += 1;
                    // Argmax over the vocab slice for this position.
                    let base = bi * s * v + si * v;
                    let mut best_idx = 0;
                    let mut best_val = f32::NEG_INFINITY;
                    for vi in 0..v {
                        let val = logit_vec[base + vi];
                        if val > best_val {
                            best_val = val;
                            best_idx = vi;
                        }
                    }
                    if best_idx == target_vec[mi] as usize {
                        total_correct += 1;
                    }
                }
            }
        }

        // NOTE: loss is averaged per batch (each batch already averages over
        // its unmasked tokens), not re-weighted by token count.
        let avg_loss = total_loss / num_batches as f64;
        let perplexity = avg_loss.exp();
        let accuracy = if total_tokens > 0 {
            total_correct as f64 / total_tokens as f64
        } else {
            0.0
        };

        Ok(EvaluationResult {
            loss: avg_loss as f32,
            perplexity: perplexity as f32,
            accuracy,
            num_batches,
            total_tokens,
        })
    }
}
655
656impl<B> MultiscreenModel<B>
657where
658 B: AutodiffBackend,
659{
660 pub fn train_token_sequences(
665 &mut self,
666 sequences: &[Vec<u32>],
667 training: &ModelTrainingConfig,
668 device: &B::Device,
669 mut on_step: impl FnMut(usize, f32),
670 ) -> Result<ModelTrainingReport> {
671 if training.batch_size == 0 {
672 return Err(Error::Training(
673 "batch_size must be greater than zero".to_string(),
674 ));
675 }
676 let windows = TrainingWindows::from_sequences(
677 sequences,
678 self.config().seq_len,
679 training.pad_token_id,
680 )?;
681 if windows.is_empty() {
682 return Err(Error::Training(
683 "training requires at least one sequence with two or more tokens".to_string(),
684 ));
685 }
686
687 let mut optimizer_config =
688 AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
689 if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
690 optimizer_config = optimizer_config
691 .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
692 }
693 let mut optimizer = optimizer_config.init::<B, Self>();
694 let mut model = self.clone();
695 let mut final_loss = f32::NAN;
696 let mut best_loss = f32::MAX;
697 let mut best_loss_step: usize = 0;
698
699 let ckpt_dir = training.checkpoint_dir.as_deref().map(Path::new);
700 if let Some(dir) = &ckpt_dir {
701 std::fs::create_dir_all(dir).map_err(|e| {
702 Error::Io(format!(
703 "failed to create checkpoint directory {:?}: {}",
704 dir, e
705 ))
706 })?;
707 }
708
709 for step in 0..training.steps {
710 let batch = windows.batch::<B>(step, training.batch_size, device)?;
711 let logits = model.forward(batch.inputs);
712 let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
713 final_loss = tensor_scalar(loss.clone())?;
714 let grads = loss.backward();
715 let grads = GradientsParams::from_grads(grads, &model);
716 model = optimizer.step(training.learning_rate, model, grads);
717
718 if final_loss < best_loss {
720 best_loss = final_loss;
721 best_loss_step = step;
722 if let Some(dir) = &ckpt_dir {
723 let path = dir.join("best.mpk");
724 model.save_parameters(&path)?;
725 }
726 }
727 if training.checkpoint_interval > 0
728 && (step + 1) % training.checkpoint_interval == 0
729 && let Some(dir) = &ckpt_dir
730 {
731 let path = dir.join(format!("step_{:06}.mpk", step + 1));
732 model.save_parameters(&path)?;
733 }
734
735 on_step(step, final_loss);
736 }
737
738 if training.steps == 0 {
739 let batch = windows.batch::<B>(0, training.batch_size, device)?;
740 final_loss = tensor_scalar(cross_entropy_loss_with_mask(
741 model.forward(batch.inputs),
742 batch.targets,
743 batch.loss_mask,
744 ))?;
745 best_loss = final_loss;
746 best_loss_step = 0;
747 }
748
749 *self = model;
750
751 Ok(ModelTrainingReport {
752 steps: training.steps,
753 final_loss,
754 best_loss,
755 best_loss_step,
756 training_window_count: windows.len(),
757 parameter_count: self.parameter_count(),
758 })
759 }
760
761 pub fn train_chat_sequences(
773 &mut self,
774 chat_pairs: &[(Vec<u32>, Vec<u32>)],
775 training: &ModelTrainingConfig,
776 device: &B::Device,
777 mut on_step: impl FnMut(usize, f32),
778 ) -> Result<ModelTrainingReport> {
779 if training.batch_size == 0 {
780 return Err(Error::Training(
781 "batch_size must be greater than zero".to_string(),
782 ));
783 }
784 let windows = TrainingWindows::from_chat_sequences(
785 chat_pairs,
786 self.config().seq_len,
787 training.pad_token_id,
788 )?;
789 if windows.is_empty() {
790 return Err(Error::Training(
791 "training requires at least one chat pair that produces two or more tokens"
792 .to_string(),
793 ));
794 }
795
796 let mut optimizer_config =
797 AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
798 if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
799 optimizer_config = optimizer_config
800 .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
801 }
802 let mut optimizer = optimizer_config.init::<B, Self>();
803 let mut model = self.clone();
804 let mut final_loss = f32::NAN;
805 let mut best_loss = f32::MAX;
806 let mut best_loss_step: usize = 0;
807
808 let ckpt_dir = training.checkpoint_dir.as_deref().map(Path::new);
809 if let Some(dir) = &ckpt_dir {
810 std::fs::create_dir_all(dir).map_err(|e| {
811 Error::Io(format!(
812 "failed to create checkpoint directory {:?}: {}",
813 dir, e
814 ))
815 })?;
816 }
817
818 for step in 0..training.steps {
819 let batch = windows.batch::<B>(step, training.batch_size, device)?;
820 let logits = model.forward(batch.inputs);
821 let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
822 final_loss = tensor_scalar(loss.clone())?;
823 let grads = loss.backward();
824 let grads = GradientsParams::from_grads(grads, &model);
825 model = optimizer.step(training.learning_rate, model, grads);
826
827 if final_loss < best_loss {
829 best_loss = final_loss;
830 best_loss_step = step;
831 if let Some(dir) = &ckpt_dir {
832 let path = dir.join("best.mpk");
833 model.save_parameters(&path)?;
834 }
835 }
836 if training.checkpoint_interval > 0
837 && (step + 1) % training.checkpoint_interval == 0
838 && let Some(dir) = &ckpt_dir
839 {
840 let path = dir.join(format!("step_{:06}.mpk", step + 1));
841 model.save_parameters(&path)?;
842 }
843
844 on_step(step, final_loss);
845 }
846
847 if training.steps == 0 {
848 let batch = windows.batch::<B>(0, training.batch_size, device)?;
849 final_loss = tensor_scalar(cross_entropy_loss_with_mask(
850 model.forward(batch.inputs),
851 batch.targets,
852 batch.loss_mask,
853 ))?;
854 best_loss = final_loss;
855 best_loss_step = 0;
856 }
857
858 *self = model;
859
860 Ok(ModelTrainingReport {
861 steps: training.steps,
862 final_loss,
863 best_loss,
864 best_loss_step,
865 training_window_count: windows.len(),
866 parameter_count: self.parameter_count(),
867 })
868 }
869}
870
// Adapter: expose the inherent forward pass through the crate's
// language-model trait.
impl<B: Backend> crate::lm::LanguageModel<B> for MultiscreenModel<B> {
    fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        MultiscreenModel::forward(self, tokens)
    }
}
876
// Marker impl: opts the model into the trainable-language-model interface.
impl<B: Backend> crate::lm::TrainableLanguageModel<B> for MultiscreenModel<B> {}
878
impl<B: Backend> GatedScreeningTile<B> {
    /// One tile's contribution for input `x` of shape [batch, seq, d_model].
    ///
    /// Pipeline: project to unit-norm q/k/v plus a gate; rotate q/k with MiPE;
    /// form causally-windowed relevance weights; aggregate values; tanh-norm;
    /// gate; project back to d_model and scale by exp(s_o).
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let q = row_unit_normalize(linear(x.clone(), self.w_q.val()));
        let k = row_unit_normalize(linear(x.clone(), self.w_k.val()));
        let v = row_unit_normalize(linear(x.clone(), self.w_v.val()));
        let g = linear(x, self.w_g.val());

        // Window size w >= 1; clamp keeps exp() in a stable range.
        let w = self.s_w.val().clamp(-10.0, 8.0).exp() + 1.0;
        // Relevance radius r in (0, 1).
        let r = activation::sigmoid(self.s_r.val().clamp(-10.0, 8.0));

        let q = apply_mipe(q, w.clone(), self.w_th);
        let k = apply_mipe(k, w.clone(), self.w_th);

        // Cosine similarity (rows are unit-norm), screened and windowed.
        let similarity = q.matmul(k.swap_dims(1, 2));
        let alpha = trim_and_square_tensor(similarity.clone(), r);
        let softmask = causal_softmask_tensor::<B>(similarity.dims()[1], w, &similarity.device());
        let relevance = alpha * softmask.unsqueeze();
        let h = relevance.matmul(v);
        let u = tanh_norm(h);
        let gate = activation::silu(g).tanh();
        let gated = u * gate;
        let out = linear(gated, self.w_o.val());
        out.clone() * expand_scalar3(self.s_o.val().exp(), out.dims())
    }
}
904
905#[allow(dead_code)]
906pub fn cross_entropy_loss<B: Backend>(
907 logits: Tensor<B, 3>,
908 targets: Tensor<B, 2, Int>,
909) -> Tensor<B, 1> {
910 let device = logits.device();
911 let [batch, seq_len, _] = logits.dims();
912 let loss_mask = Tensor::<B, 2>::ones([batch, seq_len], &device);
913 cross_entropy_loss_with_mask(logits, targets, loss_mask)
914}
915
/// Masked mean cross-entropy.
///
/// Flattens [batch, seq, vocab] logits, takes log-softmax over the vocab,
/// selects each target's log-probability via a one-hot product, and averages
/// the negative log-likelihood over positions where `loss_mask` is 1.
/// `EPS` in the denominator keeps an all-zero mask from dividing by zero.
pub fn cross_entropy_loss_with_mask<B: Backend>(
    logits: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
    loss_mask: Tensor<B, 2>,
) -> Tensor<B, 1> {
    let [batch, seq_len, vocab_size] = logits.dims();
    let token_count = batch * seq_len;
    let flat_logits = logits.reshape([token_count, vocab_size]);
    let flat_targets = targets.reshape([token_count]);
    let flat_mask = loss_mask.reshape([token_count]);
    let log_probs = activation::log_softmax(flat_logits, 1);
    // One-hot dot product picks out the target token's log-probability.
    let target_probs = flat_targets.one_hot::<2>(vocab_size).float();
    let picked = (log_probs * target_probs).sum_dim(1).reshape([token_count]);
    let masked_nll = (picked.neg() * flat_mask.clone()).sum();
    let mask_sum = flat_mask.sum();
    let denom = mask_sum.add_scalar(EPS);
    masked_nll / denom
}
937
/// Scales each row (last dimension) of `x` to unit L2 norm; `EPS` guards
/// against zero rows.
pub fn row_unit_normalize<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    let denom = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
    x / denom
}
942
/// Screening nonlinearity on cosine similarities: maps similarity s to
/// clamp(1 - (1 - s) / r, 0, 1)^2, so only similarities within distance `r`
/// of 1 survive, with a quadratic taper.
fn trim_and_square_tensor<B: Backend>(similarity: Tensor<B, 3>, r: Tensor<B, 1>) -> Tensor<B, 3> {
    let distance_from_one = similarity.mul_scalar(-1.0).add_scalar(1.0);
    let scaled = distance_from_one.clone() / expand_scalar3(r, distance_from_one.clone().dims());
    scaled
        .mul_scalar(-1.0)
        .add_scalar(1.0)
        .clamp(0.0, 1.0)
        .square()
}
952
953fn causal_softmask_tensor<B: Backend>(
954 seq_len: usize,
955 w: Tensor<B, 1>,
956 device: &B::Device,
957) -> Tensor<B, 2> {
958 let mut distances = Vec::with_capacity(seq_len * seq_len);
959 for i in 0..seq_len {
960 for j in 0..seq_len {
961 distances.push(j as f32 - i as f32);
962 }
963 }
964
965 let dist = Tensor::<B, 2>::from_data(TensorData::new(distances, [seq_len, seq_len]), device);
966 let w = expand_scalar2(w, [seq_len, seq_len]);
967 let causal = dist.clone().lower_equal_elem(0.0);
968 let within_window = dist.clone().greater(w.clone().neg());
969 let active = causal.float() * within_window.float();
970 let tapered = (dist / w)
971 .mul_scalar(PI)
972 .cos()
973 .mul_scalar(0.5)
974 .add_scalar(0.5);
975 active * tapered
976}
977
/// Soft normalization: rescales each row by tanh(|row|)/|row|, so small rows
/// pass through nearly unchanged while large rows saturate to unit norm.
pub fn tanh_norm<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    let norm = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
    let scale = norm.clone().tanh() / norm;
    x * scale
}
983
/// MiPE positional encoding: rotates the first two key dimensions of `z`
/// ([batch, seq, d_key]) by a position-dependent angle; the remaining
/// dimensions pass through untouched.
///
/// The rotation rate `gamma` is a raised cosine of w/w_th that reaches 0 at
/// w == w_th, so wide-window tiles get no positional rotation.
fn apply_mipe<B: Backend>(z: Tensor<B, 3>, w: Tensor<B, 1>, w_th: f32) -> Tensor<B, 3> {
    let [_batch, seq_len, d_key] = z.dims();
    // Guaranteed by MultiscreenModelConfig::validate (d_key >= 2).
    debug_assert!(d_key >= 2);

    // Position indices 0..seq_len as a [1, seq_len, 1] tensor.
    let positions = Tensor::<B, 3>::from_data(
        TensorData::new(
            (0..seq_len).map(|idx| idx as f32).collect::<Vec<_>>(),
            [1, seq_len, 1],
        ),
        &z.device(),
    );

    let gamma_raw = w
        .clone()
        .mul_scalar(PI / w_th)
        .cos()
        .mul_scalar(0.5)
        .add_scalar(0.5);
    // Hard zero beyond the threshold (the cosine would start rising again).
    let gamma = gamma_raw.mask_fill(w.clone().greater_equal_elem(w_th), 0.0);
    let gamma = expand_scalar3(gamma, [1, seq_len, 1]);
    let w = expand_scalar3(w, [1, seq_len, 1]);
    // Rotation angle per position: pi * pos * gamma / w.
    let phi = (positions * gamma / w).mul_scalar(PI);
    let cos_phi = phi.clone().cos();
    let sin_phi = phi.sin();

    // 2-D rotation applied to key dims 0 and 1.
    let x0 = z.clone().narrow(2, 0, 1);
    let x1 = z.clone().narrow(2, 1, 1);
    let rot0 = x0.clone() * cos_phi.clone() - x1.clone() * sin_phi.clone();
    let rot1 = x0 * sin_phi + x1 * cos_phi;

    if d_key == 2 {
        Tensor::cat(vec![rot0, rot1], 2)
    } else {
        let rest = z.narrow(2, 2, d_key - 2);
        Tensor::cat(vec![rot0, rot1, rest], 2)
    }
}
1021
/// Bias-free linear map: [batch, seq, in] x [in, out] -> [batch, seq, out],
/// broadcasting the weight across the batch dimension.
fn linear<B: Backend>(x: Tensor<B, 3>, weight: Tensor<B, 2>) -> Tensor<B, 3> {
    let [batch, _seq_len, _in_dim] = x.dims();
    let [weight_in, out_dim] = weight.dims();
    x.matmul(weight.unsqueeze::<3>().expand([batch, weight_in, out_dim]))
}
1027
/// Broadcasts a 1-element tensor to a 2-D shape.
fn expand_scalar2<B: Backend>(value: Tensor<B, 1>, dims: [usize; 2]) -> Tensor<B, 2> {
    value.unsqueeze::<2>().expand(dims)
}
1031
/// Broadcasts a 1-element tensor to a 3-D shape.
fn expand_scalar3<B: Backend>(value: Tensor<B, 1>, dims: [usize; 3]) -> Tensor<B, 3> {
    value.unsqueeze::<3>().expand(dims)
}
1035
/// Creates a trainable 1-element parameter initialized to `value`.
fn init_scalar<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
    Param::from_tensor(Tensor::<B, 1>::from_data([value], device))
}
1039
/// Creates a trainable [rows, cols] parameter with N(0, std^2) entries drawn
/// from the deterministic generator (`seed` is advanced in place).
fn init_matrix<B: Backend>(
    rows: usize,
    cols: usize,
    std: f32,
    seed: &mut u64,
    device: &B::Device,
) -> Param<Tensor<B, 2>> {
    let values = gaussian_values(rows * cols, std, seed);
    Param::from_tensor(Tensor::<B, 2>::from_data(
        TensorData::new(values, [rows, cols]),
        device,
    ))
}
1053
1054fn gaussian_values(len: usize, std: f32, seed: &mut u64) -> Vec<f32> {
1055 let mut values = Vec::with_capacity(len);
1056 while values.len() < len {
1057 let u1 = next_uniform(seed).max(1e-7);
1058 let u2 = next_uniform(seed);
1059 let radius = (-2.0 * u1.ln()).sqrt();
1060 let theta = 2.0 * PI * u2;
1061 values.push(radius * theta.cos() * std);
1062 if values.len() < len {
1063 values.push(radius * theta.sin() * std);
1064 }
1065 }
1066 values
1067}
1068
/// Advances the 64-bit LCG state and maps its top 24 bits into a uniform
/// value strictly inside (0, 1).
fn next_uniform(seed: &mut u64) -> f32 {
    // LCG step (Knuth MMIX multiplier/increment).
    *seed = seed
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    // High bits have the best statistical quality; keep 24 of them.
    let top = (*seed >> 40) as u32;
    (top as f32 + 1.0) / ((1u32 << 24) as f32 + 2.0)
}
1076
/// Returns the most recent `seq_len` tokens of `context`, left-padded with
/// `pad_token_id` when the context is shorter than the window.
fn context_window(context: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    let keep = context.len().min(seq_len);
    let tail = &context[context.len() - keep..];
    let mut window = vec![pad_token_id; seq_len];
    window[seq_len - keep..].copy_from_slice(tail);
    window
}
1088
1089fn argmax(values: &[f32]) -> Result<usize> {
1090 values
1091 .iter()
1092 .enumerate()
1093 .max_by(|left, right| left.1.total_cmp(right.1))
1094 .map(|(index, _)| index)
1095 .ok_or_else(|| Error::Inference("cannot argmax empty logits".to_string()))
1096}
1097
/// Extracts the first element of a 1-D tensor as f32 (used for loss values).
fn tensor_scalar<B: Backend>(tensor: Tensor<B, 1>) -> Result<f32> {
    tensor
        .into_data()
        .into_vec::<f32>()
        .map_err(|err| Error::Training(err.to_string()))?
        .into_iter()
        .next()
        .ok_or_else(|| Error::Training("expected scalar tensor".to_string()))
}
1107
/// Copies a 1-D tensor to a host `Vec<f32>`.
fn tensor_to_vec<B: Backend>(tensor: Tensor<B, 1>) -> Result<Vec<f32>> {
    tensor
        .into_data()
        .into_vec::<f32>()
        .map_err(|err| Error::Inference(err.to_string()))
}
1114
/// Builds an integer tensor of shape `shape` from u32 token ids, rejecting
/// ids that do not fit in i32 (the backend's int element type).
fn tensor_from_u32<B: Backend, const D: usize>(
    values: Vec<u32>,
    shape: [usize; D],
    device: &B::Device,
) -> Result<Tensor<B, D, Int>> {
    let values = values
        .into_iter()
        .map(|value| {
            i32::try_from(value)
                .map_err(|_| Error::Config(format!("token id {value} exceeds i32::MAX")))
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(Tensor::<B, D, Int>::from_data(
        TensorData::new(values, shape),
        device,
    ))
}
1132
1133fn ensure(condition: bool, message: &str) -> Result<()> {
1134 if condition {
1135 Ok(())
1136 } else {
1137 Err(Error::Config(message.to_string()))
1138 }
1139}
1140
/// One fixed-length training example: inputs[i] should predict targets[i],
/// but only where loss_mask[i] == 1.0.
struct TrainingWindow {
    inputs: Vec<u32>,
    targets: Vec<u32>,
    loss_mask: Vec<f32>,
}
1146
/// All windows derived from a dataset, plus the shared window length.
struct TrainingWindows {
    windows: Vec<TrainingWindow>,
    seq_len: usize,
}
1151
impl TrainingWindows {
    /// Builds windows from (prompt, response) pairs. Each pair is concatenated
    /// and chunked like `from_sequences`, but the loss mask is 1 only where
    /// the *target* token belongs to the response; windows whose targets are
    /// entirely prompt tokens are skipped.
    fn from_chat_sequences(
        chat_pairs: &[(Vec<u32>, Vec<u32>)],
        seq_len: usize,
        pad_token_id: u32,
    ) -> Result<Self> {
        let mut windows = Vec::new();
        for (prompt_ids, response_ids) in chat_pairs {
            let mut full_seq = prompt_ids.clone();
            full_seq.extend_from_slice(response_ids);

            // Need at least one (input, target) transition.
            if full_seq.len() < 2 {
                continue;
            }

            let prompt_len = prompt_ids.len();

            let mut start = 0;
            while start + 1 < full_seq.len() {
                // Chunks cover seq_len + 1 tokens so consecutive windows
                // overlap by one token and no transition is lost.
                let end = (start + seq_len + 1).min(full_seq.len());
                let chunk = &full_seq[start..end];
                let prediction_count = chunk.len() - 1;

                let mut inputs = vec![pad_token_id; seq_len];
                let mut targets = vec![pad_token_id; seq_len];
                let mut loss_mask = vec![0.0; seq_len];
                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
                targets[..prediction_count].copy_from_slice(&chunk[1..]);

                // Unmask only positions whose target is a response token
                // (global index >= prompt_len).
                let mut has_unmasked = false;
                for (i, mask) in loss_mask.iter_mut().enumerate().take(prediction_count) {
                    let target_global_idx = start + i + 1;
                    if target_global_idx >= prompt_len {
                        *mask = 1.0;
                        has_unmasked = true;
                    }
                }

                // An all-masked window would contribute nothing; skip it.
                if !has_unmasked {
                    if end == full_seq.len() {
                        break;
                    }
                    start += seq_len;
                    continue;
                }

                windows.push(TrainingWindow {
                    inputs,
                    targets,
                    loss_mask,
                });

                if end == full_seq.len() {
                    break;
                }
                start += seq_len;
            }
        }

        Ok(Self { windows, seq_len })
    }

    /// Builds windows from plain sequences: every next-token transition in
    /// each seq_len-sized chunk is unmasked. Sequences shorter than 2 tokens
    /// are skipped; short final chunks are right-padded.
    fn from_sequences(sequences: &[Vec<u32>], seq_len: usize, pad_token_id: u32) -> Result<Self> {
        let mut windows = Vec::new();
        for sequence in sequences {
            if sequence.len() < 2 {
                continue;
            }

            let mut start = 0;
            while start + 1 < sequence.len() {
                // seq_len + 1 tokens per chunk: seq_len inputs + shifted targets.
                let end = (start + seq_len + 1).min(sequence.len());
                let chunk = &sequence[start..end];
                let prediction_count = chunk.len() - 1;

                let mut inputs = vec![pad_token_id; seq_len];
                let mut targets = vec![pad_token_id; seq_len];
                let mut loss_mask = vec![0.0; seq_len];
                inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
                targets[..prediction_count].copy_from_slice(&chunk[1..]);
                loss_mask[..prediction_count].fill(1.0);

                windows.push(TrainingWindow {
                    inputs,
                    targets,
                    loss_mask,
                });

                if end == sequence.len() {
                    break;
                }
                start += seq_len;
            }
        }

        Ok(Self { windows, seq_len })
    }

    fn is_empty(&self) -> bool {
        self.windows.is_empty()
    }

    fn len(&self) -> usize {
        self.windows.len()
    }

    /// Assembles batch `step` by cycling through the windows round-robin.
    /// NOTE: assumes at least one window exists (the modulo below would panic
    /// otherwise); every caller checks `is_empty()` first.
    fn batch<B: Backend>(
        &self,
        step: usize,
        batch_size: usize,
        device: &B::Device,
    ) -> Result<TokenBatch<B>> {
        let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
        let mut targets = Vec::with_capacity(batch_size * self.seq_len);
        let mut loss_mask = Vec::with_capacity(batch_size * self.seq_len);

        for batch_idx in 0..batch_size {
            let index = (step * batch_size + batch_idx) % self.windows.len();
            let window = &self.windows[index];
            inputs.extend_from_slice(&window.inputs);
            targets.extend_from_slice(&window.targets);
            loss_mask.extend_from_slice(&window.loss_mask);
        }

        Ok(TokenBatch {
            inputs: tensor_from_u32(inputs, [batch_size, self.seq_len], device)?,
            targets: tensor_from_u32(targets, [batch_size, self.seq_len], device)?,
            loss_mask: Tensor::<B, 2>::from_data(
                TensorData::new(loss_mask, [batch_size, self.seq_len]),
                device,
            ),
        })
    }
}
1296
/// One device-resident training batch: inputs/targets of shape
/// [batch, seq_len] plus the float loss mask.
struct TokenBatch<B: Backend> {
    inputs: Tensor<B, 2, Int>,
    targets: Tensor<B, 2, Int>,
    loss_mask: Tensor<B, 2>,
}
1302
/// Test helper: deterministic synthetic next-token data where each row is a
/// rotation of the vocabulary and every target is input + 1 (mod vocab).
#[cfg(test)]
#[allow(dead_code)]
pub fn make_batch<B: Backend>(
    step: usize,
    batch_size: usize,
    seq_len: usize,
    vocab_size: usize,
    device: &B::Device,
) -> Result<(Tensor<B, 2, Int>, Tensor<B, 2, Int>)> {
    let mut input_ids = Vec::with_capacity(batch_size * seq_len);
    let mut target_ids = Vec::with_capacity(batch_size * seq_len);

    for row in 0..batch_size {
        // Mix step and row so different batches/rows start at different offsets.
        let offset = (step * 7 + row * 13) % vocab_size;
        for pos in 0..seq_len {
            input_ids.push(((offset + pos) % vocab_size) as u32);
            target_ids.push(((offset + pos + 1) % vocab_size) as u32);
        }
    }

    let inputs = tensor_from_u32(input_ids, [batch_size, seq_len], device)?;
    let targets = tensor_from_u32(target_ids, [batch_size, seq_len], device)?;
    Ok((inputs, targets))
}