1use crate::{join_url, ApiError, ClientConfig, RequestOptions};
2use futures::{Stream, StreamExt};
3use reqwest::{
4 header::{HeaderName, HeaderValue},
5 Client, Method, Request, Response,
6};
7use serde::de::DeserializeOwned;
8use std::{
9 pin::Pin,
10 str::FromStr,
11 task::{Context, Poll},
12};
13
14pub struct ByteStream {
16 content_length: Option<u64>,
17 inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
18}
19
20impl ByteStream {
21 pub(crate) fn new(response: Response) -> Self {
23 let content_length = response.content_length();
24 let stream = response.bytes_stream();
25
26 Self {
27 content_length,
28 inner: Box::pin(stream),
29 }
30 }
31
32 pub async fn collect(mut self) -> Result<Vec<u8>, ApiError> {
43 let mut result = Vec::new();
44 while let Some(chunk) = self.inner.next().await {
45 result.extend_from_slice(&chunk.map_err(ApiError::Network)?);
46 }
47 Ok(result)
48 }
49
50 pub async fn try_next(&mut self) -> Result<Option<bytes::Bytes>, ApiError> {
63 match self.inner.next().await {
64 Some(Ok(bytes)) => Ok(Some(bytes)),
65 Some(Err(e)) => Err(ApiError::Network(e)),
66 None => Ok(None),
67 }
68 }
69
70 pub fn content_length(&self) -> Option<u64> {
72 self.content_length
73 }
74}
75
76impl Stream for ByteStream {
77 type Item = Result<bytes::Bytes, ApiError>;
78
79 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80 match self.inner.as_mut().poll_next(cx) {
81 Poll::Ready(Some(Ok(bytes))) => Poll::Ready(Some(Ok(bytes))),
82 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ApiError::Network(e)))),
83 Poll::Ready(None) => Poll::Ready(None),
84 Poll::Pending => Poll::Pending,
85 }
86 }
87}
88
/// HTTP client that pairs a `reqwest::Client` with the configuration used to
/// build every request (base URL, credentials, custom headers, retry count).
#[derive(Clone)]
pub struct HttpClient {
    /// Underlying `reqwest` client, built once in `HttpClient::new`.
    client: Client,
    /// Client-wide settings; per-request options take precedence where both
    /// provide a value.
    config: ClientConfig,
}
95
96impl HttpClient {
97 pub fn new(config: ClientConfig) -> Result<Self, ApiError> {
98 let client = Client::builder()
99 .timeout(config.timeout)
100 .user_agent(&config.user_agent)
101 .build()
102 .map_err(ApiError::Network)?;
103
104 Ok(Self { client, config })
105 }
106
107 pub async fn execute_request<T>(
109 &self,
110 method: Method,
111 path: &str,
112 body: Option<serde_json::Value>,
113 query_params: Option<Vec<(String, String)>>,
114 options: Option<RequestOptions>,
115 ) -> Result<T, ApiError>
116 where
117 T: DeserializeOwned, {
119 let url = join_url(&self.config.base_url, path);
120 let mut request = self.client.request(method, &url);
121
122 if let Some(params) = query_params {
124 request = request.query(¶ms);
125 }
126
127 if let Some(opts) = &options {
129 if !opts.additional_query_params.is_empty() {
130 request = request.query(&opts.additional_query_params);
131 }
132 }
133
134 if let Some(body) = body {
136 request = request.json(&body);
137 }
138
139 let mut req = request.build().map_err(|e| ApiError::Network(e))?;
141
142 self.apply_auth_headers(&mut req, &options)?;
144 self.apply_custom_headers(&mut req, &options)?;
145
146 let response = self.execute_with_retries(req, &options).await?;
148 self.parse_response(response).await
149 }
150
151 fn apply_auth_headers(
152 &self,
153 request: &mut Request,
154 options: &Option<RequestOptions>,
155 ) -> Result<(), ApiError> {
156 let headers = request.headers_mut();
157
158 let api_key = options
160 .as_ref()
161 .and_then(|opts| opts.api_key.as_ref())
162 .or(self.config.api_key.as_ref());
163
164 if let Some(key) = api_key {
165 headers.insert("api_key", key.parse().map_err(|_| ApiError::InvalidHeader)?);
166 }
167
168 let token = options
170 .as_ref()
171 .and_then(|opts| opts.token.as_ref())
172 .or(self.config.token.as_ref());
173
174 if let Some(token) = token {
175 let auth_value = format!("Bearer {}", token);
176 headers.insert(
177 "Authorization",
178 auth_value.parse().map_err(|_| ApiError::InvalidHeader)?,
179 );
180 }
181
182 Ok(())
183 }
184
185 fn apply_custom_headers(
186 &self,
187 request: &mut Request,
188 options: &Option<RequestOptions>,
189 ) -> Result<(), ApiError> {
190 let headers = request.headers_mut();
191
192 for (key, value) in &self.config.custom_headers {
194 headers.insert(
195 HeaderName::from_str(key).map_err(|_| ApiError::InvalidHeader)?,
196 HeaderValue::from_str(value).map_err(|_| ApiError::InvalidHeader)?,
197 );
198 }
199
200 if let Some(options) = options {
202 for (key, value) in &options.additional_headers {
203 headers.insert(
204 HeaderName::from_str(key).map_err(|_| ApiError::InvalidHeader)?,
205 HeaderValue::from_str(value).map_err(|_| ApiError::InvalidHeader)?,
206 );
207 }
208 }
209
210 Ok(())
211 }
212
213 async fn execute_with_retries(
214 &self,
215 request: Request,
216 options: &Option<RequestOptions>,
217 ) -> Result<Response, ApiError> {
218 let max_retries = options
219 .as_ref()
220 .and_then(|opts| opts.max_retries)
221 .unwrap_or(self.config.max_retries);
222
223 let mut last_error = None;
224
225 for attempt in 0..=max_retries {
226 let cloned_request = request.try_clone().ok_or(ApiError::RequestClone)?;
227
228 match self.client.execute(cloned_request).await {
229 Ok(response) if response.status().is_success() => return Ok(response),
230 Ok(response) => {
231 let status_code = response.status().as_u16();
232 let body = response.text().await.ok();
233 return Err(ApiError::from_response(status_code, body.as_deref()));
234 }
235 Err(e) if attempt < max_retries => {
236 last_error = Some(e);
237 let delay = std::time::Duration::from_millis(100 * 2_u64.pow(attempt));
239 tokio::time::sleep(delay).await;
240 }
241 Err(e) => return Err(ApiError::Network(e)),
242 }
243 }
244
245 Err(ApiError::Network(last_error.unwrap()))
246 }
247
248 async fn parse_response<T>(&self, response: Response) -> Result<T, ApiError>
249 where
250 T: DeserializeOwned,
251 {
252 let text = response.text().await.map_err(ApiError::Network)?;
253 serde_json::from_str(&text).map_err(ApiError::Serialization)
254 }
255
256 pub async fn execute_stream_request(
312 &self,
313 method: Method,
314 path: &str,
315 body: Option<serde_json::Value>,
316 query_params: Option<Vec<(String, String)>>,
317 options: Option<RequestOptions>,
318 ) -> Result<ByteStream, ApiError> {
319 let url = join_url(&self.config.base_url, path);
320 let mut request = self.client.request(method, &url);
321
322 if let Some(params) = query_params {
324 request = request.query(¶ms);
325 }
326
327 if let Some(opts) = &options {
329 if !opts.additional_query_params.is_empty() {
330 request = request.query(&opts.additional_query_params);
331 }
332 }
333
334 if let Some(body) = body {
336 request = request.json(&body);
337 }
338
339 let mut req = request.build().map_err(|e| ApiError::Network(e))?;
341
342 self.apply_auth_headers(&mut req, &options)?;
344 self.apply_custom_headers(&mut req, &options)?;
345
346 let response = self.execute_with_retries(req, &options).await?;
348
349 Ok(ByteStream::new(response))
351 }
352
353 pub async fn execute_sse_request<T>(
387 &self,
388 method: Method,
389 path: &str,
390 body: Option<serde_json::Value>,
391 query_params: Option<Vec<(String, String)>>,
392 options: Option<RequestOptions>,
393 terminator: Option<String>,
394 ) -> Result<crate::SseStream<T>, ApiError>
395 where
396 T: DeserializeOwned + Send + 'static,
397 {
398 let url = join_url(&self.config.base_url, path);
399 let mut request = self.client.request(method, &url);
400
401 if let Some(params) = query_params {
403 request = request.query(¶ms);
404 }
405
406 if let Some(opts) = &options {
408 if !opts.additional_query_params.is_empty() {
409 request = request.query(&opts.additional_query_params);
410 }
411 }
412
413 if let Some(body) = body {
415 request = request.json(&body);
416 }
417
418 let mut req = request.build().map_err(|e| ApiError::Network(e))?;
420
421 self.apply_auth_headers(&mut req, &options)?;
423 self.apply_custom_headers(&mut req, &options)?;
424
425 req.headers_mut().insert(
427 "Accept",
428 "text/event-stream"
429 .parse()
430 .map_err(|_| ApiError::InvalidHeader)?,
431 );
432 req.headers_mut().insert(
433 "Cache-Control",
434 "no-store".parse().map_err(|_| ApiError::InvalidHeader)?,
435 );
436
437 let response = self.execute_with_retries(req, &options).await?;
439
440 crate::SseStream::new(response, terminator).await
442 }
443}