1use crate::{GemmaConfig, GemmaGenerator, gemma_cfg_from_gguf};
17use anyhow::{Context, Result, anyhow, bail};
18use rlx_cli::{LmRunner, WeightFormat};
19use rlx_core::gguf_support::{
20 GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
21 resolve_weights_file_with_options,
22};
23use rlx_qwen3::SampleOpts;
24use rlx_runtime::Device;
25use std::path::{Path, PathBuf};
26
/// Source of the [`GemmaConfig`] used by the runner: the runtime's generic
/// `ConfigSource` specialised to Gemma.
///
/// The variants matched in this file are `Explicit` (an in-memory config),
/// `JsonFile` (a path to a JSON config) and `Embedded` (config read from the
/// GGUF file itself; rejected for safetensors weights).
pub type GemmaConfigSource = rlx_runtime::ConfigSource<GemmaConfig>;
35
/// Collects the options needed to construct a [`GemmaRunner`].
///
/// All fields start unset via `Default`; `build` substitutes fallbacks for the
/// optional ones and fails when no weights path was supplied.
#[derive(Debug, Clone, Default)]
pub struct GemmaRunnerBuilder {
    /// Weights location passed to `resolve_weights_file_with_options`. Required.
    weights: Option<PathBuf>,
    /// Config override. When unset, GGUF weights use the config derived from the
    /// GGUF file and safetensors weights use `config.json` next to the weights.
    config: Option<GemmaConfigSource>,
    /// Target device; `build` falls back to `Device::Cpu`.
    device: Option<Device>,
    /// Sequence capacity handed to the inference caches / packed session;
    /// `build` falls back to 128.
    max_seq: Option<usize>,
    /// Optional cap, in GiB, on the estimated weight size; `build` fails when
    /// the estimate exceeds it.
    max_memory_gb: Option<f32>,
    /// When true, the F32 generation paths invoke the token callback through the
    /// generator's callback-taking methods; otherwise the callback is invoked
    /// for each token after generation has returned.
    stream: bool,
    /// Sampling options; `build` falls back to `SampleOpts::greedy()`.
    sample: Option<SampleOpts>,
    /// Explicit weight format, forwarded to `WeightFormat::resolve` together
    /// with the resolved weights path.
    format: Option<WeightFormat>,
    /// When true, `build` creates a `GemmaPackedSession` instead of the F32
    /// generator; this requires GGUF weights.
    packed_weights: bool,
}
48
49impl GemmaRunnerBuilder {
50 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
51 self.weights = Some(path.into());
52 self
53 }
54
55 pub fn format(mut self, fmt: WeightFormat) -> Self {
56 self.format = Some(fmt);
57 self
58 }
59
60 pub fn config(mut self, src: GemmaConfigSource) -> Self {
61 self.config = Some(src);
62 self
63 }
64
65 pub fn config_value(self, cfg: GemmaConfig) -> Self {
66 self.config(GemmaConfigSource::Explicit(cfg))
67 }
68
69 pub fn device(mut self, d: Device) -> Self {
70 self.device = Some(d);
71 self
72 }
73
74 pub fn max_seq(mut self, n: usize) -> Self {
75 self.max_seq = Some(n);
76 self
77 }
78
79 pub fn max_memory_gb(mut self, gb: f32) -> Self {
80 self.max_memory_gb = Some(gb);
81 self
82 }
83
84 pub fn stream(mut self, on: bool) -> Self {
85 self.stream = on;
86 self
87 }
88
89 pub fn sample(mut self, opts: SampleOpts) -> Self {
90 self.sample = Some(opts);
91 self
92 }
93
94 pub fn packed_weights(mut self, on: bool) -> Self {
97 self.packed_weights = on;
98 self
99 }
100
101 pub fn build(self) -> Result<GemmaRunner> {
102 let resolve = ResolveWeightsOptions {
103 prefer_gguf_substring: Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR),
104 ..Default::default()
105 };
106 let weights_path = resolve_weights_file_with_options(
107 self.weights
108 .as_ref()
109 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
110 &resolve,
111 )?;
112 let format = WeightFormat::resolve(&weights_path, self.format)?;
113 let device = self.device.unwrap_or(Device::Cpu);
114 let max_seq = self.max_seq.unwrap_or(128);
115 let stream = self.stream;
116 let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
117
118 let (cfg, total_bytes_estimate) = match format {
119 WeightFormat::Gguf => load_gemma_gguf_config(&weights_path, self.config.as_ref())?,
120 WeightFormat::Safetensors => {
121 load_gemma_safetensors_config(&weights_path, self.config.as_ref())?
122 }
123 };
124
125 if let Some(cap_gb) = self.max_memory_gb {
126 let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
127 if est_gb > cap_gb {
128 bail!(
129 "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
130 );
131 }
132 }
133
134 crate::capabilities::validate_device(&cfg, device, self.packed_weights)?;
135
136 let path_str = weights_path
137 .to_str()
138 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
139 let generator = if self.packed_weights {
140 None
141 } else {
142 Some(
143 GemmaGenerator::from_path(cfg.clone(), path_str, device)?
144 .with_inference_caches(max_seq),
145 )
146 };
147
148 let packed = if self.packed_weights {
149 if !matches!(format, WeightFormat::Gguf) {
150 bail!(
151 "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
152 format,
153 weights_path
154 );
155 }
156 eprintln!(
157 "[gemma-runner] packed_weights=true — Q4 prefill + bucketed decode on {device:?}"
158 );
159 Some(crate::packed_session::GemmaPackedSession::build(
160 cfg.clone(),
161 &weights_path,
162 max_seq,
163 device,
164 )?)
165 } else {
166 None
167 };
168
169 Ok(GemmaRunner {
170 generator,
171 cfg,
172 sample,
173 stream,
174 device,
175 packed,
176 })
177 }
178}
179
/// A ready-to-use Gemma runner produced by [`GemmaRunnerBuilder::build`].
///
/// `build` populates exactly one of `generator` and `packed`, depending on
/// whether `packed_weights(true)` was requested.
pub struct GemmaRunner {
    /// F32 generator; `None` when the runner was built in packed-weights mode.
    generator: Option<GemmaGenerator>,
    /// Model config resolved at build time.
    cfg: GemmaConfig,
    /// Sampling options applied to every generation call.
    sample: SampleOpts,
    /// Whether F32 generation delivers tokens through the generator's
    /// callback-taking methods instead of replaying them afterwards.
    stream: bool,
    /// Device the runner was built for.
    device: Device,
    /// Packed-weights session; `Some` only in packed-weights mode.
    packed: Option<crate::packed_session::GemmaPackedSession>,
}
188
189impl GemmaRunner {
190 pub fn builder() -> GemmaRunnerBuilder {
191 GemmaRunnerBuilder::default()
192 }
193
194 pub fn config(&self) -> &GemmaConfig {
195 &self.cfg
196 }
197
198 pub fn device(&self) -> Device {
199 self.device
200 }
201
202 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
204 if let Some(p) = self.packed.as_mut() {
205 return p.predict_logits(prompt_ids);
206 }
207 let generator = self
208 .generator
209 .as_mut()
210 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
211 generator.prefill_get_last_logits(prompt_ids)
212 }
213
214 pub fn generate_packed(
215 &mut self,
216 prompt_ids: &[u32],
217 n_new: usize,
218 on_token: impl FnMut(u32),
219 ) -> Result<Vec<u32>> {
220 if self.packed.is_none() {
221 bail!("generate_packed() only works in packed_weights(true) mode");
222 }
223 let sample = self.sample;
224 self.packed
225 .as_mut()
226 .unwrap()
227 .generate(prompt_ids, n_new, sample, on_token)
228 }
229
230 pub fn generate(
231 &mut self,
232 prompt_ids: &[u32],
233 n_new: usize,
234 mut on_token: impl FnMut(u32),
235 ) -> Result<Vec<u32>> {
236 if self.packed.is_some() {
237 return self.generate_packed(prompt_ids, n_new, on_token);
238 }
239 let generator = self
240 .generator
241 .as_mut()
242 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
243 generator.prefill(prompt_ids);
244 let tokens = if self.stream {
245 generator.generate_cached_with(n_new, self.sample, &mut on_token)?
246 } else {
247 let toks = generator.generate_cached(n_new, self.sample)?;
248 for &t in &toks {
249 on_token(t);
250 }
251 toks
252 };
253 Ok(tokens)
254 }
255
256 pub fn generate_from_embeds(
258 &mut self,
259 prompt_ids: &[u32],
260 inputs_embeds: &[f32],
261 n_new: usize,
262 mut on_token: impl FnMut(u32),
263 ) -> Result<Vec<u32>> {
264 if self.packed.is_some() {
265 bail!("generate_from_embeds is not supported with packed_weights(true)");
266 }
267 let generator = self
268 .generator
269 .as_mut()
270 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
271 let tokens = if self.stream {
272 generator.generate_from_embeds_with(
273 prompt_ids,
274 inputs_embeds,
275 n_new,
276 self.sample,
277 &mut on_token,
278 )?
279 } else {
280 let toks =
281 generator.generate_from_embeds(prompt_ids, inputs_embeds, n_new, self.sample)?;
282 for &t in &toks {
283 on_token(t);
284 }
285 toks
286 };
287 Ok(tokens)
288 }
289
290 pub fn generate_multimodal(
292 &mut self,
293 mm_cfg: &crate::multimodal::GemmaMultimodalConfig,
294 token_ids: &[u32],
295 image_embeds: &[f32],
296 audio_embeds: &[f32],
297 video_embeds: &[f32],
298 n_new: usize,
299 mut on_token: impl FnMut(u32),
300 ) -> Result<Vec<u32>> {
301 let generator = self
302 .generator
303 .as_ref()
304 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
305 let embeds = crate::multimodal_embed::build_multimodal_inputs_embeds(
306 generator.weights_cache(),
307 &self.cfg,
308 mm_cfg,
309 token_ids,
310 image_embeds,
311 audio_embeds,
312 video_embeds,
313 )?;
314 let attn_bias = crate::multimodal_mask::build_multimodal_prefill_attn_bias(
315 token_ids, &self.cfg, mm_cfg, 1,
316 );
317 let generator = self
318 .generator
319 .as_mut()
320 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
321 let tokens = if self.stream {
322 generator.generate_from_embeds_with_bias_and_callback(
323 token_ids,
324 &embeds,
325 attn_bias,
326 n_new,
327 self.sample,
328 &mut on_token,
329 )?
330 } else {
331 let toks = generator.generate_from_embeds_with_bias(
332 token_ids,
333 &embeds,
334 attn_bias,
335 n_new,
336 self.sample,
337 )?;
338 for &t in &toks {
339 on_token(t);
340 }
341 toks
342 };
343 Ok(tokens)
344 }
345}
346
347impl LmRunner for GemmaRunner {
348 fn family(&self) -> &'static str {
349 "gemma"
350 }
351 fn vocab_size(&self) -> usize {
352 self.config().vocab_size
353 }
354 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
355 GemmaRunner::predict_logits(self, prompt_ids)
356 }
357 fn generate(
358 &mut self,
359 prompt_ids: &[u32],
360 n_new: usize,
361 on_token: &mut dyn FnMut(u32) -> bool,
362 ) -> Result<Vec<u32>> {
363 GemmaRunner::generate(self, prompt_ids, n_new, |tok| {
365 let _ = on_token(tok);
366 })
367 }
368}
369
370fn load_gemma_gguf_config(
371 path: &Path,
372 override_src: Option<&GemmaConfigSource>,
373) -> Result<(GemmaConfig, u64)> {
374 let raw = assert_gguf_family(path, GgufModelFamily::Gemma)?;
375 let cfg = match override_src {
376 Some(GemmaConfigSource::Explicit(c)) => c.clone(),
377 Some(GemmaConfigSource::JsonFile(p)) => {
378 GemmaConfig::from_file(p).with_context(|| format!("reading override config {p:?}"))?
379 }
380 Some(GemmaConfigSource::Embedded) | None => gemma_cfg_from_gguf(&raw)?,
381 };
382 Ok((cfg, gguf_f32_bytes_estimate(&raw)))
383}
384
385fn load_gemma_safetensors_config(
386 path: &Path,
387 override_src: Option<&GemmaConfigSource>,
388) -> Result<(GemmaConfig, u64)> {
389 let cfg_path = match override_src {
390 Some(GemmaConfigSource::Explicit(c)) => {
391 return Ok((c.clone(), default_st_size_estimate(path)));
392 }
393 Some(GemmaConfigSource::JsonFile(p)) => p.clone(),
394 Some(GemmaConfigSource::Embedded) => {
395 bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
396 }
397 None => path
398 .parent()
399 .ok_or_else(|| anyhow!("weights path has no parent dir"))?
400 .join("config.json"),
401 };
402 let cfg = GemmaConfig::from_file(&cfg_path)
403 .with_context(|| format!("reading config {cfg_path:?}"))?;
404 Ok((cfg, default_st_size_estimate(path)))
405}
406
/// Best-effort size estimate for a weights file: its length in bytes as
/// reported by the filesystem, or 0 when the metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}