use std::sync::Arc;
use futures::{Stream, StreamExt, stream};
use recoco_utils::concur_control::{ConcurrencyController, Options as ConcurOptions};
use crate::{Engine, FixResult, LintResult};
/// Error attached to a single document's result when its batch task did not
/// run to completion.
#[derive(Debug)]
pub enum BatchError {
/// The Tokio task running the document's work failed to join; the wrapped
/// [`tokio::task::JoinError`] says whether it panicked or was cancelled.
TaskFailed(tokio::task::JoinError),
}
impl BatchError {
    /// Returns `true` when the underlying task join error reports a panic.
    pub fn is_panic(&self) -> bool {
        // `BatchError` has a single variant, so this pattern is irrefutable.
        let Self::TaskFailed(join_err) = self;
        join_err.is_panic()
    }

    /// Returns `true` when the underlying task join error reports a
    /// cancellation.
    pub fn is_cancelled(&self) -> bool {
        let Self::TaskFailed(join_err) = self;
        join_err.is_cancelled()
    }
}
impl std::fmt::Display for BatchError {
    /// Writes `batch task <kind>: <join error>`, where `<kind>` is
    /// `panicked`, `was cancelled`, or `failed` (panic takes precedence).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::TaskFailed(e) = self;
        let kind = match (e.is_panic(), e.is_cancelled()) {
            (true, _) => "panicked",
            (false, true) => "was cancelled",
            (false, false) => "failed",
        };
        write!(f, "batch task {kind}: {e}")
    }
}
impl std::error::Error for BatchError {
    /// Exposes the wrapped join error as the cause of this error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let Self::TaskFailed(cause) = self;
        Some(cause)
    }
}
impl From<tokio::task::JoinError> for BatchError {
fn from(e: tokio::task::JoinError) -> Self {
Self::TaskFailed(e)
}
}
/// Limits applied by [`BatchEngine`] while processing a batch of documents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchOptions {
    /// Maximum number of documents processed at the same time.
    ///
    /// The value is used both as the row limit of the concurrency controller
    /// and as the number of per-document futures driven concurrently. With
    /// `None` the controller gets no row limit and 32 futures are driven at a
    /// time.
    pub max_concurrent_docs: Option<usize>,
    /// Limit on the combined byte length of the documents being processed at
    /// the same time, as enforced by the concurrency controller. `None`
    /// configures no byte limit.
    pub max_inflight_bytes: Option<usize>,
}
impl Default for BatchOptions {
    /// Default limits: at most 32 concurrent documents and no byte limit.
    fn default() -> Self {
        let max_concurrent_docs = Some(32);
        let max_inflight_bytes = None;
        Self {
            max_concurrent_docs,
            max_inflight_bytes,
        }
    }
}
/// Runs an [`Engine`] over many documents concurrently, offloading each
/// document to Tokio's blocking thread pool.
pub struct BatchEngine {
/// Engine shared by every per-document task.
engine: Arc<Engine>,
/// Controller from which each document acquires a permit (sized by its byte
/// length) before its work is started.
controller: Arc<ConcurrencyController>,
/// Number of per-document futures driven at once (`buffer_unordered` width).
concurrent: usize,
}
impl BatchEngine {
    /// Number of documents driven concurrently when
    /// `BatchOptions::max_concurrent_docs` is `None`.
    const DEFAULT_CONCURRENCY: usize = 32;

    /// Creates a batch engine that owns `engine` and enforces the limits in
    /// `options`.
    ///
    /// A `max_concurrent_docs` of `Some(0)` is raised to `Some(1)`, because a
    /// zero limit can never make progress. `None` leaves the controller
    /// without a row limit and drives 32 documents at a time.
    pub fn new(engine: Engine, options: BatchOptions) -> Self {
        // `buffer_unordered(0)` never polls the source stream, so the batch
        // stream would stay pending forever. A zero row limit presumably also
        // leaves the controller without permits to hand out. Clamp both.
        let max_docs = options.max_concurrent_docs.map(|limit| limit.max(1));
        let controller = ConcurrencyController::new(&ConcurOptions {
            max_inflight_rows: max_docs,
            max_inflight_bytes: options.max_inflight_bytes,
        });
        Self {
            engine: Arc::new(engine),
            controller: Arc::new(controller),
            concurrent: max_docs.unwrap_or(Self::DEFAULT_CONCURRENCY),
        }
    }

    /// Lints every `(id, bytes)` document in `docs`.
    ///
    /// Returns a lazy stream of `(id, result)` pairs; no work starts until the
    /// stream is polled. Pairs are yielded as documents finish, which need not
    /// match input order. A lint task that panics or is cancelled is reported
    /// as `Err(BatchError::TaskFailed)` for that document only.
    ///
    /// # Panics
    ///
    /// Polling the stream panics if the concurrency controller fails to hand
    /// out a permit, or if it is polled outside a Tokio runtime (required by
    /// `spawn_blocking`).
    pub fn lint_many(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
    ) -> impl Stream<Item = (String, Result<LintResult, BatchError>)> {
        self.run_many(docs, |engine, data| engine.lint(&data))
    }

    /// Fixes every `(id, bytes)` document in `docs` using
    /// `FixMode::Apply`.
    ///
    /// Returns a lazy stream of `(id, result)` pairs; no work starts until the
    /// stream is polled. Pairs are yielded as documents finish, which need not
    /// match input order. A fix task that panics or is cancelled is reported
    /// as `Err(BatchError::TaskFailed)` for that document only.
    ///
    /// # Panics
    ///
    /// Polling the stream panics if the concurrency controller fails to hand
    /// out a permit, or if it is polled outside a Tokio runtime (required by
    /// `spawn_blocking`).
    pub fn fix_many(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
    ) -> impl Stream<Item = (String, Result<FixResult, BatchError>)> {
        self.run_many(docs, |engine, data| {
            engine.fix(&data, crate::FixMode::Apply)
        })
    }

    /// Shared pipeline behind `lint_many` and `fix_many`: for each document,
    /// acquire a controller permit sized by the document's byte length, then
    /// run `op` on the blocking thread pool.
    ///
    /// The returned stream holds its own `Arc` handles, so it does not borrow
    /// from `self`.
    fn run_many<T, F>(
        &self,
        docs: impl IntoIterator<Item = (String, Vec<u8>)>,
        op: F,
    ) -> impl Stream<Item = (String, Result<T, BatchError>)>
    where
        T: Send + 'static,
        F: Fn(&Engine, Vec<u8>) -> T + Clone + Send + 'static,
    {
        let engine = Arc::clone(&self.engine);
        let controller = Arc::clone(&self.controller);
        let concurrent = self.concurrent;
        stream::iter(docs)
            .map(move |(id, data)| {
                let engine = Arc::clone(&engine);
                let controller = Arc::clone(&controller);
                let op = op.clone();
                async move {
                    let byte_len = data.len();
                    // Held until this document's result is ready, so the
                    // controller's limits cover the whole blocking call.
                    let _permit = controller
                        .acquire(Some(|| byte_len))
                        .await
                        .expect("ConcurrencyController semaphore unexpectedly closed");
                    let result = tokio::task::spawn_blocking(move || op(&*engine, data))
                        .await
                        .map_err(BatchError::from);
                    (id, result)
                }
            })
            .buffer_unordered(concurrent)
    }
}