openai_interface/chat/request.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use futures_util::{TryStreamExt, stream::BoxStream};

use crate::errors::RequestError;

#[derive(Serialize, Deserialize, Debug)]
pub struct RequestBody {
    /// A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,
    /// The name of the model used to generate the response.
    pub model: String,
    /// Whether to stream the response. Although the API treats this field as
    /// optional, set it explicitly so the response arrives in the format you
    /// expect.
    pub stream: bool,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far, decreasing the model's likelihood to
    /// repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far, increasing the model's likelihood to
    /// talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// The maximum number of tokens that can be generated in the chat completion.
    /// Deprecated according to OpenAI's Python SDK in favour of
    /// `max_completion_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// An upper bound for the number of tokens that can be generated for a completion,
    /// including visible output tokens and reasoning tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// An object specifying the format that the model must output.
    ///
    /// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
    /// Outputs, which ensures the model will match your supplied JSON schema. Learn more
    /// in the
    /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
    /// Setting to `{ "type": "json_object" }` enables the older JSON mode, which
    /// ensures the message the model generates is valid JSON. Using `json_schema` is
    /// preferred for models that support it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>, // TODO: model `json_schema` as well.
    /// If specified, the system will make a best effort to sample deterministically. Determinism
    /// is not guaranteed, and you should refer to the `system_fingerprint` response parameter to
    /// monitor changes in the backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// How many chat completion choices to generate for each input message. Note that
    /// you will be charged based on the number of generated tokens across all of the
    /// choices. Keep `n` as `1` to minimize costs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Up to 4 sequences where the API will stop generating further tokens. The
    /// returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopKeywords>,
    /// Options for the streaming response. Only set this when `stream` is `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic. It is generally recommended to alter this or `top_p` but
    /// not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the
    /// model considers the results of the tokens with top_p probability mass. So 0.1
    /// means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// It is generally recommended to alter this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// A list of tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tools>>,
    /// Controls which (if any) tool is called by the model. `none` means the model will
    /// not call any tool and instead generates a message. `auto` means the model can
    /// pick between generating a message or calling one or more tools. `required` means
    /// the model must call one or more tools. Specifying a particular tool via
    /// `{"type": "function", "function": {"name": "my_function"}}` forces the model to
    /// call that tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to return log probabilities of the output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// An integer between 0 and 20 specifying the number of most likely tokens to
    /// return at each token position, each with an associated log probability.
    /// `logprobs` must be set to `true` if this parameter is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,

    /// Additional request fields that are not part of the standard OpenAI API.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra_body: Option<ExtraBody>,

    /// Additional request fields that are neither part of the standard OpenAI
    /// API nor covered by the `ExtraBody` struct.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra_body_map: Option<HashMap<String, String>>,
}

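/// A chat message, tagged by `role`: serialization adds a `"role"` field
/// taken from the variant name. A minimal sketch of the serialized shape
/// (assuming `serde_json` is available as a dev-dependency):
///
/// ```rust
/// use openai_interface::chat::request::Message;
///
/// let message = Message::User {
///     content: "Hello!".to_string(),
///     name: None,
/// };
/// let json = serde_json::to_string(&message).unwrap();
/// assert_eq!(json, r#"{"role":"user","content":"Hello!"}"#);
/// ```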
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// In this case, the role of the message author is `system`.
    /// The field `{ "role": "system" }` is added automatically.
    System {
        /// The contents of the system message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// In this case, the role of the message author is `user`.
    /// The field `{ "role": "user" }` is added automatically.
    User {
        /// The contents of the user message.
        content: String,
        /// An optional name for the participant.
        ///
        /// Provides the model information to differentiate between
        /// participants of the same role.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// In this case, the role of the message author is `assistant`.
    /// The field `{ "role": "assistant" }` is added automatically.
    ///
    /// Unimplemented params:
    /// - _audio_: Data about a previous audio response from the model.
    Assistant {
        /// The contents of the assistant message. Required unless `tool_calls`
        /// or `function_call` is specified. (Note that `function_call` is deprecated
        /// in favour of `tool_calls`.)
        content: Option<String>,
        /// The refusal message by the assistant.
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        /// An optional name for the participant.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Set this to `true` to use this assistant message as a prefix that
        /// the model must continue (DeepSeek's Chat Prefix Completion feature).
        #[serde(skip_serializing_if = "is_false")]
        prefix: bool,
        /// Used for the deepseek-reasoner model in the Chat Prefix
        /// Completion feature as the input for the CoT in the last
        /// assistant message. When using this feature, the `prefix`
        /// parameter must be set to `true`.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
        /// The tool calls generated by the model, such as function calls.
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<AssistantToolCall>>,
    },
    /// In this case, the role of the message author is `tool`.
    /// The field `{ "role": "tool" }` is added automatically.
    Tool {
        /// The contents of the tool message.
        content: String,
        /// Tool call that this message is responding to.
        tool_call_id: String,
    },
}

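/// A tool call recorded in an assistant message, tagged by `type`. A minimal
/// serialization sketch (assuming `serde_json` is available as a
/// dev-dependency):
///
/// ```rust
/// use openai_interface::chat::request::{AssistantToolCall, ToolCallFunction};
///
/// let call = AssistantToolCall::Function {
///     id: "call_0".to_string(),
///     function: ToolCallFunction {
///         arguments: "{}".to_string(),
///         name: "get_weather".to_string(),
///     },
/// };
/// let json = serde_json::to_string(&call).unwrap();
/// assert_eq!(
///     json,
///     r#"{"type":"function","id":"call_0","function":{"arguments":"{}","name":"get_weather"}}"#
/// );
/// ```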
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantToolCall {
    Function {
        /// The ID of the tool call.
        id: String,
        /// The function that the model called.
        function: ToolCallFunction,
    },
    Custom {
        /// The ID of the tool call.
        id: String,
        /// The custom tool that the model called.
        custom: ToolCallCustom,
    },
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// The arguments to call the function with, as generated by the model in JSON
    /// format. Note that the model does not always generate valid JSON, and may
    /// hallucinate parameters not defined by your function schema. Validate the
    /// arguments in your code before calling your function.
    pub arguments: String,
    /// The name of the function to call.
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCallCustom {
    /// The input for the custom tool call generated by the model.
    pub input: String,
    /// The name of the custom tool to call.
    pub name: String,
}

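/// The format that the model must output, serialized with a `type` tag. A
/// minimal sketch (assuming `serde_json` is available as a dev-dependency):
///
/// ```rust
/// use openai_interface::chat::request::ResponseFormat;
///
/// let json = serde_json::to_string(&ResponseFormat::JsonObject).unwrap();
/// assert_eq!(json, r#"{"type":"json_object"}"#);
/// ```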
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    JsonObject,
    Text,
}

/// Serde helper: skip serializing a `bool` field when it is `false`.
fn is_false(value: &bool) -> bool {
    !value
}

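/// Stop sequences for the `stop` field. Being `untagged`, a single keyword
/// serializes to a bare string and multiple keywords to an array. A minimal
/// sketch (assuming `serde_json` is available as a dev-dependency):
///
/// ```rust
/// use openai_interface::chat::request::StopKeywords;
///
/// let one = serde_json::to_string(&StopKeywords::Word("stop".to_string())).unwrap();
/// assert_eq!(one, r#""stop""#);
///
/// let many = serde_json::to_string(&StopKeywords::Words(vec!["a".into(), "b".into()])).unwrap();
/// assert_eq!(many, r#"["a","b"]"#);
/// ```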
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct StreamOptions {
    /// If `true`, an additional chunk with token usage statistics for the
    /// entire request is streamed before the `data: [DONE]` message.
    pub include_usage: bool,
}

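/// A tool the model may call. A minimal construction sketch (the weather
/// function is purely illustrative; assuming `serde_json` is available as a
/// dev-dependency):
///
/// ```rust
/// use openai_interface::chat::request::{ToolFunction, Tools};
///
/// let tool = Tools {
///     type_: "function".to_string(),
///     function: Some(ToolFunction {
///         name: "get_weather".to_string(),
///         description: "Gets the current weather for a city.".to_string(),
///         strict: None,
///     }),
/// };
/// let json = serde_json::to_string(&tool).unwrap();
/// assert!(json.starts_with(r#"{"type":"function","function":"#));
/// ```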
#[derive(Serialize, Deserialize, Debug)]
pub struct Tools {
    /// The type of the tool. Currently, only `function` is supported.
    #[serde(rename = "type")]
    pub type_: String,
    /// The function definition, required when `type_` is `"function"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<ToolFunction>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ToolFunction {
    /// The name of the function to be called.
    pub name: String,
    /// A description of what the function does, used by the model to choose
    /// when and how to call the function.
    pub description: String,
    /// Whether to enable strict schema adherence when generating the function
    /// call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ToolFunctionParameter {
    /// The name of the parameter.
    pub name: String,
    /// A description of the parameter.
    pub description: String,
    /// Whether the parameter is required.
    pub required: bool,
    /// The parameter's schema, as a JSON string.
    pub parameters: String,
}

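/// A minimal serialization sketch (assuming `serde_json` is available as a
/// dev-dependency): the named variants become bare strings, while `Specific`
/// becomes the object form described on [`RequestBody::tool_choice`].
///
/// ```rust
/// use openai_interface::chat::request::{
///     ToolChoice, ToolChoiceFunction, ToolChoiceSpecificType,
/// };
///
/// let auto = serde_json::to_string(&ToolChoice::Auto).unwrap();
/// assert_eq!(auto, r#""auto""#);
///
/// let specific = serde_json::to_string(&ToolChoice::Specific {
///     type_: ToolChoiceSpecificType::Function,
///     function: ToolChoiceFunction {
///         name: "my_function".to_string(),
///     },
/// })
/// .unwrap();
/// assert_eq!(
///     specific,
///     r#"{"type":"function","function":{"name":"my_function"}}"#
/// );
/// ```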
#[derive(Serialize, Deserialize, Debug)]
pub enum ToolChoice {
    #[serde(rename = "none")]
    None,
    #[serde(rename = "auto")]
    Auto,
    #[serde(rename = "required")]
    Required,
    #[serde(untagged)]
    Specific {
        /// This field is always the literal `"function"`.
        #[serde(rename = "type")]
        type_: ToolChoiceSpecificType,
        function: ToolChoiceFunction,
    },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ToolChoiceFunction {
    /// The name of the function to call.
    pub name: String,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceSpecificType {
    Function,
}

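/// Provider-specific request fields. Because `extra_body` is flattened into
/// [`RequestBody`], these keys end up at the top level of the request JSON
/// rather than nested under an `extra_body` key. A minimal sketch (assuming
/// `serde_json` is available as a dev-dependency):
///
/// ```rust
/// use openai_interface::chat::request::{ExtraBody, RequestBody};
///
/// let request = RequestBody {
///     extra_body: Some(ExtraBody {
///         enable_thinking: Some(true),
///         thinking_budget: None,
///         top_k: None,
///     }),
///     ..Default::default()
/// };
/// let json = serde_json::to_string(&request).unwrap();
/// assert!(json.contains(r#""enable_thinking":true"#));
/// ```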
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtraBody {
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_thinking: Option<bool>,
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
    /// The size of the candidate set for sampling during generation.
    ///
    /// Only meaningful for the Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
}

impl Default for RequestBody {
    fn default() -> Self {
        RequestBody {
            messages: vec![],
            model: "deepseek-chat".to_string(),
            frequency_penalty: None,
            presence_penalty: None,
            max_completion_tokens: None,
            max_tokens: None,
            response_format: None,
            seed: None,
            n: None,
            stop: None,
            stream: false,
            stream_options: None,
            temperature: None,
            top_p: None,
            tools: None,
            tool_choice: None,
            logprobs: None,
            top_logprobs: None,
            extra_body: None,
            extra_body_map: None,
        }
    }
}

impl RequestBody {
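    /// Sends the request and returns the whole response body as a `String`.
    /// You must ensure that `self.stream` is `false`; otherwise this method
    /// panics.
    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the example on [`Self::stream_response`]
    /// (the key-file path follows the same convention used there):
    ///
    /// ```rust,no_run
    /// use std::sync::LazyLock;
    /// use openai_interface::chat::request::{Message, RequestBody};
    ///
    /// static DEEPSEEK_API_KEY: LazyLock<&str> =
    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let request = RequestBody {
    ///         messages: vec![Message::User {
    ///             content: "What's your name?".to_string(),
    ///             name: None,
    ///         }],
    ///         model: "deepseek-chat".to_string(),
    ///         stream: false,
    ///         ..Default::default()
    ///     };
    ///
    ///     let response = request
    ///         .get_response("https://api.deepseek.com/chat/completions", *DEEPSEEK_API_KEY)
    ///         .await
    ///         .unwrap();
    ///     println!("{}", response);
    /// }
    /// ```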
    pub async fn get_response(&self, url: &str, key: &str) -> anyhow::Result<String> {
        assert!(
            !self.stream,
            "RequestBody::get_response requires `stream: false`"
        );

        let client = reqwest::Client::new();
        let response = client
            .post(url)
            .headers({
                let mut headers = reqwest::header::HeaderMap::new();
                headers.insert("Content-Type", "application/json".parse().unwrap());
                headers.insert("Accept", "application/json".parse().unwrap());
                headers
            })
            .bearer_auth(key)
            .json(self)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;

        if !response.status().is_success() {
            return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
        }

        let text = response.text().await?;

        Ok(text)
    }

    /// Gets a streaming response. You must ensure that `self.stream` is
    /// `true`; otherwise this method panics.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use std::sync::LazyLock;
    /// use futures_util::StreamExt;
    /// use openai_interface::chat::request::{Message, RequestBody};
    ///
    /// static DEEPSEEK_API_KEY: LazyLock<&str> =
    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    /// const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    /// const DEEPSEEK_MODEL: &str = "deepseek-chat";
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let request = RequestBody {
    ///         messages: vec![
    ///             Message::System {
    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
    ///                 name: None,
    ///             },
    ///             Message::User {
    ///                 content: "What's your name?".to_string(),
    ///                 name: None,
    ///             },
    ///         ],
    ///         model: DEEPSEEK_MODEL.to_string(),
    ///         stream: true,
    ///         ..Default::default()
    ///     };
    ///
    ///     let mut response = request
    ///         .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
    ///         .await
    ///         .unwrap();
    ///
    ///     while let Some(chunk) = response.next().await {
    ///         println!("{}", chunk.unwrap());
    ///     }
    /// }
    /// ```
    pub async fn stream_response(
        &self,
        url: &str,
        api_key: &str,
    ) -> Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error> {
        // Assert that streaming mode is enabled.
        assert!(
            self.stream,
            "RequestBody::stream_response requires `stream: true`"
        );

        let client = reqwest::Client::new();

        let response = client
            .post(url)
            .headers({
                let mut headers = reqwest::header::HeaderMap::new();
                headers.insert("Content-Type", "application/json".parse().unwrap());
                headers.insert("Accept", "application/json".parse().unwrap());
                headers
            })
            .bearer_auth(api_key)
            .json(self)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;

        if !response.status().is_success() {
            return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
        }

        let stream = response
            .bytes_stream()
            .map_err(|e| RequestError::StreamError(e.to_string()).into())
            .try_filter_map(|bytes| async move {
                let s = std::str::from_utf8(&bytes)
                    .map_err(|e| RequestError::SseParseError(e.to_string()))?;
                // SSE events carry a `data: ` prefix, so the terminating
                // sentinel usually arrives as `data: [DONE]`.
                let trimmed = s.trim_start();
                if trimmed.starts_with("[DONE]") || trimmed.starts_with("data: [DONE]") {
                    Ok(None)
                } else {
                    Ok(Some(s.to_string()))
                }
            });

        Ok(Box::pin(stream) as BoxStream<'static, _>)
    }
}

#[cfg(test)]
mod request_test {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use crate::chat::request::{Message, RequestBody};

    static DEEPSEEK_API_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    const DEEPSEEK_MODEL: &str = "deepseek-chat";

    #[tokio::test]
    async fn test_00_basics() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: false,
            ..Default::default()
        };

        let response = request
            .get_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        println!("{}", response);

        assert!(response.to_ascii_lowercase().contains("deepseek"));
    }

    #[tokio::test]
    async fn test_01_streaming() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: true,
            ..Default::default()
        };

        let mut response = request
            .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        while let Some(chunk) = response.next().await {
            println!("{}", chunk.unwrap());
        }
    }
}