dynamo_llm/protocols/openai/
chat_completions.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use dynamo_runtime::protocols::annotated::AnnotationsProvider;
5use serde::{Deserialize, Serialize};
6use validator::Validate;
7
8use crate::engines::ValidateRequest;
9
10use super::{
11    OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
12    common_ext::{
13        CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
14    },
15    nvext::NvExt,
16    nvext::NvExtProvider,
17    validate,
18};
19
20pub mod aggregator;
21mod delta;
22pub mod jail;
23
24pub use aggregator::DeltaAggregator;
25pub use delta::DeltaGenerator;
26
/// A request structure for creating a chat completion, extending OpenAI's
/// `CreateChatCompletionRequest` with [`NvExt`] extensions and common fields.
///
/// # Fields
/// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`.
/// - `common`: Common extension fields accepted at the root level of the request
///   (for example `ignore_eos`, `min_tokens`, `top_k`, `guided_json`), embedded using
///   `serde(flatten)`.
/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for more details.
///   Note: If a field such as `ignore_eos` is specified in both `common` and `nvext`,
///   the common (root-level) value takes precedence.
/// - `chat_template_args`: Optional extra arguments passed to the chat template
///   rendering context.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
    /// Base OpenAI request; flattened, so its fields sit at the root of the
    /// serialized form rather than under an `inner` key.
    #[serde(flatten)]
    pub inner: dynamo_async_openai::types::CreateChatCompletionRequest,

    /// Root-level extension fields; flattened, and defaulted when absent.
    #[serde(flatten, default)]
    pub common: CommonExt,

    /// NVIDIA extension section; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,

    /// Extra args to pass to the chat template rendering context.
    /// Defaults to `None` when absent and is omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
}
50
/// The response type for unary chat completion responses.
///
/// This is a plain type alias for OpenAI's `CreateChatCompletionResponse` as
/// provided by `dynamo_async_openai`; the alias itself adds no wrapper struct
/// and no extra fields.
pub type NvCreateChatCompletionResponse = dynamo_async_openai::types::CreateChatCompletionResponse;
58
/// The response type for streamed chat completions.
///
/// This is a plain type alias for OpenAI's `CreateChatCompletionStreamResponse`
/// as provided by `dynamo_async_openai`; the alias itself adds no wrapper struct
/// and no extra fields.
pub type NvCreateChatCompletionStreamResponse =
    dynamo_async_openai::types::CreateChatCompletionStreamResponse;
