1use crate::autograd::Tensor;
42use crate::demo::Qwen2Config;
43use crate::nn::{GroupedQueryAttention, Linear, Module, RMSNorm, RotaryPositionEmbedding};
44
/// Token-embedding lookup table mapping token ids to `hidden_size`-dim rows.
#[derive(Debug)]
pub struct Embedding {
    // Embedding matrix with shape `[vocab_size, hidden_size]`.
    weight: Tensor,
    // Number of rows; valid token ids are `0..vocab_size`.
    vocab_size: usize,
    // Width of each embedding row.
    hidden_size: usize,
}
59
60impl Embedding {
61 #[must_use]
63 pub fn new(vocab_size: usize, hidden_size: usize) -> Self {
64 let data: Vec<f32> = (0..vocab_size * hidden_size)
66 .map(|i| {
67 (i as f32 * 0.0001).sin() * 0.02
69 })
70 .collect();
71
72 Self {
73 weight: Tensor::new(&data, &[vocab_size, hidden_size]),
74 vocab_size,
75 hidden_size,
76 }
77 }
78
79 #[must_use]
87 pub fn placeholder(vocab_size: usize, hidden_size: usize) -> Self {
88 Self {
89 weight: Tensor::new(&[0.0], &[1]),
90 vocab_size,
91 hidden_size,
92 }
93 }
94
95 pub fn forward_into(&self, input_ids: &[u32], output: &mut [f32]) {
97 for (s, &token_id) in input_ids.iter().enumerate() {
98 let token_idx = token_id as usize;
99 if token_idx >= self.vocab_size {
100 continue;
102 }
103
104 let src_offset = token_idx * self.hidden_size;
105 let dst_offset = s * self.hidden_size;
106
107 output[dst_offset..dst_offset + self.hidden_size]
108 .copy_from_slice(&self.weight.data()[src_offset..src_offset + self.hidden_size]);
109 }
110 }
111
112 #[must_use]
114 pub fn forward(&self, input_ids: &[u32]) -> Tensor {
115 let batch_size = 1;
116 let mut output = vec![0.0f32; batch_size * input_ids.len() * self.hidden_size];
117 self.forward_into(input_ids, &mut output);
118 Tensor::new(&output, &[batch_size, input_ids.len(), self.hidden_size])
119 }
120
121 pub fn set_weight(&mut self, weight: Tensor) {
123 self.weight = weight;
124 }
125
126 #[must_use]
128 pub fn weight(&self) -> &Tensor {
129 &self.weight
130 }
131}
132
/// SwiGLU feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
#[derive(Debug)]
#[allow(clippy::struct_field_names)] pub struct Qwen2MLP {
    // hidden -> intermediate; output is SiLU-activated (the "gate").
    gate_proj: Linear,
    // hidden -> intermediate; multiplied elementwise with the gate.
    up_proj: Linear,
    // intermediate -> hidden.
    down_proj: Linear,
}
149
150impl Qwen2MLP {
151 #[must_use]
153 pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
154 Self {
155 gate_proj: Linear::new(hidden_size, intermediate_size),
156 up_proj: Linear::new(hidden_size, intermediate_size),
157 down_proj: Linear::new(intermediate_size, hidden_size),
158 }
159 }
160
161 #[must_use]
165 pub fn placeholder(hidden_size: usize, intermediate_size: usize) -> Self {
166 Self {
167 gate_proj: Linear::placeholder(hidden_size, intermediate_size),
168 up_proj: Linear::placeholder(hidden_size, intermediate_size),
169 down_proj: Linear::placeholder(intermediate_size, hidden_size),
170 }
171 }
172
173 #[must_use]
175 pub fn forward(&self, x: &Tensor) -> Tensor {
176 let gate = self.gate_proj.forward(x);
177 let gate_activated = silu(&gate);
178 let up = self.up_proj.forward(x);
179 let hidden = elementwise_mul(&gate_activated, &up);
180 self.down_proj.forward(&hidden)
181 }
182
183 pub fn gate_proj_mut(&mut self) -> &mut Linear {
185 &mut self.gate_proj
186 }
187
188 pub fn up_proj_mut(&mut self) -> &mut Linear {
190 &mut self.up_proj
191 }
192
193 pub fn down_proj_mut(&mut self) -> &mut Linear {
195 &mut self.down_proj
196 }
197}
198
/// One pre-norm transformer block: GQA attention + SwiGLU MLP,
/// each wrapped in RMSNorm and a residual connection.
#[derive(Debug)]
pub struct Qwen2DecoderLayer {
    // Grouped-query self-attention.
    self_attn: GroupedQueryAttention,
    // SwiGLU feed-forward block.
    mlp: Qwen2MLP,
    // Norm applied before attention.
    input_layernorm: RMSNorm,
    // Norm applied before the MLP.
    post_attention_layernorm: RMSNorm,
}
221
impl Qwen2DecoderLayer {
    /// Builds a layer with default-initialized weights from `config`.
    #[must_use]
    pub fn new(config: &Qwen2Config) -> Self {
        Self {
            self_attn: GroupedQueryAttention::new(
                config.hidden_size,
                config.num_attention_heads,
                config.num_kv_heads,
            ),
            mlp: Qwen2MLP::new(config.hidden_size, config.intermediate_size),
            input_layernorm: RMSNorm::new(&[config.hidden_size]),
            post_attention_layernorm: RMSNorm::new(&[config.hidden_size]),
        }
    }

