dynamo_async_openai/types/completion.rs
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0

use std::{collections::HashMap, pin::Pin};

use derive_builder::Builder;
use futures::Stream;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

use super::{ChatCompletionStreamOptions, Choice, CompletionUsage, Prompt, Stop};

/// Custom deserializer for the `echo` parameter that only accepts booleans.
/// Rejects integers and strings with clear error messages.
fn deserialize_echo_bool<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Outer visitor: handles Option semantics (Some/None/null)
    struct StrictBoolVisitor;

    impl<'de> serde::de::Visitor<'de> for StrictBoolVisitor {
        type Value = Option<bool>;

        // Required by the Visitor trait; also supplies the "expected" part of error messages
        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("echo parameter to be a boolean (true or false) or null")
        }

        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(BoolOnlyVisitor)
        }

        fn visit_none<E>(self) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(None)
        }

        fn visit_unit<E>(self) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(None)
        }
    }

    // Inner visitor: accepts only booleans, rejecting integers and strings
    struct BoolOnlyVisitor;

    impl<'de> serde::de::Visitor<'de> for BoolOnlyVisitor {
        type Value = Option<bool>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("echo parameter to be a boolean (true or false) or null")
        }

        fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(Some(value))
        }

        // Explicitly reject strings (including "null", "true", "false")
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Err(E::invalid_type(
                serde::de::Unexpected::Str(value),
                &"echo parameter to be a boolean (true or false) or null",
            ))
        }
    }

    deserializer.deserialize_option(StrictBoolVisitor)
}
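
// Behavior sketch for the strict deserializer above (mirrors the unit tests at
// the bottom of this file; serde error strings are abbreviated):
//
//   {"echo": true}   -> Ok(Some(true))
//   {"echo": null}   -> Ok(None)
//   {"echo": 1}      -> Err(invalid type: integer `1`, expected echo parameter ...)
//   {"echo": "true"} -> Err(invalid type: string "true", expected echo parameter ...)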

#[derive(Clone, Serialize, Deserialize, Default, Debug, Builder, PartialEq)]
#[builder(name = "CreateCompletionRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateCompletionRequest {
    /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
    pub model: String,

    /// The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.
    ///
    /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.
    pub prompt: Prompt,

    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>, // default: null

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>, // min: 0, max: 2, default: 1

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>, // min: 0, max: 1, default: 1

    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u8>, // min: 1, max: 128, default: 1

    /// Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]` message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>, // nullable: true

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ChatCompletionStreamOptions>,

    /// Include the log probabilities on the `logprobs` most likely output tokens, as well as the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u8>, // min: 0, max: 5, default: null, nullable: true

    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default, deserialize_with = "deserialize_echo_bool")]
    pub echo: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

    /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return; `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u8>, // min: 0, max: 20, default: 1

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null

    /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
}
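
// Typical construction goes through the generated builder (a sketch; the model
// name is a placeholder, `build()` returns `Result<_, OpenAIError>` per the
// `build_fn` attribute above, and `.prompt(&str)` assumes the upstream
// `From<&str>` impl on `Prompt`):
//
// let request = CreateCompletionRequestArgs::default()
//     .model("gpt-3.5-turbo-instruct")
//     .prompt("Say this is a test")
//     .max_tokens(16_u32)
//     .build()?;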

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateCompletionResponse {
    /// A unique identifier for the completion.
    pub id: String,
    pub choices: Vec<Choice>,
    /// The Unix timestamp (in seconds) of when the completion was created.
    pub created: u32,

    /// The model used for completion.
    pub model: String,
    /// This fingerprint represents the backend configuration that the model runs with.
    ///
    /// Can be used in conjunction with the `seed` request parameter to understand when backend changes have been
    /// made that might impact determinism.
    pub system_fingerprint: Option<String>,

    /// The object type, which is always "text_completion".
    pub object: String,
    pub usage: Option<CompletionUsage>,
}
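
// Illustrative wire shape for `CreateCompletionResponse` (values are made up;
// the `choices` entries follow the OpenAI completions `Choice` fields):
//
// {
//   "id": "cmpl-abc123",
//   "object": "text_completion",
//   "created": 1700000000,
//   "model": "gpt-3.5-turbo-instruct",
//   "system_fingerprint": null,
//   "choices": [
//     { "text": " This is a test.", "index": 0, "logprobs": null, "finish_reason": "stop" }
//   ],
//   "usage": { "prompt_tokens": 5, "completion_tokens": 6, "total_tokens": 11 }
// }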

/// Stream of parsed server-sent events; the stream ends once a `data: [DONE]` message is received from the server.
pub type CompletionResponseStream =
    Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>;
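
// Consuming the stream (a sketch; assumes an async context and a mutable
// `stream: CompletionResponseStream` obtained from some client call):
//
// use futures::StreamExt;
// while let Some(result) = stream.next().await {
//     match result {
//         Ok(response) => { /* inspect response.choices, response.usage, ... */ }
//         Err(e) => eprintln!("stream error: {e}"),
//     }
// }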

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn echo_rejects_integer() {
        let json = r#"{"model": "test_model", "prompt": "test", "echo": 1}"#;
        let result: Result<CreateCompletionRequest, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("invalid type"));
        assert!(err_msg.contains("integer"));
        assert!(err_msg.contains("echo parameter"));
    }

    #[test]
    fn echo_rejects_string() {
        let json = r#"{"model": "test_model", "prompt": "test", "echo": "null"}"#;
        let result: Result<CreateCompletionRequest, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("invalid type"));
        assert!(err_msg.contains("string"));
        assert!(err_msg.contains("echo parameter"));
    }
250}