1use std::{collections::HashSet, sync::Arc};
19
20use anyhow::Result;
21use futures::stream::{self, StreamExt};
22use tracing as log;
23
24use crate::model_card::ModelDeploymentCard;
25use dynamo_runtime::{
26 pipeline::{
27 AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
28 SingleIn, async_trait,
29 },
30 protocols::annotated::Annotated,
31};
32
33use crate::protocols::{
34 TokenIdType,
35 common::{
36 StopConditions,
37 llm_backend::{
38 BackendOutput, EmbeddingsEngineOutput, FinishReason, LLMEngineOutput,
39 PreprocessedRequest,
40 },
41 preprocessor::PreprocessedEmbeddingRequest,
42 },
43};
44use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
45use tokenizers::Tokenizer as HfTokenizer;
46
/// A single annotated chunk emitted by the execution engine while streaming.
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;

/// Streaming engine that turns a preprocessed request into engine output chunks.
pub type ExecutionContext = ServerStreamingEngine<PreprocessedRequest, ExecutionOutputStream>;
52
53#[allow(dead_code)]
55pub struct Backend {
56 pub tokenizer: Option<Tokenizer>, validate_engine_decode: bool, }
59
/// State threaded through `stream::unfold` while post-processing one response stream.
#[allow(dead_code)]
struct DecoderUnfoldState {
    // Upstream engine response stream being consumed.
    stream: ManyOut<ExecutionOutputStream>,
    // Incremental detokenizer + stop-condition tracker for this request.
    decoder: Decoder,
    // Mirror of `Backend::validate_engine_decode`.
    validate_engine_decode: bool,
}
67
68impl Backend {
69 pub fn from_tokenizer(tokenizer: HfTokenizer) -> Arc<Self> {
70 let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
71 let tokenizer = Tokenizer::from(Arc::new(tokenizer));
72
73 Arc::new(Self {
74 tokenizer: Some(tokenizer),
75 validate_engine_decode: false,
76 })
77 }
78
79 pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
80 match mdc.tokenizer_hf() {
81 Ok(tokenizer) => Self::from_tokenizer(tokenizer),
82 Err(err) => {
83 tracing::warn!(%err, "tokenizer_hf error converting ModelDeploymentCard to HF tokenizer");
84 Arc::new(Self {
85 tokenizer: None,
86 validate_engine_decode: false,
87 })
88 }
89 }
90 }
91
92 fn decoder(
93 &self,
94 stream: ManyOut<ExecutionOutputStream>,
95 prompt_token_ids: &[TokenIdType],
96 stop_conditions: StopConditions,
97 ) -> anyhow::Result<DecoderUnfoldState> {
98 let Some(tokenizer) = self.tokenizer.as_ref() else {
99 anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
100 };
101 let decoder = Decoder::new(
102 tokenizer.decode_stream(prompt_token_ids, false),
103 stop_conditions,
104 );
105
106 Ok(DecoderUnfoldState {
107 stream,
108 decoder,
109 validate_engine_decode: self.validate_engine_decode,
110 })
111 }
112}
113
#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Backend
{
    /// Post-process a completion stream: forward the request to `next`, then
    /// detokenize each engine chunk and apply hidden stop conditions before
    /// mapping it into a `BackendOutput`.
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
        // Capture what the decoder needs before `request` is moved into `next`.
        let stop_conditions = request.stop_conditions.clone();

        let prompt_token_ids = request.token_ids.clone();

        let next_stream = next.generate(request).await?;

        // Keep the upstream context so cancellation propagates through the
        // response stream handed back to the caller.
        let context = next_stream.context();
        let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;

        let processed_stream = stream::unfold(state, |mut state| async move {
            match state.stream.next().await {
                Some(output) => {
                    // Events and data-less chunks pass through untouched.
                    if output.is_event() || output.data.is_none() {
                        return Some((output, state));
                    }

                    // If the engine already decoded text and we are not in
                    // validation mode, trust it and skip local decoding.
                    if let Some(data) = &output.data
                        && data.text.is_some()
                        && !state.validate_engine_decode
                    {
                        return Some((output, state));
                    }

                    // Safe: the `data.is_none()` case was handled above.
                    let data = output.data.as_ref().unwrap();

                    // NOTE(review): this `unwrap` panics (killing the stream
                    // task) if detokenization fails; consider surfacing the
                    // error as an Annotated error event instead.
                    let result = state.decoder.process_token_ids(&data.token_ids).unwrap();

                    // Map a locally detected stop trigger onto the wire-level
                    // finish reason.
                    let finish_reason = match &result.stop_trigger {
                        Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
                        Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
                        Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
                            Some(FinishReason::Stop)
                        }
                        None => None,
                    };

                    // We stopped locally but the engine does not know yet:
                    // ask it to stop generating so resources are freed.
                    if data.finish_reason.is_none() && finish_reason.is_some() {
                        tracing::debug!(
                            ?result.stop_trigger,
                            "upstream did not provide a finish reason; issuing a stop_generation request to free resources",
                        );
                        state.stream.context().stop_generating();
                    }

                    let text = result.text;
                    let tokens = result.tokens;

                    // Validation mode: compare the engine's own decode against
                    // ours and warn (but do not fail) on mismatch.
                    if state.validate_engine_decode {
                        if data.finish_reason != finish_reason {
                            log::warn!(
                                "finish reason mismatch: expected {:?}, got {:?}",
                                data.finish_reason,
                                finish_reason
                            );
                        }

                        if data.text.is_some() && data.text != text {
                            log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text);
                        }
                    }

                    // Rebuild the chunk with our decoded text/tokens, keeping
                    // the engine's finish reason unless we found one locally.
                    let mut output = output;
                    let mut data = output.data.take().unwrap();

                    if finish_reason.is_some() {
                        data.finish_reason = finish_reason;
                    }
                    data.text = text;
                    data.tokens = Some(tokens);

                    output.data = Some(data);

                    Some((output, state))
                }

                // Upstream exhausted: end the unfold.
                None => None,
            }
        });

        // Convert LLMEngineOutput chunks into the public BackendOutput shape.
        let stream = processed_stream.map(move |output| {
            output.map_data(|data| {
                Ok(BackendOutput {
                    token_ids: data.token_ids,
                    tokens: data.tokens.unwrap_or_default(),
                    text: data.text,
                    cum_log_probs: data.cum_log_probs,
                    log_probs: data.log_probs,
                    top_logprobs: data.top_logprobs,
                    finish_reason: data.finish_reason,
                    index: data.index,
                })
            })
        });

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}
243
244#[async_trait]
245impl
246 Operator<
247 SingleIn<PreprocessedEmbeddingRequest>,
248 ManyOut<Annotated<EmbeddingsEngineOutput>>,
249 SingleIn<PreprocessedEmbeddingRequest>,
250 ManyOut<Annotated<EmbeddingsEngineOutput>>,
251 > for Backend
252{
253 async fn generate(
254 &self,
255 request: SingleIn<PreprocessedEmbeddingRequest>,
256 next: ServerStreamingEngine<
257 PreprocessedEmbeddingRequest,
258 Annotated<EmbeddingsEngineOutput>,
259 >,
260 ) -> Result<ManyOut<Annotated<EmbeddingsEngineOutput>>> {
261 let response_stream = next.generate(request).await?;
264
265 Ok(response_stream)
271 }
272}
273
/// Incremental detokenizer that also enforces stop conditions
/// (hidden stop token ids, stop sequences, and `min_tokens`).
#[allow(dead_code)]
pub struct Decoder {
    // Streaming detokenizer this decoder feeds token ids into.
    decode_stream: DecodeStream,

