openai_interface/completions/request.rs

use std::collections::HashMap;

use serde::Serialize;
use url::Url;

use crate::{
    errors::OapiError,
    rest::post::{Post, PostNoStream, PostStream},
};

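/// Request body for the `/completions` endpoint.
///
/// A minimal construction sketch (the crate path below is assumed from the module
/// layout, and the field values are illustrative only):
///
/// ```ignore
/// use openai_interface::completions::request::{CompletionRequest, Prompt};
///
/// let request = CompletionRequest {
///     model: "gpt-3.5-turbo-instruct".to_string(),
///     prompt: Prompt::PromptString("Say this is a test".to_string()),
///     max_tokens: Some(16),
///     stream: false,
///     ..Default::default()
/// };
/// ```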
#[derive(Debug, Serialize, Default, Clone)]
pub struct CompletionRequest {
    /// ID of the model to use. Note that not all models are supported for completion.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string, array of
    /// strings, array of tokens, or array of token arrays.
    /// Note that `<|endoftext|>` is the document separator that the model sees during
    /// training, so if a prompt is not specified the model will generate as if from the
    /// beginning of a new document.
    pub prompt: Prompt,
    /// Generates `best_of` completions server-side and returns the "best" (the one with
    /// the highest log probability per token). Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and
    /// `n` specifies how many to return – `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<usize>,
    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    ///
    /// [More info about frequency/presence penalties.](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT
    /// tokenizer) to an associated bias value from -100 to 100. You can use this
    /// [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to
    /// sampling. The exact effect will vary per model, but values between -1 and 1
    /// should decrease or increase likelihood of selection; values like -100 or 100
    /// should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the `<|endoftext|>`
    /// token from being generated.
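    ///
    /// A minimal sketch of building such a map for this field (token ID 50256 is the
    /// `<|endoftext|>` token in the GPT-2/GPT-3 tokenizer, as in the example above):
    ///
    /// ```
    /// use std::collections::HashMap;
    ///
    /// let mut logit_bias: HashMap<String, isize> = HashMap::new();
    /// logit_bias.insert("50256".to_string(), -100); // effectively bans the token
    /// ```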
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, isize>>,
    /// Include the log probabilities on the `logprobs` most likely output tokens, as
    /// well as the chosen tokens. For example, if `logprobs` is 5, the API will return
    /// a list of the 5 most likely tokens. The API will always return the `logprob` of
    /// the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<usize>,
    /// The maximum number of [tokens](/tokenizer) that can be generated in the
    /// completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's
    /// context length.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have reasonable
    /// settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// If specified, our system will make a best effort to sample deterministically,
    /// such that repeated requests with the same `seed` and parameters should return
    /// the same result.
    ///
    /// Determinism is not guaranteed, and you should refer to the `system_fingerprint`
    /// response parameter to monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<usize>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    ///
    /// Note: Not supported with latest reasoning models `o3` and `o4-mini`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Whether to stream back partial progress. If set, tokens will be sent as
    /// data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]`
    /// message.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
    pub stream: bool,
    /// Options for streaming response. Only set this when you set `stream: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// The suffix that comes after a completion of inserted text.
    ///
    /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// It is generally recommended to alter this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor
    /// and detect abuse.
    /// [Learn more from OpenAI](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Add additional JSON properties to the request. These are flattened into the
    /// top level of the serialized request body.
    #[serde(flatten)]
    pub extra_body: serde_json::Map<String, serde_json::Value>,
}

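/// The prompt can be supplied in any of the shapes accepted by the endpoint.
///
/// Because the enum is `#[serde(untagged)]`, each variant serializes to its bare JSON
/// form. An illustrative sketch of the resulting JSON:
///
/// ```ignore
/// let prompt = Prompt::TokenArraysArray(vec![vec![1, 2], vec![3]]);
/// // serde_json::to_string(&prompt) == Ok("[[1,2],[3]]".to_string())
/// ```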
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum Prompt {
    /// String
    PromptString(String),
    /// Array of strings
    PromptStringArray(Vec<String>),
    /// Array of tokens
    TokensArray(Vec<usize>),
    /// Array of arrays of tokens
    TokenArraysArray(Vec<Vec<usize>>),
}

impl Default for Prompt {
    fn default() -> Self {
        Self::PromptString("".to_string())
    }
}

#[derive(Debug, Serialize, Clone)]
pub struct StreamOptions {
    /// When true, stream obfuscation will be enabled.
    ///
    /// Stream obfuscation adds random characters to an `obfuscation` field on streaming
    /// delta events to normalize payload sizes as a mitigation to certain side-channel
    /// attacks. These obfuscation fields are included by default, but add a small
    /// amount of overhead to the data stream. You can set `include_obfuscation` to
    /// false to optimize for bandwidth if you trust the network links between your
    /// application and the OpenAI API.
    pub include_obfuscation: bool,
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    ///
    /// The `usage` field on this chunk shows the token usage statistics for the entire
    /// request, and the `choices` field will always be an empty array.
    ///
    /// All other chunks will also include a `usage` field, but with a null value.
    ///
    /// **NOTE:** If the stream is interrupted, you may not receive the final usage
    /// chunk which contains the total token usage for the request.
    pub include_usage: bool,
}

#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}

impl Post for CompletionRequest {
    fn is_streaming(&self) -> bool {
        self.stream
    }

    /// Builds the URL for the request.
    ///
    /// `base_url` should be like <https://api.openai.com/v1>.
    fn build_url(&self, base_url: &str) -> Result<String, OapiError> {
        let mut url = Url::parse(base_url.trim_end_matches('/')).map_err(OapiError::UrlError)?;
        url.path_segments_mut()
            .map_err(|_| OapiError::UrlCannotBeBase(base_url.to_string()))?
            .push("completions");

        Ok(url.to_string())
    }
}

impl PostNoStream for CompletionRequest {
    type Response = super::response::Completion;
}

impl PostStream for CompletionRequest {
    type Response = super::response::Completion;
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use super::*;

    const QWEN_MODEL: &str = "qwen-coder-turbo-latest";
    const QWEN_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
    static QWEN_API_KEY: LazyLock<&'static str> =
        LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());

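    /// Local sanity check (no network access): `build_url` should append the
    /// `completions` segment to the base URL and tolerate a trailing slash.
    #[test]
    fn test_build_url_appends_completions_segment() {
        let request_body = CompletionRequest::default();
        assert_eq!(
            request_body
                .build_url("https://api.openai.com/v1")
                .expect("base url should parse"),
            "https://api.openai.com/v1/completions"
        );
        assert_eq!(
            request_body
                .build_url("https://api.openai.com/v1/")
                .expect("base url with trailing slash should parse"),
            "https://api.openai.com/v1/completions"
        );
    }
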
    #[tokio::test]
    async fn test_qwen_completions_no_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
    package main

    import (
      "fmt"
      "strings"
      "net/http"
      "io/ioutil"
    )

    func main() {

      url := "https://api.deepseek.com/chat/completions"
      method := "POST"

      payload := strings.NewReader(`{
      "messages": [
        {
          "content": "You are a helpful assistant",
          "role": "system"
        },
        {
          "content": "Hi",
          "role": "user"
        }
      ],
      "model": "deepseek-chat",
      "frequency_penalty": 0,
      "max_tokens": 4096,
      "presence_penalty": 0,
      "response_format": {
        "type": "text"
      },
      "stop": null,
      "stream": false,
      "stream_options": null,
      "temperature": 1,
      "top_p": 1,
      "tools": null,
      "tool_choice": "none",
      "logprobs": false,
      "top_logprobs": null
    }`)

      client := &http.Client {
      }
      req, err := http.NewRequest(method, url, payload)

      if err != nil {
        fmt.Println(err)
        return
      }
      req.Header.Add("Content-Type", "application/json")
      req.Header.Add("Accept", "application/json")
      req.Header.Add("Authorization", "Bearer <TOKEN>")

      res, err := client.Do(req)
      if err != nil {
        fmt.Println(err)
        return
      }
      defer res.Body.Close()
"#
                .to_string(),
            ),
            suffix: Some(
                r#"
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(string(body))
}
"#
                .to_string(),
            ),
            stream: false,
            ..Default::default()
        };

        let result = request_body
            .get_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;
        println!("{}", result);

        Ok(())
    }

    #[tokio::test]
    async fn test_qwen_completions_stream() -> Result<(), anyhow::Error> {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString(
                r#"
        package main

        import (
          "fmt"
          "strings"
          "net/http"
          "io/ioutil"
        )

        func main() {

          url := "https://api.deepseek.com/chat/completions"
          method := "POST"

          payload := strings.NewReader(`{
          "messages": [
            {
              "content": "You are a helpful assistant",
              "role": "system"
            },
            {
              "content": "Hi",
              "role": "user"
            }
          ],
          "model": "deepseek-chat",
          "frequency_penalty": 0,
          "max_tokens": 4096,
          "presence_penalty": 0,
          "response_format": {
            "type": "text"
          },
          "stop": null,
          "stream": true,
          "stream_options": null,
          "temperature": 1,
          "top_p": 1,
          "tools": null,
          "tool_choice": "none",
          "logprobs": false,
          "top_logprobs": null
        }`)

          client := &http.Client {
          }
          req, err := http.NewRequest(method, url, payload)

          if err != nil {
            fmt.Println(err)
            return
          }
          req.Header.Add("Content-Type", "application/json")
          req.Header.Add("Accept", "application/json")
          req.Header.Add("Authorization", "Bearer <TOKEN>")

          res, err := client.Do(req)
          if err != nil {
            fmt.Println(err)
            return
          }
          defer res.Body.Close()
    "#
                .to_string(),
            ),
            suffix: Some(
                r#"
        if err != nil {
            fmt.Println(err)
            return
        }
        fmt.Println(string(body))
    }
    "#
                .to_string(),
            ),
            stream: true,
            ..Default::default()
        };

        let mut stream = request_body
            .get_stream_response_string(QWEN_URL, *QWEN_API_KEY)
            .await?;

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => {
                    println!("Received chunk: {:?}", data);
                }
                Err(e) => {
                    eprintln!("Error receiving chunk: {:?}", e);
                    break;
                }
            }
        }

        Ok(())
    }
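
    /// Local sketch (no network access): optional fields left as `None` should be
    /// omitted from the serialized JSON body thanks to `skip_serializing_if`.
    #[test]
    fn test_request_serialization_skips_none_fields() {
        let request_body = CompletionRequest {
            model: QWEN_MODEL.to_string(),
            prompt: Prompt::PromptString("Hello".to_string()),
            stream: false,
            ..Default::default()
        };
        let json = serde_json::to_value(&request_body).expect("serialization should succeed");
        assert_eq!(json["model"], QWEN_MODEL);
        assert_eq!(json["prompt"], "Hello");
        assert_eq!(json["stream"], false);
        assert!(json.get("temperature").is_none());
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("stop").is_none());
    }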
}