use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::{Stream, StreamExt, stream};
use recoco_utils::concur_control::{ConcurrencyController, Options as ConcurOptions};
use crate::{Engine, EngineError, FixOptions, FixResult, LintOptions, LintResult};
/// Failure reported for a single document by the batch APIs in this file.
///
/// Batch streams yield `(id, Result<_, BatchError>)` pairs, so a failure is
/// delivered as an ordinary stream item attached to the document's id.
#[derive(Debug)]
#[non_exhaustive]
pub enum BatchError {
/// The blocking task that ran the engine did not finish normally. Use
/// [`BatchError::is_panic`] / [`BatchError::is_cancelled`] to tell why.
TaskFailed(tokio::task::JoinError),
/// Acquiring a permit from the concurrency controller failed with
/// `tokio::sync::AcquireError`, so the document was never handed to the engine.
ShutdownInProgress,
/// The engine reported `EngineError::DeadlineExceeded` while fixing a document.
DocumentDeadlineExceeded {
/// Lint result handed back together with the deadline error; `Display`
/// reports its `candidates_processed` / `candidates_total` counters.
partial_lint: LintResult,
},
}
impl BatchError {
    /// Returns `true` if the wrapped task error reports a panic.
    ///
    /// Only [`BatchError::TaskFailed`] can answer `true`; every other variant is `false`.
    pub fn is_panic(&self) -> bool {
        match self {
            Self::TaskFailed(join_err) => join_err.is_panic(),
            Self::ShutdownInProgress | Self::DocumentDeadlineExceeded { .. } => false,
        }
    }

    /// Returns `true` if the wrapped task error reports a cancellation.
    ///
    /// Only [`BatchError::TaskFailed`] can answer `true`; every other variant is `false`.
    pub fn is_cancelled(&self) -> bool {
        match self {
            Self::TaskFailed(join_err) => join_err.is_cancelled(),
            Self::ShutdownInProgress | Self::DocumentDeadlineExceeded { .. } => false,
        }
    }

    /// Returns `true` for [`BatchError::ShutdownInProgress`].
    pub fn is_shutdown(&self) -> bool {
        matches!(self, Self::ShutdownInProgress)
    }

    /// Returns `true` for [`BatchError::DocumentDeadlineExceeded`].
    pub fn is_deadline_exceeded(&self) -> bool {
        matches!(self, Self::DocumentDeadlineExceeded { .. })
    }
}
/// Human-readable rendering of each failure category.
impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TaskFailed(e) => {
                // A panic takes precedence over a cancellation when choosing the wording.
                let kind = match (e.is_panic(), e.is_cancelled()) {
                    (true, _) => "panicked",
                    (false, true) => "was cancelled",
                    (false, false) => "failed",
                };
                write!(f, "batch task {kind}: {e}")
            }
            Self::ShutdownInProgress => {
                f.write_str("ConcurrencyController semaphore closed (shutdown in progress)")
            }
            Self::DocumentDeadlineExceeded { partial_lint } => {
                // Progress counters come straight from the partial lint result.
                let done = &partial_lint.candidates_processed;
                let total = &partial_lint.candidates_total;
                write!(
                    f,
                    "document deadline exceeded after {}/{} candidates",
                    done, total
                )
            }
        }
    }
}
impl std::error::Error for BatchError {
    /// Exposes the wrapped `JoinError` for [`BatchError::TaskFailed`]; the other
    /// variants have no underlying cause to chain to.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::TaskFailed(join_err) => Some(join_err),
            Self::ShutdownInProgress | Self::DocumentDeadlineExceeded { .. } => None,
        }
    }
}
impl From<tokio::task::JoinError> for BatchError {
fn from(e: tokio::task::JoinError) -> Self {
Self::TaskFailed(e)
}
}
impl From<tokio::sync::AcquireError> for BatchError {
fn from(_: tokio::sync::AcquireError) -> Self {
Self::ShutdownInProgress
}
}
/// Limits applied by [`BatchEngine`] when processing many documents.
///
/// `Debug` and `Clone` are derived so a configuration can be logged and reused
/// as a template for several engines or calls.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct BatchOptions {
    /// Upper bound on documents processed concurrently. Forwarded to the
    /// concurrency controller as its in-flight row limit; when `None`,
    /// `BatchEngine::new` falls back to a stream width of 32.
    pub max_concurrent_docs: Option<usize>,
    /// Forwarded to the concurrency controller as its in-flight byte limit;
    /// each document is charged its length in bytes. `None` sets no byte limit.
    pub max_inflight_bytes: Option<usize>,
    /// Time budget per document, turned into an absolute deadline when the
    /// document's blocking task starts. `None` passes no deadline to the engine.
    pub per_doc_deadline: Option<Duration>,
}

