use std::{
collections::HashMap,
fmt::{Debug, Display},
};
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
use tracing::{debug, info};
use crate::{
BatchError,
batch::BatchItem,
batch_inner::Generation,
batch_queue::BatchQueue,
limits::Limits,
policies::{BatchingPolicy, OnAdd, OnFinish, OnGenerationEvent},
processor::Processor,
};
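/// Background task state for a single batcher. The worker receives submitted
/// [`BatchItem`]s, groups them into a [`BatchQueue`] per key, and reacts to
/// timer, resource, and completion events according to the configured
/// [`BatchingPolicy`] and [`Limits`].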
pub(crate) struct Worker<P: Processor> {
batcher_name: String,
item_rx: mpsc::Receiver<BatchItem<P>>,
processor: P,
msg_tx: mpsc::Sender<Message<P::Key, P::Error>>,
msg_rx: mpsc::Receiver<Message<P::Key, P::Error>>,
shutdown_notifier_rx: mpsc::Receiver<ShutdownMessage>,
shutdown_notifiers: Vec<oneshot::Sender<()>>,
shutting_down: bool,
limits: Limits,
batching_policy: BatchingPolicy,
batch_queues: HashMap<P::Key, BatchQueue<P>>,
}
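/// Internal events sent back to the worker by its batch queues and the work
/// they kick off (deadline timers, resource acquisition, batch processing).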
#[derive(Debug)]
pub(crate) enum Message<K, E: Display + Debug> {
TimedOut(K, Generation),
ResourcesAcquired(K, Generation),
ResourceAcquisitionFailed(K, Generation, BatchError<E>),
Finished(K, BatchTerminalState),
}
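/// How a batch reached the end of its lifecycle.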
#[derive(Debug)]
pub(crate) enum BatchTerminalState {
Processed,
FailedAcquiring,
}
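/// Control messages delivered to the worker through a [`WorkerHandle`].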
pub(crate) enum ShutdownMessage {
Register(ShutdownNotifier),
ShutDown,
}
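/// Wraps the oneshot sender used to wake a caller of
/// [`WorkerHandle::wait_for_shutdown`] once the worker exits.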
pub(crate) struct ShutdownNotifier(oneshot::Sender<()>);
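/// Cloneable handle for requesting shutdown of the worker task and waiting for
/// it to complete.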
#[derive(Debug, Clone)]
pub struct WorkerHandle {
shutdown_tx: mpsc::Sender<ShutdownMessage>,
}
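/// Aborts the spawned worker task when dropped.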
#[derive(Debug)]
pub(crate) struct WorkerDropGuard {
handle: JoinHandle<()>,
}
impl<P: Processor> Worker<P> {
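    /// Spawns the worker onto the Tokio runtime. Returns a shutdown handle, a
    /// guard that aborts the task when dropped, and the sender used to submit
    /// items to the worker.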
pub fn spawn(
batcher_name: String,
processor: P,
limits: Limits,
batching_policy: BatchingPolicy,
) -> (WorkerHandle, WorkerDropGuard, mpsc::Sender<BatchItem<P>>) {
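        // The item and message channels are bounded by the per-key in-flight
        // limit, so senders get backpressure instead of an unbounded queue.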
let (item_tx, item_rx) = mpsc::channel(limits.max_items_in_system_per_key());
let (msg_tx, msg_rx) = mpsc::channel(limits.max_items_in_system_per_key());
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let mut worker = Worker {
batcher_name,
item_rx,
processor,
msg_tx,
msg_rx,
shutdown_notifier_rx: shutdown_rx,
shutdown_notifiers: Vec::new(),
shutting_down: false,
limits,
batching_policy,
batch_queues: HashMap::new(),
};
let handle = tokio::spawn(async move {
worker.run().await;
});
(
WorkerHandle { shutdown_tx },
WorkerDropGuard { handle },
item_tx,
)
}
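    /// Enqueues an incoming item under its key and applies the batching policy's
    /// `on_add` decision: process immediately, pre-acquire resources, schedule a
    /// deadline, just enqueue, or reject the item back to the submitter.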
fn add(&mut self, item: BatchItem<P>) {
let key = item.key.clone();
let batch_queue = self.batch_queues.entry(key.clone()).or_insert_with(|| {
BatchQueue::new(self.batcher_name.clone(), key.clone(), self.limits)
});
match self.batching_policy.on_add(batch_queue) {
OnAdd::AddAndProcess => {
batch_queue.push(item);
self.process_next_batch(&key);
}
OnAdd::AddAndAcquireResources => {
batch_queue.push(item);
batch_queue.pre_acquire_resources(self.processor.clone(), self.msg_tx.clone());
}
OnAdd::AddAndProcessAfter(duration) => {
batch_queue.push(item);
batch_queue.process_after(duration, self.msg_tx.clone());
}
OnAdd::Add => {
batch_queue.push(item);
}
OnAdd::Reject(reason) => {
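                // The policy rejected the item; report the rejection straight
                // back to the submitter instead of enqueueing it.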
if item
.tx
.send((Err(BatchError::Rejected(reason)), None))
.is_err()
{
debug!(
"Unable to send output over oneshot channel. Receiver deallocated. Batcher: {}",
self.batcher_name
);
}
}
}
}
fn process_generation(&mut self, key: P::Key, generation: Generation) {
        let batch_queue = self
            .batch_queues
            .get_mut(&key)
            .expect("batch queue should exist");
batch_queue.process_generation(generation, self.processor.clone(), self.msg_tx.clone());
}
fn process_next_ready_batch(&mut self, key: &P::Key) {
let batch_queue = self
.batch_queues
.get_mut(key)
.expect("batch queue should exist");
batch_queue.process_next_ready_batch(self.processor.clone(), self.msg_tx.clone());
}
fn process_next_batch(&mut self, key: &P::Key) {
let batch_queue = self
.batch_queues
.get_mut(key)
.expect("batch queue should exist");
batch_queue.process_next_batch(self.processor.clone(), self.msg_tx.clone());
}
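    /// Handles the batch deadline firing for `generation` of `key`, letting the
    /// policy decide whether that generation should be processed now.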
fn on_timeout(&mut self, key: P::Key, generation: Generation) {
let batch_queue = self
.batch_queues
.get_mut(&key)
.expect("batch queue should exist");
match self.batching_policy.on_timeout(generation, batch_queue) {
OnGenerationEvent::Process => {
self.process_generation(key, generation);
}
OnGenerationEvent::DoNothing => {}
}
}
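    /// Handles successful resource acquisition for `generation` of `key`, letting
    /// the policy decide whether that generation should be processed now.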
fn on_resource_acquired(&mut self, key: P::Key, generation: Generation) {
let batch_queue = self
.batch_queues
.get_mut(&key)
.expect("batch queue should exist");
batch_queue.mark_resource_acquisition_finished();
match self
.batching_policy
.on_resources_acquired(generation, batch_queue)
{
OnGenerationEvent::Process => {
self.process_generation(key, generation);
}
OnGenerationEvent::DoNothing => {}
}
}
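    /// Fails `generation` of `key` after resource acquisition returned an error.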
fn on_resource_acquisition_failed(
&mut self,
key: P::Key,
generation: Generation,
err: BatchError<P::Error>,
) {
let batch_queue = self
.batch_queues
.get_mut(&key)
.expect("batch queue should exist");
        batch_queue.fail_generation(generation, err, self.msg_tx.clone());
}
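    /// Records that a batch reached a terminal state and asks the batching policy
    /// whether another batch for this key should be processed next.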
fn on_batch_finished(&mut self, key: &P::Key, terminal_state: BatchTerminalState) {
let batch_queue = self
.batch_queues
.get_mut(key)
.expect("batch queue should exist");
match terminal_state {
BatchTerminalState::Processed => {
batch_queue.mark_processed();
}
BatchTerminalState::FailedAcquiring => {
batch_queue.mark_resource_acquisition_finished();
}
}
match self.batching_policy.on_finish(batch_queue) {
OnFinish::ProcessNextReady => {
self.process_next_ready_batch(key);
}
OnFinish::ProcessNext => {
self.process_next_batch(key);
}
OnFinish::DoNothing => {}
}
}
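    /// True once shutdown has been requested and every queue is both empty and no
    /// longer processing a batch.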
    fn ready_to_shut_down(&self) -> bool {
        self.shutting_down
            && self
                .batch_queues
                .values()
                .all(|q| q.is_empty() && !q.is_processing())
    }
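    /// Main event loop: multiplexes shutdown requests, newly submitted items, and
    /// internal batch events until the worker is ready to shut down. Returning
    /// drops the registered shutdown notifiers, which wakes any callers blocked in
    /// [`WorkerHandle::wait_for_shutdown`].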
async fn run(&mut self) {
loop {
tokio::select! {
Some(msg) = self.shutdown_notifier_rx.recv() => {
match msg {
ShutdownMessage::Register(notifier) => {
self.shutdown_notifiers.push(notifier.0);
}
ShutdownMessage::ShutDown => {
self.shutting_down = true;
}
}
}
Some(item) = self.item_rx.recv() => {
self.add(item);
}
Some(msg) = self.msg_rx.recv() => {
match msg {
Message::ResourcesAcquired(key, generation) => {
self.on_resource_acquired(key, generation);
}
Message::ResourceAcquisitionFailed(key, generation, err) => {
self.on_resource_acquisition_failed(key, generation, err);
}
Message::TimedOut(key, generation) => {
self.on_timeout(key, generation);
}
Message::Finished(key, terminal_state) => {
self.on_batch_finished(&key, terminal_state);
}
}
}
}
if self.ready_to_shut_down() {
info!("Batch worker '{}' is shutting down", &self.batcher_name);
return;
}
}
}
}
impl WorkerHandle {
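    /// Asks the worker to shut down once its queues have drained.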
pub async fn shut_down(&self) {
let _ = self.shutdown_tx.send(ShutdownMessage::ShutDown).await;
}
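    /// Resolves once the worker task has exited. The registered oneshot sender is
    /// dropped when the worker's run loop returns, which is what wakes this future.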
pub async fn wait_for_shutdown(&self) {
let (notifier_tx, notifier_rx) = oneshot::channel();
let _ = self
.shutdown_tx
.send(ShutdownMessage::Register(ShutdownNotifier(notifier_tx)))
.await;
let _ = notifier_rx.await;
}
}
impl Drop for WorkerDropGuard {
fn drop(&mut self) {
self.handle.abort();
}
}
#[cfg(test)]
mod test {
use tokio::sync::oneshot;
use tracing::Span;
use super::*;
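    /// Minimal processor for these tests: appends " processed" to each input and
    /// needs no resources.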
#[derive(Debug, Clone)]
struct SimpleBatchProcessor;
impl Processor for SimpleBatchProcessor {
type Key = String;
type Input = String;
type Output = String;
type Error = String;
type Resources = ();
async fn acquire_resources(&self, _key: String) -> Result<(), String> {
Ok(())
}
async fn process(
&self,
_key: String,
inputs: impl Iterator<Item = String> + Send,
_resources: (),
) -> Result<Vec<String>, String> {
Ok(inputs.map(|s| s + " processed").collect())
}
}
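    // Submits two items for the same key; with `BatchingPolicy::Size` and
    // `max_batch_size(2)` they are expected to be processed as one full batch.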
#[tokio::test]
async fn simple_test_over_channel() {
let (_worker_handle, _worker_guard, item_tx) = Worker::<SimpleBatchProcessor>::spawn(
"test".to_string(),
SimpleBatchProcessor,
Limits::builder().max_batch_size(2).build(),
BatchingPolicy::Size,
);
let rx1 = {
let (tx, rx) = oneshot::channel();
item_tx
.send(BatchItem {
key: "K1".to_string(),
input: "I1".to_string(),
submitted_at: tokio::time::Instant::now(),
tx,
requesting_span: Span::none(),
})
.await
.unwrap();
rx
};
let rx2 = {
let (tx, rx) = oneshot::channel();
item_tx
.send(BatchItem {
key: "K1".to_string(),
input: "I2".to_string(),
submitted_at: tokio::time::Instant::now(),
tx,
requesting_span: Span::none(),
})
.await
.unwrap();
rx
};
let o1 = rx1.await.unwrap().0.unwrap();
let o2 = rx2.await.unwrap().0.unwrap();
assert_eq!(o1, "I1 processed".to_string());
assert_eq!(o2, "I2 processed".to_string());
}
}