Skip to main content

llama_runner/runner/
gemma3.rs

1use std::{
2    io::IsTerminal,
3    num::NonZeroU32,
4    path::{Path, PathBuf},
5    str::FromStr,
6};
7
8use encoding_rs::UTF_8;
9use hf_hub::api::tokio::ApiBuilder;
10use llama_cpp_2::{
11    LlamaContextLoadError,
12    context::{LlamaContext, params::LlamaContextParams},
13    llama_batch::LlamaBatch,
14    model::{AddBos, LlamaChatTemplate, LlamaModel},
15    mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
16    sampling::LlamaSampler,
17};
18
19use crate::{
20    GenericTextLmRequest, GenericVisionLmRequest, ImageOrText, MessageRole,
21    RunnerWithRecommendedSampling, TextLmRunner, VisionLmRunner,
22    error::{CreateLlamaCppRunnerError, GenericRunnerError},
23    runner::{Gemma3Stream, LLAMA_BACKEND, PrepareRun, Runtime},
24    sample::SimpleSamplingParams,
25    template::ChatTemplate,
26};
27
// Default checkpoints fetched from the Hugging Face Hub.
// NOTE: the `GUFF` / `MODDEL` / `MULTIMODEL` spellings are part of the public API and are kept.

/// Hugging Face repository of the Gemma 3 1B (`it`, QAT, `q4_0`) GGUF release.
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// Weight file inside [`GEMMA_3_1B_GUFF_MODEL_ID`].
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";

