//! llama_runner/runner/gemma3.rs
//!
//! Gemma 3 text and vision runners backed by llama.cpp (via `llama_cpp_2`).
1use std::{
2    io::IsTerminal,
3    num::NonZeroU32,
4    path::{Path, PathBuf},
5    str::FromStr,
6};
7
8use encoding_rs::UTF_8;
9use hf_hub::api::tokio::ApiBuilder;
10use llama_cpp_2::{
11    LlamaContextLoadError,
12    context::{LlamaContext, params::LlamaContextParams},
13    llama_batch::LlamaBatch,
14    model::{AddBos, LlamaChatTemplate, LlamaModel},
15    mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
16    sampling::LlamaSampler,
17};
18
19use crate::{
20    GenericTextLmRequest, GenericVisionLmRequest, ImageOrText, MessageRole,
21    RunnerWithRecommendedSampling, TextLmRunner, VisionLmRunner,
22    error::{CreateLlamaCppRunnerError, GenericRunnerError},
23    hf::build_hf_api,
24    runner::{Gemma3Stream, LLAMA_BACKEND, PrepareRun, Runtime},
25    sample::SimpleSamplingParams,
26    template::ChatTemplate,
27};
28
// NOTE(review): "GUFF" looks like a typo for "GGUF" and "MODDEL" for "MODEL",
// but these constants are `pub` — renaming would break downstream callers, so
// the names are only flagged here.
// Hugging Face repo id for the quantized Qwen3.5-4B GGUF checkpoint.
// NOTE(review): Qwen constants living in the gemma3 module (and used by
// `Gemma3VisionRunner::default`) — confirm this is intentional.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
// Main language-model weights file inside the Qwen repo.
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
// Multimodal projector (mmproj) weights used for vision input.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

