dynamo_llm/protocols/common.rs
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Engine Protocols
//! ================
//!
//! This module contains the public API protocols for the LLM Engine and AsyncEngine facades.
//!
//! The core components are the `CompletionRequest` and `StreamingCompletionResponse` objects.
//!
//! The `StreamingCompletionResponse` objects are the outputs of the LLM Engine; however, we
//! need some additional information to propagate intermediate results for improved observability.
//! That metadata is transferred via the other arms of the `StreamingResponse` enum.
//!

use anyhow::Result;
use derive_builder::Builder;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::time::SystemTime;

use super::TokenIdType;

pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;

/// SamplingOptionsProvider is a trait that allows the caller to extract the sampling options from
/// the object that implements it. This will mutate the object.
pub trait SamplingOptionsProvider {
    fn extract_sampling_options(&self) -> Result<SamplingOptions>;
}

pub trait StopConditionsProvider {
    fn extract_stop_conditions(&self) -> Result<StopConditions>;
}

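/// The reason a generation stream finished.
///
/// A minimal round-trip sketch using the `Display`/`FromStr` impls below (the import path
/// assumes this module is reachable as `dynamo_llm::protocols::common`; adjust to your crate
/// layout):
///
/// ```ignore
/// use std::str::FromStr;
/// use dynamo_llm::protocols::common::FinishReason;
///
/// // Unit variants render as their serde names and parse back losslessly.
/// assert_eq!(FinishReason::Length.to_string(), "length");
/// assert_eq!(FinishReason::from_str("length").unwrap(), FinishReason::Length);
///
/// // The error variant carries its message through the round trip.
/// let err = FinishReason::from_str("error: engine crashed").unwrap();
/// assert_eq!(err, FinishReason::Error("engine crashed".to_string()));
/// ```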
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    #[serde(rename = "eos")]
    EoS,

    #[serde(rename = "length")]
    Length,

    #[serde(rename = "stop")]
    Stop,

    #[serde(rename = "error")]
    Error(String),

    #[serde(rename = "cancelled")]
    Cancelled,
}

impl std::fmt::Display for FinishReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FinishReason::EoS => write!(f, "eos"),
            FinishReason::Length => write!(f, "length"),
            FinishReason::Stop => write!(f, "stop"),
            FinishReason::Error(msg) => write!(f, "error: {}", msg),
            FinishReason::Cancelled => write!(f, "cancelled"),
        }
    }
}

impl std::str::FromStr for FinishReason {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "eos" => Ok(FinishReason::EoS),
            "length" => Ok(FinishReason::Length),
            "stop" => Ok(FinishReason::Stop),
            "cancelled" => Ok(FinishReason::Cancelled),
            s if s.starts_with("error: ") => Ok(FinishReason::Error(s[7..].to_string())),
            _ => Err(anyhow::anyhow!("Invalid FinishReason variant: '{}'", s)),
        }
    }
}

/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::TokenIds` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that
/// enables many of the input options that require pre/postprocessing.
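///
/// With the serde renames below, the variants are externally tagged in JSON. A sketch of the
/// wire format (the field values here are illustrative only):
///
/// ```json
/// {"token_ids": [101, 102, 103]}
/// {"raw": "<s>[INST] Hello [/INST]"}
/// {"completion": {"prompt": "Hello", "system_prompt": null}}
/// ```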
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum PromptType {
    /// If allowed, this input type allows the caller to pass a list of token_ids directly to the
    /// inference engine. This is an advanced feature that requires the caller to handle all of the
    /// necessary prompt formatting and tokenization.
    #[serde(rename = "token_ids")]
    TokenIds(Vec<TokenIdType>),

    /// If allowed, the raw text will be tokenized and converted to token_ids without any additional
    /// preprocessing. This is an advanced feature that requires the caller to correctly format the
    /// prompt as defined by the model.
    #[serde(rename = "raw")]
    Raw(String),

    /// If allowed, the `CompletionContext` will be preprocessed server-side. If the `Model` trait
    /// `requires_prompt_template` returns true, then the `CompletionContext` will be used to
    /// render the formatted prompt from the template. `Completion` is the preferred `PromptType`
    /// for single-turn completions.
    #[serde(rename = "completion")]
    Completion(CompletionContext),

    /// If allowed, the `ChatContext` will be preprocessed server-side. Most chat models will have
    /// a predefined prompt format/structure. If the `Model` trait `requires_prompt_template` returns
    /// true, then the `ChatContext` will be used to render the formatted prompt from the template.
    /// `ChatCompletion` is the preferred `PromptType` for multi-turn completions.
    #[serde(rename = "chat_completion")]
    ChatCompletion(ChatContext),

    /// If allowed, then `Model::requires_prompt_template()` must also return true. The `serde_json::Value`
    /// will be passed directly to the prompt template. This allows a completely generic data model and
    /// prompt template to be defined and used by the server.
    #[serde(rename = "custom_json")]
    CustomJson(serde_json::Value),
}

/// TensorRT LLM does not perform preprocessing or postprocessing. The input_ids / token_ids
/// are expected to be preprocessed by the client. The client is responsible for constructing
/// the model specific prompt template and applying the tokenizer.
///
/// TensorRT LLM will perform some server side postprocessing to ensure that generation is
/// efficiently stopped. See `StopConditions` below.
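///
/// A minimal construction sketch using the derive_builder-generated builder (the prompt text
/// and the default stop/sampling options here are purely illustrative):
///
/// ```ignore
/// let request = CompletionRequest::builder()
///     .prompt(PromptType::Completion(CompletionContext::from_prompt(
///         "Hello, world!".to_string(),
///     )))
///     .stop_conditions(StopConditions::default())
///     .sampling_options(SamplingOptions::default())
///     .build()?;
/// ```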
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct CompletionRequest {
    /// Type of prompt
    pub prompt: PromptType,

    /// StopConditions are conditions that the inference engine will use to stop generation.
    pub stop_conditions: StopConditions,

    /// SamplingOptions directs the inference engine to use sampling instead of greedy decoding.
    /// More documentation on how sampling options are applied, and in what order, is needed.
    pub sampling_options: SamplingOptions,

    /// The computed checksum of the Model Deployment Card (MDC).
    #[builder(default)]
    pub mdc_sum: Option<String>,

    /// User requested annotations for the request
    #[builder(default)]
    pub annotations: Option<Vec<String>>,
}

impl CompletionRequest {
    pub fn builder() -> CompletionRequestBuilder {
        CompletionRequestBuilder::default()
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
/// Defines the prompt template and system prompt for a completion request.
/// If the model does not support prompt templates, the system_prompt will be ignored.
pub struct CompletionContext {
    /// Prompt sent by the user
    pub prompt: String,

    /// Optional system_prompt for models that support prompt templates with system_prompts.
    pub system_prompt: Option<String>,
}

/// ChatTurn is a struct that contains the user and assistant messages in a chat.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatTurn {
    /// The user message
    pub user: String,

    /// The assistant response
    pub assistant: String,
}

/// ChatContext is a struct that contains the role and context of a chat message
/// along with a flattened CompletionContext.
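///
/// Because `completion` is flattened, a serialized `ChatContext` is a single JSON object; a
/// sketch of the shape (field values are illustrative only):
///
/// ```json
/// {
///   "prompt": "What about Rust?",
///   "system_prompt": null,
///   "context": [{"user": "Hi", "assistant": "Hello! How can I help?"}]
/// }
/// ```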
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatContext {
    /// CompletionContext for this chat turn
    #[serde(flatten)]
    pub completion: CompletionContext,

    /// The history/context of the user and assistant messages in the chat context
    pub context: Vec<ChatTurn>,
}

/// TensorRT LLM server-side stop conditions. These options allow for the server to evaluate
/// the generated sequence and stop generation if the sequence meets a stop condition.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StopConditions {
    /// The maximum number of tokens to generate
    pub max_tokens: Option<u32>,

    /// List of strings that stop the generation when they are generated.
    /// The returned output will not contain the stop strings.
    pub stop: Option<Vec<String>>,

    /// List of tokens that stop the generation when they are
    /// generated. The returned output will NOT contain the stop tokens.
    pub stop_token_ids_hidden: Option<Vec<TokenIdType>>,

    /// The minimum number of tokens to generate.
    /// To ignore_eos, set min_tokens to max_tokens.
    pub min_tokens: Option<u32>,

    /// Whether to ignore the EOS token and continue generating
    /// tokens after the EOS token is generated.
    // TODO(ignore_eos) - improve this by masking the EOS token with logit bias
    pub ignore_eos: Option<bool>,
}

impl StopConditions {
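    /// Rewrites the stop conditions so that EOS is effectively ignored: `min_tokens` is pinned
    /// to `max_tokens` and the stop strings/tokens are cleared.
    ///
    /// A small sketch of the effect (field values are illustrative only):
    ///
    /// ```ignore
    /// let mut cond = StopConditions {
    ///     max_tokens: Some(128),
    ///     stop: Some(vec!["###".to_string()]),
    ///     ignore_eos: Some(true),
    ///     ..Default::default()
    /// };
    /// cond.apply_ignore_eos();
    /// assert_eq!(cond.min_tokens, Some(128));
    /// assert_eq!(cond.stop, None);
    /// ```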
    pub fn apply_ignore_eos(&mut self) {
        if self.ignore_eos.unwrap_or(false) {
            self.min_tokens = self.max_tokens;
            self.stop = None;
            self.stop_token_ids_hidden = None;
        }
    }
}

/// Temperature range for sampling.
pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 1.0);

/// Top P range for sampling.
pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);

/// Frequency Penalty range for sampling.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);

/// Collection of options that control the sampling behavior of the inference engine.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions {
    /// Number of output sequences to return for the given prompt
    pub n: Option<i32>,

    /// Number of output sequences that are generated from the prompt.
    /// From these `best_of` sequences, the top `n` sequences are returned.
    /// `best_of` must be greater than or equal to `n`. This is treated as
    /// the beam width when `use_beam_search` is True. By default, `best_of`
    /// is set to `n`.
    pub best_of: Option<i32>,

    /// Float that penalizes new tokens based on whether they
    /// appear in the generated text so far. Values > 0 encourage the model
    /// to use new tokens, while values < 0 encourage the model to repeat
    /// tokens.
    pub presence_penalty: Option<f32>,

    /// Float that penalizes new tokens based on their
    /// frequency in the generated text so far. Values > 0 encourage the
    /// model to use new tokens, while values < 0 encourage the model to
    /// repeat tokens.
    pub frequency_penalty: Option<f32>,

    /// Float that penalizes new tokens based on whether
    /// they appear in the prompt and the generated text so far. Values > 1
    /// encourage the model to use new tokens, while values < 1 encourage
    /// the model to repeat tokens.
    pub repetition_penalty: Option<f32>,

    /// Float that controls the randomness of the sampling. Lower
    /// values make the model more deterministic, while higher values make
    /// the model more random. Zero means greedy sampling.
    pub temperature: Option<f32>,

    /// Float that controls the cumulative probability of the top tokens
    /// to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
    pub top_p: Option<f32>,

    /// Integer that controls the number of top tokens to consider. Set
    /// to -1 to consider all tokens.
    pub top_k: Option<i32>,

    /// Float that represents the minimum probability for a token to be
    /// considered, relative to the probability of the most likely token.
    /// Must be in [0, 1]. Set to 0 to disable this.
    pub min_p: Option<f32>,

    /// Whether to use beam search instead of sampling.
    pub use_beam_search: Option<bool>,

    /// Float that penalizes sequences based on their length.
    /// Used in beam search.
    pub length_penalty: Option<f32>,

    /// The seed to use when sampling
    pub seed: Option<i64>,
}

impl SamplingOptions {
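    /// Clears the penalty, temperature, top-p/top-k, and min-p fields so that the engine falls
    /// back to greedy decoding.
    ///
    /// A small sketch of the effect (field values are illustrative only):
    ///
    /// ```ignore
    /// let mut opts = SamplingOptions {
    ///     temperature: Some(0.7),
    ///     top_p: Some(0.9),
    ///     ..Default::default()
    /// };
    /// opts.force_greedy();
    /// assert_eq!(opts.temperature, None);
    /// assert_eq!(opts.top_p, None);
    /// ```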
    pub fn force_greedy(&mut self) {
        self.presence_penalty = None;
        self.frequency_penalty = None;
        self.repetition_penalty = None;
        self.temperature = None;
        self.top_p = None;
        self.top_k = None;
        self.min_p = None;
    }
}

/// Collection of options that control what information the inference engine returns in the response.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct OutputOptions {
    /// Number of log probabilities to return per output token.
    /// Note that the implementation follows the OpenAI API: the returned
    /// result includes the log probabilities of the `logprobs` most likely
    /// tokens, as well as the chosen tokens. The API will always return the
    /// log probability of the sampled token, so there may be up to
    /// `logprobs + 1` elements in the response.
    pub logprobs: Option<u32>,

    /// Number of log probabilities to return per prompt token.
    pub prompt_logprobs: Option<u32>,

    /// Whether to skip special tokens in the output.
    /// spaces_between_special_tokens: Whether to add spaces between special
    /// tokens in the output. Defaults to True.
    pub skip_special_tokens: Option<bool>,

    /// If true, the Context object will contain the prompt that was passed to
    /// the tokenizer. This is useful for inspecting the behavior of prompt
    /// templates that are applied during the backend preprocessing.
    pub formatted_prompt: Option<bool>,
}

// Struct for log probability information
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionLogprobs {
    /// A list of message content tokens with log probability information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ChatCompletionTokenLogprob>>,

    /// A list of message refusal tokens with log probability information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionTokenLogprob {
    /// The token.
    pub token: String,

    /// The log probability of this token, if it is within the top 20 most likely tokens.
    /// Otherwise, the value `-9999.0` signifies that the token is very unlikely.
    pub logprob: f64,

    /// A list of integers representing the UTF-8 bytes representation of the token.
    /// Useful in instances where characters are represented by multiple tokens and their
    /// byte representations must be combined to generate the correct text representation.
    /// Can be `None` if there is no bytes representation for the token.
    pub bytes: Option<Vec<u8>>,

    /// List of the most likely tokens and their log probability, at this token position.
    /// In rare cases, there may be fewer than the requested number of `top_logprobs` returned.
    pub top_logprobs: Vec<TopLogprob>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TopLogprob {
    /// The token.
    pub token: String,

    /// The log probability of this token.
    pub logprob: f64,

    /// A list of integers representing the UTF-8 bytes representation of the token.
    /// Can be `None` if there is no bytes representation for the token.
    pub bytes: Option<Vec<u8>>,
}

// /// UserData is a struct that contains user-defined data that can be passed to the inference engine.
// /// This information will be used to annotate the distributed traces for improved observability.
// #[derive(Serialize, Deserialize, Debug, Clone, Default)]
// pub struct UserData {
//     /// Apply server-side prompt template to the request
//     pub request_uuid: Option<uuid::Uuid>,
// }

/// StreamingResponse is the primary response object for the LLM Engine. The response stream
/// can emit three different types of messages. The Initialize and Finalize messages are optional
/// and primarily used over disaggregated transports to move state from the server to the client.
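///
/// A sketch of how a client might consume the stream (the `responses` iterator is hypothetical):
///
/// ```ignore
/// for message in responses {
///     match message {
///         StreamingResponse::Initialize(prologue) => { /* capture the engine context */ }
///         StreamingResponse::Step(step) => { /* handle the StreamingCompletionResponse */ }
///         StreamingResponse::Finalize(epilogue) => { /* record completion metadata */ }
///     }
/// }
/// ```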
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingResponse {
    /// Initialize transports a Prologue object which communicates the LLM Engine Context
    Initialize(Option<Prologue>),

    /// Step is the primary data in the response stream. It contains the StreamingCompletionResponse
    Step(Box<StreamingCompletionResponse>),

    /// Finalize is an optional final message in the response stream. It contains the Epilogue object which
    /// is used to communicate extra information about the completion and the engine statistics.
    Finalize(Option<Epilogue>),
}

// TODO(ryan) - this should be part of the internal api as it is not deserializable
// the public API should drop the Option<Arc<Stats>> in favor of Option<Stats>
// the two variants both serialize to the same json; however, the internal version
// can not be deserialized directly.
// we use the internal one on the server side to avoid the cost of cloning the Stats
// object; however, client side, we should always fully materialize the Stats object.
//
// TODO(ryan) - update this object to use an enum where we have the current definition be the
// StepResponse arm; then we will add the following arms:
// - Initialize(Prologue)
// - Step()
// - Finalize(Epilogue)

/// This is the first message that will be emitted by an Engine Response Stream.
/// It indicates that the request has been preprocessed and queued for execution on the backend.
#[derive(Serialize, Deserialize, Debug)]
pub struct Prologue {
    /// If the request was preprocessed with a prompt template, this will contain the formatted prompt
    pub formatted_prompt: Option<String>,

    /// If the request did not contain TokenIds, this will contain the token_ids that were generated
    /// from tokenizing the prompt.
    pub input_token_ids: Option<Vec<TokenIdType>>,
}

/// This is the final message that will be emitted by an Engine Response Stream when it
/// finishes without error. In some cases, the engine may emit an error, which will indicate
/// the end of the stream. Another case in which a Finalize(Epilogue) will not be emitted is
/// if the response handler has stalled and too many responses have accumulated.
#[derive(Serialize, Deserialize, Debug)]
pub struct Epilogue {}

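/// StreamingCompletionResponse is the per-step payload of the response stream. Note that the
/// hand-written serde impls below only emit `delta`; `logprobs` is accepted on deserialization
/// (defaulting to `None` when absent) but is not currently serialized, so a minimal wire form
/// looks like the following sketch:
///
/// ```json
/// {"delta": {"is_complete": false, "finish_reason": null, "token_ids": null, "tokens": null,
///            "text": null, "sequence_length": null, "index": null, "cum_log_probs": null,
///            "err_msg": null, "usage": null}}
/// ```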
#[derive(Debug)]
pub struct StreamingCompletionResponse {
    pub delta: Delta,
    pub logprobs: Option<ChatCompletionLogprobs>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamState {
    Active,
    Finished(FinishReason),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Logits {
    All(Vec<f32>),
    Sparse(Vec<(u32, f32)>),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum LogProbs {
    Normalized(Logits),
    Raw(Logits),
}

/// At each SequencePosition we hold position specific data
pub struct SequencePositionData {
    pub token_id: TokenIdType,

    /// The log probability of the token
    pub logprobs: Option<LogProbs>,
}

// todo(ryan) - we need to create a DeltaBuilder which is a mutable object that can be passed
// around from the low-level compute engine to the high-level api. The DeltaBuilder will allow
// us to construct the Delta object at multiple layers in the streaming response path.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Delta {
    pub is_complete: bool,

    pub finish_reason: Option<FinishReason>,

    // new token_ids
    pub token_ids: Option<Vec<u32>>,

    // tokens
    pub tokens: Option<Vec<String>>,

    // decoded text
    pub text: Option<String>,

    // current sequence length
    // when streaming, we expect this to increase by 1 on each response
    pub sequence_length: Option<usize>,

    // if the number of slots for a given request is greater than 1,
    // this indicates the index of the slot for the response
    pub index: Option<usize>,

    /// cumulative log probabilities
    pub cum_log_probs: Option<f64>,

    /// error message from engine
    /// if this is set, is_complete should also be true
    pub err_msg: Option<String>,

    /// usage info
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Usage {
    pub input_tokens_count: usize,
    pub output_tokens_count: usize,
}

// todo(ryan) - we need to update this object to make it more generic
// we need to define a set of generic stats traits that allow those stats to be None
// then back them by a concrete implementation like a TrtllmStats object
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Stats {
    /// Time since the last Epoch/Forward Pass in microseconds (us).
    /// This is measured and recorded by the Response Router rather than the
    /// Inference Engine. Note, when evaluating the responses, if this
    /// value is greater than the stream's measured value, then there was a gap
    /// between forward passes. In normal operation, the value of this field should
    /// be less than the recorded value on the response stream.
    pub time_since_last_forward_pass_us: Option<u64>,

    pub request_active_count: u32,

    pub request_context_count: u32,

    pub request_generation_count: u32,

    pub request_scheduled_count: u32,

    pub request_max_count: u32,

    pub kv_free_cache_blocks: u64,

    pub kv_max_cache_blocks: u64,

    pub kv_used_cache_blocks: u64,

    pub kv_tokens_per_cache_block: u64,

    pub runtime_cpu_memory_usage: u64,

    pub runtime_gpu_memory_usage: u64,

    pub runtime_pinned_memory_usage: u64,

    pub iteration_counter: u64,

    pub microbatch_id: u64,

    pub total_context_tokens: u32,

    pub timestamp: String,
}

impl Serialize for StreamingCompletionResponse {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Only one field is emitted, so declare a length of 1.
        let mut state = serializer.serialize_struct("StreamingCompletionResponse", 1)?;

        // Serialize the `delta` field; `logprobs` is intentionally not emitted.
        state.serialize_field("delta", &self.delta)?;

        state.end()
    }
}

impl<'de> Deserialize<'de> for StreamingCompletionResponse {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Create a temporary struct for deserialization
        #[derive(Deserialize)]
        struct TempResponse {
            delta: Delta,
            logprobs: Option<ChatCompletionLogprobs>,
        }

        let TempResponse { delta, logprobs } = TempResponse::deserialize(deserializer)?;

        Ok(StreamingCompletionResponse { delta, logprobs })
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ScatterData<T> {
    pub x: Vec<T>,
    pub y: Vec<T>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Trace {
    pub time_to_first_token: u64,
    pub token_to_token: Vec<u64>,
    pub start: SystemTime,
    pub complete: SystemTime,
    pub initial_tokens: u32,
    pub max_tokens: u32,
    pub t2ft_iteration_count: u64,
    pub t2t_iteration_count: Vec<u64>,
}

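/// Linear performance model fitted from calibration traces. As a sketch of how the fields are
/// intended to be read (inferred from the field names and comments below): predicted
/// time-to-first-token ≈ `t2ft_intercept + t2ft_slope * initial_tokens`, and likewise for the
/// token-to-token latency (`t2tl_*`) fit.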
#[derive(Serialize, Deserialize, Debug)]
pub struct PerformanceModel {
    // linear regression parameters fitting t2ft vs. initial tokens
    pub t2ft_intercept: f64,
    pub t2ft_slope: f64,

    // linear regression parameters fitting t2tl vs. initial tokens
    pub t2tl_intercept: f64,
    pub t2tl_slope: f64,

    // r2 values from the regression
    pub t2ft_fit_r2: f64,
    pub t2tl_fit_r2: f64,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CalibrationResults {
    pub effective_flops: f64,
    pub effective_memory_bandwidth: f64,
    pub max_q: u32,
    pub performance_model: PerformanceModel,
    pub traces: Vec<Trace>,
    pub t2ft_scatter_data: ScatterData<f64>,
    pub t2tl_scatter_data: ScatterData<f64>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct LoadgenResults {
    pub stats_by_iteration: HashMap<u64, Stats>,
    pub traces: Vec<Trace>,
}

impl CompletionContext {
    /// Create a new CompletionContext
    pub fn new(prompt: String, system_prompt: Option<String>) -> Self {
        Self {
            prompt,
            system_prompt,
        }
    }

    /// Create a new CompletionContext with only a prompt
    pub fn from_prompt(prompt: String) -> Self {
        Self {
            prompt,
            system_prompt: None,
        }
    }

    /// Create a new CompletionContext with a prompt and system prompt
    pub fn with_system_prompt(prompt: String, system_prompt: String) -> Self {
        Self {
            prompt,
            system_prompt: Some(system_prompt),
        }
    }
}

// todo(ryan) - create a builder for chat context
impl From<CompletionContext> for PromptType {
    fn from(context: CompletionContext) -> Self {
        PromptType::Completion(context)
    }
}

#[cfg(test)]
mod tests {
    use serde_json;

    use super::*;

    #[test]
    fn test_completion_context_new() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = Some("This is a system prompt.".to_string());
        let context = CompletionContext::new(prompt.clone(), system_prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, system_prompt);
    }

    #[test]
    fn test_completion_context_from_prompt() {
        let prompt = "Hello, world!".to_string();
        let context = CompletionContext::from_prompt(prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, None);
    }

    #[test]
    fn test_completion_context_with_system_prompt() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, Some(system_prompt));
    }

    #[test]
    fn test_completion_context_into_prompt_type() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
        let prompt_type: PromptType = context.into();

        if let PromptType::Completion(completion_context) = prompt_type {
            assert_eq!(completion_context.prompt, prompt);
            assert_eq!(completion_context.system_prompt, Some(system_prompt));
        } else {
            panic!("Expected a Completion variant");
        }
    }

    // #[test]
    // fn test_serialize_with_stats() {
    //     let response = StreamingCompletionResponse {
    //         delta: Delta {
    //             is_complete: true,
    //             finish_reason: Some(FinishReason::Length),
    //             token_ids: Some(vec![101, 102, 103]),
    //             tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
    //             text: Some("example text".to_string()),
    //             sequence_length: Some(3),
    //             index: Some(0),
    //             cum_log_probs: Some(-0.5),
    //             err_msg: None,
    //             usage: None,
    //         },
    //         logprobs: None,
    //     };

    //     // Serialize the response
    //     let serialized = serde_json::to_string(&response).expect("Failed to serialize");

    //     // Expected JSON string (simplified)
    //     let expected = r#"{
    //         "delta": {
    //             "is_complete": true,
    //             "finish_reason": "length",
    //             "token_ids": [101, 102, 103],
    //             "tokens": ["token1", "token2"],
    //             "text": "example text",
    //             "sequence_length": 3,
    //             "index": 0,
    //             "cum_log_probs": -0.5,
    //             "err_msg": null,
    //             "usage": null
    //         },
    //         "stats": {
    //             "time_since_last_forward_pass_us": 1000,
    //             "request_active_count": 2,
    //             "request_context_count": 1,
    //             "request_generation_count": 3,
    //             "request_scheduled_count": 1,
    //             "request_max_count": 10,
    //             "kv_free_cache_blocks": 500,
    //             "kv_max_cache_blocks": 1000,
    //             "kv_used_cache_blocks": 500,
    //             "kv_tokens_per_cache_block": 10,
    //             "runtime_cpu_memory_usage": 5000,
    //             "runtime_gpu_memory_usage": 2000,
    //             "runtime_pinned_memory_usage": 1000,
    //             "iteration_counter": 5,
    //             "microbatch_id": 12345,
    //             "total_context_tokens": 256,
    //             "timestamp": "2024-01-01T00:00:00Z"
    //         }
    //     }"#;

    //     assert_eq!(
    //         serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
    //         serde_json::from_str::<serde_json::Value>(expected).unwrap()
    //     );
    // }

    #[test]
    fn test_serialize_without_stats() {
        let response = StreamingCompletionResponse {
            delta: Delta {
                is_complete: false,
                finish_reason: None,
                token_ids: None,
                tokens: None,
                text: None,
                sequence_length: None,
                index: None,
                cum_log_probs: None,
                err_msg: None,
                usage: None,
            },
            logprobs: None,
        };

        // Serialize the response
        let serialized = serde_json::to_string(&response).expect("Failed to serialize");

        // Expected JSON string
        let expected = r#"{
            "delta": {
                "is_complete": false,
                "finish_reason": null,
                "token_ids": null,
                "tokens": null,
                "text": null,
                "sequence_length": null,
                "index": null,
                "cum_log_probs": null,
                "err_msg": null,
                "usage": null
            }
        }"#;

        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected).unwrap()
        );
    }

    // #[test]
    // fn test_deserialize_with_stats() {
    //     let json_data = r#"{
    //         "delta": {
    //             "is_complete": true,
    //             "finish_reason": "length",
    //             "token_ids": [101, 102, 103],
    //             "tokens": ["token1", "token2"],
    //             "text": "example text",
    //             "sequence_length": 3,
    //             "index": 0,
    //             "cum_log_probs": -0.5,
    //             "err_msg": null,
    //             "usage": null
    //         },
    //         "stats": {
    //             "time_since_last_forward_pass_us": 1000,
    //             "request_active_count": 2,
    //             "request_context_count": 1,
    //             "request_generation_count": 3,
    //             "request_scheduled_count": 1,
    //             "request_max_count": 10,
    //             "kv_free_cache_blocks": 500,
    //             "kv_max_cache_blocks": 1000,
    //             "kv_used_cache_blocks": 500,
    //             "kv_tokens_per_cache_block": 10,
    //             "runtime_cpu_memory_usage": 5000,
    //             "runtime_gpu_memory_usage": 2000,
    //             "runtime_pinned_memory_usage": 1000,
    //             "iteration_counter": 5,
    //             "microbatch_id": 12345,
    //             "total_context_tokens": 256,
    //             "timestamp": "2024-01-01T00:00:00Z"
    //         }
    //     }"#;

    //     // Deserialize the JSON string
    //     let deserialized: StreamingCompletionResponse =
    //         serde_json::from_str(json_data).expect("Failed to deserialize");

    //     // Expected response object
    //     let expected = StreamingCompletionResponse {
    //         delta: Delta {
    //             is_complete: true,
    //             finish_reason: Some(FinishReason::Length),
    //             token_ids: Some(vec![101, 102, 103]),
    //             tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
    //             text: Some("example text".to_string()),
    //             sequence_length: Some(3),
    //             index: Some(0),
    //             cum_log_probs: Some(-0.5),
    //             err_msg: None,
    //             usage: None,
    //         },
    //         logprobs: None,
    //     };

    //     // This is unwieldy but we can no longer do assert_eq!(deserialized, expected);
    //     // because the struct no longer has the PartialEq trait
    //     assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
    //     assert_eq!(
    //         deserialized.delta.finish_reason,
    //         expected.delta.finish_reason
    //     );
    //     assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
    //     assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
    //     assert_eq!(deserialized.delta.text, expected.delta.text);
    //     assert_eq!(
    //         deserialized.delta.sequence_length,
    //         expected.delta.sequence_length
    //     );
    //     assert_eq!(deserialized.delta.index, expected.delta.index);
    //     assert_eq!(
    //         deserialized.delta.cum_log_probs,
    //         expected.delta.cum_log_probs
    //     );
    //     assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
    //     assert_eq!(deserialized.delta.usage, expected.delta.usage);

    //     assert_eq!(
    //         deserialized_stats.time_since_last_forward_pass_us,
    //         expected_stats.time_since_last_forward_pass_us
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_active_count,
    //         expected_stats.request_active_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_context_count,
    //         expected_stats.request_context_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_generation_count,
    //         expected_stats.request_generation_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_scheduled_count,
    //         expected_stats.request_scheduled_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.request_max_count,
    //         expected_stats.request_max_count
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_free_cache_blocks,
    //         expected_stats.kv_free_cache_blocks
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_max_cache_blocks,
    //         expected_stats.kv_max_cache_blocks
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_used_cache_blocks,
    //         expected_stats.kv_used_cache_blocks
    //     );
    //     assert_eq!(
    //         deserialized_stats.kv_tokens_per_cache_block,
    //         expected_stats.kv_tokens_per_cache_block
    //     );
    //     assert_eq!(
    //         deserialized_stats.runtime_cpu_memory_usage,
    //         expected_stats.runtime_cpu_memory_usage
    //     );
    //     assert_eq!(
    //         deserialized_stats.runtime_gpu_memory_usage,
    //         expected_stats.runtime_gpu_memory_usage
    //     );
    //     assert_eq!(
    //         deserialized_stats.runtime_pinned_memory_usage,
    //         expected_stats.runtime_pinned_memory_usage
    //     );
    //     assert_eq!(
    //         deserialized_stats.iteration_counter,
    //         expected_stats.iteration_counter
    //     );
    //     assert_eq!(
    //         deserialized_stats.microbatch_id,
    //         expected_stats.microbatch_id
    //     );
    //     assert_eq!(
    //         deserialized_stats.total_context_tokens,
    //         expected_stats.total_context_tokens
    //     );
    //     assert_eq!(deserialized_stats.timestamp, expected_stats.timestamp);
    // }

    #[test]
    fn test_deserialize_without_stats() {
        let json_data = r#"{
            "delta": {
                "is_complete": false,
                "finish_reason": null,
                "token_ids": null,
                "tokens": null,
                "text": null,
                "sequence_length": null,
                "index": null,
                "cum_log_probs": null,
                "err_msg": null,
                "usage": null
            }
        }"#;

        // Deserialize the JSON string
        let deserialized: StreamingCompletionResponse =
            serde_json::from_str(json_data).expect("Failed to deserialize");

        // Expected response object
        let expected = StreamingCompletionResponse {
            delta: Delta {
                is_complete: false,
                finish_reason: None,
                token_ids: None,
                tokens: None,
                text: None,
                sequence_length: None,
                index: None,
                cum_log_probs: None,
                err_msg: None,
                usage: None,
            },
            logprobs: None,
        };

        // This is unwieldy but we can no longer do assert_eq!(deserialized, expected);
        // because the struct no longer has the PartialEq trait
        assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
        assert_eq!(
            deserialized.delta.finish_reason,
            expected.delta.finish_reason
        );
        assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
        assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
        assert_eq!(deserialized.delta.text, expected.delta.text);
        assert_eq!(
            deserialized.delta.sequence_length,
            expected.delta.sequence_length
        );
        assert_eq!(deserialized.delta.index, expected.delta.index);
        assert_eq!(
            deserialized.delta.cum_log_probs,
            expected.delta.cum_log_probs
        );
        assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
        assert_eq!(deserialized.delta.usage, expected.delta.usage);
    }

    #[test]
    fn test_serialize_delta_and_none_stats() {
        let response = StreamingCompletionResponse {
            delta: Delta {
                is_complete: true,
                finish_reason: Some(FinishReason::Length),
                token_ids: Some(vec![101, 102, 103]),
                tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
                text: Some("example text".to_string()),
                sequence_length: Some(3),
                index: Some(0),
                cum_log_probs: Some(-0.5),
                err_msg: None,
                usage: None,
            },
            logprobs: None,
        };

        // Serialize the response
        let serialized = serde_json::to_string(&response).expect("Failed to serialize");

        // Expected JSON string; no stats field is emitted
        let expected_json = r#"{
            "delta": {
                "is_complete": true,
                "finish_reason": "length",
                "token_ids": [101, 102, 103],
                "tokens": ["token1", "token2"],
                "text": "example text",
                "sequence_length": 3,
                "index": 0,
                "cum_log_probs": -0.5,
                "err_msg": null,
                "usage": null
            }
        }"#;

        // Parse both the serialized response and the expected JSON as serde_json::Value for easy comparison
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
            serde_json::from_str::<serde_json::Value>(expected_json).unwrap()
        );
    }
}