use crate::model::{Object, Predicate, Subject, Triple};
use crate::OxirsError;
use crossbeam_deque::Injector;
use parking_lot::{Mutex, RwLock};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::{Duration, Instant};
/// Shared, thread-safe per-triple mapping; returning `None` indicates the
/// triple should be dropped by the transform.
type TransformFn = Arc<dyn Fn(&Triple) -> Option<Triple> + Send + Sync>;

/// A unit of work accepted by the parallel batch processor.
#[derive(Clone)]
pub enum BatchOperation {
    /// Insert the given triples.
    Insert(Vec<Triple>),
    /// Remove the given triples.
    Remove(Vec<Triple>),
    /// Match triples against an optional subject/predicate/object pattern;
    /// `None` in a position matches anything.
    Query {
        subject: Option<Subject>,
        predicate: Option<Predicate>,
        object: Option<Object>,
    },
    /// Apply a mapping function to triples (see [`TransformFn`] semantics).
    Transform(TransformFn),
}
impl std::fmt::Debug for BatchOperation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BatchOperation::Insert(triples) => write!(f, "Insert({} triples)", triples.len()),
BatchOperation::Remove(triples) => write!(f, "Remove({} triples)", triples.len()),
BatchOperation::Query {
subject,
predicate,
object,
} => {
write!(f, "Query({subject:?}, {predicate:?}, {object:?})")
}
BatchOperation::Transform(_) => write!(f, "Transform(function)"),
}
}
}
/// Progress-reporting callback: invoked as `(processed, approximate_total)`.
pub type ProgressCallback = Box<dyn Fn(usize, usize) + Send + Sync>;
/// Tuning knobs for the parallel batch processor.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Worker thread count; `None` defers to the CPU count at run time.
    pub num_threads: Option<usize>,
    /// Operations per chunk (used by the rayon processing path).
    pub batch_size: usize,
    /// Upper bound on queued operations before submissions are rejected.
    pub max_queue_size: usize,
    /// Optional wall-clock limit for a processing run; `None` = unbounded.
    pub timeout: Option<Duration>,
    /// Whether the progress callback is invoked during processing.
    pub enable_progress: bool,
}
impl Default for BatchConfig {
    /// Conservative defaults: thread count left to the runtime (`None`),
    /// 1000-item batches, a queue bound scaled by the CPU count, no
    /// timeout, and progress reporting enabled.
    fn default() -> Self {
        Self {
            num_threads: None,
            batch_size: 1000,
            max_queue_size: num_cpus::get() * 10000,
            timeout: None,
            enable_progress: true,
        }
    }
}
impl BatchConfig {
    /// Builds a configuration tuned from the machine's CPU count and
    /// reported memory size: larger memory tiers get larger batches.
    ///
    /// NOTE(review): `sys_info::mem_info().total` is reported in KiB, so the
    /// thresholds below correspond to 16 GiB / 8 GiB and the fallback to
    /// 8 GiB — confirm these tiers match the intent.
    pub fn auto() -> Self {
        let workers = num_cpus::get();
        let memory = sys_info::mem_info()
            .map(|info| info.total)
            .unwrap_or(8 * 1024 * 1024);
        let batch_size = match memory {
            m if m > 16 * 1024 * 1024 => 5000,
            m if m > 8 * 1024 * 1024 => 2000,
            _ => 1000,
        };
        Self {
            num_threads: Some(workers),
            batch_size,
            max_queue_size: workers * batch_size * 10,
            timeout: None,
            enable_progress: true,
        }
    }
}
/// Lock-free counters updated by worker threads during a run.
#[derive(Debug, Default)]
pub struct BatchStats {
    /// Operations pulled off the queue (succeeded + failed).
    pub total_processed: AtomicUsize,
    /// Operations whose executor returned `Ok`.
    pub total_succeeded: AtomicUsize,
    /// Operations whose executor returned `Err`.
    pub total_failed: AtomicUsize,
    /// Wall-clock duration of the most recent run, in milliseconds.
    pub processing_time_ms: AtomicUsize,
}
impl BatchStats {
pub fn summary(&self) -> BatchStatsSummary {
BatchStatsSummary {
total_processed: self.total_processed.load(Ordering::Relaxed),
total_succeeded: self.total_succeeded.load(Ordering::Relaxed),
total_failed: self.total_failed.load(Ordering::Relaxed),
processing_time_ms: self.processing_time_ms.load(Ordering::Relaxed),
}
}
}
/// Plain-value snapshot of [`BatchStats`], safe to copy around and compare.
#[derive(Debug, Clone)]
pub struct BatchStatsSummary {
    /// Operations pulled off the queue (succeeded + failed).
    pub total_processed: usize,
    /// Operations whose executor returned `Ok`.
    pub total_succeeded: usize,
    /// Operations whose executor returned `Err`.
    pub total_failed: usize,
    /// Wall-clock duration of the most recent run, in milliseconds.
    pub processing_time_ms: usize,
}
/// Work-stealing batch processor that fans queued operations out to a pool
/// of worker threads.
pub struct ParallelBatchProcessor {
    /// Tuning parameters fixed at construction.
    config: BatchConfig,
    /// Shared queue that workers steal operations from.
    injector: Arc<Injector<BatchOperation>>,
    /// Cooperative cancellation flag checked by workers between tasks.
    cancelled: Arc<AtomicBool>,
    /// Counters shared with worker threads.
    stats: Arc<BatchStats>,
    /// Optional `(processed, approximate_total)` progress callback.
    progress_callback: Arc<Mutex<Option<ProgressCallback>>>,
    /// Errors accumulated across runs until `clear_errors` is called.
    errors: Arc<RwLock<Vec<OxirsError>>>,
}
impl ParallelBatchProcessor {
    /// Creates a processor with an empty queue and fresh statistics.
    pub fn new(config: BatchConfig) -> Self {
        ParallelBatchProcessor {
            config,
            injector: Arc::new(Injector::new()),
            cancelled: Arc::new(AtomicBool::new(false)),
            stats: Arc::new(BatchStats::default()),
            progress_callback: Arc::new(Mutex::new(None)),
            errors: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Installs (or replaces) the progress callback, invoked as
    /// `callback(processed_so_far, approximate_total)`.
    pub fn set_progress_callback<F>(&self, callback: F)
    where
        F: Fn(usize, usize) + Send + Sync + 'static,
    {
        *self.progress_callback.lock() = Some(Box::new(callback));
    }

    /// Requests cooperative cancellation; workers stop before their next task.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns whether cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Returns a snapshot of the processing statistics.
    pub fn stats(&self) -> BatchStatsSummary {
        self.stats.summary()
    }

    /// Returns a copy of all errors accumulated so far (across runs).
    pub fn errors(&self) -> Vec<OxirsError> {
        self.errors.read().clone()
    }

    /// Discards all accumulated errors.
    pub fn clear_errors(&self) {
        self.errors.write().clear();
    }

    /// Enqueues a single operation.
    ///
    /// # Errors
    /// Returns `OxirsError::Store` when the queue is already at capacity.
    pub fn submit(&self, operation: BatchOperation) -> Result<(), OxirsError> {
        // `>=` keeps the queue at no more than `max_queue_size` entries and
        // matches the overflow arithmetic used by `submit_batch`; the old
        // `>` allowed the queue to reach max_queue_size + 1.
        if self.injector.len() >= self.config.max_queue_size {
            return Err(OxirsError::Store("Queue is full".to_string()));
        }
        self.injector.push(operation);
        Ok(())
    }

    /// Enqueues a batch of operations after a single capacity check.
    ///
    /// # Errors
    /// Returns `OxirsError::Store` when accepting the whole batch would
    /// exceed `max_queue_size`.
    pub fn submit_batch(&self, operations: Vec<BatchOperation>) -> Result<(), OxirsError> {
        if self.injector.len() + operations.len() > self.config.max_queue_size {
            return Err(OxirsError::Store("Queue would overflow".to_string()));
        }
        for op in operations {
            self.injector.push(op);
        }
        Ok(())
    }

    /// Drains the queue with a pool of worker threads, applying `executor`
    /// to every operation and collecting the successful results.
    ///
    /// Assumes all work was enqueued before this call: a worker that finds
    /// the queue empty treats the run as complete and exits. Result order is
    /// unspecified (workers push as they finish).
    ///
    /// # Errors
    /// - `OxirsError::Store("Operation timed out")` when `config.timeout`
    ///   elapses before the queue drains (remaining work is cancelled).
    /// - `OxirsError::Store` when any executor call failed during this run
    ///   (the individual errors are retrievable via [`Self::errors`]).
    /// - `OxirsError::Store` when a worker thread panicked.
    pub fn process<E, R>(&self, executor: E) -> Result<Vec<R>, OxirsError>
    where
        E: Fn(BatchOperation) -> Result<R, OxirsError> + Send + Sync + 'static,
        R: Send + 'static,
    {
        let start_time = Instant::now();
        let num_threads = self.config.num_threads.unwrap_or_else(num_cpus::get);
        // +1 participant so this coordinating thread releases all workers at once.
        let barrier = Arc::new(Barrier::new(num_threads + 1));
        let executor = Arc::new(executor);
        let results = Arc::new(Mutex::new(Vec::new()));
        // Only errors added by *this* run should fail it; earlier runs may
        // have left errors behind that the caller has not cleared yet.
        let errors_before = self.errors.read().len();
        self.cancelled.store(false, Ordering::SeqCst);
        let handles: Vec<_> = (0..num_threads)
            .map(|_worker_id| {
                let injector = self.injector.clone();
                let cancelled = self.cancelled.clone();
                let stats = self.stats.clone();
                let executor = executor.clone();
                let results = results.clone();
                let errors = self.errors.clone();
                let barrier = barrier.clone();
                let progress_callback = self.progress_callback.clone();
                let enable_progress = self.config.enable_progress;
                thread::spawn(move || {
                    // Start all workers simultaneously for fair stealing.
                    barrier.wait();
                    loop {
                        if cancelled.load(Ordering::SeqCst) {
                            break;
                        }
                        // Retry transient steal contention; Empty means no
                        // task is available right now.
                        let task = loop {
                            match injector.steal() {
                                crossbeam_deque::Steal::Success(task) => break Some(task),
                                crossbeam_deque::Steal::Empty => break None,
                                crossbeam_deque::Steal::Retry => continue,
                            }
                        };
                        match task {
                            Some(operation) => {
                                let processed =
                                    stats.total_processed.fetch_add(1, Ordering::Relaxed) + 1;
                                // Throttle the callback to every 100th operation.
                                if enable_progress && processed % 100 == 0 {
                                    if let Some(callback) = &*progress_callback.lock() {
                                        // Approximate total: remaining queue
                                        // length plus work already done.
                                        let total = injector.len() + processed;
                                        callback(processed, total);
                                    }
                                }
                                match executor(operation) {
                                    Ok(result) => {
                                        stats.total_succeeded.fetch_add(1, Ordering::Relaxed);
                                        results.lock().push(result);
                                    }
                                    Err(e) => {
                                        stats.total_failed.fetch_add(1, Ordering::Relaxed);
                                        errors.write().push(e);
                                    }
                                }
                            }
                            None => {
                                // All work is enqueued up front, so an empty
                                // queue means this run is complete.
                                if injector.is_empty() {
                                    break;
                                }
                                thread::sleep(Duration::from_micros(10));
                            }
                        }
                    }
                })
            })
            .collect();
        barrier.wait();
        // A watchdog thread enforces the optional timeout. `JoinHandle::join`
        // cannot be interrupted, so checking the deadline between joins (as
        // before) could block far past the deadline on the very first join.
        let watchdog = self.config.timeout.map(|timeout| {
            let cancelled = self.cancelled.clone();
            let finished = Arc::new(AtomicBool::new(false));
            let finished_for_watchdog = finished.clone();
            let deadline = Instant::now() + timeout;
            let handle = thread::spawn(move || {
                while !finished_for_watchdog.load(Ordering::SeqCst) {
                    if Instant::now() >= deadline {
                        // Ask workers to stop; they notice before their next task.
                        cancelled.store(true, Ordering::SeqCst);
                        return true;
                    }
                    thread::sleep(Duration::from_millis(1));
                }
                false
            });
            (finished, handle)
        });
        // Join every worker even if one panicked, so no thread (or the
        // watchdog) is left running when we return.
        let mut worker_panicked = false;
        for handle in handles {
            if handle.join().is_err() {
                worker_panicked = true;
            }
        }
        let timed_out = match watchdog {
            Some((finished, handle)) => {
                finished.store(true, Ordering::SeqCst);
                handle.join().unwrap_or(true)
            }
            None => false,
        };
        let elapsed = start_time.elapsed();
        self.stats
            .processing_time_ms
            .store(elapsed.as_millis() as usize, Ordering::Relaxed);
        if worker_panicked {
            return Err(OxirsError::Store("Worker thread panicked".to_string()));
        }
        if timed_out {
            return Err(OxirsError::Store("Operation timed out".to_string()));
        }
        let new_errors = self.errors.read().len().saturating_sub(errors_before);
        if new_errors > 0 {
            return Err(OxirsError::Store(format!(
                "Batch processing failed with {new_errors} errors"
            )));
        }
        // All workers have exited, so this is the only remaining reference.
        let final_results = Arc::try_unwrap(results)
            .map_err(|_| OxirsError::Store("Failed to extract results from Arc".to_string()))?
            .into_inner();
        Ok(final_results)
    }

    /// Alternative processing path on a dedicated rayon pool: drains the
    /// queue into a `Vec`, then processes it in `batch_size` chunks.
    ///
    /// Unlike [`Self::process`], this fails fast: the first executor error
    /// aborts the run and is returned directly (after being recorded in the
    /// shared error list).
    ///
    /// # Errors
    /// Returns the first executor error, a cancellation error, or a pool
    /// construction error.
    #[cfg(feature = "parallel")]
    pub fn process_rayon<E, R>(&self, executor: E) -> Result<Vec<R>, OxirsError>
    where
        E: Fn(BatchOperation) -> Result<R, OxirsError> + Send + Sync,
        R: Send,
    {
        let start_time = Instant::now();
        let mut operations = Vec::new();
        loop {
            match self.injector.steal() {
                crossbeam_deque::Steal::Success(op) => {
                    if self.is_cancelled() {
                        return Err(OxirsError::Store("Operation cancelled".to_string()));
                    }
                    operations.push(op);
                }
                crossbeam_deque::Steal::Empty => break,
                crossbeam_deque::Steal::Retry => continue,
            }
        }
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.config.num_threads.unwrap_or_else(num_cpus::get))
            .build()
            .map_err(|e| OxirsError::Store(format!("Failed to build thread pool: {e}")))?;
        let cancelled = self.cancelled.clone();
        let stats = self.stats.clone();
        let errors = self.errors.clone();
        let batch_size = self.config.batch_size;
        let executor = Arc::new(executor);
        let results = pool.install(move || {
            operations
                .into_par_iter()
                .chunks(batch_size)
                .map(move |chunk| {
                    let mut chunk_results = Vec::new();
                    for op in chunk {
                        if cancelled.load(Ordering::SeqCst) {
                            return Err(OxirsError::Store("Operation cancelled".to_string()));
                        }
                        stats.total_processed.fetch_add(1, Ordering::Relaxed);
                        match executor(op) {
                            Ok(result) => {
                                stats.total_succeeded.fetch_add(1, Ordering::Relaxed);
                                chunk_results.push(result);
                            }
                            Err(e) => {
                                stats.total_failed.fetch_add(1, Ordering::Relaxed);
                                // Record the error, then abort this chunk.
                                errors.write().push(e.clone());
                                return Err(e);
                            }
                        }
                    }
                    Ok(chunk_results)
                })
                // Short-circuits on the first failed chunk.
                .collect::<Result<Vec<_>, _>>()
        })?;
        let results: Vec<R> = results.into_iter().flatten().collect();
        let elapsed = start_time.elapsed();
        self.stats
            .processing_time_ms
            .store(elapsed.as_millis() as usize, Ordering::Relaxed);
        Ok(results)
    }
}
impl BatchOperation {
pub fn insert(triples: Vec<Triple>) -> Self {
BatchOperation::Insert(triples)
}
pub fn remove(triples: Vec<Triple>) -> Self {
BatchOperation::Remove(triples)
}
pub fn query(
subject: Option<Subject>,
predicate: Option<Predicate>,
object: Option<Object>,
) -> Self {
BatchOperation::Query {
subject,
predicate,
object,
}
}
pub fn transform<F>(f: F) -> Self
where
F: Fn(&Triple) -> Option<Triple> + Send + Sync + 'static,
{
BatchOperation::Transform(Arc::new(f))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::NamedNode;

    /// Builds a unique, valid triple fixture keyed by `id`.
    fn create_test_triple(id: usize) -> Triple {
        Triple::new(
            Subject::NamedNode(
                NamedNode::new(format!("http://subject/{id}")).expect("valid IRI from format"),
            ),
            Predicate::NamedNode(
                NamedNode::new(format!("http://predicate/{id}")).expect("valid IRI from format"),
            ),
            Object::NamedNode(
                NamedNode::new(format!("http://object/{id}")).expect("valid IRI from format"),
            ),
        )
    }

    /// Happy path: 1000 single-triple inserts all succeed and stats agree.
    #[test]
    fn test_parallel_batch_processor() {
        let config = BatchConfig::default();
        let processor = ParallelBatchProcessor::new(config);
        let operations: Vec<_> = (0..1000)
            .map(|i| BatchOperation::insert(vec![create_test_triple(i)]))
            .collect();
        processor
            .submit_batch(operations)
            .expect("operation should succeed");
        // Executor reports the number of triples per insert (1 each here).
        let results = processor
            .process(|op| -> Result<usize, OxirsError> {
                match op {
                    BatchOperation::Insert(triples) => Ok(triples.len()),
                    _ => Ok(0),
                }
            })
            .expect("operation should succeed");
        assert_eq!(results.len(), 1000);
        assert_eq!(results.iter().sum::<usize>(), 1000);
        let stats = processor.stats();
        assert_eq!(stats.total_processed, 1000);
        assert_eq!(stats.total_succeeded, 1000);
        assert_eq!(stats.total_failed, 0);
    }

    /// Exercises the rayon path with a small per-op delay so work is spread
    /// across the 4 configured threads.
    #[test]
    #[cfg(feature = "parallel")]
    fn test_work_stealing() {
        let config = BatchConfig {
            num_threads: Some(4),
            batch_size: 10,
            ..Default::default()
        };
        let processor = ParallelBatchProcessor::new(config);
        for i in 0..100 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .expect("operation should succeed");
        }
        let results = processor
            .process_rayon(|op| -> Result<usize, OxirsError> {
                thread::sleep(Duration::from_micros(100));
                match op {
                    BatchOperation::Insert(triples) => Ok(triples.len()),
                    _ => Ok(0),
                }
            })
            .expect("operation should succeed");
        assert_eq!(results.len(), 100);
        let stats = processor.stats();
        assert_eq!(stats.total_processed, 100);
    }

