1use crate::{GemmaConfig, GemmaGenerator, gemma_cfg_from_gguf};
17use anyhow::{Context, Result, anyhow, bail};
18use rlx_cli::{LmRunner, WeightFormat};
19use rlx_core::gguf_support::{
20 GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
21 resolve_weights_file_with_options,
22};
23use rlx_core::weight_loader::GgufLoader;
24use rlx_flow::CompileProfile;
25use rlx_qwen3::SampleOpts;
26use rlx_runtime::{Device, Session};
27use std::path::{Path, PathBuf};
28
/// Where the runner obtains its [`GemmaConfig`] from.
#[derive(Debug, Clone)]
pub enum GemmaConfigSource {
    /// Derive the config from the GGUF file's own metadata.
    /// The safetensors loader in this file rejects this variant.
    Embedded,
    /// Read the config from the JSON file at this path.
    JsonFile(PathBuf),
    /// Use this already-constructed config as-is.
    Explicit(GemmaConfig),
}
39
/// Collects the options needed to construct a [`GemmaRunner`].
///
/// Every field starts unset (`None` / `false`); `build` fills in the defaults.
#[derive(Debug, Clone, Default)]
pub struct GemmaRunnerBuilder {
    /// Weights location; required, `build` fails without it.
    weights: Option<PathBuf>,
    /// Config override; when `None` each weight format uses its own default lookup.
    config: Option<GemmaConfigSource>,
    /// Target device; `build` falls back to `Device::Cpu`.
    device: Option<Device>,
    /// Maximum sequence length; `build` falls back to 128.
    max_seq: Option<usize>,
    /// Optional cap (in GiB) checked against the estimated weight size.
    max_memory_gb: Option<f32>,
    /// Whether `GemmaRunner::generate` reports tokens through the streaming generator call.
    stream: bool,
    /// Sampling options; `build` falls back to `SampleOpts::greedy()`.
    sample: Option<SampleOpts>,
    /// Explicit weight format, handed to `WeightFormat::resolve` with the resolved path.
    format: Option<WeightFormat>,
    /// When true, `build` compiles the packed prefill graph instead of a `GemmaGenerator`.
    packed_weights: bool,
}
52
53impl GemmaRunnerBuilder {
54 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
55 self.weights = Some(path.into());
56 self
57 }
58
59 pub fn format(mut self, fmt: WeightFormat) -> Self {
60 self.format = Some(fmt);
61 self
62 }
63
64 pub fn config(mut self, src: GemmaConfigSource) -> Self {
65 self.config = Some(src);
66 self
67 }
68
69 pub fn config_value(self, cfg: GemmaConfig) -> Self {
70 self.config(GemmaConfigSource::Explicit(cfg))
71 }
72
73 pub fn device(mut self, d: Device) -> Self {
74 self.device = Some(d);
75 self
76 }
77
78 pub fn max_seq(mut self, n: usize) -> Self {
79 self.max_seq = Some(n);
80 self
81 }
82
83 pub fn max_memory_gb(mut self, gb: f32) -> Self {
84 self.max_memory_gb = Some(gb);
85 self
86 }
87
88 pub fn stream(mut self, on: bool) -> Self {
89 self.stream = on;
90 self
91 }
92
93 pub fn sample(mut self, opts: SampleOpts) -> Self {
94 self.sample = Some(opts);
95 self
96 }
97
98 pub fn packed_weights(mut self, on: bool) -> Self {
101 self.packed_weights = on;
102 self
103 }
104
105 pub fn build(self) -> Result<GemmaRunner> {
106 let resolve = ResolveWeightsOptions {
107 prefer_gguf_substring: Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR),
108 ..Default::default()
109 };
110 let weights_path = resolve_weights_file_with_options(
111 self.weights
112 .as_ref()
113 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
114 &resolve,
115 )?;
116 let format = WeightFormat::resolve(&weights_path, self.format)?;
117 let device = self.device.unwrap_or(Device::Cpu);
118 let max_seq = self.max_seq.unwrap_or(128);
119 let stream = self.stream;
120 let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
121
122 let (cfg, total_bytes_estimate) = match format {
123 WeightFormat::Gguf => load_gemma_gguf_config(&weights_path, self.config.as_ref())?,
124 WeightFormat::Safetensors => {
125 load_gemma_safetensors_config(&weights_path, self.config.as_ref())?
126 }
127 };
128
129 if let Some(cap_gb) = self.max_memory_gb {
130 let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
131 if est_gb > cap_gb {
132 bail!(
133 "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
134 );
135 }
136 }
137
138 crate::capabilities::validate_device(&cfg, device, self.packed_weights)?;
139
140 let path_str = weights_path
141 .to_str()
142 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
143 let generator = if self.packed_weights {
144 None
145 } else {
146 Some(
147 GemmaGenerator::from_path(cfg.clone(), path_str, device)?
148 .with_prefill_cache(2)
149 .with_decode_cache(max_seq + 64),
150 )
151 };
152
153 let packed = if self.packed_weights {
154 if !matches!(format, WeightFormat::Gguf) {
155 bail!(
156 "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
157 format,
158 weights_path
159 );
160 }
161 eprintln!(
162 "[gemma-runner] packed_weights=true — compiling prefill graph with \
163 Op::DequantMatMul on {device:?}"
164 );
165 Some(GemmaPackedForward::build(
166 &cfg,
167 &weights_path,
168 max_seq,
169 device,
170 )?)
171 } else {
172 None
173 };
174
175 Ok(GemmaRunner {
176 generator,
177 cfg,
178 sample,
179 stream,
180 device,
181 packed,
182 })
183 }
184}
185
/// Compiled prefill graph used when the runner is built with
/// `packed_weights(true)`, together with the sequence length it was built for.
struct GemmaPackedForward {
    /// Compiled graph with its parameters already bound (see `build`).
    compiled: rlx_runtime::CompiledGraph,
    /// Sequence length handed to the graph builder; inputs are padded to this length.
    seq: usize,
}
190
191impl GemmaPackedForward {
192 fn build(cfg: &GemmaConfig, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
193 use crate::build_gemma_graph_sized_packed;
194 let mut loader = GgufLoader::from_file(
195 weights_path
196 .to_str()
197 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
198 )?;
199 let mut packed = std::collections::HashMap::new();
200 let (graph, params) =
204 build_gemma_graph_sized_packed(cfg, &mut loader, 1, seq, true, false, &mut packed)?;
205 let opts = rlx_core::flow_bridge::compile_options_for_profile(
206 &CompileProfile::gemma_prefill(),
207 device,
208 );
209 let mut compiled = Session::new(device).compile_with(graph, &opts);
210 for (name, data) in ¶ms {
211 compiled.set_param(name, data);
212 }
213 for (name, (bytes, _scheme, _shape)) in &packed {
214 compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
215 }
216 Ok(Self { compiled, seq })
217 }
218}
219
/// A ready-to-use Gemma model: either an F32 `GemmaGenerator` or a compiled
/// packed-weights prefill graph, plus the settings chosen at build time.
pub struct GemmaRunner {
    /// F32 generator; `None` when the runner was built with `packed_weights(true)`.
    generator: Option<GemmaGenerator>,
    /// Model configuration the runner was built with.
    cfg: GemmaConfig,
    /// Sampling options used by the generate methods.
    sample: SampleOpts,
    /// Whether `generate` reports tokens through the streaming generator call.
    stream: bool,
    /// Device the runner was built for.
    device: Device,
    /// Packed prefill graph; `Some` only when built with `packed_weights(true)`.
    packed: Option<GemmaPackedForward>,
}
228
229impl GemmaRunner {
230 pub fn builder() -> GemmaRunnerBuilder {
231 GemmaRunnerBuilder::default()
232 }
233
234 pub fn config(&self) -> &GemmaConfig {
235 &self.cfg
236 }
237
238 pub fn device(&self) -> Device {
239 self.device
240 }
241
242 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
244 if let Some(p) = self.packed.as_mut() {
245 let n = prompt_ids.len().min(p.seq);
248 let last = n.saturating_sub(1);
249 let mut padded = vec![0u32; p.seq];
250 for (i, &t) in prompt_ids.iter().take(p.seq).enumerate() {
251 padded[i] = t;
252 }
253 let ids_f32: Vec<f32> = padded.iter().map(|&i| i as f32).collect();
254 let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
255 let logits = out
256 .into_iter()
257 .next()
258 .ok_or_else(|| anyhow!("packed forward returned no output"))?;
259 let vocab = self.cfg.vocab_size;
260 let expected = p.seq * vocab;
261 if logits.len() < expected {
262 bail!("logits short: {} < {expected}", logits.len());
263 }
264 let start = last * vocab;
265 return Ok(logits[start..start + vocab].to_vec());
266 }
267 let generator = self
268 .generator
269 .as_mut()
270 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
271 generator.prefill_get_last_logits(prompt_ids)
272 }
273
274 pub fn generate_packed(
275 &mut self,
276 prompt_ids: &[u32],
277 n_new: usize,
278 mut on_token: impl FnMut(u32),
279 ) -> Result<Vec<u32>> {
280 if self.packed.is_none() {
281 bail!("generate_packed() only works in packed_weights(true) mode");
282 }
283 let mut history: Vec<u32> = prompt_ids.to_vec();
284 let mut out = Vec::with_capacity(n_new);
285 for _ in 0..n_new {
286 let logits = self.predict_logits(&history)?;
287 let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
288 on_token(next);
289 history.push(next);
290 out.push(next);
291 }
292 Ok(out)
293 }
294
295 pub fn generate(
296 &mut self,
297 prompt_ids: &[u32],
298 n_new: usize,
299 mut on_token: impl FnMut(u32),
300 ) -> Result<Vec<u32>> {
301 if self.packed.is_some() {
302 return self.generate_packed(prompt_ids, n_new, on_token);
303 }
304 let generator = self
305 .generator
306 .as_mut()
307 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
308 generator.prefill(prompt_ids);
309 let tokens = if self.stream {
310 generator.generate_cached_with(n_new, self.sample, &mut on_token)?
311 } else {
312 let toks = generator.generate_cached(n_new, self.sample)?;
313 for &t in &toks {
314 on_token(t);
315 }
316 toks
317 };
318 Ok(tokens)
319 }
320}
321
322impl LmRunner for GemmaRunner {
323 fn family(&self) -> &'static str {
324 "gemma"
325 }
326 fn vocab_size(&self) -> usize {
327 self.config().vocab_size
328 }
329 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
330 GemmaRunner::predict_logits(self, prompt_ids)
331 }
332 fn generate(
333 &mut self,
334 prompt_ids: &[u32],
335 n_new: usize,
336 on_token: &mut dyn FnMut(u32) -> bool,
337 ) -> Result<Vec<u32>> {
338 GemmaRunner::generate(self, prompt_ids, n_new, |tok| {
340 let _ = on_token(tok);
341 })
342 }
343}
344
345fn load_gemma_gguf_config(
346 path: &Path,
347 override_src: Option<&GemmaConfigSource>,
348) -> Result<(GemmaConfig, u64)> {
349 let raw = assert_gguf_family(path, GgufModelFamily::Gemma)?;
350 let cfg = match override_src {
351 Some(GemmaConfigSource::Explicit(c)) => c.clone(),
352 Some(GemmaConfigSource::JsonFile(p)) => {
353 GemmaConfig::from_file(p).with_context(|| format!("reading override config {p:?}"))?
354 }
355 Some(GemmaConfigSource::Embedded) | None => gemma_cfg_from_gguf(&raw)?,
356 };
357 Ok((cfg, gguf_f32_bytes_estimate(&raw)))
358}
359
360fn load_gemma_safetensors_config(
361 path: &Path,
362 override_src: Option<&GemmaConfigSource>,
363) -> Result<(GemmaConfig, u64)> {
364 let cfg_path = match override_src {
365 Some(GemmaConfigSource::Explicit(c)) => {
366 return Ok((c.clone(), default_st_size_estimate(path)));
367 }
368 Some(GemmaConfigSource::JsonFile(p)) => p.clone(),
369 Some(GemmaConfigSource::Embedded) => {
370 bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
371 }
372 None => path
373 .parent()
374 .ok_or_else(|| anyhow!("weights path has no parent dir"))?
375 .join("config.json"),
376 };
377 let cfg = GemmaConfig::from_file(&cfg_path)
378 .with_context(|| format!("reading config {cfg_path:?}"))?;
379 Ok((cfg, default_st_size_estimate(path)))
380}
381
/// Returns the on-disk size of `path` in bytes, or 0 when its metadata
/// cannot be read (for example, when the file does not exist).
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}