Skip to main content

llama_runner/
runner.rs

1use std::{
2    io::IsTerminal,
3    marker::PhantomData,
4    num::NonZeroU32,
5    path::{Path, PathBuf},
6    str::FromStr,
7    sync::LazyLock,
8};
9
10use encoding_rs::{Decoder, UTF_8};
11use hf_hub::api::tokio::ApiBuilder;
12use llama_cpp_2::{
13    LlamaContextLoadError,
14    context::{LlamaContext, params::LlamaContextParams},
15    llama_backend::LlamaBackend,
16    llama_batch::LlamaBatch,
17    model::{AddBos, LlamaChatTemplate, LlamaModel},
18    mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
19    sampling::LlamaSampler,
20    token::LlamaToken,
21};
22use strum::Display;
23
24use crate::{
25    error::{CreateLlamaCppRunnerError, GenericRunnerError},
26    sample::{LlguidanceSamplingParams, SimpleSamplingParams},
27    template::{ChatTemplate, ModelChatTemplate},
28};
29
// NOTE(review): `GUFF` / `MODDEL` / `MULTIMODEL` are misspellings of `GGUF` / `MODEL` /
// `MULTIMODAL`; they are public names, so renaming them would break callers.

/// Hugging Face repo id of the Qwen3.5-4B GGUF build (used by `Gemma3VisionRunner::default`).
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Language-model weights file inside [`QWEN_3D5_4B_GUFF_MODEL_ID`] (`Q4_K_M` per its name).
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector file inside the same repo, loaded with `MtmdContext::init_from_file`.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Hugging Face repo id of the Gemma 3 1B GGUF build (used by `Gemma3TextRunner::default`).
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// Model weights file inside [`GEMMA_3_1B_GUFF_MODEL_ID`].
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
36
37pub trait TextLmRunner<'s, 'req> {
38    type Response: Iterator<
39        Item = Result<String, GenericRunnerError<<Self::Template as ChatTemplate>::Error>>,
40    >;
41    type Template: ChatTemplate;
42    fn stream_lm_response(
43        &'s self,
44        request: GenericTextLmRequest<'req, Self::Template>,
45    ) -> Self::Response;
46}
47
48pub trait VisionLmRunner<'s, 'req> {
49    type Response: Iterator<
50        Item = Result<String, GenericRunnerError<<Self::Template as ChatTemplate>::Error>>,
51    >;
52    type Template: ChatTemplate;
53    fn stream_vlm_response(
54        &'s self,
55        request: GenericVisionLmRequest<'req, Self::Template>,
56    ) -> Self::Response;
57}
58
/// A chat-completion request, generic over the message content type (`MsgCt`) and the
/// chat template (`Tmpl`).
#[derive(Debug, Clone)]
pub struct GenericRunnerRequest<MsgCt, Tmpl> {
    /// Conversation turns in order. The runners in this module merge consecutive turns
    /// that share a role (joined with `\n`) before applying the template.
    pub messages: Vec<(MessageRole, MsgCt)>,
    /// Basic sampling parameters, converted with `to_llama()` when generation starts.
    pub sampling: SimpleSamplingParams,
    /// Optional llguidance constraint; when set it is chained in front of the regular sampler.
    pub llguidance: Option<LlguidanceSamplingParams>,
    /// Maximum number of generation steps; prefill tokens count towards this budget.
    pub max_seq: usize,
    /// Text forced as the start of the response: it is fed to the model after the prompt
    /// and yielded verbatim as the stream's first item.
    pub prefill: Option<String>,
    /// Chat template used to render `messages` into the prompt.
    pub tmpl: Tmpl,
}
68
69impl<M, T> Default for GenericRunnerRequest<M, T>
70where
71    T: Default,
72{
73    fn default() -> Self {
74        Self {
75            messages: vec![],
76            sampling: Default::default(),
77            llguidance: None,
78            max_seq: usize::MAX,
79            prefill: None,
80            tmpl: Default::default(),
81        }
82    }
83}
84
/// Request whose messages are borrowed plain-text strings, with a caller-chosen template.
pub type GenericTextLmRequest<'a, Tmpl> = GenericRunnerRequest<&'a str, Tmpl>;
/// Request whose messages are text or images, with a caller-chosen template.
pub type GenericVisionLmRequest<'a, Tmpl> = GenericRunnerRequest<ImageOrText<'a>, Tmpl>;

