use crate::{Record, Result};
/// Per-worker handler for [`Record`]s consumed by a [`ParallelReader`].
///
/// The `Send + Clone` supertraits allow a reader to give each worker thread its
/// own copy of the processor (presumably one clone per worker — confirm against
/// the concrete reader implementation). Per-thread state can therefore live in
/// ordinary fields. Anything that must be aggregated across workers should sit
/// behind shared handles such as `Arc<AtomicU64>` or `Arc<Mutex<_>>`.
pub trait ParallelProcessor: Send + Clone {
    /// Handles a single record.
    ///
    /// # Errors
    ///
    /// Returns an error if the record cannot be processed. How that error is
    /// surfaced to the caller of [`ParallelReader::process_parallel`] is up to
    /// the reader implementation.
    fn process_record(&mut self, record: Record) -> Result<()>;

    /// Hook for the end of a batch of records (batch boundaries are defined by
    /// the reader).
    ///
    /// This is a natural place to flush per-thread accumulators into shared
    /// state. The default implementation does nothing and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Implementations may return an error if flushing fails.
    fn on_batch_complete(&mut self) -> Result<()> {
        Ok(())
    }

    /// Informs the processor of the worker id it is running under.
    ///
    /// The default implementation discards the id. Override this together with
    /// [`get_tid`](ParallelProcessor::get_tid) to retain it.
    // `tid` is intentionally unused by the default no-op body; the allow keeps
    // the public parameter name readable in rustdoc instead of `_tid`.
    #[allow(unused_variables)]
    fn set_tid(&mut self, tid: usize) {}

    /// Returns the worker id previously passed to
    /// [`set_tid`](ParallelProcessor::set_tid), if the implementation retains it.
    ///
    /// The default implementation always returns `None`.
    fn get_tid(&self) -> Option<usize> {
        None
    }
}
/// A record source that can fan its records out to a [`ParallelProcessor`]
/// running on several worker threads.
pub trait ParallelReader {
    /// Runs `processor` over this reader's records using `num_threads` workers.
    ///
    /// The processor is taken by value. Presumably the implementation clones it
    /// once per worker; confirm against the concrete reader. Threading and
    /// batching details are implementation-defined.
    ///
    /// # Errors
    ///
    /// Returns an error when reading or processing fails. Whether the first
    /// error aborts the run, or the remaining work is drained first, is up to
    /// the implementation.
    fn process_parallel<P>(&self, processor: P, num_threads: usize) -> Result<()>
    where
        P: ParallelProcessor + Clone + 'static;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// Processor that accumulates per-thread counters locally and publishes
    /// them to shared atomics in `on_batch_complete`.
    #[derive(Clone, Default)]
    struct TestProcessor {
        local_count: u64,
        local_sum: u64,
        global_count: Arc<AtomicU64>,
        global_sum: Arc<AtomicU64>,
        tid: Option<usize>,
    }

    impl ParallelProcessor for TestProcessor {
        fn process_record(&mut self, record: Record) -> Result<()> {
            self.local_count += 1;
            self.local_sum += record.barcode + record.umi + record.index;
            Ok(())
        }

        fn on_batch_complete(&mut self) -> Result<()> {
            self.global_count
                .fetch_add(self.local_count, Ordering::Relaxed);
            self.global_sum.fetch_add(self.local_sum, Ordering::Relaxed);
            self.local_count = 0;
            self.local_sum = 0;
            Ok(())
        }

        fn set_tid(&mut self, tid: usize) {
            self.tid = Some(tid);
        }

        fn get_tid(&self) -> Option<usize> {
            self.tid
        }
    }

    /// Processor that fails on the record whose `index` equals `fail_on_index`.
    /// It counts every record it is handed, including the failing one.
    #[derive(Clone)]
    struct ErrorProcessor {
        fail_on_index: u64,
        records_seen: u64,
    }

    impl ParallelProcessor for ErrorProcessor {
        fn process_record(&mut self, record: Record) -> Result<()> {
            self.records_seen += 1;
            if record.index == self.fail_on_index {
                return Err(crate::IbuError::Process("Test error".into()));
            }
            Ok(())
        }
    }

