openai_interface/rest/
post.rs

1use std::{future::Future, str::FromStr};
2
3use eventsource_stream::Eventsource;
4use futures_util::{StreamExt, TryStreamExt, stream::BoxStream};
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::errors::OapiError;
8
9pub trait Post {
10    fn is_streaming(&self) -> bool;
11    /// Builds the URL for the request.
12    ///
13    /// `base_url` should be like <https://api.openai.com/v1>
14    fn build_url(&self, base_url: &str) -> Result<String, OapiError>;
15}
16
17pub trait PostNoStream: Post + Serialize + Sync + Send {
18    type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
19
20    /// Sends a POST request to the specified URL with the provided api-key.
21    fn get_response_string(
22        &self,
23        base_url: &str,
24        key: &str,
25    ) -> impl Future<Output = Result<String, OapiError>> + Send + Sync {
26        async move {
27            if self.is_streaming() {
28                return Err(OapiError::NonStreamingViolation);
29            }
30
31            let client = reqwest::Client::new();
32            let response = client
33                .post(self.build_url(base_url)?)
34                .headers({
35                    let mut headers = reqwest::header::HeaderMap::new();
36                    headers.insert("Content-Type", "application/json".parse().unwrap());
37                    headers.insert("Accept", "application/json".parse().unwrap());
38                    headers
39                })
40                .bearer_auth(key)
41                .json(self)
42                .send()
43                .await
44                .map_err(|e| OapiError::SendError(format!("Failed to send request: {:#?}", e)))?;
45
46            if response.status() != reqwest::StatusCode::OK {
47                return Err(
48                    crate::errors::OapiError::ResponseStatus(response.status().as_u16()).into(),
49                );
50            }
51
52            let text = response.text().await.map_err(|e| {
53                OapiError::ResponseError(format!("Failed to get response text: {:#?}", e))
54            })?;
55
56            Ok(text)
57        }
58    }
59
60    fn get_response(
61        &self,
62        url: &str,
63        key: &str,
64    ) -> impl Future<Output = Result<Self::Response, OapiError>> + Send + Sync {
65        async move {
66            let text = self.get_response_string(url, key).await?;
67            let result = Self::Response::from_str(&text)?;
68            Ok(result)
69        }
70    }
71}
72
73pub trait PostStream: Post + Serialize + Sync + Send {
74    type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
75
76    /// Sends a streaming POST request to the specified URL with the provided api-key.
77    ///
78    /// # Example
79    ///
80    /// ```rust
81    /// use std::sync::LazyLock;
82    /// use futures_util::StreamExt;
83    /// use openai_interface::chat::create::request::{Message, RequestBody};
84    /// use openai_interface::rest::post::PostStream;
85    ///
86    /// const DEEPSEEK_API_KEY: LazyLock<&str> =
87    ///     LazyLock::new(|| include_str!("../../keys/deepseek_domestic_key").trim());
88    /// const DEEPSEEK_CHAT_URL: &'static str = "https://api.deepseek.com/";
89    /// const DEEPSEEK_MODEL: &'static str = "deepseek-chat";
90    ///
91    /// #[tokio::main]
92    /// async fn main() {
93    ///     let request = RequestBody {
94    ///         messages: vec![
95    ///             Message::System {
96    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
97    ///                 name: None,
98    ///             },
99    ///             Message::User {
100    ///                 content: "What's your name?".to_string(),
101    ///                 name: None,
102    ///             },
103    ///         ],
104    ///         model: DEEPSEEK_MODEL.to_string(),
105    ///         stream: true,
106    ///         ..Default::default()
107    ///     };
108    ///
109    ///     let mut response = request
110    ///         .get_stream_response_string(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
111    ///         .await
112    ///         .unwrap();
113    ///
114    ///     while let Some(chunk) = response.next().await {
115    ///         println!("{}", chunk.unwrap());
116    ///     }
117    /// }
118    /// ```
119    fn get_stream_response_string(
120        &self,
121        base_url: &str,
122        api_key: &str,
123    ) -> impl Future<Output = Result<BoxStream<'static, Result<String, OapiError>>, OapiError>>
124    + Send
125    + Sync {
126        async move {
127            if !self.is_streaming() {
128                return Err(OapiError::StreamingViolation);
129            }
130
131            let client = reqwest::Client::new();
132
133            let response = client
134                .post(&self.build_url(base_url)?)
135                .headers({
136                    let mut headers = reqwest::header::HeaderMap::new();
137                    headers.insert("Content-Type", "application/json".parse().unwrap());
138                    headers.insert("Accept", "text/event-stream".parse().unwrap());
139                    headers
140                })
141                .bearer_auth(api_key)
142                .json(self)
143                .send()
144                .await
145                .map_err(|e| OapiError::ResponseError(format!("Failed to send request: {}", e)))?;
146
147            if !response.status().is_success() {
148                return Err(OapiError::ResponseStatus(response.status().as_u16()).into());
149            }
150
151            // The following code is generated by Qwen3-480B-Coder
152            // 使用 eventsource-stream 解析 SSE
153            let stream = response
154                .bytes_stream()
155                .eventsource()
156                .map(|event| match event {
157                    Ok(event) => Ok(event.data),
158                    Err(e) => Err(OapiError::SseParseError(format!("SSE parse error: {}", e))),
159                })
160                .boxed();
161
162            Ok(stream as BoxStream<'static, Result<String, OapiError>>)
163        }
164    }
165
166    fn get_stream_response(
167        &self,
168        base_url: &str,
169        api_key: &str,
170    ) -> impl Future<
171        Output = Result<BoxStream<'static, Result<Self::Response, OapiError>>, OapiError>,
172    > + Send
173    + Sync {
174        async move {
175            let stream = self.get_stream_response_string(base_url, api_key).await?;
176
177            let parsed_stream = stream
178                .take_while(|result| {
179                    let should_continue = match result {
180                        Ok(data) => data != "[DONE]",
181                        Err(_) => true, // 继续传播错误
182                    };
183                    async move { should_continue }
184                })
185                .and_then(|data| async move { Self::Response::from_str(&data) });
186
187            Ok(Box::pin(parsed_stream) as BoxStream<'static, _>)
188        }
189    }
190}