    /// Builds a layer with placeholder weights; load real weights before use.
    #[must_use]
    pub fn placeholder(config: &Qwen2Config) -> Self {
        Self {
            self_attn: GroupedQueryAttention::placeholder(
                config.hidden_size,
                config.num_attention_heads,
                config.num_kv_heads,
            ),
            mlp: Qwen2MLP::placeholder(config.hidden_size, config.intermediate_size),
            input_layernorm: RMSNorm::placeholder(&[config.hidden_size]),
            post_attention_layernorm: RMSNorm::placeholder(&[config.hidden_size]),
        }
    }

    /// Pre-norm transformer step: attention then MLP, each with a residual add.
    ///
    /// NOTE(review): `_position_ids`, `_rope`, and `_attention_mask` are
    /// accepted but ignored — `forward_self` is called with `None`, so RoPE
    /// and causal masking presumably happen inside `GroupedQueryAttention`;
    /// confirm before relying on the mask argument passed by the model.
    #[must_use]
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        _position_ids: &[usize],
        _rope: &RotaryPositionEmbedding,
        _attention_mask: Option<&Tensor>,
    ) -> Tensor {
        // Attention sub-block: x + Attn(RMSNorm(x)).
        let residual = hidden_states.clone();
        let hidden = self.input_layernorm.forward(hidden_states);
        let (attn_output, _attn_weights) = self.self_attn.forward_self(&hidden, None);
        let hidden = add_tensors(&residual, &attn_output);

        // MLP sub-block: x + MLP(RMSNorm(x)).
        let residual = hidden.clone();
        let hidden = self.post_attention_layernorm.forward(&hidden);
        let mlp_output = self.mlp.forward(&hidden);
        add_tensors(&residual, &mlp_output)
    }

    /// Like [`Self::forward`] but also returns wall-clock durations of the
    /// attention and MLP sub-blocks `(output, attn_time, mlp_time)`.
    #[must_use]
    pub fn forward_profiled(
        &self,
        hidden_states: &Tensor,
        _position_ids: &[usize],
        _rope: &RotaryPositionEmbedding,
        _attention_mask: Option<&Tensor>,
    ) -> (Tensor, std::time::Duration, std::time::Duration) {
        use std::time::Instant;

        let residual = hidden_states.clone();
        let hidden = self.input_layernorm.forward(hidden_states);

        // Time only the attention call, not the norm or residual add.
        let attn_start = Instant::now();
        let (attn_output, _attn_weights) = self.self_attn.forward_self(&hidden, None);
        let attn_time = attn_start.elapsed();

        let hidden = add_tensors(&residual, &attn_output);

        let residual = hidden.clone();
        let hidden = self.post_attention_layernorm.forward(&hidden);

        // Time only the MLP call.
        let mlp_start = Instant::now();
        let mlp_output = self.mlp.forward(&hidden);
        let mlp_time = mlp_start.elapsed();

        (add_tensors(&residual, &mlp_output), attn_time, mlp_time)
    }

    /// Mutable access to the attention block, for weight loading.
    pub fn self_attn_mut(&mut self) -> &mut GroupedQueryAttention {
        &mut self.self_attn
    }

    /// Mutable access to the MLP block, for weight loading.
    pub fn mlp_mut(&mut self) -> &mut Qwen2MLP {
        &mut self.mlp
    }

    /// Mutable access to the pre-attention norm, for weight loading.
    pub fn input_layernorm_mut(&mut self) -> &mut RMSNorm {
        &mut self.input_layernorm
    }

    /// Mutable access to the pre-MLP norm, for weight loading.
    pub fn post_attention_layernorm_mut(&mut self) -> &mut RMSNorm {
        &mut self.post_attention_layernorm
    }
}
331
/// Per-layer key/value tensor cache for incremental decoding.
///
/// NOTE(review): this cache is allocated by `Qwen2Model::enable_cache` but
/// never consulted by the forward passes visible here — generation re-runs
/// the full sequence each step; confirm intended usage.
#[derive(Debug)]
pub struct KVCache {
    /// Cached key tensor per layer (`None` until populated).
    pub keys: Vec<Option<Tensor>>,
    /// Cached value tensor per layer (`None` until populated).
    pub values: Vec<Option<Tensor>>,
    /// Number of sequence positions currently cached.
    pub cached_len: usize,
}
346
347impl KVCache {
348 #[must_use]
350 pub fn new(num_layers: usize) -> Self {
351 Self {
352 keys: vec![None; num_layers],
353 values: vec![None; num_layers],
354 cached_len: 0,
355 }
356 }
357
358 pub fn clear(&mut self) {
360 for k in &mut self.keys {
361 *k = None;
362 }
363 for v in &mut self.values {
364 *v = None;
365 }
366 self.cached_len = 0;
367 }
368}
369
/// Full Qwen2 decoder-only language model:
/// embedding -> N decoder layers -> final RMSNorm -> LM head.
#[derive(Debug)]
pub struct Qwen2Model {
    // Token embedding table.
    embed_tokens: Embedding,
    // Stack of transformer decoder layers.
    layers: Vec<Qwen2DecoderLayer>,
    // Final norm before the LM head.
    norm: RMSNorm,
    // hidden -> vocab projection producing logits.
    lm_head: Linear,
    // Rotary position embedding shared by all layers.
    rope: RotaryPositionEmbedding,
    // Configuration snapshot used to size everything above.
    config: Qwen2Config,
    // Optional KV cache (allocated by `enable_cache`).
    kv_cache: Option<KVCache>,
    // Training-mode flag (toggled by `train`/`eval`).
    training: bool,
    // Causal mask memoized per sequence length.
    cached_causal_mask: Option<Tensor>,
    // Backing buffer for the causal mask; grows monotonically.
    cached_mask_data: Vec<f32>,
    // Reusable buffer for embedding output; grows monotonically.
    cached_embed_data: Vec<f32>,
}
402
403impl Qwen2Model {
    /// Builds a model with default-initialized weights from `config`.
    #[must_use]
    pub fn new(config: &Qwen2Config) -> Self {
        // Per-head dimension; RoPE operates on head_dim-sized slices.
        let head_dim = config.hidden_size / config.num_attention_heads;

        Self {
            embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
            layers: (0..config.num_layers)
                .map(|_| Qwen2DecoderLayer::new(config))
                .collect(),
            norm: RMSNorm::new(&[config.hidden_size]),
            lm_head: Linear::new(config.hidden_size, config.vocab_size),
            rope: RotaryPositionEmbedding::with_base(
                head_dim,
                config.max_seq_len,
                // rope_theta is stored wider in the config; RoPE wants f32.
                config.rope_theta as f32,
            ),
            config: config.clone(),
            kv_cache: None,
            training: false,
            cached_causal_mask: None,
            cached_mask_data: Vec::new(),
            cached_embed_data: Vec::new(),
        }
    }
431
    /// Builds a model with placeholder weights, avoiding the cost of full
    /// initialization; intended to be followed by `load_from_safetensors`
    /// or `load_from_apr`.
    #[must_use]
    pub fn new_uninitialized(config: &Qwen2Config) -> Self {
        let head_dim = config.hidden_size / config.num_attention_heads;

        Self {
            embed_tokens: Embedding::placeholder(config.vocab_size, config.hidden_size),
            layers: (0..config.num_layers)
                .map(|_| Qwen2DecoderLayer::placeholder(config))
                .collect(),
            norm: RMSNorm::placeholder(&[config.hidden_size]),
            lm_head: Linear::placeholder(config.hidden_size, config.vocab_size),
            // RoPE tables are cheap and shape-only, so they are built for real.
            rope: RotaryPositionEmbedding::with_base(
                head_dim,
                config.max_seq_len,
                config.rope_theta as f32,
            ),
            config: config.clone(),
            kv_cache: None,
            training: false,
            cached_causal_mask: None,
            cached_mask_data: Vec::new(),
            cached_embed_data: Vec::new(),
        }
    }
459
    /// Runs a full forward pass over `input_ids` (implicit batch size 1) and
    /// returns the LM-head logits for every position.
    ///
    /// Takes `&mut self` only to reuse the internal embedding and mask
    /// buffers across calls; model weights are not modified.
    pub fn forward(&mut self, input_ids: &[u32], position_ids: &[usize]) -> Tensor {
        let seq_len = input_ids.len();
        // Grow (never shrink) the reusable embedding buffer.
        if self.cached_embed_data.len() < seq_len * self.config.hidden_size {
            self.cached_embed_data = vec![0.0f32; seq_len * self.config.hidden_size];
        }
        self.embed_tokens
            .forward_into(input_ids, &mut self.cached_embed_data);
        let mut hidden = Tensor::new(
            &self.cached_embed_data[..seq_len * self.config.hidden_size],
            &[1, seq_len, self.config.hidden_size],
        );

        // Rebuild the causal mask only when seq_len changed (mask is square,
        // so checking shape[0] is sufficient).
        if self
            .cached_causal_mask
            .as_ref()
            .map_or(true, |m| m.shape()[0] != seq_len)
        {
            if self.cached_mask_data.len() < seq_len * seq_len {
                self.cached_mask_data = vec![0.0f32; seq_len * seq_len];
            }
            generate_causal_mask_into(seq_len, &mut self.cached_mask_data);
            self.cached_causal_mask = Some(Tensor::new(
                &self.cached_mask_data[..seq_len * seq_len],
                &[seq_len, seq_len],
            ));
        }
        // Safe: the branch above guarantees the mask is Some.
        let attention_mask = self
            .cached_causal_mask
            .as_ref()
            .expect("causal mask must be initialized before forward pass");

        for layer in &self.layers {
            hidden = layer.forward(&hidden, position_ids, &self.rope, Some(attention_mask));
        }

        hidden = self.norm.forward(&hidden);

        self.lm_head.forward(&hidden)
    }
514
    /// Same as [`Self::forward`] but prints a per-stage wall-clock timing
    /// breakdown (embedding, layers, attention vs. MLP, norm, LM head) to
    /// stderr. Used by `generate_profiled` for the first decoding step.
    pub fn forward_profiled(&mut self, input_ids: &[u32], position_ids: &[usize]) -> Tensor {
        use std::time::Instant;

        let total_start = Instant::now();

        // --- Embedding stage (buffer reuse mirrors `forward`) ---
        let embed_start = Instant::now();
        let seq_len = input_ids.len();
        if self.cached_embed_data.len() < seq_len * self.config.hidden_size {
            self.cached_embed_data = vec![0.0f32; seq_len * self.config.hidden_size];
        }
        self.embed_tokens
            .forward_into(input_ids, &mut self.cached_embed_data);
        let mut hidden = Tensor::new(
            &self.cached_embed_data[..seq_len * self.config.hidden_size],
            &[1, seq_len, self.config.hidden_size],
        );
        let embed_time = embed_start.elapsed();

        // --- Causal mask (rebuilt only when seq_len changed) ---
        if self
            .cached_causal_mask
            .as_ref()
            .map_or(true, |m| m.shape()[0] != seq_len)
        {
            if self.cached_mask_data.len() < seq_len * seq_len {
                self.cached_mask_data = vec![0.0f32; seq_len * seq_len];
            }
            generate_causal_mask_into(seq_len, &mut self.cached_mask_data);
            self.cached_causal_mask = Some(Tensor::new(
                &self.cached_mask_data[..seq_len * seq_len],
                &[seq_len, seq_len],
            ));
        }
        let attention_mask = self
            .cached_causal_mask
            .as_ref()
            .expect("causal mask must be initialized before profiled forward pass");

        // --- Decoder layers, accumulating per-sub-block timings ---
        let mut total_attn = std::time::Duration::ZERO;
        let mut total_mlp = std::time::Duration::ZERO;

        let layers_start = Instant::now();
        for layer in &self.layers {
            let (output, attn_time, mlp_time) =
                layer.forward_profiled(&hidden, position_ids, &self.rope, Some(attention_mask));
            hidden = output;
            total_attn += attn_time;
            total_mlp += mlp_time;
        }
        let layers_time = layers_start.elapsed();

        // --- Final norm ---
        let norm_start = Instant::now();
        hidden = self.norm.forward(&hidden);
        let norm_time = norm_start.elapsed();

        // --- LM head ---
        let lm_head_start = Instant::now();
        let output = self.lm_head.forward(&hidden);
        let lm_head_time = lm_head_start.elapsed();

        let total_time = total_start.elapsed();

        // Report each stage in ms and as a share of the total.
        eprintln!("\n=== Forward Pass Profile (seq_len={seq_len}) ===");
        eprintln!(
            "  Embedding:        {:>8.2}ms ({:>5.1}%)",
            embed_time.as_secs_f64() * 1000.0,
            embed_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  Layers total:     {:>8.2}ms ({:>5.1}%)",
            layers_time.as_secs_f64() * 1000.0,
            layers_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "    - Attention:    {:>8.2}ms ({:>5.1}%)",
            total_attn.as_secs_f64() * 1000.0,
            total_attn.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "    - MLP:          {:>8.2}ms ({:>5.1}%)",
            total_mlp.as_secs_f64() * 1000.0,
            total_mlp.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  Final norm:       {:>8.2}ms ({:>5.1}%)",
            norm_time.as_secs_f64() * 1000.0,
            norm_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  LM head:          {:>8.2}ms ({:>5.1}%)",
            lm_head_time.as_secs_f64() * 1000.0,
            lm_head_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
        );
        eprintln!(
            "  TOTAL:            {:>8.2}ms",
            total_time.as_secs_f64() * 1000.0
        );
        eprintln!("==========================================\n");

        output
    }
622
    /// Autoregressively generates up to `max_new_tokens` after `prompt_ids`.
    ///
    /// `temperature == 0.0` means greedy decoding; otherwise temperature
    /// sampling. NOTE(review): `_top_p` is accepted but ignored — nucleus
    /// sampling is not implemented in `generate_internal`.
    pub fn generate(
        &mut self,
        prompt_ids: &[u32],
        max_new_tokens: usize,
        temperature: f32,
        _top_p: f32,
    ) -> Vec<u32> {
        self.generate_internal(prompt_ids, max_new_tokens, temperature, false)
    }
644
    /// Like [`Self::generate`] but profiles the first (full-prompt) forward
    /// pass, printing a stage breakdown to stderr.
    pub fn generate_profiled(
        &mut self,
        prompt_ids: &[u32],
        max_new_tokens: usize,
        temperature: f32,
    ) -> Vec<u32> {
        self.generate_internal(prompt_ids, max_new_tokens, temperature, true)
    }
654
655 fn generate_internal(
656 &mut self,
657 prompt_ids: &[u32],
658 max_new_tokens: usize,
659 temperature: f32,
660 profile: bool,
661 ) -> Vec<u32> {
662 let mut output_ids = Vec::with_capacity(prompt_ids.len() + max_new_tokens);
663 output_ids.extend_from_slice(prompt_ids);
664
665 let mut position_ids = Vec::with_capacity(prompt_ids.len() + max_new_tokens);
667
668 for i in 0..max_new_tokens {
669 position_ids.clear();
671 for p in 0..output_ids.len() {
672 position_ids.push(p);
673 }
674
675 let logits = if profile && i == 0 {
677 self.forward_profiled(&output_ids, &position_ids)
678 } else {
679 self.forward(&output_ids, &position_ids)
680 };
681
682 let vocab_size = self.config.vocab_size;
684 let last_pos = output_ids.len() - 1;
685 let logits_slice = &logits.data()[last_pos * vocab_size..(last_pos + 1) * vocab_size];
686
687 let next_token = if temperature == 0.0 {
689 argmax(logits_slice) as u32
691 } else {
692 sample_with_temperature(logits_slice, temperature)
694 };
695
696 if next_token == 151645 || next_token == 151644 {
698 break;
699 }
700
701 output_ids.push(next_token);
702 }
703
704 output_ids
705 }
706
    /// Borrows the model configuration.
    #[must_use]
    pub fn config(&self) -> &Qwen2Config {
        &self.config
    }
712
    /// Switches the model to inference mode (clears the `training` flag).
    pub fn eval(&mut self) {
        self.training = false;
    }
717
    /// Switches the model to training mode (sets the `training` flag).
    pub fn train(&mut self) {
        self.training = true;
    }
722
    /// Allocates a fresh KV cache sized for every layer, discarding any
    /// previous cache contents.
    pub fn enable_cache(&mut self) {
        self.kv_cache = Some(KVCache::new(self.config.num_layers));
    }
727
    /// Drops the KV cache entirely.
    pub fn disable_cache(&mut self) {
        self.kv_cache = None;
    }
732
    /// Empties the KV cache in place if one is enabled; no-op otherwise.
    pub fn clear_cache(&mut self) {
        if let Some(ref mut cache) = self.kv_cache {
            cache.clear();
        }
    }
739
    /// Number of decoder layers in this model instance.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
745
746 #[must_use]
758 pub fn weight_names(&self) -> Vec<String> {
759 let mut names = Vec::new();
760
761 names.push("model.embed_tokens.weight".to_string());
763
764 for i in 0..self.layers.len() {
766 let prefix = format!("model.layers.{i}");
767
768 names.push(format!("{prefix}.self_attn.q_proj.weight"));
770 names.push(format!("{prefix}.self_attn.k_proj.weight"));
771 names.push(format!("{prefix}.self_attn.v_proj.weight"));
772 names.push(format!("{prefix}.self_attn.o_proj.weight"));
773
774 names.push(format!("{prefix}.mlp.gate_proj.weight"));
776 names.push(format!("{prefix}.mlp.up_proj.weight"));
777 names.push(format!("{prefix}.mlp.down_proj.weight"));
778
779 names.push(format!("{prefix}.input_layernorm.weight"));
781 names.push(format!("{prefix}.post_attention_layernorm.weight"));
782 }
783
784 names.push("model.norm.weight".to_string());
786
787 names.push("lm_head.weight".to_string());
789
790 names
791 }
792
    /// Returns the expected shape `[rows, cols]` (or `[len]` for norms) of
    /// every weight, keyed by checkpoint name. Shapes are derived from the
    /// config, not read from the actual tensors.
    #[must_use]
    pub fn weight_info(&self) -> std::collections::HashMap<String, Vec<usize>> {
        use std::collections::HashMap;
        let mut info = HashMap::new();

        let h = self.config.hidden_size;
        let v = self.config.vocab_size;
        let i = self.config.intermediate_size;
        let num_heads = self.config.num_attention_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = h / num_heads;
        // K/V projections are narrower under grouped-query attention.
        let kv_dim = num_kv_heads * head_dim;

        info.insert("model.embed_tokens.weight".to_string(), vec![v, h]);

        for layer_idx in 0..self.layers.len() {
            let prefix = format!("model.layers.{layer_idx}");

            info.insert(format!("{prefix}.self_attn.q_proj.weight"), vec![h, h]);
            info.insert(format!("{prefix}.self_attn.k_proj.weight"), vec![kv_dim, h]);
            info.insert(format!("{prefix}.self_attn.v_proj.weight"), vec![kv_dim, h]);
            info.insert(format!("{prefix}.self_attn.o_proj.weight"), vec![h, h]);

            info.insert(format!("{prefix}.mlp.gate_proj.weight"), vec![i, h]);
            info.insert(format!("{prefix}.mlp.up_proj.weight"), vec![i, h]);
            info.insert(format!("{prefix}.mlp.down_proj.weight"), vec![h, i]);

            info.insert(format!("{prefix}.input_layernorm.weight"), vec![h]);
            info.insert(format!("{prefix}.post_attention_layernorm.weight"), vec![h]);
        }

        info.insert("model.norm.weight".to_string(), vec![h]);

        info.insert("lm_head.weight".to_string(), vec![v, h]);

        info
    }
838
    /// Exports weight data keyed by checkpoint name.
    ///
    /// NOTE(review): currently only the embedding matrix is exported; layer,
    /// norm, and LM-head weights are missing from the returned map. Compare
    /// with `weight_names`/`weight_info`, which list the full set.
    #[must_use]
    pub fn weights(&self) -> std::collections::HashMap<String, Vec<f32>> {
        use std::collections::HashMap;
        let mut weights = HashMap::new();

        weights.insert(
            "model.embed_tokens.weight".to_string(),
            self.embed_tokens.weight.data().to_vec(),
        );

        weights
    }
861
862 #[must_use]
864 pub fn num_parameters(&self) -> usize {
865 let info = self.weight_info();
866 info.values()
867 .map(|shape| shape.iter().product::<usize>())
868 .sum()
869 }
870
    /// Mutable access to the token embedding, for weight loading.
    pub fn embed_tokens_mut(&mut self) -> &mut Embedding {
        &mut self.embed_tokens
    }
879
    /// Mutable access to decoder layer `idx`, or `None` if out of range.
    pub fn layer_mut(&mut self, idx: usize) -> Option<&mut Qwen2DecoderLayer> {
        self.layers.get_mut(idx)
    }
884
    /// Mutable access to the final norm, for weight loading.
    pub fn norm_mut(&mut self) -> &mut RMSNorm {
        &mut self.norm
    }
889
    /// Mutable access to the LM head, for weight loading.
    pub fn lm_head_mut(&mut self) -> &mut Linear {
        &mut self.lm_head
    }
894
    /// Borrows the LM head.
    #[must_use]
    pub fn lm_head(&self) -> &Linear {
        &self.lm_head
    }
900
    /// Loads weights from a SafeTensors file, replacing every tensor found.
    ///
    /// Missing tensors are skipped silently (each load is `if let Ok`), so
    /// partially populated files still load. If `lm_head.weight` is absent
    /// the embedding matrix is reused instead (weight tying). Returns the
    /// number of tensors actually loaded.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened/parsed or a layer
    /// index is out of bounds.
    pub fn load_from_safetensors(&mut self, path: &std::path::Path) -> Result<usize, String> {
        use crate::serialization::safetensors::MappedSafeTensors;

        let mapped = MappedSafeTensors::open(path)?;
        let mut loaded_count = 0;

        // Fetches one named tensor with its declared shape.
        let load_tensor = |name: &str| -> Result<Tensor, String> {
            let meta = mapped
                .get_metadata(name)
                .ok_or_else(|| format!("Weight '{name}' not found in SafeTensors file"))?;
            let data = mapped.get_tensor(name)?;
            Ok(Tensor::new(&data, &meta.shape))
        };

        if let Ok(t) = load_tensor("model.embed_tokens.weight") {
            self.embed_tokens.set_weight(t);
            loaded_count += 1;
        }

        for i in 0..self.layers.len() {
            let prefix = format!("model.layers.{i}");
            let layer = self.layers.get_mut(i).ok_or("Layer index out of bounds")?;

            // Attention projections.
            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.q_proj.weight")) {
                layer.self_attn_mut().q_proj_mut().set_weight(t);
                loaded_count += 1;
            }
            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.k_proj.weight")) {
                layer.self_attn_mut().k_proj_mut().set_weight(t);
                loaded_count += 1;
            }
            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.v_proj.weight")) {
                layer.self_attn_mut().v_proj_mut().set_weight(t);
                loaded_count += 1;
            }
            if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.o_proj.weight")) {
                layer.self_attn_mut().out_proj_mut().set_weight(t);
                loaded_count += 1;
            }

            // MLP projections.
            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.gate_proj.weight")) {
                layer.mlp_mut().gate_proj_mut().set_weight(t);
                loaded_count += 1;
            }
            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.up_proj.weight")) {
                layer.mlp_mut().up_proj_mut().set_weight(t);
                loaded_count += 1;
            }
            if let Ok(t) = load_tensor(&format!("{prefix}.mlp.down_proj.weight")) {
                layer.mlp_mut().down_proj_mut().set_weight(t);
                loaded_count += 1;
            }

            // Norms.
            if let Ok(t) = load_tensor(&format!("{prefix}.input_layernorm.weight")) {
                layer.input_layernorm_mut().set_weight(t);
                loaded_count += 1;
            }
            if let Ok(t) = load_tensor(&format!("{prefix}.post_attention_layernorm.weight")) {
                layer.post_attention_layernorm_mut().set_weight(t);
                loaded_count += 1;
            }
        }

        if let Ok(t) = load_tensor("model.norm.weight") {
            self.norm.set_weight(t);
            loaded_count += 1;
        }

        // Prefer a dedicated LM head; otherwise tie to the embedding matrix.
        if let Ok(t) = load_tensor("lm_head.weight") {
            self.lm_head.set_weight(t);
            loaded_count += 1;
        } else if let Ok(t) = load_tensor("model.embed_tokens.weight") {
            self.lm_head.set_weight(t);
            loaded_count += 1;
        }

        Ok(loaded_count)
    }
