use std::collections::HashMap;
use std::fs::{self, File};
use std::io::Write;
use std::path::PathBuf;
use axonml_tensor::Tensor;
use indicatif::{ProgressBar, ProgressStyle};
use crate::error::{LLMError, LLMResult};
/// Base URL of the Hugging Face Hub; files are fetched from
/// `<base>/<model_id>/resolve/main/<filename>`.
const HF_API_BASE: &str = "https://huggingface.co";
/// Downloads and caches Hugging Face model files (config + safetensors
/// weights) and exposes the loaded tensors as f32 buffers.
pub struct HFLoader {
// Hub model id (e.g. "org/model"), or the local path for `from_local`.
model_id: String,
// Directory where downloaded files are stored and looked up.
cache_dir: PathBuf,
// Tensors populated by `load_tensors`, keyed by safetensors name.
tensors: HashMap<String, TensorInfo>,
// Parsed `config.json`, cached by `load_config` after the first call.
config: Option<serde_json::Value>,
}
/// One tensor read from a safetensors file, with its data converted to f32.
#[derive(Debug, Clone)]
pub struct TensorInfo {
/// Dimensions as stored in the safetensors header.
pub shape: Vec<usize>,
/// Element data converted to f32, in the order stored in the file.
pub data: Vec<f32>,
/// Debug-formatted original dtype (e.g. "F32", "F16", "BF16").
pub dtype: String,
}
impl HFLoader {
    /// Creates a loader for a Hub model id and ensures its cache directory
    /// exists.
    ///
    /// # Errors
    /// `LLMError::IoError` if the cache directory cannot be created.
    pub fn new(model_id: &str) -> LLMResult<Self> {
        let cache_dir = Self::get_cache_dir(model_id);
        fs::create_dir_all(&cache_dir).map_err(|e| LLMError::IoError(e.to_string()))?;
        Ok(Self {
            model_id: model_id.to_string(),
            cache_dir,
            tensors: HashMap::new(),
            config: None,
        })
    }

    /// Creates a loader over an existing local directory; files already
    /// present there are served as cache hits. Note the path is also used as
    /// the model id, so a download of a missing file will only succeed if
    /// the path happens to be a valid Hub id.
    ///
    /// # Errors
    /// `LLMError::ModelNotFound` if `path` does not exist.
    pub fn from_local(path: &str) -> LLMResult<Self> {
        let cache_dir = PathBuf::from(path);
        if !cache_dir.exists() {
            return Err(LLMError::ModelNotFound(path.to_string()));
        }
        Ok(Self {
            model_id: path.to_string(),
            cache_dir,
            tensors: HashMap::new(),
            config: None,
        })
    }

    /// Maps a model id to `<user cache>/axonml/hub/<id with '/' -> "--">`,
    /// falling back to the current directory when no cache dir is known.
    fn get_cache_dir(model_id: &str) -> PathBuf {
        let base = dirs::cache_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join("axonml")
            .join("hub");
        let safe_id = model_id.replace('/', "--");
        base.join(safe_id)
    }