    /// Every operation fails: process() must return Err and every failure
    /// must be counted and retrievable via errors().
    #[test]
    fn test_error_handling() {
        let config = BatchConfig::default();
        let processor = ParallelBatchProcessor::new(config);
        for i in 0..10 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .expect("operation should succeed");
        }
        let result = processor.process(|_op| -> Result<(), OxirsError> {
            Err(OxirsError::Store("Test error".to_string()))
        });
        assert!(result.is_err());
        let stats = processor.stats();
        assert_eq!(stats.total_failed, 10);
        assert_eq!(processor.errors().len(), 10);
    }

    /// Cancels mid-run: each op sleeps 10ms, cancel fires after ~50ms, so
    /// workers should stop well before all 1000 ops are processed.
    #[test]
    fn test_cancellation() {
        let config = BatchConfig::default();
        let processor = Arc::new(ParallelBatchProcessor::new(config));
        for i in 0..1000 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .expect("operation should succeed");
        }
        let processor_thread = processor.clone();
        let handle = thread::spawn(move || {
            processor_thread.process(|op| -> Result<(), OxirsError> {
                thread::sleep(Duration::from_millis(10));
                match op {
                    BatchOperation::Insert(_) => Ok(()),
                    _ => Ok(()),
                }
            })
        });
        thread::sleep(Duration::from_millis(50));
        processor.cancel();
        let _result = handle.join().expect("thread should not panic");
        let stats = processor.stats();
        assert!(stats.total_processed < 1000);
        assert!(processor.is_cancelled());
    }

    /// With 500 ops the callback (fired every 100th op) must run at least once.
    #[test]
    fn test_progress_tracking() {
        let config = BatchConfig::default();
        let processor = ParallelBatchProcessor::new(config);
        let progress_count = Arc::new(AtomicUsize::new(0));
        let progress_count_clone = progress_count.clone();
        processor.set_progress_callback(move |current, _total| {
            progress_count_clone.fetch_add(1, Ordering::Relaxed);
            println!("Progress: {current}/{_total}");
        });
        for i in 0..500 {
            processor
                .submit(BatchOperation::insert(vec![create_test_triple(i)]))
                .expect("operation should succeed");
        }
        processor
            .process(|op| -> Result<(), OxirsError> {
                match op {
                    BatchOperation::Insert(_) => Ok(()),
                    _ => Ok(()),
                }
            })
            .expect("operation should succeed");
        assert!(progress_count.load(Ordering::Relaxed) > 0);
    }
}