1use crate::{
2 error::{Error, Result},
3 runtime::DefaultAutodiffBackend,
4};
5use burn::{
6 grad_clipping::GradientClippingConfig,
7 module::{Module, Param},
8 optim::{AdamWConfig, GradientsParams, Optimizer},
9 record::{FullPrecisionSettings, NamedMpkFileRecorder},
10 tensor::{
11 activation,
12 backend::{AutodiffBackend, Backend},
13 Int, Tensor, TensorData,
14 },
15};
16use serde::{Deserialize, Serialize};
17use std::f32::consts::PI;
18use std::path::Path;
19
/// Small positive constant added to denominators and squared norms in this file
/// so that divisions and square roots stay finite.
const EPS: f32 = 1e-6;
21
/// Named parameter-count targets, each of which maps to a preset model
/// configuration via `MultiscreenModelConfig::for_parameter_budget`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiscreenParameterBudget {
    /// Target of 1,000,000 parameters.
    Params1M,
    /// Target of 5,000,000 parameters.
    Params5M,
    /// Target of 10,000,000 parameters.
    Params10M,
    /// Target of 50,000,000 parameters.
    Params50M,
    /// Target of 100,000,000 parameters.
    Params100M,
}
35
36impl MultiscreenParameterBudget {
37 pub const ALL: [Self; 5] = [
38 Self::Params1M,
39 Self::Params5M,
40 Self::Params10M,
41 Self::Params50M,
42 Self::Params100M,
43 ];
44
45 pub fn label(self) -> &'static str {
46 match self {
47 Self::Params1M => "1M",
48 Self::Params5M => "5M",
49 Self::Params10M => "10M",
50 Self::Params50M => "50M",
51 Self::Params100M => "100M",
52 }
53 }
54
55 pub fn target_parameter_count(self) -> usize {
56 match self {
57 Self::Params1M => 1_000_000,
58 Self::Params5M => 5_000_000,
59 Self::Params10M => 10_000_000,
60 Self::Params50M => 50_000_000,
61 Self::Params100M => 100_000_000,
62 }
63 }
64}
65
/// Architecture hyper-parameters of a [`MultiscreenModel`].
///
/// `validate` checks the constraints noted on the individual fields.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiscreenModelConfig {
    /// Number of distinct token ids; the row count of the token embedding. Must be > 0.
    pub vocab_size: usize,
    /// Context length that inference inputs are padded or truncated to. Must be > 0.
    pub seq_len: usize,
    /// Number of layers. Must be > 0.
    pub layers: usize,
    /// Number of gated screening tiles in every layer. Must be > 0.
    pub tiles: usize,
    /// Width of the token embedding and of each tile's input and output. Must be > 0.
    pub d_model: usize,
    /// Width of the query/key projections. Must be at least 2 for MiPE.
    pub d_key: usize,
    /// Width of the value/gate projections. Must be > 0.
    pub d_value: usize,
    /// Window threshold: used when initialising each tile's window parameter and as
    /// the cut-off at which the MiPE rotation is disabled. Must be positive and finite.
    pub w_th: f32,
}
86
87impl MultiscreenModelConfig {
88 pub fn tiny() -> Self {
89 Self {
90 vocab_size: 64,
91 seq_len: 64,
92 layers: 2,
93 tiles: 4,
94 d_model: 64,
95 d_key: 16,
96 d_value: 32,
97 w_th: 32.0,
98 }
99 }
100
101 pub fn tiny_for_tests() -> Self {
102 Self {
103 vocab_size: 32,
104 seq_len: 8,
105 layers: 1,
106 tiles: 2,
107 d_model: 16,
108 d_key: 4,
109 d_value: 8,
110 w_th: 8.0,
111 }
112 }
113
114 pub fn for_parameter_budget(
115 budget: MultiscreenParameterBudget,
116 vocab_size: usize,
117 seq_len: usize,
118 ) -> Self {
119 match budget {
120 MultiscreenParameterBudget::Params1M => Self::preset_1m(vocab_size, seq_len),
121 MultiscreenParameterBudget::Params5M => Self::preset_5m(vocab_size, seq_len),
122 MultiscreenParameterBudget::Params10M => Self::preset_10m(vocab_size, seq_len),
123 MultiscreenParameterBudget::Params50M => Self::preset_50m(vocab_size, seq_len),
124 MultiscreenParameterBudget::Params100M => Self::preset_100m(vocab_size, seq_len),
125 }
126 }
127
128 pub fn preset_1m(vocab_size: usize, seq_len: usize) -> Self {
129 Self::from_dimensions(vocab_size, seq_len, 2, 2, 128, 32, 64)
130 }
131
132 pub fn preset_5m(vocab_size: usize, seq_len: usize) -> Self {
133 Self::from_dimensions(vocab_size, seq_len, 2, 4, 384, 96, 192)
134 }
135
136 pub fn preset_10m(vocab_size: usize, seq_len: usize) -> Self {
137 Self::from_dimensions(vocab_size, seq_len, 3, 4, 512, 128, 256)
138 }
139
140 pub fn preset_50m(vocab_size: usize, seq_len: usize) -> Self {
141 Self::from_dimensions(vocab_size, seq_len, 6, 4, 960, 240, 480)
142 }
143
144 pub fn preset_100m(vocab_size: usize, seq_len: usize) -> Self {
145 Self::from_dimensions(vocab_size, seq_len, 8, 4, 1216, 304, 608)
146 }
147
148 pub fn paper_10m(vocab_size: usize, seq_len: usize) -> Self {
149 Self::preset_10m(vocab_size, seq_len)
150 }
151
152 pub fn estimated_parameter_count(&self) -> usize {
153 let embedding_params = self.vocab_size.saturating_mul(self.d_model);
154 let per_tile_params = self
155 .d_model
156 .saturating_mul(
157 2usize
158 .saturating_mul(self.d_key)
159 .saturating_add(3usize.saturating_mul(self.d_value)),
160 )
161 .saturating_add(3);
162 let tile_params = self
163 .layers
164 .saturating_mul(self.tiles)
165 .saturating_mul(per_tile_params);
166 embedding_params
167 .saturating_add(2)
168 .saturating_add(tile_params)
169 }
170
171 fn from_dimensions(
172 vocab_size: usize,
173 seq_len: usize,
174 layers: usize,
175 tiles: usize,
176 d_model: usize,
177 d_key: usize,
178 d_value: usize,
179 ) -> Self {
180 Self {
181 vocab_size,
182 seq_len,
183 layers,
184 tiles,
185 d_model,
186 d_key,
187 d_value,
188 w_th: 32.0,
189 }
190 }
191
192 pub fn validate(&self) -> Result<()> {
193 ensure(self.vocab_size > 0, "vocab_size must be greater than zero")?;
194 ensure(self.seq_len > 0, "seq_len must be greater than zero")?;
195 ensure(self.layers > 0, "layers must be greater than zero")?;
196 ensure(self.tiles > 0, "tiles must be greater than zero")?;
197 ensure(self.d_model > 0, "d_model must be greater than zero")?;
198 ensure(self.d_key >= 2, "d_key must be at least 2 for MiPE")?;
199 ensure(self.d_value > 0, "d_value must be greater than zero")?;
200 ensure(
201 self.w_th.is_finite() && self.w_th > 0.0,
202 "w_th must be positive and finite",
203 )?;
204 Ok(())
205 }
206}
207
/// Options that control the training loops of [`MultiscreenModel`].
#[derive(Clone, Debug)]
pub struct ModelTrainingConfig {
    /// Number of optimizer steps to run.
    pub steps: usize,
    /// Number of training windows per batch; the training entry points reject zero.
    pub batch_size: usize,
    /// Learning rate handed to the optimizer on every step.
    pub learning_rate: f64,
    /// Weight decay passed to the AdamW configuration.
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold; clipping is only enabled for `Some` values above zero.
    pub grad_clip_norm: Option<f64>,
    /// Token id used to fill unused positions of a training window.
    pub pad_token_id: u32,
}
218
impl Default for ModelTrainingConfig {
    /// Defaults: 100 steps, batch size 4, learning rate 2e-4, weight decay 0.01,
    /// gradient-norm clipping at 1.0 and pad token id 0.
    fn default() -> Self {
        Self {
            steps: 100,
            batch_size: 4,
            learning_rate: 2e-4,
            weight_decay: 0.01,
            grad_clip_norm: Some(1.0),
            pad_token_id: 0,
        }
    }
}
231
/// Options for token generation.
#[derive(Clone, Debug)]
pub struct ModelInferenceConfig {
    /// Upper bound on the number of tokens generated after the prompt.
    pub max_new_tokens: usize,
    /// Token id used to pad contexts shorter than the model's `seq_len`.
    pub pad_token_id: u32,
}
238
impl Default for ModelInferenceConfig {
    /// Defaults: at most 16 new tokens, pad token id 0.
    fn default() -> Self {
        Self {
            max_new_tokens: 16,
            pad_token_id: 0,
        }
    }
}
247
/// Summary returned by the training entry points.
#[derive(Clone, Debug, PartialEq)]
pub struct ModelTrainingReport {
    /// Number of optimizer steps requested by the training configuration.
    pub steps: usize,
    /// Loss measured at the last step (or on the first batch when no step was run).
    pub final_loss: f32,
    /// Number of training windows built from the input sequences.
    pub training_window_count: usize,
    /// Parameter count of the model after training.
    pub parameter_count: usize,
}
255
/// Metrics produced by `MultiscreenModel::evaluate_on_sequences`.
#[derive(Clone, Debug)]
pub struct EvaluationResult {
    /// Average masked cross-entropy loss; NaN when there was nothing to evaluate.
    pub loss: f32,
    /// Exponential of the average loss; NaN when there was nothing to evaluate.
    pub perplexity: f32,
    /// Fraction of scored positions whose highest logit matches the target token.
    pub accuracy: f64,
    /// Number of batches that were evaluated.
    pub num_batches: usize,
    /// Number of positions that contributed to the accuracy.
    pub total_tokens: usize,
}
270
/// Result of token generation.
pub struct MultiscreenModelOutput {
    /// The prompt tokens followed by every generated token.
    pub token_ids: Vec<u32>,
}
274
/// Language model made of a shared token embedding and a stack of layers, each
/// holding several gated screening tiles.
///
/// The embedding matrix is reused (transposed) to produce the output logits.
#[derive(Module, Debug)]
pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> {
    /// Architecture the model was built with; excluded from the module's parameters.
    #[module(skip)]
    config: MultiscreenModelConfig,
    /// Token embedding, created with shape `[vocab_size, d_model]`.
    token_embedding: Param<Tensor<B, 2>>,
    /// Log-scale of the embedded inputs (`exp` is applied in `forward`).
    s_e: Param<Tensor<B, 1>>,
    /// Log-scale of the output logits (`exp` is applied in `forward`).
    s_f: Param<Tensor<B, 1>>,
    /// Layers applied in order, each adding its tiles' summed output to the hidden state.
    layers: Vec<MultiscreenLayer<B>>,
}
286
/// [`MultiscreenModel`] instantiated with the crate's default autodiff backend.
pub type DefaultMultiscreenModel = MultiscreenModel<DefaultAutodiffBackend>;
289
/// One layer of the model: a set of tiles that all read the same input and whose
/// outputs are summed by `MultiscreenModel::forward`.
#[derive(Module, Debug)]
struct MultiscreenLayer<B: Backend> {
    /// Tiles belonging to this layer.
    tiles: Vec<GatedScreeningTile<B>>,
}
294
/// Parameters of a single gated screening tile.
#[derive(Module, Debug)]
struct GatedScreeningTile<B: Backend> {
    /// Query projection, created with shape `[d_model, d_key]`.
    w_q: Param<Tensor<B, 2>>,
    /// Key projection, created with shape `[d_model, d_key]`.
    w_k: Param<Tensor<B, 2>>,
    /// Value projection, created with shape `[d_model, d_value]`.
    w_v: Param<Tensor<B, 2>>,
    /// Gate projection, created with shape `[d_model, d_value]`.
    w_g: Param<Tensor<B, 2>>,
    /// Output projection, created with shape `[d_value, d_model]`.
    w_o: Param<Tensor<B, 2>>,
    /// Window parameter; the window width is `exp(clamp(s_w)) + 1`.
    s_w: Param<Tensor<B, 1>>,
    /// Trim parameter; the trim width `r` is `sigmoid(clamp(s_r))`.
    s_r: Param<Tensor<B, 1>>,
    /// Log-scale of the tile output (`exp` is applied in `forward`).
    s_o: Param<Tensor<B, 1>>,
    /// Copy of the configuration's `w_th`; excluded from the module's parameters.
    #[module(skip)]
    w_th: f32,
}
308
309impl<B: Backend> MultiscreenModel<B> {
310 pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self> {
311 config.validate()?;
312
313 let mut seed = 0x4d55_4c54_4953_4352;
314 let token_embedding = init_matrix(
315 config.vocab_size,
316 config.d_model,
317 0.1 / (config.d_model as f32).sqrt(),
318 &mut seed,
319 device,
320 );
321 let s_e = init_scalar(0.0, device);
322 let s_f = init_scalar((config.d_model as f32).sqrt().ln(), device);
323
324 let mut layers = Vec::with_capacity(config.layers);
325 for _layer_idx in 0..config.layers {
326 let mut tiles = Vec::with_capacity(config.tiles);
327 for tile_idx in 0..config.tiles {
328 let w_q = init_matrix(
329 config.d_model,
330 config.d_key,
331 0.1 / (config.d_key as f32).sqrt(),
332 &mut seed,
333 device,
334 );
335 let w_k = init_matrix(
336 config.d_model,
337 config.d_key,
338 0.1 / (config.d_key as f32).sqrt(),
339 &mut seed,
340 device,
341 );
342 let w_v = init_matrix(
343 config.d_model,
344 config.d_value,
345 0.1 / (config.d_value as f32).sqrt(),
346 &mut seed,
347 device,
348 );
349 let w_g = init_matrix(config.d_model, config.d_value, 0.1, &mut seed, device);
350 let w_o = init_matrix(
351 config.d_value,
352 config.d_model,
353 0.1 / (config.d_model as f32).sqrt(),
354 &mut seed,
355 device,
356 );
357
358 let window_frac = if config.tiles == 1 {
359 0.0
360 } else {
361 tile_idx as f32 / (config.tiles - 1) as f32
362 };
363 let s_w = init_scalar(window_frac * config.w_th.ln(), device);
364 let s_r = init_scalar(0.0, device);
365 let s_o = init_scalar(-0.5 * ((config.layers * config.tiles) as f32).ln(), device);
366
367 tiles.push(GatedScreeningTile {
368 w_q,
369 w_k,
370 w_v,
371 w_g,
372 w_o,
373 s_w,
374 s_r,
375 s_o,
376 w_th: config.w_th,
377 });
378 }
379 layers.push(MultiscreenLayer { tiles });
380 }
381
382 Ok(Self {
383 config,
384 token_embedding,
385 s_e,
386 s_f,
387 layers,
388 })
389 }
390
    /// Returns the configuration this model was built with.
    pub fn config(&self) -> &MultiscreenModelConfig {
        &self.config
    }
394
    /// Returns the number of parameters as reported by the module's `num_params`.
    pub fn parameter_count(&self) -> usize {
        self.num_params()
    }
398
399 pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
400 let [batch, seq_len] = tokens.dims();
401 let embedding = row_unit_normalize(self.token_embedding.val());
402 let one_hot = tokens.one_hot::<3>(self.config().vocab_size).float();
403 let mut x =
404 linear(one_hot, embedding.clone()).reshape([batch, seq_len, self.config().d_model]);
405 let x_dims = x.dims();
406 x = x * expand_scalar3(self.s_e.val().exp(), x_dims);
407
408 for layer in &self.layers {
409 let mut layer_update = Tensor::<B, 3>::zeros(x.dims(), &x.device());
410 for tile in &layer.tiles {
411 layer_update = layer_update + tile.forward(x.clone());
412 }
413 x = x + layer_update;
414 }
415
416 let logits_weight = embedding.swap_dims(0, 1);
417 let logits = linear(x, logits_weight);
418 logits.clone() * expand_scalar3(self.s_f.val().exp(), logits.dims())
419 }
420
421 pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()> {
422 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
423 self.clone()
424 .save_file(path.as_ref().to_path_buf(), &recorder)
425 .map_err(|err| Error::Serialization(err.to_string()))
426 }
427
428 pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()> {
429 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
430 let device =
431 self.devices().into_iter().next().ok_or_else(|| {
432 Error::Serialization("model has no device for parameter load".into())
433 })?;
434 let loaded = self
435 .clone()
436 .load_file(path.as_ref().to_path_buf(), &recorder, &device)
437 .map_err(|err| Error::Serialization(err.to_string()))?;
438 *self = loaded;
439 Ok(())
440 }
441
442 pub fn infer_tokens_stream(
453 &self,
454 prompt: &[u32],
455 inference: &ModelInferenceConfig,
456 device: &B::Device,
457 mut on_token: impl FnMut(u32, usize) -> bool,
458 ) -> Result<MultiscreenModelOutput> {
459 if prompt.is_empty() {
460 return Err(Error::Inference(
461 "prompt must contain at least one token".to_string(),
462 ));
463 }
464
465 let mut output = prompt.to_vec();
466 for i in 0..inference.max_new_tokens {
467 let next = self.predict_next_token(&output, inference.pad_token_id, device)?;
468 output.push(next);
469 if !on_token(next, i) {
470 break;
471 }
472 }
473
474 Ok(MultiscreenModelOutput { token_ids: output })
475 }
476
    /// Generates up to `inference.max_new_tokens` tokens after `prompt` without
    /// streaming; equivalent to `infer_tokens_stream` with a callback that never stops.
    pub fn infer_tokens(
        &self,
        prompt: &[u32],
        inference: &ModelInferenceConfig,
        device: &B::Device,
    ) -> Result<MultiscreenModelOutput> {
        self.infer_tokens_stream(prompt, inference, device, |_, _| true)
    }
