//! llama_runner/runner/stream.rs
//!
//! Streaming token generation for Gemma3-style models, plus a decorator
//! runner that backfills recommended sampling defaults.
1use encoding_rs::Decoder;
2use llama_cpp_2::{
3    context::LlamaContext,
4    llama_batch::LlamaBatch,
5    model::{AddBos, LlamaModel},
6    sampling::LlamaSampler,
7    token::LlamaToken,
8};
9
10use crate::{
11    GenericRunnerRequest, GenericTextLmRequest, GenericVisionLmRequest, TextLmRunner,
12    VisionLmRunner, error::GenericRunnerError, sample::SimpleSamplingParams,
13    template::ChatTemplate,
14};
15
/// Streaming generation state for a Gemma3-style model.
///
/// Implements [`Iterator`], yielding one decoded text piece per `next()` call
/// until an end-of-generation token, an error, or `req.max_seq` steps.
pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
    /// Deferred result of context creation; taken and resolved on the first `next()`.
    pub(super) ctx_source: Option<Result<LlamaContext<'a>, GenericRunnerError<Tmpl::Error>>>,
    /// Resolved llama context, populated from `ctx_source` on first poll.
    pub(super) ctx: Option<LlamaContext<'a>>,
    /// The request being served (provides `max_seq` and the optional `prefill` text).
    pub(super) req: GenericRunnerRequest<Message, Tmpl>,
    /// Owning runner; not read in this file — presumably used by the
    /// `PrepareRun` implementations elsewhere (TODO confirm).
    pub(super) runner: &'a Runner,
    /// Model used for tokenization, EOG detection, and token-to-text conversion.
    pub(super) model: &'a LlamaModel,
    /// Per-stream decode state; `None` until `PrepareRun::prepare` runs.
    pub(super) runtime: Option<Runtime<'a>>,
    /// Set once the stream is exhausted or has reported a fatal error.
    pub(super) done: bool,
}
25
/// Mutable per-stream generation state, built by `PrepareRun::prepare`.
pub(super) struct Runtime<'a> {
    /// Token sampler; each emitted token is fed back via `accept`.
    pub(super) sampler: LlamaSampler,
    /// Incremental text decoder passed to `token_to_piece`.
    pub(super) decoder: Decoder,
    /// Reusable batch: cleared and refilled with a single token each step.
    pub(super) batch: LlamaBatch<'a>,
    /// Base position offset for generated tokens — each new token is added at
    /// position `n_past + step` (see the `sample` closure in `next`).
    pub(super) n_past: i32,
    /// Generation steps taken so far; compared against `req.max_seq`.
    pub(super) step: usize,
}
33
/// One-time initialization hook, invoked lazily by `Gemma3Stream::next` when
/// `runtime` is still `None` — presumably builds the `Runtime` and performs
/// the initial prompt decode (implementations are not in this file).
pub(super) trait PrepareRun<TmplErr> {
    /// Prepare the stream for token generation; an `Err` fuses the stream.
    fn prepare(&mut self) -> Result<(), GenericRunnerError<TmplErr>>;
}
37
/// Decorator runner that fills unset sampling parameters on incoming requests
/// with recommended defaults before delegating to `inner`.
pub struct RunnerWithRecommendedSampling<Inner> {
    /// The wrapped runner that performs the actual generation.
    pub inner: Inner,
    /// Fallback values used for any sampling field a request leaves as `None`.
    pub default_sampling: SimpleSamplingParams,
}
42
43impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
44where
45    Tmpl: ChatTemplate,
46    Self: PrepareRun<Tmpl::Error>,
47{
48    type Item = Result<String, GenericRunnerError<Tmpl::Error>>;
49
50    fn next(&mut self) -> Option<Self::Item> {
51        if self.done {
52            return None;
53        }
54
55        if let Some(result) = self.ctx_source.take() {
56            match result {
57                Ok(ctx) => self.ctx = Some(ctx),
58                Err(err) => {
59                    self.done = true;
60                    return Some(Err(err));
61                }
62            }
63        }
64
65        if self.runtime.is_none()
66            && let Err(err) = self.prepare()
67        {
68            self.done = true;
69            return Some(Err(err));
70        }
71        let Runtime {
72            sampler,
73            decoder,
74            batch,
75            n_past,
76            step,
77        } = self.runtime.as_mut().unwrap();
78
79        if *step >= self.req.max_seq {
80            self.done = true;
81            return None;
82        }
83
84        // Sample response token
85        let ctx = self.ctx.as_mut().unwrap();
86        let model = self.model;
87        let sample_idx = batch.n_tokens() - 1;
88        let mut sample = |token: LlamaToken,
89                          sampler: &mut LlamaSampler,
90                          ctx: &mut LlamaContext<'a>,
91                          step: usize|
92         -> Result<Option<String>, GenericRunnerError<Tmpl::Error>> {
93            sampler.accept(token);
94            if model.is_eog_token(token) {
95                return Ok(None);
96            }
97            batch.clear();
98            batch.add(token, *n_past + (step as i32), &[0], true)?;
99
100            ctx.decode(batch)?;
101
102            let piece = model.token_to_piece(token, decoder, true, None)?;
103            Ok(Some(piece))
104        };
105        if let Some(prefill) = self.req.prefill.take() {
106            log::debug!(target: "gemma", "prefill: {}", prefill);
107            let tokens = match model.str_to_token(&prefill, AddBos::Never) {
108                Ok(tokens) => tokens,
109                Err(err) => {
110                    return Some(Err(err.into()));
111                }
112            };
113            log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
114            for token in tokens {
115                match sample(token, sampler, ctx, *step) {
116                    Ok(_) => {}
117                    Err(err) => return Some(Err(err.into())),
118                }
119                *step += 1;
120            }
121            Some(Ok(prefill))
122        } else {
123            let token = sampler.sample(ctx, sample_idx);
124            match sample(token, sampler, ctx, *step) {
125                Ok(Some(piece)) => {
126                    *step += 1;
127                    return Some(Ok(piece));
128                }
129                Ok(None) => {
130                    self.done = true;
131                    return None;
132                }
133                Err(err) => {
134                    self.done = true;
135                    return Some(Err(err));
136                }
137            }
138        }
139    }
140}
141
142impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
143where
144    Tmpl: ChatTemplate,
145{
146    pub(crate) fn new(
147        source: Result<LlamaContext<'s>, GenericRunnerError<Tmpl::Error>>,
148        req: GenericRunnerRequest<Message, Tmpl>,
149        runner: &'s Runner,
150        model: &'s LlamaModel,
151    ) -> Self {
152        Self {
153            ctx_source: Some(source),
154            ctx: None,
155            req,
156            runner,
157            model,
158            runtime: None,
159            done: false,
160        }
161    }
162}
163
164impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
165    fn get_preprocessed_simple_sampling(
166        &self,
167        sampling: SimpleSamplingParams,
168    ) -> SimpleSamplingParams {
169        let mut sampling = sampling;
170        if sampling.top_k.is_none() {
171            sampling.top_k = self.default_sampling.top_k;
172        }
173        if sampling.top_p.is_none() {
174            sampling.top_p = self.default_sampling.top_p;
175        }
176        if sampling.temperature.is_none() {
177            sampling.temperature = self.default_sampling.temperature;
178        }
179        sampling
180    }
181}
182
/// Vision-LM delegation: backfill sampling defaults, then forward to `inner`.
impl<'s, 'req, Inner, Tmpl> VisionLmRunner<'s, 'req, Tmpl> for RunnerWithRecommendedSampling<Inner>
where
    Inner: VisionLmRunner<'s, 'req, Tmpl>,
    Tmpl: ChatTemplate,
{
    fn stream_vlm_response(
        &'s self,
        mut request: GenericVisionLmRequest<'req, Tmpl>,
    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
        // Fill any unset sampling fields with this runner's defaults.
        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
        self.inner.stream_vlm_response(request)
    }
}
196
197impl<'s, 'req, Inner, Tmpl> TextLmRunner<'s, 'req, Tmpl> for RunnerWithRecommendedSampling<Inner>
198where
199    Inner: TextLmRunner<'s, 'req, Tmpl>,
200    Tmpl: ChatTemplate,
201{
202    fn stream_lm_response(
203        &'s self,
204        mut request: GenericTextLmRequest<'req, Tmpl>,
205    ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
206    where
207        Tmpl: ChatTemplate,
208    {
209        request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
210        self.inner.stream_lm_response(request)
211    }
212}
213
214impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
215    fn from(value: Inner) -> Self {
216        Self {
217            inner: value,
218            default_sampling: SimpleSamplingParams::default(),
219        }
220    }
221}