dynamo_llm/protocols/openai/
completions.rs1use std::collections::HashMap;
17
18use derive_builder::Builder;
19use serde::{Deserialize, Serialize};
20use validator::Validate;
21
22mod aggregator;
23mod delta;
24
25pub use aggregator::DeltaAggregator;
26pub use delta::DeltaGenerator;
27
28use super::{
29 common::{self, SamplingOptionsProvider, StopConditionsProvider},
30 nvext::{NvExt, NvExtProvider},
31 CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
32};
33
34use dynamo_runtime::protocols::annotated::AnnotationsProvider;
35
36#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
37pub struct NvCreateCompletionRequest {
38 #[serde(flatten)]
39 pub inner: async_openai::types::CreateCompletionRequest,
40
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub nvext: Option<NvExt>,
43}
44
/// Response body for a text completion.
///
/// Instances are assembled by `ResponseFactory::make_response`, which fills
/// every field except `choices` and `usage` from the factory's own state.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of this completion (the factory default is `cmpl-<uuid>`).
    pub id: String,

    /// Generated choices; `ResponseFactory::make_response` always supplies
    /// exactly one.
    pub choices: Vec<CompletionChoice>,

    /// Creation time; the factory default is `chrono::Utc::now().timestamp()`
    /// cast to `u64`.
    pub created: u64,

    /// Name of the model, as configured on the factory.
    pub model: String,

    /// Object-type tag (the factory default is `"text_completion"`).
    pub object: String,

    /// Optional usage statistics. There is no `skip_serializing_if` on this
    /// field, so it is still emitted when `None`.
    pub usage: Option<CompletionUsage>,

    /// Optional system fingerprint; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
81
/// A single generated completion choice.
///
/// `Builder` is derived, so values can be assembled through
/// `CompletionChoiceBuilder` (see `CompletionChoice::builder`).
#[derive(Clone, Debug, Deserialize, Serialize, Builder)]
pub struct CompletionChoice {
    /// Generated text. The builder setter accepts any `Into<String>`.
    #[builder(setter(into))]
    pub text: String,

    /// Index of this choice; the builder defaults it to `0`.
    #[builder(default = "0")]
    pub index: u64,

    /// Why generation finished, as a string. The conversion from
    /// `common::StreamingCompletionResponse` in this file produces `"stop"`,
    /// `"length"` or `"cancelled"`. The builder defaults it to `None` and its
    /// setter takes the inner value directly (`strip_option`).
    #[builder(default, setter(into, strip_option))]
    pub finish_reason: Option<String>,

    /// Optional log-probability data; omitted from serialized output when
    /// `None`. The builder defaults it to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub logprobs: Option<LogprobResult>,
}
98
impl ContentProvider for CompletionChoice {
    /// Returns an owned copy of this choice's `text`.
    fn content(&self) -> String {
        self.text.clone()
    }
}
104
impl CompletionChoice {
    /// Returns a `CompletionChoiceBuilder` in its default (empty) state.
    pub fn builder() -> CompletionChoiceBuilder {
        CompletionChoiceBuilder::default()
    }
}
110
/// Log-probability data attached to a `CompletionChoice`.
///
/// NOTE(review): the four vectors are presumably index-aligned (one entry per
/// token), but nothing in this file enforces that — confirm against producers.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LogprobResult {
    /// Token strings.
    pub tokens: Vec<String>,
    /// One `f32` log-probability value per entry.
    pub token_logprobs: Vec<f32>,
    /// Maps from token string to `f32` log-probability, one map per entry.
    pub top_logprobs: Vec<HashMap<String, f32>>,
    /// Integer text offsets, one per entry.
    pub text_offset: Vec<i32>,
}
120
121pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
122 match prompt {
123 async_openai::types::Prompt::String(s) => s.clone(),
124 async_openai::types::Prompt::StringArray(arr) => arr.join(" "), async_openai::types::Prompt::IntegerArray(arr) => arr
126 .iter()
127 .map(|&num| num.to_string())
128 .collect::<Vec<_>>()
129 .join(" "),
130 async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr
131 .iter()
132 .map(|inner| {
133 inner
134 .iter()
135 .map(|&num| num.to_string())
136 .collect::<Vec<_>>()
137 .join(" ")
138 })
139 .collect::<Vec<_>>()
140 .join(" | "), }
142}
143
144impl NvExtProvider for NvCreateCompletionRequest {
145 fn nvext(&self) -> Option<&NvExt> {
146 self.nvext.as_ref()
147 }
148
149 fn raw_prompt(&self) -> Option<String> {
150 if let Some(nvext) = self.nvext.as_ref() {
151 if let Some(use_raw_prompt) = nvext.use_raw_prompt {
152 if use_raw_prompt {
153 return Some(prompt_to_string(&self.inner.prompt));
154 }
155 }
156 }
157 None
158 }
159}
160
161impl AnnotationsProvider for NvCreateCompletionRequest {
162 fn annotations(&self) -> Option<Vec<String>> {
163 self.nvext
164 .as_ref()
165 .and_then(|nvext| nvext.annotations.clone())
166 }
167
168 fn has_annotation(&self, annotation: &str) -> bool {
169 self.nvext
170 .as_ref()
171 .and_then(|nvext| nvext.annotations.as_ref())
172 .map(|annotations| annotations.contains(&annotation.to_string()))
173 .unwrap_or(false)
174 }
175}
176
/// Exposes the sampling parameters of the wrapped request; every getter
/// forwards the matching field of `inner` unchanged.
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
    /// Returns `inner.temperature`.
    fn get_temperature(&self) -> Option<f32> {
        self.inner.temperature
    }

    /// Returns `inner.top_p`.
    fn get_top_p(&self) -> Option<f32> {
        self.inner.top_p
    }

    /// Returns `inner.frequency_penalty`.
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.inner.frequency_penalty
    }

    /// Returns `inner.presence_penalty`.
    fn get_presence_penalty(&self) -> Option<f32> {
        self.inner.presence_penalty
    }

    /// Borrows the request's `nvext` block, if one was supplied.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}
