openai_interface/rest/post.rs
1use std::future::Future;
2
3use futures_util::{TryStreamExt, stream::BoxStream};
4use serde::Serialize;
5
6use crate::errors::RequestError;
7
/// Common behaviour of request bodies that can be POSTed to an API endpoint.
///
/// Both [`NoStream`] and [`Stream`] require this trait; they consult
/// [`Post::is_streaming`] to refuse a request that is sent through the wrong method.
pub trait Post {
    /// Whether this request belongs to the streaming path ([`Stream`], `true`)
    /// or to the single-response path ([`NoStream`], `false`).
    fn is_streaming(&self) -> bool;
}
11
12pub trait NoStream: Post + Serialize + Sync + Send {
13 /// Sends a POST request to the specified URL with the provided api-key.
14 fn get_response(
15 &self,
16 url: &str,
17 key: &str,
18 ) -> impl Future<Output = Result<String, RequestError>> + Send + Sync {
19 async move {
20 if self.is_streaming() {
21 return Err(RequestError::StreamingViolation);
22 }
23
24 let client = reqwest::Client::new();
25 let response = client
26 .post(url)
27 .headers({
28 let mut headers = reqwest::header::HeaderMap::new();
29 headers.insert("Content-Type", "application/json".parse().unwrap());
30 headers.insert("Accept", "application/json".parse().unwrap());
31 headers
32 })
33 .bearer_auth(key)
34 .json(self)
35 .send()
36 .await
37 .map_err(|e| {
38 RequestError::SendError(format!("Failed to send request: {:#?}", e))
39 })?;
40
41 if response.status() != reqwest::StatusCode::OK {
42 return Err(crate::errors::RequestError::ResponseStatus(
43 response.status().as_u16(),
44 )
45 .into());
46 }
47
48 let text = response.text().await.map_err(|e| {
49 RequestError::ResponseError(format!("Failed to get response text: {:#?}", e))
50 })?;
51
52 Ok(text)
53 }
54 }
55}
56
57pub trait Stream: Post + Serialize + Sync + Send {
58 /// Sends a streaming POST request to the specified URL with the provided api-key.
59 ///
60 /// # Example
61 ///
62 /// ```rust
63 /// use std::sync::LazyLock;
64 /// use futures_util::StreamExt;
65 /// use openai_interface::chat::request::{Message, RequestBody};
66 /// use openai_interface::rest::post::Stream;
67 ///
68 /// const DEEPSEEK_API_KEY: LazyLock<&str> =
69 /// LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
70 /// const DEEPSEEK_CHAT_URL: &'static str = "https://api.deepseek.com/chat/completions";
71 /// const DEEPSEEK_MODEL: &'static str = "deepseek-chat";
72 ///
73 /// #[tokio::main]
74 /// async fn main() {
75 /// let request = RequestBody {
76 /// messages: vec![
77 /// Message::System {
78 /// content: "This is a request of test purpose. Reply briefly".to_string(),
79 /// name: None,
80 /// },
81 /// Message::User {
82 /// content: "What's your name?".to_string(),
83 /// name: None,
84 /// },
85 /// ],
86 /// model: DEEPSEEK_MODEL.to_string(),
87 /// stream: true,
88 /// ..Default::default()
89 /// };
90 ///
91 /// let mut response = request
92 /// .get_stream_response(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
93 /// .await
94 /// .unwrap();
95 ///
96 /// while let Some(chunk) = response.next().await {
97 /// println!("{}", chunk.unwrap());
98 /// }
99 /// }
100 /// ```
101 fn get_stream_response(
102 &self,
103 url: &str,
104 api_key: &str,
105 ) -> impl Future<
106 Output = Result<BoxStream<'static, Result<String, anyhow::Error>>, anyhow::Error>,
107 > + Send
108 + Sync {
109 async move {
110 if !self.is_streaming() {
111 return Err(anyhow::Error::from(RequestError::StreamingViolation));
112 }
113
114 let client = reqwest::Client::new();
115
116 let response = client
117 .post(url)
118 .headers({
119 let mut headers = reqwest::header::HeaderMap::new();
120 headers.insert("Content-Type", "application/json".parse().unwrap());
121 headers.insert("Accept", "application/json".parse().unwrap());
122 headers
123 })
124 .bearer_auth(api_key)
125 .json(self)
126 .send()
127 .await
128 .map_err(|e| anyhow::anyhow!("Failed to send request: {}", e))?;
129
130 if !response.status().is_success() {
131 return Err(RequestError::ResponseStatus(response.status().as_u16()).into());
132 }
133
134 // The following code is generated by Qwen3-Coder-480B
135 // 使用 unfold 来维护状态,正确处理 SSE 格式
136 let stream = futures_util::stream::unfold(
137 (response.bytes_stream(), String::new()),
138 |(mut stream, mut buffer)| async move {
139 loop {
140 // 检查缓冲区中是否有完整的事件(以 \n\n 结尾)
141 while let Some(end_pos) = buffer.find("\n\n") {
142 let event = buffer[..end_pos].to_string();
143 buffer.drain(..end_pos + 2); // 移除事件和分隔符
144
145 // 处理 SSE 事件
146 if event.starts_with("data: ") {
147 let data = event[6..].to_string(); // 移除 "data: " 前缀
148 // 检查是否是 [DONE] 事件
149 if data == "[DONE]" {
150 // 对于 [DONE] 事件,返回它以便调用者可以处理
151 return Some((Ok(data), (stream, buffer)));
152 } else {
153 // 对于其他数据事件,返回数据
154 return Some((Ok(data), (stream, buffer)));
155 }
156 } else if event == "[DONE]" {
157 // 直接的 [DONE] 事件(不带 data: 前缀)
158 return Some((Ok("[DONE]".to_string()), (stream, buffer)));
159 } else {
160 // 其他类型的事件(如注释),忽略
161 continue;
162 }
163 }
164
165 // 从流中获取更多数据
166 match stream.try_next().await {
167 Ok(Some(bytes)) => {
168 if let Ok(s) = std::str::from_utf8(&bytes) {
169 buffer.push_str(s);
170 // 继续循环检查是否有完整事件
171 continue;
172 } else {
173 return Some((
174 Err(RequestError::SseParseError(
175 "Invalid UTF-8 in stream".to_string(),
176 )
177 .into()),
178 (stream, buffer),
179 ));
180 }
181 }
182 Ok(None) => {
183 // 流结束,检查是否有剩余数据
184 if !buffer.is_empty() {
185 // 处理可能的不完整事件
186 if buffer.starts_with("data: ") {
187 let data = buffer[6..].to_string();
188 // 检查是否是 [DONE] 事件
189 if data == "[DONE]" {
190 return Some((Ok(data), (stream, String::new())));
191 } else {
192 return Some((Ok(data), (stream, String::new())));
193 }
194 } else if buffer == "[DONE]" {
195 // 直接的 [DONE] 事件
196 return Some((
197 Ok("[DONE]".to_string()),
198 (stream, String::new()),
199 ));
200 }
201 }
202 return None; // 流结束
203 }
204 Err(e) => {
205 return Some((
206 Err(anyhow::anyhow!("Stream error: {}", e).into()),
207 (stream, buffer),
208 ));
209 }
210 }
211 }
212 },
213 );
214
215 Ok(Box::pin(stream) as BoxStream<'static, _>)
216 }
217 }
218}