Skip to main content

spider_util/
request.rs

1//! Data structures for representing HTTP requests in `spider-lib`.
2//!
3//! This module defines the `Request` struct, which is a central component
4//! for constructing and managing outgoing HTTP requests within the
5//! `spider-lib` framework. It encapsulates all necessary details of an
6//! HTTP request, including:
7//! - The target URL and HTTP method.
8//! - Request headers and an optional request body (supporting JSON, form data, or raw bytes).
9//! - Metadata for tracking retry attempts or other custom information.
10//!
//! Additionally, the module provides methods for building requests,
//! incrementing retry counters, and generating request fingerprints
//! (64-bit hashes of the URL, method and body) for deduplication and caching.
14
15use bytes::Bytes;
16use dashmap::DashMap;
17use http::header::HeaderMap;
18use reqwest::{Method, Url};
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::hash::Hasher;
24use std::str::FromStr;
25use twox_hash::XxHash64;
26
27use crate::error::SpiderError;
28
/// The payload of an outgoing [`Request`].
///
/// (De)serialized by the hand-written `Serialize`/`Deserialize` impls in this
/// module as a single-entry map keyed by the variant name, e.g.
/// `{"Json": ...}`, `{"Form": {...}}` or `{"Bytes": ...}`.
#[derive(Debug, Clone)]
pub enum Body {
    /// A JSON document.
    Json(Value),
    /// Form fields as name/value pairs.
    Form(DashMap<String, String>),
    /// A raw byte payload.
    Bytes(Bytes),
}
35
36// Custom serialization for Body enum
37impl Serialize for Body {
38    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39    where
40        S: serde::Serializer,
41    {
42        use serde::ser::SerializeMap;
43        let mut map = serializer.serialize_map(Some(1))?;
44
45        match self {
46            Body::Json(value) => map.serialize_entry("Json", value)?,
47            Body::Form(dashmap) => {
48                let hmap: HashMap<String, String> = dashmap.clone().into_iter().collect();
49                map.serialize_entry("Form", &hmap)?
50            }
51            Body::Bytes(bytes) => map.serialize_entry("Bytes", bytes)?,
52        }
53
54        map.end()
55    }
56}
57
58impl<'de> Deserialize<'de> for Body {
59    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60    where
61        D: serde::Deserializer<'de>,
62    {
63        use serde::de::{self, MapAccess, Visitor};
64        use std::fmt;
65
66        struct BodyVisitor;
67
68        impl<'de> Visitor<'de> for BodyVisitor {
69            type Value = Body;
70
71            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
72                formatter.write_str("a body object")
73            }
74
75            fn visit_map<V>(self, mut map: V) -> Result<Body, V::Error>
76            where
77                V: MapAccess<'de>,
78            {
79                let entry = map.next_entry::<String, Value>()?;
80                let (key, value) = match entry {
81                    Some((k, v)) => (k, v),
82                    None => return Err(de::Error::custom("Expected a body variant")),
83                };
84
85                match key.as_str() {
86                    "Json" => Ok(Body::Json(value)),
87                    "Form" => {
88                        let form_data: HashMap<String, String> =
89                            serde_json::from_value(value).map_err(de::Error::custom)?;
90                        let dashmap = DashMap::new();
91                        for (k, v) in form_data {
92                            dashmap.insert(k, v);
93                        }
94                        Ok(Body::Form(dashmap))
95                    }
96                    "Bytes" => {
97                        let bytes: Bytes =
98                            serde_json::from_value(value).map_err(de::Error::custom)?;
99                        Ok(Body::Bytes(bytes))
100                    }
101                    _ => Err(de::Error::custom(format!("Unknown body variant: {}", key))),
102                }
103            }
104        }
105
106        deserializer.deserialize_map(BodyVisitor)
107    }
108}
109
/// An outgoing HTTP request: target URL, method, headers, an optional body
/// and free-form metadata.
#[derive(Debug, Clone)]
pub struct Request {
    /// Target URL.
    pub url: Url,
    /// HTTP method; [`Request::new`] starts with `GET`.
    pub method: Method,
    /// Request headers.
    pub headers: HeaderMap,
    /// Optional request body.
    pub body: Option<Body>,
    /// Per-request metadata (the retry counter lives under `"retry_attempts"`).
    /// Not written by this module's `Serialize` impl; deserialized requests
    /// start with an empty map.
    pub meta: DashMap<Cow<'static, str>, Value>,
}
118
119// Custom serialization for Request struct
120impl Serialize for Request {
121    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122    where
123        S: serde::Serializer,
124    {
125        use serde::ser::SerializeStruct;
126        // Convert HeaderMap to a serializable format
127        let headers_vec: Vec<(String, String)> = self
128            .headers
129            .iter()
130            .filter_map(|(name, value)| {
131                value
132                    .to_str()
133                    .ok()
134                    .map(|val_str| (name.as_str().to_string(), val_str.to_string()))
135            })
136            .collect();
137
138        let mut s = serializer.serialize_struct("Request", 5)?;
139        s.serialize_field("url", &self.url.as_str())?;
140        s.serialize_field("method", &self.method.as_str())?;
141        s.serialize_field("headers", &headers_vec)?;
142        s.serialize_field("body", &self.body)?;
143        s.end()
144    }
145}
146
147impl<'de> Deserialize<'de> for Request {
148    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
149    where
150        D: serde::Deserializer<'de>,
151    {
152        use serde::de::{self, MapAccess, Visitor};
153        use std::fmt;
154
155        #[derive(Deserialize)]
156        #[serde(field_identifier, rename_all = "lowercase")]
157        enum Field {
158            Url,
159            Method,
160            Headers,
161            Body,
162        }
163
164        struct RequestVisitor;
165
166        impl<'de> Visitor<'de> for RequestVisitor {
167            type Value = Request;
168
169            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
170                formatter.write_str("struct Request")
171            }
172
173            fn visit_map<V>(self, mut map: V) -> Result<Request, V::Error>
174            where
175                V: MapAccess<'de>,
176            {
177                let mut url = None;
178                let mut method = None;
179                let mut headers = None;
180                let mut body = None;
181
182                while let Some(key) = map.next_key()? {
183                    match key {
184                        Field::Url => {
185                            if url.is_some() {
186                                return Err(de::Error::duplicate_field("url"));
187                            }
188                            let url_str: String = map.next_value()?;
189                            let parsed_url = Url::parse(&url_str).map_err(de::Error::custom)?;
190                            url = Some(parsed_url);
191                        }
192                        Field::Method => {
193                            if method.is_some() {
194                                return Err(de::Error::duplicate_field("method"));
195                            }
196                            let method_str: String = map.next_value()?;
197                            let parsed_method =
198                                Method::from_str(&method_str).map_err(de::Error::custom)?;
199                            method = Some(parsed_method);
200                        }
201                        Field::Headers => {
202                            if headers.is_some() {
203                                return Err(de::Error::duplicate_field("headers"));
204                            }
205                            // Deserialize headers vector and convert back to HeaderMap
206                            let headers_vec: Vec<(String, String)> = map.next_value()?;
207                            let mut header_map = HeaderMap::new();
208                            for (name, value) in headers_vec {
209                                if let Ok(header_name) =
210                                    http::header::HeaderName::from_bytes(name.as_bytes())
211                                    && let Ok(header_value) =
212                                        http::header::HeaderValue::from_str(&value)
213                                {
214                                    header_map.insert(header_name, header_value);
215                                }
216                            }
217                            headers = Some(header_map);
218                        }
219                        Field::Body => {
220                            if body.is_some() {
221                                return Err(de::Error::duplicate_field("body"));
222                            }
223                            body = Some(map.next_value()?);
224                        }
225                    }
226                }
227
228                let url = url.ok_or_else(|| de::Error::missing_field("url"))?;
229                let method = method.ok_or_else(|| de::Error::missing_field("method"))?;
230                let headers = headers.ok_or_else(|| de::Error::missing_field("headers"))?;
231                let body = body; // Optional field
232
233                Ok(Request {
234                    url,
235                    method,
236                    headers,
237                    body,
238                    meta: DashMap::new(), // Initialize empty meta map
239                })
240            }
241        }
242
243        const FIELDS: &[&str] = &["url", "method", "headers", "body"];
244        deserializer.deserialize_struct("Request", FIELDS, RequestVisitor)
245    }
246}
247
248impl Default for Request {
249    fn default() -> Self {
250        Self {
251            url: Url::parse("http://default.invalid").unwrap(),
252            method: Method::GET,
253            headers: HeaderMap::new(),
254            body: None,
255            meta: DashMap::new(),
256        }
257    }
258}
259
260impl Request {
261    /// Creates a new `Request` with the given URL.
262    pub fn new(url: Url) -> Self {
263        Request {
264            url,
265            method: Method::GET,
266            headers: HeaderMap::new(),
267            body: None,
268            meta: DashMap::new(),
269        }
270    }
271
272    /// Sets the HTTP method for the request.
273    pub fn with_method(mut self, method: Method) -> Self {
274        self.method = method;
275        self
276    }
277
278    /// Adds a header to the request.
279    pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, SpiderError> {
280        let header_name =
281            reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
282                SpiderError::HeaderValueError(format!("Invalid header name '{}': {}", name, e))
283            })?;
284        let header_value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
285            SpiderError::HeaderValueError(format!("Invalid header value '{}': {}", value, e))
286        })?;
287
288        self.headers.insert(header_name, header_value);
289        Ok(self)
290    }
291
292    /// Sets the body of the request and defaults the method to POST.
293    pub fn with_body(mut self, body: Body) -> Self {
294        self.body = Some(body);
295        self.with_method(Method::POST)
296    }
297
298    /// Sets the body of the request to a JSON value.
299    pub fn with_json(self, json: Value) -> Self {
300        self.with_body(Body::Json(json))
301    }
302
303    /// Sets the body of the request to a form.
304    pub fn with_form(self, form: DashMap<String, String>) -> Self {
305        self.with_body(Body::Form(form))
306    }
307
308    /// Sets the body of the request to a byte slice.
309    pub fn with_bytes(self, bytes: Bytes) -> Self {
310        self.with_body(Body::Bytes(bytes))
311    }
312
313    /// Adds a value to the request's metadata.
314    pub fn with_meta(self, key: &str, value: Value) -> Self {
315        self.meta.insert(Cow::Owned(key.to_owned()), value);
316        self
317    }
318
319    const RETRY_ATTEMPTS_KEY: &str = "retry_attempts";
320
321    /// Gets the number of times the request has been retried.
322    pub fn get_retry_attempts(&self) -> u32 {
323        self.meta
324            .get(Self::RETRY_ATTEMPTS_KEY)
325            .and_then(|v| v.value().as_u64())
326            .unwrap_or(0) as u32
327    }
328
329    /// Increments the retry count for the request.
330    pub fn increment_retry_attempts(&mut self) {
331        let current_attempts = self.get_retry_attempts();
332        self.meta.insert(
333            Cow::Borrowed(Self::RETRY_ATTEMPTS_KEY),
334            Value::from(current_attempts + 1),
335        );
336    }
337
338    /// Generates a unique fingerprint for the request based on its URL, method, and body.
339    pub fn fingerprint(&self) -> String {
340        let mut hasher = XxHash64::default();
341        hasher.write(self.url.as_str().as_bytes());
342        hasher.write(self.method.as_str().as_bytes());
343
344        if let Some(ref body) = self.body {
345            match body {
346                Body::Json(json_val) => {
347                    if let Ok(serialized) = serde_json::to_string(json_val) {
348                        hasher.write(serialized.as_bytes());
349                    }
350                }
351                Body::Form(form_val) => {
352                    let mut form_string = String::new();
353                    for r in form_val.iter() {
354                        form_string.push_str(r.key());
355                        form_string.push_str(r.value());
356                    }
357                    hasher.write(form_string.as_bytes());
358                }
359                Body::Bytes(bytes_val) => {
360                    hasher.write(bytes_val);
361                }
362            }
363        }
364        format!("{:x}", hasher.finish())
365    }
366}