openai_interface/rest/
post.rs1use std::{future::Future, str::FromStr};
2
3use eventsource_stream::Eventsource;
4use futures_util::{StreamExt, TryStreamExt, stream::BoxStream};
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::errors::OapiError;
8
/// Behaviour shared by every request body that can be POSTed to an endpoint.
///
/// The only capability required here is reporting whether the request is
/// configured for a streamed response; [`PostNoStream`] and [`PostStream`]
/// use this flag to reject requests sent through the wrong code path.
pub trait Post {
    /// Reports whether this request is configured for a streaming response.
    ///
    /// `PostNoStream` methods refuse requests for which this returns `true`,
    /// and `PostStream` methods refuse requests for which it returns `false`.
    fn is_streaming(&self) -> bool;
}
12
13pub trait PostNoStream: Post + Serialize + Sync + Send {
14 type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
15
16 fn get_response_string(
18 &self,
19 url: &str,
20 key: &str,
21 ) -> impl Future<Output = Result<String, OapiError>> + Send + Sync {
22 async move {
23 if self.is_streaming() {
24 return Err(OapiError::NonStreamingViolation);
25 }
26
27 let client = reqwest::Client::new();
28 let response = client
29 .post(url)
30 .headers({
31 let mut headers = reqwest::header::HeaderMap::new();
32 headers.insert("Content-Type", "application/json".parse().unwrap());
33 headers.insert("Accept", "application/json".parse().unwrap());
34 headers
35 })
36 .bearer_auth(key)
37 .json(self)
38 .send()
39 .await
40 .map_err(|e| OapiError::SendError(format!("Failed to send request: {:#?}", e)))?;
41
42 if response.status() != reqwest::StatusCode::OK {
43 return Err(
44 crate::errors::OapiError::ResponseStatus(response.status().as_u16()).into(),
45 );
46 }
47
48 let text = response.text().await.map_err(|e| {
49 OapiError::ResponseError(format!("Failed to get response text: {:#?}", e))
50 })?;
51
52 Ok(text)
53 }
54 }
55
56 fn get_response(
57 &self,
58 url: &str,
59 key: &str,
60 ) -> impl Future<Output = Result<Self::Response, OapiError>> + Send + Sync {
61 async move {
62 let text = self.get_response_string(url, key).await?;
63 let result = Self::Response::from_str(&text)?;
64 Ok(result)
65 }
66 }
67}
68
69pub trait PostStream: Post + Serialize + Sync + Send {
70 type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
71
72 fn get_stream_response_string(
116 &self,
117 url: &str,
118 api_key: &str,
119 ) -> impl Future<Output = Result<BoxStream<'static, Result<String, OapiError>>, OapiError>>
120 + Send
121 + Sync {
122 async move {
123 if !self.is_streaming() {
124 return Err(OapiError::StreamingViolation);
125 }
126
127 let client = reqwest::Client::new();
128
129 let response = client
130 .post(url)
131 .headers({
132 let mut headers = reqwest::header::HeaderMap::new();
133 headers.insert("Content-Type", "application/json".parse().unwrap());
134 headers.insert("Accept", "text/event-stream".parse().unwrap());
135 headers
136 })
137 .bearer_auth(api_key)
138 .json(self)
139 .send()
140 .await
141 .map_err(|e| OapiError::ResponseError(format!("Failed to send request: {}", e)))?;
142
143 if !response.status().is_success() {
144 return Err(OapiError::ResponseStatus(response.status().as_u16()).into());
145 }
146
147 let stream = response
150 .bytes_stream()
151 .eventsource()
152 .map(|event| match event {
153 Ok(event) => Ok(event.data),
154 Err(e) => Err(OapiError::SseParseError(format!("SSE parse error: {}", e))),
155 })
156 .boxed();
157
158 Ok(stream as BoxStream<'static, Result<String, OapiError>>)
159 }
160 }
161
162 fn get_stream_response(
163 &self,
164 url: &str,
165 api_key: &str,
166 ) -> impl Future<
167 Output = Result<BoxStream<'static, Result<Self::Response, OapiError>>, OapiError>,
168 > + Send
169 + Sync {
170 async move {
171 let stream = self.get_stream_response_string(url, api_key).await?;
172
173 let parsed_stream = stream
174 .take_while(|result| {
175 let should_continue = match result {
176 Ok(data) => data != "[DONE]",
177 Err(_) => true, };
179 async move { should_continue }
180 })
181 .and_then(|data| async move { Self::Response::from_str(&data) });
182
183 Ok(Box::pin(parsed_stream) as BoxStream<'static, _>)
184 }
185 }
186}