198
/// Exposes the stop-condition parameters of the wrapped request.
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
    /// Returns `inner.max_tokens`.
    fn get_max_tokens(&self) -> Option<u32> {
        self.inner.max_tokens
    }

    /// Always `None`: this implementation does not report a minimum token count.
    fn get_min_tokens(&self) -> Option<u32> {
        None
    }

    /// Always `None`: this implementation does not report stop sequences.
    // NOTE(review): nothing from `self.inner` is consulted here, so any stop
    // sequences carried by the wrapped request are not forwarded — confirm
    // that this is intentional.
    fn get_stop(&self) -> Option<Vec<String>> {
        None
    }

    /// Borrows the request's `nvext` block, if one was supplied.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}
216
/// Holds the response-level fields shared by every `CompletionResponse`
/// produced through `make_response`.
///
/// Built through the derived `ResponseFactoryBuilder`; only `model` has no
/// builder default.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into each response. The builder setter accepts any
    /// `Into<String>`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint copied into each response; the builder
    /// defaults it to `None`.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; the builder defaults it to `cmpl-` followed by a freshly
    /// generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object-type tag; the builder defaults it to `"text_completion"`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation time; the builder defaults it to
    /// `chrono::Utc::now().timestamp()` cast to `u64`.
    #[builder(default = "chrono::Utc::now().timestamp() as u64")]
    pub created: u64,
}
234
235impl ResponseFactory {
236 pub fn builder() -> ResponseFactoryBuilder {
237 ResponseFactoryBuilder::default()
238 }
239
240 pub fn make_response(
241 &self,
242 choice: CompletionChoice,
243 usage: Option<CompletionUsage>,
244 ) -> CompletionResponse {
245 CompletionResponse {
246 id: self.id.clone(),
247 object: self.object.clone(),
248 created: self.created,
249 model: self.model.clone(),
250 choices: vec![choice],
251 system_fingerprint: self.system_fingerprint.clone(),
252 usage,
253 }
254 }
255}
256
257impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
259 type Error = anyhow::Error;
260
261 fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
262 if request.inner.suffix.is_some() {
284 return Err(anyhow::anyhow!("suffix is not supported"));
285 }
286
287 let stop_conditions = request
288 .extract_stop_conditions()
289 .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;
290
291 let sampling_options = request
292 .extract_sampling_options()
293 .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
294
295 let prompt = common::PromptType::Completion(common::CompletionContext {
296 prompt: prompt_to_string(&request.inner.prompt),
297 system_prompt: None,
298 });
299
300 Ok(common::CompletionRequest {
301 prompt,
302 stop_conditions,
303 sampling_options,
304 mdc_sum: None,
305 annotations: None,
306 })
307 }
308}
309
310impl TryFrom<common::StreamingCompletionResponse> for CompletionChoice {
311 type Error = anyhow::Error;
312
313 fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
314 let choice = CompletionChoice {
315 text: response
316 .delta
317 .text
318 .ok_or(anyhow::anyhow!("No text in response"))?,
319 index: response.delta.index.unwrap_or(0) as u64,
320 logprobs: None,
321 finish_reason: match &response.delta.finish_reason {
322 Some(common::FinishReason::EoS) => Some("stop".to_string()),
323 Some(common::FinishReason::Stop) => Some("stop".to_string()),
324 Some(common::FinishReason::Length) => Some("length".to_string()),
325 Some(common::FinishReason::Error(err_msg)) => {
326 return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
327 }
328 Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
329 None => None,
330 },
331 };
332
333 Ok(choice)
334 }
335}