/// Request that uses the built-in [`ModelChatTemplate`].
///
/// NOTE(review): `'a` is not used by this alias itself; presumably it is kept so the
/// aliases below share a uniform lifetime parameter.
pub type RunnerRequest<'a, MsgCnt> = GenericRunnerRequest<MsgCnt, ModelChatTemplate>;
/// Text-only request using [`ModelChatTemplate`].
pub type TextLmRequest<'a> = RunnerRequest<'a, &'a str>;
/// Text-and-image request using [`ModelChatTemplate`].
pub type VisionLmRequest<'a> = RunnerRequest<'a, ImageOrText<'a>>;
91
92pub trait TextLmRunnerExt<'s, 'req, Tmpl, TmplErr> {
93    fn get_lm_response(
94        &'s self,
95        request: GenericTextLmRequest<'req, Tmpl>,
96    ) -> Result<String, GenericRunnerError<TmplErr>>;
97}
98
99pub trait VisionLmRunnerExt<'s, 'req, Tmpl, TmplErr> {
100    fn get_vlm_response(
101        &'s self,
102        request: GenericVisionLmRequest<'req, Tmpl>,
103    ) -> Result<String, GenericRunnerError<TmplErr>>;
104}
105
106impl<'s, 'req, TextRunner, Tmpl> TextLmRunnerExt<'s, 'req, Tmpl, Tmpl::Error> for TextRunner
107where
108    Tmpl: ChatTemplate,
109    TextRunner: TextLmRunner<'s, 'req, Template = Tmpl>,
110{
111    fn get_lm_response(
112        &'s self,
113        request: GenericTextLmRequest<'req, Tmpl>,
114    ) -> Result<String, GenericRunnerError<Tmpl::Error>> {
115        self.stream_lm_response(request)
116            .collect::<Result<String, _>>()
117    }
118}
119
120impl<'s, 'req, VisionRunner, Tmpl> VisionLmRunnerExt<'s, 'req, Tmpl, Tmpl::Error> for VisionRunner
121where
122    Tmpl: ChatTemplate,
123    VisionRunner: VisionLmRunner<'s, 'req, Template = Tmpl>,
124{
125    fn get_vlm_response(
126        &'s self,
127        request: GenericVisionLmRequest<'req, Tmpl>,
128    ) -> Result<String, GenericRunnerError<Tmpl::Error>> {
129        self.stream_vlm_response(request)
130            .collect::<Result<String, _>>()
131    }
132}
133
/// Speaker of a chat message.
///
/// `Display` (via strum) renders the lowercase role name, or the wrapped string for `Custom`.
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum MessageRole {
    #[strum(to_string = "assistant")]
    Assistant,
    #[strum(to_string = "user")]
    User,
    #[strum(to_string = "system")]
    System,
    /// Any other role name, displayed verbatim.
    #[strum(to_string = "{0}")]
    Custom(&'static str),
}

/// One piece of message content: borrowed text or a borrowed image.
///
/// In vision prompts, each image is replaced in the text by the mtmd media marker and
/// supplied separately as a bitmap.
#[derive(Debug, Clone)]
pub enum ImageOrText<'a> {
    Text(&'a str),
    Image(&'a image::DynamicImage),
}
151
/// Text-only runner backed by a llama.cpp model.
///
/// NOTE(review): nothing in this type is Gemma-specific apart from the model chosen by
/// `default()`; `from_file` accepts any model file llama.cpp can load.
pub struct Gemma3TextRunner<Tmpl> {
    /// Loaded model; a fresh context is created from it for every request.
    model: LlamaModel,
    /// The model's embedded chat template (`chat_template(None)`), passed to `apply_template`.
    llama_template: LlamaChatTemplate,
    /// Context length (`n_ctx`) for each request; also the capacity of the prompt batch.
    ctx_size: NonZeroU32,
    /// Marks the chat template type; no template value is stored.
    _tmpl: PhantomData<Tmpl>,
}
158
159impl<Tmpl> Gemma3TextRunner<Tmpl> {
160    pub async fn new(
161        model_id: impl ToString,
162        model_file: impl AsRef<str>,
163        ctx_size: NonZeroU32,
164    ) -> Result<Self, CreateLlamaCppRunnerError> {
165        let repo = build_hf_api()?.model(model_id.to_string());
166        Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
167    }
168
169    pub fn recommend_sampling() -> SimpleSamplingParams {
170        SimpleSamplingParams {
171            top_p: Some(0.95f32),
172            top_k: Some(64),
173            temperature: Some(1f32),
174            ..Default::default()
175        }
176    }
177
178    pub fn from_file(
179        model_file: impl AsRef<Path>,
180        ctx_size: NonZeroU32,
181    ) -> Result<Self, CreateLlamaCppRunnerError> {
182        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
183
184        let chat_template = model.chat_template(None)?;
185        Ok(Self {
186            model,
187            llama_template: chat_template,
188            ctx_size,
189            _tmpl: PhantomData,
190        })
191    }
192}
193
194impl Gemma3TextRunner<ModelChatTemplate> {
195    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
196    {
197        let inner = Self::new(
198            GEMMA_3_1B_GUFF_MODEL_ID,
199            GEMMA_3_1B_GUFF_MODEL_FILENAME,
200            32_000.try_into().unwrap(),
201        )
202        .await?;
203        Ok(RunnerWithRecommendedSampling {
204            inner,
205            default_sampling: Self::recommend_sampling(),
206        })
207    }
208}
209
210impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req> for Gemma3TextRunner<Tmpl>
211where
212    Tmpl: ChatTemplate + 's,
213{
214    type Response = Gemma3Stream<'s, &'req str, Self, Tmpl>;
215    type Template = Tmpl;
216
217    fn stream_lm_response(&'s self, request: GenericTextLmRequest<'req, Tmpl>) -> Self::Response {
218        let ctx = self
219            .model
220            .new_context(
221                &LLAMA_BACKEND,
222                LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
223            )
224            .map_err(|err| GenericRunnerError::from(err));
225        Gemma3Stream::new(ctx, request, self, &self.model)
226    }
227}
228
/// Text-and-image runner backed by a llama.cpp model plus an mtmd multimodal projector.
pub struct Gemma3VisionRunner<Tmpl> {
    /// Loaded language model; a fresh context is created from it for every request.
    model: LlamaModel,
    /// The model's embedded chat template (`chat_template(None)`), passed to `apply_template`.
    chat_template: LlamaChatTemplate,
    /// Multimodal context used to tokenize and evaluate prompts containing images.
    // NOTE(review): this is created from `&model`, but fields drop in declaration order, so
    // `model` is dropped before `mtmd_ctx`; confirm the mtmd context does not use the model
    // while being freed.
    mtmd_ctx: MtmdContext,
    /// Context length (`n_ctx`) for each request; also the capacity of the generation batch.
    ctx_size: NonZeroU32,
    /// Marks the chat template type; no template value is stored.
    _tmpl: PhantomData<Tmpl>,
}
236
237static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
238    llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default());
239    LlamaBackend::init().unwrap()
240});
241
242impl<Tmpl> Gemma3VisionRunner<Tmpl> {
243    pub async fn new(
244        repo_id: impl ToString,
245        model_file: impl AsRef<str>,
246        multimodel_file: impl AsRef<str>,
247        ctx_size: NonZeroU32,
248    ) -> Result<Self, CreateLlamaCppRunnerError> {
249        let repo = build_hf_api()?.model(repo_id.to_string());
250        let model = LlamaModel::load_from_file(
251            &LLAMA_BACKEND,
252            repo.get(model_file.as_ref()).await?,
253            &Default::default(),
254        )?;
255
256        let mtmd_ctx = MtmdContext::init_from_file(
257            repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
258            &model,
259            &Default::default(),
260        )?;
261
262        let chat_template = model.chat_template(None)?;
263
264        Ok(Self {
265            model,
266            mtmd_ctx,
267            chat_template,
268            ctx_size,
269            _tmpl: PhantomData,
270        })
271    }
272
273    pub fn from_files(
274        model_file: impl AsRef<Path>,
275        multimodel_file: impl AsRef<Path>,
276        ctx_size: NonZeroU32,
277    ) -> Result<Self, CreateLlamaCppRunnerError> {
278        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
279        let mtmd_ctx = MtmdContext::init_from_file(
280            multimodel_file.as_ref().as_os_str().to_str().unwrap(),
281            &model,
282            &Default::default(),
283        )?;
284
285        let chat_template = model.chat_template(None)?;
286
287        Ok(Self {
288            model,
289            mtmd_ctx,
290            chat_template,
291            ctx_size,
292            _tmpl: PhantomData,
293        })
294    }
295
296    fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
297        self.model.new_context(
298            &LLAMA_BACKEND,
299            LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
300        )
301    }
302}
303
304impl Gemma3VisionRunner<ModelChatTemplate> {
305    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
306    {
307        let inner = Self::new(
308            QWEN_3D5_4B_GUFF_MODEL_ID,
309            QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
310            QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
311            16384u32.try_into().unwrap(),
312        )
313        .await?;
314        Ok(RunnerWithRecommendedSampling {
315            inner: inner,
316            default_sampling: SimpleSamplingParams {
317                top_p: Some(0.8f32),
318                top_k: Some(20),
319                temperature: Some(0.7f32),
320                presence_penalty: Some(1.5),
321                repetition_penalty: Some(1.0),
322                seed: None,
323            },
324        })
325    }
326}
327
328impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req> for Gemma3VisionRunner<Tmpl>
329where
330    Tmpl: ChatTemplate + 's,
331{
332    type Response = Gemma3Stream<'s, ImageOrText<'req>, Self, Tmpl>;
333    type Template = Tmpl;
334
335    fn stream_vlm_response(
336        &'s self,
337        request: GenericVisionLmRequest<'req, Tmpl>,
338    ) -> Self::Response {
339        let ctx = self
340            .new_context_window()
341            .map_err(|err| GenericRunnerError::from(err));
342        Gemma3Stream::new(ctx, request, self, &self.model)
343    }
344}
345
346impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req> for Gemma3VisionRunner<Tmpl>
347where
348    Tmpl: ChatTemplate + 's,
349{
350    type Response = <Self as VisionLmRunner<'s, 'req>>::Response;
351    type Template = <Self as VisionLmRunner<'s, 'req>>::Template;
352
353    fn stream_lm_response(&'s self, request: GenericTextLmRequest<'req, Tmpl>) -> Self::Response {
354        self.stream_vlm_response(request.into())
355    }
356}
357
358impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
359    fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
360        Self {
361            messages: value
362                .messages
363                .into_iter()
364                .map(|(role, text)| (role, ImageOrText::Text(text)))
365                .collect(),
366            sampling: value.sampling,
367            llguidance: value.llguidance,
368            max_seq: value.max_seq,
369            prefill: value.prefill,
370            tmpl: value.tmpl,
371        }
372    }
373}
374
/// Lazily evaluated generation stream returned by the runners in this module.
///
/// Nothing is evaluated until the first `next()`. At that point the context is resolved and
/// the prompt is prepared; afterwards each call yields one decoded piece.
pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
    /// Deferred context-creation result; taken on the first `next()`, which yields any error.
    ctx_source: Option<Result<LlamaContext<'a>, GenericRunnerError<Tmpl::Error>>>,
    /// The resolved context, set once `ctx_source` has been taken successfully.
    ctx: Option<LlamaContext<'a>>,
    /// The request being served; its `prefill` is taken the first time it is used.
    req: GenericRunnerRequest<Message, Tmpl>,
    /// The runner that created this stream (template, mtmd context, `ctx_size`).
    runner: &'a Runner,
    /// Model used for tokenization, end-of-generation checks and detokenization.
    model: &'a LlamaModel,
    /// Generation state; `None` until `prepare` has succeeded.
    runtime: Option<Runtime<'a>>,
    /// Set once the stream has ended or yielded an error; later calls return `None`.
    done: bool,
}
384
/// Per-stream generation state created by `PrepareRun::prepare`.
struct Runtime<'a> {
    /// Sampler chain: the llguidance sampler first (when requested), then the regular sampler.
    sampler: LlamaSampler,
    /// UTF-8 decoder passed to `token_to_piece` for every token.
    // NOTE(review): presumably it carries incomplete multi-byte sequences across tokens.
    decoder: Decoder,
    /// Batch reused for every token fed back into the model.
    batch: LlamaBatch<'a>,
    /// Number of positions occupied by the prompt; token `step` is placed at `n_past + step`.
    n_past: i32,
    /// Number of tokens fed back so far (prefill tokens plus sampled non-EOG tokens).
    step: usize,
}
392
393trait PrepareRun<TmplErr> {
394    fn prepare(&mut self) -> Result<(), GenericRunnerError<TmplErr>>;
395}
396
397impl<Tmpl> PrepareRun<Tmpl::Error>
398    for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner<Tmpl>, Tmpl>
399where
400    Tmpl: ChatTemplate,
401{
402    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
403        // Preprocess the message, flattening media
404        let media_marker = mtmd::mtmd_default_marker();
405        let messages = self
406            .req
407            .messages
408            .iter()
409            .fold(
410                Vec::<(MessageRole, String)>::new(),
411                |mut acc, (role, message)| {
412                    let text = match message {
413                        ImageOrText::Text(text) => text,
414                        ImageOrText::Image(_) => media_marker,
415                    };
416                    if let Some(last) = acc.last()
417                        && last.0 == *role
418                    {
419                        // merge adjacent
420                        let (_, adj) = acc.remove(acc.len() - 1);
421                        acc.push((role.clone(), format!("{0}\n{text}", adj)));
422                        acc
423                    } else {
424                        acc.push((role.clone(), text.to_string()));
425                        acc
426                    }
427                },
428            )
429            .into_iter()
430            .collect::<Vec<_>>();
431        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
432
433        // apply custom template
434        let formatted_prompt = self
435            .req
436            .tmpl
437            .apply_template(self.model, &self.runner.chat_template, &messages)
438            .map_err(GenericRunnerError::ApplyChatTemplate)?;
439
440        // Aggregate images
441        let bitmaps = self
442            .req
443            .messages
444            .iter()
445            .filter_map(|msg| match &msg.1 {
446                ImageOrText::Image(image) => Some(image),
447                _ => None,
448            })
449            .enumerate()
450            .map(|(idx, im)| {
451                MtmdBitmap::from_image_data(
452                    im.width(),
453                    im.height(),
454                    im.to_rgb8().to_vec().as_slice(),
455                )
456                .expect(format!("image#{} has corrupted RGB data", idx).as_str())
457            })
458            .collect::<Vec<_>>();
459        let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
460        let chunks = self.runner.mtmd_ctx.tokenize(
461            MtmdInputText {
462                text: formatted_prompt,
463                add_special: true,
464                parse_special: true,
465            },
466            &bitmap_refs,
467        )?;
468        log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
469        let n_past = chunks.eval_chunks(
470            &self.runner.mtmd_ctx,
471            self.ctx.as_ref().unwrap(),
472            0,
473            0,
474            1,
475            true,
476        )?;
477
478        // Generate preparation
479        let mut preparation = Runtime {
480            sampler: self.req.sampling.to_llama(),
481            decoder: UTF_8.new_decoder(),
482            batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
483            n_past,
484            step: 0,
485        };
486        if let Some(llguidance) = &self.req.llguidance {
487            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
488            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
489        }
490        self.runtime = Some(preparation);
491
492        Ok(())
493    }
494}
495
496impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error>
497    for Gemma3Stream<'_, S, Gemma3TextRunner<Tmpl>, Tmpl>
498where
499    Tmpl: ChatTemplate,
500{
501    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
502        // Preprocess the message
503        let messages = self
504            .req
505            .messages
506            .iter()
507            .fold(
508                Vec::<(MessageRole, String)>::new(),
509                |mut acc, (role, message)| {
510                    if let Some(last) = acc.last()
511                        && last.0 == *role
512                    {
513                        // merge adjacent
514                        let (_, adj) = acc.remove(acc.len() - 1);
515                        acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
516                        acc
517                    } else {
518                        acc.push((role.clone(), message.as_ref().to_string()));
519                        acc
520                    }
521                },
522            )
523            .into_iter()
524            .collect::<Vec<_>>();
525        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
526
527        // apply custom template
528        let formatted_prompt = self
529            .req
530            .tmpl
531            .apply_template(self.model, &self.runner.llama_template, &messages)
532            .map_err(GenericRunnerError::ApplyChatTemplate)?;
533
534        // Aggregate images
535        let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
536        let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
537        let token_list_len = token_list.len();
538        for (i, token) in token_list.into_iter().enumerate() {
539            batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
540        }
541        self.ctx.as_mut().unwrap().decode(&mut batch)?;
542
543        // Generate preparation
544        let mut preparation = Runtime {
545            sampler: self.req.sampling.to_llama(),
546            decoder: UTF_8.new_decoder(),
547            batch,
548            n_past: token_list_len as i32,
549            step: 0,
550        };
551        if let Some(llguidance) = &self.req.llguidance {
552            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
553            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
554        }
555        self.runtime = Some(preparation);
556
557        Ok(())
558    }
559}
560
561impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
562where
563    Tmpl: ChatTemplate,
564    Self: PrepareRun<Tmpl::Error>,
565{
566    type Item = Result<String, GenericRunnerError<Tmpl::Error>>;
567
568    fn next(&mut self) -> Option<Self::Item> {
569        if self.done {
570            return None;
571        }
572
573        if let Some(result) = self.ctx_source.take() {
574            match result {
575                Ok(ctx) => self.ctx = Some(ctx),
576                Err(err) => {
577                    self.done = true;
578                    return Some(Err(err));
579                }
580            }
581        }
582
583        if self.runtime.is_none()
584            && let Err(err) = self.prepare()
585        {
586            self.done = true;
587            return Some(Err(err));
588        }
589        let Runtime {
590            sampler,
591            decoder,
592            batch,
593            n_past,
594            step,
595        } = self.runtime.as_mut().unwrap();
596
597        if *step >= self.req.max_seq {
598            self.done = true;
599            return None;
600        }
601
602        // Sample response token
603        let ctx = self.ctx.as_mut().unwrap();
604        let model = self.model;
605        let sample_idx = batch.n_tokens() - 1;
606        let mut sample = |token: LlamaToken,
607                          sampler: &mut LlamaSampler,
608                          ctx: &mut LlamaContext<'a>,
609                          step: usize|
610         -> Result<Option<String>, GenericRunnerError<Tmpl::Error>> {
611            sampler.accept(token);
612            if model.is_eog_token(token) {
613                return Ok(None);
614            }
615            batch.clear();
616            batch.add(token, *n_past + (step as i32), &[0], true)?;
617
618            ctx.decode(batch)?;
619
620            let piece = model.token_to_piece(token, decoder, true, None)?;
621            Ok(Some(piece))
622        };
623        if let Some(prefill) = self.req.prefill.take() {
624            log::debug!(target: "gemma", "prefill: {}", prefill);
625            let tokens = match model.str_to_token(&prefill, AddBos::Never) {
626                Ok(tokens) => tokens,
627                Err(err) => {
628                    return Some(Err(err.into()));
629                }
630            };
631            log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
632            for token in tokens {
633                match sample(token, sampler, ctx, *step) {
634                    Ok(_) => {}
635                    Err(err) => return Some(Err(err.into())),
636                }
637                *step += 1;
638            }
639            Some(Ok(prefill))
640        } else {
641            let token = sampler.sample(ctx, sample_idx);
642            match sample(token, sampler, ctx, *step) {
643                Ok(Some(piece)) => {
644                    *step += 1;
645                    return Some(Ok(piece));
646                }
647                Ok(None) => {
648                    self.done = true;
649                    return None;
650                }
651                Err(err) => {
652                    self.done = true;
653                    return Some(Err(err));
654                }
655            }
656        }
657    }
658}
659
660impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
661where
662    Tmpl: ChatTemplate,
663{
664    fn new(
665        source: Result<LlamaContext<'s>, GenericRunnerError<Tmpl::Error>>,
666        req: GenericRunnerRequest<Message, Tmpl>,
667        runner: &'s Runner,
668        model: &'s LlamaModel,
669    ) -> Self {
670        Self {
671            ctx_source: Some(source),
672            ctx: None,
673            req,
674            runner,
675            model,
676            runtime: None,
677            done: false,
678        }
679    }
680}
681
/// Wraps a runner and supplies default sampling parameters for its requests.
///
/// Sampling fields a request leaves as `None` are filled from `default_sampling` before the
/// request reaches `inner` (see `get_preprocessed_simple_sampling` for which fields).
pub struct RunnerWithRecommendedSampling<Inner> {
    /// The wrapped runner that performs generation.
    pub inner: Inner,
    /// Fallback sampling values for fields the request leaves unset.
    pub default_sampling: SimpleSamplingParams,
}
686
687impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
688    fn get_preprocessed_simple_sampling(
689        &self,
690        sampling: SimpleSamplingParams,
691    ) -> SimpleSamplingParams {
692        let mut sampling = sampling;
693        if sampling.top_k.is_none() {
694            sampling.top_k = self.default_sampling.top_k;
695        }
696        if sampling.top_p.is_none() {
697            sampling.top_p = self.default_sampling.top_p;
698        }
699        if sampling.temperature.is_none() {
700            sampling.temperature = self.default_sampling.temperature;
701        }
702        sampling
703    }
704}
705
706impl<'s, 'req, Inner, Tmpl> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
707where
708    Tmpl: ChatTemplate,
709    Inner: VisionLmRunner<'s, 'req, Template = Tmpl>,
710{
711    type Response = <Inner as VisionLmRunner<'s, 'req>>::Response;
712    type Template = Tmpl;
713
714    fn stream_vlm_response(
715        &'s self,
716        mut request: GenericVisionLmRequest<'req, Tmpl>,
717    ) -> Self::Response {
718        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
719        self.inner.stream_vlm_response(request)
720    }
721}
722
723impl<'s, 'req, Inner, Tmpl> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
724where
725    Tmpl: ChatTemplate,
726    Inner: TextLmRunner<'s, 'req, Template = Tmpl>,
727{
728    type Response = <Inner as TextLmRunner<'s, 'req>>::Response;
729    type Template = Tmpl;
730
731    fn stream_lm_response(
732        &'s self,
733        mut request: GenericTextLmRequest<'req, Tmpl>,
734    ) -> Self::Response {
735        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
736        self.inner.stream_lm_response(request)
737    }
738}
739
740impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
741    fn from(value: Inner) -> Self {
742        Self {
743            inner: value,
744            default_sampling: SimpleSamplingParams::default(),
745        }
746    }
747}
748
749fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
750    let mut api = ApiBuilder::new()
751        .with_progress(std::io::stdin().is_terminal())
752        .with_token(std::env::var("HF_TOKEN").ok())
753        .with_chunk_size(Some(2 << 28));
754    if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
755        api = api.with_endpoint(endpoint);
756    }
757    if let Ok(cache) = std::env::var("HF_HOME") {
758        api = api.with_cache_dir(
759            PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
760        );
761    }
762    api.build()
763}