/// Hugging Face repository of the Qwen3.5 4B GGUF release.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Language-model weight file (Q4_K_M) inside [`QWEN_3D5_4B_GUFF_MODEL_ID`].
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector (`mmproj`, F16) file inside [`QWEN_3D5_4B_GUFF_MODEL_ID`].
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";
34
35pub struct Gemma3TextRunner {
36    model: LlamaModel,
37    llama_template: LlamaChatTemplate,
38    ctx_size: NonZeroU32,
39}
40
41impl Gemma3TextRunner {
42    pub async fn new(
43        model_id: impl ToString,
44        model_file: impl AsRef<str>,
45        ctx_size: NonZeroU32,
46    ) -> Result<Self, CreateLlamaCppRunnerError> {
47        let repo = build_hf_api()?.model(model_id.to_string());
48        Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
49    }
50
51    pub fn recommend_sampling() -> SimpleSamplingParams {
52        SimpleSamplingParams {
53            top_p: Some(0.95f32),
54            top_k: Some(64),
55            temperature: Some(1f32),
56            ..Default::default()
57        }
58    }
59
60    pub fn from_file(
61        model_file: impl AsRef<Path>,
62        ctx_size: NonZeroU32,
63    ) -> Result<Self, CreateLlamaCppRunnerError> {
64        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
65
66        let chat_template = model.chat_template(None)?;
67        Ok(Self {
68            model,
69            llama_template: chat_template,
70            ctx_size,
71        })
72    }
73
74    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
75    {
76        let inner = Self::new(
77            GEMMA_3_1B_GUFF_MODEL_ID,
78            GEMMA_3_1B_GUFF_MODEL_FILENAME,
79            32_000.try_into().unwrap(),
80        )
81        .await?;
82        Ok(RunnerWithRecommendedSampling {
83            inner,
84            default_sampling: Self::recommend_sampling(),
85        })
86    }
87}
88
89impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3TextRunner
90where
91    Tmpl: ChatTemplate,
92{
93    fn stream_lm_response(
94        &'s self,
95        request: GenericTextLmRequest<'req, Tmpl>,
96    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
97        let ctx = self
98            .model
99            .new_context(
100                &LLAMA_BACKEND,
101                LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
102            )
103            .map_err(|err| GenericRunnerError::from(err));
104        Gemma3Stream::new(ctx, request, self, &self.model)
105    }
106}
107
108pub struct Gemma3VisionRunner {
109    model: LlamaModel,
110    chat_template: LlamaChatTemplate,
111    mtmd_ctx: MtmdContext,
112    ctx_size: NonZeroU32,
113}
114
115impl Gemma3VisionRunner {
116    pub async fn new(
117        repo_id: impl ToString,
118        model_file: impl AsRef<str>,
119        multimodel_file: impl AsRef<str>,
120        ctx_size: NonZeroU32,
121    ) -> Result<Self, CreateLlamaCppRunnerError> {
122        let repo = build_hf_api()?.model(repo_id.to_string());
123        let model = LlamaModel::load_from_file(
124            &LLAMA_BACKEND,
125            repo.get(model_file.as_ref()).await?,
126            &Default::default(),
127        )?;
128
129        let mtmd_ctx = MtmdContext::init_from_file(
130            repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
131            &model,
132            &Default::default(),
133        )?;
134
135        let chat_template = model.chat_template(None)?;
136
137        Ok(Self {
138            model,
139            mtmd_ctx,
140            chat_template,
141            ctx_size,
142        })
143    }
144
145    pub fn from_files(
146        model_file: impl AsRef<Path>,
147        multimodel_file: impl AsRef<Path>,
148        ctx_size: NonZeroU32,
149    ) -> Result<Self, CreateLlamaCppRunnerError> {
150        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
151        let mtmd_ctx = MtmdContext::init_from_file(
152            multimodel_file.as_ref().as_os_str().to_str().unwrap(),
153            &model,
154            &Default::default(),
155        )?;
156
157        let chat_template = model.chat_template(None)?;
158
159        Ok(Self {
160            model,
161            mtmd_ctx,
162            chat_template,
163            ctx_size,
164        })
165    }
166
167    fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
168        self.model.new_context(
169            &LLAMA_BACKEND,
170            LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
171        )
172    }
173
174    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
175    {
176        let inner = Self::new(
177            QWEN_3D5_4B_GUFF_MODEL_ID,
178            QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
179            QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
180            16384u32.try_into().unwrap(),
181        )
182        .await?;
183        Ok(RunnerWithRecommendedSampling {
184            inner: inner,
185            default_sampling: SimpleSamplingParams {
186                top_p: Some(0.8f32),
187                top_k: Some(20),
188                temperature: Some(0.7f32),
189                presence_penalty: Some(1.5),
190                repetition_penalty: Some(1.0),
191                seed: None,
192            },
193        })
194    }
195}
196
197impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
198where
199    Tmpl: ChatTemplate,
200{
201    fn stream_vlm_response(
202        &'s self,
203        request: GenericVisionLmRequest<'req, Tmpl>,
204    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
205        let ctx = self
206            .new_context_window()
207            .map_err(|err| GenericRunnerError::from(err));
208        Gemma3Stream::new(ctx, request, self, &self.model)
209    }
210}
211
212impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req, Tmpl> for Gemma3VisionRunner
213where
214    Tmpl: ChatTemplate,
215{
216    fn stream_lm_response(
217        &'s self,
218        request: GenericTextLmRequest<'req, Tmpl>,
219    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
220        self.stream_vlm_response(request.into())
221    }
222}
223
224impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
225    fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
226        Self {
227            messages: value
228                .messages
229                .into_iter()
230                .map(|(role, text)| (role, ImageOrText::Text(text)))
231                .collect(),
232            sampling: value.sampling,
233            llguidance: value.llguidance,
234            max_seq: value.max_seq,
235            prefill: value.prefill,
236            tmpl: value.tmpl,
237        }
238    }
239}
240
241fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
242    let mut api = ApiBuilder::new()
243        .with_progress(std::io::stdin().is_terminal())
244        .with_token(std::env::var("HF_TOKEN").ok())
245        .with_chunk_size(Some(2 << 28));
246    if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
247        api = api.with_endpoint(endpoint);
248    }
249    if let Ok(cache) = std::env::var("HF_HOME") {
250        api = api.with_cache_dir(
251            PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
252        );
253    }
254    api.build()
255}
256
257impl<Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner, Tmpl>
258where
259    Tmpl: ChatTemplate,
260{
261    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
262        // Preprocess the message, flattening media
263        let media_marker = mtmd::mtmd_default_marker();
264        let messages = self
265            .req
266            .messages
267            .iter()
268            .fold(
269                Vec::<(MessageRole, String)>::new(),
270                |mut acc, (role, message)| {
271                    let text = match message {
272                        ImageOrText::Text(text) => text,
273                        ImageOrText::Image(_) => media_marker,
274                    };
275                    if let Some(last) = acc.last()
276                        && last.0 == *role
277                    {
278                        // merge adjacent
279                        let (_, adj) = acc.remove(acc.len() - 1);
280                        acc.push((role.clone(), format!("{0}\n{text}", adj)));
281                        acc
282                    } else {
283                        acc.push((role.clone(), text.to_string()));
284                        acc
285                    }
286                },
287            )
288            .into_iter()
289            .collect::<Vec<_>>();
290        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
291
292        // apply custom template
293        let formatted_prompt = self
294            .req
295            .tmpl
296            .apply_template(self.model, &self.runner.chat_template, &messages)
297            .map_err(GenericRunnerError::ApplyChatTemplate)?;
298
299        // Aggregate images
300        let bitmaps = self
301            .req
302            .messages
303            .iter()
304            .filter_map(|msg| match &msg.1 {
305                ImageOrText::Image(image) => Some(image),
306                _ => None,
307            })
308            .enumerate()
309            .map(|(idx, im)| {
310                MtmdBitmap::from_image_data(
311                    im.width(),
312                    im.height(),
313                    im.to_rgb8().to_vec().as_slice(),
314                )
315                .expect(format!("image#{} has corrupted RGB data", idx).as_str())
316            })
317            .collect::<Vec<_>>();
318        let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
319        let chunks = self.runner.mtmd_ctx.tokenize(
320            MtmdInputText {
321                text: formatted_prompt,
322                add_special: true,
323                parse_special: true,
324            },
325            &bitmap_refs,
326        )?;
327        log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
328        let n_past = chunks.eval_chunks(
329            &self.runner.mtmd_ctx,
330            self.ctx.as_ref().unwrap(),
331            0,
332            0,
333            1,
334            true,
335        )?;
336
337        // Generate preparation
338        let mut preparation = Runtime {
339            sampler: self.req.sampling.to_llama(),
340            decoder: UTF_8.new_decoder(),
341            batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
342            n_past,
343            step: 0,
344        };
345        if let Some(llguidance) = &self.req.llguidance {
346            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
347            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
348        }
349        self.runtime = Some(preparation);
350
351        Ok(())
352    }
353}
354
355impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, S, Gemma3TextRunner, Tmpl>
356where
357    Tmpl: ChatTemplate,
358{
359    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
360        // Preprocess the message
361        let messages = self
362            .req
363            .messages
364            .iter()
365            .fold(
366                Vec::<(MessageRole, String)>::new(),
367                |mut acc, (role, message)| {
368                    if let Some(last) = acc.last()
369                        && last.0 == *role
370                    {
371                        // merge adjacent
372                        let (_, adj) = acc.remove(acc.len() - 1);
373                        acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
374                        acc
375                    } else {
376                        acc.push((role.clone(), message.as_ref().to_string()));
377                        acc
378                    }
379                },
380            )
381            .into_iter()
382            .collect::<Vec<_>>();
383        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
384
385        // apply custom template
386        let formatted_prompt = self
387            .req
388            .tmpl
389            .apply_template(self.model, &self.runner.llama_template, &messages)
390            .map_err(GenericRunnerError::ApplyChatTemplate)?;
391
392        // Aggregate images
393        let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
394        let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
395        let token_list_len = token_list.len();
396        for (i, token) in token_list.into_iter().enumerate() {
397            batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
398        }
399        self.ctx.as_mut().unwrap().decode(&mut batch)?;
400
401        // Generate preparation
402        let mut preparation = Runtime {
403            sampler: self.req.sampling.to_llama(),
404            decoder: UTF_8.new_decoder(),
405            batch,
406            n_past: token_list_len as i32,
407            step: 0,
408        };
409        if let Some(llguidance) = &self.req.llguidance {
410            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
411            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
412        }
413        self.runtime = Some(preparation);
414
415        Ok(())
416    }
417}
418
419#[cfg(test)]
420mod tests {
421    use crate::*;
422
423    #[tokio::test]
424    async fn test_lm() {
425        let runner = Gemma3TextRunner::default().await.unwrap();
426        let answer = runner
427            .get_lm_response(TextLmRequest {
428                messages: vec![(MessageRole::User, "What is the capital of France?")],
429                ..Default::default()
430            })
431            .unwrap();
432        assert!(answer.contains("Paris"));
433    }
434
435    #[tokio::test]
436    async fn test_vlm() {
437        let runner = Gemma3VisionRunner::default().await.unwrap();
438        let eiffel_tower_im =
439            image::load_from_memory(include_bytes!("../../assets/eiffel-tower.jpg")).unwrap();
440        let answer = runner
441            .get_vlm_response(VisionLmRequest {
442                messages: vec![
443                    (
444                        MessageRole::User,
445                        ImageOrText::Text("Which city is this building in?"),
446                    ),
447                    (MessageRole::User, ImageOrText::Image(&eiffel_tower_im)),
448                ],
449                ..Default::default()
450            })
451            .unwrap();
452        assert!(answer.contains("Paris"));
453    }
454}