use std::time::Duration;
use tokio::task::{JoinError, JoinSet};
use tokio_util::sync::CancellationToken;
use crate::backend::{Backend, ConsumerImpl};
use crate::consumer::ConsumerOptions;
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::topic::{SequencedTopic, Topic};
/// Summary of how a supervised run ended.
///
/// Carries the number of consumer tasks that returned an error, the number
/// that panicked, and whether the drain deadline elapsed before every task
/// had finished.
#[must_use]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SupervisorOutcome {
    /// Number of tasks that completed with an `Err`.
    pub errors: usize,
    /// Number of tasks that panicked.
    pub panics: usize,
    /// `true` when the drain timeout elapsed before all tasks finished.
    pub timed_out: bool,
}

impl SupervisorOutcome {
    /// Maps the outcome onto a process exit code.
    ///
    /// The most severe condition wins: `3` for a drain timeout, `2` when at
    /// least one task panicked, `1` when at least one task returned an error,
    /// and `0` otherwise.
    pub fn exit_code(&self) -> i32 {
        match (self.timed_out, self.panics, self.errors) {
            (true, _, _) => 3,
            (false, 0, 0) => 0,
            (false, 0, _) => 1,
            (false, _, _) => 2,
        }
    }

    /// Returns `true` when nothing failed, nothing panicked and the drain did
    /// not time out (i.e. [`exit_code`](Self::exit_code) would be `0`).
    pub fn is_clean(&self) -> bool {
        !self.timed_out && self.panics == 0 && self.errors == 0
    }
}
/// Running count of task failures and panics gathered during shutdown.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct ShutdownTally {
    /// Number of tasks that ended with an error.
    pub errors: usize,
    /// Number of tasks that panicked.
    pub panics: usize,
}

#[allow(dead_code)]
impl ShutdownTally {
    /// Folds the counts of `other` into `self`, field by field.
    pub(crate) fn add(&mut self, other: ShutdownTally) {
        let ShutdownTally { errors, panics } = other;
        self.errors += errors;
        self.panics += panics;
    }
}
/// Guard that aborts the task behind the wrapped handle when it is dropped.
pub(crate) struct AbortOnDrop(pub(crate) tokio::task::AbortHandle);

