use std::{
cmp,
num::NonZeroUsize,
sync::Arc,
time::Duration,
};
use qubit_progress::{
Progress,
reporter::{
NoOpProgressReporter,
ProgressReporter,
},
};
use crate::process::{
BatchProcessResult,
BatchProcessState,
BatchProcessor,
ChunkedBatchProcessError,
};
/// A [`BatchProcessor`] adapter that buffers incoming items into chunks of at
/// most `chunk_size` items and forwards each chunk to a delegate processor,
/// reporting progress (started / running / finished / failed) to a
/// [`ProgressReporter`] while doing so.
pub struct ChunkedBatchProcessor<P> {
// Processor that receives each buffered chunk.
delegate: P,
// Maximum number of items handed to `delegate` in a single call.
chunk_size: NonZeroUsize,
// Interval handed to `Progress::new`; presumably throttles how often
// "running" reports are emitted — TODO confirm against `Progress`.
report_interval: Duration,
// Shared sink for progress events.
reporter: Arc<dyn ProgressReporter>,
}
impl<P> ChunkedBatchProcessor<P> {
    /// Report interval used by [`new`](Self::new): five seconds.
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates a processor that forwards at most `chunk_size` items per call
    /// to `delegate`.
    ///
    /// Progress goes to a no-op reporter using
    /// [`DEFAULT_REPORT_INTERVAL`](Self::DEFAULT_REPORT_INTERVAL); use
    /// [`with_reporter`](Self::with_reporter) and
    /// [`with_report_interval`](Self::with_report_interval) to change them.
    #[inline]
    #[must_use]
    pub fn new(delegate: P, chunk_size: NonZeroUsize) -> Self {
        Self {
            delegate,
            chunk_size,
            report_interval: Self::DEFAULT_REPORT_INTERVAL,
            reporter: Arc::new(NoOpProgressReporter),
        }
    }

    /// Returns this processor with `reporter` as its progress sink.
    ///
    /// The reporter is wrapped in an [`Arc`]; use
    /// [`with_reporter_arc`](Self::with_reporter_arc) to pass an already
    /// shared reporter instead.
    #[inline]
    #[must_use]
    pub fn with_reporter<R>(self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        self.with_reporter_arc(Arc::new(reporter))
    }

    /// Returns this processor with the shared `reporter` as its progress sink.
    #[inline]
    #[must_use]
    pub fn with_reporter_arc(self, reporter: Arc<dyn ProgressReporter>) -> Self {
        Self { reporter, ..self }
    }

    /// Returns this processor with `report_interval` as the interval given to
    /// progress tracking on each `process` call.
    #[inline]
    #[must_use]
    pub fn with_report_interval(self, report_interval: Duration) -> Self {
        Self {
            report_interval,
            ..self
        }
    }

    /// Maximum number of items forwarded to the delegate in a single call.
    #[inline]
    pub const fn chunk_size(&self) -> NonZeroUsize {
        self.chunk_size
    }

    /// Interval given to progress tracking on each `process` call.
    #[inline]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Shared handle to the configured progress reporter.
    #[inline]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }

    /// Borrows the wrapped delegate processor.
    #[inline]
    pub const fn delegate(&self) -> &P {
        &self.delegate
    }

    /// Mutably borrows the wrapped delegate processor.
    #[inline]
    pub fn delegate_mut(&mut self) -> &mut P {
        &mut self.delegate
    }

