use crate::checkpoint::SchedulerCheckpoint;
use crate::error::SpiderError;
use crate::request::Request;
use dashmap::DashSet;
use kanal::{AsyncReceiver, AsyncSender, bounded_async, unbounded_async};
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::{Mutex, oneshot};
use tracing::{debug, error, info};

/// Control messages consumed by the scheduler's internal run loop.
enum SchedulerMessage {
    Enqueue(Box<Request>),
    MarkAsVisited(String),
    Shutdown,
    TakeSnapshot(oneshot::Sender<SchedulerCheckpoint>),
}

/// Actor-style scheduler: callers send control messages to a spawned run
/// loop, which hands dequeued requests to the crawler over a bounded channel.
pub struct Scheduler {
    request_queue: Arc<Mutex<VecDeque<Request>>>,
    visited_urls: DashSet<String>,
    tx_internal: AsyncSender<SchedulerMessage>,
    /// Requests queued or salvaged but not yet dispatched.
    pending_requests: AtomicUsize,
    /// Requests rescued when the internal channel closed mid-enqueue.
    salvaged_requests: Arc<Mutex<VecDeque<Request>>>,
}

impl Scheduler {
    /// Creates a scheduler, optionally restored from a checkpoint, and spawns
    /// its run loop. Returns the scheduler together with the receiver from
    /// which the crawler pulls dispatched requests.
    pub fn new(initial_state: Option<SchedulerCheckpoint>) -> (Arc<Self>, AsyncReceiver<Request>) {
        let (tx_internal, rx_internal) = unbounded_async();
        // Capacity 1 keeps dispatch in lock-step with the crawler: the run
        // loop blocks on `send` until the previous request has been taken.
        let (tx_req_out, rx_req_out) = bounded_async(1);

        let (request_queue, visited_urls, pending_requests, salvaged_requests) =
            if let Some(state) = initial_state {
                info!(
                    "Initializing scheduler from checkpoint with {} requests, {} visited URLs, and {} salvaged requests.",
                    state.request_queue.len(),
                    state.visited_urls.len(),
                    state.salvaged_requests.len(),
                );
                let pending = state.request_queue.len() + state.salvaged_requests.len();
                (
                    Arc::new(Mutex::new(state.request_queue)),
                    state.visited_urls,
                    AtomicUsize::new(pending),
                    Arc::new(Mutex::new(state.salvaged_requests)),
                )
            } else {
                (
                    Arc::new(Mutex::new(VecDeque::new())),
                    DashSet::new(),
                    AtomicUsize::new(0),
                    Arc::new(Mutex::new(VecDeque::new())),
                )
            };

        let scheduler = Arc::new(Scheduler {
            request_queue,
            visited_urls,
            tx_internal,
            pending_requests,
            salvaged_requests,
        });

        let scheduler_clone = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler_clone.run_loop(rx_internal, tx_req_out).await;
        });

        (scheduler, rx_req_out)
    }

    /// Core event loop: pops queued requests and forwards them to the crawler
    /// while concurrently servicing control messages.
    async fn run_loop(
        &self,
        rx_internal: AsyncReceiver<SchedulerMessage>,
        tx_req_out: AsyncSender<Request>,
    ) {
        info!("Scheduler run_loop started.");
        loop {
            let maybe_request = if !tx_req_out.is_closed() && !self.is_idle() {
                self.request_queue.lock().await.pop_front()
            } else {
                None
            };

            if let Some(request) = maybe_request {
                tokio::select! {
                    biased;
                    // A clone is sent so the original can be re-queued if this
                    // branch is cancelled by an incoming control message.
                    send_res = tx_req_out.send(request.clone()) => {
                        if send_res.is_err() {
                            error!("Crawler receiver dropped. Scheduler can no longer send requests.");
                        }
                        self.pending_requests.fetch_sub(1, Ordering::SeqCst);
                    },
                    recv_res = rx_internal.recv() => {
                        // The cancelled send future drops its payload, so put
                        // the request back at the front of the queue rather
                        // than losing it. It is still pending, so the counter
                        // stays as-is (at-least-once delivery).
                        self.request_queue.lock().await.push_front(request);
                        if !self.handle_message(recv_res).await {
                            break;
                        }
                    }
                }
            } else if !self.handle_message(rx_internal.recv().await).await {
                break;
            }
        }
        info!("Scheduler run_loop finished.");
    }

    /// Processes one control message; returns `false` when the run loop
    /// should exit.
    async fn handle_message(&self, msg: Result<SchedulerMessage, kanal::ReceiveError>) -> bool {
        match msg {
            Ok(SchedulerMessage::Enqueue(boxed_request)) => {
                let request = *boxed_request;
                self.request_queue.lock().await.push_back(request);
                self.pending_requests.fetch_add(1, Ordering::SeqCst);
                true
            }
            Ok(SchedulerMessage::MarkAsVisited(fingerprint)) => {
                debug!("Marked URL as visited: {}", fingerprint);
                self.visited_urls.insert(fingerprint);
                true
            }
            Ok(SchedulerMessage::TakeSnapshot(responder)) => {
                let visited_urls = self.visited_urls.iter().map(|item| item.clone()).collect();
                let request_queue = self.request_queue.lock().await.clone();
                let salvaged_requests = self.salvaged_requests.lock().await.clone();

                // The requester may have given up waiting; a dropped receiver
                // is not an error worth propagating.
                let _ = responder.send(SchedulerCheckpoint {
                    request_queue,
                    visited_urls,
                    salvaged_requests,
                });
                true
            }
            Ok(SchedulerMessage::Shutdown) | Err(_) => {
                info!("Scheduler received shutdown signal or channel closed. Exiting run_loop.");
                false
            }
        }
    }

    /// Asks the run loop for a point-in-time checkpoint of the queue, the
    /// visited set, and any salvaged requests.
    pub async fn snapshot(&self) -> Result<SchedulerCheckpoint, SpiderError> {
        let (tx, rx) = oneshot::channel();
        self.tx_internal
            .send(SchedulerMessage::TakeSnapshot(tx))
            .await
            .map_err(|e| {
                SpiderError::GeneralError(format!(
                    "Scheduler: Failed to send snapshot request: {}",
                    e
                ))
            })?;
        rx.await.map_err(|e| {
            SpiderError::GeneralError(format!("Scheduler: Failed to receive snapshot: {}", e))
        })
    }

    /// Queues a request for dispatch. If the internal channel is closed, the
    /// request is salvaged so it can survive into the next checkpoint.
    pub async fn enqueue_request(&self, request: Request) -> Result<(), SpiderError> {
        // The clone lets us keep the request for salvage if the send fails,
        // since the send error does not hand the message back.
        if self
            .tx_internal
            .send(SchedulerMessage::Enqueue(Box::new(request.clone())))
            .await
            .is_err()
        {
            error!("Scheduler internal message channel is closed. Salvaging request.");
            self.salvaged_requests.lock().await.push_back(request);
            return Err(SpiderError::GeneralError(
                "Scheduler internal channel closed, request salvaged.".into(),
            ));
        }
        Ok(())
    }

    /// Signals the run loop to exit.
    pub async fn shutdown(&self) -> Result<(), SpiderError> {
        self.tx_internal
            .send(SchedulerMessage::Shutdown)
            .await
            .map_err(|e| {
                SpiderError::GeneralError(format!(
                    "Scheduler: Failed to send shutdown signal: {}",
                    e
                ))
            })
    }

    /// Records a URL fingerprint as visited via the run loop.
    pub async fn send_mark_as_visited(&self, fingerprint: String) -> Result<(), SpiderError> {
        self.tx_internal
            .send(SchedulerMessage::MarkAsVisited(fingerprint))
            .await
            .map_err(|e| {
                SpiderError::GeneralError(format!(
                    "Scheduler: Failed to send MarkAsVisited message: {}",
                    e
                ))
            })
    }

    /// Number of requests that are queued or salvaged but not yet dispatched.
    #[inline]
    pub fn len(&self) -> usize {
        self.pending_requests.load(Ordering::SeqCst)
    }

    /// Returns `true` when no requests are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Alias for `is_empty`: the scheduler is idle when nothing is pending.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.is_empty()
    }
}
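
// A minimal usage sketch of the scheduler's lifecycle: enqueue, receive on
// the output channel, then shut down. `Request::new` is a hypothetical
// constructor used only for illustration; substitute whatever constructor
// `crate::request::Request` actually exposes.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn dispatches_enqueued_request_then_shuts_down() {
        let (scheduler, rx_requests) = Scheduler::new(None);

        // Hypothetical constructor; adjust to the real `Request` API.
        let request = Request::new("https://example.com");
        scheduler
            .enqueue_request(request)
            .await
            .expect("internal channel should be open");

        // The run loop forwards the queued request to the bounded output
        // channel, where a crawler worker would normally pick it up.
        let _received = rx_requests.recv().await.expect("request should arrive");

        scheduler
            .shutdown()
            .await
            .expect("shutdown signal should send");
    }
}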