openai_interface/completions/request.rs
use std::collections::HashMap;

use serde::Serialize;
use url::Url;

use crate::{
    errors::OapiError,
    rest::post::{Post, PostNoStream, PostStream},
};

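/// Request body for the (legacy) `/completions` endpoint.
///
/// A minimal construction sketch; the model name and parameter values below are
/// illustrative assumptions, not defaults of this crate:
///
/// ```ignore
/// let request = CompletionRequest {
///     model: "gpt-3.5-turbo-instruct".to_string(),
///     prompt: Prompt::PromptString("Say this is a test".to_string()),
///     max_tokens: Some(16),
///     temperature: Some(0.2),
///     stream: false,
///     ..Default::default()
/// };
/// ```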
#[derive(Debug, Serialize, Default, Clone)]
pub struct CompletionRequest {
    /// ID of the model to use. Note that not all models are supported for completion.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string, array of
    /// strings, array of tokens, or array of token arrays.
    /// Note that <|endoftext|> is the document separator that the model sees during
    /// training, so if a prompt is not specified the model will generate as if from the
    /// beginning of a new document.
    pub prompt: Prompt,
    /// Generates `best_of` completions server-side and returns the "best" (the one with
    /// the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and
    /// `n` specifies how many to return; `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<usize>,
    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    ///
    /// [More info about frequency/presence penalties](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT
    /// tokenizer) to an associated bias value from -100 to 100. You can use this
    /// [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to
    /// sampling. The exact effect will vary per model, but values between -1 and 1
    /// should decrease or increase likelihood of selection; values like -100 or 100
    /// should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
    /// from being generated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, isize>>,
    /// Include the log probabilities on the `logprobs` most likely output tokens, as
    /// well as the chosen tokens. For example, if `logprobs` is 5, the API will return
    /// a list of the 5 most likely tokens. The API will always return the `logprob` of
    /// the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<usize>,
    /// The maximum number of [tokens](/tokenizer) that can be generated in the
    /// completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's
    /// context length.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// If specified, our system will make a best effort to sample deterministically,
    /// such that repeated requests with the same `seed` and parameters should return
    /// the same result.
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint`
    /// response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<usize>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    ///
    /// Note: Not supported with the latest reasoning models `o3` and `o4-mini`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Whether to stream back partial progress. If set, tokens will be sent as
    /// data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]`
    /// message.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
    pub stream: bool,
    /// Options for the streaming response. Only set this when you set `stream: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// It is generally recommended to alter this or `top_p`, but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature`, but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A unique identifier representing your end-user, which can help OpenAI monitor
    /// and detect abuse.
    /// [Learn more from OpenAI](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Additional JSON properties to add to the request.
    pub extra_body: serde_json::Map<String, serde_json::Value>,
}

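/// The prompt to complete, in any of the shapes the API accepts.
///
/// The enum is `#[serde(untagged)]`, so each variant serializes directly to the JSON
/// shape its name describes. A serialization sketch (the token IDs and strings are
/// arbitrary examples):
///
/// ```ignore
/// let prompt = Prompt::TokensArray(vec![15496, 995]);
/// assert_eq!(serde_json::to_string(&prompt).unwrap(), "[15496,995]");
///
/// let prompt = Prompt::PromptStringArray(vec!["Once upon a time".to_string()]);
/// assert_eq!(serde_json::to_string(&prompt).unwrap(), r#"["Once upon a time"]"#);
/// ```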
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum Prompt {
    /// A single string
    PromptString(String),
    /// An array of strings
    PromptStringArray(Vec<String>),
    /// An array of tokens
    TokensArray(Vec<usize>),
    /// An array of arrays of tokens
    TokenArraysArray(Vec<Vec<usize>>),
}

impl Default for Prompt {
    fn default() -> Self {
        Self::PromptString("".to_string())
    }
}

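/// Options that only apply when `stream` is set to `true` on the request.
///
/// A construction sketch; the field values are illustrative, and pairing it with
/// `stream: true` reflects intended use rather than anything enforced by the type:
///
/// ```ignore
/// let request = CompletionRequest {
///     stream: true,
///     stream_options: Some(StreamOptions {
///         include_obfuscation: false,
///         include_usage: true,
///     }),
///     ..Default::default()
/// };
/// ```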
#[derive(Debug, Serialize, Clone)]
pub struct StreamOptions {
    /// When true, stream obfuscation will be enabled.
    ///
    /// Stream obfuscation adds random characters to an `obfuscation` field on streaming
    /// delta events to normalize payload sizes as a mitigation to certain side-channel
    /// attacks. These obfuscation fields are included by default, but add a small
    /// amount of overhead to the data stream. You can set `include_obfuscation` to
    /// false to optimize for bandwidth if you trust the network links between your
    /// application and the OpenAI API.
    pub include_obfuscation: bool,
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    ///
    /// The `usage` field on this chunk shows the token usage statistics for the entire
    /// request, and the `choices` field will always be an empty array.
    ///
    /// All other chunks will also include a `usage` field, but with a null value.
    ///
    /// **NOTE:** If the stream is interrupted, you may not receive the final usage
    /// chunk, which contains the total token usage for the request.
    pub include_usage: bool,
}

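/// Stop sequences, serialized untagged as either a single string or an array of
/// strings, matching the two shapes the API accepts.
///
/// A serialization sketch (the sequences themselves are arbitrary examples):
///
/// ```ignore
/// let stop = StopKeywords::Word("\n".to_string());
/// assert_eq!(serde_json::to_string(&stop).unwrap(), r#""\n""#);
///
/// let stop = StopKeywords::Words(vec!["END".to_string(), "STOP".to_string()]);
/// assert_eq!(serde_json::to_string(&stop).unwrap(), r#"["END","STOP"]"#);
/// ```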
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum StopKeywords {
    /// A single stop sequence
    Word(String),
    /// Multiple stop sequences (up to 4)
    Words(Vec<String>),
}

impl Post for CompletionRequest {
    fn is_streaming(&self) -> bool {
        self.stream
    }

    /// Builds the URL for the request.
    ///
    /// `base_url` should be like <https://api.openai.com/v1>
    fn build_url(&self, base_url: &str) -> Result<String, OapiError> {
        let mut url = Url::parse(base_url.trim_end_matches('/')).map_err(OapiError::UrlError)?;
        url.path_segments_mut()
            .map_err(|_| OapiError::UrlCannotBeBase(base_url.to_string()))?
            .push("completions");

        Ok(url.to_string())
    }
}

impl PostNoStream for CompletionRequest {
    type Response = super::response::Completion;
}

impl PostStream for CompletionRequest {
    type Response = super::response::Completion;
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use super::*;

    const QWEN_MODEL: &str = "qwen-coder-turbo-latest";
    const QWEN_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
    static QWEN_API_KEY: LazyLock<&'static str> =
        LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());

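    // A small sanity-check sketch added for illustration: it exercises `build_url`
    // as defined above and assumes only the `url` crate's standard path-segment
    // behaviour, so it needs no network access or API key.
    #[test]
    fn test_build_url() {
        let request = CompletionRequest::default();
        let url = request
            .build_url("https://api.openai.com/v1")
            .expect("base url should parse");
        assert_eq!(url, "https://api.openai.com/v1/completions");

        // A trailing slash on the base URL should not produce a double slash.
        let url = request
            .build_url("https://api.openai.com/v1/")
            .expect("base url should parse");
        assert_eq!(url, "https://api.openai.com/v1/completions");
    }
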
    #[tokio::test]
    async fn test_qwen_completions_no_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
package main

import (
    "fmt"
    "strings"
    "net/http"
    "io/ioutil"
)

func main() {

    url := "https://api.deepseek.com/chat/completions"
    method := "POST"

    payload := strings.NewReader(`{
        "messages": [
            {
                "content": "You are a helpful assistant",
                "role": "system"
            },
            {
                "content": "Hi",
                "role": "user"
            }
        ],
        "model": "deepseek-chat",
        "frequency_penalty": 0,
        "max_tokens": 4096,
        "presence_penalty": 0,
        "response_format": {
            "type": "text"
        },
        "stop": null,
        "stream": false,
        "stream_options": null,
        "temperature": 1,
        "top_p": 1,
        "tools": null,
        "tool_choice": "none",
        "logprobs": false,
        "top_logprobs": null
    }`)

    client := &http.Client{}
    req, err := http.NewRequest(method, url, payload)

    if err != nil {
        fmt.Println(err)
        return
    }
    req.Header.Add("Content-Type", "application/json")
    req.Header.Add("Accept", "application/json")
    req.Header.Add("Authorization", "Bearer <TOKEN>")

    res, err := client.Do(req)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: false,
            ..Default::default()
        };

        let result = request_body
            .get_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;
        println!("{}", result);

        Ok(())
    }

    #[tokio::test]
    async fn test_qwen_completions_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
package main

import (
    "fmt"
    "strings"
    "net/http"
    "io/ioutil"
)

func main() {

    url := "https://api.deepseek.com/chat/completions"
    method := "POST"

    payload := strings.NewReader(`{
        "messages": [
            {
                "content": "You are a helpful assistant",
                "role": "system"
            },
            {
                "content": "Hi",
                "role": "user"
            }
        ],
        "model": "deepseek-chat",
        "frequency_penalty": 0,
        "max_tokens": 4096,
        "presence_penalty": 0,
        "response_format": {
            "type": "text"
        },
        "stop": null,
        "stream": true,
        "stream_options": null,
        "temperature": 1,
        "top_p": 1,
        "tools": null,
        "tool_choice": "none",
        "logprobs": false,
        "top_logprobs": null
    }`)

    client := &http.Client{}
    req, err := http.NewRequest(method, url, payload)

    if err != nil {
        fmt.Println(err)
        return
    }
    req.Header.Add("Content-Type", "application/json")
    req.Header.Add("Accept", "application/json")
    req.Header.Add("Authorization", "Bearer <TOKEN>")

    res, err := client.Do(req)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: true,
            ..Default::default()
        };

        let mut stream = request_body
            .get_stream_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => {
                    println!("Received chunk: {:?}", data);
                }
                Err(e) => {
                    eprintln!("Error receiving chunk: {:?}", e);
                    break;
                }
            }
        }

        Ok(())
    }
}