// dynamo_llm/protocols/openai/chat_completions.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::nvext::NvExt;
17use super::nvext::NvExtProvider;
18use super::OpenAISamplingOptionsProvider;
19use super::OpenAIStopConditionsProvider;
20use dynamo_runtime::protocols::annotated::AnnotationsProvider;
21use serde::{Deserialize, Serialize};
22use validator::Validate;
23
24mod aggregator;
25mod delta;
26
27pub use aggregator::DeltaAggregator;
28pub use delta::DeltaGenerator;
29
/// A request structure for creating a chat completion, extending OpenAI's
/// `CreateChatCompletionRequest` with [`NvExt`] extensions.
///
/// # Fields
/// - `inner`: The base OpenAI chat completion request, embedded using
///   `serde(flatten)` so its fields (de)serialize at the top level and the
///   wire format stays compatible with the upstream OpenAI API.
/// - `nvext`: The optional NVIDIA extension field; omitted from serialized
///   output when `None`. See [`NvExt`] for more details.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
    // Flattened so callers see a standard OpenAI request body plus `nvext`.
    #[serde(flatten)]
    pub inner: async_openai::types::CreateChatCompletionRequest,

    // Skipped when unset so requests without extensions serialize cleanly.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}
45
/// A response structure for unary (non-streaming) chat completion responses,
/// embedding OpenAI's `CreateChatCompletionResponse`.
///
/// # Fields
/// - `inner`: The base OpenAI unary chat completion response, embedded
///   using `serde(flatten)` so the serialized form matches the upstream
///   OpenAI response exactly.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionResponse {
    #[serde(flatten)]
    pub inner: async_openai::types::CreateChatCompletionResponse,
}
57
/// A response structure for streamed chat completion chunks, embedding
/// OpenAI's `CreateChatCompletionStreamResponse`.
///
/// # Fields
/// - `inner`: The base OpenAI streaming chat completion response (one SSE
///   chunk), embedded using `serde(flatten)` so the serialized form matches
///   the upstream OpenAI stream chunk exactly.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionStreamResponse {
    #[serde(flatten)]
    pub inner: async_openai::types::CreateChatCompletionStreamResponse,
}
69
/// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateChatCompletionRequest {
    /// Returns a borrowed reference to the `NvExt` extension, or `None`
    /// when the request carries no NVIDIA-specific options.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Always `None`: chat requests carry structured messages, so there is
    /// no single raw prompt string to expose.
    fn raw_prompt(&self) -> Option<String> {
        None
    }
}
83
84/// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`,
85/// enabling retrieval and management of request annotations.
86impl AnnotationsProvider for NvCreateChatCompletionRequest {
87    /// Retrieves the list of annotations from `NvExt`, if present.
88    fn annotations(&self) -> Option<Vec<String>> {
89        self.nvext
90            .as_ref()
91            .and_then(|nvext| nvext.annotations.clone())
92    }
93
94    /// Checks whether a specific annotation exists in the request.
95    ///
96    /// # Arguments
97    /// * `annotation` - A string slice representing the annotation to check.
98    ///
99    /// # Returns
100    /// `true` if the annotation exists, `false` otherwise.
101    fn has_annotation(&self, annotation: &str) -> bool {
102        self.nvext
103            .as_ref()
104            .and_then(|nvext| nvext.annotations.as_ref())
105            .map(|annotations| annotations.contains(&annotation.to_string()))
106            .unwrap_or(false)
107    }
108}
109
/// Implements `OpenAISamplingOptionsProvider` for `NvCreateChatCompletionRequest`,
/// exposing OpenAI's sampling parameters for chat completion.
///
/// All getters simply forward to the corresponding optional field on the
/// embedded OpenAI request (`self.inner`); `None` means "not specified by
/// the client".
impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
    /// Retrieves the temperature parameter for sampling, if set.
    fn get_temperature(&self) -> Option<f32> {
        self.inner.temperature
    }

    /// Retrieves the top-p (nucleus sampling) parameter, if set.
    fn get_top_p(&self) -> Option<f32> {
        self.inner.top_p
    }

    /// Retrieves the frequency penalty parameter, if set.
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.inner.frequency_penalty
    }

    /// Retrieves the presence penalty parameter, if set.
    fn get_presence_penalty(&self) -> Option<f32> {
        self.inner.presence_penalty
    }

    /// Returns a reference to the optional `NvExt` extension, if available.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}
138
139/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
140/// providing access to stop conditions that control chat completion behavior.
141impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
142    /// Retrieves the maximum number of tokens allowed in the response.
143    #[allow(deprecated)]
144    fn get_max_tokens(&self) -> Option<u32> {
145        self.inner.max_completion_tokens.or(self.inner.max_tokens)
146    }
147
148    /// Retrieves the minimum number of tokens required in the response.
149    ///
150    /// # Note
151    /// This method is currently a placeholder and always returns `None`
152    /// since `min_tokens` is not an OpenAI-supported parameter.
153    fn get_min_tokens(&self) -> Option<u32> {
154        None
155    }
156
157    /// Retrieves the stop conditions that terminate the chat completion response.
158    ///
159    /// Converts OpenAI's `Stop` enum to a `Vec<String>`, normalizing the representation.
160    ///
161    /// # Returns
162    /// * `Some(Vec<String>)` if stop conditions are set.
163    /// * `None` if no stop conditions are defined.
164    fn get_stop(&self) -> Option<Vec<String>> {
165        self.inner.stop.as_ref().map(|stop| match stop {
166            async_openai::types::Stop::String(s) => vec![s.clone()],
167            async_openai::types::Stop::StringArray(arr) => arr.clone(),
168        })
169    }
170
171    /// Returns a reference to the optional `NvExt` extension, if available.
172    fn nvext(&self) -> Option<&NvExt> {
173        self.nvext.as_ref()
174    }
175}