1use std::{
2 io::IsTerminal,
3 num::NonZeroU32,
4 path::{Path, PathBuf},
5 str::FromStr,
6 sync::LazyLock,
7};
8
9use encoding_rs::{Decoder, UTF_8};
10use hf_hub::api::tokio::ApiBuilder;
11use llama_cpp_2::{
12 LlamaContextLoadError,
13 context::{LlamaContext, params::LlamaContextParams},
14 llama_backend::LlamaBackend,
15 llama_batch::LlamaBatch,
16 model::{AddBos, LlamaChatTemplate, LlamaModel},
17 mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
18 sampling::LlamaSampler,
19 token::LlamaToken,
20};
21use strum::Display;
22
23use crate::{
24 error::{CreateLlamaCppRunnerError, GenericRunnerError},
25 sample::{LlguidanceSamplingParams, SimpleSamplingParams},
26 template::{ChatTemplate, ModelChatTemplate},
27};
28
/// Default Hugging Face artifacts for the vision runner (Qwen 3.5 4B, GGUF).
/// NOTE(review): "GUFF"/"MODDEL" look like typos for "GGUF"/"MODEL", but these
/// names are public API — renaming would break downstream callers.
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
/// Multimodal projector (mmproj) weights that pair with the model above.
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";

/// Default Hugging Face artifacts for the text-only runner (Gemma 3 1B, QAT q4_0).
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
35
/// A runner that can stream a text-only chat completion.
///
/// `'s` is the lifetime of the borrowed runner, `'req` the lifetime of
/// borrowed request content. Implementors yield the response incrementally,
/// one decoded piece per iterator item.
pub trait TextLmRunner<'s, 'req> {
    /// Streams the model's response to `request`, yielding decoded text
    /// fragments or an error (after which the stream ends).
    fn stream_lm_response<Tmpl>(
        &'s self,
        request: GenericTextLmRequest<'req, Tmpl>,
    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
    where
        Tmpl: ChatTemplate;
}
44
/// A runner that can stream a multimodal (text + image) chat completion.
///
/// Same contract as [`TextLmRunner`], but messages may contain images.
pub trait VisionLmRunner<'s, 'req> {
    /// Streams the model's response to `request`, yielding decoded text
    /// fragments or an error (after which the stream ends).
    fn stream_vlm_response<Tmpl>(
        &'s self,
        request: GenericVisionLmRequest<'req, Tmpl>,
    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
    where
        Tmpl: ChatTemplate;
}
53
/// A chat-completion request, generic over the per-message content type
/// (`MsgCt`) and the chat-template implementation (`Tmpl`).
#[derive(Debug, Clone)]
pub struct GenericRunnerRequest<MsgCt, Tmpl> {
    /// Ordered conversation history as (role, content) pairs.
    pub messages: Vec<(MessageRole, MsgCt)>,
    /// Basic sampling knobs (top-k / top-p / temperature / penalties / seed).
    pub sampling: SimpleSamplingParams,
    /// Optional grammar-constrained sampling; chained in front of `sampling`
    /// when present.
    pub llguidance: Option<LlguidanceSamplingParams>,
    /// Hard cap on the number of generation steps.
    pub max_seq: usize,
    /// Optional text forced as the beginning of the model's reply.
    pub prefill: Option<String>,
    /// Template used to render `messages` into the model's prompt format.
    pub tmpl: Tmpl,
}
63
64impl<M, T> Default for GenericRunnerRequest<M, T>
65where
66 T: Default,
67{
68 fn default() -> Self {
69 Self {
70 messages: vec![],
71 sampling: Default::default(),
72 llguidance: None,
73 max_seq: usize::MAX,
74 prefill: None,
75 tmpl: Default::default(),
76 }
77 }
78}
79
/// Text-only request: message content is a borrowed `&str`.
pub type GenericTextLmRequest<'a, Tmpl> = GenericRunnerRequest<&'a str, Tmpl>;
/// Vision request: each message is either text or a borrowed image.
pub type GenericVisionLmRequest<'a, Tmpl> = GenericRunnerRequest<ImageOrText<'a>, Tmpl>;

/// Request specialized to the crate's [`ModelChatTemplate`].
/// NOTE(review): `'a` is unused in the alias body — presumably kept so the
/// parameter lists of the aliases below line up; confirm before removing.
pub type RunnerRequest<'a, MsgCnt> = GenericRunnerRequest<MsgCnt, ModelChatTemplate>;
pub type TextLmRequest<'a> = RunnerRequest<'a, &'a str>;
pub type VisionLmRequest<'a> = RunnerRequest<'a, ImageOrText<'a>>;
86
/// Convenience over [`TextLmRunner`]: collect the whole streamed response
/// into a single `String`. Blanket-implemented for every `TextLmRunner`.
pub trait TextLmRunnerExt<'s, 'req> {
    /// Runs the request to completion, returning the concatenated response
    /// or the first error produced by the stream.
    fn get_lm_response<Tmpl>(
        &'s self,
        request: GenericTextLmRequest<'req, Tmpl>,
    ) -> Result<String, GenericRunnerError<Tmpl::Error>>
    where
        Tmpl: ChatTemplate;
}
95
/// Convenience over [`VisionLmRunner`]: collect the whole streamed response
/// into a single `String`. Blanket-implemented for every `VisionLmRunner`.
pub trait VisionLmRunnerExt<'s, 'req> {
    /// Runs the request to completion, returning the concatenated response
    /// or the first error produced by the stream.
    fn get_vlm_response<Tmpl>(
        &'s self,
        request: GenericVisionLmRequest<'req, Tmpl>,
    ) -> Result<String, GenericRunnerError<Tmpl::Error>>
    where
        Tmpl: ChatTemplate;
}
104
105impl<'s, 'req, TextRunner> TextLmRunnerExt<'s, 'req> for TextRunner
106where
107 TextRunner: TextLmRunner<'s, 'req>,
108{
109 fn get_lm_response<Tmpl>(
110 &'s self,
111 request: GenericTextLmRequest<'req, Tmpl>,
112 ) -> Result<String, GenericRunnerError<Tmpl::Error>>
113 where
114 Tmpl: ChatTemplate,
115 {
116 self.stream_lm_response(request)
117 .collect::<Result<String, _>>()
118 }
119}
120
121impl<'s, 'req, VisionRunner> VisionLmRunnerExt<'s, 'req> for VisionRunner
122where
123 VisionRunner: VisionLmRunner<'s, 'req>,
124{
125 fn get_vlm_response<Tmpl>(
126 &'s self,
127 request: GenericVisionLmRequest<'req, Tmpl>,
128 ) -> Result<String, GenericRunnerError<Tmpl::Error>>
129 where
130 Tmpl: ChatTemplate,
131 {
132 self.stream_vlm_response(request)
133 .collect::<Result<String, _>>()
134 }
135}
136
/// Speaker role attached to each chat message; `Display` (via strum) renders
/// the lowercase role string chat templates expect.
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum MessageRole {
    #[strum(to_string = "assistant")]
    Assistant,
    #[strum(to_string = "user")]
    User,
    #[strum(to_string = "system")]
    System,
    /// Any non-standard role; displays as the wrapped string.
    #[strum(to_string = "{0}")]
    Custom(&'static str),
}
148
/// One message content for vision requests: either plain text or a borrowed
/// image to be fed through the multimodal projector.
#[derive(Debug, Clone)]
pub enum ImageOrText<'a> {
    Text(&'a str),
    Image(&'a image::DynamicImage),
}
154
/// Text-only llama.cpp runner (defaults to the Gemma 3 1B GGUF weights).
pub struct Gemma3TextRunner {
    // Loaded GGUF weights.
    model: LlamaModel,
    // Chat template taken from the model file via `chat_template(None)`.
    llama_template: LlamaChatTemplate,
    // Context window (n_ctx) used for every generation.
    ctx_size: NonZeroU32,
}
160
161impl Gemma3TextRunner {
162 pub async fn new(
163 model_id: impl ToString,
164 model_file: impl AsRef<str>,
165 ctx_size: NonZeroU32,
166 ) -> Result<Self, CreateLlamaCppRunnerError> {
167 let repo = build_hf_api()?.model(model_id.to_string());
168 Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
169 }
170
171 pub fn recommend_sampling() -> SimpleSamplingParams {
172 SimpleSamplingParams {
173 top_p: Some(0.95f32),
174 top_k: Some(64),
175 temperature: Some(1f32),
176 ..Default::default()
177 }
178 }
179
180 pub fn from_file(
181 model_file: impl AsRef<Path>,
182 ctx_size: NonZeroU32,
183 ) -> Result<Self, CreateLlamaCppRunnerError> {
184 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
185
186 let chat_template = model.chat_template(None)?;
187 Ok(Self {
188 model,
189 llama_template: chat_template,
190 ctx_size,
191 })
192 }
193
194 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
195 {
196 let inner = Self::new(
197 GEMMA_3_1B_GUFF_MODEL_ID,
198 GEMMA_3_1B_GUFF_MODEL_FILENAME,
199 32_000.try_into().unwrap(),
200 )
201 .await?;
202 Ok(RunnerWithRecommendedSampling {
203 inner,
204 default_sampling: Self::recommend_sampling(),
205 })
206 }
207}
208
209impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3TextRunner {
210 fn stream_lm_response<Tmpl>(
211 &'s self,
212 request: GenericTextLmRequest<'req, Tmpl>,
213 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
214 where
215 Tmpl: ChatTemplate,
216 {
217 let ctx = self
218 .model
219 .new_context(
220 &LLAMA_BACKEND,
221 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
222 )
223 .map_err(|err| GenericRunnerError::from(err));
224 Gemma3Stream::new(ctx, request, self, &self.model)
225 }
226}
227
/// Multimodal (text + image) llama.cpp runner built on the mtmd helper API.
pub struct Gemma3VisionRunner {
    // Loaded GGUF weights.
    model: LlamaModel,
    // Chat template taken from the model file.
    chat_template: LlamaChatTemplate,
    // mtmd context wrapping the multimodal projector (mmproj) weights.
    mtmd_ctx: MtmdContext,
    // Context window (n_ctx) used for every generation.
    ctx_size: NonZeroU32,
}
234
/// Process-wide llama.cpp backend, initialized lazily on first use. Routes
/// llama.cpp's native logging into `tracing` before initializing.
static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
    llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default());
    // Backend init failing is unrecoverable for this process.
    LlamaBackend::init().unwrap()
});
239
240impl Gemma3VisionRunner {
241 pub async fn new(
242 repo_id: impl ToString,
243 model_file: impl AsRef<str>,
244 multimodel_file: impl AsRef<str>,
245 ctx_size: NonZeroU32,
246 ) -> Result<Self, CreateLlamaCppRunnerError> {
247 let repo = build_hf_api()?.model(repo_id.to_string());
248 let model = LlamaModel::load_from_file(
249 &LLAMA_BACKEND,
250 repo.get(model_file.as_ref()).await?,
251 &Default::default(),
252 )?;
253
254 let mtmd_ctx = MtmdContext::init_from_file(
255 repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
256 &model,
257 &Default::default(),
258 )?;
259
260 let chat_template = model.chat_template(None)?;
261
262 Ok(Self {
263 model,
264 mtmd_ctx,
265 chat_template,
266 ctx_size,
267 })
268 }
269
270 pub fn from_files(
271 model_file: impl AsRef<Path>,
272 multimodel_file: impl AsRef<Path>,
273 ctx_size: NonZeroU32,
274 ) -> Result<Self, CreateLlamaCppRunnerError> {
275 let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
276 let mtmd_ctx = MtmdContext::init_from_file(
277 multimodel_file.as_ref().as_os_str().to_str().unwrap(),
278 &model,
279 &Default::default(),
280 )?;
281
282 let chat_template = model.chat_template(None)?;
283
284 Ok(Self {
285 model,
286 mtmd_ctx,
287 chat_template,
288 ctx_size,
289 })
290 }
291
292 fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
293 self.model.new_context(
294 &LLAMA_BACKEND,
295 LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
296 )
297 }
298
299 pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
300 {
301 let inner = Self::new(
302 QWEN_3D5_4B_GUFF_MODEL_ID,
303 QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
304 QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
305 16384u32.try_into().unwrap(),
306 )
307 .await?;
308 Ok(RunnerWithRecommendedSampling {
309 inner: inner,
310 default_sampling: SimpleSamplingParams {
311 top_p: Some(0.8f32),
312 top_k: Some(20),
313 temperature: Some(0.7f32),
314 presence_penalty: Some(1.5),
315 repetition_penalty: Some(1.0),
316 seed: None,
317 },
318 })
319 }
320}
321
322impl<'s, 'req> VisionLmRunner<'s, 'req> for Gemma3VisionRunner {
323 fn stream_vlm_response<Tmpl>(
324 &'s self,
325 request: GenericVisionLmRequest<'req, Tmpl>,
326 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
327 where
328 Tmpl: ChatTemplate,
329 {
330 let ctx = self
331 .new_context_window()
332 .map_err(|err| GenericRunnerError::from(err));
333 Gemma3Stream::new(ctx, request, self, &self.model)
334 }
335}
336
337impl<'s, 'req> TextLmRunner<'s, 'req> for Gemma3VisionRunner {
338 fn stream_lm_response<Tmpl>(
339 &'s self,
340 request: GenericTextLmRequest<'req, Tmpl>,
341 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
342 where
343 Tmpl: ChatTemplate,
344 {
345 self.stream_vlm_response(request.into())
346 }
347}
348
349impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
350 fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
351 Self {
352 messages: value
353 .messages
354 .into_iter()
355 .map(|(role, text)| (role, ImageOrText::Text(text)))
356 .collect(),
357 sampling: value.sampling,
358 llguidance: value.llguidance,
359 max_seq: value.max_seq,
360 prefill: value.prefill,
361 tmpl: value.tmpl,
362 }
363 }
364}
365
/// Lazy generation stream shared by the text and vision runners.
///
/// Construction is infallible; the context result and all heavy preparation
/// (templating, tokenization, prompt evaluation) are deferred until the first
/// `next()` call so that errors surface as stream items.
pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
    // Deferred result of context creation; taken on the first `next()`.
    ctx_source: Option<Result<LlamaContext<'a>, GenericRunnerError<Tmpl::Error>>>,
    // Live context, populated from `ctx_source` on first use.
    ctx: Option<LlamaContext<'a>>,
    req: GenericRunnerRequest<Message, Tmpl>,
    runner: &'a Runner,
    model: &'a LlamaModel,
    // Per-generation state, built lazily by `PrepareRun::prepare`.
    runtime: Option<Runtime<'a>>,
    // Latched once the stream finished (EOG, max_seq, or a fatal error).
    done: bool,
}
375
/// Mutable per-generation state created by `PrepareRun::prepare`.
struct Runtime<'a> {
    // Sampler chain (optionally prefixed with an llguidance sampler).
    sampler: LlamaSampler,
    // Incremental UTF-8 decoder so multi-byte characters split across
    // tokens are emitted correctly.
    decoder: Decoder,
    // Reusable batch for feeding sampled tokens back to the context.
    batch: LlamaBatch<'a>,
    // Number of positions already consumed by the evaluated prompt.
    n_past: i32,
    // Generation steps taken so far (compared against `max_seq`).
    step: usize,
}
383
/// Internal hook: builds the per-generation `Runtime` (prompt templating,
/// tokenization and prompt evaluation) the first time the stream is polled.
trait PrepareRun<TmplErr> {
    fn prepare(&mut self) -> Result<(), GenericRunnerError<TmplErr>>;
}
387
388impl<Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner, Tmpl>
389where
390 Tmpl: ChatTemplate,
391{
392 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
393 let media_marker = mtmd::mtmd_default_marker();
395 let messages = self
396 .req
397 .messages
398 .iter()
399 .fold(
400 Vec::<(MessageRole, String)>::new(),
401 |mut acc, (role, message)| {
402 let text = match message {
403 ImageOrText::Text(text) => text,
404 ImageOrText::Image(_) => media_marker,
405 };
406 if let Some(last) = acc.last()
407 && last.0 == *role
408 {
409 let (_, adj) = acc.remove(acc.len() - 1);
411 acc.push((role.clone(), format!("{0}\n{text}", adj)));
412 acc
413 } else {
414 acc.push((role.clone(), text.to_string()));
415 acc
416 }
417 },
418 )
419 .into_iter()
420 .collect::<Vec<_>>();
421 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
422
423 let formatted_prompt = self
425 .req
426 .tmpl
427 .apply_template(self.model, &self.runner.chat_template, &messages)
428 .map_err(GenericRunnerError::ApplyChatTemplate)?;
429
430 let bitmaps = self
432 .req
433 .messages
434 .iter()
435 .filter_map(|msg| match &msg.1 {
436 ImageOrText::Image(image) => Some(image),
437 _ => None,
438 })
439 .enumerate()
440 .map(|(idx, im)| {
441 MtmdBitmap::from_image_data(
442 im.width(),
443 im.height(),
444 im.to_rgb8().to_vec().as_slice(),
445 )
446 .expect(format!("image#{} has corrupted RGB data", idx).as_str())
447 })
448 .collect::<Vec<_>>();
449 let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
450 let chunks = self.runner.mtmd_ctx.tokenize(
451 MtmdInputText {
452 text: formatted_prompt,
453 add_special: true,
454 parse_special: true,
455 },
456 &bitmap_refs,
457 )?;
458 log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
459 let n_past = chunks.eval_chunks(
460 &self.runner.mtmd_ctx,
461 self.ctx.as_ref().unwrap(),
462 0,
463 0,
464 1,
465 true,
466 )?;
467
468 let mut preparation = Runtime {
470 sampler: self.req.sampling.to_llama(),
471 decoder: UTF_8.new_decoder(),
472 batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
473 n_past,
474 step: 0,
475 };
476 if let Some(llguidance) = &self.req.llguidance {
477 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
478 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
479 }
480 self.runtime = Some(preparation);
481
482 Ok(())
483 }
484}
485
486impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error> for Gemma3Stream<'_, S, Gemma3TextRunner, Tmpl>
487where
488 Tmpl: ChatTemplate,
489{
490 fn prepare(&mut self) -> Result<(), GenericRunnerError<Tmpl::Error>> {
491 let messages = self
493 .req
494 .messages
495 .iter()
496 .fold(
497 Vec::<(MessageRole, String)>::new(),
498 |mut acc, (role, message)| {
499 if let Some(last) = acc.last()
500 && last.0 == *role
501 {
502 let (_, adj) = acc.remove(acc.len() - 1);
504 acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
505 acc
506 } else {
507 acc.push((role.clone(), message.as_ref().to_string()));
508 acc
509 }
510 },
511 )
512 .into_iter()
513 .collect::<Vec<_>>();
514 log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
515
516 let formatted_prompt = self
518 .req
519 .tmpl
520 .apply_template(self.model, &self.runner.llama_template, &messages)
521 .map_err(GenericRunnerError::ApplyChatTemplate)?;
522
523 let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
525 let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
526 let token_list_len = token_list.len();
527 for (i, token) in token_list.into_iter().enumerate() {
528 batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
529 }
530 self.ctx.as_mut().unwrap().decode(&mut batch)?;
531
532 let mut preparation = Runtime {
534 sampler: self.req.sampling.to_llama(),
535 decoder: UTF_8.new_decoder(),
536 batch,
537 n_past: token_list_len as i32,
538 step: 0,
539 };
540 if let Some(llguidance) = &self.req.llguidance {
541 let llg_sampler = llguidance.to_llama(&self.runner.model)?;
542 preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
543 }
544 self.runtime = Some(preparation);
545
546 Ok(())
547 }
548}
549
550impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
551where
552 Tmpl: ChatTemplate,
553 Self: PrepareRun<Tmpl::Error>,
554{
555 type Item = Result<String, GenericRunnerError<Tmpl::Error>>;
556
557 fn next(&mut self) -> Option<Self::Item> {
558 if self.done {
559 return None;
560 }
561
562 if let Some(result) = self.ctx_source.take() {
563 match result {
564 Ok(ctx) => self.ctx = Some(ctx),
565 Err(err) => {
566 self.done = true;
567 return Some(Err(err));
568 }
569 }
570 }
571
572 if self.runtime.is_none()
573 && let Err(err) = self.prepare()
574 {
575 self.done = true;
576 return Some(Err(err));
577 }
578 let Runtime {
579 sampler,
580 decoder,
581 batch,
582 n_past,
583 step,
584 } = self.runtime.as_mut().unwrap();
585
586 if *step >= self.req.max_seq {
587 self.done = true;
588 return None;
589 }
590
591 let ctx = self.ctx.as_mut().unwrap();
593 let model = self.model;
594 let sample_idx = batch.n_tokens() - 1;
595 let mut sample = |token: LlamaToken,
596 sampler: &mut LlamaSampler,
597 ctx: &mut LlamaContext<'a>,
598 step: usize|
599 -> Result<Option<String>, GenericRunnerError<Tmpl::Error>> {
600 sampler.accept(token);
601 if model.is_eog_token(token) {
602 return Ok(None);
603 }
604 batch.clear();
605 batch.add(token, *n_past + (step as i32), &[0], true)?;
606
607 ctx.decode(batch)?;
608
609 let piece = model.token_to_piece(token, decoder, true, None)?;
610 Ok(Some(piece))
611 };
612 if let Some(prefill) = self.req.prefill.take() {
613 log::debug!(target: "gemma", "prefill: {}", prefill);
614 let tokens = match model.str_to_token(&prefill, AddBos::Never) {
615 Ok(tokens) => tokens,
616 Err(err) => {
617 return Some(Err(err.into()));
618 }
619 };
620 log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
621 for token in tokens {
622 match sample(token, sampler, ctx, *step) {
623 Ok(_) => {}
624 Err(err) => return Some(Err(err.into())),
625 }
626 *step += 1;
627 }
628 Some(Ok(prefill))
629 } else {
630 let token = sampler.sample(ctx, sample_idx);
631 match sample(token, sampler, ctx, *step) {
632 Ok(Some(piece)) => {
633 *step += 1;
634 return Some(Ok(piece));
635 }
636 Ok(None) => {
637 self.done = true;
638 return None;
639 }
640 Err(err) => {
641 self.done = true;
642 return Some(Err(err));
643 }
644 }
645 }
646 }
647}
648
649impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
650where
651 Tmpl: ChatTemplate,
652{
653 fn new(
654 source: Result<LlamaContext<'s>, GenericRunnerError<Tmpl::Error>>,
655 req: GenericRunnerRequest<Message, Tmpl>,
656 runner: &'s Runner,
657 model: &'s LlamaModel,
658 ) -> Self {
659 Self {
660 ctx_source: Some(source),
661 ctx: None,
662 req,
663 runner,
664 model,
665 runtime: None,
666 done: false,
667 }
668 }
669}
670
/// Wraps a runner with a set of fallback sampling parameters that fill in
/// any knobs the caller left unset on a request.
pub struct RunnerWithRecommendedSampling<Inner> {
    pub inner: Inner,
    pub default_sampling: SimpleSamplingParams,
}
675
676impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
677 fn get_preprocessed_simple_sampling(
678 &self,
679 sampling: SimpleSamplingParams,
680 ) -> SimpleSamplingParams {
681 let mut sampling = sampling;
682 if sampling.top_k.is_none() {
683 sampling.top_k = self.default_sampling.top_k;
684 }
685 if sampling.top_p.is_none() {
686 sampling.top_p = self.default_sampling.top_p;
687 }
688 if sampling.temperature.is_none() {
689 sampling.temperature = self.default_sampling.temperature;
690 }
691 sampling
692 }
693}
694
695impl<'s, 'req, Inner> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
696where
697 Inner: VisionLmRunner<'s, 'req>,
698{
699 fn stream_vlm_response<Tmpl>(
700 &'s self,
701 mut request: GenericVisionLmRequest<'req, Tmpl>,
702 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
703 where
704 Tmpl: ChatTemplate,
705 {
706 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
707 self.inner.stream_vlm_response(request)
708 }
709}
710
711impl<'s, 'req, Inner> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
712where
713 Inner: TextLmRunner<'s, 'req>,
714{
715 fn stream_lm_response<Tmpl>(
716 &'s self,
717 mut request: GenericTextLmRequest<'req, Tmpl>,
718 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
719 where
720 Tmpl: ChatTemplate,
721 {
722 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
723 self.inner.stream_lm_response(request)
724 }
725}
726
727impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
728 fn from(value: Inner) -> Self {
729 Self {
730 inner: value,
731 default_sampling: SimpleSamplingParams::default(),
732 }
733 }
734}
735
736fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
737 let mut api = ApiBuilder::new()
738 .with_progress(std::io::stdin().is_terminal())
739 .with_token(std::env::var("HF_TOKEN").ok())
740 .with_chunk_size(Some(2 << 28));
741 if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
742 api = api.with_endpoint(endpoint);
743 }
744 if let Ok(cache) = std::env::var("HF_HOME") {
745 api = api.with_cache_dir(
746 PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
747 );
748 }
749 api.build()
750}