1use crate::SchedulerCheckpoint;
15
16use crate::error::SpiderError;
17use crate::request::Request;
18use crate::utils::BloomFilter;
19use crossbeam::queue::SegQueue;
20use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
21use moka::sync::Cache;
22use std::sync::Arc;
23use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
24use tracing::{debug, error, info};
25
/// Control messages consumed by the scheduler's background run loop.
enum SchedulerMessage {
    /// Add a request to the scheduling queue (boxed to keep the message small).
    Enqueue(Box<Request>),
    /// Record a URL fingerprint as visited (exact cache + bloom filter).
    MarkAsVisited(String),
    /// Ask the run loop to terminate.
    Shutdown,
}
31
/// Central request scheduler: deduplicates URLs and feeds requests to the
/// crawler through a bounded channel, driven by a spawned background run loop.
pub struct Scheduler {
    /// FIFO queue of requests waiting to be handed to the crawler.
    request_queue: SegQueue<Request>,
    /// Exact visited set, a capacity-bounded cache keyed by URL fingerprint.
    visited_urls: Cache<String, bool>,
    /// Probabilistic pre-filter consulted before the exact cache lookup.
    bloom_filter: BloomFilter,
    /// Sender side of the internal control-message channel into the run loop.
    tx_internal: AsyncSender<SchedulerMessage>,
    /// Count of requests enqueued but not yet delivered to the crawler.
    pending_requests: AtomicUsize,
    /// Requests rescued when the internal channel was unavailable.
    salvaged_requests: SegQueue<Request>,
    /// Set once shutdown begins; used to suppress channel-closed error logs.
    pub(crate) is_shutting_down: AtomicBool,
}
41
42impl Scheduler {
43 pub fn new(
45 #[cfg(feature = "checkpoint")] initial_state: Option<SchedulerCheckpoint>,
46 #[cfg(not(feature = "checkpoint"))] _initial_state: Option<SchedulerCheckpoint>,
47 ) -> (Arc<Self>, AsyncReceiver<Request>) {
48 let (tx_internal, rx_internal) = unbounded_async();
49
50 let (tx_req_out, rx_req_out) = bounded_async(100);
51
52 let request_queue: SegQueue<Request>;
53 let visited_urls: Cache<String, bool>;
54 let pending_requests: AtomicUsize;
55 let salvaged_requests: SegQueue<Request>;
56
57 #[cfg(feature = "checkpoint")]
58 if let Some(state) = initial_state {
59 info!(
60 "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
61 state.request_queue.len(),
62 state.visited_urls.len(),
63 state.salvaged_requests.len(),
64 );
65 let pending = state.request_queue.len() + state.salvaged_requests.len();
66 request_queue = SegQueue::new();
67 for request in state.request_queue {
68 request_queue.push(request);
69 }
70
71 visited_urls = Cache::builder().max_capacity(100000).build();
72 for url in state.visited_urls {
73 visited_urls.insert(url, true);
74 }
75
76 pending_requests = AtomicUsize::new(pending);
77 salvaged_requests = SegQueue::new();
78 for request in state.salvaged_requests {
79 salvaged_requests.push(request);
80 }
81 } else {
82 request_queue = SegQueue::new();
83 visited_urls = Cache::builder().max_capacity(100000).build();
84 pending_requests = AtomicUsize::new(0);
85 salvaged_requests = SegQueue::new();
86 }
87
88 #[cfg(not(feature = "checkpoint"))]
89 {
90 request_queue = SegQueue::new();
91 visited_urls = Cache::builder().max_capacity(100000).build();
92 pending_requests = AtomicUsize::new(0);
93 salvaged_requests = SegQueue::new();
94 }
95
96 let scheduler = Arc::new(Scheduler {
97 request_queue,
98 visited_urls,
99 bloom_filter: BloomFilter::new(1000000, 3),
100 tx_internal,
101 pending_requests,
102 salvaged_requests,
103 is_shutting_down: AtomicBool::new(false),
104 });
105
106 let scheduler_clone = Arc::clone(&scheduler);
107 tokio::spawn(async move {
108 scheduler_clone.run_loop(rx_internal, tx_req_out).await;
109 });
110
111 (scheduler, rx_req_out)
112 }
113
114 async fn run_loop(
115 &self,
116 rx_internal: AsyncReceiver<SchedulerMessage>,
117 tx_req_out: AsyncSender<Request>,
118 ) {
119 info!("Scheduler run_loop started.");
120 loop {
121 let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
122 self.request_queue.pop()
123 } else {
124 None
125 };
126
127 if let Some(request) = maybe_request {
128 tokio::select! {
129 biased;
130 send_res = tx_req_out.send(request) => {
131 if send_res.is_err() {
132 error!("Crawler receiver dropped. Scheduler can no longer send requests.");
133 }
134 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
135 },
136 recv_res = rx_internal.recv() => {
137 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
138 if !self.handle_message(recv_res).await {
139 break;
140 }
141 }
142 }
143 } else if !self.handle_message(rx_internal.recv().await).await {
144 break;
145 }
146 }
147 info!("Scheduler run_loop finished.");
148 }
149
150 async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
151 match msg {
152 Ok(SchedulerMessage::Enqueue(boxed_request)) => {
153 let request = *boxed_request;
154 self.request_queue.push(request);
155 self.pending_requests.fetch_add(1, Ordering::SeqCst);
156 true
157 }
158 Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
159 self.visited_urls.insert(fingerprint.clone(), true);
160 self.bloom_filter.add(&fingerprint);
162 debug!("Marked URL as visited: {}", fingerprint);
163 true
164 }
165 Ok(SchedulerMessage::Shutdown) | Err(_) => {
166 info!("Scheduler received shutdown signal or channel closed. Exiting run_loop.");
167 self.is_shutting_down.store(true, Ordering::SeqCst);
168 false
169 }
170 }
171 }
172
173 #[cfg(feature = "checkpoint")]
176 pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
177 let visited_urls = dashmap::DashSet::new();
178 for entry in self.visited_urls.iter() {
179 let (key, _) = entry;
180 visited_urls.insert(key.as_ref().clone());
181 }
182
183 let mut request_queue = std::collections::VecDeque::new();
184 let mut temp_requests = Vec::new();
185 while let Some(request) = self.request_queue.pop() {
186 temp_requests.push(request);
187 }
188
189 for request in temp_requests.into_iter().rev() {
190 request_queue.push_back(request.clone());
191 self.request_queue.push(request);
192 }
193
194 let mut salvaged_requests = std::collections::VecDeque::new();
195 let mut temp_salvaged = Vec::new();
196 while let Some(request) = self.salvaged_requests.pop() {
197 temp_salvaged.push(request);
198 }
199
200 for request in temp_salvaged.into_iter().rev() {
201 salvaged_requests.push_back(request.clone());
202 self.salvaged_requests.push(request);
203 }
204
205 Ok(SchedulerCheckpoint {
206 request_queue,
207 visited_urls,
208 salvaged_requests,
209 })
210 }
211
212 pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
214 if !self.should_enqueue_request(&request) {
215 debug!("Request already visited, skipping: {}", request.url);
216 return Ok(());
217 }
218
219 if self
220 .tx_internal
221 .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
222 .await
223 .is_err()
224 {
225 if !self.is_shutting_down.load(Ordering::SeqCst) {
226 error!("Scheduler internal message channel is closed. Salvaging request.");
227 }
228 self.salvaged_requests.push(request);
229 return Err(SpiderError::GeneralError(
230 "Scheduler internal channel closed, request salvaged.".into(),
231 ));
232 }
233 Ok(())
234 }
235
236 pub async fn shutdown(&self) -> Result<(), SpiderError> {
238 self.is_shutting_down.store(true, Ordering::SeqCst);
239
240 self.tx_internal
241 .send(SchedulerMessage::Shutdown)
242 .await
243 .map_err(|e| {
244 SpiderError::GeneralError(format!(
245 "Scheduler: Failed to send shutdown signal: {}",
246 e
247 ))
248 })
249 }
250
251 pub async fn send_mark_as_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
253 self.tx_internal
254 .send(SchedulerMessage::MarkAsVisited(fingerprint))
255 .await
256 .map_err(|e| {
257
258 if !self.is_shutting_down.load(Ordering::SeqCst) {
259 error!("Scheduler internal message channel is closed. Failed to mark URL as visited: {}", e);
260 }
261 SpiderError::GeneralError(format!(
262 "Scheduler: Failed to send MarkAsVisited message: {}",
263 e
264 ))
265 })
266 }
267
268 pub fn has_been_visited(&self, fingerprint: &str) -> bool {
270 if !self.bloom_filter.might_contain(&fingerprint) {
271 return false;
272 }
273
274 self.visited_urls.contains_key(fingerprint)
275 }
276
277 pub fn should_enqueue_request(&self, request: &Request) -> bool {
279 let fingerprint = request.fingerprint();
280 !self.has_been_visited(&fingerprint)
281 }
282
    /// Number of requests currently tracked as pending — enqueued but not
    /// yet handed off to the crawler.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending_requests.load(Ordering::SeqCst)
    }
288
    /// Returns `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
294
    /// Returns `true` when the scheduler has no pending work; alias of
    /// [`Scheduler::is_empty`] used by the run loop's pop guard.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
300}