    #[test]
    fn test_processor_basic_functionality() {
        let processor = TestProcessor::default();
        let mut processor_clone = processor.clone();

        assert_eq!(processor_clone.get_tid(), None);
        processor_clone.set_tid(42);
        assert_eq!(processor_clone.get_tid(), Some(42));

        processor_clone.process_record(Record::new(1, 2, 3)).unwrap();
        processor_clone.process_record(Record::new(4, 5, 6)).unwrap();
        assert_eq!(processor_clone.local_count, 2);
        assert_eq!(processor_clone.local_sum, 1 + 2 + 3 + 4 + 5 + 6);

        processor_clone.on_batch_complete().unwrap();
        assert_eq!(processor_clone.local_count, 0);
        assert_eq!(processor_clone.local_sum, 0);

        // The original shares its atomics with the clone, so it sees the flush.
        assert_eq!(processor.global_count.load(Ordering::Relaxed), 2);
        assert_eq!(processor.global_sum.load(Ordering::Relaxed), 21);
    }

    #[test]
    fn test_processor_thread_safety() {
        fn is_send<T: Send>() {}
        fn is_clone<T: Clone>() {}
        is_send::<TestProcessor>();
        is_clone::<TestProcessor>();

        let processor = TestProcessor::default();
        let mut clone1 = processor.clone();
        let mut clone2 = processor.clone();

        // Thread ids are per-clone state...
        clone1.set_tid(1);
        clone2.set_tid(2);
        assert_eq!(clone1.get_tid(), Some(1));
        assert_eq!(clone2.get_tid(), Some(2));

        // ...while the global accumulators are shared by every clone.
        assert!(Arc::ptr_eq(&clone1.global_count, &clone2.global_count));
        assert!(Arc::ptr_eq(&clone1.global_sum, &clone2.global_sum));
    }

    #[test]
    fn test_concurrent_batches_aggregate_globally() {
        const THREADS: usize = 4;
        const RECORDS_PER_THREAD: u64 = 100;

        let processor = TestProcessor::default();
        let handles: Vec<_> = (0..THREADS)
            .map(|tid| {
                let mut worker = processor.clone();
                worker.set_tid(tid);
                thread::spawn(move || {
                    for _ in 0..RECORDS_PER_THREAD {
                        worker.process_record(Record::new(1, 2, 3)).unwrap();
                    }
                    worker.on_batch_complete().unwrap();
                    worker.get_tid()
                })
            })
            .collect();

        for (tid, handle) in handles.into_iter().enumerate() {
            assert_eq!(handle.join().expect("worker thread panicked"), Some(tid));
        }

        // `join` synchronises with each worker, so the Relaxed loads below
        // observe every `fetch_add` the workers performed before exiting.
        let threads = THREADS as u64;
        assert_eq!(
            processor.global_count.load(Ordering::Relaxed),
            threads * RECORDS_PER_THREAD
        );
        assert_eq!(
            processor.global_sum.load(Ordering::Relaxed),
            threads * RECORDS_PER_THREAD * (1 + 2 + 3)
        );
    }

    #[test]
    fn test_error_handling() {
        let mut processor = ErrorProcessor {
            fail_on_index: 5,
            records_seen: 0,
        };

        assert!(processor.process_record(Record::new(1, 2, 3)).is_ok());
        assert!(processor.process_record(Record::new(1, 2, 4)).is_ok());

        let result = processor.process_record(Record::new(1, 2, 5));
        assert!(
            matches!(result, Err(crate::IbuError::Process(_))),
            "Expected Process error, got: {:?}",
            result
        );
        // The counter is bumped before the index check, so the failing record counts.
        assert_eq!(processor.records_seen, 3);
    }

    #[test]
    fn test_default_implementations() {
        #[derive(Clone)]
        struct MinimalProcessor;

        impl ParallelProcessor for MinimalProcessor {
            fn process_record(&mut self, _record: Record) -> Result<()> {
                Ok(())
            }
        }

        let mut processor = MinimalProcessor;
        assert!(processor.on_batch_complete().is_ok());
        assert_eq!(processor.get_tid(), None);
        // The default `set_tid` discards the id, so `get_tid` stays `None`.
        processor.set_tid(123);
        assert_eq!(processor.get_tid(), None);
    }

    #[test]
    fn test_multiple_batch_completions() {
        let processor = TestProcessor::default();
        let mut processor_clone = processor.clone();

        processor_clone.process_record(Record::new(1, 0, 0)).unwrap();
        processor_clone.on_batch_complete().unwrap();

        processor_clone.process_record(Record::new(2, 0, 0)).unwrap();
        processor_clone.process_record(Record::new(3, 0, 0)).unwrap();
        processor_clone.on_batch_complete().unwrap();

        // Each flush adds to (rather than overwrites) the shared totals.
        assert_eq!(processor.global_count.load(Ordering::Relaxed), 3);
        assert_eq!(processor.global_sum.load(Ordering::Relaxed), 1 + 2 + 3);
    }
}