llama-rs 0.16.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
//! Bidirectional mapper between HuggingFace and llama.cpp-style tensor names.

use std::collections::HashMap;
use crate::model::Architecture;

/// Bidirectional mapper between HuggingFace and internal tensor names.
///
/// Holds one lookup table per direction so that both
/// HuggingFace → internal and internal → HuggingFace queries are a single
/// hash lookup. The `Default` value is an empty mapper with no entries.
#[derive(Debug, Clone, Default)]
pub struct TensorNameMapper {
    /// HuggingFace name → internal name
    hf_to_internal: HashMap<String, String>,
    /// Internal name → HuggingFace name
    internal_to_hf: HashMap<String, String>,
}

impl TensorNameMapper {
    /// Build the mapper by scanning tensor names from SafeTensors files.
    ///
    /// `hf_names` is the list of all tensor names found in the SafeTensors file(s).
    /// Every name is run through the known HuggingFace patterns; a name that
    /// matches none of them is skipped and appears in neither direction of the
    /// mapping. `architecture` is only consulted through `is_gemma()`, which
    /// selects the translation of `post_attention_layernorm.weight`.
    ///
    /// When several HuggingFace names translate to the same internal name, the
    /// reverse lookup keeps the one that appears last in `hf_names`.
    pub fn from_tensor_names(hf_names: &[String], architecture: Architecture) -> Self {
        let is_gemma = architecture.is_gemma();

        // Each input name contributes at most one entry per direction.
        let mut hf_to_internal = HashMap::with_capacity(hf_names.len());
        let mut internal_to_hf = HashMap::with_capacity(hf_names.len());

        for hf_name in hf_names {
            if let Some(internal) = Self::map_hf_to_internal(hf_name, is_gemma) {
                hf_to_internal.insert(hf_name.clone(), internal.clone());
                internal_to_hf.insert(internal, hf_name.clone());
            }
        }

        Self { hf_to_internal, internal_to_hf }
    }

    /// Look up internal name from HuggingFace name.
    ///
    /// Returns `None` when `hf_name` was not recorded while building the mapper.
    pub fn to_internal(&self, hf_name: &str) -> Option<&str> {
        self.hf_to_internal.get(hf_name).map(String::as_str)
    }

    /// Look up HuggingFace name from internal name.
    ///
    /// Returns `None` when no recorded HuggingFace name translates to
    /// `internal_name`.
    pub fn to_hf(&self, internal_name: &str) -> Option<&str> {
        self.internal_to_hf.get(internal_name).map(String::as_str)
    }

    /// Number of mapped tensors (entries in the HuggingFace → internal table).
    pub fn len(&self) -> usize {
        self.hf_to_internal.len()
    }

    /// Returns `true` when no tensor name has been mapped.
    pub fn is_empty(&self) -> bool {
        self.hf_to_internal.is_empty()
    }

