1use std::{
2 io::IsTerminal,
3 num::NonZeroU32,
4 path::{Path, PathBuf},
5 str::FromStr,
6};
7
8use encoding_rs::UTF_8;
9use hf_hub::api::tokio::ApiBuilder;
10use llama_cpp_2::{
11 LlamaContextLoadError,
12 context::{LlamaContext, params::LlamaContextParams},
13 llama_batch::LlamaBatch,
14 model::{AddBos, LlamaChatTemplate, LlamaModel},
15 mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
16 sampling::LlamaSampler,
17};
18
19use crate::{
20 GenericTextLmRequest, GenericVisionLmRequest, ImageOrText, MessageRole,
21 RunnerWithRecommendedSampling, TextLmRunner, VisionLmRunner,
22 error::{CreateLlamaCppRunnerError, GenericRunnerError},
23 runner::{Gemma3Stream, LLAMA_BACKEND, PrepareRun, Runtime},
24 sample::SimpleSamplingParams,
25 template::ChatTemplate,
26};
27
// Default Hugging Face repo/file names for the bundled models.
//
// NOTE(review): "GUFF" in these const names is presumably a typo for "GGUF"
// (and "MODDEL" for "MODEL"), but they are public constants, so renaming them
// would break downstream callers — flagging only.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
// Multimodal projector weights that pair with the Qwen language model above.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

