use std::{
cmp,
num::NonZeroUsize,
sync::Arc,
time::{
Duration,
Instant,
},
};
use crate::progress::{
NoOpProgressReporter,
ProgressReporter,
};
use super::{
BatchProcessResult,
BatchProcessor,
ChunkedBatchProcessError,
};
/// Batch processor wrapper that buffers incoming items into chunks of at most
/// `chunk_size` items and forwards each chunk to a delegate processor, while
/// notifying a [`ProgressReporter`] about start, periodic progress and finish.
pub struct ChunkedBatchProcessor<P> {
/// Processor that receives every chunk.
delegate: P,
/// Number of buffered items that triggers forwarding a chunk to the delegate.
chunk_size: NonZeroUsize,
/// Minimum time between two progress callbacks.
report_interval: Duration,
/// Shared reporter receiving the start / progress / finish callbacks.
reporter: Arc<dyn ProgressReporter>,
}
impl<P> ChunkedBatchProcessor<P> {
    /// Progress report interval applied by [`Self::new`]: five seconds.
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Wraps `delegate` so that it is fed chunks of at most `chunk_size` items.
    ///
    /// The new processor starts with [`Self::DEFAULT_REPORT_INTERVAL`] and a
    /// [`NoOpProgressReporter`]; use the `with_*` builders to change either.
    #[inline]
    pub fn new(delegate: P, chunk_size: NonZeroUsize) -> Self {
        let reporter: Arc<dyn ProgressReporter> = Arc::new(NoOpProgressReporter);
        Self {
            delegate,
            chunk_size,
            report_interval: Self::DEFAULT_REPORT_INTERVAL,
            reporter,
        }
    }

    /// Returns the processor with `reporter` installed, taking ownership of it.
    ///
    /// Convenience wrapper around [`Self::with_reporter_arc`] that puts the
    /// reporter into a fresh [`Arc`].
    #[inline]
    pub fn with_reporter<R>(self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        let shared: Arc<dyn ProgressReporter> = Arc::new(reporter);
        self.with_reporter_arc(shared)
    }

    /// Returns the processor with an already shared `reporter` installed,
    /// replacing the previously configured one.
    #[inline]
    pub fn with_reporter_arc(mut self, reporter: Arc<dyn ProgressReporter>) -> Self {
        self.reporter = reporter;
        self
    }

    /// Returns the processor with the minimum time between progress callbacks
    /// set to `report_interval`.
    #[inline]
    pub fn with_report_interval(mut self, report_interval: Duration) -> Self {
        self.report_interval = report_interval;
        self
    }

    /// Number of buffered items that triggers forwarding a chunk to the delegate.
    #[inline]
    pub const fn chunk_size(&self) -> NonZeroUsize {
        self.chunk_size
    }

    /// Minimum time between two progress callbacks.
    #[inline]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Shared handle to the configured progress reporter.
    #[inline]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }

    /// Shared access to the wrapped delegate processor.
    #[inline]
    pub const fn delegate(&self) -> &P {
        &self.delegate
    }

    /// Exclusive access to the wrapped delegate processor.
    #[inline]
    pub fn delegate_mut(&mut self) -> &mut P {
        &mut self.delegate
    }

