openai_interface/rest/
post.rs1use std::{future::Future, str::FromStr};
2
3use eventsource_stream::Eventsource;
4use futures_util::{StreamExt, TryStreamExt, stream::BoxStream};
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::errors::OapiError;
8
9pub trait Post {
10 fn is_streaming(&self) -> bool;
11 fn build_url(&self, base_url: &str) -> Result<String, OapiError>;
15}
16
17pub trait PostNoStream: Post + Serialize + Sync + Send {
18 type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
19
20 fn get_response_string(
22 &self,
23 base_url: &str,
24 key: &str,
25 ) -> impl Future<Output = Result<String, OapiError>> + Send + Sync {
26 async move {
27 if self.is_streaming() {
28 return Err(OapiError::NonStreamingViolation);
29 }
30
31 let client = reqwest::Client::new();
32 let response = client
33 .post(self.build_url(base_url)?)
34 .headers({
35 let mut headers = reqwest::header::HeaderMap::new();
36 headers.insert("Content-Type", "application/json".parse().unwrap());
37 headers.insert("Accept", "application/json".parse().unwrap());
38 headers
39 })
40 .bearer_auth(key)
41 .json(self)
42 .send()
43 .await
44 .map_err(|e| OapiError::SendError(format!("Failed to send request: {:#?}", e)))?;
45
46 if response.status() != reqwest::StatusCode::OK {
47 return Err(
48 crate::errors::OapiError::ResponseStatus(response.status().as_u16()).into(),
49 );
50 }
51
52 let text = response.text().await.map_err(|e| {
53 OapiError::ResponseError(format!("Failed to get response text: {:#?}", e))
54 })?;
55
56 Ok(text)
57 }
58 }
59
60 fn get_response(
61 &self,
62 url: &str,
63 key: &str,
64 ) -> impl Future<Output = Result<Self::Response, OapiError>> + Send + Sync {
65 async move {
66 let text = self.get_response_string(url, key).await?;
67 let result = Self::Response::from_str(&text)?;
68 Ok(result)
69 }
70 }
71}
72
73pub trait PostStream: Post + Serialize + Sync + Send {
74 type Response: DeserializeOwned + FromStr<Err = OapiError> + Send + Sync;
75
76 fn get_stream_response_string(
120 &self,
121 base_url: &str,
122 api_key: &str,
123 ) -> impl Future<Output = Result<BoxStream<'static, Result<String, OapiError>>, OapiError>>
124 + Send
125 + Sync {
126 async move {
127 if !self.is_streaming() {
128 return Err(OapiError::StreamingViolation);
129 }
130
131 let client = reqwest::Client::new();
132
133 let response = client
134 .post(&self.build_url(base_url)?)
135 .headers({
136 let mut headers = reqwest::header::HeaderMap::new();
137 headers.insert("Content-Type", "application/json".parse().unwrap());
138 headers.insert("Accept", "text/event-stream".parse().unwrap());
139 headers
140 })
141 .bearer_auth(api_key)
142 .json(self)
143 .send()
144 .await
145 .map_err(|e| OapiError::ResponseError(format!("Failed to send request: {}", e)))?;
146
147 if !response.status().is_success() {
148 return Err(OapiError::ResponseStatus(response.status().as_u16()).into());
149 }
150
151 let stream = response
154 .bytes_stream()
155 .eventsource()
156 .map(|event| match event {
157 Ok(event) => Ok(event.data),
158 Err(e) => Err(OapiError::SseParseError(format!("SSE parse error: {}", e))),
159 })
160 .boxed();
161
162 Ok(stream as BoxStream<'static, Result<String, OapiError>>)
163 }
164 }
165
166 fn get_stream_response(
167 &self,
168 base_url: &str,
169 api_key: &str,
170 ) -> impl Future<
171 Output = Result<BoxStream<'static, Result<Self::Response, OapiError>>, OapiError>,
172 > + Send
173 + Sync {
174 async move {
175 let stream = self.get_stream_response_string(base_url, api_key).await?;
176
177 let parsed_stream = stream
178 .take_while(|result| {
179 let should_continue = match result {
180 Ok(data) => data != "[DONE]",
181 Err(_) => true, };
183 async move { should_continue }
184 })
185 .and_then(|data| async move { Self::Response::from_str(&data) });
186
187 Ok(Box::pin(parsed_stream) as BoxStream<'static, _>)
188 }
189 }
190}