code_mesh_core/tool/
http.rs

1//! HTTP client abstraction for unified web requests across native and WASM
2
3use crate::error::Error;
4use async_trait::async_trait;
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::Duration;
9use url::Url;
10
11/// HTTP method enumeration
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum HttpMethod {
14    Get,
15    Post,
16    Put,
17    Delete,
18    Head,
19    Options,
20    Patch,
21}
22
23impl std::fmt::Display for HttpMethod {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            HttpMethod::Get => write!(f, "GET"),
27            HttpMethod::Post => write!(f, "POST"),
28            HttpMethod::Put => write!(f, "PUT"),
29            HttpMethod::Delete => write!(f, "DELETE"),
30            HttpMethod::Head => write!(f, "HEAD"),
31            HttpMethod::Options => write!(f, "OPTIONS"),
32            HttpMethod::Patch => write!(f, "PATCH"),
33        }
34    }
35}
36
37/// HTTP request builder
38#[derive(Debug, Clone)]
39pub struct HttpRequest {
40    pub method: HttpMethod,
41    pub url: Url,
42    pub headers: HashMap<String, String>,
43    pub body: Option<Bytes>,
44    pub timeout: Option<Duration>,
45    pub follow_redirects: bool,
46    pub max_redirects: u32,
47    pub user_agent: Option<String>,
48}
49
50impl HttpRequest {
51    pub fn new(method: HttpMethod, url: Url) -> Self {
52        Self {
53            method,
54            url,
55            headers: HashMap::new(),
56            body: None,
57            timeout: Some(Duration::from_secs(30)),
58            follow_redirects: true,
59            max_redirects: 10,
60            user_agent: Some(default_user_agent()),
61        }
62    }
63
64    pub fn get(url: Url) -> Self {
65        Self::new(HttpMethod::Get, url)
66    }
67
68    pub fn post(url: Url) -> Self {
69        Self::new(HttpMethod::Post, url)
70    }
71
72    pub fn header(mut self, key: String, value: String) -> Self {
73        self.headers.insert(key, value);
74        self
75    }
76
77    pub fn body(mut self, body: impl Into<Bytes>) -> Self {
78        self.body = Some(body.into());
79        self
80    }
81
82    pub fn json<T: Serialize>(mut self, data: &T) -> Result<Self, Error> {
83        let json = serde_json::to_vec(data)
84            .map_err(|e| Error::Other(anyhow::anyhow!("JSON serialization failed: {}", e)))?;
85        self.body = Some(json.into());
86        self.headers.insert("Content-Type".to_string(), "application/json".to_string());
87        Ok(self)
88    }
89
90    pub fn form(mut self, data: &HashMap<String, String>) -> Self {
91        let form_data = data
92            .iter()
93            .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
94            .collect::<Vec<_>>()
95            .join("&");
96        self.body = Some(form_data.into_bytes().into());
97        self.headers.insert("Content-Type".to_string(), "application/x-www-form-urlencoded".to_string());
98        self
99    }
100
101    pub fn timeout(mut self, duration: Duration) -> Self {
102        self.timeout = Some(duration);
103        self
104    }
105
106    pub fn user_agent(mut self, ua: String) -> Self {
107        self.user_agent = Some(ua);
108        self
109    }
110
111    pub fn no_redirects(mut self) -> Self {
112        self.follow_redirects = false;
113        self
114    }
115}
116
117/// HTTP response
118#[derive(Debug)]
119pub struct HttpResponse {
120    pub status: u16,
121    pub headers: HashMap<String, String>,
122    pub body: Bytes,
123    pub url: Url,
124}
125
126impl HttpResponse {
127    pub fn status(&self) -> u16 {
128        self.status
129    }
130
131    pub fn is_success(&self) -> bool {
132        self.status >= 200 && self.status < 300
133    }
134
135    pub fn is_redirect(&self) -> bool {
136        self.status >= 300 && self.status < 400
137    }
138
139    pub fn header(&self, name: &str) -> Option<&String> {
140        self.headers.get(name)
141    }
142
143    pub fn content_type(&self) -> Option<&String> {
144        self.header("content-type").or_else(|| self.header("Content-Type"))
145    }
146
147    pub fn content_length(&self) -> Option<usize> {
148        self.header("content-length")
149            .or_else(|| self.header("Content-Length"))
150            .and_then(|s| s.parse().ok())
151    }
152
153    pub fn body(&self) -> &Bytes {
154        &self.body
155    }
156
157    pub fn text(&self) -> Result<String, Error> {
158        String::from_utf8(self.body.to_vec())
159            .map_err(|e| Error::Other(anyhow::anyhow!("Invalid UTF-8: {}", e)))
160    }
161
162    pub fn json<T: for<'de> Deserialize<'de>>(&self) -> Result<T, Error> {
163        serde_json::from_slice(&self.body)
164            .map_err(|e| Error::Other(anyhow::anyhow!("JSON deserialization failed: {}", e)))
165    }
166}
167
168/// Request/Response interceptor trait
169#[async_trait]
170pub trait HttpInterceptor: Send + Sync {
171    /// Called before sending a request
172    async fn before_request(&self, request: &mut HttpRequest) -> Result<(), Error>;
173    
174    /// Called after receiving a response
175    async fn after_response(&self, response: &mut HttpResponse) -> Result<(), Error>;
176}
177
178/// Rate limiting interceptor
179pub struct RateLimiter {
180    requests_per_second: f64,
181    last_request: std::sync::Arc<parking_lot::Mutex<Option<std::time::Instant>>>,
182}
183
184impl RateLimiter {
185    pub fn new(requests_per_second: f64) -> Self {
186        Self {
187            requests_per_second,
188            last_request: std::sync::Arc::new(parking_lot::Mutex::new(None)),
189        }
190    }
191}
192
193#[async_trait]
194impl HttpInterceptor for RateLimiter {
195    async fn before_request(&self, _request: &mut HttpRequest) -> Result<(), Error> {
196        let sleep_duration = {
197            let mut last = self.last_request.lock();
198            if let Some(last_time) = *last {
199                let min_interval = Duration::from_secs_f64(1.0 / self.requests_per_second);
200                let elapsed = last_time.elapsed();
201                if elapsed < min_interval {
202                    Some(min_interval - elapsed)
203                } else {
204                    None
205                }
206            } else {
207                None
208            }
209        };
210        
211        if let Some(duration) = sleep_duration {
212            tokio::time::sleep(duration).await;
213        }
214        
215        {
216            let mut last = self.last_request.lock();
217            *last = Some(std::time::Instant::now());
218        }
219        
220        Ok(())
221    }
222
223    async fn after_response(&self, _response: &mut HttpResponse) -> Result<(), Error> {
224        Ok(())
225    }
226}
227
/// Interceptor that supplies a fallback User-Agent string.
pub struct UserAgentInterceptor {
    /// String applied to requests that have no User-Agent of their own.
    user_agent: String,
}

