1#![warn(missing_docs)]
7
8use serde::{Serialize, de::DeserializeOwned};
9use std::{collections::HashMap, fmt, sync::Arc, time::Duration};
10use tokio::{
11 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
12 net::TcpStream,
13 time::timeout,
14};
15use tokio_rustls::{TlsConnector, rustls::pki_types::ServerName};
16use tracing::{debug, error, info};
17use url::Url;
18use wae_types::{NetworkErrorKind, WaeError, WaeResult};
19
/// Errors produced while building, sending, or decoding an HTTP request.
#[derive(Debug)]
pub enum HttpError {
    /// The URL was rejected, e.g. it could not be parsed or has no host.
    InvalidUrl(String),

    /// Host name resolution failed.
    ///
    /// NOTE: the client in this module reports resolution failures raised by
    /// `TcpStream::connect` as [`HttpError::ConnectionFailed`] instead.
    DnsFailed(String),

    /// Establishing the TCP connection or writing the request to it failed.
    ConnectionFailed(String),

    /// The TLS server name was invalid or the TLS handshake failed.
    TlsError(String),

    /// An operation did not complete within its configured time limit.
    Timeout,

    /// An HTTP/1.1-level failure while exchanging the message, such as a malformed status line
    /// or an unreadable header or body.
    ProtocolError(String),

    /// The response body could not be decoded (for example invalid JSON or UTF-8).
    ParseError(String),

    /// A request body could not be serialized.
    SerializationError(String),

    /// The server returned a non-success status code.
    StatusError {
        /// Numeric HTTP status code, e.g. `404`.
        status: u16,
        /// Response body, decoded lossily as UTF-8.
        body: String,
    },
}

// Human-readable, single-line descriptions; each variant's payload is appended verbatim.
impl fmt::Display for HttpError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HttpError::InvalidUrl(msg) => write!(f, "Invalid URL: {}", msg),
            HttpError::DnsFailed(msg) => write!(f, "DNS resolution failed: {}", msg),
            HttpError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
            HttpError::TlsError(msg) => write!(f, "TLS error: {}", msg),
            HttpError::Timeout => write!(f, "Request timeout"),
            HttpError::ProtocolError(msg) => write!(f, "HTTP protocol error: {}", msg),
            HttpError::ParseError(msg) => write!(f, "Response parse error: {}", msg),
            HttpError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
            HttpError::StatusError { status, body } => write!(f, "HTTP {}: {}", status, body),
        }
    }
}

// No variant wraps an underlying error value, so the default `source()` (None) is accurate.
impl std::error::Error for HttpError {}
73
/// Tunable settings for an HTTP client: timeouts, retry policy and default headers.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Time limit applied separately to writing the request and to reading the response.
    pub timeout: Duration,
    /// Time limit for establishing the connection.
    pub connect_timeout: Duration,
    /// Default `User-Agent` request header value.
    pub user_agent: String,
    /// Number of additional attempts made after the first one fails retryably.
    pub max_retries: u32,
    /// Base delay between attempts; the `n`-th retry waits `n * retry_delay`.
    pub retry_delay: Duration,
    /// Headers added to every request.
    pub default_headers: HashMap<String, String>,
}

