Skip to main content

alpaca_http/
client.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use alpaca_core::BaseUrl;
5use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
6use serde::de::DeserializeOwned;
7
8use crate::Error;
9use crate::auth::Authenticator;
10use crate::meta::{ErrorMeta, HttpResponse, ResponseMeta};
11use crate::observer::{
12    ErrorEvent, NoopObserver, RequestStart, ResponseEvent, RetryEvent, TransportObserver,
13};
14use crate::rate_limit::ConcurrencyLimit;
15use crate::request::{NoContent, RequestBody, RequestParts};
16use crate::retry::{RetryConfig, RetryDecision};
17
18#[derive(Clone)]
19pub struct HttpClient {
20    client: reqwest::Client,
21    default_headers: HeaderMap,
22    request_id_header_name: HeaderName,
23    retry_config: RetryConfig,
24    observer: Arc<dyn TransportObserver>,
25    concurrency_limit: ConcurrencyLimit,
26}
27
28#[derive(Clone)]
29pub struct HttpClientBuilder {
30    reqwest_client: Option<reqwest::Client>,
31    timeout: Duration,
32    default_headers: HeaderMap,
33    request_id_header_name: HeaderName,
34    retry_config: RetryConfig,
35    observer: Arc<dyn TransportObserver>,
36    concurrency_limit: ConcurrencyLimit,
37}
38
39struct ResponseParts {
40    meta: ResponseMeta,
41    body: String,
42}
43
44impl HttpClient {
45    #[must_use]
46    pub fn builder() -> HttpClientBuilder {
47        HttpClientBuilder::default()
48    }
49
50    pub async fn send_json<T>(
51        &self,
52        base_url: &BaseUrl,
53        request: RequestParts,
54        authenticator: Option<&dyn Authenticator>,
55    ) -> Result<HttpResponse<T>, Error>
56    where
57        T: DeserializeOwned,
58    {
59        let response = self.send(base_url, &request, authenticator).await?;
60        let parsed = serde_json::from_str(&response.body).map_err(|error| {
61            let meta = ErrorMeta::from_response_meta(response.meta.clone(), response.body.clone());
62            let error = Error::Deserialize {
63                message: error.to_string(),
64                meta: Some(meta.clone()),
65            };
66            self.observer.on_error(&ErrorEvent { meta: Some(meta) });
67            error
68        })?;
69
70        self.observer.on_response(&ResponseEvent {
71            meta: response.meta.clone(),
72        });
73        Ok(HttpResponse::new(parsed, response.meta))
74    }
75
76    pub async fn send_text(
77        &self,
78        base_url: &BaseUrl,
79        request: RequestParts,
80        authenticator: Option<&dyn Authenticator>,
81    ) -> Result<HttpResponse<String>, Error> {
82        let response = self.send(base_url, &request, authenticator).await?;
83        self.observer.on_response(&ResponseEvent {
84            meta: response.meta.clone(),
85        });
86        Ok(HttpResponse::new(response.body, response.meta))
87    }
88
89    pub async fn send_no_content(
90        &self,
91        base_url: &BaseUrl,
92        request: RequestParts,
93        authenticator: Option<&dyn Authenticator>,
94    ) -> Result<HttpResponse<NoContent>, Error> {
95        let response = self.send(base_url, &request, authenticator).await?;
96        if response.meta.status() != 204 {
97            let meta = ErrorMeta::from_response_meta(response.meta, response.body);
98            let error = Error::HttpStatus(meta.clone());
99            self.observer.on_error(&ErrorEvent { meta: Some(meta) });
100            return Err(error);
101        }
102
103        self.observer.on_response(&ResponseEvent {
104            meta: response.meta.clone(),
105        });
106        Ok(HttpResponse::new(NoContent, response.meta))
107    }
108
109    async fn send(
110        &self,
111        base_url: &BaseUrl,
112        request: &RequestParts,
113        authenticator: Option<&dyn Authenticator>,
114    ) -> Result<ResponseParts, Error> {
115        let _permit = self.concurrency_limit.acquire().await?;
116        let url = base_url.join_path(request.path());
117        let mut attempt = 0;
118        let started_at = Instant::now();
119
120        loop {
121            self.observer.on_request_start(&RequestStart {
122                operation: request.operation().map(ToOwned::to_owned),
123                method: request.method(),
124                url: url.clone(),
125            });
126
127            let request_builder = self.build_request(&url, request, authenticator)?;
128            let response = request_builder.send().await.map_err(|error| {
129                let error = Error::from_reqwest(error, None);
130                self.observer.on_error(&ErrorEvent { meta: None });
131                error
132            })?;
133
134            let status = response.status();
135            let headers = response.headers().clone();
136            let meta = ResponseMeta::from_response_parts(
137                request.operation().map(ToOwned::to_owned),
138                url.clone(),
139                status,
140                &headers,
141                &self.request_id_header_name,
142                attempt + 1,
143                started_at.elapsed(),
144            );
145            let body = response.text().await.map_err(|error| {
146                let error_meta = ErrorMeta::from_response_meta(meta.clone(), String::new());
147                let error = Error::from_reqwest(error, Some(error_meta.clone()));
148                self.observer.on_error(&ErrorEvent {
149                    meta: Some(error_meta),
150                });
151                error
152            })?;
153
154            match self.retry_config.classify_response(
155                &request.method(),
156                status,
157                attempt,
158                meta.retry_after(),
159                started_at.elapsed(),
160            ) {
161                RetryDecision::RetryAfter(wait) => {
162                    self.observer.on_retry(&RetryEvent {
163                        operation: request.operation().map(ToOwned::to_owned),
164                        method: request.method(),
165                        url: url.clone(),
166                        attempt: attempt + 1,
167                        status: Some(status),
168                        wait,
169                    });
170                    tokio::time::sleep(wait).await;
171                    attempt += 1;
172                    continue;
173                }
174                RetryDecision::DoNotRetry => {}
175            }
176
177            if status.is_success() {
178                return Ok(ResponseParts { meta, body });
179            }
180
181            let error_meta = ErrorMeta::from_response_meta(meta, body);
182            let error = if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
183                Error::RateLimited(error_meta.clone())
184            } else {
185                Error::HttpStatus(error_meta.clone())
186            };
187            self.observer.on_error(&ErrorEvent {
188                meta: Some(error_meta),
189            });
190            return Err(error);
191        }
192    }
193
194    fn build_request(
195        &self,
196        url: &str,
197        request: &RequestParts,
198        authenticator: Option<&dyn Authenticator>,
199    ) -> Result<reqwest::RequestBuilder, Error> {
200        let mut headers = self.default_headers.clone();
201        headers.extend(request.headers().clone());
202        if let Some(authenticator) = authenticator {
203            authenticator.apply(&mut headers)?;
204        }
205
206        let mut builder = self
207            .client
208            .request(request.method(), url)
209            .headers(headers)
210            .query(request.query());
211
212        builder = match request.body() {
213            RequestBody::Empty => builder,
214            RequestBody::Json(value) => builder.json(value),
215            RequestBody::Text(value) => builder.body(value.clone()),
216            RequestBody::Bytes(value) => builder.body(value.clone()),
217        };
218
219        if matches!(request.body(), RequestBody::Text(_))
220            && !request.headers().contains_key(CONTENT_TYPE)
221            && !self.default_headers.contains_key(CONTENT_TYPE)
222        {
223            builder = builder.header(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
224        }
225
226        Ok(builder)
227    }
228}
229
230impl Default for HttpClientBuilder {
231    fn default() -> Self {
232        Self {
233            reqwest_client: None,
234            timeout: Duration::from_secs(30),
235            default_headers: HeaderMap::new(),
236            request_id_header_name: HeaderName::from_static("x-request-id"),
237            retry_config: RetryConfig::default(),
238            observer: Arc::new(NoopObserver),
239            concurrency_limit: ConcurrencyLimit::default(),
240        }
241    }
242}
243
244impl HttpClientBuilder {
245    #[must_use]
246    pub fn timeout(mut self, timeout: Duration) -> Self {
247        self.timeout = timeout;
248        self
249    }
250
251    #[must_use]
252    pub fn reqwest_client(mut self, client: reqwest::Client) -> Self {
253        self.reqwest_client = Some(client);
254        self
255    }
256
257    pub fn default_header(mut self, name: &str, value: &str) -> Result<Self, Error> {
258        let name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
259            Error::InvalidRequest(format!("invalid default header name: {error}"))
260        })?;
261        let value = HeaderValue::from_str(value).map_err(|error| {
262            Error::InvalidRequest(format!("invalid default header value: {error}"))
263        })?;
264        self.default_headers.insert(name, value);
265        Ok(self)
266    }
267
268    pub fn request_id_header_name(mut self, name: &str) -> Result<Self, Error> {
269        self.request_id_header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
270            Error::InvalidRequest(format!("invalid request id header name: {error}"))
271        })?;
272        Ok(self)
273    }
274
275    #[must_use]
276    pub fn retry_config(mut self, retry_config: RetryConfig) -> Self {
277        self.retry_config = retry_config;
278        self
279    }
280
281    #[must_use]
282    pub fn observer(mut self, observer: Arc<dyn TransportObserver>) -> Self {
283        self.observer = observer;
284        self
285    }
286
287    #[must_use]
288    pub fn concurrency_limit(mut self, concurrency_limit: ConcurrencyLimit) -> Self {
289        self.concurrency_limit = concurrency_limit;
290        self
291    }
292
293    pub fn build(self) -> Result<HttpClient, Error> {
294        let client = match self.reqwest_client {
295            Some(client) => client,
296            None => reqwest::Client::builder()
297                .timeout(self.timeout)
298                .build()
299                .map_err(|error| Error::from_reqwest(error, None))?,
300        };
301
302        Ok(HttpClient {
303            client,
304            default_headers: self.default_headers,
305            request_id_header_name: self.request_id_header_name,
306            retry_config: self.retry_config,
307            observer: self.observer,
308            concurrency_limit: self.concurrency_limit,
309        })
310    }
311}