impl UserAgentInterceptor {
    /// Creates an interceptor that falls back to `user_agent`.
    pub fn new(user_agent: String) -> Self {
        UserAgentInterceptor { user_agent }
    }
}
238
239#[async_trait]
240impl HttpInterceptor for UserAgentInterceptor {
241    async fn before_request(&self, request: &mut HttpRequest) -> Result<(), Error> {
242        if request.user_agent.is_none() {
243            request.user_agent = Some(self.user_agent.clone());
244        }
245        Ok(())
246    }
247
248    async fn after_response(&self, _response: &mut HttpResponse) -> Result<(), Error> {
249        Ok(())
250    }
251}
252
253/// Cookie jar for session management
254#[derive(Debug, Clone)]
255pub struct CookieJar {
256    cookies: std::sync::Arc<parking_lot::RwLock<HashMap<String, cookie::Cookie<'static>>>>,
257}
258
259impl CookieJar {
260    pub fn new() -> Self {
261        Self {
262            cookies: std::sync::Arc::new(parking_lot::RwLock::new(HashMap::new())),
263        }
264    }
265
266    pub fn add_cookie(&self, cookie: cookie::Cookie<'static>) {
267        let mut cookies = self.cookies.write();
268        cookies.insert(cookie.name().to_string(), cookie);
269    }
270
271    pub fn get_cookies_for_url(&self, url: &Url) -> Vec<cookie::Cookie<'static>> {
272        let cookies = self.cookies.read();
273        cookies
274            .values()
275            .filter(|cookie| {
276                // Basic domain/path matching
277                if let Some(domain) = cookie.domain() {
278                    if let Some(host) = url.host_str() {
279                        if !host.ends_with(domain) && host != domain {
280                            return false;
281                        }
282                    }
283                }
284                if let Some(path) = cookie.path() {
285                    if !url.path().starts_with(path) {
286                        return false;
287                    }
288                }
289                true
290            })
291            .cloned()
292            .collect()
293    }
294
295    pub fn cookie_header_for_url(&self, url: &Url) -> Option<String> {
296        let cookies = self.get_cookies_for_url(url);
297        if cookies.is_empty() {
298            None
299        } else {
300            Some(
301                cookies
302                    .iter()
303                    .map(|c| format!("{}={}", c.name(), c.value()))
304                    .collect::<Vec<_>>()
305                    .join("; ")
306            )
307        }
308    }
309}
310
311/// Cookie interceptor
312pub struct CookieInterceptor {
313    jar: CookieJar,
314}
315
316impl CookieInterceptor {
317    pub fn new(jar: CookieJar) -> Self {
318        Self { jar }
319    }
320}
321
322#[async_trait]
323impl HttpInterceptor for CookieInterceptor {
324    async fn before_request(&self, request: &mut HttpRequest) -> Result<(), Error> {
325        if let Some(cookie_header) = self.jar.cookie_header_for_url(&request.url) {
326            request.headers.insert("Cookie".to_string(), cookie_header);
327        }
328        Ok(())
329    }
330
331    async fn after_response(&self, response: &mut HttpResponse) -> Result<(), Error> {
332        // Parse Set-Cookie headers
333        for (name, value) in &response.headers {
334            if name.to_lowercase() == "set-cookie" {
335                if let Ok(cookie) = cookie::Cookie::parse(value.clone()) {
336                    self.jar.add_cookie(cookie.into_owned());
337                }
338            }
339        }
340        Ok(())
341    }
342}
343
344/// HTTP client trait
345#[async_trait]
346pub trait HttpClient: Send + Sync {
347    async fn execute(&self, request: HttpRequest) -> Result<HttpResponse, Error>;
348}
349
350/// HTTP client builder
351pub struct HttpClientBuilder {
352    interceptors: Vec<Box<dyn HttpInterceptor>>,
353    cookie_jar: Option<CookieJar>,
354    rate_limit: Option<f64>,
355    default_user_agent: Option<String>,
356    default_timeout: Option<Duration>,
357    verify_ssl: bool,
358    proxy: Option<String>,
359}
360
361impl HttpClientBuilder {
362    pub fn new() -> Self {
363        Self {
364            interceptors: Vec::new(),
365            cookie_jar: None,
366            rate_limit: None,
367            default_user_agent: None,
368            default_timeout: Some(Duration::from_secs(30)),
369            verify_ssl: true,
370            proxy: None,
371        }
372    }
373
374    pub fn interceptor(mut self, interceptor: Box<dyn HttpInterceptor>) -> Self {
375        self.interceptors.push(interceptor);
376        self
377    }
378
379    pub fn cookie_jar(mut self, jar: CookieJar) -> Self {
380        self.cookie_jar = Some(jar);
381        self
382    }
383
384    pub fn rate_limit(mut self, requests_per_second: f64) -> Self {
385        self.rate_limit = Some(requests_per_second);
386        self
387    }
388
389    pub fn user_agent(mut self, ua: String) -> Self {
390        self.default_user_agent = Some(ua);
391        self
392    }
393
394    pub fn timeout(mut self, duration: Duration) -> Self {
395        self.default_timeout = Some(duration);
396        self
397    }
398
399    pub fn verify_ssl(mut self, verify: bool) -> Self {
400        self.verify_ssl = verify;
401        self
402    }
403
404    pub fn proxy(mut self, proxy_url: String) -> Self {
405        self.proxy = Some(proxy_url);
406        self
407    }
408
409    pub fn build(mut self) -> Result<Box<dyn HttpClient>, Error> {
410        // Add default interceptors
411        if let Some(rate) = self.rate_limit {
412            self.interceptors.push(Box::new(RateLimiter::new(rate)));
413        }
414
415        if let Some(ua) = self.default_user_agent {
416            self.interceptors.push(Box::new(UserAgentInterceptor::new(ua)));
417        }
418
419        if let Some(jar) = self.cookie_jar {
420            self.interceptors.push(Box::new(CookieInterceptor::new(jar)));
421        }
422
423        cfg_if::cfg_if! {
424            if #[cfg(target_arch = "wasm32")] {
425                Ok(Box::new(WasmHttpClient::new(self.interceptors, self.default_timeout)?))
426            } else {
427                Ok(Box::new(NativeHttpClient::new(
428                    self.interceptors,
429                    self.default_timeout,
430                    self.verify_ssl,
431                    self.proxy,
432                )?))
433            }
434        }
435    }
436}
437
438impl Default for HttpClientBuilder {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
444/// Native HTTP client implementation using reqwest
445#[cfg(not(target_arch = "wasm32"))]
446pub struct NativeHttpClient {
447    client: reqwest::Client,
448    interceptors: Vec<Box<dyn HttpInterceptor>>,
449}
450
451#[cfg(not(target_arch = "wasm32"))]
452impl NativeHttpClient {
453    pub fn new(
454        interceptors: Vec<Box<dyn HttpInterceptor>>,
455        default_timeout: Option<Duration>,
456        verify_ssl: bool,
457        proxy: Option<String>,
458    ) -> Result<Self, Error> {
459        let mut builder = reqwest::Client::builder()
460            .danger_accept_invalid_certs(!verify_ssl)
461            .redirect(reqwest::redirect::Policy::none());
462
463        if let Some(timeout) = default_timeout {
464            builder = builder.timeout(timeout);
465        }
466
467        if let Some(proxy_url) = proxy {
468            let proxy = reqwest::Proxy::all(&proxy_url)
469                .map_err(|e| Error::Other(anyhow::anyhow!("Invalid proxy URL: {}", e)))?;
470            builder = builder.proxy(proxy);
471        }
472
473        let client = builder
474            .build()
475            .map_err(|e| Error::Other(anyhow::anyhow!("Failed to create HTTP client: {}", e)))?;
476
477        Ok(Self { client, interceptors })
478    }
479}
480
481#[cfg(not(target_arch = "wasm32"))]
482#[async_trait]
483impl HttpClient for NativeHttpClient {
484    async fn execute(&self, mut request: HttpRequest) -> Result<HttpResponse, Error> {
485        // Apply request interceptors
486        for interceptor in &self.interceptors {
487            interceptor.before_request(&mut request).await?;
488        }
489
490        let method = match request.method {
491            HttpMethod::Get => reqwest::Method::GET,
492            HttpMethod::Post => reqwest::Method::POST,
493            HttpMethod::Put => reqwest::Method::PUT,
494            HttpMethod::Delete => reqwest::Method::DELETE,
495            HttpMethod::Head => reqwest::Method::HEAD,
496            HttpMethod::Options => reqwest::Method::OPTIONS,
497            HttpMethod::Patch => reqwest::Method::PATCH,
498        };
499
500        let mut req_builder = self.client.request(method, request.url.clone());
501
502        // Add headers
503        for (key, value) in &request.headers {
504            req_builder = req_builder.header(key, value);
505        }
506
507        // Add user agent
508        if let Some(ua) = &request.user_agent {
509            req_builder = req_builder.header("User-Agent", ua);
510        }
511
512        // Add body
513        if let Some(body) = request.body {
514            req_builder = req_builder.body(body);
515        }
516
517        // Set timeout
518        if let Some(timeout) = request.timeout {
519            req_builder = req_builder.timeout(timeout);
520        }
521
522        let response = req_builder
523            .send()
524            .await
525            .map_err(|e| Error::Other(anyhow::anyhow!("HTTP request failed: {}", e)))?;
526
527        let status = response.status().as_u16();
528        let headers = response
529            .headers()
530            .iter()
531            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
532            .collect();
533
534        let body = response
535            .bytes()
536            .await
537            .map_err(|e| Error::Other(anyhow::anyhow!("Failed to read response body: {}", e)))?;
538
539        let mut http_response = HttpResponse {
540            status,
541            headers,
542            body,
543            url: request.url,
544        };
545
546        // Apply response interceptors
547        for interceptor in &self.interceptors {
548            interceptor.after_response(&mut http_response).await?;
549        }
550
551        Ok(http_response)
552    }
553}
554
/// [`HttpClient`] for `wasm32` targets, backed by the browser fetch API
/// through `web-sys`.
#[cfg(target_arch = "wasm32")]
pub struct WasmHttpClient {
    /// Interceptors run around every request, in order.
    interceptors: Vec<Box<dyn HttpInterceptor>>,
    /// Default timeout supplied at construction.
    ///
    /// NOTE(review): the `execute` implementation in this file does not
    /// appear to read this value — confirm whether timeouts are enforced.
    default_timeout: Option<Duration>,
}

#[cfg(target_arch = "wasm32")]
impl WasmHttpClient {
    /// Creates the client from its interceptors and default timeout.
    ///
    /// Construction itself cannot fail; the `Result` mirrors the native
    /// client's constructor.
    pub fn new(
        interceptors: Vec<Box<dyn HttpInterceptor>>,
        default_timeout: Option<Duration>,
    ) -> Result<Self, Error> {
        let client = WasmHttpClient {
            interceptors,
            default_timeout,
        };
        Ok(client)
    }
}
574
/// Executes requests through `window.fetch`.
#[cfg(target_arch = "wasm32")]
#[async_trait]
impl HttpClient for WasmHttpClient {
    /// Runs the request interceptors, performs the fetch, buffers the
    /// response and runs the response interceptors.
    ///
    /// Errors are built with `anyhow::anyhow!`, the same way as everywhere
    /// else in this module. A missing global `window` (for example in a
    /// worker context) is reported as an error instead of panicking.
    ///
    /// Neither `request.timeout` nor the client's default timeout is applied
    /// in this function.
    ///
    /// # Errors
    ///
    /// Returns the first interceptor error, or `Error::Other` when the
    /// request cannot be built, the fetch fails, or the response cannot be
    /// read.
    async fn execute(&self, mut request: HttpRequest) -> Result<HttpResponse, Error> {
        use wasm_bindgen::prelude::*;
        use wasm_bindgen_futures::JsFuture;
        use web_sys::{Request, RequestInit, Response};

        // Apply request interceptors
        for interceptor in &self.interceptors {
            interceptor.before_request(&mut request).await?;
        }

        let mut opts = RequestInit::new();
        opts.method(&request.method.to_string());

        // Add body
        if let Some(body) = request.body {
            // Lossless cast: `usize` is 32 bits wide on wasm32.
            let uint8_array = js_sys::Uint8Array::new_with_length(body.len() as u32);
            uint8_array.copy_from(&body);
            opts.body(Some(&uint8_array));
        }

        // Create headers
        let headers = web_sys::Headers::new()
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to create headers")))?;

        for (key, value) in &request.headers {
            headers
                .set(key, value)
                .map_err(|_| Error::Other(anyhow::anyhow!("Failed to set header: {}", key)))?;
        }

        // `set` replaces, so the dedicated field wins over a header-map entry.
        if let Some(ua) = &request.user_agent {
            headers
                .set("User-Agent", ua)
                .map_err(|_| Error::Other(anyhow::anyhow!("Failed to set User-Agent")))?;
        }

        opts.headers(&headers);

        let req = Request::new_with_str_and_init(&request.url.to_string(), &opts)
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to create request")))?;

        let window = web_sys::window()
            .ok_or_else(|| Error::Other(anyhow::anyhow!("No global window object available")))?;
        let resp_value = JsFuture::from(window.fetch_with_request(&req))
            .await
            .map_err(|_| Error::Other(anyhow::anyhow!("Fetch failed")))?;

        let resp: Response = resp_value
            .dyn_into()
            .map_err(|_| Error::Other(anyhow::anyhow!("Invalid response")))?;

        let status = resp.status() as u16;

        // Extract headers
        let mut response_headers = HashMap::new();
        let headers_iter = js_sys::try_iter(&resp.headers())
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to iterate headers")))?
            .ok_or_else(|| Error::Other(anyhow::anyhow!("Headers not iterable")))?;

        for item in headers_iter {
            let item =
                item.map_err(|_| Error::Other(anyhow::anyhow!("Header iteration error")))?;
            // Each entry is a two-element `[name, value]` array.
            let entry = js_sys::Array::from(&item);
            let key = entry.get(0).as_string().unwrap_or_default();
            let value = entry.get(1).as_string().unwrap_or_default();
            response_headers.insert(key, value);
        }

        // Read body
        let body_promise = resp
            .array_buffer()
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to read response body")))?;
        let array_buffer = JsFuture::from(body_promise)
            .await
            .map_err(|_| Error::Other(anyhow::anyhow!("Failed to read response body")))?;

        let uint8_array = js_sys::Uint8Array::new(&array_buffer);
        let body = uint8_array.to_vec().into();

        let mut http_response = HttpResponse {
            status,
            headers: response_headers,
            body,
            url: request.url,
        };

        // Apply response interceptors
        for interceptor in &self.interceptors {
            interceptor.after_response(&mut http_response).await?;
        }

        Ok(http_response)
    }
}
667
/// Returns the User-Agent string used when a request does not specify one:
/// a desktop Chrome identifier with a `CodeMesh/1.0` suffix.
pub fn default_user_agent() -> String {
    String::from(
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 CodeMesh/1.0",
    )
}
672
673/// Security helpers for SSRF protection
674pub fn is_safe_url(url: &Url) -> bool {
675    // Check scheme
676    if !matches!(url.scheme(), "http" | "https") {
677        return false;
678    }
679
680    // Check for private/internal IP ranges
681    if let Some(host) = url.host() {
682        match host {
683            url::Host::Ipv4(ip) => {
684                if ip.is_private() || ip.is_loopback() || ip.is_link_local() {
685                    return false;
686                }
687            }
688            url::Host::Ipv6(ip) => {
689                if ip.is_loopback() || ip.is_unspecified() {
690                    return false;
691                }
692            }
693            url::Host::Domain(domain) => {
694                // Block localhost and internal domains
695                if domain == "localhost" || domain.ends_with(".local") || domain.ends_with(".internal") {
696                    return false;
697                }
698            }
699        }
700    }
701
702    true
703}
704
705/// Sanitize URL to prevent SSRF attacks
706pub fn sanitize_url(url_str: &str) -> Result<Url, Error> {
707    let url = Url::parse(url_str)
708        .map_err(|e| Error::Other(anyhow::anyhow!("Invalid URL: {}", e)))?;
709
710    if !is_safe_url(&url) {
711        return Err(Error::Other(anyhow::anyhow!("URL not allowed for security reasons")));
712    }
713
714    Ok(url)
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    /// Public hosts pass the SSRF check; loopback, private and non-HTTP
    /// targets do not.
    #[test]
    fn test_url_safety() {
        let allowed = ["https://example.com", "http://google.com"];
        for raw in allowed.iter() {
            let url = Url::parse(raw).unwrap();
            assert!(is_safe_url(&url), "expected {} to be allowed", raw);
        }

        let rejected = [
            "http://127.0.0.1",
            "http://localhost",
            "http://192.168.1.1",
            "file:///etc/passwd",
        ];
        for raw in rejected.iter() {
            let url = Url::parse(raw).unwrap();
            assert!(!is_safe_url(&url), "expected {} to be rejected", raw);
        }
    }

    /// A stored cookie is rendered into the `Cookie` header for a matching URL.
    #[test]
    fn test_cookie_jar() {
        let jar = CookieJar::new();
        jar.add_cookie(
            cookie::Cookie::build(("session", "abc123"))
                .domain("example.com")
                .path("/")
                .finish()
                .into_owned(),
        );

        let target = Url::parse("https://example.com/test").unwrap();
        assert_eq!(
            jar.cookie_header_for_url(&target),
            Some("session=abc123".to_string())
        );
    }
}