1008
    /// Builds a model from `config` and immediately loads SafeTensors weights.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::load_from_safetensors`].
    pub fn from_safetensors(config: &Qwen2Config, path: &std::path::Path) -> Result<Self, String> {
        let mut model = Self::new(config);
        model.load_from_safetensors(path)?;
        Ok(model)
    }
1017
1018 pub fn load_from_apr(&mut self, path: &std::path::Path) -> Result<usize, String> {
1034 use crate::bundle::MappedFile;
1035 use crate::format::v2::AprV2ReaderRef;
1036
1037 let mapped = MappedFile::open(path).map_err(|e| format!("mmap failed: {e}"))?;
1039 let reader = AprV2ReaderRef::from_bytes(mapped.as_slice())
1041 .map_err(|e| format!("APR parse failed: {e}"))?;
1042
1043 let mut loaded_count = 0;
1044
1045 let load_tensor = |name: &str| -> Result<Tensor, String> {
1048 let entry = reader
1049 .get_tensor(name)
1050 .ok_or_else(|| format!("Weight '{name}' not found in APR file"))?;
1051 let data = reader
1052 .get_f32_tensor(name)
1053 .ok_or_else(|| format!("Failed to read f32 data for '{name}'"))?;
1054 Ok(Tensor::new(&data, &entry.shape))
1055 };
1056
1057 if let Ok(t) = load_tensor("embed_tokens.weight") {
1059 self.embed_tokens.set_weight(t);
1060 loaded_count += 1;
1061 }
1062
1063 for i in 0..self.layers.len() {
1065 let prefix = format!("layers.{i}");
1066 let layer = self.layers.get_mut(i).ok_or("Layer index out of bounds")?;
1067
1068 if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.q_proj.weight")) {
1070 layer.self_attn_mut().q_proj_mut().set_weight(t);
1071 loaded_count += 1;
1072 }
1073 if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.k_proj.weight")) {
1074 layer.self_attn_mut().k_proj_mut().set_weight(t);
1075 loaded_count += 1;
1076 }
1077 if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.v_proj.weight")) {
1078 layer.self_attn_mut().v_proj_mut().set_weight(t);
1079 loaded_count += 1;
1080 }
1081 if let Ok(t) = load_tensor(&format!("{prefix}.self_attn.o_proj.weight")) {
1082 layer.self_attn_mut().out_proj_mut().set_weight(t);
1083 loaded_count += 1;
1084 }
1085
1086 if let Ok(t) = load_tensor(&format!("{prefix}.mlp.gate_proj.weight")) {
1088 layer.mlp_mut().gate_proj_mut().set_weight(t);
1089 loaded_count += 1;
1090 }
1091 if let Ok(t) = load_tensor(&format!("{prefix}.mlp.up_proj.weight")) {
1092 layer.mlp_mut().up_proj_mut().set_weight(t);
1093 loaded_count += 1;
1094 }
1095 if let Ok(t) = load_tensor(&format!("{prefix}.mlp.down_proj.weight")) {
1096 layer.mlp_mut().down_proj_mut().set_weight(t);
1097 loaded_count += 1;
1098 }
1099
1100 if let Ok(t) = load_tensor(&format!("{prefix}.input_layernorm.weight")) {
1102 layer.input_layernorm_mut().set_weight(t);
1103 loaded_count += 1;
1104 }
1105 if let Ok(t) = load_tensor(&format!("{prefix}.post_attention_layernorm.weight")) {
1106 layer.post_attention_layernorm_mut().set_weight(t);
1107 loaded_count += 1;
1108 }
1109 }
1110
1111 if let Ok(t) = load_tensor("norm.weight") {
1113 self.norm.set_weight(t);
1114 loaded_count += 1;
1115 }
1116
1117 if let Ok(t) = load_tensor("lm_head.weight") {
1120 self.lm_head.set_weight(t);
1121 loaded_count += 1;
1122 } else {
1123 if let Ok(t) = load_tensor("embed_tokens.weight") {
1126 self.lm_head.set_weight(t);
1127 loaded_count += 1;
1128 }
1129 }
1130
1131 Ok(loaded_count)
1132 }
1133
    /// Builds a model from `config` and immediately loads APR weights.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::load_from_apr`].
    pub fn from_apr(config: &Qwen2Config, path: &std::path::Path) -> Result<Self, String> {
        let mut model = Self::new(config);
        model.load_from_apr(path)?;
        Ok(model)
    }
