1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use alpaca_core::BaseUrl;
5use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
6use serde::de::DeserializeOwned;
7
8use crate::Error;
9use crate::auth::Authenticator;
10use crate::meta::{ErrorMeta, HttpResponse, ResponseMeta};
11use crate::observer::{
12 ErrorEvent, NoopObserver, RequestStart, ResponseEvent, RetryEvent, TransportObserver,
13};
14use crate::rate_limit::ConcurrencyLimit;
15use crate::request::{NoContent, RequestBody, RequestParts};
16use crate::retry::{RetryConfig, RetryDecision};
17
18#[derive(Clone)]
19pub struct HttpClient {
20 client: reqwest::Client,
21 default_headers: HeaderMap,
22 request_id_header_name: HeaderName,
23 retry_config: RetryConfig,
24 observer: Arc<dyn TransportObserver>,
25 concurrency_limit: ConcurrencyLimit,
26}
27
28#[derive(Clone)]
29pub struct HttpClientBuilder {
30 reqwest_client: Option<reqwest::Client>,
31 timeout: Duration,
32 default_headers: HeaderMap,
33 request_id_header_name: HeaderName,
34 retry_config: RetryConfig,
35 observer: Arc<dyn TransportObserver>,
36 concurrency_limit: ConcurrencyLimit,
37}
38
39struct ResponseParts {
40 meta: ResponseMeta,
41 body: String,
42}
43
44impl HttpClient {
45 #[must_use]
46 pub fn builder() -> HttpClientBuilder {
47 HttpClientBuilder::default()
48 }
49
50 pub async fn send_json<T>(
51 &self,
52 base_url: &BaseUrl,
53 request: RequestParts,
54 authenticator: Option<&dyn Authenticator>,
55 ) -> Result<HttpResponse<T>, Error>
56 where
57 T: DeserializeOwned,
58 {
59 let response = self.send(base_url, &request, authenticator).await?;
60 let parsed = serde_json::from_str(&response.body).map_err(|error| {
61 let meta = ErrorMeta::from_response_meta(response.meta.clone(), response.body.clone());
62 let error = Error::Deserialize {
63 message: error.to_string(),
64 meta: Some(meta.clone()),
65 };
66 self.observer.on_error(&ErrorEvent { meta: Some(meta) });
67 error
68 })?;
69
70 self.observer.on_response(&ResponseEvent {
71 meta: response.meta.clone(),
72 });
73 Ok(HttpResponse::new(parsed, response.meta))
74 }
75
76 pub async fn send_text(
77 &self,
78 base_url: &BaseUrl,
79 request: RequestParts,
80 authenticator: Option<&dyn Authenticator>,
81 ) -> Result<HttpResponse<String>, Error> {
82 let response = self.send(base_url, &request, authenticator).await?;
83 self.observer.on_response(&ResponseEvent {
84 meta: response.meta.clone(),
85 });
86 Ok(HttpResponse::new(response.body, response.meta))
87 }
88
89 pub async fn send_no_content(
90 &self,
91 base_url: &BaseUrl,
92 request: RequestParts,
93 authenticator: Option<&dyn Authenticator>,
94 ) -> Result<HttpResponse<NoContent>, Error> {
95 let response = self.send(base_url, &request, authenticator).await?;
96 if response.meta.status() != 204 {
97 let meta = ErrorMeta::from_response_meta(response.meta, response.body);
98 let error = Error::HttpStatus(meta.clone());
99 self.observer.on_error(&ErrorEvent { meta: Some(meta) });
100 return Err(error);
101 }
102
103 self.observer.on_response(&ResponseEvent {
104 meta: response.meta.clone(),
105 });
106 Ok(HttpResponse::new(NoContent, response.meta))
107 }
108
109 async fn send(
110 &self,
111 base_url: &BaseUrl,
112 request: &RequestParts,
113 authenticator: Option<&dyn Authenticator>,
114 ) -> Result<ResponseParts, Error> {
115 let _permit = self.concurrency_limit.acquire().await?;
116 let url = base_url.join_path(request.path());
117 let mut attempt = 0;
118 let started_at = Instant::now();
119
120 loop {
121 self.observer.on_request_start(&RequestStart {
122 operation: request.operation().map(ToOwned::to_owned),
123 method: request.method(),
124 url: url.clone(),
125 });
126
127 let request_builder = self.build_request(&url, request, authenticator)?;
128 let response = request_builder.send().await.map_err(|error| {
129 let error = Error::from_reqwest(error, None);
130 self.observer.on_error(&ErrorEvent { meta: None });
131 error
132 })?;
133
134 let status = response.status();
135 let headers = response.headers().clone();
136 let meta = ResponseMeta::from_response_parts(
137 request.operation().map(ToOwned::to_owned),
138 url.clone(),
139 status,
140 &headers,
141 &self.request_id_header_name,
142 attempt + 1,
143 started_at.elapsed(),
144 );
145 let body = response.text().await.map_err(|error| {
146 let error_meta = ErrorMeta::from_response_meta(meta.clone(), String::new());
147 let error = Error::from_reqwest(error, Some(error_meta.clone()));
148 self.observer.on_error(&ErrorEvent {
149 meta: Some(error_meta),
150 });
151 error
152 })?;
153
154 match self.retry_config.classify_response(
155 &request.method(),
156 status,
157 attempt,
158 meta.retry_after(),
159 started_at.elapsed(),
160 ) {
161 RetryDecision::RetryAfter(wait) => {
162 self.observer.on_retry(&RetryEvent {
163 operation: request.operation().map(ToOwned::to_owned),
164 method: request.method(),
165 url: url.clone(),
166 attempt: attempt + 1,
167 status: Some(status),
168 wait,
169 });
170 tokio::time::sleep(wait).await;
171 attempt += 1;
172 continue;
173 }
174 RetryDecision::DoNotRetry => {}
175 }
176
177 if status.is_success() {
178 return Ok(ResponseParts { meta, body });
179 }
180
181 let error_meta = ErrorMeta::from_response_meta(meta, body);
182 let error = if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
183 Error::RateLimited(error_meta.clone())
184 } else {
185 Error::HttpStatus(error_meta.clone())
186 };
187 self.observer.on_error(&ErrorEvent {
188 meta: Some(error_meta),
189 });
190 return Err(error);
191 }
192 }
193
194 fn build_request(
195 &self,
196 url: &str,
197 request: &RequestParts,
198 authenticator: Option<&dyn Authenticator>,
199 ) -> Result<reqwest::RequestBuilder, Error> {
200 let mut headers = self.default_headers.clone();
201 headers.extend(request.headers().clone());
202 if let Some(authenticator) = authenticator {
203 authenticator.apply(&mut headers)?;
204 }
205
206 let mut builder = self
207 .client
208 .request(request.method(), url)
209 .headers(headers)
210 .query(request.query());
211
212 builder = match request.body() {
213 RequestBody::Empty => builder,
214 RequestBody::Json(value) => builder.json(value),
215 RequestBody::Text(value) => builder.body(value.clone()),
216 RequestBody::Bytes(value) => builder.body(value.clone()),
217 };
218
219 if matches!(request.body(), RequestBody::Text(_))
220 && !request.headers().contains_key(CONTENT_TYPE)
221 && !self.default_headers.contains_key(CONTENT_TYPE)
222 {
223 builder = builder.header(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
224 }
225
226 Ok(builder)
227 }
228}
229
230impl Default for HttpClientBuilder {
231 fn default() -> Self {
232 Self {
233 reqwest_client: None,
234 timeout: Duration::from_secs(30),
235 default_headers: HeaderMap::new(),
236 request_id_header_name: HeaderName::from_static("x-request-id"),
237 retry_config: RetryConfig::default(),
238 observer: Arc::new(NoopObserver),
239 concurrency_limit: ConcurrencyLimit::default(),
240 }
241 }
242}
243
244impl HttpClientBuilder {
245 #[must_use]
246 pub fn timeout(mut self, timeout: Duration) -> Self {
247 self.timeout = timeout;
248 self
249 }
250
251 #[must_use]
252 pub fn reqwest_client(mut self, client: reqwest::Client) -> Self {
253 self.reqwest_client = Some(client);
254 self
255 }
256
257 pub fn default_header(mut self, name: &str, value: &str) -> Result<Self, Error> {
258 let name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
259 Error::InvalidRequest(format!("invalid default header name: {error}"))
260 })?;
261 let value = HeaderValue::from_str(value).map_err(|error| {
262 Error::InvalidRequest(format!("invalid default header value: {error}"))
263 })?;
264 self.default_headers.insert(name, value);
265 Ok(self)
266 }
267
268 pub fn request_id_header_name(mut self, name: &str) -> Result<Self, Error> {
269 self.request_id_header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
270 Error::InvalidRequest(format!("invalid request id header name: {error}"))
271 })?;
272 Ok(self)
273 }
274
275 #[must_use]
276 pub fn retry_config(mut self, retry_config: RetryConfig) -> Self {
277 self.retry_config = retry_config;
278 self
279 }
280
281 #[must_use]
282 pub fn observer(mut self, observer: Arc<dyn TransportObserver>) -> Self {
283 self.observer = observer;
284 self
285 }
286
287 #[must_use]
288 pub fn concurrency_limit(mut self, concurrency_limit: ConcurrencyLimit) -> Self {
289 self.concurrency_limit = concurrency_limit;
290 self
291 }
292
293 pub fn build(self) -> Result<HttpClient, Error> {
294 let client = match self.reqwest_client {
295 Some(client) => client,
296 None => reqwest::Client::builder()
297 .timeout(self.timeout)
298 .build()
299 .map_err(|error| Error::from_reqwest(error, None))?,
300 };
301
302 Ok(HttpClient {
303 client,
304 default_headers: self.default_headers,
305 request_id_header_name: self.request_id_header_name,
306 retry_config: self.retry_config,
307 observer: self.observer,
308 concurrency_limit: self.concurrency_limit,
309 })
310 }
311}