impl Default for BatchOptions {
    /// At most 32 concurrent documents, no byte limit, and no per-document deadline.
    fn default() -> Self {
        Self {
            max_concurrent_docs: Some(32),
            max_inflight_bytes: None,
            per_doc_deadline: None,
        }
    }
}
/// Runs an [`Engine`] over many documents, yielding per-document results as a stream.
pub struct BatchEngine {
/// Shared engine; an `Arc` clone is moved into each per-document task.
engine: Arc<Engine>,
/// Admission control; a permit is acquired before a document is processed.
controller: Arc<ConcurrencyController>,
/// Width passed to `buffer_unordered` by the batch streams.
concurrent: usize,
/// Per-document time budget used by `lint_many` and `fix_many`.
per_doc_deadline: Option<Duration>,
}
impl BatchEngine {
    /// Stream width used when `BatchOptions::max_concurrent_docs` is `None`.
    const DEFAULT_CONCURRENCY: usize = 32;

    /// Wraps `engine` for batch use with the limits in `options`.
    ///
    /// * `max_concurrent_docs` becomes both the controller's in-flight row limit and the
    ///   number of per-document futures the streams poll at once. `None` leaves the row
    ///   limit unset and uses a stream width of 32. `Some(0)` is treated as `Some(1)`:
    ///   with no permit to hand out, no document could ever be admitted and every stream
    ///   would stall instead of yielding.
    /// * `max_inflight_bytes` is forwarded to the controller unchanged.
    /// * `per_doc_deadline` becomes the budget used by [`Self::lint_many`] and
    ///   [`Self::fix_many`].
    pub fn new(engine: Engine, options: BatchOptions) -> Self {
        // NOTE(review): assumes the controller's row limit is a semaphore permit count,
        // as its `AcquireError`-returning `acquire` suggests.
        let max_concurrent_docs = options.max_concurrent_docs.map(|limit| limit.max(1));
        let concurrent = max_concurrent_docs.unwrap_or(Self::DEFAULT_CONCURRENCY);
        let controller = ConcurrencyController::new(&ConcurOptions {
            max_inflight_rows: max_concurrent_docs,
            max_inflight_bytes: options.max_inflight_bytes,
        });
        Self {
            engine: Arc::new(engine),
            controller: Arc::new(controller),
            concurrent,
            per_doc_deadline: options.per_doc_deadline,
        }
    }

    /// Converts a relative per-document budget into an absolute deadline starting now.
    ///
    /// If `now + budget` is not representable as an `Instant`, the deadline falls back to
    /// `now`, i.e. it is already due.
    fn deadline_after(budget: Option<Duration>) -> Option<Instant> {
        budget.map(|limit| {
            let now = Instant::now();
            now.checked_add(limit).unwrap_or(now)
        })
    }

    /// Lints every `(id, bytes)` pair in `docs` using the deadline configured at construction.
    ///
    /// Items are yielded as `(id, result)` in completion order, which may differ from the
    /// input order. The engine runs on Tokio's blocking pool, so the stream must be polled
    /// from within a Tokio runtime.
    ///
    /// # Errors
    ///
    /// Per item: [`BatchError::ShutdownInProgress`] if no permit could be acquired, or
    /// [`BatchError::TaskFailed`] if the blocking task did not complete.
    pub fn lint_many(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
    ) -> impl Stream<Item = (String, Result<LintResult, BatchError>)> {
        self.lint_many_inner(docs, self.per_doc_deadline)
    }

    /// Like [`Self::lint_many`], but takes the per-document deadline from `opts`.
    ///
    /// Only `opts.per_doc_deadline` is read; the concurrency and byte limits fixed in
    /// [`Self::new`] stay in effect.
    pub fn lint_many_with_options(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
        opts: &BatchOptions,
    ) -> impl Stream<Item = (String, Result<LintResult, BatchError>)> {
        self.lint_many_inner(docs, opts.per_doc_deadline)
    }

    /// Shared implementation of the lint streams.
    fn lint_many_inner(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
        per_doc_deadline: Option<Duration>,
    ) -> impl Stream<Item = (String, Result<LintResult, BatchError>)> {
        let engine = Arc::clone(&self.engine);
        let controller = Arc::clone(&self.controller);
        let concurrent = self.concurrent;
        stream::iter(docs)
            .map(move |(id, data)| {
                let engine = Arc::clone(&engine);
                let controller = Arc::clone(&controller);
                async move {
                    let byte_len = data.len();
                    // The permit lives until this future finishes, so the document counts
                    // against the limits for the whole time the blocking task runs.
                    let _permit = match controller.acquire(Some(|| byte_len)).await {
                        Ok(permit) => permit,
                        Err(closed) => return (id, Err(BatchError::from(closed))),
                    };
                    let result = tokio::task::spawn_blocking(move || {
                        // The clock starts when the blocking task runs, not when the
                        // document was queued.
                        let opts = LintOptions {
                            deadline: Self::deadline_after(per_doc_deadline),
                            ..LintOptions::default()
                        };
                        engine.lint_with_options(&data, &opts)
                    })
                    .await
                    .map_err(BatchError::from);
                    (id, result)
                }
            })
            .buffer_unordered(concurrent)
    }

