1pub mod config;
16
17#[cfg(feature = "hf-download")]
18pub mod download;
19
20use anyhow::{Context, Result};
21use config::validate_weights_kind;
22use rlx_cli::WeightFormat;
23use rlx_llama_base::LlamaBaseConfig;
24use rlx_runtime::Device;
25use std::path::{Path, PathBuf};
26
27pub use config::{config_json_path, llama_config_from_hf, minicpm5_1b_preset};
28#[cfg(feature = "hf-download")]
29pub use download::{
30 default_hf_cache_dir, download_minicpm5_1b, download_minicpm5_gguf, fetch_minicpm5_1b,
31 fetch_minicpm5_gguf, materialize_minicpm5_1b, materialize_minicpm5_gguf,
32};
33pub use rlx_llama32::{Llama32Config, Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
34
/// Model family name for this crate.
pub const FAMILY: &str = "MiniCPM5";
/// Hugging Face model id of the 1B checkpoint.
pub const HF_MODEL_ID_1B: &str = "openbmb/MiniCPM5-1B";
/// Hugging Face model id of the GGUF repository.
pub const HF_MODEL_ID_GGUF: &str = "openbmb/MiniCPM5-1B-GGUF";

/// Known GGUF variants as `(quantization label, file name)` pairs; each file
/// name embeds its label.
// NOTE(review): presumably these are the files published under
// `HF_MODEL_ID_GGUF` — confirm against the `download` module.
pub const MINICPM5_GGUF_FILES: &[(&str, &str)] = &[
    ("Q4_K_M", "MiniCPM5-1B-Q4_K_M.gguf"),
    ("Q8_0", "MiniCPM5-1B-Q8_0.gguf"),
    ("F16", "MiniCPM5-1B-F16.gguf"),
];
47
48pub struct MiniCpm5Runner {
49 inner: Llama32Runner,
50 base: LlamaBaseConfig,
52}
53
54impl MiniCpm5Runner {
55 pub fn builder() -> MiniCpm5RunnerBuilder {
56 MiniCpm5RunnerBuilder::default()
57 }
58
59 pub fn base_config(&self) -> &LlamaBaseConfig {
60 &self.base
61 }
62
63 pub fn llama_config(&self) -> &Llama32Config {
64 self.inner.config()
65 }
66
67 pub fn inner(&self) -> &Llama32Runner {
68 &self.inner
69 }
70
71 pub fn inner_mut(&mut self) -> &mut Llama32Runner {
72 &mut self.inner
73 }
74
75 pub fn generate_packed(
76 &mut self,
77 prompt_ids: &[u32],
78 n_new: usize,
79 on_token: impl FnMut(u32),
80 ) -> Result<Vec<u32>> {
81 self.inner.generate_packed(prompt_ids, n_new, on_token)
82 }
83
84 pub fn generate(
86 &mut self,
87 prompt_ids: &[u32],
88 n_new: usize,
89 on_token: impl FnMut(u32),
90 ) -> Result<Vec<u32>> {
91 self.inner.generate(prompt_ids, n_new, on_token)
92 }
93
94 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
96 self.inner.predict_logits(prompt_ids)
97 }
98}
99
100#[derive(Debug, Clone, Default)]
101pub struct MiniCpm5RunnerBuilder {
102 weights: Option<PathBuf>,
103 inner: Llama32RunnerBuilder,
104}
105
106impl MiniCpm5RunnerBuilder {
107 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
108 let p: PathBuf = path.into();
109 self.weights = Some(p.clone());
110 self.inner = self.inner.weights(p);
111 self
112 }
113
114 pub fn max_seq(mut self, n: usize) -> Self {
115 self.inner = self.inner.max_seq(n);
116 self
117 }
118
119 pub fn packed_weights(mut self, on: bool) -> Self {
120 self.inner = self.inner.packed_weights(on);
121 self
122 }
123
124 pub fn device(mut self, d: Device) -> Self {
125 self.inner = self.inner.device(d);
126 self
127 }
128
129 pub fn build(self) -> Result<MiniCpm5Runner> {
130 let weights = self
131 .weights
132 .as_ref()
133 .ok_or_else(|| anyhow::anyhow!("weights path required (call .weights(...))"))?
134 .clone();
135
136 validate_weights_kind(&weights)?;
137
138 let base = match WeightFormat::from_path(&weights)? {
139 WeightFormat::Gguf => LlamaBaseConfig::from_gguf_path(&weights)
140 .with_context(|| format!("rlx-minicpm5: parse GGUF {weights:?}"))?,
141 WeightFormat::Safetensors => llama_base_from_hf(&weights)?,
142 };
143
144 let inner = self
145 .inner
146 .build()
147 .context("rlx-minicpm5: building underlying Llama32Runner")?;
148
149 Ok(MiniCpm5Runner { inner, base })
150 }
151}
152
153fn llama_base_from_hf(weights_or_dir: &Path) -> Result<LlamaBaseConfig> {
154 let cfg = config::llama_config_from_hf(weights_or_dir)?;
155 Ok(LlamaBaseConfig {
156 arch: "llama".into(),
157 vocab_size: cfg.vocab_size,
158 hidden_size: cfg.hidden_size,
159 intermediate_size: cfg.intermediate_size,
160 num_hidden_layers: cfg.num_hidden_layers,
161 num_attention_heads: cfg.num_attention_heads,
162 num_key_value_heads: cfg.num_key_value_heads,
163 head_dim: cfg.head_dim,
164 rms_norm_eps: cfg.rms_norm_eps,
165 rope_theta: cfg.rope_theta,
166 rope_scaling: None,
167 sliding_window: None,
168 max_position_embeddings: cfg.max_position_embeddings,
169 })
170}
171
172pub fn cli_run(args: &[String]) -> Result<()> {
174 if let Some(first) = args.iter().position(|a| a == "--weights") {
175 if let Some(path) = args.get(first + 1) {
176 validate_weights_kind(Path::new(path))?;
177 }
178 }
179 rlx_llama32::cli::run(args)
180}