use std::{future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc, time::Duration};
use async_trait::async_trait;
use crossbeam_skiplist::SkipMap;
use fail_parallel::{fail_point, FailPointRegistry};
use futures::{
future::BoxFuture,
stream::{BoxStream, FuturesUnordered},
FutureExt, StreamExt,
};
use log::error;
use parking_lot::Mutex;
use tokio::{runtime::Handle, sync::mpsc, task::JoinHandle};
use tokio_util::{sync::CancellationToken, task::JoinMap};
use crate::{
clock::{SystemClock, SystemClockTicker},
error::SlateDBError,
utils::{panic_string, split_join_result, split_unwind_result, WatchableOnceCell},
};
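/// Factory closure that produces a fresh message each time a ticker fires.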
pub(crate) type MessageFactory<T> = dyn Fn() -> T + Send;
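/// Drives a [`MessageHandler`]: forwards messages from an unbounded channel and
/// from the handler's periodic tickers until the cancellation token fires or the
/// handler returns an error.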
struct MessageDispatcher<T: Send + std::fmt::Debug> {
handler: Box<dyn MessageHandler<T>>,
rx: mpsc::UnboundedReceiver<T>,
clock: Arc<dyn SystemClock>,
cancellation_token: CancellationToken,
#[allow(dead_code)]
fp_registry: Arc<FailPointRegistry>,
}
impl<T: Send + std::fmt::Debug> MessageDispatcher<T> {
#[allow(dead_code)]
fn new(
handler: Box<dyn MessageHandler<T>>,
rx: mpsc::UnboundedReceiver<T>,
clock: Arc<dyn SystemClock>,
cancellation_token: CancellationToken,
) -> Self {
Self {
handler,
rx,
clock,
cancellation_token,
fp_registry: Arc::new(FailPointRegistry::new()),
}
}
#[allow(dead_code)]
fn with_fp_registry(mut self, fp_registry: Arc<FailPointRegistry>) -> Self {
self.fp_registry = fp_registry;
self
}
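    /// Main dispatch loop. The `biased` select gives cancellation the highest
    /// priority, then channel messages, then ticker messages; each received
    /// message is passed to the handler, and a handler error stops the loop.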
async fn run(&mut self) -> Result<(), SlateDBError> {
let mut tickers = self
.handler
.tickers()
.into_iter()
.map(|(dur, factory)| MessageDispatcherTicker::new(self.clock.ticker(dur), factory))
.collect::<Vec<_>>();
let mut ticker_futures: FuturesUnordered<_> =
tickers.iter_mut().map(|t| t.tick()).collect();
loop {
fail_point!(Arc::clone(&self.fp_registry), "dispatcher-run-loop", |_| {
Err(SlateDBError::Fenced)
});
tokio::select! {
biased;
_ = self.cancellation_token.cancelled() => {
break;
}
Some(message) = self.rx.recv() => {
self.handler.handle(message).await?;
},
Some((message, ticker)) = ticker_futures.next() => {
self.handler.handle(message).await?;
ticker_futures.push(ticker.tick());
},
}
}
Ok(())
}
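    /// Closes the receiver and hands any messages still queued in the channel to
    /// the handler's `cleanup`, together with the terminal result of the task.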
async fn cleanup(&mut self, result: Result<(), SlateDBError>) -> Result<(), SlateDBError> {
fail_point!(Arc::clone(&self.fp_registry), "dispatcher-cleanup", |_| {
Err(SlateDBError::Fenced)
});
self.rx.close();
let messages = futures::stream::unfold(&mut self.rx, |rx| async move {
rx.recv().await.map(|message| (message, rx))
});
self.handler.cleanup(Box::pin(messages), result).await
}
}
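/// Pairs a [`SystemClockTicker`] with a message factory. `tick` builds the next
/// message up front and resolves once the underlying clock ticker fires.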
struct MessageDispatcherTicker<'a, T: Send> {
inner: SystemClockTicker<'a>,
message_factory: Box<MessageFactory<T>>,
}
impl<'a, T: Send> MessageDispatcherTicker<'a, T> {
fn new(inner: SystemClockTicker<'a>, message_factory: Box<MessageFactory<T>>) -> Self {
Self {
inner,
message_factory,
}
}
fn tick(&mut self) -> Pin<Box<dyn Future<Output = (T, &mut Self)> + Send + '_>> {
let message = (self.message_factory)();
Box::pin(async move {
self.inner.tick().await;
(message, self)
})
}
}
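/// Implemented by background components that process messages on a dedicated task.
/// `tickers` declares periodic messages, `handle` processes each message, and
/// `cleanup` runs on shutdown with the messages that were never delivered and the
/// task's terminal result.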
#[async_trait]
pub(crate) trait MessageHandler<T: Send>: Send {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<T>>)> {
vec![]
}
async fn handle(&mut self, message: T) -> Result<(), SlateDBError>;
async fn cleanup(
&mut self,
messages: BoxStream<'async_trait, T>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError>;
}
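/// A named dispatcher future that has been registered but not yet spawned, along
/// with its cancellation token and the runtime handle it should be spawned on.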
struct MessageHandlerFuture {
name: String,
future: BoxFuture<'static, Result<(), SlateDBError>>,
token: CancellationToken,
handle: Handle,
}
impl std::fmt::Debug for MessageHandlerFuture {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MessageHandlerFuture")
.field("name", &self.name)
.finish()
}
}
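/// Registers named message handlers, spawns a dispatcher task per handler, and
/// monitors their results: task results are written to `closed_result` and to a
/// per-task result cell, and a failure cancels all remaining tasks.
///
/// Illustrative sketch of the intended usage, mirroring the tests below
/// (`MyHandler` and the channel wiring are hypothetical):
///
/// ```ignore
/// let executor = MessageHandlerExecutor::new(WatchableOnceCell::new(), clock);
/// let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
/// executor.add_handler("my-task".to_string(), Box::new(MyHandler), rx, &Handle::current())?;
/// executor.monitor_on(&Handle::current())?;
/// // ... later ...
/// executor.shutdown_task("my-task").await?;
/// ```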
#[derive(Debug)]
pub(crate) struct MessageHandlerExecutor {
futures: Mutex<Option<Vec<MessageHandlerFuture>>>,
tokens: SkipMap<String, CancellationToken>,
results: Arc<SkipMap<String, WatchableOnceCell<Result<(), SlateDBError>>>>,
closed_result: WatchableOnceCell<Result<(), SlateDBError>>,
clock: Arc<dyn SystemClock>,
#[allow(dead_code)]
fp_registry: Arc<FailPointRegistry>,
}
impl MessageHandlerExecutor {
pub(crate) fn new(
closed_result: WatchableOnceCell<Result<(), SlateDBError>>,
clock: Arc<dyn SystemClock>,
) -> Self {
Self {
futures: Mutex::new(Some(vec![])),
closed_result,
clock,
tokens: SkipMap::new(),
results: Arc::new(SkipMap::new()),
fp_registry: Arc::new(FailPointRegistry::new()),
}
}
#[allow(dead_code)]
pub(crate) fn with_fp_registry(mut self, fp_registry: Arc<FailPointRegistry>) -> Self {
self.fp_registry = fp_registry;
self
}
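    /// Registers `handler` under `name`. Fails with `BackgroundTaskExists` if the
    /// name is already registered and with `BackgroundTaskExecutorStarted` if
    /// `monitor_on` has already consumed the registered handlers. The task itself
    /// is not spawned until `monitor_on` is called.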
    #[allow(clippy::panic)]
    pub(crate) fn add_handler<T: Send + std::fmt::Debug + 'static>(
&self,
name: String,
handler: Box<dyn MessageHandler<T>>,
rx: mpsc::UnboundedReceiver<T>,
handle: &Handle,
) -> Result<(), SlateDBError> {
let token = CancellationToken::new();
let mut dispatcher = MessageDispatcher::new(handler, rx, self.clock.clone(), token.clone())
.with_fp_registry(self.fp_registry.clone());
let this_closed_result = self.closed_result.clone();
let this_name = name.clone();
#[allow(unused_variables)]
let this_fp_registry = self.fp_registry.clone();
let task_future = async move {
let run_unwind_result = AssertUnwindSafe(dispatcher.run()).catch_unwind().await;
let (run_result, run_maybe_panic) =
split_unwind_result(this_name.clone(), run_unwind_result);
if let Err(ref err) = run_result {
error!(
"background task panicked unexpectedly. [task_name={}, error={:?}, panic={:?}]",
this_name,
err,
run_maybe_panic.map(|p| panic_string(&p))
);
}
fail_point!(
this_fp_registry.clone(),
"executor-wrapper-before-write",
|_| {
panic!("failpoint: executor-wrapper-before-write");
}
);
this_closed_result.write(run_result.clone());
let final_result = this_closed_result
.reader()
.read()
.expect("error state was unexpectedly empty");
let cleanup_unwind_result = AssertUnwindSafe(dispatcher.cleanup(final_result))
.catch_unwind()
.await;
let (cleanup_result, cleanup_maybe_panic) =
split_unwind_result(this_name.clone(), cleanup_unwind_result);
if let Err(err) = cleanup_result {
error!(
"background task failed to clean up on shutdown [name={}, error={:?}, panic={:?}]",
this_name.clone(),
err,
cleanup_maybe_panic.map(|p| panic_string(&p))
);
}
run_result
};
let mut guard = self.futures.lock();
if let Some(task_definitions) = guard.as_mut() {
if task_definitions.iter().any(|t| t.name == name) {
return Err(SlateDBError::BackgroundTaskExists(name));
}
task_definitions.push(MessageHandlerFuture {
name,
future: Box::pin(task_future),
token,
handle: handle.clone(),
});
Ok(())
} else {
Err(SlateDBError::BackgroundTaskExecutorStarted)
}
}
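    /// Spawns every registered dispatcher on its handle plus a monitor task on
    /// `handle`. The monitor records each task's result, writes it to
    /// `closed_result` (first write wins), and cancels all remaining tasks when
    /// any task returns an error.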
pub(crate) fn monitor_on(&self, handle: &Handle) -> Result<JoinHandle<()>, SlateDBError> {
let mut task_definitions = {
let mut guard = self.futures.lock();
if let Some(task_definitions) = guard.take() {
task_definitions
} else {
return Err(SlateDBError::BackgroundTaskExecutorStarted);
}
};
let mut tasks = JoinMap::new();
for task_definition in task_definitions.drain(..) {
self.tokens
.insert(task_definition.name.clone(), task_definition.token.clone());
self.results
.insert(task_definition.name.clone(), WatchableOnceCell::new());
tasks.spawn_on(
task_definition.name.clone(),
task_definition.future,
&task_definition.handle,
);
}
let this_closed_result = self.closed_result.clone();
let this_results = self.results.clone();
let this_tokens = self
.tokens
.iter()
.map(|e| e.value().clone())
.collect::<Vec<_>>();
let monitor_future = async move {
while !tasks.is_empty() {
if let Some((name, join_result)) = tasks.join_next().await {
let (task_result, task_maybe_panic) =
split_join_result(name.clone(), join_result);
if let Err(ref err) = task_result {
error!(
"background task failed [name={}, error={:?}, panic={:?}]",
name,
err,
task_maybe_panic.map(|p| panic_string(&p))
);
}
this_closed_result.write(task_result.clone());
let entry = this_results
.get(&name)
.expect("result cell isn't set when expected");
let result_cell = entry.value();
result_cell.write(task_result.clone());
if task_result.is_err() {
this_tokens.iter().for_each(|t| t.cancel());
}
}
}
};
Ok(handle.spawn(monitor_future))
}
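    /// Signals the named task's cancellation token; a no-op if the name is unknown.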
pub(crate) fn cancel_task(&self, name: &str) {
if let Some(entry) = self.tokens.get(name) {
entry.value().cancel();
}
}
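    /// Waits for the named task's result as recorded by the monitor; returns
    /// `Ok(())` immediately if no such task was spawned.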
pub(crate) async fn join_task(&self, name: &str) -> Result<(), SlateDBError> {
if let Some(entry) = self.results.get(name) {
return entry.value().reader().await_value().await;
}
Ok(())
}
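    /// Cancels the named task and waits for its result.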
pub(crate) async fn shutdown_task(&self, name: &str) -> Result<(), SlateDBError> {
self.cancel_task(name);
self.join_task(name).await
}
}
#[cfg(all(test, feature = "test-util"))]
mod test {
use super::{MessageDispatcher, MessageHandler};
use crate::clock::{DefaultSystemClock, MockSystemClock, SystemClock};
use crate::dispatcher::{MessageFactory, MessageHandlerExecutor};
use crate::error::SlateDBError;
use crate::utils::WatchableOnceCell;
use fail_parallel::FailPointRegistry;
use futures::stream::BoxStream;
use futures::StreamExt;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio::task::yield_now;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone, PartialEq, Eq)]
enum TestMessage {
Channel(i32),
Tick(i32),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
Pre,
Cleanup,
}
#[derive(Clone)]
struct TestHandler {
log: Arc<Mutex<Vec<(Phase, TestMessage)>>>,
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
tickers: Vec<(Duration, u8)>,
clock: Arc<dyn SystemClock>,
clock_schedule: VecDeque<Duration>,
}
impl TestHandler {
fn new(
log: Arc<Mutex<Vec<(Phase, TestMessage)>>>,
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
clock: Arc<dyn SystemClock>,
) -> Self {
Self {
log,
cleanup_called,
tickers: vec![],
clock,
clock_schedule: VecDeque::new(),
}
}
fn add_ticker(mut self, d: Duration, id: u8) -> Self {
self.tickers.push((d, id));
self
}
fn add_clock_schedule(mut self, ts: u64) -> Self {
self.clock_schedule.push_back(Duration::from_millis(ts));
self
}
}
#[async_trait::async_trait]
impl MessageHandler<TestMessage> for TestHandler {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<TestMessage>>)> {
let mut tickers: Vec<(Duration, Box<MessageFactory<_>>)> = vec![];
for (interval, id) in self.tickers.iter() {
let id = *id as i32;
tickers.push((*interval, Box::new(move || TestMessage::Tick(id))));
}
tickers
}
async fn handle(&mut self, message: TestMessage) -> Result<(), SlateDBError> {
self.log.lock().unwrap().push((Phase::Pre, message));
if let Some(advance_duration) = self.clock_schedule.pop_front() {
self.clock.advance(advance_duration).await;
}
Ok(())
}
async fn cleanup(
&mut self,
mut messages: futures::stream::BoxStream<'async_trait, TestMessage>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError> {
self.cleanup_called.write(result);
while let Some(m) = messages.next().await {
self.log.lock().unwrap().push((Phase::Cleanup, m));
}
Ok(())
}
}
async fn wait_for_message_count(log: Arc<Mutex<Vec<(Phase, TestMessage)>>>, count: usize) {
timeout(Duration::from_secs(30), async move {
while log.lock().unwrap().len() < count {
yield_now().await;
}
})
.await
.expect("timeout waiting for message count");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_run_happy_path() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (tx, rx) = mpsc::unbounded_channel();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(5), 1)
            .add_clock_schedule(5);
        let cancellation_token = CancellationToken::new();
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
);
let join = tokio::spawn(async move { dispatcher.run().await });
tx.send(TestMessage::Channel(10)).unwrap();
wait_for_message_count(log.clone(), 2).await;
tx.send(TestMessage::Channel(20)).unwrap();
wait_for_message_count(log.clone(), 3).await;
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
let messages = log.lock().unwrap().clone();
assert_eq!(
messages,
vec![
(Phase::Pre, TestMessage::Channel(10)),
(Phase::Pre, TestMessage::Tick(1)),
(Phase::Pre, TestMessage::Channel(20))
]
);
}
#[cfg(feature = "test-util")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_executor_propagates_handler_error_drains_messages() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let cleanup_called = WatchableOnceCell::new();
let mut cleanup_reader = cleanup_called.reader();
let (tx, rx) = mpsc::unbounded_channel();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), cleanup_called.clone(), clock.clone());
let closed_result = WatchableOnceCell::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock.clone())
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "1*off->return").unwrap();
fail_parallel::cfg(fp_registry.clone(), "dispatcher-cleanup", "pause").unwrap();
task_executor
.add_handler(
"test".to_string(),
Box::new(handler),
rx,
&Handle::current(),
)
.expect("spawn failed");
task_executor
.monitor_on(&Handle::current())
.expect("failed to monitor executor");
tx.send(TestMessage::Channel(42)).unwrap();
wait_for_message_count(log.clone(), 1).await;
tx.send(TestMessage::Channel(77)).unwrap();
fail_parallel::cfg(fp_registry.clone(), "dispatcher-cleanup", "off").unwrap();
let _ = cleanup_reader.await_value().await;
let result = timeout(Duration::from_secs(30), task_executor.join_task("test"))
.await
.expect("dispatcher did not stop in time");
assert!(matches!(result, Err(SlateDBError::Fenced)));
assert!(matches!(
closed_result.reader().read(),
Some(Err(SlateDBError::Fenced))
));
assert!(matches!(
cleanup_reader.read(),
Some(Err(SlateDBError::Fenced))
));
let messages = log.lock().unwrap().clone();
assert_eq!(
messages,
vec![
(Phase::Pre, TestMessage::Channel(42)),
(Phase::Cleanup, TestMessage::Channel(77)),
]
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_prioritizes_messages_over_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (tx, rx) = mpsc::unbounded_channel();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(5), 1);
let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
clock.advance(Duration::from_millis(5)).await;
tx.send(TestMessage::Channel(99)).unwrap();
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 2).await;
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
let messages = log.lock().unwrap().clone();
assert_eq!(
messages,
vec![
(Phase::Pre, TestMessage::Channel(99)),
(Phase::Pre, TestMessage::Tick(1)),
]
);
}
#[cfg(feature = "test-util")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_supports_multiple_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (_tx, rx) = mpsc::unbounded_channel::<TestMessage>();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(5), 1)
.add_ticker(Duration::from_millis(7), 2)
            .add_clock_schedule(0)
            .add_clock_schedule(5)
            .add_clock_schedule(2)
            .add_clock_schedule(3)
            .add_clock_schedule(4)
            .add_clock_schedule(1)
            .add_clock_schedule(5)
            .add_clock_schedule(1);
        let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
assert_eq!(log.lock().unwrap().len(), 0);
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 9).await;
assert_eq!(
log.lock().unwrap().clone(),
vec![
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
            ]
);
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
assert_eq!(log.lock().unwrap().len(), 9);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_supports_overlapping_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (_tx, rx) = mpsc::unbounded_channel::<TestMessage>();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(3), 3)
.add_ticker(Duration::from_millis(5), 5)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(2)
            .add_clock_schedule(1)
            .add_clock_schedule(3)
            .add_clock_schedule(1)
            .add_clock_schedule(2)
            .add_clock_schedule(3);
        let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
assert_eq!(log.lock().unwrap().len(), 0);
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 10).await;
assert_eq!(
log.lock().unwrap().clone()[..8],
vec![
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(5)),
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(5)),
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(5)),
                (Phase::Pre, TestMessage::Tick(3)),
            ]
);
let mut last_two_ticks = log.lock().unwrap().clone()[8..].to_vec();
last_two_ticks.sort_by(|a, b| match (a.1.clone(), b.1.clone()) {
(TestMessage::Tick(a), TestMessage::Tick(b)) => a.cmp(&b),
_ => panic!("expected ticks"),
});
assert_eq!(
last_two_ticks,
vec![
(Phase::Pre, TestMessage::Tick(3)),
(Phase::Pre, TestMessage::Tick(5))
]
);
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
assert_eq!(log.lock().unwrap().len(), 10);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_supports_identical_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (_tx, rx) = mpsc::unbounded_channel::<TestMessage>();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(3), 1)
.add_ticker(Duration::from_millis(3), 2)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0);
        let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
assert_eq!(log.lock().unwrap().len(), 0);
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 16).await;
assert_eq!(
log.lock().unwrap().clone(),
vec![
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
            ]
);
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
assert_eq!(log.lock().unwrap().len(), 16);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_executor_catches_panic_from_run_loop() {
#[derive(Clone)]
struct PanicHandler {
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
}
#[async_trait::async_trait]
impl MessageHandler<u8> for PanicHandler {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<u8>>)> {
vec![]
}
async fn handle(&mut self, _message: u8) -> Result<(), SlateDBError> {
panic!("intentional panic in handler");
}
async fn cleanup(
&mut self,
mut messages: BoxStream<'async_trait, u8>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError> {
self.cleanup_called.write(result);
while messages.next().await.is_some() {}
Ok(())
}
}
let cleanup_called = WatchableOnceCell::new();
let mut cleanup_reader = cleanup_called.reader();
let (tx, rx) = mpsc::unbounded_channel::<u8>();
let clock = Arc::new(DefaultSystemClock::new());
let handler = PanicHandler { cleanup_called };
let closed_result = WatchableOnceCell::new();
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock.clone());
task_executor
.add_handler(
"test".to_string(),
Box::new(handler),
rx,
&Handle::current(),
)
.expect("failed to spawn task");
task_executor
.monitor_on(&Handle::current())
.expect("failed to monitor executor");
tx.send(1u8).unwrap();
let _ = timeout(Duration::from_secs(30), cleanup_reader.await_value())
.await
.expect("timeout waiting for cleanup result");
let result = timeout(Duration::from_secs(30), task_executor.join_task("test"))
.await
.expect("dispatcher did not stop in time");
assert!(matches!(result, Err(SlateDBError::BackgroundTaskPanic(_))));
assert!(matches!(
closed_result.reader().read(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
assert!(matches!(
cleanup_reader.read(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_executor_panic_in_wrapper() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let cleanup_called = WatchableOnceCell::new();
let cleanup_reader = cleanup_called.reader();
let (tx, rx) = mpsc::unbounded_channel::<TestMessage>();
        drop(tx);
        let clock = Arc::new(DefaultSystemClock::new());
let handler = TestHandler::new(log.clone(), cleanup_called.clone(), clock.clone());
let closed_result = WatchableOnceCell::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock.clone())
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(
fp_registry.clone(),
"executor-wrapper-before-write",
"panic",
)
.unwrap();
task_executor
.add_handler(
"test".to_string(),
Box::new(handler),
rx,
&Handle::current(),
)
.expect("failed to spawn task");
task_executor
.monitor_on(&Handle::current())
.expect("failed to monitor executor");
task_executor.cancel_task("test");
let result = timeout(Duration::from_secs(30), task_executor.join_task("test"))
.await
.expect("dispatcher did not stop in time");
assert!(matches!(result, Err(SlateDBError::BackgroundTaskPanic(_))));
assert!(matches!(
closed_result.reader().read(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
assert!(cleanup_reader.read().is_none());
}
}