impl Default for HttpClientConfig {
    /// 30 s timeout, 10 s connect timeout, user agent `wae-request/0.1.0`, 3 retries with a
    /// 1 s base delay, and no default headers.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            user_agent: "wae-request/0.1.0".to_string(),
            max_retries: 3,
            retry_delay: Duration::from_millis(1000),
            default_headers: HashMap::new(),
        }
    }
}
103
104#[derive(Debug)]
106pub struct HttpResponse {
107 pub version: String,
109 pub status: u16,
111 pub status_text: String,
113 pub headers: HashMap<String, String>,
115 pub body: Vec<u8>,
117}
118
119impl HttpResponse {
120 pub fn json<T: DeserializeOwned>(&self) -> Result<T, HttpError> {
122 serde_json::from_slice(&self.body).map_err(|e| HttpError::ParseError(format!("JSON parse error: {}", e)))
123 }
124
125 pub fn text(&self) -> Result<String, HttpError> {
127 String::from_utf8(self.body.clone()).map_err(|e| HttpError::ParseError(format!("UTF-8 decode error: {}", e)))
128 }
129
130 pub fn is_success(&self) -> bool {
132 self.status >= 200 && self.status < 300
133 }
134}
135
136static TLS_CONNECTOR: std::sync::OnceLock<Arc<TlsConnector>> = std::sync::OnceLock::new();
138
139fn get_tls_connector() -> Arc<TlsConnector> {
140 TLS_CONNECTOR
141 .get_or_init(|| {
142 let mut roots = tokio_rustls::rustls::RootCertStore::empty();
143 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
144 let config = tokio_rustls::rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth();
145 Arc::new(TlsConnector::from(Arc::new(config)))
146 })
147 .clone()
148}
149
150#[derive(Debug, Clone)]
152pub struct HttpClient {
153 config: HttpClientConfig,
154}
155
156impl Default for HttpClient {
157 fn default() -> Self {
158 Self::new(HttpClientConfig::default())
159 }
160}
161
162impl HttpClient {
163 pub fn new(config: HttpClientConfig) -> Self {
165 Self { config }
166 }
167
168 pub fn with_defaults() -> Self {
170 Self::default()
171 }
172
173 pub fn config(&self) -> &HttpClientConfig {
175 &self.config
176 }
177
178 pub async fn get(&self, url: &str) -> Result<HttpResponse, HttpError> {
180 self.request("GET", url, None, None).await
181 }
182
183 pub async fn get_with_headers(&self, url: &str, headers: HashMap<String, String>) -> Result<HttpResponse, HttpError> {
185 self.request("GET", url, None, Some(headers)).await
186 }
187
188 pub async fn post_json<T: Serialize>(&self, url: &str, body: &T) -> Result<HttpResponse, HttpError> {
190 let json_body = serde_json::to_vec(body).map_err(|e| HttpError::SerializationError(e.to_string()))?;
191
192 let mut headers = HashMap::new();
193 headers.insert("Content-Type".to_string(), "application/json".to_string());
194
195 self.request("POST", url, Some(json_body), Some(headers)).await
196 }
197
198 pub async fn post_with_headers(
200 &self,
201 url: &str,
202 body: Vec<u8>,
203 headers: HashMap<String, String>,
204 ) -> Result<HttpResponse, HttpError> {
205 self.request("POST", url, Some(body), Some(headers)).await
206 }
207
208 pub async fn request(
210 &self,
211 method: &str,
212 url: &str,
213 body: Option<Vec<u8>>,
214 headers: Option<HashMap<String, String>>,
215 ) -> Result<HttpResponse, HttpError> {
216 let mut last_error = None;
217
218 for attempt in 0..=self.config.max_retries {
219 if attempt > 0 {
220 let delay = self.config.retry_delay * attempt;
221 debug!("Retry attempt {} after {:?}", attempt, delay);
222 tokio::time::sleep(delay).await;
223 }
224
225 match self.request_once(method, url, body.clone(), headers.clone()).await {
226 Ok(response) => {
227 if response.is_success() {
228 info!("Request succeeded on attempt {}", attempt);
229 return Ok(response);
230 }
231
232 if Self::is_retryable_status(response.status) && attempt < self.config.max_retries {
233 last_error = Some(HttpError::StatusError {
234 status: response.status,
235 body: String::from_utf8_lossy(&response.body).to_string(),
236 });
237 continue;
238 }
239
240 return Err(HttpError::StatusError {
241 status: response.status,
242 body: String::from_utf8_lossy(&response.body).to_string(),
243 });
244 }
245 Err(e) => {
246 error!("Request error on attempt {}: {}", attempt, e);
247 if Self::is_retryable_error(&e) && attempt < self.config.max_retries {
248 last_error = Some(e);
249 continue;
250 }
251 return Err(e);
252 }
253 }
254 }
255
256 Err(last_error.unwrap_or(HttpError::Timeout))
257 }
258
259 async fn request_once(
261 &self,
262 method: &str,
263 url_str: &str,
264 body: Option<Vec<u8>>,
265 extra_headers: Option<HashMap<String, String>>,
266 ) -> Result<HttpResponse, HttpError> {
267 let url = Url::parse(url_str).map_err(|e| HttpError::InvalidUrl(e.to_string()))?;
268
269 let host = url.host_str().ok_or_else(|| HttpError::InvalidUrl("Missing host".into()))?;
270 let port = url.port().unwrap_or(if url.scheme() == "https" { 443 } else { 80 });
271 let path = url.path();
272 let query = url.query().map(|q| format!("?{}", q)).unwrap_or_default();
273 let uri = format!("{}{}", path, query);
274
275 let is_https = url.scheme() == "https";
276
277 let connect_result =
278 timeout(self.config.connect_timeout, TcpStream::connect((host, port))).await.map_err(|_| HttpError::Timeout)?;
279
280 let tcp_stream = connect_result.map_err(|e| HttpError::ConnectionFailed(format!("TCP connect failed: {}", e)))?;
281
282 tcp_stream.set_nodelay(true).ok();
283
284 let response = if is_https {
285 let connector = get_tls_connector();
286 let server_name = ServerName::try_from(host.to_string())
287 .map_err(|e| HttpError::TlsError(format!("Invalid server name: {}", e)))?;
288
289 let tls_stream = connector
290 .connect(server_name, tcp_stream)
291 .await
292 .map_err(|e| HttpError::TlsError(format!("TLS handshake failed: {}", e)))?;
293
294 let (reader, writer) = tokio::io::split(tls_stream);
295 self.send_http_request(reader, writer, method, host, &uri, body, extra_headers).await?
296 }
297 else {
298 let (reader, writer) = tcp_stream.into_split();
299 self.send_http_request(reader, writer, method, host, &uri, body, extra_headers).await?
300 };
301
302 Ok(response)
303 }
304
305 #[allow(clippy::too_many_arguments)]
307 async fn send_http_request<R, W>(
308 &self,
309 reader: R,
310 mut writer: W,
311 method: &str,
312 host: &str,
313 uri: &str,
314 body: Option<Vec<u8>>,
315 extra_headers: Option<HashMap<String, String>>,
316 ) -> Result<HttpResponse, HttpError>
317 where
318 R: AsyncReadExt + Unpin,
319 W: AsyncWriteExt + Unpin,
320 {
321 let body_len = body.as_ref().map(|b| b.len()).unwrap_or(0);
322
323 let mut request =
324 format!("{} {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: {}\r\n", method, uri, host, self.config.user_agent);
325
326 if body_len > 0 {
327 request.push_str(&format!("Content-Length: {}\r\n", body_len));
328 }
329
330 for (key, value) in &self.config.default_headers {
331 request.push_str(&format!("{}: {}\r\n", key, value));
332 }
333
334 if let Some(headers) = extra_headers {
335 for (key, value) in headers {
336 request.push_str(&format!("{}: {}\r\n", key, value));
337 }
338 }
339
340 request.push_str("Connection: close\r\n\r\n");
341
342 let mut request_bytes = request.into_bytes();
343 if let Some(b) = body {
344 request_bytes.extend(b);
345 }
346
347 timeout(self.config.timeout, async {
348 writer
349 .write_all(&request_bytes)
350 .await
351 .map_err(|e| HttpError::ConnectionFailed(format!("Write request failed: {}", e)))?;
352 writer.flush().await.map_err(|e| HttpError::ConnectionFailed(format!("Flush failed: {}", e)))?;
353 Ok::<_, HttpError>(())
354 })
355 .await
356 .map_err(|_| HttpError::Timeout)??;
357
358 let response = timeout(self.config.timeout, self.read_response(reader)).await.map_err(|_| HttpError::Timeout)??;
359
360 Ok(response)
361 }
362
363 async fn read_response<R: AsyncReadExt + Unpin>(&self, reader: R) -> Result<HttpResponse, HttpError> {
365 let mut buf_reader = BufReader::new(reader);
366 let mut status_line = String::new();
367
368 buf_reader
369 .read_line(&mut status_line)
370 .await
371 .map_err(|e| HttpError::ProtocolError(format!("Read status line failed: {}", e)))?;
372
373 let status_parts: Vec<&str> = status_line.trim().splitn(3, ' ').collect();
374 if status_parts.len() < 2 {
375 return Err(HttpError::ProtocolError("Invalid status line".into()));
376 }
377
378 let version = status_parts[0].to_string();
379 let status: u16 = status_parts[1].parse().map_err(|_| HttpError::ProtocolError("Invalid status code".into()))?;
380 let status_text = status_parts.get(2).unwrap_or(&"").to_string();
381
382 let mut headers = HashMap::new();
383 loop {
384 let mut line = String::new();
385 buf_reader
386 .read_line(&mut line)
387 .await
388 .map_err(|e| HttpError::ProtocolError(format!("Read header failed: {}", e)))?;
389
390 if line == "\r\n" || line.is_empty() {
391 break;
392 }
393
394 if let Some((key, value)) = line.split_once(':') {
395 headers.insert(key.trim().to_string(), value.trim().to_string());
396 }
397 }
398
399 let content_length: Option<usize> = headers.get("content-length").and_then(|v| v.parse().ok());
400
401 let mut body = Vec::new();
402
403 if let Some(len) = content_length {
404 body.resize(len, 0);
405 buf_reader.read_exact(&mut body).await.map_err(|e| HttpError::ProtocolError(format!("Read body failed: {}", e)))?;
406 }
407 else {
408 buf_reader
409 .read_to_end(&mut body)
410 .await
411 .map_err(|e| HttpError::ProtocolError(format!("Read body failed: {}", e)))?;
412 }
413
414 Ok(HttpResponse { version, status, status_text, headers, body })
415 }
416
417 fn is_retryable_status(status: u16) -> bool {
419 matches!(status, 408 | 429 | 500 | 502 | 503 | 504)
420 }
421
422 fn is_retryable_error(error: &HttpError) -> bool {
424 matches!(error, HttpError::Timeout | HttpError::ConnectionFailed(_) | HttpError::DnsFailed(_))
425 }
426}
427
428pub struct RequestBuilder {
430 client: HttpClient,
431 method: String,
432 url: String,
433 headers: HashMap<String, String>,
434 body: Option<Vec<u8>>,
435}
436
437impl RequestBuilder {
438 pub fn new(client: HttpClient, method: &str, url: &str) -> Self {
440 Self { client, method: method.to_string(), url: url.to_string(), headers: HashMap::new(), body: None }
441 }
442
443 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
445 self.headers.insert(key.into(), value.into());
446 self
447 }
448
449 pub fn bearer_auth(self, token: impl Into<String>) -> Self {
451 self.header("Authorization", format!("Bearer {}", token.into()))
452 }
453
454 pub fn json<T: Serialize>(mut self, body: &T) -> Self {
456 if let Ok(json) = serde_json::to_vec(body) {
457 self.headers.insert("Content-Type".into(), "application/json".into());
458 self.body = Some(json);
459 }
460 self
461 }
462
463 pub fn body(mut self, body: Vec<u8>) -> Self {
465 self.body = Some(body);
466 self
467 }
468
469 pub async fn send(self) -> Result<HttpResponse, HttpError> {
471 self.client.request(&self.method, &self.url, self.body, Some(self.headers)).await
472 }
473
474 pub async fn send_json<T: DeserializeOwned>(self) -> Result<T, HttpError> {
476 let response = self.send().await?;
477 response.json()
478 }
479}
480
481pub fn get(url: &str) -> RequestBuilder {
483 RequestBuilder::new(HttpClient::default(), "GET", url)
484}
485
486pub fn post(url: &str) -> RequestBuilder {
488 RequestBuilder::new(HttpClient::default(), "POST", url)
489}
490
491impl From<HttpError> for WaeError {
493 fn from(error: HttpError) -> Self {
494 match error {
495 HttpError::InvalidUrl(msg) => WaeError::network(NetworkErrorKind::ConnectionFailed).with_param("message", msg),
496 HttpError::Timeout => WaeError::network(NetworkErrorKind::Timeout),
497 HttpError::ConnectionFailed(msg) => {
498 WaeError::network(NetworkErrorKind::ConnectionFailed).with_param("message", msg)
499 }
500 HttpError::DnsFailed(msg) => WaeError::network(NetworkErrorKind::DnsFailed).with_param("detail", msg),
501 HttpError::TlsError(msg) => WaeError::network(NetworkErrorKind::TlsError).with_param("message", msg),
502 HttpError::StatusError { status, body } => WaeError::network(NetworkErrorKind::ProtocolError)
503 .with_param("status", status.to_string())
504 .with_param("body", body),
505 HttpError::ProtocolError(msg) => WaeError::network(NetworkErrorKind::ProtocolError).with_param("message", msg),
506 HttpError::ParseError(msg) => WaeError::internal(format!("Response parse error: {}", msg)),
507 HttpError::SerializationError(msg) => WaeError::internal(format!("Serialization error: {}", msg)),
508 }
509 }
510}
511
512#[derive(Debug, Clone)]
514pub struct RequestClient {
515 client: HttpClient,
516 config: RequestConfig,
517}
518
/// Plain-number configuration for a `RequestClient`, expressed in seconds and milliseconds.
#[derive(Debug, Clone)]
pub struct RequestConfig {
    /// Read/write timeout, in seconds.
    pub timeout_secs: u64,
    /// Connect timeout, in seconds.
    pub connect_timeout_secs: u64,
    /// Number of additional attempts after a retryable failure.
    pub max_retries: u32,
    /// Base delay between retries, in milliseconds.
    pub retry_delay_ms: u64,
    /// `User-Agent` header value.
    pub user_agent: String,
}

