// argus_frontier/stream.rs

1#[cfg(feature = "redis")]
2use anyhow::{Context, Result};
3#[cfg(feature = "redis")]
4use argus_common::types::CrawlJob;
5#[cfg(feature = "redis")]
6use async_trait::async_trait;
7#[cfg(feature = "redis")]
8use redis::{aio::ConnectionManager, AsyncCommands, RedisError};
9#[cfg(feature = "redis")]
10use std::collections::HashMap;
11
#[cfg(feature = "redis")]
/// Decoded shape used for XAUTOCLAIM replies in `claim_abandoned`.
///
/// NOTE(review): XAUTOCLAIM documents its reply as `[next-cursor, entries,
/// (Redis 7.0+) deleted-ids]`, not the map-of-streams shape XREADGROUP uses.
/// Decoding into this alias looks mismatched — verify against the redis-rs
/// version in use before relying on `claim_abandoned`.
type XAutoClaimResult = Vec<HashMap<String, Vec<(String, HashMap<String, String>)>>>;
14
15#[cfg(feature = "redis")]
16use crate::frontier::Frontier;
17
/// Redis Streams-based frontier for high-throughput job distribution
/// Provides better backpressure handling and consumer groups
#[cfg(feature = "redis")]
pub struct StreamFrontier {
    /// Multiplexed Redis connection; cheap to clone (see `Frontier` impl).
    conn: ConnectionManager,
    /// Stream key jobs are XADDed to (default: "argus:jobs").
    stream_key: String,
    /// Consumer-group name used by XREADGROUP/XACK/XAUTOCLAIM (default: "workers").
    consumer_group: String,
    /// Unique name identifying this consumer within the group.
    consumer_name: String,
    /// Max entries requested per XAUTOCLAIM batch (default: 10).
    batch_size: usize,
}
28
#[cfg(feature = "redis")]
impl StreamFrontier {
    /// Create a new Redis Streams frontier
    ///
    /// Connects to `redis_url`, then ensures the consumer group exists
    /// (creating an empty stream via MKSTREAM if necessary).
    ///
    /// # Arguments
    /// * `redis_url` - Redis connection URL
    /// * `stream_key` - Stream name (default: "argus:jobs")
    /// * `consumer_group` - Consumer group name (default: "workers")
    /// * `consumer_name` - Unique consumer identifier
    ///
    /// # Errors
    /// Fails if the client cannot be built, the connection cannot be
    /// established, or XGROUP CREATE fails for any reason other than the
    /// group already existing.
    pub async fn new(
        redis_url: &str,
        stream_key: Option<String>,
        consumer_group: Option<String>,
        consumer_name: String,
    ) -> Result<Self> {
        let client = redis::Client::open(redis_url)?;
        let conn = ConnectionManager::new(client).await?;

        let stream_key = stream_key.unwrap_or_else(|| "argus:jobs".to_string());
        let consumer_group = consumer_group.unwrap_or_else(|| "workers".to_string());

        let mut frontier = Self {
            conn,
            stream_key,
            consumer_group,
            consumer_name,
            batch_size: 10,
        };

        // Create consumer group if it doesn't exist
        frontier.ensure_consumer_group().await?;

        Ok(frontier)
    }

    /// Set batch size for reading from stream
    /// (used as the COUNT argument to XAUTOCLAIM in `claim_abandoned`).
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Idempotently create the consumer group.
    ///
    /// Starts the group at ID "0" so it observes the entire stream, and
    /// passes MKSTREAM so a missing stream is created empty. A BUSYGROUP
    /// error (group already exists) is treated as success.
    async fn ensure_consumer_group(&mut self) -> Result<()> {
        let result: Result<String, RedisError> = redis::cmd("XGROUP")
            .arg("CREATE")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .arg("0")
            .arg("MKSTREAM")
            .query_async(&mut self.conn)
            .await;

        match result {
            Ok(_) => Ok(()),
            Err(e) => {
                // Ignore "BUSYGROUP" error (group already exists)
                if e.to_string().contains("BUSYGROUP") {
                    Ok(())
                } else {
                    Err(e.into())
                }
            }
        }
    }

    /// Acknowledge a job as processed.
    ///
    /// XACK removes `message_id` from the group's pending entries list so
    /// it will no longer be redelivered or claimable.
    pub async fn ack(&mut self, message_id: &str) -> Result<()> {
        let _: i64 = redis::cmd("XACK")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .arg(message_id)
            .query_async(&mut self.conn)
            .await
            .context("Failed to acknowledge message")?;

        Ok(())
    }