// Hugging Face repo id for the Gemma 3 1B instruction-tuned QAT q4_0 GGUF.
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
// Weights file inside the Gemma repo.
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
35
/// Text-only LM runner backed by a GGUF model loaded through llama.cpp.
pub struct Gemma3TextRunner {
    // Loaded llama.cpp model weights.
    model: LlamaModel,
    // Chat template read from the GGUF metadata at load time.
    llama_template: LlamaChatTemplate,
    // Context-window size used when creating a per-request context.
    ctx_size: NonZeroU32,
}
41
42impl Gemma3TextRunner {
43    pub async fn new(
44        model_id: impl ToString,
45        model_file: impl AsRef<str>,
46        ctx_size: NonZeroU32,
47    ) -> Result<Self, CreateLlamaCppRunnerError> {
48        let repo = build_hf_api()?.model(model_id.to_string());
49        Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
50    }
51
52    pub fn recommend_sampling() -> SimpleSamplingParams {
53        SimpleSamplingParams {
54            top_p: Some(0.95f32),
55            top_k: Some(64),
56            temperature: Some(1f32),
57            ..Default::default()
58        }
59    }
60
61    pub fn from_file(
62        model_file: impl AsRef<Path>,
63        ctx_size: NonZeroU32,
64    ) -> Result<Self, CreateLlamaCppRunnerError> {
65        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
66
67        let chat_template = model.chat_template(None)?;
68        Ok(Self {
69            model,
70            llama_template: chat_template,
71            ctx_size,
72        })
73    }
74
75    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
76    {
77        let inner = Self::new(
78            GEMMA_3_1B_GUFF_MODEL_ID,
79            GEMMA_3_1B_GUFF_MODEL_FILENAME,
80            32_000.try_into().unwrap(),
81        )
82        .await?;
83        Ok(RunnerWithRecommendedSampling {
84            inner,
85            default_sampling: Self::recommend_sampling(),
86        })
87    }
88}
89
90impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3TextRunner
91where
92    Tmpl: ChatTemplate,
93{
94    fn stream_lm_response(
95        &'s self,
96        request: GenericTextLmRequest<'req, Tmpl>,
97    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
98        let ctx = self
99            .model
100            .new_context(
101                &LLAMA_BACKEND,
102                LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
103            )
104            .map_err(|err| GenericRunnerError::from(err));
105        Gemma3Stream::new(ctx, request, self, &self.model)
106    }
107}
108
/// Vision-capable LM runner: a GGUF language model plus an mtmd multimodal
/// projector context for image input.
pub struct Gemma3VisionRunner {
    // Loaded llama.cpp model weights.
    model: LlamaModel,
    // Chat template read from the GGUF metadata at load time.
    chat_template: LlamaChatTemplate,
    // Multimodal (mtmd) context built from the projector weights.
    mtmd_ctx: MtmdContext,
    // Context-window size used when creating a per-request context.
    ctx_size: NonZeroU32,
}
115
116impl Gemma3VisionRunner {
117    pub async fn new(
118        repo_id: impl ToString,
119        model_file: impl AsRef<str>,
120        multimodel_file: impl AsRef<str>,
121        ctx_size: NonZeroU32,
122    ) -> Result<Self, CreateLlamaCppRunnerError> {
123        let repo = build_hf_api()?.model(repo_id.to_string());
124        let model = LlamaModel::load_from_file(
125            &LLAMA_BACKEND,
126            repo.get(model_file.as_ref()).await?,
127            &Default::default(),
128        )?;
129
130        let mtmd_ctx = MtmdContext::init_from_file(
131            repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
132            &model,
133            &Default::default(),
134        )?;
135
136        let chat_template = model.chat_template(None)?;
137
138        Ok(Self {
139            model,
140            mtmd_ctx,
141            chat_template,
142            ctx_size,
143        })
144    }
145
146    pub fn from_files(
147        model_file: impl AsRef<Path>,
148        multimodel_file: impl AsRef<Path>,
149        ctx_size: NonZeroU32,
150    ) -> Result<Self, CreateLlamaCppRunnerError> {
151        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
152        let mtmd_ctx = MtmdContext::init_from_file(
153            multimodel_file.as_ref().as_os_str().to_str().unwrap(),
154            &model,
155            &Default::default(),
156        )?;
157
158        let chat_template = model.chat_template(None)?;
159
160        Ok(Self {
161            model,
162            mtmd_ctx,
163            chat_template,
164            ctx_size,
165        })
166    }
167
168    fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
169        self.model.new_context(
170            &LLAMA_BACKEND,
171            LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
172        )
173    }
174
175    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
176    {
177        let inner = Self::new(
178            QWEN_3D5_4B_GUFF_MODEL_ID,
179            QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
180            QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
181            16384u32.try_into().unwrap(),
182        )
183        .await?;
184        Ok(RunnerWithRecommendedSampling {
185            inner: inner,
186            default_sampling: SimpleSamplingParams {
187                top_p: Some(0.8f32),
188                top_k: Some(20),
189                temperature: Some(0.7f32),
190                presence_penalty: Some(1.5),
191                repetition_penalty: Some(1.0),
192                seed: None,
193            },
194        })
195    }
196}
197
198impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
199where
200    Tmpl: ChatTemplate,
201{
202    fn stream_vlm_response(
203        &'s self,
204        request: GenericVisionLmRequest<'req, Tmpl>,
205    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
206        let ctx = self
207            .new_context_window()
208            .map_err(|err| GenericRunnerError::from(err));
209        Gemma3Stream::new(ctx, request, self, &self.model)
210    }
211}
212
213impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
214where
215    Tmpl: ChatTemplate,
216{
217    fn stream_lm_response(
218        &'s self,
219        request: GenericTextLmRequest<'req, Tmpl>,
220    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
221        self.stream_vlm_response(request.into())
222    }
223}
224
225impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
226    fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
227        Self {
228            messages: value
229                .messages
230                .into_iter()
231                .map(|(role, text)| (role, ImageOrText::Text(text)))
232                .collect(),
233            sampling: value.sampling,
234            llguidance: value.llguidance,
235            max_seq: value.max_seq,
236            prefill: value.prefill,
237            tmpl: value.tmpl,
238        }
239    }
240}
241
242impl<Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner, Tmpl>
243where
244    Tmpl: ChatTemplate,
245{
246    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
247        // Preprocess the message, flattening media
248        let media_marker = mtmd::mtmd_default_marker();
249        let messages = self
250            .req
251            .messages
252            .iter()
253            .fold(
254                Vec::<(MessageRole, String)>::new(),
255                |mut acc, (role, message)| {
256                    let text = match message {
257                        ImageOrText::Text(text) => text,
258                        ImageOrText::Image(_) => media_marker,
259                    };
260                    if let Some(last) = acc.last()
261                        && last.0 == *role
262                    {
263                        // merge adjacent
264                        let (_, adj) = acc.remove(acc.len() - 1);
265                        acc.push((role.clone(), format!("{0}\n{text}", adj)));
266                        acc
267                    } else {
268                        acc.push((role.clone(), text.to_string()));
269                        acc
270                    }
271                },
272            )
273            .into_iter()
274            .collect::<Vec<_>>();
275        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
276
277        // apply custom template
278        let formatted_prompt = self
279            .req
280            .tmpl
281            .apply_template(self.model, &self.runner.chat_template, &messages)
282            .map_err(GenericRunnerError::ApplyChatTemplate)?;
283
284        // Aggregate images
285        let bitmaps = self
286            .req
287            .messages
288            .iter()
289            .filter_map(|msg| match &msg.1 {
290                ImageOrText::Image(image) => Some(image),
291                _ => None,
292            })
293            .enumerate()
294            .map(|(idx, im)| {
295                MtmdBitmap::from_image_data(
296                    im.width(),
297                    im.height(),
298                    im.to_rgb8().to_vec().as_slice(),
299                )
300                .expect(format!("image#{} has corrupted RGB data", idx).as_str())
301            })
302            .collect::<Vec<_>>();
303        let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
304        let chunks = self.runner.mtmd_ctx.tokenize(
305            MtmdInputText {
306                text: formatted_prompt,
307                add_special: true,
308                parse_special: true,
309            },
310            &bitmap_refs,
311        )?;
312        log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
313        let n_past = chunks.eval_chunks(
314            &self.runner.mtmd_ctx,
315            self.ctx.as_ref().unwrap(),
316            0,
317            0,
318            1,
319            true,
320        )?;
321
322        // Generate preparation
323        let mut preparation = Runtime {
324            sampler: self.req.sampling.to_llama(),
325            decoder: UTF_8.new_decoder(),
326            batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
327            n_past,
328            step: 0,
329        };
330        if let Some(llguidance) = &self.req.llguidance {
331            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
332            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
333        }
334        self.runtime = Some(preparation);
335
336        Ok(())
337    }
338}
339
340impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, S, Gemma3TextRunner, Tmpl>
341where
342    Tmpl: ChatTemplate,
343{
344    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
345        // Preprocess the message
346        let messages = self
347            .req
348            .messages
349            .iter()
350            .fold(
351                Vec::<(MessageRole, String)>::new(),
352                |mut acc, (role, message)| {
353                    if let Some(last) = acc.last()
354                        && last.0 == *role
355                    {
356                        // merge adjacent
357                        let (_, adj) = acc.remove(acc.len() - 1);
358                        acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
359                        acc
360                    } else {
361                        acc.push((role.clone(), message.as_ref().to_string()));
362                        acc
363                    }
364                },
365            )
366            .into_iter()
367            .collect::<Vec<_>>();
368        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
369
370        // apply custom template
371        let formatted_prompt = self
372            .req
373            .tmpl
374            .apply_template(self.model, &self.runner.llama_template, &messages)
375            .map_err(GenericRunnerError::ApplyChatTemplate)?;
376
377        // Aggregate images
378        let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
379        let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
380        let token_list_len = token_list.len();
381        for (i, token) in token_list.into_iter().enumerate() {
382            batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
383        }
384        self.ctx.as_mut().unwrap().decode(&mut batch)?;
385
386        // Generate preparation
387        let mut preparation = Runtime {
388            sampler: self.req.sampling.to_llama(),
389            decoder: UTF_8.new_decoder(),
390            batch,
391            n_past: token_list_len as i32,
392            step: 0,
393        };
394        if let Some(llguidance) = &self.req.llguidance {
395            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
396            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
397        }
398        self.runtime = Some(preparation);
399
400        Ok(())
401    }
402}
403
#[cfg(test)]
mod tests {
    use crate::*;

