Skip to main content

oxigdal_etl/
stream.rs

1//! Stream processing primitives for ETL pipelines
2//!
3//! This module provides async stream processing capabilities with backpressure,
4//! state management, checkpointing, and parallel processing.
5
6use crate::error::{Result, StreamError};
7use async_trait::async_trait;
8use dashmap::DashMap;
9use futures::stream::{Stream, StreamExt};
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{RwLock, Semaphore, mpsc};
14use tokio::time::timeout;
15
/// A single unit of data flowing through the stream, as raw bytes.
pub type StreamItem = Vec<u8>;

/// A heap-allocated, pinned async stream of fallible items that can be
/// moved across threads (`Send`) and owns all its data (`'static`).
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T>> + Send + 'static>>;
/// Configuration for stream buffering, checkpointing, and parallelism.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the internal mpsc channel; also the number of
    /// backpressure permits available to `BufferedStream::push`.
    pub buffer_size: usize,
    /// How long `BufferedStream::push` waits for a free buffer slot before
    /// failing with `StreamError::BackpressureTimeout`.
    pub backpressure_timeout: Duration,
    /// Enable checkpointing; when `false`, `needs_checkpoint` is always `false`.
    pub checkpointing: bool,
    /// Trigger a checkpoint every this many pushed items.
    pub checkpoint_interval: usize,
    /// Maximum number of items processed concurrently by `ParallelProcessor`.
    pub max_parallelism: usize,
}
36
37impl Default for StreamConfig {
38    fn default() -> Self {
39        Self {
40            buffer_size: 1000,
41            backpressure_timeout: Duration::from_secs(30),
42            checkpointing: false,
43            checkpoint_interval: 1000,
44            max_parallelism: num_cpus(),
45        }
46    }
47}
48
/// Transformation applied to each item flowing through a stream.
///
/// Implementors must be thread-safe (`Send + Sync`) because
/// `ParallelProcessor` calls `process` from multiple spawned tasks.
#[async_trait]
pub trait StreamProcessor: Send + Sync {
    /// Process a single item, returning the transformed item.
    async fn process(&self, item: StreamItem) -> Result<StreamItem>;

    /// Serialize processor state for checkpointing.
    ///
    /// Default implementation is for stateless processors: returns an empty buffer.
    async fn checkpoint(&self) -> Result<Vec<u8>> {
        Ok(Vec::new())
    }