    // No stop condition may fire before this many tokens were generated.
    min_tokens: u32,

    // Token ids that stop generation without being shown to the client.
    hidden_stop_ids: HashSet<TokenIdType>,

    // Stop strings scanned for in the decoded text (from `StopConditions::stop`).
    hidden_stop_sequences: Vec<String>,

    // Count of token ids processed so far via `step`.
    generated_tokens: u32,

    // Rolling window ("jail") of recent decoded text used to detect stop
    // sequences that straddle token boundaries.
    jail: String,

    // Target size of `jail` in bytes: length of the longest stop sequence
    // (0 disables stop-sequence scanning).
    jail_max_bytes: usize,

    // NOTE(review): never written anywhere in this file, so
    // `jailed_string()` always returns `None` today.
    jailed_bytes: usize,
}
310
/// Reason the decoder decided generation must stop.
#[allow(dead_code)]
#[derive(Debug)]
pub enum StopTrigger {
    /// Token budget exhausted (mapped to `FinishReason::Length`).
    MaxTokensLimit,
    /// A hidden stop token id was produced (mapped to `FinishReason::Stop`).
    HiddenStopTokenDetected(TokenIdType),
    /// A hidden stop sequence appeared in the decoded text
    /// (mapped to `FinishReason::Stop`).
    HiddenStopSequenceDetected(String),
}
318
319impl StopTrigger {
320 pub fn should_hide_text(&self) -> bool {
321 match self {
322 StopTrigger::MaxTokensLimit => false,
323 StopTrigger::HiddenStopTokenDetected(_) => true,
324 StopTrigger::HiddenStopSequenceDetected(_) => true,
325 }
326 }
327}
328
/// Result of decoding a single token id via `Decoder::step`.
pub struct StepResult {
    // Text fragment produced by this token, if the detokenizer emitted any.
    pub token: Option<String>,
    // Stop condition triggered by this token, if any.
    pub stop_trigger: Option<StopTrigger>,
}
333
334impl StepResult {
335 fn ok(token: Option<String>) -> Self {
336 Self {
337 token,
338 stop_trigger: None,
339 }
340 }
341
342 fn with_stop_trigger(token: Option<String>, stop_trigger: StopTrigger) -> Self {
343 Self {
344 token,
345 stop_trigger: Some(stop_trigger),
346 }
347 }
348}
349
350pub struct SeqResult {
352 pub tokens: Vec<Option<String>>, pub text: Option<String>, pub stop_trigger: Option<StopTrigger>, }
356
357#[allow(dead_code)]
358impl Decoder {
359 pub fn new(
360 decode_stream: DecodeStream,
361 stop_condition: StopConditions,
362 ) -> Self {
364 let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
365 .stop_token_ids_hidden
366 .unwrap_or_default()
367 .iter()
368 .copied()
369 .collect();
370
371 let hidden_stop_sequences: Vec<String> = stop_condition
372 .stop
373 .unwrap_or_default()
374 .iter()
375 .map(|x| x.to_string())
376 .collect();
377
378 let jail_max_bytes = hidden_stop_sequences
379 .iter()
380 .map(|x| x.len())
381 .max()
382 .unwrap_or(0);
383
384 Self {
385 decode_stream,
386 hidden_stop_ids,
387 hidden_stop_sequences,
388 min_tokens: stop_condition.min_tokens.unwrap_or(0),
391 generated_tokens: 0,
392 jail: String::new(),
393 jail_max_bytes,
394 jailed_bytes: 0,
395 }
396 }
397
398 pub fn step(&mut self, token_id: TokenIdType) -> Result<StepResult> {
405 self.generated_tokens += 1;
407
408 let token = self.decode_stream.step(token_id)?;
410
411 if self.generated_tokens < self.min_tokens {
413 return Ok(StepResult::ok(token));
414 }
415
416 if self.hidden_stop_ids.contains(&token_id) {
418 return Ok(StepResult::with_stop_trigger(
419 token,
420 StopTrigger::HiddenStopTokenDetected(token_id),
421 ));
422 }
423
424 if self.jail_max_bytes > 0
427 && let Some(token) = &token
428 {
429 let pre_append = self.jail.len();
430 log::debug!("pre_append: {}", pre_append);
431 log::debug!("jail: {}", self.jail);
432 self.jail.push_str(token);
433 log::debug!("post_append: {}", self.jail.len());
434 log::debug!("jail: {}", self.jail);
435
436 for seq in &self.hidden_stop_sequences {
437 log::debug!("stop seq: {}", seq);
438 if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
439 {
440 log::debug!("offset: {}", offset);
441 let partial_token = if offset >= pre_append {
449 self.jail[pre_append..offset].to_string()
450 } else {
451 "".to_string()
452 };
453 return Ok(StepResult::with_stop_trigger(
454 Some(partial_token),
455 StopTrigger::HiddenStopSequenceDetected(seq.to_string()),
456 ));
457 }
458 }
459
460 if self.jail.len() > self.jail_max_bytes {
461 let drain_len = self.jail.len() - self.jail_max_bytes;
463 self.jail.drain(0..drain_len);
464 }
465 }
466
467 Ok(StepResult::ok(token))
468 }
469
470 pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
471 let mut text: Option<String> = None;
472 let mut tokens = Vec::with_capacity(token_ids.len());
473
474 for token_id in token_ids {
475 let StepResult {
476 token,
477 stop_trigger,
478 } = self.step(*token_id)?;
479
480 let hide_text = stop_trigger
481 .as_ref()
482 .map(|x| x.should_hide_text())
483 .unwrap_or(false);
484
485 if !hide_text && let Some(token) = &token {
486 text.get_or_insert_with(|| String::with_capacity(token_ids.len()))
487 .push_str(token);
488 }
489 tokens.push(token);
490
491 if let Some(stop_trigger) = stop_trigger {
492 return Ok(SeqResult {
493 tokens,
494 text,
495 stop_trigger: Some(stop_trigger),
496 });
497 }
498 }
499
500 Ok(SeqResult {
501 tokens,
502 text,
503 stop_trigger: None,
504 })
505 }
506
507 fn return_token(&self, token: Option<String>) -> StepResult {
508 StepResult {
509 token,
510 stop_trigger: None,
511 }
512 }
513
514 fn return_with_stop_trigger(
515 &self,
516 token: Option<String>,
517 stop_trigger: StopTrigger,
518 ) -> StepResult {
519 StepResult {
520 token,
521 stop_trigger: Some(stop_trigger),
522 }
523 }
524
525 fn jailed_string(&self) -> Option<String> {
526 if self.jailed_bytes > 0 {
527 Some(self.jail[self.jail.len() - self.jailed_bytes..].to_string())
529 } else {
530 None
531 }
532 }
533}