// openai_interface/rest/post.rs

1use std::{future::Future, str::FromStr};
2
3use eventsource_stream::Eventsource;
4use futures_util::{StreamExt, TryStreamExt, stream::BoxStream};
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::errors::OapiError;
8
/// Behaviour shared by every request body that can be sent with an HTTP POST.
///
/// [`PostNoStream`] and [`PostStream`] consult [`Post::is_streaming`] before
/// sending, so a request is only ever sent through the matching trait.
pub trait Post {
    /// Returns `true` when this request asks the server for a streamed response,
    /// `false` when it expects a single complete response.
    fn is_streaming(&self) -> bool;
}
12
13pub trait PostNoStream: Post + Serialize + Sync + Send {
14    type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
15
16    /// Sends a POST request to the specified URL with the provided api-key.
17    fn get_response_string(
18        &self,
19        url: &str,
20        key: &str,
21    ) -> impl Future<Output = Result<String, OapiError>> + Send + Sync {
22        async move {
23            if self.is_streaming() {
24                return Err(OapiError::NonStreamingViolation);
25            }
26
27            let client = reqwest::Client::new();
28            let response = client
29                .post(url)
30                .headers({
31                    let mut headers = reqwest::header::HeaderMap::new();
32                    headers.insert("Content-Type", "application/json".parse().unwrap());
33                    headers.insert("Accept", "application/json".parse().unwrap());
34                    headers
35                })
36                .bearer_auth(key)
37                .json(self)
38                .send()
39                .await
40                .map_err(|e| OapiError::SendError(format!("Failed to send request: {:#?}", e)))?;
41
42            if response.status() != reqwest::StatusCode::OK {
43                return Err(
44                    crate::errors::OapiError::ResponseStatus(response.status().as_u16()).into(),
45                );
46            }
47
48            let text = response.text().await.map_err(|e| {
49                OapiError::ResponseError(format!("Failed to get response text: {:#?}", e))
50            })?;
51
52            Ok(text)
53        }
54    }
55
56    fn get_response(
57        &self,
58        url: &str,
59        key: &str,
60    ) -> impl Future<Output = Result<Self::Response, OapiError>> + Send + Sync {
61        async move {
62            let text = self.get_response_string(url, key).await?;
63            let result = Self::Response::from_str(&text)?;
64            Ok(result)
65        }
66    }
67}
68
69pub trait PostStream: Post + Serialize + Sync + Send {
70    type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
71
72    /// Sends a streaming POST request to the specified URL with the provided api-key.
73    ///
74    /// # Example
75    ///
76    /// ```rust
77    /// use std::sync::LazyLock;
78    /// use futures_util::StreamExt;
79    /// use openai_interface::chat::request::{Message, RequestBody};
80    /// use openai_interface::rest::post::PostStream;
81    ///
82    /// const DEEPSEEK_API_KEY: LazyLock<&str> =
83    ///     LazyLock::new(|| include_str!("../.././keys/deepseek_domestic_key").trim());
84    /// const DEEPSEEK_CHAT_URL: &'static str = "https://api.deepseek.com/chat/completions";
85    /// const DEEPSEEK_MODEL: &'static str = "deepseek-chat";
86    ///
87    /// #[tokio::main]
88    /// async fn main() {
89    ///     let request = RequestBody {
90    ///         messages: vec![
91    ///             Message::System {
92    ///                 content: "This is a request of test purpose. Reply briefly".to_string(),
93    ///                 name: None,
94    ///             },
95    ///             Message::User {
96    ///                 content: "What's your name?".to_string(),
97    ///                 name: None,
98    ///             },
99    ///         ],
100    ///         model: DEEPSEEK_MODEL.to_string(),
101    ///         stream: true,
102    ///         ..Default::default()
103    ///     };
104    ///
105    ///     let mut response = request
106    ///         .get_stream_response_string(DEEPSEEK_CHAT_URL, *DEEPSEEK_API_KEY)
107    ///         .await
108    ///         .unwrap();
109    ///
110    ///     while let Some(chunk) = response.next().await {
111    ///         println!("{}", chunk.unwrap());
112    ///     }
113    /// }
114    /// ```
115    fn get_stream_response_string(
116        &self,
117        url: &str,
118        api_key: &str,
119    ) -> impl Future<Output = Result<BoxStream<'static, Result<String, OapiError>>, OapiError>>
120    + Send
121    + Sync {
122        async move {
123            if !self.is_streaming() {
124                return Err(OapiError::StreamingViolation);
125            }
126
127            let client = reqwest::Client::new();
128
129            let response = client
130                .post(url)
131                .headers({
132                    let mut headers = reqwest::header::HeaderMap::new();
133                    headers.insert("Content-Type", "application/json".parse().unwrap());
134                    headers.insert("Accept", "text/event-stream".parse().unwrap());
135                    headers
136                })
137                .bearer_auth(api_key)
138                .json(self)
139                .send()
140                .await
141                .map_err(|e| OapiError::ResponseError(format!("Failed to send request: {}", e)))?;
142
143            if !response.status().is_success() {
144                return Err(OapiError::ResponseStatus(response.status().as_u16()).into());
145            }
146
147            // The following code is generated by Qwen3-480B-Coder
148            // 使用 eventsource-stream 解析 SSE
149            let stream = response
150                .bytes_stream()
151                .eventsource()
152                .map(|event| match event {
153                    Ok(event) => Ok(event.data),
154                    Err(e) => Err(OapiError::SseParseError(format!("SSE parse error: {}", e))),
155                })
156                .boxed();
157
158            Ok(stream as BoxStream<'static, Result<String, OapiError>>)
159        }
160    }
161
162    fn get_stream_response(
163        &self,
164        url: &str,
165        api_key: &str,
166    ) -> impl Future<
167        Output = Result<BoxStream<'static, Result<Self::Response, OapiError>>, OapiError>,
168    > + Send
169    + Sync {
170        async move {
171            let stream = self.get_stream_response_string(url, api_key).await?;
172
173            let parsed_stream = stream
174                .take_while(|result| {
175                    let should_continue = match result {
176                        Ok(data) => data != "[DONE]",
177                        Err(_) => true, // 继续传播错误
178                    };
179                    async move { should_continue }
180                })
181                .and_then(|data| async move { Self::Response::from_str(&data) });
182
183            Ok(Box::pin(parsed_stream) as BoxStream<'static, _>)
184        }
185    }
186}