dynamo_llm/backend.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Backend
//!
//! A [`Backend`] is the final stage of the pipeline. It represents the execution of the LLM
//! on some processing hardware.
//!
//! At minimum, the Backend is split into two components: the [`Backend`] itself and a downstream [`ExecutionContext`].
//!
//! The [`ExecutionContext`] can be thought of as the core driver of the forward pass, whereas the [`Backend`] is the
//! manager of all resources and concurrent tasks surrounding the LLM execution context / forward pass.
//!
//! In almost every known scenario, detokenization and initial post-processing must happen in the Backend.
//! Further post-processing can happen in the response stream. One example is the jailing mechanism for partial
//! hidden stop condition matches, which can be handled in the response stream rather than the backend.
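//!
//! A minimal usage sketch (illustrative only; `card`, `engine`, and `request` are assumed to be
//! provided by the caller):
//!
//! ```ignore
//! // Build the backend from a model deployment card; this loads the tokenizer, if one is defined.
//! let backend = Backend::from_mdc(card).await?;
//!
//! // `engine` is an ExecutionContext produced elsewhere. As an `Operator`, the backend wraps the
//! // engine's stream, detokenizing its output and applying hidden stop conditions.
//! let stream = backend.generate(request, engine).await?;
//! ```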

use std::{collections::HashSet, sync::Arc};

use anyhow::{Error, Result};
use futures::stream::{self, StreamExt};
use tracing as log;

use crate::model_card::model::{ModelDeploymentCard, TokenizerKind};
use dynamo_runtime::{
    pipeline::{
        async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
        ServerStreamingEngine, SingleIn,
    },
    protocols::annotated::Annotated,
};

use crate::protocols::{
    common::{
        llm_backend::{BackendOutput, FinishReason, LLMEngineOutput, PreprocessedRequest},
        StopConditions,
    },
    TokenIdType,
};
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use tokenizers::Tokenizer as HfTokenizer;

/// Represents the output stream from the execution engine
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;

/// Context for executing LLM inference; the engine consumes a preprocessed request and produces an execution output stream
pub type ExecutionContext = ServerStreamingEngine<PreprocessedRequest, ExecutionOutputStream>;

/// Backend handles resource management and orchestrates LLM execution
#[allow(dead_code)]
pub struct Backend {
    pub tokenizer: Option<Tokenizer>, // Handles token encoding/decoding
    validate_engine_decode: bool,     // Enable validation of engine decoding
}

/// Internal state for managing token decoding and stream processing
#[allow(dead_code)]
struct DecoderUnfoldState {
    stream: ManyOut<ExecutionOutputStream>,
    decoder: Decoder,
    validate_engine_decode: bool,
}

impl Backend {
    pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> {
        let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
        let tokenizer = Tokenizer::from(Arc::new(tokenizer));

        Ok(Arc::new(Self {
            tokenizer: Some(tokenizer),
            validate_engine_decode: false,
        }))
    }

    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let tokenizer = match &mdc.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => {
                HfTokenizer::from_file(file).map_err(Error::msg)?
            }
            Some(TokenizerKind::GGUF(t)) => *t.clone(),
            None => {
                return Ok(Arc::new(Self {
                    tokenizer: None,
                    validate_engine_decode: false,
                }));
            }
        };
        Self::from_tokenizer(tokenizer).await
    }

    fn decoder(
        &self,
        stream: ManyOut<ExecutionOutputStream>,
        stop_conditions: StopConditions,
    ) -> anyhow::Result<DecoderUnfoldState> {
        let Some(tokenizer) = self.tokenizer.as_ref() else {
            anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
        };
        let decoder = Decoder::new(tokenizer.decode_stream(false), stop_conditions);

        Ok(DecoderUnfoldState {
            stream,
            decoder,
            validate_engine_decode: self.validate_engine_decode,
        })
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Backend
{
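    /// Forwards the preprocessed request to the downstream engine, then unfolds the engine's
    /// response stream: token ids are decoded into text, hidden stop conditions are applied, and
    /// each chunk is mapped into a [`BackendOutput`].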
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
        let stop_conditions = request.stop_conditions.clone();
        let next_stream = next.generate(request).await?;

        let context = next_stream.context();
        let state = self.decoder(next_stream, stop_conditions)?;

        let processed_stream = stream::unfold(state, |mut state| async move {
            match state.stream.next().await {
                Some(output) => {
                    // move to state.process_output
                    // handle any error conditions / unwraps here

                    // events are pass thru
                    if output.is_event() || output.data.is_none() {
                        return Some((output, state));
                    }

                    // if we have a data field without an event, then we might need to update the data
                    if let Some(data) = &output.data {
                        if data.text.is_some() && !state.validate_engine_decode {
                            return Some((output, state));
                        }
                    }

                    let data = output.data.as_ref().unwrap();

                    let result = state.decoder.process_token_ids(&data.token_ids).unwrap();

                    // todo - propagate finish reason details - possibly an annotation
                    let finish_reason = match &result.stop_trigger {
                        Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
                        Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
                        Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
                            Some(FinishReason::Stop)
                        }
                        None => None,
                    };

                    if data.finish_reason.is_none() && finish_reason.is_some() {
                        tracing::debug!(
                            ?result.stop_trigger,
                            "upstream did not provide a finish reason; issuing a stop_generation request to free resources",
                        );
                        state.stream.context().stop_generating();
                    }

                    let text = result.text;
                    let tokens = result.tokens;

                    if state.validate_engine_decode {
                        if data.finish_reason != finish_reason {
                            log::warn!(
                                "finish reason mismatch: expected {:?}, got {:?}",
                                data.finish_reason,
                                finish_reason
                            );
                        }

                        if data.text.is_some() && data.text != text {
                            log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text);
                        }
                    }

                    // update output in-place
                    let mut output = output;
                    let mut data = output.data.take().unwrap();

                    data.finish_reason = finish_reason;
                    data.text = text;
                    data.tokens = Some(tokens);

                    output.data = Some(data);

                    Some((output, state))
                }

                None => None,
            }
        });

        // convert stream of processed Annotated<LLMEngineOutput> to Annotated<BackendOutput>
        //let mdcsum = self.mdcsum.clone();
        let stream = processed_stream.map(move |output| {
            output.map_data(|data| {
                Ok(BackendOutput {
                    token_ids: data.token_ids,
                    tokens: data.tokens.unwrap_or_default(),
                    text: data.text,
                    cum_log_probs: data.cum_log_probs,
                    log_probs: data.log_probs,
                    finish_reason: data.finish_reason,
                    //mdcsum: mdcsum.clone(),
                })
            })
        });

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

// todo - add visible stop conditions
// visible_stop_ids: HashSet<TokenIdType>,
// visible_stop_sequences: Vec<String>,

/// The [`Decoder`] could be a member of either the internal LLM engine or part of the
/// postprocessor. If it lives in the postprocessor, it should run in the same process as the
/// engine, or at the very least on the same physical machine, connected via IPC.
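///
/// A minimal sketch of driving the decoder directly (illustrative only; `tokenizer`,
/// `stop_conditions`, and `engine_token_ids` are assumed to come from the surrounding code):
///
/// ```ignore
/// let mut decoder = Decoder::new(tokenizer.decode_stream(false), stop_conditions);
///
/// // decode a chunk of token ids from the engine and check whether generation should stop
/// let result = decoder.process_token_ids(&engine_token_ids)?;
/// if let Some(stop_trigger) = result.stop_trigger {
///     // e.g. a hidden stop token or a hidden stop sequence was detected
/// }
/// ```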
#[allow(dead_code)]
pub struct Decoder {
    decode_stream: DecodeStream,

    // do not trigger stop conditions until at least this many tokens have been generated
    min_tokens: u32,

    // single tokens that if found in the response will trigger a stop condition after the
    // minimum number of tokens have been generated
    hidden_stop_ids: HashSet<TokenIdType>,

    // text sequences that if found in the response will trigger a stop condition after the
    // minimum number of tokens have been generated
    hidden_stop_sequences: Vec<String>,

    // number of generated tokens
    generated_tokens: u32,

    // content jailed by partial hidden stop matches
    jail: String,

    // maximum number of bytes for the largest stop sequence
    jail_max_bytes: usize,

    // the number of bytes currently jailed
    jailed_bytes: usize,
    // mdcsum
    //mdcsum: String,
}

#[allow(dead_code)]
#[derive(Debug)]
pub enum StopTrigger {
    MaxTokensLimit,
    HiddenStopTokenDetected(TokenIdType),
    HiddenStopSequenceDetected(String),
}

impl StopTrigger {
    pub fn should_hide_text(&self) -> bool {
        match self {
            StopTrigger::MaxTokensLimit => false,
            StopTrigger::HiddenStopTokenDetected(_) => true,
            StopTrigger::HiddenStopSequenceDetected(_) => true,
        }
    }
}

pub struct StepResult {
    pub token: Option<String>,
    pub stop_trigger: Option<StopTrigger>,
}

impl StepResult {
    fn ok(token: Option<String>) -> Self {
        Self {
            token,
            stop_trigger: None,
        }
    }

    fn with_stop_trigger(token: Option<String>, stop_trigger: StopTrigger) -> Self {
        Self {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }
}

/// Result of processing a sequence of tokens
pub struct SeqResult {
    pub tokens: Vec<Option<String>>,       // Individual decoded tokens
    pub text: Option<String>,              // Combined decoded text
    pub stop_trigger: Option<StopTrigger>, // Reason for stopping generation, if any
}

#[allow(dead_code)]
impl Decoder {
    pub fn new(
        decode_stream: DecodeStream,
        stop_condition: StopConditions,
        //mdcsum: String,
    ) -> Self {
        let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
            .stop_token_ids_hidden
            .unwrap_or_default()
            .iter()
            .copied()
            .collect();

        let hidden_stop_sequences: Vec<String> = stop_condition
            .stop
            .unwrap_or_default()
            .iter()
            .map(|x| x.to_string())
            .collect();

        let jail_max_bytes = hidden_stop_sequences
            .iter()
            .map(|x| x.len())
            .max()
            .unwrap_or(0);

        Self {
            decode_stream,
            hidden_stop_ids,
            hidden_stop_sequences,
            //visible_stop_ids: HashSet::new(),
            //visible_stop_sequences: Vec::new(),
            min_tokens: stop_condition.min_tokens.unwrap_or(0),
            generated_tokens: 0,
            jail: String::new(),
            jail_max_bytes,
            jailed_bytes: 0,
        }
    }

    /// Performs the minimum amount of work needed to determine whether a generated/decoded sequence
    /// should be stopped. This method can be called by the innermost loop of the LLM engine, or at
    /// minimum in the same process as the LLM engine.
    ///
    /// In the future, this method may kick off async CPU/tokio tasks and/or async CUDA tasks to
    /// handle logits post-processing and/or other work.
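    ///
    /// A sketch of the per-token flow (illustrative only; `decoder` is a [`Decoder`] constructed
    /// beforehand and `token_id` comes from the engine):
    ///
    /// ```ignore
    /// let StepResult { token, stop_trigger } = decoder.step(token_id)?;
    /// if let Some(trigger) = stop_trigger {
    ///     // generation should end; `trigger.should_hide_text()` indicates whether the triggering
    ///     // text should be withheld from the visible output
    /// }
    /// ```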
    pub fn step(&mut self, token_id: TokenIdType) -> Result<StepResult> {
        // increment the generated tokens
        self.generated_tokens += 1;

        // decode the token
        let token = self.decode_stream.step(token_id)?;

        // stop conditions do not apply until the minimum number of tokens have been generated
        if self.generated_tokens < self.min_tokens {
            return Ok(StepResult::ok(token));
        }

        // check for hidden stop tokens - eos takes precedence
        if self.hidden_stop_ids.contains(&token_id) {
            return Ok(StepResult::with_stop_trigger(
                token,
                StopTrigger::HiddenStopTokenDetected(token_id),
            ));
        }

        // check stop sequences - the jail will always hold at least the largest stop sequence
        // if jail_max_bytes is 0, then there are no stop sequences
        if self.jail_max_bytes > 0 {
            if let Some(token) = &token {
                let pre_append = self.jail.len();
                log::debug!("pre_append: {}", pre_append);
                log::debug!("jail: {}", self.jail);
                self.jail.push_str(token);
                log::debug!("post_append: {}", self.jail.len());
                log::debug!("jail: {}", self.jail);

                for seq in &self.hidden_stop_sequences {
                    log::debug!("stop seq: {}", seq);
                    if let Some(offset) =
                        galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
                    {
                        log::debug!("offset: {}", offset);
                        // return only new bytes after pre_append .. offset+seq.len()
                        // example: seq = "ox", token = "boxes", return "b"
                        // note: this changes when we start jailing tokens for partial matches
                        // on the suffix of the jail with prefixes of the stop sequences
                        //
                        // we might have returned a partial match; if so, offset < pre_append
                        // and we return the empty string
                        let partial_token = if offset >= pre_append {
                            self.jail[pre_append..offset].to_string()
                        } else {
                            "".to_string()
                        };
                        return Ok(StepResult::with_stop_trigger(
                            Some(partial_token),
                            StopTrigger::HiddenStopSequenceDetected(seq.to_string()),
                        ));
                    }
                }

                if self.jail.len() > self.jail_max_bytes {
                    // truncate the jail
                    let drain_len = self.jail.len() - self.jail_max_bytes;
                    self.jail.drain(0..drain_len);
                }
            }
        }

        Ok(StepResult::ok(token))
    }

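    /// Processes a batch of token ids by running each through [`Decoder::step`], accumulating the
    /// decoded text and tokens, and stopping early if any token triggers a stop condition.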
    pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
        let mut text: Option<String> = None;
        let mut tokens = Vec::new();

        for token_id in token_ids {
            let StepResult {
                token,
                stop_trigger,
            } = self.step(*token_id)?;

            let hide_text = stop_trigger
                .as_ref()
                .map(|x| x.should_hide_text())
                .unwrap_or(false);

            if !hide_text {
                if let Some(token) = &token {
                    text.get_or_insert_with(String::new).push_str(token);
                }
            }
            tokens.push(token);

            if let Some(stop_trigger) = stop_trigger {
                return Ok(SeqResult {
                    tokens,
                    text,
                    stop_trigger: Some(stop_trigger),
                });
            }
        }

        Ok(SeqResult {
            tokens,
            text,
            stop_trigger: None,
        })
    }

    fn return_token(&self, token: Option<String>) -> StepResult {
        StepResult {
            token,
            stop_trigger: None,
        }
    }

    fn return_with_stop_trigger(
        &self,
        token: Option<String>,
        stop_trigger: StopTrigger,
    ) -> StepResult {
        StepResult {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }

    fn jailed_string(&self) -> Option<String> {
        if self.jailed_bytes > 0 {
            // get the last jailed_bytes from the jail
            Some(self.jail[self.jail.len() - self.jailed_bytes..].to_string())
        } else {
            None
        }
    }
}