oxirs_core/concurrent/parallel_batch.rs

//! Parallel batch processing for high-throughput RDF operations
//!
//! This module provides a parallel batch processor with work-stealing queues,
//! configurable thread pools, and progress tracking for efficient RDF data processing.
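//!
//! A minimal usage sketch (the module path is assumed from the file location;
//! error handling is elided):
//!
//! ```ignore
//! use oxirs_core::concurrent::parallel_batch::{
//!     BatchConfig, BatchOperation, ParallelBatchProcessor,
//! };
//!
//! let processor = ParallelBatchProcessor::new(BatchConfig::default());
//! processor.submit(BatchOperation::insert(vec![/* triples */]))?;
//! let inserted = processor.process(|op| match op {
//!     BatchOperation::Insert(ts) => Ok(ts.len()),
//!     _ => Ok(0),
//! })?;
//! ```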

use crate::model::{Object, Predicate, Subject, Triple};
use crate::OxirsError;
use crossbeam_deque::Injector;
use parking_lot::{Mutex, RwLock};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::{Duration, Instant};

/// Type alias for transform functions
type TransformFn = Arc<dyn Fn(&Triple) -> Option<Triple> + Send + Sync>;

/// Batch operation types
#[derive(Clone)]
pub enum BatchOperation {
    /// Insert a collection of triples
    Insert(Vec<Triple>),
    /// Remove a collection of triples
    Remove(Vec<Triple>),
    /// Execute a query with pattern matching
    Query {
        subject: Option<Subject>,
        predicate: Option<Predicate>,
        object: Option<Object>,
    },
    /// Transform triples using a function
    Transform(TransformFn),
}

impl std::fmt::Debug for BatchOperation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BatchOperation::Insert(triples) => write!(f, "Insert({} triples)", triples.len()),
            BatchOperation::Remove(triples) => write!(f, "Remove({} triples)", triples.len()),
            BatchOperation::Query {
                subject,
                predicate,
                object,
            } => {
                write!(f, "Query({subject:?}, {predicate:?}, {object:?})")
            }
            BatchOperation::Transform(_) => write!(f, "Transform(function)"),
        }
    }
}

/// Progress callback for tracking batch operations
pub type ProgressCallback = Box<dyn Fn(usize, usize) + Send + Sync>;

/// Configuration for parallel batch processing
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Number of worker threads (defaults to number of CPU cores)
    pub num_threads: Option<usize>,
    /// Size of each batch for processing
    pub batch_size: usize,
    /// Maximum queue size before applying backpressure
    pub max_queue_size: usize,
    /// Timeout for batch operations
    pub timeout: Option<Duration>,
    /// Enable progress tracking
    pub enable_progress: bool,
}

impl Default for BatchConfig {
    fn default() -> Self {
        let num_cpus = num_cpus::get();
        BatchConfig {
            num_threads: None,
            batch_size: 1000,
            max_queue_size: num_cpus * 10000,
            timeout: None,
            enable_progress: true,
        }
    }
}

impl BatchConfig {
    /// Create a config optimized for the current system
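    ///
    /// The result is equivalent to hand-tuning the same fields; a sketch using
    /// struct-update syntax with illustrative values:
    ///
    /// ```ignore
    /// let config = BatchConfig {
    ///     num_threads: Some(8),
    ///     batch_size: 2000,
    ///     ..BatchConfig::default()
    /// };
    /// ```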
    pub fn auto() -> Self {
        let num_cpus = num_cpus::get();
        // sys_info reports memory in KB, so these constants are KB counts
        let total_memory = sys_info::mem_info()
            .map(|info| info.total)
            .unwrap_or(8 * 1024 * 1024); // assume 8 GB when unavailable

        // Adjust batch size based on available memory
        let batch_size = if total_memory > 16 * 1024 * 1024 {
            // more than 16 GB
            5000
        } else if total_memory > 8 * 1024 * 1024 {
            // more than 8 GB
            2000
        } else {
            1000
        };

        BatchConfig {
            num_threads: Some(num_cpus),
            batch_size,
            max_queue_size: num_cpus * batch_size * 10,
            timeout: None,
            enable_progress: true,
        }
    }
}

/// Statistics for batch processing
#[derive(Debug, Default)]
pub struct BatchStats {
    pub total_processed: AtomicUsize,
    pub total_succeeded: AtomicUsize,
    pub total_failed: AtomicUsize,
    pub processing_time_ms: AtomicUsize,
}

impl BatchStats {
    /// Get a summary of the statistics
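    ///
    /// Counters are loaded with relaxed ordering, so the summary is a
    /// point-in-time snapshot rather than a fully consistent cut across all
    /// fields. A usage sketch:
    ///
    /// ```ignore
    /// let snapshot = stats.summary();
    /// println!("{}/{} succeeded", snapshot.total_succeeded, snapshot.total_processed);
    /// ```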
    pub fn summary(&self) -> BatchStatsSummary {
        BatchStatsSummary {
            total_processed: self.total_processed.load(Ordering::Relaxed),
            total_succeeded: self.total_succeeded.load(Ordering::Relaxed),
            total_failed: self.total_failed.load(Ordering::Relaxed),
            processing_time_ms: self.processing_time_ms.load(Ordering::Relaxed),
        }
    }
}

#[derive(Debug, Clone)]
pub struct BatchStatsSummary {
    pub total_processed: usize,
    pub total_succeeded: usize,
    pub total_failed: usize,
    pub processing_time_ms: usize,
}

/// Parallel batch processor with work-stealing queues
pub struct ParallelBatchProcessor {
    config: BatchConfig,
    /// Global work queue (injector)
    injector: Arc<Injector<BatchOperation>>,
    /// Cancellation flag
    cancelled: Arc<AtomicBool>,
    /// Processing statistics
    stats: Arc<BatchStats>,
    /// Progress callback
    progress_callback: Arc<Mutex<Option<ProgressCallback>>>,
    /// Error accumulator
    errors: Arc<RwLock<Vec<OxirsError>>>,
}

impl ParallelBatchProcessor {
    /// Create a new parallel batch processor
    pub fn new(config: BatchConfig) -> Self {
        let injector = Arc::new(Injector::new());

        ParallelBatchProcessor {
            config,
            injector,
            cancelled: Arc::new(AtomicBool::new(false)),
            stats: Arc::new(BatchStats::default()),
            progress_callback: Arc::new(Mutex::new(None)),
            errors: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Set a progress callback
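    ///
    /// The callback receives `(processed, estimated_total)`. For example,
    /// logging every update (a sketch):
    ///
    /// ```ignore
    /// processor.set_progress_callback(|done, total| {
    ///     println!("progress: {done}/{total}");
    /// });
    /// ```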
    pub fn set_progress_callback<F>(&self, callback: F)
    where
        F: Fn(usize, usize) + Send + Sync + 'static,
    {
        *self.progress_callback.lock() = Some(Box::new(callback));
    }

    /// Cancel ongoing operations
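    ///
    /// Safe to call from another thread while `process` is running; workers
    /// check the flag between operations. A sketch, assuming the processor is
    /// shared via `Arc`:
    ///
    /// ```ignore
    /// let p = Arc::clone(&processor);
    /// std::thread::spawn(move || p.cancel());
    /// ```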
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Check if operations are cancelled
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Get current statistics
    pub fn stats(&self) -> BatchStatsSummary {
        self.stats.summary()
    }

    /// Get accumulated errors
    pub fn errors(&self) -> Vec<OxirsError> {
        self.errors.read().clone()
    }

    /// Clear accumulated errors
    pub fn clear_errors(&self) {
        self.errors.write().clear();
    }

    /// Submit a batch operation
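    ///
    /// Fails fast instead of blocking when the queue is full, leaving the
    /// backpressure policy to the caller; a retry sketch:
    ///
    /// ```ignore
    /// while processor.submit(op.clone()).is_err() {
    ///     std::thread::sleep(std::time::Duration::from_millis(1));
    /// }
    /// ```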
    pub fn submit(&self, operation: BatchOperation) -> Result<(), OxirsError> {
        // Check queue size for backpressure
        if self.injector.len() > self.config.max_queue_size {
            return Err(OxirsError::Store("Queue is full".to_string()));
        }

        self.injector.push(operation);
        Ok(())
    }

    /// Submit multiple operations
    pub fn submit_batch(&self, operations: Vec<BatchOperation>) -> Result<(), OxirsError> {
        // Check if adding these operations would exceed queue size
        if self.injector.len() + operations.len() > self.config.max_queue_size {
            return Err(OxirsError::Store("Queue would overflow".to_string()));
        }

        for op in operations {
            self.injector.push(op);
        }
        Ok(())
    }

    /// Process operations with the given executor
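    ///
    /// The executor maps each operation to a result; failures are accumulated
    /// and surfaced after all workers finish. A sketch counting affected
    /// triples (query and transform handling elided):
    ///
    /// ```ignore
    /// let counts = processor.process(|op| match op {
    ///     BatchOperation::Insert(ts) | BatchOperation::Remove(ts) => Ok(ts.len()),
    ///     _ => Ok(0),
    /// })?;
    /// ```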
    pub fn process<E, R>(&self, executor: E) -> Result<Vec<R>, OxirsError>
    where
        E: Fn(BatchOperation) -> Result<R, OxirsError> + Send + Sync + 'static,
        R: Send + 'static,
    {
        let start_time = Instant::now();
        let num_threads = self.config.num_threads.unwrap_or_else(num_cpus::get);
        let barrier = Arc::new(Barrier::new(num_threads + 1));
        let executor = Arc::new(executor);
        let results = Arc::new(Mutex::new(Vec::new()));

        // Reset cancellation flag
        self.cancelled.store(false, Ordering::SeqCst);

        // Spawn worker threads
        let handles: Vec<_> = (0..num_threads)
            .map(|_worker_id| {
                let injector = self.injector.clone();
                let cancelled = self.cancelled.clone();
                let stats = self.stats.clone();
                let executor = executor.clone();
                let results = results.clone();
                let errors = self.errors.clone();
                let barrier = barrier.clone();
                let progress_callback = self.progress_callback.clone();
                let enable_progress = self.config.enable_progress;

                thread::spawn(move || {
                    // Wait for all threads to be ready
                    barrier.wait();

                    loop {
                        // Check for cancellation
                        if cancelled.load(Ordering::SeqCst) {
                            break;
                        }

                        // Try to get work from global queue
                        let task = loop {
                            match injector.steal() {
                                crossbeam_deque::Steal::Success(task) => break Some(task),
                                crossbeam_deque::Steal::Empty => break None,
                                crossbeam_deque::Steal::Retry => continue,
                            }
                        };

                        match task {
                            Some(operation) => {
                                // Process the operation
                                let processed =
                                    stats.total_processed.fetch_add(1, Ordering::Relaxed) + 1;

                                // Report progress every 100 operations; the
                                // total is an estimate (remaining queue length
                                // plus work already claimed)
                                if enable_progress && processed % 100 == 0 {
                                    if let Some(callback) = &*progress_callback.lock() {
                                        let total = injector.len() + processed;
                                        callback(processed, total);
                                    }
                                }

                                match executor(operation) {
                                    Ok(result) => {
                                        stats.total_succeeded.fetch_add(1, Ordering::Relaxed);
                                        results.lock().push(result);
                                    }
                                    Err(e) => {
                                        stats.total_failed.fetch_add(1, Ordering::Relaxed);
                                        errors.write().push(e);
                                    }
                                }
                            }
                            None => {
                                // No work available, check if we're done
                                if injector.is_empty() {
                                    break;
                                }
                                // Brief sleep to avoid busy-waiting
                                thread::sleep(Duration::from_micros(10));
                            }
                        }
                    }
                })
            })
            .collect();

        // Signal all threads to start
        barrier.wait();

        // Wait for completion or timeout
        if let Some(timeout) = self.config.timeout {
            let deadline = Instant::now() + timeout;
            for handle in handles {
                let remaining = deadline.saturating_duration_since(Instant::now());
                if remaining.is_zero() {
                    self.cancel();
                    return Err(OxirsError::Store("Operation timed out".to_string()));
                }
                // Note: std's JoinHandle has no join-with-timeout, so the
                // deadline is only checked between joins; a long-running
                // worker can still overrun the timeout.
                handle
                    .join()
                    .map_err(|_| OxirsError::Store("Worker thread panicked".to_string()))?;
            }
        } else {
            for handle in handles {
                handle
                    .join()
                    .map_err(|_| OxirsError::Store("Worker thread panicked".to_string()))?;
            }
        }

        // Record processing time
        let elapsed = start_time.elapsed();
        self.stats
            .processing_time_ms
            .store(elapsed.as_millis() as usize, Ordering::Relaxed);

        // Check for errors
        let errors = self.errors.read();
        if !errors.is_empty() {
            return Err(OxirsError::Store(format!(
                "Batch processing failed with {} errors",
                errors.len()
            )));
        }

        // Extract results
        let final_results = Arc::try_unwrap(results)
            .map_err(|_| OxirsError::Store("Failed to extract results from Arc".to_string()))?
            .into_inner();

        Ok(final_results)
    }

    /// Process operations in parallel using rayon
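    ///
    /// Available with the `parallel` feature; drains the queue up front and
    /// fails fast on the first executor error. Usage mirrors [`Self::process`],
    /// as a sketch:
    ///
    /// ```ignore
    /// let results = processor.process_rayon(|op| match op {
    ///     BatchOperation::Insert(ts) => Ok(ts.len()),
    ///     _ => Ok(0),
    /// })?;
    /// ```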
    #[cfg(feature = "parallel")]
    pub fn process_rayon<E, R>(&self, executor: E) -> Result<Vec<R>, OxirsError>
    where
        E: Fn(BatchOperation) -> Result<R, OxirsError> + Send + Sync,
        R: Send,
    {
        let start_time = Instant::now();

        // Collect all operations from the queue
        let mut operations = Vec::new();
        loop {
            match self.injector.steal() {
                crossbeam_deque::Steal::Success(op) => {
                    if self.is_cancelled() {
                        return Err(OxirsError::Store("Operation cancelled".to_string()));
                    }
                    operations.push(op);
                }
                crossbeam_deque::Steal::Empty => break,
                crossbeam_deque::Steal::Retry => continue,
            }
        }

        // Configure rayon thread pool
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.config.num_threads.unwrap_or_else(num_cpus::get))
            .build()
            .map_err(|e| OxirsError::Store(format!("Failed to build thread pool: {e}")))?;

        // Clone needed references
        let cancelled = self.cancelled.clone();
        let stats = self.stats.clone();
        let errors = self.errors.clone();
        let batch_size = self.config.batch_size;
        let executor = Arc::new(executor);

        // Process in parallel
        let results = pool.install(move || {
            operations
                .into_par_iter()
                .chunks(batch_size)
                .map(move |chunk| {
                    let mut chunk_results = Vec::new();
                    for op in chunk {
                        if cancelled.load(Ordering::SeqCst) {
                            return Err(OxirsError::Store("Operation cancelled".to_string()));
                        }

                        stats.total_processed.fetch_add(1, Ordering::Relaxed);

                        match executor(op) {
                            Ok(result) => {
                                stats.total_succeeded.fetch_add(1, Ordering::Relaxed);
                                chunk_results.push(result);
                            }
                            Err(e) => {
                                stats.total_failed.fetch_add(1, Ordering::Relaxed);
                                errors.write().push(e.clone());
                                return Err(e);
                            }
                        }
                    }
                    Ok(chunk_results)
                })
                .collect::<Result<Vec<_>, _>>()
        })?;

        // Flatten results
        let results: Vec<R> = results.into_iter().flatten().collect();

        // Record processing time
        let elapsed = start_time.elapsed();
        self.stats
            .processing_time_ms
            .store(elapsed.as_millis() as usize, Ordering::Relaxed);

        Ok(results)
    }
}

/// Helper functions for creating batch operations
impl BatchOperation {
    /// Create an insert operation
    pub fn insert(triples: Vec<Triple>) -> Self {
        BatchOperation::Insert(triples)
    }

    /// Create a remove operation
    pub fn remove(triples: Vec<Triple>) -> Self {
        BatchOperation::Remove(triples)
    }

    /// Create a query operation
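    ///
    /// A `None` leaves that position unconstrained; a pattern with every
    /// position open, as a sketch:
    ///
    /// ```ignore
    /// let op = BatchOperation::query(None, None, None);
    /// ```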
    pub fn query(
        subject: Option<Subject>,
        predicate: Option<Predicate>,
        object: Option<Object>,
    ) -> Self {
        BatchOperation::Query {
            subject,
            predicate,
            object,
        }
    }

    /// Create a transform operation
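    ///
    /// The function returns `Some` with a replacement triple or `None` to drop
    /// the input; a trivial identity transform, as a sketch:
    ///
    /// ```ignore
    /// let op = BatchOperation::transform(|t| Some(t.clone()));
    /// ```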
    pub fn transform<F>(f: F) -> Self
    where
        F: Fn(&Triple) -> Option<Triple> + Send + Sync + 'static,
    {
        BatchOperation::Transform(Arc::new(f))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::NamedNode;

    fn create_test_triple(id: usize) -> Triple {
        Triple::new(
            Subject::NamedNode(NamedNode::new(format!("http://subject/{id}")).unwrap()),
            Predicate::NamedNode(NamedNode::new(format!("http://predicate/{id}")).unwrap()),
            Object::NamedNode(NamedNode::new(format!("http://object/{id}")).unwrap()),
        )
    }

    #[test]
    fn test_parallel_batch_processor() {
        let config = BatchConfig::default();
        let processor = ParallelBatchProcessor::new(config);

        // Submit some operations
        let operations: Vec<_> = (0..1000)
            .map(|i| BatchOperation::insert(vec![create_test_triple(i)]))
            .collect();

        processor.submit_batch(operations).unwrap();

        // Process with a simple executor
        let results = processor
            .process(|op| -> Result<usize, OxirsError> {
                match op {
                    BatchOperation::Insert(triples) => Ok(triples.len()),
                    _ => Ok(0),
                }
            })
            .unwrap();

        assert_eq!(results.len(), 1000);
        assert_eq!(results.iter().sum::<usize>(), 1000);

        let stats = processor.stats();
        assert_eq!(stats.total_processed, 1000);
        assert_eq!(stats.total_succeeded, 1000);
        assert_eq!(stats.total_failed, 0);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_work_stealing() {
        let config = BatchConfig {
            num_threads: Some(4),
            batch_size: 10,
            ..Default::default()
        };

        let processor = ParallelBatchProcessor::new(config);

        // Submit operations
        for i in 0..100 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .unwrap();
        }

        // Process and verify work is distributed
        let results = processor
            .process_rayon(|op| -> Result<usize, OxirsError> {
                // Simulate some work
                thread::sleep(Duration::from_micros(100));
                match op {
                    BatchOperation::Insert(triples) => Ok(triples.len()),
                    _ => Ok(0),
                }
            })
            .unwrap();

        assert_eq!(results.len(), 100);
        let stats = processor.stats();
        assert_eq!(stats.total_processed, 100);
    }

    #[test]
    fn test_error_handling() {
        let config = BatchConfig::default();
        let processor = ParallelBatchProcessor::new(config);

        // Submit operations that will fail
        for i in 0..10 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .unwrap();
        }

        // Process with failing executor
        let result = processor.process(|_op| -> Result<(), OxirsError> {
            Err(OxirsError::Store("Test error".to_string()))
        });

        assert!(result.is_err());
        let stats = processor.stats();
        assert_eq!(stats.total_failed, 10);
        assert_eq!(processor.errors().len(), 10);
    }

    #[test]
    fn test_cancellation() {
        let config = BatchConfig::default();
        let processor = Arc::new(ParallelBatchProcessor::new(config));

        // Submit many operations
        for i in 0..1000 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .unwrap();
        }

        // Start processing in a thread
        let processor_thread = processor.clone();

        let handle = thread::spawn(move || {
            processor_thread.process(|op| -> Result<(), OxirsError> {
                // Simulate slow processing
                thread::sleep(Duration::from_millis(10));
                match op {
                    BatchOperation::Insert(_) => Ok(()),
                    _ => Ok(()),
                }
            })
        });

        // Cancel after a short delay
        thread::sleep(Duration::from_millis(50));
        processor.cancel();

        // Wait for completion
        let _result = handle.join().unwrap();

        // Should have processed some but not all
        let stats = processor.stats();
        assert!(stats.total_processed < 1000);
        assert!(processor.is_cancelled());
    }

    #[test]
    fn test_progress_tracking() {
        let config = BatchConfig::default();
        let processor = ParallelBatchProcessor::new(config);

        let progress_count = Arc::new(AtomicUsize::new(0));
        let progress_count_clone = progress_count.clone();

        processor.set_progress_callback(move |current, total| {
            progress_count_clone.fetch_add(1, Ordering::Relaxed);
            println!("Progress: {current}/{total}");
        });

        // Submit operations
        for i in 0..500 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .unwrap();
        }

        // Process
        processor
            .process(|op| -> Result<(), OxirsError> {
                match op {
                    BatchOperation::Insert(_) => Ok(()),
                    _ => Ok(()),
                }
            })
            .unwrap();

        // Should have received progress updates
        assert!(progress_count.load(Ordering::Relaxed) > 0);
    }
}