488
489 pub fn predict_next_token(
490 &self,
491 context: &[u32],
492 pad_token_id: u32,
493 device: &B::Device,
494 ) -> Result<u32> {
495 if context.is_empty() {
496 return Err(Error::Inference(
497 "context must contain at least one token".to_string(),
498 ));
499 }
500
501 let input = context_window(context, self.config().seq_len, pad_token_id);
502 let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
503 let logits = self.forward(input);
504 let last_logits = logits
505 .slice([
506 0..1,
507 self.config().seq_len - 1..self.config().seq_len,
508 0..self.config().vocab_size,
509 ])
510 .reshape([self.config().vocab_size]);
511 let values = tensor_to_vec(last_logits)?;
512 argmax(&values).map(|idx| idx as u32)
513 }
514
515 pub fn forward_logits(
523 &self,
524 context: &[u32],
525 pad_token_id: u32,
526 device: &B::Device,
527 ) -> Result<Tensor<B, 3>> {
528 if context.is_empty() {
529 return Err(Error::Inference(
530 "context must contain at least one token".to_string(),
531 ));
532 }
533
534 let input = context_window(context, self.config().seq_len, pad_token_id);
535 let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
536 let logits = self.forward(input);
537 Ok(logits)
538 }
539
540 pub fn evaluate_on_sequences(
546 &self,
547 sequences: &[Vec<u32>],
548 seq_len: usize,
549 batch_size: usize,
550 pad_token_id: u32,
551 device: &B::Device,
552 ) -> Result<EvaluationResult> {
553 let windows = TrainingWindows::from_sequences(sequences, seq_len, pad_token_id)?;
554 if windows.is_empty() {
555 return Ok(EvaluationResult {
556 loss: f32::NAN,
557 perplexity: f32::NAN,
558 accuracy: 0.0,
559 num_batches: 0,
560 total_tokens: 0,
561 });
562 }
563
564 let num_batches = windows.len().div_ceil(batch_size);
565 let mut total_loss = 0.0_f64;
566 let mut total_correct = 0_usize;
567 let mut total_tokens = 0_usize;
568
569 for step in 0..num_batches {
570 let batch = windows.batch::<B>(step, batch_size, device)?;
571 let logits = self.forward(batch.inputs); let loss = cross_entropy_loss_with_mask(
573 logits.clone(),
574 batch.targets.clone(),
575 batch.loss_mask.clone(),
576 );
577 let loss_val = tensor_scalar(loss)? as f64;
578 total_loss += loss_val;
579
580 let [b, s, v] = logits.dims();
582 let mask_vec: Vec<f32> = batch
583 .loss_mask
584 .clone()
585 .reshape([b * s])
586 .into_data()
587 .into_vec::<f32>()
588 .map_err(|e| Error::Inference(e.to_string()))?;
589 let target_vec: Vec<i32> = batch
590 .targets
591 .clone()
592 .reshape([b * s])
593 .into_data()
594 .into_vec::<i32>()
595 .map_err(|e| Error::Inference(e.to_string()))?;
596 let logit_vec: Vec<f32> = logits
597 .reshape([b * s * v])
598 .into_data()
599 .into_vec::<f32>()
600 .map_err(|e| Error::Inference(e.to_string()))?;
601
602 for bi in 0..b {
603 for si in 0..s {
604 let mi = bi * s + si;
605 if mask_vec[mi] < 0.5 {
606 continue;
607 }
608 total_tokens += 1;
609 let base = bi * s * v + si * v;
610 let mut best_idx = 0;
611 let mut best_val = f32::NEG_INFINITY;
612 for vi in 0..v {
613 let val = logit_vec[base + vi];
614 if val > best_val {
615 best_val = val;
616 best_idx = vi;
617 }
618 }
619 if best_idx == target_vec[mi] as usize {
620 total_correct += 1;
621 }
622 }
623 }
624 }
625
626 let avg_loss = total_loss / num_batches as f64;
627 let perplexity = avg_loss.exp();
628 let accuracy = if total_tokens > 0 {
629 total_correct as f64 / total_tokens as f64
630 } else {
631 0.0
632 };
633
634 Ok(EvaluationResult {
635 loss: avg_loss as f32,
636 perplexity: perplexity as f32,
637 accuracy,
638 num_batches,
639 total_tokens,
640 })
641 }
642}
643
644impl<B> MultiscreenModel<B>
645where
646 B: AutodiffBackend,
647{
648 pub fn train_token_sequences(
653 &mut self,
654 sequences: &[Vec<u32>],
655 training: &ModelTrainingConfig,
656 device: &B::Device,
657 mut on_step: impl FnMut(usize, f32),
658 ) -> Result<ModelTrainingReport> {
659 if training.batch_size == 0 {
660 return Err(Error::Training(
661 "batch_size must be greater than zero".to_string(),
662 ));
663 }
664 let windows = TrainingWindows::from_sequences(
665 sequences,
666 self.config().seq_len,
667 training.pad_token_id,
668 )?;
669 if windows.is_empty() {
670 return Err(Error::Training(
671 "training requires at least one sequence with two or more tokens".to_string(),
672 ));
673 }
674
675 let mut optimizer_config =
676 AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
677 if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
678 optimizer_config = optimizer_config
679 .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
680 }
681 let mut optimizer = optimizer_config.init::<B, Self>();
682 let mut model = self.clone();
683 let mut final_loss = f32::NAN;
684
685 for step in 0..training.steps {
686 let batch = windows.batch::<B>(step, training.batch_size, device)?;
687 let logits = model.forward(batch.inputs);
688 let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
689 final_loss = tensor_scalar(loss.clone())?;
690 let grads = loss.backward();
691 let grads = GradientsParams::from_grads(grads, &model);
692 model = optimizer.step(training.learning_rate, model, grads);
693 on_step(step, final_loss);
694 }
695
696 if training.steps == 0 {
697 let batch = windows.batch::<B>(0, training.batch_size, device)?;
698 final_loss = tensor_scalar(cross_entropy_loss_with_mask(
699 model.forward(batch.inputs),
700 batch.targets,
701 batch.loss_mask,
702 ))?;
703 }
704
705 *self = model;
706
707 Ok(ModelTrainingReport {
708 steps: training.steps,
709 final_loss,
710 training_window_count: windows.len(),
711 parameter_count: self.parameter_count(),
712 })
713 }
714
715 pub fn train_chat_sequences(
727 &mut self,
728 chat_pairs: &[(Vec<u32>, Vec<u32>)],
729 training: &ModelTrainingConfig,
730 device: &B::Device,
731 mut on_step: impl FnMut(usize, f32),
732 ) -> Result<ModelTrainingReport> {
733 if training.batch_size == 0 {
734 return Err(Error::Training(
735 "batch_size must be greater than zero".to_string(),
736 ));
737 }
738 let windows = TrainingWindows::from_chat_sequences(
739 chat_pairs,
740 self.config().seq_len,
741 training.pad_token_id,
742 )?;
743 if windows.is_empty() {
744 return Err(Error::Training(
745 "training requires at least one chat pair that produces two or more tokens"
746 .to_string(),
747 ));
748 }
749
750 let mut optimizer_config =
751 AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
752 if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
753 optimizer_config = optimizer_config
754 .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
755 }
756 let mut optimizer = optimizer_config.init::<B, Self>();
757 let mut model = self.clone();
758 let mut final_loss = f32::NAN;
759
760 for step in 0..training.steps {
761 let batch = windows.batch::<B>(step, training.batch_size, device)?;
762 let logits = model.forward(batch.inputs);
763 let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
764 final_loss = tensor_scalar(loss.clone())?;
765 let grads = loss.backward();
766 let grads = GradientsParams::from_grads(grads, &model);
767 model = optimizer.step(training.learning_rate, model, grads);
768 on_step(step, final_loss);
769 }
770
771 if training.steps == 0 {
772 let batch = windows.batch::<B>(0, training.batch_size, device)?;
773 final_loss = tensor_scalar(cross_entropy_loss_with_mask(
774 model.forward(batch.inputs),
775 batch.targets,
776 batch.loss_mask,
777 ))?;
778 }
779
780 *self = model;
781
782 Ok(ModelTrainingReport {
783 steps: training.steps,
784 final_loss,
785 training_window_count: windows.len(),
786 parameter_count: self.parameter_count(),
787 })
788 }
789}
790
/// Exposes the model through the crate's `LanguageModel` trait by delegating
/// to the inherent `forward`.
impl<B: Backend> crate::lm::LanguageModel<B> for MultiscreenModel<B> {
    fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        // Fully qualified so the inherent method is called rather than this trait method.
        MultiscreenModel::forward(self, tokens)
    }
}
796
/// Marks the model as trainable; the impl body is empty, so no trait items are overridden.
impl<B: Backend> crate::lm::TrainableLanguageModel<B> for MultiscreenModel<B> {}
798
799impl<B: Backend> GatedScreeningTile<B> {
800 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
801 let q = row_unit_normalize(linear(x.clone(), self.w_q.val()));
802 let k = row_unit_normalize(linear(x.clone(), self.w_k.val()));
803 let v = row_unit_normalize(linear(x.clone(), self.w_v.val()));
804 let g = linear(x, self.w_g.val());
805
806 let w = self.s_w.val().clamp(-10.0, 8.0).exp() + 1.0;
807 let r = activation::sigmoid(self.s_r.val().clamp(-10.0, 8.0));
808
809 let q = apply_mipe(q, w.clone(), self.w_th);
810 let k = apply_mipe(k, w.clone(), self.w_th);
811
812 let similarity = q.matmul(k.swap_dims(1, 2));
813 let alpha = trim_and_square_tensor(similarity.clone(), r);
814 let softmask = causal_softmask_tensor::<B>(similarity.dims()[1], w, &similarity.device());
815 let relevance = alpha * softmask.unsqueeze();
816 let h = relevance.matmul(v);
817 let u = tanh_norm(h);
818 let gate = activation::silu(g).tanh();
819 let gated = u * gate;
820 let out = linear(gated, self.w_o.val());
821 out.clone() * expand_scalar3(self.s_o.val().exp(), out.dims())
822 }
823}
824
825#[allow(dead_code)]
826pub fn cross_entropy_loss<B: Backend>(
827 logits: Tensor<B, 3>,
828 targets: Tensor<B, 2, Int>,
829) -> Tensor<B, 1> {
830 let device = logits.device();
831 let [batch, seq_len, _] = logits.dims();
832 let loss_mask = Tensor::<B, 2>::ones([batch, seq_len], &device);
833 cross_entropy_loss_with_mask(logits, targets, loss_mask)
834}
835
836pub fn cross_entropy_loss_with_mask<B: Backend>(
837 logits: Tensor<B, 3>,
838 targets: Tensor<B, 2, Int>,
839 loss_mask: Tensor<B, 2>,
840) -> Tensor<B, 1> {
841 let [batch, seq_len, vocab_size] = logits.dims();
842 let token_count = batch * seq_len;
843 let flat_logits = logits.reshape([token_count, vocab_size]);
844 let flat_targets = targets.reshape([token_count]);
845 let flat_mask = loss_mask.reshape([token_count]);
846 let log_probs = activation::log_softmax(flat_logits, 1);
847 let target_probs = flat_targets.one_hot::<2>(vocab_size).float();
848 let picked = (log_probs * target_probs).sum_dim(1).reshape([token_count]);
849 let denom = flat_mask.clone().sum().add_scalar(EPS);
850 (picked.neg() * flat_mask).sum() / denom
851}
852
853pub fn row_unit_normalize<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
854 let denom = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
855 x / denom
856}
857
858fn trim_and_square_tensor<B: Backend>(similarity: Tensor<B, 3>, r: Tensor<B, 1>) -> Tensor<B, 3> {
859 let distance_from_one = similarity.mul_scalar(-1.0).add_scalar(1.0);
860 let scaled = distance_from_one.clone() / expand_scalar3(r, distance_from_one.clone().dims());
861 scaled
862 .mul_scalar(-1.0)
863 .add_scalar(1.0)
864 .clamp(0.0, 1.0)
865 .square()
866}
867
868fn causal_softmask_tensor<B: Backend>(
869 seq_len: usize,
870 w: Tensor<B, 1>,
871 device: &B::Device,
872) -> Tensor<B, 2> {
873 let mut distances = Vec::with_capacity(seq_len * seq_len);
874 for i in 0..seq_len {
875 for j in 0..seq_len {
876 distances.push(j as f32 - i as f32);
877 }
878 }
879
880 let dist = Tensor::<B, 2>::from_data(TensorData::new(distances, [seq_len, seq_len]), device);
881 let w = expand_scalar2(w, [seq_len, seq_len]);
882 let causal = dist.clone().lower_equal_elem(0.0);
883 let within_window = dist.clone().greater(w.clone().neg());
884 let active = causal.float() * within_window.float();
885 let tapered = (dist / w)
886 .mul_scalar(PI)
887 .cos()
888 .mul_scalar(0.5)
889 .add_scalar(0.5);
890 active * tapered
891}
892
893pub fn tanh_norm<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
894 let norm = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
895 let scale = norm.clone().tanh() / norm;
896 x * scale
897}
898
899fn apply_mipe<B: Backend>(z: Tensor<B, 3>, w: Tensor<B, 1>, w_th: f32) -> Tensor<B, 3> {
900 let [_batch, seq_len, d_key] = z.dims();
901 debug_assert!(d_key >= 2);
902
903 let positions = Tensor::<B, 3>::from_data(
904 TensorData::new(
905 (0..seq_len).map(|idx| idx as f32).collect::<Vec<_>>(),
906 [1, seq_len, 1],
907 ),
908 &z.device(),
909 );
910
911 let gamma_raw = w
912 .clone()
913 .mul_scalar(PI / w_th)
914 .cos()
915 .mul_scalar(0.5)
916 .add_scalar(0.5);
917 let gamma = gamma_raw.mask_fill(w.clone().greater_equal_elem(w_th), 0.0);
918 let gamma = expand_scalar3(gamma, [1, seq_len, 1]);
919 let w = expand_scalar3(w, [1, seq_len, 1]);
920 let phi = (positions * gamma / w).mul_scalar(PI);
921 let cos_phi = phi.clone().cos();
922 let sin_phi = phi.sin();
923
924 let x0 = z.clone().narrow(2, 0, 1);
925 let x1 = z.clone().narrow(2, 1, 1);
926 let rot0 = x0.clone() * cos_phi.clone() - x1.clone() * sin_phi.clone();
927 let rot1 = x0 * sin_phi + x1 * cos_phi;
928
929 if d_key == 2 {
930 Tensor::cat(vec![rot0, rot1], 2)
931 } else {
932 let rest = z.narrow(2, 2, d_key - 2);
933 Tensor::cat(vec![rot0, rot1, rest], 2)
934 }
935}
936
937fn linear<B: Backend>(x: Tensor<B, 3>, weight: Tensor<B, 2>) -> Tensor<B, 3> {
938 let [batch, _seq_len, _in_dim] = x.dims();
939 let [weight_in, out_dim] = weight.dims();
940 x.matmul(weight.unsqueeze::<3>().expand([batch, weight_in, out_dim]))
941}
942
/// Broadcasts a 1-D tensor to the 2-D shape `dims`; every call site in this
/// file passes a one-element tensor.
fn expand_scalar2<B: Backend>(value: Tensor<B, 1>, dims: [usize; 2]) -> Tensor<B, 2> {
    value.unsqueeze::<2>().expand(dims)
}
946
/// Broadcasts a 1-D tensor to the 3-D shape `dims`; every call site in this
/// file passes a one-element tensor.
fn expand_scalar3<B: Backend>(value: Tensor<B, 1>, dims: [usize; 3]) -> Tensor<B, 3> {
    value.unsqueeze::<3>().expand(dims)
}
950
/// Creates a one-element parameter tensor holding `value` on `device`.
fn init_scalar<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
    Param::from_tensor(Tensor::<B, 1>::from_data([value], device))
}
954
955fn init_matrix<B: Backend>(
956 rows: usize,
957 cols: usize,
958 std: f32,
959 seed: &mut u64,
960 device: &B::Device,
961) -> Param<Tensor<B, 2>> {
962 let values = gaussian_values(rows * cols, std, seed);
963 Param::from_tensor(Tensor::<B, 2>::from_data(
964 TensorData::new(values, [rows, cols]),
965 device,
966 ))
967}
968
969fn gaussian_values(len: usize, std: f32, seed: &mut u64) -> Vec<f32> {
970 let mut values = Vec::with_capacity(len);
971 while values.len() < len {
972 let u1 = next_uniform(seed).max(1e-7);
973 let u2 = next_uniform(seed);
974 let radius = (-2.0 * u1.ln()).sqrt();
975 let theta = 2.0 * PI * u2;
976 values.push(radius * theta.cos() * std);
977 if values.len() < len {
978 values.push(radius * theta.sin() * std);
979 }
980 }
981 values
982}
983
/// Advances the 64-bit linear congruential generator in `seed` and returns a
/// uniform sample strictly between 0 and 1 built from its top 24 bits.
fn next_uniform(seed: &mut u64) -> f32 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    // 2^24 + 2: with the `+ 1.0` on the numerator this keeps the result inside (0, 1).
    const DENOMINATOR: f32 = 16_777_218.0;

    let advanced = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *seed = advanced;

    let top_bits = (advanced >> 40) as u32;
    (top_bits as f32 + 1.0) / DENOMINATOR
}
991
/// Returns exactly `seq_len` tokens: the last `seq_len` tokens of `context`,
/// left-padded with `pad_token_id` when the context is shorter.
fn context_window(context: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    let kept = context.len().min(seq_len);
    let tail = &context[context.len() - kept..];

    let mut window = Vec::with_capacity(seq_len);
    window.resize(seq_len - kept, pad_token_id);
    window.extend_from_slice(tail);
    window
}
1003
1004fn argmax(values: &[f32]) -> Result<usize> {
1005 values
1006 .iter()
1007 .enumerate()
1008 .max_by(|left, right| left.1.total_cmp(right.1))
1009 .map(|(index, _)| index)
1010 .ok_or_else(|| Error::Inference("cannot argmax empty logits".to_string()))
1011}
1012
1013fn tensor_scalar<B: Backend>(tensor: Tensor<B, 1>) -> Result<f32> {
1014 tensor
1015 .into_data()
1016 .into_vec::<f32>()
1017 .map_err(|err| Error::Training(err.to_string()))?
1018 .into_iter()
1019 .next()
1020 .ok_or_else(|| Error::Training("expected scalar tensor".to_string()))
1021}
1022
/// Reads a 1-D tensor back to the host as a `Vec<f32>`.
///
/// # Errors
/// Returns `Error::Inference` when the data cannot be read as `f32`.
fn tensor_to_vec<B: Backend>(tensor: Tensor<B, 1>) -> Result<Vec<f32>> {
    tensor
        .into_data()
        .into_vec::<f32>()
        .map_err(|err| Error::Inference(err.to_string()))
}
1029
1030fn tensor_from_u32<B: Backend, const D: usize>(
1031 values: Vec<u32>,
1032 shape: [usize; D],
1033 device: &B::Device,
1034) -> Result<Tensor<B, D, Int>> {
1035 let values = values
1036 .into_iter()
1037 .map(|value| {
1038 i32::try_from(value)
1039 .map_err(|_| Error::Config(format!("token id {value} exceeds i32::MAX")))
1040 })
1041 .collect::<Result<Vec<_>>>()?;
1042 Ok(Tensor::<B, D, Int>::from_data(
1043 TensorData::new(values, shape),
1044 device,
1045 ))
1046}
1047
1048fn ensure(condition: bool, message: &str) -> Result<()> {
1049 if condition {
1050 Ok(())
1051 } else {
1052 Err(Error::Config(message.to_string()))
1053 }
1054}
1055
/// One fixed-length training example; all three vectors have the window length.
struct TrainingWindow {
    /// Input tokens, padded at the end with the pad token.
    inputs: Vec<u32>,
    /// Target tokens (each input's successor), padded at the end with the pad token.
    targets: Vec<u32>,
    /// 1.0 for positions that count towards the loss, 0.0 elsewhere.
    loss_mask: Vec<f32>,
}
1061
/// A set of training windows that share one window length.
struct TrainingWindows {
    /// The windows, in the order they were cut from the input sequences.
    windows: Vec<TrainingWindow>,
    /// Length of every window.
    seq_len: usize,
}
1066
1067impl TrainingWindows {
1068 fn from_chat_sequences(
1074 chat_pairs: &[(Vec<u32>, Vec<u32>)],
1075 seq_len: usize,
1076 pad_token_id: u32,
1077 ) -> Result<Self> {
1078 let mut windows = Vec::new();
1079 for (prompt_ids, response_ids) in chat_pairs {
1080 let mut full_seq = prompt_ids.clone();
1081 full_seq.extend_from_slice(response_ids);
1082
1083 if full_seq.len() < 2 {
1084 continue;
1085 }
1086
1087 let prompt_len = prompt_ids.len();
1088
1089 let mut start = 0;
1090 while start + 1 < full_seq.len() {
1091 let end = (start + seq_len + 1).min(full_seq.len());
1092 let chunk = &full_seq[start..end];
1093 let prediction_count = chunk.len() - 1;
1094
1095 let mut inputs = vec![pad_token_id; seq_len];
1096 let mut targets = vec![pad_token_id; seq_len];
1097 let mut loss_mask = vec![0.0; seq_len];
1098 inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
1099 targets[..prediction_count].copy_from_slice(&chunk[1..]);
1100
1101 for (i, mask) in loss_mask.iter_mut().enumerate().take(prediction_count) {
1104 let target_global_idx = start + i + 1;
1105 if target_global_idx >= prompt_len {
1106 *mask = 1.0;
1107 }
1108 }
1109
1110 windows.push(TrainingWindow {
1111 inputs,
1112 targets,
1113 loss_mask,
1114 });
1115
1116 if end == full_seq.len() {
1117 break;
1118 }
1119 start += seq_len;
1120 }
1121 }
1122
1123 Ok(Self { windows, seq_len })
1124 }
1125
1126 fn from_sequences(sequences: &[Vec<u32>], seq_len: usize, pad_token_id: u32) -> Result<Self> {
1127 let mut windows = Vec::new();
1128 for sequence in sequences {
1129 if sequence.len() < 2 {
1130 continue;
1131 }
1132
1133 let mut start = 0;
1134 while start + 1 < sequence.len() {
1135 let end = (start + seq_len + 1).min(sequence.len());
1136 let chunk = &sequence[start..end];
1137 let prediction_count = chunk.len() - 1;
1138
1139 let mut inputs = vec![pad_token_id; seq_len];
1140 let mut targets = vec![pad_token_id; seq_len];
1141 let mut loss_mask = vec![0.0; seq_len];
1142 inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
1143 targets[..prediction_count].copy_from_slice(&chunk[1..]);
1144 loss_mask[..prediction_count].fill(1.0);
1145
1146 windows.push(TrainingWindow {
1147 inputs,
1148 targets,
1149 loss_mask,
1150 });
1151
1152 if end == sequence.len() {
1153 break;
1154 }
1155 start += seq_len;
1156 }
1157 }
1158
1159 Ok(Self { windows, seq_len })
1160 }
1161
1162 fn is_empty(&self) -> bool {
1163 self.windows.is_empty()
1164 }
1165
1166 fn len(&self) -> usize {
1167 self.windows.len()
1168 }
1169
1170 fn batch<B: Backend>(
1171 &self,
1172 step: usize,
1173 batch_size: usize,
1174 device: &B::Device,
1175 ) -> Result<TokenBatch<B>> {
1176 let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
1177 let mut targets = Vec::with_capacity(batch_size * self.seq_len);
1178 let mut loss_mask = Vec::with_capacity(batch_size * self.seq_len);
1179
1180 for batch_idx in 0..batch_size {
1181 let index = (step * batch_size + batch_idx) % self.windows.len();
1182 let window = &self.windows[index];
1183 inputs.extend_from_slice(&window.inputs);
1184 targets.extend_from_slice(&window.targets);
1185 loss_mask.extend_from_slice(&window.loss_mask);
1186 }
1187
1188 Ok(TokenBatch {
1189 inputs: tensor_from_u32(inputs, [batch_size, self.seq_len], device)?,
1190 targets: tensor_from_u32(targets, [batch_size, self.seq_len], device)?,
1191 loss_mask: Tensor::<B, 2>::from_data(
1192 TensorData::new(loss_mask, [batch_size, self.seq_len]),
1193 device,
1194 ),
1195 })
1196 }
1197}
1198
/// A batch of training windows as tensors, each of shape `[batch_size, seq_len]`.
struct TokenBatch<B: Backend> {
    /// Input token ids.
    inputs: Tensor<B, 2, Int>,
    /// Target token ids.
    targets: Tensor<B, 2, Int>,
    /// Loss mask copied from the windows' `loss_mask` values.
    loss_mask: Tensor<B, 2>,
}
1204
/// Test helper producing a deterministic synthetic batch.
///
/// Every row counts upwards modulo `vocab_size` from an offset derived from
/// `step` and the row index; targets are the inputs shifted by one. Both
/// tensors have shape `[batch_size, seq_len]`.
///
/// # Errors
/// Returns a configuration error when a token id does not fit in `i32`.
#[cfg(test)]
#[allow(dead_code)]
pub fn make_batch<B: Backend>(
    step: usize,
    batch_size: usize,
    seq_len: usize,
    vocab_size: usize,
    device: &B::Device,
) -> Result<(Tensor<B, 2, Int>, Tensor<B, 2, Int>)> {
    let token_count = batch_size * seq_len;
    let mut inputs: Vec<u32> = Vec::with_capacity(token_count);
    let mut targets: Vec<u32> = Vec::with_capacity(token_count);

    for row in 0..batch_size {
        let offset = (step * 7 + row * 13) % vocab_size;
        inputs.extend((0..seq_len).map(|pos| ((offset + pos) % vocab_size) as u32));
        targets.extend((0..seq_len).map(|pos| ((offset + pos + 1) % vocab_size) as u32));
    }

    let shape = [batch_size, seq_len];
    let inputs = tensor_from_u32(inputs, shape, device)?;
    let targets = tensor_from_u32(targets, shape, device)?;
    Ok((inputs, targets))
}