impl Drop for AbortOnDrop {
    /// Requests an abort of the guarded task.
    fn drop(&mut self) {
        let Self(handle) = self;
        handle.abort();
    }
}
pub(crate) fn tally_join_result(
res: std::result::Result<Result<()>, JoinError>,
errors: &mut usize,
panics: &mut usize,
) {
match res {
Ok(Ok(())) => {}
Ok(Err(e)) => {
tracing::error!(error = %e, "consumer task failed");
*errors += 1;
}
Err(e) if e.is_cancelled() => {}
Err(e) => {
tracing::error!(error = %e, "consumer task panicked");
*panics += 1;
}
}
}
/// Supervises already-spawned FIFO shard tasks until they all finish or
/// `signal` resolves, then drains them with a deadline.
///
/// * `handles` - join handles of the shard tasks to supervise.
/// * `shutdown` - token cancelled once `signal` resolves.
/// * `signal` - future whose completion starts the shutdown.
/// * `drain_timeout` - how long to wait for the shards after `shutdown` has
///   been cancelled before aborting the survivors.
///
/// Returns the tally of failed and panicked shards. `timed_out` is `true`
/// only when the drain deadline elapsed and the remaining shards were
/// aborted. If every shard finishes before `signal` resolves, the tally is
/// returned right away and `shutdown` is not cancelled by this function.
#[allow(dead_code)] pub(crate) async fn drive_fifo_until_timeout<S>(
handles: Vec<tokio::task::JoinHandle<Result<()>>>,
shutdown: CancellationToken,
signal: S,
drain_timeout: Duration,
) -> SupervisorOutcome
where
S: Future<Output = ()> + Send + 'static,
{
// Re-home every shard in one JoinSet. Each wrapper owns an AbortOnDrop guard,
// so aborting or dropping the wrapper also aborts the underlying shard task.
let mut joinset: JoinSet<Result<()>> = JoinSet::new();
for handle in handles {
let abort_guard = AbortOnDrop(handle.abort_handle());
joinset.spawn(async move {
let _abort_guard = abort_guard;
match handle.await {
Ok(r) => r,
// A cancelled shard is reported as a clean exit.
Err(e) if e.is_cancelled() => Ok(()),
// Re-raise the shard's panic so the JoinSet reports this wrapper as panicked.
Err(e) => std::panic::resume_unwind(e.into_panic()),
}
});
}
let mut errors = 0usize;
let mut panics = 0usize;
// Future that resolves once every shard has been joined. The mutable borrows
// are scoped to this future so the counters and the JoinSet are usable again
// after the `select!` below has consumed it.
let shards_done = {
let errors = &mut errors;
let panics = &mut panics;
let joinset = &mut joinset;
async move {
while let Some(res) = joinset.join_next().await {
tally_join_result(res, errors, panics);
}
}
};
// `biased` makes the branches get polled in the order written, so `signal`
// is checked before the shard-completion future.
let signal_won = tokio::select! {
biased;
_ = signal => true,
_ = shards_done => false,
};
if !signal_won {
// All shards ended on their own; nothing is left to drain.
return SupervisorOutcome {
errors,
panics,
timed_out: false,
};
}
shutdown.cancel();
// Continue joining whatever is still running; results tallied before the
// signal fired are already in `errors` / `panics`.
let drain = async {
while let Some(res) = joinset.join_next().await {
tally_join_result(res, &mut errors, &mut panics);
}
};
match tokio::time::timeout(drain_timeout, drain).await {
Ok(()) => SupervisorOutcome {
errors,
panics,
timed_out: false,
},
Err(_) => {
// NOTE(review): the log text names `run_fifo_until_timeout`, not this
// function — presumably the public caller; confirm the name is intended.
tracing::warn!(
timeout_ms = drain_timeout.as_millis() as u64,
"run_fifo_until_timeout: drain timed out; aborting surviving shards"
);
// Abort the survivors and join them so their results are still tallied
// (cancellations are ignored by `tally_join_result`).
joinset.abort_all();
while let Some(res) = joinset.join_next().await {
tally_join_result(res, &mut errors, &mut panics);
}
SupervisorOutcome {
errors,
panics,
timed_out: true,
}
}
}
}
/// Runs a set of topic consumers as tasks that share one shutdown token.
///
/// `B` selects the backend; `Ctx` is the context value cloned and handed to
/// each registered handler (defaults to `()`).
pub struct ConsumerSupervisor<B: Backend, Ctx: Clone + Send + Sync + 'static = ()> {
/// Backend consumer built by `B::make_consumer`.
consumer: B::ConsumerImpl,
/// Handler context; cloned once per registration.
ctx: Ctx,
/// Shutdown token; a clone is attached to the options of every registration.
shutdown: CancellationToken,
/// Tasks spawned by the `register*` methods, joined by `run_until_timeout`.
tasks: JoinSet<Result<()>>,
/// Queue names already registered, used to reject duplicate registrations.
registered: std::collections::HashSet<&'static str>,
}
impl<B: Backend> ConsumerSupervisor<B, ()> {
    /// Creates a supervisor with no registered topics, a fresh shutdown
    /// token, and the unit context.
    ///
    /// The backend consumer is obtained from `B::make_consumer(client)`.
    pub(crate) fn new(client: &B::Client) -> Self {
        let consumer = B::make_consumer(client);
        Self {
            consumer,
            ctx: (),
            shutdown: CancellationToken::new(),
            tasks: JoinSet::new(),
            registered: std::collections::HashSet::new(),
        }
    }

