1use std::{collections::HashSet, sync::Arc};
31
32use anyhow::{Error, Result};
33use futures::stream::{self, StreamExt};
34use tracing as log;
35
36use crate::model_card::model::{ModelDeploymentCard, TokenizerKind};
37use dynamo_runtime::{
38 pipeline::{
39 async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
40 ServerStreamingEngine, SingleIn,
41 },
42 protocols::annotated::Annotated,
43};
44
45use crate::protocols::{
46 common::{
47 llm_backend::{BackendInput, BackendOutput, FinishReason, LLMEngineOutput},
48 StopConditions,
49 },
50 TokenIdType,
51};
52use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
53use tokenizers::Tokenizer as HfTokenizer;
54
/// A single item of the engine's output stream: an annotated LLM engine output chunk.
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;

/// Streaming engine that maps a [`BackendInput`] request to a stream of
/// [`ExecutionOutputStream`] items.
pub type ExecutionContext = ServerStreamingEngine<BackendInput, ExecutionOutputStream>;
60
/// Post-processing stage that detokenizes engine output and applies hidden
/// stop conditions on the response stream.
#[allow(dead_code)]
pub struct Backend {
    /// Tokenizer used for incremental decoding; `None` when built from a
    /// `ModelDeploymentCard` that carries no tokenizer.
    pub tokenizer: Option<Tokenizer>,
    /// When true, decode locally even if the engine already provided text and
    /// warn on any mismatch (text or finish reason).
    validate_engine_decode: bool,
}
67
/// State threaded through the `stream::unfold` loop in `Operator::generate`:
/// the upstream engine stream plus the decoder applied to each chunk.
#[allow(dead_code)]
struct DecoderUnfoldState {
    /// Upstream response stream produced by the engine.
    stream: ManyOut<ExecutionOutputStream>,
    /// Incremental detokenizer / stop-condition checker for this request.
    decoder: Decoder,
    /// Copied from [`Backend::validate_engine_decode`].
    validate_engine_decode: bool,
}
75
76impl Backend {
77 pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> {
78 let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
79 let tokenizer = Tokenizer::from(Arc::new(tokenizer));
80
81 Ok(Arc::new(Self {
82 tokenizer: Some(tokenizer),
83 validate_engine_decode: false,
84 }))
85 }
86
87 pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
88 let tokenizer = match &mdc.tokenizer {
89 Some(TokenizerKind::HfTokenizerJson(file)) => {
90 HfTokenizer::from_file(file).map_err(Error::msg)?
91 }
92 Some(TokenizerKind::GGUF(t)) => *t.clone(),
93 None => {
94 return Ok(Arc::new(Self {
95 tokenizer: None,
96 validate_engine_decode: false,
97 }));
98 }
99 };
100 Self::from_tokenizer(tokenizer).await
101 }
102
103 fn decoder(
104 &self,
105 stream: ManyOut<ExecutionOutputStream>,
106 stop_conditions: StopConditions,
107 ) -> anyhow::Result<DecoderUnfoldState> {
108 let Some(tokenizer) = self.tokenizer.as_ref() else {
109 anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
110 };
111 let decoder = Decoder::new(tokenizer.decode_stream(false), stop_conditions);
112
113 Ok(DecoderUnfoldState {
114 stream,
115 decoder,
116 validate_engine_decode: self.validate_engine_decode,
117 })
118 }
119}
120
#[async_trait]
impl
    Operator<
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Backend
{
    /// Forward `request` to `next`, then post-process the resulting stream:
    /// detokenize each chunk, apply hidden stop conditions, and convert
    /// `LLMEngineOutput` items into `BackendOutput`.
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
        next: ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
        // Capture stop conditions before the request is consumed by `next`.
        let stop_conditions = request.stop_conditions.clone();
        let next_stream = next.generate(request).await?;

        // Keep the upstream context so the returned stream shares cancellation.
        let context = next_stream.context();
        let state = self.decoder(next_stream, stop_conditions)?;

        let processed_stream = stream::unfold(state, |mut state| async move {
            match state.stream.next().await {
                Some(output) => {
                    // Pass through events and data-less items untouched.
                    if output.is_event() || output.data.is_none() {
                        return Some((output, state));
                    }

                    // If the engine already decoded text and we are not
                    // validating, trust it and pass the item through.
                    if let Some(data) = &output.data {
                        if data.text.is_some() && !state.validate_engine_decode {
                            return Some((output, state));
                        }
                    }

                    // Safe: the `is_none` case returned above.
                    let data = output.data.as_ref().unwrap();

                    // NOTE(review): a detokenizer error here panics the stream
                    // task rather than surfacing as an error item — confirm
                    // this is acceptable.
                    let result = state.decoder.process_token_ids(&data.token_ids).unwrap();

                    // Map a local stop trigger onto the wire-level finish reason.
                    let finish_reason = match &result.stop_trigger {
                        Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
                        Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
                        Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
                            Some(FinishReason::Stop)
                        }
                        None => None,
                    };

                    // We detected a stop the engine did not report: tell the
                    // engine to stop generating so its resources are freed.
                    if data.finish_reason.is_none() && finish_reason.is_some() {
                        tracing::debug!(
                            ?result.stop_trigger,
                            "upstream did not provide a finish reason; issuing a stop_generation request to free resources",
                        );
                        state.stream.context().stop_generating();
                    }

                    let text = result.text;
                    let tokens = result.tokens;

                    // Validation mode: compare engine-provided decode against ours.
                    if state.validate_engine_decode {
                        if data.finish_reason != finish_reason {
                            log::warn!(
                                "finish reason mismatch: expected {:?}, got {:?}",
                                data.finish_reason,
                                finish_reason
                            );
                        }

                        if data.text.is_some() && data.text != text {
                            log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text);
                        }
                    }

                    // Rebuild the item with the locally decoded fields.
                    let mut output = output;
                    let mut data = output.data.take().unwrap();

                    data.finish_reason = finish_reason;
                    data.text = text;
                    data.tokens = Some(tokens);

                    output.data = Some(data);

                    Some((output, state))
                }

                // Upstream exhausted: end the unfold.
                None => None,
            }
        });

        // Convert each LLMEngineOutput payload into the public BackendOutput shape.
        let stream = processed_stream.map(move |output| {
            output.map_data(|data| {
                Ok(BackendOutput {
                    token_ids: data.token_ids,
                    tokens: data.tokens.unwrap_or_default(),
                    text: data.text,
                    cum_log_probs: data.cum_log_probs,
                    log_probs: data.log_probs,
                    finish_reason: data.finish_reason,
                })
            })
        });

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}
234
/// Incremental detokenizer that also enforces hidden stop conditions
/// (stop token ids and stop text sequences) on the decoded stream.
#[allow(dead_code)]
pub struct Decoder {
    /// Streaming detokenizer state.
    decode_stream: DecodeStream,