impl Default for RequestConfig {
    /// 30 s timeout, 10 s connect timeout, 3 retries with a 1000 ms base delay, and user agent
    /// `wae-request/0.1.0`.
    fn default() -> Self {
        Self {
            timeout_secs: 30,
            connect_timeout_secs: 10,
            max_retries: 3,
            retry_delay_ms: 1000,
            user_agent: "wae-request/0.1.0".to_string(),
        }
    }
}
545
546impl RequestClient {
547 pub fn new(config: RequestConfig) -> WaeResult<Self> {
549 let http_config = HttpClientConfig {
550 timeout: Duration::from_secs(config.timeout_secs),
551 connect_timeout: Duration::from_secs(config.connect_timeout_secs),
552 user_agent: config.user_agent.clone(),
553 max_retries: config.max_retries,
554 retry_delay: Duration::from_millis(config.retry_delay_ms),
555 default_headers: HashMap::new(),
556 };
557 Ok(Self { client: HttpClient::new(http_config), config })
558 }
559
560 pub fn with_defaults() -> WaeResult<Self> {
562 Self::new(RequestConfig::default())
563 }
564
565 pub fn config(&self) -> &RequestConfig {
567 &self.config
568 }
569
570 pub async fn get<T: DeserializeOwned>(&self, url: &str) -> WaeResult<T> {
572 let response = self.client.get(url).await.map_err(WaeError::from)?;
573 response.json().map_err(WaeError::from)
574 }
575
576 pub async fn get_raw(&self, url: &str) -> WaeResult<HttpResponse> {
578 self.client.get(url).await.map_err(WaeError::from)
579 }
580
581 pub async fn post<T: DeserializeOwned, B: Serialize>(&self, url: &str, body: &B) -> WaeResult<T> {
583 let response = self.client.post_json(url, body).await.map_err(WaeError::from)?;
584 response.json().map_err(WaeError::from)
585 }
586
587 pub async fn post_raw<B: Serialize>(&self, url: &str, body: &B) -> WaeResult<HttpResponse> {
589 self.client.post_json(url, body).await.map_err(WaeError::from)
590 }
591
592 pub fn builder(&self) -> RequestBuilder {
594 RequestBuilder::new(self.client.clone(), "GET", "")
595 }
596}