1use bytes::Bytes;
16use dashmap::DashMap;
17use http::header::HeaderMap;
18use reqwest::{Method, Url};
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::hash::Hasher;
24use std::str::FromStr;
25use std::sync::Arc;
26use twox_hash::XxHash64;
27
28use crate::error::SpiderError;
29
30#[derive(Debug, Clone)]
31pub enum Body {
32 Json(Value),
33 Form(DashMap<String, String>),
34 Bytes(Bytes),
35}
36
37impl Serialize for Body {
39 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40 where
41 S: serde::Serializer,
42 {
43 use serde::ser::SerializeMap;
44 let mut map = serializer.serialize_map(Some(1))?;
45
46 match self {
47 Body::Json(value) => map.serialize_entry("Json", value)?,
48 Body::Form(dashmap) => {
49 let hmap: HashMap<String, String> = dashmap.clone().into_iter().collect();
50 map.serialize_entry("Form", &hmap)?
51 }
52 Body::Bytes(bytes) => map.serialize_entry("Bytes", bytes)?,
53 }
54
55 map.end()
56 }
57}
58
59impl<'de> Deserialize<'de> for Body {
60 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
61 where
62 D: serde::Deserializer<'de>,
63 {
64 use serde::de::{self, MapAccess, Visitor};
65 use std::fmt;
66
67 struct BodyVisitor;
68
69 impl<'de> Visitor<'de> for BodyVisitor {
70 type Value = Body;
71
72 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
73 formatter.write_str("a body object")
74 }
75
76 fn visit_map<V>(self, mut map: V) -> Result<Body, V::Error>
77 where
78 V: MapAccess<'de>,
79 {
80 let entry = map.next_entry::<String, Value>()?;
81 let (key, value) = match entry {
82 Some((k, v)) => (k, v),
83 None => return Err(de::Error::custom("Expected a body variant")),
84 };
85
86 match key.as_str() {
87 "Json" => Ok(Body::Json(value)),
88 "Form" => {
89 let form_data: HashMap<String, String> =
90 serde_json::from_value(value).map_err(de::Error::custom)?;
91 let dashmap = DashMap::new();
92 for (k, v) in form_data {
93 dashmap.insert(k, v);
94 }
95 Ok(Body::Form(dashmap))
96 }
97 "Bytes" => {
98 let bytes: Bytes =
99 serde_json::from_value(value).map_err(de::Error::custom)?;
100 Ok(Body::Bytes(bytes))
101 }
102 _ => Err(de::Error::custom(format!("Unknown body variant: {}", key))),
103 }
104 }
105 }
106
107 deserializer.deserialize_map(BodyVisitor)
108 }
109}
110
111#[derive(Debug, Clone)]
112pub struct Request {
113 pub url: Url,
114 pub method: Method,
115 pub headers: HeaderMap,
116 pub body: Option<Body>,
117 pub meta: Arc<DashMap<Cow<'static, str>, Value>>,
118}
119
120impl Serialize for Request {
122 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
123 where
124 S: serde::Serializer,
125 {
126 use serde::ser::SerializeStruct;
127 let headers_vec: Vec<(String, String)> = self
129 .headers
130 .iter()
131 .filter_map(|(name, value)| {
132 value
133 .to_str()
134 .ok()
135 .map(|val_str| (name.as_str().to_string(), val_str.to_string()))
136 })
137 .collect();
138
139 let mut s = serializer.serialize_struct("Request", 5)?;
140 s.serialize_field("url", &self.url.as_str())?;
141 s.serialize_field("method", &self.method.as_str())?;
142 s.serialize_field("headers", &headers_vec)?;
143 s.serialize_field("body", &self.body)?;
144 s.end()
145 }
146}
147
148impl<'de> Deserialize<'de> for Request {
149 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
150 where
151 D: serde::Deserializer<'de>,
152 {
153 use serde::de::{self, MapAccess, Visitor};
154 use std::fmt;
155
156 #[derive(Deserialize)]
157 #[serde(field_identifier, rename_all = "lowercase")]
158 enum Field {
159 Url,
160 Method,
161 Headers,
162 Body,
163 }
164
165 struct RequestVisitor;
166
167 impl<'de> Visitor<'de> for RequestVisitor {
168 type Value = Request;
169
170 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
171 formatter.write_str("struct Request")
172 }
173
174 fn visit_map<V>(self, mut map: V) -> Result<Request, V::Error>
175 where
176 V: MapAccess<'de>,
177 {
178 let mut url = None;
179 let mut method = None;
180 let mut headers = None;
181 let mut body = None;
182
183 while let Some(key) = map.next_key()? {
184 match key {
185 Field::Url => {
186 if url.is_some() {
187 return Err(de::Error::duplicate_field("url"));
188 }
189 let url_str: String = map.next_value()?;
190 let parsed_url = Url::parse(&url_str).map_err(de::Error::custom)?;
191 url = Some(parsed_url);
192 }
193 Field::Method => {
194 if method.is_some() {
195 return Err(de::Error::duplicate_field("method"));
196 }
197 let method_str: String = map.next_value()?;
198 let parsed_method =
199 Method::from_str(&method_str).map_err(de::Error::custom)?;
200 method = Some(parsed_method);
201 }
202 Field::Headers => {
203 if headers.is_some() {
204 return Err(de::Error::duplicate_field("headers"));
205 }
206 let headers_vec: Vec<(String, String)> = map.next_value()?;
208 let mut header_map = HeaderMap::new();
209 for (name, value) in headers_vec {
210 if let Ok(header_name) =
211 http::header::HeaderName::from_bytes(name.as_bytes())
212 && let Ok(header_value) =
213 http::header::HeaderValue::from_str(&value)
214 {
215 header_map.insert(header_name, header_value);
216 }
217 }
218 headers = Some(header_map);
219 }
220 Field::Body => {
221 if body.is_some() {
222 return Err(de::Error::duplicate_field("body"));
223 }
224 body = Some(map.next_value()?);
225 }
226 }
227 }
228
229 let url = url.ok_or_else(|| de::Error::missing_field("url"))?;
230 let method = method.ok_or_else(|| de::Error::missing_field("method"))?;
231 let headers = headers.ok_or_else(|| de::Error::missing_field("headers"))?;
232 let body = body; Ok(Request {
235 url,
236 method,
237 headers,
238 body,
239 meta: Arc::new(DashMap::new()), })
241 }
242 }
243
244 const FIELDS: &[&str] = &["url", "method", "headers", "body"];
245 deserializer.deserialize_struct("Request", FIELDS, RequestVisitor)
246 }
247}
248
249impl Default for Request {
250 fn default() -> Self {
251 Self {
252 url: Url::parse("http://default.invalid").unwrap(),
253 method: Method::GET,
254 headers: HeaderMap::new(),
255 body: None,
256 meta: Arc::new(DashMap::new()),
257 }
258 }
259}
260
261impl Request {
262 pub fn new(url: Url) -> Self {
264 Request {
265 url,
266 method: Method::GET,
267 headers: HeaderMap::new(),
268 body: None,
269 meta: Arc::new(DashMap::new()),
270 }
271 }
272
273 pub fn with_method(mut self, method: Method) -> Self {
275 self.method = method;
276 self
277 }
278
279 pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, SpiderError> {
281 let header_name =
282 reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
283 SpiderError::HeaderValueError(format!("Invalid header name '{}': {}", name, e))
284 })?;
285 let header_value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
286 SpiderError::HeaderValueError(format!("Invalid header value '{}': {}", value, e))
287 })?;
288
289 self.headers.insert(header_name, header_value);
290 Ok(self)
291 }
292
293 pub fn with_body(mut self, body: Body) -> Self {
295 self.body = Some(body);
296 self.with_method(Method::POST)
297 }
298
299 pub fn with_json(self, json: Value) -> Self {
301 self.with_body(Body::Json(json))
302 }
303
304 pub fn with_form(self, form: DashMap<String, String>) -> Self {
306 self.with_body(Body::Form(form))
307 }
308
309 pub fn with_bytes(self, bytes: Bytes) -> Self {
311 self.with_body(Body::Bytes(bytes))
312 }
313
314 pub fn with_meta(self, key: &str, value: Value) -> Self {
316 self.meta.insert(Cow::Owned(key.to_owned()), value);
317 self
318 }
319
320 const RETRY_ATTEMPTS_KEY: &str = "retry_attempts";
321
322 pub fn get_retry_attempts(&self) -> u32 {
324 self.meta
325 .get(Self::RETRY_ATTEMPTS_KEY)
326 .and_then(|v| v.value().as_u64())
327 .unwrap_or(0) as u32
328 }
329
330 pub fn increment_retry_attempts(&mut self) {
332 let current_attempts = self.get_retry_attempts();
333 self.meta.insert(
334 Cow::Borrowed(Self::RETRY_ATTEMPTS_KEY),
335 Value::from(current_attempts + 1),
336 );
337 }
338
339 pub fn fingerprint(&self) -> String {
341 let mut hasher = XxHash64::default();
342 hasher.write(self.url.as_str().as_bytes());
343 hasher.write(self.method.as_str().as_bytes());
344
345 if let Some(ref body) = self.body {
346 match body {
347 Body::Json(json_val) => {
348 if let Ok(serialized) = serde_json::to_string(json_val) {
349 hasher.write(serialized.as_bytes());
350 }
351 }
352 Body::Form(form_val) => {
353 let mut form_string = String::new();
354 for r in form_val.iter() {
355 form_string.push_str(r.key());
356 form_string.push_str(r.value());
357 }
358 hasher.write(form_string.as_bytes());
359 }
360 Body::Bytes(bytes_val) => {
361 hasher.write(bytes_val);
362 }
363 }
364 }
365 format!("{:x}", hasher.finish())
366 }
367}