    /// Stop conditions are not evaluated until this many tokens were generated.
    min_tokens: u32,

    /// Token ids that end generation without their text being shown.
    hidden_stop_ids: HashSet<TokenIdType>,

    /// Text sequences that end generation without being shown.
    hidden_stop_sequences: Vec<String>,

    /// Count of tokens processed so far by `step`.
    generated_tokens: u32,

    /// Rolling window of recently decoded text, used to detect stop sequences
    /// that span token boundaries.
    jail: String,

    /// Maximum bytes kept in `jail` (length of the longest stop sequence).
    jail_max_bytes: usize,

    /// NOTE(review): never written (stays 0), so `jailed_string()` always
    /// returns `None` — presumably intended for holding back matched text.
    jailed_bytes: usize,
}
271
/// Why the decoder decided generation should stop.
#[allow(dead_code)]
#[derive(Debug)]
pub enum StopTrigger {
    /// The maximum token budget was exhausted.
    MaxTokensLimit,
    /// A hidden stop token id was produced (payload: the token id).
    HiddenStopTokenDetected(TokenIdType),
    /// A hidden stop sequence appeared in the decoded text (payload: the sequence).
    HiddenStopSequenceDetected(String),
}
279
280impl StopTrigger {
281 pub fn should_hide_text(&self) -> bool {
282 match self {
283 StopTrigger::MaxTokensLimit => false,
284 StopTrigger::HiddenStopTokenDetected(_) => true,
285 StopTrigger::HiddenStopSequenceDetected(_) => true,
286 }
287 }
288}
289
/// Outcome of decoding a single token id.
pub struct StepResult {
    /// Decoded text for this step; `None` while the detokenizer is buffering
    /// a partial character.
    pub token: Option<String>,
    /// Set when this step tripped a stop condition.
    pub stop_trigger: Option<StopTrigger>,
}
294
295impl StepResult {
296 fn ok(token: Option<String>) -> Self {
297 Self {
298 token,
299 stop_trigger: None,
300 }
301 }
302
303 fn with_stop_trigger(token: Option<String>, stop_trigger: StopTrigger) -> Self {
304 Self {
305 token,
306 stop_trigger: Some(stop_trigger),
307 }
308 }
309}
310
/// Outcome of decoding a batch of token ids.
pub struct SeqResult {
    /// Per-token decoded text; entries may be `None` for partial characters.
    pub tokens: Vec<Option<String>>,
    /// Accumulated visible text, if any was produced.
    pub text: Option<String>,
    /// Stop condition hit while processing, if any.
    pub stop_trigger: Option<StopTrigger>,
}
317
318#[allow(dead_code)]
319impl Decoder {
320 pub fn new(
321 decode_stream: DecodeStream,
322 stop_condition: StopConditions,
323 ) -> Self {
325 let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
326 .stop_token_ids_hidden
327 .unwrap_or_default()
328 .iter()
329 .copied()
330 .collect();
331
332 let hidden_stop_sequences: Vec<String> = stop_condition
333 .stop
334 .unwrap_or_default()
335 .iter()
336 .map(|x| x.to_string())
337 .collect();
338
339 let jail_max_bytes = hidden_stop_sequences
340 .iter()
341 .map(|x| x.len())
342 .max()
343 .unwrap_or(0);
344
345 Self {
346 decode_stream,
347 hidden_stop_ids,
348 hidden_stop_sequences,
349 min_tokens: stop_condition.min_tokens.unwrap_or(0),
352 generated_tokens: 0,
353 jail: String::new(),
354 jail_max_bytes,
355 jailed_bytes: 0,
356 }
357 }
358
359 pub fn step(&mut self, token_id: TokenIdType) -> Result<StepResult> {
366 self.generated_tokens += 1;
368
369 let token = self.decode_stream.step(token_id)?;
371
372 if self.generated_tokens < self.min_tokens {
374 return Ok(StepResult::ok(token));
375 }
376
377 if self.hidden_stop_ids.contains(&token_id) {
379 return Ok(StepResult::with_stop_trigger(
380 token,
381 StopTrigger::HiddenStopTokenDetected(token_id),
382 ));
383 }
384
385 if self.jail_max_bytes > 0 {
388 if let Some(token) = &token {
389 let pre_append = self.jail.len();
390 log::debug!("pre_append: {}", pre_append);
391 log::debug!("jail: {}", self.jail);
392 self.jail.push_str(token);
393 log::debug!("post_append: {}", self.jail.len());
394 log::debug!("jail: {}", self.jail);
395
396 for seq in &self.hidden_stop_sequences {
397 log::debug!("stop seq: {}", seq);
398 if let Some(offset) =
399 galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
400 {
401 log::debug!("offset: {}", offset);
402 let partial_token = if offset >= pre_append {
410 self.jail[pre_append..offset].to_string()
411 } else {
412 "".to_string()
413 };
414 return Ok(StepResult::with_stop_trigger(
415 Some(partial_token),
416 StopTrigger::HiddenStopSequenceDetected(seq.to_string()),
417 ));
418 }
419 }
420
421 if self.jail.len() > self.jail_max_bytes {
422 let drain_len = self.jail.len() - self.jail_max_bytes;
424 self.jail.drain(0..drain_len);
425 }
426 }
427 }
428
429 Ok(StepResult::ok(token))
430 }
431
432 pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
433 let mut text: Option<String> = None;
434 let mut tokens = Vec::new();
435
436 for token_id in token_ids {
437 let StepResult {
438 token,
439 stop_trigger,
440 } = self.step(*token_id)?;
441
442 let hide_text = stop_trigger
443 .as_ref()
444 .map(|x| x.should_hide_text())
445 .unwrap_or(false);
446
447 if !hide_text {
448 if let Some(token) = &token {
449 text.get_or_insert_with(String::new).push_str(token);
450 }
451 }
452 tokens.push(token);
453
454 if let Some(stop_trigger) = stop_trigger {
455 return Ok(SeqResult {
456 tokens,
457 text,
458 stop_trigger: Some(stop_trigger),
459 });
460 }
461 }
462
463 Ok(SeqResult {
464 tokens,
465 text,
466 stop_trigger: None,
467 })
468 }
469
470 fn return_token(&self, token: Option<String>) -> StepResult {
471 StepResult {
472 token,
473 stop_trigger: None,
474 }
475 }
476
477 fn return_with_stop_trigger(
478 &self,
479 token: Option<String>,
480 stop_trigger: StopTrigger,
481 ) -> StepResult {
482 StepResult {
483 token,
484 stop_trigger: Some(stop_trigger),
485 }
486 }
487
488 fn jailed_string(&self) -> Option<String> {
489 if self.jailed_bytes > 0 {
490 Some(self.jail[self.jail.len() - self.jailed_bytes..].to_string())
492 } else {
493 None
494 }
495 }
496}