67
68/// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`,
69/// providing access to NVIDIA-specific extensions.
70impl NvExtProvider for NvCreateChatCompletionRequest {
71    /// Returns a reference to the optional `NvExt` extension, if available.
72    fn nvext(&self) -> Option<&NvExt> {
73        self.nvext.as_ref()
74    }
75
76    /// Returns `None`, as raw prompt extraction is not implemented.
77    fn raw_prompt(&self) -> Option<String> {
78        None
79    }
80}
81
82/// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`,
83/// enabling retrieval and management of request annotations.
84impl AnnotationsProvider for NvCreateChatCompletionRequest {
85    /// Retrieves the list of annotations from `NvExt`, if present.
86    fn annotations(&self) -> Option<Vec<String>> {
87        self.nvext
88            .as_ref()
89            .and_then(|nvext| nvext.annotations.clone())
90    }
91
92    /// Checks whether a specific annotation exists in the request.
93    ///
94    /// # Arguments
95    /// * `annotation` - A string slice representing the annotation to check.
96    ///
97    /// # Returns
98    /// `true` if the annotation exists, `false` otherwise.
99    fn has_annotation(&self, annotation: &str) -> bool {
100        self.nvext
101            .as_ref()
102            .and_then(|nvext| nvext.annotations.as_ref())
103            .map(|annotations| annotations.contains(&annotation.to_string()))
104            .unwrap_or(false)
105    }
106}
107
108/// Implements `OpenAISamplingOptionsProvider` for `NvCreateChatCompletionRequest`,
109/// exposing OpenAI's sampling parameters for chat completion.
110impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
111    /// Retrieves the temperature parameter for sampling, if set.
112    fn get_temperature(&self) -> Option<f32> {
113        self.inner.temperature
114    }
115
116    /// Retrieves the top-p (nucleus sampling) parameter, if set.
117    fn get_top_p(&self) -> Option<f32> {
118        self.inner.top_p
119    }
120
121    /// Retrieves the frequency penalty parameter, if set.
122    fn get_frequency_penalty(&self) -> Option<f32> {
123        self.inner.frequency_penalty
124    }
125
126    /// Retrieves the presence penalty parameter, if set.
127    fn get_presence_penalty(&self) -> Option<f32> {
128        self.inner.presence_penalty
129    }
130
131    /// Returns a reference to the optional `NvExt` extension, if available.
132    fn nvext(&self) -> Option<&NvExt> {
133        self.nvext.as_ref()
134    }
135    /// Retrieves the seed value for random number generation, if set.
136    fn get_seed(&self) -> Option<i64> {
137        self.inner.seed
138    }
139
140    /// Retrieves the number of completions to generate for each prompt, if set.
141    fn get_n(&self) -> Option<u8> {
142        self.inner.n
143    }
144
145    /// Retrieves the best_of parameter, if set.
146    fn get_best_of(&self) -> Option<u8> {
147        None // Not supported in chat completions
148    }
149}
150
151/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
152/// providing access to common extension fields.
153impl CommonExtProvider for NvCreateChatCompletionRequest {
154    /// Returns a reference to the CommonExt struct.
155    fn common_ext(&self) -> Option<&CommonExt> {
156        Some(&self.common)
157    }
158
159    /// Guided Decoding Options
160    fn get_guided_json(&self) -> Option<&serde_json::Value> {
161        // Note: This one needs special handling since it returns a reference
162        if let Some(nvext) = &self.nvext
163            && nvext.guided_json.is_some()
164        {
165            emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
166        }
167        self.common
168            .guided_json
169            .as_ref()
170            .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
171    }
172
173    fn get_guided_regex(&self) -> Option<String> {
174        choose_with_deprecation(
175            "guided_regex",
176            self.common.guided_regex.as_ref(),
177            self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
178        )
179    }
180
181    fn get_guided_grammar(&self) -> Option<String> {
182        choose_with_deprecation(
183            "guided_grammar",
184            self.common.guided_grammar.as_ref(),
185            self.nvext
186                .as_ref()
187                .and_then(|nv| nv.guided_grammar.as_ref()),
188        )
189    }
190
191    fn get_guided_choice(&self) -> Option<Vec<String>> {
192        choose_with_deprecation(
193            "guided_choice",
194            self.common.guided_choice.as_ref(),
195            self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
196        )
197    }
198
199    fn get_guided_decoding_backend(&self) -> Option<String> {
200        choose_with_deprecation(
201            "guided_decoding_backend",
202            self.common.guided_decoding_backend.as_ref(),
203            self.nvext
204                .as_ref()
205                .and_then(|nv| nv.guided_decoding_backend.as_ref()),
206        )
207    }
208
209    fn get_top_k(&self) -> Option<i32> {
210        choose_with_deprecation(
211            "top_k",
212            self.common.top_k.as_ref(),
213            self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
214        )
215    }
216
217    fn get_min_p(&self) -> Option<f32> {
218        choose_with_deprecation(
219            "min_p",
220            self.common.min_p.as_ref(),
221            self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
222        )
223    }
224
225    fn get_repetition_penalty(&self) -> Option<f32> {
226        choose_with_deprecation(
227            "repetition_penalty",
228            self.common.repetition_penalty.as_ref(),
229            self.nvext
230                .as_ref()
231                .and_then(|nv| nv.repetition_penalty.as_ref()),
232        )
233    }
234
235    fn get_include_stop_str_in_output(&self) -> Option<bool> {
236        self.common.include_stop_str_in_output
237    }
238}
239
240/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
241/// providing access to stop conditions that control chat completion behavior.
242impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
243    /// Retrieves the maximum number of tokens allowed in the response.
244    #[allow(deprecated)]
245    fn get_max_tokens(&self) -> Option<u32> {
246        self.inner.max_completion_tokens.or(self.inner.max_tokens)
247    }
248
249    /// Retrieves the minimum number of tokens required in the response.
250    /// Returns `min_tokens` Value
251    /// `min_tokens` is not an OpenAI-supported parameter.
252    fn get_min_tokens(&self) -> Option<u32> {
253        self.common.min_tokens
254    }
255
256    /// Retrieves the stop conditions that terminate the chat completion response.
257    ///
258    /// Converts OpenAI's `Stop` enum to a `Vec<String>`, normalizing the representation.
259    ///
260    /// # Returns
261    /// * `Some(Vec<String>)` if stop conditions are set.
262    /// * `None` if no stop conditions are defined.
263    fn get_stop(&self) -> Option<Vec<String>> {
264        self.inner.stop.as_ref().map(|stop| match stop {
265            dynamo_async_openai::types::Stop::String(s) => vec![s.clone()],
266            dynamo_async_openai::types::Stop::StringArray(arr) => arr.clone(),
267        })
268    }
269
270    /// Returns a reference to the optional `NvExt` extension, if available.
271    fn nvext(&self) -> Option<&NvExt> {
272        self.nvext.as_ref()
273    }
274
275    /// Get ignore_eos from CommonExt.
276    fn get_common_ignore_eos(&self) -> Option<bool> {
277        self.common.ignore_eos
278    }
279
280    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
281    /// CommonExt (root-level) takes precedence over NvExt.
282    fn get_ignore_eos(&self) -> Option<bool> {
283        choose_with_deprecation(
284            "ignore_eos",
285            self.get_common_ignore_eos().as_ref(),
286            NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
287        )
288    }
289}
290
291impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
292    fn get_logprobs(&self) -> Option<u32> {
293        match self.inner.logprobs {
294            Some(true) => match self.inner.top_logprobs {
295                Some(top_logprobs) => Some(top_logprobs as u32),
296                None => Some(1_u32),
297            },
298            Some(false) => None,
299            None => None,
300        }
301    }
302
303    fn get_prompt_logprobs(&self) -> Option<u32> {
304        None
305    }
306
307    fn get_skip_special_tokens(&self) -> Option<bool> {
308        None
309    }
310
311    fn get_formatted_prompt(&self) -> Option<bool> {
312        None
313    }
314}
315
316/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
317/// allowing us to validate the data.
318impl ValidateRequest for NvCreateChatCompletionRequest {
319    fn validate(&self) -> Result<(), anyhow::Error> {
320        validate::validate_messages(&self.inner.messages)?;
321        validate::validate_model(&self.inner.model)?;
322        // none for store
323        validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
324        validate::validate_metadata(&self.inner.metadata)?;
325        validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
326        validate::validate_logit_bias(&self.inner.logit_bias)?;
327        // none for logprobs
328        validate::validate_top_logprobs(self.inner.top_logprobs)?;
329        // validate::validate_max_tokens(self.inner.max_tokens)?; // warning depricated field
330        validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?;
331        validate::validate_n(self.inner.n)?;
332        // none for modalities
333        // none for prediction
334        // none for audio
335        validate::validate_presence_penalty(self.inner.presence_penalty)?;
336        // none for response_format
337        // none for seed
338        validate::validate_service_tier(&self.inner.service_tier)?;
339        validate::validate_stop(&self.inner.stop)?;
340        // none for stream
341        // none for stream_options
342        validate::validate_temperature(self.inner.temperature)?;
343        validate::validate_top_p(self.inner.top_p)?;
344        validate::validate_tools(&self.inner.tools.as_deref())?;
345        // none for tool_choice
346        // none for parallel_tool_calls
347        validate::validate_user(self.inner.user.as_deref())?;
348        // none for function call
349        // none for functions
350        // Common Ext
351        validate::validate_repetition_penalty(self.get_repetition_penalty())?;
352
353        Ok(())
354    }
355}