1use encoding_rs::Decoder;
2use llama_cpp_2::{
3 context::LlamaContext,
4 llama_batch::LlamaBatch,
5 model::{AddBos, LlamaModel},
6 sampling::LlamaSampler,
7 token::LlamaToken,
8};
9
10use crate::{
11 GenericRunnerRequest, GenericTextLmRequest, GenericVisionLmRequest, TextLmRunner,
12 VisionLmRunner, error::GenericRunnerError, sample::SimpleSamplingParams,
13 template::ChatTemplate,
14};
15
/// Streaming iterator over text pieces generated by a Gemma-3 model.
///
/// The inference context arrives as a `Result` in `ctx_source` and is
/// unwrapped lazily on the first `next()` call, so context-construction
/// errors surface through the iterator instead of at build time.
pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
    /// Pending context (or the error that produced it); `Some` until first poll.
    pub(super) ctx_source: Option<Result<LlamaContext<'a>, GenericRunnerError<Tmpl::Error>>>,
    /// Live inference context, populated from `ctx_source` on the first poll.
    pub(super) ctx: Option<LlamaContext<'a>>,
    /// Originating request; `next()` reads `prefill` and `max_seq` from it.
    pub(super) req: GenericRunnerRequest<Message, Tmpl>,
    /// Owning runner — not touched in this file; presumably used by the
    /// `PrepareRun` implementation elsewhere (TODO confirm).
    pub(super) runner: &'a Runner,
    /// Model used for tokenization, detokenization, and EOG detection.
    pub(super) model: &'a LlamaModel,
    /// Sampler/batch/decoder state, created lazily via `PrepareRun::prepare`.
    pub(super) runtime: Option<Runtime<'a>>,
    /// Set once the stream has finished (EOG, `max_seq` reached, or error).
    pub(super) done: bool,
}
25
/// Mutable per-generation state, built lazily once the stream starts.
pub(super) struct Runtime<'a> {
    /// Token sampler driving generation; tokens are `accept`ed back into it.
    pub(super) sampler: LlamaSampler,
    /// Incremental text decoder handed to `token_to_piece` so multi-byte
    /// characters split across tokens decode correctly.
    pub(super) decoder: Decoder,
    /// Reusable batch; cleared and refilled with a single token per step.
    pub(super) batch: LlamaBatch<'a>,
    /// Base position offset: token positions are `n_past + step`.
    pub(super) n_past: i32,
    /// Steps taken so far; generation stops once it reaches `req.max_seq`.
    pub(super) step: usize,
}
33
/// Hook that initializes a stream's `Runtime` before the first token is
/// produced; `next()` invokes it once, lazily, when `runtime` is `None`.
pub(super) trait PrepareRun<TmplErr> {
    /// Sets up generation state (sampler, batch, decoder, positions).
    fn prepare(&mut self) -> Result<(), GenericRunnerError<TmplErr>>;
}
37
/// Decorator that back-fills unset sampling parameters on incoming requests
/// with recommended defaults before delegating to the wrapped runner.
pub struct RunnerWithRecommendedSampling<Inner> {
    /// The wrapped runner that performs the actual inference.
    pub inner: Inner,
    /// Defaults applied to any sampling field the request leaves as `None`.
    pub default_sampling: SimpleSamplingParams,
}
42
43impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
44where
45 Tmpl: ChatTemplate,
46 Self: PrepareRun<Tmpl::Error>,
47{
48 type Item = Result<String, GenericRunnerError<Tmpl::Error>>;
49
50 fn next(&mut self) -> Option<Self::Item> {
51 if self.done {
52 return None;
53 }
54
55 if let Some(result) = self.ctx_source.take() {
56 match result {
57 Ok(ctx) => self.ctx = Some(ctx),
58 Err(err) => {
59 self.done = true;
60 return Some(Err(err));
61 }
62 }
63 }
64
65 if self.runtime.is_none()
66 && let Err(err) = self.prepare()
67 {
68 self.done = true;
69 return Some(Err(err));
70 }
71 let Runtime {
72 sampler,
73 decoder,
74 batch,
75 n_past,
76 step,
77 } = self.runtime.as_mut().unwrap();
78
79 if *step >= self.req.max_seq {
80 self.done = true;
81 return None;
82 }
83
84 let ctx = self.ctx.as_mut().unwrap();
86 let model = self.model;
87 let sample_idx = batch.n_tokens() - 1;
88 let mut sample = |token: LlamaToken,
89 sampler: &mut LlamaSampler,
90 ctx: &mut LlamaContext<'a>,
91 step: usize|
92 -> Result<Option<String>, GenericRunnerError<Tmpl::Error>> {
93 sampler.accept(token);
94 if model.is_eog_token(token) {
95 return Ok(None);
96 }
97 batch.clear();
98 batch.add(token, *n_past + (step as i32), &[0], true)?;
99
100 ctx.decode(batch)?;
101
102 let piece = model.token_to_piece(token, decoder, true, None)?;
103 Ok(Some(piece))
104 };
105 if let Some(prefill) = self.req.prefill.take() {
106 log::debug!(target: "gemma", "prefill: {}", prefill);
107 let tokens = match model.str_to_token(&prefill, AddBos::Never) {
108 Ok(tokens) => tokens,
109 Err(err) => {
110 return Some(Err(err.into()));
111 }
112 };
113 log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
114 for token in tokens {
115 match sample(token, sampler, ctx, *step) {
116 Ok(_) => {}
117 Err(err) => return Some(Err(err.into())),
118 }
119 *step += 1;
120 }
121 Some(Ok(prefill))
122 } else {
123 let token = sampler.sample(ctx, sample_idx);
124 match sample(token, sampler, ctx, *step) {
125 Ok(Some(piece)) => {
126 *step += 1;
127 return Some(Ok(piece));
128 }
129 Ok(None) => {
130 self.done = true;
131 return None;
132 }
133 Err(err) => {
134 self.done = true;
135 return Some(Err(err));
136 }
137 }
138 }
139 }
140}
141
142impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
143where
144 Tmpl: ChatTemplate,
145{
146 pub(crate) fn new(
147 source: Result<LlamaContext<'s>, GenericRunnerError<Tmpl::Error>>,
148 req: GenericRunnerRequest<Message, Tmpl>,
149 runner: &'s Runner,
150 model: &'s LlamaModel,
151 ) -> Self {
152 Self {
153 ctx_source: Some(source),
154 ctx: None,
155 req,
156 runner,
157 model,
158 runtime: None,
159 done: false,
160 }
161 }
162}
163
164impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
165 fn get_preprocessed_simple_sampling(
166 &self,
167 sampling: SimpleSamplingParams,
168 ) -> SimpleSamplingParams {
169 let mut sampling = sampling;
170 if sampling.top_k.is_none() {
171 sampling.top_k = self.default_sampling.top_k;
172 }
173 if sampling.top_p.is_none() {
174 sampling.top_p = self.default_sampling.top_p;
175 }
176 if sampling.temperature.is_none() {
177 sampling.temperature = self.default_sampling.temperature;
178 }
179 sampling
180 }
181}
182
183impl<'s, 'req, Inner, Tmpl> VisionLmRunner<'s, 'req, Tmpl> for RunnerWithRecommendedSampling<Inner>
184where
185 Inner: VisionLmRunner<'s, 'req, Tmpl>,
186 Tmpl: ChatTemplate,
187{
188 fn stream_vlm_response(
189 &'s self,
190 mut request: GenericVisionLmRequest<'req, Tmpl>,
191 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>> {
192 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
193 self.inner.stream_vlm_response(request)
194 }
195}
196
197impl<'s, 'req, Inner, Tmpl> TextLmRunner<'s, 'req, Tmpl> for RunnerWithRecommendedSampling<Inner>
198where
199 Inner: TextLmRunner<'s, 'req, Tmpl>,
200 Tmpl: ChatTemplate,
201{
202 fn stream_lm_response(
203 &'s self,
204 mut request: GenericTextLmRequest<'req, Tmpl>,
205 ) -> impl Iterator<Item = Result<String, GenericRunnerError<Tmpl::Error>>>
206 where
207 Tmpl: ChatTemplate,
208 {
209 request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
210 self.inner.stream_lm_response(request)
211 }
212}
213
214impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
215 fn from(value: Inner) -> Self {
216 Self {
217 inner: value,
218 default_sampling: SimpleSamplingParams::default(),
219 }
220 }
221}