dynamo_llm/protocols/openai/
completions.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use derive_builder::Builder;
5use dynamo_runtime::protocols::annotated::AnnotationsProvider;
6use serde::{Deserialize, Serialize};
7use validator::Validate;
8
9use crate::engines::ValidateRequest;
10
11use super::{
12    ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
13    OpenAIStopConditionsProvider,
14    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
15    common_ext::{
16        CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
17    },
18    nvext::{NvExt, NvExtProvider},
19    validate,
20};
21
22mod aggregator;
23mod delta;
24
25pub use aggregator::DeltaAggregator;
26pub use delta::DeltaGenerator;
27
28#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
29pub struct NvCreateCompletionRequest {
30    #[serde(flatten)]
31    pub inner: dynamo_async_openai::types::CreateCompletionRequest,
32
33    #[serde(flatten)]
34    pub common: CommonExt,
35
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub nvext: Option<NvExt>,
38}
39
40#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
41pub struct NvCreateCompletionResponse {
42    #[serde(flatten)]
43    pub inner: dynamo_async_openai::types::CreateCompletionResponse,
44}
45
46impl ContentProvider for dynamo_async_openai::types::Choice {
47    fn content(&self) -> String {
48        self.text.clone()
49    }
50}
51
52pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
53    match prompt {
54        dynamo_async_openai::types::Prompt::String(s) => s.clone(),
55        dynamo_async_openai::types::Prompt::StringArray(arr) => arr.join(" "), // Join strings with spaces
56        dynamo_async_openai::types::Prompt::IntegerArray(arr) => arr
57            .iter()
58            .map(|&num| num.to_string())
59            .collect::<Vec<_>>()
60            .join(" "),
61        dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr
62            .iter()
63            .map(|inner| {
64                inner
65                    .iter()
66                    .map(|&num| num.to_string())
67                    .collect::<Vec<_>>()
68                    .join(" ")
69            })
70            .collect::<Vec<_>>()
71            .join(" | "), // Separate arrays with a delimiter
72    }
73}
74
75impl NvExtProvider for NvCreateCompletionRequest {
76    fn nvext(&self) -> Option<&NvExt> {
77        self.nvext.as_ref()
78    }
79
80    fn raw_prompt(&self) -> Option<String> {
81        if let Some(nvext) = self.nvext.as_ref()
82            && let Some(use_raw_prompt) = nvext.use_raw_prompt
83            && use_raw_prompt
84        {
85            return Some(prompt_to_string(&self.inner.prompt));
86        }
87        None
88    }
89}
90
91impl AnnotationsProvider for NvCreateCompletionRequest {
92    fn annotations(&self) -> Option<Vec<String>> {
93        self.nvext
94            .as_ref()
95            .and_then(|nvext| nvext.annotations.clone())
96    }
97
98    fn has_annotation(&self, annotation: &str) -> bool {
99        self.nvext
100            .as_ref()
101            .and_then(|nvext| nvext.annotations.as_ref())
102            .map(|annotations| annotations.contains(&annotation.to_string()))
103            .unwrap_or(false)
104    }
105}
106
107impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
108    fn get_temperature(&self) -> Option<f32> {
109        self.inner.temperature
110    }
111
112    fn get_top_p(&self) -> Option<f32> {
113        self.inner.top_p
114    }
115
116    fn get_frequency_penalty(&self) -> Option<f32> {
117        self.inner.frequency_penalty
118    }
119
120    fn get_presence_penalty(&self) -> Option<f32> {
121        self.inner.presence_penalty
122    }
123
124    fn nvext(&self) -> Option<&NvExt> {
125        self.nvext.as_ref()
126    }
127
128    fn get_seed(&self) -> Option<i64> {
129        self.inner.seed
130    }
131
132    fn get_n(&self) -> Option<u8> {
133        self.inner.n
134    }
135
136    fn get_best_of(&self) -> Option<u8> {
137        self.inner.best_of
138    }
139}
140
141impl CommonExtProvider for NvCreateCompletionRequest {
142    fn common_ext(&self) -> Option<&CommonExt> {
143        Some(&self.common)
144    }
145
146    /// Guided Decoding Options
147    fn get_guided_json(&self) -> Option<&serde_json::Value> {
148        // Note: This one needs special handling since it returns a reference
149        if let Some(nvext) = &self.nvext
150            && nvext.guided_json.is_some()
151        {
152            emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
153        }
154        self.common
155            .guided_json
156            .as_ref()
157            .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
158    }
159
160    fn get_guided_regex(&self) -> Option<String> {
161        choose_with_deprecation(
162            "guided_regex",
163            self.common.guided_regex.as_ref(),
164            self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
165        )
166    }
167
168    fn get_guided_grammar(&self) -> Option<String> {
169        choose_with_deprecation(
170            "guided_grammar",
171            self.common.guided_grammar.as_ref(),
172            self.nvext
173                .as_ref()
174                .and_then(|nv| nv.guided_grammar.as_ref()),
175        )
176    }
177
178    fn get_guided_choice(&self) -> Option<Vec<String>> {
179        choose_with_deprecation(
180            "guided_choice",
181            self.common.guided_choice.as_ref(),
182            self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
183        )
184    }
185
186    fn get_guided_decoding_backend(&self) -> Option<String> {
187        choose_with_deprecation(
188            "guided_decoding_backend",
189            self.common.guided_decoding_backend.as_ref(),
190            self.nvext
191                .as_ref()
192                .and_then(|nv| nv.guided_decoding_backend.as_ref()),
193        )
194    }
195
196    fn get_top_k(&self) -> Option<i32> {
197        choose_with_deprecation(
198            "top_k",
199            self.common.top_k.as_ref(),
200            self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
201        )
202    }
203
204    fn get_min_p(&self) -> Option<f32> {
205        choose_with_deprecation(
206            "min_p",
207            self.common.min_p.as_ref(),
208            self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
209        )
210    }
211
212    fn get_repetition_penalty(&self) -> Option<f32> {
213        choose_with_deprecation(
214            "repetition_penalty",
215            self.common.repetition_penalty.as_ref(),
216            self.nvext
217                .as_ref()
218                .and_then(|nv| nv.repetition_penalty.as_ref()),
219        )
220    }
221
222    fn get_include_stop_str_in_output(&self) -> Option<bool> {
223        self.common.include_stop_str_in_output
224    }
225}
226
227impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
228    fn get_max_tokens(&self) -> Option<u32> {
229        self.inner.max_tokens
230    }
231
232    fn get_min_tokens(&self) -> Option<u32> {
233        self.common.min_tokens
234    }
235
236    fn get_stop(&self) -> Option<Vec<String>> {
237        None
238    }
239
240    fn nvext(&self) -> Option<&NvExt> {
241        self.nvext.as_ref()
242    }
243
244    fn get_common_ignore_eos(&self) -> Option<bool> {
245        self.common.ignore_eos
246    }
247
248    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
249    /// CommonExt (root-level) takes precedence over NvExt.
250    fn get_ignore_eos(&self) -> Option<bool> {
251        choose_with_deprecation(
252            "ignore_eos",
253            self.get_common_ignore_eos().as_ref(),
254            NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
255        )
256    }
257}
258
259#[derive(Builder)]
260pub struct ResponseFactory {
261    #[builder(setter(into))]
262    pub model: String,
263
264    #[builder(default)]
265    pub system_fingerprint: Option<String>,
266
267    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
268    pub id: String,
269
270    #[builder(default = "\"text_completion\".to_string()")]
271    pub object: String,
272
273    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
274    pub created: u32,
275}
276
277impl ResponseFactory {
278    pub fn builder() -> ResponseFactoryBuilder {
279        ResponseFactoryBuilder::default()
280    }
281
282    pub fn make_response(
283        &self,
284        choice: dynamo_async_openai::types::Choice,
285        usage: Option<dynamo_async_openai::types::CompletionUsage>,
286    ) -> NvCreateCompletionResponse {
287        let inner = dynamo_async_openai::types::CreateCompletionResponse {
288            id: self.id.clone(),
289            object: self.object.clone(),
290            created: self.created,
291            model: self.model.clone(),
292            choices: vec![choice],
293            system_fingerprint: self.system_fingerprint.clone(),
294            usage,
295        };
296        NvCreateCompletionResponse { inner }
297    }
298}
299
300/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
301impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
302    type Error = anyhow::Error;
303
304    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
305        // openai_api_rs::v1::completion::CompletionRequest {
306        // NA  pub model: String,
307        //     pub prompt: String,
308        // **  pub suffix: Option<String>,
309        //     pub max_tokens: Option<i32>,
310        //     pub temperature: Option<f32>,
311        //     pub top_p: Option<f32>,
312        //     pub n: Option<i32>,
313        //     pub stream: Option<bool>,
314        //     pub logprobs: Option<i32>,
315        //     pub echo: Option<bool>,
316        //     pub stop: Option<Vec<String, Global>>,
317        //     pub presence_penalty: Option<f32>,
318        //     pub frequency_penalty: Option<f32>,
319        //     pub best_of: Option<i32>,
320        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
321        //     pub user: Option<String>,
322        // }
323        //
324        // ** no supported
325
326        if request.inner.suffix.is_some() {
327            return Err(anyhow::anyhow!("suffix is not supported"));
328        }
329
330        let stop_conditions = request
331            .extract_stop_conditions()
332            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;
333
334        let sampling_options = request
335            .extract_sampling_options()
336            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
337
338        let output_options = request
339            .extract_output_options()
340            .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;
341
342        let prompt = common::PromptType::Completion(common::CompletionContext {
343            prompt: prompt_to_string(&request.inner.prompt),
344            system_prompt: None,
345        });
346
347        Ok(common::CompletionRequest {
348            prompt,
349            stop_conditions,
350            sampling_options,
351            output_options,
352            mdc_sum: None,
353            annotations: None,
354        })
355    }
356}
357
358impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
359    type Error = anyhow::Error;
360
361    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
362        let text = response
363            .delta
364            .text
365            .ok_or(anyhow::anyhow!("No text in response"))?;
366
367        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
368        // so we're fairly safe knowing we won't generate that many Choices
369        let index: u32 = response
370            .delta
371            .index
372            .unwrap_or(0)
373            .try_into()
374            .expect("index exceeds u32::MAX");
375
376        // TODO handle aggregating logprobs
377        let logprobs = None;
378
379        let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
380            response.delta.finish_reason.map(Into::into);
381
382        let choice = dynamo_async_openai::types::Choice {
383            text,
384            index,
385            logprobs,
386            finish_reason,
387        };
388
389        Ok(choice)
390    }
391}
392
393impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
394    fn get_logprobs(&self) -> Option<u32> {
395        self.inner.logprobs.map(|logprobs| logprobs as u32)
396    }
397
398    fn get_prompt_logprobs(&self) -> Option<u32> {
399        self.inner
400            .echo
401            .and_then(|echo| if echo { Some(1) } else { None })
402    }
403
404    fn get_skip_special_tokens(&self) -> Option<bool> {
405        None
406    }
407
408    fn get_formatted_prompt(&self) -> Option<bool> {
409        None
410    }
411}
412
413/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
414/// allowing us to validate the data.
415impl ValidateRequest for NvCreateCompletionRequest {
416    fn validate(&self) -> Result<(), anyhow::Error> {
417        validate::validate_model(&self.inner.model)?;
418        validate::validate_prompt(&self.inner.prompt)?;
419        validate::validate_suffix(self.inner.suffix.as_deref())?;
420        validate::validate_max_tokens(self.inner.max_tokens)?;
421        validate::validate_temperature(self.inner.temperature)?;
422        validate::validate_top_p(self.inner.top_p)?;
423        validate::validate_n(self.inner.n)?;
424        // none for stream
425        // none for stream_options
426        validate::validate_logprobs(self.inner.logprobs)?;
427        // none for echo
428        validate::validate_stop(&self.inner.stop)?;
429        validate::validate_presence_penalty(self.inner.presence_penalty)?;
430        validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
431        validate::validate_best_of(self.inner.best_of, self.inner.n)?;
432        validate::validate_logit_bias(&self.inner.logit_bias)?;
433        validate::validate_user(self.inner.user.as_deref())?;
434        // none for seed
435
436        // Common Ext
437        validate::validate_repetition_penalty(self.get_repetition_penalty())?;
438
439        Ok(())
440    }
441}