1#[cfg(feature = "redis")]
2use anyhow::{Context, Result};
3#[cfg(feature = "redis")]
4use argus_common::types::CrawlJob;
5#[cfg(feature = "redis")]
6use async_trait::async_trait;
7#[cfg(feature = "redis")]
8use redis::{aio::ConnectionManager, AsyncCommands, RedisError};
9#[cfg(feature = "redis")]
10use std::collections::HashMap;
11
#[cfg(feature = "redis")]
// Decoded shape used for stream-read style replies:
// stream-name -> list of (entry-id, field-name -> field-value) pairs.
// NOTE(review): XAUTOCLAIM actually replies with [next-cursor, entries,
// deleted-ids] — confirm this decode against a live server.
type XAutoClaimResult = Vec<HashMap<String, Vec<(String, HashMap<String, String>)>>>;
14
15#[cfg(feature = "redis")]
16use crate::frontier::Frontier;
17
#[cfg(feature = "redis")]
/// Crawl-job frontier backed by a Redis Stream with a consumer group,
/// so multiple worker processes can share one job queue.
pub struct StreamFrontier {
    conn: ConnectionManager,  // cloneable, auto-reconnecting connection handle
    stream_key: String,       // Redis stream key holding serialized jobs
    consumer_group: String,   // consumer group shared by all workers
    consumer_name: String,    // this worker's unique name within the group
    batch_size: usize,        // max entries claimed per XAUTOCLAIM call
}
28
#[cfg(feature = "redis")]
impl StreamFrontier {
    /// Connects to Redis and ensures the consumer group exists (the
    /// stream itself is created on demand via MKSTREAM).
    ///
    /// `stream_key` defaults to `"argus:jobs"` and `consumer_group` to
    /// `"workers"` when `None` is supplied.
    ///
    /// # Errors
    /// Fails if the URL is invalid, the connection cannot be
    /// established, or the consumer group cannot be created.
    pub async fn new(
        redis_url: &str,
        stream_key: Option<String>,
        consumer_group: Option<String>,
        consumer_name: String,
    ) -> Result<Self> {
        let client = redis::Client::open(redis_url)?;
        let conn = ConnectionManager::new(client).await?;

        let stream_key = stream_key.unwrap_or_else(|| "argus:jobs".to_string());
        let consumer_group = consumer_group.unwrap_or_else(|| "workers".to_string());

        let mut frontier = Self {
            conn,
            stream_key,
            consumer_group,
            consumer_name,
            batch_size: 10,
        };

        frontier.ensure_consumer_group().await?;

        Ok(frontier)
    }

