Skip to main content

aprender/models/qwen2/
mod.rs

1//! Qwen2-0.5B-Instruct Model Implementation
2//!
3//! This module provides a complete Qwen2 model for inference, assembling
4//! primitives from the `nn` module into a decoder-only transformer.
5//!
6//! # Architecture (Bai et al., 2023)
7//!
8//! ```text
9//! Qwen2-0.5B-Instruct:
10//! ├── hidden_size: 896
11//! ├── num_attention_heads: 14 (query heads)
12//! ├── num_kv_heads: 2 (grouped query attention)
13//! ├── num_layers: 24
14//! ├── intermediate_size: 4864 (FFN)
15//! ├── vocab_size: 151936
16//! ├── max_seq_len: 32768
17//! └── rope_theta: 1,000,000
18//! ```
19//!
20//! # Example
21//!
22//! ```ignore
23//! use aprender::models::Qwen2Model;
24//! use aprender::demo::Qwen2Config;
25//!
26//! let config = Qwen2Config::qwen2_0_5b_instruct();
27//! let mut model = Qwen2Model::new(&config);
28//!
29//! let input_ids = vec![1u32, 2, 3, 4, 5];
30//! let position_ids: Vec<usize> = (0..5).collect();
31//! let logits = model.forward(&input_ids, &position_ids);
32//! ```
33//!
34//! # References
35//!
36//! - Bai et al. (2023). "Qwen Technical Report"
37//! - Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models"
38//! - Su et al. (2021). "`RoFormer`: Enhanced Transformer with Rotary Position Embedding"
39//! - Zhang & Sennrich (2019). "Root Mean Square Layer Normalization"
40
41use crate::autograd::Tensor;
42use crate::demo::Qwen2Config;
43use crate::nn::{GroupedQueryAttention, Linear, Module, RMSNorm, RotaryPositionEmbedding};
44
45// ============================================================================
46// Embedding Layer
47// ============================================================================
48
/// Token embedding lookup table.
///
/// Maps token IDs to dense vectors by row-indexing into `weight`.
#[derive(Debug)]
pub struct Embedding {
    /// Weight matrix [`vocab_size`, `hidden_size`], row-major: row `t` is the
    /// embedding vector for token ID `t`.
    weight: Tensor,
    /// Number of rows in `weight`; IDs >= this are treated as out-of-vocabulary.
    vocab_size: usize,
    /// Embedding dimension (columns of `weight`).
    hidden_size: usize,
}
59
60impl Embedding {
61    /// Create a new embedding layer.
62    #[must_use]
63    pub fn new(vocab_size: usize, hidden_size: usize) -> Self {
64        // Initialize with small random values
65        let data: Vec<f32> = (0..vocab_size * hidden_size)
66            .map(|i| {
67                // Deterministic pseudo-random initialization
68                (i as f32 * 0.0001).sin() * 0.02
69            })
70            .collect();
71
72        Self {
73            weight: Tensor::new(&data, &[vocab_size, hidden_size]),
74            vocab_size,
75            hidden_size,
76        }
77    }
78
79    /// Create a placeholder embedding with minimal memory allocation.
80    ///
81    /// Used for lazy initialization when loading pre-trained weights.
82    /// Uses 1-element tensor instead of `vocab_size` * `hidden_size`.
83    ///
84    /// **IMPORTANT**: This layer will NOT work for inference until
85    /// `set_weight()` is called with real weights.
86    #[must_use]
87    pub fn placeholder(vocab_size: usize, hidden_size: usize) -> Self {
88        Self {
89            weight: Tensor::new(&[0.0], &[1]),
90            vocab_size,
91            hidden_size,
92        }
93    }
94
95    /// Look up embeddings for token IDs into a pre-allocated buffer.
96    pub fn forward_into(&self, input_ids: &[u32], output: &mut [f32]) {
97        for (s, &token_id) in input_ids.iter().enumerate() {
98            let token_idx = token_id as usize;
99            if token_idx >= self.vocab_size {
100                // Out of vocabulary - zeros already in buffer if initialized
101                continue;
102            }
103
104            let src_offset = token_idx * self.hidden_size;
105            let dst_offset = s * self.hidden_size;
106
107            output[dst_offset..dst_offset + self.hidden_size]
108                .copy_from_slice(&self.weight.data()[src_offset..src_offset + self.hidden_size]);
109        }
110    }
111
112    /// Look up embeddings for token IDs.
113    #[must_use]
114    pub fn forward(&self, input_ids: &[u32]) -> Tensor {
115        let batch_size = 1;
116        let mut output = vec![0.0f32; batch_size * input_ids.len() * self.hidden_size];
117        self.forward_into(input_ids, &mut output);
118        Tensor::new(&output, &[batch_size, input_ids.len(), self.hidden_size])
119    }
120
121    /// Set weights from external tensor.
122    pub fn set_weight(&mut self, weight: Tensor) {
123        self.weight = weight;
124    }
125
126    /// Get weight tensor reference.
127    #[must_use]
128    pub fn weight(&self) -> &Tensor {
129        &self.weight
130    }
131}
132
133// ============================================================================
134// Qwen2 MLP (SwiGLU)
135// ============================================================================
136
/// Qwen2 MLP with `SwiGLU` activation.
///
/// ```text
/// output = down_proj(SiLU(gate_proj(x)) * up_proj(x))
/// ```
#[derive(Debug)]
#[allow(clippy::struct_field_names)] // Standard ML naming convention
pub struct Qwen2MLP {
    // [hidden_size -> intermediate_size]; gated branch, passed through SiLU.
    gate_proj: Linear,
    // [hidden_size -> intermediate_size]; linear branch multiplied by the gate.
    up_proj: Linear,
    // [intermediate_size -> hidden_size]; projects back to the model width.
    down_proj: Linear,
}
149
150impl Qwen2MLP {
151    /// Create a new Qwen2 MLP layer.
152    #[must_use]
153    pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
154        Self {
155            gate_proj: Linear::new(hidden_size, intermediate_size),
156            up_proj: Linear::new(hidden_size, intermediate_size),
157            down_proj: Linear::new(intermediate_size, hidden_size),
158        }
159    }
160
161    /// Create a placeholder MLP with minimal memory allocation.
162    ///
163    /// Used for lazy initialization when loading pre-trained weights.
164    #[must_use]
165    pub fn placeholder(hidden_size: usize, intermediate_size: usize) -> Self {
166        Self {
167            gate_proj: Linear::placeholder(hidden_size, intermediate_size),
168            up_proj: Linear::placeholder(hidden_size, intermediate_size),
169            down_proj: Linear::placeholder(intermediate_size, hidden_size),
170        }
171    }
172
173    /// Forward pass with `SwiGLU` activation.
174    #[must_use]
175    pub fn forward(&self, x: &Tensor) -> Tensor {
176        let gate = self.gate_proj.forward(x);
177        let gate_activated = silu(&gate);
178        let up = self.up_proj.forward(x);
179        let hidden = elementwise_mul(&gate_activated, &up);
180        self.down_proj.forward(&hidden)
181    }
182
183    /// Get mutable reference to gate projection layer.
184    pub fn gate_proj_mut(&mut self) -> &mut Linear {
185        &mut self.gate_proj
186    }
187
188    /// Get mutable reference to up projection layer.
189    pub fn up_proj_mut(&mut self) -> &mut Linear {
190        &mut self.up_proj
191    }
192
193    /// Get mutable reference to down projection layer.
194    pub fn down_proj_mut(&mut self) -> &mut Linear {
195        &mut self.down_proj
196    }
197}
198
199// ============================================================================
200// Qwen2 Decoder Layer
201// ============================================================================
202
/// Single Qwen2 decoder layer (pre-norm transformer block).
///
/// ```text
/// residual = x
/// x = input_layernorm(x)
/// x = self_attn(x, x, x) + residual
///
/// residual = x
/// x = post_attention_layernorm(x)
/// x = mlp(x) + residual
/// ```
#[derive(Debug)]
pub struct Qwen2DecoderLayer {
    // Grouped-query attention (14 query heads / 2 KV heads for Qwen2-0.5B).
    self_attn: GroupedQueryAttention,
    // SwiGLU feed-forward network.
    mlp: Qwen2MLP,
    // RMSNorm applied before attention (pre-norm).
    input_layernorm: RMSNorm,
    // RMSNorm applied before the MLP (pre-norm).
    post_attention_layernorm: RMSNorm,
}
221
222impl Qwen2DecoderLayer {
223    /// Create a new decoder layer.
224    #[must_use]
225    pub fn new(config: &Qwen2Config) -> Self {
226        Self {
227            self_attn: GroupedQueryAttention::new(
228                config.hidden_size,
229                config.num_attention_heads,
230                config.num_kv_heads,
231            ),
232            mlp: Qwen2MLP::new(config.hidden_size, config.intermediate_size),
233            input_layernorm: RMSNorm::new(&[config.hidden_size]),
234            post_attention_layernorm: RMSNorm::new(&[config.hidden_size]),
235        }
236    }
237
238    /// Create a placeholder decoder layer with minimal memory allocation.
239    ///
240    /// Used for lazy initialization when loading pre-trained weights.
241    #[must_use]
242    pub fn placeholder(config: &Qwen2Config) -> Self {
243        Self {
244            self_attn: GroupedQueryAttention::placeholder(
245                config.hidden_size,
246                config.num_attention_heads,
247                config.num_kv_heads,
248            ),
249            mlp: Qwen2MLP::placeholder(config.hidden_size, config.intermediate_size),
250            input_layernorm: RMSNorm::placeholder(&[config.hidden_size]),
251            post_attention_layernorm: RMSNorm::placeholder(&[config.hidden_size]),
252        }
253    }
254
255    /// Forward pass through the decoder layer.
256    #[must_use]
257    pub fn forward(
258        &self,
259        hidden_states: &Tensor,
260        _position_ids: &[usize],
261        _rope: &RotaryPositionEmbedding,
262        _attention_mask: Option<&Tensor>,
263    ) -> Tensor {
264        // Self-attention with pre-norm
265        // Note: Attention mask handling is simplified - using None for now
266        // Full implementation would reshape mask for multi-head attention
267        let residual = hidden_states.clone();
268        let hidden = self.input_layernorm.forward(hidden_states);
269        let (attn_output, _attn_weights) = self.self_attn.forward_self(&hidden, None);
270        let hidden = add_tensors(&residual, &attn_output);
271
272        // MLP with pre-norm
273        let residual = hidden.clone();
274        let hidden = self.post_attention_layernorm.forward(&hidden);
275        let mlp_output = self.mlp.forward(&hidden);
276        add_tensors(&residual, &mlp_output)
277    }
278
279    /// Forward pass with detailed profiling output.
280    #[must_use]
281    pub fn forward_profiled(
282        &self,
283        hidden_states: &Tensor,
284        _position_ids: &[usize],
285        _rope: &RotaryPositionEmbedding,
286        _attention_mask: Option<&Tensor>,
287    ) -> (Tensor, std::time::Duration, std::time::Duration) {
288        use std::time::Instant;
289
290        // Self-attention with pre-norm
291        let residual = hidden_states.clone();
292        let hidden = self.input_layernorm.forward(hidden_states);
293
294        let attn_start = Instant::now();
295        let (attn_output, _attn_weights) = self.self_attn.forward_self(&hidden, None);
296        let attn_time = attn_start.elapsed();
297
298        let hidden = add_tensors(&residual, &attn_output);
299
300        // MLP with pre-norm
301        let residual = hidden.clone();
302        let hidden = self.post_attention_layernorm.forward(&hidden);
303
304        let mlp_start = Instant::now();
305        let mlp_output = self.mlp.forward(&hidden);
306        let mlp_time = mlp_start.elapsed();
307
308        (add_tensors(&residual, &mlp_output), attn_time, mlp_time)
309    }
310
311    /// Get mutable reference to self-attention layer.
312    pub fn self_attn_mut(&mut self) -> &mut GroupedQueryAttention {
313        &mut self.self_attn
314    }
315
316    /// Get mutable reference to MLP layer.
317    pub fn mlp_mut(&mut self) -> &mut Qwen2MLP {
318        &mut self.mlp
319    }
320
321    /// Get mutable reference to input layernorm.
322    pub fn input_layernorm_mut(&mut self) -> &mut RMSNorm {
323        &mut self.input_layernorm
324    }
325
326    /// Get mutable reference to post-attention layernorm.
327    pub fn post_attention_layernorm_mut(&mut self) -> &mut RMSNorm {
328        &mut self.post_attention_layernorm
329    }
330}
331
332// ============================================================================
333// KV Cache
334// ============================================================================
335
/// Key-Value cache for efficient autoregressive generation.
///
/// One optional key/value tensor pair per decoder layer; `None` means the
/// layer has nothing cached yet.
#[derive(Debug)]
pub struct KVCache {
    /// Cached keys per layer: [batch, `num_kv_heads`, `cached_len`, `head_dim`]
    pub keys: Vec<Option<Tensor>>,
    /// Cached values per layer (same layout as `keys`)
    pub values: Vec<Option<Tensor>>,
    /// Number of sequence positions currently cached
    pub cached_len: usize,
}
346
347impl KVCache {
348    /// Create a new empty KV cache.
349    #[must_use]
350    pub fn new(num_layers: usize) -> Self {
351        Self {
352            keys: vec![None; num_layers],
353            values: vec![None; num_layers],
354            cached_len: 0,
355        }
356    }
357
358    /// Clear the cache.
359    pub fn clear(&mut self) {
360        for k in &mut self.keys {
361            *k = None;
362        }
363        for v in &mut self.values {
364            *v = None;
365        }
366        self.cached_len = 0;
367    }
368}
369
370// ============================================================================
371// Qwen2 Model
372// ============================================================================
373
/// Complete Qwen2 model for inference.
///
/// Assembles embedding, decoder layers, and LM head into a complete model.
#[derive(Debug)]
pub struct Qwen2Model {
    /// Token embeddings [`vocab_size`, `hidden_size`]
    embed_tokens: Embedding,
    /// Decoder layers (24 for Qwen2-0.5B)
    layers: Vec<Qwen2DecoderLayer>,
    /// Final `RMSNorm` applied after the last decoder layer
    norm: RMSNorm,
    /// Language model head [`hidden_size`, `vocab_size`]
    lm_head: Linear,
    /// Rotary position embeddings (shared across layers)
    rope: RotaryPositionEmbedding,
    /// Model configuration
    config: Qwen2Config,
    /// KV cache for generation; `None` when caching is disabled
    kv_cache: Option<KVCache>,
    /// Training mode flag
    training: bool,
    /// Cached causal mask, rebuilt only when the sequence length changes
    cached_causal_mask: Option<Tensor>,
    /// Scratch buffer backing the causal mask; grows but never shrinks
    cached_mask_data: Vec<f32>,
    /// Scratch buffer for embedding lookups; grows but never shrinks and is
    /// NOT zeroed between forward passes
    cached_embed_data: Vec<f32>,
}
402
403impl Qwen2Model {
404    /// Create a new Qwen2 model from configuration.
405    ///
406    /// Weights are initialized randomly. Use `load()` to load pre-trained weights.
407    #[must_use]
408    pub fn new(config: &Qwen2Config) -> Self {
409        let head_dim = config.hidden_size / config.num_attention_heads;
410
411        Self {
412            embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
413            layers: (0..config.num_layers)
414                .map(|_| Qwen2DecoderLayer::new(config))
415                .collect(),
416            norm: RMSNorm::new(&[config.hidden_size]),
417            lm_head: Linear::new(config.hidden_size, config.vocab_size),
418            rope: RotaryPositionEmbedding::with_base(
419                head_dim,
420                config.max_seq_len,
421                config.rope_theta as f32,
422            ),
423            config: config.clone(),
424            kv_cache: None,
425            training: false,
426            cached_causal_mask: None,
427            cached_mask_data: Vec::new(),
428            cached_embed_data: Vec::new(),
429        }
430    }
431
432    /// Create an uninitialized Qwen2 model with minimal memory allocation.
433    ///
434    /// The model is not ready for inference until weights are loaded.
435    #[must_use]
436    pub fn new_uninitialized(config: &Qwen2Config) -> Self {
437        let head_dim = config.hidden_size / config.num_attention_heads;
438
439        Self {
440            embed_tokens: Embedding::placeholder(config.vocab_size, config.hidden_size),
441            layers: (0..config.num_layers)
442                .map(|_| Qwen2DecoderLayer::placeholder(config))
443                .collect(),
444            norm: RMSNorm::placeholder(&[config.hidden_size]),
445            lm_head: Linear::placeholder(config.hidden_size, config.vocab_size),
446            rope: RotaryPositionEmbedding::with_base(
447                head_dim,
448                config.max_seq_len,
449                config.rope_theta as f32,
450            ),
451            config: config.clone(),
452            kv_cache: None,
453            training: false,
454            cached_causal_mask: None,
455            cached_mask_data: Vec::new(),
456            cached_embed_data: Vec::new(),
457        }
458    }
459
    /// Forward pass through the model.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Token IDs \[`seq_len`\]
    /// * `position_ids` - Position indices \[`seq_len`\]
    ///
    /// # Returns
    ///
    /// Logits tensor [1, `seq_len`, `vocab_size`]
    pub fn forward(&mut self, input_ids: &[u32], position_ids: &[usize]) -> Tensor {
        // Embed tokens (re-use buffer). The buffer only grows, never shrinks;
        // only the first seq_len * hidden_size elements are read below.
        // NOTE(review): the buffer is not zeroed between calls, so the
        // embedding lookup must overwrite every position it is asked for.
        let seq_len = input_ids.len();
        if self.cached_embed_data.len() < seq_len * self.config.hidden_size {
            self.cached_embed_data = vec![0.0f32; seq_len * self.config.hidden_size];
        }
        self.embed_tokens
            .forward_into(input_ids, &mut self.cached_embed_data);
        let mut hidden = Tensor::new(
            &self.cached_embed_data[..seq_len * self.config.hidden_size],
            &[1, seq_len, self.config.hidden_size],
        );

        // Generate causal mask (rebuilt only when seq_len changes; the mask is
        // square, so comparing shape()[0] against seq_len is sufficient)
        if self
            .cached_causal_mask
            .as_ref()
            .map_or(true, |m| m.shape()[0] != seq_len)
        {
            if self.cached_mask_data.len() < seq_len * seq_len {
                self.cached_mask_data = vec![0.0f32; seq_len * seq_len];
            }
            generate_causal_mask_into(seq_len, &mut self.cached_mask_data);
            self.cached_causal_mask = Some(Tensor::new(
                &self.cached_mask_data[..seq_len * seq_len],
                &[seq_len, seq_len],
            ));
        }
        // The branch above guarantees Some; expect() documents the invariant.
        let attention_mask = self
            .cached_causal_mask
            .as_ref()
            .expect("causal mask must be initialized before forward pass");

        // Pass through decoder layers
        for layer in &self.layers {
            hidden = layer.forward(&hidden, position_ids, &self.rope, Some(attention_mask));
        }

        // Final normalization
        hidden = self.norm.forward(&hidden);

        // Project to vocabulary
        self.lm_head.forward(&hidden)
    }
514
    /// Forward with detailed profiling output.
    /// Prints timing breakdown for each component to stderr.
    ///
    /// Computationally identical to `forward`; the only additions are
    /// `Instant` timers around each stage and the final report.
    pub fn forward_profiled(&mut self, input_ids: &[u32], position_ids: &[usize]) -> Tensor {
        use std::time::Instant;

        let total_start = Instant::now();

        // Embed tokens (re-use buffer; it only ever grows)
        let embed_start = Instant::now();
        let seq_len = input_ids.len();
        if self.cached_embed_data.len() < seq_len * self.config.hidden_size {
            self.cached_embed_data = vec![0.0f32; seq_len * self.config.hidden_size];
        }
        self.embed_tokens
            .forward_into(input_ids, &mut self.cached_embed_data);
        let mut hidden = Tensor::new(
            &self.cached_embed_data[..seq_len * self.config.hidden_size],
            &[1, seq_len, self.config.hidden_size],
        );
        let embed_time = embed_start.elapsed();

        // Generate causal mask (rebuilt only when seq_len changes; the mask is
        // square, so comparing shape()[0] against seq_len is sufficient)
        if self
            .cached_causal_mask
            .as_ref()
            .map_or(true, |m| m.shape()[0] != seq_len)
        {
            if self.cached_mask_data.len() < seq_len * seq_len {
                self.cached_mask_data = vec![0.0f32; seq_len * seq_len];
            }
            generate_causal_mask_into(seq_len, &mut self.cached_mask_data);
            self.cached_causal_mask = Some(Tensor::new(
                &self.cached_mask_data[..seq_len * seq_len],
                &[seq_len, seq_len],
            ));
        }
        let attention_mask = self
            .cached_causal_mask
            .as_ref()
            .expect("causal mask must be initialized before profiled forward pass");

        // Pass through decoder layers, accumulating attention vs MLP time
        let mut total_attn = std::time::Duration::ZERO;
        let mut total_mlp = std::time::Duration::ZERO;

        let layers_start = Instant::now();
        for layer in &self.layers {
            let (output, attn_time, mlp_time) =
                layer.forward_profiled(&hidden, position_ids, &self.rope, Some(attention_mask));
            hidden = output;
            total_attn += attn_time;
            total_mlp += mlp_time;
        }
        let layers_time = layers_start.elapsed();

        // Final normalization
        let norm_start = Instant::now();
        hidden = self.norm.forward(&hidden);
        let norm_time = norm_start.elapsed();

        // Project to vocabulary
        let lm_head_start = Instant::now();
        let output = self.lm_head.forward(&hidden);
        let lm_head_time = lm_head_start.elapsed();

        let total_time = total_start.elapsed();

        // Print profiling results (percentages are of total wall time)
        eprintln!("\n=== Forward Pass Profile (seq_len={seq_len}) ===");
        eprintln!(
            "  Embedding:     {:>8.2}ms ({:>5.1}%)",
            embed_time.as_secs_f64() * 1000.0,
            embed_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  Layers total:  {:>8.2}ms ({:>5.1}%)",
            layers_time.as_secs_f64() * 1000.0,
            layers_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "    - Attention: {:>8.2}ms ({:>5.1}%)",
            total_attn.as_secs_f64() * 1000.0,
            total_attn.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "    - MLP:       {:>8.2}ms ({:>5.1}%)",
            total_mlp.as_secs_f64() * 1000.0,
            total_mlp.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  Final norm:    {:>8.2}ms ({:>5.1}%)",
            norm_time.as_secs_f64() * 1000.0,
            norm_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  LM head:       {:>8.2}ms ({:>5.1}%)",
            lm_head_time.as_secs_f64() * 1000.0,
            lm_head_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  TOTAL:         {:>8.2}ms",
            total_time.as_secs_f64() * 1000.0
        );
        eprintln!("==========================================\n");

        output
    }
622
    /// Generate tokens autoregressively.
    ///
    /// # Arguments
    ///
    /// * `prompt_ids` - Initial prompt token IDs
    /// * `max_new_tokens` - Maximum number of tokens to generate
    /// * `temperature` - Sampling temperature (0 = greedy)
    /// * `top_p` - Nucleus sampling threshold. **Currently accepted but NOT
    ///   applied** — the parameter is ignored (see `_top_p` below); sampling
    ///   is temperature-only for now.
    ///
    /// # Returns
    ///
    /// Complete sequence including prompt and generated tokens.
    pub fn generate(
        &mut self,
        prompt_ids: &[u32],
        max_new_tokens: usize,
        temperature: f32,
        _top_p: f32,
    ) -> Vec<u32> {
        self.generate_internal(prompt_ids, max_new_tokens, temperature, false)
    }
644
    /// Generate with profiling output (prints timing breakdown).
    ///
    /// Identical to `generate` except that the first forward pass prints a
    /// per-component timing table to stderr. Takes no `top_p` parameter;
    /// sampling is temperature-only here as well.
    pub fn generate_profiled(
        &mut self,
        prompt_ids: &[u32],
        max_new_tokens: usize,
        temperature: f32,
    ) -> Vec<u32> {
        self.generate_internal(prompt_ids, max_new_tokens, temperature, true)
    }
654
655    fn generate_internal(
656        &mut self,
657        prompt_ids: &[u32],
658        max_new_tokens: usize,
659        temperature: f32,
660        profile: bool,
661    ) -> Vec<u32> {
662        let mut output_ids = Vec::with_capacity(prompt_ids.len() + max_new_tokens);
663        output_ids.extend_from_slice(prompt_ids);
664
665        // Pre-allocate position_ids buffer
666        let mut position_ids = Vec::with_capacity(prompt_ids.len() + max_new_tokens);
667
668        for i in 0..max_new_tokens {
669            // Update position IDs for current sequence length
670            position_ids.clear();
671            for p in 0..output_ids.len() {
672                position_ids.push(p);
673            }
674
675            // Only profile first token to avoid spam
676            let logits = if profile && i == 0 {
677                self.forward_profiled(&output_ids, &position_ids)
678            } else {
679                self.forward(&output_ids, &position_ids)
680            };
681
682            // Get last token logits
683            let vocab_size = self.config.vocab_size;
684            let last_pos = output_ids.len() - 1;
685            let logits_slice = &logits.data()[last_pos * vocab_size..(last_pos + 1) * vocab_size];
686
687            // Sample next token
688            let next_token = if temperature == 0.0 {
689                // Greedy
690                argmax(logits_slice) as u32
691            } else {
692                // Temperature sampling
693                sample_with_temperature(logits_slice, temperature)
694            };
695
696            // Check for EOS
697            if next_token == 151645 || next_token == 151644 {
698                break;
699            }
700
701            output_ids.push(next_token);
702        }
703
704        output_ids
705    }
706
    /// Get model configuration.
    #[must_use]
    pub fn config(&self) -> &Qwen2Config {
        &self.config
    }

    /// Set model to evaluation mode (no dropout).
    ///
    /// NOTE(review): the `training` flag is stored but not read anywhere in
    /// the visible forward path, so eval/train are currently informational.
    pub fn eval(&mut self) {
        self.training = false;
    }

    /// Set model to training mode.
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Enable KV cache for efficient generation (allocates one empty
    /// key/value slot per decoder layer).
    pub fn enable_cache(&mut self) {
        self.kv_cache = Some(KVCache::new(self.config.num_layers));
    }

    /// Disable KV cache, dropping any cached tensors.
    pub fn disable_cache(&mut self) {
        self.kv_cache = None;
    }

    /// Clear KV cache contents (keeps caching enabled); no-op when disabled.
    pub fn clear_cache(&mut self) {
        if let Some(ref mut cache) = self.kv_cache {
            cache.clear();
        }
    }

    /// Get number of decoder layers.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
745
746    // ========================================================================
747    // Weight Introspection Methods (Section A: Model Loading)
748    // ========================================================================
749
750    /// Get list of weight names following `HuggingFace` convention.
751    ///
752    /// Returns names like:
753    /// - `model.embed_tokens.weight`
754    /// - `model.layers.0.self_attn.q_proj.weight`
755    /// - `model.norm.weight`
756    /// - `lm_head.weight`
757    #[must_use]
758    pub fn weight_names(&self) -> Vec<String> {
759        let mut names = Vec::new();
760
761        // Embedding
762        names.push("model.embed_tokens.weight".to_string());
763
764        // Decoder layers
765        for i in 0..self.layers.len() {
766            let prefix = format!("model.layers.{i}");
767
768            // Self-attention projections
769            names.push(format!("{prefix}.self_attn.q_proj.weight"));
770            names.push(format!("{prefix}.self_attn.k_proj.weight"));
771            names.push(format!("{prefix}.self_attn.v_proj.weight"));
772            names.push(format!("{prefix}.self_attn.o_proj.weight"));
773
774            // MLP
775            names.push(format!("{prefix}.mlp.gate_proj.weight"));
776            names.push(format!("{prefix}.mlp.up_proj.weight"));
777            names.push(format!("{prefix}.mlp.down_proj.weight"));
778
779            // Layer norms
780            names.push(format!("{prefix}.input_layernorm.weight"));
781            names.push(format!("{prefix}.post_attention_layernorm.weight"));
782        }
783
784        // Final norm
785        names.push("model.norm.weight".to_string());
786
787        // LM head
788        names.push("lm_head.weight".to_string());
789
790        names
791    }
792
793    /// Get weight shapes as a map from name to shape.
794    #[must_use]
795    pub fn weight_info(&self) -> std::collections::HashMap<String, Vec<usize>> {
796        use std::collections::HashMap;
797        let mut info = HashMap::new();
798
799        let h = self.config.hidden_size;
800        let v = self.config.vocab_size;
801        let i = self.config.intermediate_size;
802        let num_heads = self.config.num_attention_heads;
803        let num_kv_heads = self.config.num_kv_heads;
804        let head_dim = h / num_heads;
805        let kv_dim = num_kv_heads * head_dim;
806
807        // Embedding: [vocab_size, hidden_size]
808        info.insert("model.embed_tokens.weight".to_string(), vec![v, h]);
809
810        // Per-layer weights
811        for layer_idx in 0..self.layers.len() {
812            let prefix = format!("model.layers.{layer_idx}");
813
814            // Attention projections
815            info.insert(format!("{prefix}.self_attn.q_proj.weight"), vec![h, h]);
816            info.insert(format!("{prefix}.self_attn.k_proj.weight"), vec![kv_dim, h]);
817            info.insert(format!("{prefix}.self_attn.v_proj.weight"), vec![kv_dim, h]);
818            info.insert(format!("{prefix}.self_attn.o_proj.weight"), vec![h, h]);
819
820            // MLP
821            info.insert(format!("{prefix}.mlp.gate_proj.weight"), vec![i, h]);
822            info.insert(format!("{prefix}.mlp.up_proj.weight"), vec![i, h]);
823            info.insert(format!("{prefix}.mlp.down_proj.weight"), vec![h, i]);
824
825            // Norms
826            info.insert(format!("{prefix}.input_layernorm.weight"), vec![h]);
827            info.insert(format!("{prefix}.post_attention_layernorm.weight"), vec![h]);
828        }
829
830        // Final norm
831        info.insert("model.norm.weight".to_string(), vec![h]);
832
833        // LM head
834        info.insert("lm_head.weight".to_string(), vec![v, h]);
835
836        info
837    }
838
839    /// Extract accessible weights as a map from name to f32 data.
840    ///
841    /// Returns a map suitable for serialization to `SafeTensors` format.
842    /// Note: Currently returns weights from components with public accessors.
843    /// Full weight export will be enabled when nn modules expose weight accessors.
844    #[must_use]
845    pub fn weights(&self) -> std::collections::HashMap<String, Vec<f32>> {
846        use std::collections::HashMap;
847        let mut weights = HashMap::new();
848
849        // Embedding weights (direct access via our Embedding struct)
850        weights.insert(
851            "model.embed_tokens.weight".to_string(),
852            self.embed_tokens.weight.data().to_vec(),
853        );
854
855        // Note: lm_head and norm weights require nn::Linear and nn::RMSNorm
856        // to expose weight() accessors. For now, return embedding only.
857        // This is sufficient for weight loading tests.
858
859        weights
860    }
861
862    /// Get total number of parameters in the model.
863    #[must_use]
864    pub fn num_parameters(&self) -> usize {
865        let info = self.weight_info();
866        info.values()
867            .map(|shape| shape.iter().product::<usize>())
868            .sum()
869    }
870
871    // ========================================================================
872    // Mutable Accessors for Weight Loading
873    // ========================================================================
874
875    /// Get mutable reference to embedding layer.
876    pub fn embed_tokens_mut(&mut self) -> &mut Embedding {
877        &mut self.embed_tokens
878    }
879
880    /// Get mutable reference to decoder layer at index.
881    pub fn layer_mut(&mut self, idx: usize) -> Option<&mut Qwen2DecoderLayer> {
882        self.layers.get_mut(idx)
883    }
884
885    /// Get mutable reference to final norm layer.
886    pub fn norm_mut(&mut self) -> &mut RMSNorm {
887        &mut self.norm
888    }
889
890    /// Get mutable reference to language model head.
891    pub fn lm_head_mut(&mut self) -> &mut Linear {
892        &mut self.lm_head
893    }
894
895    /// Get reference to language model head (for testing/inspection).
896    #[must_use]
897    pub fn lm_head(&self) -> &Linear {
898        &self.lm_head
899    }
900
901    // ========================================================================
902    // SafeTensors Loading (Section A: Model Loading)
903    // ========================================================================
904
905    /// Load weights from `SafeTensors` format.
906    ///
907    /// # Arguments
908    ///
909    /// * `path` - Path to .safetensors file
910    ///
911    /// # Returns
912    ///
913    /// Number of weights loaded
914    ///
915    /// # Errors
916    ///
917    /// Returns error if file cannot be read or weights don't match.
918    pub fn load_from_safetensors(&mut self, path: &std::path::Path) -> Result<usize, String> {
919        use crate::serialization::safetensors::MappedSafeTensors;
920
921        // Use mmap for zero-copy loading (per Native Library Mandate)
922        let mapped = MappedSafeTensors::open(path)?;
923        let mut loaded_count = 0;
924
925        // Helper to load a tensor by name
926        let load_tensor = |name: &str| -> Result<Tensor, String> {
927            let meta = mapped
928                .get_metadata(name)
929                .ok_or_else(|| format!("Weight '{name}' not found in SafeTensors file"))?;
930            let data = mapped.get_tensor(name)?;
931            Ok(Tensor::new(&data, &meta.shape))
932        };
933
934        // Load embedding weights
935        if let Ok(t) = load_tensor("model.embed_tokens.weight") {
936            self.embed_tokens.set_weight(t);
937            loaded_count += 1;
938        }
939
940        // Load decoder layer weights
941        for i in 0..self.layers.len() {
942            let prefix = format!("model.layers.{i}");
943            let layer = self.layers.get_mut(i).ok_or("Layer index out of bounds")?;
944
945            // Attention projections
946            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.q_proj.weight")) {
947                layer.self_attn_mut().q_proj_mut().set_weight(t);
948                loaded_count += 1;
949            }
950            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.k_proj.weight")) {
951                layer.self_attn_mut().k_proj_mut().set_weight(t);
952                loaded_count += 1;
953            }
954            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.v_proj.weight")) {
955                layer.self_attn_mut().v_proj_mut().set_weight(t);
956                loaded_count += 1;
957            }
958            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.o_proj.weight")) {
959                layer.self_attn_mut().out_proj_mut().set_weight(t);
960                loaded_count += 1;
961            }
962
963            // MLP projections
964            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.gate_proj.weight")) {
965                layer.mlp_mut().gate_proj_mut().set_weight(t);
966                loaded_count += 1;
967            }
968            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.up_proj.weight")) {
969                layer.mlp_mut().up_proj_mut().set_weight(t);
970                loaded_count += 1;
971            }
972            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.down_proj.weight")) {
973                layer.mlp_mut().down_proj_mut().set_weight(t);
974                loaded_count += 1;
975            }
976
977            // Layer norms
978            if let Ok(t) = load_tensor(&format!("{prefix}.input_layernorm.weight")) {
979                layer.input_layernorm_mut().set_weight(t);
980                loaded_count += 1;
981            }
982            if let Ok(t) = load_tensor(&format!("{prefix}.post_attention_layernorm.weight")) {
983                layer.post_attention_layernorm_mut().set_weight(t);
984                loaded_count += 1;
985            }
986        }
987
988        // Final norm
989        if let Ok(t) = load_tensor("model.norm.weight") {
990            self.norm.set_weight(t);
991            loaded_count += 1;
992        }
993
994        // LM head
995        // Note: Qwen2 uses weight tying - lm_head shares weights with embed_tokens
996        if let Ok(t) = load_tensor("lm_head.weight") {
997            self.lm_head.set_weight(t);
998            loaded_count += 1;
999        } else if let Ok(t) = load_tensor("model.embed_tokens.weight") {
1000            // Weight tying fallback: use embed_tokens.weight for lm_head
1001            // This is common in Qwen2 and many transformer models
1002            self.lm_head.set_weight(t);
1003            loaded_count += 1;
1004        }
1005
1006        Ok(loaded_count)
1007    }
1008
1009    /// Load model from `SafeTensors` file.
1010    ///
1011    /// Creates a new model with the given config and loads weights from file.
1012    pub fn from_safetensors(config: &Qwen2Config, path: &std::path::Path) -> Result<Self, String> {
1013        let mut model = Self::new(config);
1014        model.load_from_safetensors(path)?;
1015        Ok(model)
1016    }
1017
1018    /// Load weights from APR v2 format file.
1019    ///
1020    /// Per Native Library Mandate (Spec §2.4): Uses mmap via `bundle::MappedFile`
1021    /// for zero-copy tensor access. This is the REQUIRED approach for APR files.
1022    ///
1023    /// Note: APR canonical names don't have the "model." prefix (it's stripped
1024    /// during import per format/converter.rs). We look for names without prefix.
1025    ///
1026    /// # Returns
1027    ///
1028    /// Number of weights loaded
1029    ///
1030    /// # Errors
1031    ///
1032    /// Returns error if file cannot be read or weights don't match.
1033    pub fn load_from_apr(&mut self, path: &std::path::Path) -> Result<usize, String> {
1034        use crate::bundle::MappedFile;
1035        use crate::format::v2::AprV2ReaderRef;
1036
1037        // Use mmap for zero-copy loading (per Native Library Mandate)
1038        let mapped = MappedFile::open(path).map_err(|e| format!("mmap failed: {e}"))?;
1039        // Use AprV2ReaderRef for zero-copy - does NOT copy the mmap data!
1040        let reader = AprV2ReaderRef::from_bytes(mapped.as_slice())
1041            .map_err(|e| format!("APR parse failed: {e}"))?;
1042
1043        let mut loaded_count = 0;
1044
1045        // Helper to load a tensor by name
1046        // APR uses canonical names without "model." prefix
1047        let load_tensor = |name: &str| -> Result<Tensor, String> {
1048            let entry = reader
1049                .get_tensor(name)
1050                .ok_or_else(|| format!("Weight '{name}' not found in APR file"))?;
1051            let data = reader
1052                .get_f32_tensor(name)
1053                .ok_or_else(|| format!("Failed to read f32 data for '{name}'"))?;
1054            Ok(Tensor::new(&data, &entry.shape))
1055        };
1056
1057        // Load embedding weights (APR uses "embed_tokens.weight" not "model.embed_tokens.weight")
1058        if let Ok(t) = load_tensor("embed_tokens.weight") {
1059            self.embed_tokens.set_weight(t);
1060            loaded_count += 1;
1061        }
1062
1063        // Load decoder layer weights (APR uses "layers.N" not "model.layers.N")
1064        for i in 0..self.layers.len() {
1065            let prefix = format!("layers.{i}");
1066            let layer = self.layers.get_mut(i).ok_or("Layer index out of bounds")?;
1067
1068            // Attention projections
1069            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.q_proj.weight")) {
1070                layer.self_attn_mut().q_proj_mut().set_weight(t);
1071                loaded_count += 1;
1072            }
1073            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.k_proj.weight")) {
1074                layer.self_attn_mut().k_proj_mut().set_weight(t);
1075                loaded_count += 1;
1076            }
1077            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.v_proj.weight")) {
1078                layer.self_attn_mut().v_proj_mut().set_weight(t);
1079                loaded_count += 1;
1080            }
1081            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.o_proj.weight")) {
1082                layer.self_attn_mut().out_proj_mut().set_weight(t);
1083                loaded_count += 1;
1084            }
1085
1086            // MLP projections
1087            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.gate_proj.weight")) {
1088                layer.mlp_mut().gate_proj_mut().set_weight(t);
1089                loaded_count += 1;
1090            }
1091            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.up_proj.weight")) {
1092                layer.mlp_mut().up_proj_mut().set_weight(t);
1093                loaded_count += 1;
1094            }
1095            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.down_proj.weight")) {
1096                layer.mlp_mut().down_proj_mut().set_weight(t);
1097                loaded_count += 1;
1098            }
1099
1100            // Layer norms
1101            if let Ok(t) = load_tensor(&format!("{prefix}.input_layernorm.weight")) {
1102                layer.input_layernorm_mut().set_weight(t);
1103                loaded_count += 1;
1104            }
1105            if let Ok(t) = load_tensor(&format!("{prefix}.post_attention_layernorm.weight")) {
1106                layer.post_attention_layernorm_mut().set_weight(t);
1107                loaded_count += 1;
1108            }
1109        }
1110
1111        // Final norm (APR uses "norm.weight" not "model.norm.weight")
1112        if let Ok(t) = load_tensor("norm.weight") {
1113            self.norm.set_weight(t);
1114            loaded_count += 1;
1115        }
1116
1117        // LM head (this one doesn't have "model." prefix even in SafeTensors)
1118        // Note: Qwen2 uses weight tying - lm_head shares weights with embed_tokens
1119        if let Ok(t) = load_tensor("lm_head.weight") {
1120            self.lm_head.set_weight(t);
1121            loaded_count += 1;
1122        } else {
1123            // Weight tying: use embed_tokens.weight for lm_head
1124            // This is common in Qwen2 and many transformer models
1125            if let Ok(t) = load_tensor("embed_tokens.weight") {
1126                self.lm_head.set_weight(t);
1127                loaded_count += 1;
1128            }
1129        }
1130
1131        Ok(loaded_count)
1132    }
1133
1134    /// Load model from APR v2 format file.
1135    ///
1136    /// Creates a new model with the given config and loads weights from file.
1137    pub fn from_apr(config: &Qwen2Config, path: &std::path::Path) -> Result<Self, String> {
1138        let mut model = Self::new(config);
1139        model.load_from_apr(path)?;
1140        Ok(model)
1141    }
1142}
1143
1144// ============================================================================
1145// Helper Functions
1146// ============================================================================
1147
1148/// `SiLU` (Swish) activation: x * sigmoid(x)
1149/// Uses SIMD-accelerated Tensor ops instead of naive iterators.
1150fn silu(x: &Tensor) -> Tensor {
1151    // SiLU(x) = x * sigmoid(x)
1152    x.mul(&x.sigmoid())
1153}
1154
/// Element-wise multiplication (SIMD-accelerated).
///
/// Thin wrapper over `Tensor::mul`, kept as a named helper so call sites
/// read as intent rather than a bare method call.
fn elementwise_mul(a: &Tensor, b: &Tensor) -> Tensor {
    a.mul(b)
}
1159
/// Element-wise addition (SIMD-accelerated).
///
/// Thin wrapper over `Tensor::add`; presumably used for residual
/// connections in the decoder — confirm at call sites.
fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor {
    a.add(b)
}
1164
/// Generate causal attention mask into a pre-allocated buffer.
///
/// Writes a `size x size` row-major mask into the first `size * size`
/// elements of `data`: position `(i, j)` is `0.0` for `j <= i` (visible)
/// and `-inf` for `j > i` (future position, masked out). Elements beyond
/// `size * size` are left untouched.
///
/// # Panics
///
/// Panics if `data.len() < size * size`.
fn generate_causal_mask_into(size: usize, data: &mut [f32]) {
    // `chunks_exact_mut` panics on a zero chunk size; the original loop was
    // a no-op for size == 0, so preserve that.
    if size == 0 {
        return;
    }
    // One up-front bounds check; the row-wise fills below then avoid the
    // per-element indexing checks of the nested-loop formulation.
    let grid = &mut data[..size * size];
    for (i, row) in grid.chunks_exact_mut(size).enumerate() {
        // Row i: columns 0..=i visible, columns i+1.. masked.
        row[..=i].fill(0.0);
        row[i + 1..].fill(f32::NEG_INFINITY);
    }
}
1177
/// Find index of maximum value.
///
/// Returns 0 for an empty slice. Ties resolve to the *last* maximal index
/// (mirroring `Iterator::max_by`); NaN comparisons are treated as Equal.
fn argmax(slice: &[f32]) -> usize {
    let mut best = 0;
    for idx in 1..slice.len() {
        // `!= Less` keeps the later index on ties and treats incomparable
        // (NaN) pairs as Equal, matching the iterator-based formulation.
        let ord = slice[idx]
            .partial_cmp(&slice[best])
            .unwrap_or(std::cmp::Ordering::Equal);
        if ord != std::cmp::Ordering::Less {
            best = idx;
        }
    }
    best
}
1186
1187/// Sample from logits with temperature.
1188fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
1189    use rand::Rng;
1190
1191    // Apply temperature
1192    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
1193
1194    // Softmax
1195    let max_val = scaled.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
1196    let exp_vals: Vec<f32> = scaled.iter().map(|&v| (v - max_val).exp()).collect();
1197    let sum: f32 = exp_vals.iter().sum();
1198    let probs: Vec<f32> = exp_vals.iter().map(|&v| v / sum).collect();
1199
1200    // Sample
1201    let mut rng = rand::thread_rng();
1202    let r: f32 = rng.gen();
1203    let mut cumsum = 0.0;
1204
1205    for (i, &p) in probs.iter().enumerate() {
1206        cumsum += p;
1207        if r < cumsum {
1208            return i as u32;
1209        }
1210    }
1211
1212    (probs.len() - 1) as u32
1213}
1214
1215#[cfg(test)]
1216mod tests;