dynamo_llm/protocols/openai/
completions.rs1use derive_builder::Builder;
5use dynamo_runtime::protocols::annotated::AnnotationsProvider;
6use serde::{Deserialize, Serialize};
7use validator::Validate;
8
9use crate::engines::ValidateRequest;
10
11use super::{
12 ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
13 OpenAIStopConditionsProvider,
14 common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
15 common_ext::{
16 CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
17 },
18 nvext::{NvExt, NvExtProvider},
19 validate,
20};
21
22mod aggregator;
23mod delta;
24
25pub use aggregator::DeltaAggregator;
26pub use delta::DeltaGenerator;
27
28#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
29pub struct NvCreateCompletionRequest {
30 #[serde(flatten)]
31 pub inner: dynamo_async_openai::types::CreateCompletionRequest,
32
33 #[serde(flatten)]
34 pub common: CommonExt,
35
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub nvext: Option<NvExt>,
38}
39
40#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
41pub struct NvCreateCompletionResponse {
42 #[serde(flatten)]
43 pub inner: dynamo_async_openai::types::CreateCompletionResponse,
44}
45
46impl ContentProvider for dynamo_async_openai::types::Choice {
47 fn content(&self) -> String {
48 self.text.clone()
49 }
50}
51
52pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
53 match prompt {
54 dynamo_async_openai::types::Prompt::String(s) => s.clone(),
55 dynamo_async_openai::types::Prompt::StringArray(arr) => arr.join(" "), dynamo_async_openai::types::Prompt::IntegerArray(arr) => arr
57 .iter()
58 .map(|&num| num.to_string())
59 .collect::<Vec<_>>()
60 .join(" "),
61 dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr
62 .iter()
63 .map(|inner| {
64 inner
65 .iter()
66 .map(|&num| num.to_string())
67 .collect::<Vec<_>>()
68 .join(" ")
69 })
70 .collect::<Vec<_>>()
71 .join(" | "), }
73}
74
75impl NvExtProvider for NvCreateCompletionRequest {
76 fn nvext(&self) -> Option<&NvExt> {
77 self.nvext.as_ref()
78 }
79
80 fn raw_prompt(&self) -> Option<String> {
81 if let Some(nvext) = self.nvext.as_ref()
82 && let Some(use_raw_prompt) = nvext.use_raw_prompt
83 && use_raw_prompt
84 {
85 return Some(prompt_to_string(&self.inner.prompt));
86 }
87 None
88 }
89}
90
91impl AnnotationsProvider for NvCreateCompletionRequest {
92 fn annotations(&self) -> Option<Vec<String>> {
93 self.nvext
94 .as_ref()
95 .and_then(|nvext| nvext.annotations.clone())
96 }
97
98 fn has_annotation(&self, annotation: &str) -> bool {
99 self.nvext
100 .as_ref()
101 .and_then(|nvext| nvext.annotations.as_ref())
102 .map(|annotations| annotations.contains(&annotation.to_string()))
103 .unwrap_or(false)
104 }
105}
106
107impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
108 fn get_temperature(&self) -> Option<f32> {
109 self.inner.temperature
110 }
111
112 fn get_top_p(&self) -> Option<f32> {
113 self.inner.top_p
114 }
115
116 fn get_frequency_penalty(&self) -> Option<f32> {
117 self.inner.frequency_penalty
118 }
119
120 fn get_presence_penalty(&self) -> Option<f32> {
121 self.inner.presence_penalty
122 }
123
124 fn nvext(&self) -> Option<&NvExt> {
125 self.nvext.as_ref()
126 }
127
128 fn get_seed(&self) -> Option<i64> {
129 self.inner.seed
130 }
131
132 fn get_n(&self) -> Option<u8> {
133 self.inner.n
134 }
135
136 fn get_best_of(&self) -> Option<u8> {
137 self.inner.best_of
138 }
139}
140
141impl CommonExtProvider for NvCreateCompletionRequest {
142 fn common_ext(&self) -> Option<&CommonExt> {
143 Some(&self.common)
144 }
145
146 fn get_guided_json(&self) -> Option<&serde_json::Value> {
148 if let Some(nvext) = &self.nvext
150 && nvext.guided_json.is_some()
151 {
152 emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
153 }
154 self.common
155 .guided_json
156 .as_ref()
157 .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
158 }
159
160 fn get_guided_regex(&self) -> Option<String> {
161 choose_with_deprecation(
162 "guided_regex",
163 self.common.guided_regex.as_ref(),
164 self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
165 )
166 }
167
168 fn get_guided_grammar(&self) -> Option<String> {
169 choose_with_deprecation(
170 "guided_grammar",
171 self.common.guided_grammar.as_ref(),
172 self.nvext
173 .as_ref()
174 .and_then(|nv| nv.guided_grammar.as_ref()),
175 )
176 }
177
178 fn get_guided_choice(&self) -> Option<Vec<String>> {
179 choose_with_deprecation(
180 "guided_choice",
181 self.common.guided_choice.as_ref(),
182 self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
183 )
184 }
185
186 fn get_guided_decoding_backend(&self) -> Option<String> {
187 choose_with_deprecation(
188 "guided_decoding_backend",
189 self.common.guided_decoding_backend.as_ref(),
190 self.nvext
191 .as_ref()
192 .and_then(|nv| nv.guided_decoding_backend.as_ref()),
193 )
194 }
195
196 fn get_top_k(&self) -> Option<i32> {
197 choose_with_deprecation(
198 "top_k",
199 self.common.top_k.as_ref(),
200 self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
201 )
202 }
203
204 fn get_min_p(&self) -> Option<f32> {
205 choose_with_deprecation(
206 "min_p",
207 self.common.min_p.as_ref(),
208 self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
209 )
210 }
211
212 fn get_repetition_penalty(&self) -> Option<f32> {
213 choose_with_deprecation(
214 "repetition_penalty",
215 self.common.repetition_penalty.as_ref(),
216 self.nvext
217 .as_ref()
218 .and_then(|nv| nv.repetition_penalty.as_ref()),
219 )
220 }
221
222 fn get_include_stop_str_in_output(&self) -> Option<bool> {
223 self.common.include_stop_str_in_output
224 }
225}
226
227impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
228 fn get_max_tokens(&self) -> Option<u32> {
229 self.inner.max_tokens
230 }
231
232 fn get_min_tokens(&self) -> Option<u32> {
233 self.common.min_tokens
234 }
235
236 fn get_stop(&self) -> Option<Vec<String>> {
237 None
238 }
239
240 fn nvext(&self) -> Option<&NvExt> {
241 self.nvext.as_ref()
242 }
243
244 fn get_common_ignore_eos(&self) -> Option<bool> {
245 self.common.ignore_eos
246 }
247
248 fn get_ignore_eos(&self) -> Option<bool> {
251 choose_with_deprecation(
252 "ignore_eos",
253 self.get_common_ignore_eos().as_ref(),
254 NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
255 )
256 }
257}
258
259#[derive(Builder)]
260pub struct ResponseFactory {
261 #[builder(setter(into))]
262 pub model: String,
263
264 #[builder(default)]
265 pub system_fingerprint: Option<String>,
266
267 #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
268 pub id: String,
269
270 #[builder(default = "\"text_completion\".to_string()")]
271 pub object: String,
272
273 #[builder(default = "chrono::Utc::now().timestamp() as u32")]
274 pub created: u32,
275}
276
277impl ResponseFactory {
278 pub fn builder() -> ResponseFactoryBuilder {
279 ResponseFactoryBuilder::default()
280 }
281
282 pub fn make_response(
283 &self,
284 choice: dynamo_async_openai::types::Choice,
285 usage: Option<dynamo_async_openai::types::CompletionUsage>,
286 ) -> NvCreateCompletionResponse {
287 let inner = dynamo_async_openai::types::CreateCompletionResponse {
288 id: self.id.clone(),
289 object: self.object.clone(),
290 created: self.created,
291 model: self.model.clone(),
292 choices: vec![choice],
293 system_fingerprint: self.system_fingerprint.clone(),
294 usage,
295 };
296 NvCreateCompletionResponse { inner }
297 }
298}
299
300impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
302 type Error = anyhow::Error;
303
304 fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
305 if request.inner.suffix.is_some() {
327 return Err(anyhow::anyhow!("suffix is not supported"));
328 }
329
330 let stop_conditions = request
331 .extract_stop_conditions()
332 .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;
333
334 let sampling_options = request
335 .extract_sampling_options()
336 .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
337
338 let output_options = request
339 .extract_output_options()
340 .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;
341
342 let prompt = common::PromptType::Completion(common::CompletionContext {
343 prompt: prompt_to_string(&request.inner.prompt),
344 system_prompt: None,
345 });
346
347 Ok(common::CompletionRequest {
348 prompt,
349 stop_conditions,
350 sampling_options,
351 output_options,
352 mdc_sum: None,
353 annotations: None,
354 })
355 }
356}
357
358impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
359 type Error = anyhow::Error;
360
361 fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
362 let text = response
363 .delta
364 .text
365 .ok_or(anyhow::anyhow!("No text in response"))?;
366
367 let index: u32 = response
370 .delta
371 .index
372 .unwrap_or(0)
373 .try_into()
374 .expect("index exceeds u32::MAX");
375
376 let logprobs = None;
378
379 let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
380 response.delta.finish_reason.map(Into::into);
381
382 let choice = dynamo_async_openai::types::Choice {
383 text,
384 index,
385 logprobs,
386 finish_reason,
387 };
388
389 Ok(choice)
390 }
391}
392
393impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
394 fn get_logprobs(&self) -> Option<u32> {
395 self.inner.logprobs.map(|logprobs| logprobs as u32)
396 }
397
398 fn get_prompt_logprobs(&self) -> Option<u32> {
399 self.inner
400 .echo
401 .and_then(|echo| if echo { Some(1) } else { None })
402 }
403
404 fn get_skip_special_tokens(&self) -> Option<bool> {
405 None
406 }
407
408 fn get_formatted_prompt(&self) -> Option<bool> {
409 None
410 }
411}
412
413impl ValidateRequest for NvCreateCompletionRequest {
416 fn validate(&self) -> Result<(), anyhow::Error> {
417 validate::validate_model(&self.inner.model)?;
418 validate::validate_prompt(&self.inner.prompt)?;
419 validate::validate_suffix(self.inner.suffix.as_deref())?;
420 validate::validate_max_tokens(self.inner.max_tokens)?;
421 validate::validate_temperature(self.inner.temperature)?;
422 validate::validate_top_p(self.inner.top_p)?;
423 validate::validate_n(self.inner.n)?;
424 validate::validate_logprobs(self.inner.logprobs)?;
427 validate::validate_stop(&self.inner.stop)?;
429 validate::validate_presence_penalty(self.inner.presence_penalty)?;
430 validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
431 validate::validate_best_of(self.inner.best_of, self.inner.n)?;
432 validate::validate_logit_bias(&self.inner.logit_bias)?;
433 validate::validate_user(self.inner.user.as_deref())?;
434 validate::validate_repetition_penalty(self.get_repetition_penalty())?;
438
439 Ok(())
440 }
441}