// llama_runner/runner.rs

1use std::{
2    env,
3    io::IsTerminal,
4    num::NonZeroU32,
5    path::{Path, PathBuf},
6    str::FromStr,
7    sync::LazyLock,
8};
9
10use encoding_rs::{Decoder, UTF_8};
11use hf_hub::api::tokio::ApiBuilder;
12use llama_cpp_2::{
13    LlamaContextLoadError,
14    context::{LlamaContext, params::LlamaContextParams},
15    llama_backend::LlamaBackend,
16    llama_batch::LlamaBatch,
17    model::{AddBos, LlamaChatMessage, LlamaChatTemplate, LlamaModel},
18    mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
19    sampling::LlamaSampler,
20    token::LlamaToken,
21};
22use strum::Display;
23
24use crate::{
25    error::{CreateLlamaCppRunnerError, RunnerError},
26    sample::{LlguidanceSamplingParams, SimpleSamplingParams},
27};
28
/// Hugging Face repo id of the Qwen 3.5 4B GGUF build used by
/// `Gemma3VisionRunner::default()`.
// NOTE(review): "GUFF" / "MODDEL" look like typos for "GGUF" / "MODEL",
// but these consts are `pub`, so renaming would break downstream callers.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Quantized (Q4_K_M) model weights file inside the Qwen repo.
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector (mmproj) file inside the Qwen repo.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Hugging Face repo id of the Gemma 3 1B QAT Q4_0 GGUF build used by
/// `Gemma3TextRunner::default()`.
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// Quantized (Q4_0) model weights file inside the Gemma repo.
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
35
/// A runner that streams a text-only language-model response.
pub trait TextLmRunner<'s, 'req> {
    /// Stream of decoded text pieces; an `Err` item aborts generation.
    type Response: Iterator<Item = Result<String, RunnerError>>;
    /// Starts generation for `request` and returns the token stream.
    fn stream_lm_response(&'s self, request: TextLmRequest<'req>) -> Self::Response;
}
40
/// A runner that streams a vision-language-model response (messages may
/// contain images as well as text).
pub trait VisionLmRunner<'s, 'req> {
    /// Stream of decoded text pieces; an `Err` item aborts generation.
    type Response: Iterator<Item = Result<String, RunnerError>>;
    /// Starts generation for `request` and returns the token stream.
    fn stream_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Self::Response;
}
45
/// A single generation request, generic over the message body type `M`
/// (`&str` for text-only, [`ImageOrText`] for multimodal).
#[derive(Debug, Clone)]
pub struct RunnerRequest<M> {
    /// Ordered chat history as (role, body) pairs.
    pub messages: Vec<(MessageRole, M)>,
    /// Plain sampling parameters (top-k/top-p/temperature, …).
    pub sampling: SimpleSamplingParams,
    /// Optional llguidance grammar for constrained decoding.
    pub llguidance: Option<LlguidanceSamplingParams>,
    /// Maximum number of generation steps before the stream ends.
    pub max_seq: usize,
    /// Optional text force-fed as the start of the assistant response.
    pub prefill: Option<String>,
}
54
55impl<M> Default for RunnerRequest<M> {
56    fn default() -> Self {
57        Self {
58            messages: vec![],
59            sampling: Default::default(),
60            llguidance: None,
61            max_seq: usize::MAX,
62            prefill: None,
63        }
64    }
65}
66
/// Text-only request: message bodies are plain `&str`.
pub type TextLmRequest<'a> = RunnerRequest<&'a str>;
/// Multimodal request: message bodies are text or borrowed images.
pub type VisionLmRequest<'a> = RunnerRequest<ImageOrText<'a>>;
69
/// Convenience extension: collect a streamed text response into one `String`.
pub trait TextLmRunnerExt<'s, 'req> {
    /// Runs the request to completion and returns the concatenated output.
    fn get_lm_response(&'s self, request: TextLmRequest<'req>) -> Result<String, RunnerError>;
}
73
/// Convenience extension: collect a streamed vision response into one `String`.
pub trait VisionLmRunnerExt<'s, 'req> {
    /// Runs the request to completion and returns the concatenated output.
    fn get_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Result<String, RunnerError>;
}
77
// Blanket impl: every streaming text runner gets the collecting helper.
impl<'s, 'req, T> TextLmRunnerExt<'s, 'req> for T
where
    T: TextLmRunner<'s, 'req>,
{
    fn get_lm_response(&'s self, request: TextLmRequest<'req>) -> Result<String, RunnerError> {
        // `collect` into `Result<String, _>` concatenates pieces and
        // short-circuits on the first error.
        self.stream_lm_response(request)
            .collect::<Result<String, _>>()
    }
}
87
// Blanket impl: every streaming vision runner gets the collecting helper.
impl<'s, 'req, T> VisionLmRunnerExt<'s, 'req> for T
where
    T: VisionLmRunner<'s, 'req>,
{
    fn get_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Result<String, RunnerError> {
        // Concatenates streamed pieces, short-circuiting on the first error.
        self.stream_vlm_response(request)
            .collect::<Result<String, _>>()
    }
}
97
/// Chat role of a message; `Display` yields the lowercase names
/// ("assistant"/"user"/"system") that chat templates expect.
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum MessageRole {
    #[strum(to_string = "assistant")]
    Assistant,
    #[strum(to_string = "user")]
    User,
    #[strum(to_string = "system")]
    System,
}
107
/// Body of one multimodal chat message: either borrowed text or a
/// borrowed decoded image.
#[derive(Debug, Clone)]
pub enum ImageOrText<'a> {
    Text(&'a str),
    Image(&'a image::DynamicImage),
}
113
/// Text-only llama.cpp runner: a loaded GGUF model, its chat template,
/// and the context-window size to use per request.
pub struct Gemma3TextRunner {
    model: LlamaModel,
    chat_template: LlamaChatTemplate,
    // Context size in tokens; each request creates a fresh context of this size.
    ctx_size: NonZeroU32,
}
119
120impl Gemma3TextRunner {
121    pub async fn new(
122        model_id: impl ToString,
123        model_file: impl AsRef<str>,
124        ctx_size: NonZeroU32,
125    ) -> Result<Self, CreateLlamaCppRunnerError> {
126        let repo = build_hf_api()?.model(model_id.to_string());
127        Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
128    }
129
130    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
131    {
132        let inner = Self::new(
133            GEMMA_3_1B_GUFF_MODEL_ID,
134            GEMMA_3_1B_GUFF_MODEL_FILENAME,
135            32_000.try_into().unwrap(),
136        )
137        .await?;
138        Ok(RunnerWithRecommendedSampling {
139            inner,
140            default_sampling: Self::recommend_sampling(),
141        })
142    }
143
144    pub fn recommend_sampling() -> SimpleSamplingParams {
145        SimpleSamplingParams {
146            top_p: Some(0.95f32),
147            top_k: Some(64),
148            temperature: Some(1f32),
149            ..Default::default()
150        }
151    }
152
153    pub fn from_file(
154        model_file: impl AsRef<Path>,
155        ctx_size: NonZeroU32,
156    ) -> Result<Self, CreateLlamaCppRunnerError> {
157        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
158
159        let chat_template = model.chat_template(None)?;
160        Ok(Self {
161            model,
162            chat_template,
163            ctx_size,
164        })
165    }
166}
167
168impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3TextRunner {
169    type Response = Gemma3Stream<'s, &'req str, Gemma3TextRunner>;
170
171    fn stream_lm_response(&'s self, request: TextLmRequest<'req>) -> Self::Response {
172        let ctx = self
173            .model
174            .new_context(
175                &LLAMA_BACKEND,
176                LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
177            )
178            .map_err(|err| RunnerError::from(err));
179        Gemma3Stream::new(ctx, request, self, &self.model)
180    }
181}
182
/// Multimodal llama.cpp runner: a loaded GGUF model, its chat template,
/// and an mtmd context for image tokenization/projection.
pub struct Gemma3VisionRunner {
    model: LlamaModel,
    chat_template: LlamaChatTemplate,
    // Multimodal (mtmd) context built from the mmproj file.
    mtmd_ctx: MtmdContext,
    // Context size in tokens; each request creates a fresh context of this size.
    ctx_size: NonZeroU32,
}
189
190static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
191    llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default().with_logs_enabled(
192        env::var("RUST_LOG").map_or(false, |lvl| lvl.to_lowercase() == "debug"),
193    ));
194    LlamaBackend::init().unwrap()
195});
196
197impl Gemma3VisionRunner {
198    pub async fn new(
199        repo_id: impl ToString,
200        model_file: impl AsRef<str>,
201        multimodel_file: impl AsRef<str>,
202        ctx_size: NonZeroU32,
203    ) -> Result<Self, CreateLlamaCppRunnerError> {
204        let repo = build_hf_api()?.model(repo_id.to_string());
205        let model = LlamaModel::load_from_file(
206            &LLAMA_BACKEND,
207            repo.get(model_file.as_ref()).await?,
208            &Default::default(),
209        )?;
210
211        let mtmd_ctx = MtmdContext::init_from_file(
212            repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
213            &model,
214            &Default::default(),
215        )?;
216
217        let chat_template = model.chat_template(None)?;
218
219        Ok(Self {
220            model,
221            mtmd_ctx,
222            chat_template,
223            ctx_size,
224        })
225    }
226
227    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
228    {
229        let inner = Self::new(
230            QWEN_3D5_4B_GUFF_MODEL_ID,
231            QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
232            QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
233            16384u32.try_into().unwrap(),
234        )
235        .await?;
236        Ok(RunnerWithRecommendedSampling {
237            inner: inner,
238            default_sampling: SimpleSamplingParams {
239                top_p: Some(0.8f32),
240                top_k: Some(20),
241                temperature: Some(0.7f32),
242                presence_penalty: Some(1.5),
243                repetition_penalty: Some(1.0),
244                seed: None,
245            },
246        })
247    }
248
249    pub fn from_files(
250        model_file: impl AsRef<Path>,
251        multimodel_file: impl AsRef<Path>,
252        ctx_size: NonZeroU32,
253    ) -> Result<Self, CreateLlamaCppRunnerError> {
254        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
255        let mtmd_ctx = MtmdContext::init_from_file(
256            multimodel_file.as_ref().as_os_str().to_str().unwrap(),
257            &model,
258            &Default::default(),
259        )?;
260
261        let chat_template = model.chat_template(None)?;
262
263        Ok(Self {
264            model,
265            mtmd_ctx,
266            chat_template,
267            ctx_size,
268        })
269    }
270
271    fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
272        self.model.new_context(
273            &LLAMA_BACKEND,
274            LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
275        )
276    }
277}
278
279impl<'s, 'req> VisionLmRunner<'s, 'req> for Gemma3VisionRunner {
280    type Response = Gemma3Stream<'s, ImageOrText<'req>, Gemma3VisionRunner>;
281
282    fn stream_vlm_response(&'s self, request: VisionLmRequest<'req>) -> Self::Response {
283        let ctx = self
284            .new_context_window()
285            .map_err(|err| RunnerError::from(err));
286        Gemma3Stream::new(ctx, request, self, &self.model)
287    }
288}
289
// A vision runner also serves text-only requests by lifting them into
// vision requests (see `From<TextLmRequest> for VisionLmRequest`).
impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3VisionRunner {
    type Response = <Self as VisionLmRunner<'s, 'req>>::Response;

    fn stream_lm_response(&'s self, request: TextLmRequest<'req>) -> Self::Response {
        self.stream_vlm_response(request.into())
    }
}
297
298impl<'a> From<TextLmRequest<'a>> for VisionLmRequest<'a> {
299    fn from(value: TextLmRequest<'a>) -> Self {
300        Self {
301            messages: value
302                .messages
303                .into_iter()
304                .map(|(role, text)| (role, ImageOrText::Text(text)))
305                .collect(),
306            sampling: value.sampling,
307            llguidance: value.llguidance,
308            max_seq: value.max_seq,
309            prefill: value.prefill,
310        }
311    }
312}
313
/// Lazy generation stream returned by the runners; all work (context
/// setup, prompt evaluation, sampling) happens inside `next()`.
pub struct Gemma3Stream<'a, Message, Runner> {
    // Deferred context-creation result; taken on the first `next()` so a
    // creation error is yielded as the first stream item.
    ctx_source: Option<Result<LlamaContext<'a>, RunnerError>>,
    // The live context, once `ctx_source` resolved successfully.
    ctx: Option<LlamaContext<'a>>,
    req: RunnerRequest<Message>,
    runner: &'a Runner,
    model: &'a LlamaModel,
    // Sampler/batch/decoder state, installed by `PrepareRun::prepare`.
    runtime: Option<Runtime<'a>>,
    // Once true, `next()` always returns `None`.
    done: bool,
}
323
/// Per-request generation state created once the prompt has been evaluated.
struct Runtime<'a> {
    sampler: LlamaSampler,
    // Incremental UTF-8 decoder so multi-byte characters split across
    // tokens are emitted correctly.
    decoder: Decoder,
    batch: LlamaBatch<'a>,
    // Number of prompt tokens already evaluated in the context.
    n_past: i32,
    // Number of generation steps taken so far (bounded by `req.max_seq`).
    step: usize,
}
331
/// One-time prompt preparation: renders the chat template, evaluates the
/// prompt in the context, and installs `Runtime` on the stream.
trait PrepareRun {
    fn prepare(&mut self) -> Result<(), RunnerError>;
}
335
336impl PrepareRun for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner> {
337    fn prepare(&mut self) -> Result<(), RunnerError> {
338        // Preprocess the message, flattening media
339        let media_marker = mtmd::mtmd_default_marker();
340        let messages = self
341            .req
342            .messages
343            .iter()
344            .fold(
345                Vec::<(MessageRole, String)>::new(),
346                |mut acc, (role, message)| {
347                    let text = match message {
348                        ImageOrText::Text(text) => text,
349                        ImageOrText::Image(_) => media_marker,
350                    };
351                    if let Some(last) = acc.last()
352                        && last.0 == *role
353                    {
354                        // merge adjacent
355                        let (_, adj) = acc.remove(acc.len() - 1);
356                        acc.push((role.clone(), format!("{0}\n{text}", adj)));
357                        acc
358                    } else {
359                        acc.push((role.clone(), text.to_string()));
360                        acc
361                    }
362                },
363            )
364            .into_iter()
365            .map(|(role, content)| LlamaChatMessage::new(role.to_string(), content))
366            .collect::<Result<Vec<_>, _>>()
367            .expect("message preprocessing failed");
368        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
369
370        // Aggregate images
371        let formatted_prompt =
372            self.runner
373                .model
374                .apply_chat_template(&self.runner.chat_template, &messages, true)?;
375        let bitmaps = self
376            .req
377            .messages
378            .iter()
379            .filter_map(|msg| match &msg.1 {
380                ImageOrText::Image(image) => Some(image),
381                _ => None,
382            })
383            .enumerate()
384            .map(|(idx, im)| {
385                MtmdBitmap::from_image_data(
386                    im.width(),
387                    im.height(),
388                    im.to_rgb8().to_vec().as_slice(),
389                )
390                .expect(format!("image#{} has corrupted RGB data", idx).as_str())
391            })
392            .collect::<Vec<_>>();
393        let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
394        let chunks = self.runner.mtmd_ctx.tokenize(
395            MtmdInputText {
396                text: formatted_prompt,
397                add_special: true,
398                parse_special: true,
399            },
400            &bitmap_refs,
401        )?;
402        log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
403        let n_past = chunks.eval_chunks(
404            &self.runner.mtmd_ctx,
405            self.ctx.as_ref().unwrap(),
406            0,
407            0,
408            1,
409            true,
410        )?;
411
412        // Generate preparation
413        let mut preparation = Runtime {
414            sampler: self.req.sampling.to_llama(),
415            decoder: UTF_8.new_decoder(),
416            batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
417            n_past,
418            step: 0,
419        };
420        if let Some(llguidance) = &self.req.llguidance {
421            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
422            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
423        }
424        self.runtime = Some(preparation);
425
426        Ok(())
427    }
428}
429
430impl<S: AsRef<str>> PrepareRun for Gemma3Stream<'_, S, Gemma3TextRunner> {
431    fn prepare(&mut self) -> Result<(), RunnerError> {
432        // Preprocess the message
433        let messages = self
434            .req
435            .messages
436            .iter()
437            .fold(
438                Vec::<(MessageRole, String)>::new(),
439                |mut acc, (role, message)| {
440                    if let Some(last) = acc.last()
441                        && last.0 == *role
442                    {
443                        // merge adjacent
444                        let (_, adj) = acc.remove(acc.len() - 1);
445                        acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
446                        acc
447                    } else {
448                        acc.push((role.clone(), message.as_ref().to_string()));
449                        acc
450                    }
451                },
452            )
453            .into_iter()
454            .map(|(role, content)| LlamaChatMessage::new(role.to_string(), content))
455            .collect::<Result<Vec<_>, _>>()
456            .expect("message preprocessing failed");
457        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
458
459        // Aggregate images
460        let formatted_prompt =
461            self.runner
462                .model
463                .apply_chat_template(&self.runner.chat_template, &messages, true)?;
464        let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
465        let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
466        let token_list_len = token_list.len();
467        for (i, token) in token_list.into_iter().enumerate() {
468            batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
469        }
470        self.ctx.as_mut().unwrap().decode(&mut batch)?;
471
472        // Generate preparation
473        let mut preparation = Runtime {
474            sampler: self.req.sampling.to_llama(),
475            decoder: UTF_8.new_decoder(),
476            batch,
477            n_past: token_list_len as i32,
478            step: 0,
479        };
480        if let Some(llguidance) = &self.req.llguidance {
481            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
482            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
483        }
484        self.runtime = Some(preparation);
485
486        Ok(())
487    }
488}
489
impl<'a, M, R> Iterator for Gemma3Stream<'a, M, R>
where
    Self: PrepareRun,
{
    type Item = Result<String, RunnerError>;

    /// Yields the next decoded piece of the response.
    ///
    /// First call resolves the deferred context and runs `prepare()`; any
    /// failure is yielded once as `Err` and the stream then terminates.
    /// A pending `prefill` is consumed in a single call: its tokens are fed
    /// through the sampler/context and the whole prefill string is yielded
    /// as one item. Otherwise one token is sampled per call until an
    /// end-of-generation token or `max_seq` steps.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        // Resolve the deferred context-creation result exactly once.
        if let Some(result) = self.ctx_source.take() {
            match result {
                Ok(ctx) => self.ctx = Some(ctx),
                Err(err) => {
                    self.done = true;
                    return Some(Err(err));
                }
            }
        }

        // Lazily evaluate the prompt and build the sampler/batch state.
        if self.runtime.is_none()
            && let Err(err) = self.prepare()
        {
            self.done = true;
            return Some(Err(err));
        }
        let Runtime {
            sampler,
            decoder,
            batch,
            n_past,
            step,
        } = self.runtime.as_mut().unwrap();

        // Stop (without an error) once the step budget is exhausted.
        if *step >= self.req.max_seq {
            self.done = true;
            return None;
        }

        // Sample response token
        let ctx = self.ctx.as_mut().unwrap();
        let model = self.model;
        // Logits live at the last token of the previous batch.
        let sample_idx = batch.n_tokens() - 1;
        // Accepts `token` into the sampler, decodes it at position
        // `n_past + step`, and returns its text piece — or `Ok(None)` on an
        // end-of-generation token.
        let mut sample = |token: LlamaToken,
                          sampler: &mut LlamaSampler,
                          ctx: &mut LlamaContext<'a>,
                          step: usize|
         -> Result<Option<String>, RunnerError> {
            sampler.accept(token);
            if model.is_eog_token(token) {
                return Ok(None);
            }
            batch.clear();
            batch.add(token, *n_past + (step as i32), &[0], true)?;

            ctx.decode(batch)?;

            let piece = model.token_to_piece(token, decoder, true, None)?;
            Ok(Some(piece))
        };
        if let Some(prefill) = self.req.prefill.take() {
            log::debug!(target: "gemma", "prefill: {}", prefill);
            let tokens = match model.str_to_token(&prefill, AddBos::Never) {
                Ok(tokens) => tokens,
                Err(err) => {
                    return Some(Err(err.into()));
                }
            };
            log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
            for token in tokens {
                // NOTE(review): `Ok(_)` also covers `Ok(None)` — an EOG
                // token inside the prefill is silently ignored here;
                // confirm that is intended.
                match sample(token, sampler, ctx, *step) {
                    Ok(_) => {}
                    Err(err) => return Some(Err(err.into())),
                }
                *step += 1;
            }
            // The entire prefill is yielded as one stream item.
            Some(Ok(prefill))
        } else {
            let token = sampler.sample(ctx, sample_idx);
            match sample(token, sampler, ctx, *step) {
                Ok(Some(piece)) => {
                    *step += 1;
                    return Some(Ok(piece));
                }
                Ok(None) => {
                    // End-of-generation token: finish cleanly.
                    self.done = true;
                    return None;
                }
                Err(err) => {
                    self.done = true;
                    return Some(Err(err));
                }
            }
        }
    }
}
587
588impl<'s, M, R> Gemma3Stream<'s, M, R> {
589    fn new(
590        source: Result<LlamaContext<'s>, RunnerError>,
591        req: RunnerRequest<M>,
592        runner: &'s R,
593        model: &'s LlamaModel,
594    ) -> Self {
595        Self {
596            ctx_source: Some(source),
597            ctx: None,
598            req,
599            runner,
600            model,
601            runtime: None,
602            done: false,
603        }
604    }
605}
606
/// Wraps a runner with model-recommended sampling defaults that fill in
/// any sampling fields the caller left unset.
pub struct RunnerWithRecommendedSampling<Inner> {
    pub inner: Inner,
    /// Defaults applied to requests whose `top_k`/`top_p`/`temperature`
    /// are `None`.
    pub default_sampling: SimpleSamplingParams,
}
611
612impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
613    fn get_preprocessed_simple_sampling(
614        &self,
615        sampling: SimpleSamplingParams,
616    ) -> SimpleSamplingParams {
617        let mut sampling = sampling;
618        if sampling.top_k.is_none() {
619            sampling.top_k = self.default_sampling.top_k;
620        }
621        if sampling.top_p.is_none() {
622            sampling.top_p = self.default_sampling.top_p;
623        }
624        if sampling.temperature.is_none() {
625            sampling.temperature = self.default_sampling.temperature;
626        }
627        sampling
628    }
629}
630
// Delegates streaming to the inner vision runner after topping up the
// request's sampling parameters with the recommended defaults.
impl<'s, 'req, Inner> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
where
    Inner: VisionLmRunner<'s, 'req>,
{
    type Response = <Inner as VisionLmRunner<'s, 'req>>::Response;

    fn stream_vlm_response(&'s self, mut request: VisionLmRequest<'req>) -> Self::Response {
        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
        self.inner.stream_vlm_response(request)
    }
}
642
// Delegates streaming to the inner text runner after topping up the
// request's sampling parameters with the recommended defaults.
impl<'s, 'req, Inner> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
where
    Inner: TextLmRunner<'s, 'req>,
{
    type Response = <Inner as TextLmRunner<'s, 'req>>::Response;

    fn stream_lm_response(&'s self, mut request: TextLmRequest<'req>) -> Self::Response {
        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
        self.inner.stream_lm_response(request)
    }
}
654
655impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
656    fn from(value: Inner) -> Self {
657        Self {
658            inner: value,
659            default_sampling: SimpleSamplingParams::default(),
660        }
661    }
662}
663
664fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
665    let mut api = ApiBuilder::new()
666        .with_progress(std::io::stdin().is_terminal())
667        .with_token(std::env::var("HF_TOKEN").ok())
668        .with_chunk_size(Some(2 << 28));
669    if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
670        api = api.with_endpoint(endpoint);
671    }
672    if let Ok(cache) = std::env::var("HF_HOME") {
673        api = api.with_cache_dir(
674            PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
675        );
676    }
677    api.build()
678}