// candle_mi/config.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! Transformer configuration and `HuggingFace` `config.json` parsing.
//!
//! [`TransformerConfig`] captures the ~12 configuration axes that distinguish
//! modern decoder-only transformer architectures (`LLaMA`, `Qwen2`, Gemma 2,
//! `Phi-3`, `StarCoder2`, Mistral, etc.).  One forward pass implementation
//! covers all of them; adding a new model family requires only a new
//! `parse_*` function (~30 lines).
//!
//! # Usage
//!
//! ```
//! use candle_mi::TransformerConfig;
//!
//! let config_str = r#"{"model_type": "llama", "hidden_size": 2048,
//!     "num_hidden_layers": 16, "num_attention_heads": 32,
//!     "num_key_value_heads": 8, "intermediate_size": 8192,
//!     "vocab_size": 32000, "rms_norm_eps": 1e-5,
//!     "rope_theta": 500000.0, "max_position_embeddings": 131072}"#;
//! let json: serde_json::Value = serde_json::from_str(config_str).unwrap();
//! let config = TransformerConfig::from_hf_config(&json).unwrap();
//! assert_eq!(config.num_layers, 16);
//! ```

use std::fmt;
use std::io::Read as _;
use std::path::Path;

use serde_json::Value;

use crate::error::{MIError, Result};

// ---------------------------------------------------------------------------
// Supported model types
// ---------------------------------------------------------------------------

/// The set of `model_type` values that
/// [`TransformerConfig::from_hf_config`] knows how to parse.
///
/// Useful for cache discovery, UI filtering, or any other place that needs
/// to know which `HuggingFace` model families the generic transformer
/// backend handles.
pub const SUPPORTED_MODEL_TYPES: &[&str] =
    &["gemma", "gemma2", "llama", "mistral", "phi3", "qwen2", "starcoder2"];

// ---------------------------------------------------------------------------
// Configuration enums
// ---------------------------------------------------------------------------

/// Which layer-normalization formula a model uses.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// Standard RMS normalization: `x * weight / sqrt(mean(x^2) + eps)`.
    RmsNorm,
    /// Standard layer normalization (weight + bias).
    LayerNorm,
    /// Gemma-style RMS norm that adds `1.0` to the learned weight:
    /// `x * (weight + 1) / sqrt(mean(x^2) + eps)`.
    GemmaRmsNorm,
}

impl fmt::Display for NormType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Display names mirror the variant identifiers exactly.
        let name = match self {
            Self::RmsNorm => "RmsNorm",
            Self::LayerNorm => "LayerNorm",
            Self::GemmaRmsNorm => "GemmaRmsNorm",
        };
        f.write_str(name)
    }
}

/// Activation function applied inside the MLP.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// Sigmoid Linear Unit (used in `SwiGLU` gating).
    Silu,
    /// Gaussian Error Linear Unit — exact (erf) variant.
    Gelu,
    /// Gaussian Error Linear Unit — `PyTorch` tanh approximation.
    ///
    /// Used by Gemma 2, `StarCoder2`, and other models that specify
    /// `hidden_act: "gelu_pytorch_tanh"` in their `HuggingFace` config.
    GeluApprox,
}

impl fmt::Display for Activation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Silu => "SiLU",
            Self::Gelu => "GELU",
            Self::GeluApprox => "GELU (tanh approx)",
        })
    }
}

/// How the attention block's Q, K, V projections are laid out in the
/// checkpoint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QkvLayout {
    /// Three separate linear layers: `q_proj`, `k_proj`, `v_proj`.
    Separate,
    /// Single fused linear layer `qkv_proj`, split via `narrow()`.
    Fused,
}

impl fmt::Display for QkvLayout {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Separate => "Separate",
            Self::Fused => "Fused",
        })
    }
}

/// How the feed-forward network (MLP) is laid out in the checkpoint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlpLayout {
    /// Gated MLP with separate gate and up projections:
    /// `down(act(gate(x)) * up(x))`.
    GatedSeparate,
    /// Gated MLP with fused gate+up projection:
    /// `gate_up = fused(x)`, split, then `down(act(gate) * up)`.
    GatedFused,
    /// Plain (non-gated) MLP: `proj(act(fc(x)))`.
    Plain,
}

impl fmt::Display for MlpLayout {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::GatedSeparate => "GatedSeparate",
            Self::GatedFused => "GatedFused",
            Self::Plain => "Plain",
        };
        f.write_str(label)
    }
}

// ---------------------------------------------------------------------------
// TransformerConfig
// ---------------------------------------------------------------------------

/// Configuration for a generic decoder-only transformer.
///
/// Captures ~12 configuration axes that distinguish modern transformer
/// architectures.  Parsed from `HuggingFace` `config.json` via
/// [`from_hf_config`](Self::from_hf_config).
///
/// # Supported model families
///
/// | Family | Key config traits |
/// |--------|------------------|
/// | `LLaMA` 1/2/3 | Baseline: GQA, `SiLU`, `RmsNorm` |
/// | `Qwen` 2/2.5 | + QKV bias, conditional tied embeddings |
/// | Gemma / Gemma 2 | + `GemmaRmsNorm`, embedding scale, soft-capping, 4-norm |
/// | `Phi-3` / `Phi-4` | + Fused QKV, fused MLP |
/// | `StarCoder2` | + Plain MLP, GELU, bias everywhere |
/// | Mistral | + Sliding window attention |
///
/// # `config.json` field reference
///
/// ## Required fields (all families)
///
/// | Field | `config.json` key |
/// |-------|-------------------|
/// | — | `model_type` |
/// | `hidden_size` | `hidden_size` |
/// | `num_layers` | `num_hidden_layers` |
/// | `num_attention_heads` | `num_attention_heads` |
/// | `intermediate_size` | `intermediate_size` |
/// | `vocab_size` | `vocab_size` |
///
/// ## Optional fields (all families)
///
/// | Field | `config.json` key | Default |
/// |-------|-------------------|---------|
/// | `num_kv_heads` | `num_key_value_heads` | `num_attention_heads` |
/// | `head_dim` | `head_dim` | `hidden_size / num_attention_heads` |
/// | `norm_eps` | `rms_norm_eps` ¹ | 1e-5 ² |
/// | `rope_theta` | `rope_theta` | 10 000 ³ |
/// | `max_position_embeddings` | `max_position_embeddings` | 4 096 ⁴ |
/// | `tie_word_embeddings` | `tie_word_embeddings` | `false` ⁵ |
///
/// ¹ `StarCoder2` reads `norm_epsilon` instead.\
/// ² 1e-6 for Qwen2, Gemma, Gemma 2.\
/// ³ 1 000 000 for Qwen2.\
/// ⁴ 32 768 for Qwen2/Mistral; 16 384 for `StarCoder2`; 8 192 for
///   Gemma/Gemma 2; 4 096 for `LLaMA`/`Phi-3`.\
/// ⁵ `true` for Gemma, Gemma 2, `StarCoder2`.
///
/// ## Hardcoded architecture axes
///
/// The following fields are **set by the family-specific parser**, not
/// read from `config.json` (except where noted):
///
/// | Field | Description |
/// |-------|-------------|
/// | `norm_type` | [`RmsNorm`](NormType::RmsNorm) for most; [`GemmaRmsNorm`](NormType::GemmaRmsNorm) for Gemma/Gemma 2; read from `norm_type` key for `StarCoder2` (default [`RmsNorm`](NormType::RmsNorm), `"layer_norm"` → [`LayerNorm`](NormType::LayerNorm)) |
/// | `activation` | [`Silu`](Activation::Silu) for `LLaMA`/Qwen2/`Phi-3`/Mistral; [`GeluApprox`](Activation::GeluApprox) for Gemma/Gemma 2/`StarCoder2` |
/// | `qkv_layout` | [`Fused`](QkvLayout::Fused) for `Phi-3`; [`Separate`](QkvLayout::Separate) for all others |
/// | `mlp_layout` | [`GatedFused`](MlpLayout::GatedFused) for `Phi-3`; [`Plain`](MlpLayout::Plain) for `StarCoder2`; [`GatedSeparate`](MlpLayout::GatedSeparate) for all others |
/// | `embedding_scale` | `Some(sqrt(hidden_size))` for Gemma/Gemma 2; `None` for all others |
/// | `use_post_norms` | `true` for Gemma 2 (4 norms per layer); `false` for all others |
/// | `alternating_sliding_window` | `true` for Gemma 2; `false` for all others |
///
/// ## Per-family `config.json` extensions
///
/// **Qwen2** — reads `attention_bias` (default `true`) → `qkv_bias`.
///
/// **Gemma / Gemma 2** — hardcodes `embedding_scale` to `sqrt(hidden_size)`,
/// `tie_word_embeddings` defaults to `true`, and `norm_eps` defaults to 1e-6.
/// Gemma 2 additionally reads:
///
/// | `config.json` key | Field | Default |
/// |-------------------|-------|---------|
/// | `attn_logit_softcapping` | `attn_logit_softcapping` | `None` |
/// | `final_logit_softcapping` | `final_logit_softcapping` | `None` |
/// | `query_pre_attn_scalar` | `query_pre_attn_scalar` | `Some(256.0)` |
/// | `sliding_window` | `sliding_window` | `None` |
///
/// **`Phi-3`** — no extra `config.json` keys; fused QKV and fused gated MLP
/// are hardcoded.
///
/// **`StarCoder2`** — reads `use_bias` (default `true`) → `qkv_bias`,
/// `o_proj_bias`, and `mlp_bias`.  Reads `norm_type` (default `RmsNorm`,
/// `"layer_norm"` → `LayerNorm`).  Uses `norm_epsilon` key (not
/// `rms_norm_eps`).  Hardcodes [`Plain`](MlpLayout::Plain) MLP and
/// [`GeluApprox`](Activation::GeluApprox) activation.
///
/// **Mistral** — reads `sliding_window` (default `None`).  Otherwise
/// identical to `LLaMA`; `max_position_embeddings` defaults to 32 768.
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::struct_excessive_bools)] // Config structs legitimately have many boolean axes
pub struct TransformerConfig {
    // --- Dimensions ----------------------------------------------------------
    /// Hidden dimension (`d_model`).
    pub hidden_size: usize,
    /// Number of transformer layers (decoder blocks).
    pub num_layers: usize,
    /// Number of query attention heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads (GQA when < `num_attention_heads`).
    pub num_kv_heads: usize,
    /// Dimension per head (usually `hidden_size / num_attention_heads`).
    pub head_dim: usize,
    /// MLP intermediate dimension.
    pub intermediate_size: usize,
    /// Vocabulary size.
    pub vocab_size: usize,

