llama-rs 0.17.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
//! HuggingFace config.json parser
//!
//! Parses the `config.json` file that accompanies HuggingFace Optimum ONNX exports
//! and SafeTensors models into the internal `ModelConfig` type.

use std::path::Path;

use crate::model::{
    ActivationType, Architecture, ModelConfig, RopeConfig, RopeScalingType, RopeType,
};

use super::error::{ModelError, ModelResult};

/// Parsed HuggingFace config.json
///
/// Every field is deserialized by name from the JSON object. `Option` fields
/// are `None` when the key is absent; the two `bool` fields fall back to
/// `false`. `HfConfig::from_json` additionally back-fills several of the
/// Gemma 4 fields from differently named keys, and `HfConfig::to_model_config`
/// supplies defaults for most fields that are still `None`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct HfConfig {
    /// Model type identifier (e.g., "llama", "mistral", "qwen2").
    /// Drives `HfConfig::architecture`; unrecognised or missing values map to
    /// `Architecture::Unknown`.
    pub model_type: Option<String>,

    /// Vocabulary size (`to_model_config` falls back to 32000 when absent)
    pub vocab_size: Option<usize>,

    /// Hidden size / embedding dimension (required by `to_model_config`)
    pub hidden_size: Option<usize>,

    /// Intermediate (FFN) size
    /// (`to_model_config` falls back to `hidden_size * 4 * 2 / 3` when absent)
    pub intermediate_size: Option<usize>,

    /// Number of hidden layers (required by `to_model_config`)
    pub num_hidden_layers: Option<usize>,

    /// Number of attention heads (required by `to_model_config`)
    pub num_attention_heads: Option<usize>,

    /// Number of key-value heads (for GQA).
    /// `to_model_config` falls back to `num_attention_heads` when absent.
    pub num_key_value_heads: Option<usize>,

    /// Maximum position embeddings (`to_model_config` falls back to 2048)
    pub max_position_embeddings: Option<usize>,

    /// RMS normalization epsilon (`to_model_config` falls back to 1e-5)
    pub rms_norm_eps: Option<f32>,

    /// RoPE theta (frequency base); `to_model_config` falls back to 10000.0.
    /// `from_json` may overwrite this with
    /// `rope_parameters.full_attention.rope_theta`.
    pub rope_theta: Option<f32>,

    /// RoPE scaling configuration
    pub rope_scaling: Option<RopeScalingConfig>,

    /// Whether to tie word embeddings (input and output).
    /// Treated as `false` by `to_model_config` when absent.
    pub tie_word_embeddings: Option<bool>,

    /// Hidden activation function name.
    /// `hidden_activation` takes precedence over this field when both are set.
    pub hidden_act: Option<String>,

    /// Whether attention uses bias (`false` when the key is absent)
    #[serde(default)]
    pub attention_bias: bool,

    /// Whether MLP uses bias (`false` when the key is absent)
    #[serde(default)]
    pub mlp_bias: bool,

    /// Head dimension (explicit override, rare).
    /// When the JSON carries `global_head_dim` and `head_dim_swa` is unset,
    /// `from_json` moves the parsed value into `head_dim_swa` and stores the
    /// global head dimension here instead.
    pub head_dim: Option<usize>,

    /// BOS token ID
    pub bos_token_id: Option<u32>,

    /// EOS token ID, kept as a raw JSON value.
    // NOTE(review): presumably untyped because configs may give either a
    // single id or a list of ids — confirm against the consumers of this field.
    pub eos_token_id: Option<serde_json::Value>,

    // --- Gemma 4 architecture fields ---
    /// Per-layer sliding window attention pattern (Gemma 4).
    /// `true` marks a sliding-window layer. When the key is absent,
    /// `from_json` derives it from `layer_types` (entries equal to
    /// `"sliding_attention"` become `true`, everything else `false`).
    // NOTE(review): a non-array value under this key (some HF configs appear
    // to store an integer period here) would fail deserialization of the whole
    // config — confirm whether such configs need to be supported.
    #[serde(default)]
    pub sliding_window_pattern: Option<Vec<bool>>,

    /// Number of KV shared layers (Gemma 4).
    /// Back-filled by `from_json` from `num_kv_shared_layers` when absent.
    #[serde(default)]
    pub num_key_value_shared_layers: Option<usize>,

    /// Per-layer input embedding dimension (Gemma 4 PLIE).
    /// Back-filled by `from_json` from `hidden_size_per_layer_input` when absent.
    #[serde(default)]
    pub n_embd_per_layer: Option<usize>,

    /// SWA head dimension (Gemma 4).
    /// When absent and `global_head_dim` is present, `from_json` sets this to
    /// the parsed `head_dim`.
    #[serde(default)]
    pub head_dim_swa: Option<usize>,

    /// SWA RoPE frequency base (Gemma 4).
    /// Back-filled by `from_json` from
    /// `rope_parameters.sliding_attention.rope_theta` when absent.
    #[serde(default)]
    pub rope_theta_swa: Option<f32>,

    /// Sliding window size (`to_model_config` uses 0 when absent)
    #[serde(default)]
    pub sliding_window: Option<usize>,

    /// Final logit softcapping value (Gemma 4: 30.0).
    /// `to_model_config` uses 0.0 when absent.
    #[serde(default)]
    pub final_logit_softcapping: Option<f32>,

    /// Partial rotary factor for global attention (Gemma 4: 0.25)
    /// Parsed from rope_parameters.full_attention.partial_rotary_factor
    /// when the flat key is absent; `to_model_config` uses 1.0 if still unset.
    #[serde(default)]
    pub partial_rotary_factor: Option<f32>,

    /// Hidden activation function name (alternative to hidden_act, used by Gemma 4).
    /// Preferred over `hidden_act` by `to_model_config`.
    #[serde(default)]
    pub hidden_activation: Option<String>,
}