    /// Returns the local path of `filename`, downloading it from the Hub if
    /// it is not already cached.
    ///
    /// The body is buffered fully in memory, then written to a temporary
    /// `.part` file that is renamed into place only after the write
    /// succeeds, so a failed or interrupted write can never leave a
    /// truncated file that a later run would mistake for a valid cache hit.
    ///
    /// # Errors
    /// `LLMError::NetworkError` on request failure or non-success status;
    /// `LLMError::IoError` on filesystem failure.
    pub fn download_file(&self, filename: &str) -> LLMResult<PathBuf> {
        let local_path = self.cache_dir.join(filename);
        if local_path.exists() {
            println!("Using cached: {}", local_path.display());
            return Ok(local_path);
        }
        let url = format!(
            "{}/{}/resolve/main/{}",
            HF_API_BASE, self.model_id, filename
        );
        println!("Downloading: {}", url);
        let pb = ProgressBar::new(0);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );
        pb.set_message(format!("Downloading {}", filename));
        let client = reqwest::blocking::Client::new();
        let response = client
            .get(&url)
            .send()
            .map_err(|e| LLMError::NetworkError(e.to_string()))?;
        if !response.status().is_success() {
            // NOTE: `download_file_if_exists` matches on the "HTTP <code>"
            // suffix of this message; keep the format in sync.
            return Err(LLMError::NetworkError(format!(
                "Failed to download {}: HTTP {}",
                filename,
                response.status()
            )));
        }
        if let Some(len) = response.content_length() {
            pb.set_length(len);
        }
        // The body is read in one piece, so the bar jumps from 0 to 100%;
        // it still reports total size and elapsed time.
        let bytes = response
            .bytes()
            .map_err(|e| LLMError::NetworkError(e.to_string()))?;
        pb.set_position(bytes.len() as u64);
        // Write atomically: temp file first, rename into place on success,
        // so a partial write never poisons the cache.
        let tmp_path = self.cache_dir.join(format!("{}.part", filename));
        let mut file = File::create(&tmp_path).map_err(|e| LLMError::IoError(e.to_string()))?;
        if let Err(e) = file.write_all(&bytes) {
            // Best-effort cleanup of the partial temp file.
            let _ = fs::remove_file(&tmp_path);
            return Err(LLMError::IoError(e.to_string()));
        }
        fs::rename(&tmp_path, &local_path).map_err(|e| LLMError::IoError(e.to_string()))?;
        pb.finish_with_message(format!("Downloaded {}", filename));
        Ok(local_path)
    }

    /// Returns the model's `config.json` as parsed JSON, downloading it on
    /// first use and serving the cached copy afterwards.
    ///
    /// # Errors
    /// Download, I/O, or JSON-parse failures.
    pub fn load_config(&mut self) -> LLMResult<serde_json::Value> {
        if let Some(ref config) = self.config {
            return Ok(config.clone());
        }
        let path = self.download_file("config.json")?;
        let content = fs::read_to_string(&path).map_err(|e| LLMError::IoError(e.to_string()))?;
        let config: serde_json::Value =
            serde_json::from_str(&content).map_err(|e| LLMError::ParseError(e.to_string()))?;
        self.config = Some(config.clone());
        Ok(config)
    }

    /// Loads all model weights into `self.tensors`, handling both the
    /// single-file layout (`model.safetensors`) and the sharded layout
    /// described by `model.safetensors.index.json`.
    ///
    /// # Errors
    /// Download, I/O, or parse failures for the index or shard files.
    pub fn load_tensors(&mut self) -> LLMResult<()> {
        // Prefer the single-file layout; any failure to fetch it (most
        // commonly a 404 on sharded repos) falls through to the index.
        let single_file = self.cache_dir.join("model.safetensors");
        if single_file.exists() || self.download_file("model.safetensors").is_ok() {
            return self.load_safetensors_file("model.safetensors");
        }
        let index_path = self.download_file("model.safetensors.index.json")?;
        let index_content =
            fs::read_to_string(&index_path).map_err(|e| LLMError::IoError(e.to_string()))?;
        let index: serde_json::Value = serde_json::from_str(&index_content)
            .map_err(|e| LLMError::ParseError(e.to_string()))?;
        let weight_map = index["weight_map"]
            .as_object()
            .ok_or_else(|| LLMError::ParseError("Invalid index file".to_string()))?;
        // The index maps weight name -> shard filename; collect the unique
        // shard names, then download and load each shard once.
        let mut shard_files: Vec<String> = weight_map
            .values()
            .filter_map(|v| v.as_str().map(String::from))
            .collect();
        shard_files.sort();
        shard_files.dedup();
        for shard in &shard_files {
            self.download_file(shard)?;
            self.load_safetensors_file(shard)?;
        }
        Ok(())
    }

    /// Parses one cached safetensors file and inserts every tensor it
    /// contains into `self.tensors` (converted to f32).
    ///
    /// # Errors
    /// `LLMError::IoError` on read failure, `LLMError::ParseError` on a
    /// malformed file, or an unsupported-dtype error from conversion.
    fn load_safetensors_file(&mut self, filename: &str) -> LLMResult<()> {
        let path = self.cache_dir.join(filename);
        let data = fs::read(&path).map_err(|e| LLMError::IoError(e.to_string()))?;
        let tensors = safetensors::SafeTensors::deserialize(&data)
            .map_err(|e| LLMError::ParseError(e.to_string()))?;
        for (name, tensor) in tensors.tensors() {
            let shape: Vec<usize> = tensor.shape().to_vec();
            let dtype = format!("{:?}", tensor.dtype());
            let data = self.convert_tensor_to_f32(&tensor)?;
            self.tensors
                .insert(name.to_string(), TensorInfo { shape, data, dtype });
        }
        println!("Loaded {} tensors from {}", tensors.len(), filename);
        Ok(())
    }

    /// Decodes a tensor's little-endian raw bytes into `Vec<f32>`.
    /// Supports F32 (bit-copy), F16, and BF16 (widened via `half`).
    ///
    /// # Errors
    /// `LLMError::UnsupportedFormat` for any other dtype.
    fn convert_tensor_to_f32(
        &self,
        tensor: &safetensors::tensor::TensorView,
    ) -> LLMResult<Vec<f32>> {
        use safetensors::Dtype;
        let data = tensor.data();
        match tensor.dtype() {
            Dtype::F32 => {
                Ok(data
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect())
            }
            Dtype::F16 => {
                Ok(data
                    .chunks_exact(2)
                    .map(|b| {
                        let bits = u16::from_le_bytes([b[0], b[1]]);
                        half::f16::from_bits(bits).to_f32()
                    })
                    .collect())
            }
            Dtype::BF16 => {
                Ok(data
                    .chunks_exact(2)
                    .map(|b| {
                        let bits = u16::from_le_bytes([b[0], b[1]]);
                        half::bf16::from_bits(bits).to_f32()
                    })
                    .collect())
            }
            dtype => Err(LLMError::UnsupportedFormat(format!(
                "Unsupported tensor dtype: {:?}",
                dtype
            ))),
        }
    }

    /// Borrows a loaded tensor by its safetensors name, if present.
    pub fn get_tensor(&self, name: &str) -> Option<&TensorInfo> {
        self.tensors.get(name)
    }

    /// Builds an owned `Tensor<f32>` for `name`; this clones the underlying
    /// data buffer.
    ///
    /// # Errors
    /// `LLMError::WeightNotFound` if the name is unknown, or
    /// `LLMError::TensorError` if construction fails.
    pub fn get_as_tensor(&self, name: &str) -> LLMResult<Tensor<f32>> {
        let info = self
            .tensors
            .get(name)
            .ok_or_else(|| LLMError::WeightNotFound(name.to_string()))?;
        Tensor::from_vec(info.data.clone(), &info.shape)
            .map_err(|e| LLMError::TensorError(e.to_string()))
    }

    /// Lists the names of all loaded tensors (unsorted).
    pub fn tensor_names(&self) -> Vec<&str> {
        self.tensors.keys().map(|s| s.as_str()).collect()
    }

    /// Prints every loaded tensor's name, shape, and dtype, sorted by name.
    pub fn print_tensor_info(&self) {
        println!("\nLoaded tensors:");
        let mut names: Vec<_> = self.tensors.keys().collect();
        names.sort();
        for name in names {
            let info = &self.tensors[name];
            println!("  {} {:?} ({})", name, info.shape, info.dtype);
        }
    }

    /// Returns the directory used for cached files.
    pub fn cache_dir(&self) -> &std::path::Path {
        &self.cache_dir
    }

    /// Returns the model id (or local path) this loader was built from.
    pub fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Like `download_file`, but turns "file not found on the Hub" into
    /// `Ok(false)` instead of an error.
    ///
    /// Detection is heuristic: it inspects the error message produced by
    /// `download_file` for an HTTP 4xx status.
    ///
    /// # Errors
    /// Any non-4xx download failure is propagated.
    pub fn download_file_if_exists(&self, filename: &str) -> LLMResult<bool> {
        match self.download_file(filename) {
            Ok(_) => Ok(true),
            Err(LLMError::NetworkError(msg)) if msg.contains("404") || msg.contains("HTTP 4") => {
                Ok(false)
            }
            Err(e) => Err(e),
        }
    }

    /// Borrows the full map of loaded tensors.
    pub fn tensors(&self) -> &HashMap<String, TensorInfo> {
        &self.tensors
    }
}
/// Translates Hugging Face checkpoint weight names into the names this
/// crate's model implementations expect.
pub trait WeightMapper {
    /// Returns the internal name for `hf_name`, or `None` to skip it.
    fn map_name(&self, hf_name: &str) -> Option<String>;
    /// Lists every weight name the model is expected to provide.
    fn expected_weights(&self) -> Vec<String>;
}

