1use std::{
2 io::IsTerminal,
3 marker::PhantomData,
4 num::NonZeroU32,
5 path::{Path, PathBuf},
6 str::FromStr,
7 sync::LazyLock,
8};
9
10use encoding_rs::{Decoder, UTF_8};
11use hf_hub::api::tokio::ApiBuilder;
12use llama_cpp_2::{
13 LlamaContextLoadError,
14 context::{LlamaContext, params::LlamaContextParams},
15 llama_backend::LlamaBackend,
16 llama_batch::LlamaBatch,
17 model::{AddBos, LlamaChatTemplate, LlamaModel},
18 mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
19 sampling::LlamaSampler,
20 token::LlamaToken,
21};
22use strum::Display;
23
24use crate::{
25 error::{CreateLlamaCppRunnerError, GenericRunnerError},
26 sample::{LlguidanceSamplingParams, SimpleSamplingParams},
27 template::{ChatTemplate, ModelChatTemplate},
28};
29
/// Hugging Face repo id for the Qwen3.5 4B GGUF weights.
/// NOTE(review): "GUFF"/"MODDEL" in these constant names look like typos for
/// "GGUF"/"MODEL"; the names are public API, so they are left unchanged.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
/// Quantized (Q4_K_M) model file within the Qwen repo.
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector (mmproj) file within the Qwen repo, used for vision input.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Hugging Face repo id for Gemma 3 1B instruction-tuned QAT GGUF weights.
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
/// Quantized (q4_0) model file within the Gemma repo.
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
36
/// A language-model runner that streams response pieces for a text-only request.
///
/// `'s` is the lifetime of the runner borrow held by the returned stream and
/// `'req` the lifetime of borrowed request content.
pub trait TextLmRunner<'s, 'req> {
    /// Streamed response: yields decoded text pieces, or an error that ends the stream.
    type Response: Iterator<
        Item = Result<String, GenericRunnerError<<Self::Template as ChatTemplate>::Error>>,
    >;
    /// Chat-template implementation used to render the request's messages.
    type Template: ChatTemplate;
    /// Begin generating a response for `request`; the actual work is driven by
    /// iterating the returned stream.
    fn stream_lm_response(
        &'s self,
        request: GenericTextLmRequest<'req, Self::Template>,
    ) -> Self::Response;
}
47
/// A multimodal (image + text) language-model runner that streams response pieces.
///
/// `'s` is the lifetime of the runner borrow held by the returned stream and
/// `'req` the lifetime of borrowed request content (text and images).
pub trait VisionLmRunner<'s, 'req> {
    /// Streamed response: yields decoded text pieces, or an error that ends the stream.
    type Response: Iterator<
        Item = Result<String, GenericRunnerError<<Self::Template as ChatTemplate>::Error>>,
    >;
    /// Chat-template implementation used to render the request's messages.
    type Template: ChatTemplate;
    /// Begin generating a response for `request`; the actual work is driven by
    /// iterating the returned stream.
    fn stream_vlm_response(
        &'s self,
        request: GenericVisionLmRequest<'req, Self::Template>,
    ) -> Self::Response;
}
58
/// A chat-completion request, generic over message content (`MsgCt`, e.g. plain
/// text or image-or-text) and the chat-template implementation (`Tmpl`).
#[derive(Debug, Clone)]
pub struct GenericRunnerRequest<MsgCt, Tmpl> {
    // Ordered conversation turns as (role, content) pairs.
    pub messages: Vec<(MessageRole, MsgCt)>,
    // Basic sampling knobs (top-k / top-p / temperature, ...).
    pub sampling: SimpleSamplingParams,
    // Optional llguidance constraints for structured output.
    pub llguidance: Option<LlguidanceSamplingParams>,
    // Maximum number of generation steps before the stream stops.
    pub max_seq: usize,
    // Text forced as the start of the model's reply before free generation.
    pub prefill: Option<String>,
    // Chat template used to render `messages` into the prompt string.
    pub tmpl: Tmpl,
}
68
69impl<M, T> Default for GenericRunnerRequest<M, T>
70where
71 T: Default,
72{
73 fn default() -> Self {
74 Self {
75 messages: vec![],
76 sampling: Default::default(),
77 llguidance: None,
78 max_seq: usize::MAX,
79 prefill: None,
80 tmpl: Default::default(),
81 }
82 }
83}
84
/// Text-only request: each message's content is borrowed text.
pub type GenericTextLmRequest<'a, Tmpl> = GenericRunnerRequest<&'a str, Tmpl>;
/// Multimodal request: each message's content is either text or a borrowed image.
pub type GenericVisionLmRequest<'a, Tmpl> = GenericRunnerRequest<ImageOrText<'a>, Tmpl>;

