1use crate::error::Error;
4use async_trait::async_trait;
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::Duration;
9use url::Url;
10
/// HTTP request methods that an [`HttpRequest`] can carry.
///
/// The `Display` implementation in this file renders each variant as its
/// upper-case wire name (for example `Get` is written as `GET`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HttpMethod {
    Get,
    Post,
    Put,
    Delete,
    Head,
    Options,
    Patch,
}
22
23impl std::fmt::Display for HttpMethod {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 HttpMethod::Get => write!(f, "GET"),
27 HttpMethod::Post => write!(f, "POST"),
28 HttpMethod::Put => write!(f, "PUT"),
29 HttpMethod::Delete => write!(f, "DELETE"),
30 HttpMethod::Head => write!(f, "HEAD"),
31 HttpMethod::Options => write!(f, "OPTIONS"),
32 HttpMethod::Patch => write!(f, "PATCH"),
33 }
34 }
35}
36
/// Description of a single outgoing HTTP request.
///
/// Values are normally produced with [`HttpRequest::new`] and the chained
/// builder-style methods on this type, then handed to an `HttpClient`.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// Request method.
    pub method: HttpMethod,
    /// Target URL.
    pub url: Url,
    /// Extra request headers, keyed by header name exactly as inserted.
    pub headers: HashMap<String, String>,
    /// Optional request body.
    pub body: Option<Bytes>,
    /// Per-request timeout; `HttpRequest::new` sets this to 30 seconds.
    pub timeout: Option<Duration>,
    /// Whether redirects should be followed.
    // NOTE(review): no client implementation visible in this file reads this
    // field - confirm where (or whether) redirects are handled.
    pub follow_redirects: bool,
    /// Upper bound on the number of redirects to follow.
    // NOTE(review): not read anywhere in this file either.
    pub max_redirects: u32,
    /// Value for the `User-Agent` header; `HttpRequest::new` fills in
    /// `default_user_agent()`.
    pub user_agent: Option<String>,
}
49
50impl HttpRequest {
51 pub fn new(method: HttpMethod, url: Url) -> Self {
52 Self {
53 method,
54 url,
55 headers: HashMap::new(),
56 body: None,
57 timeout: Some(Duration::from_secs(30)),
58 follow_redirects: true,
59 max_redirects: 10,
60 user_agent: Some(default_user_agent()),
61 }
62 }
63
64 pub fn get(url: Url) -> Self {
65 Self::new(HttpMethod::Get, url)
66 }
67
68 pub fn post(url: Url) -> Self {
69 Self::new(HttpMethod::Post, url)
70 }
71
72 pub fn header(mut self, key: String, value: String) -> Self {
73 self.headers.insert(key, value);
74 self
75 }
76
77 pub fn body(mut self, body: impl Into<Bytes>) -> Self {
78 self.body = Some(body.into());
79 self
80 }
81
82 pub fn json<T: Serialize>(mut self, data: &T) -> Result<Self, Error> {
83 let json = serde_json::to_vec(data)
84 .map_err(|e| Error::Other(anyhow::anyhow!("JSON serialization failed: {}", e)))?;
85 self.body = Some(json.into());
86 self.headers.insert("Content-Type".to_string(), "application/json".to_string());
87 Ok(self)
88 }
89
90 pub fn form(mut self, data: &HashMap<String, String>) -> Self {
91 let form_data = data
92 .iter()
93 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
94 .collect::<Vec<_>>()
95 .join("&");
96 self.body = Some(form_data.into_bytes().into());
97 self.headers.insert("Content-Type".to_string(), "application/x-www-form-urlencoded".to_string());
98 self
99 }
100
101 pub fn timeout(mut self, duration: Duration) -> Self {
102 self.timeout = Some(duration);
103 self
104 }
105
106 pub fn user_agent(mut self, ua: String) -> Self {
107 self.user_agent = Some(ua);
108 self
109 }
110
111 pub fn no_redirects(mut self) -> Self {
112 self.follow_redirects = false;
113 self
114 }
115}
116
/// A fully buffered HTTP response.
#[derive(Debug)]
pub struct HttpResponse {
    /// Numeric status code.
    pub status: u16,
    /// Response headers, one value per header name.
    pub headers: HashMap<String, String>,
    /// Complete response body.
    pub body: Bytes,
    /// URL associated with the response; the client implementations in this
    /// file store the URL of the request that was sent.
    pub url: Url,
}
125
126impl HttpResponse {
127 pub fn status(&self) -> u16 {
128 self.status
129 }
130
131 pub fn is_success(&self) -> bool {
132 self.status >= 200 && self.status < 300
133 }
134
135 pub fn is_redirect(&self) -> bool {
136 self.status >= 300 && self.status < 400
137 }
138
139 pub fn header(&self, name: &str) -> Option<&String> {
140 self.headers.get(name)
141 }
142
143 pub fn content_type(&self) -> Option<&String> {
144 self.header("content-type").or_else(|| self.header("Content-Type"))
145 }
146
147 pub fn content_length(&self) -> Option<usize> {
148 self.header("content-length")
149 .or_else(|| self.header("Content-Length"))
150 .and_then(|s| s.parse().ok())
151 }
152
153 pub fn body(&self) -> &Bytes {
154 &self.body
155 }
156
157 pub fn text(&self) -> Result<String, Error> {
158 String::from_utf8(self.body.to_vec())
159 .map_err(|e| Error::Other(anyhow::anyhow!("Invalid UTF-8: {}", e)))
160 }
161
162 pub fn json<T: for<'de> Deserialize<'de>>(&self) -> Result<T, Error> {
163 serde_json::from_slice(&self.body)
164 .map_err(|e| Error::Other(anyhow::anyhow!("JSON deserialization failed: {}", e)))
165 }
166}
167
/// Hook that can inspect or modify requests before they are sent and
/// responses after they are received.
///
/// The client implementations in this file call every registered interceptor
/// in registration order and propagate the first error with `?`.
#[async_trait]
pub trait HttpInterceptor: Send + Sync {
    /// Called with the outgoing request before it is sent.
    async fn before_request(&self, request: &mut HttpRequest) -> Result<(), Error>;

