Skip to main content

llama_rs/model/
loader.rs

1//! Model loader for GGUF files
2//!
3//! This module provides functionality to load model weights from GGUF files
4//! and construct model instances.
5
6use std::path::Path;
7
8use crate::gguf::{GgufFile, MetadataValue};
9use crate::tensor::{DType, Tensor};
10
11use super::Architecture;
12use super::config::{ActivationType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
13use super::deltanet::{BetaAlphaProjection, DeltaNetConfig, DeltaNetLayer};
14use super::mamba::{MambaConfig, MambaLayer};
15use super::error::{ModelError, ModelResult};
16use super::layers::{
17    Attention, AttentionLayer, FeedForward, FfnLayer, LayerNorm, Linear, NormLayer,
18    NoGateFeedForward, RMSNorm, TransformerLayer,
19};
20use super::bert::{BertLayer, BertModel};
21use super::llama::LlamaModel;
22use super::moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter};
23
24/// Trait for model weight sources (GGUF, ONNX, SafeTensors).
25///
26/// Implementors provide tensor loading and config access.
27/// The shared model assembly functions use this trait to load weights
28/// from any supported format.
29pub trait ModelSource {
30    /// Get parsed model configuration.
31    fn config(&self) -> &ModelConfig;
32
33    /// Get mutable reference to model configuration.
34    fn config_mut(&mut self) -> &mut ModelConfig;
35
36    /// Get detected architecture.
37    fn architecture(&self) -> Architecture;
38
39    /// Load a tensor by internal name (e.g., "blk.0.attn_q.weight").
40    fn load_tensor(&self, name: &str) -> ModelResult<Tensor>;
41
42    /// Try to load a tensor (None if not found).
43    fn try_load_tensor(&self, name: &str) -> Option<Tensor>;
44}
45
46/// Model loader for GGUF files
47pub struct ModelLoader {
48    /// Loaded GGUF file
49    gguf: GgufFile,
50    /// Detected architecture
51    architecture: Architecture,
52    /// Parsed model configuration
53    config: ModelConfig,
54}
55
56impl ModelLoader {
57    /// Load a model from a GGUF file path
58    pub fn load<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
59        let gguf = GgufFile::open(path)?;
60
61        // Detect architecture
62        let arch_str = gguf
63            .data
64            .get_string("general.architecture")
65            .ok_or_else(|| ModelError::MissingMetadata("general.architecture".into()))?;
66
67        let architecture = Architecture::from_gguf_str(arch_str);
68
69        if matches!(architecture, Architecture::Unknown) {
70            return Err(ModelError::UnsupportedArchitecture(arch_str.to_string()));
71        }
72
73        // Parse configuration from metadata
74        let config = Self::parse_config(&gguf, &architecture)?;
75
76        Ok(Self {
77            gguf,
78            architecture,
79            config,
80        })
81    }
82
83    /// Parse model configuration from GGUF metadata
84    fn parse_config(gguf: &GgufFile, architecture: &Architecture) -> ModelResult<ModelConfig> {
85        let arch = architecture.as_str();
86
87        // Helper to get u32 metadata
88        let get_u32 = |key: &str| -> ModelResult<u32> {
89            gguf.data
90                .get_u32(key)
91                .ok_or_else(|| ModelError::MissingMetadata(key.into()))
92        };
93
94        // Helper to get f32 metadata with default
95        let get_f32_or =
96            |key: &str, default: f32| -> f32 { gguf.data.get_f32(key).unwrap_or(default) };
97
98        // Get core configuration
99        // Try multiple methods to determine vocab size
100        let vocab_size = get_u32(&format!("{}.vocab_size", arch))
101            .or_else(|_| get_u32("tokenizer.ggml.vocab_size"))
102            .map(|v| v as usize)
103            .unwrap_or_else(|_| {
104                // Fallback: get vocab size from tokenizer tokens array length
105                if let Some(tokens) = gguf.data.metadata.get("tokenizer.ggml.tokens")
106                    && let MetadataValue::Array(arr) = tokens
107                {
108                    return arr.values.len();
109                }
110                // Last resort: infer from embedding tensor shape
111                if let Some(emb_info) = gguf.data.get_tensor("token_embd.weight") {
112                    // Shape is [hidden_size, vocab_size] in llama.cpp convention
113                    if emb_info.dims.len() == 2 {
114                        return emb_info.dims[1] as usize;
115                    }
116                }
117                // Default
118                32000
119            });
120
121        let hidden_size = get_u32(&format!("{}.embedding_length", arch))? as usize;
122
123        let num_layers = get_u32(&format!("{}.block_count", arch))? as usize;
124
125        // Mamba/Mamba2 have no attention heads; use SSM params or defaults
126        let (num_heads, num_kv_heads, head_dim) =
127            if matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
128                let nh = get_u32(&format!("{}.attention.head_count", arch)).unwrap_or(1) as usize;
129                let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
130                    .unwrap_or(nh as u32) as usize;
131                let hd = get_u32(&format!("{}.attention.key_length", arch))
132                    .unwrap_or_else(|_| (hidden_size / nh.max(1)) as u32) as usize;
133                (nh, nkv, hd)
134            } else {
135                let nh = get_u32(&format!("{}.attention.head_count", arch))? as usize;
136                let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
137                    .unwrap_or(nh as u32) as usize;
138                let hd = get_u32(&format!("{}.attention.key_length", arch))
139                    .map(|v| v as usize)
140                    .unwrap_or(hidden_size / nh);
141                (nh, nkv, hd)
142            };
143
144        let intermediate_size = get_u32(&format!("{}.feed_forward_length", arch))
145            .unwrap_or_else(|_| {
146                if matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
147                    hidden_size as u32 // Pure Mamba may have no FFN
148                } else {
149                    (hidden_size * 4 * 2 / 3) as u32
150                }
151            }) as usize;
152
153        let max_seq_len = get_u32(&format!("{}.context_length", arch)).unwrap_or(2048) as usize;
154
155        let norm_eps = gguf
156            .data
157            .get_f32(&format!("{}.attention.layer_norm_rms_epsilon", arch))
158            .or_else(|| gguf.data.get_f32(&format!("{}.attention.layer_norm_epsilon", arch)))
159            .unwrap_or(1e-5);
160
161        // Parse RoPE configuration
162        let freq_base = get_f32_or(&format!("{}.rope.freq_base", arch), 10000.0);
163        let freq_scale = get_f32_or(&format!("{}.rope.scale_linear", arch), 1.0);
164
165        // Determine RoPE type based on architecture
166        // NeoX-style RoPE pairs (x[i], x[i+d/2]) — most modern architectures.
167        // Normal/LLaMA-style pairs consecutive elements (x[2i], x[2i+1]).
168        let rope_type = match architecture {
169            Architecture::Qwen2
170            | Architecture::Qwen2Moe
171            | Architecture::Qwen3
172            | Architecture::Qwen35
173            | Architecture::Qwen35Moe
174            | Architecture::Qwen3Moe
175            | Architecture::Qwen3Next
176            | Architecture::GPTNeoX
177            | Architecture::Falcon
178            | Architecture::Phi
179            | Architecture::Phi2
180            | Architecture::Phi3
181            | Architecture::PhiMoe
182            | Architecture::GPTJ
183            | Architecture::StableLM
184            | Architecture::Gemma
185            | Architecture::Gemma2
186            | Architecture::Gemma3
187            | Architecture::Gemma3N
188            | Architecture::Gemma4
189            | Architecture::GemmaEmbedding => RopeType::NeoX,
190            _ => RopeType::Normal,
191        };
192
193        // MoE configuration
194        let num_experts = get_u32(&format!("{}.expert_count", arch)).unwrap_or(0) as usize;
195        let num_experts_per_token =
196            get_u32(&format!("{}.expert_used_count", arch)).unwrap_or(0) as usize;
197        let expert_intermediate_size =
198            get_u32(&format!("{}.expert_feed_forward_length", arch)).unwrap_or(0) as usize;
199
200        // Attention head dimensions (may differ from hidden_size / num_heads)
201        let key_length =
202            get_u32(&format!("{}.attention.key_length", arch)).unwrap_or(head_dim as u32) as usize;
203        let value_length = get_u32(&format!("{}.attention.value_length", arch))
204            .unwrap_or(head_dim as u32) as usize;
205
206        let rope_n_dims = get_u32(&format!("{}.rope.dimension_count", arch))
207            .unwrap_or(head_dim as u32) as usize;
208
209        // MRoPE dimension sections (Qwen 3.6: [11, 11, 10, 0])
210        let mrope_sections = if let Some(MetadataValue::Array(arr)) =
211            gguf.data.metadata.get(&format!("{}.rope.dimension_sections", arch))
212        {
213            let sections: Vec<usize> = arr.values.iter().filter_map(|v| match v {
214                MetadataValue::Int32(n) if *n > 0 => Some(*n as usize),
215                _ => None,
216            }).collect();
217            if sections.is_empty() { None } else { Some(sections) }
218        } else {
219            None
220        };
221
222        let rope_config = RopeConfig {
223            freq_base,
224            freq_scale,
225            n_dims: rope_n_dims,
226            scaling_type: RopeScalingType::None,
227            original_max_position_embeddings: max_seq_len,
228            rope_type,
229            mrope_sections,
230        };
231
232        // Architecture-specific configuration
233        let has_combined_qkv = architecture.has_combined_qkv();
234        let uses_layer_norm = architecture.uses_layer_norm();
235        let uses_gelu = architecture.uses_gelu();
236        let has_ffn_gate = !architecture.has_no_gate_ffn();
237
238        // Gemma2 logit softcapping
239        let attn_logit_softcap =
240            get_f32_or(&format!("{}.attn_logit_softcapping", arch), 0.0);
241        let final_logit_softcap =
242            get_f32_or(&format!("{}.final_logit_softcapping", arch), 0.0);
243        let sliding_window =
244            get_u32(&format!("{}.attention.sliding_window", arch)).unwrap_or(0) as usize;
245
246        // Some architectures default to attention bias
247        let attention_bias = matches!(
248            architecture,
249            Architecture::Qwen
250                | Architecture::Qwen2
251                | Architecture::Qwen2Moe
252                | Architecture::Phi2
253                | Architecture::Phi3
254                | Architecture::PhiMoe
255                | Architecture::GPTNeoX
256                | Architecture::GPTJ
257                | Architecture::Falcon
258                | Architecture::BLOOM
259                | Architecture::MPT
260                | Architecture::OPT
261                | Architecture::GPT2
262                | Architecture::StableLM
263                | Architecture::Baichuan
264        );
265
266        let mlp_bias = matches!(
267            architecture,
268            Architecture::GPT2
269                | Architecture::GPTJ
270                | Architecture::GPTNeoX
271                | Architecture::BLOOM
272                | Architecture::OPT
273                | Architecture::StableLM
274                | Architecture::Phi2
275                | Architecture::Phi3
276        );
277
278        // Parallel residual: attention and FFN both computed from norm(x), added to residual.
279        // Phi-3/Phi-4/PhiMoe use sequential residual (separate attn_norm + ffn_norm).
280        let use_parallel_residual = matches!(
281            architecture,
282            Architecture::GPTNeoX
283                | Architecture::GPTJ
284                | Architecture::StableLM
285                | Architecture::Phi
286                | Architecture::Phi2
287                | Architecture::CodeShell
288        );
289
290        // Activation type
291        let hidden_act = if architecture.uses_gelu() {
292            ActivationType::GELU
293        } else {
294            ActivationType::SiLU
295        };
296
297        let mut config = ModelConfig {
298            vocab_size,
299            hidden_size,
300            intermediate_size,
301            num_layers,
302            num_heads,
303            num_kv_heads,
304            head_dim,
305            max_seq_len,
306            norm_eps,
307            rope_config,
308            use_parallel_residual,
309            hidden_act,
310            attention_bias,
311            mlp_bias,
312            tie_word_embeddings: gguf
313                .data
314                .get_string("general.tie_word_embeddings")
315                .map(|s| s == "true")
316                .unwrap_or(false),
317            num_experts,
318            num_experts_per_token,
319            expert_intermediate_size,
320            key_length,
321            value_length,
322            ssm_d_inner: get_u32(&format!("{}.ssm.inner_size", arch)).unwrap_or(0) as usize,
323            ssm_d_state: get_u32(&format!("{}.ssm.state_size", arch)).unwrap_or(0) as usize,
324            ssm_n_group: {
325                let g = get_u32(&format!("{}.ssm.group_count", arch)).unwrap_or(0) as usize;
326                // Mamba1 has no group_count; default to 1
327                if g == 0 && matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
328                    1
329                } else {
330                    g
331                }
332            },
333            ssm_dt_rank: get_u32(&format!("{}.ssm.time_step_rank", arch)).unwrap_or(0) as usize,
334            ssm_conv_kernel: get_u32(&format!("{}.ssm.conv_kernel", arch)).unwrap_or(0) as usize,
335            attn_logit_softcap,
336            final_logit_softcap,
337            sliding_window,
338            has_combined_qkv,
339            uses_layer_norm,
340            uses_gelu,
341            has_ffn_gate,
342            attention_layer_configs: None,
343            kv_source_layer: None,
344        };
345
346        // Gemma 4: heterogeneous attention layers.
347        //
348        // GGUF convention: unsuffixed keys = global attention, `_swa` = sliding window.
349        // The per-layer pattern is an explicit bool array (sliding_window_pattern),
350        // not a fixed period.
351        if architecture.has_heterogeneous_attention() {
352            // Global attention params (the GGUF "default" keys)
353            let global_head_dim = config.head_dim; // key_length = 512
354            let global_kv_heads = config.num_kv_heads; // head_count_kv = 1
355            let global_rope_freq_base = config.rope_config.freq_base; // rope.freq_base = 1e6
356            // Global RoPE dims: the GGUF dimension_count may report the full head_dim
357            // (e.g., 512), but Gemma 4 uses rope_freqs.weight to mask out dimensions.
358            // Entries with value ~1.0 are active; entries with 1e30 are frozen.
359            // Count active pairs to get the actual rotated dimension count.
360            // Global RoPE dims: the GGUF dimension_count may report the full head_dim
361            // (e.g., 512), but Gemma 4 uses rope_freqs.weight to mask out dimensions.
362            // Entries with value ~1.0 are active; entries with 1e30 are frozen.
363            // Count active pairs to get the actual rotated dimension count.
364            let global_rope_dims = if let Some(data) = gguf.tensor_data("rope_freqs.weight") {
365                let floats: &[f32] = bytemuck::cast_slice(data);
366                let active_pairs = floats.iter().filter(|&&v| v < 1e10).count();
367                active_pairs * 2
368            } else {
369                get_u32(&format!("{}.rope.dimension_count", arch))
370                    .unwrap_or(global_head_dim as u32) as usize
371            };
372
373            // Sliding window attention params (_swa suffix)
374            let swa_head_dim =
375                get_u32(&format!("{}.attention.key_length_swa", arch))
376                    .unwrap_or(global_head_dim as u32) as usize;
377            let swa_kv_heads =
378                get_u32(&format!("{}.attention.head_count_kv_swa", arch))
379                    .unwrap_or(global_kv_heads as u32) as usize;
380            let swa_rope_freq_base =
381                get_f32_or(&format!("{}.rope.freq_base_swa", arch), global_rope_freq_base);
382            let swa_rope_dims =
383                get_u32(&format!("{}.rope.dimension_count_swa", arch))
384                    .unwrap_or(swa_head_dim as u32) as usize;
385            let sliding_window = config.sliding_window;
386
387            // Per-layer pattern: bool array where true = SWA, false = global
388            let swa_pattern: Vec<bool> =
389                if let Some(MetadataValue::Array(arr)) =
390                    gguf.data.metadata.get(&format!("{}.attention.sliding_window_pattern", arch))
391                {
392                    arr.values
393                        .iter()
394                        .map(|v| matches!(v, MetadataValue::Bool(true)))
395                        .collect()
396                } else {
397                    // Fallback: 5 SWA + 1 global repeating pattern
398                    (0..config.num_layers)
399                        .map(|i| i % 6 != 5)
400                        .collect()
401                };
402
403            // `head_count_kv` may be a per-layer array (Gemma 4: 8 KV heads for
404            // sliding layers, 1 for global). The scalar `get_u32` returns None for
405            // an array, so `num_kv_heads` above fell back to `num_heads` and the
406            // `_swa` scalar keys are absent. Recover the per-type counts from the
407            // array, indexed by the sliding-window pattern.
408            let (swa_kv_heads, global_kv_heads) = match gguf
409                .data
410                .get_u32_array(&format!("{}.attention.head_count_kv", arch))
411                .filter(|v| v.len() == config.num_layers)
412            {
413                Some(per_layer) => {
414                    let swa = swa_pattern
415                        .iter()
416                        .position(|&s| s)
417                        .map(|i| per_layer[i] as usize)
418                        .unwrap_or(swa_kv_heads);
419                    let global = swa_pattern
420                        .iter()
421                        .position(|&s| !s)
422                        .map(|i| per_layer[i] as usize)
423                        .unwrap_or(global_kv_heads);
424                    (swa, global)
425                }
426                None => (swa_kv_heads, global_kv_heads),
427            };
428            // Keep the scalar config consistent with the global ("default") layer.
429            config.num_kv_heads = global_kv_heads;
430
431            config.attention_layer_configs =
432                Some(ModelConfig::build_attention_layer_configs_from_pattern(
433                    &swa_pattern,
434                    swa_head_dim,
435                    swa_kv_heads,
436                    swa_rope_freq_base,
437                    swa_rope_dims,
438                    sliding_window,
439                    global_head_dim,
440                    global_kv_heads,
441                    global_rope_freq_base,
442                    global_rope_dims,
443                ));
444
445            // Shared KV layers: type-specific mapping.
446            // Shared SWA layers reuse the last KV-owning SWA layer's cache.
447            // Shared global layers reuse the last KV-owning global layer's cache.
448            let shared_layers =
449                get_u32(&format!("{}.attention.shared_kv_layers", arch)).unwrap_or(0) as usize;
450            if shared_layers > 0 {
451                config.kv_source_layer = Some(ModelConfig::build_kv_source_mapping(
452                    config.num_layers,
453                    shared_layers,
454                    config.attention_layer_configs.as_ref().unwrap(),
455                ));
456            }
457        }
458
459        Ok(config)
460    }
461
462    /// Get the model configuration
463    pub fn config(&self) -> &ModelConfig {
464        &self.config
465    }
466
467    /// Get mutable reference to model configuration (e.g., to clamp context length).
468    pub fn config_mut(&mut self) -> &mut ModelConfig {
469        &mut self.config
470    }
471
472    /// Get the detected architecture
473    pub fn architecture(&self) -> Architecture {
474        self.architecture
475    }
476
477    /// Build the model from loaded weights.
478    ///
479    /// Delegates to the format-independent [`build_llama_model`] free function
480    /// via the [`ModelSource`] trait, so that SafeTensors and other loaders can
481    /// reuse the same layer assembly logic.
482    pub fn build_model(self) -> ModelResult<LlamaModel> {
483        build_llama_model(&self)
484    }
485
486    /// Build a BERT encoder-only model from loaded weights
487    pub fn build_bert_model(self) -> ModelResult<BertModel> {
488        let token_embedding = self.load_tensor("token_embd.weight")?;
489
490        let position_embedding = self.try_load_tensor("position_embd.weight");
491        let token_type_embedding = self.try_load_tensor("token_types.weight");
492
493        // Embedding normalization
494        let embed_norm = if let Some(w) = self.try_load_tensor("token_embd_norm.weight") {
495            if let Some(b) = self.try_load_tensor("token_embd_norm.bias") {
496                Some(NormLayer::Layer(LayerNorm::new(w, b, self.config.norm_eps)?))
497            } else {
498                Some(NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?))
499            }
500        } else {
501            None
502        };
503
504        let mut layers = Vec::with_capacity(self.config.num_layers);
505        for i in 0..self.config.num_layers {
506            let prefix = format!("blk.{}", i);
507
508            // Attention normalization: try attn_output_norm (BERT) then attn_norm
509            let attn_norm_w = self
510                .try_load_tensor(&format!("{}.attn_output_norm.weight", prefix))
511                .or_else(|| self.try_load_tensor(&format!("{}.attn_norm.weight", prefix)))
512                .ok_or_else(|| {
513                    ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix))
514                })?;
515            let attn_norm_b = self
516                .try_load_tensor(&format!("{}.attn_output_norm.bias", prefix))
517                .or_else(|| self.try_load_tensor(&format!("{}.attn_norm.bias", prefix)));
518            let attn_norm = if let Some(b) = attn_norm_b {
519                NormLayer::Layer(LayerNorm::new(attn_norm_w, b, self.config.norm_eps)?)
520            } else {
521                NormLayer::RMS(RMSNorm::new(attn_norm_w, self.config.norm_eps)?)
522            };
523
524            // Load Q, K, V (combined or separate)
525            let (wq, wk, wv) =
526                if let Some(qkv) = self.try_load_tensor(&format!("{}.attn_qkv.weight", prefix)) {
527                    // Split combined QKV for BERT
528                    let num_heads = self.config.num_heads;
529                    let head_dim = self.config.head_dim;
530                    let hidden = self.config.hidden_size;
531                    let q_size = num_heads * head_dim;
532                    let k_size = num_heads * head_dim;
533                    let v_size = num_heads * head_dim;
534
535                    let qkv_f32 = if qkv.dtype() == DType::F32 {
536                        qkv.as_f32()?.to_vec()
537                    } else {
538                        let backend = crate::backend::default_backend();
539                        let mut deq = Tensor::zeros(vec![qkv.numel()], DType::F32);
540                        backend
541                            .dequantize(&qkv, &mut deq)
542                            .map_err(|e| ModelError::ConfigError(format!("Dequant QKV: {}", e)))?;
543                        deq.as_f32()?.to_vec()
544                    };
545
546                    // GGUF layout: ne[0]=hidden (innermost), ne[1]=total (outer).
547                    // Q/K/V occupy contiguous rows: first q_size rows, then k_size, then v_size.
548                    let q_start = 0;
549                    let k_start_off = q_size * hidden;
550                    let v_start_off = (q_size + k_size) * hidden;
551
552                    let qkv_bias = self.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
553                    let (qb, kb, vb) = if let Some(ref b) = qkv_bias {
554                        let bd = b.as_f32()?;
555                        (
556                            Some(Tensor::from_f32(&bd[..q_size], vec![q_size])?),
557                            Some(Tensor::from_f32(
558                                &bd[q_size..q_size + k_size],
559                                vec![k_size],
560                            )?),
561                            Some(Tensor::from_f32(&bd[q_size + k_size..], vec![v_size])?),
562                        )
563                    } else {
564                        (None, None, None)
565                    };
566
567                    (
568                        Linear::new(
569                            Tensor::from_f32(&qkv_f32[q_start..q_start + q_size * hidden], vec![hidden, q_size])?,
570                            qb,
571                        )?,
572                        Linear::new(
573                            Tensor::from_f32(&qkv_f32[k_start_off..k_start_off + k_size * hidden], vec![hidden, k_size])?,
574                            kb,
575                        )?,
576                        Linear::new(
577                            Tensor::from_f32(&qkv_f32[v_start_off..v_start_off + v_size * hidden], vec![hidden, v_size])?,
578                            vb,
579                        )?,
580                    )
581                } else {
582                    let qb = self.try_load_tensor(&format!("{}.attn_q.bias", prefix));
583                    let kb = self.try_load_tensor(&format!("{}.attn_k.bias", prefix));
584                    let vb = self.try_load_tensor(&format!("{}.attn_v.bias", prefix));
585                    (
586                        Linear::new(
587                            self.load_tensor(&format!("{}.attn_q.weight", prefix))?,
588                            qb,
589                        )?,
590                        Linear::new(
591                            self.load_tensor(&format!("{}.attn_k.weight", prefix))?,
592                            kb,
593                        )?,
594                        Linear::new(
595                            self.load_tensor(&format!("{}.attn_v.weight", prefix))?,
596                            vb,
597                        )?,
598                    )
599                };
600
601            let wo_bias = self.try_load_tensor(&format!("{}.attn_output.bias", prefix));
602            let wo = Linear::new(
603                self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
604                wo_bias,
605            )?;
606
607            // FFN normalization: try layer_output_norm (BERT) then ffn_norm
608            let ffn_norm_w = self
609                .try_load_tensor(&format!("{}.layer_output_norm.weight", prefix))
610                .or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.weight", prefix)))
611                .ok_or_else(|| {
612                    ModelError::MissingTensor(format!("{}.ffn_norm.weight", prefix))
613                })?;
614            let ffn_norm_b = self
615                .try_load_tensor(&format!("{}.layer_output_norm.bias", prefix))
616                .or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.bias", prefix)));
617            let ffn_norm = if let Some(b) = ffn_norm_b {
618                NormLayer::Layer(LayerNorm::new(ffn_norm_w, b, self.config.norm_eps)?)
619            } else {
620                NormLayer::RMS(RMSNorm::new(ffn_norm_w, self.config.norm_eps)?)
621            };
622
623            let ffn_up_bias = self.try_load_tensor(&format!("{}.ffn_up.bias", prefix));
624            let ffn_up = Linear::new(
625                self.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
626                ffn_up_bias,
627            )?;
628            let ffn_down_bias = self.try_load_tensor(&format!("{}.ffn_down.bias", prefix));
629            let ffn_down = Linear::new(
630                self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
631                ffn_down_bias,
632            )?;
633
634            layers.push(BertLayer {
635                attn_norm,
636                wq,
637                wk,
638                wv,
639                wo,
640                num_heads: self.config.num_heads,
641                head_dim: self.config.head_dim,
642                ffn_norm,
643                ffn_up,
644                ffn_down,
645            });
646        }
647
648        BertModel::new(
649            self.config,
650            token_embedding,
651            position_embedding,
652            token_type_embedding,
653            embed_norm,
654            layers,
655            self.architecture,
656        )
657    }
658
659    /// Get the DeltaNet config for creating recurrent state (Qwen3Next).
660    /// Returns None if the model has no SSM layers or is Mamba.
661    pub fn deltanet_config(&self) -> Option<DeltaNetConfig> {
662        deltanet_config_from_source(self)
663    }
664
665    /// Get the recurrent config (DeltaNet or Mamba) for creating inference context.
666    pub fn recurrent_config(&self) -> Option<super::deltanet::RecurrentConfig> {
667        if !self.config.has_ssm() {
668            return None;
669        }
670        if matches!(self.architecture, Architecture::Mamba | Architecture::Mamba2) {
671            Some(super::deltanet::RecurrentConfig::Mamba(MambaConfig {
672                d_inner: self.config.ssm_d_inner,
673                d_state: self.config.ssm_d_state,
674                dt_rank: self.config.ssm_dt_rank,
675                conv_kernel: self.config.ssm_conv_kernel.max(1),
676            }))
677        } else if let Some(dn) = self.deltanet_config() {
678            Some(super::deltanet::RecurrentConfig::DeltaNet(dn))
679        } else {
680            None
681        }
682    }
683
684    /// Try to load a tensor from the GGUF file, returning None if not found.
685    /// This is the GGUF-specific implementation used by the [`ModelSource`] trait.
686    fn gguf_try_load_tensor(&self, name: &str) -> Option<Tensor> {
687        let tensor_info = self.gguf.data.get_tensor(name)?;
688        let tensor_data = self.gguf.tensor_data(name)?;
689
690        let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
691        let dtype = DType::from(tensor_info.dtype);
692
693        Tensor::new(tensor_data.to_vec(), shape, dtype)
694            .ok()
695            .map(|mut t| {
696                t.set_name(name);
697                t
698            })
699    }
700
701    /// Load a tensor from the GGUF file.
702    /// This is the GGUF-specific implementation used by the [`ModelSource`] trait.
703    fn gguf_load_tensor(&self, name: &str) -> ModelResult<Tensor> {
704        let tensor_info = self
705            .gguf
706            .data
707            .get_tensor(name)
708            .ok_or_else(|| ModelError::MissingTensor(name.into()))?;
709
710        let tensor_data = self
711            .gguf
712            .tensor_data(name)
713            .ok_or_else(|| ModelError::MissingTensor(name.into()))?;
714
715        let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
716        let dtype = DType::from(tensor_info.dtype);
717
718        // Copy the tensor data to owned storage
719        // This is necessary because the GGUF file is dropped after build_model() returns
720        // and the memory-mapped data would become invalid
721        let mut tensor = Tensor::new(tensor_data.to_vec(), shape, dtype)?;
722
723        // Store the GGUF tensor name for GPU weight lookup
724        tensor.set_name(name);
725
726        Ok(tensor)
727    }
728}
729
730impl ModelSource for ModelLoader {
731    fn config(&self) -> &ModelConfig {
732        &self.config
733    }
734
735    fn config_mut(&mut self) -> &mut ModelConfig {
736        &mut self.config
737    }
738
739    fn architecture(&self) -> Architecture {
740        self.architecture
741    }
742
743    fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
744        self.gguf_load_tensor(name)
745    }
746
747    fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
748        self.gguf_try_load_tensor(name)
749    }
750}
751
752// ---------------------------------------------------------------------------
753// Format-independent model assembly functions
754//
755// These free functions accept `&dyn ModelSource` so that any weight format
756// (GGUF, SafeTensors, ONNX) can reuse the same layer construction logic.
757// ---------------------------------------------------------------------------
758
759/// Build a LlamaModel from any [`ModelSource`].
760///
761/// This is the format-independent entry point for model assembly. It loads
762/// token embeddings, transformer layers, output normalization, output
763/// projection, and PLIE tensors from the given source.
764pub fn build_llama_model(source: &dyn ModelSource) -> ModelResult<LlamaModel> {
765    // Load token embeddings
766    let token_embedding = source.load_tensor("token_embd.weight")?;
767
768    // Load transformer layers
769    let config = source.config();
770    let mut layers = Vec::with_capacity(config.num_layers);
771    for i in 0..config.num_layers {
772        let layer = load_transformer_layer(source, i)?;
773        layers.push(layer);
774    }
775
776    // Log recurrent layer summary
777    let recurrent_count = layers.iter().filter(|l| l.is_recurrent()).count();
778    if recurrent_count > 0 {
779        tracing::info!(
780            "Model has {}/{} DeltaNet recurrent layers",
781            recurrent_count,
782            layers.len()
783        );
784    }
785
786    // Load final normalization (LayerNorm if bias exists, else RMSNorm)
787    let norm_weight =
788        apply_gemma_norm_weight_offset(source.load_tensor("output_norm.weight")?)?;
789    let norm = if let Some(bias) = source.try_load_tensor("output_norm.bias") {
790        NormLayer::Layer(LayerNorm::new(norm_weight, bias, config.norm_eps)?)
791    } else {
792        NormLayer::RMS(RMSNorm::new(norm_weight, config.norm_eps)?)
793    };
794
795    // Load output projection (may be tied to embeddings)
796    let output_bias = source.try_load_tensor("output.bias");
797    let output =
798        if config.tie_word_embeddings || source.try_load_tensor("output.weight").is_none() {
799            Linear::new(token_embedding.clone(), output_bias)?
800        } else {
801            let output_weight = source.load_tensor("output.weight")?;
802            Linear::new(output_weight, output_bias)?
803        };
804
805    // Gemma 4 PLIE: load shared per-layer embedding tensors
806    let per_layer_token_embd = source.try_load_tensor("per_layer_token_embd.weight");
807    let per_layer_model_proj = source
808        .try_load_tensor("per_layer_model_proj.weight")
809        .map(|w| {
810            // This tensor is often BF16; dequantize to F32 at load time
811            if w.dtype() != DType::F32 {
812                let backend = crate::backend::default_backend();
813                let mut deq = Tensor::zeros(vec![w.numel()], DType::F32);
814                backend
815                    .dequantize(&w, &mut deq)
816                    .map_err(|e| {
817                        ModelError::ConfigError(format!(
818                            "Failed to dequantize per_layer_model_proj: {e}"
819                        ))
820                    })?;
821                let shape = w.shape().to_vec();
822                let deq = deq.reshape(shape)?;
823                Linear::new(deq, None)
824            } else {
825                Linear::new(w, None)
826            }
827        })
828        .transpose()?;
829    let per_layer_proj_norm = source
830        .try_load_tensor("per_layer_proj_norm.weight")
831        .map(|w| RMSNorm::new(w, config.norm_eps))
832        .transpose()?;
833
834    // Determine n_epl from per_layer_proj_norm dimension (256 for Gemma 4)
835    let n_epl = per_layer_proj_norm
836        .as_ref()
837        .map(|n| n.hidden_size)
838        .unwrap_or(0);
839
840    if n_epl > 0 {
841        tracing::info!(
842            "Gemma 4 PLIE active: n_epl={}, n_layers={}, total_pl_dim={}",
843            n_epl,
844            config.num_layers,
845            n_epl * config.num_layers
846        );
847    }
848
849    LlamaModel::new(
850        config.clone(),
851        token_embedding,
852        layers,
853        norm,
854        output,
855        source.architecture(),
856        per_layer_token_embd,
857        per_layer_model_proj,
858        per_layer_proj_norm,
859        n_epl,
860    )
861}
862
863/// Load a single transformer layer from any [`ModelSource`].
864fn load_transformer_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<TransformerLayer> {
865    let prefix = format!("blk.{}", layer_idx);
866    let config = source.config();
867    let arch = source.architecture();
868    let is_mamba = matches!(arch, Architecture::Mamba | Architecture::Mamba2);
869
870    // Attention normalization (Mamba may use norm.weight or attn_norm.weight)
871    let attn_norm_weight = source
872        .try_load_tensor(&format!("{}.attn_norm.weight", prefix))
873        .or_else(|| source.try_load_tensor(&format!("{}.norm.weight", prefix)))
874        .ok_or_else(|| ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix)))?;
875    let attn_norm_weight = apply_gemma_norm_weight_offset(attn_norm_weight)?;
876    let attn_norm_bias = source
877        .try_load_tensor(&format!("{}.attn_norm.bias", prefix))
878        .or_else(|| source.try_load_tensor(&format!("{}.norm.bias", prefix)));
879    let attn_norm = if let Some(bias) = attn_norm_bias {
880        NormLayer::Layer(LayerNorm::new(attn_norm_weight, bias, config.norm_eps)?)
881    } else {
882        NormLayer::RMS(RMSNorm::new(attn_norm_weight, config.norm_eps)?)
883    };
884
885    // Load attention based on available tensors
886    let attn_layer = load_attention_layer(source, layer_idx)?;
887
888    // Normalization between attention and FFN. Two conventions coexist:
889    //
890    //  * Gemma2/Cohere2: the `post_attention_norm` is applied to the attention
891    //    output BEFORE the residual add (`h = x + post_attn_norm(attn_out)`),
892    //    and a separate `ffn_norm` normalizes the residual before the FFN.
893    //  * Qwen3.5/3.6 MoE: no `ffn_norm` tensor exists. The sole post-attention
894    //    tensor (`post_attention_norm`) *is* the FFN normalization and is
895    //    applied AFTER the residual (`ffn(post_attention_norm(h))`).
896    //
897    // We disambiguate by checking whether a dedicated `ffn_norm` tensor is
898    // present: if both exist we follow the Gemma convention; if only
899    // `post_attention_norm` exists we remap it to `ffn_norm`.
900    let ffn_norm_weight = source.try_load_tensor(&format!("{}.ffn_norm.weight", prefix));
901    let ffn_norm_bias = source.try_load_tensor(&format!("{}.ffn_norm.bias", prefix));
902    let post_attn_w =
903        source.try_load_tensor(&format!("{}.post_attention_norm.weight", prefix));
904    let post_attn_b = source.try_load_tensor(&format!("{}.post_attention_norm.bias", prefix));
905
906    let (ffn_norm, post_attn_norm) = if let Some(w) = ffn_norm_weight {
907        let w = apply_gemma_norm_weight_offset(w)?;
908        let ffn = if let Some(bias) = ffn_norm_bias {
909            NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
910        } else {
911            NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
912        };
913        let pan = post_attn_w
914            .map(|w| -> ModelResult<NormLayer> {
915                let w = apply_gemma_norm_weight_offset(w)?;
916                Ok(if let Some(bias) = post_attn_b {
917                    NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
918                } else {
919                    NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
920                })
921            })
922            .transpose()?;
923        (ffn, pan)
924    } else if let Some(w) = post_attn_w {
925        // Qwen3.5/3.6 MoE: `post_attention_norm` is the FFN norm. Don't wire
926        // it into `post_attn_norm` — that path normalizes pre-residual.
927        let w = apply_gemma_norm_weight_offset(w)?;
928        let ffn = if let Some(bias) = post_attn_b {
929            NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
930        } else {
931            NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
932        };
933        (ffn, None)
934    } else if is_mamba || config.use_parallel_residual {
935        // Parallel residual models (Phi-2, GPT-NeoX, GPT-J) share the
936        // attention norm for both branches, so ffn_norm doesn't exist.
937        // Mamba may lack a separate ffn_norm entirely. Use an identity norm.
938        let hidden = config.hidden_size;
939        let ffn = NormLayer::RMS(RMSNorm::new(
940            Tensor::from_f32(&vec![1.0f32; hidden], vec![hidden])?,
941            config.norm_eps,
942        )?);
943        (ffn, None)
944    } else {
945        return Err(ModelError::MissingTensor(format!(
946            "{}.ffn_norm.weight",
947            prefix
948        )));
949    };
950
951    // Load FFN: MoE, dense, or dummy for pure Mamba without FFN
952    let ffn_layer = if config.is_moe() {
953        load_moe_layer(source, layer_idx)?
954    } else if is_mamba
955        && source.try_load_tensor(&format!("{}.ffn_up.weight", prefix)).is_none()
956    {
957        FfnLayer::Identity
958    } else if !config.has_ffn_gate {
959        let up_tensor = source.load_tensor(&format!("{}.ffn_up.weight", prefix))?;
960        let up_out_dim = up_tensor.shape()[up_tensor.ndim() - 1];
961        let intermediate = config.intermediate_size;
962
963        if up_out_dim == 2 * intermediate {
964            // Fused gate+up projection (Phi-3, Phi-4): split into separate gate and up
965            let hidden = config.hidden_size;
966            let up_f32 = if up_tensor.dtype() == DType::F32 {
967                up_tensor.as_f32()?.to_vec()
968            } else {
969                let backend = crate::backend::default_backend();
970                let mut deq = Tensor::zeros(vec![up_tensor.numel()], DType::F32);
971                backend
972                    .dequantize(&up_tensor, &mut deq)
973                    .map_err(|e| ModelError::ConfigError(format!("Dequant ffn_up: {}", e)))?;
974                deq.as_f32()?.to_vec()
975            };
976
977            let gate_data = &up_f32[..hidden * intermediate];
978            let up_data = &up_f32[hidden * intermediate..];
979            let w_gate = Linear::new(
980                Tensor::from_f32(gate_data, vec![hidden, intermediate])?,
981                None,
982            )?;
983            let w_up = Linear::new(
984                Tensor::from_f32(up_data, vec![hidden, intermediate])?,
985                None,
986            )?;
987            let w_down = Linear::new(
988                source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
989                None,
990            )?;
991            let mut ffn = FeedForward::new(w_gate, w_up, w_down);
992            ffn.use_gelu = config.uses_gelu;
993            FfnLayer::Dense(ffn)
994        } else {
995            let w_up = Linear::new(
996                up_tensor,
997                source.try_load_tensor(&format!("{}.ffn_up.bias", prefix)),
998            )?;
999            let w_down = Linear::new(
1000                source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
1001                source.try_load_tensor(&format!("{}.ffn_down.bias", prefix)),
1002            )?;
1003            FfnLayer::NoGate(NoGateFeedForward::new(
1004                w_up,
1005                w_down,
1006                config.uses_gelu,
1007            ))
1008        }
1009    } else {
1010        let w_gate = Linear::new(
1011            source.load_tensor(&format!("{}.ffn_gate.weight", prefix))?,
1012            None,
1013        )?;
1014        let w_up = Linear::new(
1015            source.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
1016            None,
1017        )?;
1018        let w_down = Linear::new(
1019            source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
1020            None,
1021        )?;
1022        let mut ffn = FeedForward::new(w_gate, w_up, w_down);
1023        ffn.use_gelu = config.uses_gelu;
1024        FfnLayer::Dense(ffn)
1025    };
1026
1027    // Post-FFN normalization (Gemma2, Cohere2)
1028    let post_ffn_norm =
1029        if let Some(w) = source.try_load_tensor(&format!("{}.post_ffw_norm.weight", prefix)) {
1030            let w = apply_gemma_norm_weight_offset(w)?;
1031            let b = source.try_load_tensor(&format!("{}.post_ffw_norm.bias", prefix));
1032            Some(if let Some(bias) = b {
1033                NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
1034            } else {
1035                NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
1036            })
1037        } else {
1038            None
1039        };
1040
1041    let rope_freq_base_override = config
1042        .attention_layer_configs
1043        .as_ref()
1044        .map(|cfgs| cfgs[layer_idx].rope_freq_base)
1045        .unwrap_or(0.0);
1046
1047    // Gemma 4 PLIE per-layer tensors
1048    let plie_inp_gate = source
1049        .try_load_tensor(&format!("{}.inp_gate.weight", prefix))
1050        .map(|w| Linear::new(w, None))
1051        .transpose()?;
1052    let plie_proj = source
1053        .try_load_tensor(&format!("{}.proj.weight", prefix))
1054        .map(|w| Linear::new(w, None))
1055        .transpose()?;
1056    let plie_post_norm = source
1057        .try_load_tensor(&format!("{}.post_norm.weight", prefix))
1058        .map(|w| RMSNorm::new(w, config.norm_eps))
1059        .transpose()?;
1060    let layer_output_scale = source
1061        .try_load_tensor(&format!("{}.layer_output_scale.weight", prefix))
1062        .and_then(|t| {
1063            if t.dtype() == crate::tensor::DType::F32 {
1064                t.as_f32().ok().map(|d| d[0])
1065            } else {
1066                // BF16 scalar — convert manually
1067                let raw = t.data();
1068                if raw.len() >= 2 {
1069                    let bits = u16::from_le_bytes([raw[0], raw[1]]);
1070                    Some(f32::from_bits((bits as u32) << 16))
1071                } else {
1072                    None
1073                }
1074            }
1075        });
1076
1077    Ok(TransformerLayer {
1078        attn_norm,
1079        attn_layer,
1080        post_attn_norm,
1081        ffn_norm,
1082        ffn_layer,
1083        post_ffn_norm,
1084        layer_idx,
1085        use_parallel_residual: config.use_parallel_residual,
1086        rope_freq_base_override,
1087        plie_inp_gate,
1088        plie_proj,
1089        plie_post_norm,
1090        layer_output_scale,
1091    })
1092}
1093
1094/// Load attention for a layer: either full softmax or delta-net recurrent.
1095fn load_attention_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<AttentionLayer> {
1096    let prefix = format!("blk.{}", layer_idx);
1097    let config = source.config();
1098
1099    if let Some(wq_weight) = source.try_load_tensor(&format!("{}.attn_q.weight", prefix)) {
1100        // Separate Q/K/V projections - standard LLaMA-like
1101        let attn = load_full_attention(source, layer_idx, wq_weight)?;
1102        Ok(AttentionLayer::FullAttention(attn))
1103    } else if let Some(qkv_weight) =
1104        source.try_load_tensor(&format!("{}.attn_qkv.weight", prefix))
1105    {
1106        if config.has_ssm() {
1107            // DeltaNet recurrent layer (Qwen3Next hybrid)
1108            let dn = load_deltanet_layer(source, layer_idx)?;
1109            Ok(AttentionLayer::DeltaNet(Box::new(dn)))
1110        } else {
1111            // Combined QKV for regular attention (Phi2, GPT-NeoX, GPT-J, Falcon, etc.)
1112            let attn = load_combined_qkv_attention(source, layer_idx, qkv_weight)?;
1113            Ok(AttentionLayer::FullAttention(attn))
1114        }
1115    } else if config.has_ssm()
1116        && source.try_load_tensor(&format!("{}.ssm_in.weight", prefix)).is_some()
1117    {
1118        // Pure Mamba/Mamba2 SSM layer (no attention tensors at all)
1119        let mamba = load_mamba_layer(source, layer_idx)?;
1120        Ok(AttentionLayer::Mamba(Box::new(mamba)))
1121    } else {
1122        Err(ModelError::MissingTensor(format!(
1123            "{}.attn_q.weight or {}.attn_qkv.weight or {}.ssm_in.weight",
1124            prefix, prefix, prefix
1125        )))
1126    }
1127}
1128
1129/// Load a full softmax attention layer from separate Q/K/V/O tensors.
1130fn load_full_attention(
1131    source: &dyn ModelSource,
1132    layer_idx: usize,
1133    wq_weight: Tensor,
1134) -> ModelResult<Attention> {
1135    let prefix = format!("blk.{}", layer_idx);
1136    let config = source.config();
1137    let arch = source.architecture();
1138    let use_neox_rope = matches!(config.rope_config.rope_type, RopeType::NeoX);
1139
1140    // Per-layer attention config overrides (Gemma 4 heterogeneous attention)
1141    let (num_kv_heads, head_dim, kl, vl, rope_dims) =
1142        if let Some(ref layer_configs) = config.attention_layer_configs {
1143            let lc = &layer_configs[layer_idx];
1144            (lc.num_kv_heads, lc.head_dim, lc.head_dim, lc.head_dim, lc.rope_dims)
1145        } else {
1146            let kl = config.key_length;
1147            let vl = config.value_length;
1148            let rope_dims = config.rope_config.n_dims;
1149            (config.num_kv_heads, config.head_dim, kl, vl, rope_dims)
1150        };
1151
1152    let wq_bias = source.try_load_tensor(&format!("{}.attn_q.bias", prefix));
1153    let actual_q_out = wq_weight.shape()[1];
1154    let has_attention_gate = actual_q_out == config.num_heads * (kl + vl);
1155
1156    let wq = Linear::new(wq_weight, wq_bias)?;
1157
1158    let wk_bias = source.try_load_tensor(&format!("{}.attn_k.bias", prefix));
1159    let wk = Linear::new(
1160        source.load_tensor(&format!("{}.attn_k.weight", prefix))?,
1161        wk_bias,
1162    )?;
1163    // Gemma 4 global layers set `attention_k_eq_v`: the value projection is tied
1164    // to the key projection and the GGUF omits `attn_v.weight`. When V is absent,
1165    // alias it to the key projection. The forward pass applies k-norm and RoPE only
1166    // to the K buffer and `normalize_v` only to the V buffer, so reusing `attn_k`
1167    // here reproduces `value = v_norm(k_proj(x))` with no k-norm or RoPE on V.
1168    let wv_bias = source.try_load_tensor(&format!("{}.attn_v.bias", prefix));
1169    let wv_weight = match source.try_load_tensor(&format!("{}.attn_v.weight", prefix)) {
1170        Some(w) => w,
1171        None => source.load_tensor(&format!("{}.attn_k.weight", prefix))?,
1172    };
1173    let wv = Linear::new(wv_weight, wv_bias)?;
1174    let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
1175    let wo = Linear::new(
1176        source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
1177        wo_bias,
1178    )?;
1179
1180    let mut attention = Attention::with_kv_dims(
1181        wq, wk, wv, wo,
1182        config.num_heads,
1183        num_kv_heads,
1184        head_dim,
1185        kl, vl, rope_dims,
1186        use_neox_rope,
1187        has_attention_gate,
1188    );
1189
1190    if arch.uses_qk_norm()
1191        && let (Some(q_norm_w), Some(k_norm_w)) = (
1192            source.try_load_tensor(&format!("{}.attn_q_norm.weight", prefix)),
1193            source.try_load_tensor(&format!("{}.attn_k_norm.weight", prefix)),
1194        )
1195    {
1196        let q_norm = RMSNorm::new(q_norm_w, config.norm_eps)?;
1197        let k_norm = RMSNorm::new(k_norm_w, config.norm_eps)?;
1198        attention.set_qk_norms(q_norm, k_norm);
1199    }
1200
1201    if config.attn_logit_softcap > 0.0 {
1202        attention.set_attn_logit_softcap(config.attn_logit_softcap);
1203    }
1204
1205    // Qwen3Next/Qwen35Moe use [nope | rope] layout; all others use [rope | nope]
1206    if matches!(arch, Architecture::Qwen3Next | Architecture::Qwen35Moe) {
1207        attention.set_rope_partial_at_end(true);
1208    }
1209
1210    // MRoPE sections (Qwen 3.6)
1211    if let Some(ref sections) = config.rope_config.mrope_sections {
1212        attention.mrope_sections = Some(sections.clone());
1213    }
1214
1215    // Per-layer sliding window, RoPE freq dim, V norm, attn scale (Gemma 4)
1216    if let Some(ref layer_configs) = config.attention_layer_configs {
1217        let lc = &layer_configs[layer_idx];
1218        if lc.sliding_window > 0 {
1219            attention.set_sliding_window(lc.sliding_window);
1220        }
1221        if lc.rope_dims < lc.head_dim {
1222            attention.set_rope_freq_dim(lc.head_dim);
1223        }
1224        // Gemma 4: raw RMS norm on V values (no learned weights)
1225        attention.normalize_v = true;
1226        // Gemma 4 softmax scale is 1.0, NOT 1/sqrt(head_dim). The per-head QK
1227        // RMSNorm (with its learned weights) sets the dot-product magnitude, so
1228        // no additional 1/sqrt(head_dim) factor is applied. Verified against a
1229        // first-principles reconstruction from the GGUF weights: with scale=1.0
1230        // the layer-0 attention output matches; any 1/sqrt(d)-style scale does
1231        // not. (Position 0 hides this because a 1-element softmax is
1232        // scale-invariant.)
1233        attention.scale = 1.0;
1234    }
1235
1236    Ok(attention)
1237}
1238
1239/// Load attention from a combined QKV tensor by splitting it into separate Q, K, V.
1240fn load_combined_qkv_attention(
1241    source: &dyn ModelSource,
1242    layer_idx: usize,
1243    qkv_weight: Tensor,
1244) -> ModelResult<Attention> {
1245    let prefix = format!("blk.{}", layer_idx);
1246    let config = source.config();
1247    let use_neox_rope = matches!(config.rope_config.rope_type, RopeType::NeoX);
1248    let kl = config.key_length;
1249    let vl = config.value_length;
1250    let rope_dims = config.rope_config.n_dims;
1251    let num_heads = config.num_heads;
1252    let num_kv_heads = config.num_kv_heads;
1253    let head_dim = config.head_dim;
1254
1255    // Combined QKV shape: [hidden_size, (num_heads + 2 * num_kv_heads) * head_dim]
1256    let qkv_shape = qkv_weight.shape();
1257    let in_features = qkv_shape[0];
1258    let q_size = num_heads * head_dim;
1259    let k_size = num_kv_heads * head_dim;
1260    let v_size = num_kv_heads * head_dim;
1261
1262    // QKV bias
1263    let qkv_bias = source.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
1264
1265    if qkv_weight.dtype() == DType::F32 {
1266        let qkv_f32 = qkv_weight.as_f32()?;
1267
1268        // GGUF layout: ne[0]=in_features (innermost), ne[1]=total_out (outer).
1269        // Data has total_out rows of in_features elements each.
1270        // Q occupies the first q_size rows, K next k_size rows, V last v_size rows.
1271        let q_start = 0;
1272        let k_start = q_size * in_features;
1273        let v_start = (q_size + k_size) * in_features;
1274
1275        let q_tensor = Tensor::from_f32(
1276            &qkv_f32[q_start..q_start + q_size * in_features],
1277            vec![in_features, q_size],
1278        )?;
1279        let k_tensor = Tensor::from_f32(
1280            &qkv_f32[k_start..k_start + k_size * in_features],
1281            vec![in_features, k_size],
1282        )?;
1283        let v_tensor = Tensor::from_f32(
1284            &qkv_f32[v_start..v_start + v_size * in_features],
1285            vec![in_features, v_size],
1286        )?;
1287
1288        // Split bias if present
1289        let (q_bias, k_bias, v_bias) = if let Some(ref bias) = qkv_bias {
1290            let b = bias.as_f32()?;
1291            let qb = Tensor::from_f32(&b[..q_size], vec![q_size])?;
1292            let kb = Tensor::from_f32(&b[q_size..q_size + k_size], vec![k_size])?;
1293            let vb = Tensor::from_f32(&b[q_size + k_size..], vec![v_size])?;
1294            (Some(qb), Some(kb), Some(vb))
1295        } else {
1296            (None, None, None)
1297        };
1298
1299        let wq = Linear::new(q_tensor, q_bias)?;
1300        let wk = Linear::new(k_tensor, k_bias)?;
1301        let wv = Linear::new(v_tensor, v_bias)?;
1302
1303        let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
1304        let wo = Linear::new(
1305            source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
1306            wo_bias,
1307        )?;
1308
1309        Ok(Attention::with_kv_dims(
1310            wq, wk, wv, wo,
1311            num_heads, num_kv_heads, head_dim,
1312            kl, vl, rope_dims,
1313            use_neox_rope, false,
1314        ))
1315    } else {
1316        // For quantized combined QKV, we need to dequantize first, split, then use F32
1317        // This is less memory efficient but necessary for correctness
1318        let backend = crate::backend::default_backend();
1319        let numel = qkv_weight.numel();
1320        let mut dequant = Tensor::zeros(vec![numel], DType::F32);
1321        backend
1322            .dequantize(&qkv_weight, &mut dequant)
1323            .map_err(|e| ModelError::ConfigError(format!("Failed to dequantize QKV: {}", e)))?;
1324        let qkv_f32 = dequant.as_f32()?;
1325
1326        // Same contiguous split as the F32 path
1327        let q_start = 0;
1328        let k_start = q_size * in_features;
1329        let v_start = (q_size + k_size) * in_features;
1330
1331        let q_tensor = Tensor::from_f32(
1332            &qkv_f32[q_start..q_start + q_size * in_features],
1333            vec![in_features, q_size],
1334        )?;
1335        let k_tensor = Tensor::from_f32(
1336            &qkv_f32[k_start..k_start + k_size * in_features],
1337            vec![in_features, k_size],
1338        )?;
1339        let v_tensor = Tensor::from_f32(
1340            &qkv_f32[v_start..v_start + v_size * in_features],
1341            vec![in_features, v_size],
1342        )?;
1343
1344        let (q_bias, k_bias, v_bias) = if let Some(ref bias) = qkv_bias {
1345            let b = bias.as_f32()?;
1346            let qb = Tensor::from_f32(&b[..q_size], vec![q_size])?;
1347            let kb = Tensor::from_f32(&b[q_size..q_size + k_size], vec![k_size])?;
1348            let vb = Tensor::from_f32(&b[q_size + k_size..], vec![v_size])?;
1349            (Some(qb), Some(kb), Some(vb))
1350        } else {
1351            (None, None, None)
1352        };
1353
1354        let wq = Linear::new(q_tensor, q_bias)?;
1355        let wk = Linear::new(k_tensor, k_bias)?;
1356        let wv = Linear::new(v_tensor, v_bias)?;
1357
1358        let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
1359        let wo = Linear::new(
1360            source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
1361            wo_bias,
1362        )?;
1363
1364        Ok(Attention::with_kv_dims(
1365            wq, wk, wv, wo,
1366            num_heads, num_kv_heads, head_dim,
1367            kl, vl, rope_dims,
1368            use_neox_rope, false,
1369        ))
1370    }
1371}
1372
1373/// Load a DeltaNet (recurrent) layer from SSM tensors.
1374fn load_deltanet_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<DeltaNetLayer> {
1375    let prefix = format!("blk.{}", layer_idx);
1376    let cfg = source.config();
1377
1378    let d_inner = cfg.ssm_d_inner;
1379    let d_state = cfg.ssm_d_state;
1380    let num_v_heads = cfg.ssm_dt_rank;
1381    let num_k_heads = cfg.ssm_n_group;
1382    let head_v_dim = d_inner / num_v_heads;
1383    let head_k_dim = d_state;
1384    let conv_kernel = cfg.ssm_conv_kernel;
1385    let q_dim = num_k_heads * head_k_dim;
1386    let k_dim = num_k_heads * head_k_dim;
1387    let qkv_dim = q_dim + k_dim + d_inner;
1388
1389    let dn_config = DeltaNetConfig {
1390        d_inner,
1391        d_state,
1392        num_v_heads,
1393        num_k_heads,
1394        head_v_dim,
1395        head_k_dim,
1396        conv_kernel,
1397        qkv_dim,
1398    };
1399
1400    let attn_qkv = Linear::new(
1401        source.load_tensor(&format!("{}.attn_qkv.weight", prefix))?,
1402        None,
1403    )?;
1404
1405    let attn_gate = Linear::new(
1406        source.load_tensor(&format!("{}.attn_gate.weight", prefix))?,
1407        None,
1408    )?;
1409
1410    let ssm_ba = if let Some(ba_weight) =
1411        source.try_load_tensor(&format!("{}.ssm_ba.weight", prefix))
1412    {
1413        BetaAlphaProjection::Combined(Linear::new(ba_weight, None)?)
1414    } else {
1415        let beta_w = source.load_tensor(&format!("{}.ssm_beta.weight", prefix))?;
1416        let alpha_w = source.load_tensor(&format!("{}.ssm_alpha.weight", prefix))?;
1417        BetaAlphaProjection::Separate {
1418            beta: Linear::new(beta_w, None)?,
1419            alpha: Linear::new(alpha_w, None)?,
1420        }
1421    };
1422
1423    let ssm_conv1d_weight = source.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
1424    let ssm_a = source.load_tensor(&format!("{}.ssm_a", prefix))?;
1425    let ssm_dt_bias = source.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
1426
1427    let ssm_norm_weight = source.load_tensor(&format!("{}.ssm_norm.weight", prefix))?;
1428    let ssm_norm = RMSNorm::new(ssm_norm_weight, cfg.norm_eps)?;
1429
1430    let ssm_out = Linear::new(
1431        source.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
1432        None,
1433    )?;
1434
1435    tracing::info!("Layer {}: loaded DeltaNet (d_inner={}, d_state={}, v_heads={}, k_heads={}, conv={})",
1436        layer_idx, d_inner, d_state, num_v_heads, num_k_heads, conv_kernel);
1437
1438    Ok(DeltaNetLayer {
1439        config: dn_config,
1440        attn_qkv,
1441        attn_gate,
1442        ssm_ba,
1443        ssm_conv1d_weight,
1444        ssm_a,
1445        ssm_dt_bias,
1446        ssm_norm,
1447        ssm_out,
1448    })
1449}
1450
1451/// Load a pure Mamba/Mamba2 SSM layer from Mamba-specific tensor names.
1452///
1453/// Mamba v1 uses: ssm_in, ssm_conv1d, ssm_x, ssm_dt, ssm_a, ssm_d, ssm_out.
1454fn load_mamba_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<MambaLayer> {
1455    let prefix = format!("blk.{}", layer_idx);
1456    let cfg = source.config();
1457
1458    let d_inner = cfg.ssm_d_inner;
1459    let d_state = cfg.ssm_d_state;
1460    let dt_rank = cfg.ssm_dt_rank;
1461    let conv_kernel = cfg.ssm_conv_kernel.max(1);
1462
1463    let mamba_config = MambaConfig {
1464        d_inner,
1465        d_state,
1466        dt_rank,
1467        conv_kernel,
1468    };
1469
1470    let ssm_in = Linear::new(
1471        source.load_tensor(&format!("{}.ssm_in.weight", prefix))?,
1472        None,
1473    )?;
1474
1475    let ssm_conv1d_weight = source.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
1476    let ssm_conv1d_bias = source.try_load_tensor(&format!("{}.ssm_conv1d.bias", prefix));
1477
1478    let ssm_x = Linear::new(
1479        source.load_tensor(&format!("{}.ssm_x.weight", prefix))?,
1480        None,
1481    )?;
1482
1483    let ssm_dt = Linear::new(
1484        source.load_tensor(&format!("{}.ssm_dt.weight", prefix))?,
1485        None,
1486    )?;
1487
1488    let ssm_dt_bias = source.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
1489    let ssm_a = source.load_tensor(&format!("{}.ssm_a", prefix))?;
1490    let ssm_d = source.try_load_tensor(&format!("{}.ssm_d", prefix));
1491
1492    let ssm_norm = match source.try_load_tensor(&format!("{}.ssm_norm.weight", prefix)) {
1493        Some(w) => Some(RMSNorm::new(w, cfg.norm_eps)?),
1494        None => None,
1495    };
1496
1497    let ssm_out = Linear::new(
1498        source.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
1499        None,
1500    )?;
1501
1502    tracing::info!(
1503        "Layer {}: loaded Mamba SSM (d_inner={}, d_state={}, dt_rank={}, conv={})",
1504        layer_idx, d_inner, d_state, dt_rank, conv_kernel
1505    );
1506
1507    Ok(MambaLayer {
1508        ssm_in,
1509        ssm_conv1d_weight,
1510        ssm_conv1d_bias,
1511        ssm_x,
1512        ssm_dt,
1513        ssm_dt_bias,
1514        ssm_a,
1515        ssm_d,
1516        ssm_norm,
1517        ssm_out,
1518        config: mamba_config,
1519    })
1520}
1521
1522/// Load MoE layer tensors for a given layer index.
1523fn load_moe_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<FfnLayer> {
1524    let prefix = format!("blk.{}", layer_idx);
1525    let config = source.config();
1526    let num_experts = config.num_experts;
1527    let hidden_dim = config.hidden_size;
1528
1529    // Expert FFN dimension: use expert_intermediate_size if set,
1530    // otherwise fall back to intermediate_size / num_experts_per_token
1531    let expert_ffn_dim = if config.expert_intermediate_size > 0 {
1532        config.expert_intermediate_size
1533    } else {
1534        config.intermediate_size / config.num_experts_per_token
1535    };
1536
1537    // Router/gate weights: [hidden_dim, num_experts]
1538    let router_weight = source.load_tensor(&format!("{}.ffn_gate_inp.weight", prefix))?;
1539    let router = MoeRouter::from_weight(
1540        router_weight,
1541        config.num_experts_per_token,
1542        false, // Qwen3 MoE uses softmax, not log-softmax normalization
1543    );
1544
1545    // Load batched expert weights and split into individual experts
1546    // GGUF stores these as 3D tensors: [n_expert, ffn_dim, hidden_dim] or similar
1547    let gate_exps = source.load_tensor(&format!("{}.ffn_gate_exps.weight", prefix))?;
1548    let up_exps = source.load_tensor(&format!("{}.ffn_up_exps.weight", prefix))?;
1549    let down_exps = source.load_tensor(&format!("{}.ffn_down_exps.weight", prefix))?;
1550
1551    let mut experts = Vec::with_capacity(num_experts);
1552    for e in 0..num_experts {
1553        let mut gate_proj = extract_expert_tensor(&gate_exps, e)?;
1554        let mut up_proj = extract_expert_tensor(&up_exps, e)?;
1555        let mut down_proj = extract_expert_tensor(&down_exps, e)?;
1556
1557        gate_proj.set_name(format!("blk.{}.ffn_gate.{}.weight", layer_idx, e));
1558        up_proj.set_name(format!("blk.{}.ffn_up.{}.weight", layer_idx, e));
1559        down_proj.set_name(format!("blk.{}.ffn_down.{}.weight", layer_idx, e));
1560
1561        experts.push(MoeExpert {
1562            gate_proj,
1563            up_proj,
1564            down_proj,
1565            use_gelu: config.uses_gelu,
1566        });
1567    }
1568
1569    // Load shared experts if present (Qwen3Next)
1570    let mut shared_experts = Vec::new();
1571    if let (Some(mut gate_shexp), Some(mut up_shexp), Some(mut down_shexp)) = (
1572        source.try_load_tensor(&format!("{}.ffn_gate_shexp.weight", prefix)),
1573        source.try_load_tensor(&format!("{}.ffn_up_shexp.weight", prefix)),
1574        source.try_load_tensor(&format!("{}.ffn_down_shexp.weight", prefix)),
1575    ) {
1576        gate_shexp.set_name(format!("blk.{}.ffn_gate_shexp.0.weight", layer_idx));
1577        up_shexp.set_name(format!("blk.{}.ffn_up_shexp.0.weight", layer_idx));
1578        down_shexp.set_name(format!("blk.{}.ffn_down_shexp.0.weight", layer_idx));
1579        shared_experts.push(MoeExpert {
1580            gate_proj: gate_shexp,
1581            up_proj: up_shexp,
1582            down_proj: down_shexp,
1583            use_gelu: config.uses_gelu,
1584        });
1585    }
1586
1587    // Load shared expert gate weight if present (Qwen3Next).
1588    // This tensor may be BF16 — convert to F32 for inference.
1589    let shared_expert_gate =
1590        source.try_load_tensor(&format!("{}.ffn_gate_inp_shexp.weight", prefix))
1591            .map(|t| {
1592                if t.dtype() == DType::F32 {
1593                    t
1594                } else {
1595                    let raw = t.data();
1596                    let f32_vals: Vec<f32> = match t.dtype() {
1597                        DType::BF16 => {
1598                            raw.chunks_exact(2)
1599                                .map(|c| {
1600                                    let bits = u16::from_le_bytes([c[0], c[1]]);
1601                                    f32::from_bits((bits as u32) << 16)
1602                                })
1603                                .collect()
1604                        }
1605                        _ => {
1606                            tracing::warn!("Unsupported dtype for shared expert gate, zeroing");
1607                            vec![0.0f32; t.numel()]
1608                        }
1609                    };
1610                    let shape = t.shape().to_vec();
1611                    Tensor::from_f32(&f32_vals, shape).unwrap()
1612                }
1613            });
1614    if shared_expert_gate.is_some() {
1615        tracing::debug!("Layer {}: loaded shared expert gate", layer_idx);
1616    }
1617
1618    let num_shared = shared_experts.len();
1619    let moe_config = MoeConfig {
1620        num_experts,
1621        num_experts_per_token: config.num_experts_per_token,
1622        expert_hidden_dim: expert_ffn_dim,
1623        num_shared_experts: num_shared,
1624        aux_loss_coef: 0.0,
1625        normalize_router_logits: false,
1626    };
1627
1628    let mut moe_layer = MoeLayer::new(hidden_dim, moe_config);
1629    moe_layer.router = router;
1630    moe_layer.experts = experts;
1631    moe_layer.shared_experts = shared_experts;
1632    moe_layer.shared_expert_gate = shared_expert_gate;
1633
1634    Ok(FfnLayer::Moe(moe_layer))
1635}
1636
1637/// Extract a single expert's weight tensor from a batched 3D expert tensor.
1638///
1639/// Batched expert weights are stored as `[ne0, ne1, n_expert]` where the
1640/// expert dimension is outermost (slowest). Each expert's 2D weight has
1641/// shape `[ne0, ne1]` preserving column-major convention.
1642fn extract_expert_tensor(
1643    batched: &Tensor,
1644    expert_idx: usize,
1645) -> ModelResult<Tensor> {
1646    let shape = batched.shape();
1647    if shape.len() != 3 {
1648        return Err(ModelError::ConfigError(format!(
1649            "Expected 3D batched expert tensor, got shape {:?}",
1650            shape
1651        )));
1652    }
1653    let ne0 = shape[0];
1654    let ne1 = shape[1];
1655    let num_experts = shape[2];
1656    let expert_numel = ne0 * ne1;
1657
1658    if expert_idx >= num_experts {
1659        return Err(ModelError::ConfigError(format!(
1660            "Expert index {} out of bounds ({})",
1661            expert_idx, num_experts
1662        )));
1663    }
1664
1665    let per_expert_shape = vec![ne0, ne1];
1666
1667    if batched.dtype().is_quantized() {
1668        let block_size = batched.dtype().block_size();
1669        let block_bytes = batched.dtype().block_bytes();
1670
1671        if !expert_numel.is_multiple_of(block_size) {
1672            return Err(ModelError::ConfigError(format!(
1673                "Expert tensor elements ({}) not aligned to block size ({})",
1674                expert_numel, block_size
1675            )));
1676        }
1677
1678        let blocks_per_expert = expert_numel / block_size;
1679        let bytes_per_expert = blocks_per_expert * block_bytes;
1680        let byte_offset = expert_idx * bytes_per_expert;
1681
1682        let raw_data = batched.data();
1683        let expert_bytes = &raw_data[byte_offset..byte_offset + bytes_per_expert];
1684
1685        let mut tensor =
1686            Tensor::new(expert_bytes.to_vec(), per_expert_shape, batched.dtype())?;
1687        tensor.set_name(format!("expert.{}", expert_idx));
1688        Ok(tensor)
1689    } else {
1690        let f32_data = batched.as_f32()?;
1691        let offset = expert_idx * expert_numel;
1692        let expert_slice = &f32_data[offset..offset + expert_numel];
1693
1694        let mut tensor = Tensor::from_f32(expert_slice, per_expert_shape)?;
1695        tensor.set_name(format!("expert.{}", expert_idx));
1696        Ok(tensor)
1697    }
1698}
1699
1700/// Gemma's HuggingFace implementation uses `(1 + weight)` in RMS norm, but
1701/// the GGUF converter (`convert_hf_to_gguf.py`) already adds +1 to norm
1702/// weights during conversion. The GGUF file contains final-form weights,
1703/// so no adjustment is needed at load time. This function is kept as a no-op
1704/// identity for documentation.
1705fn apply_gemma_norm_weight_offset(weight: Tensor) -> ModelResult<Tensor> {
1706    Ok(weight)
1707}
1708
1709/// Get the DeltaNet config from any [`ModelSource`].
1710/// Returns None if the model has no SSM layers or is Mamba.
1711pub fn deltanet_config_from_source(source: &dyn ModelSource) -> Option<DeltaNetConfig> {
1712    let config = source.config();
1713    let arch = source.architecture();
1714    if !config.has_ssm()
1715        || matches!(arch, Architecture::Mamba | Architecture::Mamba2)
1716    {
1717        return None;
1718    }
1719    let d_inner = config.ssm_d_inner;
1720    let d_state = config.ssm_d_state;
1721    let num_v_heads = config.ssm_dt_rank;
1722    let num_k_heads = config.ssm_n_group.max(1);
1723    let head_v_dim = d_inner / num_v_heads.max(1);
1724    let head_k_dim = d_state;
1725    let conv_kernel = config.ssm_conv_kernel;
1726    let q_dim = num_k_heads * head_k_dim;
1727    let k_dim = num_k_heads * head_k_dim;
1728    let qkv_dim = q_dim + k_dim + d_inner;
1729
1730    Some(DeltaNetConfig {
1731        d_inner,
1732        d_state,
1733        num_v_heads,
1734        num_k_heads,
1735        head_v_dim,
1736        head_k_dim,
1737        conv_kernel,
1738        qkv_dim,
1739    })
1740}
1741
1742/// Convenience function to load a LLaMA-like model from a GGUF file
1743///
1744/// Supports all LLaMA-compatible architectures including Qwen3 MoE.
1745pub fn load_llama_model<P: AsRef<Path>>(path: P) -> ModelResult<LlamaModel> {
1746    let loader = ModelLoader::load(path)?;
1747
1748    if !loader.architecture().is_llama_like() {
1749        return Err(ModelError::UnsupportedArchitecture(
1750            loader.architecture().to_string(),
1751        ));
1752    }
1753
1754    loader.build_model()
1755}
1756
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::config::{AttentionLayerConfig, AttentionLayerType};
    use std::collections::HashMap;

    /// In-memory [`ModelSource`] backed by a name-to-tensor map, so the free
    /// assembly functions can be exercised without a GGUF file on disk.
    struct MockSource {
        tensors: HashMap<String, Tensor>,
        config: ModelConfig,
        arch: Architecture,
    }

    impl ModelSource for MockSource {
        fn config(&self) -> &ModelConfig {
            &self.config
        }

        fn config_mut(&mut self) -> &mut ModelConfig {
            &mut self.config
        }

        fn architecture(&self) -> Architecture {
            self.arch
        }

        fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
            match self.try_load_tensor(name) {
                Some(tensor) => Ok(tensor),
                None => Err(ModelError::MissingTensor(name.to_string())),
            }
        }

        fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
            self.tensors.get(name).cloned()
        }
    }

    /// Gemma 4 global (full-attention) layers set `attention_k_eq_v`: the value
    /// projection is tied to the key projection, so the GGUF has no
    /// `attn_v.weight` for them. The loader must alias V to K for those layers
    /// rather than fail with `MissingTensor`.
    #[test]
    fn test_load_full_attention_aliases_v_to_k_when_v_absent() {
        let hidden = 3840usize;
        let n_heads = 16usize;
        let n_kv_heads = 1usize; // the global layer has a single KV head
        let global_head_dim = 512usize;
        let kv_out = n_kv_heads * global_head_dim;
        let q_out = n_heads * global_head_dim; // no attention gate

        // GGUF convention: weight shape is [in_features, out_features].
        // Deliberately no "blk.1.attn_v.weight": layer 1 is the global layer.
        let tensors: HashMap<String, Tensor> = HashMap::from([
            (
                "blk.1.attn_k.weight".to_string(),
                Tensor::zeros(vec![hidden, kv_out], DType::F32),
            ),
            (
                "blk.1.attn_output.weight".to_string(),
                Tensor::zeros(vec![n_heads * global_head_dim, hidden], DType::F32),
            ),
        ]);

        let sliding_layer = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 256,
            num_kv_heads: 8,
            rope_freq_base: 10_000.0,
            rope_dims: 256,
            sliding_window: 1024,
        };
        let global_layer = AttentionLayerConfig {
            layer_type: AttentionLayerType::Global,
            head_dim: 512,
            num_kv_heads: 1,
            rope_freq_base: 1_000_000.0,
            rope_dims: 128,
            sliding_window: 0,
        };

        let mut config = ModelConfig::default();
        config.num_heads = n_heads;
        config.hidden_size = hidden;
        config.attn_logit_softcap = 0.0;
        config.attention_layer_configs = Some(vec![sliding_layer, global_layer]);

        let src = MockSource {
            tensors,
            config,
            arch: Architecture::Gemma4,
        };
        let wq = Tensor::zeros(vec![hidden, q_out], DType::F32);

        let attn = load_full_attention(&src, 1, wq)
            .expect("global layer with tied K/V should load without attn_v");

        assert_eq!(
            attn.wv.out_features, attn.wk.out_features,
            "V projection must alias K projection when attn_v is absent"
        );
        assert_eq!(attn.wv.out_features, kv_out);
    }

    #[test]
    fn test_architecture_detection() {
        for arch in [Architecture::Llama, Architecture::Mistral, Architecture::GPT2] {
            assert!(arch.is_llama_like());
        }
        for arch in [Architecture::Bert, Architecture::Mamba] {
            assert!(!arch.is_llama_like());
        }
    }
}