    /// Consumes the wrapper and hands back the delegate processor.
    #[inline]
    pub fn into_delegate(self) -> P {
        let Self { delegate, .. } = self;
        delegate
    }
}
impl<Item, P> BatchProcessor<Item> for ChunkedBatchProcessor<P>
where
    P: BatchProcessor<Item>,
{
    /// Delegate error wrapped with chunk and item-count context.
    type Error = ChunkedBatchProcessError<P::Error>;

    /// Feeds `items` to the delegate in chunks of at most `chunk_size` items.
    ///
    /// `count` is the number of items the caller declares `items` will yield.
    /// The reporter receives `start(count)` first and `finish(..)` exactly once
    /// on every exit path of this method.
    ///
    /// # Errors
    ///
    /// * `CountExceeded` - `items` yielded more than `count` items. The error is
    ///   returned as soon as the extra item is seen; items still buffered at
    ///   that point are dropped without being forwarded to the delegate.
    /// * `ChunkFailed` - the delegate rejected a chunk; later items are not read.
    /// * `CountShortfall` - `items` ended before `count` items were seen. All
    ///   items that were yielded have been forwarded to the delegate by then.
    ///
    /// # Panics
    ///
    /// The first report deadline is computed with `Instant + Duration`, which
    /// panics when the sum is not representable, so an extremely large
    /// `report_interval` panics here.
    fn process<I>(&mut self, items: I, count: usize) -> Result<BatchProcessResult, Self::Error>
    where
        I: IntoIterator<Item = Item>,
    {
        self.reporter.start(count);

        let started = Instant::now();
        let mut report_deadline = started + self.report_interval;
        let mut state = ChunkedProcessState::new(count);

        // Never reserve more than the declared item count, but keep room for at
        // least one item so a declared count of zero still yields a usable buffer.
        let buffer_capacity = cmp::min(self.chunk_size.get(), count.max(1));
        let mut pending = Vec::with_capacity(buffer_capacity);

        for item in items {
            // Seeing one more item after `count` items were accepted means the
            // caller's declared count was too small.
            if state.actual_count == count {
                let result = state.to_result(started.elapsed());
                self.reporter.finish(count, result.elapsed());
                return Err(ChunkedBatchProcessError::CountExceeded {
                    expected: count,
                    observed_at_least: count + 1,
                    result,
                });
            }

            pending.push(item);
            state.actual_count += 1;

            let chunk_is_full = pending.len() == self.chunk_size.get();
            if chunk_is_full {
                self.process_chunk(&mut pending, &mut state, started, &mut report_deadline)?;
            }
        }

        // Flush the trailing, partially filled chunk.
        if !pending.is_empty() {
            self.process_chunk(&mut pending, &mut state, started, &mut report_deadline)?;
        }

        let result = state.to_result(started.elapsed());
        self.reporter.finish(count, result.elapsed());

        if state.actual_count >= count {
            return Ok(result);
        }
        Err(ChunkedBatchProcessError::CountShortfall {
            expected: count,
            actual: state.actual_count,
            result,
        })
    }
}
impl<P> ChunkedBatchProcessor<P> {
    /// Forwards the buffered `chunk` to the delegate and folds the outcome into
    /// `state`.
    ///
    /// `chunk` is always left empty on return, with its allocation kept so the
    /// caller can refill it for the next chunk without reallocating.
    ///
    /// On success the chunk length is added to `state.completed_count`, the
    /// delegate-reported processed count to `state.processed_count`,
    /// `state.chunk_count` is incremented, and a progress callback is emitted
    /// if `next_progress` has been reached.
    ///
    /// # Errors
    ///
    /// When the delegate fails, the reporter's `finish` is called and
    /// `ChunkFailed` is returned. It carries the zero-based index of the chunk,
    /// the position of the chunk's first item in the overall input, the chunk
    /// length, the delegate error and a result snapshot that excludes the
    /// failed chunk.
    fn process_chunk<Item>(
        &mut self,
        chunk: &mut Vec<Item>,
        state: &mut ChunkedProcessState,
        start: Instant,
        next_progress: &mut Instant,
    ) -> Result<(), ChunkedBatchProcessError<P::Error>>
    where
        P: BatchProcessor<Item>,
    {
        let chunk_len = chunk.len();
        // `actual_count` already includes the items of this chunk.
        let start_index = state.actual_count - chunk_len;
        let chunk_index = state.chunk_count;

        // Drain rather than `mem::take`: the delegate still receives every item
        // by value and in order, but the buffer keeps its capacity instead of
        // being replaced by an empty, unallocated `Vec` after each chunk.
        match self.delegate.process(chunk.drain(..), chunk_len) {
            Ok(chunk_result) => {
                state.completed_count += chunk_len;
                state.processed_count += chunk_result.processed_count();
                state.chunk_count += 1;
                report_progress_if_due(
                    self.reporter.as_ref(),
                    state.item_count,
                    state.completed_count,
                    start,
                    self.report_interval,
                    next_progress,
                );
                Ok(())
            }
            Err(source) => {
                let result = state.to_result(start.elapsed());
                self.reporter.finish(state.item_count, result.elapsed());
                Err(ChunkedBatchProcessError::ChunkFailed {
                    chunk_index,
                    start_index,
                    chunk_len,
                    source,
                    result,
                })
            }
        }
    }
}
/// Running counters for one chunked `process` call.
struct ChunkedProcessState {
/// Item count declared by the caller.
item_count: usize,
/// Items pulled from the input iterator so far.
actual_count: usize,
/// Items belonging to chunks the delegate processed successfully.
completed_count: usize,
/// Sum of the processed counts reported by the delegate for each chunk.
processed_count: usize,
/// Number of chunks the delegate processed successfully.
chunk_count: usize,
}
impl ChunkedProcessState {
    /// Creates the state for a run declared to contain `item_count` items,
    /// with every running counter at zero.
    const fn new(item_count: usize) -> Self {
        Self {
            item_count,
            actual_count: 0,
            completed_count: 0,
            processed_count: 0,
            chunk_count: 0,
        }
    }

    /// Builds a [`BatchProcessResult`] snapshot from the current counters and
    /// the given `elapsed` time. `actual_count` is not part of the snapshot.
    const fn to_result(&self, elapsed: Duration) -> BatchProcessResult {
        let Self {
            item_count,
            completed_count,
            processed_count,
            chunk_count,
            ..
        } = *self;
        BatchProcessResult::new(
            item_count,
            completed_count,
            processed_count,
            chunk_count,
            elapsed,
        )
    }
}
/// Emits a progress callback when the report deadline has been reached.
///
/// Does nothing while the current time is before `*next_progress`. Otherwise
/// it calls `reporter.process` with the declared item count, the completed
/// count and the time elapsed since `start`, then moves `*next_progress` to
/// `report_interval` after the current time.
///
/// NOTE(review): the second argument passed to `reporter.process` is always
/// `0`; presumably it is an in-flight/active item count - confirm against
/// `ProgressReporter::process`.
fn report_progress_if_due(
    reporter: &dyn ProgressReporter,
    item_count: usize,
    completed_count: usize,
    start: Instant,
    report_interval: Duration,
    next_progress: &mut Instant,
) {
    let now = Instant::now();
    if now < *next_progress {
        return;
    }
    reporter.process(item_count, 0, completed_count, start.elapsed());
    *next_progress = now + report_interval;
}