Skip to main content

oxirs_stream/
utils.rs

1//! # Stream Utilities
2//!
3//! Utility functions and helpers for common stream operations.
4
5use crate::event::StreamEvent;
6use crate::{Stream, StreamConfig};
7use anyhow::Result;
8use std::time::Duration;
9
10/// Batch processor for processing multiple events efficiently
11pub struct BatchProcessor {
12    batch_size: usize,
13    timeout: Duration,
14}
15
16impl BatchProcessor {
17    /// Create a new batch processor
18    pub fn new(batch_size: usize, timeout: Duration) -> Self {
19        Self {
20            batch_size,
21            timeout,
22        }
23    }
24
25    /// Process events in batches with a callback
26    pub async fn process<F, Fut>(&self, stream: &mut Stream, mut callback: F) -> Result<usize>
27    where
28        F: FnMut(Vec<StreamEvent>) -> Fut,
29        Fut: std::future::Future<Output = Result<()>>,
30    {
31        let mut batch = Vec::with_capacity(self.batch_size);
32        let mut total_processed = 0;
33        let start = tokio::time::Instant::now();
34
35        loop {
36            match tokio::time::timeout(self.timeout, stream.consume()).await {
37                Ok(Ok(Some(event))) => {
38                    batch.push(event);
39
40                    if batch.len() >= self.batch_size {
41                        callback(std::mem::take(&mut batch)).await?;
42                        total_processed += self.batch_size;
43                    }
44                }
45                Ok(Ok(None)) => {
46                    // No more events, process remaining batch
47                    if !batch.is_empty() {
48                        let count = batch.len();
49                        callback(std::mem::take(&mut batch)).await?;
50                        total_processed += count;
51                    }
52                    break;
53                }
54                Ok(Err(e)) => {
55                    return Err(e);
56                }
57                Err(_) => {
58                    // Timeout - process what we have
59                    if !batch.is_empty() {
60                        let count = batch.len();
61                        callback(std::mem::take(&mut batch)).await?;
62                        total_processed += count;
63                    }
64
65                    // Check if we should continue or stop
66                    if start.elapsed() > self.timeout * 2 {
67                        break;
68                    }
69                }
70            }
71        }
72
73        Ok(total_processed)
74    }
75}
76
77/// Type alias for event predicate functions
78type EventPredicate = Box<dyn Fn(&StreamEvent) -> bool + Send + Sync>;
79
80/// Event filter builder for creating complex event filters
81pub struct EventFilter {
82    predicates: Vec<EventPredicate>,
83}
84
85impl EventFilter {
86    /// Create a new event filter
87    pub fn new() -> Self {
88        Self {
89            predicates: Vec::new(),
90        }
91    }
92
93    /// Add a predicate to the filter
94    pub fn add_predicate<F>(mut self, predicate: F) -> Self
95    where
96        F: Fn(&StreamEvent) -> bool + Send + Sync + 'static,
97    {
98        self.predicates.push(Box::new(predicate));
99        self
100    }
101
102    /// Filter events by subject pattern
103    pub fn by_subject(self, pattern: String) -> Self {
104        self.add_predicate(move |event| match event {
105            StreamEvent::TripleAdded { subject, .. } => subject.contains(&pattern),
106            StreamEvent::TripleRemoved { subject, .. } => subject.contains(&pattern),
107            _ => false,
108        })
109    }
110
111    /// Filter events by predicate pattern
112    pub fn by_predicate(self, pattern: String) -> Self {
113        self.add_predicate(move |event| match event {
114            StreamEvent::TripleAdded { predicate, .. } => predicate.contains(&pattern),
115            StreamEvent::TripleRemoved { predicate, .. } => predicate.contains(&pattern),
116            _ => false,
117        })
118    }
119
120    /// Filter events by graph
121    pub fn by_graph(self, graph_name: String) -> Self {
122        self.add_predicate(move |event| match event {
123            StreamEvent::TripleAdded { graph, .. } => {
124                graph.as_ref().is_some_and(|g| g == &graph_name)
125            }
126            StreamEvent::TripleRemoved { graph, .. } => {
127                graph.as_ref().is_some_and(|g| g == &graph_name)
128            }
129            _ => false,
130        })
131    }
132
133    /// Test if an event matches all predicates
134    pub fn matches(&self, event: &StreamEvent) -> bool {
135        self.predicates.iter().all(|predicate| predicate(event))
136    }
137
138    /// Filter a batch of events
139    pub fn filter_batch(&self, events: Vec<StreamEvent>) -> Vec<StreamEvent> {
140        events.into_iter().filter(|e| self.matches(e)).collect()
141    }
142}
143
144impl Default for EventFilter {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
/// Stream statistics aggregator.
///
/// Counters saturate at `u64::MAX` instead of overflowing.
#[derive(Debug, Clone, Default)]
pub struct StreamStats {
    /// Number of events recorded via [`StreamStats::record_event`].
    pub total_events: u64,
    /// Events per second since `start_time`, refreshed on every recorded event.
    pub events_per_second: f64,
    /// Integer mean of the recorded event sizes (`total_bytes / total_events`).
    pub avg_event_size: u64,
    /// Sum of all recorded event sizes.
    pub total_bytes: u64,
    /// Number of errors recorded via [`StreamStats::record_error`].
    pub error_count: u64,
    /// Moment the aggregator was created; `None` when built through `Default`,
    /// in which case `events_per_second` is never updated.
    pub start_time: Option<std::time::Instant>,
}

impl StreamStats {
    /// Create a new stream statistics aggregator whose clock starts now.
    pub fn new() -> Self {
        Self {
            start_time: Some(std::time::Instant::now()),
            ..Default::default()
        }
    }

    /// Record one event of `event_size` bytes and refresh the derived figures.
    pub fn record_event(&mut self, event_size: u64) {
        // Saturate rather than overflow: an oversized `event_size` must not
        // panic (debug builds) or wrap the totals (release builds).
        self.total_events = self.total_events.saturating_add(1);
        self.total_bytes = self.total_bytes.saturating_add(event_size);

        if let Some(start) = self.start_time {
            let elapsed = start.elapsed().as_secs_f64();
            // Skip the update when no measurable time has passed to avoid
            // dividing by zero.
            if elapsed > 0.0 {
                self.events_per_second = self.total_events as f64 / elapsed;
            }
        }

        // `total_events` is at least 1 here, so the division is always defined.
        self.avg_event_size = self.total_bytes / self.total_events;
    }

    /// Record an error.
    pub fn record_error(&mut self) {
        self.error_count = self.error_count.saturating_add(1);
    }

    /// Get the error rate as `error_count / total_events`.
    ///
    /// Returns `0.0` when no events have been recorded.
    pub fn error_rate(&self) -> f64 {
        if self.total_events == 0 {
            return 0.0;
        }
        self.error_count as f64 / self.total_events as f64
    }

    /// Reset all statistics and restart the clock.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
205
206/// Stream multiplexer for consuming from multiple streams
207pub struct StreamMultiplexer {
208    streams: Vec<Stream>,
209}
210
211impl StreamMultiplexer {
212    /// Create a new stream multiplexer
213    pub fn new(streams: Vec<Stream>) -> Self {
214        Self { streams }
215    }
216
217    /// Consume from all streams round-robin
218    pub async fn consume_round_robin(&mut self) -> Result<Option<StreamEvent>> {
219        for stream in &mut self.streams {
220            if let Some(event) = stream.consume().await? {
221                return Ok(Some(event));
222            }
223        }
224        Ok(None)
225    }
226
227    /// Consume from all streams in parallel and return the first available event
228    pub async fn consume_first_available(&mut self) -> Result<Option<StreamEvent>> {
229        use futures::future::select_all;
230
231        let futures: Vec<_> = self
232            .streams
233            .iter_mut()
234            .map(|stream| Box::pin(stream.consume()))
235            .collect();
236
237        if futures.is_empty() {
238            return Ok(None);
239        }
240
241        let (result, _index, _remaining) = select_all(futures).await;
242        result
243    }
244
245    /// Get the number of streams
246    pub fn len(&self) -> usize {
247        self.streams.len()
248    }
249
250    /// Check if the multiplexer is empty
251    pub fn is_empty(&self) -> bool {
252        self.streams.is_empty()
253    }
254}
255
256/// Helper to create a stream with sensible defaults for development
257pub async fn create_dev_stream(topic: &str) -> Result<Stream> {
258    let config = StreamConfig::development(topic);
259    Stream::new(config).await
260}
261
262/// Helper to create a stream with production settings
263pub async fn create_prod_stream(topic: &str) -> Result<Stream> {
264    let config = StreamConfig::production(topic);
265    Stream::new(config).await
266}
267
268/// Simple rate limiter for controlling event publishing rate
269pub struct SimpleRateLimiter {
270    permits_per_second: u64,
271    last_refill: tokio::time::Instant,
272    available_permits: u64,
273}
274
275impl SimpleRateLimiter {
276    /// Create a new rate limiter
277    pub fn new(permits_per_second: u64) -> Self {
278        Self {
279            permits_per_second,
280            last_refill: tokio::time::Instant::now(),
281            available_permits: permits_per_second,
282        }
283    }
284
285    /// Acquire a permit, blocking if necessary
286    pub async fn acquire(&mut self) -> Result<()> {
287        loop {
288            // Refill permits based on elapsed time
289            let now = tokio::time::Instant::now();
290            let elapsed = now.duration_since(self.last_refill);
291            let new_permits = (elapsed.as_secs_f64() * self.permits_per_second as f64) as u64;
292
293            if new_permits > 0 {
294                self.available_permits =
295                    (self.available_permits + new_permits).min(self.permits_per_second);
296                self.last_refill = now;
297            }
298
299            if self.available_permits > 0 {
300                self.available_permits -= 1;
301                return Ok(());
302            }
303
304            // Wait before checking again
305            tokio::time::sleep(Duration::from_millis(10)).await;
306        }
307    }
308}
309
/// Deterministic event sampler that keeps a fixed fraction of events.
pub struct EventSampler {
    /// Fraction of events to keep, within `0.0..=1.0`.
    sample_rate: f64,
    /// Number of events seen so far.
    count: u64,
}

impl EventSampler {
    /// Create a new event sampler.
    ///
    /// # Arguments
    /// * `sample_rate` - Fraction of events to keep (0.0 to 1.0)
    ///
    /// # Panics
    ///
    /// Panics when `sample_rate` lies outside `0.0..=1.0` (this includes NaN).
    pub fn new(sample_rate: f64) -> Self {
        let valid_rates = 0.0..=1.0;
        assert!(
            valid_rates.contains(&sample_rate),
            "Sample rate must be between 0 and 1"
        );
        Self {
            count: 0,
            sample_rate,
        }
    }

    /// Count the current event and report whether it should be kept.
    ///
    /// The decision depends only on the running event count and the sample
    /// rate, so the same sequence of calls always gives the same answers.
    pub fn should_sample(&mut self) -> bool {
        self.count += 1;

        match self.sample_rate {
            rate if rate >= 1.0 => true,
            rate if rate <= 0.0 => false,
            // Keep the event when the fractional part of `count * rate` is
            // below `rate`.
            rate => (self.count as f64 * rate).fract() < rate,
        }
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Three recorded events must show up in the count, the byte total and the mean size.
    #[test]
    fn test_stream_stats() {
        let mut stats = StreamStats::new();

        for size in [100, 200, 300] {
            stats.record_event(size);
        }

        assert_eq!(stats.total_events, 3);
        assert_eq!(stats.total_bytes, 600);
        assert_eq!(stats.avg_event_size, 200);
    }

    /// A subject filter must accept a triple whose subject contains the pattern.
    #[test]
    fn test_event_filter() {
        use crate::EventMetadata;
        use std::collections::HashMap;

        let metadata = EventMetadata {
            event_id: "test".to_string(),
            timestamp: chrono::Utc::now(),
            source: "test".to_string(),
            user: None,
            context: None,
            caused_by: None,
            version: "1.0".to_string(),
            properties: HashMap::new(),
            checksum: None,
        };
        let event = StreamEvent::TripleAdded {
            subject: "http://example.org/test".to_string(),
            predicate: "http://example.org/prop".to_string(),
            object: "value".to_string(),
            graph: None,
            metadata,
        };

        let filter = EventFilter::new().by_subject("example.org".to_string());

        assert!(filter.matches(&event));
    }

    /// A 50% sampler must keep roughly half of 1000 events.
    #[test]
    fn test_event_sampler() {
        let mut sampler = EventSampler::new(0.5);

        let sampled = (0..1000).filter(|_| sampler.should_sample()).count();

        // Should be approximately 500 (50% sampling)
        assert!((450..=550).contains(&sampled), "Sampled {sampled} events");
    }

    /// Taking fewer permits than the initial allowance must not have to wait.
    #[tokio::test]
    async fn test_simple_rate_limiter() {
        // 10 permits per second
        let mut limiter = SimpleRateLimiter::new(10);

        let started = tokio::time::Instant::now();
        for _ in 0..5 {
            limiter.acquire().await.unwrap();
        }
        let took = started.elapsed();

        // Should complete almost instantly for 5 permits
        assert!(took < Duration::from_millis(100));
    }
}