1use super::common::{
17 Headers, Method, ProtocolError, ProtocolResult, RetryConfig, StatusCode, Timeout, Uri,
18};
19use std::collections::HashMap;
20use std::time::Duration;
21
/// Settings used to construct a [`Client`].
///
/// The `Default` impl supplies 30 s connect/read timeouts, up to 10 followed redirects,
/// compression enabled and a `sigil-http/<crate version>` user agent.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Prefix joined onto request URLs that do not start with `http://` or `https://`.
    pub base_url: Option<String>,
    /// Headers copied into every request created by the client.
    pub default_headers: Headers,
    /// Connect, read and total timeouts. Each one that is set is passed to the reqwest
    /// builder; the total timeout is also copied onto every request.
    pub timeout: Timeout,
    /// Retry policy. NOTE(review): not consulted by `RequestBuilder::send` in this file.
    pub retry: Option<RetryConfig>,
    /// Maximum number of redirects followed when `follow_redirects` is true.
    pub max_redirects: u32,
    /// Whether redirects are followed at all; `false` selects reqwest's `Policy::none()`.
    pub follow_redirects: bool,
    /// Value used for the `User-Agent` header.
    pub user_agent: String,
    /// When true, gzip and brotli decompression are enabled on the reqwest builder.
    pub accept_compressed: bool,
    /// Idle timeout for pooled connections; `None` leaves reqwest's default in place.
    pub pool_idle_timeout: Option<Duration>,
    /// Presumably the cap on idle pooled connections per host (the name mirrors reqwest's
    /// `pool_max_idle_per_host`) — TODO confirm it is forwarded by `Client::with_config`.
    pub pool_max_idle_per_host: usize,
}
46
47impl Default for ClientConfig {
48 fn default() -> Self {
49 ClientConfig {
50 base_url: None,
51 default_headers: Headers::new(),
52 timeout: Timeout::new()
53 .connect_timeout(Duration::from_secs(30))
54 .read_timeout(Duration::from_secs(30)),
55 retry: None,
56 max_redirects: 10,
57 follow_redirects: true,
58 user_agent: format!("sigil-http/{}", env!("CARGO_PKG_VERSION")),
59 accept_compressed: true,
60 pool_idle_timeout: Some(Duration::from_secs(90)),
61 pool_max_idle_per_host: 10,
62 }
63 }
64}
65
/// HTTP client handle: the configuration it was created from plus, when the `reqwest`
/// feature is enabled, the underlying `reqwest::Client`.
///
/// Builder-style methods (`base_url`, `bearer_auth`, ...) consume and return the client,
/// so their return value must be kept.
#[derive(Debug, Clone)]
pub struct Client {
    /// Settings given to `Client::with_config`, as modified by the builder methods.
    config: ClientConfig,
    /// Underlying reqwest client; `None` when building it inside `Client::with_config`
    /// failed (the build error is discarded via `.ok()`).
    #[cfg(feature = "reqwest")]
    inner: Option<reqwest::Client>,
}
73
74impl Client {
75 pub fn new() -> Self {
77 Client::with_config(ClientConfig::default())
78 }
79
80 pub fn with_config(config: ClientConfig) -> Self {
82 #[cfg(feature = "reqwest")]
83 let inner = {
84 let mut builder = reqwest::Client::builder()
85 .user_agent(&config.user_agent)
86 .redirect(if config.follow_redirects {
87 reqwest::redirect::Policy::limited(config.max_redirects as usize)
88 } else {
89 reqwest::redirect::Policy::none()
90 });
91
92 if let Some(timeout) = config.timeout.connect {
93 builder = builder.connect_timeout(timeout);
94 }
95 if let Some(timeout) = config.timeout.read {
96 builder = builder.read_timeout(timeout);
97 }
98 if let Some(timeout) = config.timeout.total {
99 builder = builder.timeout(timeout);
100 }
101 if let Some(idle) = config.pool_idle_timeout {
102 builder = builder.pool_idle_timeout(idle);
103 }
104
105 if config.accept_compressed {
106 builder = builder.gzip(true).brotli(true);
107 }
108
109 builder.build().ok()
110 };
111
112 #[cfg(not(feature = "reqwest"))]
113 let inner = None;
114
115 Client {
116 config,
117 #[cfg(feature = "reqwest")]
118 inner,
119 }
120 }
121
122 pub fn base_url(mut self, url: impl Into<String>) -> Self {
124 self.config.base_url = Some(url.into());
125 self
126 }
127
128 pub fn default_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
130 self.config.default_headers.insert(key, value);
131 self
132 }
133
134 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
136 self.config.timeout = self.config.timeout.connect_timeout(timeout);
137 self
138 }
139
140 pub fn read_timeout(mut self, timeout: Duration) -> Self {
142 self.config.timeout = self.config.timeout.read_timeout(timeout);
143 self
144 }
145
146 pub fn timeout(mut self, timeout: Duration) -> Self {
148 self.config.timeout = self.config.timeout.total_timeout(timeout);
149 self
150 }
151
152 pub fn retry(mut self, config: RetryConfig) -> Self {
154 self.config.retry = Some(config);
155 self
156 }
157
158 pub fn user_agent(mut self, ua: impl Into<String>) -> Self {
160 self.config.user_agent = ua.into();
161 self
162 }
163
164 pub fn bearer_auth(self, token: impl Into<String>) -> Self {
166 self.default_header("Authorization", format!("Bearer {}", token.into()))
167 }
168
169 pub fn basic_auth(self, username: impl Into<String>, password: impl Into<String>) -> Self {
171 use base64::{engine::general_purpose::STANDARD, Engine};
172 let credentials = format!("{}:{}", username.into(), password.into());
173 let encoded = STANDARD.encode(credentials.as_bytes());
174 self.default_header("Authorization", format!("Basic {}", encoded))
175 }
176
177 pub fn get(&self, url: impl Into<String>) -> RequestBuilder {
179 self.request(Method::GET, url)
180 }
181
182 pub fn post(&self, url: impl Into<String>) -> RequestBuilder {
184 self.request(Method::POST, url)
185 }
186
187 pub fn put(&self, url: impl Into<String>) -> RequestBuilder {
189 self.request(Method::PUT, url)
190 }
191
192 pub fn delete(&self, url: impl Into<String>) -> RequestBuilder {
194 self.request(Method::DELETE, url)
195 }
196
197 pub fn patch(&self, url: impl Into<String>) -> RequestBuilder {
199 self.request(Method::PATCH, url)
200 }
201
202 pub fn head(&self, url: impl Into<String>) -> RequestBuilder {
204 self.request(Method::HEAD, url)
205 }
206
207 pub fn request(&self, method: Method, url: impl Into<String>) -> RequestBuilder {
209 let url_str = url.into();
210 let full_url = if let Some(ref base) = self.config.base_url {
211 if url_str.starts_with("http://") || url_str.starts_with("https://") {
212 url_str
213 } else {
214 format!(
215 "{}{}",
216 base.trim_end_matches('/'),
217 if url_str.starts_with('/') {
218 url_str
219 } else {
220 format!("/{}", url_str)
221 }
222 )
223 }
224 } else {
225 url_str
226 };
227
228 RequestBuilder {
229 client: self.clone(),
230 method,
231 url: full_url,
232 headers: self.config.default_headers.clone(),
233 query: Vec::new(),
234 body: None,
235 timeout: self.config.timeout.total,
236 }
237 }
238
239 fn resolve_url(&self, url: &str) -> String {
241 if let Some(ref base) = self.config.base_url {
242 if url.starts_with("http://") || url.starts_with("https://") {
243 url.to_string()
244 } else {
245 format!(
246 "{}{}",
247 base.trim_end_matches('/'),
248 if url.starts_with('/') {
249 url.to_string()
250 } else {
251 format!("/{}", url)
252 }
253 )
254 }
255 } else {
256 url.to_string()
257 }
258 }
259}
260
261impl Default for Client {
262 fn default() -> Self {
263 Client::new()
264 }
265}
266
267#[derive(Debug, Clone)]
269pub struct RequestBuilder {
270 client: Client,
271 method: Method,
272 url: String,
273 headers: Headers,
274 query: Vec<(String, String)>,
275 body: Option<Body>,
276 timeout: Option<Duration>,
277}
278
279impl RequestBuilder {
280 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
282 self.headers.insert(key, value);
283 self
284 }
285
286 pub fn query(mut self, params: &[(&str, &str)]) -> Self {
288 for (k, v) in params {
289 self.query.push((k.to_string(), v.to_string()));
290 }
291 self
292 }
293
294 pub fn query_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
296 self.query.push((key.into(), value.into()));
297 self
298 }
299
300 pub fn json<T: serde::Serialize>(mut self, value: &T) -> ProtocolResult<Self> {
302 let json = serde_json::to_string(value)
303 .map_err(|e| ProtocolError::Serialization(e.to_string()))?;
304 self.headers.set("Content-Type", "application/json");
305 self.body = Some(Body::Text(json));
306 Ok(self)
307 }
308
309 pub fn body(mut self, body: impl Into<String>) -> Self {
311 self.body = Some(Body::Text(body.into()));
312 self
313 }
314
315 pub fn body_bytes(mut self, bytes: Vec<u8>) -> Self {
317 self.body = Some(Body::Bytes(bytes));
318 self
319 }
320
321 pub fn form(mut self, data: &[(&str, &str)]) -> Self {
323 let encoded: String = data
324 .iter()
325 .map(|(k, v)| format!("{}={}", urlencoded(k), urlencoded(v)))
326 .collect::<Vec<_>>()
327 .join("&");
328 self.headers
329 .set("Content-Type", "application/x-www-form-urlencoded");
330 self.body = Some(Body::Text(encoded));
331 self
332 }
333
334 pub fn timeout(mut self, timeout: Duration) -> Self {
336 self.timeout = Some(timeout);
337 self
338 }
339
340 pub fn build(self) -> Request {
342 let mut url = self.url;
343 if !self.query.is_empty() {
344 let query_string: String = self
345 .query
346 .iter()
347 .map(|(k, v)| format!("{}={}", urlencoded(k), urlencoded(v)))
348 .collect::<Vec<_>>()
349 .join("&");
350 if url.contains('?') {
351 url = format!("{}&{}", url, query_string);
352 } else {
353 url = format!("{}?{}", url, query_string);
354 }
355 }
356
357 Request {
358 method: self.method,
359 url,
360 headers: self.headers,
361 body: self.body,
362 timeout: self.timeout,
363 }
364 }
365
366 #[cfg(feature = "reqwest")]
368 pub async fn send(self) -> ProtocolResult<Response> {
369 let client = self.client.clone();
370 let request = self.build();
371
372 if let Some(ref inner) = client.inner {
373 let mut req_builder = match request.method {
374 Method::GET => inner.get(&request.url),
375 Method::POST => inner.post(&request.url),
376 Method::PUT => inner.put(&request.url),
377 Method::DELETE => inner.delete(&request.url),
378 Method::PATCH => inner.patch(&request.url),
379 Method::HEAD => inner.head(&request.url),
380 _ => {
381 return Err(ProtocolError::Protocol(format!(
382 "Unsupported method: {:?}",
383 request.method
384 )))
385 }
386 };
387
388 for (key, values) in request.headers.iter() {
390 for value in values {
391 req_builder = req_builder.header(key.as_str(), value.as_str());
392 }
393 }
394
395 if let Some(body) = request.body {
397 req_builder = match body {
398 Body::Text(text) => req_builder.body(text),
399 Body::Bytes(bytes) => req_builder.body(bytes),
400 };
401 }
402
403 if let Some(timeout) = request.timeout {
405 req_builder = req_builder.timeout(timeout);
406 }
407
408 let resp = req_builder.send().await.map_err(|e| {
410 if e.is_timeout() {
411 ProtocolError::RequestTimeout
412 } else if e.is_connect() {
413 ProtocolError::ConnectionFailed(e.to_string())
414 } else {
415 ProtocolError::Protocol(e.to_string())
416 }
417 })?;
418
419 let status = StatusCode::from_u16(resp.status().as_u16());
421 let mut headers = Headers::new();
422 for (key, value) in resp.headers() {
423 if let Ok(v) = value.to_str() {
424 headers.insert(key.as_str(), v);
425 }
426 }
427
428 let bytes = resp
429 .bytes()
430 .await
431 .map_err(|e| ProtocolError::Io(e.to_string()))?;
432
433 Ok(Response {
434 status,
435 headers,
436 body: bytes.to_vec(),
437 })
438 } else {
439 Err(ProtocolError::Protocol(
440 "HTTP client not initialized".to_string(),
441 ))
442 }
443 }
444
445 #[cfg(not(feature = "reqwest"))]
447 pub fn send_sync(self) -> ProtocolResult<Response> {
448 Err(ProtocolError::Protocol(
449 "HTTP client requires 'http-client' feature".to_string(),
450 ))
451 }
452}
453
/// Payload attached to a request.
#[derive(Debug, Clone)]
pub enum Body {
    /// Text payload; produced by `RequestBuilder::body`, `json` and `form`.
    Text(String),
    /// Raw byte payload; produced by `RequestBuilder::body_bytes`.
    Bytes(Vec<u8>),
}
462
/// A fully assembled request, as produced by `RequestBuilder::build`.
#[derive(Debug, Clone)]
pub struct Request {
    /// HTTP method.
    pub method: Method,
    /// URL as resolved by the client (joined onto the base URL when relative and a base
    /// is configured), with any query parameters already percent-encoded and appended.
    pub url: String,
    /// Client default headers combined with headers added on the builder.
    pub headers: Headers,
    /// Optional request body.
    pub body: Option<Body>,
    /// Overall timeout for this request: the client's total timeout unless overridden
    /// via `RequestBuilder::timeout`.
    pub timeout: Option<Duration>,
}
477
/// A fully buffered HTTP response: status, headers and the complete body.
#[derive(Debug, Clone)]
pub struct Response {
    /// Response status code.
    pub status: StatusCode,
    /// Response headers; `RequestBuilder::send` skips values whose `to_str()` fails.
    pub headers: Headers,
    /// Entire response body, read into memory.
    pub body: Vec<u8>,
}
488
489impl Response {
490 pub fn status(&self) -> StatusCode {
492 self.status
493 }
494
495 pub fn headers(&self) -> &Headers {
497 &self.headers
498 }
499
500 pub fn text(&self) -> ProtocolResult<String> {
502 String::from_utf8(self.body.clone())
503 .map_err(|e| ProtocolError::Deserialization(e.to_string()))
504 }
505
506 pub fn json<T: serde::de::DeserializeOwned>(&self) -> ProtocolResult<T> {
508 serde_json::from_slice(&self.body)
509 .map_err(|e| ProtocolError::Deserialization(e.to_string()))
510 }
511
512 pub fn bytes(&self) -> &[u8] {
514 &self.body
515 }
516
517 pub fn into_bytes(self) -> Vec<u8> {
519 self.body
520 }
521
522 pub fn is_success(&self) -> bool {
524 self.status.is_success()
525 }
526
527 pub fn is_client_error(&self) -> bool {
529 self.status.is_client_error()
530 }
531
532 pub fn is_server_error(&self) -> bool {
534 self.status.is_server_error()
535 }
536
537 pub fn error_for_status(self) -> ProtocolResult<Self> {
539 if self.status.is_client_error() {
540 let body = self.text().unwrap_or_default();
541 Err(ProtocolError::ClientError(self.status, body))
542 } else if self.status.is_server_error() {
543 let body = self.text().unwrap_or_default();
544 Err(ProtocolError::ServerError(self.status, body))
545 } else {
546 Ok(self)
547 }
548 }
549
550 pub fn content_type(&self) -> Option<&str> {
552 self.headers.get("content-type")
553 }
554
555 pub fn content_length(&self) -> Option<u64> {
557 self.headers.content_length()
558 }
559}
560
/// Percent-encodes `s` per RFC 3986.
///
/// Unreserved characters (`A-Z a-z 0-9 - _ . ~`) are kept. Every other byte of the UTF-8
/// encoding becomes `%XX` with uppercase hex, so a space becomes `%20`, not `+`.
///
/// This works byte-wise: multi-byte UTF-8 sequences consist only of bytes >= 0x80, which
/// are never unreserved. No per-character allocations are needed.
fn urlencoded(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut result = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            result.push(char::from(byte));
        } else {
            result.push('%');
            result.push(char::from(HEX[usize::from(byte >> 4)]));
            result.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    result
}
579
580pub async fn get(url: impl Into<String>) -> ProtocolResult<Response> {
582 #[cfg(feature = "reqwest")]
583 {
584 Client::new().get(url).send().await
585 }
586 #[cfg(not(feature = "reqwest"))]
587 {
588 let _ = url;
589 Err(ProtocolError::Protocol(
590 "HTTP client requires 'http-client' feature".to_string(),
591 ))
592 }
593}
594
595pub async fn post_json<T: serde::Serialize>(
597 url: impl Into<String>,
598 body: &T,
599) -> ProtocolResult<Response> {
600 #[cfg(feature = "reqwest")]
601 {
602 Client::new().post(url).json(body)?.send().await
603 }
604 #[cfg(not(feature = "reqwest"))]
605 {
606 let _ = (url, body);
607 Err(ProtocolError::Protocol(
608 "HTTP client requires 'http-client' feature".to_string(),
609 ))
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods record base URL, bearer credentials and total timeout.
    #[test]
    fn test_client_builder() {
        let client = Client::new()
            .base_url("https://api.example.com")
            .bearer_auth("token123")
            .timeout(Duration::from_secs(30));

        assert_eq!(
            client.config.base_url,
            Some("https://api.example.com".to_string())
        );
        assert_eq!(
            client.config.default_headers.get("Authorization"),
            Some("Bearer token123")
        );
        assert_eq!(client.config.timeout.total, Some(Duration::from_secs(30)));
    }

    /// Query parameters are encoded in order and headers are carried into the request.
    #[test]
    fn test_request_builder() {
        let client = Client::new().base_url("https://api.example.com");
        let request = client
            .get("/users")
            .query(&[("page", "1"), ("limit", "10")])
            .header("X-Custom", "value")
            .build();

        assert_eq!(request.method, Method::GET);
        assert_eq!(request.url, "https://api.example.com/users?page=1&limit=10");
        assert_eq!(request.headers.get("X-Custom"), Some("value"));
    }

    /// Relative paths join onto the base with exactly one slash; absolute URLs pass through.
    #[test]
    fn test_url_resolution() {
        let client = Client::new().base_url("https://api.example.com/");
        assert_eq!(client.get("users").build().url, "https://api.example.com/users");
        assert_eq!(client.get("/users").build().url, "https://api.example.com/users");
        assert_eq!(
            client.get("https://other.example.com/x").build().url,
            "https://other.example.com/x"
        );
        assert_eq!(
            Client::new().get("https://a.example.com/b").build().url,
            "https://a.example.com/b"
        );
    }

    /// `form` percent-encodes keys and values and sets the form content type.
    #[test]
    fn test_form_body() {
        let request = Client::new()
            .post("http://localhost/")
            .form(&[("a b", "c&d")])
            .build();

        assert_eq!(
            request.headers.get("Content-Type"),
            Some("application/x-www-form-urlencoded")
        );
        match request.body {
            Some(Body::Text(text)) => assert_eq!(text, "a%20b=c%26d"),
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn test_url_encoding() {
        assert_eq!(urlencoded("hello world"), "hello%20world");
        assert_eq!(urlencoded("foo=bar"), "foo%3Dbar");
        assert_eq!(urlencoded("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(urlencoded("\u{e9}"), "%C3%A9");
        assert_eq!(urlencoded(""), "");
    }
}