use std::{
sync::Arc,
time::Duration,
};
use qubit_function::{
BoxConsumer,
Consumer,
};
use qubit_progress::{
Progress,
reporter::{
NoOpProgressReporter,
ProgressReporter,
},
};
use crate::process::{
BatchProcessError,
BatchProcessResult,
BatchProcessState,
BatchProcessor,
};
/// Batch processor that hands items to a single boxed consumer one after
/// another, in iteration order, while emitting progress events through a
/// shared [`ProgressReporter`].
///
/// A new instance starts with a no-op reporter and
/// `DEFAULT_REPORT_INTERVAL`; both can be replaced with the `with_*`
/// builder methods.
pub struct SequentialBatchProcessor<Item> {
/// Consumer invoked once, by reference, for every accepted item.
consumer: BoxConsumer<Item>,
/// Interval handed to `Progress` when a run starts.
// NOTE(review): presumably the minimum spacing between "running" progress
// reports; confirm against `qubit_progress::Progress`.
report_interval: Duration,
/// Reporter that receives the progress events of each run.
reporter: Arc<dyn ProgressReporter>,
}
impl<Item> SequentialBatchProcessor<Item> {
    /// Report interval used by [`Self::new`]: five seconds.
    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);

    /// Creates a processor that feeds every item to `consumer`.
    ///
    /// The consumer is boxed via `into_box`. The new processor reports to a
    /// [`NoOpProgressReporter`] and uses [`Self::DEFAULT_REPORT_INTERVAL`].
    #[inline]
    pub fn new<C>(consumer: C) -> Self
    where
        C: Consumer<Item> + 'static,
    {
        let consumer = consumer.into_box();
        let reporter: Arc<dyn ProgressReporter> = Arc::new(NoOpProgressReporter);
        Self {
            consumer,
            report_interval: Self::DEFAULT_REPORT_INTERVAL,
            reporter,
        }
    }

    /// Returns this processor with `reporter` installed as its progress
    /// reporter.
    ///
    /// The reporter is wrapped in an [`Arc`] and passed on to
    /// [`Self::with_reporter_arc`].
    #[inline]
    pub fn with_reporter<R>(self, reporter: R) -> Self
    where
        R: ProgressReporter + 'static,
    {
        let shared: Arc<dyn ProgressReporter> = Arc::new(reporter);
        self.with_reporter_arc(shared)
    }

    /// Returns this processor with its reporter replaced by the given shared
    /// reporter; all other settings are kept.
    #[inline]
    pub fn with_reporter_arc(mut self, reporter: Arc<dyn ProgressReporter>) -> Self {
        self.reporter = reporter;
        self
    }

    /// Returns this processor with its report interval replaced by
    /// `report_interval`; all other settings are kept.
    #[inline]
    pub fn with_report_interval(mut self, report_interval: Duration) -> Self {
        self.report_interval = report_interval;
        self
    }

    /// Returns the configured report interval.
    #[inline]
    pub const fn report_interval(&self) -> Duration {
        self.report_interval
    }

    /// Borrows the shared progress reporter.
    #[inline]
    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
        &self.reporter
    }

    /// Borrows the boxed consumer.
    #[inline]
    pub const fn consumer(&self) -> &BoxConsumer<Item> {
        &self.consumer
    }

    /// Consumes the processor and returns its boxed consumer, discarding the
    /// reporter and the report interval.
    #[inline]
    pub fn into_consumer(self) -> BoxConsumer<Item> {
        let Self { consumer, .. } = self;
        consumer
    }
}
impl<Item> BatchProcessor<Item> for SequentialBatchProcessor<Item> {
    type Error = BatchProcessError;

    /// Passes each element of `items` to the consumer, in iteration order,
    /// and checks that exactly `count` elements were supplied.
    ///
    /// A "started" progress event is reported before the first item, a
    /// "running" event is offered after every processed item (emitted only
    /// when `Progress` considers it due), and a final "finished" or "failed"
    /// event closes the run. The returned result is built from the batch
    /// state and the elapsed time carried by that final event.
    ///
    /// # Errors
    ///
    /// * [`BatchProcessError::CountExceeded`] - `items` yielded more than
    ///   `count` elements. Processing stops at the first surplus element,
    ///   which is not passed to the consumer.
    /// * [`BatchProcessError::CountShortfall`] - `items` ended before `count`
    ///   elements had been observed.
    ///
    /// Both variants carry the result snapshot taken at the point of failure.
    fn process<I>(&mut self, items: I, count: usize) -> Result<BatchProcessResult, Self::Error>
    where
        I: IntoIterator<Item = Item>,
    {
        let state = BatchProcessState::new(count);
        let mut progress = Progress::new(self.reporter.as_ref(), self.report_interval);
        progress.report_started(state.progress_counters());

        for item in items {
            let seen = state.record_item_observed();

            // The source produced more items than declared: abort before
            // handing the surplus item to the consumer.
            if seen > count {
                let failure = progress.report_failed(state.progress_counters());
                let snapshot = state.to_direct_result(failure.elapsed());
                return Err(BatchProcessError::CountExceeded {
                    expected: count,
                    observed_at_least: seen,
                    result: snapshot,
                });
            }

            state.record_item_started();
            self.consumer.accept(&item);
            state.record_item_processed();

            // The value returned by the periodic report is deliberately
            // ignored.
            let _ = progress.report_running_if_due(state.progress_counters());
        }

        // The source ran dry before delivering the declared number of items.
        if state.observed_count() < count {
            let failure = progress.report_failed(state.progress_counters());
            let snapshot = state.to_direct_result(failure.elapsed());
            return Err(BatchProcessError::CountShortfall {
                expected: count,
                actual: state.observed_count(),
                result: snapshot,
            });
        }

        let completion = progress.report_finished(state.progress_counters());
        Ok(state.to_direct_result(completion.elapsed()))
    }
}