openai_interface/completions/request.rs
use std::collections::HashMap;

use serde::Serialize;

use crate::rest::post::{Post, PostNoStream, PostStream};

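/// Request body for an OpenAI-compatible `/completions` call.
///
/// Unset optional fields are skipped during serialization, so a request can be
/// built from [`Default`] and only the fields you care about. A minimal
/// construction sketch (fenced as `ignore` because the crate/module path and
/// the model name are assumptions taken from this file's location and the
/// field docs, not verified here):
///
/// ```ignore
/// use openai_interface::completions::request::{CompletionRequest, Prompt};
///
/// let request = CompletionRequest {
///     model: "gpt-3.5-turbo-instruct".to_string(),
///     prompt: Prompt::PromptString("Say this is a test.".to_string()),
///     max_tokens: Some(16),
///     stream: false,
///     ..Default::default()
/// };
/// ```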
#[derive(Debug, Serialize, Default, Clone)]
pub struct CompletionRequest {
    /// ID of the model to use. Note that not all models are supported for completion.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string, array of
    /// strings, array of tokens, or array of token arrays.
    ///
    /// Note that `<|endoftext|>` is the document separator that the model sees during
    /// training, so if a prompt is not specified the model will generate as if from the
    /// beginning of a new document.
    pub prompt: Prompt,
    /// Generates `best_of` completions server-side and returns the "best" (the one with
    /// the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and
    /// `n` specifies how many to return; `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<usize>,
    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    ///
    /// [More about frequency and presence penalties](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT
    /// tokenizer) to an associated bias value from -100 to 100. You can use the
    /// [tokenizer tool](https://platform.openai.com/tokenizer) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to
    /// sampling. The exact effect will vary per model, but values between -1 and 1
    /// should decrease or increase likelihood of selection; values like -100 or 100
    /// should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the `<|endoftext|>` token
    /// from being generated.
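    ///
    /// A minimal sketch of building that map for this field; the token ID and
    /// bias value are the ones from the example above:
    ///
    /// ```
    /// use std::collections::HashMap;
    ///
    /// // Ban the `<|endoftext|>` token (ID 50256 in the GPT-2/GPT-3 tokenizer).
    /// let logit_bias: Option<HashMap<String, isize>> = Some(HashMap::from([
    ///     ("50256".to_string(), -100),
    /// ]));
    /// assert_eq!(logit_bias.unwrap()["50256"], -100);
    /// ```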
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, isize>>,
    /// Include the log probabilities on the `logprobs` most likely output tokens, as
    /// well as the chosen tokens. For example, if `logprobs` is 5, the API will return a
    /// list of the 5 most likely tokens. The API will always return the `logprob` of
    /// the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<usize>,
    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can
    /// be generated in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's
    /// context length. See this
    /// [example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    ///
    /// [More about frequency and presence penalties](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// If specified, the backend will make a best effort to sample deterministically,
    /// such that repeated requests with the same `seed` and parameters should return
    /// the same result.
    ///
    /// Determinism is not guaranteed; refer to the `system_fingerprint` response
    /// parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<usize>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    ///
    /// Note: not supported with the latest reasoning models `o3` and `o4-mini`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Whether to stream back partial progress. If set, tokens will be sent as data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]` message.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
    pub stream: bool,
    /// Options for streaming responses. Only set this when `stream` is `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// It is generally recommended to alter this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the
    /// model considers the results of the tokens with `top_p` probability mass. So 0.1
    /// means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A unique identifier representing your end-user, which can help OpenAI monitor
    /// and detect abuse.
    /// [Learn more from OpenAI](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Additional JSON properties to add to the request.
    pub extra_body: serde_json::Map<String, serde_json::Value>,
}

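/// The `prompt` payload. Serialization is untagged, so each variant maps
/// directly onto the JSON shape the API accepts: a bare string, an array of
/// strings, an array of token IDs, or an array of token-ID arrays.
///
/// A small serialization sketch (fenced as `ignore` because the crate/module
/// path is assumed from this file's location):
///
/// ```ignore
/// use openai_interface::completions::request::Prompt;
///
/// // An untagged newtype variant serializes as its inner value.
/// let single = serde_json::to_string(&Prompt::PromptString("Hi".into())).unwrap();
/// assert_eq!(single, r#""Hi""#);
///
/// let tokens = serde_json::to_string(&Prompt::TokensArray(vec![9906, 0])).unwrap();
/// assert_eq!(tokens, "[9906,0]");
/// ```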
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum Prompt {
    /// A single prompt string.
    PromptString(String),
    /// An array of prompt strings.
    PromptStringArray(Vec<String>),
    /// A single prompt given as an array of token IDs.
    TokensArray(Vec<usize>),
    /// Multiple prompts, each given as an array of token IDs.
    TokenArraysArray(Vec<Vec<usize>>),
}

impl Default for Prompt {
    fn default() -> Self {
        Self::PromptString(String::new())
    }
}

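/// Extra options for streaming responses. Only meaningful when the request's
/// `stream` flag is `true`; see `CompletionRequest::stream_options`.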
#[derive(Debug, Serialize, Clone)]
pub struct StreamOptions {
    /// When true, stream obfuscation will be enabled.
    ///
    /// Stream obfuscation adds random characters to an `obfuscation` field on streaming
    /// delta events to normalize payload sizes as a mitigation to certain side-channel
    /// attacks. These obfuscation fields are included by default, but add a small
    /// amount of overhead to the data stream. You can set `include_obfuscation` to
    /// false to optimize for bandwidth if you trust the network links between your
    /// application and the OpenAI API.
    pub include_obfuscation: bool,
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    ///
    /// The `usage` field on this chunk shows the token usage statistics for the entire
    /// request, and the `choices` field will always be an empty array.
    ///
    /// All other chunks will also include a `usage` field, but with a null value.
    ///
    /// **Note:** If the stream is interrupted, you may not receive the final usage
    /// chunk which contains the total token usage for the request.
    pub include_usage: bool,
}

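/// Stop sequences for the `stop` field, serialized untagged as either a single
/// string or an array of strings (the API accepts up to four).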
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}

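// Wiring into the generic REST layer: `is_streaming` exposes the request's
// `stream` flag (presumably so the POST helper can pick the streaming or
// non-streaming path), and both paths deserialize into the same `Completion`
// response type.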
impl Post for CompletionRequest {
    fn is_streaming(&self) -> bool {
        self.stream
    }
}

impl PostNoStream for CompletionRequest {
    type Response = super::response::Completion;
}

impl PostStream for CompletionRequest {
    type Response = super::response::Completion;
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use super::*;

    const QWEN_MODEL: &str = "qwen-coder-turbo-latest";
    const QWEN_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1/completions";
    // `static` rather than `const`: a `const` `LazyLock` would be re-created
    // (and re-initialized) at every use site.
    static QWEN_API_KEY: LazyLock<&'static str> =
        LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());

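    // Live integration tests: they POST to the DashScope OpenAI-compatible
    // endpoint above and only pass with network access and a valid key in
    // `keys/modelstudio_domestic_key`.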
    #[tokio::test]
    async fn test_qwen_completions_no_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
package main

import (
    "fmt"
    "strings"
    "net/http"
    "io/ioutil"
)

func main() {

    url := "https://api.deepseek.com/chat/completions"
    method := "POST"

    payload := strings.NewReader(`{
      "messages": [
        {
          "content": "You are a helpful assistant",
          "role": "system"
        },
        {
          "content": "Hi",
          "role": "user"
        }
      ],
      "model": "deepseek-chat",
      "frequency_penalty": 0,
      "max_tokens": 4096,
      "presence_penalty": 0,
      "response_format": {
        "type": "text"
      },
      "stop": null,
      "stream": false,
      "stream_options": null,
      "temperature": 1,
      "top_p": 1,
      "tools": null,
      "tool_choice": "none",
      "logprobs": false,
      "top_logprobs": null
    }`)

    client := &http.Client{}
    req, err := http.NewRequest(method, url, payload)

    if err != nil {
        fmt.Println(err)
        return
    }
    req.Header.Add("Content-Type", "application/json")
    req.Header.Add("Accept", "application/json")
    req.Header.Add("Authorization", "Bearer <TOKEN>")

    res, err := client.Do(req)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: false,
            ..Default::default()
        };

        let result = request_body
            .get_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;
        println!("{}", result);

        Ok(())
    }

    #[tokio::test]
    async fn test_qwen_completions_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
package main

import (
    "fmt"
    "strings"
    "net/http"
    "io/ioutil"
)

func main() {

    url := "https://api.deepseek.com/chat/completions"
    method := "POST"

    payload := strings.NewReader(`{
      "messages": [
        {
          "content": "You are a helpful assistant",
          "role": "system"
        },
        {
          "content": "Hi",
          "role": "user"
        }
      ],
      "model": "deepseek-chat",
      "frequency_penalty": 0,
      "max_tokens": 4096,
      "presence_penalty": 0,
      "response_format": {
        "type": "text"
      },
      "stop": null,
      "stream": true,
      "stream_options": null,
      "temperature": 1,
      "top_p": 1,
      "tools": null,
      "tool_choice": "none",
      "logprobs": false,
      "top_logprobs": null
    }`)

    client := &http.Client{}
    req, err := http.NewRequest(method, url, payload)

    if err != nil {
        fmt.Println(err)
        return
    }
    req.Header.Add("Content-Type", "application/json")
    req.Header.Add("Accept", "application/json")
    req.Header.Add("Authorization", "Bearer <TOKEN>")

    res, err := client.Do(req)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: true,
            ..Default::default()
        };

        let mut stream = request_body
            .get_stream_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => {
                    println!("Received chunk: {:?}", data);
                }
                Err(e) => {
                    eprintln!("Error receiving chunk: {:?}", e);
                    break;
                }
            }
        }

        Ok(())
    }
}