dynamo_llm/protocols/
openai.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7use super::{
8    ContentProvider,
9    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
10};
11use crate::protocols::openai::common_ext::{CommonExtProvider, choose_with_deprecation};
12
13pub mod chat_completions;
14pub mod common_ext;
15pub mod completions;
16pub mod embeddings;
17pub mod models;
18pub mod nvext;
19pub mod responses;
20pub mod validate;
21
22use validate::{
23    BEST_OF_RANGE, FREQUENCY_PENALTY_RANGE, MIN_P_RANGE, N_RANGE, PRESENCE_PENALTY_RANGE,
24    TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
25};
26
27#[derive(Serialize, Deserialize, Debug)]
28pub struct AnnotatedDelta<R> {
29    pub delta: R,
30    pub id: Option<String>,
31    pub event: Option<String>,
32    pub comment: Option<String>,
33}
34
35trait OpenAISamplingOptionsProvider {
36    fn get_temperature(&self) -> Option<f32>;
37
38    fn get_top_p(&self) -> Option<f32>;
39
40    fn get_frequency_penalty(&self) -> Option<f32>;
41
42    fn get_presence_penalty(&self) -> Option<f32>;
43
44    fn get_seed(&self) -> Option<i64>;
45
46    fn get_n(&self) -> Option<u8>;
47
48    fn get_best_of(&self) -> Option<u8>;
49
50    fn nvext(&self) -> Option<&nvext::NvExt>;
51}
52
53trait OpenAIStopConditionsProvider {
54    fn get_max_tokens(&self) -> Option<u32>;
55
56    fn get_min_tokens(&self) -> Option<u32>;
57
58    fn get_stop(&self) -> Option<Vec<String>>;
59
60    fn nvext(&self) -> Option<&nvext::NvExt>;
61
62    /// Get ignore_eos from CommonExt if the type supports it.
63    /// Default returns None for types without CommonExt support.
64    fn get_common_ignore_eos(&self) -> Option<bool> {
65        None
66    }
67
68    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
69    /// CommonExt (root-level) takes precedence over NvExt.
70    fn get_ignore_eos(&self) -> Option<bool> {
71        choose_with_deprecation(
72            "ignore_eos",
73            self.get_common_ignore_eos().as_ref(),
74            self.nvext().and_then(|nv| nv.ignore_eos.as_ref()),
75        )
76    }
77
78    /// Get max_thinking_tokens from nvext
79    /// NOTE: This is currently a passthrough for future thinking budget implementation
80    fn get_max_thinking_tokens(&self) -> Option<u32> {
81        self.nvext().and_then(|nv| nv.max_thinking_tokens)
82    }
83}
84
/// Accessors for the output-related fields of an OpenAI-style request.
///
/// A type implementing this trait receives `OutputOptionsProvider` through
/// the blanket impl in this module, which copies each value into
/// `common::OutputOptions` unchanged.
trait OpenAIOutputOptionsProvider {
    /// Value forwarded as `logprobs`.
    fn get_logprobs(&self) -> Option<u32>;

    /// Value forwarded as `prompt_logprobs`.
    fn get_prompt_logprobs(&self) -> Option<u32>;

    /// Value forwarded as `skip_special_tokens`.
    fn get_skip_special_tokens(&self) -> Option<bool>;

