//! See <https://platform.openai.com/docs/api-reference/completions>.
//! Use with [Client::create_completion](crate::Client::create_completion).
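//!
//! A minimal end-to-end sketch, marked `no_run` because it needs a live API key.
//! The `Client::new` constructor shown here is an assumption about the crate's
//! surface; `create_completion` is the method referenced above.
//!
//! ```no_run
//! # async fn example() {
//! // Assumed constructor taking the API key; adjust to the actual Client API.
//! let client = openai_rust::Client::new(&std::env::var("OPENAI_API_KEY").unwrap());
//! let args = openai_rust::completions::CompletionArguments::new(
//!     "text-davinci-003",
//!     "The quick brown fox".to_owned(),
//! );
//! let res = client.create_completion(args).await.unwrap();
//! // `CompletionResponse` implements `Display`, printing the first choice.
//! println!("{}", res);
//! # }
//! ```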
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

/// Request arguments for completions.
///
/// See <https://platform.openai.com/docs/api-reference/completions/create>.
///
/// ```
/// let args = openai_rust::completions::CompletionArguments::new(
///     "text-davinci-003",
///     "The quick brown fox".to_owned()
/// );
/// ```
#[derive(Serialize, Debug, Clone)]
pub struct CompletionArguments {
    /// ID of the model to use.
    /// You can use the [List models](crate::Client::list_models) API to see all of your available models,
    /// or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
    pub model: String,

    /// The prompt(s) to generate completions for,
    /// encoded as a string, array of strings, array of tokens,
    /// or array of token arrays.
    ///
    /// Defaults to `<|endoftext|>`.
    ///
    /// Note that `<|endoftext|>` is the document separator that the model
    /// sees during training, so if a prompt is not specified the model
    /// will generate as if from the beginning of a new document.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,

    /// The suffix that comes after a completion of inserted text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length.
    /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between 0 and 2.
    /// Higher values like 0.8 will make the output more random,
    /// while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p` probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// How many completions to generate for each prompt.
    ///
    /// *Note:* Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) stream: Option<bool>,

    /// Include the log probabilities on the `logprobs` most likely tokens,
    /// as well as the chosen tokens. For example, if `logprobs` is 5,
    /// the API will return a list of the 5 most likely tokens.
    /// The API will always return the `logprob` of the sampled token,
    /// so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    /// If you need more than this, please contact us through our [Help center](https://help.openai.com/) and describe your use case.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u8>,

    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,

    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<String>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token).
    /// Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.
    ///
    /// *Note:* Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,

    // logit_bias is not implemented yet.

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl CompletionArguments {
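    /// Creates arguments with the required model and an initial prompt.
    /// Every optional parameter starts out as `None` and can be set directly
    /// on the returned value, as sketched below (the chosen values are
    /// illustrative assumptions, not recommendations):
    ///
    /// ```
    /// let mut args = openai_rust::completions::CompletionArguments::new(
    ///     "text-davinci-003",
    ///     "The quick brown fox".to_owned(),
    /// );
    /// args.max_tokens = Some(16);
    /// args.temperature = Some(0.2);
    /// args.stop = Some("\n".to_owned());
    /// ```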
    pub fn new(model: impl AsRef<str>, prompt: String) -> CompletionArguments {
        CompletionArguments {
            model: model.as_ref().to_owned(),
            prompt: Some(prompt),
            suffix: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            logprobs: None,
            echo: None,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            user: None,
        }
    }
}

/// The response of a completion request.
///
/// It implements [Display](std::fmt::Display) as a shortcut to easily extract the content.
/// ```
/// # use serde_json;
/// # let json = "{
/// #  \"id\": \"cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7\",
/// #  \"object\": \"text_completion\",
/// #  \"created\": 1589478378,
/// #  \"model\": \"text-davinci-003\",
/// #  \"choices\": [
/// #    {
/// #      \"text\": \"\\n\\nThis is indeed a test\",
/// #      \"index\": 0,
/// #      \"logprobs\": null,
/// #      \"finish_reason\": \"length\"
/// #    }
/// #  ],
/// #  \"usage\": {
/// #    \"prompt_tokens\": 5,
/// #    \"completion_tokens\": 7,
/// #    \"total_tokens\": 12
/// #  }
/// # }";
/// # let res = serde_json::from_str::<openai_rust::completions::CompletionResponse>(json).unwrap();
/// let text = &res.choices[0].text;
/// // or
/// let text = res.to_string();
/// ```
#[derive(Deserialize, Debug, Clone)]
pub struct CompletionResponse {
    /// A unique identifier for this completion.
    pub id: String,
    /// The Unix timestamp (in seconds) of when the completion was created.
    pub created: u32,
    /// The model used for this completion.
    pub model: String,
    /// The generated completion choices.
    pub choices: Vec<Choice>,
    /// Token usage information for this request.
    pub usage: Usage,
}

impl std::fmt::Display for CompletionResponse {
    /// Automatically grab the first choice
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.choices[0].text)
    }
}

/// The completion choices of a completion response.
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    /// The generated text.
    pub text: String,
    /// The index of this choice among the generated completions.
    pub index: u32,
    /// Log probability details, present only if `logprobs` was requested.
    pub logprobs: Option<LogProbs>,
    /// Why generation stopped, e.g. `stop` or `length`.
    pub finish_reason: String,
}

/// The log probabilities of a completion response.
///
/// The four vectors are parallel: entry `i` of each one describes the `i`-th
/// generated token (see the example payload at the bottom of this file).
#[derive(Deserialize, Debug, Clone)]
pub struct LogProbs {
    /// The generated tokens.
    pub tokens: Vec<String>,
    /// The log probability of each generated token.
    pub token_logprobs: Vec<f32>,
    /// For each position, the most likely tokens and their log probabilities.
    pub top_logprobs: Vec<HashMap<String, f32>>,
    /// The character offset of each token within the text.
    pub text_offset: Vec<u32>,
}

/// Information about the tokens used by [CompletionResponse].
#[derive(Deserialize, Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

// An example `logprobs` payload, kept for reference:
/*
{
  "logprobs": {
    "tokens": [
      "\"",
      "\n",
      "\n",
      "The",
      " quick",
      " brown",
      " fox",
      " jumped",
      " over",
      " the",
      " lazy",
      " dog",
      "."
    ],
    "token_logprobs": [
      -3.4888523,
      -0.081398554,
      -0.27080205,
      -0.010607235,
      -0.03842781,
      -0.00033003604,
      -0.00006468596,
      -0.8200931,
      -0.0002035838,
      -0.00010665305,
      -0.0003372524,
      -0.002368947,
      -0.0031320814
    ],
    "top_logprobs": [
      {
        "\n": -1.016303
      },
      {
        "\n": -0.081398554
      },
      {
        "\n": -0.27080205
      },
      {
        "The": -0.010607235
      },
      {
        " quick": -0.03842781
      },
      {
        " brown": -0.00033003604
      },
      {
        " fox": -0.00006468596
      },
      {
        " jumps": -0.58238596
      },
      {
        " over": -0.0002035838
      },
      {
        " the": -0.00010665305
      },
      {
        " lazy": -0.0003372524
      },
      {
        " dog": -0.002368947
      },
      {
        ".": -0.0031320814
      }
    ],
    "text_offset": [
      13,
      14,
      15,
      16,
      19,
      25,
      31,
      35,
      42,
      47,
      51,
      56,
      60
    ]
  }
}
*/
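
// A small sanity test for the serialization behaviour of `CompletionArguments`:
// with `skip_serializing_if = "Option::is_none"`, only fields that are actually
// set should appear in the JSON body. This is a hedged sketch; it assumes
// `serde_json` is available to tests (it is already used by the doctests above).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unset_optional_fields_are_omitted() {
        let args = CompletionArguments::new("text-davinci-003", "Say this is a test".to_owned());
        let json = serde_json::to_value(&args).unwrap();
        let obj = json.as_object().unwrap();
        // Only `model` and `prompt` were set by `new`.
        assert_eq!(obj.len(), 2);
        assert_eq!(obj["model"], "text-davinci-003");
        assert_eq!(obj["prompt"], "Say this is a test");
    }
}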