1142}
1143
/// SiLU (swish) activation: `x * sigmoid(x)`, applied elementwise.
fn silu(x: &Tensor) -> Tensor {
    x.mul(&x.sigmoid())
}
1154
/// Elementwise product of two tensors (thin wrapper over `Tensor::mul`).
fn elementwise_mul(a: &Tensor, b: &Tensor) -> Tensor {
    a.mul(b)
}
1159
/// Elementwise sum of two tensors (thin wrapper over `Tensor::add`),
/// used for the residual connections.
fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor {
    a.add(b)
}
1164
/// Fills the first `size * size` entries of `data` (row-major) with a causal
/// mask: `-inf` above the diagonal (future positions), `0.0` on and below it.
fn generate_causal_mask_into(size: usize, data: &mut [f32]) {
    for (row, cells) in data.chunks_mut(size).enumerate().take(size) {
        for (col, cell) in cells.iter_mut().enumerate().take(size) {
            *cell = if col > row { f32::NEG_INFINITY } else { 0.0 };
        }
    }
}
1177
/// Index of the maximum value; ties resolve to the LAST maximal element
/// (matching `Iterator::max_by`). NaN compares as equal to everything here,
/// so an all-NaN slice yields the last index. Empty slices return 0.
fn argmax(slice: &[f32]) -> usize {
    let mut best = 0usize;
    for (i, v) in slice.iter().enumerate().skip(1) {
        let ordering = slice[best]
            .partial_cmp(v)
            .unwrap_or(std::cmp::Ordering::Equal);
        // `!= Greater` keeps the last element among equals, like max_by.
        if ordering != std::cmp::Ordering::Greater {
            best = i;
        }
    }
    best
}
1186
/// Samples a token index from temperature-scaled softmax of `logits`.
///
/// Uses the max-subtraction trick for numerical stability, then draws from
/// the resulting categorical distribution via inverse-CDF sampling. Falls
/// back to the last index if rounding keeps `cumsum` below `r`.
///
/// Caller must ensure `temperature != 0.0` (division by temperature);
/// `generate_internal` routes temperature 0 to `argmax` instead.
fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
    use rand::Rng;

    // Scale logits: lower temperature sharpens, higher flattens.
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();

    // Numerically stable softmax.
    let max_val = scaled.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_vals: Vec<f32> = scaled.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f32 = exp_vals.iter().sum();
    let probs: Vec<f32> = exp_vals.iter().map(|&v| v / sum).collect();

    // Inverse-CDF draw.
    let mut rng = rand::thread_rng();
    let r: f32 = rng.gen();
    let mut cumsum = 0.0;

    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if r < cumsum {
            return i as u32;
        }
    }

    // Floating-point slack: default to the final index.
    (probs.len() - 1) as u32
}
1214
1215#[cfg(test)]
1216mod tests;