1use crate::error::{Result, StreamError};
7use async_trait::async_trait;
8use dashmap::DashMap;
9use futures::stream::{Stream, StreamExt};
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{RwLock, Semaphore, mpsc};
14use tokio::time::timeout;
15
/// A single unit of data flowing through a stream pipeline.
pub type StreamItem = Vec<u8>;

/// A pinned, heap-allocated stream of fallible items that can cross task
/// boundaries (`Send + 'static`).
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T>> + Send + 'static>>;
/// Tuning knobs for stream buffering, backpressure, checkpointing, and
/// parallel processing.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the internal buffer (both the mpsc channel and the
    /// backpressure semaphore are sized with this value).
    pub buffer_size: usize,
    /// Maximum time `BufferedStream::push` waits for free buffer space
    /// before failing with `StreamError::BackpressureTimeout`.
    pub backpressure_timeout: Duration,
    /// Enables periodic checkpointing (see `BufferedStream::needs_checkpoint`).
    pub checkpointing: bool,
    /// Number of items between checkpoints when checkpointing is enabled.
    pub checkpoint_interval: usize,
    /// Upper bound on concurrently processed items in `ParallelProcessor`.
    pub max_parallelism: usize,
}
36
37impl Default for StreamConfig {
38 fn default() -> Self {
39 Self {
40 buffer_size: 1000,
41 backpressure_timeout: Duration::from_secs(30),
42 checkpointing: false,
43 checkpoint_interval: 1000,
44 max_parallelism: num_cpus(),
45 }
46 }
47}
48
/// A processing stage applied to each item of a stream.
///
/// Implementors may optionally persist and restore internal state through
/// `checkpoint` / `restore`; the provided defaults are no-ops suitable for
/// stateless processors.
#[async_trait]
pub trait StreamProcessor: Send + Sync {
    /// Transforms a single item, returning the processed bytes.
    async fn process(&self, item: StreamItem) -> Result<StreamItem>;

    /// Serializes internal state for checkpointing.
    /// Default: stateless — returns an empty buffer.
    async fn checkpoint(&self) -> Result<Vec<u8>> {
        Ok(Vec::new())
    }