/// Weight-name mapper for LLaMA-style checkpoints.
pub struct LLaMAWeightMapper {
    num_layers: usize,
}

impl LLaMAWeightMapper {
    /// Creates a mapper for a model with `num_layers` transformer layers.
    pub fn new(num_layers: usize) -> Self {
        Self { num_layers }
    }
}

impl WeightMapper for LLaMAWeightMapper {
    /// Drops the leading "model." prefix, if present.
    fn map_name(&self, hf_name: &str) -> Option<String> {
        let stripped = hf_name.strip_prefix("model.").unwrap_or(hf_name);
        Some(String::from(stripped))
    }

    /// Embedding and final norm, plus nine projection/norm weights per layer.
    fn expected_weights(&self) -> Vec<String> {
        const PER_LAYER: [&str; 9] = [
            "self_attn.q_proj.weight",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_proj.weight",
            "mlp.up_proj.weight",
            "mlp.down_proj.weight",
            "input_layernorm.weight",
            "post_attention_layernorm.weight",
        ];
        let mut names = vec!["embed_tokens.weight".to_string(), "norm.weight".to_string()];
        names.extend((0..self.num_layers).flat_map(|layer| {
            PER_LAYER
                .iter()
                .map(move |suffix| format!("layers.{}.{}", layer, suffix))
        }));
        names
    }
}
/// Weight-name mapper for Mistral checkpoints; the weight layout matches
/// LLaMA, so the expected-name list is delegated to `LLaMAWeightMapper`.
pub struct MistralWeightMapper {
    num_layers: usize,
}

