dynamo_llm/protocols/
openai.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16pub mod chat_completions;
17pub mod completions;
18pub mod embeddings;
19pub mod models;
20pub mod nvext;
21
22use anyhow::Result;
23use serde::{Deserialize, Serialize};
24use std::{
25    fmt::Display,
26    ops::{Add, Div, Mul, Sub},
27};
28
29use super::{
30    common::{self, SamplingOptionsProvider, StopConditionsProvider},
31    ContentProvider,
32};
33
/// Minimum allowed value for OpenAI's `temperature` sampling option.
pub const MIN_TEMPERATURE: f32 = 0.0;

/// Maximum allowed value for OpenAI's `temperature` sampling option.
pub const MAX_TEMPERATURE: f32 = 2.0;

/// Allowed range `(min, max)`, bounds inclusive, of values for OpenAI's `temperature` sampling
/// option.
pub const TEMPERATURE_RANGE: (f32, f32) = (MIN_TEMPERATURE, MAX_TEMPERATURE);

/// Minimum allowed value for OpenAI's `top_p` sampling option.
pub const MIN_TOP_P: f32 = 0.0;

/// Maximum allowed value for OpenAI's `top_p` sampling option.
pub const MAX_TOP_P: f32 = 1.0;

/// Allowed range `(min, max)`, bounds inclusive, of values for OpenAI's `top_p` sampling option.
pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P);

/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option.
pub const MIN_FREQUENCY_PENALTY: f32 = -2.0;

/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option.
pub const MAX_FREQUENCY_PENALTY: f32 = 2.0;

/// Allowed range `(min, max)`, bounds inclusive, of values for OpenAI's `frequency_penalty`
/// sampling option.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (MIN_FREQUENCY_PENALTY, MAX_FREQUENCY_PENALTY);

/// Minimum allowed value for OpenAI's `presence_penalty` sampling option.
pub const MIN_PRESENCE_PENALTY: f32 = -2.0;

/// Maximum allowed value for OpenAI's `presence_penalty` sampling option.
pub const MAX_PRESENCE_PENALTY: f32 = 2.0;