    /// Applies fixes to every `(id, bytes)` pair in `docs` using the deadline configured at
    /// construction.
    ///
    /// Items are yielded as `(id, result)` in completion order, which may differ from the
    /// input order. The engine runs on Tokio's blocking pool, so the stream must be polled
    /// from within a Tokio runtime.
    ///
    /// # Errors
    ///
    /// Per item: [`BatchError::ShutdownInProgress`] if no permit could be acquired,
    /// [`BatchError::TaskFailed`] if the blocking task did not complete, or
    /// [`BatchError::DocumentDeadlineExceeded`] if the engine reported a missed deadline.
    ///
    /// # Panics
    ///
    /// Polling the stream panics if the engine returns `EngineError::InvalidThreshold`,
    /// which is treated as impossible because no threshold override is set here.
    pub fn fix_many(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
    ) -> impl Stream<Item = (String, Result<FixResult, BatchError>)> {
        self.fix_many_inner(docs, self.per_doc_deadline)
    }

    /// Like [`Self::fix_many`], but takes the per-document deadline from `opts`.
    ///
    /// Only `opts.per_doc_deadline` is read; the concurrency and byte limits fixed in
    /// [`Self::new`] stay in effect.
    pub fn fix_many_with_options(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
        opts: &BatchOptions,
    ) -> impl Stream<Item = (String, Result<FixResult, BatchError>)> {
        self.fix_many_inner(docs, opts.per_doc_deadline)
    }

    /// Shared implementation of the fix streams.
    fn fix_many_inner(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
        per_doc_deadline: Option<Duration>,
    ) -> impl Stream<Item = (String, Result<FixResult, BatchError>)> {
        let engine = Arc::clone(&self.engine);
        let controller = Arc::clone(&self.controller);
        let concurrent = self.concurrent;
        stream::iter(docs)
            .map(move |(id, data)| {
                let engine = Arc::clone(&engine);
                let controller = Arc::clone(&controller);
                async move {
                    let byte_len = data.len();
                    // The permit lives until this future finishes, so the document counts
                    // against the limits for the whole time the blocking task runs.
                    let _permit = match controller.acquire(Some(|| byte_len)).await {
                        Ok(permit) => permit,
                        Err(closed) => return (id, Err(BatchError::from(closed))),
                    };
                    let result = tokio::task::spawn_blocking(move || {
                        // The clock starts when the blocking task runs, not when the
                        // document was queued.
                        let opts = FixOptions {
                            deadline: Self::deadline_after(per_doc_deadline),
                            ..FixOptions::default()
                        };
                        engine.fix_with_options(&data, crate::FixMode::Apply, &opts)
                    })
                    .await;
                    // Outer layer: did the task complete? Inner layer: did the engine succeed?
                    let mapped = match result {
                        Ok(Ok(fix_result)) => Ok(fix_result),
                        Ok(Err(EngineError::DeadlineExceeded { partial_lint })) => {
                            Err(BatchError::DocumentDeadlineExceeded { partial_lint })
                        }
                        Ok(Err(EngineError::InvalidThreshold(_))) => unreachable!(
                            "BatchEngine does not set FixOptions::threshold_override; \
                             InvalidThreshold cannot fire"
                        ),
                        Err(join_error) => Err(BatchError::from(join_error)),
                    };
                    (id, mapped)
                }
            })
            .buffer_unordered(concurrent)
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// A shutdown error is its own category: neither a panic nor a cancellation.
    #[test]
    fn shutdown_error_is_not_panic_or_cancellation() {
        let err = BatchError::ShutdownInProgress;
        assert!(!err.is_panic());
        assert!(!err.is_cancelled());
        assert!(err.is_shutdown());
    }

    /// The rendered message mentions both the state and the signal that caused it.
    #[test]
    fn shutdown_error_display_names_the_state() {
        let s = BatchError::ShutdownInProgress.to_string();
        assert!(
            s.contains("shutdown"),
            "ShutdownInProgress Display should name the state explicitly: got {s:?}"
        );
        assert!(
            s.contains("closed"),
            "Display should name the underlying signal (semaphore closed): got {s:?}"
        );
    }

    /// The shutdown variant carries no underlying cause.
    #[test]
    fn shutdown_error_has_no_source() {
        use std::error::Error as _;

        let err = BatchError::ShutdownInProgress;
        assert!(
            err.source().is_none(),
            "ShutdownInProgress must not chain to a source"
        );
    }

    /// Acquiring from a closed semaphore produces an error that converts to the shutdown variant.
    #[test]
    fn from_acquire_error_yields_shutdown_variant() {
        let closed_sem = tokio::sync::Semaphore::new(1);
        closed_sem.close();
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("current_thread runtime builds");
        let acquire_err = runtime.block_on(closed_sem.acquire()).unwrap_err();
        let batch_err = BatchError::from(acquire_err);
        assert!(batch_err.is_shutdown());
        assert!(!batch_err.is_panic());
    }
}