1use crate::{Llama32Config, Llama32Generator, llama32_cfg_from_gguf};
5use anyhow::{Context, Result, anyhow, bail};
6use rlx_cli::{LmRunner, WeightFormat};
7use rlx_core::weight_loader::GgufLoader;
8use rlx_flow::CompileProfile;
9use rlx_gguf::{GgufFile, MetaValue};
10use rlx_qwen3::SampleOpts;
11use rlx_runtime::{Device, Session};
12use std::path::{Path, PathBuf};
13
14#[derive(Debug, Clone)]
19pub enum Llama32ConfigSource {
20 Embedded,
21 JsonFile(PathBuf),
22 Explicit(Llama32Config),
23}
24
25#[derive(Debug, Clone, Default)]
26pub struct Llama32RunnerBuilder {
27 weights: Option<PathBuf>,
28 config: Option<Llama32ConfigSource>,
29 device: Option<Device>,
30 max_seq: Option<usize>,
31 max_memory_gb: Option<f32>,
32 stream: bool,
33 sample: Option<SampleOpts>,
34 format: Option<WeightFormat>,
35 packed_weights: bool,
36}
37
38impl Llama32RunnerBuilder {
39 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
40 self.weights = Some(path.into());
41 self
42 }
43
44 pub fn format(mut self, fmt: WeightFormat) -> Self {
45 self.format = Some(fmt);
46 self
47 }
48
49 pub fn config(mut self, src: Llama32ConfigSource) -> Self {
50 self.config = Some(src);
51 self
52 }
53
54 pub fn config_value(self, cfg: Llama32Config) -> Self {
55 self.config(Llama32ConfigSource::Explicit(cfg))
56 }
57
58 pub fn device(mut self, d: Device) -> Self {
59 self.device = Some(d);
60 self
61 }
62
63 pub fn max_seq(mut self, n: usize) -> Self {
64 self.max_seq = Some(n);
65 self
66 }
67
68 pub fn max_memory_gb(mut self, gb: f32) -> Self {
69 self.max_memory_gb = Some(gb);
70 self
71 }
72
73 pub fn stream(mut self, on: bool) -> Self {
74 self.stream = on;
75 self
76 }
77
78 pub fn sample(mut self, opts: SampleOpts) -> Self {
79 self.sample = Some(opts);
80 self
81 }
82
83 pub fn packed_weights(mut self, on: bool) -> Self {
86 self.packed_weights = on;
87 self
88 }
89
90 pub fn build(self) -> Result<Llama32Runner> {
91 let weights_path = self
92 .weights
93 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
94 let format = match self.format {
95 Some(f) => f,
96 None => WeightFormat::from_path(&weights_path)?,
97 };
98 let device = self.device.unwrap_or(Device::Cpu);
99 let max_seq = self.max_seq.unwrap_or(128);
100 let stream = self.stream;
101 let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
102
103 let (cfg, total_bytes_estimate) = match format {
104 WeightFormat::Gguf => load_llama32_gguf_config(&weights_path, self.config.as_ref())?,
105 WeightFormat::Safetensors => {
106 load_llama32_safetensors_config(&weights_path, self.config.as_ref())?
107 }
108 };
109
110 if let Some(cap_gb) = self.max_memory_gb {
111 let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
112 if est_gb > cap_gb {
113 bail!(
114 "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
115 );
116 }
117 }
118
119 crate::validate_device(&cfg, device, self.packed_weights)?;
120
121 let path_str = weights_path
122 .to_str()
123 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
124 let generator = if self.packed_weights {
125 None
126 } else {
127 Some(
128 Llama32Generator::from_path(cfg.clone(), path_str, device)?
129 .with_prefill_cache(2)
130 .with_decode_cache(max_seq + 64),
131 )
132 };
133
134 let packed = if self.packed_weights {
135 if !matches!(format, WeightFormat::Gguf) {
136 bail!(
137 "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
138 format,
139 weights_path
140 );
141 }
142 eprintln!(
143 "[llama32-runner] packed_weights=true — compiling prefill graph with \
144 Op::DequantMatMul on {device:?}"
145 );
146 Some(Llama32PackedForward::build(
147 &cfg,
148 &weights_path,
149 max_seq,
150 device,
151 )?)
152 } else {
153 None
154 };
155
156 Ok(Llama32Runner {
157 generator,
158 cfg,
159 sample,
160 stream,
161 device,
162 packed,
163 })
164 }
165}
166
167struct Llama32PackedForward {
168 compiled: rlx_runtime::CompiledGraph,
169 seq: usize,
170}
171
172impl Llama32PackedForward {
173 fn build(cfg: &Llama32Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
174 use crate::build_llama32_graph_sized_packed;
175 let mut loader = GgufLoader::from_file(
176 weights_path
177 .to_str()
178 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
179 )?;
180 let mut packed = std::collections::HashMap::new();
181 let (graph, params) =
182 build_llama32_graph_sized_packed(cfg, &mut loader, 1, seq, true, true, &mut packed)?;
183 let opts = rlx_core::flow_bridge::compile_options_for_profile(
184 &CompileProfile::llama32_prefill(),
185 device,
186 );
187 let mut compiled = Session::new(device).compile_with(graph, &opts);
188 for (name, data) in ¶ms {
189 compiled.set_param(name, data);
190 }
191 for (name, (bytes, _scheme, _shape)) in &packed {
192 compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
193 }
194 Ok(Self { compiled, seq })
195 }
196}
197
198pub struct Llama32Runner {
199 generator: Option<Llama32Generator>,
200 cfg: Llama32Config,
201 sample: SampleOpts,
202 stream: bool,
203 device: Device,
204 packed: Option<Llama32PackedForward>,
205}
206
207impl Llama32Runner {
208 pub fn builder() -> Llama32RunnerBuilder {
209 Llama32RunnerBuilder::default()
210 }
211
212 pub fn config(&self) -> &Llama32Config {
213 &self.cfg
214 }
215
216 pub fn device(&self) -> Device {
217 self.device
218 }
219
220 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
222 if let Some(p) = self.packed.as_mut() {
223 let mut padded = vec![*prompt_ids.first().unwrap_or(&0); p.seq];
224 for (i, &t) in prompt_ids.iter().take(p.seq).enumerate() {
225 padded[i] = t;
226 }
227 let ids_f32: Vec<f32> = padded.iter().map(|&i| i as f32).collect();
228 let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
229 let logits = out
230 .into_iter()
231 .next()
232 .ok_or_else(|| anyhow!("packed forward returned no output"))?;
233 let vocab = self.cfg.vocab_size;
234 if logits.len() < vocab {
235 bail!("logits short: {} < {vocab}", logits.len());
236 }
237 return Ok(logits[..vocab].to_vec());
238 }
239 let generator = self
240 .generator
241 .as_mut()
242 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
243 generator.prefill_get_last_logits(prompt_ids)
244 }
245
246 pub fn generate_packed(
247 &mut self,
248 prompt_ids: &[u32],
249 n_new: usize,
250 mut on_token: impl FnMut(u32),
251 ) -> Result<Vec<u32>> {
252 if self.packed.is_none() {
253 bail!("generate_packed() only works in packed_weights(true) mode");
254 }
255 let mut history: Vec<u32> = prompt_ids.to_vec();
256 let mut out = Vec::with_capacity(n_new);
257 for _ in 0..n_new {
258 let logits = self.predict_logits(&history)?;
259 let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
260 on_token(next);
261 history.push(next);
262 out.push(next);
263 }
264 Ok(out)
265 }
266
267 pub fn generate(
268 &mut self,
269 prompt_ids: &[u32],
270 n_new: usize,
271 mut on_token: impl FnMut(u32),
272 ) -> Result<Vec<u32>> {
273 if self.packed.is_some() {
274 return self.generate_packed(prompt_ids, n_new, on_token);
275 }
276 let generator = self
277 .generator
278 .as_mut()
279 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
280 generator.prefill(prompt_ids);
281 let tokens = if self.stream {
282 generator.generate_cached_with(n_new, self.sample, &mut on_token)?
283 } else {
284 let toks = generator.generate_cached(n_new, self.sample)?;
285 for &t in &toks {
286 on_token(t);
287 }
288 toks
289 };
290 Ok(tokens)
291 }
292}
293
294impl LmRunner for Llama32Runner {
295 fn family(&self) -> &'static str {
296 "llama32"
297 }
298 fn vocab_size(&self) -> usize {
299 self.config().vocab_size
300 }
301 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
302 Llama32Runner::predict_logits(self, prompt_ids)
303 }
304 fn generate(
305 &mut self,
306 prompt_ids: &[u32],
307 n_new: usize,
308 on_token: &mut dyn FnMut(u32) -> bool,
309 ) -> Result<Vec<u32>> {
310 Llama32Runner::generate(self, prompt_ids, n_new, |tok| {
311 let _ = on_token(tok);
312 })
313 }
314}
315
316fn load_llama32_gguf_config(
317 path: &Path,
318 override_src: Option<&Llama32ConfigSource>,
319) -> Result<(Llama32Config, u64)> {
320 let raw = GgufFile::from_path(path).with_context(|| format!("opening {path:?}"))?;
321 let arch = raw
322 .metadata
323 .get("general.architecture")
324 .and_then(MetaValue::as_str)
325 .unwrap_or("llama");
326 if arch != "llama" {
327 bail!(
328 "{path:?} has architecture {arch:?}; Llama32Runner expects general.architecture=llama"
329 );
330 }
331 let cfg = match override_src {
332 Some(Llama32ConfigSource::Explicit(c)) => c.clone(),
333 Some(Llama32ConfigSource::JsonFile(p)) => {
334 Llama32Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
335 }
336 Some(Llama32ConfigSource::Embedded) | None => llama32_cfg_from_gguf(&raw)?,
337 };
338 let bytes_est: u64 = raw
339 .tensors
340 .values()
341 .map(|t| (t.n_elements() as u64) * 4)
342 .sum();
343 Ok((cfg, bytes_est))
344}
345
346fn load_llama32_safetensors_config(
347 path: &Path,
348 override_src: Option<&Llama32ConfigSource>,
349) -> Result<(Llama32Config, u64)> {
350 let cfg_path = match override_src {
351 Some(Llama32ConfigSource::Explicit(c)) => {
352 return Ok((c.clone(), default_st_size_estimate(path)));
353 }
354 Some(Llama32ConfigSource::JsonFile(p)) => p.clone(),
355 Some(Llama32ConfigSource::Embedded) => {
356 bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
357 }
358 None => path
359 .parent()
360 .ok_or_else(|| anyhow!("weights path has no parent dir"))?
361 .join("config.json"),
362 };
363 let cfg = Llama32Config::from_file(&cfg_path)
364 .with_context(|| format!("reading config {cfg_path:?}"))?;
365 Ok((cfg, default_st_size_estimate(path)))
366}
367
/// Returns the on-disk size of `path` in bytes, or `0` when its metadata
/// cannot be read (best-effort estimate; the error is intentionally dropped).
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}