    // --- Architecture axes ---------------------------------------------------
    /// Normalization variant.
    pub norm_type: NormType,
    /// Epsilon for normalization layers.
    pub norm_eps: f64,
    /// MLP activation function.
    pub activation: Activation,
    /// QKV projection layout (separate or fused).
    pub qkv_layout: QkvLayout,
    /// MLP layout (gated separate, gated fused, or plain).
    pub mlp_layout: MlpLayout,
    /// Whether Q, K, V projections have bias terms.
    pub qkv_bias: bool,
    /// Whether the output projection (`o_proj`) has a bias term.
    pub o_proj_bias: bool,
    /// Whether MLP projections have bias terms.
    pub mlp_bias: bool,
    /// Embedding scale factor (`Some(sqrt(hidden_size))` for Gemma models).
    pub embedding_scale: Option<f64>,
    /// Whether the LM head shares weights with the token embedding.
    pub tie_word_embeddings: bool,

    // --- Positional encoding -------------------------------------------------
    /// Base frequency for rotary position embeddings.
    pub rope_theta: f64,
    /// Maximum sequence length for position embeddings.
    pub max_position_embeddings: usize,

    // --- Gemma 2 extensions --------------------------------------------------
    /// Attention logit soft-capping: `tanh(scores / cap) * cap` before softmax.
    /// `Some(50.0)` for Gemma 2; `None` for most models.
    pub attn_logit_softcapping: Option<f64>,
    /// Final logit soft-capping: `tanh(logits / cap) * cap` after LM head.
    /// `Some(30.0)` for Gemma 2; `None` for most models.
    pub final_logit_softcapping: Option<f64>,
    /// Custom attention scaling factor.  When set, scale = `1/sqrt(scalar)`
    /// instead of the default `1/sqrt(head_dim)`.
    /// `Some(256.0)` for Gemma 2; `None` for most models.
    pub query_pre_attn_scalar: Option<f64>,
    /// Whether each layer has post-attention and post-feedforward norms
    /// (4 norms per layer instead of 2).  `true` for Gemma 2.
    pub use_post_norms: bool,

    // --- Sliding window attention --------------------------------------------
    /// Sliding window size.  `None` for global attention.
    pub sliding_window: Option<usize>,
    /// Whether sliding window alternates with global attention per layer.
    /// When `true`, even layers (0, 2, 4, ...) use sliding window and
    /// odd layers use global causal.  `true` for Gemma 2.
    pub alternating_sliding_window: bool,
}

// ---------------------------------------------------------------------------
// Config parsing — entry point
// ---------------------------------------------------------------------------

316impl TransformerConfig {
317    /// Parse a [`TransformerConfig`] from a `HuggingFace` `config.json` value.
318    ///
319    /// Dispatches on the `model_type` field to a family-specific parser.
320    /// See the [`TransformerConfig`] struct-level documentation for the
321    /// full field reference (required/optional keys, defaults, and
322    /// per-family extensions).
323    ///
324    /// # Errors
325    ///
326    /// Returns [`MIError::Config`] if `model_type` is missing, unsupported,
327    /// or if required fields are absent.
328    pub fn from_hf_config(config: &Value) -> Result<Self> {
329        let model_type = config
330            .get("model_type")
331            .and_then(Value::as_str)
332            .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;
333
334        // Keep in sync with SUPPORTED_MODEL_TYPES.
335        match model_type {
336            "llama" => Self::parse_llama(config),
337            "qwen2" => Self::parse_qwen2(config),
338            "gemma" => Self::parse_gemma(config),
339            "gemma2" => Self::parse_gemma2(config),
340            "phi3" => Self::parse_phi3(config),
341            "starcoder2" => Self::parse_starcoder2(config),
342            "mistral" => Self::parse_mistral(config),
343            other => Err(MIError::Config(format!(
344                "unsupported model_type: '{other}'"
345            ))),
346        }
347    }
348}
349
// ---------------------------------------------------------------------------
// Per-family config parsers
// ---------------------------------------------------------------------------

impl TransformerConfig {
    /// Parse a `LLaMA`-family config (`LLaMA` 1/2/3, `Code-LLaMA`).
    ///
    /// Simplest baseline: no bias, no embedding scale, no sliding window,
    /// separate LM head (unless `tie_word_embeddings` is set).
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_llama(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            // GQA: absent num_key_value_heads means full multi-head attention.
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Parse a Qwen2/Qwen2.5 config.
    ///
    /// Adds QKV bias and conditional tied embeddings on top of the
    /// `LLaMA` baseline.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_qwen2(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            // Qwen2 checkpoints ship Q/K/V bias; the config key can disable it.
            qkv_bias: get_bool_or(config, "attention_bias", true),
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 1_000_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 32_768),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Parse a Gemma config (Gemma 1, `CodeGemma`).
    ///
    /// Adds `GemmaRmsNorm` (weight + 1), sqrt embedding scale, and GELU.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_gemma(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::GemmaRmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
            // NOTE(review): activation is hardcoded here; `parse_activation_str`
            // (defined below) is currently unused by these parsers — confirm
            // whether Gemma configs carrying `hidden_activation` should be honored.
            activation: Activation::GeluApprox,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            // CAST: usize → f64, hidden_size fits in f64 mantissa (d_model <= 2^52)
            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
            // PROMOTE: embedding scale is sqrt(hidden_size); precision loss negligible for d_model <= 2^52
            embedding_scale: Some((hidden_size as f64).sqrt()),
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(
                config,
                "max_position_embeddings",
                8192,
            ),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Parse a Gemma 2 config.
    ///
    /// Adds attention/final logit soft-capping, 4-norm layers,
    /// `query_pre_attn_scalar`, and alternating sliding window attention.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_gemma2(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::GemmaRmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
            activation: Activation::GeluApprox,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            // CAST: usize → f64, hidden_size fits in f64 mantissa (d_model <= 2^52)
            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
            // PROMOTE: embedding scale is sqrt(hidden_size); precision loss negligible for d_model <= 2^52
            embedding_scale: Some((hidden_size as f64).sqrt()),
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(
                config,
                "max_position_embeddings",
                8192,
            ),

            attn_logit_softcapping: get_optional_f64(config, "attn_logit_softcapping"),
            final_logit_softcapping: get_optional_f64(config, "final_logit_softcapping"),
            // `.or(Some(256.0))` means this field is always Some for Gemma 2:
            // the config value if present, otherwise the documented 256.0 default.
            query_pre_attn_scalar: get_optional_f64(config, "query_pre_attn_scalar")
                .or(Some(256.0)),
            use_post_norms: true,
            sliding_window: get_optional_usize(config, "sliding_window"),
            alternating_sliding_window: true,
        })
    }

    /// Parse a Phi-3 config.
    ///
    /// Adds fused QKV projection and fused gate+up MLP projection.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_phi3(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Fused,
            mlp_layout: MlpLayout::GatedFused,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Parse a `StarCoder2` config.
    ///
    /// Adds plain (non-gated) MLP, GELU activation, and bias on all
    /// projections.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_starcoder2(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;
        // One flag controls bias on QKV, o_proj, and MLP projections alike.
        let use_bias = get_bool_or(config, "use_bias", true);

        // StarCoder2 specifies norm_type in config (usually "layer_norm").
        let norm_type = match config.get("norm_type").and_then(Value::as_str) {
            Some("layer_norm") => NormType::LayerNorm,
            _ => NormType::RmsNorm,
        };

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type,
            // StarCoder2 uses the `norm_epsilon` key, not `rms_norm_eps`.
            norm_eps: get_f64_or(config, "norm_epsilon", 1e-5),
            activation: Activation::GeluApprox,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::Plain,
            qkv_bias: use_bias,
            o_proj_bias: use_bias,
            mlp_bias: use_bias,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 16_384),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: get_optional_usize(config, "sliding_window"),
            alternating_sliding_window: false,
        })
    }

    /// Parse a Mistral config.
    ///
    /// LLaMA-like with sliding window attention on all layers.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    fn parse_mistral(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 32_768),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            // Unlike Gemma 2, the window (when present) applies to every layer.
            sliding_window: get_optional_usize(config, "sliding_window"),
            alternating_sliding_window: false,
        })
    }
}

