//! openai_interface/chat/request.rs — chat-completion request types and senders.

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use futures_util::{TryStreamExt, stream::BoxStream};
6
7use crate::errors::RequestError;
8
9#[derive(Serialize, Deserialize, Debug)]
10pub struct RequestBody {
11    pub messages: Vec<Message>,
12    pub model: String,
13    /// `frequency_penalty` must be between -2.0 and 2.0
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub frequency_penalty: Option<f32>,
16    /// `presence_penalty` must be between -2.0 and 2.0
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub presence_penalty: Option<f32>,
19    /// `max_tokens` must be greater than 1
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub max_tokens: Option<u32>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub response_format: Option<ResponseFormat>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub seed: Option<i64>,
26    /// The number of responses to generate.
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub n: Option<u32>,
29    /// stop keywords
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub stop: Option<StopKeywords>,
32    /// Although it is optional, you should explicitly designate it
33    /// for an expected response.
34    pub stream: bool,
35    pub stream_options: Option<StreamOptions>,
36    pub temperature: Option<f32>,
37    pub top_p: Option<f32>,
38    pub tools: Option<Vec<Tools>>,
39    pub tool_choice: Option<ToolChoice>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub logprobs: Option<bool>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub top_logprobs: Option<u32>,
44
45    /// Other request bodies that are not in standard OpenAI API.
46    #[serde(flatten)]
47    pub extra_body: Option<ExtraBody>,
48
49    /// Other request bodies that are not in standard OpenAI API and
50    /// not included in the ExtraBody struct.
51    #[serde(flatten)]
52    pub extra_body_map: Option<HashMap<String, String>>,
53}
54
/// A single chat message.
///
/// Serialized with a `"role"` tag whose value is `"system"`, `"user"`,
/// or `"assistant"` (the lowercased variant name).
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt message.
    System {
        content: String,
        /// Optional participant name; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user message.
    User {
        content: String,
        /// Optional participant name; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant (model) message.
    Assistant {
        content: String,
        /// Optional participant name; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Set this to true for completion
        /// (field is omitted from JSON when false, via `is_false`).
        #[serde(skip_serializing_if = "is_false")]
        prefix: bool,
        /// Used for the deepseek-reasoner model in the Chat Prefix
        /// Completion feature as the input for the CoT in the last
        /// assistant message. When using this feature, the prefix
        /// parameter must be set to true.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
}
83
84#[derive(Serialize, Deserialize, Debug)]
85pub enum ResponseFormat {
86    JsonObject,
87    Text,
88}
89
/// Serde helper for `skip_serializing_if`: returns `true` exactly when
/// the referenced flag is `false`, so false-valued fields are omitted.
fn is_false(value: &bool) -> bool {
    !*value
}
93
/// Stop sequence(s) for generation.
///
/// `untagged` makes this (de)serialize as either a plain JSON string
/// (`Word`) or an array of strings (`Words`), matching the API's
/// flexible `stop` parameter.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum StopKeywords {
    Word(String),
    Words(Vec<String>),
}
100
/// Options that only make sense when `RequestBody::stream` is `true`.
#[derive(Serialize, Deserialize, Debug)]
pub struct StreamOptions {
    // Presumably asks the server to append a usage-statistics chunk to
    // the stream — confirm against the provider's streaming docs.
    pub include_usage: bool,
}
105
106#[derive(Serialize, Deserialize, Debug)]
107pub struct Tools {
108    #[serde(rename = "type")]
109    pub type_: String,
110    pub function: Option<Vec<ToolFunction>>,
111}
112
/// Declaration of a callable function exposed to the model.
#[derive(Serialize, Deserialize, Debug)]
pub struct ToolFunction {
    name: String,
    description: String,
    // Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    strict: Option<bool>,
    // NOTE(review): no `parameters` field is serialized here even though
    // `ToolFunctionParameter` exists below — confirm the intended schema.
}
120
/// A single function-parameter description.
///
/// NOTE(review): this type is not referenced by `ToolFunction` or any
/// other type visible in this file — presumably work in progress;
/// confirm before removing or wiring it up.
#[derive(Serialize, Deserialize, Debug)]
pub struct ToolFunctionParameter {
    name: String,
    description: String,
    required: bool,
    parameters: String,
}
128
129#[derive(Serialize, Deserialize, Debug)]
130pub enum ToolChoice {
131    #[serde(rename = "none")]
132    None,
133    #[serde(rename = "auto")]
134    Auto,
135    #[serde(rename = "required")]
136    Required,
137    #[serde(untagged)]
138    Specific {
139        /// This parameter should always be "function" literal.
140        #[serde(rename = "type")]
141        type_: ToolChoiceSpecificType,
142        function: ToolChoiceFunction,
143    },
144}
145
/// Names the one function the model is forced to call when
/// `ToolChoice::Specific` is used.
#[derive(Serialize, Deserialize, Debug)]
pub struct ToolChoiceFunction {
    pub name: String,
}
150
/// The `"type"` discriminator inside `ToolChoice::Specific`; its only
/// value serializes as the string `"function"`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceSpecificType {
    Function,
}
156
/// Vendor-specific request parameters outside the standard OpenAI API.
/// Flattened into the top level of `RequestBody` when serialized.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtraBody {
    /// Make sense only for Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_thinking: Option<bool>,
    /// Make sense only for Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
    /// The size of the candidate set for sampling during generation.
    ///
    /// Make sense only for Qwen API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
}
171
172impl Default for RequestBody {
173    fn default() -> Self {
174        RequestBody {
175            messages: vec![],
176            model: "deepseek-chat".to_string(),
177            frequency_penalty: None,
178            presence_penalty: None,
179            max_tokens: None,
180            response_format: None,
181            seed: None,
182            n: None,
183            stop: None,
184            stream: false,
185            stream_options: None,
186            temperature: None,
187            top_p: None,
188            tools: None,
189            tool_choice: None,
190            logprobs: None,
191            top_logprobs: None,
192            extra_body: None,
193            extra_body_map: None,
194        }
195    }
196}
197
198impl RequestBody {
199    pub async fn get_response(&self, url: &str, key: &str) -> anyhow::Result<String> {
200        assert!(!self.stream);
201
202        let client = reqwest::Client::new();
203        let response = client
204            .post(url)
205            .headers({
206                let mut headers = reqwest::header::HeaderMap::new();
207                headers.insert("Content-Type", "application/json".parse().unwrap());
208                headers.insert("Accept", "application/json".parse().unwrap());
209                headers
210            })
211            .bearer_auth(key)
212            .json(self)
213            .send()
214            .await
215            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
216
217        if response.status() != reqwest::StatusCode::OK {
218            return Err(
219                crate::errors::RequestError::ResponseStatus(response.status().as_u16()).into(),
220            );
221        }
222
223        let text = response.text().await?;
224
225        Ok(text)
226    }
227
228    /// Getting stream response. You must ensure self.stream is true, or otherwise it will panic.
229    ///
230    /// # Example
231    ///
232    /// ```rust
233    /// use std::sync::LazyLock;
234    /// use futures_util::StreamExt;
235    /// use openai_interface::chat::request::{Message, RequestBody};
236    ///
237    /// const DEEPSEEK_API_KEY: LazyLock<&str> =
238    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
239    /// const DEEPSEEK_CHAT_URL: &'static str = "https://api.deepseek.com/chat/completions";
240    /// const DEEPSEEK_MODEL: &'static str = "deepseek-chat";
241    ///
242    /// #[tokio::main]
243    /// async fn main() {
244    ///     let request = RequestBody {
245    ///         messages: vec![
246    ///             Message::System {
247    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
248    ///                 name: None,
249    ///             },
250    ///             Message::User {
251    ///                 content: "What's your name?".to_string(),
252    ///                 name: None,
253    ///             },
254    ///         ],
255    ///         model: DEEPSEEK_MODEL.to_string(),
256    ///         stream: true,
257    ///         ..Default::default()
258    ///     };
259    ///
260    ///     let mut response = request
261    ///         .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
262    ///         .await
263    ///         .unwrap();
264    ///
265    ///     while let Some(chunk) = response.next().await {
266    ///         println!("{}", chunk.unwrap());
267    ///     }
268    /// }
269    /// ```
270    pub async fn stream_response(
271        &self,
272        url: &str,
273        api_key: &str,
274    ) -> Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error> {
275        // 断言开启了流模式
276        assert!(
277            self.stream,
278            "RequestBody::stream_response requires `stream: true`"
279        );
280
281        let client = reqwest::Client::new();
282
283        let response = client
284            .post(url)
285            .headers({
286                let mut headers = reqwest::header::HeaderMap::new();
287                headers.insert("Content-Type", "application/json".parse().unwrap());
288                headers.insert("Accept", "application/json".parse().unwrap());
289                headers
290            })
291            .bearer_auth(api_key)
292            .json(self)
293            .send()
294            .await
295            .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
296
297        if !response.status().is_success() {
298            return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
299        }
300
301        let stream = response
302            .bytes_stream()
303            .map_err(|e| RequestError::StreamError(e.to_string()).into())
304            .try_filter_map(|bytes| async move {
305                let s = std::str::from_utf8(&bytes)
306                    .map_err(|e| RequestError::SseParseError(e.to_string()))?;
307                if s.starts_with("[DONE]") {
308                    Ok(None)
309                } else {
310                    Ok(Some(s.to_string()))
311                }
312            });
313
314        Ok(Box::pin(stream) as BoxStream<'static, _>)
315
316        // return Err(anyhow!("Not implemented"));
317    }
318}
319
#[cfg(test)]
mod request_test {
    use std::sync::LazyLock;

    use futures_util::StreamExt;

    use crate::chat::request::{Message, RequestBody};

    // `static`, not `const`: a `const` with interior mutability is
    // inlined (copied) at every use site, so a `const LazyLock` would
    // re-run its initializer on each access and never cache the key.
    static DEEPSEEK_API_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
    const DEEPSEEK_CHAT_URL: &str = "https://api.deepseek.com/chat/completions";
    const DEEPSEEK_MODEL: &str = "deepseek-chat";

    /// Live-network smoke test for the blocking (non-stream) path.
    #[tokio::test]
    async fn test_00_basics() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: false,
            ..Default::default()
        };

        let response = request
            .get_response(DEEPSEEK_CHAT_URL, &*DEEPSEEK_API_KEY)
            .await
            .unwrap();

        println!("{}", response);

        assert!(response.to_ascii_lowercase().contains("deepseek"));
    }

    /// Live-network smoke test for the streaming path.
    #[tokio::test]
    async fn test_01_streaming() {
        let request = RequestBody {
            messages: vec![
                Message::System {
                    content: "This is a request of test purpose. Reply briefly".to_string(),
                    name: None,
                },
                Message::User {
                    content: "What's your name?".to_string(),
                    name: None,
                },
            ],
            model: DEEPSEEK_MODEL.to_string(),
            stream: true,
            ..Default::default()
        };

        let mut response = request
            .stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
            .await
            .unwrap();

        while let Some(chunk) = response.next().await {
            println!("{}", chunk.unwrap());
        }
    }
}