dynamo_llm/protocols/common/llm_backend.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use serde::{Deserialize, Serialize};

pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::maybe_error::MaybeError;

pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
    pub rank: u32,
    pub token_id: TokenIdType,
    pub token: TokenType,
    pub logprob: f64,
}
pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput {
    /// New token_ids generated from the LLM Engine
    pub token_ids: Vec<TokenIdType>,

    /// Unlike [`LLMEngineOutput::tokens`], this is a vector of tokens, not an optional.
    /// The size of this vector should be the same as the size of `token_ids`.
    pub tokens: Vec<TokenType>,

    /// Decoded text from the list of tokens.
    pub text: Option<String>,

    /// Optional cumulative log probabilities
    pub cum_log_probs: Option<f64>,

    /// Optional log probabilities
    pub log_probs: Option<LogProbs>,

    pub top_logprobs: Option<TopLogprobs>,

    // TODO: Enrich this with more information so we can apply our first-level postprocessing
    // logic and return more detailed information
    pub finish_reason: Option<FinishReason>,

    // Model Deployment Card checksum
    // pub mdcsum: String,

    // Index field for batch requests to match the OpenAI format
    pub index: Option<u32>,
}

/// The LLM engine and the backend that manages it maintain their own state, specifically
/// translating how a given request/slot is managed on that particular backend.
///
/// For nvLLM's purposes, there is a single traceable request_id as part of its context that
/// has propagated through the service pipeline to the backend.
///
/// This is the minimal raw output from the LLM engine. The Backend may then apply multiple
/// levels of post-processing before the [`BackendOutput`] is returned.
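///
/// A minimal sketch of how an engine adapter might emit one of these per decode step
/// (illustrative only, not compiled as a doctest; the crate path shown is an assumption):
///
/// ```ignore
/// use dynamo_llm::protocols::common::llm_backend::LLMEngineOutput;
///
/// // One decode step that produced a single token with no engine-side detokenization.
/// let step = LLMEngineOutput {
///     token_ids: vec![42],
///     tokens: None,          // the Backend will detokenize
///     text: None,
///     cum_log_probs: None,
///     log_probs: None,
///     top_logprobs: None,
///     finish_reason: None,   // still generating
///     index: None,
/// };
/// assert!(step.finish_reason.is_none());
/// ```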
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMEngineOutput {
    /// New token_ids generated by the LLM Engine
    pub token_ids: Vec<TokenIdType>,

    /// If the LLM Engine performs detokenization, this holds `Some` with the detokenized
    /// tokens. If this value is `None`, then the Backend is responsible for detokenization.
    pub tokens: Option<Vec<TokenType>>,

    /// Decoded text
    pub text: Option<String>,

    /// Optional cumulative log probabilities
    pub cum_log_probs: Option<f64>,

    /// Optional log probabilities
    pub log_probs: Option<LogProbs>,

    pub top_logprobs: Option<TopLogprobs>,

    // TODO: Enrich this with more information so we can apply our first-level postprocessing
    // logic and return more detailed information
    pub finish_reason: Option<FinishReason>,

    // Index field for batch requests to match the OpenAI format
    pub index: Option<u32>,
}

impl LLMEngineOutput {
    pub fn cancelled() -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: Some(FinishReason::Cancelled),
            index: None,
        }
    }

    pub fn stop() -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: Some(FinishReason::Stop),
            index: None,
        }
    }

    pub fn length() -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: Some(FinishReason::Length),
            index: None,
        }
    }

    pub fn error(err_msg: String) -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: Some(FinishReason::Error(err_msg)),
            index: None,
        }
    }
}

impl MaybeError for LLMEngineOutput {
    fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
        LLMEngineOutput::error(format!("{:?}", err))
    }

    fn err(&self) -> Option<anyhow::Error> {
        if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
            Some(anyhow::Error::msg(err_msg.clone()))
        } else {
            None
        }
    }
}

/// Raw output from embedding engines containing embedding vectors
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct EmbeddingsEngineOutput {
    /// Generated embedding vectors (one per input text)
    pub embeddings: Vec<Vec<f64>>,

    /// Token usage information
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

#[cfg(test)]
mod tests {
    use super::*;

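    // Illustrative sketch of the convenience constructors above: each returns an empty
    // output whose only populated field is `finish_reason`. This uses only items defined
    // in this module, so there are no extra assumptions here.
    #[test]
    fn test_finish_reason_constructors() {
        let stopped = LLMEngineOutput::stop();
        assert!(matches!(stopped.finish_reason, Some(FinishReason::Stop)));
        assert!(stopped.token_ids.is_empty());
        assert!(stopped.tokens.is_none());

        let cancelled = LLMEngineOutput::cancelled();
        assert!(matches!(cancelled.finish_reason, Some(FinishReason::Cancelled)));

        let truncated = LLMEngineOutput::length();
        assert!(matches!(truncated.finish_reason, Some(FinishReason::Length)));
    }
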
    #[test]
    fn test_maybe_error() {
        let output = LLMEngineOutput::stop();
        assert!(output.err().is_none());
        assert!(output.is_ok());
        assert!(!output.is_err());

        let output = LLMEngineOutput::error("Test error".to_string());
        assert_eq!(format!("{}", output.err().unwrap()), "Test error");
        assert!(!output.is_ok());
        assert!(output.is_err());
    }
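
    // Hedged sketch of a serde round-trip for `BackendOutput`. It assumes `serde_json` is
    // available to this crate's tests and that `TokenIdType` is a plain integer alias;
    // both are assumptions of this example rather than guarantees of the module.
    #[test]
    fn test_backend_output_serde_roundtrip() {
        let output = BackendOutput {
            token_ids: vec![1, 2, 3],
            tokens: vec![Some("a".to_string()), Some("b".to_string()), None],
            text: Some("ab".to_string()),
            cum_log_probs: Some(-0.5),
            log_probs: Some(vec![-0.1, -0.2, -0.2]),
            top_logprobs: None,
            finish_reason: Some(FinishReason::Stop),
            index: Some(0),
        };

        let json = serde_json::to_string(&output).expect("serialization should succeed");
        let decoded: BackendOutput =
            serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(decoded, output);
    }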
}