llama-rs 0.17.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
//! SafeTensors model loader implementing the [`ModelSource`] trait.
//!
//! Loads HuggingFace SafeTensors model directories (config.json + *.safetensors)
//! and provides tensor data to the shared [`build_llama_model`] pipeline.
//!
//! Key differences from GGUF loading:
//!
//! - **Shape transposition**: SafeTensors stores 2D weight matrices as
//!   `[out_features, in_features]` (PyTorch convention), whereas the internal
//!   GGML convention is `[in_features, out_features]`. The raw bytes are
//!   identical; only the shape metadata is swapped. Tensors of any other rank
//!   keep their SafeTensors shape unchanged.
//!
//! - **Gemma norm weight offset**: the loader was designed around HuggingFace
//!   Gemma models that apply `(1 + weight) * rms_norm(x)` in RMSNorm and store
//!   the raw weight, which would require adding +1 to norm weights at load
//!   time. This offset is currently **disabled**: empirically, GGUF (unsloth)
//!   and SafeTensors (Google) Gemma 4 checkpoints store identical raw norm
//!   weights, and the shared `RMSNorm::forward()` uses `w * rms_norm(x)`
//!   directly. The switch lives in `SafeTensorsLoader::needs_gemma_norm_offset()`.

use std::path::Path;

use crate::model::{Architecture, ModelConfig, ModelError, ModelResult, ModelSource};
use crate::model::hf_config::HfConfig;
use crate::tensor::Tensor;

use super::{ShardedSafeTensors, TensorNameMapper};
use super::reader::{bf16_to_f32, f16_to_f32, SafeTensorsDtype};

/// Loads a HuggingFace SafeTensors model directory and implements [`ModelSource`].
///
/// Construct it with [`SafeTensorsLoader::load`]. That reads `config.json`,
/// opens the weight file(s), and builds the internal↔HuggingFace tensor-name
/// mapping. Tensors are then fetched by internal name through
/// [`ModelSource::load_tensor`] and returned as F32.
///
/// Expected directory layout:
/// ```text
/// model_dir/
///   config.json
///   model.safetensors            (single-file)
///   model.safetensors.index.json (sharded, optional)
///   model-00001-of-00004.safetensors ...
///   tokenizer.json               (used by Engine, not by this loader)
/// ```
pub struct SafeTensorsLoader {
    /// The opened SafeTensors weight file(s); single-file or sharded.
    files: ShardedSafeTensors,
    /// Maps internal tensor names (e.g. `blk.0.attn_q.weight`) to the
    /// HuggingFace names present in `files`.
    name_mapper: TensorNameMapper,
    /// Model hyper-parameters converted from the HuggingFace `config.json`.
    config: ModelConfig,
    /// Architecture detected from the config's `model_type`.
    architecture: Architecture,
}

