1use crate::model::{Object, Predicate, Subject, Triple};
7use crate::OxirsError;
8use crossbeam_deque::Injector;
9use parking_lot::{Mutex, RwLock};
10#[cfg(feature = "parallel")]
11use rayon::prelude::*;
12use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use std::sync::{Arc, Barrier};
14use std::thread;
15use std::time::{Duration, Instant};
16
17type TransformFn = Arc<dyn Fn(&Triple) -> Option<Triple> + Send + Sync>;
19
20#[derive(Clone)]
22pub enum BatchOperation {
23 Insert(Vec<Triple>),
25 Remove(Vec<Triple>),
27 Query {
29 subject: Option<Subject>,
30 predicate: Option<Predicate>,
31 object: Option<Object>,
32 },
33 Transform(TransformFn),
35}
36
37impl std::fmt::Debug for BatchOperation {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 BatchOperation::Insert(triples) => write!(f, "Insert({} triples)", triples.len()),
41 BatchOperation::Remove(triples) => write!(f, "Remove({} triples)", triples.len()),
42 BatchOperation::Query {
43 subject,
44 predicate,
45 object,
46 } => {
47 write!(f, "Query({subject:?}, {predicate:?}, {object:?})")
48 }
49 BatchOperation::Transform(_) => write!(f, "Transform(function)"),
50 }
51 }
52}
53
/// Boxed progress listener taking two counters.
///
/// [`ParallelBatchProcessor::process`] calls it as `callback(processed, total)`,
/// where `total` is the number processed so far plus the current queue length.
pub type ProgressCallback = Box<dyn Fn(usize, usize) + Send + Sync>;
56
/// Tuning knobs for a [`ParallelBatchProcessor`].
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Worker thread count; `None` falls back to the detected CPU count.
    pub num_threads: Option<usize>,
    /// Number of operations grouped into one chunk by the rayon-based path.
    pub batch_size: usize,
    /// Queue length limit checked when operations are submitted.
    pub max_queue_size: usize,
    /// Optional time limit for a `process` run; `None` means no limit.
    pub timeout: Option<Duration>,
    /// Whether `process` invokes the installed progress callback.
    pub enable_progress: bool,
}
71
72impl Default for BatchConfig {
73 fn default() -> Self {
74 let num_cpus = num_cpus::get();
75 BatchConfig {
76 num_threads: None,
77 batch_size: 1000,
78 max_queue_size: num_cpus * 10000,
79 timeout: None,
80 enable_progress: true,
81 }
82 }
83}
84
85impl BatchConfig {
86 pub fn auto() -> Self {
88 let num_cpus = num_cpus::get();
89 let total_memory = sys_info::mem_info()
90 .map(|info| info.total)
91 .unwrap_or(8 * 1024 * 1024); let batch_size = if total_memory > 16 * 1024 * 1024 {
95 5000
96 } else if total_memory > 8 * 1024 * 1024 {
97 2000
98 } else {
99 1000
100 };
101
102 BatchConfig {
103 num_threads: Some(num_cpus),
104 batch_size,
105 max_queue_size: num_cpus * batch_size * 10,
106 timeout: None,
107 enable_progress: true,
108 }
109 }
110}
111
/// Live counters updated by worker threads while a batch is running.
///
/// Every field is an atomic so the struct can be shared behind an `Arc`
/// without further locking. All counters start at zero.
#[derive(Debug, Default)]
pub struct BatchStats {
    /// Operations taken from the queue so far.
    pub total_processed: AtomicUsize,
    /// Operations whose executor returned `Ok`.
    pub total_succeeded: AtomicUsize,
    /// Operations whose executor returned `Err`.
    pub total_failed: AtomicUsize,
    /// Wall-clock duration of the most recent run, in milliseconds.
    pub processing_time_ms: AtomicUsize,
}
120
121impl BatchStats {
122 pub fn summary(&self) -> BatchStatsSummary {
124 BatchStatsSummary {
125 total_processed: self.total_processed.load(Ordering::Relaxed),
126 total_succeeded: self.total_succeeded.load(Ordering::Relaxed),
127 total_failed: self.total_failed.load(Ordering::Relaxed),
128 processing_time_ms: self.processing_time_ms.load(Ordering::Relaxed),
129 }
130 }
131}
132
/// Plain-value copy of [`BatchStats`], safe to clone and pass around.
#[derive(Debug, Clone)]
pub struct BatchStatsSummary {
    /// Operations taken from the queue.
    pub total_processed: usize,
    /// Operations that completed successfully.
    pub total_succeeded: usize,
    /// Operations that returned an error.
    pub total_failed: usize,
    /// Duration of the most recent run, in milliseconds.
    pub processing_time_ms: usize,
}
140
141pub struct ParallelBatchProcessor {
143 config: BatchConfig,
144 injector: Arc<Injector<BatchOperation>>,
146 cancelled: Arc<AtomicBool>,
148 stats: Arc<BatchStats>,
150 progress_callback: Arc<Mutex<Option<ProgressCallback>>>,
152 errors: Arc<RwLock<Vec<OxirsError>>>,
154}
155
156impl ParallelBatchProcessor {
157 pub fn new(config: BatchConfig) -> Self {
159 let injector = Arc::new(Injector::new());
160
161 ParallelBatchProcessor {
162 config,
163 injector,
164 cancelled: Arc::new(AtomicBool::new(false)),
165 stats: Arc::new(BatchStats::default()),
166 progress_callback: Arc::new(Mutex::new(None)),
167 errors: Arc::new(RwLock::new(Vec::new())),
168 }
169 }
170
171 pub fn set_progress_callback<F>(&self, callback: F)
173 where
174 F: Fn(usize, usize) + Send + Sync + 'static,
175 {
176 *self.progress_callback.lock() = Some(Box::new(callback));
177 }
178
179 pub fn cancel(&self) {
181 self.cancelled.store(true, Ordering::SeqCst);
182 }
183
184 pub fn is_cancelled(&self) -> bool {
186 self.cancelled.load(Ordering::SeqCst)
187 }
188
189 pub fn stats(&self) -> BatchStatsSummary {
191 self.stats.summary()
192 }
193
194 pub fn errors(&self) -> Vec<OxirsError> {
196 self.errors.read().clone()
197 }
198
199 pub fn clear_errors(&self) {
201 self.errors.write().clear();
202 }
203
204 pub fn submit(&self, operation: BatchOperation) -> Result<(), OxirsError> {
206 if self.injector.len() > self.config.max_queue_size {
208 return Err(OxirsError::Store("Queue is full".to_string()));
209 }
210
211 self.injector.push(operation);
212 Ok(())
213 }
214
215 pub fn submit_batch(&self, operations: Vec<BatchOperation>) -> Result<(), OxirsError> {
217 if self.injector.len() + operations.len() > self.config.max_queue_size {
219 return Err(OxirsError::Store("Queue would overflow".to_string()));
220 }
221
222 for op in operations {
223 self.injector.push(op);
224 }
225 Ok(())
226 }
227
228 pub fn process<E, R>(&self, executor: E) -> Result<Vec<R>, OxirsError>
230 where
231 E: Fn(BatchOperation) -> Result<R, OxirsError> + Send + Sync + 'static,
232 R: Send + 'static,
233 {
234 let start_time = Instant::now();
235 let num_threads = self.config.num_threads.unwrap_or_else(num_cpus::get);
236 let barrier = Arc::new(Barrier::new(num_threads + 1));
237 let executor = Arc::new(executor);
238 let results = Arc::new(Mutex::new(Vec::new()));
239
240 self.cancelled.store(false, Ordering::SeqCst);
242
243 let handles: Vec<_> = (0..num_threads)
245 .map(|_worker_id| {
246 let injector = self.injector.clone();
247 let cancelled = self.cancelled.clone();
248 let stats = self.stats.clone();
249 let executor = executor.clone();
250 let results = results.clone();
251 let errors = self.errors.clone();
252 let barrier = barrier.clone();
253 let progress_callback = self.progress_callback.clone();
254 let enable_progress = self.config.enable_progress;
255
256 thread::spawn(move || {
257 barrier.wait();
259
260 loop {
261 if cancelled.load(Ordering::SeqCst) {
263 break;
264 }
265
266 let task = loop {
268 match injector.steal() {
269 crossbeam_deque::Steal::Success(task) => break Some(task),
270 crossbeam_deque::Steal::Empty => break None,
271 crossbeam_deque::Steal::Retry => continue,
272 }
273 };
274
275 match task {
276 Some(operation) => {
277 let processed =
279 stats.total_processed.fetch_add(1, Ordering::Relaxed) + 1;
280
281 if enable_progress && processed % 100 == 0 {
283 if let Some(callback) = &*progress_callback.lock() {
284 let total = injector.len() + processed;
285 callback(processed, total);
286 }
287 }
288
289 match executor(operation) {
290 Ok(result) => {
291 stats.total_succeeded.fetch_add(1, Ordering::Relaxed);
292 results.lock().push(result);
293 }
294 Err(e) => {
295 stats.total_failed.fetch_add(1, Ordering::Relaxed);
296 errors.write().push(e);
297 }
298 }
299 }
300 None => {
301 if injector.is_empty() {
303 break;
304 }
305 thread::sleep(Duration::from_micros(10));
307 }
308 }
309 }
310 })
311 })
312 .collect();
313
314 barrier.wait();
316
317 if let Some(timeout) = self.config.timeout {
319 let deadline = Instant::now() + timeout;
320 for handle in handles {
321 let remaining = deadline.saturating_duration_since(Instant::now());
322 if remaining.is_zero() {
323 self.cancel();
324 return Err(OxirsError::Store("Operation timed out".to_string()));
325 }
326 handle
328 .join()
329 .map_err(|_| OxirsError::Store("Worker thread panicked".to_string()))?;
330 }
331 } else {
332 for handle in handles {
333 handle
334 .join()
335 .map_err(|_| OxirsError::Store("Worker thread panicked".to_string()))?;
336 }
337 }
338
339 let elapsed = start_time.elapsed();
341 self.stats
342 .processing_time_ms
343 .store(elapsed.as_millis() as usize, Ordering::Relaxed);
344
345 let errors = self.errors.read();
347 if !errors.is_empty() {
348 return Err(OxirsError::Store(format!(
349 "Batch processing failed with {} errors",
350 errors.len()
351 )));
352 }
353
354 let final_results = Arc::try_unwrap(results)
356 .map_err(|_| OxirsError::Store("Failed to extract results from Arc".to_string()))?
357 .into_inner();
358
359 Ok(final_results)
360 }
361
362 #[cfg(feature = "parallel")]
364 pub fn process_rayon<E, R>(&self, executor: E) -> Result<Vec<R>, OxirsError>
365 where
366 E: Fn(BatchOperation) -> Result<R, OxirsError> + Send + Sync,
367 R: Send,
368 {
369 let start_time = Instant::now();
370
371 let mut operations = Vec::new();
373 loop {
374 match self.injector.steal() {
375 crossbeam_deque::Steal::Success(op) => {
376 if self.is_cancelled() {
377 return Err(OxirsError::Store("Operation cancelled".to_string()));
378 }
379 operations.push(op);
380 }
381 crossbeam_deque::Steal::Empty => break,
382 crossbeam_deque::Steal::Retry => continue,
383 }
384 }
385
386 let pool = rayon::ThreadPoolBuilder::new()
388 .num_threads(self.config.num_threads.unwrap_or_else(num_cpus::get))
389 .build()
390 .map_err(|e| OxirsError::Store(format!("Failed to build thread pool: {e}")))?;
391
392 let cancelled = self.cancelled.clone();
394 let stats = self.stats.clone();
395 let errors = self.errors.clone();
396 let batch_size = self.config.batch_size;
397 let executor = Arc::new(executor);
398
399 let results = pool.install(move || {
401 operations
402 .into_par_iter()
403 .chunks(batch_size)
404 .map(move |chunk| {
405 let mut chunk_results = Vec::new();
406 for op in chunk {
407 if cancelled.load(Ordering::SeqCst) {
408 return Err(OxirsError::Store("Operation cancelled".to_string()));
409 }
410
411 stats.total_processed.fetch_add(1, Ordering::Relaxed);
412
413 match executor(op) {
414 Ok(result) => {
415 stats.total_succeeded.fetch_add(1, Ordering::Relaxed);
416 chunk_results.push(result);
417 }
418 Err(e) => {
419 stats.total_failed.fetch_add(1, Ordering::Relaxed);
420 errors.write().push(e.clone());
421 return Err(e);
422 }
423 }
424 }
425 Ok(chunk_results)
426 })
427 .collect::<Result<Vec<_>, _>>()
428 })?;
429
430 let results: Vec<R> = results.into_iter().flatten().collect();
432
433 let elapsed = start_time.elapsed();
435 self.stats
436 .processing_time_ms
437 .store(elapsed.as_millis() as usize, Ordering::Relaxed);
438
439 Ok(results)
440 }
441}
442
443impl BatchOperation {
445 pub fn insert(triples: Vec<Triple>) -> Self {
447 BatchOperation::Insert(triples)
448 }
449
450 pub fn remove(triples: Vec<Triple>) -> Self {
452 BatchOperation::Remove(triples)
453 }
454
455 pub fn query(
457 subject: Option<Subject>,
458 predicate: Option<Predicate>,
459 object: Option<Object>,
460 ) -> Self {
461 BatchOperation::Query {
462 subject,
463 predicate,
464 object,
465 }
466 }
467
468 pub fn transform<F>(f: F) -> Self
470 where
471 F: Fn(&Triple) -> Option<Triple> + Send + Sync + 'static,
472 {
473 BatchOperation::Transform(Arc::new(f))
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::NamedNode;

    /// Builds a triple whose subject, predicate and object IRIs all embed `id`.
    fn create_test_triple(id: usize) -> Triple {
        let subject = NamedNode::new(format!("http://subject/{id}")).unwrap();
        let predicate = NamedNode::new(format!("http://predicate/{id}")).unwrap();
        let object = NamedNode::new(format!("http://object/{id}")).unwrap();

        Triple::new(
            Subject::NamedNode(subject),
            Predicate::NamedNode(predicate),
            Object::NamedNode(object),
        )
    }

    /// Wraps a single test triple in an insert operation.
    fn insert_op(id: usize) -> BatchOperation {
        BatchOperation::insert(vec![create_test_triple(id)])
    }

    /// Submits `count` single-triple inserts one at a time.
    fn submit_inserts(processor: &ParallelBatchProcessor, count: usize) {
        for id in 0..count {
            processor.submit(insert_op(id)).unwrap();
        }
    }

    /// Number of triples carried by an insert; every other variant counts as 0.
    fn inserted_len(op: BatchOperation) -> Result<usize, OxirsError> {
        match op {
            BatchOperation::Insert(triples) => Ok(triples.len()),
            _ => Ok(0),
        }
    }

    /// A full batch submitted at once is processed completely and counted.
    #[test]
    fn test_parallel_batch_processor() {
        let processor = ParallelBatchProcessor::new(BatchConfig::default());

        let operations: Vec<_> = (0..1000).map(insert_op).collect();
        processor.submit_batch(operations).unwrap();

        let results = processor.process(inserted_len).unwrap();

        assert_eq!(results.len(), 1000);
        assert_eq!(results.iter().sum::<usize>(), 1000);

        let stats = processor.stats();
        assert_eq!(stats.total_processed, 1000);
        assert_eq!(stats.total_succeeded, 1000);
        assert_eq!(stats.total_failed, 0);
    }

    /// The rayon path processes every queued operation.
    #[test]
    #[cfg(feature = "parallel")]
    fn test_work_stealing() {
        let config = BatchConfig {
            num_threads: Some(4),
            batch_size: 10,
            ..Default::default()
        };
        let processor = ParallelBatchProcessor::new(config);

        submit_inserts(&processor, 100);

        let results = processor
            .process_rayon(|op| -> Result<usize, OxirsError> {
                // Simulate a small amount of work per operation.
                thread::sleep(Duration::from_micros(100));
                inserted_len(op)
            })
            .unwrap();

        assert_eq!(results.len(), 100);
        assert_eq!(processor.stats().total_processed, 100);
    }

    /// Executor failures are counted, recorded and surfaced as an error.
    #[test]
    fn test_error_handling() {
        let processor = ParallelBatchProcessor::new(BatchConfig::default());
        submit_inserts(&processor, 10);

        let result = processor.process(|_op| -> Result<(), OxirsError> {
            Err(OxirsError::Store("Test error".to_string()))
        });

        assert!(result.is_err());
        assert_eq!(processor.stats().total_failed, 10);
        assert_eq!(processor.errors().len(), 10);
    }

    /// Cancelling mid-run stops the workers before the queue is drained.
    #[test]
    fn test_cancellation() {
        let processor = Arc::new(ParallelBatchProcessor::new(BatchConfig::default()));
        submit_inserts(&processor, 1000);

        let background = Arc::clone(&processor);
        let handle = thread::spawn(move || {
            background.process(|_op| -> Result<(), OxirsError> {
                // Slow executor so the run is still in flight when cancelled.
                thread::sleep(Duration::from_millis(10));
                Ok(())
            })
        });

        // Give the workers time to start before cancelling.
        thread::sleep(Duration::from_millis(50));
        processor.cancel();

        let _result = handle.join().unwrap();

        assert!(processor.stats().total_processed < 1000);
        assert!(processor.is_cancelled());
    }

    /// The progress callback fires at least once for a 500-operation run.
    #[test]
    fn test_progress_tracking() {
        let processor = ParallelBatchProcessor::new(BatchConfig::default());

        let progress_count = Arc::new(AtomicUsize::new(0));
        let progress_count_clone = Arc::clone(&progress_count);

        processor.set_progress_callback(move |current, _total| {
            progress_count_clone.fetch_add(1, Ordering::Relaxed);
            println!("Progress: {current}/{_total}");
        });

        submit_inserts(&processor, 500);

        processor
            .process(|_op| -> Result<(), OxirsError> { Ok(()) })
            .unwrap();

        assert!(progress_count.load(Ordering::Relaxed) > 0);
    }
}