Skip to main content

spider_lib/middlewares/
referer.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use reqwest::header::{HeaderValue, REFERER};
4use std::sync::Arc;
5use url::Url;
6
7use crate::error::SpiderError;
8use crate::middleware::{Middleware, MiddlewareAction};
9use crate::request::Request;
10use crate::response::Response;
11use tracing::{debug, info};
12
13/// Referer middleware that automatically sets Referer headers
14/// based on the navigation chain
15#[derive(Debug, Clone)]
16pub struct RefererMiddleware {
17    /// Whether to use same-origin only referer
18    pub same_origin_only: bool,
19    /// Maximum referer chain length to keep in memory
20    pub max_chain_length: usize,
21    /// Whether to include fragment in referer URL
22    pub include_fragment: bool,
23    /// Map of request ID to referer URL
24    referer_map: Arc<DashMap<String, Url>>,
25}
26
27impl Default for RefererMiddleware {
28    fn default() -> Self {
29        let middleware = RefererMiddleware {
30            same_origin_only: true,
31            max_chain_length: 1000,
32            include_fragment: false,
33            referer_map: Arc::new(DashMap::new()),
34        };
35        info!(
36            "Initializing RefererMiddleware with config: {:?}",
37            middleware
38        );
39        middleware
40    }
41}
42
43impl RefererMiddleware {
44    /// Create a new RefererMiddleware with default config
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Set whether to use same-origin only referer.
50    pub fn same_origin_only(mut self, same_origin_only: bool) -> Self {
51        self.same_origin_only = same_origin_only;
52        self
53    }
54
55    /// Set the maximum referer chain length to keep in memory.
56    pub fn max_chain_length(mut self, max_chain_length: usize) -> Self {
57        self.max_chain_length = max_chain_length;
58        self
59    }
60
61    /// Set whether to include the fragment in the referer URL.
62    pub fn include_fragment(mut self, include_fragment: bool) -> Self {
63        self.include_fragment = include_fragment;
64        self
65    }
66
67    /// Extract referer from request metadata and clean it
68    fn get_referer_for_request(&self, request: &Request) -> Option<Url> {
69        if let Some(referer_value) = request.meta.get("referer")
70            && let Some(referer_str) = referer_value.value().as_str()
71            && let Ok(url) = Url::parse(referer_str)
72        {
73            if self.same_origin_only {
74                let request_origin = format!(
75                    "{}://{}",
76                    request.url.scheme(),
77                    request.url.host_str().unwrap_or("")
78                );
79                let referer_origin = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
80
81                if request_origin == referer_origin {
82                    return Some(self.clean_url(&url));
83                }
84            } else {
85                return Some(self.clean_url(&url));
86            }
87        }
88
89        None
90    }
91
92    /// Clean URL by removing fragments if configured
93    fn clean_url(&self, url: &Url) -> Url {
94        if !self.include_fragment && url.fragment().is_some() {
95            let mut cleaned = url.clone();
96            cleaned.set_fragment(None);
97            cleaned
98        } else {
99            url.clone()
100        }
101    }
102
103    /// Generate a unique key for request tracking
104    fn request_key(&self, request: &Request) -> String {
105        // Simple key based on URL and method
106        format!("{}:{}", request.method, request.url)
107    }
108}
109
110#[async_trait]
111impl<C: Send + Sync> Middleware<C> for RefererMiddleware {
112    fn name(&self) -> &str {
113        "RefererMiddleware"
114    }
115
116    async fn process_request(
117        &mut self,
118        _client: &C,
119        mut request: Request,
120    ) -> Result<MiddlewareAction<Request>, SpiderError> {
121        let referer = self.get_referer_for_request(&request);
122        let referer = if let Some(ref_from_meta) = referer {
123            Some(ref_from_meta)
124        } else {
125            let request_key = self.request_key(&request);
126            self.referer_map
127                .get(&request_key)
128                .map(|entry| entry.value().clone())
129        };
130
131        let referer = if let Some(ref_url) = referer {
132            Some(ref_url)
133        } else if let Some(parent_id) = request.meta.get("parent_request_id") {
134            if let Some(parent_id_str) = parent_id.value().as_str() {
135                self.referer_map
136                    .get(parent_id_str)
137                    .map(|entry| entry.value().clone())
138            } else {
139                None
140            }
141        } else {
142            None
143        };
144
145        if let Some(referer) = referer {
146            match HeaderValue::from_str(referer.as_str()) {
147                Ok(header_value) => {
148                    request.headers.insert(REFERER, header_value);
149                    debug!(
150                        "Set Referer header to: {} for request: {}",
151                        referer, request.url
152                    );
153                }
154                Err(e) => {
155                    debug!("Failed to set Referer header: {}", e);
156                }
157            }
158        }
159
160        Ok(MiddlewareAction::Continue(request))
161    }
162    
163    async fn process_response(
164        &mut self,
165        response: Response,
166    ) -> Result<MiddlewareAction<Response>, SpiderError> {
167        let response_url = response.url.clone();
168        let request = response.request_from_response();
169        let request_id = format!("req_{:x}", seahash::hash(request.url.as_str().as_bytes()));
170
171        // Store mapping:
172        // 1. Request ID -> Response URL (for parent-child relationships)
173        // 2. Request key -> Response URL (for direct lookups)
174
175        let request_key = self.request_key(&request);
176        let cleaned_url = self.clean_url(&response_url);
177
178        if self.referer_map.len() < self.max_chain_length {
179            self.referer_map.insert(request_key, cleaned_url.clone());
180            self.referer_map.insert(request_id.clone(), cleaned_url);
181
182            debug!(
183                "Stored referer mapping for request {}: {}",
184                request.url, response_url
185            );
186        }
187
188        Ok(MiddlewareAction::Continue(response))
189    }
190}