/// Allowed range `(min, max)`, bounds inclusive, of values for OpenAI's `presence_penalty`
/// sampling option.
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
69
70/// Usage statistics for the completion request
71#[derive(Serialize, Deserialize, Debug, Clone, Default)]
72pub struct CompletionUsage {
73    /// Number of tokens in the generated completion.
74    pub completion_tokens: i32,
75
76    /// Number of tokens in the prompt.
77    pub prompt_tokens: i32,
78
79    /// Total number of tokens used in the request (prompt + completion).
80    pub total_tokens: i32,
81
82    /// Breakdown of tokens used in a completion, optional.
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub completion_tokens_details: Option<CompletionTokensDetails>,
85
86    /// Breakdown of tokens used in the prompt, optional.
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub prompt_tokens_details: Option<PromptTokensDetails>,
89}
90
91// Struct for details on completion tokens
92#[derive(Debug, Serialize, Deserialize, Clone)]
93pub struct CompletionTokensDetails {
94    /// Audio input tokens generated by the model.
95    pub audio_tokens: Option<i32>,
96
97    /// Tokens generated by the model for reasoning.
98    pub reasoning_tokens: Option<i32>,
99}
100
101// Struct for details on prompt tokens
102#[derive(Debug, Serialize, Deserialize, Clone)]
103pub struct PromptTokensDetails {
104    /// Audio input tokens present in the prompt.
105    pub audio_tokens: Option<i32>,
106
107    /// Cached tokens present in the prompt.
108    pub cached_tokens: Option<i32>,
109}
110
111/// Represents a streaming response from the OpenAI API
112/// The object is generalized on R, which is the type of the response.
113/// For SSE streaming responses, the expected `data: ` field is always a JSON
114/// object corresponding to `R`; however, the comments in the SSE stream `: `
115/// may correspond to other types of information, such as performance metrics,
116/// as represented by other arms of this enum.
117///
118/// This is part of the common API as both the client and service need to agree
119/// on the format of the streaming responses.
120#[derive(Serialize, Deserialize, Debug)]
121pub enum StreamingDelta<R> {
122    /// Represents a response delta from the API
123    Delta(R),
124    Comment(String),
125}
126
127#[derive(Serialize, Deserialize, Debug)]
128pub struct AnnotatedDelta<R> {
129    pub delta: R,
130    pub id: Option<String>,
131    pub event: Option<String>,
132    pub comment: Option<String>,
133}
134
135trait OpenAISamplingOptionsProvider {
136    fn get_temperature(&self) -> Option<f32>;
137
138    fn get_top_p(&self) -> Option<f32>;
139
140    fn get_frequency_penalty(&self) -> Option<f32>;
141
142    fn get_presence_penalty(&self) -> Option<f32>;
143
144    fn nvext(&self) -> Option<&nvext::NvExt>;
145}
146
147trait OpenAIStopConditionsProvider {
148    fn get_max_tokens(&self) -> Option<u32>;
149
150    fn get_min_tokens(&self) -> Option<u32>;
151
152    fn get_stop(&self) -> Option<Vec<String>>;
153
154    fn nvext(&self) -> Option<&nvext::NvExt>;
155}
156
157impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
158    fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
159        // let result = self.validate();
160        // if let Err(e) = result {
161        //     return Err(format!("Error validating sampling options: {}", e));
162        // }
163
164        let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
165            .map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
166        let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
167            .map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
168        let frequency_penalty =
169            validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
170                .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
171        let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
172            .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
173
174        if let Some(nvext) = self.nvext() {
175            let greedy = nvext.greed_sampling.unwrap_or(false);
176            if greedy {
177                top_p = None;
178                temperature = None;
179            }
180        }
181
182        Ok(common::SamplingOptions {
183            n: None,
184            best_of: None,
185            frequency_penalty,
186            presence_penalty,
187            repetition_penalty: None,
188            temperature,
189            top_p,
190            top_k: None,
191            min_p: None,
192            seed: None,
193            use_beam_search: None,
194            length_penalty: None,
195        })
196    }
197}
198
199impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
200    fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
201        let max_tokens = self.get_max_tokens();
202        let min_tokens = self.get_min_tokens();
203        let stop = self.get_stop();
204
205        if let Some(stop) = &stop {
206            if stop.len() > 4 {
207                anyhow::bail!("stop conditions must be less than 4")
208            }
209        }
210
211        let mut ignore_eos = None;
212
213        if let Some(nvext) = self.nvext() {
214            ignore_eos = nvext.ignore_eos;
215        }
216
217        Ok(common::StopConditions {
218            max_tokens,
219            min_tokens,
220            stop,
221            stop_token_ids_hidden: None,
222            ignore_eos,
223        })
224    }
225}
226
227/// Common structure for chat completion responses; the only delta is the type of choices which differs
228/// between streaming and non-streaming requests.
229#[derive(Serialize, Deserialize, Debug, Clone)]
230pub struct GenericCompletionResponse<C>
231// where
232//     C: Serialize + Clone,
233{
234    /// A unique identifier for the chat completion.
235    pub id: String,
236
237    /// A list of chat completion choices. Can be more than one if n is greater than 1.
238    pub choices: Vec<C>,
239
240    /// The Unix timestamp (in seconds) of when the chat completion was created.
241    pub created: u64,
242
243    /// The model used for the chat completion.
244    pub model: String,
245
246    /// The object type, which is `chat.completion` if the type of `Choice` is `ChatCompletionChoice`,
247    /// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
248    pub object: String,
249
250    pub usage: Option<CompletionUsage>,
251
252    /// This fingerprint represents the backend configuration that the model runs with.
253    ///
254    /// Can be used in conjunction with the seed request parameter to understand when backend changes
255    /// have been made that might impact determinism.
256    ///
257    /// NIM Compatibility:
258    /// This field is not supported by the NIM; however it will be added in the future.
259    /// The optional nature of this field will be relaxed when it is supported.
260    pub system_fingerprint: Option<String>,
261    // TODO() - add NvResponseExtention
262}
263
264// todo - move to common location
265fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
266where
267    T: PartialOrd + Display,
268{
269    if value.is_none() {
270        return Ok(None);
271    }
272    let value = value.unwrap();
273    if value < range.0 || value > range.1 {
274        anyhow::bail!("Value {} is out of range [{}, {}]", value, range.0, range.1);
275    }
276    Ok(Some(value))
277}
278
279// todo - move to common location
280/// scale value in `src` range to `dst` range
281pub fn scale_value<T>(value: &T, src: &(T, T), dst: &(T, T)) -> Result<T>
282where
283    T: Copy
284        + PartialOrd
285        + Add<Output = T>
286        + Sub<Output = T>
287        + Mul<Output = T>
288        + Div<Output = T>
289        + From<f32>,
290{
291    let dst_range = dst.1 - dst.0;
292    let src_range = src.1 - src.0;
293    if dst_range == T::from(0.0) {
294        anyhow::bail!("dst range is 0");
295    }
296    if src_range == T::from(0.0) {
297        anyhow::bail!("src range is 0");
298    }
299    let value_scaled = (*value - src.0) / src_range;
300    Ok(dst.0 + (value_scaled * dst_range))
301}
302
303pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
304    Send + Sync + 'static
305{
306    fn choice_from_postprocessor(
307        &mut self,
308        response: common::llm_backend::BackendOutput,
309    ) -> Result<ResponseType>;
310
311    /// Gets the current prompt token count (Input Sequence Length).
312    fn get_isl(&self) -> Option<u32>;
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_range() {
        // (value, range) pairs that must be accepted and returned unchanged.
        let accepted: [(f64, (f64, f64)); 3] =
            [(0.5, (0.0, 1.0)), (0.0, (0.0, 1.0)), (1.0, (1.0, 1.0))];
        for (value, range) in accepted {
            assert_eq!(validate_range(Some(value), &range).unwrap(), Some(value));
        }
        assert_eq!(validate_range(Some(1_i32), &(1, 1)).unwrap(), Some(1));

        // (value, range, expected message) triples that must be rejected.
        let rejected: [(f64, (f64, f64), &str); 2] = [
            (1.1, (0.0, 1.0), "Value 1.1 is out of range [0, 1]"),
            (-0.1, (0.0, 1.0), "Value -0.1 is out of range [0, 1]"),
        ];
        for (value, range, message) in rejected {
            let err = validate_range(Some(value), &range).unwrap_err();
            assert_eq!(err.to_string(), message);
        }
    }

    #[test]
    fn test_scaled_value() {
        // (value, src, dst, expected)
        let cases: [(f64, (f64, f64), (f64, f64), f64); 3] = [
            (0.5, (0.0, 1.0), (0.0, 2.0), 1.0),
            (0.0, (0.0, 1.0), (0.0, 2.0), 0.0),
            (-1.0, (-2.0, 2.0), (1.0, 2.0), 1.25),
        ];
        for (value, src, dst, expected) in cases {
            assert_eq!(scale_value(&value, &src, &dst).unwrap(), expected);
        }

        // A zero-width source range cannot be rescaled.
        assert!(scale_value(&1.0, &(1.0, 1.0), &(0.0, 2.0)).is_err());
    }
}