dynamo_llm/protocols/
common.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Engine Protocols
//! ================
//!
//! This module contains the protocols in the public API for the LLM Engine and AsyncEngine facades.
//!
//! The core components are the `CompletionRequest` and `StreamingCompletionResponse` objects.
//!
//! The `StreamingCompletionResponse` objects are the outputs of the LLM Engine; however, we
//! need some additional information to propagate intermediate results for improved observability.
//! That metadata is transferred via the other arms of the `StreamingResponse` enum.
//!

use anyhow::Result;
use derive_builder::Builder;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::time::SystemTime;

use super::TokenIdType;

pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;

/// SamplingOptionsProvider is a trait that allows the caller to extract the sampling options from
/// the object that implements it.
pub trait SamplingOptionsProvider {
    fn extract_sampling_options(&self) -> Result<SamplingOptions>;
}

/// StopConditionsProvider is the analogous trait for extracting stop conditions.
pub trait StopConditionsProvider {
    fn extract_stop_conditions(&self) -> Result<StopConditions>;
}

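/// Reason a generation stream finished.
///
/// `FinishReason` round-trips through its `Display` and `FromStr`
/// implementations below. A minimal sketch (crate path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use dynamo_llm::protocols::common::FinishReason;
///
/// let reason: FinishReason = "length".parse().unwrap();
/// assert_eq!(reason, FinishReason::Length);
/// assert_eq!(reason.to_string(), "length");
///
/// // The Error variant carries its message after the "error: " prefix.
/// let err: FinishReason = "error: out of memory".parse().unwrap();
/// assert_eq!(err, FinishReason::Error("out of memory".to_string()));
/// ```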
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    #[serde(rename = "eos")]
    EoS,

    #[serde(rename = "length")]
    Length,

    #[serde(rename = "stop")]
    Stop,

    #[serde(rename = "error")]
    Error(String),

    #[serde(rename = "cancelled")]
    Cancelled,
}

impl std::fmt::Display for FinishReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FinishReason::EoS => write!(f, "eos"),
            FinishReason::Length => write!(f, "length"),
            FinishReason::Stop => write!(f, "stop"),
            FinishReason::Error(msg) => write!(f, "error: {}", msg),
            FinishReason::Cancelled => write!(f, "cancelled"),
        }
    }
}

impl std::str::FromStr for FinishReason {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "eos" => Ok(FinishReason::EoS),
            "length" => Ok(FinishReason::Length),
            "stop" => Ok(FinishReason::Stop),
            "cancelled" => Ok(FinishReason::Cancelled),
            s if s.starts_with("error: ") => Ok(FinishReason::Error(s[7..].to_string())),
            _ => Err(anyhow::anyhow!("Invalid FinishReason variant: '{}'", s)),
        }
    }
}

/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::TokenIds` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that
/// enables many of the input options that require pre/postprocessing.
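///
/// A minimal construction sketch (crate path assumed, hence `ignore`):
///
/// ```ignore
/// use dynamo_llm::protocols::common::{CompletionContext, PromptType};
///
/// // Advanced path: pre-tokenized input, no server-side preprocessing.
/// let tokens = PromptType::TokenIds(vec![1, 2, 3]);
///
/// // Preferred path for single-turn completions: server-side preprocessing.
/// let completion: PromptType =
///     CompletionContext::from_prompt("Hello, world!".to_string()).into();
/// ```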
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum PromptType {
    /// If allowed, this input type allows the caller to pass a list of token_ids directly to the
    /// inference engine. This is an advanced feature that requires the caller to handle all of the
    /// necessary prompt formatting and tokenization.
    #[serde(rename = "token_ids")]
    TokenIds(Vec<TokenIdType>),

    /// If allowed, the raw text will be tokenized and converted to token_ids without any additional
    /// preprocessing. This is an advanced feature that requires the caller to correctly format the
    /// prompt as defined by the model.
    #[serde(rename = "raw")]
    Raw(String),

    /// If allowed, the `CompletionContext` will be preprocessed server-side. If the `Model` trait's
    /// `requires_prompt_template` returns true, the `CompletionContext` will be used to render the
    /// formatted prompt from the template. `Completion` is the preferred `PromptType` for
    /// single-turn completions.
    #[serde(rename = "completion")]
    Completion(CompletionContext),

    /// If allowed, the `ChatContext` will be preprocessed server-side. Most chat models have a
    /// predefined prompt format/structure. If the `Model` trait's `requires_prompt_template` returns
    /// true, the `ChatContext` will be used to render the formatted prompt from the template.
    /// `ChatCompletion` is the preferred `PromptType` for multi-turn completions.
    #[serde(rename = "chat_completion")]
    ChatCompletion(ChatContext),

    /// If allowed, then `Model::requires_prompt_template()` must also return true. The `serde_json::Value`
    /// will be passed directly to the prompt template. This allows a completely generic data model and
    /// prompt template to be defined and used by the server.
    #[serde(rename = "custom_json")]
    CustomJson(serde_json::Value),
}

/// TensorRT LLM does not perform preprocessing or postprocessing. The input_ids / token_ids
/// are expected to be preprocessed by the client. The client is responsible for constructing
/// the model-specific prompt template and applying the tokenizer.
///
/// TensorRT LLM will perform some server-side postprocessing to ensure that generation is
/// efficiently stopped. See `StopConditions` below.
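///
/// A minimal request built via the derived builder. A sketch (crate path
/// assumed, hence `ignore`):
///
/// ```ignore
/// use dynamo_llm::protocols::common::{
///     CompletionRequest, PromptType, SamplingOptions, StopConditions,
/// };
///
/// let request = CompletionRequest::builder()
///     .prompt(PromptType::TokenIds(vec![1, 2, 3]))
///     .stop_conditions(StopConditions {
///         max_tokens: Some(16),
///         ..Default::default()
///     })
///     .sampling_options(SamplingOptions::default())
///     // `mdc_sum` and `annotations` are `#[builder(default)]` and may be omitted.
///     .build()
///     .unwrap();
/// ```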
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct CompletionRequest {
    /// Type of prompt
    pub prompt: PromptType,

    /// StopConditions are conditions that the inference engine will use to stop generation.
    pub stop_conditions: StopConditions,

    /// SamplingOptions directs the inference engine to use sampling instead of greedy decoding.
    /// More documentation is needed on how, and in what order, sampling options are applied.
    pub sampling_options: SamplingOptions,

    /// The computed checksum of the Model Deployment Card (MDC).
    #[builder(default)]
    pub mdc_sum: Option<String>,

    /// User-requested annotations for the request
    #[builder(default)]
    pub annotations: Option<Vec<String>>,
}

impl CompletionRequest {
    pub fn builder() -> CompletionRequestBuilder {
        CompletionRequestBuilder::default()
    }
}

/// Defines the prompt template and system prompt for a completion request.
/// If the model does not support prompt templates, the system_prompt will be ignored.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct CompletionContext {
    /// Prompt sent by the user
    pub prompt: String,

    /// Optional system_prompt for models that support prompt templates with system_prompts.
    pub system_prompt: Option<String>,
}

/// ChatTurn is a struct that contains the user and assistant messages in a chat.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatTurn {
    /// The user message
    pub user: String,

    /// The assistant response
    pub assistant: String,
}

/// ChatContext is a struct that contains the history of a chat
/// along with a flattened CompletionContext for the current turn.
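///
/// Because `completion` is marked `#[serde(flatten)]`, its fields are inlined
/// at the top level of the JSON form. A sketch of the expected shape
/// (hand-written JSON, hence `ignore`):
///
/// ```ignore
/// let json = r#"{
///     "prompt": "What about Rust?",
///     "system_prompt": null,
///     "context": [
///         { "user": "Name a systems language.", "assistant": "C" }
///     ]
/// }"#;
/// let chat: ChatContext = serde_json::from_str(json).unwrap();
/// assert_eq!(chat.completion.prompt, "What about Rust?");
/// ```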
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatContext {
    /// CompletionContext for this chat turn
    #[serde(flatten)]
    pub completion: CompletionContext,

    /// The history/context of the user and assistant messages in the chat context
    pub context: Vec<ChatTurn>,
}

/// TensorRT LLM server-side stop conditions. These options allow the server to evaluate
/// the generated sequence and stop generation if the sequence meets a stop condition.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StopConditions {
    /// The maximum number of tokens to generate
    pub max_tokens: Option<u32>,

    /// List of strings that stop the generation when they are generated.
    /// The returned output will not contain the stop strings.
    pub stop: Option<Vec<String>>,

    /// List of tokens that stop the generation when they are
    /// generated. The returned output will NOT contain the stop tokens.
    pub stop_token_ids_hidden: Option<Vec<TokenIdType>>,

    /// The minimum number of tokens to generate.
    /// To ignore EOS, set `min_tokens` to `max_tokens`.
    pub min_tokens: Option<u32>,

    /// Whether to ignore the EOS token and continue generating
    /// tokens after the EOS token is generated.
    // TODO(ignore_eos) - improve this by masking the EOS token with logit bias
    pub ignore_eos: Option<bool>,
}

impl StopConditions {
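    /// Force generation to run to `max_tokens` by clearing every other stop
    /// condition. A behavioral sketch (crate path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let mut cond = StopConditions {
    ///     max_tokens: Some(32),
    ///     stop: Some(vec!["\n".to_string()]),
    ///     ignore_eos: Some(true),
    ///     ..Default::default()
    /// };
    /// cond.apply_ignore_eos();
    /// assert_eq!(cond.min_tokens, Some(32));
    /// assert!(cond.stop.is_none());
    /// ```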
    pub fn apply_ignore_eos(&mut self) {
        if self.ignore_eos.unwrap_or(false) {
            self.min_tokens = self.max_tokens;
            self.stop = None;
            self.stop_token_ids_hidden = None;
        }
    }
}

/// Temperature range for sampling.
pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 1.0);

/// Top-p range for sampling.
pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);

/// Frequency penalty range for sampling.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);

/// Collection of options that control the sampling behavior of the inference engine.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions {
    /// Number of output sequences to return for the given prompt
    pub n: Option<i32>,

    /// Number of output sequences that are generated from the prompt.
    /// From these `best_of` sequences, the top `n` sequences are returned.
    /// `best_of` must be greater than or equal to `n`. This is treated as
    /// the beam width when `use_beam_search` is true. By default, `best_of`
    /// is set to `n`.
    pub best_of: Option<i32>,

    /// Float that penalizes new tokens based on whether they
    /// appear in the generated text so far. Values > 0 encourage the model
    /// to use new tokens, while values < 0 encourage the model to repeat
    /// tokens.
    pub presence_penalty: Option<f32>,

    /// Float that penalizes new tokens based on their
    /// frequency in the generated text so far. Values > 0 encourage the
    /// model to use new tokens, while values < 0 encourage the model to
    /// repeat tokens.
    pub frequency_penalty: Option<f32>,

    /// Float that penalizes new tokens based on whether
    /// they appear in the prompt and the generated text so far. Values > 1
    /// encourage the model to use new tokens, while values < 1 encourage
    /// the model to repeat tokens.
    pub repetition_penalty: Option<f32>,

    /// Float that controls the randomness of the sampling. Lower
    /// values make the model more deterministic, while higher values make
    /// the model more random. Zero means greedy sampling.
    pub temperature: Option<f32>,

    /// Float that controls the cumulative probability of the top tokens
    /// to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
    pub top_p: Option<f32>,

    /// Integer that controls the number of top tokens to consider. Set
    /// to -1 to consider all tokens.
    pub top_k: Option<i32>,

    /// Float that represents the minimum probability for a token to be
    /// considered, relative to the probability of the most likely token.
    /// Must be in [0, 1]. Set to 0 to disable this.
    pub min_p: Option<f32>,

    /// Whether to use beam search instead of sampling.
    pub use_beam_search: Option<bool>,

    /// Float that penalizes sequences based on their length.
    /// Used in beam search.
    pub length_penalty: Option<f32>,

    /// The seed to use when sampling
    pub seed: Option<i64>,
}

impl SamplingOptions {
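    /// Clear every sampling knob so the engine falls back to greedy decoding.
    /// A behavioral sketch (crate path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let mut opts = SamplingOptions {
    ///     temperature: Some(0.7),
    ///     top_p: Some(0.9),
    ///     ..Default::default()
    /// };
    /// opts.force_greedy();
    /// assert!(opts.temperature.is_none() && opts.top_p.is_none());
    /// ```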
    pub fn force_greedy(&mut self) {
        self.presence_penalty = None;
        self.frequency_penalty = None;
        self.repetition_penalty = None;
        self.temperature = None;
        self.top_p = None;
        self.top_k = None;
        self.min_p = None;
    }
}

/// Collection of options that control what information the inference engine returns in the response.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct OutputOptions {
    /// Number of log probabilities to return per output token.
    /// Note that the implementation follows the OpenAI API: the returned
    /// result includes the log probabilities on the `logprobs` most likely
    /// tokens, as well as the chosen tokens. The API will always return the
    /// log probability of the sampled token, so there may be up to
    /// `logprobs + 1` elements in the response.
    pub logprobs: Option<u32>,

    /// Number of log probabilities to return per prompt token.
    pub prompt_logprobs: Option<u32>,

    /// Whether to skip special tokens in the output.
    /// (A related option, `spaces_between_special_tokens`, controls whether
    /// spaces are added between special tokens in the output; it defaults to true.)
    pub skip_special_tokens: Option<bool>,

    /// If true, the Context object will contain the prompt that was passed to
    /// the tokenizer. This is useful for inspecting the behavior of prompt
    /// templates that are applied during the backend preprocessing.
    pub formatted_prompt: Option<bool>,
}

/// Struct for log probability information
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionLogprobs {
    /// A list of message content tokens with log probability information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ChatCompletionTokenLogprob>>,

    /// A list of message refusal tokens with log probability information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionTokenLogprob {
    /// The token.
    pub token: String,

    /// The log probability of this token, if it is within the top 20 most likely tokens.
    /// Otherwise, the value `-9999.0` signifies that the token is very unlikely.
    pub logprob: f64,

    /// A list of integers representing the UTF-8 bytes representation of the token.
    /// Useful in instances where characters are represented by multiple tokens and their
    /// byte representations must be combined to generate the correct text representation.
    /// Can be `None` if there is no bytes representation for the token.
    pub bytes: Option<Vec<u8>>,

    /// List of the most likely tokens and their log probability, at this token position.
    /// In rare cases, there may be fewer than the requested number of `top_logprobs` returned.
    pub top_logprobs: Vec<TopLogprob>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TopLogprob {
    /// The token.
    pub token: String,

    /// The log probability of this token.
    pub logprob: f64,

    /// A list of integers representing the UTF-8 bytes representation of the token.
    /// Can be `None` if there is no bytes representation for the token.
    pub bytes: Option<Vec<u8>>,
}

// /// UserData is a struct that contains user-defined data that can be passed to the inference engine.
// /// This information will be used to annotate the distributed traces for improved observability.
// #[derive(Serialize, Deserialize, Debug, Clone, Default)]
// pub struct UserData {
//     /// Apply server-side prompt template to the request
//     pub request_uuid: Option<uuid::Uuid>,
// }

/// StreamingResponse is the primary response object for the LLM Engine. The response stream
/// can emit three different types of messages. The Initialize and Finalize messages are optional
/// and primarily used over disaggregated transports to move state from the server to the client.
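///
/// A sketch of how a client might drain such a stream (the stream type and
/// combinators are assumed, hence `ignore`):
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(msg) = stream.next().await {
///     match msg {
///         StreamingResponse::Initialize(_prologue) => { /* capture engine context */ }
///         StreamingResponse::Step(step) => {
///             print!("{}", step.delta.text.clone().unwrap_or_default());
///         }
///         StreamingResponse::Finalize(_epilogue) => break,
///     }
/// }
/// ```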
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingResponse {
    /// Initialize transports a Prologue object which communicates the LLM Engine context
    Initialize(Option<Prologue>),

    /// Step is the primary data in the response stream. It contains the StreamingCompletionResponse
    Step(Box<StreamingCompletionResponse>),

    /// Finalize is an optional final message in the response stream. It contains the Epilogue object,
    /// which is used to communicate extra information about the completion and the engine statistics.
    Finalize(Option<Epilogue>),
}

// TODO(ryan) - this should be part of the internal API as it is not deserializable.
//              The public API should drop the Option<Arc<Stats>> in favor of Option<Stats>;
//              the two variants both serialize to the same JSON; however, the internal version
//              cannot be deserialized directly.
//              We use the internal one on the server side to avoid the cost of cloning the Stats
//              object; client side, we should always fully materialize the Stats object.
//
// TODO(ryan) - update this object to use an enum where we have the current definition be the
//              StepResponse arm; then we will add the following arms:
//              - Initialize(Prologue)
//              - Step()
//              - Finalize(Epilogue)

/// This is the first message that will be emitted by an Engine Response Stream.
/// It indicates that the request has been preprocessed and queued for execution on the backend.
#[derive(Serialize, Deserialize, Debug)]
pub struct Prologue {
    /// If the request was preprocessed with a prompt template, this will contain the formatted prompt
    pub formatted_prompt: Option<String>,

    /// If the request did not contain TokenIds, this will contain the token_ids that were generated
    /// from tokenizing the prompt.
    pub input_token_ids: Option<Vec<TokenIdType>>,
}

/// This is the final message that will be emitted by an Engine Response Stream when it
/// finishes without error. In some cases, the engine may emit an error, which indicates
/// the end of the stream. A Finalize(Epilogue) will also not be emitted if the response
/// handler has stalled and too many responses have accumulated.
#[derive(Serialize, Deserialize, Debug)]
pub struct Epilogue {}

#[derive(Debug)]
pub struct StreamingCompletionResponse {
    pub delta: Delta,
    pub logprobs: Option<ChatCompletionLogprobs>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamState {
    Active,
    Finished(FinishReason),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Logits {
    All(Vec<f32>),
    Sparse(Vec<(u32, f32)>),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum LogProbs {
    Normalized(Logits),
    Raw(Logits),
}

/// At each SequencePosition we hold position-specific data
pub struct SequencePositionData {
    pub token_id: TokenIdType,

    /// The log probability of the token
    pub logprobs: Option<LogProbs>,
}

// todo(ryan) - we need to create a DeltaBuilder, a mutable object that can be passed
// around from the low-level compute engine to the high-level API. The DeltaBuilder will allow
// us to construct the Delta object at multiple layers in the streaming response path.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Delta {
    pub is_complete: bool,

    pub finish_reason: Option<FinishReason>,

    /// New token_ids generated since the last response
    pub token_ids: Option<Vec<u32>>,

    /// Tokens corresponding to `token_ids`
    pub tokens: Option<Vec<String>>,

    /// Decoded text
    pub text: Option<String>,

    /// Current sequence length; when streaming, we expect
    /// this to increase by 1 on each response
    pub sequence_length: Option<usize>,

    /// If the number of slots for a given request is greater than 1,
    /// this indicates the index of the slot for the response
    pub index: Option<usize>,

    /// Cumulative log probabilities
    pub cum_log_probs: Option<f64>,

    /// Error message from the engine;
    /// if this is set, is_complete should also be true
    pub err_msg: Option<String>,

    /// Usage info
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Usage {
    pub input_tokens_count: usize,
    pub output_tokens_count: usize,
}

// todo(ryan) - we need to update this object to make it more generic.
// We need to define a set of generic stats traits that allow those stats to be None,
// then back them with a concrete implementation like a TrtllmStats object.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Stats {
    /// Time since the last Epoch/Forward Pass in microseconds (us).
    /// This is measured and recorded by the Response Router rather than the
    /// Inference Engine. Note, when evaluating the responses, if this
    /// value is greater than the stream's measured value, then there was a gap
    /// between forward passes. In normal operation, the value of this field should
    /// be less than the recorded value on the response stream.
    pub time_since_last_forward_pass_us: Option<u64>,

    pub request_active_count: u32,

    pub request_context_count: u32,

    pub request_generation_count: u32,

    pub request_scheduled_count: u32,

    pub request_max_count: u32,

    pub kv_free_cache_blocks: u64,

    pub kv_max_cache_blocks: u64,

    pub kv_used_cache_blocks: u64,

    pub kv_tokens_per_cache_block: u64,

    pub runtime_cpu_memory_usage: u64,

    pub runtime_gpu_memory_usage: u64,

    pub runtime_pinned_memory_usage: u64,

    pub iteration_counter: u64,

    pub microbatch_id: u64,

    pub total_context_tokens: u32,

    pub timestamp: String,
}

impl Serialize for StreamingCompletionResponse {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `logprobs` is intentionally not serialized, so the struct has a
        // single field on the wire; see the tests at the bottom of this file.
        let mut state = serializer.serialize_struct("StreamingCompletionResponse", 1)?;

        // Serialize `delta` field
        state.serialize_field("delta", &self.delta)?;

        state.end()
    }
}

impl<'de> Deserialize<'de> for StreamingCompletionResponse {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Create a temporary struct for deserialization
        #[derive(Deserialize)]
        struct TempResponse {
            delta: Delta,
            logprobs: Option<ChatCompletionLogprobs>,
        }

        let TempResponse { delta, logprobs } = TempResponse::deserialize(deserializer)?;

        Ok(StreamingCompletionResponse { delta, logprobs })
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ScatterData<T> {
    pub x: Vec<T>,
    pub y: Vec<T>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Trace {
    pub time_to_first_token: u64,
    pub token_to_token: Vec<u64>,
    pub start: SystemTime,
    pub complete: SystemTime,
    pub initial_tokens: u32,
    pub max_tokens: u32,
    pub t2ft_iteration_count: u64,
    pub t2t_iteration_count: Vec<u64>,
}

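/// Linear performance model fit from calibration traces: the predicted
/// time-to-first-token is `t2ft_intercept + t2ft_slope * initial_tokens`, and
/// analogously for token-to-token latency (t2tl). A sketch of evaluating the
/// fit (field semantics assumed from the inline comments below, hence
/// `ignore`):
///
/// ```ignore
/// fn predict_t2ft(model: &PerformanceModel, initial_tokens: u32) -> f64 {
///     model.t2ft_intercept + model.t2ft_slope * initial_tokens as f64
/// }
/// ```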
#[derive(Serialize, Deserialize, Debug)]
pub struct PerformanceModel {
    // linear regression parameters fitting t2ft vs. initial tokens
    pub t2ft_intercept: f64,
    pub t2ft_slope: f64,

    // linear regression parameters fitting t2tl vs. initial tokens
    pub t2tl_intercept: f64,
    pub t2tl_slope: f64,

    // r2 values from the regression
    pub t2ft_fit_r2: f64,
    pub t2tl_fit_r2: f64,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CalibrationResults {
    pub effective_flops: f64,
    pub effective_memory_bandwidth: f64,
    pub max_q: u32,
    pub performance_model: PerformanceModel,
    pub traces: Vec<Trace>,
    pub t2ft_scatter_data: ScatterData<f64>,
    pub t2tl_scatter_data: ScatterData<f64>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct LoadgenResults {
    pub stats_by_iteration: HashMap<u64, Stats>,
    pub traces: Vec<Trace>,
}

impl CompletionContext {
    /// Create a new CompletionContext
    pub fn new(prompt: String, system_prompt: Option<String>) -> Self {
        Self {
            prompt,
            system_prompt,
        }
    }

    /// Create a new CompletionContext with only a prompt
    pub fn from_prompt(prompt: String) -> Self {
        Self {
            prompt,
            system_prompt: None,
        }
    }

    /// Create a new CompletionContext with a prompt and system prompt
    pub fn with_system_prompt(prompt: String, system_prompt: String) -> Self {
        Self {
            prompt,
            system_prompt: Some(system_prompt),
        }
    }
}

// todo(ryan) - create a builder for chat context
impl From<CompletionContext> for PromptType {
    fn from(context: CompletionContext) -> Self {
        PromptType::Completion(context)
    }
}

#[cfg(test)]
mod tests {
    use serde_json;

    use super::*;

    #[test]
    fn test_completion_context_new() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = Some("This is a system prompt.".to_string());
        let context = CompletionContext::new(prompt.clone(), system_prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, system_prompt);
    }

    #[test]
    fn test_completion_context_from_prompt() {
        let prompt = "Hello, world!".to_string();
        let context = CompletionContext::from_prompt(prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, None);
    }

    #[test]
    fn test_completion_context_with_system_prompt() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, Some(system_prompt));
    }

    #[test]
    fn test_completion_context_into_prompt_type() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
        let prompt_type: PromptType = context.into();

        if let PromptType::Completion(completion_context) = prompt_type {
            assert_eq!(completion_context.prompt, prompt);
            assert_eq!(completion_context.system_prompt, Some(system_prompt));
        } else {
            panic!("Expected a Completion variant");
        }
    }

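    // A sketch of additional coverage: the string round-trip of `FinishReason`
    // and the clearing behavior of `StopConditions::apply_ignore_eos`.
    #[test]
    fn test_finish_reason_round_trip() {
        for reason in [
            FinishReason::EoS,
            FinishReason::Length,
            FinishReason::Stop,
            FinishReason::Cancelled,
            FinishReason::Error("oops".to_string()),
        ] {
            let parsed: FinishReason = reason.to_string().parse().unwrap();
            assert_eq!(parsed, reason);
        }
    }

    #[test]
    fn test_apply_ignore_eos_clears_stops() {
        let mut cond = StopConditions {
            max_tokens: Some(32),
            stop: Some(vec!["\n".to_string()]),
            stop_token_ids_hidden: Some(vec![2]),
            ignore_eos: Some(true),
            ..Default::default()
        };
        cond.apply_ignore_eos();
        assert_eq!(cond.min_tokens, Some(32));
        assert!(cond.stop.is_none());
        assert!(cond.stop_token_ids_hidden.is_none());
    }
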
    // #[test]
    // fn test_serialize_with_stats() {
    //     let response = StreamingCompletionResponse {
    //         delta: Delta {
    //             is_complete: true,
    //             finish_reason: Some(FinishReason::Length),
    //             token_ids: Some(vec![101, 102, 103]),
    //             tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
    //             text: Some("example text".to_string()),
    //             sequence_length: Some(3),
    //             index: Some(0),
    //             cum_log_probs: Some(-0.5),
    //             err_msg: None,
    //             usage: None,
    //         },
    //         logprobs: None,
    //     };

    //     // Serialize the response
    //     let serialized = serde_json::to_string(&response).expect("Failed to serialize");

    //     // Expected JSON string (simplified)
    //     let expected = r#"{
    //         "delta": {
    //             "is_complete": true,
    //             "finish_reason": "length",
    //             "token_ids": [101, 102, 103],
    //             "tokens": ["token1", "token2"],
    //             "text": "example text",
    //             "sequence_length": 3,
    //             "index": 0,
    //             "cum_log_probs": -0.5,
    //             "err_msg": null,
    //             "usage": null
    //         },
    //         "stats": {
    //             "time_since_last_forward_pass_us": 1000,
    //             "request_active_count": 2,
    //             "request_context_count": 1,
    //             "request_generation_count": 3,
    //             "request_scheduled_count": 1,
    //             "request_max_count": 10,
    //             "kv_free_cache_blocks": 500,
    //             "kv_max_cache_blocks": 1000,
    //             "kv_used_cache_blocks": 500,
    //             "kv_tokens_per_cache_block": 10,
    //             "runtime_cpu_memory_usage": 5000,
    //             "runtime_gpu_memory_usage": 2000,
    //             "runtime_pinned_memory_usage": 1000,
    //             "iteration_counter": 5,
    //             "microbatch_id": 12345,
    //             "total_context_tokens": 256,
    //             "timestamp": "2024-01-01T00:00:00Z"
    //         }
    //     }"#;

    //     assert_eq!(
    //         serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
    //         serde_json::from_str::<serde_json::Value>(expected).unwrap()
    //     );
    // }

    #[test]
    fn test_serialize_without_stats() {
        let response = StreamingCompletionResponse {
            delta: Delta {
                is_complete: false,
                finish_reason: None,
                token_ids: None,
                tokens: None,
                text: None,
                sequence_length: None,
                index: None,
                cum_log_probs: None,
                err_msg: None,
                usage: None,
            },
            logprobs: None,
        };

        // Serialize the response
        let serialized = serde_json::to_string(&response).expect("Failed to serialize");

        // Expected JSON string
        let expected = r#"{
            "delta": {
                "is_complete": false,
                "finish_reason": null,
                "token_ids": null,
                "tokens": null,
                "text": null,
                "sequence_length": null,
                "index": null,
                "cum_log_probs": null,
                "err_msg": null,
                "usage": null
            }
        }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    // #[test]
    // fn test_deserialize_with_stats() {
    //     let json_data = r#"{
    //         "delta": {
    //             "is_complete": true,
    //             "finish_reason": "length",
    //             "token_ids": [101, 102, 103],
    //             "tokens": ["token1", "token2"],
    //             "text": "example text",
    //             "sequence_length": 3,
    //             "index": 0,
    //             "cum_log_probs": -0.5,
    //             "err_msg": null,
    //             "usage": null
    //         },
    //         "stats": {
    //             "time_since_last_forward_pass_us": 1000,
    //             "request_active_count": 2,
    //             "request_context_count": 1,
    //             "request_generation_count": 3,
    //             "request_scheduled_count": 1,
    //             "request_max_count": 10,
    //             "kv_free_cache_blocks": 500,
    //             "kv_max_cache_blocks": 1000,
    //             "kv_used_cache_blocks": 500,
    //             "kv_tokens_per_cache_block": 10,
    //             "runtime_cpu_memory_usage": 5000,
    //             "runtime_gpu_memory_usage": 2000,
    //             "runtime_pinned_memory_usage": 1000,
    //             "iteration_counter": 5,
    //             "microbatch_id": 12345,
    //             "total_context_tokens": 256,
    //             "timestamp": "2024-01-01T00:00:00Z"
    //         }
    //     }"#;

    //     // Deserialize the JSON string
    //     let deserialized: StreamingCompletionResponse =
    //         serde_json::from_str(json_data).expect("Failed to deserialize");

    //     // Expected response object
    //     let expected = StreamingCompletionResponse {
    //         delta: Delta {
    //             is_complete: true,
    //             finish_reason: Some(FinishReason::Length),
    //             token_ids: Some(vec![101, 102, 103]),
    //             tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
    //             text: Some("example text".to_string()),
    //             sequence_length: Some(3),
    //             index: Some(0),
    //             cum_log_probs: Some(-0.5),
    //             err_msg: None,
    //             usage: None,
    //         },
    //         logprobs: None,
    //     };

    //     // This is unwieldy, but we can no longer do assert_eq!(deserialized, expected)
    //     // because the struct no longer has the PartialEq trait.
    //     assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
    //     assert_eq!(
    //         deserialized.delta.finish_reason,
    //         expected.delta.finish_reason
    //     );
    //     assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
    //     assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
    //     assert_eq!(deserialized.delta.text, expected.delta.text);
    //     assert_eq!(
    //         deserialized.delta.sequence_length,
    //         expected.delta.sequence_length
    //     );
    //     assert_eq!(deserialized.delta.index, expected.delta.index);
    //     assert_eq!(
    //         deserialized.delta.cum_log_probs,
    //         expected.delta.cum_log_probs
    //     );
    //     assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
    //     assert_eq!(deserialized.delta.usage, expected.delta.usage);

    //     assert_eq!(
    //         deserialized_stats.time_since_last_forward_pass_us,
    //         expected_stats.time_since_last_forward_pass_us
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_active_count,
    //         expected_stats.request_active_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_context_count,
    //         expected_stats.request_context_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_generation_count,
    //         expected_stats.request_generation_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_scheduled_count,
    //         expected_stats.request_scheduled_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_max_count,
    //         expected_stats.request_max_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_free_cache_blocks,
    //         expected_stats.kv_free_cache_blocks
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_max_cache_blocks,
    //         expected_stats.kv_max_cache_blocks
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_used_cache_blocks,
    //         expected_stats.kv_used_cache_blocks
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_tokens_per_cache_block,
    //         expected_stats.kv_tokens_per_cache_block
    //     );
    //     assert_eq!(
    //         deserialized_stats.runtime_cpu_memory_usage,
    //         expected_stats.runtime_cpu_memory_usage
    //     );
    //     assert_eq!(
    //         deserialized_stats.runtime_gpu_memory_usage,
    //         expected_stats.runtime_gpu_memory_usage
    //     );
    //     assert_eq!(
    //         deserialized_stats.runtime_pinned_memory_usage,
    //         expected_stats.runtime_pinned_memory_usage
    //     );
    //     assert_eq!(
    //         deserialized_stats.iteration_counter,
    //         expected_stats.iteration_counter
    //     );
    //     assert_eq!(
    //         deserialized_stats.microbatch_id,
    //         expected_stats.microbatch_id
    //     );
    //     assert_eq!(
    //         deserialized_stats.total_context_tokens,
    //         expected_stats.total_context_tokens
    //     );
    //     assert_eq!(deserialized_stats.timestamp, expected_stats.timestamp);
    // }

    #[test]
    fn test_deserialize_without_stats() {
        let json_data = r#"{
            "delta": {
                "is_complete": false,
                "finish_reason": null,
                "token_ids": null,
                "tokens": null,
                "text": null,
                "sequence_length": null,
                "index": null,
                "cum_log_probs": null,
                "err_msg": null,
                "usage": null
            }
        }"#;

        // Deserialize the JSON string
        let deserialized: StreamingCompletionResponse =
            serde_json::from_str(json_data).expect("Failed to deserialize");

        // Expected response object
        let expected = StreamingCompletionResponse {
            delta: Delta {
                is_complete: false,
                finish_reason: None,
                token_ids: None,
                tokens: None,
                text: None,
                sequence_length: None,
                index: None,
                cum_log_probs: None,
                err_msg: None,
                usage: None,
            },
            logprobs: None,
        };

        // This is unwieldy, but we can no longer do assert_eq!(deserialized, expected)
        // because the struct no longer has the PartialEq trait.
        assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
        assert_eq!(
            deserialized.delta.finish_reason,
            expected.delta.finish_reason
        );
        assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
        assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
        assert_eq!(deserialized.delta.text, expected.delta.text);
        assert_eq!(
            deserialized.delta.sequence_length,
            expected.delta.sequence_length
        );
        assert_eq!(deserialized.delta.index, expected.delta.index);
        assert_eq!(
            deserialized.delta.cum_log_probs,
            expected.delta.cum_log_probs
        );
        assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
        assert_eq!(deserialized.delta.usage, expected.delta.usage);
    }

    #[test]
    fn test_serialize_delta_and_none_stats() {
        let response = StreamingCompletionResponse {
            delta: Delta {
                is_complete: true,
                finish_reason: Some(FinishReason::Length),
                token_ids: Some(vec![101, 102, 103]),
                tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
                text: Some("example text".to_string()),
                sequence_length: Some(3),
                index: Some(0),
                cum_log_probs: Some(-0.5),
                err_msg: None,
                usage: None,
            },
            logprobs: None,
        };

        // Serialize the response
        let serialized = serde_json::to_string(&response).expect("Failed to serialize");

        // Expected JSON string where stats is null
        let expected_json = r#"{
            "delta": {
                "is_complete": true,
                "finish_reason": "length",
                "token_ids": [101, 102, 103],
                "tokens": ["token1", "token2"],
                "text": "example text",
                "sequence_length": 3,
                "index": 0,
                "cum_log_probs": -0.5,
                "err_msg": null,
                "usage": null
            }
        }"#;

        // Parse both the serialized response and the expected JSON as serde_json::Value for easy comparison
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected_json).unwrap()
        );
    }
}