1use crate::{
2 error::{Error, Result},
3 runtime::DefaultAutodiffBackend,
4};
5use burn::{
6 grad_clipping::GradientClippingConfig,
7 module::{Module, Param},
8 optim::{AdamWConfig, GradientsParams, Optimizer},
9 record::{FullPrecisionSettings, NamedMpkFileRecorder},
10 tensor::{
11 activation,
12 backend::{AutodiffBackend, Backend},
13 Int, Tensor, TensorData,
14 },
15};
16use serde::{Deserialize, Serialize};
17use std::f32::consts::PI;
18use std::path::Path;
19
/// Small positive constant added inside square roots and to loss denominators
/// (`row_unit_normalize`, `tanh_norm`, `cross_entropy_loss_with_mask`) so that
/// all-zero inputs or fully masked batches never divide by zero.
const EPS: f32 = 1e-6;
21
/// Named model-size presets.
///
/// Each variant maps to a concrete `MultiscreenModelConfig` through
/// `MultiscreenModelConfig::for_parameter_budget`. The variant names are
/// serialized by serde, so renaming them changes the on-disk format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiscreenParameterBudget {
    /// Roughly one million parameters.
    Params1M,
    /// Roughly five million parameters.
    Params5M,
    /// Roughly ten million parameters.
    Params10M,
    /// Roughly fifty million parameters.
    Params50M,
    /// Roughly one hundred million parameters.
    Params100M,
}
35
36impl MultiscreenParameterBudget {
37 pub const ALL: [Self; 5] = [
38 Self::Params1M,
39 Self::Params5M,
40 Self::Params10M,
41 Self::Params50M,
42 Self::Params100M,
43 ];
44
45 pub fn label(self) -> &'static str {
46 match self {
47 Self::Params1M => "1M",
48 Self::Params5M => "5M",
49 Self::Params10M => "10M",
50 Self::Params50M => "50M",
51 Self::Params100M => "100M",
52 }
53 }
54
55 pub fn target_parameter_count(self) -> usize {
56 match self {
57 Self::Params1M => 1_000_000,
58 Self::Params5M => 5_000_000,
59 Self::Params10M => 10_000_000,
60 Self::Params50M => 50_000_000,
61 Self::Params100M => 100_000_000,
62 }
63 }
64}
65
/// Architecture hyper-parameters for a `MultiscreenModel`.
///
/// Serialized with serde, so field names are part of the stored format.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiscreenModelConfig {
    /// Number of token ids; rows of the (tied) token embedding matrix.
    pub vocab_size: usize,
    /// Window length used for training windows and for inference context.
    pub seq_len: usize,
    /// Number of stacked layers.
    pub layers: usize,
    /// Number of gated screening tiles per layer.
    pub tiles: usize,
    /// Width of the residual stream / embedding dimension.
    pub d_model: usize,
    /// Query/key projection width; must be at least 2 for MiPE.
    pub d_key: usize,
    /// Value/gate projection width.
    pub d_value: usize,
    /// MiPE window threshold; also the scale used to spread initial tile windows.
    pub w_th: f32,
}
86
87impl MultiscreenModelConfig {
88 pub fn tiny() -> Self {
89 Self {
90 vocab_size: 64,
91 seq_len: 64,
92 layers: 2,
93 tiles: 4,
94 d_model: 64,
95 d_key: 16,
96 d_value: 32,
97 w_th: 32.0,
98 }
99 }
100
101 pub fn tiny_for_tests() -> Self {
102 Self {
103 vocab_size: 32,
104 seq_len: 8,
105 layers: 1,
106 tiles: 2,
107 d_model: 16,
108 d_key: 4,
109 d_value: 8,
110 w_th: 8.0,
111 }
112 }
113
114 pub fn for_parameter_budget(
115 budget: MultiscreenParameterBudget,
116 vocab_size: usize,
117 seq_len: usize,
118 ) -> Self {
119 match budget {
120 MultiscreenParameterBudget::Params1M => Self::preset_1m(vocab_size, seq_len),
121 MultiscreenParameterBudget::Params5M => Self::preset_5m(vocab_size, seq_len),
122 MultiscreenParameterBudget::Params10M => Self::preset_10m(vocab_size, seq_len),
123 MultiscreenParameterBudget::Params50M => Self::preset_50m(vocab_size, seq_len),
124 MultiscreenParameterBudget::Params100M => Self::preset_100m(vocab_size, seq_len),
125 }
126 }
127
128 pub fn preset_1m(vocab_size: usize, seq_len: usize) -> Self {
129 Self::from_dimensions(vocab_size, seq_len, 2, 2, 128, 32, 64)
130 }
131
132 pub fn preset_5m(vocab_size: usize, seq_len: usize) -> Self {
133 Self::from_dimensions(vocab_size, seq_len, 2, 4, 384, 96, 192)
134 }
135
136 pub fn preset_10m(vocab_size: usize, seq_len: usize) -> Self {
137 Self::from_dimensions(vocab_size, seq_len, 3, 4, 512, 128, 256)
138 }
139
140 pub fn preset_50m(vocab_size: usize, seq_len: usize) -> Self {
141 Self::from_dimensions(vocab_size, seq_len, 6, 4, 960, 240, 480)
142 }
143
144 pub fn preset_100m(vocab_size: usize, seq_len: usize) -> Self {
145 Self::from_dimensions(vocab_size, seq_len, 8, 4, 1216, 304, 608)
146 }
147
148 pub fn paper_10m(vocab_size: usize, seq_len: usize) -> Self {
149 Self::preset_10m(vocab_size, seq_len)
150 }
151
152 pub fn estimated_parameter_count(&self) -> usize {
153 let embedding_params = self.vocab_size.saturating_mul(self.d_model);
154 let per_tile_params = self
155 .d_model
156 .saturating_mul(
157 2usize
158 .saturating_mul(self.d_key)
159 .saturating_add(3usize.saturating_mul(self.d_value)),
160 )
161 .saturating_add(3);
162 let tile_params = self
163 .layers
164 .saturating_mul(self.tiles)
165 .saturating_mul(per_tile_params);
166 embedding_params
167 .saturating_add(2)
168 .saturating_add(tile_params)
169 }
170
171 fn from_dimensions(
172 vocab_size: usize,
173 seq_len: usize,
174 layers: usize,
175 tiles: usize,
176 d_model: usize,
177 d_key: usize,
178 d_value: usize,
179 ) -> Self {
180 Self {
181 vocab_size,
182 seq_len,
183 layers,
184 tiles,
185 d_model,
186 d_key,
187 d_value,
188 w_th: 32.0,
189 }
190 }
191
192 pub fn validate(&self) -> Result<()> {
193 ensure(self.vocab_size > 0, "vocab_size must be greater than zero")?;
194 ensure(self.seq_len > 0, "seq_len must be greater than zero")?;
195 ensure(self.layers > 0, "layers must be greater than zero")?;
196 ensure(self.tiles > 0, "tiles must be greater than zero")?;
197 ensure(self.d_model > 0, "d_model must be greater than zero")?;
198 ensure(self.d_key >= 2, "d_key must be at least 2 for MiPE")?;
199 ensure(self.d_value > 0, "d_value must be greater than zero")?;
200 ensure(
201 self.w_th.is_finite() && self.w_th > 0.0,
202 "w_th must be positive and finite",
203 )?;
204 Ok(())
205 }
206}
207
/// Optimisation settings for `MultiscreenModel::train_token_sequences`.
#[derive(Clone, Debug)]
pub struct ModelTrainingConfig {
    /// Number of optimizer steps (one batch per step).
    pub steps: usize,
    /// Windows per batch; must be greater than zero.
    pub batch_size: usize,
    /// AdamW learning rate.
    pub learning_rate: f64,
    /// AdamW weight decay (converted to `f32` for the optimizer config).
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold; `None` or a non-positive value disables clipping.
    pub grad_clip_norm: Option<f64>,
    /// Token id used to pad short windows (padded targets are masked out of the loss).
    pub pad_token_id: u32,
}
218
219impl Default for ModelTrainingConfig {
220 fn default() -> Self {
221 Self {
222 steps: 100,
223 batch_size: 4,
224 learning_rate: 2e-4,
225 weight_decay: 0.01,
226 grad_clip_norm: Some(1.0),
227 pad_token_id: 0,
228 }
229 }
230}
231
/// Settings for greedy generation (`MultiscreenModel::infer_tokens*`).
#[derive(Clone, Debug)]
pub struct ModelInferenceConfig {
    /// Maximum number of tokens appended after the prompt.
    pub max_new_tokens: usize,
    /// Token id used to fill the context window when the context is shorter than `seq_len`.
    pub pad_token_id: u32,
}
238
239impl Default for ModelInferenceConfig {
240 fn default() -> Self {
241 Self {
242 max_new_tokens: 16,
243 pad_token_id: 0,
244 }
245 }
246}
247
/// Summary returned by `MultiscreenModel::train_token_sequences`.
#[derive(Clone, Debug, PartialEq)]
pub struct ModelTrainingReport {
    /// Number of optimizer steps requested.
    pub steps: usize,
    /// Loss of the last step (or of the first batch when `steps == 0`).
    pub final_loss: f32,
    /// Number of training windows built from the input sequences.
    pub training_window_count: usize,
    /// Parameter count of the trained model.
    pub parameter_count: usize,
}
255
/// Metrics returned by `MultiscreenModel::evaluate_on_sequences`.
#[derive(Clone, Debug)]
pub struct EvaluationResult {
    /// Average cross-entropy loss; NaN when no windows could be built.
    pub loss: f32,
    /// `exp(loss)`; NaN when no windows could be built.
    pub perplexity: f32,
    /// Fraction of loss-masked positions whose highest logit equals the target.
    pub accuracy: f64,
    /// Number of evaluation batches run.
    pub num_batches: usize,
    /// Number of loss-masked (non-padding) positions counted.
    pub total_tokens: usize,
}
270
/// Result of generation: the prompt followed by every generated token.
pub struct MultiscreenModelOutput {
    pub token_ids: Vec<u32>,
}
274
/// Multiscreen language model: a tied token embedding followed by stacked
/// layers of gated screening tiles whose outputs are summed into a residual stream.
#[derive(Module, Debug)]
pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> {
    /// Architecture settings; not a learnable parameter.
    #[module(skip)]
    config: MultiscreenModelConfig,
    /// `vocab_size x d_model` embedding, reused (transposed) as the output projection.
    token_embedding: Param<Tensor<B, 2>>,
    /// Log-scale applied to the input embeddings (`exp(s_e)`).
    s_e: Param<Tensor<B, 1>>,
    /// Log-scale applied to the output logits (`exp(s_f)`).
    s_f: Param<Tensor<B, 1>>,
    layers: Vec<MultiscreenLayer<B>>,
}

