use std::collections::BTreeSet;
use std::path::PathBuf;
use anyhow::{Context, Result};
use hf_hub::{
Repo, RepoType,
api::sync::{Api, ApiRepo},
};
use mlx_rs::module::ModuleParametersExt;
use tokenizers::Tokenizer;
use crate::engine::config::{GenerationConfig, HfConfig, ModelArgs};
use crate::engine::model::Model;
use crate::engine::template::ChatTemplate;
/// Everything [`load`] produces for a single model repository.
pub struct Loaded {
    /// Network built from `args`, with the safetensors weights loaded and,
    /// when requested, quantized.
    pub model: Model,
    /// Tokenizer read from the repository's `tokenizer.json`.
    pub tokenizer: Tokenizer,
    /// Model hyper-parameters converted from the parsed `config.json`.
    pub args: ModelArgs,
    /// Token ids returned by `GenerationConfig::eos_or_default()`.
    pub eos: Vec<u32>,
    /// Copy of `HfConfig::model_type` as parsed from `config.json`.
    pub model_type: String,
    /// Chat template built from `tokenizer_config.json`, or
    /// `ChatTemplate::chatml_only()` when that file is unavailable.
    pub chat_template: ChatTemplate,
}
/// Weight quantization levels that can be requested when loading a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Quantize {
    /// 2 bits per weight.
    Q2,
    /// 3 bits per weight.
    Q3,
    /// 4 bits per weight.
    Q4,
    /// 8 bits per weight.
    Q8,
}

impl Quantize {
    /// Every variant, used to drive the reverse lookup in [`Quantize::from_bits`].
    const VARIANTS: [Quantize; 4] = [Self::Q2, Self::Q3, Self::Q4, Self::Q8];

    /// Returns the bit width this level stands for.
    pub fn bits(self) -> i32 {
        match self {
            Self::Q2 => 2,
            Self::Q3 => 3,
            Self::Q4 => 4,
            Self::Q8 => 8,
        }
    }

    /// Returns the quantization group size; it is 64 for every level.
    pub fn group_size(self) -> i32 {
        64
    }

    /// Maps a bit width back to its level, or `None` when `bits` is not one
    /// of 2, 3, 4 or 8.
    pub fn from_bits(bits: i32) -> Option<Quantize> {
        Self::VARIANTS
            .iter()
            .copied()
            .find(|level| level.bits() == bits)
    }
}
pub fn load(model_id: &str, quant: Option<Quantize>) -> Result<Loaded> {
let api = Api::new()?;
let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
let config_path = repo.get("config.json").context("fetch config.json")?;
let tokenizer_path = repo.get("tokenizer.json").context("fetch tokenizer.json")?;
let weight_paths = resolve_weights(&repo).context("resolve model weight files")?;
let gen_cfg: GenerationConfig = match repo.get("generation_config.json") {
Ok(p) => serde_json::from_reader(std::fs::File::open(p)?)?,
Err(_) => GenerationConfig::default(),
};
let eos = gen_cfg.eos_or_default();
let hf: HfConfig = serde_json::from_reader(std::fs::File::open(config_path)?)?;
let model_type = hf.model_type.clone();
let args: ModelArgs = hf.into();
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::anyhow!(e))?;
let chat_template = build_chat_template(&repo);
let mut model = Model::new(&args)?;
for p in &weight_paths {
model
.load_safetensors(p)
.context("load_safetensors (struct field names must match HF tensor keys)")?;
}
if args.tie_word_embeddings {
model.tie_lm_head();
}
if let Some(q) = quant {
model = mlx_rs::nn::quantize(model, Some(q.group_size()), Some(q.bits()))?;
}
Ok(Loaded {
model,
tokenizer,
args,
eos,
model_type,
chat_template,
})
}
/// Builds the chat template from the repository's `tokenizer_config.json`.
///
/// This is best effort: if the file cannot be fetched, opened or parsed as
/// JSON, `ChatTemplate::chatml_only()` is returned instead of an error.
/// Otherwise the template is created from the `chat_template` string (if the
/// key holds a string) together with the `bos_token` and `eos_token` entries,
/// each reduced to text by [`token_str`].
fn build_chat_template(repo: &ApiRepo) -> ChatTemplate {
    let cfg = repo
        .get("tokenizer_config.json")
        .ok()
        .and_then(|path| std::fs::File::open(path).ok())
        // serde_json does not buffer its reader, so wrap the file ourselves.
        .map(std::io::BufReader::new)
        .and_then(|reader| serde_json::from_reader::<_, serde_json::Value>(reader).ok());
    let Some(cfg) = cfg else {
        return ChatTemplate::chatml_only();
    };

    // A non-string `chat_template` value is treated as absent.
    let template = cfg
        .get("chat_template")
        .and_then(|v| v.as_str())
        .map(String::from);

    ChatTemplate::new(
        template,
        token_str(cfg.get("bos_token")),
        token_str(cfg.get("eos_token")),
    )
}
fn token_str(v: Option<&serde_json::Value>) -> String {
match v {
Some(serde_json::Value::String(s)) => s.clone(),
Some(serde_json::Value::Object(o)) => o
.get("content")
.and_then(|c| c.as_str())
.unwrap_or_default()
.to_string(),
_ => String::new(),
}
}
pub fn resolve_weights(repo: &ApiRepo) -> Result<Vec<PathBuf>> {
if let Ok(index_path) = repo.get("model.safetensors.index.json") {
let v: serde_json::Value = serde_json::from_reader(std::fs::File::open(index_path)?)?;
let map = v
.get("weight_map")
.and_then(|m| m.as_object())
.context("index.json missing weight_map")?;
let shards: BTreeSet<String> = map
.values()
.filter_map(|x| x.as_str().map(String::from))
.collect();
let mut paths = Vec::with_capacity(shards.len());
for shard in shards {
paths.push(repo.get(&shard).with_context(|| format!("fetch {shard}"))?);
}
return Ok(paths);
}
Ok(vec![
repo.get("model.safetensors")
.context("fetch model.safetensors")?,
])
}