Skip to main content

llama_runner/
runner.rs

1use std::{
2    io::IsTerminal,
3    num::NonZeroU32,
4    path::{Path, PathBuf},
5    str::FromStr,
6    sync::LazyLock,
7};
8
9use encoding_rs::{Decoder, UTF_8};
10use hf_hub::api::tokio::ApiBuilder;
11use llama_cpp_2::{
12    LlamaContextLoadError,
13    context::{LlamaContext, params::LlamaContextParams},
14    llama_backend::LlamaBackend,
15    llama_batch::LlamaBatch,
16    model::{AddBos, LlamaChatTemplate, LlamaModel},
17    mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
18    sampling::LlamaSampler,
19    token::LlamaToken,
20};
21use strum::Display;
22
23use crate::{
24    error::{CreateLlamaCppRunnerError, GenericRunnerError},
25    sample::{LlguidanceSamplingParams, SimpleSamplingParams},
26    template::{ChatTemplate, ModelChatTemplate},
27};
28
/// Hugging Face repository holding the default vision model (Qwen3.5 4B in GGUF format).
///
/// The `GUFF` / `MODDEL` / `MULTIMODEL` spellings in these constant names are part of the
/// public API and are therefore kept unchanged.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Q4_K_M-quantised language-model weights inside [`QWEN_3D5_4B_GUFF_MODEL_ID`].
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// F16 multimodal projector file inside [`QWEN_3D5_4B_GUFF_MODEL_ID`].
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Hugging Face repository holding the default text model (Gemma 3 1B instruction-tuned, QAT Q4_0).
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// GGUF weights file inside [`GEMMA_3_1B_GUFF_MODEL_ID`].
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
35
36pub trait TextLmRunner<'s, 'req> {
37    fn stream_lm_response<Tmpl>(
38        &'s self,
39        request: GenericTextLmRequest<'req, Tmpl>,
40    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
41    where
42        Tmpl: ChatTemplate;
43}
44
45pub trait VisionLmRunner<'s, 'req> {
46    fn stream_vlm_response<Tmpl>(
47        &'s self,
48        request: GenericVisionLmRequest<'req, Tmpl>,
49    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
50    where
51        Tmpl: ChatTemplate;
52}
53
54#[derive(Debug, Clone)]
55pub struct GenericRunnerRequest<MsgCt, Tmpl> {
56    pub messages: Vec<(MessageRole, MsgCt)>,
57    pub sampling: SimpleSamplingParams,
58    pub llguidance: Option<LlguidanceSamplingParams>,
59    pub max_seq: usize,
60    pub prefill: Option<String>,
61    pub tmpl: Tmpl,
62}
63
64impl<M, T> Default for GenericRunnerRequest<M, T>
65where
66    T: Default,
67{
68    fn default() -> Self {
69        Self {
70            messages: vec![],
71            sampling: Default::default(),
72            llguidance: None,
73            max_seq: usize::MAX,
74            prefill: None,
75            tmpl: Default::default(),
76        }
77    }
78}
79
80pub type GenericTextLmRequest<'a, Tmpl> = GenericRunnerRequest<&'a str, Tmpl>;
81pub type GenericVisionLmRequest<'a, Tmpl> = GenericRunnerRequest<ImageOrText<'a>, Tmpl>;
82
83pub type RunnerRequest<'a, MsgCnt> = GenericRunnerRequest<MsgCnt, ModelChatTemplate>;
84pub type TextLmRequest<'a> = RunnerRequest<'a, &'a str>;
85pub type VisionLmRequest<'a> = RunnerRequest<'a, ImageOrText<'a>>;
86
87pub trait TextLmRunnerExt<'s, 'req> {
88    fn get_lm_response<Tmpl>(
89        &'s self,
90        request: GenericTextLmRequest<'req, Tmpl>,
91    ) -> Result<String, GenericRunnerError<Tmpl::Error>>
92    where
93        Tmpl: ChatTemplate;
94}
95
96pub trait VisionLmRunnerExt<'s, 'req> {
97    fn get_vlm_response<Tmpl>(
98        &'s self,
99        request: GenericVisionLmRequest<'req, Tmpl>,
100    ) -> Result<String, GenericRunnerError<Tmpl::Error>>
101    where
102        Tmpl: ChatTemplate;
103}
104
105impl<'s, 'req, TextRunner> TextLmRunnerExt<'s, 'req> for TextRunner
106where
107    TextRunner: TextLmRunner<'s, 'req>,
108{
109    fn get_lm_response<Tmpl>(
110        &'s self,
111        request: GenericTextLmRequest<'req, Tmpl>,
112    ) -> Result<String, GenericRunnerError<Tmpl::Error>>
113    where
114        Tmpl: ChatTemplate,
115    {
116        self.stream_lm_response(request)
117            .collect::<Result<String, _>>()
118    }
119}
120
121impl<'s, 'req, VisionRunner> VisionLmRunnerExt<'s, 'req> for VisionRunner
122where
123    VisionRunner: VisionLmRunner<'s, 'req>,
124{
125    fn get_vlm_response<Tmpl>(
126        &'s self,
127        request: GenericVisionLmRequest<'req, Tmpl>,
128    ) -> Result<String, GenericRunnerError<Tmpl::Error>>
129    where
130        Tmpl: ChatTemplate,
131    {
132        self.stream_vlm_response(request)
133            .collect::<Result<String, _>>()
134    }
135}
136
137#[derive(Debug, Clone, Display, PartialEq, Eq)]
138pub enum MessageRole {
139    #[strum(to_string = "assistant")]
140    Assistant,
141    #[strum(to_string = "user")]
142    User,
143    #[strum(to_string = "system")]
144    System,
145    #[strum(to_string = "{0}")]
146    Custom(&'static str),
147}
148
149#[derive(Debug, Clone)]
150pub enum ImageOrText<'a> {
151    Text(&'a str),
152    Image(&'a image::DynamicImage),
153}
154
155pub struct Gemma3TextRunner {
156    model: LlamaModel,
157    llama_template: LlamaChatTemplate,
158    ctx_size: NonZeroU32,
159}
160
161impl Gemma3TextRunner {
162    pub async fn new(
163        model_id: impl ToString,
164        model_file: impl AsRef<str>,
165        ctx_size: NonZeroU32,
166    ) -> Result<Self, CreateLlamaCppRunnerError> {
167        let repo = build_hf_api()?.model(model_id.to_string());
168        Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
169    }
170
171    pub fn recommend_sampling() -> SimpleSamplingParams {
172        SimpleSamplingParams {
173            top_p: Some(0.95f32),
174            top_k: Some(64),
175            temperature: Some(1f32),
176            ..Default::default()
177        }
178    }
179
180    pub fn from_file(
181        model_file: impl AsRef<Path>,
182        ctx_size: NonZeroU32,
183    ) -> Result<Self, CreateLlamaCppRunnerError> {
184        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
185
186        let chat_template = model.chat_template(None)?;
187        Ok(Self {
188            model,
189            llama_template: chat_template,
190            ctx_size,
191        })
192    }
193
194    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
195    {
196        let inner = Self::new(
197            GEMMA_3_1B_GUFF_MODEL_ID,
198            GEMMA_3_1B_GUFF_MODEL_FILENAME,
199            32_000.try_into().unwrap(),
200        )
201        .await?;
202        Ok(RunnerWithRecommendedSampling {
203            inner,
204            default_sampling: Self::recommend_sampling(),
205        })
206    }
207}
208
209impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3TextRunner {
210    fn stream_lm_response<Tmpl>(
211        &'s self,
212        request: GenericTextLmRequest<'req, Tmpl>,
213    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
214    where
215        Tmpl: ChatTemplate,
216    {
217        let ctx = self
218            .model
219            .new_context(
220                &LLAMA_BACKEND,
221                LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
222            )
223            .map_err(|err| GenericRunnerError::from(err));
224        Gemma3Stream::new(ctx, request, self, &self.model)
225    }
226}
227
228pub struct Gemma3VisionRunner {
229    model: LlamaModel,
230    chat_template: LlamaChatTemplate,
231    mtmd_ctx: MtmdContext,
232    ctx_size: NonZeroU32,
233}
234
235static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
236    llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default());
237    LlamaBackend::init().unwrap()
238});
239
240impl Gemma3VisionRunner {
241    pub async fn new(
242        repo_id: impl ToString,
243        model_file: impl AsRef<str>,
244        multimodel_file: impl AsRef<str>,
245        ctx_size: NonZeroU32,
246    ) -> Result<Self, CreateLlamaCppRunnerError> {
247        let repo = build_hf_api()?.model(repo_id.to_string());
248        let model = LlamaModel::load_from_file(
249            &LLAMA_BACKEND,
250            repo.get(model_file.as_ref()).await?,
251            &Default::default(),
252        )?;
253
254        let mtmd_ctx = MtmdContext::init_from_file(
255            repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
256            &model,
257            &Default::default(),
258        )?;
259
260        let chat_template = model.chat_template(None)?;
261
262        Ok(Self {
263            model,
264            mtmd_ctx,
265            chat_template,
266            ctx_size,
267        })
268    }
269
270    pub fn from_files(
271        model_file: impl AsRef<Path>,
272        multimodel_file: impl AsRef<Path>,
273        ctx_size: NonZeroU32,
274    ) -> Result<Self, CreateLlamaCppRunnerError> {
275        let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
276        let mtmd_ctx = MtmdContext::init_from_file(
277            multimodel_file.as_ref().as_os_str().to_str().unwrap(),
278            &model,
279            &Default::default(),
280        )?;
281
282        let chat_template = model.chat_template(None)?;
283
284        Ok(Self {
285            model,
286            mtmd_ctx,
287            chat_template,
288            ctx_size,
289        })
290    }
291
292    fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
293        self.model.new_context(
294            &LLAMA_BACKEND,
295            LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
296        )
297    }
298
299    pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
300    {
301        let inner = Self::new(
302            QWEN_3D5_4B_GUFF_MODEL_ID,
303            QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
304            QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
305            16384u32.try_into().unwrap(),
306        )
307        .await?;
308        Ok(RunnerWithRecommendedSampling {
309            inner: inner,
310            default_sampling: SimpleSamplingParams {
311                top_p: Some(0.8f32),
312                top_k: Some(20),
313                temperature: Some(0.7f32),
314                presence_penalty: Some(1.5),
315                repetition_penalty: Some(1.0),
316                seed: None,
317            },
318        })
319    }
320}
321
322impl<'s, 'req> VisionLmRunner<'s, 'req> for Gemma3VisionRunner {
323    fn stream_vlm_response<Tmpl>(
324        &'s self,
325        request: GenericVisionLmRequest<'req, Tmpl>,
326    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
327    where
328        Tmpl: ChatTemplate,
329    {
330        let ctx = self
331            .new_context_window()
332            .map_err(|err| GenericRunnerError::from(err));
333        Gemma3Stream::new(ctx, request, self, &self.model)
334    }
335}
336
337impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3VisionRunner {
338    fn stream_lm_response<Tmpl>(
339        &'s self,
340        request: GenericTextLmRequest<'req, Tmpl>,
341    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
342    where
343        Tmpl: ChatTemplate,
344    {
345        self.stream_vlm_response(request.into())
346    }
347}
348
349impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
350    fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
351        Self {
352            messages: value
353                .messages
354                .into_iter()
355                .map(|(role, text)| (role, ImageOrText::Text(text)))
356                .collect(),
357            sampling: value.sampling,
358            llguidance: value.llguidance,
359            max_seq: value.max_seq,
360            prefill: value.prefill,
361            tmpl: value.tmpl,
362        }
363    }
364}
365
366pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
367    ctx_source: Option<Result<LlamaContext<'a>, GenericRunnerError<Tmpl::Error>>>,
368    ctx: Option<LlamaContext<'a>>,
369    req: GenericRunnerRequest<Message, Tmpl>,
370    runner: &'a Runner,
371    model: &'a LlamaModel,
372    runtime: Option<Runtime<'a>>,
373    done: bool,
374}
375
376struct Runtime<'a> {
377    sampler: LlamaSampler,
378    decoder: Decoder,
379    batch: LlamaBatch<'a>,
380    n_past: i32,
381    step: usize,
382}
383
384trait PrepareRun<TmplErr> {
385    fn prepare(&mut self) -> Result<(), GenericRunnerError<TmplErr>>;
386}
387
388impl<Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner, Tmpl>
389where
390    Tmpl: ChatTemplate,
391{
392    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
393        // Preprocess the message, flattening media
394        let media_marker = mtmd::mtmd_default_marker();
395        let messages = self
396            .req
397            .messages
398            .iter()
399            .fold(
400                Vec::<(MessageRole, String)>::new(),
401                |mut acc, (role, message)| {
402                    let text = match message {
403                        ImageOrText::Text(text) => text,
404                        ImageOrText::Image(_) => media_marker,
405                    };
406                    if let Some(last) = acc.last()
407                        && last.0 == *role
408                    {
409                        // merge adjacent
410                        let (_, adj) = acc.remove(acc.len() - 1);
411                        acc.push((role.clone(), format!("{0}\n{text}", adj)));
412                        acc
413                    } else {
414                        acc.push((role.clone(), text.to_string()));
415                        acc
416                    }
417                },
418            )
419            .into_iter()
420            .collect::<Vec<_>>();
421        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
422
423        // apply custom template
424        let formatted_prompt = self
425            .req
426            .tmpl
427            .apply_template(self.model, &self.runner.chat_template, &messages)
428            .map_err(GenericRunnerError::ApplyChatTemplate)?;
429
430        // Aggregate images
431        let bitmaps = self
432            .req
433            .messages
434            .iter()
435            .filter_map(|msg| match &msg.1 {
436                ImageOrText::Image(image) => Some(image),
437                _ => None,
438            })
439            .enumerate()
440            .map(|(idx, im)| {
441                MtmdBitmap::from_image_data(
442                    im.width(),
443                    im.height(),
444                    im.to_rgb8().to_vec().as_slice(),
445                )
446                .expect(format!("image#{} has corrupted RGB data", idx).as_str())
447            })
448            .collect::<Vec<_>>();
449        let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
450        let chunks = self.runner.mtmd_ctx.tokenize(
451            MtmdInputText {
452                text: formatted_prompt,
453                add_special: true,
454                parse_special: true,
455            },
456            &bitmap_refs,
457        )?;
458        log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
459        let n_past = chunks.eval_chunks(
460            &self.runner.mtmd_ctx,
461            self.ctx.as_ref().unwrap(),
462            0,
463            0,
464            1,
465            true,
466        )?;
467
468        // Generate preparation
469        let mut preparation = Runtime {
470            sampler: self.req.sampling.to_llama(),
471            decoder: UTF_8.new_decoder(),
472            batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
473            n_past,
474            step: 0,
475        };
476        if let Some(llguidance) = &self.req.llguidance {
477            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
478            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
479        }
480        self.runtime = Some(preparation);
481
482        Ok(())
483    }
484}
485
486impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, S, Gemma3TextRunner, Tmpl>
487where
488    Tmpl: ChatTemplate,
489{
490    fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
491        // Preprocess the message
492        let messages = self
493            .req
494            .messages
495            .iter()
496            .fold(
497                Vec::<(MessageRole, String)>::new(),
498                |mut acc, (role, message)| {
499                    if let Some(last) = acc.last()
500                        && last.0 == *role
501                    {
502                        // merge adjacent
503                        let (_, adj) = acc.remove(acc.len() - 1);
504                        acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
505                        acc
506                    } else {
507                        acc.push((role.clone(), message.as_ref().to_string()));
508                        acc
509                    }
510                },
511            )
512            .into_iter()
513            .collect::<Vec<_>>();
514        log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
515
516        // apply custom template
517        let formatted_prompt = self
518            .req
519            .tmpl
520            .apply_template(self.model, &self.runner.llama_template, &messages)
521            .map_err(GenericRunnerError::ApplyChatTemplate)?;
522
523        // Aggregate images
524        let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
525        let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
526        let token_list_len = token_list.len();
527        for (i, token) in token_list.into_iter().enumerate() {
528            batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
529        }
530        self.ctx.as_mut().unwrap().decode(&mut batch)?;
531
532        // Generate preparation
533        let mut preparation = Runtime {
534            sampler: self.req.sampling.to_llama(),
535            decoder: UTF_8.new_decoder(),
536            batch,
537            n_past: token_list_len as i32,
538            step: 0,
539        };
540        if let Some(llguidance) = &self.req.llguidance {
541            let llg_sampler = llguidance.to_llama(&self.runner.model)?;
542            preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
543        }
544        self.runtime = Some(preparation);
545
546        Ok(())
547    }
548}
549
550impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
551where
552    Tmpl: ChatTemplate,
553    Self: PrepareRun<Tmpl::Error>,
554{
555    type Item = Result<String, GenericRunnerError<Tmpl::Error>>;
556
557    fn next(&mut self) -> Option<Self::Item> {
558        if self.done {
559            return None;
560        }
561
562        if let Some(result) = self.ctx_source.take() {
563            match result {
564                Ok(ctx) => self.ctx = Some(ctx),
565                Err(err) => {
566                    self.done = true;
567                    return Some(Err(err));
568                }
569            }
570        }
571
572        if self.runtime.is_none()
573            && let Err(err) = self.prepare()
574        {
575            self.done = true;
576            return Some(Err(err));
577        }
578        let Runtime {
579            sampler,
580            decoder,
581            batch,
582            n_past,
583            step,
584        } = self.runtime.as_mut().unwrap();
585
586        if *step >= self.req.max_seq {
587            self.done = true;
588            return None;
589        }
590
591        // Sample response token
592        let ctx = self.ctx.as_mut().unwrap();
593        let model = self.model;
594        let sample_idx = batch.n_tokens() - 1;
595        let mut sample = |token: LlamaToken,
596                          sampler: &mut LlamaSampler,
597                          ctx: &mut LlamaContext<'a>,
598                          step: usize|
599         -> Result<Option<String>, GenericRunnerError<Tmpl::Error>> {
600            sampler.accept(token);
601            if model.is_eog_token(token) {
602                return Ok(None);
603            }
604            batch.clear();
605            batch.add(token, *n_past + (step as i32), &[0], true)?;
606
607            ctx.decode(batch)?;
608
609            let piece = model.token_to_piece(token, decoder, true, None)?;
610            Ok(Some(piece))
611        };
612        if let Some(prefill) = self.req.prefill.take() {
613            log::debug!(target: "gemma", "prefill: {}", prefill);
614            let tokens = match model.str_to_token(&prefill, AddBos::Never) {
615                Ok(tokens) => tokens,
616                Err(err) => {
617                    return Some(Err(err.into()));
618                }
619            };
620            log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
621            for token in tokens {
622                match sample(token, sampler, ctx, *step) {
623                    Ok(_) => {}
624                    Err(err) => return Some(Err(err.into())),
625                }
626                *step += 1;
627            }
628            Some(Ok(prefill))
629        } else {
630            let token = sampler.sample(ctx, sample_idx);
631            match sample(token, sampler, ctx, *step) {
632                Ok(Some(piece)) => {
633                    *step += 1;
634                    return Some(Ok(piece));
635                }
636                Ok(None) => {
637                    self.done = true;
638                    return None;
639                }
640                Err(err) => {
641                    self.done = true;
642                    return Some(Err(err));
643                }
644            }
645        }
646    }
647}
648
649impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
650where
651    Tmpl: ChatTemplate,
652{
653    fn new(
654        source: Result<LlamaContext<'s>, GenericRunnerError<Tmpl::Error>>,
655        req: GenericRunnerRequest<Message, Tmpl>,
656        runner: &'s Runner,
657        model: &'s LlamaModel,
658    ) -> Self {
659        Self {
660            ctx_source: Some(source),
661            ctx: None,
662            req,
663            runner,
664            model,
665            runtime: None,
666            done: false,
667        }
668    }
669}
670
671pub struct RunnerWithRecommendedSampling<Inner> {
672    pub inner: Inner,
673    pub default_sampling: SimpleSamplingParams,
674}
675
676impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
677    fn get_preprocessed_simple_sampling(
678        &self,
679        sampling: SimpleSamplingParams,
680    ) -> SimpleSamplingParams {
681        let mut sampling = sampling;
682        if sampling.top_k.is_none() {
683            sampling.top_k = self.default_sampling.top_k;
684        }
685        if sampling.top_p.is_none() {
686            sampling.top_p = self.default_sampling.top_p;
687        }
688        if sampling.temperature.is_none() {
689            sampling.temperature = self.default_sampling.temperature;
690        }
691        sampling
692    }
693}
694
695impl<'s, 'req, Inner> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
696where
697    Inner: VisionLmRunner<'s, 'req>,
698{
699    fn stream_vlm_response<Tmpl>(
700        &'s self,
701        mut request: GenericVisionLmRequest<'req, Tmpl>,
702    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
703    where
704        Tmpl: ChatTemplate,
705    {
706        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
707        self.inner.stream_vlm_response(request)
708    }
709}
710
711impl<'s, 'req, Inner> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
712where
713    Inner: TextLmRunner<'s, 'req>,
714{
715    fn stream_lm_response<Tmpl>(
716        &'s self,
717        mut request: GenericTextLmRequest<'req, Tmpl>,
718    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
719    where
720        Tmpl: ChatTemplate,
721    {
722        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
723        self.inner.stream_lm_response(request)
724    }
725}
726
727impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
728    fn from(value: Inner) -> Self {
729        Self {
730            inner: value,
731            default_sampling: SimpleSamplingParams::default(),
732        }
733    }
734}
735
736fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
737    let mut api = ApiBuilder::new()
738        .with_progress(std::io::stdin().is_terminal())
739        .with_token(std::env::var("HF_TOKEN").ok())
740        .with_chunk_size(Some(2 << 28));
741    if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
742        api = api.with_endpoint(endpoint);
743    }
744    if let Ok(cache) = std::env::var("HF_HOME") {
745        api = api.with_cache_dir(
746            PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
747        );
748    }
749    api.build()
750}