    // NOTE(review): these are network-heavy integration tests — they download
    // GGUF checkpoints from Hugging Face on first run and execute real model
    // inference, so they are unsuitable for fast CI without caching.

    /// Smoke test: the default Gemma 3 text runner answers a factual question.
    #[tokio::test]
    async fn test_lm() {
        let runner = Gemma3TextRunner::default().await.unwrap();
        let answer = runner
            .get_lm_response(TextLmRequest {
                messages: vec![(MessageRole::User, "What is the capital of France?")],
                ..Default::default()
            })
            .unwrap();
        assert!(answer.contains("Paris"));
    }

    /// Smoke test: the default vision runner identifies a landmark image.
    #[tokio::test]
    async fn test_vlm() {
        let runner = Gemma3VisionRunner::default().await.unwrap();
        let eiffel_tower_im =
            image::load_from_memory(include_bytes!("../../assets/eiffel-tower.jpg")).unwrap();
        let answer = runner
            .get_vlm_response(VisionLmRequest {
                messages: vec![
                    (
                        MessageRole::User,
                        ImageOrText::Text("Which city is this building in?"),
                    ),
                    (MessageRole::User, ImageOrText::Image(&eiffel_tower_im)),
                ],
                ..Default::default()
            })
            .unwrap();
        assert!(answer.contains("Paris"));
    }
}
439}