// dynamo_llm/protocols/openai/completions.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use std::collections::HashMap;
17
18use derive_builder::Builder;
19use serde::{Deserialize, Serialize};
20use validator::Validate;
21
22mod aggregator;
23mod delta;
24
25pub use aggregator::DeltaAggregator;
26pub use delta::DeltaGenerator;
27
28use super::{
29    common::{self, SamplingOptionsProvider, StopConditionsProvider},
30    nvext::{NvExt, NvExtProvider},
31    CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
32};
33
34use dynamo_runtime::protocols::annotated::AnnotationsProvider;
35
/// Completions request in the OpenAI `/v1/completions` shape, extended with
/// NVIDIA-specific options carried in the optional `nvext` field.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionRequest {
    /// Standard OpenAI completion request. `#[serde(flatten)]` keeps the wire
    /// format identical to the upstream API, with `nvext` as the only extra key.
    #[serde(flatten)]
    pub inner: async_openai::types::CreateCompletionRequest,

    /// Optional NVIDIA extensions (e.g. annotations, raw-prompt passthrough);
    /// omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}
44
/// Legacy OpenAI CompletionResponse
/// Represents a completion response from the API.
/// Note: both the streamed and non-streamed response objects share the same
/// shape (unlike the chat endpoint).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// A unique identifier for the completion.
    pub id: String,

    /// The list of completion choices the model generated for the input prompt.
    pub choices: Vec<CompletionChoice>,

    /// The Unix timestamp (in seconds) of when the completion was created.
    pub created: u64,

    /// The model used for completion.
    pub model: String,

    /// The object type, which is always "text_completion"
    pub object: String,

    /// Usage statistics for the completion request.
    pub usage: Option<CompletionUsage>,

    /// This fingerprint represents the backend configuration that the model runs with.
    /// Can be used in conjunction with the seed request parameter to understand when backend
    /// changes have been made that might impact determinism.
    ///
    /// NIM Compatibility:
    /// This field is not supported by the NIM; however it will be added in the future.
    /// The optional nature of this field will be relaxed when it is supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    // TODO(ryan)
    // pub nvext: Option<NimResponseExt>,
}
81
/// Legacy OpenAI CompletionResponse Choice component.
/// Built via `CompletionChoice::builder()` (derived `Builder`).
#[derive(Clone, Debug, Deserialize, Serialize, Builder)]
pub struct CompletionChoice {
    /// The generated completion text for this choice.
    #[builder(setter(into))]
    pub text: String,

    /// Position of this choice in the response's `choices` array; defaults to 0.
    #[builder(default = "0")]
    pub index: u64,

    /// Why generation stopped (e.g. "stop", "length"); `None` for
    /// intermediate streamed chunks.
    #[builder(default, setter(into, strip_option))]
    pub finish_reason: Option<String>,

    /// Per-token log-probability details; omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub logprobs: Option<LogprobResult>,
}
98
99impl ContentProvider for CompletionChoice {
100    fn content(&self) -> String {
101        self.text.clone()
102    }
103}
104
105impl CompletionChoice {
106    pub fn builder() -> CompletionChoiceBuilder {
107        CompletionChoiceBuilder::default()
108    }
109}
110
// TODO: validate this is the correct format
/// Legacy OpenAI LogprobResult component
/// NOTE(review): field layout mirrors the legacy OpenAI completions
/// `logprobs` object — confirm against the upstream API before relying on it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LogprobResult {
    /// The token strings, in generation order.
    pub tokens: Vec<String>,
    /// Log probability of each token in `tokens` (parallel array).
    pub token_logprobs: Vec<f32>,
    /// For each position, a map of the most likely tokens to their log probs.
    pub top_logprobs: Vec<HashMap<String, f32>>,
    /// Character offset of each token within the generated text.
    pub text_offset: Vec<i32>,
}
120
121pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
122    match prompt {
123        async_openai::types::Prompt::String(s) => s.clone(),
124        async_openai::types::Prompt::StringArray(arr) => arr.join(" "), // Join strings with spaces
125        async_openai::types::Prompt::IntegerArray(arr) => arr
126            .iter()
127            .map(|&num| num.to_string())
128            .collect::<Vec<_>>()
129            .join(" "),
130        async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr
131            .iter()
132            .map(|inner| {
133                inner
134                    .iter()
135                    .map(|&num| num.to_string())
136                    .collect::<Vec<_>>()
137                    .join(" ")
138            })
139            .collect::<Vec<_>>()
140            .join(" | "), // Separate arrays with a delimiter
141    }
142}
143
144impl NvExtProvider for NvCreateCompletionRequest {
145    fn nvext(&self) -> Option<&NvExt> {
146        self.nvext.as_ref()
147    }
148
149    fn raw_prompt(&self) -> Option<String> {
150        if let Some(nvext) = self.nvext.as_ref() {
151            if let Some(use_raw_prompt) = nvext.use_raw_prompt {
152                if use_raw_prompt {
153                    return Some(prompt_to_string(&self.inner.prompt));
154                }
155            }
156        }
157        None
158    }
159}
160
161impl AnnotationsProvider for NvCreateCompletionRequest {
162    fn annotations(&self) -> Option<Vec<String>> {
163        self.nvext
164            .as_ref()
165            .and_then(|nvext| nvext.annotations.clone())
166    }
167
168    fn has_annotation(&self, annotation: &str) -> bool {
169        self.nvext
170            .as_ref()
171            .and_then(|nvext| nvext.annotations.as_ref())
172            .map(|annotations| annotations.contains(&annotation.to_string()))
173            .unwrap_or(false)
174    }
175}
176
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
    /// Sampling temperature from the inner OpenAI request, if set.
    fn get_temperature(&self) -> Option<f32> {
        self.inner.temperature
    }

    /// Nucleus-sampling probability mass from the inner OpenAI request, if set.
    fn get_top_p(&self) -> Option<f32> {
        self.inner.top_p
    }

    /// Frequency penalty from the inner OpenAI request, if set.
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.inner.frequency_penalty
    }

    /// Presence penalty from the inner OpenAI request, if set.
    fn get_presence_penalty(&self) -> Option<f32> {
        self.inner.presence_penalty
    }

    /// Borrows the NVIDIA extension block, if present.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}
