// openai_interface/completions/request.rs

use std::collections::HashMap;

use serde::Serialize;

use crate::rest::post::{Post, PostNoStream, PostStream};

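/// Request body for an OpenAI-style `/v1/completions` call.
///
/// A minimal sketch of building a request (the model name and parameter values below
/// are illustrative assumptions, and the import path assumes this crate is named
/// `openai_interface`):
///
/// ```ignore
/// use openai_interface::completions::request::{CompletionRequest, Prompt};
///
/// let request = CompletionRequest {
///     model: "gpt-3.5-turbo-instruct".to_string(),
///     prompt: Prompt::PromptString("Say this is a test.".to_string()),
///     max_tokens: Some(16),
///     temperature: Some(0.0),
///     ..Default::default()
/// };
///
/// // Optional fields left as `None` are skipped during serialization.
/// let body = serde_json::to_string(&request).unwrap();
/// ```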
#[derive(Debug, Serialize, Default, Clone)]
pub struct CompletionRequest {
    /// ID of the model to use. Note that not all models support the completions endpoint.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string, array of
    /// strings, array of tokens, or array of token arrays.
    /// Note that <|endoftext|> is the document separator that the model sees during
    /// training, so if a prompt is not specified the model will generate as if from the
    /// beginning of a new document.
    pub prompt: Prompt,
    /// Generates `best_of` completions server-side and returns the "best" (the one with
    /// the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and
    /// `n` specifies how many to return; `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<usize>,
    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    ///
    /// [More info about frequency/presence penalties](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT
    /// tokenizer) to an associated bias value from -100 to 100. You can use this
    /// [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to
    /// sampling. The exact effect will vary per model, but values between -1 and 1
    /// should decrease or increase likelihood of selection; values like -100 or 100
    /// should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
    /// from being generated.
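    ///
    /// A minimal Rust sketch of building such a map (the token ID and bias are just the
    /// example above, not recommended settings; `request` is assumed to be a mutable
    /// `CompletionRequest`):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// // Ban token 50256 (<|endoftext|>) by mapping its ID to a bias of -100.
    /// request.logit_bias = Some(HashMap::from([("50256".to_string(), -100)]));
    /// ```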
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, isize>>,
    /// Include the log probabilities on the `logprobs` most likely output tokens, as
    /// well as the chosen tokens. For example, if `logprobs` is 5, the API will return a
    /// list of the 5 most likely tokens. The API will always return the `logprob` of
    /// the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<usize>,
    /// The maximum number of [tokens](/tokenizer) that can be generated in the
    /// completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's
    /// context length.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// If specified, our system will make a best effort to sample deterministically,
    /// such that repeated requests with the same `seed` and parameters should return
    /// the same result.
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint`
    /// response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<usize>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    ///
    /// Note: Not supported with the latest reasoning models `o3` and `o4-mini`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Whether to stream back partial progress. If set, tokens will be sent as
    /// data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]`
    /// message.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
    pub stream: bool,
    /// Options for the streaming response. Only set this when `stream` is `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// It is generally recommended to alter this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor
    /// and detect abuse.
    /// [Learn more from OpenAI](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Additional JSON properties to include in the request body.
    pub extra_body: serde_json::Map<String, serde_json::Value>,
}

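/// The `prompt` field of a [`CompletionRequest`].
///
/// The enum is `#[serde(untagged)]`, so each variant serializes as the bare JSON value
/// the endpoint expects. A rough sketch of the resulting JSON (values are illustrative):
///
/// ```ignore
/// // Serializes as: "hello"
/// let single = Prompt::PromptString("hello".to_string());
/// // Serializes as: ["a", "b"]
/// let many = Prompt::PromptStringArray(vec!["a".to_string(), "b".to_string()]);
/// // Serializes as: [1, 2, 3]
/// let tokens = Prompt::TokensArray(vec![1, 2, 3]);
///
/// assert_eq!(serde_json::to_string(&single).unwrap(), "\"hello\"");
/// ```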
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum Prompt {
    /// String
    PromptString(String),
    /// Array of strings
    PromptStringArray(Vec<String>),
    /// Array of tokens
    TokensArray(Vec<usize>),
    /// Array of arrays of tokens
    TokenArraysArray(Vec<Vec<usize>>),
}

impl Default for Prompt {
    fn default() -> Self {
        Self::PromptString("".to_string())
    }
}

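/// Extra options for streaming requests; see [`CompletionRequest::stream_options`].
///
/// A small sketch of requesting a final usage chunk (assumes `request` is a mutable
/// `CompletionRequest` with `stream: true`):
///
/// ```ignore
/// request.stream_options = Some(StreamOptions {
///     include_obfuscation: false,
///     include_usage: true,
/// });
/// ```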
#[derive(Debug, Serialize, Clone)]
pub struct StreamOptions {
    /// When true, stream obfuscation will be enabled.
    ///
    /// Stream obfuscation adds random characters to an `obfuscation` field on streaming
    /// delta events to normalize payload sizes as a mitigation to certain side-channel
    /// attacks. These obfuscation fields are included by default, but add a small
    /// amount of overhead to the data stream. You can set `include_obfuscation` to
    /// false to optimize for bandwidth if you trust the network links between your
    /// application and the OpenAI API.
    pub include_obfuscation: bool,
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    ///
    /// The `usage` field on this chunk shows the token usage statistics for the entire
    /// request, and the `choices` field will always be an empty array.
    ///
    /// All other chunks will also include a `usage` field, but with a null value.
    ///
    /// **NOTE:** If the stream is interrupted, you may not receive the final usage
    /// chunk which contains the total token usage for the request.
    pub include_usage: bool,
}

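/// The `stop` field of a [`CompletionRequest`]: either a single stop sequence or a
/// list of them (the API accepts up to 4).
///
/// Because the enum is `#[serde(untagged)]`, `Word` serializes as a JSON string and
/// `Words` as a JSON array. An illustrative sketch (assumes a mutable `request`):
///
/// ```ignore
/// request.stop = Some(StopKeywords::Words(vec!["\n".to_string(), "END".to_string()]));
/// ```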
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}

impl Post for CompletionRequest {
    fn is_streaming(&self) -> bool {
        self.stream
    }
}

impl PostNoStream for CompletionRequest {
    type Response = super::response::Completion;
}

impl PostStream for CompletionRequest {
    type Response = super::response::Completion;
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use super::*;

    const QWEN_MODEL: &str = "qwen-coder-turbo-latest";
    const QWEN_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1/completions";
    // Use `static`, not `const`: a `const` `LazyLock` is inlined and re-initialized at
    // every use site instead of being shared.
    static QWEN_API_KEY: LazyLock<&'static str> =
        LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());

    #[tokio::test]
    async fn test_qwen_completions_no_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
    package main

    import (
      "fmt"
      "strings"
      "net/http"
      "io/ioutil"
    )

    func main() {

      url := "https://api.deepseek.com/chat/completions"
      method := "POST"

      payload := strings.NewReader(`{
      "messages": [
        {
          "content": "You are a helpful assistant",
          "role": "system"
        },
        {
          "content": "Hi",
          "role": "user"
        }
      ],
      "model": "deepseek-chat",
      "frequency_penalty": 0,
      "max_tokens": 4096,
      "presence_penalty": 0,
      "response_format": {
        "type": "text"
      },
      "stop": null,
      "stream": false,
      "stream_options": null,
      "temperature": 1,
      "top_p": 1,
      "tools": null,
      "tool_choice": "none",
      "logprobs": false,
      "top_logprobs": null
    }`)

      client := &http.Client {
      }
      req, err := http.NewRequest(method, url, payload)

      if err != nil {
        fmt.Println(err)
        return
      }
      req.Header.Add("Content-Type", "application/json")
      req.Header.Add("Accept", "application/json")
      req.Header.Add("Authorization", "Bearer <TOKEN>")

      res, err := client.Do(req)
      if err != nil {
        fmt.Println(err)
        return
      }
      defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: false,
            ..Default::default()
        };

        let result = request_body
            .get_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;
        println!("{}", result);

        Ok(())
    }

    #[tokio::test]
    async fn test_qwen_completions_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
        package main

        import (
          "fmt"
          "strings"
          "net/http"
          "io/ioutil"
        )

        func main() {

          url := "https://api.deepseek.com/chat/completions"
          method := "POST"

          payload := strings.NewReader(`{
          "messages": [
            {
              "content": "You are a helpful assistant",
              "role": "system"
            },
            {
              "content": "Hi",
              "role": "user"
            }
          ],
          "model": "deepseek-chat",
          "frequency_penalty": 0,
          "max_tokens": 4096,
          "presence_penalty": 0,
          "response_format": {
            "type": "text"
          },
          "stop": null,
          "stream": true,
          "stream_options": null,
          "temperature": 1,
          "top_p": 1,
          "tools": null,
          "tool_choice": "none",
          "logprobs": false,
          "top_logprobs": null
        }`)

          client := &http.Client {
          }
          req, err := http.NewRequest(method, url, payload)

          if err != nil {
            fmt.Println(err)
            return
          }
          req.Header.Add("Content-Type", "application/json")
          req.Header.Add("Accept", "application/json")
          req.Header.Add("Authorization", "Bearer <TOKEN>")

          res, err := client.Do(req)
          if err != nil {
            fmt.Println(err)
            return
          }
          defer res.Body.Close()
    "#
                .to_string(),
            ),
            suffix: Some(
                r#"
        if err != nil {
            fmt.Println(err)
            return
        }
        fmt.Println(string(body))
    }
    "#
                .to_string(),
            ),
            stream: true,
            ..Default::default()
        };

        let mut stream = request_body
            .get_stream_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => {
                    println!("Received chunk: {:?}", data);
                }
                Err(e) => {
                    eprintln!("Error receiving chunk: {:?}", e);
                    break;
                }
            }
        }

        Ok(())
    }
}