// ---------------------------------------------------------------------------
// JSON extraction helpers
// ---------------------------------------------------------------------------

687/// Extract a required `usize` field from a JSON object.
688pub(crate) fn get_usize(config: &Value, key: &str) -> Result<usize> {
689    let val = config
690        .get(key)
691        .and_then(Value::as_u64)
692        .ok_or_else(|| MIError::Config(format!("missing or invalid field '{key}'")))?;
693    usize::try_from(val)
694        .map_err(|_| MIError::Config(format!("field '{key}' value {val} overflows usize")))
695}
696
697/// Extract an optional `usize` field, returning a default if absent.
698pub(crate) fn get_usize_or(config: &Value, key: &str, default: usize) -> usize {
699    config
700        .get(key)
701        .and_then(Value::as_u64)
702        .and_then(|v| usize::try_from(v).ok())
703        .unwrap_or(default)
704}
705
706/// Extract an optional `usize` field, returning `None` if absent.
707pub(crate) fn get_optional_usize(config: &Value, key: &str) -> Option<usize> {
708    config
709        .get(key)
710        .and_then(Value::as_u64)
711        .and_then(|v| usize::try_from(v).ok())
712}
713
714/// Extract an `f64` field, returning a default if absent.
715pub(crate) fn get_f64_or(config: &Value, key: &str, default: f64) -> f64 {
716    config.get(key).and_then(Value::as_f64).unwrap_or(default)
717}
718
719/// Extract an optional `f64` field, returning `None` if absent.
720pub(crate) fn get_optional_f64(config: &Value, key: &str) -> Option<f64> {
721    config.get(key).and_then(Value::as_f64)
722}
723
724/// Extract a `bool` field, returning a default if absent.
725pub(crate) fn get_bool_or(config: &Value, key: &str, default: bool) -> bool {
726    config.get(key).and_then(Value::as_bool).unwrap_or(default)
727}
728
729/// Extract `head_dim`, falling back to `hidden_size / num_attention_heads`.
730pub(crate) fn get_head_dim(
731    config: &Value,
732    hidden_size: usize,
733    num_attention_heads: usize,
734) -> Result<usize> {
735    // Explicit head_dim in config takes precedence.
736    let explicit = config.get("head_dim").and_then(Value::as_u64).map(|hd| {
737        usize::try_from(hd).map_err(|_| MIError::Config("head_dim overflows usize".into()))
738    });
739
740    match explicit {
741        Some(result) => result,
742        None if num_attention_heads == 0 => Err(MIError::Config(
743            "num_attention_heads is 0, cannot compute head_dim".into(),
744        )),
745        None => Ok(hidden_size / num_attention_heads),
746    }
747}
748
749// ---------------------------------------------------------------------------
750// Activation string parsing
751// ---------------------------------------------------------------------------
752
753/// Infer [`Activation`] from `hidden_activation` or `hidden_act` config fields.
754///
755/// Prefers `hidden_activation` (used by Gemma 2) over `hidden_act`.
756/// Defaults to [`Activation::Silu`] when neither field is present.
757fn parse_activation_str(config: &Value) -> Activation {
758    let act_str = config
759        .get("hidden_activation")
760        .or_else(|| config.get("hidden_act"))
761        .and_then(Value::as_str);
762    match act_str {
763        Some("gelu_pytorch_tanh") => Activation::GeluApprox,
764        Some("gelu") => Activation::Gelu,
765        _ => Activation::Silu,
766    }
767}
768
769// ---------------------------------------------------------------------------
770// Tensor name utilities
771// ---------------------------------------------------------------------------
772
773/// Extract tensor names from a single `.safetensors` file header.
774///
775/// Reads only the JSON header (first 8 bytes = length, then header bytes);
776/// no weight data is loaded.
777///
778/// # Errors
779///
780/// Returns [`MIError::Io`] on read failure, [`MIError::Config`] if the
781/// header is malformed.
782pub fn tensor_names_from_safetensors(path: &Path) -> Result<Vec<String>> {
783    let mut file = std::fs::File::open(path)?;
784    let mut len_buf = [0u8; 8];
785    file.read_exact(&mut len_buf)?;
786    let header_len = u64::from_le_bytes(len_buf);
787    let header_len = usize::try_from(header_len)
788        .map_err(|_| MIError::Config("safetensors header length overflows usize".into()))?;
789    let mut header_buf = vec![0u8; header_len];
790    file.read_exact(&mut header_buf)?;
791    let header: Value = serde_json::from_slice(&header_buf)
792        .map_err(|e| MIError::Config(format!("failed to parse safetensors header: {e}")))?;
793    let obj = header
794        .as_object()
795        .ok_or_else(|| MIError::Config("safetensors header is not a JSON object".into()))?;
796    Ok(obj
797        .keys()
798        .filter(|k| *k != "__metadata__")
799        .cloned()
800        .collect())
801}
802
803/// Extract tensor names from a `model.safetensors.index.json` index file.
804///
805/// Reads the `weight_map` keys from the sharded model index.
806///
807/// # Errors
808///
809/// Returns [`MIError::Io`] on read failure, [`MIError::Config`] if the
810/// index is malformed or missing `weight_map`.
811pub fn tensor_names_from_index(path: &Path) -> Result<Vec<String>> {
812    let content = std::fs::read_to_string(path)?;
813    let index: Value = serde_json::from_str(&content)
814        .map_err(|e| MIError::Config(format!("failed to parse safetensors index: {e}")))?;
815    let weight_map = index
816        .get("weight_map")
817        .and_then(Value::as_object)
818        .ok_or_else(|| MIError::Config("missing 'weight_map' in safetensors index".into()))?;
819    Ok(weight_map.keys().cloned().collect())
820}
821
822// ---------------------------------------------------------------------------
823// Auto-config: generic parser for unknown model families
824// ---------------------------------------------------------------------------
825
impl TransformerConfig {
    /// Parse a [`TransformerConfig`] from a `HuggingFace` `config.json` value
    /// and safetensors tensor names.
    ///
    /// Two-tier dispatch:
    /// - **Known families** (listed in [`SUPPORTED_MODEL_TYPES`]): delegates to
    ///   the existing manually-validated parser via [`from_hf_config`](Self::from_hf_config).
    /// - **Unknown families**: auto-detects architecture axes from `config.json`
    ///   scalars and safetensors tensor names (QKV/MLP layout, bias flags, norm
    ///   type, post-norms), with `model_type`-based fixups for Gemma-family
    ///   traits.
    ///
    /// `tensor_names` should contain all tensor names from the model's
    /// safetensors file(s).  Use [`tensor_names_from_safetensors`] or
    /// [`tensor_names_from_index`] to obtain them without loading weights.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if `model_type` is missing or if required
    /// dimension fields are absent.
    pub fn from_hf_config_auto(config: &Value, tensor_names: &[String]) -> Result<Self> {
        let model_type = config
            .get("model_type")
            .and_then(Value::as_str)
            .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;

        // Known families: use existing manually-validated parsers
        if SUPPORTED_MODEL_TYPES.contains(&model_type) {
            return Self::from_hf_config(config);
        }

        // Unknown families: auto-detect from config.json + tensor names
        Self::parse_auto(config, tensor_names, model_type)
    }