impl SafeTensorsLoader {
    /// Open a SafeTensors model directory and parse its configuration.
    ///
    /// `path` can be either a directory or any file inside the model directory
    /// (for example `model.safetensors` or `config.json`).
    ///
    /// # Errors
    ///
    /// Returns `ModelError::ConfigError` if the SafeTensors file(s) cannot be
    /// opened or `config.json` cannot be read. Errors from parsing the
    /// HuggingFace config or converting it to a [`ModelConfig`] are propagated.
    pub fn load(path: &Path) -> ModelResult<Self> {
        // 1. Determine model directory
        let dir = Self::resolve_model_dir(path);

        // 2. Open SafeTensors files (handles both single and sharded)
        let files = ShardedSafeTensors::open(&dir)
            .map_err(|e| ModelError::ConfigError(format!("SafeTensors: {e}")))?;

        // 3. Load and parse config.json
        let config_path = dir.join("config.json");
        let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
            ModelError::ConfigError(format!("Failed to read {}: {e}", config_path.display()))
        })?;
        let hf_config = HfConfig::from_json(&config_str)?;

        // 4. Detect architecture from model_type. If HfConfig does not
        //    recognise it, fall back to the GGUF architecture-string table for
        //    broader coverage (a missing model_type is treated as "llama").
        let architecture = match hf_config.architecture() {
            Architecture::Unknown => {
                let model_type = hf_config.model_type.as_deref().unwrap_or("llama");
                Architecture::from_gguf_str(model_type)
            }
            known => known,
        };

        // 5. Build ModelConfig from HF config
        let config = hf_config.to_model_config()?;

        // 6. Build bidirectional name mapper from actual tensor names
        let tensor_names = files.tensor_names();
        let name_mapper = TensorNameMapper::from_tensor_names(&tensor_names, architecture);

        tracing::info!(
            "SafeTensors model: {} tensors, {} mapped, arch={:?}",
            files.num_tensors(),
            name_mapper.len(),
            architecture,
        );

        Ok(Self { files, name_mapper, config, architecture })
    }

    /// Resolve the model directory from a user-supplied path.
    ///
    /// A directory is used as-is; otherwise its parent directory is used.
    /// `Path::parent` returns `Some("")` for a bare file name such as
    /// `model.safetensors`, and `None` for `""` or a root. Both cases map to
    /// `.` so callers never receive an empty directory path.
    fn resolve_model_dir(path: &Path) -> std::path::PathBuf {
        if path.is_dir() {
            return path.to_path_buf();
        }
        match path.parent() {
            Some(parent) if !parent.as_os_str().is_empty() => parent.to_path_buf(),
            _ => Path::new(".").to_path_buf(),
        }
    }

    /// Whether norm weights need the Gemma-style +1 offset at load time.
    ///
    /// Currently always `false`. The offset targets HuggingFace checkpoints
    /// that store raw norm weights and apply `(1 + w)` in the forward pass.
    /// Empirically, though, GGUF (unsloth) and SafeTensors (Google) Gemma 4
    /// checkpoints store identical raw norm weights. The GGUF converter does
    /// not add +1 for Gemma 4, and the downstream RMSNorm uses
    /// `w * rms_norm(x)` directly. This is the single switch to flip if a
    /// checkpoint ever needs the offset again.
    fn needs_gemma_norm_offset(&self) -> bool {
        false
    }

    /// Check whether an internal tensor name refers to a norm weight that
    /// would receive the Gemma +1 offset.
    ///
    /// Any name ending in `_norm.weight` matches. That covers
    /// `output_norm.weight` as well as per-block norm weights named the
    /// same way.
    fn is_norm_weight(name: &str) -> bool {
        name.ends_with("_norm.weight")
    }

    /// Apply +1 to every element of a norm weight tensor (Gemma only).
    fn apply_norm_offset(data: &mut [f32]) {
        for v in data.iter_mut() {
            *v += 1.0;
        }
    }
}

impl ModelSource for SafeTensorsLoader {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn config_mut(&mut self) -> &mut ModelConfig {
        &mut self.config
    }

    fn architecture(&self) -> Architecture {
        self.architecture
    }

    fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
        // Map internal name (e.g. "blk.0.attn_q.weight") to HF name
        let hf_name = self.name_mapper.to_hf(name)
            .ok_or_else(|| ModelError::MissingTensor(format!(
                "{name} (no HuggingFace mapping found)"
            )))?;

        let info = self.files.tensor_info(hf_name)
            .ok_or_else(|| ModelError::MissingTensor(hf_name.to_string()))?;
        let data = self.files.tensor_data(hf_name)
            .ok_or_else(|| ModelError::MissingTensor(hf_name.to_string()))?;

        // Convert to F32
        let mut f32_data = match info.dtype {
            SafeTensorsDtype::F32 => {
                // Parse F32 values from raw bytes. We avoid bytemuck::cast_slice
                // because the data offset within the mmap may not be 4-byte aligned,
                // causing an alignment panic.
                data.chunks_exact(4)
                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                    .collect()
            }
            SafeTensorsDtype::BF16 => bf16_to_f32(data),
            SafeTensorsDtype::F16 => f16_to_f32(data),
            other => return Err(ModelError::ConfigError(
                format!("Unsupported SafeTensors dtype {:?} for tensor {name}", other)
            )),
        };