impl MistralWeightMapper {
    /// Creates a mapper for a model with `num_layers` transformer layers.
    pub fn new(num_layers: usize) -> Self {
        Self { num_layers }
    }
}

impl WeightMapper for MistralWeightMapper {
    /// Drops the leading "model." prefix, if present.
    fn map_name(&self, hf_name: &str) -> Option<String> {
        let stripped = match hf_name.strip_prefix("model.") {
            Some(rest) => rest,
            None => hf_name,
        };
        Some(String::from(stripped))
    }

    /// Same per-layer names as LLaMA; reuse its mapper.
    fn expected_weights(&self) -> Vec<String> {
        LLaMAWeightMapper::new(self.num_layers).expected_weights()
    }
}
/// Weight-name mapper for Phi checkpoints.
pub struct PhiWeightMapper {
    num_layers: usize,
}

impl PhiWeightMapper {
    /// Creates a mapper for a model with `num_layers` transformer layers.
    pub fn new(num_layers: usize) -> Self {
        Self { num_layers }
    }
}

impl WeightMapper for PhiWeightMapper {
    /// Drops a leading "model." or "transformer." prefix, if present.
    fn map_name(&self, hf_name: &str) -> Option<String> {
        let stripped = hf_name
            .strip_prefix("model.")
            .or_else(|| hf_name.strip_prefix("transformer."))
            .unwrap_or(hf_name);
        Some(String::from(stripped))
    }

