openai_rust/completions.rs
//! See <https://platform.openai.com/docs/api-reference/completions>.
//! Use with [Client::create_completion](crate::Client::create_completion).
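//!
//! A minimal end-to-end sketch (this assumes the async `Client::create_completion`
//! linked above takes these arguments and returns a [CompletionResponse]; adjust
//! to the actual client API):
//!
//! ```no_run
//! # async fn demo(client: openai_rust::Client) {
//! let args = openai_rust::completions::CompletionArguments::new(
//!     "text-davinci-003",
//!     "The quick brown fox".to_owned()
//! );
//! // `Display` on the response yields the text of the first choice.
//! match client.create_completion(args).await {
//!     Ok(res) => println!("{}", res),
//!     Err(e) => eprintln!("{}", e),
//! }
//! # }
//! ```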
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

/// Request arguments for completions.
///
/// See <https://platform.openai.com/docs/api-reference/completions/create>.
///
/// ```
/// let args = openai_rust::completions::CompletionArguments::new(
///     "text-davinci-003",
///     "The quick brown fox".to_owned()
/// );
/// ```
#[derive(Serialize, Debug, Clone)]
pub struct CompletionArguments {
    /// ID of the model to use.
    /// You can use the [List models](crate::Client::list_models) API to see all of your available models,
    /// or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
    pub model: String,

    /// The prompt to generate completions for.
    /// The upstream API accepts a string, array of strings, array of tokens,
    /// or array of token arrays; this binding currently sends a single string.
    ///
    /// Defaults to <|endoftext|>.
    ///
    /// Note that <|endoftext|> is the document separator that the model
    /// sees during training, so if a prompt is not specified the model
    /// will generate as if from the beginning of a new document.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,

    /// The suffix that comes after a completion of inserted text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length.
    /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between 0 and 2.
    /// Higher values like 0.8 will make the output more random,
    /// while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p` probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// How many completions to generate for each prompt.
    ///
    /// *Note:* Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Whether to stream back partial progress. Set internally by the streaming request methods.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) stream: Option<bool>,

    /// Include the log probabilities on the `logprobs` most likely tokens,
    /// as well as the chosen tokens. For example, if `logprobs` is 5,
    /// the API will return a list of the 5 most likely tokens.
    /// The API will always return the `logprob` of the sampled token,
    /// so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    /// If you need more than this, please contact us through our [Help center](https://help.openai.com/) and describe your use case.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u8>,

    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    /// The returned text will not contain the stop sequence.
    /// (This binding currently accepts a single sequence.)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<String>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token).
    /// Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.
    ///
    /// *Note:* Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
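    ///
    /// A short sketch of a valid pairing (values illustrative):
    ///
    /// ```
    /// let mut args = openai_rust::completions::CompletionArguments::new(
    ///     "text-davinci-003",
    ///     "The quick brown fox".to_owned()
    /// );
    /// // Generate five candidates server-side and return the best two.
    /// args.best_of = Some(5);
    /// args.n = Some(2);
    /// ```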
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,

    // TODO: `logit_bias` is not yet implemented.

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl CompletionArguments {
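    /// Creates arguments with the required `model` and `prompt`; every optional
    /// parameter starts as `None`.
    ///
    /// A short sketch of setting optional parameters after construction (the
    /// fields are public; the values here are illustrative):
    ///
    /// ```
    /// let mut args = openai_rust::completions::CompletionArguments::new(
    ///     "text-davinci-003",
    ///     "The quick brown fox".to_owned()
    /// );
    /// args.max_tokens = Some(16);
    /// args.temperature = Some(0.2);
    /// ```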
    pub fn new(model: impl AsRef<str>, prompt: String) -> CompletionArguments {
        CompletionArguments {
            model: model.as_ref().to_owned(),
            prompt: Some(prompt),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            logprobs: None,
            echo: None,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            user: None,
        }
    }
}

/// The response of a completion request.
///
/// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
/// ```
/// # use serde_json;
/// # let json = "{
/// #   \"id\": \"cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7\",
/// #   \"object\": \"text_completion\",
/// #   \"created\": 1589478378,
/// #   \"model\": \"text-davinci-003\",
/// #   \"choices\": [
/// #     {
/// #       \"text\": \"\\n\\nThis is indeed a test\",
/// #       \"index\": 0,
/// #       \"logprobs\": null,
/// #       \"finish_reason\": \"length\"
/// #     }
/// #   ],
/// #   \"usage\": {
/// #     \"prompt_tokens\": 5,
/// #     \"completion_tokens\": 7,
/// #     \"total_tokens\": 12
/// #   }
/// # }";
/// # let res = serde_json::from_str::<openai_rust::completions::CompletionResponse>(json).unwrap();
/// let text = &res.choices[0].text;
/// // or
/// let text = res.to_string();
/// ```
#[derive(Deserialize, Debug, Clone)]
pub struct CompletionResponse {
    pub id: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

impl std::fmt::Display for CompletionResponse {
    /// Automatically writes the text of the first choice.
    /// Panics if `choices` is empty.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.choices[0].text)?;
        Ok(())
    }
}

/// A single completion choice from a completion response.
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    pub text: String,
    pub index: u32,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: String,
}

/// The log probabilities of a completion response.
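/// The vectors are parallel: `tokens[i]`, `token_logprobs[i]`, `top_logprobs[i]`,
/// and `text_offset[i]` all describe the same sampled token (see the commented
/// sample payload at the end of this source file).
///
/// A minimal sketch of walking them together (assumes `lp` was deserialized from
/// a response; illustrative only):
///
/// ```
/// # fn walk(lp: &openai_rust::completions::LogProbs) {
/// for (token, logprob) in lp.tokens.iter().zip(&lp.token_logprobs) {
///     println!("{:?}: {}", token, logprob);
/// }
/// # }
/// ```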
#[derive(Deserialize, Debug, Clone)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    pub token_logprobs: Vec<f32>,
    pub top_logprobs: Vec<HashMap<String, f32>>,
    pub text_offset: Vec<u32>,
}

/// Information about the tokens used by a [CompletionResponse].
#[derive(Deserialize, Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/*
Sample of the `logprobs` object returned by the API:
{
  "logprobs": {
    "tokens": [
      "\"",
      "\n",
      "\n",
      "The",
      " quick",
      " brown",
      " fox",
      " jumped",
      " over",
      " the",
      " lazy",
      " dog",
      "."
    ],
    "token_logprobs": [
      -3.4888523,
      -0.081398554,
      -0.27080205,
      -0.010607235,
      -0.03842781,
      -0.00033003604,
      -0.00006468596,
      -0.8200931,
      -0.0002035838,
      -0.00010665305,
      -0.0003372524,
      -0.002368947,
      -0.0031320814
    ],
    "top_logprobs": [
      {
        "\n": -1.016303
      },
      {
        "\n": -0.081398554
      },
      {
        "\n": -0.27080205
      },
      {
        "The": -0.010607235
      },
      {
        " quick": -0.03842781
      },
      {
        " brown": -0.00033003604
      },
      {
        " fox": -0.00006468596
      },
      {
        " jumps": -0.58238596
      },
      {
        " over": -0.0002035838
      },
      {
        " the": -0.00010665305
      },
      {
        " lazy": -0.0003372524
      },
      {
        " dog": -0.002368947
      },
      {
        ".": -0.0031320814
      }
    ],
    "text_offset": [
      13,
      14,
      15,
      16,
      19,
      25,
      31,
      35,
      42,
      47,
      51,
      56,
      60
    ]
  }
}
*/