dynamo_llm/protocols/
openai.rs1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7use super::{
8 ContentProvider,
9 common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
10};
11use crate::protocols::openai::common_ext::{CommonExtProvider, choose_with_deprecation};
12
13pub mod chat_completions;
14pub mod common_ext;
15pub mod completions;
16pub mod embeddings;
17pub mod models;
18pub mod nvext;
19pub mod responses;
20pub mod validate;
21
22use validate::{
23 BEST_OF_RANGE, FREQUENCY_PENALTY_RANGE, MIN_P_RANGE, N_RANGE, PRESENCE_PENALTY_RANGE,
24 TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
25};
26
27#[derive(Serialize, Deserialize, Debug)]
28pub struct AnnotatedDelta<R> {
29 pub delta: R,
30 pub id: Option<String>,
31 pub event: Option<String>,
32 pub comment: Option<String>,
33}
34
35trait OpenAISamplingOptionsProvider {
36 fn get_temperature(&self) -> Option<f32>;
37
38 fn get_top_p(&self) -> Option<f32>;
39
40 fn get_frequency_penalty(&self) -> Option<f32>;
41
42 fn get_presence_penalty(&self) -> Option<f32>;
43
44 fn get_seed(&self) -> Option<i64>;
45
46 fn get_n(&self) -> Option<u8>;
47
48 fn get_best_of(&self) -> Option<u8>;
49
50 fn nvext(&self) -> Option<&nvext::NvExt>;
51}
52
53trait OpenAIStopConditionsProvider {
54 fn get_max_tokens(&self) -> Option<u32>;
55
56 fn get_min_tokens(&self) -> Option<u32>;
57
58 fn get_stop(&self) -> Option<Vec<String>>;
59
60 fn nvext(&self) -> Option<&nvext::NvExt>;
61
62 fn get_common_ignore_eos(&self) -> Option<bool> {
65 None
66 }
67
68 fn get_ignore_eos(&self) -> Option<bool> {
71 choose_with_deprecation(
72 "ignore_eos",
73 self.get_common_ignore_eos().as_ref(),
74 self.nvext().and_then(|nv| nv.ignore_eos.as_ref()),
75 )
76 }
77
78 fn get_max_thinking_tokens(&self) -> Option<u32> {
81 self.nvext().and_then(|nv| nv.max_thinking_tokens)
82 }
83}
84
/// Accessors for the output-shaping options carried by an OpenAI-style request.
trait OpenAIOutputOptionsProvider {
    /// Requested number of log-probabilities per generated token.
    fn get_logprobs(&self) -> Option<u32>;

    /// Requested number of log-probabilities per prompt token.
    fn get_prompt_logprobs(&self) -> Option<u32>;

    /// Whether special tokens should be skipped when detokenizing.
    fn get_skip_special_tokens(&self) -> Option<bool>;

    /// Whether the formatted prompt was requested.
    fn get_formatted_prompt(&self) -> Option<bool>;
}
94
95impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
96 fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
97 let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
103 .map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
104 let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
105 .map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
106 let frequency_penalty =
107 validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
108 .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
109 let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
110 .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
111 let top_k = CommonExtProvider::get_top_k(self);
112 let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
113 let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
114 let seed = self.get_seed();
115 let n = validate_range(self.get_n(), &N_RANGE)
116 .map_err(|e| anyhow::anyhow!("Error validating n: {}", e))?;
117 let best_of = validate_range(self.get_best_of(), &BEST_OF_RANGE)
118 .map_err(|e| anyhow::anyhow!("Error validating best_of: {}", e))?;
119
120 let min_p = validate_range(CommonExtProvider::get_min_p(self), &MIN_P_RANGE)
121 .map_err(|e| anyhow::anyhow!("Error validating min_p: {}", e))?;
122
123 if let Some(nvext) = self.nvext() {
124 let greedy = nvext.greed_sampling.unwrap_or(false);
125 if greedy {
126 top_p = None;
127 temperature = None;
128 }
129 }
130
131 let guided_decoding_backend = self.get_guided_decoding_backend();
132 let guided_json = self.get_guided_json();
133 let guided_regex = self.get_guided_regex();
134 let guided_grammar = self.get_guided_grammar();
135 let guided_choice = self.get_guided_choice();
136
137 let guided_decoding = match common::GuidedDecodingOptions::from_optional(
138 guided_json.cloned(),
139 guided_regex,
140 guided_choice,
141 guided_grammar,
142 guided_decoding_backend,
143 ) {
144 Ok(options) => options,
145 Err(e) => {
146 tracing::error!("Invalid guided decoding options: {:?}", e);
148 return Err(e);
149 }
150 };
151
152 Ok(common::SamplingOptions {
153 n,
154 best_of,
155 frequency_penalty,
156 presence_penalty,
157 repetition_penalty,
158 temperature,
159 top_p,
160 top_k,
161 min_p,
162 seed,
163 use_beam_search: None,
164 length_penalty: None,
165 guided_decoding,
166 include_stop_str_in_output,
167 })
168 }
169}
170
171impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
172 fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
173 let max_tokens = self.get_max_tokens();
174 let min_tokens = self.get_min_tokens();
175 let stop = self.get_stop();
176 let max_thinking_tokens = self.get_max_thinking_tokens();
177
178 if let Some(stop) = &stop
179 && stop.len() > 4
180 {
181 anyhow::bail!("stop conditions must be less than 4")
182 }
183
184 let ignore_eos = self.get_ignore_eos();
186
187 Ok(common::StopConditions {
188 max_tokens,
189 min_tokens,
190 stop,
191 stop_token_ids_hidden: None,
192 ignore_eos,
193 max_thinking_tokens,
194 })
195 }
196}
197
198impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
199 fn extract_output_options(&self) -> Result<common::OutputOptions> {
200 let logprobs = self.get_logprobs();
201 let prompt_logprobs = self.get_prompt_logprobs();
202 let skip_special_tokens = self.get_skip_special_tokens();
203 let formatted_prompt = self.get_formatted_prompt();
204
205 Ok(common::OutputOptions {
206 logprobs,
207 prompt_logprobs,
208 skip_special_tokens,
209 formatted_prompt,
210 })
211 }
212}
213
214pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
215 Send + 'static
216{
217 fn choice_from_postprocessor(
218 &mut self,
219 response: common::llm_backend::BackendOutput,
220 ) -> Result<ResponseType>;
221
222 fn get_isl(&self) -> Option<u32>;
224
225 fn create_usage_chunk(&self) -> ResponseType;
227
228 fn is_usage_enabled(&self) -> bool;
230}
231
232#[derive(Clone, Debug, Serialize, Deserialize, Default)]
233pub struct ParsingOptions {
234 pub tool_call_parser: Option<String>,
235
236 pub reasoning_parser: Option<String>,
237}
238
239impl ParsingOptions {
240 pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
241 Self {
242 tool_call_parser,
243 reasoning_parser,
244 }
245 }
246}