    /// Embedding and final layer norm, plus seven weights per layer.
    fn expected_weights(&self) -> Vec<String> {
        const PER_LAYER: [&str; 7] = [
            "self_attn.q_proj.weight",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.dense.weight",
            "mlp.fc1.weight",
            "mlp.fc2.weight",
            "input_layernorm.weight",
        ];
        let mut names = vec![
            "embed_tokens.weight".to_string(),
            "final_layernorm.weight".to_string(),
        ];
        names.extend((0..self.num_layers).flat_map(|layer| {
            PER_LAYER
                .iter()
                .map(move |suffix| format!("layers.{}.{}", layer, suffix))
        }));
        names
    }
}
/// Downloads (or reads from cache) a LLaMA checkpoint from the Hugging Face
/// Hub and returns its parsed config together with all weights keyed by
/// internal (mapped) name.
///
/// # Errors
/// Propagates download, config-parse, and tensor-construction failures.
pub fn load_llama_from_hf(
    model_id: &str,
) -> LLMResult<(crate::LLaMAConfig, HashMap<String, Tensor<f32>>)> {
    let mut loader = HFLoader::new(model_id)?;
    let config_json = loader.load_config()?;
    let config = parse_llama_config_from_json(&config_json)?;
    loader.load_tensors()?;
    let mapper = LLaMAWeightMapper::new(config.num_hidden_layers);
    let mut weights = HashMap::new();
    // Take ownership of the loaded tensors so the (potentially huge) f32
    // buffers are moved into the result instead of being cloned; the loader
    // is not used again afterwards.
    for (hf_name, tensor_info) in std::mem::take(&mut loader.tensors) {
        if let Some(mapped_name) = mapper.map_name(&hf_name) {
            let tensor = Tensor::from_vec(tensor_info.data, &tensor_info.shape)
                .map_err(|e| LLMError::TensorError(e.to_string()))?;
            weights.insert(mapped_name, tensor);
        }
    }
    Ok((config, weights))
}
/// Builds a `LLaMAConfig` from a Hugging Face `config.json`, falling back
/// to hard-coded defaults for any missing or mistyped field. In particular,
/// `num_key_value_heads` defaults to `num_attention_heads` (no GQA).
pub fn parse_llama_config_from_json(json: &serde_json::Value) -> LLMResult<crate::LLaMAConfig> {
    // Field readers: unsigned-integer and f32 values with a default.
    let uint = |key: &str, default: u64| json[key].as_u64().unwrap_or(default) as usize;
    let float = |key: &str, default: f64| json[key].as_f64().unwrap_or(default) as f32;
    Ok(crate::LLaMAConfig {
        vocab_size: uint("vocab_size", 32000),
        hidden_size: uint("hidden_size", 4096),
        intermediate_size: uint("intermediate_size", 11008),
        num_hidden_layers: uint("num_hidden_layers", 32),
        num_attention_heads: uint("num_attention_heads", 32),
        num_key_value_heads: json["num_key_value_heads"]
            .as_u64()
            .unwrap_or(json["num_attention_heads"].as_u64().unwrap_or(32))
            as usize,
        max_position_embeddings: uint("max_position_embeddings", 4096),
        rms_norm_eps: float("rms_norm_eps", 1e-5),
        rope_theta: float("rope_theta", 10000.0),
        attention_dropout: 0.0,
        hidden_dropout: 0.0,
    })
}
/// Downloads (or reads from cache) a Mistral checkpoint from the Hugging
/// Face Hub and returns its parsed config together with all weights keyed
/// by internal (mapped) name.
///
/// # Errors
/// Propagates download, config-parse, and tensor-construction failures.
pub fn load_mistral_from_hf(
    model_id: &str,
) -> LLMResult<(crate::MistralConfig, HashMap<String, Tensor<f32>>)> {
    let mut loader = HFLoader::new(model_id)?;
    let config_json = loader.load_config()?;
    let config = parse_mistral_config(&config_json)?;
    loader.load_tensors()?;
    let mapper = MistralWeightMapper::new(config.num_hidden_layers);
    let mut weights = HashMap::new();
    // Take ownership of the loaded tensors so the (potentially huge) f32
    // buffers are moved into the result instead of being cloned; the loader
    // is not used again afterwards.
    for (hf_name, tensor_info) in std::mem::take(&mut loader.tensors) {
        if let Some(mapped_name) = mapper.map_name(&hf_name) {
            let tensor = Tensor::from_vec(tensor_info.data, &tensor_info.shape)
                .map_err(|e| LLMError::TensorError(e.to_string()))?;
            weights.insert(mapped_name, tensor);
        }
    }
    Ok((config, weights))
}
/// Builds a `MistralConfig` from a Hugging Face `config.json`, falling back
/// to hard-coded defaults for any missing or mistyped field.
fn parse_mistral_config(json: &serde_json::Value) -> LLMResult<crate::MistralConfig> {
    // Field readers: unsigned-integer and f32 values with a default.
    let uint = |key: &str, default: u64| json[key].as_u64().unwrap_or(default) as usize;
    let float = |key: &str, default: f64| json[key].as_f64().unwrap_or(default) as f32;
    Ok(crate::MistralConfig {
        vocab_size: uint("vocab_size", 32000),
        hidden_size: uint("hidden_size", 4096),
        intermediate_size: uint("intermediate_size", 14336),
        num_hidden_layers: uint("num_hidden_layers", 32),
        num_attention_heads: uint("num_attention_heads", 32),
        num_key_value_heads: uint("num_key_value_heads", 8),
        max_position_embeddings: uint("max_position_embeddings", 32768),
        sliding_window: uint("sliding_window", 4096),
        rms_norm_eps: float("rms_norm_eps", 1e-5),
        rope_theta: float("rope_theta", 10000.0),
        attention_dropout: 0.0,
    })
}
#[cfg(test)]
mod tests {
use super::*;
// Cache paths must encode the model id with '/' replaced by "--".
#[test]
fn test_cache_dir() {
let dir = HFLoader::get_cache_dir("meta-llama/Llama-2-7b-hf");
assert!(dir.to_string_lossy().contains("meta-llama--Llama-2-7b-hf"));
}
// The LLaMA mapper strips the leading "model." prefix from HF names.
#[test]
fn test_llama_weight_mapper() {
let mapper = LLaMAWeightMapper::new(2);
assert_eq!(
mapper.map_name("model.embed_tokens.weight"),
Some("embed_tokens.weight".to_string())
);
assert_eq!(
mapper.map_name("model.layers.0.self_attn.q_proj.weight"),
Some("layers.0.self_attn.q_proj.weight".to_string())
);
}
// expected_weights() must enumerate per-layer weights for every layer.
#[test]
fn test_expected_weights() {
let mapper = LLaMAWeightMapper::new(2);
let weights = mapper.expected_weights();
assert!(weights.contains(&"embed_tokens.weight".to_string()));
assert!(weights.contains(&"layers.0.self_attn.q_proj.weight".to_string()));
assert!(weights.contains(&"layers.1.mlp.down_proj.weight".to_string()));
}
}