dynamo_async_openai/types/completion.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0

use std::{collections::HashMap, pin::Pin};

use derive_builder::Builder;
use futures::Stream;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

use super::{ChatCompletionStreamOptions, Choice, CompletionUsage, Prompt, Stop};

/// Custom deserializer for the `echo` parameter that only accepts booleans.
/// Rejects integers and strings with clear error messages.
fn deserialize_echo_bool<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Outer visitor: handles Option semantics (Some/None/null)
    struct StrictBoolVisitor;

    impl<'de> serde::de::Visitor<'de> for StrictBoolVisitor {
        type Value = Option<bool>;

        // Required by the Visitor trait
        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("echo parameter to be a boolean (true or false) or null")
        }

        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(BoolOnlyVisitor)
        }

        fn visit_none<E>(self) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(None)
        }

        fn visit_unit<E>(self) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(None)
        }
    }

    // Inner visitor: validates the type is boolean, rejects integers and strings
    struct BoolOnlyVisitor;

    impl<'de> serde::de::Visitor<'de> for BoolOnlyVisitor {
        type Value = Option<bool>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("echo parameter to be a boolean (true or false) or null")
        }

        fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(Some(value))
        }

        // Explicitly reject strings (including "null", "true", "false")
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Err(E::invalid_type(
                serde::de::Unexpected::Str(value),
                &"echo parameter to be a boolean (true or false) or null",
            ))
        }
    }

    deserializer.deserialize_option(StrictBoolVisitor)
}

#[derive(Clone, Serialize, Deserialize, Default, Debug, Builder, PartialEq)]
#[builder(name = "CreateCompletionRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateCompletionRequest {
    /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
    pub model: String,

    /// The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.
    ///
    /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.
    pub prompt: Prompt,

    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>, // default: null

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>, // min: 0, max: 2, default: 1

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>, // min: 0, max: 1, default: 1

    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u8>, // min: 1, max: 128, default: 1

    /// Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]` message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>, // nullable: true

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ChatCompletionStreamOptions>,

    /// Include the log probabilities on the `logprobs` most likely output tokens, as well as the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u8>, // min: 0, max: 5, default: null, nullable: true

    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default, deserialize_with = "deserialize_echo_bool")]
    pub echo: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

    /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u8>, // min: 0, max: 20, default: 1

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase the likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null

    /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateCompletionResponse {
    /// A unique identifier for the completion.
    pub id: String,
    pub choices: Vec<Choice>,
    /// The Unix timestamp (in seconds) of when the completion was created.
    pub created: u32,

    /// The model used for completion.
    pub model: String,
    /// This fingerprint represents the backend configuration that the model runs with.
    ///
    /// Can be used in conjunction with the `seed` request parameter to understand when backend changes have been
    /// made that might impact determinism.
    pub system_fingerprint: Option<String>,

    /// The object type, which is always "text_completion".
    pub object: String,
    pub usage: Option<CompletionUsage>,
}

/// Stream of parsed server-sent events, terminated when a \[DONE\] message is received from the server.
pub type CompletionResponseStream =
    Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn echo_rejects_integer() {
        let json = r#"{"model": "test_model", "prompt": "test", "echo": 1}"#;
        let result: Result<CreateCompletionRequest, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("invalid type"));
        assert!(err_msg.contains("integer"));
        assert!(err_msg.contains("echo parameter"));
    }
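
    // A positive-path companion to the rejection tests: plain booleans and an
    // explicit null are the only inputs `deserialize_echo_bool` accepts.
    #[test]
    fn echo_accepts_bool_and_null() {
        let json = r#"{"model": "test_model", "prompt": "test", "echo": true}"#;
        let request: CreateCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.echo, Some(true));

        let json = r#"{"model": "test_model", "prompt": "test", "echo": null}"#;
        let request: CreateCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.echo, None);
    }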

    #[test]
    fn echo_rejects_string() {
        let json = r#"{"model": "test_model", "prompt": "test", "echo": "null"}"#;
        let result: Result<CreateCompletionRequest, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("invalid type"));
        assert!(err_msg.contains("string"));
        assert!(err_msg.contains("echo parameter"));
    }
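
    // A minimal sketch of building a request with the derived
    // `CreateCompletionRequestArgs` builder. All values are illustrative, and
    // `.prompt("...")` assumes the `From<&str>` conversion for `Prompt`
    // provided elsewhere in this crate.
    #[test]
    fn builder_constructs_request() {
        let request = CreateCompletionRequestArgs::default()
            .model("test_model")
            .prompt("test prompt")
            .max_tokens(16u32)
            .build()
            .expect("builder should produce a valid request");
        assert_eq!(request.model, "test_model");
        assert_eq!(request.max_tokens, Some(16));
        assert!(request.echo.is_none());
    }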
}
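
// Hedged sketches of the response types above: one test round-trips a minimal
// response JSON (with `choices` left empty to avoid assuming the shape of
// `Choice`), and one drives a `CompletionResponseStream` end to end. The
// latter assumes the `futures` crate's default `executor` feature is enabled.
#[cfg(test)]
mod response_tests {
    use super::*;
    use futures::StreamExt;

    fn sample_response() -> CreateCompletionResponse {
        CreateCompletionResponse {
            id: "cmpl-test".to_string(),
            choices: vec![],
            created: 0,
            model: "test_model".to_string(),
            system_fingerprint: None,
            object: "text_completion".to_string(),
            usage: None,
        }
    }

    #[test]
    fn response_deserializes_minimal_json() {
        let json = r#"{
            "id": "cmpl-123",
            "choices": [],
            "created": 1700000000,
            "model": "test_model",
            "object": "text_completion"
        }"#;
        let response: CreateCompletionResponse =
            serde_json::from_str(json).expect("minimal response should deserialize");
        assert_eq!(response.object, "text_completion");
        // Optional fields absent from the payload come back as `None`.
        assert!(response.system_fingerprint.is_none());
        assert!(response.usage.is_none());
    }

    #[test]
    fn completion_response_stream_yields_items() {
        // Wrap one ready-made item in a pinned, boxed stream, mirroring the
        // shape a streaming client would hand back.
        let items: Vec<Result<CreateCompletionResponse, OpenAIError>> =
            vec![Ok(sample_response())];
        let mut stream: CompletionResponseStream = Box::pin(futures::stream::iter(items));

        let first = futures::executor::block_on(stream.next());
        assert!(matches!(first, Some(Ok(_))));
        assert!(futures::executor::block_on(stream.next()).is_none());
    }
}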