use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Semaphore, mpsc};
#[cfg(feature = "trace")]
use tracing::Instrument;
use super::dispatch::{AsyncListenerMeta, execute_async_listener};
use super::tracker::AsyncTaskTracker;
use super::types::{ControlNotification, ErasedAsyncHandlerFn, EventType};
/// RAII guard that pairs `AsyncTaskTracker::track_external` with
/// `AsyncTaskTracker::finish_external`.
///
/// Creating the guard calls `track_external()` on the tracker; dropping it calls
/// `finish_external()` on the same tracker. Because the release lives in `Drop`,
/// it runs whenever the owner is dropped, whether or not the owner's work was
/// ever executed.
#[must_use = "dropping the guard immediately calls `finish_external` on the tracker"]
struct ExternalWorkGuard {
    tracker: Arc<AsyncTaskTracker>,
}

impl ExternalWorkGuard {
    /// Calls `track_external()` on `tracker` and returns a guard that calls
    /// `finish_external()` on that same tracker when it is dropped.
    fn new(tracker: Arc<AsyncTaskTracker>) -> Self {
        tracker.track_external();
        Self { tracker }
    }
}

impl Drop for ExternalWorkGuard {
    fn drop(&mut self) {
        // Balances the `track_external()` call made in `new`.
        self.tracker.finish_external();
    }
}
/// One unit of async listener work queued to an `AsyncSlotWorker`.
///
/// The `_guard` field is private, so outside this module a `WorkItem` can only be
/// built with `WorkItem::from_data`, which registers external work on the tracker.
/// That registration is released when the guard is dropped.
pub(crate) struct WorkItem {
/// Type-erased async handler passed to `execute_async_listener`.
pub handler: ErasedAsyncHandlerFn,
/// Event value passed to the handler via `execute_async_listener`.
pub event: EventType,
/// Event name passed to `execute_async_listener`.
pub event_name: &'static str,
/// Listener metadata passed to `execute_async_listener`.
pub meta: AsyncListenerMeta,
/// Optional timeout forwarded to `execute_async_listener`.
pub handler_timeout: Option<Duration>,
/// Channel on which `ControlNotification::Failure` is sent when the handler reports a failure.
pub notify_tx: mpsc::UnboundedSender<ControlNotification>,
/// When set, a permit from this semaphore is held while the handler runs.
pub concurrency_semaphore: Option<Arc<Semaphore>>,
/// Span the handler's execution is instrumented with (only with the `trace` feature).
#[cfg(feature = "trace")]
pub parent_span: tracing::Span,
// Keep this field LAST: struct fields drop in declaration order, so when a whole
// `WorkItem` is dropped the tracker is notified only after every other field
// has already been dropped.
_guard: ExternalWorkGuard,
}
/// The payload of a `WorkItem` without its tracker guard.
///
/// Callers fill this in and convert it with `WorkItem::from_data`, which adds the
/// guard. Fields carry the same meaning as the identically named `WorkItem` fields.
pub(crate) struct WorkItemData {
pub handler: ErasedAsyncHandlerFn,
pub event: EventType,
pub event_name: &'static str,
pub meta: AsyncListenerMeta,
pub handler_timeout: Option<Duration>,
pub notify_tx: mpsc::UnboundedSender<ControlNotification>,
pub concurrency_semaphore: Option<Arc<Semaphore>>,
// Only present when the `trace` feature is enabled.
#[cfg(feature = "trace")]
pub parent_span: tracing::Span,
}
impl WorkItem {
pub(crate) fn from_data(data: WorkItemData, tracker: Arc<AsyncTaskTracker>) -> Self {
Self {
handler: data.handler,
event: data.event,
event_name: data.event_name,
meta: data.meta,
handler_timeout: data.handler_timeout,
notify_tx: data.notify_tx,
concurrency_semaphore: data.concurrency_semaphore,
#[cfg(feature = "trace")]
parent_span: data.parent_span,
_guard: ExternalWorkGuard::new(tracker),
}
}
}
/// Cloneable handle to a background task that runs queued `WorkItem`s one at a time.
///
/// Clones share the same underlying sender, so items submitted through any clone
/// land in the same queue and are processed sequentially by a single task.
#[derive(Clone)]
pub(crate) struct AsyncSlotWorker {
    tx: mpsc::UnboundedSender<WorkItem>,
}

impl AsyncSlotWorker {
    /// Creates the queue and spawns `worker_loop` to drain it.
    ///
    /// This must be called from within a Tokio runtime, because `tokio::spawn`
    /// panics otherwise. The task's `JoinHandle` is discarded, so the task runs
    /// detached.
    ///
    /// `_tracker` is accepted but not used here. Per-item tracking is handled by
    /// the guard each `WorkItem` carries.
    pub(crate) fn spawn(_tracker: Arc<AsyncTaskTracker>) -> Self {
        let (sender, receiver) = mpsc::unbounded_channel::<WorkItem>();
        tokio::spawn(worker_loop(receiver));
        AsyncSlotWorker { tx: sender }
    }

    /// Queues `item` for execution.
    ///
    /// Returns `false` when the worker's receiver is gone. In that case the
    /// rejected item is dropped here, which releases its tracker guard.
    pub(crate) fn send(&self, item: WorkItem) -> bool {
        match self.tx.send(item) {
            Ok(()) => true,
            Err(_rejected) => false,
        }
    }
}
async fn worker_loop(mut rx: mpsc::UnboundedReceiver<WorkItem>) {
while let Some(item) = rx.recv().await {
let task = async {
let inner = async {
if let Some(failure) = execute_async_listener(item.handler, item.event, item.event_name, item.meta, item.handler_timeout).await {
let _ = item.notify_tx.send(ControlNotification::Failure(failure));
}
};
if let Some(semaphore) = item.concurrency_semaphore {
if let Ok(_permit) = semaphore.acquire().await {
inner.await;
}
} else {
inner.await;
}
};
#[cfg(feature = "trace")]
let task = task.instrument(item.parent_span);
task.await;
}
}