    /// Auto-detect a [`TransformerConfig`] from `config.json` scalars and
    /// safetensors tensor names.
    ///
    /// Uses a four-tier inference strategy:
    /// 1. Required scalars from `config.json`
    /// 2. Optional scalars from `config.json` with sensible defaults
    /// 3. Architecture axes inferred from layer-0 tensor names
    /// 4. `model_type`-based fixups (Gemma `RmsNorm`, embedding scale)
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Config`] if required dimension fields are missing.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    fn parse_auto(config: &Value, tensor_names: &[String], model_type: &str) -> Result<Self> {
        // Helper: check if a tensor matching `layers.0.<suffix>` exists
        // (prefix-agnostic: matches "model.layers.0.…" and bare "layers.0.…")
        let has_layer0 = |suffix: &str| {
            tensor_names
                .iter()
                .any(|n| n.contains("layers.0.") && n.ends_with(suffix))
        };

        // --- Tier 1: Required scalars ---
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        // --- Tier 2: Optional scalars ---
        // Most families use `rms_norm_eps`; some configs name the field
        // `norm_epsilon` instead, so try both before defaulting.
        let norm_eps = config
            .get("rms_norm_eps")
            .and_then(Value::as_f64)
            .or_else(|| config.get("norm_epsilon").and_then(Value::as_f64))
            .unwrap_or(1e-5);

        let activation = parse_activation_str(config);

        // Sliding window: respect `use_sliding_window: false` (Qwen2)
        let sliding_window =
            if config.get("use_sliding_window").and_then(Value::as_bool) == Some(false) {
                None
            } else {
                get_optional_usize(config, "sliding_window")
            };

        // tie_word_embeddings: config.json field, fallback to tensor name check
        // (absence of a separate `lm_head.weight` tensor implies tied embeddings)
        let tie_word_embeddings = config
            .get("tie_word_embeddings")
            .and_then(Value::as_bool)
            .unwrap_or_else(|| !tensor_names.iter().any(|n| n == "lm_head.weight"));

        // Gemma 2 extensions (Tier 2 — read from config.json if present)
        let attn_logit_softcapping = get_optional_f64(config, "attn_logit_softcapping");
        let final_logit_softcapping = get_optional_f64(config, "final_logit_softcapping");
        let query_pre_attn_scalar = get_optional_f64(config, "query_pre_attn_scalar");

        // --- Tier 3: Tensor name inference ---

        // QKV layout
        let qkv_layout = if has_layer0("self_attn.qkv_proj.weight") {
            QkvLayout::Fused
        } else {
            QkvLayout::Separate
        };

        // MLP layout
        let mlp_layout = if has_layer0("mlp.gate_up_proj.weight") {
            MlpLayout::GatedFused
        } else if has_layer0("mlp.gate_proj.weight") {
            MlpLayout::GatedSeparate
        } else if has_layer0("mlp.c_fc.weight") {
            MlpLayout::Plain
        } else {
            MlpLayout::GatedSeparate // safest default for decoder-only transformers
        };

        // Bias flags
        let qkv_bias = has_layer0("self_attn.q_proj.bias") || has_layer0("self_attn.qkv_proj.bias");
        let o_proj_bias = has_layer0("self_attn.o_proj.bias");
        let mlp_bias = has_layer0("mlp.down_proj.bias")
            || has_layer0("mlp.c_fc.bias")
            || has_layer0("mlp.gate_proj.bias")
            || has_layer0("mlp.gate_up_proj.bias");

        // Norm type: LayerNorm if norm layers have bias tensors
        // (RmsNorm has no bias term, so a bias tensor indicates LayerNorm)
        let has_norm_bias = has_layer0("input_layernorm.bias");
        let base_norm_type = if has_norm_bias {
            NormType::LayerNorm
        } else {
            NormType::RmsNorm
        };

        // Post-norms (4-norm layers, Gemma 2 style)
        let use_post_norms = has_layer0("post_feedforward_layernorm.weight")
            || has_layer0("pre_feedforward_layernorm.weight");

        // --- Tier 4: model_type fixups ---
        let is_gemma = model_type.contains("gemma");

        let norm_type = if is_gemma {
            NormType::GemmaRmsNorm
        } else {
            base_norm_type
        };

        // CAST: usize → f64, hidden_size fits in f64 mantissa (d_model <= 2^52)
        // PROMOTE: embedding scale is sqrt(hidden_size); precision loss negligible for d_model <= 2^52
        let embedding_scale = if is_gemma {
            Some((hidden_size as f64).sqrt())
        } else {
            None
        };

        // Gemma + post-norms is treated as Gemma 2 style, which alternates
        // sliding-window and global attention layers
        let alternating_sliding_window = is_gemma && use_post_norms;

        // Gemma 2-like models default query_pre_attn_scalar to 256
        // (shadowing: the explicit config value, if any, still wins via `.or`)
        let query_pre_attn_scalar = if is_gemma && use_post_norms {
            query_pre_attn_scalar.or(Some(256.0))
        } else {
            query_pre_attn_scalar
        };

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type,
            norm_eps,
            activation,
            qkv_layout,
            mlp_layout,
            qkv_bias,
            o_proj_bias,
            mlp_bias,
            embedding_scale,
            tie_word_embeddings,

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),

            attn_logit_softcapping,
            final_logit_softcapping,
            query_pre_attn_scalar,
            use_post_norms,
            sliding_window,
            alternating_sliding_window,
        })
    }
}
1012
1013// ---------------------------------------------------------------------------
1014// Auto-config compatibility check
1015// ---------------------------------------------------------------------------
1016
/// Result of a compatibility check for auto-config loading.
///
/// Returned by [`TransformerConfig::check_auto_compatibility`].
#[derive(Debug, Clone)]
pub struct CompatibilityReport {
    /// Whether the model is loadable by `GenericTransformer`.
    pub compatible: bool,
    /// Human-readable issues found (empty if compatible).
    /// Each entry is a standalone diagnostic; see
    /// [`CompatibilityReport::into_result`] for a joined summary.
    pub issues: Vec<String>,
}
1027
1028impl CompatibilityReport {
1029    /// Returns `Ok(())` if compatible, or [`MIError::Config`] with a
1030    /// diagnostic summary of all issues.
1031    ///
1032    /// # Errors
1033    ///
1034    /// Returns [`MIError::Config`] listing all detected incompatibilities.
1035    pub fn into_result(self) -> Result<()> {
1036        if self.compatible {
1037            Ok(())
1038        } else {
1039            Err(MIError::Config(format!(
1040                "model is not compatible with GenericTransformer:\n  - {}",
1041                self.issues.join("\n  - ")
1042            )))
1043        }
1044    }
1045}
1046
1047impl TransformerConfig {
1048    /// Check whether `config.json` contains the required fields for auto-config.
1049    ///
1050    /// This is a lightweight check that does not require tensor names or
1051    /// downloading weights.  It validates that the five required scalar
1052    /// fields (`hidden_size`, `num_hidden_layers`, `num_attention_heads`,
1053    /// `intermediate_size`, `vocab_size`) are present.
1054    ///
1055    /// A passing check does **not** guarantee full compatibility — use
1056    /// [`check_auto_compatibility`](Self::check_auto_compatibility) with
1057    /// tensor names for a definitive answer.
1058    #[must_use]
1059    pub fn check_config_fields(config: &Value) -> CompatibilityReport {
1060        let required = [
1061            "hidden_size",
1062            "num_hidden_layers",
1063            "num_attention_heads",
1064            "intermediate_size",
1065            "vocab_size",
1066        ];
1067        let mut issues = Vec::new();
1068        for key in &required {
1069            if config.get(*key).and_then(Value::as_u64).is_none() {
1070                issues.push(format!("missing or invalid required field '{key}'"));
1071            }
1072        }
1073        CompatibilityReport {
1074            compatible: issues.is_empty(),
1075            issues,
1076        }
1077    }
1078
1079    /// Check whether a model is fully compatible with `GenericTransformer`
1080    /// auto-config loading.
1081    ///
1082    /// Validates both `config.json` fields and safetensors tensor names
1083    /// against the patterns `GenericTransformer::load()` expects.  Call
1084    /// this after downloading but before loading to get a clear diagnostic
1085    /// instead of a cryptic "tensor not found" error.
1086    ///
1087    /// Checks performed:
1088    /// - Required `config.json` scalars are present
1089    /// - Embedding tensor (`model.embed_tokens.weight`) exists
1090    /// - Layer-0 normalization tensors exist (`input_layernorm.weight`,
1091    ///   `post_attention_layernorm.weight`)
1092    /// - Final norm tensor (`model.norm.weight`) exists
1093    /// - At least one recognized attention projection pattern
1094    /// - At least one recognized MLP projection pattern
1095    /// - `lm_head.weight` exists when `tie_word_embeddings` is false
1096    #[must_use]
1097    pub fn check_auto_compatibility(
1098        config: &Value,
1099        tensor_names: &[String],
1100    ) -> CompatibilityReport {
1101        let mut issues = Vec::new();
1102
1103        // --- Config field checks ---
1104        let field_report = Self::check_config_fields(config);
1105        issues.extend(field_report.issues);
1106
1107        // --- Tensor name checks (with "did you mean?" hints) ---
1108        let has_tensor_issues = check_tensor_names(config, tensor_names, &mut issues);
1109
1110        // --- Summary of actual naming convention (when tensor checks fail) ---
1111        if has_tensor_issues && !tensor_names.is_empty() {
1112            if let Some(hint) = detect_naming_convention(tensor_names) {
1113                issues.push(hint);
1114            }
1115        }
1116
1117        CompatibilityReport {
1118            compatible: issues.is_empty(),
1119            issues,
1120        }
1121    }
1122}
1123
/// Check safetensors tensor names against the patterns `GenericTransformer`
/// expects, appending actionable diagnostics (with "did you mean?" hints)
/// to `issues`.
///
/// Returns `true` if any tensor-name issue was found.  Note that the LM-head
/// tie check at the bottom pushes an issue without setting the return flag,
/// so it alone never triggers the naming-convention hint in the caller.
#[allow(clippy::too_many_lines)]
fn check_tensor_names(config: &Value, tensor_names: &[String], issues: &mut Vec<String>) -> bool {
    // Helper: check if a tensor name exists (exact match)
    let has = |name: &str| tensor_names.iter().any(|n| n == name);
    // Helper: check for `<anything>layers.0.<suffix>` — prefix-agnostic so it
    // matches both "model.layers.0.…" and bare "layers.0.…" names
    let has_layer0 = |suffix: &str| {
        tensor_names
            .iter()
            .any(|n| n.contains("layers.0.") && n.ends_with(suffix))
    };

    // Helper: find tensors matching a keyword (for "did you mean?" hints)
    let find_matching = |keyword: &str, limit: usize| -> Vec<&str> {
        tensor_names
            .iter()
            .filter(|n| n.to_lowercase().contains(keyword))
            .take(limit)
            .map(String::as_str)
            .collect::<Vec<_>>()
    };

    let mut has_issues = false;

    // --- Embedding ---
    if !has("model.embed_tokens.weight") {
        has_issues = true;
        let found: Vec<&str> = tensor_names
            .iter()
            .filter(|n| n.contains("embed") || n.contains("wte") || n.contains("word_embeddings"))
            .take(3)
            .map(String::as_str)
            .collect();
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found embedding-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing embedding tensor 'model.embed_tokens.weight'{hint}"
        ));
    }

    // --- Layer-0 normalization ---
    if !has_layer0("input_layernorm.weight") {
        has_issues = true;
        let found = find_matching("norm", 4);
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found norm-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing normalization tensor \
             'model.layers.0.input_layernorm.weight'{hint}"
        ));
    }
    // `pre_feedforward_layernorm` is accepted as the Gemma 2-style
    // equivalent of `post_attention_layernorm`.
    if !has_layer0("post_attention_layernorm.weight")
        && !has_layer0("pre_feedforward_layernorm.weight")
    {
        has_issues = true;
        issues.push(
            "missing normalization tensor \
             'model.layers.0.post_attention_layernorm.weight'"
                .into(),
        );
    }

    // --- Final norm ---
    if !has("model.norm.weight") {
        has_issues = true;
        let found: Vec<&str> = tensor_names
            .iter()
            .filter(|n| {
                (n.contains("ln_f") || n.contains("final_layer_norm") || n.contains("ln_out"))
                    && n.ends_with(".weight")
            })
            .take(2)
            .map(String::as_str)
            .collect();
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found final-norm-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing final norm tensor 'model.norm.weight'{hint}"
        ));
    }

    // --- Attention projections ---
    let has_separate_attn = has_layer0("self_attn.q_proj.weight");
    let has_fused_attn = has_layer0("self_attn.qkv_proj.weight");
    if !has_separate_attn && !has_fused_attn {
        has_issues = true;
        let found = find_matching("attn", 4);
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found attention-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing attention projections: expected \
             'self_attn.q_proj.weight' or 'self_attn.qkv_proj.weight'{hint}"
        ));
    }

    // --- MLP projections ---
    let has_gated_separate = has_layer0("mlp.gate_proj.weight");
    let has_gated_fused = has_layer0("mlp.gate_up_proj.weight");
    let has_plain = has_layer0("mlp.c_fc.weight");
    // Also accept down_proj as evidence of a recognized MLP
    let has_down = has_layer0("mlp.down_proj.weight");
    if !has_gated_separate && !has_gated_fused && !has_plain && !has_down {
        has_issues = true;
        let found: Vec<&str> = tensor_names
            .iter()
            .filter(|n| n.contains("mlp") || n.contains("ffn") || n.contains("fc"))
            .take(4)
            .map(String::as_str)
            .collect();
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found MLP-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing MLP projections: expected 'mlp.gate_proj.weight', \
             'mlp.gate_up_proj.weight', or 'mlp.c_fc.weight'{hint}"
        ));
    }

    // --- LM head ---
    // Mirrors parse_auto(): an absent `tie_word_embeddings` field is
    // inferred from whether a separate `lm_head.weight` tensor exists.
    let tie = config
        .get("tie_word_embeddings")
        .and_then(Value::as_bool)
        .unwrap_or_else(|| !tensor_names.iter().any(|n| n == "lm_head.weight"));
    if !tie && !has("lm_head.weight") {
        issues.push("tie_word_embeddings is false but 'lm_head.weight' tensor is missing".into());
    }

    has_issues
}
1270
/// Detect known non-standard weight naming conventions and produce a
/// human-readable hint explaining why the model is incompatible.
///
/// Returns `None` if the naming convention is unrecognized, or if the
/// names already use the expected `model.layers.{i}` prefix.
fn detect_naming_convention(tensor_names: &[String]) -> Option<String> {
    // Known non-standard prefix patterns
    // NOTE(review): the family attributions below are display-only labels
    // used in the diagnostic message; some are approximate (e.g. Falcon
    // checkpoints commonly use 'transformer.h.{i}' rather than
    // 'transformer.blocks.{i}') — verify against real checkpoints before
    // relying on them.
    let patterns: &[(&str, &str)] = &[
        (
            "transformer.h.",
            "GPT-2 / GPT-J / GPT-NeoX (uses 'transformer.h.{i}' prefix)",
        ),
        (
            "transformer.blocks.",
            "Falcon / MPT (uses 'transformer.blocks.{i}' prefix)",
        ),
        (
            "gpt_neox.layers.",
            "GPT-NeoX / Pythia (uses 'gpt_neox.layers.{i}' prefix)",
        ),
        (
            "transformer.layer.",
            "BLOOM (uses 'transformer.layer.{i}' prefix)",
        ),
    ];

    // First matching prefix wins; patterns are checked in declaration order.
    for &(prefix, description) in patterns {
        if tensor_names.iter().any(|n| n.starts_with(prefix)) {
            return Some(format!(
                "this model uses {description} — candle-mi currently requires \
                 HF-standard 'model.layers.{{i}}' weight naming. \
                 Support for this architecture is planned in Phase 9 \
                 (tensor name remapping)"
            ));
        }
    }

    // If no known pattern matched, show the first few tensor names as a
    // diagnostic aid
    if !tensor_names.iter().any(|n| n.starts_with("model.layers.")) {
        let sample: Vec<&str> = tensor_names.iter().take(5).map(String::as_str).collect();
        return Some(format!(
            "weight tensors use an unrecognized naming convention \
             (first 5: {}). candle-mi expects 'model.layers.{{i}}.self_attn.*' / \
             'model.layers.{{i}}.mlp.*' naming",
            sample.join(", ")
        ));
    }

    None
}
1321
1322// ---------------------------------------------------------------------------
1323// Tests
1324// ---------------------------------------------------------------------------
1325
1326#[cfg(test)]
1327#[allow(clippy::unwrap_used)]
1328mod tests {
1329    use super::*;
1330
    /// Helper to create a minimal LLaMA-style config JSON.
    fn llama_config_json() -> Value {
        serde_json::json!({
            "model_type": "llama",
            "hidden_size": 2048,
            "num_hidden_layers": 16,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "intermediate_size": 8192,
            "vocab_size": 128256,
            "rms_norm_eps": 1e-5,
            "rope_theta": 500000.0,
            "max_position_embeddings": 131072
        })
    }

    /// Baseline LLaMA parse: every axis should come out with LLaMA defaults
    /// (RmsNorm, SiLU, separate QKV, gated-separate MLP, no biases,
    /// no softcapping, no sliding window).
    #[test]
    fn parse_llama_basic() {
        let config = TransformerConfig::from_hf_config(&llama_config_json()).unwrap();
        assert_eq!(config.hidden_size, 2048);
        assert_eq!(config.num_layers, 16);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.num_kv_heads, 8);
        // head_dim derived: hidden_size / num_attention_heads = 2048 / 32
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.intermediate_size, 8192);
        assert_eq!(config.vocab_size, 128256);
        assert_eq!(config.norm_type, NormType::RmsNorm);
        assert_eq!(config.activation, Activation::Silu);
        assert_eq!(config.qkv_layout, QkvLayout::Separate);
        assert_eq!(config.mlp_layout, MlpLayout::GatedSeparate);
        assert!(!config.qkv_bias);
        assert!(!config.o_proj_bias);
        assert!(!config.mlp_bias);
        assert!(config.embedding_scale.is_none());
        assert!(!config.tie_word_embeddings);
        assert!((config.rope_theta - 500_000.0).abs() < f64::EPSILON);
        assert!(config.attn_logit_softcapping.is_none());
        assert!(config.sliding_window.is_none());
    }

    /// Qwen2: `attention_bias` sets QKV bias only, and embeddings are tied.
    #[test]
    fn parse_qwen2_bias() {
        let json = serde_json::json!({
            "model_type": "qwen2",
            "hidden_size": 896,
            "num_hidden_layers": 24,
            "num_attention_heads": 14,
            "num_key_value_heads": 2,
            "intermediate_size": 4864,
            "vocab_size": 151936,
            "attention_bias": true,
            "tie_word_embeddings": true
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert!(config.qkv_bias);
        assert!(!config.o_proj_bias);
        assert!(config.tie_word_embeddings);
    }

    /// Gemma 2: explicit head_dim, Gemma-style RmsNorm + embedding scale,
    /// softcapping, post-norms, and alternating sliding window.
    #[test]
    fn parse_gemma2_extensions() {
        let json = serde_json::json!({
            "model_type": "gemma2",
            "hidden_size": 2304,
            "num_hidden_layers": 26,
            "num_attention_heads": 8,
            "num_key_value_heads": 4,
            "head_dim": 256,
            "intermediate_size": 9216,
            "vocab_size": 256000,
            "attn_logit_softcapping": 50.0,
            "final_logit_softcapping": 30.0,
            "query_pre_attn_scalar": 256,
            "sliding_window": 4096
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.norm_type, NormType::GemmaRmsNorm);
        assert_eq!(config.head_dim, 256);
        assert!(config.embedding_scale.is_some());
        assert!((config.attn_logit_softcapping.unwrap() - 50.0).abs() < f64::EPSILON);
        assert!((config.final_logit_softcapping.unwrap() - 30.0).abs() < f64::EPSILON);
        assert!((config.query_pre_attn_scalar.unwrap() - 256.0).abs() < f64::EPSILON);
        assert!(config.use_post_norms);
        assert_eq!(config.sliding_window, Some(4096));
        assert!(config.alternating_sliding_window);
    }

    /// Phi-3: fused QKV and fused gate/up MLP projections.
    #[test]
    fn parse_phi3_fused() {
        let json = serde_json::json!({
            "model_type": "phi3",
            "hidden_size": 3072,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
            "num_key_value_heads": 32,
            "intermediate_size": 8192,
            "vocab_size": 32064
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.qkv_layout, QkvLayout::Fused);
        assert_eq!(config.mlp_layout, MlpLayout::GatedFused);
    }

    /// StarCoder2: plain (non-gated) MLP, GELU-tanh activation, LayerNorm,
    /// and `use_bias` enabling every bias flag.
    #[test]
    fn parse_starcoder2_bias_and_plain_mlp() {
        let json = serde_json::json!({
            "model_type": "starcoder2",
            "hidden_size": 3072,
            "num_hidden_layers": 30,
            "num_attention_heads": 24,
            "num_key_value_heads": 2,
            "intermediate_size": 12288,
            "vocab_size": 49152,
            "use_bias": true,
            "norm_type": "layer_norm"
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.mlp_layout, MlpLayout::Plain);
        assert_eq!(config.activation, Activation::GeluApprox);
        assert_eq!(config.norm_type, NormType::LayerNorm);
        assert!(config.qkv_bias);
        assert!(config.o_proj_bias);
        assert!(config.mlp_bias);
    }

    /// Mistral: uniform (non-alternating) sliding-window attention.
    #[test]
    fn parse_mistral_sliding_window() {
        let json = serde_json::json!({
            "model_type": "mistral",
            "hidden_size": 4096,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "intermediate_size": 14336,
            "vocab_size": 32000,
            "sliding_window": 4096
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.sliding_window, Some(4096));
        assert!(!config.alternating_sliding_window);
    }

    /// A model_type outside SUPPORTED_MODEL_TYPES must be rejected.
    #[test]
    fn unsupported_model_type_errors() {
        let json = serde_json::json!({ "model_type": "bert" });
        let result = TransformerConfig::from_hf_config(&json);
        assert!(result.is_err());
    }

    /// A config with no model_type at all must be rejected.
    #[test]
    fn missing_model_type_errors() {
        let json = serde_json::json!({ "hidden_size": 768 });
        let result = TransformerConfig::from_hf_config(&json);
        assert!(result.is_err());
    }
1486
1487    // -----------------------------------------------------------------------
1488    // Auto-config validation: parse_auto() must match manual parsers
1489    // -----------------------------------------------------------------------
1490    //
1491    // For each of the 7 known transformer families, we verify that
1492    // parse_auto() produces the SAME TransformerConfig as the manual
1493    // parser.  Config JSON and tensor names are taken from real cached
1494    // models.
1495    //
1496    // Known exception — Phi-3 `sliding_window`: The Phi-3 config.json
1497    // contains "sliding_window": 2047 but the HuggingFace implementation
1498    // ignores it.  The manual parser sets None; the auto-parser reads
1499    // Some(2047).  We test all other fields and assert the sliding_window
1500    // difference explicitly.
1501
    /// Helper: convert `&[&str]` to `Vec<String>` for tensor names.
    fn tensor_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| (*s).to_owned()).collect()
    }

    /// parse_auto() must reproduce the manual LLaMA parser field-for-field.
    #[test]
    fn auto_config_matches_llama() {
        // LLaMA 3.2 1B — actual config.json + tensor names
        let json = serde_json::json!({
            "model_type": "llama",
            "hidden_size": 2048,
            "num_hidden_layers": 16,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "head_dim": 64,
            "intermediate_size": 8192,
            "vocab_size": 128256,
            "rms_norm_eps": 1e-5,
            "rope_theta": 500000.0,
            "max_position_embeddings": 131072,
            "hidden_act": "silu",
            "attention_bias": false,
            "mlp_bias": false,
            "tie_word_embeddings": true
        });
        // Layer-0 tensors only — parse_auto inspects just the first layer.
        let names = tensor_names(&[
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.0.mlp.down_proj.weight",
            "model.layers.0.mlp.gate_proj.weight",
            "model.layers.0.mlp.up_proj.weight",
            "model.layers.0.post_attention_layernorm.weight",
            "model.layers.0.self_attn.k_proj.weight",
            "model.layers.0.self_attn.o_proj.weight",
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.self_attn.v_proj.weight",
            "model.norm.weight",
        ]);

        let manual = TransformerConfig::from_hf_config(&json).unwrap();
        let auto = TransformerConfig::parse_auto(&json, &names, "llama").unwrap();
        assert_eq!(auto, manual);
    }
1545
1546    #[test]
1547    fn auto_config_matches_qwen2() {
1548        // Qwen2.5-Coder-3B-Instruct — actual config.json + tensor names
1549        let json = serde_json::json!({
1550            "model_type": "qwen2",
1551            "hidden_size": 2048,
1552            "num_hidden_layers": 36,
1553            "num_attention_heads": 16,
1554            "num_key_value_heads": 2,
1555            "intermediate_size": 11008,
1556            "vocab_size": 151936,
1557            "rms_norm_eps": 1e-6,
1558            "rope_theta": 1000000.0,
1559            "max_position_embeddings": 32768,
1560            "hidden_act": "silu",
1561            "tie_word_embeddings": true,
1562            "sliding_window": 32768,
1563            "use_sliding_window": false
1564        });
1565        let names = tensor_names(&[
1566            "model.embed_tokens.weight",
1567            "model.layers.0.input_layernorm.weight",
1568            "model.layers.0.mlp.down_proj.weight",
1569            "model.layers.0.mlp.gate_proj.weight",
1570            "model.layers.0.mlp.up_proj.weight",
1571            "model.layers.0.post_attention_layernorm.weight",
1572            "model.layers.0.self_attn.k_proj.bias",
1573            "model.layers.0.self_attn.k_proj.weight",
1574            "model.layers.0.self_attn.o_proj.weight",
1575            "model.layers.0.self_attn.q_proj.bias",
1576            "model.layers.0.self_attn.q_proj.weight",
1577            "model.layers.0.self_attn.v_proj.bias",
1578            "model.layers.0.self_attn.v_proj.weight",
1579            "model.norm.weight",
1580        ]);
1581
1582        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1583        let auto = TransformerConfig::parse_auto(&json, &names, "qwen2").unwrap();
1584        assert_eq!(auto, manual);
1585    }
1586
1587    #[test]
1588    fn auto_config_matches_gemma() {
1589        // CodeGemma 7B IT — actual config.json + tensor names
1590        let json = serde_json::json!({
1591            "model_type": "gemma",
1592            "hidden_size": 3072,
1593            "num_hidden_layers": 28,
1594            "num_attention_heads": 16,
1595            "num_key_value_heads": 16,
1596            "head_dim": 256,
1597            "intermediate_size": 24576,
1598            "vocab_size": 256000,
1599            "rms_norm_eps": 1e-6,
1600            "rope_theta": 10000.0,
1601            "max_position_embeddings": 8192,
1602            "hidden_activation": "gelu_pytorch_tanh"
1603        });
1604        let names = tensor_names(&[
1605            "model.embed_tokens.weight",
1606            "model.layers.0.input_layernorm.weight",
1607            "model.layers.0.mlp.down_proj.weight",
1608            "model.layers.0.mlp.gate_proj.weight",
1609            "model.layers.0.mlp.up_proj.weight",
1610            "model.layers.0.post_attention_layernorm.weight",
1611            "model.layers.0.self_attn.k_proj.weight",
1612            "model.layers.0.self_attn.o_proj.weight",
1613            "model.layers.0.self_attn.q_proj.weight",
1614            "model.layers.0.self_attn.v_proj.weight",
1615            "model.norm.weight",
1616        ]);
1617
1618        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1619        let auto = TransformerConfig::parse_auto(&json, &names, "gemma").unwrap();
1620        assert_eq!(auto, manual);
1621    }
1622
1623    #[test]
1624    fn auto_config_matches_gemma2() {
1625        // Gemma 2 2B — actual config.json + tensor names
1626        let json = serde_json::json!({
1627            "model_type": "gemma2",
1628            "hidden_size": 2304,
1629            "num_hidden_layers": 26,
1630            "num_attention_heads": 8,
1631            "num_key_value_heads": 4,
1632            "head_dim": 256,
1633            "intermediate_size": 9216,
1634            "vocab_size": 256000,
1635            "rms_norm_eps": 1e-6,
1636            "rope_theta": 10000.0,
1637            "max_position_embeddings": 8192,
1638            "hidden_act": "gelu_pytorch_tanh",
1639            "hidden_activation": "gelu_pytorch_tanh",
1640            "attn_logit_softcapping": 50.0,
1641            "final_logit_softcapping": 30.0,
1642            "query_pre_attn_scalar": 256,
1643            "sliding_window": 4096
1644        });
1645        let names = tensor_names(&[
1646            "model.embed_tokens.weight",
1647            "model.layers.0.input_layernorm.weight",
1648            "model.layers.0.mlp.down_proj.weight",
1649            "model.layers.0.mlp.gate_proj.weight",
1650            "model.layers.0.mlp.up_proj.weight",
1651            "model.layers.0.post_attention_layernorm.weight",
1652            "model.layers.0.post_feedforward_layernorm.weight",
1653            "model.layers.0.pre_feedforward_layernorm.weight",
1654            "model.layers.0.self_attn.k_proj.weight",
1655            "model.layers.0.self_attn.o_proj.weight",
1656            "model.layers.0.self_attn.q_proj.weight",
1657            "model.layers.0.self_attn.v_proj.weight",
1658            "model.norm.weight",
1659        ]);
1660
1661        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1662        let auto = TransformerConfig::parse_auto(&json, &names, "gemma2").unwrap();
1663        assert_eq!(auto, manual);
1664    }
1665
1666    #[test]
1667    fn auto_config_matches_phi3() {
1668        // Phi-3-mini-4k-instruct — actual config.json + tensor names
1669        //
1670        // Known exception: Phi-3 config.json contains "sliding_window": 2047
1671        // but the manual parser ignores it (sets None).  The auto-parser
1672        // reads it as Some(2047).  We verify all other fields match and
1673        // assert the sliding_window difference explicitly.
1674        let json = serde_json::json!({
1675            "model_type": "phi3",
1676            "hidden_size": 3072,
1677            "num_hidden_layers": 32,
1678            "num_attention_heads": 32,
1679            "num_key_value_heads": 32,
1680            "intermediate_size": 8192,
1681            "vocab_size": 32064,
1682            "rms_norm_eps": 1e-5,
1683            "rope_theta": 10000.0,
1684            "max_position_embeddings": 4096,
1685            "hidden_act": "silu",
1686            "tie_word_embeddings": false,
1687            "sliding_window": 2047,
1688            "attention_bias": false
1689        });
1690        let names = tensor_names(&[
1691            "lm_head.weight",
1692            "model.embed_tokens.weight",
1693            "model.layers.0.input_layernorm.weight",
1694            "model.layers.0.mlp.down_proj.weight",
1695            "model.layers.0.mlp.gate_up_proj.weight",
1696            "model.layers.0.post_attention_layernorm.weight",
1697            "model.layers.0.self_attn.o_proj.weight",
1698            "model.layers.0.self_attn.qkv_proj.weight",
1699            "model.norm.weight",
1700        ]);
1701
1702        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1703        let auto = TransformerConfig::parse_auto(&json, &names, "phi3").unwrap();
1704
1705        // Known exception: sliding_window
1706        assert_eq!(manual.sliding_window, None);
1707        assert_eq!(auto.sliding_window, Some(2047));
1708
1709        // All other fields must match — compare field by field excluding
1710        // sliding_window by creating copies with the same value.
1711        let mut auto_adjusted = auto;
1712        auto_adjusted.sliding_window = None;
1713        assert_eq!(auto_adjusted, manual);
1714    }
1715
1716    #[test]
1717    fn auto_config_matches_starcoder2() {
1718        // StarCoder2-3B — actual config.json + tensor names
1719        let json = serde_json::json!({
1720            "model_type": "starcoder2",
1721            "hidden_size": 3072,
1722            "num_hidden_layers": 30,
1723            "num_attention_heads": 24,
1724            "num_key_value_heads": 2,
1725            "intermediate_size": 12288,
1726            "vocab_size": 49152,
1727            "norm_epsilon": 1e-5,
1728            "norm_type": "layer_norm",
1729            "rope_theta": 999999.4420358813,
1730            "max_position_embeddings": 16384,
1731            "hidden_act": "gelu_pytorch_tanh",
1732            "use_bias": true,
1733            "sliding_window": 4096
1734        });
1735        let names = tensor_names(&[
1736            "model.embed_tokens.weight",
1737            "model.layers.0.input_layernorm.bias",
1738            "model.layers.0.input_layernorm.weight",
1739            "model.layers.0.mlp.c_fc.bias",
1740            "model.layers.0.mlp.c_fc.weight",
1741            "model.layers.0.mlp.c_proj.bias",
1742            "model.layers.0.mlp.c_proj.weight",
1743            "model.layers.0.post_attention_layernorm.bias",
1744            "model.layers.0.post_attention_layernorm.weight",
1745            "model.layers.0.self_attn.k_proj.bias",
1746            "model.layers.0.self_attn.k_proj.weight",
1747            "model.layers.0.self_attn.o_proj.bias",
1748            "model.layers.0.self_attn.o_proj.weight",
1749            "model.layers.0.self_attn.q_proj.bias",
1750            "model.layers.0.self_attn.q_proj.weight",
1751            "model.layers.0.self_attn.v_proj.bias",
1752            "model.layers.0.self_attn.v_proj.weight",
1753            "model.norm.bias",
1754            "model.norm.weight",
1755        ]);
1756
1757        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1758        let auto = TransformerConfig::parse_auto(&json, &names, "starcoder2").unwrap();
1759        assert_eq!(auto, manual);
1760    }
1761
1762    #[test]
1763    fn auto_config_matches_mistral() {
1764        // Mistral 7B v0.1 — actual config.json + tensor names
1765        let json = serde_json::json!({
1766            "model_type": "mistral",
1767            "hidden_size": 4096,
1768            "num_hidden_layers": 32,
1769            "num_attention_heads": 32,
1770            "num_key_value_heads": 8,
1771            "intermediate_size": 14336,
1772            "vocab_size": 32000,
1773            "rms_norm_eps": 1e-5,
1774            "rope_theta": 10000.0,
1775            "max_position_embeddings": 32768,
1776            "hidden_act": "silu",
1777            "tie_word_embeddings": false,
1778            "sliding_window": 4096
1779        });
1780        let names = tensor_names(&[
1781            "lm_head.weight",
1782            "model.embed_tokens.weight",
1783            "model.layers.0.input_layernorm.weight",
1784            "model.layers.0.mlp.down_proj.weight",
1785            "model.layers.0.mlp.gate_proj.weight",
1786            "model.layers.0.mlp.up_proj.weight",
1787            "model.layers.0.post_attention_layernorm.weight",
1788            "model.layers.0.self_attn.k_proj.weight",
1789            "model.layers.0.self_attn.o_proj.weight",
1790            "model.layers.0.self_attn.q_proj.weight",
1791            "model.layers.0.self_attn.v_proj.weight",
1792            "model.norm.weight",
1793        ]);
1794
1795        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1796        let auto = TransformerConfig::parse_auto(&json, &names, "mistral").unwrap();
1797        assert_eq!(auto, manual);
1798    }
1799
1800    #[test]
1801    fn auto_config_unknown_model_type() {
1802        // Verify auto-config works for an unknown model_type using
1803        // LLaMA-like config.json + tensor names.
1804        let json = serde_json::json!({
1805            "model_type": "my_custom_llama",
1806            "hidden_size": 2048,
1807            "num_hidden_layers": 16,
1808            "num_attention_heads": 32,
1809            "num_key_value_heads": 8,
1810            "intermediate_size": 8192,
1811            "vocab_size": 32000,
1812            "rms_norm_eps": 1e-5,
1813            "rope_theta": 10000.0,
1814            "max_position_embeddings": 4096,
1815            "hidden_act": "silu"
1816        });
1817        let names = tensor_names(&[
1818            "lm_head.weight",
1819            "model.embed_tokens.weight",
1820            "model.layers.0.input_layernorm.weight",
1821            "model.layers.0.mlp.down_proj.weight",
1822            "model.layers.0.mlp.gate_proj.weight",
1823            "model.layers.0.mlp.up_proj.weight",
1824            "model.layers.0.post_attention_layernorm.weight",
1825            "model.layers.0.self_attn.k_proj.weight",
1826            "model.layers.0.self_attn.o_proj.weight",
1827            "model.layers.0.self_attn.q_proj.weight",
1828            "model.layers.0.self_attn.v_proj.weight",
1829            "model.norm.weight",
1830        ]);
1831
1832        // from_hf_config_auto should use auto-parser (not error)
1833        let config = TransformerConfig::from_hf_config_auto(&json, &names).unwrap();
1834        assert_eq!(config.hidden_size, 2048);
1835        assert_eq!(config.num_layers, 16);
1836        assert_eq!(config.num_attention_heads, 32);
1837        assert_eq!(config.num_kv_heads, 8);
1838        assert_eq!(config.head_dim, 64);
1839        assert_eq!(config.norm_type, NormType::RmsNorm);
1840        assert_eq!(config.activation, Activation::Silu);
1841        assert_eq!(config.qkv_layout, QkvLayout::Separate);
1842        assert_eq!(config.mlp_layout, MlpLayout::GatedSeparate);
1843        assert!(!config.qkv_bias);
1844        assert!(!config.o_proj_bias);
1845        assert!(!config.mlp_bias);
1846        assert!(config.embedding_scale.is_none());
1847        assert!(!config.tie_word_embeddings);
1848        assert!(config.sliding_window.is_none());
1849    }
1850
1851    #[test]
1852    fn auto_config_dispatches_known_families() {
1853        // Verify from_hf_config_auto delegates known families to manual parsers
1854        let json = llama_config_json();
1855        let names = tensor_names(&["model.embed_tokens.weight"]);
1856
1857        let auto = TransformerConfig::from_hf_config_auto(&json, &names).unwrap();
1858        let manual = TransformerConfig::from_hf_config(&json).unwrap();
1859        assert_eq!(auto, manual);
1860    }
1861
1862    // -----------------------------------------------------------------------
1863    // Compatibility check tests
1864    // -----------------------------------------------------------------------
1865
1866    #[test]
1867    fn compatibility_check_passes_standard_model() {
1868        let json = serde_json::json!({
1869            "model_type": "my_custom",
1870            "hidden_size": 2048,
1871            "num_hidden_layers": 16,
1872            "num_attention_heads": 32,
1873            "intermediate_size": 8192,
1874            "vocab_size": 32000,
1875            "tie_word_embeddings": true
1876        });
1877        let names = tensor_names(&[
1878            "model.embed_tokens.weight",
1879            "model.layers.0.input_layernorm.weight",
1880            "model.layers.0.post_attention_layernorm.weight",
1881            "model.layers.0.self_attn.q_proj.weight",
1882            "model.layers.0.mlp.gate_proj.weight",
1883            "model.norm.weight",
1884        ]);
1885        let report = TransformerConfig::check_auto_compatibility(&json, &names);
1886        assert!(report.compatible, "issues: {:?}", report.issues);
1887    }
1888
1889    #[test]
1890    fn compatibility_check_detects_missing_norms() {
1891        // OLMo-like: no norm weights at all
1892        let json = serde_json::json!({
1893            "model_type": "olmo",
1894            "hidden_size": 2048,
1895            "num_hidden_layers": 16,
1896            "num_attention_heads": 16,
1897            "intermediate_size": 8192,
1898            "vocab_size": 50304
1899        });
1900        let names = tensor_names(&[
1901            "model.embed_tokens.weight",
1902            "model.layers.0.self_attn.q_proj.weight",
1903            "model.layers.0.mlp.gate_proj.weight",
1904            "model.layers.0.mlp.down_proj.weight",
1905        ]);
1906        let report = TransformerConfig::check_auto_compatibility(&json, &names);
1907        assert!(!report.compatible);
1908        // Should detect missing input_layernorm, post_attention_layernorm, and model.norm
1909        assert!(report.issues.len() >= 3, "issues: {:?}", report.issues);
1910        assert!(
1911            report.issues.iter().any(|i| i.contains("input_layernorm")),
1912            "should mention input_layernorm"
1913        );
1914        assert!(
1915            report.issues.iter().any(|i| i.contains("model.norm")),
1916            "should mention model.norm"
1917        );
1918    }
1919
1920    #[test]
1921    fn compatibility_check_detects_missing_config_fields() {
1922        let json = serde_json::json!({
1923            "model_type": "mystery",
1924            "hidden_size": 768
1925        });
1926        let names = tensor_names(&[]);
1927        let report = TransformerConfig::check_auto_compatibility(&json, &names);
1928        assert!(!report.compatible);
1929        // Missing: num_hidden_layers, num_attention_heads, intermediate_size, vocab_size
1930        assert!(
1931            report
1932                .issues
1933                .iter()
1934                .any(|i| i.contains("num_hidden_layers")),
1935            "should mention num_hidden_layers"
1936        );
1937    }
1938
1939    #[test]
1940    fn compatibility_check_detects_missing_lm_head() {
1941        let json = serde_json::json!({
1942            "model_type": "custom",
1943            "hidden_size": 2048,
1944            "num_hidden_layers": 16,
1945            "num_attention_heads": 32,
1946            "intermediate_size": 8192,
1947            "vocab_size": 32000,
1948            "tie_word_embeddings": false
1949        });
1950        let names = tensor_names(&[
1951            "model.embed_tokens.weight",
1952            "model.layers.0.input_layernorm.weight",
1953            "model.layers.0.post_attention_layernorm.weight",
1954            "model.layers.0.self_attn.q_proj.weight",
1955            "model.layers.0.mlp.gate_proj.weight",
1956            "model.norm.weight",
1957            // Missing: lm_head.weight
1958        ]);
1959        let report = TransformerConfig::check_auto_compatibility(&json, &names);
1960        assert!(!report.compatible);
1961        assert!(
1962            report.issues.iter().any(|i| i.contains("lm_head")),
1963            "should mention lm_head"
1964        );
1965    }
1966
1967    #[test]
1968    fn compatibility_check_config_only() {
1969        let good = serde_json::json!({
1970            "hidden_size": 2048,
1971            "num_hidden_layers": 16,
1972            "num_attention_heads": 32,
1973            "intermediate_size": 8192,
1974            "vocab_size": 32000
1975        });
1976        assert!(TransformerConfig::check_config_fields(&good).compatible);
1977
1978        let bad = serde_json::json!({
1979            "hidden_size": 2048
1980        });
1981        let report = TransformerConfig::check_config_fields(&bad);
1982        assert!(!report.compatible);
1983        assert_eq!(report.issues.len(), 4); // missing 4 of 5 required fields
1984    }
1985
1986    #[test]
1987    fn compatibility_into_result_error_message() {
1988        let json = serde_json::json!({
1989            "model_type": "olmo",
1990            "hidden_size": 2048,
1991            "num_hidden_layers": 16,
1992            "num_attention_heads": 16,
1993            "intermediate_size": 8192,
1994            "vocab_size": 50304
1995        });
1996        let names = tensor_names(&[
1997            "model.embed_tokens.weight",
1998            "model.layers.0.self_attn.q_proj.weight",
1999            "model.layers.0.mlp.gate_proj.weight",
2000        ]);
2001        let result = TransformerConfig::check_auto_compatibility(&json, &names).into_result();
2002        assert!(result.is_err());
2003        let msg = result.unwrap_err().to_string();
2004        assert!(
2005            msg.contains("not compatible with GenericTransformer"),
2006            "error should explain incompatibility: {msg}"
2007        );
2008    }
2009
2010    #[test]
2011    fn compatibility_check_shows_gpt2_naming_hint() {
2012        let json = serde_json::json!({
2013            "model_type": "gpt2",
2014            "hidden_size": 768,
2015            "num_hidden_layers": 12,
2016            "num_attention_heads": 12,
2017            "intermediate_size": 3072,
2018            "vocab_size": 50257
2019        });
2020        let names = tensor_names(&[
2021            "transformer.wte.weight",
2022            "transformer.wpe.weight",
2023            "transformer.h.0.ln_1.weight",
2024            "transformer.h.0.attn.c_attn.weight",
2025            "transformer.h.0.mlp.c_fc.weight",
2026            "transformer.ln_f.weight",
2027        ]);
2028        let report = TransformerConfig::check_auto_compatibility(&json, &names);
2029        assert!(!report.compatible);
2030        // Should detect GPT-2 naming
2031        assert!(
2032            report.issues.iter().any(|i| i.contains("GPT-2")),
2033            "should detect GPT-2 naming convention: {:?}",
2034            report.issues
2035        );
2036        // Should show found embedding-like tensors
2037        assert!(
2038            report
2039                .issues
2040                .iter()
2041                .any(|i| i.contains("transformer.wte.weight")),
2042            "should show found embedding tensor: {:?}",
2043            report.issues
2044        );
2045        // Should show found attention-like tensors
2046        assert!(
2047            report.issues.iter().any(|i| i.contains("c_attn")),
2048            "should show found attention tensor: {:?}",
2049            report.issues
2050        );
2051    }
2052
2053    #[test]
2054    fn compatibility_check_shows_found_tensors_for_unknown_naming() {
2055        let json = serde_json::json!({
2056            "model_type": "custom_arch",
2057            "hidden_size": 512,
2058            "num_hidden_layers": 6,
2059            "num_attention_heads": 8,
2060            "intermediate_size": 2048,
2061            "vocab_size": 30000
2062        });
2063        let names = tensor_names(&[
2064            "encoder.layer.0.attention.query.weight",
2065            "encoder.layer.0.attention.key.weight",
2066            "encoder.layer.0.ffn.dense.weight",
2067            "encoder.embeddings.weight",
2068        ]);
2069        let report = TransformerConfig::check_auto_compatibility(&json, &names);
2070        assert!(!report.compatible);
2071        // Should show the unrecognized naming hint with sample tensors
2072        assert!(
2073            report
2074                .issues
2075                .iter()
2076                .any(|i| i.contains("unrecognized naming convention")),
2077            "should flag unrecognized naming: {:?}",
2078            report.issues
2079        );
2080        // Should show found embedding-like tensor
2081        assert!(
2082            report
2083                .issues
2084                .iter()
2085                .any(|i| i.contains("encoder.embeddings.weight")),
2086            "should show found embedding: {:?}",
2087            report.issues
2088        );
2089    }
2090
2091    #[test]
2092    fn compatibility_check_shows_found_norm_tensors() {
2093        // A model with HF-standard layer prefix but non-standard norm names
2094        let json = serde_json::json!({
2095            "model_type": "custom",
2096            "hidden_size": 2048,
2097            "num_hidden_layers": 16,
2098            "num_attention_heads": 32,
2099            "intermediate_size": 8192,
2100            "vocab_size": 32000,
2101            "tie_word_embeddings": true
2102        });
2103        let names = tensor_names(&[
2104            "model.embed_tokens.weight",
2105            "model.layers.0.self_attn.q_proj.weight",
2106            "model.layers.0.mlp.gate_proj.weight",
2107            "model.layers.0.attention_norm.weight",
2108            "model.layers.0.ffn_norm.weight",
2109            "model.final_norm.weight",
2110        ]);
2111        let report = TransformerConfig::check_auto_compatibility(&json, &names);
2112        assert!(!report.compatible);
2113        // Should show the alternative norm tensors that were found
2114        assert!(
2115            report.issues.iter().any(|i| i.contains("attention_norm")),
2116            "should show found norm tensors: {:?}",
2117            report.issues
2118        );
2119    }
2120}