    /// Get the pending-message count for the whole consumer group.
    ///
    /// NOTE(review): the summary form of XPENDING used here takes no
    /// consumer argument, so this is a group-wide total, not a count for
    /// this consumer specifically. The first element of the summary reply
    /// is the total pending count; any other reply shape yields 0.
    pub async fn pending_count(&mut self) -> Result<usize> {
        let result: Vec<redis::Value> = redis::cmd("XPENDING")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .query_async(&mut self.conn)
            .await?;

        if let Some(redis::Value::Int(count)) = result.first() {
            Ok(*count as usize)
        } else {
            Ok(0)
        }
    }

    /// Get stream length (XLEN): total entries currently in the stream,
    /// including ones already delivered to consumers but not trimmed.
    pub async fn stream_len(&mut self) -> Result<usize> {
        let len: usize = self.conn.xlen(&self.stream_key).await?;
        Ok(len)
    }

    /// Claim abandoned messages (from dead consumers)
    ///
    /// Uses XAUTOCLAIM from cursor "0-0" to take ownership of entries that
    /// have been pending longer than `idle_time_ms`, up to `batch_size` per
    /// call. Entries whose "job" field is missing or fails to deserialize
    /// are silently skipped.
    ///
    /// NOTE(review): XAUTOCLAIM replies as `[next-cursor, entries,
    /// (Redis 7.0+) deleted-ids]`; decoding into `XAutoClaimResult` (a
    /// map-of-streams shape like XREADGROUP's) looks mismatched and may
    /// fail to decode at runtime — verify against the redis-rs version in
    /// use, ideally with an integration test against a live Redis.
    pub async fn claim_abandoned(
        &mut self,
        idle_time_ms: usize,
    ) -> Result<Vec<(String, CrawlJob)>> {
        let result: XAutoClaimResult = redis::cmd("XAUTOCLAIM")
            .arg(&self.stream_key)
            .arg(&self.consumer_group)
            .arg(&self.consumer_name)
            .arg(idle_time_ms)
            .arg("0-0")
            .arg("COUNT")
            .arg(self.batch_size)
            .query_async(&mut self.conn)
            .await?;

        let mut jobs = Vec::new();
        for entry in result {
            for (_, messages) in entry {
                for (msg_id, fields) in messages {
                    if let Some(job_json) = fields.get("job") {
                        if let Ok(job) = serde_json::from_str::<CrawlJob>(job_json) {
                            jobs.push((msg_id, job));
                        }
                    }
                }
            }
        }

        Ok(jobs)
    }
}
159
#[cfg(feature = "redis")]
#[async_trait]
impl Frontier for StreamFrontier {
    /// Append a job to the stream via XADD (auto-generated "*" ID).
    ///
    /// The `Frontier` trait offers no way to report failure, so errors are
    /// logged to stderr — matching how serialization failures were already
    /// reported — instead of being silently discarded (previously a failed
    /// XADD dropped the job with no trace).
    async fn push(&self, job: CrawlJob) {
        let job_json = match serde_json::to_string(&job) {
            Ok(json) => json,
            Err(e) => {
                eprintln!("Failed to serialize job: {}", e);
                return;
            }
        };

        // ConnectionManager is designed to be cloned per use; the clone
        // shares the underlying multiplexed connection.
        let mut conn = self.conn.clone();
        let result: Result<String, RedisError> = conn
            .xadd(&self.stream_key, "*", &[("job", job_json.as_str())])
            .await;
        if let Err(e) = result {
            // Surface dropped jobs instead of swallowing the error.
            eprintln!("Failed to enqueue job: {}", e);
        }
    }

