Skip to main content

forgellm_runtime/
interpreter.rs

//! Interpreter — executes IR graphs directly on CPU.
//!
//! Uses optimized kernels from the `kernels` module for compute-heavy
//! operations (matmul, rms_norm). Validates correctness and serves as
//! the primary inference path until the AOT codegen is ready.
6
7use forgellm_frontend::ir::*;
8use forgellm_frontend::weight_loader::ModelWeights;
9
10use crate::kernels;
11use crate::kv_cache::KVCache;
12
13/// Run a single forward pass for one token through the model.
14///
15/// Returns logits of shape `[vocab_size]`.
16pub fn forward(
17    token_id: u32,
18    pos: usize,
19    graph: &Graph,
20    weights: &ModelWeights,
21    cache: &mut KVCache,
22) -> Vec<f32> {
23    let config = graph.config.as_ref().expect("graph must have config");
24
25    let hidden = config.hidden_size;
26    let intermediate = config.intermediate_size;
27    let num_heads = config.num_attention_heads;
28    let num_kv_heads = config.num_kv_heads;
29    let head_dim = config.head_dim;
30    let vocab = config.vocab_size;
31
32    // Embedding lookup
33    let embed_w = weights.tensor("model.embed_tokens.weight");
34    let mut hidden_state = vec![0.0f32; hidden];
35    let offset = token_id as usize * hidden;
36    hidden_state.copy_from_slice(&embed_w[offset..offset + hidden]);
37
38    // Gemma-1 scales the embedding by sqrt(hidden_size) after lookup.
39    // Doing it here (not to the stored weight) preserves tied embeddings for
40    // the final logit projection.
41    if matches!(
42        config.architecture,
43        forgellm_frontend::ir::Architecture::Gemma
44    ) {
45        let scale = (hidden as f32).sqrt();
46        for v in hidden_state.iter_mut() {
47            *v *= scale;
48        }
49    }
50
51    // Pre-allocate buffers
52    let mut normed = vec![0.0f32; hidden];
53    let mut q = vec![0.0f32; num_heads * head_dim];
54    let mut k = vec![0.0f32; num_kv_heads * head_dim];
55    let mut v = vec![0.0f32; num_kv_heads * head_dim];
56    let mut attn_out = vec![0.0f32; num_heads * head_dim];
57    let mut attn_proj = vec![0.0f32; hidden];
58    let mut residual = vec![0.0f32; hidden];
59    let mut gate = vec![0.0f32; intermediate];
60    let mut gate_act = vec![0.0f32; intermediate];
61    let mut up = vec![0.0f32; intermediate];
62    let mut ffn_hidden = vec![0.0f32; intermediate];
63    let mut ffn_out = vec![0.0f32; hidden];
64
65    for layer_idx in 0..config.num_layers {
66        let prefix = format!("model.layers.{layer_idx}");
67
68        // Attention norm
69        let norm_w = weights.tensor(&format!("{prefix}.input_layernorm.weight"));
70        rms_norm(&mut normed, &hidden_state, norm_w, config.rms_norm_eps);
71
72        // QKV projections
73        let q_w = weights.tensor(&format!("{prefix}.self_attn.q_proj.weight"));
74        let k_w = weights.tensor(&format!("{prefix}.self_attn.k_proj.weight"));
75        let v_w = weights.tensor(&format!("{prefix}.self_attn.v_proj.weight"));
76        matmul(&mut q, &normed, q_w, 1, hidden, num_heads * head_dim);
77        matmul(&mut k, &normed, k_w, 1, hidden, num_kv_heads * head_dim);
78        matmul(&mut v, &normed, v_w, 1, hidden, num_kv_heads * head_dim);
79
80        // Add QKV biases if present (Qwen2 uses biases on QKV)
81        if let Some(q_bias) = weights.get(&format!("{prefix}.self_attn.q_proj.bias")) {
82            elementwise_add_inplace(&mut q, q_bias);
83        }
84        if let Some(k_bias) = weights.get(&format!("{prefix}.self_attn.k_proj.bias")) {
85            elementwise_add_inplace(&mut k, k_bias);
86        }
87        if let Some(v_bias) = weights.get(&format!("{prefix}.self_attn.v_proj.bias")) {
88            elementwise_add_inplace(&mut v, v_bias);
89        }
90
91        // RoPE
92        rope(&mut q, pos, head_dim, num_heads, config.rope_theta);
93        rope(&mut k, pos, head_dim, num_kv_heads, config.rope_theta);
94
95        // Update KV cache
96        cache.append(layer_idx, &k, &v);
97
98        // Attention
99        attention(
100            &mut attn_out,
101            &q,
102            cache.k(layer_idx),
103            cache.v(layer_idx),
104            &AttentionParams {
105                seq_len: pos + 1,
106                num_heads,
107                num_kv_heads,
108                head_dim,
109            },
110        );
111
112        // Output projection
113        let o_w = weights.tensor(&format!("{prefix}.self_attn.o_proj.weight"));
114        matmul(
115            &mut attn_proj,
116            &attn_out,
117            o_w,
118            1,
119            num_heads * head_dim,
120            hidden,
121        );
122
123        // Residual
124        elementwise_add(&mut residual, &hidden_state, &attn_proj);
125
126        // FFN norm
127        let ffn_norm_w = weights.tensor(&format!("{prefix}.post_attention_layernorm.weight"));
128        rms_norm(&mut normed, &residual, ffn_norm_w, config.rms_norm_eps);
129
130        // FFN
131        let gate_w = weights.tensor(&format!("{prefix}.mlp.gate_proj.weight"));
132        let up_w = weights.tensor(&format!("{prefix}.mlp.up_proj.weight"));
133        let down_w = weights.tensor(&format!("{prefix}.mlp.down_proj.weight"));
134
135        matmul(&mut gate, &normed, gate_w, 1, hidden, intermediate);
136        match config.hidden_activation {
137            forgellm_frontend::ir::HiddenActivation::SiLU => silu(&mut gate_act, &gate),
138            forgellm_frontend::ir::HiddenActivation::GeluApprox => {
139                kernels::gelu(&mut gate_act, &gate)
140            }
141        }
142        matmul(&mut up, &normed, up_w, 1, hidden, intermediate);
143        elementwise_mul(&mut ffn_hidden, &gate_act, &up);
144        matmul(&mut ffn_out, &ffn_hidden, down_w, 1, intermediate, hidden);
145
146        // Residual
147        elementwise_add(&mut hidden_state, &residual, &ffn_out);
148    }
149
150    // Final norm
151    let final_norm_w = weights.tensor("model.norm.weight");
152    rms_norm(
153        &mut normed,
154        &hidden_state,
155        final_norm_w,
156        config.rms_norm_eps,
157    );
158
159    // Logits projection (may use tied embeddings)
160    let lm_head_w = weights
161        .get("lm_head.weight")
162        .unwrap_or_else(|| weights.tensor("model.embed_tokens.weight"));
163    let mut logits = vec![0.0f32; vocab];
164    matmul(&mut logits, &normed, lm_head_w, 1, hidden, vocab);
165
166    logits
167}
168
169// --- Kernel wrappers (delegate to optimized kernels module) ---
170
/// RMS-normalizes `input` into `output`; thin wrapper over `kernels::rms_norm`.
///
/// `weight` is the per-element norm weight and `eps` the stabilizing epsilon.
/// The unit test in this file expects `output[i] ≈ input[i] / sqrt(mean(input²) + eps)`
/// when every weight is 1.0.
fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
    kernels::rms_norm(output, input, weight, eps);
}
174
/// Matrix multiply; thin wrapper over `kernels::matmul`.
///
/// As exercised by the unit test in this file, `weight` is stored row-major as
/// `[n, k]` and, for a single input row, `output[j] = Σ_i input[i] * weight[j * k + i]`.
/// `m` is passed as 1 by every caller here — presumably the number of input
/// rows; confirm against the kernel before using other values.
fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
    kernels::matmul(output, input, weight, m, k, n);
}
178
/// Applies the SiLU activation to `input`, writing into `output`; thin wrapper
/// over `kernels::silu`.
///
/// The unit test in this file pins `silu(0) = 0`, `silu(1) ≈ 0.7311` and
/// `silu(-1) ≈ -0.2689`, i.e. `x * sigmoid(x)`.
fn silu(output: &mut [f32], input: &[f32]) {
    kernels::silu(output, input);
}
182
/// Thin wrapper over `kernels::elementwise_mul`, writing into `output`.
///
/// `forward` uses it to combine the activated gate with the up projection
/// (presumably `output[i] = a[i] * b[i]`; see the kernel for length requirements).
fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
    kernels::elementwise_mul(output, a, b);
}
186
/// Thin wrapper over `kernels::elementwise_add`, writing into `output`.
///
/// `forward` uses it for both residual connections (presumably
/// `output[i] = a[i] + b[i]`; see the kernel for length requirements).
fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
    kernels::elementwise_add(output, a, b);
}
190
/// Adds `b` onto `a` element-wise, in place (`a[i] += b[i]`).
///
/// `forward` uses this for the optional Q/K/V projection biases. Only the
/// first `a.len()` elements of `b` are read; any extra trailing elements are
/// ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. The length is checked before the first
/// write, so a mismatched bias never leaves `a` partially updated.
fn elementwise_add_inplace(a: &mut [f32], b: &[f32]) {
    assert!(
        b.len() >= a.len(),
        "elementwise_add_inplace: rhs has {} elements, need at least {}",
        b.len(),
        a.len()
    );
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += src;
    }
}
196
/// Applies rotary position embedding (RoPE) in place.
///
/// `data` holds `num_heads` consecutive heads of `head_dim` values each.
/// Within every head, the adjacent pair `(i, i + 1)` for even `i` is rotated
/// by the angle `pos / theta^(i / head_dim)`.
///
/// The rotation angles depend only on the pair index, not on the head, so the
/// `(cos, sin)` values are computed once per call and reused for every head.
///
/// # Panics
///
/// Panics if `head_dim` is odd (pairs would straddle head boundaries) or if
/// `data` holds fewer than `num_heads * head_dim` elements.
fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, theta: f32) {
    assert!(
        head_dim % 2 == 0,
        "rope: head_dim must be even, got {head_dim}"
    );
    if head_dim == 0 {
        // Nothing to rotate; also avoids a zero chunk size below.
        return;
    }

    // One (cos, sin) per rotated pair, shared by all heads.
    let rotations: Vec<(f32, f32)> = (0..head_dim)
        .step_by(2)
        .map(|i| {
            let freq = 1.0 / theta.powf(i as f32 / head_dim as f32);
            let angle = pos as f32 * freq;
            (angle.cos(), angle.sin())
        })
        .collect();

    // Only the first `num_heads` heads are touched; trailing data is left alone.
    let heads = &mut data[..num_heads * head_dim];
    for head in heads.chunks_exact_mut(head_dim) {
        for (pair, &(cos_val, sin_val)) in head.chunks_exact_mut(2).zip(&rotations) {
            let (x0, x1) = (pair[0], pair[1]);
            pair[0] = x0 * cos_val - x1 * sin_val;
            pair[1] = x0 * sin_val + x1 * cos_val;
        }
    }
}
212
/// Shape parameters for one call to the `attention` wrapper in this file,
/// bundled so the call site does not pass four bare `usize` values.
#[derive(Debug, Clone, Copy)]
struct AttentionParams {
    /// Number of cached positions to attend over (`forward` passes `pos + 1`).
    seq_len: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads in the cache.
    num_kv_heads: usize,
    /// Width of a single head.
    head_dim: usize,
}
219
220fn attention(
221    output: &mut [f32],
222    q: &[f32],
223    k_cache: &[f32],
224    v_cache: &[f32],
225    params: &AttentionParams,
226) {
227    kernels::attention(
228        output,
229        q,
230        k_cache,
231        v_cache,
232        params.seq_len,
233        params.num_heads,
234        params.num_kv_heads,
235        params.head_dim,
236    );
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // ── Shared fixtures ──────────────────────────────────────────────────
    //
    // Every model-level test runs the same one-layer toy model, so its
    // dimensions, config and weights are defined once here.

    const HIDDEN: usize = 8;
    const INTERMEDIATE: usize = 16;
    const VOCAB: usize = 16;
    const NUM_HEADS: usize = 2;
    const NUM_KV_HEADS: usize = 1;
    const HEAD_DIM: usize = 4;

    /// Config of the one-layer toy Llama model used by the tests below.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: HIDDEN,
            intermediate_size: INTERMEDIATE,
            num_layers: 1,
            num_attention_heads: NUM_HEADS,
            num_kv_heads: NUM_KV_HEADS,
            head_dim: HEAD_DIM,
            vocab_size: VOCAB,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F32,
            sliding_window_size: None,
            qkv_bias: false,
            hidden_activation: HiddenActivation::SiLU,
        }
    }

    /// Graph built from [`tiny_config`].
    fn tiny_graph() -> Graph {
        forgellm_frontend::graph_builder::build_graph(&tiny_config()).unwrap()
    }

    /// Empty cache for the toy model. The arguments mirror the original call
    /// `KVCache::new(1, num_kv_heads, head_dim)`.
    fn fresh_cache() -> KVCache {
        KVCache::new(1, 1, 4)
    }

    /// `len` values following `((i % modulus) + offset) * scale`.
    fn ramp(len: usize, modulus: usize, offset: f32, scale: f32) -> Vec<f32> {
        (0..len)
            .map(|i| ((i % modulus) as f32 + offset) * scale)
            .collect()
    }

    /// Weights where every tensor is a constant: 0.1 for the embedding,
    /// 1.0 for the norms and 0.01 for every projection.
    fn uniform_weights() -> ModelWeights {
        let q_len = NUM_HEADS * HEAD_DIM * HIDDEN;
        let kv_len = NUM_KV_HEADS * HEAD_DIM * HIDDEN;
        let ffn_len = INTERMEDIATE * HIDDEN;

        let mut tensors = HashMap::new();
        tensors.insert(
            "model.embed_tokens.weight".into(),
            vec![0.1f32; VOCAB * HIDDEN],
        );
        tensors.insert(
            "model.layers.0.input_layernorm.weight".into(),
            vec![1.0f32; HIDDEN],
        );
        tensors.insert(
            "model.layers.0.self_attn.q_proj.weight".into(),
            vec![0.01f32; q_len],
        );
        tensors.insert(
            "model.layers.0.self_attn.k_proj.weight".into(),
            vec![0.01f32; kv_len],
        );
        tensors.insert(
            "model.layers.0.self_attn.v_proj.weight".into(),
            vec![0.01f32; kv_len],
        );
        tensors.insert(
            "model.layers.0.self_attn.o_proj.weight".into(),
            vec![0.01f32; q_len],
        );
        tensors.insert(
            "model.layers.0.post_attention_layernorm.weight".into(),
            vec![1.0f32; HIDDEN],
        );
        tensors.insert(
            "model.layers.0.mlp.gate_proj.weight".into(),
            vec![0.01f32; ffn_len],
        );
        tensors.insert(
            "model.layers.0.mlp.up_proj.weight".into(),
            vec![0.01f32; ffn_len],
        );
        tensors.insert(
            "model.layers.0.mlp.down_proj.weight".into(),
            vec![0.01f32; ffn_len],
        );
        tensors.insert("model.norm.weight".into(), vec![1.0f32; HIDDEN]);
        tensors.insert("lm_head.weight".into(), vec![0.01f32; VOCAB * HIDDEN]);

        ModelWeights { tensors }
    }

    /// Weights with distinguishable per-token embeddings and non-constant
    /// projections, so different token IDs produce different logits.
    fn varied_weights() -> ModelWeights {
        let q_len = NUM_HEADS * HEAD_DIM * HIDDEN;
        let kv_len = NUM_KV_HEADS * HEAD_DIM * HIDDEN;
        let ffn_len = INTERMEDIATE * HIDDEN;
        let embed_len = VOCAB * HIDDEN;

        let mut tensors = HashMap::new();
        // Each token gets a distinct embedding row: (flat_index + 1) * 0.05.
        // A modulus equal to the length leaves the index unchanged.
        tensors.insert(
            "model.embed_tokens.weight".into(),
            ramp(embed_len, embed_len, 1.0, 0.05),
        );
        tensors.insert(
            "model.layers.0.input_layernorm.weight".into(),
            vec![1.0f32; HIDDEN],
        );
        // Different periods per projection keep the model from being degenerate.
        tensors.insert(
            "model.layers.0.self_attn.q_proj.weight".into(),
            ramp(q_len, 7, 1.0, 0.01),
        );
        tensors.insert(
            "model.layers.0.self_attn.k_proj.weight".into(),
            ramp(kv_len, 5, 1.0, 0.01),
        );
        tensors.insert(
            "model.layers.0.self_attn.v_proj.weight".into(),
            ramp(kv_len, 3, 1.0, 0.01),
        );
        tensors.insert(
            "model.layers.0.self_attn.o_proj.weight".into(),
            ramp(q_len, 11, 1.0, 0.01),
        );
        tensors.insert(
            "model.layers.0.post_attention_layernorm.weight".into(),
            vec![1.0f32; HIDDEN],
        );
        tensors.insert(
            "model.layers.0.mlp.gate_proj.weight".into(),
            ramp(ffn_len, 13, 1.0, 0.01),
        );
        tensors.insert(
            "model.layers.0.mlp.up_proj.weight".into(),
            ramp(ffn_len, 9, 1.0, 0.01),
        );
        tensors.insert(
            "model.layers.0.mlp.down_proj.weight".into(),
            ramp(ffn_len, 7, 1.0, 0.01),
        );
        tensors.insert("model.norm.weight".into(), vec![1.0f32; HIDDEN]);
        // Signed lm_head so different hidden states map to different logits.
        tensors.insert(
            "lm_head.weight".into(),
            ramp(VOCAB * HIDDEN, 17, -8.0, 0.02),
        );

        ModelWeights { tensors }
    }

    /// Build a tiny model with distinguishable per-token embeddings so that
    /// different token IDs produce different logit distributions.
    fn tiny_model_with_varied_weights() -> (ModelConfig, Graph, ModelWeights) {
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        (config, graph, varied_weights())
    }

    // ── Kernel wrapper tests ─────────────────────────────────────────────

    #[test]
    fn rms_norm_basic() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        // RMS = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386
        let rms = (30.0f32 / 4.0 + 1e-5).sqrt();
        let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
    }

    #[test]
    fn matmul_basic() {
        // [1, 2] x [[1, 2], [3, 4]]^T = [1*1+2*2, 1*3+2*4] = [5, 11]
        // weight stored as [n, k] = [[1, 2], [3, 4]]
        let input = vec![1.0, 2.0];
        let weight = vec![1.0, 2.0, 3.0, 4.0]; // row 0: [1,2], row 1: [3,4]
        let mut output = vec![0.0; 2];
        matmul(&mut output, &input, &weight, 1, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-6);
        assert!((output[1] - 11.0).abs() < 1e-6);
    }

    #[test]
    fn silu_basic() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];
        silu(&mut output, &input);
        // silu(0) = 0, silu(1) = 1/(1+e^-1) ≈ 0.7311
        assert!((output[0] - 0.0).abs() < 1e-6);
        assert!((output[1] - 0.7311).abs() < 1e-3);
        assert!((output[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn softmax_basic() {
        let mut values = vec![1.0, 2.0, 3.0];
        kernels::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn rope_preserves_magnitude() {
        // RoPE is a rotation, so it should preserve vector magnitude
        let mut data = vec![1.0, 0.0, 0.0, 1.0]; // 1 head, dim=4
        let mag_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        rope(&mut data, 5, 4, 1, 10000.0);
        let mag_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag_before - mag_after).abs() < 1e-5,
            "RoPE changed magnitude: {mag_before} → {mag_after}"
        );
    }

    // ── Forward-pass smoke tests ─────────────────────────────────────────

    #[test]
    fn forward_with_tiny_model() {
        // Minimal model: verify the interpreter runs without panicking.
        let graph = tiny_graph();
        let weights = uniform_weights();
        let mut kv_cache = fresh_cache();

        let logits = forward(0, 0, &graph, &weights, &mut kv_cache);
        assert_eq!(logits.len(), VOCAB);
        assert_eq!(kv_cache.len(), 0); // advance not called by forward

        // Logits should be finite
        for &l in &logits {
            assert!(l.is_finite(), "logit is not finite: {l}");
        }
    }

    #[test]
    fn forward_multi_token() {
        let graph = tiny_graph();
        let weights = uniform_weights();
        let mut cache = fresh_cache();

        // Generate 3 tokens
        for pos in 0..3 {
            let logits = forward(1, pos, &graph, &weights, &mut cache);
            assert_eq!(logits.len(), VOCAB);
            cache.advance();
        }

        assert_eq!(cache.len(), 3);
    }

    // ── Real-world validation tests ──────────────────────────────────────

    #[test]
    fn different_prompts_produce_different_logits() {
        // Two different token IDs at pos=0 should produce different logit vectors.
        // This validates that the model distinguishes inputs (not degenerate).
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        let mut cache1 = fresh_cache();
        let logits1 = forward(0, 0, &graph, &weights, &mut cache1);

        let mut cache2 = fresh_cache();
        let logits2 = forward(5, 0, &graph, &weights, &mut cache2);

        // Both should be finite
        for &l in &logits1 {
            assert!(l.is_finite(), "logits1 contains non-finite value: {l}");
        }
        for &l in &logits2 {
            assert!(l.is_finite(), "logits2 contains non-finite value: {l}");
        }

        // They should differ (the model is not degenerate)
        let differs = logits1
            .iter()
            .zip(logits2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(
            differs,
            "different input tokens should produce different logit distributions"
        );
    }

    #[test]
    fn cache_reset_produces_same_logits() {
        // After clearing the cache, running the same token at pos=0 should
        // produce identical logits as a fresh run. This validates that clear()
        // truly resets all state for independent multi-request serving.
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        // First run: fresh cache
        let mut cache = fresh_cache();
        let logits_fresh = forward(3, 0, &graph, &weights, &mut cache);

        // Advance the cache with more tokens to build up state
        cache.advance();
        let _ = forward(7, 1, &graph, &weights, &mut cache);
        cache.advance();
        assert_eq!(cache.len(), 2);

        // Clear and re-run
        cache.clear();
        assert_eq!(cache.len(), 0);
        let logits_after_reset = forward(3, 0, &graph, &weights, &mut cache);

        // Should be identical
        for (i, (a, b)) in logits_fresh
            .iter()
            .zip(logits_after_reset.iter())
            .enumerate()
        {
            assert!(
                (a - b).abs() < 1e-6,
                "logit[{i}] differs after reset: fresh={a}, after_reset={b}"
            );
        }
    }

    #[test]
    fn forward_at_pos_zero_no_nan() {
        // pos=0 is the first token where seq_len=1 in attention.
        // This is a common edge case: softmax over a single element,
        // RoPE with angle=0, and KV cache with one entry.
        let (_config, graph, weights) = tiny_model_with_varied_weights();
        let mut cache = fresh_cache();

        let logits = forward(0, 0, &graph, &weights, &mut cache);
        assert_eq!(logits.len(), VOCAB);

        for (i, &l) in logits.iter().enumerate() {
            assert!(
                !l.is_nan(),
                "logit[{i}] is NaN at pos=0 — likely a softmax or attention bug"
            );
            assert!(!l.is_infinite(), "logit[{i}] is infinite at pos=0");
        }
    }
}