Skip to main content

sapient_models/
weights.rs

1//! HuggingFace safetensors weight loading and key resolution.
2
3use std::collections::HashMap;
4use std::path::PathBuf;
5
6use anyhow::{bail, Context, Result};
7use sapient_core::Tensor;
8use sapient_io::SafetensorsLoader;
9
10/// Load and merge safetensors shards from disk.
11pub fn load_hf_weights(paths: &[PathBuf]) -> Result<HashMap<String, Tensor>> {
12    let mut merged = HashMap::new();
13    for path in paths {
14        let shard = SafetensorsLoader::load(path)
15            .with_context(|| format!("failed to load weights from {}", path.display()))?;
16        for (k, v) in shard {
17            if merged.insert(k.clone(), v).is_some() {
18                bail!("duplicate weight key '{k}' in shard {}", path.display());
19            }
20        }
21    }
22    Ok(merged)
23}
24
25/// Detect the common prefix for transformer weight keys.
26pub fn detect_weight_prefix(weights: &HashMap<String, Tensor>) -> String {
27    const CANDIDATES: &[&str] = &[
28        "model.text_model.",
29        "model.language_model.",
30        "transformer.",
31        "model.",
32        "gpt_neox.",
33    ];
34
35    for prefix in CANDIDATES {
36        let embed_key = format!("{prefix}embed_tokens.weight");
37        if weights.contains_key(&embed_key) {
38            return prefix.to_string();
39        }
40    }
41
42    if weights.contains_key("embed_tokens.weight") {
43        return String::new();
44    }
45
46    // Fall back: find any embed_tokens key.
47    weights
48        .keys()
49        .find(|k| k.ends_with("embed_tokens.weight"))
50        .map(|k| {
51            k.strip_suffix("embed_tokens.weight")
52                .unwrap_or("")
53                .to_string()
54        })
55        .unwrap_or_else(|| "model.".to_string())
56}
57
58/// Resolve a weight tensor by logical suffix (e.g. `layers.0.self_attn.q_proj`).
59pub fn resolve_weight<'a>(
60    weights: &'a HashMap<String, Tensor>,
61    prefix: &str,
62    suffix: &str,
63) -> Result<&'a Tensor> {
64    let key = format!("{prefix}{suffix}.weight");
65    weights
66        .get(&key)
67        .or_else(|| weights.get(suffix))
68        .with_context(|| format!("missing weight '{key}'"))
69}
70
71/// Resolve an optional bias tensor by logical suffix (e.g. `layers.0.self_attn.q_proj`).
72/// Returns `None` when the model has no bias for that layer (e.g. Llama/Mistral).
73pub fn resolve_bias<'a>(
74    weights: &'a HashMap<String, Tensor>,
75    prefix: &str,
76    suffix: &str,
77) -> Option<&'a Tensor> {
78    let key = format!("{prefix}{suffix}.bias");
79    weights
80        .get(&key)
81        .or_else(|| weights.get(&format!("{suffix}.bias")))
82}
83
84/// Resolve lm_head — may live outside the model prefix.
85pub fn resolve_lm_head<'a>(
86    weights: &'a HashMap<String, Tensor>,
87    prefix: &str,
88    tie_word_embeddings: bool,
89    embed_key: &str,
90) -> Result<&'a Tensor> {
91    if tie_word_embeddings {
92        return weights
93            .get(embed_key)
94            .with_context(|| format!("missing tied embedding weight '{embed_key}'"));
95    }
96
97    weights
98        .get("lm_head.weight")
99        .or_else(|| weights.get(&format!("{prefix}lm_head.weight")))
100        .with_context(|| "missing lm_head.weight")
101}
102
103pub fn tie_word_embeddings_from_config(raw: &serde_json::Value) -> bool {
104    raw.get("tie_word_embeddings")
105        .and_then(|v| v.as_bool())
106        .or_else(|| {
107            raw.get("text_config")
108                .and_then(|tc| tc.get("tie_word_embeddings"))
109                .and_then(|v| v.as_bool())
110        })
111        .unwrap_or(false)
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone `model.text_model.` embedding key must select that prefix.
    #[test]
    fn detect_text_model_prefix() {
        // The tensor contents are irrelevant; only the key is inspected.
        let placeholder = Tensor::zeros(vec![1, 1], sapient_core::DType::F32).unwrap();
        let weights: HashMap<String, Tensor> = HashMap::from([(
            "model.text_model.embed_tokens.weight".to_string(),
            placeholder,
        )]);

        let detected = detect_weight_prefix(&weights);

        assert_eq!(detected, "model.text_model.");
    }
}