    /// Called with the response after its body has been read.
    async fn after_response(&self, response: &mut HttpResponse) -> Result<(), Error>;
}
177
/// Interceptor that spaces requests out in time.
pub struct RateLimiter {
    /// Target rate; the minimum gap between requests is
    /// `1 / requests_per_second` seconds.
    requests_per_second: f64,
    /// Timestamp recorded for the most recent request, `None` until the first
    /// request passes through.
    last_request: std::sync::Arc<parking_lot::Mutex<Option<std::time::Instant>>>,
}
183
184impl RateLimiter {
185 pub fn new(requests_per_second: f64) -> Self {
186 Self {
187 requests_per_second,
188 last_request: std::sync::Arc::new(parking_lot::Mutex::new(None)),
189 }
190 }
191}
192
193#[async_trait]
194impl HttpInterceptor for RateLimiter {
195 async fn before_request(&self, _request: &mut HttpRequest) -> Result<(), Error> {
196 let sleep_duration = {
197 let mut last = self.last_request.lock();
198 if let Some(last_time) = *last {
199 let min_interval = Duration::from_secs_f64(1.0 / self.requests_per_second);
200 let elapsed = last_time.elapsed();
201 if elapsed < min_interval {
202 Some(min_interval - elapsed)
203 } else {
204 None
205 }
206 } else {
207 None
208 }
209 };
210
211 if let Some(duration) = sleep_duration {
212 tokio::time::sleep(duration).await;
213 }
214
215 {
216 let mut last = self.last_request.lock();
217 *last = Some(std::time::Instant::now());
218 }
219
220 Ok(())
221 }
222
223 async fn after_response(&self, _response: &mut HttpResponse) -> Result<(), Error> {
224 Ok(())
225 }
226}
227
/// Interceptor that supplies a user agent for requests that have none.
pub struct UserAgentInterceptor {
    /// User agent string applied when the request's `user_agent` is `None`.
    user_agent: String,
}
232
233impl UserAgentInterceptor {
234 pub fn new(user_agent: String) -> Self {
235 Self { user_agent }
236 }
237}
238
239#[async_trait]
240impl HttpInterceptor for UserAgentInterceptor {
241 async fn before_request(&self, request: &mut HttpRequest) -> Result<(), Error> {
242 if request.user_agent.is_none() {
243 request.user_agent = Some(self.user_agent.clone());
244 }
245 Ok(())
246 }
247
248 async fn after_response(&self, _response: &mut HttpResponse) -> Result<(), Error> {
249 Ok(())
250 }
251}
252
/// Shared, in-memory cookie store.
///
/// Cloning the jar clones the `Arc`, so all clones operate on the same
/// underlying map.
#[derive(Debug, Clone)]
pub struct CookieJar {
    /// Stored cookies behind a read/write lock; `add_cookie` keys entries by
    /// cookie name.
    cookies: std::sync::Arc<parking_lot::RwLock<HashMap<String, cookie::Cookie<'static>>>>,
}
258
259impl CookieJar {
260 pub fn new() -> Self {
261 Self {
262 cookies: std::sync::Arc::new(parking_lot::RwLock::new(HashMap::new())),
263 }
264 }
265
266 pub fn add_cookie(&self, cookie: cookie::Cookie<'static>) {
267 let mut cookies = self.cookies.write();
268 cookies.insert(cookie.name().to_string(), cookie);
269 }
270
271 pub fn get_cookies_for_url(&self, url: &Url) -> Vec<cookie::Cookie<'static>> {
272 let cookies = self.cookies.read();
273 cookies
274 .values()
275 .filter(|cookie| {
276 if let Some(domain) = cookie.domain() {
278 if let Some(host) = url.host_str() {
279 if !host.ends_with(domain) && host != domain {
280 return false;
281 }
282 }
283 }
284 if let Some(path) = cookie.path() {
285 if !url.path().starts_with(path) {
286 return false;
287 }
288 }
289 true
290 })
291 .cloned()
292 .collect()
293 }
294
295 pub fn cookie_header_for_url(&self, url: &Url) -> Option<String> {
296 let cookies = self.get_cookies_for_url(url);
297 if cookies.is_empty() {
298 None
299 } else {
300 Some(
301 cookies
302 .iter()
303 .map(|c| format!("{}={}", c.name(), c.value()))
304 .collect::<Vec<_>>()
305 .join("; ")
306 )
307 }
308 }
309}
310
/// Interceptor that attaches cookies from a [`CookieJar`] to requests and
/// records cookies set by responses.
pub struct CookieInterceptor {
    /// Jar read in `before_request` and written in `after_response`.
    jar: CookieJar,
}
315
316impl CookieInterceptor {
317 pub fn new(jar: CookieJar) -> Self {
318 Self { jar }
319 }
320}
321
322#[async_trait]
323impl HttpInterceptor for CookieInterceptor {
324 async fn before_request(&self, request: &mut HttpRequest) -> Result<(), Error> {
325 if let Some(cookie_header) = self.jar.cookie_header_for_url(&request.url) {
326 request.headers.insert("Cookie".to_string(), cookie_header);
327 }
328 Ok(())
329 }
330
331 async fn after_response(&self, response: &mut HttpResponse) -> Result<(), Error> {
332 for (name, value) in &response.headers {
334 if name.to_lowercase() == "set-cookie" {
335 if let Ok(cookie) = cookie::Cookie::parse(value.clone()) {
336 self.jar.add_cookie(cookie.into_owned());
337 }
338 }
339 }
340 Ok(())
341 }
342}
343
/// Abstraction over the platform-specific HTTP clients defined in this file.
#[async_trait]
pub trait HttpClient: Send + Sync {
    /// Sends `request` and returns the buffered response, or an error if the
    /// request could not be completed.
    async fn execute(&self, request: HttpRequest) -> Result<HttpResponse, Error>;
}
349
/// Builder that assembles an [`HttpClient`] with its interceptors.
pub struct HttpClientBuilder {
    /// Interceptors registered by the caller, in registration order.
    interceptors: Vec<Box<dyn HttpInterceptor>>,
    /// When set, `build` adds a `CookieInterceptor` for this jar.
    cookie_jar: Option<CookieJar>,
    /// When set, `build` adds a `RateLimiter` with this rate (requests/second).
    rate_limit: Option<f64>,
    /// When set, `build` adds a `UserAgentInterceptor` with this value.
    default_user_agent: Option<String>,
    /// Timeout handed to the client constructor; 30 seconds by default.
    default_timeout: Option<Duration>,
    /// Whether TLS certificates are verified (native client only).
    verify_ssl: bool,
    /// Optional proxy URL (native client only).
    proxy: Option<String>,
}
360
361impl HttpClientBuilder {
362 pub fn new() -> Self {
363 Self {
364 interceptors: Vec::new(),
365 cookie_jar: None,
366 rate_limit: None,
367 default_user_agent: None,
368 default_timeout: Some(Duration::from_secs(30)),
369 verify_ssl: true,
370 proxy: None,
371 }
372 }
373
374 pub fn interceptor(mut self, interceptor: Box<dyn HttpInterceptor>) -> Self {
375 self.interceptors.push(interceptor);
376 self
377 }
378
379 pub fn cookie_jar(mut self, jar: CookieJar) -> Self {
380 self.cookie_jar = Some(jar);
381 self
382 }
383
384 pub fn rate_limit(mut self, requests_per_second: f64) -> Self {
385 self.rate_limit = Some(requests_per_second);
386 self
387 }
388
389 pub fn user_agent(mut self, ua: String) -> Self {
390 self.default_user_agent = Some(ua);
391 self
392 }
393
394 pub fn timeout(mut self, duration: Duration) -> Self {
395 self.default_timeout = Some(duration);
396 self
397 }
398
399 pub fn verify_ssl(mut self, verify: bool) -> Self {
400 self.verify_ssl = verify;
401 self
402 }
403
404 pub fn proxy(mut self, proxy_url: String) -> Self {
405 self.proxy = Some(proxy_url);
406 self
407 }
408
409 pub fn build(mut self) -> Result<Box<dyn HttpClient>, Error> {
410 if let Some(rate) = self.rate_limit {
412 self.interceptors.push(Box::new(RateLimiter::new(rate)));
413 }
414
415 if let Some(ua) = self.default_user_agent {
416 self.interceptors.push(Box::new(UserAgentInterceptor::new(ua)));
417 }
418
419 if let Some(jar) = self.cookie_jar {
420 self.interceptors.push(Box::new(CookieInterceptor::new(jar)));
421 }
422
423 cfg_if::cfg_if! {
424 if #[cfg(target_arch = "wasm32")] {
425 Ok(Box::new(WasmHttpClient::new(self.interceptors, self.default_timeout)?))
426 } else {
427 Ok(Box::new(NativeHttpClient::new(
428 self.interceptors,
429 self.default_timeout,
430 self.verify_ssl,
431 self.proxy,
432 )?))
433 }
434 }
435 }
436}
437
438impl Default for HttpClientBuilder {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
/// [`HttpClient`] implementation for non-wasm targets, backed by `reqwest`.
#[cfg(not(target_arch = "wasm32"))]
pub struct NativeHttpClient {
    /// Underlying `reqwest` client.
    client: reqwest::Client,
    /// Interceptors run around every request, in order.
    interceptors: Vec<Box<dyn HttpInterceptor>>,
}
450
451#[cfg(not(target_arch = "wasm32"))]
452impl NativeHttpClient {
453 pub fn new(
454 interceptors: Vec<Box<dyn HttpInterceptor>>,
455 default_timeout: Option<Duration>,
456 verify_ssl: bool,
457 proxy: Option<String>,
458 ) -> Result<Self, Error> {
459 let mut builder = reqwest::Client::builder()
460 .danger_accept_invalid_certs(!verify_ssl)
461 .redirect(reqwest::redirect::Policy::none());
462
463 if let Some(timeout) = default_timeout {
464 builder = builder.timeout(timeout);
465 }
466
467 if let Some(proxy_url) = proxy {
468 let proxy = reqwest::Proxy::all(&proxy_url)
469 .map_err(|e| Error::Other(anyhow::anyhow!("Invalid proxy URL: {}", e)))?;
470 builder = builder.proxy(proxy);
471 }
472
473 let client = builder
474 .build()
475 .map_err(|e| Error::Other(anyhow::anyhow!("Failed to create HTTP client: {}", e)))?;
476
477 Ok(Self { client, interceptors })
478 }
479}
480
481#[cfg(not(target_arch = "wasm32"))]
482#[async_trait]
483impl HttpClient for NativeHttpClient {
484 async fn execute(&self, mut request: HttpRequest) -> Result<HttpResponse, Error> {
485 for interceptor in &self.interceptors {
487 interceptor.before_request(&mut request).await?;
488 }
489
490 let method = match request.method {
491 HttpMethod::Get => reqwest::Method::GET,
492 HttpMethod::Post => reqwest::Method::POST,
493 HttpMethod::Put => reqwest::Method::PUT,
494 HttpMethod::Delete => reqwest::Method::DELETE,
495 HttpMethod::Head => reqwest::Method::HEAD,
496 HttpMethod::Options => reqwest::Method::OPTIONS,
497 HttpMethod::Patch => reqwest::Method::PATCH,
498 };
499
500 let mut req_builder = self.client.request(method, request.url.clone());
501
502 for (key, value) in &request.headers {
504 req_builder = req_builder.header(key, value);
505 }
506
507 if let Some(ua) = &request.user_agent {
509 req_builder = req_builder.header("User-Agent", ua);
510 }
511
512 if let Some(body) = request.body {
514 req_builder = req_builder.body(body);
515 }
516
517 if let Some(timeout) = request.timeout {
519 req_builder = req_builder.timeout(timeout);
520 }
521
522 let response = req_builder
523 .send()
524 .await
525 .map_err(|e| Error::Other(anyhow::anyhow!("HTTP request failed: {}", e)))?;
526
527 let status = response.status().as_u16();
528 let headers = response
529 .headers()
530 .iter()
531 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
532 .collect();
533
534 let body = response
535 .bytes()
536 .await
537 .map_err(|e| Error::Other(anyhow::anyhow!("Failed to read response body: {}", e)))?;
538
539 let mut http_response = HttpResponse {
540 status,
541 headers,
542 body,
543 url: request.url,
544 };
545
546 for interceptor in &self.interceptors {
548 interceptor.after_response(&mut http_response).await?;
549 }
550
551 Ok(http_response)
552 }
553}
554
/// [`HttpClient`] implementation for `wasm32` targets, backed by the browser
/// `fetch` API.
#[cfg(target_arch = "wasm32")]
pub struct WasmHttpClient {
    /// Interceptors run around every request, in order.
    interceptors: Vec<Box<dyn HttpInterceptor>>,
    /// Default timeout supplied by the builder.
    // NOTE(review): stored but not read by `execute` in this file.
    default_timeout: Option<Duration>,
}
561
#[cfg(target_arch = "wasm32")]
impl WasmHttpClient {
    /// Stores the interceptors and default timeout; this constructor never
    /// returns an error.
    pub fn new(
        interceptors: Vec<Box<dyn HttpInterceptor>>,
        default_timeout: Option<Duration>,
    ) -> Result<Self, Error> {
        let client = WasmHttpClient {
            interceptors,
            default_timeout,
        };
        Ok(client)
    }
}
574
#[cfg(target_arch = "wasm32")]
#[async_trait]
impl HttpClient for WasmHttpClient {
    /// Sends `request` through the browser `fetch` API and buffers the whole
    /// response.
    ///
    /// Steps: run every interceptor's `before_request`, build a
    /// `web_sys::Request`, call `window.fetch`, copy headers and body out of
    /// the JS response, then run every interceptor's `after_response`.
    ///
    /// Neither `request.timeout` nor the client's default timeout is applied
    /// in this method.
    ///
    /// # Errors
    ///
    /// Returns the first interceptor error, or `Error::Other` when there is
    /// no global `window`, the request cannot be built, the fetch fails, or
    /// the response cannot be read.
    // NOTE(review): `JsFuture` is not `Send`; confirm this impl satisfies the
    // `Send` bound that `#[async_trait]` applies by default.
    async fn execute(&self, mut request: HttpRequest) -> Result<HttpResponse, Error> {
        use wasm_bindgen::prelude::*;
        use wasm_bindgen_futures::JsFuture;
        use web_sys::{Request, RequestInit, Response};

        for interceptor in &self.interceptors {
            interceptor.before_request(&mut request).await?;
        }

        let mut opts = RequestInit::new();
        opts.method(&request.method.to_string());

        if let Some(body) = request.body {
            // `usize` is 32 bits wide on wasm32, so this cast cannot truncate.
            let uint8_array = js_sys::Uint8Array::new_with_length(body.len() as u32);
            uint8_array.copy_from(&body);
            opts.body(Some(&uint8_array));
        }

        let headers = web_sys::Headers::new()
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to create headers")))?;

        for (key, value) in &request.headers {
            headers
                .set(key, value)
                .map_err(|_| Error::Other(anyhow::anyhow!("Failed to set header: {}", key)))?;
        }

        if let Some(ua) = &request.user_agent {
            headers
                .set("User-Agent", ua)
                .map_err(|_| Error::Other(anyhow::anyhow!("Failed to set User-Agent")))?;
        }

        opts.headers(&headers);

        let req = Request::new_with_str_and_init(&request.url.to_string(), &opts)
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to create request")))?;

        // `window()` is `None` outside a window context; report that as an
        // error instead of panicking.
        let window = web_sys::window()
            .ok_or_else(|| Error::Other(anyhow::anyhow!("No global window available")))?;
        let resp_value = JsFuture::from(window.fetch_with_request(&req))
            .await
            .map_err(|_| Error::Other(anyhow::anyhow!("Fetch failed")))?;

        let resp: Response = resp_value
            .dyn_into()
            .map_err(|_| Error::Other(anyhow::anyhow!("Invalid response")))?;

        let status = resp.status() as u16;

        let mut response_headers = HashMap::new();
        let headers_iter = js_sys::try_iter(&resp.headers())
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to iterate headers")))?
            .ok_or_else(|| Error::Other(anyhow::anyhow!("Headers not iterable")))?;

        for item in headers_iter {
            let item = item.map_err(|_| Error::Other(anyhow::anyhow!("Header iteration error")))?;
            // Each item is treated as a `[name, value]` pair.
            let entry = js_sys::Array::from(&item);
            let key = entry.get(0).as_string().unwrap_or_default();
            let value = entry.get(1).as_string().unwrap_or_default();
            response_headers.insert(key, value);
        }

        let array_buffer = JsFuture::from(resp.array_buffer())
            .await
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to read response body")))?;

        let uint8_array = js_sys::Uint8Array::new(&array_buffer);
        let body = uint8_array.to_vec().into();

        let mut http_response = HttpResponse {
            status,
            headers: response_headers,
            body,
            url: request.url,
        };

        for interceptor in &self.interceptors {
            interceptor.after_response(&mut http_response).await?;
        }

        Ok(http_response)
    }
}
667
/// Returns the user agent string that `HttpRequest::new` assigns to new
/// requests.
pub fn default_user_agent() -> String {
    String::from(
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 CodeMesh/1.0",
    )
}
672
673pub fn is_safe_url(url: &Url) -> bool {
675 if !matches!(url.scheme(), "http" | "https") {
677 return false;
678 }
679
680 if let Some(host) = url.host() {
682 match host {
683 url::Host::Ipv4(ip) => {
684 if ip.is_private() || ip.is_loopback() || ip.is_link_local() {
685 return false;
686 }
687 }
688 url::Host::Ipv6(ip) => {
689 if ip.is_loopback() || ip.is_unspecified() {
690 return false;
691 }
692 }
693 url::Host::Domain(domain) => {
694 if domain == "localhost" || domain.ends_with(".local") || domain.ends_with(".internal") {
696 return false;
697 }
698 }
699 }
700 }
701
702 true
703}
704
705pub fn sanitize_url(url_str: &str) -> Result<Url, Error> {
707 let url = Url::parse(url_str)
708 .map_err(|e| Error::Other(anyhow::anyhow!("Invalid URL: {}", e)))?;
709
710 if !is_safe_url(&url) {
711 return Err(Error::Other(anyhow::anyhow!("URL not allowed for security reasons")));
712 }
713
714 Ok(url)
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    /// Public http(s) hosts are accepted; loopback, private, link-local and
    /// local-only hosts as well as non-http schemes are rejected.
    #[test]
    fn test_url_safety() {
        assert!(is_safe_url(&Url::parse("https://example.com").unwrap()));
        assert!(is_safe_url(&Url::parse("http://google.com").unwrap()));

        assert!(!is_safe_url(&Url::parse("http://127.0.0.1").unwrap()));
        assert!(!is_safe_url(&Url::parse("http://localhost").unwrap()));
        assert!(!is_safe_url(&Url::parse("http://192.168.1.1").unwrap()));
        assert!(!is_safe_url(&Url::parse("file:///etc/passwd").unwrap()));

        assert!(!is_safe_url(&Url::parse("http://10.0.0.5").unwrap()));
        assert!(!is_safe_url(
            &Url::parse("http://169.254.169.254/latest/meta-data").unwrap()
        ));
        assert!(!is_safe_url(&Url::parse("http://[::1]/").unwrap()));
        assert!(!is_safe_url(&Url::parse("http://printer.local").unwrap()));
        assert!(!is_safe_url(&Url::parse("ftp://example.com").unwrap()));
    }

    /// A cookie scoped to `example.com` is sent to that host and not to an
    /// unrelated one.
    #[test]
    fn test_cookie_jar() {
        let jar = CookieJar::new();
        // `CookieBuilder::finish` is deprecated; `build` is its replacement.
        let cookie = cookie::Cookie::build(("session", "abc123"))
            .domain("example.com")
            .path("/")
            .build();

        jar.add_cookie(cookie.into_owned());

        let url = Url::parse("https://example.com/test").unwrap();
        let header = jar.cookie_header_for_url(&url);
        assert_eq!(header, Some("session=abc123".to_string()));

        let unrelated = Url::parse("https://other.org/test").unwrap();
        assert_eq!(jar.cookie_header_for_url(&unrelated), None);
    }
}