        // Gemma norm weight offset: add +1 to raw HF norm weights so the
        // shared RMSNorm forward pass (which expects GGUF-style final-form
        // weights) produces correct results.
        if self.needs_gemma_norm_offset() && Self::is_norm_weight(name) {
            Self::apply_norm_offset(&mut f32_data);
        }

        // Swap 2D shape from PyTorch [out, in] to GGML [in, out].
        // GGML uses column-major storage: W[i,j] at index i + j*k.
        // PyTorch uses row-major [out, in]: W[o,i] at index o*k + i.
        // These produce identical memory layouts, so only the shape metadata
        // needs to change — no data reordering required.
        let shape = if info.shape.len() == 2 {
            vec![info.shape[1], info.shape[0]]
        } else {
            info.shape.clone()
        };

        let mut tensor = Tensor::from_f32(&f32_data, shape)?;
        tensor.set_name(name);
        Ok(tensor)
    }

    fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
        self.load_tensor(name).ok()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Create a minimal SafeTensors file with the given F32 tensors.
    ///
    /// Each entry is `(name, shape, values)`. Values are written little-endian
    /// and back to back, after an 8-byte header length and the JSON header.
    fn write_safetensors_file(
        path: &std::path::Path,
        tensors: &[(&str, Vec<usize>, &[f32])],
    ) {
        let mut data_buf = Vec::new();
        let mut header = serde_json::Map::new();

        for &(name, ref shape, values) in tensors {
            let start = data_buf.len();
            for &v in values {
                data_buf.extend_from_slice(&v.to_le_bytes());
            }
            let end = data_buf.len();

            let shape_json: Vec<serde_json::Value> = shape.iter()
                .map(|&d| serde_json::Value::Number(serde_json::Number::from(d)))
                .collect();

            let mut obj = serde_json::Map::new();
            obj.insert("dtype".into(), "F32".into());
            obj.insert("shape".into(), serde_json::Value::Array(shape_json));
            obj.insert("data_offsets".into(), serde_json::json!([start, end]));
            header.insert(name.to_string(), serde_json::Value::Object(obj));
        }

        let header_str = serde_json::to_string(&serde_json::Value::Object(header)).unwrap();
        let header_bytes = header_str.as_bytes();

        let mut f = std::fs::File::create(path).unwrap();
        f.write_all(&(header_bytes.len() as u64).to_le_bytes()).unwrap();
        f.write_all(header_bytes).unwrap();
        f.write_all(&data_buf).unwrap();
    }

    /// Write a minimal config.json for testing.
    fn write_config_json(dir: &std::path::Path, model_type: &str) {
        let config = serde_json::json!({
            "model_type": model_type,
            "hidden_size": 64,
            "intermediate_size": 128,
            "num_hidden_layers": 1,
            "num_attention_heads": 2,
            "num_key_value_heads": 2,
            "max_position_embeddings": 128,
            "rms_norm_eps": 1e-5,
            "vocab_size": 8,
            "head_dim": 32,
        });
        let path = dir.join("config.json");
        std::fs::write(path, serde_json::to_string_pretty(&config).unwrap()).unwrap();
    }

    #[test]
    fn test_load_safetensors_basic() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();

        // Write config.json
        write_config_json(dir_path, "llama");

        // Write a minimal model.safetensors with an embedding tensor
        let embd_data: Vec<f32> = (0..512).map(|i| i as f32 * 0.01).collect(); // 8 x 64
        let norm_data: Vec<f32> = vec![1.0; 64];
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[
                ("model.embed_tokens.weight", vec![8, 64], &embd_data),
                ("model.norm.weight", vec![64], &norm_data),
            ],
        );

        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert_eq!(loader.architecture(), Architecture::Llama);
        assert_eq!(loader.config().hidden_size, 64);
        assert_eq!(loader.config().vocab_size, 8);

        // Load the embedding tensor
        let embd = loader.load_tensor("token_embd.weight").unwrap();
        // SafeTensors shape [8, 64] -> GGML shape [64, 8] (transposed)
        assert_eq!(embd.shape(), &[64, 8]);

        // 1D norm tensor shape should not be transposed
        let norm = loader.load_tensor("output_norm.weight").unwrap();
        assert_eq!(norm.shape(), &[64]);
    }

    #[test]
    fn test_load_from_file_inside_model_dir() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();

        write_config_json(dir_path, "llama");
        let model_file = dir_path.join("model.safetensors");
        write_safetensors_file(
            &model_file,
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0; 512])],
        );

        // A path to a file inside the model directory resolves to its parent.
        let loader = SafeTensorsLoader::load(&model_file).unwrap();
        assert_eq!(loader.config().hidden_size, 64);
        let embd = loader.load_tensor("token_embd.weight").unwrap();
        assert_eq!(embd.shape(), &[64, 8]);
    }

    #[test]
    fn test_missing_tensor_returns_error() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();

        write_config_json(dir_path, "llama");
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0; 512])],
        );

        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert!(loader.load_tensor("nonexistent.weight").is_err());
    }

    #[test]
    fn test_try_load_tensor_returns_none() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();

        write_config_json(dir_path, "llama");
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0; 512])],
        );

        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert!(loader.try_load_tensor("nonexistent.weight").is_none());
        assert!(loader.try_load_tensor("token_embd.weight").is_some());
    }

    #[test]
    fn test_gemma_norm_no_offset() {
        // Gemma 4 norm weights are stored identically in GGUF and SafeTensors
        // (both use raw values). No +1 offset is needed.
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();

        let config = serde_json::json!({
            "model_type": "gemma4",
            "hidden_size": 4,
            "intermediate_size": 8,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "num_key_value_heads": 1,
            "max_position_embeddings": 16,
            "rms_norm_eps": 1e-5,
            "vocab_size": 4,
            "head_dim": 4,
        });
        std::fs::write(dir_path.join("config.json"),
            serde_json::to_string_pretty(&config).unwrap()
        ).unwrap();

        let norm_data = vec![9.375f32; 4];
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[
                ("model.norm.weight", vec![4], &norm_data),
                ("model.embed_tokens.weight", vec![4, 4], &[0.5; 16]),
            ],
        );

        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        let norm = loader.load_tensor("output_norm.weight").unwrap();
        let norm_vals = norm.as_f32().unwrap();
        // Values should be unchanged (no +1 offset)
        for &v in norm_vals.iter() {
            assert!((v - 9.375).abs() < 1e-6, "Expected 9.375, got {v}");
        }
    }

    #[test]
    fn test_non_gemma_no_norm_offset() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();

        write_config_json(dir_path, "llama");

        let norm_data = vec![0.5f32; 64];
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[
                ("model.norm.weight", vec![64], &norm_data),
                ("model.embed_tokens.weight", vec![8, 64], &[0.0; 512]),
            ],
        );

        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert!(!loader.architecture().is_gemma());

        let norm = loader.load_tensor("output_norm.weight").unwrap();
        let norm_vals = norm.as_f32().unwrap();
        // No offset for non-Gemma: values should remain 0.5
        for &v in norm_vals.iter() {
            assert!((v - 0.5).abs() < 1e-6, "Expected 0.5, got {v}");
        }
    }

    #[test]
    fn test_is_norm_weight() {
        assert!(SafeTensorsLoader::is_norm_weight("output_norm.weight"));
        assert!(SafeTensorsLoader::is_norm_weight("blk.0.attn_norm.weight"));
        assert!(!SafeTensorsLoader::is_norm_weight("token_embd.weight"));
        assert!(!SafeTensorsLoader::is_norm_weight("output_norm.bias"));
    }

    #[test]
    fn test_apply_norm_offset() {
        let mut data = [0.0f32, -1.0, 2.5];
        SafeTensorsLoader::apply_norm_offset(&mut data);
        assert_eq!(data, [1.0, 0.0, 3.5]);
    }
}