Skip to main content

llama_runner/runner/
ext.rs

1use crate::{
2    GenericTextLmRequest, GenericVisionLmRequest, TextLmRunner, VisionLmRunner,
3    error::GenericRunnerError, template::ChatTemplate,
4};
5
6pub trait TextLmRunnerExt<'s, 'req, Tmpl>
7where
8    Tmpl: ChatTemplate,
9{
10    fn get_lm_response(
11        &'s self,
12        request: GenericTextLmRequest<'req, Tmpl>,
13    ) -> Result<String, GenericRunnerError<Tmpl::Error>>;
14}
15
16pub trait VisionLmRunnerExt<'s, 'req, Tmpl>
17where
18    Tmpl: ChatTemplate,
19{
20    fn get_vlm_response(
21        &'s self,
22        request: GenericVisionLmRequest<'req, Tmpl>,
23    ) -> Result<String, GenericRunnerError<Tmpl::Error>>;
24}
25
26impl<'s, 'req, TextRunner, Tmpl> TextLmRunnerExt<'s, 'req, Tmpl> for TextRunner
27where
28    TextRunner: TextLmRunner<'s, 'req, Tmpl>,
29    Tmpl: ChatTemplate,
30{
31    fn get_lm_response(
32        &'s self,
33        request: GenericTextLmRequest<'req, Tmpl>,
34    ) -> Result<String, GenericRunnerError<Tmpl::Error>> {
35        self.stream_lm_response(request)
36            .collect::<Result<String, _>>()
37    }
38}
39
40impl<'s, 'req, VisionRunner, Tmpl> VisionLmRunnerExt<'s, 'req, Tmpl> for VisionRunner
41where
42    VisionRunner: VisionLmRunner<'s, 'req, Tmpl>,
43    Tmpl: ChatTemplate,
44{
45    fn get_vlm_response(
46        &'s self,
47        request: GenericVisionLmRequest<'req, Tmpl>,
48    ) -> Result<String, GenericRunnerError<Tmpl::Error>> {
49        self.stream_vlm_response(request)
50            .collect::<Result<String, _>>()
51    }
52}