    /// Value forwarded as `formatted_prompt`.
    fn get_formatted_prompt(&self) -> Option<bool>;
}
94
95impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
96    fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
97        // let result = self.validate();
98        // if let Err(e) = result {
99        //     return Err(format!("Error validating sampling options: {}", e));
100        // }
101
102        let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
103            .map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
104        let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
105            .map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
106        let frequency_penalty =
107            validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
108                .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
109        let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
110            .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
111        let top_k = CommonExtProvider::get_top_k(self);
112        let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
113        let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
114        let seed = self.get_seed();
115        let n = validate_range(self.get_n(), &N_RANGE)
116            .map_err(|e| anyhow::anyhow!("Error validating n: {}", e))?;
117        let best_of = validate_range(self.get_best_of(), &BEST_OF_RANGE)
118            .map_err(|e| anyhow::anyhow!("Error validating best_of: {}", e))?;
119
120        let min_p = validate_range(CommonExtProvider::get_min_p(self), &MIN_P_RANGE)
121            .map_err(|e| anyhow::anyhow!("Error validating min_p: {}", e))?;
122
123        if let Some(nvext) = self.nvext() {
124            let greedy = nvext.greed_sampling.unwrap_or(false);
125            if greedy {
126                top_p = None;
127                temperature = None;
128            }
129        }
130
131        let guided_decoding_backend = self.get_guided_decoding_backend();
132        let guided_json = self.get_guided_json();
133        let guided_regex = self.get_guided_regex();
134        let guided_grammar = self.get_guided_grammar();
135        let guided_choice = self.get_guided_choice();
136
137        let guided_decoding = match common::GuidedDecodingOptions::from_optional(
138            guided_json.cloned(),
139            guided_regex,
140            guided_choice,
141            guided_grammar,
142            guided_decoding_backend,
143        ) {
144            Ok(options) => options,
145            Err(e) => {
146                // Handle the validation error (log, return error, etc.)
147                tracing::error!("Invalid guided decoding options: {:?}", e);
148                return Err(e);
149            }
150        };
151
152        Ok(common::SamplingOptions {
153            n,
154            best_of,
155            frequency_penalty,
156            presence_penalty,
157            repetition_penalty,
158            temperature,
159            top_p,
160            top_k,
161            min_p,
162            seed,
163            use_beam_search: None,
164            length_penalty: None,
165            guided_decoding,
166            include_stop_str_in_output,
167        })
168    }
169}
170
171impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
172    fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
173        let max_tokens = self.get_max_tokens();
174        let min_tokens = self.get_min_tokens();
175        let stop = self.get_stop();
176        let max_thinking_tokens = self.get_max_thinking_tokens();
177
178        if let Some(stop) = &stop
179            && stop.len() > 4
180        {
181            anyhow::bail!("stop conditions must be less than 4")
182        }
183
184        // Use the trait method to get ignore_eos, which handles precedence
185        let ignore_eos = self.get_ignore_eos();
186
187        Ok(common::StopConditions {
188            max_tokens,
189            min_tokens,
190            stop,
191            stop_token_ids_hidden: None,
192            ignore_eos,
193            max_thinking_tokens,
194        })
195    }
196}
197
198impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
199    fn extract_output_options(&self) -> Result<common::OutputOptions> {
200        let logprobs = self.get_logprobs();
201        let prompt_logprobs = self.get_prompt_logprobs();
202        let skip_special_tokens = self.get_skip_special_tokens();
203        let formatted_prompt = self.get_formatted_prompt();
204
205        Ok(common::OutputOptions {
206            logprobs,
207            prompt_logprobs,
208            skip_special_tokens,
209            formatted_prompt,
210        })
211    }
212}
213
214pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
215    Send + 'static
216{
217    fn choice_from_postprocessor(
218        &mut self,
219        response: common::llm_backend::BackendOutput,
220    ) -> Result<ResponseType>;
221
222    /// Gets the current prompt token count (Input Sequence Length).
223    fn get_isl(&self) -> Option<u32>;
224
225    /// Creates a final usage-only chunk for OpenAI compliance.
226    fn create_usage_chunk(&self) -> ResponseType;
227
228    /// Check if usage tracking is enabled.
229    fn is_usage_enabled(&self) -> bool;
230}
231
232#[derive(Clone, Debug, Serialize, Deserialize, Default)]
233pub struct ParsingOptions {
234    pub tool_call_parser: Option<String>,
235
236    pub reasoning_parser: Option<String>,
237}
238
239impl ParsingOptions {
240    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
241        Self {
242            tool_call_parser,
243            reasoning_parser,
244        }
245    }
246}