Skip to main content

llama_runner/runner/
gemma4.rs

1use std::num::NonZeroU32;
2
3use crate::{
4    Gemma3VisionRunner, GenericVisionLmRequest, RunnerWithRecommendedSampling, VisionLmRunner,
5    error::{CreateLlamaCppRunnerError, GenericRunnerError},
6    sample::SimpleSamplingParams,
7    template::ChatTemplate,
8};
9
/// Model repository hosting the default Gemma 4 E2B instruction-tuned GGUF files
/// (presumably a Hugging Face repo id).
pub const GEMMA_4_E2B_GGUF_MODEL_ID: &str = "unsloth/gemma-4-E2B-it-GGUF";
/// File name of the Q4_0-quantized Gemma 4 E2B language-model weights inside
/// [`GEMMA_4_E2B_GGUF_MODEL_ID`].
pub const GEMMA_4_E2B_GGUF_MODEL_FILENAME: &str = "gemma-4-E2B-it-Q4_0.gguf";
/// File name of the F16 multimodal (vision) projector inside [`GEMMA_4_E2B_GGUF_MODEL_ID`].
pub const GEMMA_4_E2B_GGUF_MULTIMODAL_FILENAME: &str = "mmproj-F16.gguf";

// The original public names misspell "GGUF" as "GUFF" and "multimodal" as
// "multimodel". They are kept as exact aliases so existing callers keep compiling;
// new code should prefer the correctly spelled names above.

/// Alias of [`GEMMA_4_E2B_GGUF_MODEL_ID`] (original, misspelled name).
pub const GEMMA_4_E2B_GUFF_MODEL_ID: &str = GEMMA_4_E2B_GGUF_MODEL_ID;
/// Alias of [`GEMMA_4_E2B_GGUF_MODEL_FILENAME`] (original, misspelled name).
pub const GEMMA_4_E2B_GUFF_MODEL_FILENAME: &str = GEMMA_4_E2B_GGUF_MODEL_FILENAME;
/// Alias of [`GEMMA_4_E2B_GGUF_MULTIMODAL_FILENAME`] (original, misspelled name).
pub const GEMMA_4_E2B_GUFF_MULTIMODEL_FILENAME: &str = GEMMA_4_E2B_GGUF_MULTIMODAL_FILENAME;
13
14#[repr(transparent)]
15pub struct Gemma4VisionRunner(Gemma3VisionRunner);
16
17impl Gemma4VisionRunner {
18    pub async fn new(
19        repo_id: impl ToString,
20        model_file: impl AsRef<str>,
21        multimodel_file: impl AsRef<str>,
22        ctx_size: NonZeroU32,
23    ) -> Result<Self, CreateLlamaCppRunnerError> {
24        Ok(Self(
25            Gemma3VisionRunner::new(repo_id, model_file, multimodel_file, ctx_size).await?,
26        ))
27    }
28
29    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
30    {
31        let inner = Self::new(
32            GEMMA_4_E2B_GUFF_MODEL_ID,
33            GEMMA_4_E2B_GUFF_MODEL_FILENAME,
34            GEMMA_4_E2B_GUFF_MULTIMODEL_FILENAME,
35            128_000u32.try_into().unwrap(),
36        )
37        .await?;
38        Ok(RunnerWithRecommendedSampling {
39            inner,
40            default_sampling: SimpleSamplingParams {
41                top_p: Some(0.95f32),
42                top_k: Some(64),
43                temperature: Some(1.0f32),
44                presence_penalty: None,
45                repetition_penalty: None,
46                seed: None,
47            },
48        })
49    }
50}
51
52pub trait Gemma4ApplicableChatTemplate: ChatTemplate {}
53
54impl<'s, 'req, Tmpl: Gemma4ApplicableChatTemplate> VisionLmRunner<'s, 'req, Tmpl>
55    for Gemma4VisionRunner
56{
57    fn stream_vlm_response(
58        &'s self,
59        request: GenericVisionLmRequest<'req, Tmpl>,
60    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
61        self.0.stream_vlm_response(request)
62    }
63}
64
65#[cfg(test)]
66mod test {
67    use crate::{mcp::Gemma4ChatTemplate, *};
68
69    #[tokio::test]
70    #[cfg(feature = "mcp")]
71    async fn test_vlm() {
72        let runner = Gemma4VisionRunner::default().await.unwrap();
73        let eiffel_tower_im =
74            image::load_from_memory(include_bytes!("../../assets/eiffel-tower.jpg")).unwrap();
75        let answer = runner
76            .get_vlm_response(GenericRunnerRequest {
77                messages: vec![
78                    (
79                        MessageRole::User,
80                        ImageOrText::Text("Which city is this building in?"),
81                    ),
82                    (MessageRole::User, ImageOrText::Image(&eiffel_tower_im)),
83                ],
84                tmpl: Gemma4ChatTemplate::default(),
85                ..Default::default()
86            })
87            .unwrap();
88        assert!(answer.contains("Paris"));
89    }
90}