    /// Translate a single HuggingFace tensor name into its internal name.
    ///
    /// Returns `None` when the name matches no known pattern. `is_gemma`
    /// switches the translation of `post_attention_layernorm.weight` between
    /// `post_attention_norm.weight` (Gemma) and `ffn_norm.weight` (otherwise).
    ///
    /// Layer tensors must look like `layers.<index>.<suffix>` where `<index>`
    /// consists of ASCII digits only; the result is `blk.<index>.<suffix>` with
    /// the index rendered in plain decimal.
    fn map_hf_to_internal(hf_name: &str, is_gemma: bool) -> Option<String> {
        // Strip common prefixes to normalize across model types.
        // Gemma 4 multimodal uses "model.language_model." prefix;
        // standard models use "model." prefix.
        let normalized = hf_name
            .strip_prefix("model.language_model.")
            .or_else(|| hf_name.strip_prefix("model."))
            .unwrap_or(hf_name);

        // Non-layer tensors
        match normalized {
            "embed_tokens.weight" => return Some("token_embd.weight".into()),
            "norm.weight" => return Some("output_norm.weight".into()),
            // Gemma 4 PLIE shared tensors
            "per_layer_model_projection.weight" | "per_layer_model_proj.weight" => {
                return Some("per_layer_model_proj.weight".into())
            }
            "embed_tokens_per_layer.weight" | "per_layer_token_embd.weight" => {
                return Some("per_layer_token_embd.weight".into())
            }
            "per_layer_projection_norm.weight" | "per_layer_proj_norm.weight" => {
                return Some("per_layer_proj_norm.weight".into())
            }
            _ => {}
        }

        // Top-level lm_head: matched against the raw name, so only the
        // un-prefixed spelling is recognized.
        if hf_name == "lm_head.weight" {
            return Some("output.weight".into());
        }

        // Layer tensors: layers.{i}.suffix
        let rest = normalized.strip_prefix("layers.")?;
        let dot_pos = rest.find('.')?;
        let index = &rest[..dot_pos];
        let suffix = &rest[dot_pos + 1..];

        // `str::parse::<usize>` accepts a leading '+', which would let
        // "layers.+3.…" alias the real layer 3 and clobber its reverse-map
        // entry. Only plain digit runs are valid layer indices.
        if !index.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        // Still fallible: an empty index or one that overflows `usize`.
        let layer_num: usize = index.parse().ok()?;

        let internal_suffix = match suffix {
            // Attention projections
            "self_attn.q_proj.weight" => "attn_q.weight",
            "self_attn.k_proj.weight" => "attn_k.weight",
            "self_attn.v_proj.weight" => "attn_v.weight",
            "self_attn.o_proj.weight" => "attn_output.weight",
            "self_attn.q_proj.bias" => "attn_q.bias",
            "self_attn.k_proj.bias" => "attn_k.bias",
            "self_attn.v_proj.bias" => "attn_v.bias",
            "self_attn.o_proj.bias" => "attn_output.bias",

            // Attention normalization
            "self_attn.q_norm.weight" => "attn_q_norm.weight",
            "self_attn.k_norm.weight" => "attn_k_norm.weight",

            // MLP projections
            "mlp.gate_proj.weight" => "ffn_gate.weight",
            "mlp.up_proj.weight" => "ffn_up.weight",
            "mlp.down_proj.weight" => "ffn_down.weight",
            "mlp.gate_proj.bias" => "ffn_gate.bias",
            "mlp.up_proj.bias" => "ffn_up.bias",
            "mlp.down_proj.bias" => "ffn_down.bias",

            // Norm layers — Gemma uses different naming than LLaMA
            "input_layernorm.weight" => "attn_norm.weight",
            "post_attention_layernorm.weight" => {
                if is_gemma {
                    "post_attention_norm.weight"
                } else {
                    "ffn_norm.weight"
                }
            }
            "pre_feedforward_layernorm.weight" => "ffn_norm.weight",
            "post_feedforward_layernorm.weight" => "post_ffw_norm.weight",

            // Gemma 4 PLIE per-layer tensors
            "per_layer_input_gate.weight" => "inp_gate.weight",
            "per_layer_projection.weight" | "per_layer_proj.weight" => "proj.weight",
            "post_per_layer_input_norm.weight" | "per_layer_post_norm.weight" => {
                "post_norm.weight"
            }
            "layer_scalar" | "layer_output_scale" => "layer_output_scale.weight",

            _ => return None,
        };

        Some(format!("blk.{}.{}", layer_num, internal_suffix))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mapper for `arch` from a slice of string literals.
    fn make_mapper(names: &[&str], arch: Architecture) -> TensorNameMapper {
        let owned: Vec<String> = names.iter().map(|name| name.to_string()).collect();
        TensorNameMapper::from_tensor_names(&owned, arch)
    }

    /// Checks every `(hf_name, internal_name)` pair in the HF → internal direction.
    fn assert_forward(mapper: &TensorNameMapper, expected: &[(&str, &str)]) {
        for &(hf_name, internal_name) in expected {
            assert_eq!(
                mapper.to_internal(hf_name),
                Some(internal_name),
                "HF name: {}",
                hf_name
            );
        }
    }

    /// Checks every `(internal_name, hf_name)` pair in the internal → HF direction.
    fn assert_reverse(mapper: &TensorNameMapper, expected: &[(&str, &str)]) {
        for &(internal_name, hf_name) in expected {
            assert_eq!(
                mapper.to_hf(internal_name),
                Some(hf_name),
                "internal name: {}",
                internal_name
            );
        }
    }

    #[test]
    fn test_llama_basic_mapping() {
        let mapper = make_mapper(
            &[
                "model.embed_tokens.weight",
                "model.layers.0.self_attn.q_proj.weight",
                "model.layers.0.self_attn.k_proj.weight",
                "model.layers.0.self_attn.v_proj.weight",
                "model.layers.0.self_attn.o_proj.weight",
                "model.layers.0.mlp.gate_proj.weight",
                "model.layers.0.mlp.up_proj.weight",
                "model.layers.0.mlp.down_proj.weight",
                "model.layers.0.input_layernorm.weight",
                "model.layers.0.post_attention_layernorm.weight",
                "model.norm.weight",
                "lm_head.weight",
            ],
            Architecture::Llama3,
        );

        assert_forward(
            &mapper,
            &[
                ("model.embed_tokens.weight", "token_embd.weight"),
                ("model.layers.0.self_attn.q_proj.weight", "blk.0.attn_q.weight"),
                ("model.layers.0.mlp.gate_proj.weight", "blk.0.ffn_gate.weight"),
                ("model.layers.0.post_attention_layernorm.weight", "blk.0.ffn_norm.weight"),
                ("model.norm.weight", "output_norm.weight"),
                ("lm_head.weight", "output.weight"),
            ],
        );

        // The same entries must be reachable from the internal side.
        assert_reverse(
            &mapper,
            &[
                ("token_embd.weight", "model.embed_tokens.weight"),
                ("blk.0.attn_q.weight", "model.layers.0.self_attn.q_proj.weight"),
            ],
        );
    }

    #[test]
    fn test_gemma4_actual_safetensors_names() {
        // Actual tensor names from google/gemma-4-E2B-it SafeTensors file
        let mapper = make_mapper(
            &[
                "model.language_model.embed_tokens.weight",
                "model.language_model.norm.weight",
                "model.language_model.layers.0.self_attn.q_proj.weight",
                "model.language_model.layers.0.self_attn.k_proj.weight",
                "model.language_model.layers.0.self_attn.v_proj.weight",
                "model.language_model.layers.0.self_attn.o_proj.weight",
                "model.language_model.layers.0.self_attn.q_norm.weight",
                "model.language_model.layers.0.self_attn.k_norm.weight",
                "model.language_model.layers.0.mlp.gate_proj.weight",
                "model.language_model.layers.0.mlp.up_proj.weight",
                "model.language_model.layers.0.mlp.down_proj.weight",
                "model.language_model.layers.0.input_layernorm.weight",
                "model.language_model.layers.0.post_attention_layernorm.weight",
                "model.language_model.layers.0.pre_feedforward_layernorm.weight",
                "model.language_model.layers.0.post_feedforward_layernorm.weight",
                "model.language_model.layers.0.per_layer_input_gate.weight",
                "model.language_model.layers.0.per_layer_projection.weight",
                "model.language_model.layers.0.post_per_layer_input_norm.weight",
                "model.language_model.layers.0.layer_scalar",
                "model.language_model.embed_tokens_per_layer.weight",
                "model.language_model.per_layer_model_projection.weight",
                "model.language_model.per_layer_projection_norm.weight",
            ],
            Architecture::Gemma4,
        );

        assert_forward(
            &mapper,
            &[
                // Embedding and output
                ("model.language_model.embed_tokens.weight", "token_embd.weight"),
                ("model.language_model.norm.weight", "output_norm.weight"),
                // Standard attention
                ("model.language_model.layers.0.self_attn.q_proj.weight", "blk.0.attn_q.weight"),
                ("model.language_model.layers.0.self_attn.o_proj.weight", "blk.0.attn_output.weight"),
                // QK norm
                ("model.language_model.layers.0.self_attn.q_norm.weight", "blk.0.attn_q_norm.weight"),
                // FFN
                ("model.language_model.layers.0.mlp.gate_proj.weight", "blk.0.ffn_gate.weight"),
                // Gemma norms
                ("model.language_model.layers.0.input_layernorm.weight", "blk.0.attn_norm.weight"),
                (
                    "model.language_model.layers.0.post_attention_layernorm.weight",
                    "blk.0.post_attention_norm.weight",
                ),
                (
                    "model.language_model.layers.0.pre_feedforward_layernorm.weight",
                    "blk.0.ffn_norm.weight",
                ),
                (
                    "model.language_model.layers.0.post_feedforward_layernorm.weight",
                    "blk.0.post_ffw_norm.weight",
                ),
                // PLIE per-layer
                ("model.language_model.layers.0.per_layer_input_gate.weight", "blk.0.inp_gate.weight"),
                ("model.language_model.layers.0.per_layer_projection.weight", "blk.0.proj.weight"),
                (
                    "model.language_model.layers.0.post_per_layer_input_norm.weight",
                    "blk.0.post_norm.weight",
                ),
                ("model.language_model.layers.0.layer_scalar", "blk.0.layer_output_scale.weight"),
                // PLIE shared
                ("model.language_model.embed_tokens_per_layer.weight", "per_layer_token_embd.weight"),
                (
                    "model.language_model.per_layer_model_projection.weight",
                    "per_layer_model_proj.weight",
                ),
                (
                    "model.language_model.per_layer_projection_norm.weight",
                    "per_layer_proj_norm.weight",
                ),
            ],
        );
    }

    #[test]
    fn test_multi_layer_mapping() {
        let mapper = make_mapper(
            &[
                "model.layers.0.self_attn.q_proj.weight",
                "model.layers.15.self_attn.q_proj.weight",
                "model.layers.34.self_attn.q_proj.weight",
            ],
            Architecture::Llama3,
        );

        // The layer index must be carried over into the block number.
        assert_forward(
            &mapper,
            &[
                ("model.layers.15.self_attn.q_proj.weight", "blk.15.attn_q.weight"),
                ("model.layers.34.self_attn.q_proj.weight", "blk.34.attn_q.weight"),
            ],
        );
    }

    #[test]
    fn test_unknown_name_returns_none() {
        let mapper = make_mapper(&["some.random.tensor"], Architecture::Llama3);

        assert_eq!(mapper.to_internal("some.random.tensor"), None);
        assert_eq!(mapper.to_hf("nonexistent"), None);
    }

    #[test]
    fn test_gemma_vs_llama_post_attention_norm() {
        let hf_name = "model.layers.0.post_attention_layernorm.weight";
        let names = [hf_name];

        // Same HF tensor, different internal name depending on architecture.
        let llama = make_mapper(&names, Architecture::Llama3);
        assert_forward(&llama, &[(hf_name, "blk.0.ffn_norm.weight")]);

        let gemma = make_mapper(&names, Architecture::Gemma4);
        assert_forward(&gemma, &[(hf_name, "blk.0.post_attention_norm.weight")]);
    }

    #[test]
    fn test_bias_tensors() {
        let mapper = make_mapper(
            &[
                "model.layers.0.self_attn.q_proj.bias",
                "model.layers.0.mlp.gate_proj.bias",
            ],
            Architecture::Llama3,
        );

        assert_forward(
            &mapper,
            &[
                ("model.layers.0.self_attn.q_proj.bias", "blk.0.attn_q.bias"),
                ("model.layers.0.mlp.gate_proj.bias", "blk.0.ffn_gate.bias"),
            ],
        );
    }
}