//! llama_runner/runner/gemma4.rs — Gemma 4 vision-language runner.

use llama_cpp_2::model::{LlamaChatTemplate, LlamaModel};

// Crate-internal building blocks: the Gemma 3 runner this module wraps,
// request/response types, error types, sampling params, and chat templates.
use crate::{
    Gemma3Stream, Gemma3VisionRunner, GenericVisionLmRequest, MessageRole,
    RunnerWithRecommendedSampling, VisionLmRunner,
    error::{CreateLlamaCppRunnerError, GenericRunnerError},
    sample::SimpleSamplingParams,
    template::ChatTemplate,
};
10
// NOTE(review): "GUFF" in these const names looks like a typo for "GGUF"
// (the value strings spell it "GGUF"); they are `pub`, so renaming would
// break downstream callers — confirm before changing.
/// Hugging Face repository id for the default Gemma 4 E2B GGUF checkpoint.
pub const GEMMA_4_E2B_GUFF_MODEL_ID: &str = "unsloth/gemma-4-E2B-it-GGUF";
/// Q4_0-quantized language-model weights file within that repository.
pub const GEMMA_4_E2B_GUFF_MODEL_FILENAME: &str = "gemma-4-E2B-it-Q4_0.gguf";
/// Multimodal projector (vision) weights file within the same repository.
pub const GEMMA_4_E2B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";
14
/// Vision-language runner for Gemma 4, implemented as a thin newtype over
/// [`Gemma3VisionRunner`] — the underlying llama.cpp plumbing is shared;
/// only the default model files and accepted chat templates differ.
#[repr(transparent)]
pub struct Gemma4VisionRunner(Gemma3VisionRunner);
17
18impl Gemma4VisionRunner {
19    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
20    {
21        let inner = Gemma3VisionRunner::new(
22            GEMMA_4_E2B_GUFF_MODEL_ID,
23            GEMMA_4_E2B_GUFF_MODEL_FILENAME,
24            GEMMA_4_E2B_GUFF_MULTIMODEL_FILENAME,
25            128_000u32.try_into().unwrap(),
26        )
27        .await?;
28        Ok(RunnerWithRecommendedSampling {
29            inner: Gemma4VisionRunner(inner),
30            default_sampling: SimpleSamplingParams {
31                top_p: Some(0.8f32),
32                top_k: Some(20),
33                temperature: Some(0.7f32),
34                presence_penalty: Some(1.5),
35                repetition_penalty: Some(1.0),
36                seed: None,
37            },
38        })
39    }
40}
41
/// Marker trait restricting which [`ChatTemplate`] implementations may be
/// used with [`Gemma4VisionRunner`]; implement it on templates that emit
/// Gemma 4's prompt format.
pub trait Gemma4ApplicableChatTemplate: ChatTemplate {}
43
/// Streaming vision-LM inference for Gemma 4: the newtype forwards every
/// request to the wrapped [`Gemma3VisionRunner`], adding only the
/// [`Gemma4ApplicableChatTemplate`] bound on the accepted templates.
impl<'s, 'req, Tmpl: Gemma4ApplicableChatTemplate> VisionLmRunner<'s, 'req, Tmpl>
    for Gemma4VisionRunner
{
    fn stream_vlm_response(
        &'s self,
        request: GenericVisionLmRequest<'req, Tmpl>,
    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
        // Pure delegation to the inner Gemma 3 runner.
        self.0.stream_vlm_response(request)
    }
}
54
#[cfg(test)]
mod test {
    use crate::{mcp::Gemma4ChatTemplate, *};

    // Smoke test: loads the default Gemma 4 model and checks that the VLM
    // identifies the city of a landmark image. Gated on the "mcp" feature;
    // NOTE(review): presumably needs network access for the model download
    // the first time it runs — confirm before enabling in CI.
    #[tokio::test]
    #[cfg(feature = "mcp")]
    async fn test_vlm() {
        let runner = Gemma4VisionRunner::default().await.unwrap();
        let eiffel_tower_im =
            image::load_from_memory(include_bytes!("../../assets/eiffel-tower.jpg")).unwrap();
        let answer = runner
            .get_vlm_response(GenericRunnerRequest {
                messages: vec![
                    (
                        MessageRole::User,
                        ImageOrText::Text("Which city is this building in?"),
                    ),
                    // The image follows the question as a second user turn.
                    (MessageRole::User, ImageOrText::Image(&eiffel_tower_im)),
                ],
                tmpl: Gemma4ChatTemplate::default(),
                ..Default::default()
            })
            .unwrap();
        // Loose assertion: any answer mentioning Paris passes.
        assert!(answer.contains("Paris"));
    }
}