Skip to main content

rlx_minicpm5/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// MiniCPM5 — OpenBMB edge LMs (e.g. [MiniCPM5-1B](https://huggingface.co/openbmb/MiniCPM5-1B)).
5//
6// Standard Llama decoder: GQA + RoPE + SwiGLU + RMSNorm, `LlamaForCausalLM`
7// weight layout. This crate wraps [`rlx_llama32::Llama32Runner`] with:
8//
9// * GGUF `general.architecture = llama` validation;
10// * HF `config.json` checks (`model_type = llama`, `LlamaForCausalLM`) for safetensors;
11// * a typed [`MiniCpm5Runner`] surface and `rlx-minicpm5` CLI binary.
12//
13// **How to run:** see [README.md](README.md) (download, CLI, chat, GGUF, examples).
14
15pub mod config;
16
17#[cfg(feature = "hf-download")]
18pub mod download;
19
20use anyhow::{Context, Result};
21use config::validate_weights_kind;
22use rlx_cli::WeightFormat;
23use rlx_llama_base::LlamaBaseConfig;
24use rlx_runtime::Device;
25use std::path::{Path, PathBuf};
26
27pub use config::{config_json_path, llama_config_from_hf, minicpm5_1b_preset};
28#[cfg(feature = "hf-download")]
29pub use download::{
30    default_hf_cache_dir, download_minicpm5_1b, download_minicpm5_gguf, fetch_minicpm5_1b,
31    fetch_minicpm5_gguf, materialize_minicpm5_1b, materialize_minicpm5_gguf,
32};
33pub use rlx_llama32::{Llama32Config, Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
34
/// Model family name used by this crate.
pub const FAMILY: &str = "MiniCPM5";
/// HF model id for the 1B reference checkpoint.
pub const HF_MODEL_ID_1B: &str = "openbmb/MiniCPM5-1B";
/// GGUF quants (Q4_K_M, Q8_0, F16) on Hugging Face.
pub const HF_MODEL_ID_GGUF: &str = "openbmb/MiniCPM5-1B-GGUF";

/// Published GGUF filenames on Hugging Face (`openbmb/MiniCPM5-1B-GGUF`).
///
/// Each entry is a `(quantization label, file name)` pair.
pub const MINICPM5_GGUF_FILES: &[(&str, &str)] = &[
    ("Q4_K_M", "MiniCPM5-1B-Q4_K_M.gguf"),
    ("Q8_0", "MiniCPM5-1B-Q8_0.gguf"),
    ("F16", "MiniCPM5-1B-F16.gguf"),
];
47
/// MiniCPM5 runner: a [`Llama32Runner`] paired with the [`LlamaBaseConfig`]
/// resolved for the loaded weights.
pub struct MiniCpm5Runner {
    /// Underlying Llama runner; generation and logits calls are forwarded to it.
    inner: Llama32Runner,
    /// Parsed GGUF metadata when weights are GGUF; otherwise derived from HF config.
    base: LlamaBaseConfig,
}
53
impl MiniCpm5Runner {
    /// Returns a default-initialised [`MiniCpm5RunnerBuilder`].
    pub fn builder() -> MiniCpm5RunnerBuilder {
        MiniCpm5RunnerBuilder::default()
    }

    /// Borrows the base config resolved at build time (GGUF metadata or HF config).
    pub fn base_config(&self) -> &LlamaBaseConfig {
        &self.base
    }

    /// Borrows the wrapped runner's [`Llama32Config`].
    pub fn llama_config(&self) -> &Llama32Config {
        self.inner.config()
    }

    /// Shared access to the wrapped [`Llama32Runner`].
    pub fn inner(&self) -> &Llama32Runner {
        &self.inner
    }

    /// Exclusive access to the wrapped [`Llama32Runner`].
    pub fn inner_mut(&mut self) -> &mut Llama32Runner {
        &mut self.inner
    }

    /// Forwards to [`Llama32Runner::generate_packed`] with the arguments unchanged.
    ///
    /// NOTE(review): presumably the packed-weights generation path, with
    /// `on_token` invoked per produced token — confirm against `rlx_llama32`.
    ///
    /// # Errors
    /// Returns whatever error the wrapped runner reports.
    pub fn generate_packed(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        self.inner.generate_packed(prompt_ids, n_new, on_token)
    }

    /// KV-cached greedy generation (F32 weights; safetensors or GGUF dequant).
    ///
    /// Forwards to [`Llama32Runner::generate`] with the arguments unchanged.
    ///
    /// # Errors
    /// Returns whatever error the wrapped runner reports.
    pub fn generate(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        self.inner.generate(prompt_ids, n_new, on_token)
    }

    /// Last-position logits after prefill.
    ///
    /// Forwards to [`Llama32Runner::predict_logits`].
    ///
    /// # Errors
    /// Returns whatever error the wrapped runner reports.
    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
        self.inner.predict_logits(prompt_ids)
    }
}
99
/// Builder for [`MiniCpm5Runner`]; wraps a [`Llama32RunnerBuilder`] and also
/// remembers the weights path so `build` can validate it and resolve the base config.
#[derive(Debug, Clone, Default)]
pub struct MiniCpm5RunnerBuilder {
    /// Weights path supplied via `weights(...)`; `None` until that is called.
    weights: Option<PathBuf>,
    /// Underlying builder that receives every option set on this builder.
    inner: Llama32RunnerBuilder,
}
105
106impl MiniCpm5RunnerBuilder {
107    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
108        let p: PathBuf = path.into();
109        self.weights = Some(p.clone());
110        self.inner = self.inner.weights(p);
111        self
112    }
113
114    pub fn max_seq(mut self, n: usize) -> Self {
115        self.inner = self.inner.max_seq(n);
116        self
117    }
118
119    pub fn packed_weights(mut self, on: bool) -> Self {
120        self.inner = self.inner.packed_weights(on);
121        self
122    }
123
124    pub fn device(mut self, d: Device) -> Self {
125        self.inner = self.inner.device(d);
126        self
127    }
128
129    pub fn build(self) -> Result<MiniCpm5Runner> {
130        let weights = self
131            .weights
132            .as_ref()
133            .ok_or_else(|| anyhow::anyhow!("weights path required (call .weights(...))"))?
134            .clone();
135
136        validate_weights_kind(&weights)?;
137
138        let base = match WeightFormat::from_path(&weights)? {
139            WeightFormat::Gguf => LlamaBaseConfig::from_gguf_path(&weights)
140                .with_context(|| format!("rlx-minicpm5: parse GGUF {weights:?}"))?,
141            WeightFormat::Safetensors => llama_base_from_hf(&weights)?,
142        };
143
144        let inner = self
145            .inner
146            .build()
147            .context("rlx-minicpm5: building underlying Llama32Runner")?;
148
149        Ok(MiniCpm5Runner { inner, base })
150    }
151}
152
153fn llama_base_from_hf(weights_or_dir: &Path) -> Result<LlamaBaseConfig> {
154    let cfg = config::llama_config_from_hf(weights_or_dir)?;
155    Ok(LlamaBaseConfig {
156        arch: "llama".into(),
157        vocab_size: cfg.vocab_size,
158        hidden_size: cfg.hidden_size,
159        intermediate_size: cfg.intermediate_size,
160        num_hidden_layers: cfg.num_hidden_layers,
161        num_attention_heads: cfg.num_attention_heads,
162        num_key_value_heads: cfg.num_key_value_heads,
163        head_dim: cfg.head_dim,
164        rms_norm_eps: cfg.rms_norm_eps,
165        rope_theta: cfg.rope_theta,
166        rope_scaling: None,
167        sliding_window: None,
168        max_position_embeddings: cfg.max_position_embeddings,
169    })
170}
171
172/// CLI entry — delegates to `rlx_llama32::cli::run` after weight-kind checks.
173pub fn cli_run(args: &[String]) -> Result<()> {
174    if let Some(first) = args.iter().position(|a| a == "--weights") {
175        if let Some(path) = args.get(first + 1) {
176            validate_weights_kind(Path::new(path))?;
177        }
178    }
179    rlx_llama32::cli::run(args)
180}