    /// Replaces the unit context with `ctx`, carrying over the consumer, the
    /// shutdown token, the spawned tasks and the set of registered queues.
    pub fn with_context<Ctx: Clone + Send + Sync + 'static>(
        self,
        ctx: Ctx,
    ) -> ConsumerSupervisor<B, Ctx> {
        // Take the old supervisor apart; its `()` context is simply discarded.
        let Self {
            consumer,
            ctx: (),
            shutdown,
            tasks,
            registered,
        } = self;
        ConsumerSupervisor {
            consumer,
            ctx,
            shutdown,
            tasks,
            registered,
        }
    }
}
impl<B: Backend, Ctx: Clone + Send + Sync + 'static> ConsumerSupervisor<B, Ctx> {
/// Returns a clone of the supervisor's shutdown token.
pub fn cancellation_token(&self) -> CancellationToken {
    let token = &self.shutdown;
    token.clone()
}
/// Spawns a consumer task for the unsequenced topic `T`, handled by `handler`.
///
/// The supervisor's shutdown token is attached to `options` before the
/// consumer is started, and the task is added to the supervisor's task set.
///
/// # Errors
///
/// Returns [`ShoveError::Topology`] when the topic's topology has a
/// sequencing config, or when its queue name has already been registered on
/// this supervisor.
pub fn register<T, H>(&mut self, handler: H, options: ConsumerOptions<B>) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T, Context = Ctx>,
{
    let queue = T::topology().queue();

    let is_sequenced = T::topology().sequencing().is_some();
    if is_sequenced {
        let reason = format!(
            "topic '{queue}' has a sequencing config; `ConsumerSupervisor::register` \
             would silently drop FIFO ordering. Use `register_fifo` instead."
        );
        return Err(ShoveError::Topology(reason));
    }

    // `insert` reports `false` when the queue name was already present.
    let newly_registered = self.registered.insert(queue);
    if !newly_registered {
        let reason = format!("topic '{queue}' is already registered on this supervisor");
        return Err(ShoveError::Topology(reason));
    }

    let worker = self.consumer.clone();
    let handler_ctx = self.ctx.clone();
    let run_options = options.with_shutdown(self.shutdown.clone()).into_inner();
    let consume = async move { worker.run::<T, H>(handler, handler_ctx, run_options).await };
    self.tasks.spawn(consume);
    Ok(())
}
/// Spawns the FIFO shard consumers for the sequenced topic `T`, handled by
/// `handler`, and places them under this supervisor's control.
///
/// The supervisor's shutdown token is attached to `options` before the shards
/// are spawned. Every returned shard handle is wrapped in a task owned by the
/// supervisor's task set; the wrapper holds an [`AbortOnDrop`] guard so that
/// aborting the wrapper also aborts the shard.
///
/// The queue name is recorded as registered only after the shards were
/// spawned successfully, so a failed call can be retried on the same
/// supervisor instead of being rejected as a duplicate.
///
/// # Errors
///
/// * [`ShoveError::Topology`] when the topology has no sequencing config.
/// * [`ShoveError::Topology`] when the queue name is already registered.
/// * Any error produced by the backend's `spawn_fifo_shards`.
pub async fn register_fifo<T, H>(
    &mut self,
    handler: H,
    options: ConsumerOptions<B>,
) -> Result<()>
where
    T: SequencedTopic,
    H: MessageHandler<T, Context = Ctx>,
{
    let queue: &'static str = T::topology().queue();
    if T::topology().sequencing().is_none() {
        return Err(ShoveError::Topology(format!(
            "topic '{queue}' implements `SequencedTopic` but its topology has no \
             sequencing config; `ConsumerSupervisor::register_fifo` would attach to \
             FIFO shard queues that were never declared. Use `register` for \
             unsequenced topics, or add `.sequenced(...)` to the topology."
        )));
    }
    // Only *check* for a duplicate here. `&mut self` is held across the await
    // below, so no other registration can slip in between check and insert.
    if self.registered.contains(queue) {
        return Err(ShoveError::Topology(format!(
            "topic '{queue}' is already registered on this supervisor"
        )));
    }
    let ctx = self.ctx.clone();
    let inner = options.with_shutdown(self.shutdown.clone()).into_inner();
    // NOTE(review): assumes `spawn_fifo_shards` leaves no shard running when it
    // returns `Err` — confirm against the backend implementations.
    let handles = self
        .consumer
        .spawn_fifo_shards::<T, H>(handler, ctx, inner)
        .await?;
    // The shards exist now; from here on the topic counts as registered.
    self.registered.insert(queue);
    for handle in handles {
        let abort_guard = AbortOnDrop(handle.abort_handle());
        self.tasks.spawn(async move {
            let _abort_guard = abort_guard;
            match handle.await {
                Ok(r) => r,
                // A cancelled shard is reported as a clean exit.
                Err(e) if e.is_cancelled() => Ok(()),
                // Re-raise the shard's panic so the task set reports a panic.
                Err(e) => std::panic::resume_unwind(e.into_panic()),
            }
        });
    }
    Ok(())
}
/// Waits for shutdown, then drains the registered tasks with a deadline.
///
/// The wait ends when `signal` resolves (the shutdown token is then
/// cancelled) or when the shutdown token is cancelled from elsewhere. After
/// that, tasks are joined for at most `drain_timeout`; any task still running
/// at the deadline is aborted and joined.
///
/// Returns the number of tasks that failed or panicked, with `timed_out` set
/// when the deadline elapsed.
pub async fn run_until_timeout<S>(
    mut self,
    signal: S,
    drain_timeout: Duration,
) -> SupervisorOutcome
where
    S: Future<Output = ()> + Send + 'static,
{
    tokio::select! {
        _ = signal => { self.shutdown.cancel(); }
        _ = self.shutdown.cancelled() => {}
    }

    let mut failed = 0usize;
    let mut panicked = 0usize;

    // Graceful phase: join tasks until the set is empty or the deadline hits.
    let graceful_drain = async {
        while let Some(joined) = self.tasks.join_next().await {
            tally_join_result(joined, &mut failed, &mut panicked);
        }
    };
    let timed_out = tokio::time::timeout(drain_timeout, graceful_drain)
        .await
        .is_err();

    if timed_out {
        tracing::warn!(
            timeout_ms = drain_timeout.as_millis() as u64,
            "drain timeout elapsed; aborting surviving tasks"
        );
        // Forced phase: abort the stragglers and still collect their results.
        self.tasks.abort_all();
        while let Some(joined) = self.tasks.join_next().await {
            tally_join_result(joined, &mut failed, &mut panicked);
        }
    }

    SupervisorOutcome {
        errors: failed,
        panics: panicked,
        timed_out,
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ShoveError;

    /// Builds an outcome from its three components.
    fn outcome(errors: usize, panics: usize, timed_out: bool) -> SupervisorOutcome {
        SupervisorOutcome {
            errors,
            panics,
            timed_out,
        }
    }

    /// Runs one join result through `tally_join_result` starting from zeroed
    /// counters and returns `(errors, panics)`.
    fn tally(res: std::result::Result<Result<()>, JoinError>) -> (usize, usize) {
        let (mut errors, mut panics) = (0usize, 0usize);
        tally_join_result(res, &mut errors, &mut panics);
        (errors, panics)
    }

    #[test]
    fn clean_outcome_has_exit_code_zero() {
        let clean = SupervisorOutcome::default();
        assert_eq!(clean.exit_code(), 0);
        assert!(clean.is_clean());
    }

    #[test]
    fn errors_produce_exit_code_one() {
        assert_eq!(outcome(3, 0, false).exit_code(), 1);
    }

    #[test]
    fn panics_outrank_errors() {
        assert_eq!(outcome(3, 1, false).exit_code(), 2);
    }

    #[test]
    fn timeout_outranks_everything() {
        assert_eq!(outcome(3, 1, true).exit_code(), 3);
    }

    #[tokio::test]
    async fn tally_increments_errors_on_task_failure() {
        let failed_task = Ok(Err(ShoveError::Topology("test".into())));
        let (errors, panics) = tally(failed_task);
        assert_eq!(errors, 1);
        assert_eq!(panics, 0);
    }

    #[tokio::test]
    async fn tally_increments_panics_on_join_panic() {
        // Produce a genuine panic `JoinError` by joining a panicking task.
        let join_err = tokio::spawn(async { panic!("boom") })
            .await
            .unwrap_err();
        assert!(join_err.is_panic());

        let (errors, panics) = tally(Err(join_err));
        assert_eq!(panics, 1);
        assert_eq!(errors, 0);
    }

    #[tokio::test]
    async fn tally_ignores_cancellation() {
        // Abort a long sleeper to obtain a cancellation `JoinError`.
        let sleeper: tokio::task::JoinHandle<Result<()>> = tokio::spawn(async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(())
        });
        sleeper.abort();
        let join_err = sleeper.await.unwrap_err();
        assert!(join_err.is_cancelled());

        let (errors, panics) = tally(Err(join_err));
        assert_eq!(errors, 0);
        assert_eq!(panics, 0);
    }

    #[test]
    fn tally_does_not_count_success() {
        let (errors, panics) = tally(Ok(Ok(())));
        assert_eq!(errors, 0);
        assert_eq!(panics, 0);
    }
}
// End-to-end supervisor tests against the in-memory backend. Compiled only for
// test builds with the `inmemory` feature enabled.
#[cfg(all(test, feature = "inmemory"))]
mod inmemory_tests {
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::consumer::ConsumerOptions;
use crate::define_sequenced_topic;
use crate::error::ShoveError;
use crate::inmemory::InMemoryConfig;
use crate::markers::InMemory;
use crate::topic::SequencedTopic;
use crate::topology::{SequenceFailure, TopologyBuilder};
use crate::{Broker, MessageHandler, MessageMetadata, Outcome};
// Message payload used by the `Ledger` test topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LedgerEntry {
account_id: String,
}
// Declares the sequenced test topic `Ledger` carrying `LedgerEntry` messages.
// The closure presumably supplies the sequence key (here the account id) —
// confirm against the `define_sequenced_topic!` definition.
define_sequenced_topic!(
Ledger,
LedgerEntry,
|msg| msg.account_id.clone(),
TopologyBuilder::new("supervisor-ledger-test")
.sequenced(SequenceFailure::FailAll)
.hold_queue(Duration::from_millis(50))
.dlq()
.build()
);
// Handler that acknowledges every message without doing any work.
struct NoopHandler;
impl MessageHandler<Ledger> for NoopHandler {
type Context = ();
async fn handle(&self, _: LedgerEntry, _: MessageMetadata, _: &()) -> Outcome {
Outcome::Ack
}
}
// A FIFO registration followed by an external cancellation must end in a
// clean outcome.
#[tokio::test]
async fn register_fifo_runs_cleanly() {
let broker = Broker::<InMemory>::new(InMemoryConfig::default())
.await
.expect("broker");
broker
.topology()
.declare::<Ledger>()
.await
.expect("declare");
let mut sup = broker.consumer_supervisor();
sup.register_fifo::<Ledger, _>(NoopHandler, ConsumerOptions::<InMemory>::new())
.await
.expect("register_fifo");
// Cancel through the shared token after a short delay; the signal future
// passed below never resolves, so the token is the only way to shut down.
let token = sup.cancellation_token();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
token.cancel();
});
let outcome = sup
.run_until_timeout(std::future::pending::<()>(), Duration::from_secs(1))
.await;
assert!(outcome.is_clean(), "unexpected outcome: {outcome:?}");
}
// `register` must refuse a sequenced topic, and the error text must point the
// caller at `register_fifo`.
#[tokio::test]
async fn register_rejection_points_at_register_fifo() {
let broker = Broker::<InMemory>::new(InMemoryConfig::default())
.await
.expect("broker");
broker
.topology()
.declare::<Ledger>()
.await
.expect("declare");
let mut sup = broker.consumer_supervisor();
let result = sup.register::<Ledger, _>(NoopHandler, ConsumerOptions::<InMemory>::new());
match result {
Err(ShoveError::Topology(msg)) => {
assert!(msg.contains("register_fifo"), "unexpected msg: {msg}");
}
other => panic!("expected Topology error, got {other:?}"),
}
}
// Registering the same sequenced topic twice on one supervisor must fail with
// an "already registered" topology error.
#[tokio::test]
async fn register_fifo_rejects_duplicate_topic() {
let broker = Broker::<InMemory>::new(InMemoryConfig::default())
.await
.expect("broker");
broker
.topology()
.declare::<Ledger>()
.await
.expect("declare");
let mut sup = broker.consumer_supervisor();
sup.register_fifo::<Ledger, _>(NoopHandler, ConsumerOptions::<InMemory>::new())
.await
.expect("first register_fifo should succeed");
let result = sup
.register_fifo::<Ledger, _>(NoopHandler, ConsumerOptions::<InMemory>::new())
.await;
match result {
Err(ShoveError::Topology(msg)) => {
assert!(msg.contains("already registered"), "unexpected msg: {msg}");
}
other => panic!("expected Topology error, got {other:?}"),
}
}
}