Skip to main content

spider_middleware/
referer.rs

1//! Middleware that fills `Referer` headers for follow-up requests.
2
3use async_trait::async_trait;
4use dashmap::DashMap;
5use reqwest::header::{HeaderValue, REFERER};
6use std::sync::Arc;
7use url::Url;
8
9use crate::middleware::{Middleware, MiddlewareAction};
10use log::{debug, info};
11use spider_util::error::SpiderError;
12use spider_util::request::Request;
13use spider_util::response::Response;
14
15/// Middleware that derives `Referer` values from request metadata and history.
16#[derive(Debug, Clone)]
17pub struct RefererMiddleware {
18    /// Whether to use same-origin only referer
19    pub same_origin_only: bool,
20    /// Maximum referer chain length to keep in memory
21    pub max_chain_length: usize,
22    /// Whether to include fragment in referer URL
23    pub include_fragment: bool,
24    /// Map of request ID to referer URL
25    referer_map: Arc<DashMap<String, Url>>,
26}
27
28impl Default for RefererMiddleware {
29    fn default() -> Self {
30        let middleware = RefererMiddleware {
31            same_origin_only: true,
32            max_chain_length: 1000,
33            include_fragment: false,
34            referer_map: Arc::new(DashMap::new()),
35        };
36        info!(
37            "Initializing RefererMiddleware with config: {:?}",
38            middleware
39        );
40        middleware
41    }
42}
43
44impl RefererMiddleware {
45    /// Creates a middleware with default settings.
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Set whether to use same-origin only referer.
51    pub fn same_origin_only(mut self, same_origin_only: bool) -> Self {
52        self.same_origin_only = same_origin_only;
53        self
54    }
55
56    /// Set the maximum referer chain length to keep in memory.
57    pub fn max_chain_length(mut self, max_chain_length: usize) -> Self {
58        self.max_chain_length = max_chain_length;
59        self
60    }
61
62    /// Set whether to include the fragment in the referer URL.
63    pub fn include_fragment(mut self, include_fragment: bool) -> Self {
64        self.include_fragment = include_fragment;
65        self
66    }
67
68    /// Extract referer from request metadata and clean it
69    fn get_referer_for_request(&self, request: &Request) -> Option<Url> {
70        if let Some(referer_value) = request.get_meta_ref("referer")
71            && let Some(referer_str) = referer_value.value().as_str()
72            && let Ok(url) = Url::parse(referer_str)
73        {
74            if self.same_origin_only {
75                let request_origin = format!(
76                    "{}://{}",
77                    request.url.scheme(),
78                    request.url.host_str().unwrap_or("")
79                );
80                let referer_origin = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
81
82                if request_origin == referer_origin {
83                    return Some(self.clean_url(&url));
84                }
85            } else {
86                return Some(self.clean_url(&url));
87            }
88        }
89
90        None
91    }
92
93    /// Clean URL by removing fragments if configured
94    fn clean_url(&self, url: &Url) -> Url {
95        if !self.include_fragment && url.fragment().is_some() {
96            let mut cleaned = url.clone();
97            cleaned.set_fragment(None);
98            cleaned
99        } else {
100            url.clone()
101        }
102    }
103
104    /// Generate a unique key for request tracking
105    fn request_key(&self, request: &Request) -> String {
106        // Simple key based on URL and method
107        format!("{}:{}", request.method, request.url)
108    }
109}
110
111#[async_trait]
112impl<C: Send + Sync> Middleware<C> for RefererMiddleware {
113    fn name(&self) -> &str {
114        "RefererMiddleware"
115    }
116
117    async fn process_request(
118        &self,
119        _client: &C,
120        mut request: Request,
121    ) -> Result<MiddlewareAction<Request>, SpiderError> {
122        let referer = self.get_referer_for_request(&request);
123        let referer = if let Some(ref_from_meta) = referer {
124            Some(ref_from_meta)
125        } else {
126            let request_key = self.request_key(&request);
127            self.referer_map
128                .get(&request_key)
129                .map(|entry| entry.value().clone())
130        };
131
132        let referer = if let Some(ref_url) = referer {
133            Some(ref_url)
134        } else if let Some(parent_id) = request.get_meta_ref("parent_request_id") {
135            if let Some(parent_id_str) = parent_id.value().as_str() {
136                self.referer_map
137                    .get(parent_id_str)
138                    .map(|entry| entry.value().clone())
139            } else {
140                None
141            }
142        } else {
143            None
144        };
145
146        if let Some(referer) = referer {
147            match HeaderValue::from_str(referer.as_str()) {
148                Ok(header_value) => {
149                    request.headers.insert(REFERER, header_value);
150                    debug!(
151                        "Set Referer header to: {} for request: {}",
152                        referer, request.url
153                    );
154                }
155                Err(e) => {
156                    debug!("Failed to set Referer header: {}", e);
157                }
158            }
159        }
160
161        Ok(MiddlewareAction::Continue(request))
162    }
163
164    async fn process_response(
165        &self,
166        response: Response,
167    ) -> Result<MiddlewareAction<Response>, SpiderError> {
168        let response_url = response.url.clone();
169        let request = response.request_from_response();
170        let request_id = format!("req_{:x}", seahash::hash(request.url.as_str().as_bytes()));
171
172        // Store mapping:
173        // 1. Request ID -> Response URL (for parent-child relationships)
174        // 2. Request key -> Response URL (for direct lookups)
175
176        let request_key = self.request_key(&request);
177        let cleaned_url = self.clean_url(&response_url);
178
179        if self.referer_map.len() < self.max_chain_length {
180            self.referer_map.insert(request_key, cleaned_url.clone());
181            self.referer_map.insert(request_id.clone(), cleaned_url);
182
183            debug!(
184                "Stored referer mapping for request {}: {}",
185                request.url, response_url
186            );
187        }
188
189        Ok(MiddlewareAction::Continue(response))
190    }
191}