use anyhow::{Context, Result, bail};
use rlx_cli::WeightFormat;
use rlx_llama_base::LlamaBaseConfig;
use rlx_llama32::Llama32Config;
use serde::Deserialize;
use std::path::{Path, PathBuf};
/// Minimal view of a Hugging Face `config.json`, deserialized by
/// `validate_hf_config` to check the model family before loading.
///
/// Only these keys are read; any other keys in the file are ignored because the
/// struct does not use `#[serde(deny_unknown_fields)]`.
#[derive(Debug, Clone, Deserialize)]
struct HfConfigProbe {
    // HF `model_type`; `validate_hf_config` requires it to be `"llama"`.
    #[serde(default)]
    model_type: Option<String>,
    // HF `architectures` list; when present it must contain `LlamaForCausalLM`.
    #[serde(default)]
    architectures: Option<Vec<String>>,
    // HF `_name_or_path`; deserialized but not read by any visible code
    // (the leading underscore also keeps the dead-code lint quiet).
    #[serde(default)]
    _name_or_path: Option<String>,
}
/// Resolves the `config.json` that accompanies a weights location.
///
/// * If `weights_or_dir` is an existing directory, the result is
///   `<weights_or_dir>/config.json`.
/// * Otherwise it is treated as a weights file, and the result is the
///   `config.json` sitting next to it. The file does not need to exist.
/// * If the path has no parent (for example `/` or an empty path), a bare
///   `config.json` is returned, i.e. one relative to the working directory.
pub fn config_json_path(weights_or_dir: &Path) -> PathBuf {
    let containing_dir = if weights_or_dir.is_dir() {
        Some(weights_or_dir)
    } else {
        weights_or_dir.parent()
    };
    match containing_dir {
        Some(dir) => dir.join("config.json"),
        None => PathBuf::from("config.json"),
    }
}
/// Loads the [`Llama32Config`] from the HF `config.json` that belongs to
/// `weights_or_dir`. The file is located with [`config_json_path`].
///
/// This only parses the config. Model-family checks live in
/// [`validate_hf_config`].
///
/// # Errors
///
/// Returns the error from `Llama32Config::from_file`, with the resolved config
/// path attached as context.
pub fn llama_config_from_hf(weights_or_dir: &Path) -> Result<Llama32Config> {
    let cfg_path = config_json_path(weights_or_dir);
    let loaded = Llama32Config::from_file(&cfg_path);
    loaded.with_context(|| format!("reading MiniCPM5 HF config {cfg_path:?}"))
}
pub fn validate_hf_config(weights_or_dir: &Path) -> Result<()> {
let cfg_path = config_json_path(weights_or_dir);
let raw =
std::fs::read_to_string(&cfg_path).with_context(|| format!("reading {cfg_path:?}"))?;
let probe: HfConfigProbe =
serde_json::from_str(&raw).with_context(|| format!("parsing {cfg_path:?}"))?;
match probe.model_type.as_deref() {
Some("llama") => {}
Some(other) => bail!(
"rlx-minicpm5: {cfg_path:?} has model_type={other:?}; expected `llama` \
(MiniCPM5-1B is LlamaForCausalLM-shaped)"
),
None => bail!("rlx-minicpm5: {cfg_path:?} missing model_type"),
}
if let Some(archs) = &probe.architectures {
let ok = archs.iter().any(|a| a == "LlamaForCausalLM");
if !ok {
bail!(
"rlx-minicpm5: {cfg_path:?} architectures={archs:?}; \
expected LlamaForCausalLM (MiniCPM5)"
);
}
}
Ok(())
}
/// Checks that `weights` holds a Llama-architecture checkpoint.
///
/// The format is detected with `WeightFormat::from_path`:
///
/// * GGUF: the file is parsed with `LlamaBaseConfig::from_gguf_path`, and its
///   `arch` must be `"llama"`;
/// * safetensors: the neighbouring HF `config.json` is checked with
///   [`validate_hf_config`].
///
/// # Errors
///
/// Propagates format-detection and parse failures, and fails on an
/// architecture mismatch.
pub fn validate_weights_kind(weights: &Path) -> Result<()> {
    match WeightFormat::from_path(weights)? {
        WeightFormat::Safetensors => validate_hf_config(weights),
        WeightFormat::Gguf => {
            let cfg = LlamaBaseConfig::from_gguf_path(weights)
                .with_context(|| format!("rlx-minicpm5: parse GGUF {weights:?}"))?;
            if cfg.arch == "llama" {
                Ok(())
            } else {
                bail!(
                    "rlx-minicpm5: expected `general.architecture = llama`; \
                     got `{}` at {weights:?}",
                    cfg.arch
                )
            }
        }
    }
}
/// Returns the built-in [`Llama32Config`] preset for MiniCPM5-1B.
///
/// The values are hard-coded, not read from disk. The unit test
/// `preset_matches_hf_card` pins a subset of them.
/// NOTE(review): keep in sync with the upstream `config.json` if it changes.
pub fn minicpm5_1b_preset() -> Llama32Config {
    Llama32Config {
        // Vocabulary and embeddings.
        vocab_size: 130_560,
        tie_word_embeddings: false,

        // Transformer body.
        num_hidden_layers: 24,
        hidden_size: 1536,
        intermediate_size: 4608,
        hidden_act: "silu".into(),
        rms_norm_eps: 1e-6,

        // Attention: 16 query heads over 2 key/value heads. `head_dim` is given
        // explicitly because 1536 / 16 = 96, not 128 (16 * 128 = 2048 != hidden_size).
        num_attention_heads: 16,
        num_key_value_heads: 2,
        head_dim: Some(128),
        attention_bias: false,

        // Positional encoding.
        max_position_embeddings: 131_072,
        rope_theta: 5.0e6,
        rope_scaling: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal Llama-shaped HF config mirroring the MiniCPM5-1B preset.
    const VALID_CONFIG: &str = r#"{
  "model_type": "llama",
  "architectures": ["LlamaForCausalLM"],
  "vocab_size": 130560,
  "hidden_size": 1536,
  "intermediate_size": 4608,
  "num_hidden_layers": 24,
  "num_attention_heads": 16,
  "num_key_value_heads": 2,
  "head_dim": 128,
  "max_position_embeddings": 131072,
  "rms_norm_eps": 1e-6,
  "rope_theta": 5000000
}"#;

    /// Scratch directory that is removed on drop, so cleanup also happens
    /// when a test panics halfway through.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates a fresh, uniquely named directory under the system temp dir.
        fn new(label: &str) -> Self {
            // pid + timestamp alone can collide between tests running on parallel
            // threads of the same process; the counter makes each name unique.
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let dir = std::env::temp_dir().join(format!(
                "rlx_minicpm5_{label}_{}_{}_{}",
                std::process::id(),
                nanos,
                COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }

        /// Writes `json` as `config.json` inside the scratch directory.
        fn write_config(&self, json: &str) {
            std::fs::write(self.0.join("config.json"), json).unwrap();
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover temp dir must not fail the test.
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    #[test]
    fn preset_matches_hf_card() {
        let p = minicpm5_1b_preset();
        assert_eq!(p.hidden_size, 1536);
        assert_eq!(p.num_hidden_layers, 24);
        assert_eq!(p.head_dim(), 128);
        assert!((p.rope_theta - 5_000_000.0).abs() < 1.0);
    }

    #[test]
    fn validates_llama_model_type() {
        let dir = ScratchDir::new("cfg");
        dir.write_config(VALID_CONFIG);
        validate_hf_config(dir.path()).expect("valid llama config");
        let cfg = llama_config_from_hf(dir.path()).unwrap();
        assert_eq!(cfg.hidden_size, 1536);
    }

    #[test]
    fn rejects_non_llama_model_type() {
        let dir = ScratchDir::new("bad_type");
        dir.write_config(r#"{"model_type": "minicpm"}"#);
        let err = validate_hf_config(dir.path()).unwrap_err();
        assert!(err.to_string().contains("model_type"), "{err:#}");
    }

    #[test]
    fn rejects_missing_model_type() {
        let dir = ScratchDir::new("no_type");
        dir.write_config(r#"{"architectures": ["LlamaForCausalLM"]}"#);
        let err = validate_hf_config(dir.path()).unwrap_err();
        assert!(err.to_string().contains("missing model_type"), "{err:#}");
    }

    #[test]
    fn rejects_foreign_architecture() {
        let dir = ScratchDir::new("bad_arch");
        dir.write_config(r#"{"model_type": "llama", "architectures": ["MiniCPMForCausalLM"]}"#);
        let err = validate_hf_config(dir.path()).unwrap_err();
        assert!(err.to_string().contains("architectures"), "{err:#}");
    }

    #[test]
    fn accepts_missing_architectures() {
        let dir = ScratchDir::new("no_arch");
        dir.write_config(r#"{"model_type": "llama"}"#);
        validate_hf_config(dir.path()).expect("architectures is optional");
    }

    #[test]
    fn config_path_resolves_for_dir_and_weights_file() {
        let dir = ScratchDir::new("path");
        let expected = dir.path().join("config.json");
        assert_eq!(config_json_path(dir.path()), expected);
        // The weights file need not exist: a non-directory path uses its parent.
        assert_eq!(config_json_path(&dir.path().join("model.safetensors")), expected);
    }
}