198
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
    /// Maximum number of tokens to generate, from the inner OpenAI request.
    fn get_max_tokens(&self) -> Option<u32> {
        self.inner.max_tokens
    }

    /// The OpenAI completions API has no minimum-token field, so this is
    /// always `None`.
    fn get_min_tokens(&self) -> Option<u32> {
        None
    }

    // NOTE(review): `self.inner.stop` is never surfaced here, so stop strings
    // supplied in the OpenAI request are dropped by this provider — confirm
    // whether they are handled elsewhere or this is a gap.
    fn get_stop(&self) -> Option<Vec<String>> {
        None
    }

    /// Borrows the NVIDIA extension block, if present.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}
216
/// Stamps the per-response identity fields (id, object, created, model,
/// fingerprint) onto completion responses; build one per request via
/// `ResponseFactory::builder()`.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name echoed back in every response.
    #[builder(setter(into))]
    pub model: String,

    /// Backend-configuration fingerprint; `None` until supported.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Unique completion id; defaults to a fresh `cmpl-<uuid>`.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// OpenAI object type; defaults to "text_completion".
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation time as a Unix timestamp (seconds); defaults to now.
    #[builder(default = "chrono::Utc::now().timestamp() as u64")]
    pub created: u64,
}
234
235impl ResponseFactory {
236    pub fn builder() -> ResponseFactoryBuilder {
237        ResponseFactoryBuilder::default()
238    }
239
240    pub fn make_response(
241        &self,
242        choice: CompletionChoice,
243        usage: Option<CompletionUsage>,
244    ) -> CompletionResponse {
245        CompletionResponse {
246            id: self.id.clone(),
247            object: self.object.clone(),
248            created: self.created,
249            model: self.model.clone(),
250            choices: vec![choice],
251            system_fingerprint: self.system_fingerprint.clone(),
252            usage,
253        }
254    }
255}
256
257/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
258impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
259    type Error = anyhow::Error;
260
261    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
262        // openai_api_rs::v1::completion::CompletionRequest {
263        // NA  pub model: String,
264        //     pub prompt: String,
265        // **  pub suffix: Option<String>,
266        //     pub max_tokens: Option<i32>,
267        //     pub temperature: Option<f32>,
268        //     pub top_p: Option<f32>,
269        //     pub n: Option<i32>,
270        //     pub stream: Option<bool>,
271        //     pub logprobs: Option<i32>,
272        //     pub echo: Option<bool>,
273        //     pub stop: Option<Vec<String, Global>>,
274        //     pub presence_penalty: Option<f32>,
275        //     pub frequency_penalty: Option<f32>,
276        //     pub best_of: Option<i32>,
277        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
278        //     pub user: Option<String>,
279        // }
280        //
281        // ** no supported
282
283        if request.inner.suffix.is_some() {
284            return Err(anyhow::anyhow!("suffix is not supported"));
285        }
286
287        let stop_conditions = request
288            .extract_stop_conditions()
289            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;
290
291        let sampling_options = request
292            .extract_sampling_options()
293            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
294
295        let prompt = common::PromptType::Completion(common::CompletionContext {
296            prompt: prompt_to_string(&request.inner.prompt),
297            system_prompt: None,
298        });
299
300        Ok(common::CompletionRequest {
301            prompt,
302            stop_conditions,
303            sampling_options,
304            mdc_sum: None,
305            annotations: None,
306        })
307    }
308}
309
310impl TryFrom<common::StreamingCompletionResponse> for CompletionChoice {
311    type Error = anyhow::Error;
312
313    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
314        let choice = CompletionChoice {
315            text: response
316                .delta
317                .text
318                .ok_or(anyhow::anyhow!("No text in response"))?,
319            index: response.delta.index.unwrap_or(0) as u64,
320            logprobs: None,
321            finish_reason: match &response.delta.finish_reason {
322                Some(common::FinishReason::EoS) => Some("stop".to_string()),
323                Some(common::FinishReason::Stop) => Some("stop".to_string()),
324                Some(common::FinishReason::Length) => Some("length".to_string()),
325                Some(common::FinishReason::Error(err_msg)) => {
326                    return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
327                }
328                Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
329                None => None,
330            },
331        };
332
333        Ok(choice)
334    }
335}