use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use crate::runtime::metrics;
#[derive(Debug, Error, PartialEq, Eq)]
pub enum SidejobError {
#[error("sidejob mailbox is full")]
Overloaded,
#[error("sidejob actor has stopped")]
Stopped,
}
/// Cloneable handle to one background actor that takes requests from a
/// bounded mailbox and rejects new work when that mailbox is full.
pub struct Sidejob<Req, Reply> {
    /// Static label used in log fields and as the overload metric's label.
    name: &'static str,
    /// Mailbox sender. Every message pairs a request with the one-shot sender
    /// the actor uses to hand back the reply.
    tx: mpsc::Sender<(Req, oneshot::Sender<Reply>)>,
    /// Count of submissions rejected because the mailbox was full. The
    /// counter is shared by every clone of this handle.
    full_failures: Arc<AtomicU64>,
}
impl<Req, Reply> Clone for Sidejob<Req, Reply> {
fn clone(&self) -> Self {
Self {
name: self.name,
tx: self.tx.clone(),
full_failures: Arc::clone(&self.full_failures),
}
}
}
impl<Req, Reply> Sidejob<Req, Reply>
where
    Req: Send + 'static,
    Reply: Send + 'static,
{
    /// Starts the actor task and returns a handle to it.
    ///
    /// `name` labels log events and the overload metric. `capacity` is the
    /// number of requests the mailbox can hold.
    ///
    /// Requests are handled strictly one after another. For each request,
    /// `handler` is called to build a future. That future runs on its own Tokio
    /// task and is awaited before the next request is taken from the mailbox.
    ///
    /// A panic is isolated to the request that caused it. This holds whether the
    /// panic happens while calling `handler` or while polling the future it
    /// returned. The request's reply sender is dropped, so a caller waiting in
    /// [`Sidejob::submit`] gets [`SidejobError::Stopped`]. The actor then keeps
    /// serving later requests.
    ///
    /// The actor loop ends once every handle (including clones) has been
    /// dropped and the mailbox is drained.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero. It also panics if called outside a Tokio
    /// runtime, because `tokio::spawn` requires one.
    pub fn spawn<F, Fut>(name: &'static str, capacity: usize, mut handler: F) -> Self
    where
        F: FnMut(Req) -> Fut + Send + 'static,
        Fut: Future<Output = Reply> + Send + 'static,
    {
        assert!(capacity > 0, "sidejob capacity must be > 0");
        // NOTE(review): presumably touches the labelled series so it exists
        // before the first overload; confirm against `metrics`.
        let _ = metrics::sidejob_overload().with_label_values(&[name]);
        let (tx, mut rx) = mpsc::channel::<(Req, oneshot::Sender<Reply>)>(capacity);
        tokio::spawn(async move {
            while let Some((req, reply_tx)) = rx.recv().await {
                // Calling the handler is synchronous and can panic before any
                // future exists. Without this guard, such a panic would unwind
                // through the actor task, drop `rx`, and permanently close the
                // mailbox. `AssertUnwindSafe` accepts that `handler`'s captured
                // state may be partially updated by a panicking call; it keeps
                // being used for subsequent requests.
                let guarded = std::panic::AssertUnwindSafe(|| handler(req));
                let fut = match std::panic::catch_unwind(guarded) {
                    Ok(fut) => fut,
                    Err(_) => {
                        tracing::warn!(
                            sidejob = name,
                            "handler panicked; reply channel dropped"
                        );
                        drop(reply_tx);
                        continue;
                    }
                };
                // Run the future on its own task. A panic while polling it then
                // surfaces here as a `JoinError` instead of killing this loop.
                match tokio::spawn(fut).await {
                    Ok(reply) => {
                        // The caller may have dropped its receiver; the reply
                        // is discarded in that case.
                        let _ = reply_tx.send(reply);
                    }
                    Err(join_err) => {
                        if join_err.is_panic() {
                            tracing::warn!(
                                sidejob = name,
                                "handler panicked; reply channel dropped"
                            );
                        }
                        // Dropping the sender makes the caller's receiver
                        // resolve with an error, reported as `Stopped`.
                        drop(reply_tx);
                    }
                }
            }
            tracing::debug!(sidejob = name, "actor loop exited (channel closed)");
        });
        Self {
            name,
            tx,
            full_failures: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Submits `req` and waits for the actor's reply.
    ///
    /// This never waits for mailbox space; a full mailbox fails immediately.
    ///
    /// # Errors
    ///
    /// - [`SidejobError::Overloaded`] if the mailbox is full.
    /// - [`SidejobError::Stopped`] if the mailbox is closed, or if the request
    ///   was dropped without a reply (for example because the handler panicked).
    pub async fn submit(&self, req: Req) -> Result<Reply, SidejobError> {
        let rx = self.try_submit(req)?;
        rx.await.map_err(|_| SidejobError::Stopped)
    }

    /// Enqueues `req` without waiting and returns the receiver for its reply.
    ///
    /// Dropping the returned receiver does not cancel the request. The handler
    /// still runs and its reply is discarded.
    ///
    /// # Errors
    ///
    /// - [`SidejobError::Overloaded`] if the mailbox is full. This also
    ///   increments [`Sidejob::full_failures`] and the overload metric for
    ///   this sidejob's name.
    /// - [`SidejobError::Stopped`] if the mailbox has been closed.
    pub fn try_submit(&self, req: Req) -> Result<oneshot::Receiver<Reply>, SidejobError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        match self.tx.try_send((req, reply_tx)) {
            Ok(()) => Ok(reply_rx),
            Err(mpsc::error::TrySendError::Full(_)) => {
                self.full_failures.fetch_add(1, Ordering::Relaxed);
                metrics::sidejob_overload()
                    .with_label_values(&[self.name])
                    .inc();
                Err(SidejobError::Overloaded)
            }
            Err(mpsc::error::TrySendError::Closed(_)) => Err(SidejobError::Stopped),
        }
    }

    /// Number of submissions rejected with [`SidejobError::Overloaded`].
    /// The count covers all clones of this handle, because they share one
    /// counter.
    pub fn full_failures(&self) -> u64 {
        self.full_failures.load(Ordering::Relaxed)
    }

    /// The label this sidejob was spawned with.
    pub fn name(&self) -> &'static str {
        self.name
    }
}
impl<Req, Reply> std::fmt::Debug for Sidejob<Req, Reply> {
    /// Prints the name, the mailbox's configured capacity and a snapshot of
    /// the overload counter. Queued requests are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.tx.max_capacity();
        let full_failures = self.full_failures.load(Ordering::Relaxed);
        let mut out = f.debug_struct("Sidejob");
        out.field("name", &self.name);
        out.field("capacity", &capacity);
        out.field("full_failures", &full_failures);
        out.finish()
    }
}