// Text-only Gemma 3 1B (QAT q4_0 quantization) from Google's official repo.
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
34
/// Text-only LLM runner backed by llama.cpp (Gemma 3 1B by default).
pub struct Gemma3TextRunner {
    // Loaded GGUF model weights.
    model: LlamaModel,
    // Chat template read from the model file, used to format prompts.
    llama_template: LlamaChatTemplate,
    // Context window size (in tokens) used for each new llama context.
    ctx_size: NonZeroU32,
}
40
41impl Gemma3TextRunner {
42 pub async fn new(
43 model_id: impl ToString,
44 model_file: impl AsRef<str>,
45 ctx_size: NonZeroU32,
46 ) -> Result<Self, CreateLlamaCppRunnerError> {
47 let repo = build_hf_api()?.model(model_id.to_string());
48 Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
49 }
50
51 pub fn recommend_sampling() -> SimpleSamplingParams {
52 SimpleSamplingParams {
53 top_p: Some(0.95f32),
54 top_k: Some(64),
55 temperature: Some(1f32),
56 ..Default::default()
57 }
58 }
59
60 pub fn from_file(
61 model_file: impl AsRef<Path>,
62 ctx_size: NonZeroU32,
63 ) -> Result<Self, CreateLlamaCppRunnerError> {
64 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
65
66 let chat_template = model.chat_template(None)?;
67 Ok(Self {
68 model,
69 llama_template: chat_template,
70 ctx_size,
71 })
72 }
73
74 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
75 {
76 let inner = Self::new(
77 GEMMA_3_1B_GUFF_MODEL_ID,
78 GEMMA_3_1B_GUFF_MODEL_FILENAME,
79 32_000.try_into().unwrap(),
80 )
81 .await?;
82 Ok(RunnerWithRecommendedSampling {
83 inner,
84 default_sampling: Self::recommend_sampling(),
85 })
86 }
87}
88
89impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3TextRunner
90where
91 Tmpl: ChatTemplate,
92{
93 fn stream_lm_response(
94 &'s self,
95 request: GenericTextLmRequest<'req, Tmpl>,
96 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
97 let ctx = self
98 .model
99 .new_context(
100 &LLAMA_BACKEND,
101 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
102 )
103 .map_err(|err| GenericRunnerError::from(err));
104 Gemma3Stream::new(ctx, request, self, &self.model)
105 }
106}
107
/// Vision-language (multimodal) runner backed by llama.cpp + mtmd.
pub struct Gemma3VisionRunner {
    // Loaded GGUF language-model weights.
    model: LlamaModel,
    // Chat template read from the model file.
    chat_template: LlamaChatTemplate,
    // Multimodal (image) projector context built from the mmproj GGUF file.
    mtmd_ctx: MtmdContext,
    // Context window size (in tokens) used for each new llama context.
    ctx_size: NonZeroU32,
}
114
115impl Gemma3VisionRunner {
116 pub async fn new(
117 repo_id: impl ToString,
118 model_file: impl AsRef<str>,
119 multimodel_file: impl AsRef<str>,
120 ctx_size: NonZeroU32,
121 ) -> Result<Self, CreateLlamaCppRunnerError> {
122 let repo = build_hf_api()?.model(repo_id.to_string());
123 let model = LlamaModel::load_from_file(
124 &LLAMA_BACKEND,
125 repo.get(model_file.as_ref()).await?,
126 &Default::default(),
127 )?;
128
129 let mtmd_ctx = MtmdContext::init_from_file(
130 repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
131 &model,
132 &Default::default(),
133 )?;
134
135 let chat_template = model.chat_template(None)?;
136
137 Ok(Self {
138 model,
139 mtmd_ctx,
140 chat_template,
141 ctx_size,
142 })
143 }
144
145 pub fn from_files(
146 model_file: impl AsRef<Path>,
147 multimodel_file: impl AsRef<Path>,
148 ctx_size: NonZeroU32,
149 ) -> Result<Self, CreateLlamaCppRunnerError> {
150 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
151 let mtmd_ctx = MtmdContext::init_from_file(
152 multimodel_file.as_ref().as_os_str().to_str().unwrap(),
153 &model,
154 &Default::default(),
155 )?;
156
157 let chat_template = model.chat_template(None)?;
158
159 Ok(Self {
160 model,
161 mtmd_ctx,
162 chat_template,
163 ctx_size,
164 })
165 }
166
167 fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
168 self.model.new_context(
169 &LLAMA_BACKEND,
170 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
171 )
172 }
173
174 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
175 {
176 let inner = Self::new(
177 QWEN_3D5_4B_GUFF_MODEL_ID,
178 QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
179 QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
180 16384u32.try_into().unwrap(),
181 )
182 .await?;
183 Ok(RunnerWithRecommendedSampling {
184 inner: inner,
185 default_sampling: SimpleSamplingParams {
186 top_p: Some(0.8f32),
187 top_k: Some(20),
188 temperature: Some(0.7f32),
189 presence_penalty: Some(1.5),
190 repetition_penalty: Some(1.0),
191 seed: None,
192 },
193 })
194 }
195}
196
197impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
198where
199 Tmpl: ChatTemplate,
200{
201 fn stream_vlm_response(
202 &'s self,
203 request: GenericVisionLmRequest<'req, Tmpl>,
204 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
205 let ctx = self
206 .new_context_window()
207 .map_err(|err| GenericRunnerError::from(err));
208 Gemma3Stream::new(ctx, request, self, &self.model)
209 }
210}
211
212impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
213where
214 Tmpl: ChatTemplate,
215{
216 fn stream_lm_response(
217 &'s self,
218 request: GenericTextLmRequest<'req, Tmpl>,
219 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
220 self.stream_vlm_response(request.into())
221 }
222}
223
224impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
225 fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
226 Self {
227 messages: value
228 .messages
229 .into_iter()
230 .map(|(role, text)| (role, ImageOrText::Text(text)))
231 .collect(),
232 sampling: value.sampling,
233 llguidance: value.llguidance,
234 max_seq: value.max_seq,
235 prefill: value.prefill,
236 tmpl: value.tmpl,
237 }
238 }
239}
240
241fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
242 let mut api = ApiBuilder::new()
243 .with_progress(std::io::stdin().is_terminal())
244 .with_token(std::env::var("HF_TOKEN").ok())
245 .with_chunk_size(Some(2 << 28));
246 if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
247 api = api.with_endpoint(endpoint);
248 }
249 if let Ok(cache) = std::env::var("HF_HOME") {
250 api = api.with_cache_dir(
251 PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
252 );
253 }
254 api.build()
255}
256
257impl<Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner, Tmpl>
258where
259 Tmpl: ChatTemplate,
260{
261 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
262 let media_marker = mtmd::mtmd_default_marker();
264 let messages = self
265 .req
266 .messages
267 .iter()
268 .fold(
269 Vec::<(MessageRole, String)>::new(),
270 |mut acc, (role, message)| {
271 let text = match message {
272 ImageOrText::Text(text) => text,
273 ImageOrText::Image(_) => media_marker,
274 };
275 if let Some(last) = acc.last()
276 && last.0 == *role
277 {
278 let (_, adj) = acc.remove(acc.len() - 1);
280 acc.push((role.clone(), format!("{0}\n{text}", adj)));
281 acc
282 } else {
283 acc.push((role.clone(), text.to_string()));
284 acc
285 }
286 },
287 )
288 .into_iter()
289 .collect::<Vec<_>>();
290 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
291
292 let formatted_prompt = self
294 .req
295 .tmpl
296 .apply_template(self.model, &self.runner.chat_template, &messages)
297 .map_err(GenericRunnerError::ApplyChatTemplate)?;
298
299 let bitmaps = self
301 .req
302 .messages
303 .iter()
304 .filter_map(|msg| match &msg.1 {
305 ImageOrText::Image(image) => Some(image),
306 _ => None,
307 })
308 .enumerate()
309 .map(|(idx, im)| {
310 MtmdBitmap::from_image_data(
311 im.width(),
312 im.height(),
313 im.to_rgb8().to_vec().as_slice(),
314 )
315 .expect(format!("image#{} has corrupted RGB data", idx).as_str())
316 })
317 .collect::<Vec<_>>();
318 let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
319 let chunks = self.runner.mtmd_ctx.tokenize(
320 MtmdInputText {
321 text: formatted_prompt,
322 add_special: true,
323 parse_special: true,
324 },
325 &bitmap_refs,
326 )?;
327 log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
328 let n_past = chunks.eval_chunks(
329 &self.runner.mtmd_ctx,
330 self.ctx.as_ref().unwrap(),
331 0,
332 0,
333 1,
334 true,
335 )?;
336
337 let mut preparation = Runtime {
339 sampler: self.req.sampling.to_llama(),
340 decoder: UTF_8.new_decoder(),
341 batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
342 n_past,
343 step: 0,
344 };
345 if let Some(llguidance) = &self.req.llguidance {
346 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
347 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
348 }
349 self.runtime = Some(preparation);
350
351 Ok(())
352 }
353}
354
355impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, S, Gemma3TextRunner, Tmpl>
356where
357 Tmpl: ChatTemplate,
358{
359 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
360 let messages = self
362 .req
363 .messages
364 .iter()
365 .fold(
366 Vec::<(MessageRole, String)>::new(),
367 |mut acc, (role, message)| {
368 if let Some(last) = acc.last()
369 && last.0 == *role
370 {
371 let (_, adj) = acc.remove(acc.len() - 1);
373 acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
374 acc
375 } else {
376 acc.push((role.clone(), message.as_ref().to_string()));
377 acc
378 }
379 },
380 )
381 .into_iter()
382 .collect::<Vec<_>>();
383 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
384
385 let formatted_prompt = self
387 .req
388 .tmpl
389 .apply_template(self.model, &self.runner.llama_template, &messages)
390 .map_err(GenericRunnerError::ApplyChatTemplate)?;
391
392 let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
394 let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
395 let token_list_len = token_list.len();
396 for (i, token) in token_list.into_iter().enumerate() {
397 batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
398 }
399 self.ctx.as_mut().unwrap().decode(&mut batch)?;
400
401 let mut preparation = Runtime {
403 sampler: self.req.sampling.to_llama(),
404 decoder: UTF_8.new_decoder(),
405 batch,
406 n_past: token_list_len as i32,
407 step: 0,
408 };
409 if let Some(llguidance) = &self.req.llguidance {
410 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
411 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
412 }
413 self.runtime = Some(preparation);
414
415 Ok(())
416 }
417}
418
#[cfg(test)]
mod tests {
    use crate::*;

    // NOTE: these tests download GGUF weights from the Hugging Face Hub and
    // run real inference, so they need network access and substantial time.

    #[tokio::test]
    async fn test_lm() {
        let text_runner = Gemma3TextRunner::default().await.unwrap();
        let request = TextLmRequest {
            messages: vec![(MessageRole::User, "What is the capital of France?")],
            ..Default::default()
        };
        let reply = text_runner.get_lm_response(request).unwrap();
        assert!(reply.contains("Paris"));
    }

    #[tokio::test]
    async fn test_vlm() {
        let vision_runner = Gemma3VisionRunner::default().await.unwrap();
        let tower_photo =
            image::load_from_memory(include_bytes!("../../assets/eiffel-tower.jpg")).unwrap();
        let request = VisionLmRequest {
            messages: vec![
                (
                    MessageRole::User,
                    ImageOrText::Text("Which city is this building in?"),
                ),
                (MessageRole::User, ImageOrText::Image(&tower_photo)),
            ],
            ..Default::default()
        };
        let reply = vision_runner.get_vlm_response(request).unwrap();
        assert!(reply.contains("Paris"));
    }
}