/// Request specialized to the model's own built-in chat template.
/// NOTE(review): the `'a` parameter is unused on the right-hand side — it only
/// exists so the aliases below can reuse it; confirm this is intentional.
pub type RunnerRequest<'a, MsgCnt> = GenericRunnerRequest<MsgCnt, ModelChatTemplate>;
/// Text-only request using the model's built-in chat template.
pub type TextLmRequest<'a> = RunnerRequest<'a, &'a str>;
/// Multimodal request using the model's built-in chat template.
pub type VisionLmRequest<'a> = RunnerRequest<'a, ImageOrText<'a>>;
91
/// Convenience extension: collect the full text response instead of streaming.
/// Blanket-implemented for every [`TextLmRunner`].
pub trait TextLmRunnerExt<'s, 'req, Tmpl, TmplErr> {
    /// Run the request to completion and return the concatenated response,
    /// or the first error produced by the stream.
    fn get_lm_response(
        &'s self,
        request: GenericTextLmRequest<'req, Tmpl>,
    ) -> Result<String, GenericRunnerError<TmplErr>>;
}
98
/// Convenience extension: collect the full multimodal response instead of
/// streaming. Blanket-implemented for every [`VisionLmRunner`].
pub trait VisionLmRunnerExt<'s, 'req, Tmpl, TmplErr> {
    /// Run the request to completion and return the concatenated response,
    /// or the first error produced by the stream.
    fn get_vlm_response(
        &'s self,
        request: GenericVisionLmRequest<'req, Tmpl>,
    ) -> Result<String, GenericRunnerError<TmplErr>>;
}
105
106impl<'s, 'req, TextRunner, Tmpl> TextLmRunnerExt<'s, 'req, Tmpl, Tmpl::Error> for TextRunner
107where
108 Tmpl: ChatTemplate,
109 TextRunner: TextLmRunner<'s, 'req, Template = Tmpl>,
110{
111 fn get_lm_response(
112 &'s self,
113 request: GenericTextLmRequest<'req, Tmpl>,
114 ) -> Result<String, GenericRunnerError<Tmpl::Error>> {
115 self.stream_lm_response(request)
116 .collect::<Result<String, _>>()
117 }
118}
119
120impl<'s, 'req, VisionRunner, Tmpl> VisionLmRunnerExt<'s, 'req, Tmpl, Tmpl::Error> for VisionRunner
121where
122 Tmpl: ChatTemplate,
123 VisionRunner: VisionLmRunner<'s, 'req, Template = Tmpl>,
124{
125 fn get_vlm_response(
126 &'s self,
127 request: GenericVisionLmRequest<'req, Tmpl>,
128 ) -> Result<String, GenericRunnerError<Tmpl::Error>> {
129 self.stream_vlm_response(request)
130 .collect::<Result<String, _>>()
131 }
132}
133
/// Speaker role attached to each chat message; `Display` renders the lowercase
/// role string that chat templates expect.
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum MessageRole {
    #[strum(to_string = "assistant")]
    Assistant,
    #[strum(to_string = "user")]
    User,
    #[strum(to_string = "system")]
    System,
    /// Any other role; displays as the wrapped string itself.
    #[strum(to_string = "{0}")]
    Custom(&'static str),
}
145
/// Content of a single multimodal message: either borrowed text or a borrowed
/// decoded image.
#[derive(Debug, Clone)]
pub enum ImageOrText<'a> {
    Text(&'a str),
    Image(&'a image::DynamicImage),
}
151
/// Text-only runner backed by a GGUF model loaded through llama.cpp.
pub struct Gemma3TextRunner<Tmpl> {
    // Loaded llama.cpp model.
    model: LlamaModel,
    // Chat template reported by the model file itself.
    llama_template: LlamaChatTemplate,
    // Context window size (n_ctx) used for each generation.
    ctx_size: NonZeroU32,
    // Ties the runner to a chat-template strategy type without storing one.
    _tmpl: PhantomData<Tmpl>,
}
158
159impl<Tmpl> Gemma3TextRunner<Tmpl> {
160 pub async fn new(
161 model_id: impl ToString,
162 model_file: impl AsRef<str>,
163 ctx_size: NonZeroU32,
164 ) -> Result<Self, CreateLlamaCppRunnerError> {
165 let repo = build_hf_api()?.model(model_id.to_string());
166 Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
167 }
168
169 pub fn recommend_sampling() -> SimpleSamplingParams {
170 SimpleSamplingParams {
171 top_p: Some(0.95f32),
172 top_k: Some(64),
173 temperature: Some(1f32),
174 ..Default::default()
175 }
176 }
177
178 pub fn from_file(
179 model_file: impl AsRef<Path>,
180 ctx_size: NonZeroU32,
181 ) -> Result<Self, CreateLlamaCppRunnerError> {
182 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
183
184 let chat_template = model.chat_template(None)?;
185 Ok(Self {
186 model,
187 llama_template: chat_template,
188 ctx_size,
189 _tmpl: PhantomData,
190 })
191 }
192}
193
194impl Gemma3TextRunner<ModelChatTemplate> {
195 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
196 {
197 let inner = Self::new(
198 GEMMA_3_1B_GUFF_MODEL_ID,
199 GEMMA_3_1B_GUFF_MODEL_FILENAME,
200 32_000.try_into().unwrap(),
201 )
202 .await?;
203 Ok(RunnerWithRecommendedSampling {
204 inner,
205 default_sampling: Self::recommend_sampling(),
206 })
207 }
208}
209
210impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req> for Gemma3TextRunner<Tmpl>
211where
212 Tmpl: ChatTemplate + 's,
213{
214 type Response = Gemma3Stream<'s, &'req str, Self, Tmpl>;
215 type Template = Tmpl;
216
217 fn stream_lm_response(&'s self, request: GenericTextLmRequest<'req, Tmpl>) -> Self::Response {
218 let ctx = self
219 .model
220 .new_context(
221 &LLAMA_BACKEND,
222 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
223 )
224 .map_err(|err| GenericRunnerError::from(err));
225 Gemma3Stream::new(ctx, request, self, &self.model)
226 }
227}
228
/// Multimodal runner: a GGUF model plus an mtmd (multimodal projector) context
/// for encoding images, loaded through llama.cpp.
pub struct Gemma3VisionRunner<Tmpl> {
    // Loaded llama.cpp model.
    model: LlamaModel,
    // Chat template reported by the model file itself.
    chat_template: LlamaChatTemplate,
    // Multimodal context used to tokenize/evaluate image chunks.
    mtmd_ctx: MtmdContext,
    // Context window size (n_ctx) used for each generation.
    ctx_size: NonZeroU32,
    // Ties the runner to a chat-template strategy type without storing one.
    _tmpl: PhantomData<Tmpl>,
}
236
/// Process-wide llama.cpp backend, initialized once on first use; also routes
/// llama.cpp's native logging into `tracing`.
static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
    llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default());
    // Backend init failure is unrecoverable for this process.
    LlamaBackend::init().unwrap()
});
241
242impl<Tmpl> Gemma3VisionRunner<Tmpl> {
243 pub async fn new(
244 repo_id: impl ToString,
245 model_file: impl AsRef<str>,
246 multimodel_file: impl AsRef<str>,
247 ctx_size: NonZeroU32,
248 ) -> Result<Self, CreateLlamaCppRunnerError> {
249 let repo = build_hf_api()?.model(repo_id.to_string());
250 let model = LlamaModel::load_from_file(
251 &LLAMA_BACKEND,
252 repo.get(model_file.as_ref()).await?,
253 &Default::default(),
254 )?;
255
256 let mtmd_ctx = MtmdContext::init_from_file(
257 repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
258 &model,
259 &Default::default(),
260 )?;
261
262 let chat_template = model.chat_template(None)?;
263
264 Ok(Self {
265 model,
266 mtmd_ctx,
267 chat_template,
268 ctx_size,
269 _tmpl: PhantomData,
270 })
271 }
272
273 pub fn from_files(
274 model_file: impl AsRef<Path>,
275 multimodel_file: impl AsRef<Path>,
276 ctx_size: NonZeroU32,
277 ) -> Result<Self, CreateLlamaCppRunnerError> {
278 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
279 let mtmd_ctx = MtmdContext::init_from_file(
280 multimodel_file.as_ref().as_os_str().to_str().unwrap(),
281 &model,
282 &Default::default(),
283 )?;
284
285 let chat_template = model.chat_template(None)?;
286
287 Ok(Self {
288 model,
289 mtmd_ctx,
290 chat_template,
291 ctx_size,
292 _tmpl: PhantomData,
293 })
294 }
295
296 fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
297 self.model.new_context(
298 &LLAMA_BACKEND,
299 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
300 )
301 }
302}
303
304impl Gemma3VisionRunner<ModelChatTemplate> {
305 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
306 {
307 let inner = Self::new(
308 QWEN_3D5_4B_GUFF_MODEL_ID,
309 QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
310 QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
311 16384u32.try_into().unwrap(),
312 )
313 .await?;
314 Ok(RunnerWithRecommendedSampling {
315 inner: inner,
316 default_sampling: SimpleSamplingParams {
317 top_p: Some(0.8f32),
318 top_k: Some(20),
319 temperature: Some(0.7f32),
320 presence_penalty: Some(1.5),
321 repetition_penalty: Some(1.0),
322 seed: None,
323 },
324 })
325 }
326}
327
328impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req> for Gemma3VisionRunner<Tmpl>
329where
330 Tmpl: ChatTemplate + 's,
331{
332 type Response = Gemma3Stream<'s, ImageOrText<'req>, Self, Tmpl>;
333 type Template = Tmpl;
334
335 fn stream_vlm_response(
336 &'s self,
337 request: GenericVisionLmRequest<'req, Tmpl>,
338 ) -> Self::Response {
339 let ctx = self
340 .new_context_window()
341 .map_err(|err| GenericRunnerError::from(err));
342 Gemma3Stream::new(ctx, request, self, &self.model)
343 }
344}
345
346impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req> for Gemma3VisionRunner<Tmpl>
347where
348 Tmpl: ChatTemplate + 's,
349{
350 type Response = <Self as VisionLmRunner<'s, 'req>>::Response;
351 type Template = <Self as VisionLmRunner<'s, 'req>>::Template;
352
353 fn stream_lm_response(&'s self, request: GenericTextLmRequest<'req, Tmpl>) -> Self::Response {
354 self.stream_vlm_response(request.into())
355 }
356}
357
358impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
359 fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
360 Self {
361 messages: value
362 .messages
363 .into_iter()
364 .map(|(role, text)| (role, ImageOrText::Text(text)))
365 .collect(),
366 sampling: value.sampling,
367 llguidance: value.llguidance,
368 max_seq: value.max_seq,
369 prefill: value.prefill,
370 tmpl: value.tmpl,
371 }
372 }
373}
374
/// Lazy generation stream: no model work happens until the first `next()` call.
/// Yields decoded text pieces until end-of-generation, `max_seq`, or an error.
pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
    // Deferred context-creation result; taken (and any error surfaced once) on first poll.
    ctx_source: Option<Result<LlamaContext<'a>, GenericRunnerError<Tmpl::Error>>>,
    // The live context, populated from `ctx_source` on the first successful poll.
    ctx: Option<LlamaContext<'a>>,
    // The request being served.
    req: GenericRunnerRequest<Message, Tmpl>,
    // Owning runner (provides template, ctx_size, and — for vision — the mtmd context).
    runner: &'a Runner,
    // Model used for tokenization, EOG checks, and detokenization.
    model: &'a LlamaModel,
    // Per-run state, created by `PrepareRun::prepare` on the first poll.
    runtime: Option<Runtime<'a>>,
    // Set once the stream has finished (EOG, max_seq reached, or error).
    done: bool,
}
384
/// Mutable per-generation state, built once the prompt has been evaluated.
struct Runtime<'a> {
    // Token sampler (optionally chained behind an llguidance constraint sampler).
    sampler: LlamaSampler,
    // Incremental decoder handed to `token_to_piece` when detokenizing output.
    decoder: Decoder,
    // Batch reused for each single-token decode step.
    batch: LlamaBatch<'a>,
    // Number of positions already evaluated from the prompt.
    n_past: i32,
    // Generation steps taken so far (compared against `max_seq`).
    step: usize,
}
392
/// First-poll initialization: render the prompt, evaluate it in the context,
/// and install the `Runtime` state. Implemented per runner/message-type pairing.
trait PrepareRun<TmplErr> {
    fn prepare(&mut self) -> Result<(), GenericRunnerError<TmplErr>>;
}
396
397impl<Tmpl> PrepareRun<Tmpl::Error>
398 for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner<Tmpl>, Tmpl>
399where
400 Tmpl: ChatTemplate,
401{
402 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
403 let media_marker = mtmd::mtmd_default_marker();
405 let messages = self
406 .req
407 .messages
408 .iter()
409 .fold(
410 Vec::<(MessageRole, String)>::new(),
411 |mut acc, (role, message)| {
412 let text = match message {
413 ImageOrText::Text(text) => text,
414 ImageOrText::Image(_) => media_marker,
415 };
416 if let Some(last) = acc.last()
417 && last.0 == *role
418 {
419 let (_, adj) = acc.remove(acc.len() - 1);
421 acc.push((role.clone(), format!("{0}\n{text}", adj)));
422 acc
423 } else {
424 acc.push((role.clone(), text.to_string()));
425 acc
426 }
427 },
428 )
429 .into_iter()
430 .collect::<Vec<_>>();
431 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
432
433 let formatted_prompt = self
435 .req
436 .tmpl
437 .apply_template(self.model, &self.runner.chat_template, &messages)
438 .map_err(GenericRunnerError::ApplyChatTemplate)?;
439
440 let bitmaps = self
442 .req
443 .messages
444 .iter()
445 .filter_map(|msg| match &msg.1 {
446 ImageOrText::Image(image) => Some(image),
447 _ => None,
448 })
449 .enumerate()
450 .map(|(idx, im)| {
451 MtmdBitmap::from_image_data(
452 im.width(),
453 im.height(),
454 im.to_rgb8().to_vec().as_slice(),
455 )
456 .expect(format!("image#{} has corrupted RGB data", idx).as_str())
457 })
458 .collect::<Vec<_>>();
459 let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
460 let chunks = self.runner.mtmd_ctx.tokenize(
461 MtmdInputText {
462 text: formatted_prompt,
463 add_special: true,
464 parse_special: true,
465 },
466 &bitmap_refs,
467 )?;
468 log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
469 let n_past = chunks.eval_chunks(
470 &self.runner.mtmd_ctx,
471 self.ctx.as_ref().unwrap(),
472 0,
473 0,
474 1,
475 true,
476 )?;
477
478 let mut preparation = Runtime {
480 sampler: self.req.sampling.to_llama(),
481 decoder: UTF_8.new_decoder(),
482 batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
483 n_past,
484 step: 0,
485 };
486 if let Some(llguidance) = &self.req.llguidance {
487 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
488 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
489 }
490 self.runtime = Some(preparation);
491
492 Ok(())
493 }
494}
495
496impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error>
497 for Gemma3Stream<'_, S, Gemma3TextRunner<Tmpl>, Tmpl>
498where
499 Tmpl: ChatTemplate,
500{
501 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
502 let messages = self
504 .req
505 .messages
506 .iter()
507 .fold(
508 Vec::<(MessageRole, String)>::new(),
509 |mut acc, (role, message)| {
510 if let Some(last) = acc.last()
511 && last.0 == *role
512 {
513 let (_, adj) = acc.remove(acc.len() - 1);
515 acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
516 acc
517 } else {
518 acc.push((role.clone(), message.as_ref().to_string()));
519 acc
520 }
521 },
522 )
523 .into_iter()
524 .collect::<Vec<_>>();
525 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
526
527 let formatted_prompt = self
529 .req
530 .tmpl
531 .apply_template(self.model, &self.runner.llama_template, &messages)
532 .map_err(GenericRunnerError::ApplyChatTemplate)?;
533
534 let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
536 let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
537 let token_list_len = token_list.len();
538 for (i, token) in token_list.into_iter().enumerate() {
539 batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
540 }
541 self.ctx.as_mut().unwrap().decode(&mut batch)?;
542
543 let mut preparation = Runtime {
545 sampler: self.req.sampling.to_llama(),
546 decoder: UTF_8.new_decoder(),
547 batch,
548 n_past: token_list_len as i32,
549 step: 0,
550 };
551 if let Some(llguidance) = &self.req.llguidance {
552 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
553 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
554 }
555 self.runtime = Some(preparation);
556
557 Ok(())
558 }
559}
560
561impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
562where
563 Tmpl: ChatTemplate,
564 Self: PrepareRun<Tmpl::Error>,
565{
566 type Item = Result<String, GenericRunnerError<Tmpl::Error>>;
567
568 fn next(&mut self) -> Option<Self::Item> {
569 if self.done {
570 return None;
571 }
572
573 if let Some(result) = self.ctx_source.take() {
574 match result {
575 Ok(ctx) => self.ctx = Some(ctx),
576 Err(err) => {
577 self.done = true;
578 return Some(Err(err));
579 }
580 }
581 }
582
583 if self.runtime.is_none()
584 && let Err(err) = self.prepare()
585 {
586 self.done = true;
587 return Some(Err(err));
588 }
589 let Runtime {
590 sampler,
591 decoder,
592 batch,
593 n_past,
594 step,
595 } = self.runtime.as_mut().unwrap();
596
597 if *step >= self.req.max_seq {
598 self.done = true;
599 return None;
600 }
601
602 let ctx = self.ctx.as_mut().unwrap();
604 let model = self.model;
605 let sample_idx = batch.n_tokens() - 1;
606 let mut sample = |token: LlamaToken,
607 sampler: &mut LlamaSampler,
608 ctx: &mut LlamaContext<'a>,
609 step: usize|
610 -> Result<Option<String>, GenericRunnerError<Tmpl::Error>> {
611 sampler.accept(token);
612 if model.is_eog_token(token) {
613 return Ok(None);
614 }
615 batch.clear();
616 batch.add(token, *n_past + (step as i32), &[0], true)?;
617
618 ctx.decode(batch)?;
619
620 let piece = model.token_to_piece(token, decoder, true, None)?;
621 Ok(Some(piece))
622 };
623 if let Some(prefill) = self.req.prefill.take() {
624 log::debug!(target: "gemma", "prefill: {}", prefill);
625 let tokens = match model.str_to_token(&prefill, AddBos::Never) {
626 Ok(tokens) => tokens,
627 Err(err) => {
628 return Some(Err(err.into()));
629 }
630 };
631 log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
632 for token in tokens {
633 match sample(token, sampler, ctx, *step) {
634 Ok(_) => {}
635 Err(err) => return Some(Err(err.into())),
636 }
637 *step += 1;
638 }
639 Some(Ok(prefill))
640 } else {
641 let token = sampler.sample(ctx, sample_idx);
642 match sample(token, sampler, ctx, *step) {
643 Ok(Some(piece)) => {
644 *step += 1;
645 return Some(Ok(piece));
646 }
647 Ok(None) => {
648 self.done = true;
649 return None;
650 }
651 Err(err) => {
652 self.done = true;
653 return Some(Err(err));
654 }
655 }
656 }
657 }
658}
659
660impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
661where
662 Tmpl: ChatTemplate,
663{
664 fn new(
665 source: Result<LlamaContext<'s>, GenericRunnerError<Tmpl::Error>>,
666 req: GenericRunnerRequest<Message, Tmpl>,
667 runner: &'s Runner,
668 model: &'s LlamaModel,
669 ) -> Self {
670 Self {
671 ctx_source: Some(source),
672 ctx: None,
673 req,
674 runner,
675 model,
676 runtime: None,
677 done: false,
678 }
679 }
680}
681
/// Decorator that fills in unset sampling parameters (top_k / top_p /
/// temperature) from a model-specific recommended default before delegating.
pub struct RunnerWithRecommendedSampling<Inner> {
    // The wrapped runner that actually serves requests.
    pub inner: Inner,
    // Fallback values applied to any request field left as `None`.
    pub default_sampling: SimpleSamplingParams,
}
686
687impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
688 fn get_preprocessed_simple_sampling(
689 &self,
690 sampling: SimpleSamplingParams,
691 ) -> SimpleSamplingParams {
692 let mut sampling = sampling;
693 if sampling.top_k.is_none() {
694 sampling.top_k = self.default_sampling.top_k;
695 }
696 if sampling.top_p.is_none() {
697 sampling.top_p = self.default_sampling.top_p;
698 }
699 if sampling.temperature.is_none() {
700 sampling.temperature = self.default_sampling.temperature;
701 }
702 sampling
703 }
704}
705
706impl<'s, 'req, Inner, Tmpl> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
707where
708 Tmpl: ChatTemplate,
709 Inner: VisionLmRunner<'s, 'req, Template = Tmpl>,
710{
711 type Response = <Inner as VisionLmRunner<'s, 'req>>::Response;
712 type Template = Tmpl;
713
714 fn stream_vlm_response(
715 &'s self,
716 mut request: GenericVisionLmRequest<'req, Tmpl>,
717 ) -> Self::Response {
718 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
719 self.inner.stream_vlm_response(request)
720 }
721}
722
723impl<'s, 'req, Inner, Tmpl> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
724where
725 Tmpl: ChatTemplate,
726 Inner: TextLmRunner<'s, 'req, Template = Tmpl>,
727{
728 type Response = <Inner as TextLmRunner<'s, 'req>>::Response;
729 type Template = Tmpl;
730
731 fn stream_lm_response(
732 &'s self,
733 mut request: GenericTextLmRequest<'req, Tmpl>,
734 ) -> Self::Response {
735 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
736 self.inner.stream_lm_response(request)
737 }
738}
739
740impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
741 fn from(value: Inner) -> Self {
742 Self {
743 inner: value,
744 default_sampling: SimpleSamplingParams::default(),
745 }
746 }
747}
748
749fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
750 let mut api = ApiBuilder::new()
751 .with_progress(std::io::stdin().is_terminal())
752 .with_token(std::env::var("HF_TOKEN").ok())
753 .with_chunk_size(Some(2 << 28));
754 if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
755 api = api.with_endpoint(endpoint);
756 }
757 if let Ok(cache) = std::env::var("HF_HOME") {
758 api = api.with_cache_dir(
759 PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
760 );
761 }
762 api.build()
763}