    /// Consumes this adapter and returns the wrapped delegate processor.
    #[inline]
    pub fn into_delegate(self) -> P {
        self.delegate
    }
}
impl<Item, P> BatchProcessor<Item> for ChunkedBatchProcessor<P>
where
    P: BatchProcessor<Item>,
{
    type Error = ChunkedBatchProcessError<P::Error>;

    /// Buffers `items` into chunks of at most `chunk_size` and forwards each
    /// full chunk (and the trailing partial one) to the delegate.
    ///
    /// `count` is the number of items the caller declares `items` will yield.
    /// Progress is reported as started up front, then finished on success or
    /// failed on any error.
    ///
    /// # Errors
    ///
    /// - [`ChunkedBatchProcessError::CountExceeded`] as soon as more than
    ///   `count` items are observed. Items already buffered are forwarded
    ///   first, and the rest of the input is not consumed.
    /// - [`ChunkedBatchProcessError::CountShortfall`] if `items` yields fewer
    ///   than `count` items.
    /// - Any error raised while forwarding a chunk to the delegate.
    fn process<I>(&mut self, items: I, count: usize) -> Result<BatchProcessResult, Self::Error>
    where
        I: IntoIterator<Item = Item>,
    {
        // `progress` borrows the reporter for the whole call, while
        // `process_chunk` needs `&mut self`; work from a local handle so the
        // two borrows do not overlap.
        let reporter_handle = Arc::clone(&self.reporter);
        let mut progress = Progress::new(reporter_handle.as_ref(), self.report_interval);
        let state = BatchProcessState::new(count);
        progress.report_started(state.progress_counters());

        let chunk_limit = self.chunk_size.get();
        // Reserve at most one chunk, and no more than the declared count
        // (with a floor of one slot).
        let initial_capacity = cmp::min(chunk_limit, count.max(1));
        let mut buffer = Vec::with_capacity(initial_capacity);

        for item in items {
            let seen = state.record_item_observed();
            if seen <= count {
                buffer.push(item);
                if buffer.len() == chunk_limit {
                    self.process_chunk(&mut buffer, &state, &mut progress)?;
                }
                continue;
            }

            // More items than declared: flush what is buffered, then stop.
            if !buffer.is_empty() {
                self.process_chunk(&mut buffer, &state, &mut progress)?;
            }
            let failed = progress.report_failed(state.progress_counters());
            let result = state.to_chunked_result(failed.elapsed());
            return Err(ChunkedBatchProcessError::CountExceeded {
                expected: count,
                observed_at_least: seen,
                result,
            });
        }

        // Forward the trailing partial chunk, if any.
        if !buffer.is_empty() {
            self.process_chunk(&mut buffer, &state, &mut progress)?;
        }

        if state.observed_count() >= count {
            let finished = progress.report_finished(state.progress_counters());
            return Ok(state.to_chunked_result(finished.elapsed()));
        }

        let failed = progress.report_failed(state.progress_counters());
        let result = state.to_chunked_result(failed.elapsed());
        Err(ChunkedBatchProcessError::CountShortfall {
            expected: count,
            actual: state.observed_count(),
            result,
        })
    }
}
impl<P> ChunkedBatchProcessor<P> {
    /// Forwards the buffered `chunk` to the delegate and records the outcome in
    /// `state`.
    ///
    /// On return `chunk` is always empty but keeps its allocation, so the caller
    /// can refill it for the next chunk without regrowing the buffer.
    ///
    /// # Errors
    ///
    /// - [`ChunkedBatchProcessError::ChunkFailed`] if the delegate returns an
    ///   error.
    /// - [`ChunkedBatchProcessError::InvalidChunkResult`] if the delegate's
    ///   result reports an `item_count` or `completed_count` that differs from
    ///   the chunk length.
    ///
    /// In both cases a failure is reported to `progress`. The attached result
    /// reflects `state` before this chunk, because the chunk is only recorded
    /// on success.
    fn process_chunk<Item>(
        &mut self,
        chunk: &mut Vec<Item>,
        state: &BatchProcessState,
        progress: &mut Progress<'_>,
    ) -> Result<(), ChunkedBatchProcessError<P::Error>>
    where
        P: BatchProcessor<Item>,
    {
        let chunk_len = chunk.len();
        // Position of this chunk in the overall batch, captured before it is recorded.
        let start_index = state.completed_count();
        let chunk_index = state.chunk_count();
        // `drain(..)` hands every buffered item to the delegate by value while
        // leaving `chunk`'s capacity in place. `mem::take` would swap in a
        // zero-capacity Vec and force each later chunk to be regrown from
        // scratch. Items the delegate leaves unconsumed are dropped along with
        // the drain iterator, so `chunk` is empty afterwards either way.
        match self.delegate.process(chunk.drain(..), chunk_len) {
            Ok(chunk_result) => {
                if chunk_result.item_count() != chunk_len
                    || chunk_result.completed_count() != chunk_len
                {
                    let failed = progress.report_failed(state.progress_counters());
                    let result = state.to_chunked_result(failed.elapsed());
                    return Err(ChunkedBatchProcessError::InvalidChunkResult {
                        chunk_index,
                        start_index,
                        chunk_len,
                        item_count: chunk_result.item_count(),
                        completed_count: chunk_result.completed_count(),
                        result,
                    });
                }
                state.record_chunk_processed(chunk_len, chunk_result.processed_count());
                // Running reports are best-effort; whether one was emitted is irrelevant here.
                let _ = progress.report_running_if_due(state.running_chunk_progress_counters());
                Ok(())
            }
            Err(source) => {
                let failed = progress.report_failed(state.progress_counters());
                let result = state.to_chunked_result(failed.elapsed());
                Err(ChunkedBatchProcessError::ChunkFailed {
                    chunk_index,
                    start_index,
                    chunk_len,
                    source,
                    result,
                })
            }
        }
    }
}