openai_interface/rest/
post.rs

1use std::future::Future;
2
3use futures_util::{TryStreamExt, stream::BoxStream};
4use serde::Serialize;
5
6use crate::errors::RequestError;
7
/// Common base for the request types sent by [`NoStream`] and [`Stream`].
///
/// Both of those traits consult [`Post::is_streaming`] to refuse a request that is sent through
/// the wrong method.
pub trait Post {
    /// Reports whether this request is a streaming request.
    ///
    /// [`NoStream::get_response`] refuses requests for which this returns `true`, and
    /// [`Stream::get_stream_response`] refuses requests for which it returns `false`.
    fn is_streaming(&self) -> bool;
}
11
12pub trait NoStream: Post + Serialize + Sync + Send {
13    /// Sends a POST request to the specified URL with the provided api-key.
14    fn get_response(
15        &self,
16        url: &str,
17        key: &str,
18    ) -> impl Future<Output = Result<String, RequestError>> + Send + Sync {
19        async move {
20            if self.is_streaming() {
21                return Err(RequestError::StreamingViolation);
22            }
23
24            let client = reqwest::Client::new();
25            let response = client
26                .post(url)
27                .headers({
28                    let mut headers = reqwest::header::HeaderMap::new();
29                    headers.insert("Content-Type", "application/json".parse().unwrap());
30                    headers.insert("Accept", "application/json".parse().unwrap());
31                    headers
32                })
33                .bearer_auth(key)
34                .json(self)
35                .send()
36                .await
37                .map_err(|e| {
38                    RequestError::SendError(format!("Failed to send request: {:#?}", e))
39                })?;
40
41            if response.status() != reqwest::StatusCode::OK {
42                return Err(crate::errors::RequestError::ResponseStatus(
43                    response.status().as_u16(),
44                )
45                .into());
46            }
47
48            let text = response.text().await.map_err(|e| {
49                RequestError::ResponseError(format!("Failed to get response text: {:#?}", e))
50            })?;
51
52            Ok(text)
53        }
54    }
55}
56
57pub trait Stream: Post + Serialize + Sync + Send {
58    /// Sends a streaming POST request to the specified URL with the provided api-key.
59    ///
60    /// # Example
61    ///
62    /// ```rust
63    /// use std::sync::LazyLock;
64    /// use futures_util::StreamExt;
65    /// use openai_interface::chat::request::{Message, RequestBody};
66    /// use openai_interface::rest::post::Stream;
67    ///
68    /// const DEEPSEEK_API_KEY: LazyLock<&str> =
69    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
70    /// const DEEPSEEK_CHAT_URL: &'static str = "https://api.deepseek.com/chat/completions";
71    /// const DEEPSEEK_MODEL: &'static str = "deepseek-chat";
72    ///
73    /// #[tokio::main]
74    /// async fn main() {
75    ///     let request = RequestBody {
76    ///         messages: vec![
77    ///             Message::System {
78    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
79    ///                 name: None,
80    ///             },
81    ///             Message::User {
82    ///                 content: "What's your name?".to_string(),
83    ///                 name: None,
84    ///             },
85    ///         ],
86    ///         model: DEEPSEEK_MODEL.to_string(),
87    ///         stream: true,
88    ///         ..Default::default()
89    ///     };
90    ///
91    ///     let mut response = request
92    ///         .get_stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
93    ///         .await
94    ///         .unwrap();
95    ///
96    ///     while let Some(chunk) = response.next().await {
97    ///         println!("{}", chunk.unwrap());
98    ///     }
99    /// }
100    /// ```
101    fn get_stream_response(
102        &self,
103        url: &str,
104        api_key: &str,
105    ) -> impl Future<
106        Output = Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error>,
107    > + Send
108    + Sync {
109        async move {
110            if !self.is_streaming() {
111                return Err(anyhow::Error::from(RequestError::StreamingViolation));
112            }
113
114            let client = reqwest::Client::new();
115
116            let response = client
117                .post(url)
118                .headers({
119                    let mut headers = reqwest::header::HeaderMap::new();
120                    headers.insert("Content-Type", "application/json".parse().unwrap());
121                    headers.insert("Accept", "application/json".parse().unwrap());
122                    headers
123                })
124                .bearer_auth(api_key)
125                .json(self)
126                .send()
127                .await
128                .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
129
130            if !response.status().is_success() {
131                return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
132            }
133
134            // The following code is generated by Qwen3-Coder-480B
135            // 使用 unfold 来维护状态,正确处理 SSE 格式
136            let stream = futures_util::stream::unfold(
137                (response.bytes_stream(), String::new()),
138                |(mut stream, mut buffer)| async move {
139                    loop {
140                        // 检查缓冲区中是否有完整的事件(以 \n\n 结尾)
141                        while let Some(end_pos) = buffer.find("\n\n") {
142                            let event = buffer[..end_pos].to_string();
143                            buffer.drain(..end_pos + 2); // 移除事件和分隔符
144
145                            // 处理 SSE 事件
146                            if event.starts_with("data: ") {
147                                let data = event[6..].to_string(); // 移除 "data: " 前缀
148                                // 检查是否是 [DONE] 事件
149                                if data == "[DONE]" {
150                                    // 对于 [DONE] 事件,返回它以便调用者可以处理
151                                    return Some((Ok(data), (stream, buffer)));
152                                } else {
153                                    // 对于其他数据事件,返回数据
154                                    return Some((Ok(data), (stream, buffer)));
155                                }
156                            } else if event == "[DONE]" {
157                                // 直接的 [DONE] 事件(不带 data: 前缀)
158                                return Some((Ok("[DONE]".to_string()), (stream, buffer)));
159                            } else {
160                                // 其他类型的事件(如注释),忽略
161                                continue;
162                            }
163                        }
164
165                        // 从流中获取更多数据
166                        match stream.try_next().await {
167                            Ok(Some(bytes)) => {
168                                if let Ok(s) = std::str::from_utf8(&bytes) {
169                                    buffer.push_str(s);
170                                    // 继续循环检查是否有完整事件
171                                    continue;
172                                } else {
173                                    return Some((
174                                        Err(RequestError::SseParseError(
175                                            "Invalid UTF-8 in stream".to_string(),
176                                        )
177                                        .into()),
178                                        (stream, buffer),
179                                    ));
180                                }
181                            }
182                            Ok(None) => {
183                                // 流结束,检查是否有剩余数据
184                                if !buffer.is_empty() {
185                                    // 处理可能的不完整事件
186                                    if buffer.starts_with("data: ") {
187                                        let data = buffer[6..].to_string();
188                                        // 检查是否是 [DONE] 事件
189                                        if data == "[DONE]" {
190                                            return Some((Ok(data), (stream, String::new())));
191                                        } else {
192                                            return Some((Ok(data), (stream, String::new())));
193                                        }
194                                    } else if buffer == "[DONE]" {
195                                        // 直接的 [DONE] 事件
196                                        return Some((
197                                            Ok("[DONE]".to_string()),
198                                            (stream, String::new()),
199                                        ));
200                                    }
201                                }
202                                return None; // 流结束
203                            }
204                            Err(e) => {
205                                return Some((
206                                    Err(anyhow::anyhow!("Stream error: {}", e).into()),
207                                    (stream, buffer),
208                                ));
209                            }
210                        }
211                    }
212                },
213            );
214
215            Ok(Box::pin(stream) as BoxStream<'static, _>)
216        }
217    }
218}