dynamo_llm/backend.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Backend
//!
//! A [`Backend`] is the final stage of the pipeline. It represents the execution of the LLM
//! on some processing hardware.
//!
//! At minimum, the Backend is split into two components: the [`Backend`] itself and a downstream [`ExecutionContext`].
//!
//! The [`ExecutionContext`] can be thought of as the core driver of the forward pass, whereas the [`Backend`] is the
//! manager of all resources and concurrent tasks surrounding the LLM execution context / forward pass.
//!
//! For almost every known scenario, detokenization and initial post-processing must happen in the Backend.
//! Further post-processing can happen in the response stream. One example is the jailing mechanism for partial
//! hidden stop condition matches, which can be handled in the response stream rather than the backend.
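//!
//! As an illustrative example of jailing (not specific to any model): if a request sets a hidden
//! stop sequence such as "</answer>" and the model emits it across several tokens ("</", "answer",
//! ">"), no single decoded token contains the full sequence. The [`Decoder`] in this module keeps a
//! small "jail" buffer of recently decoded text so the match is still detected once the final token
//! arrives, and the matched text is withheld from the output.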

use std::{collections::HashSet, sync::Arc};

use anyhow::Result;
use futures::stream::{self, StreamExt};
use tracing as log;

use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::{
    pipeline::{
        AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
        SingleIn, async_trait,
    },
    protocols::annotated::Annotated,
};

use crate::protocols::{
    TokenIdType,
    common::{
        StopConditions,
        llm_backend::{
            BackendOutput, EmbeddingsEngineOutput, FinishReason, LLMEngineOutput,
            PreprocessedRequest,
        },
        preprocessor::PreprocessedEmbeddingRequest,
    },
};
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use tokenizers::Tokenizer as HfTokenizer;

/// Represents the output stream from the execution engine
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;

/// Context for executing LLM inference; the engine consumes backend input and produces an execution output stream
pub type ExecutionContext = ServerStreamingEngine<PreprocessedRequest, ExecutionOutputStream>;

/// Backend handles resource management and orchestrates LLM execution
#[allow(dead_code)]
pub struct Backend {
    pub tokenizer: Option<Tokenizer>, // Handles token encoding/decoding
    validate_engine_decode: bool,     // Enable validation of engine decoding
}

/// Internal state for managing token decoding and stream processing
#[allow(dead_code)]
struct DecoderUnfoldState {
    stream: ManyOut<ExecutionOutputStream>,
    decoder: Decoder,
    validate_engine_decode: bool,
}

impl Backend {
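    /// Build a [`Backend`] around a Hugging Face `tokenizers::Tokenizer`, with engine-decode
    /// validation disabled.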
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Arc<Self> {
        let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
        let tokenizer = Tokenizer::from(Arc::new(tokenizer));

        Arc::new(Self {
            tokenizer: Some(tokenizer),
            validate_engine_decode: false,
        })
    }

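    /// Build a [`Backend`] from a [`ModelDeploymentCard`]. If the card cannot provide a Hugging Face
    /// tokenizer, a warning is logged and the Backend is constructed without one; in that case
    /// detokenization is unavailable and the internal `decoder` helper will return an error.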
    pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
        match mdc.tokenizer_hf() {
            Ok(tokenizer) => Self::from_tokenizer(tokenizer),
            Err(err) => {
                tracing::warn!(%err, "tokenizer_hf error converting ModelDeploymentCard to HF tokenizer");
                Arc::new(Self {
                    tokenizer: None,
                    validate_engine_decode: false,
                })
            }
        }
    }

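    /// Pair an engine output stream with a fresh [`Decoder`] seeded from the prompt token ids and
    /// the request's stop conditions. Fails if this [`Backend`] was built without a tokenizer.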
    fn decoder(
        &self,
        stream: ManyOut<ExecutionOutputStream>,
        prompt_token_ids: &[TokenIdType],
        stop_conditions: StopConditions,
    ) -> anyhow::Result<DecoderUnfoldState> {
        let Some(tokenizer) = self.tokenizer.as_ref() else {
            anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
        };
        let decoder = Decoder::new(
            tokenizer.decode_stream(prompt_token_ids, false),
            stop_conditions,
        );

        Ok(DecoderUnfoldState {
            stream,
            decoder,
            validate_engine_decode: self.validate_engine_decode,
        })
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Backend
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
        let stop_conditions = request.stop_conditions.clone();

        let prompt_token_ids = request.token_ids.clone();

        let next_stream = next.generate(request).await?;

        let context = next_stream.context();
        let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;

        let processed_stream = stream::unfold(state, |mut state| async move {
            match state.stream.next().await {
                Some(output) => {
                    // move to state.process_output
                    // handle any error conditions / unwraps here

                    // events are pass thru
                    if output.is_event() || output.data.is_none() {
                        return Some((output, state));
                    }

                    // if we have a data field without an event, then we might need to update the data
                    if let Some(data) = &output.data
                        && data.text.is_some()
                        && !state.validate_engine_decode
                    {
                        return Some((output, state));
                    }

                    let data = output.data.as_ref().unwrap();

                    let result = state.decoder.process_token_ids(&data.token_ids).unwrap();

                    // NOTE: the `finish_reason` is computed from the generated `token_ids` alone.
                    // The `data` field can have a `finish_reason` set, coming from the underlying
                    // LLM inference `Engine`, and empty `token_ids`. See comment below for more details.
                    let finish_reason = match &result.stop_trigger {
                        Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
                        Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
                        Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
                            Some(FinishReason::Stop)
                        }
                        None => None,
                    };

                    if data.finish_reason.is_none() && finish_reason.is_some() {
                        tracing::debug!(
                            ?result.stop_trigger,
                            "upstream did not provide a finish reason; issuing a stop_generation request to free resources",
                        );
                        state.stream.context().stop_generating();
                    }

                    let text = result.text;
                    let tokens = result.tokens;

                    if state.validate_engine_decode {
                        if data.finish_reason != finish_reason {
                            log::warn!(
                                "finish reason mismatch: expected {:?}, got {:?}",
                                data.finish_reason,
                                finish_reason
                            );
                        }

                        if data.text.is_some() && data.text != text {
                            log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text);
                        }
                    }

                    // update output in-place
                    let mut output = output;
                    let mut data = output.data.take().unwrap();

                    // NOTE: If `finish_reason.is_some()`, then one of the stop conditions was triggered
                    // by the token generation. We should update the `data.finish_reason` in that case.
                    // However, if `finish_reason.is_none()`, it is possible that we are in the case where
                    // `data.token_ids` is empty, and `data.finish_reason` is already correctly set.
                    // In that case, `process_token_ids` above will rewrite `finish_reason` to `None`,
                    // which we don't want to propagate to `data.finish_reason`.
                    if finish_reason.is_some() {
                        data.finish_reason = finish_reason;
                    }
                    data.text = text;
                    data.tokens = Some(tokens);

                    output.data = Some(data);

                    Some((output, state))
                }

                None => None,
            }
        });

        // convert stream of processed Annotated<LLMEngineOutput> to Annotated<BackendOutput>
        //let mdcsum = self.mdcsum.clone();
        let stream = processed_stream.map(move |output| {
            output.map_data(|data| {
                Ok(BackendOutput {
                    token_ids: data.token_ids,
                    tokens: data.tokens.unwrap_or_default(),
                    text: data.text,
                    cum_log_probs: data.cum_log_probs,
                    log_probs: data.log_probs,
                    top_logprobs: data.top_logprobs,
                    finish_reason: data.finish_reason,
                    //mdcsum: mdcsum.clone(),
                    index: data.index,
                })
            })
        });

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedEmbeddingRequest>,
        ManyOut<Annotated<EmbeddingsEngineOutput>>,
        SingleIn<PreprocessedEmbeddingRequest>,
        ManyOut<Annotated<EmbeddingsEngineOutput>>,
    > for Backend
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedEmbeddingRequest>,
        next: ServerStreamingEngine<
            PreprocessedEmbeddingRequest,
            Annotated<EmbeddingsEngineOutput>,
        >,
    ) -> Result<ManyOut<Annotated<EmbeddingsEngineOutput>>> {
        // For embeddings, we mostly pass through since no detokenization is needed
        // But we could add validation, logging, or other post-processing here
        let response_stream = next.generate(request).await?;

        // Could add embedding-specific post-processing here:
        // - Validation of embedding dimensions
        // - Normalization if requested
        // - Usage statistics validation

        Ok(response_stream)
    }
}

// todo - add visible stop conditions
// visible_stop_ids: HashSet<TokenIdType>,
// visible_stop_sequences: Vec<String>,

/// The [`Decoder`] could live either inside the internal LLM engine or in the postprocessor. If it
/// lives in the postprocessor, it should run in the same process as the engine, or at the very least
/// on the same physical machine, connected by IPC.
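///
/// The decoder tracks hidden stop token ids, hidden stop sequences, and a minimum token count taken
/// from the request's [`StopConditions`], and keeps a small "jail" buffer (sized to the longest stop
/// sequence) of recently decoded text so that stop sequences spanning multiple tokens can still be
/// detected.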
#[allow(dead_code)]
pub struct Decoder {
    decode_stream: DecodeStream,

    // do not trigger stop conditions until at least this many tokens have been generated
    min_tokens: u32,

    // single tokens that if found in the response will trigger a stop condition after the
    // minimum number of tokens have been generated
    hidden_stop_ids: HashSet<TokenIdType>,

    // text sequences that if found in the response will trigger a stop condition after the
    // minimum number of tokens have been generated
    hidden_stop_sequences: Vec<String>,

    // number of generated tokens
    generated_tokens: u32,

    // content jailed by partial hidden stop matches
    jail: String,

    // maximum number of bytes for the largest stop sequence
    jail_max_bytes: usize,

    // the number of bytes currently jailed
    jailed_bytes: usize,
    // mdcsum
    //mdcsum: String,
}

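/// Reason a stop condition fired while decoding generated tokens.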
#[allow(dead_code)]
#[derive(Debug)]
pub enum StopTrigger {
    MaxTokensLimit,
    HiddenStopTokenDetected(TokenIdType),
    HiddenStopSequenceDetected(String),
}

impl StopTrigger {
    pub fn should_hide_text(&self) -> bool {
        match self {
            StopTrigger::MaxTokensLimit => false,
            StopTrigger::HiddenStopTokenDetected(_) => true,
            StopTrigger::HiddenStopSequenceDetected(_) => true,
        }
    }
}

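/// Result of decoding a single token via [`Decoder::step`].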
pub struct StepResult {
    pub token: Option<String>,
    pub stop_trigger: Option<StopTrigger>,
}

impl StepResult {
    fn ok(token: Option<String>) -> Self {
        Self {
            token,
            stop_trigger: None,
        }
    }

    fn with_stop_trigger(token: Option<String>, stop_trigger: StopTrigger) -> Self {
        Self {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }
}

/// Result of processing a sequence of tokens
pub struct SeqResult {
    pub tokens: Vec<Option<String>>,       // Individual decoded tokens
    pub text: Option<String>,              // Combined decoded text
    pub stop_trigger: Option<StopTrigger>, // Reason for stopping generation, if any
}

#[allow(dead_code)]
impl Decoder {
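    /// Create a [`Decoder`] from a tokenizer [`DecodeStream`] and the request's [`StopConditions`].
    /// The jail buffer is sized to the longest hidden stop sequence (zero when there are none).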
    pub fn new(
        decode_stream: DecodeStream,
        stop_condition: StopConditions,
        //mdcsum: String,
    ) -> Self {
        let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
            .stop_token_ids_hidden
            .unwrap_or_default()
            .iter()
            .copied()
            .collect();

        let hidden_stop_sequences: Vec<String> = stop_condition
            .stop
            .unwrap_or_default()
            .iter()
            .map(|x| x.to_string())
            .collect();

        let jail_max_bytes = hidden_stop_sequences
            .iter()
            .map(|x| x.len())
            .max()
            .unwrap_or(0);

        Self {
            decode_stream,
            hidden_stop_ids,
            hidden_stop_sequences,
            //visible_stop_ids: HashSet::new(),
            //visible_stop_sequences: Vec::new(),
            min_tokens: stop_condition.min_tokens.unwrap_or(0),
            generated_tokens: 0,
            jail: String::new(),
            jail_max_bytes,
            jailed_bytes: 0,
        }
    }

    /// Minimum amount of work to determine if a given generated/decoded sequence should be stopped.
    /// This method can be called from the innermost loop of the LLM engine, or at minimum from the
    /// same process as the LLM engine.
    ///
    /// In the future, this method may kick off async cpu/tokio tasks and/or async cuda tasks to
    /// handle logits post-processing and/or other tasks.
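    ///
    /// A minimal usage sketch (illustrative only; assumes a `decoder: Decoder` built via
    /// [`Decoder::new`] and an iterator yielding generated token ids):
    ///
    /// ```ignore
    /// for token_id in token_ids {
    ///     let step = decoder.step(token_id)?;
    ///     if let Some(token) = &step.token {
    ///         // `token` is the newly decoded text for this step (withhold it if the
    ///         // trigger below says to hide it)
    ///     }
    ///     if let Some(trigger) = step.stop_trigger {
    ///         // a stop condition fired; stop generating and map `trigger` to a finish reason
    ///         break;
    ///     }
    /// }
    /// ```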
    pub fn step(&mut self, token_id: TokenIdType) -> Result<StepResult> {
        // increment the generated tokens
        self.generated_tokens += 1;

        // decode the token
        let token = self.decode_stream.step(token_id)?;

        // stop conditions do not apply until the minimum number of tokens have been generated
        if self.generated_tokens < self.min_tokens {
            return Ok(StepResult::ok(token));
        }

        // check for hidden stop tokens - eos takes precedence
        if self.hidden_stop_ids.contains(&token_id) {
            return Ok(StepResult::with_stop_trigger(
                token,
                StopTrigger::HiddenStopTokenDetected(token_id),
            ));
        }

        // check stop sequences - the jail will always hold at least the largest stop sequence
        // if jail_max_bytes is 0, then there are no stop sequences
        if self.jail_max_bytes > 0
            && let Some(token) = &token
        {
            let pre_append = self.jail.len();
            log::debug!("pre_append: {}", pre_append);
            log::debug!("jail: {}", self.jail);
            self.jail.push_str(token);
            log::debug!("post_append: {}", self.jail.len());
            log::debug!("jail: {}", self.jail);

            for seq in &self.hidden_stop_sequences {
                log::debug!("stop seq: {}", seq);
                if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
                {
                    log::debug!("offset: {}", offset);
                    // return only the new bytes in pre_append..offset; the matched stop sequence itself is hidden
                    // example: seq = "ox", token = "boxes", return "b"
                    // note: this changes when we start jailing tokens for partial matches
                    // on the suffix of the jail with prefixes of the stop sequences
                    //
                    // we might have already returned a partial match; if so, offset < pre_append
                    // and we return the empty string
                    let partial_token = if offset >= pre_append {
                        self.jail[pre_append..offset].to_string()
                    } else {
                        "".to_string()
                    };
                    return Ok(StepResult::with_stop_trigger(
                        Some(partial_token),
                        StopTrigger::HiddenStopSequenceDetected(seq.to_string()),
                    ));
                }
            }

            if self.jail.len() > self.jail_max_bytes {
                // truncate the jail
                let drain_len = self.jail.len() - self.jail_max_bytes;
                self.jail.drain(0..drain_len);
            }
        }

        Ok(StepResult::ok(token))
    }

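    /// Decode a batch of generated token ids, accumulating per-token strings and the visible text.
    /// Returns early with the [`StopTrigger`] if any token fires a stop condition; text belonging to
    /// a hidden stop match is excluded from the returned `text`.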
    pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
        let mut text: Option<String> = None;
        let mut tokens = Vec::with_capacity(token_ids.len());

        for token_id in token_ids {
            let StepResult {
                token,
                stop_trigger,
            } = self.step(*token_id)?;

            let hide_text = stop_trigger
                .as_ref()
                .map(|x| x.should_hide_text())
                .unwrap_or(false);

            if !hide_text && let Some(token) = &token {
                text.get_or_insert_with(|| String::with_capacity(token_ids.len()))
                    .push_str(token);
            }
            tokens.push(token);

            if let Some(stop_trigger) = stop_trigger {
                return Ok(SeqResult {
                    tokens,
                    text,
                    stop_trigger: Some(stop_trigger),
                });
            }
        }

        Ok(SeqResult {
            tokens,
            text,
            stop_trigger: None,
        })
    }

    fn return_token(&self, token: Option<String>) -> StepResult {
        StepResult {
            token,
            stop_trigger: None,
        }
    }

    fn return_with_stop_trigger(
        &self,
        token: Option<String>,
        stop_trigger: StopTrigger,
    ) -> StepResult {
        StepResult {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }

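    /// Returns the currently jailed suffix of decoded text, if any bytes are being withheld.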
    fn jailed_string(&self) -> Option<String> {
        if self.jailed_bytes > 0 {
            // get the last jailed_bytes from the jail
            Some(self.jail[self.jail.len() - self.jailed_bytes..].to_string())
        } else {
            None
        }
    }
}