1use std::{
2 io::IsTerminal,
3 num::NonZeroU32,
4 path::{Path, PathBuf},
5 str::FromStr,
6};
7
8use encoding_rs::UTF_8;
9use hf_hub::api::tokio::ApiBuilder;
10use llama_cpp_2::{
11 LlamaContextLoadError,
12 context::{LlamaContext, params::LlamaContextParams},
13 llama_batch::LlamaBatch,
14 model::{AddBos, LlamaChatTemplate, LlamaModel},
15 mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
16 sampling::LlamaSampler,
17};
18
19use crate::{
20 GenericTextLmRequest, GenericVisionLmRequest, ImageOrText, MessageRole,
21 RunnerWithRecommendedSampling, TextLmRunner, VisionLmRunner,
22 error::{CreateLlamaCppRunnerError, GenericRunnerError},
23 hf::build_hf_api,
24 runner::{Gemma3Stream, LLAMA_BACKEND, PrepareRun, Runtime},
25 sample::SimpleSamplingParams,
26 template::ChatTemplate,
27};
28
// NOTE(review): "GUFF" and "MODDEL" below look like typos for "GGUF"/"MODEL",
// but these constants are `pub`, so renaming them would break external callers.

/// Hugging Face repo id for the Qwen3.5-4B GGUF build (used by
/// `Gemma3VisionRunner::default`).
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Quantized base-model file inside the Qwen3.5-4B repo.
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector (mmproj) file inside the Qwen3.5-4B repo.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Hugging Face repo id for the Gemma 3 1B instruction-tuned QAT GGUF build.
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// Quantized model file inside the Gemma 3 1B repo.
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
35
/// Text-only LM runner backed by a llama.cpp GGUF model (Gemma 3 by default).
pub struct Gemma3TextRunner {
    // Loaded llama.cpp model weights.
    model: LlamaModel,
    // Chat template read from the GGUF metadata at load time.
    llama_template: LlamaChatTemplate,
    // Context-window size used when creating a per-request llama context.
    ctx_size: NonZeroU32,
}
41
42impl Gemma3TextRunner {
43 pub async fn new(
44 model_id: impl ToString,
45 model_file: impl AsRef<str>,
46 ctx_size: NonZeroU32,
47 ) -> Result<Self, CreateLlamaCppRunnerError> {
48 let repo = build_hf_api()?.model(model_id.to_string());
49 Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
50 }
51
52 pub fn recommend_sampling() -> SimpleSamplingParams {
53 SimpleSamplingParams {
54 top_p: Some(0.95f32),
55 top_k: Some(64),
56 temperature: Some(1f32),
57 ..Default::default()
58 }
59 }
60
61 pub fn from_file(
62 model_file: impl AsRef<Path>,
63 ctx_size: NonZeroU32,
64 ) -> Result<Self, CreateLlamaCppRunnerError> {
65 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
66
67 let chat_template = model.chat_template(None)?;
68 Ok(Self {
69 model,
70 llama_template: chat_template,
71 ctx_size,
72 })
73 }
74
75 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
76 {
77 let inner = Self::new(
78 GEMMA_3_1B_GUFF_MODEL_ID,
79 GEMMA_3_1B_GUFF_MODEL_FILENAME,
80 32_000.try_into().unwrap(),
81 )
82 .await?;
83 Ok(RunnerWithRecommendedSampling {
84 inner,
85 default_sampling: Self::recommend_sampling(),
86 })
87 }
88}
89
90impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3TextRunner
91where
92 Tmpl: ChatTemplate,
93{
94 fn stream_lm_response(
95 &'s self,
96 request: GenericTextLmRequest<'req, Tmpl>,
97 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
98 let ctx = self
99 .model
100 .new_context(
101 &LLAMA_BACKEND,
102 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
103 )
104 .map_err(|err| GenericRunnerError::from(err));
105 Gemma3Stream::new(ctx, request, self, &self.model)
106 }
107}
108
/// Vision-language runner backed by a llama.cpp GGUF model plus an mtmd
/// multimodal-projector context for image input.
pub struct Gemma3VisionRunner {
    // Loaded llama.cpp model weights.
    model: LlamaModel,
    // Chat template read from the GGUF metadata at load time.
    chat_template: LlamaChatTemplate,
    // Multimodal (mmproj) context used to tokenize/evaluate image chunks.
    mtmd_ctx: MtmdContext,
    // Context-window size used when creating a per-request llama context.
    ctx_size: NonZeroU32,
}
115
116impl Gemma3VisionRunner {
117 pub async fn new(
118 repo_id: impl ToString,
119 model_file: impl AsRef<str>,
120 multimodel_file: impl AsRef<str>,
121 ctx_size: NonZeroU32,
122 ) -> Result<Self, CreateLlamaCppRunnerError> {
123 let repo = build_hf_api()?.model(repo_id.to_string());
124 let model = LlamaModel::load_from_file(
125 &LLAMA_BACKEND,
126 repo.get(model_file.as_ref()).await?,
127 &Default::default(),
128 )?;
129
130 let mtmd_ctx = MtmdContext::init_from_file(
131 repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
132 &model,
133 &Default::default(),
134 )?;
135
136 let chat_template = model.chat_template(None)?;
137
138 Ok(Self {
139 model,
140 mtmd_ctx,
141 chat_template,
142 ctx_size,
143 })
144 }
145
146 pub fn from_files(
147 model_file: impl AsRef<Path>,
148 multimodel_file: impl AsRef<Path>,
149 ctx_size: NonZeroU32,
150 ) -> Result<Self, CreateLlamaCppRunnerError> {
151 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
152 let mtmd_ctx = MtmdContext::init_from_file(
153 multimodel_file.as_ref().as_os_str().to_str().unwrap(),
154 &model,
155 &Default::default(),
156 )?;
157
158 let chat_template = model.chat_template(None)?;
159
160 Ok(Self {
161 model,
162 mtmd_ctx,
163 chat_template,
164 ctx_size,
165 })
166 }
167
168 fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
169 self.model.new_context(
170 &LLAMA_BACKEND,
171 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
172 )
173 }
174
175 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
176 {
177 let inner = Self::new(
178 QWEN_3D5_4B_GUFF_MODEL_ID,
179 QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
180 QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
181 16384u32.try_into().unwrap(),
182 )
183 .await?;
184 Ok(RunnerWithRecommendedSampling {
185 inner: inner,
186 default_sampling: SimpleSamplingParams {
187 top_p: Some(0.8f32),
188 top_k: Some(20),
189 temperature: Some(0.7f32),
190 presence_penalty: Some(1.5),
191 repetition_penalty: Some(1.0),
192 seed: None,
193 },
194 })
195 }
196}
197
198impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
199where
200 Tmpl: ChatTemplate,
201{
202 fn stream_vlm_response(
203 &'s self,
204 request: GenericVisionLmRequest<'req, Tmpl>,
205 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
206 let ctx = self
207 .new_context_window()
208 .map_err(|err| GenericRunnerError::from(err));
209 Gemma3Stream::new(ctx, request, self, &self.model)
210 }
211}
212
213impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
214where
215 Tmpl: ChatTemplate,
216{
217 fn stream_lm_response(
218 &'s self,
219 request: GenericTextLmRequest<'req, Tmpl>,
220 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
221 self.stream_vlm_response(request.into())
222 }
223}
224
225impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
226 fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
227 Self {
228 messages: value
229 .messages
230 .into_iter()
231 .map(|(role, text)| (role, ImageOrText::Text(text)))
232 .collect(),
233 sampling: value.sampling,
234 llguidance: value.llguidance,
235 max_seq: value.max_seq,
236 prefill: value.prefill,
237 tmpl: value.tmpl,
238 }
239 }
240}
241
242impl<Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner, Tmpl>
243where
244 Tmpl: ChatTemplate,
245{
246 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
247 let media_marker = mtmd::mtmd_default_marker();
249 let messages = self
250 .req
251 .messages
252 .iter()
253 .fold(
254 Vec::<(MessageRole, String)>::new(),
255 |mut acc, (role, message)| {
256 let text = match message {
257 ImageOrText::Text(text) => text,
258 ImageOrText::Image(_) => media_marker,
259 };
260 if let Some(last) = acc.last()
261 && last.0 == *role
262 {
263 let (_, adj) = acc.remove(acc.len() - 1);
265 acc.push((role.clone(), format!("{0}\n{text}", adj)));
266 acc
267 } else {
268 acc.push((role.clone(), text.to_string()));
269 acc
270 }
271 },
272 )
273 .into_iter()
274 .collect::<Vec<_>>();
275 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
276
277 let formatted_prompt = self
279 .req
280 .tmpl
281 .apply_template(self.model, &self.runner.chat_template, &messages)
282 .map_err(GenericRunnerError::ApplyChatTemplate)?;
283
284 let bitmaps = self
286 .req
287 .messages
288 .iter()
289 .filter_map(|msg| match &msg.1 {
290 ImageOrText::Image(image) => Some(image),
291 _ => None,
292 })
293 .enumerate()
294 .map(|(idx, im)| {
295 MtmdBitmap::from_image_data(
296 im.width(),
297 im.height(),
298 im.to_rgb8().to_vec().as_slice(),
299 )
300 .expect(format!("image#{} has corrupted RGB data", idx).as_str())
301 })
302 .collect::<Vec<_>>();
303 let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
304 let chunks = self.runner.mtmd_ctx.tokenize(
305 MtmdInputText {
306 text: formatted_prompt,
307 add_special: true,
308 parse_special: true,
309 },
310 &bitmap_refs,
311 )?;
312 log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
313 let n_past = chunks.eval_chunks(
314 &self.runner.mtmd_ctx,
315 self.ctx.as_ref().unwrap(),
316 0,
317 0,
318 1,
319 true,
320 )?;
321
322 let mut preparation = Runtime {
324 sampler: self.req.sampling.to_llama(),
325 decoder: UTF_8.new_decoder(),
326 batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
327 n_past,
328 step: 0,
329 };
330 if let Some(llguidance) = &self.req.llguidance {
331 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
332 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
333 }
334 self.runtime = Some(preparation);
335
336 Ok(())
337 }
338}
339
340impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, S, Gemma3TextRunner, Tmpl>
341where
342 Tmpl: ChatTemplate,
343{
344 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
345 let messages = self
347 .req
348 .messages
349 .iter()
350 .fold(
351 Vec::<(MessageRole, String)>::new(),
352 |mut acc, (role, message)| {
353 if let Some(last) = acc.last()
354 && last.0 == *role
355 {
356 let (_, adj) = acc.remove(acc.len() - 1);
358 acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
359 acc
360 } else {
361 acc.push((role.clone(), message.as_ref().to_string()));
362 acc
363 }
364 },
365 )
366 .into_iter()
367 .collect::<Vec<_>>();
368 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
369
370 let formatted_prompt = self
372 .req
373 .tmpl
374 .apply_template(self.model, &self.runner.llama_template, &messages)
375 .map_err(GenericRunnerError::ApplyChatTemplate)?;
376
377 let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
379 let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
380 let token_list_len = token_list.len();
381 for (i, token) in token_list.into_iter().enumerate() {
382 batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
383 }
384 self.ctx.as_mut().unwrap().decode(&mut batch)?;
385
386 let mut preparation = Runtime {
388 sampler: self.req.sampling.to_llama(),
389 decoder: UTF_8.new_decoder(),
390 batch,
391 n_past: token_list_len as i32,
392 step: 0,
393 };
394 if let Some(llguidance) = &self.req.llguidance {
395 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
396 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
397 }
398 self.runtime = Some(preparation);
399
400 Ok(())
401 }
402}
403
#[cfg(test)]
mod tests {
    use crate::*;

    // NOTE(review): these are integration tests — they download GGUF weights
    // from Hugging Face and run real inference, so they need network access
    // and take a long time on first run.

    // Sanity check: the default text runner answers a trivial factual question.
    #[tokio::test]
    async fn test_lm() {
        let runner = Gemma3TextRunner::default().await.unwrap();
        let answer = runner
            .get_lm_response(TextLmRequest {
                messages: vec![(MessageRole::User, "What is the capital of France?")],
                ..Default::default()
            })
            .unwrap();
        assert!(answer.contains("Paris"));
    }

    // Sanity check: the default vision runner recognizes a landmark image.
    #[tokio::test]
    async fn test_vlm() {
        let runner = Gemma3VisionRunner::default().await.unwrap();
        let eiffel_tower_im =
            image::load_from_memory(include_bytes!("../../assets/eiffel-tower.jpg")).unwrap();
        let answer = runner
            .get_vlm_response(VisionLmRequest {
                messages: vec![
                    (
                        MessageRole::User,
                        ImageOrText::Text("Which city is this building in?"),
                    ),
                    (MessageRole::User, ImageOrText::Image(&eiffel_tower_im)),
                ],
                ..Default::default()
            })
            .unwrap();
        assert!(answer.contains("Paris"));
    }
}