    /// Restore processor state from a previous `checkpoint` payload.
    ///
    /// Default implementation ignores the state (stateless processor).
    async fn restore(&self, _state: &[u8]) -> Result<()> {
        Ok(())
    }
}
65
/// Bounded in-memory stream with backpressure on `push`.
pub struct BufferedStream {
    config: StreamConfig,
    /// Producer side of the bounded channel.
    sender: mpsc::Sender<StreamItem>,
    /// Consumer side; behind a lock so `pull` can take `&self`.
    receiver: Arc<RwLock<mpsc::Receiver<StreamItem>>>,
    /// Running count of items pushed (drives checkpoint scheduling).
    items_processed: Arc<RwLock<usize>>,
    /// Backpressure permits; intended as one permit per free buffer slot.
    semaphore: Arc<Semaphore>,
}
74
75impl BufferedStream {
76    /// Create a new buffered stream
77    pub fn new(config: StreamConfig) -> Self {
78        let (sender, receiver) = mpsc::channel(config.buffer_size);
79        Self {
80            semaphore: Arc::new(Semaphore::new(config.buffer_size)),
81            config,
82            sender,
83            receiver: Arc::new(RwLock::new(receiver)),
84            items_processed: Arc::new(RwLock::new(0)),
85        }
86    }
87
88    /// Push an item to the stream with backpressure
89    pub async fn push(&self, item: StreamItem) -> Result<()> {
90        // Acquire permit with timeout
91        let permit = timeout(self.config.backpressure_timeout, self.semaphore.acquire())
92            .await
93            .map_err(|_| StreamError::BackpressureTimeout {
94                duration: self.config.backpressure_timeout,
95            })?
96            .map_err(|_| StreamError::ChannelClosed)?;
97
98        // Send item
99        self.sender
100            .send(item)
101            .await
102            .map_err(|_| StreamError::ChannelClosed)?;
103
104        // Release permit
105        permit.forget();
106
107        // Update counter
108        let mut count = self.items_processed.write().await;
109        *count += 1;
110
111        Ok(())
112    }
113
114    /// Pull an item from the stream
115    pub async fn pull(&self) -> Result<Option<StreamItem>> {
116        let mut receiver = self.receiver.write().await;
117        Ok(receiver.recv().await)
118    }
119
120    /// Get number of items processed
121    pub async fn items_processed(&self) -> usize {
122        *self.items_processed.read().await
123    }
124
125    /// Check if checkpoint is needed
126    pub async fn needs_checkpoint(&self) -> bool {
127        if !self.config.checkpointing {
128            return false;
129        }
130        let count = self.items_processed().await;
131        count > 0 && count % self.config.checkpoint_interval == 0
132    }
133}
134
/// Key/value state store for stream processing, with optional
/// disk-backed checkpointing.
pub struct StateManager {
    /// Concurrent in-memory state map.
    state: DashMap<String, Vec<u8>>,
    /// Directory for checkpoint files; `None` disables persistence.
    checkpoint_dir: Option<std::path::PathBuf>,
}
140
141impl StateManager {
142    /// Create a new state manager
143    pub fn new(checkpoint_dir: Option<std::path::PathBuf>) -> Self {
144        Self {
145            state: DashMap::new(),
146            checkpoint_dir,
147        }
148    }
149
150    /// Set state for a key
151    pub fn set(&self, key: String, value: Vec<u8>) {
152        self.state.insert(key, value);
153    }
154
155    /// Get state for a key
156    pub fn get(&self, key: &str) -> Option<Vec<u8>> {
157        self.state.get(key).map(|v| v.clone())
158    }
159
160    /// Save checkpoint to disk
161    pub async fn save_checkpoint(&self, pipeline_id: &str) -> Result<()> {
162        let checkpoint_dir =
163            self.checkpoint_dir
164                .as_ref()
165                .ok_or_else(|| StreamError::StateFailed {
166                    message: "No checkpoint directory configured".to_string(),
167                })?;
168
169        tokio::fs::create_dir_all(checkpoint_dir).await?;
170
171        let checkpoint_file = checkpoint_dir.join(format!("{}.checkpoint", pipeline_id));
172        let mut data = Vec::new();
173
174        for entry in self.state.iter() {
175            let key_bytes = entry.key().as_bytes();
176            data.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
177            data.extend_from_slice(key_bytes);
178            data.extend_from_slice(&(entry.value().len() as u32).to_le_bytes());
179            data.extend_from_slice(entry.value());
180        }
181
182        tokio::fs::write(checkpoint_file, data).await?;
183        Ok(())
184    }
185
186    /// Load checkpoint from disk
187    pub async fn load_checkpoint(&self, pipeline_id: &str) -> Result<()> {
188        let checkpoint_dir =
189            self.checkpoint_dir
190                .as_ref()
191                .ok_or_else(|| StreamError::StateFailed {
192                    message: "No checkpoint directory configured".to_string(),
193                })?;
194
195        let checkpoint_file = checkpoint_dir.join(format!("{}.checkpoint", pipeline_id));
196        if !checkpoint_file.exists() {
197            return Ok(());
198        }
199
200        let data = tokio::fs::read(checkpoint_file).await?;
201        let mut offset = 0;
202
203        while offset < data.len() {
204            if offset + 4 > data.len() {
205                break;
206            }
207
208            let key_len = u32::from_le_bytes([
209                data[offset],
210                data[offset + 1],
211                data[offset + 2],
212                data[offset + 3],
213            ]) as usize;
214            offset += 4;
215
216            if offset + key_len > data.len() {
217                break;
218            }
219
220            let key = String::from_utf8_lossy(&data[offset..offset + key_len]).to_string();
221            offset += key_len;
222
223            if offset + 4 > data.len() {
224                break;
225            }
226
227            let value_len = u32::from_le_bytes([
228                data[offset],
229                data[offset + 1],
230                data[offset + 2],
231                data[offset + 3],
232            ]) as usize;
233            offset += 4;
234
235            if offset + value_len > data.len() {
236                break;
237            }
238
239            let value = data[offset..offset + value_len].to_vec();
240            offset += value_len;
241
242            self.state.insert(key, value);
243        }
244
245        Ok(())
246    }
247
248    /// Clear all state
249    pub fn clear(&self) {
250        self.state.clear();
251    }
252}
253
/// Runs a `StreamProcessor` over many items concurrently, bounded by
/// `StreamConfig::max_parallelism`, with checkpoint support.
pub struct ParallelProcessor {
    config: StreamConfig,
    /// Per-item transformation; shared with spawned tasks.
    processor: Arc<dyn StreamProcessor>,
    /// Backing store used for checkpoint persistence.
    state_manager: Arc<StateManager>,
    /// Pipeline identifier for checkpointing (state keys and file names)
    pipeline_id: String,
}
262
263impl ParallelProcessor {
264    /// Create a new parallel processor
265    pub fn new(
266        config: StreamConfig,
267        processor: Arc<dyn StreamProcessor>,
268        state_manager: Arc<StateManager>,
269    ) -> Self {
270        Self {
271            config,
272            processor,
273            state_manager,
274            pipeline_id: "default".to_string(),
275        }
276    }
277
278    /// Create a new parallel processor with a specific pipeline ID
279    pub fn with_pipeline_id(
280        config: StreamConfig,
281        processor: Arc<dyn StreamProcessor>,
282        state_manager: Arc<StateManager>,
283        pipeline_id: String,
284    ) -> Self {
285        Self {
286            config,
287            processor,
288            state_manager,
289            pipeline_id,
290        }
291    }
292
293    /// Get the state manager for external access
294    pub fn state_manager(&self) -> &Arc<StateManager> {
295        &self.state_manager
296    }
297
298    /// Save checkpoint using the state manager
299    pub async fn save_checkpoint(&self) -> Result<()> {
300        // Get checkpoint data from processor
301        let checkpoint_data = self.processor.checkpoint().await?;
302
303        // Store in state manager
304        self.state_manager
305            .set(format!("processor_{}", self.pipeline_id), checkpoint_data);
306
307        // Persist to disk if configured
308        self.state_manager.save_checkpoint(&self.pipeline_id).await
309    }
310
311    /// Restore from checkpoint
312    pub async fn restore_checkpoint(&self) -> Result<()> {
313        // Load checkpoint from disk
314        self.state_manager
315            .load_checkpoint(&self.pipeline_id)
316            .await?;
317
318        // Restore processor state if available
319        if let Some(state) = self
320            .state_manager
321            .get(&format!("processor_{}", self.pipeline_id))
322        {
323            self.processor.restore(&state).await?;
324        }
325
326        Ok(())
327    }
328
329    /// Process a stream in parallel
330    pub async fn process_stream<S>(&self, mut stream: S) -> Result<Vec<StreamItem>>
331    where
332        S: Stream<Item = Result<StreamItem>> + Unpin + Send,
333    {
334        let mut results = Vec::new();
335        let semaphore = Arc::new(Semaphore::new(self.config.max_parallelism));
336        let mut handles = Vec::new();
337
338        while let Some(item_result) = stream.next().await {
339            let item = item_result?;
340
341            let processor = Arc::clone(&self.processor);
342            let semaphore = Arc::clone(&semaphore);
343
344            let handle = tokio::spawn(async move {
345                let _permit =
346                    semaphore
347                        .acquire()
348                        .await
349                        .map_err(|_| StreamError::ParallelFailed {
350                            message: "Failed to acquire semaphore".to_string(),
351                        })?;
352                processor.process(item).await
353            });
354
355            handles.push(handle);
356        }
357
358        // Collect results
359        for handle in handles {
360            let result = handle.await.map_err(|e| StreamError::ParallelFailed {
361                message: format!("Task join error: {}", e),
362            })??;
363            results.push(result);
364        }
365
366        Ok(results)
367    }
368
369    /// Process a batch of items
370    pub async fn process_batch(&self, items: Vec<StreamItem>) -> Result<Vec<StreamItem>> {
371        let mut results = Vec::new();
372        let semaphore = Arc::new(Semaphore::new(self.config.max_parallelism));
373        let mut handles = Vec::new();
374
375        for item in items {
376            let processor = Arc::clone(&self.processor);
377            let semaphore = Arc::clone(&semaphore);
378
379            let handle = tokio::spawn(async move {
380                let _permit =
381                    semaphore
382                        .acquire()
383                        .await
384                        .map_err(|_| StreamError::ParallelFailed {
385                            message: "Failed to acquire semaphore".to_string(),
386                        })?;
387                processor.process(item).await
388            });
389
390            handles.push(handle);
391        }
392
393        // Collect results
394        for handle in handles {
395            let result = handle.await.map_err(|e| StreamError::ParallelFailed {
396                message: format!("Task join error: {}", e),
397            })??;
398            results.push(result);
399        }
400
401        Ok(results)
402    }
403}
404
/// Number of CPUs available for parallelism, falling back to 1 when the
/// platform cannot report it (e.g. restricted environments).
// Note: the previous `#[allow(clippy::unnecessary_wraps)]` was stale —
// this function returns a plain `usize`, so that lint can never fire here.
fn num_cpus() -> usize {
    std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1)
}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity processor: echoes every item back unchanged.
    struct TestProcessor;

    #[async_trait]
    impl StreamProcessor for TestProcessor {
        async fn process(&self, item: StreamItem) -> Result<StreamItem> {
            Ok(item)
        }
    }

    /// Push one item through a default buffered stream and read it back.
    #[tokio::test]
    async fn test_buffered_stream() {
        let stream = BufferedStream::new(StreamConfig::default());

        let payload = vec![1, 2, 3, 4];
        stream.push(payload.clone()).await.expect("Failed to push");

        let received = stream.pull().await.expect("Failed to pull");
        assert_eq!(received, Some(payload));

        assert_eq!(stream.items_processed().await, 1);
    }

    /// Set/get/clear round-trip on an in-memory-only state manager.
    #[tokio::test]
    async fn test_state_manager() {
        let manager = StateManager::new(None);

        manager.set("test_key".to_string(), vec![1, 2, 3]);
        assert_eq!(manager.get("test_key"), Some(vec![1, 2, 3]));

        manager.clear();
        assert_eq!(manager.get("test_key"), None);
    }

    /// Identity processor over a three-item batch yields three results.
    #[tokio::test]
    async fn test_parallel_processor() {
        let parallel = ParallelProcessor::new(
            StreamConfig::default(),
            Arc::new(TestProcessor),
            Arc::new(StateManager::new(None)),
        );

        let batch = vec![vec![1, 2], vec![3, 4], vec![5, 6]];
        let outputs = parallel
            .process_batch(batch.clone())
            .await
            .expect("Failed to process");

        assert_eq!(outputs.len(), 3);
    }

    /// Checkpoint is due exactly at multiples of the interval.
    #[tokio::test]
    async fn test_checkpoint_needed() {
        let config = StreamConfig {
            checkpointing: true,
            checkpoint_interval: 2,
            ..Default::default()
        };
        let stream = BufferedStream::new(config);

        stream.push(vec![1]).await.expect("Failed to push");
        assert!(!stream.needs_checkpoint().await);

        stream.push(vec![2]).await.expect("Failed to push");
        assert!(stream.needs_checkpoint().await);
    }
}