/// RoPE scaling configuration from config.json
///
/// Deserialized from the `rope_scaling` object. Both fields are optional, so
/// an empty object parses successfully with both set to `None`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScalingConfig {
    /// Scaling type (e.g., "linear", "dynamic"), read from the JSON key `type`.
    // NOTE(review): this value is parsed but `HfConfig::to_model_config`
    // always emits `RopeScalingType::None`; confirm whether it should be
    // mapped. Newer HF configs may also spell this key `rope_type` — verify.
    #[serde(rename = "type")]
    pub scaling_type: Option<String>,

    /// Scaling factor.
    /// Copied verbatim into `RopeConfig::freq_scale` by
    /// `HfConfig::to_model_config` (1.0 when absent).
    pub factor: Option<f32>,
}

impl HfConfig {
    /// Read a `config.json` from disk and parse it.
    ///
    /// The file is read fully into memory as UTF-8 text and then handed to
    /// [`HfConfig::from_json`].
    ///
    /// # Errors
    ///
    /// Returns `ModelError::ConfigError` carrying `"<path>: <io error>"` when
    /// the file cannot be read, or whatever error `from_json` reports for the
    /// file contents.
    pub fn from_file<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
        let config_path: &Path = path.as_ref();
        match std::fs::read_to_string(config_path) {
            Ok(contents) => Self::from_json(&contents),
            Err(io_err) => Err(ModelError::ConfigError(format!(
                "{}: {}",
                config_path.display(),
                io_err
            ))),
        }
    }

    /// Parse config from a JSON string.
    ///
    /// Handles multimodal models (Gemma 4) where the text model config is
    /// nested under `text_config`. Maps Gemma 4-specific field names to
    /// the HfConfig field names.
    pub fn from_json(json: &str) -> ModelResult<Self> {
        let raw: serde_json::Value = serde_json::from_str(json)
            .map_err(|e| ModelError::ConfigError(format!("Failed to parse config.json: {e}")))?;

        // For multimodal models, flatten text_config into the top level
        let config_value = if let Some(text_config) = raw.get("text_config") {
            let mut merged = text_config.clone();
            // Keep top-level model_type if present (e.g., "gemma4")
            if let Some(mt) = raw.get("model_type") {
                merged.as_object_mut().unwrap().insert("model_type".into(), mt.clone());
            }
            if let Some(tie) = raw.get("tie_word_embeddings") {
                merged.as_object_mut().unwrap().insert("tie_word_embeddings".into(), tie.clone());
            }
            merged
        } else {
            raw.clone()
        };

        let mut config: HfConfig = serde_json::from_value(config_value.clone())
            .map_err(|e| ModelError::ConfigError(format!("Failed to parse config.json: {e}")))?;

        // Map Gemma 4-specific field names to our generic fields
        if let Some(obj) = config_value.as_object() {
            // layer_types: ["sliding_attention", "full_attention", ...] → sliding_window_pattern
            if config.sliding_window_pattern.is_none() {
                if let Some(layer_types) = obj.get("layer_types").and_then(|v| v.as_array()) {
                    config.sliding_window_pattern = Some(
                        layer_types.iter()
                            .map(|v| v.as_str() == Some("sliding_attention"))
                            .collect()
                    );
                }
            }

            // num_kv_shared_layers (Gemma 4 naming) → num_key_value_shared_layers
            if config.num_key_value_shared_layers.is_none() {
                if let Some(v) = obj.get("num_kv_shared_layers").and_then(|v| v.as_u64()) {
                    config.num_key_value_shared_layers = Some(v as usize);
                }
            }

            // hidden_size_per_layer_input → n_embd_per_layer
            if config.n_embd_per_layer.is_none() {
                if let Some(v) = obj.get("hidden_size_per_layer_input").and_then(|v| v.as_u64()) {
                    config.n_embd_per_layer = Some(v as usize);
                }
            }

            // global_head_dim (Gemma 4: 512 for global attention)
            // head_dim is already parsed as the SWA head dim (256)
            // We need head_dim_swa = head_dim and then set head_dim = global_head_dim
            if config.head_dim_swa.is_none() {
                if let Some(global_hd) = obj.get("global_head_dim").and_then(|v| v.as_u64()) {
                    // SWA head dim is the regular head_dim
                    config.head_dim_swa = config.head_dim;
                    // Override head_dim to global for the base config
                    config.head_dim = Some(global_hd as usize);
                }
            }

            // rope_parameters: { "sliding_attention": { "rope_theta": 10000 }, "full_attention": { "rope_theta": 1000000, "partial_rotary_factor": 0.25 } }
            if config.rope_theta_swa.is_none() {
                if let Some(rp) = obj.get("rope_parameters").and_then(|v| v.as_object()) {
                    if let Some(swa) = rp.get("sliding_attention").and_then(|v| v.as_object()) {
                        if let Some(theta) = swa.get("rope_theta").and_then(|v| v.as_f64()) {
                            config.rope_theta_swa = Some(theta as f32);
                        }
                    }
                    if let Some(full) = rp.get("full_attention").and_then(|v| v.as_object()) {
                        if let Some(theta) = full.get("rope_theta").and_then(|v| v.as_f64()) {
                            config.rope_theta = Some(theta as f32);
                        }
                        if config.partial_rotary_factor.is_none() {
                            if let Some(prf) = full.get("partial_rotary_factor").and_then(|v| v.as_f64()) {
                                config.partial_rotary_factor = Some(prf as f32);
                            }
                        }
                    }
                }
            }
        }

        Ok(config)
    }

    /// Map the `model_type` string onto an [`Architecture`] variant.
    ///
    /// Matching is exact and case-sensitive. A missing `model_type`, or one
    /// that is not listed below, yields `Architecture::Unknown`.
    pub fn architecture(&self) -> Architecture {
        self.model_type
            .as_deref()
            .map_or(Architecture::Unknown, |kind| match kind {
                "llama" => Architecture::Llama,
                "mistral" => Architecture::Mistral,
                "qwen2" => Architecture::Qwen2,
                "codellama" => Architecture::CodeLlama,
                "yi" => Architecture::Yi,
                "deepseek" | "deepseek_v2" => Architecture::DeepSeek,
                "mixtral" => Architecture::Mixtral,
                "gemma4" | "gemma4_text" => Architecture::Gemma4,
                "gemma2" | "gemma2_text" => Architecture::Gemma2,
                "gemma" => Architecture::Gemma,
                _ => Architecture::Unknown,
            })
    }

    /// Convert to internal ModelConfig
    ///
    /// `hidden_size`, `num_attention_heads` and `num_hidden_layers` are
    /// required. Every other value falls back to a default when absent:
    ///
    /// * `num_key_value_heads` → `num_attention_heads`
    /// * `head_dim` → `hidden_size / num_attention_heads`
    /// * `intermediate_size` → `hidden_size * 4 * 2 / 3`
    /// * `max_position_embeddings` → 2048, `rms_norm_eps` → 1e-5,
    ///   `vocab_size` → 32000, `rope_theta` → 10000.0
    /// * activation → SiLU unless the name is one of the GELU spellings
    ///
    /// When a `sliding_window_pattern` is present and the architecture
    /// reports heterogeneous attention, per-layer attention configs (and the
    /// KV-sharing map, if `num_key_value_shared_layers > 0`) are attached.
    ///
    /// # Errors
    ///
    /// Returns `ModelError::ConfigError` when a required field is missing, or
    /// when `head_dim` has to be derived but `num_attention_heads` is zero.
    pub fn to_model_config(&self) -> ModelResult<ModelConfig> {
        let hidden_size = self
            .hidden_size
            .ok_or_else(|| ModelError::ConfigError("missing hidden_size in config.json".into()))?;

        let num_heads = self.num_attention_heads.ok_or_else(|| {
            ModelError::ConfigError("missing num_attention_heads in config.json".into())
        })?;

        let num_layers = self.num_hidden_layers.ok_or_else(|| {
            ModelError::ConfigError("missing num_hidden_layers in config.json".into())
        })?;

        let num_kv_heads = self.num_key_value_heads.unwrap_or(num_heads);

        // Only derive head_dim when it is not given explicitly, and never
        // divide by a zero head count taken from untrusted JSON: the eager
        // `unwrap_or(hidden_size / num_heads)` form would panic here.
        let head_dim = match self.head_dim {
            Some(dim) => dim,
            None => hidden_size.checked_div(num_heads).ok_or_else(|| {
                ModelError::ConfigError(
                    "num_attention_heads must be non-zero in config.json".into(),
                )
            })?,
        };

        let intermediate_size = self.intermediate_size.unwrap_or(hidden_size * 4 * 2 / 3);
        let max_seq_len = self.max_position_embeddings.unwrap_or(2048);
        let norm_eps = self.rms_norm_eps.unwrap_or(1e-5);
        let vocab_size = self.vocab_size.unwrap_or(32000);

        // Determine RoPE type from architecture
        let architecture = self.architecture();
        let rope_type = match architecture {
            Architecture::Qwen2 => RopeType::NeoX,
            _ => RopeType::Normal,
        };

        let freq_base = self.rope_theta.unwrap_or(10000.0);
        // NOTE(review): the HF scaling factor is passed through unchanged and
        // the scaling type is always reported as None below — confirm that
        // downstream RoPE code expects exactly this.
        let freq_scale = self
            .rope_scaling
            .as_ref()
            .and_then(|s| s.factor)
            .unwrap_or(1.0);

        let rope_config = RopeConfig {
            freq_base,
            freq_scale,
            n_dims: head_dim,
            scaling_type: RopeScalingType::None,
            original_max_position_embeddings: max_seq_len,
            rope_type,
            mrope_sections: None,
        };

        // `hidden_activation` (Gemma 4 spelling) wins over `hidden_act`.
        let act_str = self
            .hidden_activation
            .as_deref()
            .or(self.hidden_act.as_deref());
        let hidden_act = match act_str {
            Some("gelu") | Some("gelu_new") | Some("gelu_pytorch_tanh") => ActivationType::GELU,
            _ => ActivationType::SiLU,
        };
        let uses_gelu = matches!(hidden_act, ActivationType::GELU);

        let sliding_window = self.sliding_window.unwrap_or(0);

        let mut config = ModelConfig {
            vocab_size,
            hidden_size,
            intermediate_size,
            num_layers,
            num_heads,
            num_kv_heads,
            head_dim,
            max_seq_len,
            norm_eps,
            rope_config,
            use_parallel_residual: false,
            hidden_act,
            attention_bias: self.attention_bias,
            mlp_bias: self.mlp_bias,
            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
            num_experts: 0,
            num_experts_per_token: 0,
            expert_intermediate_size: 0,
            key_length: head_dim,
            value_length: head_dim,
            ssm_d_inner: 0,
            ssm_d_state: 0,
            ssm_n_group: 0,
            ssm_dt_rank: 0,
            ssm_conv_kernel: 0,
            attn_logit_softcap: 0.0,
            final_logit_softcap: self.final_logit_softcapping.unwrap_or(0.0),
            sliding_window,
            has_combined_qkv: false,
            uses_layer_norm: false,
            uses_gelu,
            has_ffn_gate: true,
            attention_layer_configs: None,
            kv_source_layer: None,
        };

        // After building the base ModelConfig, add Gemma 4 heterogeneous attention.
        // This needs both a per-layer pattern and an architecture that uses it.
        if let Some(pattern) = &self.sliding_window_pattern {
            if architecture.has_heterogeneous_attention() {
                // SWA-specific fields
                let swa_head_dim = self.head_dim_swa.unwrap_or(config.head_dim);
                let swa_kv_heads = config.num_kv_heads;
                let swa_rope_freq_base =
                    self.rope_theta_swa.unwrap_or(config.rope_config.freq_base);
                let swa_rope_dims = swa_head_dim; // SWA: full rotation

                // Global attention params
                let global_head_dim = config.head_dim;
                let global_kv_heads = config.num_kv_heads;
                let global_rope_freq_base = config.rope_config.freq_base;
                // Global rope_dims: head_dim * partial_rotary_factor (0.25 for Gemma 4)
                let prf = self.partial_rotary_factor.unwrap_or(1.0);
                let global_rope_dims = (global_head_dim as f32 * prf) as usize;

                // Build per-layer attention configs
                let layer_configs = ModelConfig::build_attention_layer_configs_from_pattern(
                    pattern,
                    swa_head_dim,
                    swa_kv_heads,
                    swa_rope_freq_base,
                    swa_rope_dims,
                    sliding_window,
                    global_head_dim,
                    global_kv_heads,
                    global_rope_freq_base,
                    global_rope_dims,
                );

                // Handle shared KV layers if specified. The mapping is built
                // from the local binding, so no unwrap of the just-assigned
                // Option is needed.
                if let Some(shared_layers) = self.num_key_value_shared_layers {
                    if shared_layers > 0 {
                        config.kv_source_layer = Some(ModelConfig::build_kv_source_mapping(
                            config.num_layers,
                            shared_layers,
                            &layer_configs,
                        ));
                    }
                }

                config.attention_layer_configs = Some(layer_configs);
            }
        }

        Ok(config)
    }
}