1use std::{
2 env,
3 io::IsTerminal,
4 num::NonZeroU32,
5 path::{Path, PathBuf},
6 str::FromStr,
7 sync::LazyLock,
8};
9
10use encoding_rs::{Decoder, UTF_8};
11use hf_hub::api::tokio::ApiBuilder;
12use llama_cpp_2::{
13 LlamaContextLoadError,
14 context::{LlamaContext, params::LlamaContextParams},
15 llama_backend::LlamaBackend,
16 llama_batch::LlamaBatch,
17 model::{AddBos, LlamaChatMessage, LlamaChatTemplate, LlamaModel},
18 mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
19 sampling::LlamaSampler,
20 token::LlamaToken,
21};
22use strum::Display;
23
24use crate::{
25 error::{CreateLlamaCppRunnerError, RunnerError},
26 sample::{LlguidanceSamplingParams, SimpleSamplingParams},
27};
28
/// Hugging Face repo id for the default vision checkpoint (Qwen3.5 4B GGUF).
// NOTE(review): "GUFF" is presumably a typo for "GGUF" (and "MODDEL" for
// "MODEL") in these const names, but they are `pub`, so renaming would break
// downstream callers.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Quantized model weights file inside the Qwen repo.
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector (mmproj) file used by the mtmd vision context.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Hugging Face repo id for the default text-only checkpoint (Gemma 3 1B QAT).
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// Quantized model weights file inside the Gemma repo.
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
35
/// A runner that streams completions for text-only chat requests.
pub trait TextLmRunner<'s, 'req> {
    /// Streamed response: each item is a decoded text fragment, or the error
    /// that terminated generation.
    type Response: Iterator<Item = Result<String, RunnerError>>;
    /// Starts generation for `request`; heavy work is deferred into the
    /// returned iterator.
    fn stream_lm_response(&'s self, request: TextLmRequest<'req>) -> Self::Response;
}
40
/// A runner that streams completions for requests mixing text and images.
pub trait VisionLmRunner<'s, 'req> {
    /// Streamed response: each item is a decoded text fragment, or the error
    /// that terminated generation.
    type Response: Iterator<Item = Result<String, RunnerError>>;
    /// Starts generation for `request`; heavy work is deferred into the
    /// returned iterator.
    fn stream_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Self::Response;
}
45
/// A chat-completion request, generic over the message payload type `M`
/// (plain text for [`TextLmRequest`], text-or-image for [`VisionLmRequest`]).
#[derive(Debug, Clone)]
pub struct RunnerRequest<M> {
    // Ordered chat history as (role, payload) pairs.
    pub messages: Vec<(MessageRole, M)>,
    // Basic sampling knobs (top-k / top-p / temperature / penalties / seed).
    pub sampling: SimpleSamplingParams,
    // Optional grammar-constrained sampling; chained in front of the base
    // sampler when set.
    pub llguidance: Option<LlguidanceSamplingParams>,
    // Hard cap on the number of generation steps.
    pub max_seq: usize,
    // Text force-fed as the start of the assistant response before sampling.
    pub prefill: Option<String>,
}
54
55impl<M> Default for RunnerRequest<M> {
56 fn default() -> Self {
57 Self {
58 messages: vec![],
59 sampling: Default::default(),
60 llguidance: None,
61 max_seq: usize::MAX,
62 prefill: None,
63 }
64 }
65}
66
/// Text-only request: message payloads are borrowed string slices.
pub type TextLmRequest<'a> = RunnerRequest<&'a str>;
/// Vision request: message payloads may be borrowed text or borrowed images.
pub type VisionLmRequest<'a> = RunnerRequest<ImageOrText<'a>>;
69
/// Convenience extension: collect a whole text response instead of streaming.
pub trait TextLmRunnerExt<'s, 'req> {
    /// Runs generation to completion and returns the concatenated response,
    /// or the first error produced by the stream.
    fn get_lm_response(&'s self, request: TextLmRequest<'req>) -> Result<String, RunnerError>;
}
73
/// Convenience extension: collect a whole vision response instead of streaming.
pub trait VisionLmRunnerExt<'s, 'req> {
    /// Runs generation to completion and returns the concatenated response,
    /// or the first error produced by the stream.
    fn get_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Result<String, RunnerError>;
}
77
78impl<'s, 'req, T> TextLmRunnerExt<'s, 'req> for T
79where
80 T: TextLmRunner<'s, 'req>,
81{
82 fn get_lm_response(&'s self, request: TextLmRequest<'req>) -> Result<String, RunnerError> {
83 self.stream_lm_response(request)
84 .collect::<Result<String, _>>()
85 }
86}
87
88impl<'s, 'req, T> VisionLmRunnerExt<'s, 'req> for T
89where
90 T: VisionLmRunner<'s, 'req>,
91{
92 fn get_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Result<String, RunnerError> {
93 self.stream_vlm_response(request)
94 .collect::<Result<String, _>>()
95 }
96}
97
/// Chat role; the `strum` display strings are the role names fed to the
/// llama.cpp chat template (see `PrepareRun`, which calls `role.to_string()`).
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum MessageRole {
    #[strum(to_string = "assistant")]
    Assistant,
    #[strum(to_string = "user")]
    User,
    #[strum(to_string = "system")]
    System,
}
107
/// One message payload for vision requests: either borrowed text or a
/// borrowed, already-decoded image.
#[derive(Debug, Clone)]
pub enum ImageOrText<'a> {
    Text(&'a str),
    Image(&'a image::DynamicImage),
}
113
/// Text-only llama.cpp runner: loaded model, its chat template, and the
/// context-window size used for every generation.
pub struct Gemma3TextRunner {
    model: LlamaModel,
    chat_template: LlamaChatTemplate,
    // Tokens of context allocated per `new_context` call.
    ctx_size: NonZeroU32,
}
119
impl Gemma3TextRunner {
    /// Downloads `model_file` from the Hugging Face repo `model_id` (caching
    /// via hf-hub) and loads it.
    ///
    /// # Errors
    /// Fails on download, API-construction, or model-load errors.
    pub async fn new(
        model_id: impl ToString,
        model_file: impl AsRef<str>,
        ctx_size: NonZeroU32,
    ) -> Result<Self, CreateLlamaCppRunnerError> {
        let repo = build_hf_api()?.model(model_id.to_string());
        Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
    }

    /// Default runner: Gemma 3 1B (Q4_0) with a 32k-token context, wrapped
    /// with its recommended sampling parameters.
    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
    {
        let inner = Self::new(
            GEMMA_3_1B_GUFF_MODEL_ID,
            GEMMA_3_1B_GUFF_MODEL_FILENAME,
            // 32_000 is a non-zero literal, so the conversion cannot fail.
            32_000.try_into().unwrap(),
        )
        .await?;
        Ok(RunnerWithRecommendedSampling {
            inner,
            default_sampling: Self::recommend_sampling(),
        })
    }

    /// Sampling defaults used by `default()`: top-p 0.95, top-k 64, temp 1.0.
    pub fn recommend_sampling() -> SimpleSamplingParams {
        SimpleSamplingParams {
            top_p: Some(0.95f32),
            top_k: Some(64),
            temperature: Some(1f32),
            ..Default::default()
        }
    }

    /// Loads a local GGUF file and extracts the model's built-in chat template.
    ///
    /// # Errors
    /// Fails if the file cannot be loaded or carries no chat template.
    pub fn from_file(
        model_file: impl AsRef<Path>,
        ctx_size: NonZeroU32,
    ) -> Result<Self, CreateLlamaCppRunnerError> {
        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;

        let chat_template = model.chat_template(None)?;
        Ok(Self {
            model,
            chat_template,
            ctx_size,
        })
    }
}
167
168impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3TextRunner {
169 type Response = Gemma3Stream<'s, &'req str, Gemma3TextRunner>;
170
171 fn stream_lm_response(&'s self, request: TextLmRequest<'req>) -> Self::Response {
172 let ctx = self
173 .model
174 .new_context(
175 &LLAMA_BACKEND,
176 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
177 )
178 .map_err(|err| RunnerError::from(err));
179 Gemma3Stream::new(ctx, request, self, &self.model)
180 }
181}
182
/// Vision-capable llama.cpp runner: base model plus a multimodal (mtmd)
/// projector context used to embed images into the prompt.
pub struct Gemma3VisionRunner {
    model: LlamaModel,
    chat_template: LlamaChatTemplate,
    // Multimodal context built from the mmproj file.
    mtmd_ctx: MtmdContext,
    // Tokens of context allocated per `new_context_window` call.
    ctx_size: NonZeroU32,
}
189
190static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
191 llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default().with_logs_enabled(
192 env::var("RUST_LOG").map_or(false, |lvl| lvl.to_lowercase() == "debug"),
193 ));
194 LlamaBackend::init().unwrap()
195});
196
197impl Gemma3VisionRunner {
198 pub async fn new(
199 repo_id: impl ToString,
200 model_file: impl AsRef<str>,
201 multimodel_file: impl AsRef<str>,
202 ctx_size: NonZeroU32,
203 ) -> Result<Self, CreateLlamaCppRunnerError> {
204 let repo = build_hf_api()?.model(repo_id.to_string());
205 let model = LlamaModel::load_from_file(
206 &LLAMA_BACKEND,
207 repo.get(model_file.as_ref()).await?,
208 &Default::default(),
209 )?;
210
211 let mtmd_ctx = MtmdContext::init_from_file(
212 repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
213 &model,
214 &Default::default(),
215 )?;
216
217 let chat_template = model.chat_template(None)?;
218
219 Ok(Self {
220 model,
221 mtmd_ctx,
222 chat_template,
223 ctx_size,
224 })
225 }
226
227 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
228 {
229 let inner = Self::new(
230 QWEN_3D5_4B_GUFF_MODEL_ID,
231 QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
232 QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
233 16384u32.try_into().unwrap(),
234 )
235 .await?;
236 Ok(RunnerWithRecommendedSampling {
237 inner: inner,
238 default_sampling: SimpleSamplingParams {
239 top_p: Some(0.8f32),
240 top_k: Some(20),
241 temperature: Some(0.7f32),
242 presence_penalty: Some(1.5),
243 repetition_penalty: Some(1.0),
244 seed: None,
245 },
246 })
247 }
248
249 pub fn from_files(
250 model_file: impl AsRef<Path>,
251 multimodel_file: impl AsRef<Path>,
252 ctx_size: NonZeroU32,
253 ) -> Result<Self, CreateLlamaCppRunnerError> {
254 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
255 let mtmd_ctx = MtmdContext::init_from_file(
256 multimodel_file.as_ref().as_os_str().to_str().unwrap(),
257 &model,
258 &Default::default(),
259 )?;
260
261 let chat_template = model.chat_template(None)?;
262
263 Ok(Self {
264 model,
265 mtmd_ctx,
266 chat_template,
267 ctx_size,
268 })
269 }
270
271 fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
272 self.model.new_context(
273 &LLAMA_BACKEND,
274 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
275 )
276 }
277}
278
279impl<'s, 'req> VisionLmRunner<'s, 'req> for Gemma3VisionRunner {
280 type Response = Gemma3Stream<'s, ImageOrText<'req>, Gemma3VisionRunner>;
281
282 fn stream_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Self::Response {
283 let ctx = self
284 .new_context_window()
285 .map_err(|err| RunnerError::from(err));
286 Gemma3Stream::new(ctx, request, self, &self.model)
287 }
288}
289
// Text requests are served by lifting them into vision requests with
// text-only payloads and reusing the vision pipeline.
impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3VisionRunner {
    type Response = <Self as VisionLmRunner<'s, 'req>>::Response;

    fn stream_lm_response(&'s self, request: TextLmRequest<'req>) -> Self::Response {
        self.stream_vlm_response(request.into())
    }
}
297
298impl<'a> From<TextLmRequest<'a>> for VisionLmRequest<'a> {
299 fn from(value: TextLmRequest<'a>) -> Self {
300 Self {
301 messages: value
302 .messages
303 .into_iter()
304 .map(|(role, text)| (role, ImageOrText::Text(text)))
305 .collect(),
306 sampling: value.sampling,
307 llguidance: value.llguidance,
308 max_seq: value.max_seq,
309 prefill: value.prefill,
310 }
311 }
312}
313
/// Lazy generation stream returned by the runners.
///
/// All heavy work (installing the context, evaluating the prompt) happens on
/// the first `next()` call so the runner methods themselves cannot fail.
pub struct Gemma3Stream<'a, Message, Runner> {
    // Context (or the error from creating it), not yet moved into `ctx`.
    ctx_source: Option<Result<LlamaContext<'a>, RunnerError>>,
    // Active context once the stream has started.
    ctx: Option<LlamaContext<'a>>,
    req: RunnerRequest<Message>,
    runner: &'a Runner,
    model: &'a LlamaModel,
    // Sampler/decoder/batch state; `None` until `prepare()` has run.
    runtime: Option<Runtime<'a>>,
    // Set once generation finishes or fails; `next()` then returns `None`.
    done: bool,
}
323
/// Per-generation mutable state, created by `PrepareRun::prepare`.
struct Runtime<'a> {
    sampler: LlamaSampler,
    // Incremental UTF-8 decoder so multi-byte characters split across tokens
    // are still emitted correctly.
    decoder: Decoder,
    batch: LlamaBatch<'a>,
    // Number of prompt positions already evaluated in the context.
    n_past: i32,
    // Tokens generated (or prefilled) so far; compared against `max_seq`.
    step: usize,
}
331
/// One-time setup hook: formats the prompt, evaluates it through the model,
/// and installs the `Runtime` before the first token is sampled.
trait PrepareRun {
    fn prepare(&mut self) -> Result<(), RunnerError>;
}
335
336impl PrepareRun for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner> {
337 fn prepare(&mut self) -> Result<(), RunnerError> {
338 let media_marker = mtmd::mtmd_default_marker();
340 let messages = self
341 .req
342 .messages
343 .iter()
344 .fold(
345 Vec::<(MessageRole, String)>::new(),
346 |mut acc, (role, message)| {
347 let text = match message {
348 ImageOrText::Text(text) => text,
349 ImageOrText::Image(_) => media_marker,
350 };
351 if let Some(last) = acc.last()
352 && last.0 == *role
353 {
354 let (_, adj) = acc.remove(acc.len() - 1);
356 acc.push((role.clone(), format!("{0}\n{text}", adj)));
357 acc
358 } else {
359 acc.push((role.clone(), text.to_string()));
360 acc
361 }
362 },
363 )
364 .into_iter()
365 .map(|(role, content)| LlamaChatMessage::new(role.to_string(), content))
366 .collect::<Result<Vec<_>, _>>()
367 .expect("message preprocessing failed");
368 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
369
370 let formatted_prompt =
372 self.runner
373 .model
374 .apply_chat_template(&self.runner.chat_template, &messages, true)?;
375 let bitmaps = self
376 .req
377 .messages
378 .iter()
379 .filter_map(|msg| match &msg.1 {
380 ImageOrText::Image(image) => Some(image),
381 _ => None,
382 })
383 .enumerate()
384 .map(|(idx, im)| {
385 MtmdBitmap::from_image_data(
386 im.width(),
387 im.height(),
388 im.to_rgb8().to_vec().as_slice(),
389 )
390 .expect(format!("image#{} has corrupted RGB data", idx).as_str())
391 })
392 .collect::<Vec<_>>();
393 let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
394 let chunks = self.runner.mtmd_ctx.tokenize(
395 MtmdInputText {
396 text: formatted_prompt,
397 add_special: true,
398 parse_special: true,
399 },
400 &bitmap_refs,
401 )?;
402 log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
403 let n_past = chunks.eval_chunks(
404 &self.runner.mtmd_ctx,
405 self.ctx.as_ref().unwrap(),
406 0,
407 0,
408 1,
409 true,
410 )?;
411
412 let mut preparation = Runtime {
414 sampler: self.req.sampling.to_llama(),
415 decoder: UTF_8.new_decoder(),
416 batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
417 n_past,
418 step: 0,
419 };
420 if let Some(llguidance) = &self.req.llguidance {
421 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
422 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
423 }
424 self.runtime = Some(preparation);
425
426 Ok(())
427 }
428}
429
430impl<S: AsRef<str>> PrepareRun for Gemma3Stream<'_, S, Gemma3TextRunner> {
431 fn prepare(&mut self) -> Result<(), RunnerError> {
432 let messages = self
434 .req
435 .messages
436 .iter()
437 .fold(
438 Vec::<(MessageRole, String)>::new(),
439 |mut acc, (role, message)| {
440 if let Some(last) = acc.last()
441 && last.0 == *role
442 {
443 let (_, adj) = acc.remove(acc.len() - 1);
445 acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
446 acc
447 } else {
448 acc.push((role.clone(), message.as_ref().to_string()));
449 acc
450 }
451 },
452 )
453 .into_iter()
454 .map(|(role, content)| LlamaChatMessage::new(role.to_string(), content))
455 .collect::<Result<Vec<_>, _>>()
456 .expect("message preprocessing failed");
457 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
458
459 let formatted_prompt =
461 self.runner
462 .model
463 .apply_chat_template(&self.runner.chat_template, &messages, true)?;
464 let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
465 let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
466 let token_list_len = token_list.len();
467 for (i, token) in token_list.into_iter().enumerate() {
468 batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
469 }
470 self.ctx.as_mut().unwrap().decode(&mut batch)?;
471
472 let mut preparation = Runtime {
474 sampler: self.req.sampling.to_llama(),
475 decoder: UTF_8.new_decoder(),
476 batch,
477 n_past: token_list_len as i32,
478 step: 0,
479 };
480 if let Some(llguidance) = &self.req.llguidance {
481 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
482 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
483 }
484 self.runtime = Some(preparation);
485
486 Ok(())
487 }
488}
489
impl<'a, M, R> Iterator for Gemma3Stream<'a, M, R>
where
    Self: PrepareRun,
{
    type Item = Result<String, RunnerError>;

    /// Produces the next decoded text fragment, or `None` once generation
    /// finishes. All setup (context install, prompt evaluation) happens
    /// lazily on the first call.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        // First call: move the context (or its creation error) out of the
        // deferred slot.
        if let Some(result) = self.ctx_source.take() {
            match result {
                Ok(ctx) => self.ctx = Some(ctx),
                Err(err) => {
                    self.done = true;
                    return Some(Err(err));
                }
            }
        }

        // First call: evaluate the prompt and build the sampling runtime.
        if self.runtime.is_none()
            && let Err(err) = self.prepare()
        {
            self.done = true;
            return Some(Err(err));
        }
        let Runtime {
            sampler,
            decoder,
            batch,
            n_past,
            step,
        } = self.runtime.as_mut().unwrap();

        // Respect the caller's generation cap.
        if *step >= self.req.max_seq {
            self.done = true;
            return None;
        }

        let ctx = self.ctx.as_mut().unwrap();
        let model = self.model;
        let sample_idx = batch.n_tokens() - 1;
        // Accepts `token` into the sampler, feeds it back through the model,
        // and decodes it to text; `Ok(None)` signals an end-of-generation token.
        let mut sample = |token: LlamaToken,
                          sampler: &mut LlamaSampler,
                          ctx: &mut LlamaContext<'a>,
                          step: usize|
         -> Result<Option<String>, RunnerError> {
            sampler.accept(token);
            if model.is_eog_token(token) {
                return Ok(None);
            }
            batch.clear();
            batch.add(token, *n_past + (step as i32), &[0], true)?;

            ctx.decode(batch)?;

            let piece = model.token_to_piece(token, decoder, true, None)?;
            Ok(Some(piece))
        };
        if let Some(prefill) = self.req.prefill.take() {
            // Prefill path: force-feed the requested assistant prefix through
            // the model and yield it as a single chunk.
            // NOTE(review): an EOG token inside the prefill is accepted by the
            // sampler but not fed to the model, and the loop keeps going —
            // confirm that is the intended handling.
            log::debug!(target: "gemma", "prefill: {}", prefill);
            let tokens = match model.str_to_token(&prefill, AddBos::Never) {
                Ok(tokens) => tokens,
                Err(err) => {
                    return Some(Err(err.into()));
                }
            };
            log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
            for token in tokens {
                match sample(token, sampler, ctx, *step) {
                    Ok(_) => {}
                    Err(err) => return Some(Err(err.into())),
                }
                *step += 1;
            }
            Some(Ok(prefill))
        } else {
            // Normal path: sample one token from the last batch position.
            let token = sampler.sample(ctx, sample_idx);
            match sample(token, sampler, ctx, *step) {
                Ok(Some(piece)) => {
                    *step += 1;
                    return Some(Ok(piece));
                }
                Ok(None) => {
                    // End-of-generation token: finish cleanly.
                    self.done = true;
                    return None;
                }
                Err(err) => {
                    self.done = true;
                    return Some(Err(err));
                }
            }
        }
    }
}
587
588impl<'s, M, R> Gemma3Stream<'s, M, R> {
589 fn new(
590 source: Result<LlamaContext<'s>, RunnerError>,
591 req: RunnerRequest<M>,
592 runner: &'s R,
593 model: &'s LlamaModel,
594 ) -> Self {
595 Self {
596 ctx_source: Some(source),
597 ctx: None,
598 req,
599 runner,
600 model,
601 runtime: None,
602 done: false,
603 }
604 }
605}
606
/// Wraps a runner together with default sampling parameters that backfill
/// any knobs the caller left unset on a request.
pub struct RunnerWithRecommendedSampling<Inner> {
    pub inner: Inner,
    pub default_sampling: SimpleSamplingParams,
}
611
612impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
613 fn get_preprocessed_simple_sampling(
614 &self,
615 sampling: SimpleSamplingParams,
616 ) -> SimpleSamplingParams {
617 let mut sampling = sampling;
618 if sampling.top_k.is_none() {
619 sampling.top_k = self.default_sampling.top_k;
620 }
621 if sampling.top_p.is_none() {
622 sampling.top_p = self.default_sampling.top_p;
623 }
624 if sampling.temperature.is_none() {
625 sampling.temperature = self.default_sampling.temperature;
626 }
627 sampling
628 }
629}
630
631impl<'s, 'req, Inner> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
632where
633 Inner: VisionLmRunner<'s, 'req>,
634{
635 type Response = <Inner as VisionLmRunner<'s, 'req>>::Response;
636
637 fn stream_vlm_response(&'s self, mut request: VisionLmRequest<'req>) -> Self::Response {
638 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
639 self.inner.stream_vlm_response(request)
640 }
641}
642
643impl<'s, 'req, Inner> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
644where
645 Inner: TextLmRunner<'s, 'req>,
646{
647 type Response = <Inner as TextLmRunner<'s, 'req>>::Response;
648
649 fn stream_lm_response(&'s self, mut request: TextLmRequest<'req>) -> Self::Response {
650 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
651 self.inner.stream_lm_response(request)
652 }
653}
654
655impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
656 fn from(value: Inner) -> Self {
657 Self {
658 inner: value,
659 default_sampling: SimpleSamplingParams::default(),
660 }
661 }
662}
663
664fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
665 let mut api = ApiBuilder::new()
666 .with_progress(std::io::stdin().is_terminal())
667 .with_token(std::env::var("HF_TOKEN").ok())
668 .with_chunk_size(Some(2 << 28));
669 if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
670 api = api.with_endpoint(endpoint);
671 }
672 if let Ok(cache) = std::env::var("HF_HOME") {
673 api = api.with_cache_dir(
674 PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
675 );
676 }
677 api.build()
678}