/// `MultiscreenModel` on the crate's default autodiff backend.
pub type DefaultMultiscreenModel = MultiscreenModel<DefaultAutodiffBackend>;
289
/// One layer: the outputs of all its tiles are summed into a single residual update.
#[derive(Module, Debug)]
struct MultiscreenLayer<B: Backend> {
    tiles: Vec<GatedScreeningTile<B>>,
}
294
/// A single gated screening tile (attention-like block with a learned window).
#[derive(Module, Debug)]
struct GatedScreeningTile<B: Backend> {
    /// Query projection, `d_model x d_key`.
    w_q: Param<Tensor<B, 2>>,
    /// Key projection, `d_model x d_key`.
    w_k: Param<Tensor<B, 2>>,
    /// Value projection, `d_model x d_value`.
    w_v: Param<Tensor<B, 2>>,
    /// Gate projection, `d_model x d_value`.
    w_g: Param<Tensor<B, 2>>,
    /// Output projection, `d_value x d_model`.
    w_o: Param<Tensor<B, 2>>,
    /// Log window parameter; the window is `exp(clamp(s_w)) + 1`.
    s_w: Param<Tensor<B, 1>>,
    /// Screening radius logit; the radius is `sigmoid(clamp(s_r))`.
    s_r: Param<Tensor<B, 1>>,
    /// Log output scale (`exp(s_o)`).
    s_o: Param<Tensor<B, 1>>,
    /// MiPE threshold copied from the config; not learnable.
    #[module(skip)]
    w_th: f32,
}
308
309impl<B: Backend> MultiscreenModel<B> {
310 pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self> {
311 config.validate()?;
312
313 let mut seed = 0x4d55_4c54_4953_4352;
314 let token_embedding = init_matrix(
315 config.vocab_size,
316 config.d_model,
317 0.1 / (config.d_model as f32).sqrt(),
318 &mut seed,
319 device,
320 );
321 let s_e = init_scalar(0.0, device);
322 let s_f = init_scalar((config.d_model as f32).sqrt().ln(), device);
323
324 let mut layers = Vec::with_capacity(config.layers);
325 for _layer_idx in 0..config.layers {
326 let mut tiles = Vec::with_capacity(config.tiles);
327 for tile_idx in 0..config.tiles {
328 let w_q = init_matrix(
329 config.d_model,
330 config.d_key,
331 0.1 / (config.d_key as f32).sqrt(),
332 &mut seed,
333 device,
334 );
335 let w_k = init_matrix(
336 config.d_model,
337 config.d_key,
338 0.1 / (config.d_key as f32).sqrt(),
339 &mut seed,
340 device,
341 );
342 let w_v = init_matrix(
343 config.d_model,
344 config.d_value,
345 0.1 / (config.d_value as f32).sqrt(),
346 &mut seed,
347 device,
348 );
349 let w_g = init_matrix(config.d_model, config.d_value, 0.1, &mut seed, device);
350 let w_o = init_matrix(
351 config.d_value,
352 config.d_model,
353 0.1 / (config.d_model as f32).sqrt(),
354 &mut seed,
355 device,
356 );
357
358 let window_frac = if config.tiles == 1 {
359 0.0
360 } else {
361 tile_idx as f32 / (config.tiles - 1) as f32
362 };
363 let s_w = init_scalar(window_frac * config.w_th.ln(), device);
364 let s_r = init_scalar(0.0, device);
365 let s_o = init_scalar(-0.5 * ((config.layers * config.tiles) as f32).ln(), device);
366
367 tiles.push(GatedScreeningTile {
368 w_q,
369 w_k,
370 w_v,
371 w_g,
372 w_o,
373 s_w,
374 s_r,
375 s_o,
376 w_th: config.w_th,
377 });
378 }
379 layers.push(MultiscreenLayer { tiles });
380 }
381
382 Ok(Self {
383 config,
384 token_embedding,
385 s_e,
386 s_f,
387 layers,
388 })
389 }
390
391 pub fn config(&self) -> &MultiscreenModelConfig {
392 &self.config
393 }
394
395 pub fn parameter_count(&self) -> usize {
396 self.num_params()
397 }
398
399 pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
400 let [batch, seq_len] = tokens.dims();
401 let embedding = row_unit_normalize(self.token_embedding.val());
402 let one_hot = tokens.one_hot::<3>(self.config().vocab_size).float();
403 let mut x =
404 linear(one_hot, embedding.clone()).reshape([batch, seq_len, self.config().d_model]);
405 let x_dims = x.dims();
406 x = x * expand_scalar3(self.s_e.val().exp(), x_dims);
407
408 for layer in &self.layers {
409 let mut layer_update = Tensor::<B, 3>::zeros(x.dims(), &x.device());
410 for tile in &layer.tiles {
411 layer_update = layer_update + tile.forward(x.clone());
412 }
413 x = x + layer_update;
414 }
415
416 let logits_weight = embedding.swap_dims(0, 1);
417 let logits = linear(x, logits_weight);
418 logits.clone() * expand_scalar3(self.s_f.val().exp(), logits.dims())
419 }
420
421 pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()> {
422 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
423 self.clone()
424 .save_file(path.as_ref().to_path_buf(), &recorder)
425 .map_err(|err| Error::Serialization(err.to_string()))
426 }
427
428 pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()> {
429 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
430 let device =
431 self.devices().into_iter().next().ok_or_else(|| {
432 Error::Serialization("model has no device for parameter load".into())
433 })?;
434 let loaded = self
435 .clone()
436 .load_file(path.as_ref().to_path_buf(), &recorder, &device)
437 .map_err(|err| Error::Serialization(err.to_string()))?;
438 *self = loaded;
439 Ok(())
440 }
441
442 pub fn infer_tokens_stream(
453 &self,
454 prompt: &[u32],
455 inference: &ModelInferenceConfig,
456 device: &B::Device,
457 mut on_token: impl FnMut(u32, usize) -> bool,
458 ) -> Result<MultiscreenModelOutput> {
459 if prompt.is_empty() {
460 return Err(Error::Inference(
461 "prompt must contain at least one token".to_string(),
462 ));
463 }
464
465 let mut output = prompt.to_vec();
466 for i in 0..inference.max_new_tokens {
467 let next = self.predict_next_token(&output, inference.pad_token_id, device)?;
468 output.push(next);
469 if !on_token(next, i) {
470 break;
471 }
472 }
473
474 Ok(MultiscreenModelOutput { token_ids: output })
475 }
476
477 pub fn infer_tokens(
481 &self,
482 prompt: &[u32],
483 inference: &ModelInferenceConfig,
484 device: &B::Device,
485 ) -> Result<MultiscreenModelOutput> {
486 self.infer_tokens_stream(prompt, inference, device, |_, _| true)
487 }
488
489 pub fn predict_next_token(
490 &self,
491 context: &[u32],
492 pad_token_id: u32,
493 device: &B::Device,
494 ) -> Result<u32> {
495 if context.is_empty() {
496 return Err(Error::Inference(
497 "context must contain at least one token".to_string(),
498 ));
499 }
500
501 let input = context_window(context, self.config().seq_len, pad_token_id);
502 let input = tensor_from_u32::<B, 2>(input, [1, self.config().seq_len], device)?;
503 let logits = self.forward(input);
504 let last_logits = logits
505 .slice([
506 0..1,
507 self.config().seq_len - 1..self.config().seq_len,
508 0..self.config().vocab_size,
509 ])
510 .reshape([self.config().vocab_size]);
511 let values = tensor_to_vec(last_logits)?;
512 argmax(&values).map(|idx| idx as u32)
513 }
514}
515
516impl<B> MultiscreenModel<B>
517where
518 B: AutodiffBackend,
519{
520 pub fn train_token_sequences(
525 &mut self,
526 sequences: &[Vec<u32>],
527 training: &ModelTrainingConfig,
528 device: &B::Device,
529 mut on_step: impl FnMut(usize, f32),
530 ) -> Result<ModelTrainingReport> {
531 if training.batch_size == 0 {
532 return Err(Error::Training(
533 "batch_size must be greater than zero".to_string(),
534 ));
535 }
536 let windows = TrainingWindows::from_sequences(
537 sequences,
538 self.config().seq_len,
539 training.pad_token_id,
540 )?;
541 if windows.is_empty() {
542 return Err(Error::Training(
543 "training requires at least one sequence with two or more tokens".to_string(),
544 ));
545 }
546
547 let mut optimizer_config =
548 AdamWConfig::new().with_weight_decay(training.weight_decay as f32);
549 if let Some(max_norm) = training.grad_clip_norm.filter(|value| *value > 0.0) {
550 optimizer_config = optimizer_config
551 .with_grad_clipping(Some(GradientClippingConfig::Norm(max_norm as f32)));
552 }
553 let mut optimizer = optimizer_config.init::<B, Self>();
554 let mut model = self.clone();
555 let mut final_loss = f32::NAN;
556
557 for step in 0..training.steps {
558 let batch = windows.batch::<B>(step, training.batch_size, device)?;
559 let logits = model.forward(batch.inputs);
560 let loss = cross_entropy_loss_with_mask(logits, batch.targets, batch.loss_mask);
561 final_loss = tensor_scalar(loss.clone())?;
562 let grads = loss.backward();
563 let grads = GradientsParams::from_grads(grads, &model);
564 model = optimizer.step(training.learning_rate, model, grads);
565 on_step(step, final_loss);
566 }
567
568 if training.steps == 0 {
569 let batch = windows.batch::<B>(0, training.batch_size, device)?;
570 final_loss = tensor_scalar(cross_entropy_loss_with_mask(
571 model.forward(batch.inputs),
572 batch.targets,
573 batch.loss_mask,
574 ))?;
575 }
576
577 *self = model;
578
579 Ok(ModelTrainingReport {
580 steps: training.steps,
581 final_loss,
582 training_window_count: windows.len(),
583 parameter_count: self.parameter_count(),
584 })
585 }
586
587 pub fn evaluate_on_sequences(
590 &self,
591 sequences: &[Vec<u32>],
592 seq_len: usize,
593 batch_size: usize,
594 pad_token_id: u32,
595 device: &B::Device,
596 ) -> Result<EvaluationResult> {
597 let windows = TrainingWindows::from_sequences(sequences, seq_len, pad_token_id)?;
598 if windows.is_empty() {
599 return Ok(EvaluationResult {
600 loss: f32::NAN,
601 perplexity: f32::NAN,
602 accuracy: 0.0,
603 num_batches: 0,
604 total_tokens: 0,
605 });
606 }
607
608 let num_batches = windows.len().div_ceil(batch_size);
609 let mut total_loss = 0.0_f64;
610 let mut total_correct = 0_usize;
611 let mut total_tokens = 0_usize;
612
613 for step in 0..num_batches {
614 let batch = windows.batch::<B>(step, batch_size, device)?;
615 let logits = self.forward(batch.inputs); let loss = cross_entropy_loss_with_mask(
617 logits.clone(),
618 batch.targets.clone(),
619 batch.loss_mask.clone(),
620 );
621 let loss_val = tensor_scalar(loss)? as f64;
622 total_loss += loss_val;
623
624 let [b, s, v] = logits.dims();
626 let mask_vec: Vec<f32> = batch
627 .loss_mask
628 .clone()
629 .reshape([b * s])
630 .into_data()
631 .into_vec::<f32>()
632 .map_err(|e| Error::Inference(e.to_string()))?;
633 let target_vec: Vec<i32> = batch
634 .targets
635 .clone()
636 .reshape([b * s])
637 .into_data()
638 .into_vec::<i32>()
639 .map_err(|e| Error::Inference(e.to_string()))?;
640 let logit_vec: Vec<f32> = logits
641 .reshape([b * s * v])
642 .into_data()
643 .into_vec::<f32>()
644 .map_err(|e| Error::Inference(e.to_string()))?;
645
646 for bi in 0..b {
647 for si in 0..s {
648 let mi = bi * s + si;
649 if mask_vec[mi] < 0.5 {
650 continue;
651 }
652 total_tokens += 1;
653 let base = bi * s * v + si * v;
654 let mut best_idx = 0;
655 let mut best_val = f32::NEG_INFINITY;
656 for vi in 0..v {
657 let val = logit_vec[base + vi];
658 if val > best_val {
659 best_val = val;
660 best_idx = vi;
661 }
662 }
663 if best_idx == target_vec[mi] as usize {
664 total_correct += 1;
665 }
666 }
667 }
668 }
669
670 let avg_loss = total_loss / num_batches as f64;
671 let perplexity = avg_loss.exp();
672 let accuracy = if total_tokens > 0 {
673 total_correct as f64 / total_tokens as f64
674 } else {
675 0.0
676 };
677
678 Ok(EvaluationResult {
679 loss: avg_loss as f32,
680 perplexity: perplexity as f32,
681 accuracy,
682 num_batches,
683 total_tokens,
684 })
685 }
686}
687
688impl<B: Backend> crate::lm::LanguageModel<B> for MultiscreenModel<B> {
689 fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
690 MultiscreenModel::forward(self, tokens)
691 }
692}
693
694impl<B: Backend> crate::lm::TrainableLanguageModel<B> for MultiscreenModel<B> {}
695
696impl<B: Backend> GatedScreeningTile<B> {
697 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
698 let q = row_unit_normalize(linear(x.clone(), self.w_q.val()));
699 let k = row_unit_normalize(linear(x.clone(), self.w_k.val()));
700 let v = row_unit_normalize(linear(x.clone(), self.w_v.val()));
701 let g = linear(x, self.w_g.val());
702
703 let w = self.s_w.val().clamp(-10.0, 8.0).exp() + 1.0;
704 let r = activation::sigmoid(self.s_r.val().clamp(-10.0, 8.0));
705
706 let q = apply_mipe(q, w.clone(), self.w_th);
707 let k = apply_mipe(k, w.clone(), self.w_th);
708
709 let similarity = q.matmul(k.swap_dims(1, 2));
710 let alpha = trim_and_square_tensor(similarity.clone(), r);
711 let softmask = causal_softmask_tensor::<B>(similarity.dims()[1], w, &similarity.device());
712 let relevance = alpha * softmask.unsqueeze();
713 let h = relevance.matmul(v);
714 let u = tanh_norm(h);
715 let gate = activation::silu(g).tanh();
716 let gated = u * gate;
717 let out = linear(gated, self.w_o.val());
718 out.clone() * expand_scalar3(self.s_o.val().exp(), out.dims())
719 }
720}
721
722#[allow(dead_code)]
723pub fn cross_entropy_loss<B: Backend>(
724 logits: Tensor<B, 3>,
725 targets: Tensor<B, 2, Int>,
726) -> Tensor<B, 1> {
727 let device = logits.device();
728 let [batch, seq_len, _] = logits.dims();
729 let loss_mask = Tensor::<B, 2>::ones([batch, seq_len], &device);
730 cross_entropy_loss_with_mask(logits, targets, loss_mask)
731}
732
733pub fn cross_entropy_loss_with_mask<B: Backend>(
734 logits: Tensor<B, 3>,
735 targets: Tensor<B, 2, Int>,
736 loss_mask: Tensor<B, 2>,
737) -> Tensor<B, 1> {
738 let [batch, seq_len, vocab_size] = logits.dims();
739 let token_count = batch * seq_len;
740 let flat_logits = logits.reshape([token_count, vocab_size]);
741 let flat_targets = targets.reshape([token_count]);
742 let flat_mask = loss_mask.reshape([token_count]);
743 let log_probs = activation::log_softmax(flat_logits, 1);
744 let target_probs = flat_targets.one_hot::<2>(vocab_size).float();
745 let picked = (log_probs * target_probs).sum_dim(1).reshape([token_count]);
746 let denom = flat_mask.clone().sum().add_scalar(EPS);
747 (picked.neg() * flat_mask).sum() / denom
748}
749
750pub fn row_unit_normalize<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
751 let denom = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
752 x / denom
753}
754
755fn trim_and_square_tensor<B: Backend>(similarity: Tensor<B, 3>, r: Tensor<B, 1>) -> Tensor<B, 3> {
756 let distance_from_one = similarity.mul_scalar(-1.0).add_scalar(1.0);
757 let scaled = distance_from_one.clone() / expand_scalar3(r, distance_from_one.clone().dims());
758 scaled
759 .mul_scalar(-1.0)
760 .add_scalar(1.0)
761 .clamp(0.0, 1.0)
762 .square()
763}
764
765fn causal_softmask_tensor<B: Backend>(
766 seq_len: usize,
767 w: Tensor<B, 1>,
768 device: &B::Device,
769) -> Tensor<B, 2> {
770 let mut distances = Vec::with_capacity(seq_len * seq_len);
771 for i in 0..seq_len {
772 for j in 0..seq_len {
773 distances.push(j as f32 - i as f32);
774 }
775 }
776
777 let dist = Tensor::<B, 2>::from_data(TensorData::new(distances, [seq_len, seq_len]), device);
778 let w = expand_scalar2(w, [seq_len, seq_len]);
779 let causal = dist.clone().lower_equal_elem(0.0);
780 let within_window = dist.clone().greater(w.clone().neg());
781 let active = causal.float() * within_window.float();
782 let tapered = (dist / w)
783 .mul_scalar(PI)
784 .cos()
785 .mul_scalar(0.5)
786 .add_scalar(0.5);
787 active * tapered
788}
789
790pub fn tanh_norm<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
791 let norm = x.clone().square().sum_dim(D - 1).add_scalar(EPS).sqrt();
792 let scale = norm.clone().tanh() / norm;
793 x * scale
794}
795
796fn apply_mipe<B: Backend>(z: Tensor<B, 3>, w: Tensor<B, 1>, w_th: f32) -> Tensor<B, 3> {
797 let [_batch, seq_len, d_key] = z.dims();
798 debug_assert!(d_key >= 2);
799
800 let positions = Tensor::<B, 3>::from_data(
801 TensorData::new(
802 (0..seq_len).map(|idx| idx as f32).collect::<Vec<_>>(),
803 [1, seq_len, 1],
804 ),
805 &z.device(),
806 );
807
808 let gamma_raw = w
809 .clone()
810 .mul_scalar(PI / w_th)
811 .cos()
812 .mul_scalar(0.5)
813 .add_scalar(0.5);
814 let gamma = gamma_raw.mask_fill(w.clone().greater_equal_elem(w_th), 0.0);
815 let gamma = expand_scalar3(gamma, [1, seq_len, 1]);
816 let w = expand_scalar3(w, [1, seq_len, 1]);
817 let phi = (positions * gamma / w).mul_scalar(PI);
818 let cos_phi = phi.clone().cos();
819 let sin_phi = phi.sin();
820
821 let x0 = z.clone().narrow(2, 0, 1);
822 let x1 = z.clone().narrow(2, 1, 1);
823 let rot0 = x0.clone() * cos_phi.clone() - x1.clone() * sin_phi.clone();
824 let rot1 = x0 * sin_phi + x1 * cos_phi;
825
826 if d_key == 2 {
827 Tensor::cat(vec![rot0, rot1], 2)
828 } else {
829 let rest = z.narrow(2, 2, d_key - 2);
830 Tensor::cat(vec![rot0, rot1, rest], 2)
831 }
832}
833
834fn linear<B: Backend>(x: Tensor<B, 3>, weight: Tensor<B, 2>) -> Tensor<B, 3> {
835 let [batch, _seq_len, _in_dim] = x.dims();
836 let [weight_in, out_dim] = weight.dims();
837 x.matmul(weight.unsqueeze::<3>().expand([batch, weight_in, out_dim]))
838}
839
840fn expand_scalar2<B: Backend>(value: Tensor<B, 1>, dims: [usize; 2]) -> Tensor<B, 2> {
841 value.unsqueeze::<2>().expand(dims)
842}
843
844fn expand_scalar3<B: Backend>(value: Tensor<B, 1>, dims: [usize; 3]) -> Tensor<B, 3> {
845 value.unsqueeze::<3>().expand(dims)
846}
847
848fn init_scalar<B: Backend>(value: f32, device: &B::Device) -> Param<Tensor<B, 1>> {
849 Param::from_tensor(Tensor::<B, 1>::from_data([value], device))
850}
851
852fn init_matrix<B: Backend>(
853 rows: usize,
854 cols: usize,
855 std: f32,
856 seed: &mut u64,
857 device: &B::Device,
858) -> Param<Tensor<B, 2>> {
859 let values = gaussian_values(rows * cols, std, seed);
860 Param::from_tensor(Tensor::<B, 2>::from_data(
861 TensorData::new(values, [rows, cols]),
862 device,
863 ))
864}
865
866fn gaussian_values(len: usize, std: f32, seed: &mut u64) -> Vec<f32> {
867 let mut values = Vec::with_capacity(len);
868 while values.len() < len {
869 let u1 = next_uniform(seed).max(1e-7);
870 let u2 = next_uniform(seed);
871 let radius = (-2.0 * u1.ln()).sqrt();
872 let theta = 2.0 * PI * u2;
873 values.push(radius * theta.cos() * std);
874 if values.len() < len {
875 values.push(radius * theta.sin() * std);
876 }
877 }
878 values
879}
880
/// Advances the 64-bit LCG state in `seed` and returns a uniform sample in (0, 1)
/// built from the top 24 bits of the new state.
fn next_uniform(seed: &mut u64) -> f32 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    let state = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *seed = state;
    let top_bits = (state >> 40) as u32;
    let denominator = (1u32 << 24) as f32 + 2.0;
    (top_bits as f32 + 1.0) / denominator
}
888
/// Returns a `seq_len`-long window holding the last `seq_len` tokens of `context`,
/// right-aligned and preceded by `pad_token_id` when the context is shorter.
fn context_window(context: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    let keep = context.len().min(seq_len);
    let tail = &context[context.len() - keep..];
    let mut window = Vec::with_capacity(seq_len);
    window.resize(seq_len - keep, pad_token_id);
    window.extend_from_slice(tail);
    window
}
900
901fn argmax(values: &[f32]) -> Result<usize> {
902 values
903 .iter()
904 .enumerate()
905 .max_by(|left, right| left.1.total_cmp(right.1))
906 .map(|(index, _)| index)
907 .ok_or_else(|| Error::Inference("cannot argmax empty logits".to_string()))
908}
909
910fn tensor_scalar<B: Backend>(tensor: Tensor<B, 1>) -> Result<f32> {
911 tensor
912 .into_data()
913 .into_vec::<f32>()
914 .map_err(|err| Error::Training(err.to_string()))?
915 .into_iter()
916 .next()
917 .ok_or_else(|| Error::Training("expected scalar tensor".to_string()))
918}
919
920fn tensor_to_vec<B: Backend>(tensor: Tensor<B, 1>) -> Result<Vec<f32>> {
921 tensor
922 .into_data()
923 .into_vec::<f32>()
924 .map_err(|err| Error::Inference(err.to_string()))
925}
926
927fn tensor_from_u32<B: Backend, const D: usize>(
928 values: Vec<u32>,
929 shape: [usize; D],
930 device: &B::Device,
931) -> Result<Tensor<B, D, Int>> {
932 let values = values
933 .into_iter()
934 .map(|value| {
935 i32::try_from(value)
936 .map_err(|_| Error::Config(format!("token id {value} exceeds i32::MAX")))
937 })
938 .collect::<Result<Vec<_>>>()?;
939 Ok(Tensor::<B, D, Int>::from_data(
940 TensorData::new(values, shape),
941 device,
942 ))
943}
944
945fn ensure(condition: bool, message: &str) -> Result<()> {
946 if condition {
947 Ok(())
948 } else {
949 Err(Error::Config(message.to_string()))
950 }
951}
952
/// One fixed-length next-token training example.
struct TrainingWindow {
    /// Input ids, right-padded to `seq_len`.
    inputs: Vec<u32>,
    /// Inputs shifted left by one token, right-padded to `seq_len`.
    targets: Vec<u32>,
    /// 1.0 for real prediction positions, 0.0 for padding.
    loss_mask: Vec<f32>,
}
958
/// All windows built from a set of sequences, each of length `seq_len`.
struct TrainingWindows {
    windows: Vec<TrainingWindow>,
    seq_len: usize,
}
963
964impl TrainingWindows {
965 fn from_sequences(sequences: &[Vec<u32>], seq_len: usize, pad_token_id: u32) -> Result<Self> {
966 let mut windows = Vec::new();
967 for sequence in sequences {
968 if sequence.len() < 2 {
969 continue;
970 }
971
972 let mut start = 0;
973 while start + 1 < sequence.len() {
974 let end = (start + seq_len + 1).min(sequence.len());
975 let chunk = &sequence[start..end];
976 let prediction_count = chunk.len() - 1;
977
978 let mut inputs = vec![pad_token_id; seq_len];
979 let mut targets = vec![pad_token_id; seq_len];
980 let mut loss_mask = vec![0.0; seq_len];
981 inputs[..prediction_count].copy_from_slice(&chunk[..prediction_count]);
982 targets[..prediction_count].copy_from_slice(&chunk[1..]);
983 loss_mask[..prediction_count].fill(1.0);
984
985 windows.push(TrainingWindow {
986 inputs,
987 targets,
988 loss_mask,
989 });
990
991 if end == sequence.len() {
992 break;
993 }
994 start += seq_len;
995 }
996 }
997
998 Ok(Self { windows, seq_len })
999 }
1000
1001 fn is_empty(&self) -> bool {
1002 self.windows.is_empty()
1003 }
1004
1005 fn len(&self) -> usize {
1006 self.windows.len()
1007 }
1008
1009 fn batch<B: Backend>(
1010 &self,
1011 step: usize,
1012 batch_size: usize,
1013 device: &B::Device,
1014 ) -> Result<TokenBatch<B>> {
1015 let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
1016 let mut targets = Vec::with_capacity(batch_size * self.seq_len);
1017 let mut loss_mask = Vec::with_capacity(batch_size * self.seq_len);
1018
1019 for batch_idx in 0..batch_size {
1020 let index = (step * batch_size + batch_idx) % self.windows.len();
1021 let window = &self.windows[index];
1022 inputs.extend_from_slice(&window.inputs);
1023 targets.extend_from_slice(&window.targets);
1024 loss_mask.extend_from_slice(&window.loss_mask);
1025 }
1026
1027 Ok(TokenBatch {
1028 inputs: tensor_from_u32(inputs, [batch_size, self.seq_len], device)?,
1029 targets: tensor_from_u32(targets, [batch_size, self.seq_len], device)?,
1030 loss_mask: Tensor::<B, 2>::from_data(
1031 TensorData::new(loss_mask, [batch_size, self.seq_len]),
1032 device,
1033 ),
1034 })
1035 }
1036}
1037
/// A batch of windows as device tensors, each shaped `[batch_size, seq_len]`.
struct TokenBatch<B: Backend> {
    inputs: Tensor<B, 2, Int>,
    targets: Tensor<B, 2, Int>,
    /// 1.0 where the target contributes to the loss, 0.0 for padding.
    loss_mask: Tensor<B, 2>,
}
1043
/// Test helper: builds a synthetic `[batch_size, seq_len]` batch of cyclic token
/// runs, where each target is the input plus one (mod `vocab_size`). The starting
/// offset of each row depends on `step` and the row index.
#[cfg(test)]
#[allow(dead_code)]
pub fn make_batch<B: Backend>(
    step: usize,
    batch_size: usize,
    seq_len: usize,
    vocab_size: usize,
    device: &B::Device,
) -> Result<(Tensor<B, 2, Int>, Tensor<B, 2, Int>)> {
    let (inputs, targets): (Vec<u32>, Vec<u32>) = (0..batch_size)
        .flat_map(|row| {
            let offset = (step * 7 + row * 13) % vocab_size;
            (0..seq_len).map(move |pos| {
                (
                    ((offset + pos) % vocab_size) as u32,
                    ((offset + pos + 1) % vocab_size) as u32,
                )
            })
        })
        .unzip();

    let inputs = tensor_from_u32(inputs, [batch_size, seq_len], device)?;
    let targets = tensor_from_u32(targets, [batch_size, seq_len], device)?;
    Ok((inputs, targets))
}