    /// Builder-style override of the XAUTOCLAIM batch size (default: 10).
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Creates the consumer group idempotently: an already-existing
    /// group (BUSYGROUP) is treated as success, so multiple workers can
    /// start up concurrently against the same stream.
    async fn ensure_consumer_group(&mut self) -> Result<()> {
        let result: Result<String, RedisError> = redis::cmd("XGROUP")
            .arg("CREATE")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .arg("0") // deliver the full stream history to the group
            .arg("MKSTREAM")
            .query_async(&mut self.conn)
            .await;

        match result {
            Ok(_) => Ok(()),
            // Match the structured error code rather than substring-searching
            // the Display output, which is fragile across redis-rs versions.
            Err(e) if e.code() == Some("BUSYGROUP") => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// Acknowledges a processed message (XACK), removing it from the
    /// group's pending entries list.
    pub async fn ack(&mut self, message_id: &str) -> Result<()> {
        let _: i64 = redis::cmd("XACK")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .arg(message_id)
            .query_async(&mut self.conn)
            .await
            .context("Failed to acknowledge message")?;

        Ok(())
    }

    /// Number of delivered-but-unacknowledged messages for the group
    /// (first field of the XPENDING summary reply); returns 0 when the
    /// reply is empty or has an unexpected shape.
    pub async fn pending_count(&mut self) -> Result<usize> {
        let result: Vec<redis::Value> = redis::cmd("XPENDING")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .query_async(&mut self.conn)
            .await?;

        if let Some(redis::Value::Int(count)) = result.first() {
            Ok(*count as usize)
        } else {
            Ok(0)
        }
    }

    /// Total number of entries currently in the stream (XLEN),
    /// regardless of delivery or acknowledgement state.
    pub async fn stream_len(&mut self) -> Result<usize> {
        let len: usize = self.conn.xlen(&self.stream_key).await?;
        Ok(len)
    }

    /// Claims up to `batch_size` messages that have been pending longer
    /// than `idle_time_ms` (i.e. abandoned by a crashed or stalled
    /// consumer) and returns them as `(message_id, job)` pairs. Entries
    /// whose payload fails to deserialize are skipped silently.
    pub async fn claim_abandoned(
        &mut self,
        idle_time_ms: usize,
    ) -> Result<Vec<(String, CrawlJob)>> {
        // NOTE(review): XAUTOCLAIM replies with [next-cursor, entries,
        // deleted-ids] (three elements on Redis >= 7.0), which does not
        // obviously decode into the XREADGROUP-like `XAutoClaimResult`
        // shape — verify this against a live server.
        let result: XAutoClaimResult = redis::cmd("XAUTOCLAIM")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .arg(&self.consumer_name)
            .arg(idle_time_ms)
            .arg("0-0") // scan the pending entries list from the start
            .arg("COUNT")
            .arg(self.batch_size)
            .query_async(&mut self.conn)
            .await?;

        let mut jobs = Vec::new();
        for entry in result {
            for (_, messages) in entry {
                for (msg_id, fields) in messages {
                    if let Some(job_json) = fields.get("job") {
                        if let Ok(job) = serde_json::from_str::<CrawlJob>(job_json) {
                            jobs.push((msg_id, job));
                        }
                    }
                }
            }
        }

        Ok(jobs)
    }
}
159
#[cfg(feature = "redis")]
#[async_trait]
impl Frontier for StreamFrontier {
    /// Serializes the job and appends it to the stream (XADD). The trait
    /// signature returns nothing, so failures are logged to stderr
    /// rather than propagated.
    async fn push(&self, job: CrawlJob) {
        let job_json = match serde_json::to_string(&job) {
            Ok(json) => json,
            Err(e) => {
                eprintln!("Failed to serialize job: {}", e);
                return;
            }
        };

        // Cloning ConnectionManager yields another handle to the shared
        // connection, letting us issue a command behind the trait's `&self`.
        let mut conn = self.conn.clone();
        let result: Result<String, RedisError> = conn
            .xadd(&self.stream_key, "*", &[("job", job_json.as_str())])
            .await;
        // Previously this error was silently discarded; surface it so
        // dropped jobs are at least visible in the logs.
        if let Err(e) = result {
            eprintln!("Failed to push job to stream: {}", e);
        }
    }

    /// Pops one not-yet-delivered job via XREADGROUP (`>` cursor),
    /// blocking for up to 1000 ms. Returns `None` on timeout, on a Redis
    /// error, or when the payload fails to deserialize.
    ///
    /// NOTE(review): the message id is discarded here and the entry is
    /// never XACKed, so every popped job stays in the pending entries
    /// list until reclaimed — confirm this is the intended at-least-once
    /// flow (pairing with `claim_abandoned`/`ack`).
    async fn pop(&self) -> Option<CrawlJob> {
        let mut conn = self.conn.clone();

        let result: Result<
            Vec<HashMap<String, Vec<(String, HashMap<String, String>)>>>,
            RedisError,
        > = redis::cmd("XREADGROUP")
            .arg("GROUP")
            .arg(&self.consumer_group)
            .arg(&self.consumer_name)
            .arg("COUNT")
            .arg(1)
            .arg("BLOCK")
            .arg(1000)
            .arg("STREAMS")
            .arg(&self.stream_key)
            .arg(">") // only entries never delivered to this group
            .query_async(&mut conn)
            .await;

        match result {
            Ok(streams) => {
                for stream in streams {
                    for (_, messages) in stream {
                        for (_msg_id, fields) in messages {
                            if let Some(job_json) = fields.get("job") {
                                if let Ok(job) = serde_json::from_str::<CrawlJob>(job_json) {
                                    return Some(job);
                                }
                            }
                        }
                    }
                }
                None
            }
            Err(_) => None,
        }
    }
}
220
#[cfg(not(feature = "redis"))]
/// Stub used when the `redis` feature is disabled; construction always fails.
pub struct StreamFrontier;
223
224#[cfg(not(feature = "redis"))]
225impl StreamFrontier {
226 pub async fn new(
227 _redis_url: &str,
228 _stream_key: Option<String>,
229 _consumer_group: Option<String>,
230 _consumer_name: String,
231 ) -> anyhow::Result<Self> {
232 anyhow::bail!("Redis Streams not enabled. Compile with 'redis' feature.")
233 }
234}
235
#[cfg(all(test, feature = "redis"))]
mod tests {
    use super::*;

    /// Deletes the test stream and re-creates the consumer group so each
    /// run starts from a clean slate. Without this, entries left over
    /// from a previous run break the count assertions (e.g.
    /// `stream_len == 5`) and pollute pop results.
    async fn reset(frontier: &mut StreamFrontier) {
        let _: i64 = redis::cmd("DEL")
            .arg(&frontier.stream_key)
            .query_async(&mut frontier.conn)
            .await
            .unwrap();
        // DEL also removed the consumer group; re-create it.
        frontier.ensure_consumer_group().await.unwrap();
    }

    // All tests require a live Redis at localhost:6379, hence #[ignore].
    #[tokio::test]
    #[ignore]
    async fn stream_frontier_basic() {
        let mut frontier = StreamFrontier::new(
            "redis://localhost:6379",
            Some("test:stream".to_string()),
            Some("test:group".to_string()),
            "consumer1".to_string(),
        )
        .await
        .unwrap();
        reset(&mut frontier).await;

        let job = CrawlJob {
            url: "https://example.com".to_string(),
            normalized_url: "https://example.com".to_string(),
            host: "example.com".to_string(),
            depth: 0,
        };

        frontier.push(job.clone()).await;

        let popped = frontier.pop().await;
        assert!(popped.is_some());
        assert_eq!(popped.unwrap().url, job.url);
    }

    #[tokio::test]
    #[ignore]
    async fn stream_frontier_consumer_groups() {
        let mut consumer1 = StreamFrontier::new(
            "redis://localhost:6379",
            Some("test:stream2".to_string()),
            Some("test:group2".to_string()),
            "consumer1".to_string(),
        )
        .await
        .unwrap();
        reset(&mut consumer1).await;

        // The group already exists at this point; BUSYGROUP is tolerated.
        let consumer2 = StreamFrontier::new(
            "redis://localhost:6379",
            Some("test:stream2".to_string()),
            Some("test:group2".to_string()),
            "consumer2".to_string(),
        )
        .await
        .unwrap();

        for i in 0..10 {
            let job = CrawlJob {
                url: format!("https://example.com/{}", i),
                normalized_url: format!("https://example.com/{}", i),
                host: "example.com".to_string(),
                depth: 0,
            };
            consumer1.push(job).await;
        }

        // Two consumers in the same group must receive distinct messages.
        let job1 = consumer1.pop().await;
        let job2 = consumer2.pop().await;

        assert!(job1.is_some());
        assert!(job2.is_some());
        assert_ne!(job1.unwrap().url, job2.unwrap().url);
    }

    #[tokio::test]
    #[ignore]
    async fn stream_frontier_stats() {
        let mut frontier = StreamFrontier::new(
            "redis://localhost:6379",
            Some("test:stream3".to_string()),
            Some("test:group3".to_string()),
            "consumer1".to_string(),
        )
        .await
        .unwrap();
        reset(&mut frontier).await;

        for i in 0..5 {
            let job = CrawlJob {
                url: format!("https://example.com/{}", i),
                normalized_url: format!("https://example.com/{}", i),
                host: "example.com".to_string(),
                depth: 0,
            };
            frontier.push(job).await;
        }

        let len = frontier.stream_len().await.unwrap();
        assert_eq!(len, 5);
    }
}