    /// Restores internal state from a payload produced by `checkpoint`.
    /// Default: no-op.
    async fn restore(&self, _state: &[u8]) -> Result<()> {
        Ok(())
    }
}
65
/// A bounded in-memory stream with timeout-based backpressure.
///
/// Capacity is enforced by the mpsc channel and mirrored by a semaphore
/// whose permits represent free buffer slots; `push` blocks (up to a
/// configurable timeout) when no slot is free.
pub struct BufferedStream {
    /// Buffering / checkpointing configuration.
    config: StreamConfig,
    /// Producer side of the bounded channel.
    sender: mpsc::Sender<StreamItem>,
    /// Consumer side; wrapped in `RwLock` so `pull` works through `&self`.
    receiver: Arc<RwLock<mpsc::Receiver<StreamItem>>>,
    /// Total number of items pushed so far (drives checkpoint cadence).
    items_processed: Arc<RwLock<usize>>,
    /// Free-slot permits used by `push` to apply backpressure with a timeout.
    semaphore: Arc<Semaphore>,
}
74
75impl BufferedStream {
76 pub fn new(config: StreamConfig) -> Self {
78 let (sender, receiver) = mpsc::channel(config.buffer_size);
79 Self {
80 semaphore: Arc::new(Semaphore::new(config.buffer_size)),
81 config,
82 sender,
83 receiver: Arc::new(RwLock::new(receiver)),
84 items_processed: Arc::new(RwLock::new(0)),
85 }
86 }
87
88 pub async fn push(&self, item: StreamItem) -> Result<()> {
90 let permit = timeout(self.config.backpressure_timeout, self.semaphore.acquire())
92 .await
93 .map_err(|_| StreamError::BackpressureTimeout {
94 duration: self.config.backpressure_timeout,
95 })?
96 .map_err(|_| StreamError::ChannelClosed)?;
97
98 self.sender
100 .send(item)
101 .await
102 .map_err(|_| StreamError::ChannelClosed)?;
103
104 permit.forget();
106
107 let mut count = self.items_processed.write().await;
109 *count += 1;
110
111 Ok(())
112 }
113
114 pub async fn pull(&self) -> Result<Option<StreamItem>> {
116 let mut receiver = self.receiver.write().await;
117 Ok(receiver.recv().await)
118 }
119
120 pub async fn items_processed(&self) -> usize {
122 *self.items_processed.read().await
123 }
124
125 pub async fn needs_checkpoint(&self) -> bool {
127 if !self.config.checkpointing {
128 return false;
129 }
130 let count = self.items_processed().await;
131 count > 0 && count % self.config.checkpoint_interval == 0
132 }
133}
134
/// Thread-safe key/value state store with optional on-disk checkpointing.
pub struct StateManager {
    /// In-memory state; `DashMap` allows concurrent access without an
    /// external lock.
    state: DashMap<String, Vec<u8>>,
    /// Directory for checkpoint files; `None` disables persistence.
    checkpoint_dir: Option<std::path::PathBuf>,
}
140
141impl StateManager {
142 pub fn new(checkpoint_dir: Option<std::path::PathBuf>) -> Self {
144 Self {
145 state: DashMap::new(),
146 checkpoint_dir,
147 }
148 }
149
150 pub fn set(&self, key: String, value: Vec<u8>) {
152 self.state.insert(key, value);
153 }
154
155 pub fn get(&self, key: &str) -> Option<Vec<u8>> {
157 self.state.get(key).map(|v| v.clone())
158 }
159
160 pub async fn save_checkpoint(&self, pipeline_id: &str) -> Result<()> {
162 let checkpoint_dir =
163 self.checkpoint_dir
164 .as_ref()
165 .ok_or_else(|| StreamError::StateFailed {
166 message: "No checkpoint directory configured".to_string(),
167 })?;
168
169 tokio::fs::create_dir_all(checkpoint_dir).await?;
170
171 let checkpoint_file = checkpoint_dir.join(format!("{}.checkpoint", pipeline_id));
172 let mut data = Vec::new();
173
174 for entry in self.state.iter() {
175 let key_bytes = entry.key().as_bytes();
176 data.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
177 data.extend_from_slice(key_bytes);
178 data.extend_from_slice(&(entry.value().len() as u32).to_le_bytes());
179 data.extend_from_slice(entry.value());
180 }
181
182 tokio::fs::write(checkpoint_file, data).await?;
183 Ok(())
184 }
185
186 pub async fn load_checkpoint(&self, pipeline_id: &str) -> Result<()> {
188 let checkpoint_dir =
189 self.checkpoint_dir
190 .as_ref()
191 .ok_or_else(|| StreamError::StateFailed {
192 message: "No checkpoint directory configured".to_string(),
193 })?;
194
195 let checkpoint_file = checkpoint_dir.join(format!("{}.checkpoint", pipeline_id));
196 if !checkpoint_file.exists() {
197 return Ok(());
198 }
199
200 let data = tokio::fs::read(checkpoint_file).await?;
201 let mut offset = 0;
202
203 while offset < data.len() {
204 if offset + 4 > data.len() {
205 break;
206 }
207
208 let key_len = u32::from_le_bytes([
209 data[offset],
210 data[offset + 1],
211 data[offset + 2],
212 data[offset + 3],
213 ]) as usize;
214 offset += 4;
215
216 if offset + key_len > data.len() {
217 break;
218 }
219
220 let key = String::from_utf8_lossy(&data[offset..offset + key_len]).to_string();
221 offset += key_len;
222
223 if offset + 4 > data.len() {
224 break;
225 }
226
227 let value_len = u32::from_le_bytes([
228 data[offset],
229 data[offset + 1],
230 data[offset + 2],
231 data[offset + 3],
232 ]) as usize;
233 offset += 4;
234
235 if offset + value_len > data.len() {
236 break;
237 }
238
239 let value = data[offset..offset + value_len].to_vec();
240 offset += value_len;
241
242 self.state.insert(key, value);
243 }
244
245 Ok(())
246 }
247
248 pub fn clear(&self) {
250 self.state.clear();
251 }
252}
253
/// Runs a `StreamProcessor` over many items concurrently, bounded by
/// `config.max_parallelism`, with checkpoint save/restore delegated to a
/// shared `StateManager`.
pub struct ParallelProcessor {
    /// Parallelism and buffering configuration.
    config: StreamConfig,
    /// The per-item processing stage.
    processor: Arc<dyn StreamProcessor>,
    /// Backing store for processor checkpoints.
    state_manager: Arc<StateManager>,
    /// Namespaces checkpoint files and state keys for this pipeline.
    pipeline_id: String,
}
262
263impl ParallelProcessor {
264 pub fn new(
266 config: StreamConfig,
267 processor: Arc<dyn StreamProcessor>,
268 state_manager: Arc<StateManager>,
269 ) -> Self {
270 Self {
271 config,
272 processor,
273 state_manager,
274 pipeline_id: "default".to_string(),
275 }
276 }
277
278 pub fn with_pipeline_id(
280 config: StreamConfig,
281 processor: Arc<dyn StreamProcessor>,
282 state_manager: Arc<StateManager>,
283 pipeline_id: String,
284 ) -> Self {
285 Self {
286 config,
287 processor,
288 state_manager,
289 pipeline_id,
290 }
291 }
292
293 pub fn state_manager(&self) -> &Arc<StateManager> {
295 &self.state_manager
296 }
297
298 pub async fn save_checkpoint(&self) -> Result<()> {
300 let checkpoint_data = self.processor.checkpoint().await?;
302
303 self.state_manager
305 .set(format!("processor_{}", self.pipeline_id), checkpoint_data);
306
307 self.state_manager.save_checkpoint(&self.pipeline_id).await
309 }
310
311 pub async fn restore_checkpoint(&self) -> Result<()> {
313 self.state_manager
315 .load_checkpoint(&self.pipeline_id)
316 .await?;
317
318 if let Some(state) = self
320 .state_manager
321 .get(&format!("processor_{}", self.pipeline_id))
322 {
323 self.processor.restore(&state).await?;
324 }
325
326 Ok(())
327 }
328
329 pub async fn process_stream<S>(&self, mut stream: S) -> Result<Vec<StreamItem>>
331 where
332 S: Stream<Item = Result<StreamItem>> + Unpin + Send,
333 {
334 let mut results = Vec::new();
335 let semaphore = Arc::new(Semaphore::new(self.config.max_parallelism));
336 let mut handles = Vec::new();
337
338 while let Some(item_result) = stream.next().await {
339 let item = item_result?;
340
341 let processor = Arc::clone(&self.processor);
342 let semaphore = Arc::clone(&semaphore);
343
344 let handle = tokio::spawn(async move {
345 let _permit =
346 semaphore
347 .acquire()
348 .await
349 .map_err(|_| StreamError::ParallelFailed {
350 message: "Failed to acquire semaphore".to_string(),
351 })?;
352 processor.process(item).await
353 });
354
355 handles.push(handle);
356 }
357
358 for handle in handles {
360 let result = handle.await.map_err(|e| StreamError::ParallelFailed {
361 message: format!("Task join error: {}", e),
362 })??;
363 results.push(result);
364 }
365
366 Ok(results)
367 }
368
369 pub async fn process_batch(&self, items: Vec<StreamItem>) -> Result<Vec<StreamItem>> {
371 let mut results = Vec::new();
372 let semaphore = Arc::new(Semaphore::new(self.config.max_parallelism));
373 let mut handles = Vec::new();
374
375 for item in items {
376 let processor = Arc::clone(&self.processor);
377 let semaphore = Arc::clone(&semaphore);
378
379 let handle = tokio::spawn(async move {
380 let _permit =
381 semaphore
382 .acquire()
383 .await
384 .map_err(|_| StreamError::ParallelFailed {
385 message: "Failed to acquire semaphore".to_string(),
386 })?;
387 processor.process(item).await
388 });
389
390 handles.push(handle);
391 }
392
393 for handle in handles {
395 let result = handle.await.map_err(|e| StreamError::ParallelFailed {
396 message: format!("Task join error: {}", e),
397 })??;
398 results.push(result);
399 }
400
401 Ok(results)
402 }
403}
404
/// Returns the number of logical CPUs available to this process, falling
/// back to 1 when the platform cannot report it.
///
/// (The previous `#[allow(clippy::unnecessary_wraps)]` was removed: this
/// function returns a plain `usize`, so that lint can never fire here.)
fn num_cpus() -> usize {
    std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1)
}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity processor used to exercise pipeline plumbing without
    /// transforming data.
    struct TestProcessor;

    #[async_trait]
    impl StreamProcessor for TestProcessor {
        async fn process(&self, item: StreamItem) -> Result<StreamItem> {
            Ok(item)
        }
    }

    /// Round-trips one item through the buffer and checks the push counter.
    #[tokio::test]
    async fn test_buffered_stream() {
        let config = StreamConfig::default();
        let stream = BufferedStream::new(config);

        let item = vec![1, 2, 3, 4];
        stream.push(item.clone()).await.expect("Failed to push");

        let pulled = stream.pull().await.expect("Failed to pull");
        assert_eq!(pulled, Some(item));

        assert_eq!(stream.items_processed().await, 1);
    }

    /// Verifies in-memory set/get and that clear removes entries.
    #[tokio::test]
    async fn test_state_manager() {
        let manager = StateManager::new(None);

        manager.set("test_key".to_string(), vec![1, 2, 3]);
        let value = manager.get("test_key");
        assert_eq!(value, Some(vec![1, 2, 3]));

        manager.clear();
        let value = manager.get("test_key");
        assert_eq!(value, None);
    }

    /// Processes a small batch through the identity processor and checks
    /// that every item comes back.
    #[tokio::test]
    async fn test_parallel_processor() {
        let config = StreamConfig::default();
        let processor = Arc::new(TestProcessor);
        let state_manager = Arc::new(StateManager::new(None));

        let parallel = ParallelProcessor::new(config, processor, state_manager);

        let items = vec![vec![1, 2], vec![3, 4], vec![5, 6]];
        let results = parallel
            .process_batch(items.clone())
            .await
            .expect("Failed to process");

        assert_eq!(results.len(), 3);
    }

    /// With interval 2, a checkpoint is due after the 2nd push but not the
    /// 1st.
    #[tokio::test]
    async fn test_checkpoint_needed() {
        let config = StreamConfig {
            checkpointing: true,
            checkpoint_interval: 2,
            ..Default::default()
        };

        let stream = BufferedStream::new(config);

        stream.push(vec![1]).await.expect("Failed to push");
        assert!(!stream.needs_checkpoint().await);

        stream.push(vec![2]).await.expect("Failed to push");
        assert!(stream.needs_checkpoint().await);
    }
}