    /// Pop the next new job for this consumer via XREADGROUP, blocking up
    /// to 1 second. Returns `None` on timeout, Redis error, or when the
    /// payload cannot be deserialized.
    async fn pop(&self) -> Option<CrawlJob> {
        let mut conn = self.conn.clone();

        // Read one new (">") message on behalf of this group member.
        let result: Result<
            Vec<HashMap<String, Vec<(String, HashMap<String, String>)>>>,
            RedisError,
        > = redis::cmd("XREADGROUP")
            .arg("GROUP")
            .arg(&self.consumer_group)
            .arg(&self.consumer_name)
            .arg("COUNT")
            .arg(1)
            .arg("BLOCK")
            .arg(1000) // 1 second timeout
            .arg("STREAMS")
            .arg(&self.stream_key)
            .arg(">")
            .query_async(&mut conn)
            .await;

        match result {
            Ok(streams) => {
                for stream in streams {
                    for (_, messages) in stream {
                        for (_msg_id, fields) in messages {
                            if let Some(job_json) = fields.get("job") {
                                if let Ok(job) = serde_json::from_str::<CrawlJob>(job_json) {
                                    // NOTE: the message ID is discarded here, so the
                                    // entry remains in the pending entries list until
                                    // claimed or acked elsewhere. In production, track
                                    // the ID and call `ack` after the job completes.
                                    return Some(job);
                                }
                            }
                        }
                    }
                }
                None
            }
            Err(e) => {
                // A BLOCK timeout normally decodes as an empty reply, so an
                // Err here is presumably a real transport/protocol error —
                // log it rather than hiding it (TODO confirm nil-on-timeout
                // decode behavior for the redis-rs version in use).
                eprintln!("XREADGROUP failed: {}", e);
                None
            }
        }
    }
}
220
221#[cfg(not(feature = "redis"))]
222pub struct StreamFrontier;
223
224#[cfg(not(feature = "redis"))]
225impl StreamFrontier {
226    pub async fn new(
227        _redis_url: &str,
228        _stream_key: Option<String>,
229        _consumer_group: Option<String>,
230        _consumer_name: String,
231    ) -> anyhow::Result<Self> {
232        anyhow::bail!("Redis Streams not enabled. Compile with 'redis' feature.")
233    }
234}
235
#[cfg(all(test, feature = "redis"))]
mod tests {
    use super::*;

    const REDIS_URL: &str = "redis://localhost:6379";

    /// Delete a stream key so each test starts from a clean slate.
    ///
    /// Without this, re-running a test against the same Redis instance
    /// accumulates entries from previous runs, which breaks the length
    /// assertion in `stream_frontier_stats` and leaves stale pending
    /// state behind. DEL errors are ignored (key may not exist yet).
    async fn reset_stream(key: &str) {
        let client = redis::Client::open(REDIS_URL).unwrap();
        let mut conn = ConnectionManager::new(client).await.unwrap();
        let _: Result<i64, RedisError> = redis::cmd("DEL")
            .arg(key)
            .query_async(&mut conn)
            .await;
    }

    /// Build a minimal CrawlJob for the given URL.
    fn make_job(url: &str) -> CrawlJob {
        CrawlJob {
            url: url.to_string(),
            normalized_url: url.to_string(),
            host: "example.com".to_string(),
            depth: 0,
        }
    }

    #[tokio::test]
    #[ignore] // Requires Redis
    async fn stream_frontier_basic() {
        reset_stream("test:stream").await;

        let frontier = StreamFrontier::new(
            REDIS_URL,
            Some("test:stream".to_string()),
            Some("test:group".to_string()),
            "consumer1".to_string(),
        )
        .await
        .unwrap();

        let job = make_job("https://example.com");
        frontier.push(job.clone()).await;

        let popped = frontier.pop().await;
        assert!(popped.is_some());
        assert_eq!(popped.unwrap().url, job.url);
    }

    #[tokio::test]
    #[ignore] // Requires Redis
    async fn stream_frontier_consumer_groups() {
        reset_stream("test:stream2").await;

        let consumer1 = StreamFrontier::new(
            REDIS_URL,
            Some("test:stream2".to_string()),
            Some("test:group2".to_string()),
            "consumer1".to_string(),
        )
        .await
        .unwrap();

        let consumer2 = StreamFrontier::new(
            REDIS_URL,
            Some("test:stream2".to_string()),
            Some("test:group2".to_string()),
            "consumer2".to_string(),
        )
        .await
        .unwrap();

        // Push multiple jobs
        for i in 0..10 {
            consumer1
                .push(make_job(&format!("https://example.com/{}", i)))
                .await;
        }

        // Two members of the same group should receive distinct messages.
        let job1 = consumer1.pop().await;
        let job2 = consumer2.pop().await;

        assert!(job1.is_some());
        assert!(job2.is_some());
        assert_ne!(job1.unwrap().url, job2.unwrap().url);
    }

    #[tokio::test]
    #[ignore] // Requires Redis
    async fn stream_frontier_stats() {
        reset_stream("test:stream3").await;

        let mut frontier = StreamFrontier::new(
            REDIS_URL,
            Some("test:stream3".to_string()),
            Some("test:group3".to_string()),
            "consumer1".to_string(),
        )
        .await
        .unwrap();

        // Push jobs, then verify XLEN sees exactly what we pushed.
        for i in 0..5 {
            frontier
                .push(make_job(&format!("https://example.com/{}", i)))
                .await;
        }

        let len = frontier.stream_len().await.unwrap();
        assert_eq!(len, 5);
    }
}