use std::{
collections::HashMap, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
time::Duration,
};
use async_trait::async_trait;
use crossbeam_skiplist::SkipMap;
use fail_parallel::{fail_point, FailPointRegistry};
use futures::{
future::BoxFuture,
stream::{BoxStream, FuturesUnordered},
FutureExt, StreamExt,
};
use log::error;
use parking_lot::Mutex;
use tokio::{runtime::Handle, task::JoinHandle};
use tokio_util::{sync::CancellationToken, task::JoinMap};
use crate::{
db_status::ClosedResultWriter,
error::SlateDBError,
utils::{panic_string, split_join_result, split_unwind_result, WatchableOnceCell},
};
use slatedb_common::clock::{SystemClock, SystemClockTicker};
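/// Factory closure that produces a fresh message of type `T`; tickers call it
/// each time their interval elapses to generate the message to dispatch.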
pub(crate) type MessageFactory<T> = dyn Fn() -> T + Send;
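/// Drives a single [`MessageHandler`]: pulls messages from an async channel and
/// from the handler's periodic tickers and forwards them to the handler until
/// the cancellation token fires or the handler returns an error.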
struct MessageDispatcher<T: Send + std::fmt::Debug> {
handler: Box<dyn MessageHandler<T>>,
rx: async_channel::Receiver<T>,
clock: Arc<dyn SystemClock>,
cancellation_token: CancellationToken,
#[allow(dead_code)]
fp_registry: Arc<FailPointRegistry>,
}
impl<T: Send + std::fmt::Debug> MessageDispatcher<T> {
#[allow(dead_code)]
fn new(
handler: Box<dyn MessageHandler<T>>,
rx: async_channel::Receiver<T>,
clock: Arc<dyn SystemClock>,
cancellation_token: CancellationToken,
) -> Self {
Self {
handler,
rx,
clock,
cancellation_token,
fp_registry: Arc::new(FailPointRegistry::new()),
}
}
#[allow(dead_code)]
fn with_fp_registry(mut self, fp_registry: Arc<FailPointRegistry>) -> Self {
self.fp_registry = fp_registry;
self
}
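    /// Runs the dispatch loop. The `select!` is biased so cancellation is
    /// observed first, channel messages next, and ticker messages last.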
async fn run(&mut self) -> Result<(), SlateDBError> {
let mut tickers = self
.handler
.tickers()
.into_iter()
.map(|(dur, factory)| MessageDispatcherTicker::new(self.clock.ticker(dur), factory))
.collect::<Vec<_>>();
let mut ticker_futures: FuturesUnordered<_> =
tickers.iter_mut().map(|t| t.tick()).collect();
loop {
fail_point!(Arc::clone(&self.fp_registry), "dispatcher-run-loop", |_| {
Err(SlateDBError::Fenced)
});
tokio::select! {
biased;
_ = self.cancellation_token.cancelled() => {
break;
}
Ok(message) = self.rx.recv() => {
self.handler.handle(message).await?;
},
Some((message, ticker)) = ticker_futures.next() => {
self.handler.handle(message).await?;
ticker_futures.push(ticker.tick());
},
}
}
Ok(())
}
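    /// Runs the dispatcher to completion, catching panics from both the run
    /// loop and cleanup. The run result is recorded in `closed_result` before
    /// cleanup executes, so cleanup observes the first result written by any task.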
    #[allow(clippy::panic)]
    async fn run_lifecycle(
mut self,
name: String,
closed_result: Arc<dyn ClosedResultWriter>,
#[allow(unused_variables)] fp_registry: Arc<FailPointRegistry>,
) -> Result<(), SlateDBError> {
let run_unwind_result = AssertUnwindSafe(self.run()).catch_unwind().await;
let (run_result, run_maybe_panic) = split_unwind_result(name.clone(), run_unwind_result);
if let Err(ref err) = run_result {
error!(
"background task panicked unexpectedly. [task_name={}, error={:?}, panic={:?}]",
name,
err,
run_maybe_panic.map(|p| panic_string(&p))
);
}
fail_point!(fp_registry.clone(), "executor-wrapper-before-write", |_| {
panic!("failpoint: executor-wrapper-before-write");
});
closed_result.write_result(run_result.clone());
let final_result = closed_result
.result_reader()
.read()
.expect("error state was unexpectedly empty");
let cleanup_unwind_result = AssertUnwindSafe(self.cleanup(final_result))
.catch_unwind()
.await;
let (cleanup_result, cleanup_maybe_panic) =
split_unwind_result(name.clone(), cleanup_unwind_result);
if let Err(err) = cleanup_result {
error!(
"background task failed to clean up on shutdown [name={}, error={:?}, panic={:?}]",
name,
err,
cleanup_maybe_panic.map(|p| panic_string(&p))
);
}
run_result
}
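    /// Closes the receive channel and hands any undelivered messages, along with
    /// the final result, to the handler's `cleanup` so they can be drained.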
async fn cleanup(&mut self, result: Result<(), SlateDBError>) -> Result<(), SlateDBError> {
fail_point!(Arc::clone(&self.fp_registry), "dispatcher-cleanup", |_| {
Err(SlateDBError::Fenced)
});
self.rx.close();
let messages = futures::stream::unfold(&mut self.rx, |rx| async move {
rx.recv().await.ok().map(|message| (message, rx))
});
self.handler.cleanup(Box::pin(messages), result).await
}
}
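/// Pairs a [`SystemClockTicker`] with a message factory so that each elapsed
/// interval yields a message for the dispatcher to handle.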
struct MessageDispatcherTicker<'a, T: Send> {
inner: SystemClockTicker<'a>,
message_factory: Box<MessageFactory<T>>,
}
impl<'a, T: Send> MessageDispatcherTicker<'a, T> {
fn new(inner: SystemClockTicker<'a>, message_factory: Box<MessageFactory<T>>) -> Self {
Self {
inner,
message_factory,
}
}
fn tick(&mut self) -> Pin<Box<dyn Future<Output = (T, &mut Self)> + Send + '_>> {
let message = (self.message_factory)();
Box::pin(async move {
self.inner.tick().await;
(message, self)
})
}
}
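/// A background component that processes messages of type `T`. Implementors may
/// declare periodic tickers, handle individual messages, and drain any
/// undelivered messages during cleanup when the dispatcher shuts down.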
#[async_trait]
pub(crate) trait MessageHandler<T: Send>: Send {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<T>>)> {
vec![]
}
async fn handle(&mut self, message: T) -> Result<(), SlateDBError>;
async fn cleanup(
&mut self,
messages: BoxStream<'async_trait, T>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError>;
}
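/// A dispatcher future that has been registered but not yet spawned, together
/// with its task name, position within its group, cancellation token, and the
/// runtime handle it will be spawned on.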
struct MessageHandlerFuture {
name: String,
group_index: usize,
future: BoxFuture<'static, Result<(), SlateDBError>>,
token: CancellationToken,
handle: Handle,
}
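/// Tracks the number of outstanding members in a named task group and retains
/// the first non-`Ok` result reported by any member.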
struct TaskGroup {
remaining: usize,
result: Result<(), SlateDBError>,
}
impl TaskGroup {
fn new(count: usize) -> Self {
Self {
remaining: count,
result: Ok(()),
}
}
fn member_completed(&mut self, result: Result<(), SlateDBError>) -> bool {
if self.result.is_ok() {
self.result = result;
}
self.remaining -= 1;
self.remaining == 0
}
}
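/// Awaits spawned dispatcher tasks. The first failure is written to the shared
/// closed result and cancels every registered token; once the final member of a
/// group completes, the group's result is published to its watchable cell.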
struct TaskMonitor {
tasks: JoinMap<(String, usize), Result<(), SlateDBError>>,
groups: HashMap<String, TaskGroup>,
closed_result: Arc<dyn ClosedResultWriter>,
results: Arc<SkipMap<String, WatchableOnceCell<Result<(), SlateDBError>>>>,
tokens: Vec<CancellationToken>,
}
impl TaskMonitor {
async fn run(mut self) {
while !self.tasks.is_empty() {
if let Some(((group_name, _), join_result)) = self.tasks.join_next().await {
let (task_result, task_maybe_panic) =
split_join_result(group_name.clone(), join_result);
if let Err(ref err) = task_result {
error!(
"background task failed [name={}, error={:?}, panic={:?}]",
group_name,
err,
task_maybe_panic.map(|p| panic_string(&p))
);
self.closed_result.write_result(task_result.clone());
self.tokens.iter().for_each(|t| t.cancel());
}
let group = self
.groups
.get_mut(&group_name)
.expect("group tracking entry missing");
if group.member_completed(task_result) {
let group_result = self
.groups
.remove(&group_name)
.expect("group tracking entry missing on final member")
.result;
self.results
.get(&group_name)
.expect("result cell isn't set when expected")
.value()
.write(group_result);
}
}
}
}
}
impl std::fmt::Debug for MessageHandlerFuture {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MessageHandlerFuture")
.field("name", &self.name)
.finish()
}
}
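/// Registers message handlers as named background tasks, spawns their
/// dispatchers onto Tokio runtimes, and exposes cancel/join/shutdown operations
/// keyed by task name.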
#[derive(Debug)]
pub(crate) struct MessageHandlerExecutor {
futures: Mutex<Option<Vec<MessageHandlerFuture>>>,
tokens: SkipMap<String, CancellationToken>,
results: Arc<SkipMap<String, WatchableOnceCell<Result<(), SlateDBError>>>>,
closed_result: Arc<dyn ClosedResultWriter>,
clock: Arc<dyn SystemClock>,
#[allow(dead_code)]
fp_registry: Arc<FailPointRegistry>,
}
impl MessageHandlerExecutor {
pub(crate) fn new(
closed_result: Arc<dyn ClosedResultWriter>,
clock: Arc<dyn SystemClock>,
) -> Self {
Self {
futures: Mutex::new(Some(vec![])),
closed_result,
clock,
tokens: SkipMap::new(),
results: Arc::new(SkipMap::new()),
fp_registry: Arc::new(FailPointRegistry::new()),
}
}
#[allow(dead_code)]
pub(crate) fn with_fp_registry(mut self, fp_registry: Arc<FailPointRegistry>) -> Self {
self.fp_registry = fp_registry;
self
}
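    /// Registers a single handler as the sole member of the named task group.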
pub(crate) fn add_handler<T: Send + std::fmt::Debug + 'static>(
&self,
name: String,
handler: Box<dyn MessageHandler<T>>,
rx: async_channel::Receiver<T>,
handle: &Handle,
) -> Result<(), SlateDBError> {
self.add_handlers(name, vec![handler], rx, handle)
}
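    /// Registers a group of handlers that share one receive channel, each driven
    /// by its own dispatcher. Fails if a task with the same name already exists
    /// or if `monitor_on` has already consumed the pending task definitions.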
pub(crate) fn add_handlers<T: Send + std::fmt::Debug + 'static>(
&self,
name: String,
handlers: Vec<Box<dyn MessageHandler<T>>>,
rx: async_channel::Receiver<T>,
handle: &Handle,
) -> Result<(), SlateDBError> {
assert!(!handlers.is_empty(), "handlers must not be empty");
let token = CancellationToken::new();
let mut futures = Vec::with_capacity(handlers.len());
for (group_index, handler) in handlers.into_iter().enumerate() {
let dispatcher =
MessageDispatcher::new(handler, rx.clone(), self.clock.clone(), token.clone())
.with_fp_registry(self.fp_registry.clone());
let future = dispatcher.run_lifecycle(
name.clone(),
self.closed_result.clone(),
self.fp_registry.clone(),
);
futures.push(MessageHandlerFuture {
name: name.clone(),
group_index,
future: Box::pin(future),
token: token.clone(),
handle: handle.clone(),
});
}
let mut guard = self.futures.lock();
if let Some(task_definitions) = guard.as_mut() {
if task_definitions.iter().any(|t| t.name == name) {
return Err(SlateDBError::BackgroundTaskExists(name));
}
task_definitions.extend(futures);
Ok(())
} else {
Err(SlateDBError::BackgroundTaskExecutorStarted)
}
}
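    /// Spawns every registered dispatcher on its configured runtime handle and
    /// starts a [`TaskMonitor`] on `handle` to track their results. May only be
    /// called once; subsequent calls return `BackgroundTaskExecutorStarted`.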
pub(crate) fn monitor_on(&self, handle: &Handle) -> Result<JoinHandle<()>, SlateDBError> {
let mut task_definitions = {
let mut guard = self.futures.lock();
if let Some(task_definitions) = guard.take() {
task_definitions
} else {
return Err(SlateDBError::BackgroundTaskExecutorStarted);
}
};
let mut tasks = JoinMap::new();
let mut groups: HashMap<String, TaskGroup> = HashMap::new();
for task_definition in task_definitions.drain(..) {
if !self.tokens.contains_key(&task_definition.name) {
self.tokens
.insert(task_definition.name.clone(), task_definition.token.clone());
self.results
.insert(task_definition.name.clone(), WatchableOnceCell::new());
}
groups
.entry(task_definition.name.clone())
.or_insert_with(|| TaskGroup::new(0))
.remaining += 1;
tasks.spawn_on(
(task_definition.name.clone(), task_definition.group_index),
task_definition.future,
&task_definition.handle,
);
}
let monitor = TaskMonitor {
tasks,
groups,
closed_result: self.closed_result.clone(),
results: self.results.clone(),
tokens: self.tokens.iter().map(|e| e.value().clone()).collect(),
};
Ok(handle.spawn(monitor.run()))
}
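    /// Cancels the named task group's cancellation token, if it exists.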
pub(crate) fn cancel_task(&self, name: &str) {
if let Some(entry) = self.tokens.get(name) {
entry.value().cancel();
}
}
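    /// Waits for the named task group to finish and returns its result. Returns
    /// `Ok(())` immediately if no such task was registered.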
pub(crate) async fn join_task(&self, name: &str) -> Result<(), SlateDBError> {
if let Some(entry) = self.results.get(name) {
return entry.value().reader().await_value().await;
}
Ok(())
}
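    /// Cancels the named task group and then waits for it to complete.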
pub(crate) async fn shutdown_task(&self, name: &str) -> Result<(), SlateDBError> {
self.cancel_task(name);
self.join_task(name).await
}
}
#[cfg(all(test, feature = "test-util"))]
mod test {
use super::{MessageDispatcher, MessageHandler};
use crate::db_status::ClosedResultWriter;
use crate::dispatcher::{MessageFactory, MessageHandlerExecutor};
use crate::error::SlateDBError;
use crate::utils::WatchableOnceCell;
use fail_parallel::FailPointRegistry;
use futures::stream::BoxStream;
use futures::StreamExt;
use slatedb_common::clock::{DefaultSystemClock, MockSystemClock, SystemClock};
use std::collections::{HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::task::yield_now;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone, PartialEq, Eq)]
enum TestMessage {
Channel(i32),
Tick(i32),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
Pre,
Cleanup,
}
#[derive(Clone)]
struct TestHandler {
log: Arc<Mutex<Vec<(Phase, TestMessage)>>>,
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
tickers: Vec<(Duration, u8)>,
clock: Arc<dyn SystemClock>,
clock_schedule: VecDeque<Duration>,
}
impl TestHandler {
fn new(
log: Arc<Mutex<Vec<(Phase, TestMessage)>>>,
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
clock: Arc<dyn SystemClock>,
) -> Self {
Self {
log,
cleanup_called,
tickers: vec![],
clock,
clock_schedule: VecDeque::new(),
}
}
fn add_ticker(mut self, d: Duration, id: u8) -> Self {
self.tickers.push((d, id));
self
}
fn add_clock_schedule(mut self, ts: u64) -> Self {
self.clock_schedule.push_back(Duration::from_millis(ts));
self
}
}
#[async_trait::async_trait]
impl MessageHandler<TestMessage> for TestHandler {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<TestMessage>>)> {
let mut tickers: Vec<(Duration, Box<MessageFactory<_>>)> = vec![];
for (interval, id) in self.tickers.iter() {
let id = *id as i32;
tickers.push((*interval, Box::new(move || TestMessage::Tick(id))));
}
tickers
}
async fn handle(&mut self, message: TestMessage) -> Result<(), SlateDBError> {
self.log.lock().unwrap().push((Phase::Pre, message));
if let Some(advance_duration) = self.clock_schedule.pop_front() {
self.clock.advance(advance_duration).await;
}
Ok(())
}
async fn cleanup(
&mut self,
mut messages: futures::stream::BoxStream<'async_trait, TestMessage>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError> {
self.cleanup_called.write(result);
while let Some(m) = messages.next().await {
self.log.lock().unwrap().push((Phase::Cleanup, m));
}
Ok(())
}
}
async fn wait_for_message_count(log: Arc<Mutex<Vec<(Phase, TestMessage)>>>, count: usize) {
timeout(Duration::from_secs(30), async move {
while log.lock().unwrap().len() < count {
yield_now().await;
}
})
.await
.expect("timeout waiting for message count");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_run_happy_path() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (tx, rx) = async_channel::unbounded();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(5), 1)
            .add_clock_schedule(5);
        let cancellation_token = CancellationToken::new();
let fp = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp.clone());
fail_parallel::cfg(fp.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
tx.try_send(TestMessage::Channel(10)).unwrap();
fail_parallel::cfg(fp.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 2).await;
tx.try_send(TestMessage::Channel(20)).unwrap();
wait_for_message_count(log.clone(), 3).await;
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
let messages = log.lock().unwrap().clone();
assert_eq!(
messages,
vec![
(Phase::Pre, TestMessage::Channel(10)),
(Phase::Pre, TestMessage::Tick(1)),
(Phase::Pre, TestMessage::Channel(20))
]
);
}
#[cfg(feature = "test-util")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_executor_propagates_handler_error_drains_messages() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let cleanup_called = WatchableOnceCell::new();
let mut cleanup_reader = cleanup_called.reader();
let (tx, rx) = async_channel::unbounded();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), cleanup_called.clone(), clock.clone());
let closed_result = Arc::new(WatchableOnceCell::new());
let fp_registry = Arc::new(FailPointRegistry::default());
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock.clone())
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "1*off->return").unwrap();
fail_parallel::cfg(fp_registry.clone(), "dispatcher-cleanup", "pause").unwrap();
task_executor
.add_handler(
"test".to_string(),
Box::new(handler),
rx,
&Handle::current(),
)
.expect("spawn failed");
task_executor
.monitor_on(&Handle::current())
.expect("failed to monitor executor");
tx.try_send(TestMessage::Channel(42)).unwrap();
wait_for_message_count(log.clone(), 1).await;
tx.try_send(TestMessage::Channel(77)).unwrap();
fail_parallel::cfg(fp_registry.clone(), "dispatcher-cleanup", "off").unwrap();
let _ = cleanup_reader.await_value().await;
let result = timeout(Duration::from_secs(30), task_executor.join_task("test"))
.await
.expect("dispatcher did not stop in time");
assert!(matches!(result, Err(SlateDBError::Fenced)));
assert!(matches!(
closed_result.result_reader().read(),
Some(Err(SlateDBError::Fenced))
));
assert!(matches!(
cleanup_reader.read(),
Some(Err(SlateDBError::Fenced))
));
let messages = log.lock().unwrap().clone();
assert_eq!(
messages,
vec![
(Phase::Pre, TestMessage::Channel(42)),
(Phase::Cleanup, TestMessage::Channel(77)),
]
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_prioritizes_messages_over_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (tx, rx) = async_channel::unbounded();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(5), 1);
let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
clock.advance(Duration::from_millis(5)).await;
tx.try_send(TestMessage::Channel(99)).unwrap();
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 2).await;
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
let messages = log.lock().unwrap().clone();
assert_eq!(
messages,
vec![
(Phase::Pre, TestMessage::Channel(99)),
(Phase::Pre, TestMessage::Tick(1)),
]
);
}
#[cfg(feature = "test-util")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_supports_multiple_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (_tx, rx) = async_channel::unbounded::<TestMessage>();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(5), 1)
.add_ticker(Duration::from_millis(7), 2)
            .add_clock_schedule(0)
            .add_clock_schedule(5)
            .add_clock_schedule(2)
            .add_clock_schedule(3)
            .add_clock_schedule(4)
            .add_clock_schedule(1)
            .add_clock_schedule(5)
            .add_clock_schedule(1);
        let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
assert_eq!(log.lock().unwrap().len(), 0);
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 9).await;
assert_eq!(
log.lock().unwrap().clone(),
vec![
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
            ]
);
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
assert_eq!(log.lock().unwrap().len(), 9);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_supports_overlapping_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (_tx, rx) = async_channel::unbounded::<TestMessage>();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(3), 3)
.add_ticker(Duration::from_millis(5), 5)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(2)
            .add_clock_schedule(1)
            .add_clock_schedule(3)
            .add_clock_schedule(1)
            .add_clock_schedule(2)
            .add_clock_schedule(3);
        let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
assert_eq!(log.lock().unwrap().len(), 0);
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 10).await;
assert_eq!(
log.lock().unwrap().clone()[..8],
vec![
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(5)),
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(5)),
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(3)),
                (Phase::Pre, TestMessage::Tick(5)),
                (Phase::Pre, TestMessage::Tick(3)),
            ]
);
let mut last_two_ticks = log.lock().unwrap().clone()[8..].to_vec();
last_two_ticks.sort_by(|a, b| match (a.1.clone(), b.1.clone()) {
(TestMessage::Tick(a), TestMessage::Tick(b)) => a.cmp(&b),
_ => panic!("expected ticks"),
});
assert_eq!(
last_two_ticks,
vec![
(Phase::Pre, TestMessage::Tick(3)),
(Phase::Pre, TestMessage::Tick(5))
]
);
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
assert_eq!(log.lock().unwrap().len(), 10);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_dispatcher_supports_identical_tickers() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (_tx, rx) = async_channel::unbounded::<TestMessage>();
let clock = Arc::new(MockSystemClock::new());
let handler = TestHandler::new(log.clone(), WatchableOnceCell::new(), clock.clone())
.add_ticker(Duration::from_millis(3), 1)
.add_ticker(Duration::from_millis(3), 2)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0)
            .add_clock_schedule(3)
            .add_clock_schedule(0);
        let cancellation_token = CancellationToken::new();
let fp_registry = Arc::new(FailPointRegistry::default());
let mut dispatcher = MessageDispatcher::new(
Box::new(handler),
rx,
clock.clone(),
cancellation_token.clone(),
)
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
let join = tokio::spawn(async move { dispatcher.run().await });
assert_eq!(log.lock().unwrap().len(), 0);
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
wait_for_message_count(log.clone(), 16).await;
assert_eq!(
log.lock().unwrap().clone(),
vec![
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
                (Phase::Pre, TestMessage::Tick(1)),
                (Phase::Pre, TestMessage::Tick(2)),
            ]
);
cancellation_token.cancel();
let result = timeout(Duration::from_secs(30), join)
.await
.expect("dispatcher did not stop in time")
.expect("join failed");
assert!(matches!(result, Ok(())));
assert_eq!(log.lock().unwrap().len(), 16);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_executor_catches_panic_from_run_loop() {
#[derive(Clone)]
struct PanicHandler {
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
}
#[async_trait::async_trait]
impl MessageHandler<u8> for PanicHandler {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<u8>>)> {
vec![]
}
async fn handle(&mut self, _message: u8) -> Result<(), SlateDBError> {
panic!("intentional panic in handler");
}
async fn cleanup(
&mut self,
mut messages: BoxStream<'async_trait, u8>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError> {
self.cleanup_called.write(result);
while messages.next().await.is_some() {}
Ok(())
}
}
let cleanup_called = WatchableOnceCell::new();
let mut cleanup_reader = cleanup_called.reader();
let (tx, rx) = async_channel::unbounded::<u8>();
let clock = Arc::new(DefaultSystemClock::new());
let handler = PanicHandler { cleanup_called };
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock.clone());
task_executor
.add_handler(
"test".to_string(),
Box::new(handler),
rx,
&Handle::current(),
)
.expect("failed to spawn task");
task_executor
.monitor_on(&Handle::current())
.expect("failed to monitor executor");
tx.try_send(1u8).unwrap();
let _ = timeout(Duration::from_secs(30), cleanup_reader.await_value())
.await
.expect("timeout waiting for cleanup result");
let result = timeout(Duration::from_secs(30), task_executor.join_task("test"))
.await
.expect("dispatcher did not stop in time");
assert!(matches!(result, Err(SlateDBError::BackgroundTaskPanic(_))));
assert!(matches!(
closed_result.result_reader().read(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
assert!(matches!(
cleanup_reader.read(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_executor_panic_in_wrapper() {
let log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let cleanup_called = WatchableOnceCell::new();
let cleanup_reader = cleanup_called.reader();
let (tx, rx) = async_channel::unbounded::<TestMessage>();
        drop(tx);
        let clock = Arc::new(DefaultSystemClock::new());
let handler = TestHandler::new(log.clone(), cleanup_called.clone(), clock.clone());
let closed_result = Arc::new(WatchableOnceCell::new());
let fp_registry = Arc::new(FailPointRegistry::default());
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock.clone())
.with_fp_registry(fp_registry.clone());
fail_parallel::cfg(
fp_registry.clone(),
"executor-wrapper-before-write",
"panic",
)
.unwrap();
task_executor
.add_handler(
"test".to_string(),
Box::new(handler),
rx,
&Handle::current(),
)
.expect("failed to spawn task");
task_executor
.monitor_on(&Handle::current())
.expect("failed to monitor executor");
task_executor.cancel_task("test");
let result = timeout(Duration::from_secs(30), task_executor.join_task("test"))
.await
.expect("dispatcher did not stop in time");
assert!(matches!(result, Err(SlateDBError::BackgroundTaskPanic(_))));
assert!(matches!(
closed_result.result_reader().read(),
Some(Err(SlateDBError::BackgroundTaskPanic(_)))
));
assert!(cleanup_reader.read().is_none());
}
struct ParallelTestHandler {
handler_id: u8,
log: Arc<Mutex<Vec<(u8, TestMessage)>>>,
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
block_on: Option<Arc<tokio::sync::Notify>>,
error_after: Option<usize>,
handled_count: usize,
tickers: Vec<(Duration, u8)>,
}
impl ParallelTestHandler {
fn new(
handler_id: u8,
log: Arc<Mutex<Vec<(u8, TestMessage)>>>,
cleanup_called: WatchableOnceCell<Result<(), SlateDBError>>,
) -> Self {
Self {
handler_id,
log,
cleanup_called,
block_on: None,
error_after: None,
handled_count: 0,
tickers: vec![],
}
}
fn with_block_on(mut self, notify: Arc<tokio::sync::Notify>) -> Self {
self.block_on = Some(notify);
self
}
fn with_error_after(mut self, n: usize) -> Self {
self.error_after = Some(n);
self
}
fn add_ticker(mut self, d: Duration, id: u8) -> Self {
self.tickers.push((d, id));
self
}
}
#[async_trait::async_trait]
impl MessageHandler<TestMessage> for ParallelTestHandler {
fn tickers(&mut self) -> Vec<(Duration, Box<MessageFactory<TestMessage>>)> {
self.tickers
.iter()
.map(|(interval, id)| {
let id = *id as i32;
(
*interval,
Box::new(move || TestMessage::Tick(id)) as Box<MessageFactory<TestMessage>>,
)
})
.collect()
}
async fn handle(&mut self, message: TestMessage) -> Result<(), SlateDBError> {
self.log.lock().unwrap().push((self.handler_id, message));
if let Some(notify) = &self.block_on {
notify.notified().await;
}
self.handled_count += 1;
if let Some(n) = self.error_after {
if self.handled_count > n {
return Err(SlateDBError::Fenced);
}
}
Ok(())
}
async fn cleanup(
&mut self,
mut messages: BoxStream<'async_trait, TestMessage>,
result: Result<(), SlateDBError>,
) -> Result<(), SlateDBError> {
self.cleanup_called.write(result);
while let Some(m) = messages.next().await {
self.log.lock().unwrap().push((self.handler_id, m));
}
Ok(())
}
}
async fn wait_for_parallel_log_count(log: Arc<Mutex<Vec<(u8, TestMessage)>>>, count: usize) {
timeout(Duration::from_secs(30), async move {
while log.lock().unwrap().len() < count {
yield_now().await;
}
})
.await
.expect("timeout waiting for parallel log count");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_parallel_handlers_both_members_process_messages() {
let log = Arc::new(Mutex::new(Vec::<(u8, TestMessage)>::new()));
let (tx, rx) = async_channel::unbounded();
let clock = Arc::new(DefaultSystemClock::new());
let block1 = Arc::new(tokio::sync::Notify::new());
let block2 = Arc::new(tokio::sync::Notify::new());
let handler1 = ParallelTestHandler::new(1, log.clone(), WatchableOnceCell::new())
.with_block_on(block1.clone());
let handler2 = ParallelTestHandler::new(2, log.clone(), WatchableOnceCell::new())
.with_block_on(block2.clone());
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor = MessageHandlerExecutor::new(closed_result, clock);
task_executor
.add_handlers(
"parallel".to_string(),
vec![Box::new(handler1), Box::new(handler2)],
rx,
&Handle::current(),
)
.expect("failed to add handlers");
task_executor.monitor_on(&Handle::current()).unwrap();
tx.try_send(TestMessage::Channel(1)).unwrap();
tx.try_send(TestMessage::Channel(2)).unwrap();
wait_for_parallel_log_count(log.clone(), 2).await;
let entries = log.lock().unwrap().clone();
let handler_ids: HashSet<u8> = entries.iter().map(|(id, _)| *id).collect();
assert!(
handler_ids.contains(&1) && handler_ids.contains(&2),
"expected both handlers to process messages, got: {:?}",
entries,
);
block1.notify_waiters();
block2.notify_waiters();
task_executor.shutdown_task("parallel").await.ok();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_parallel_handlers_independent_tickers() {
let log = Arc::new(Mutex::new(Vec::<(u8, TestMessage)>::new()));
let (_tx, rx) = async_channel::unbounded::<TestMessage>();
let clock = Arc::new(DefaultSystemClock::new());
let handler1 = ParallelTestHandler::new(1, log.clone(), WatchableOnceCell::new())
.add_ticker(Duration::from_millis(1), 10);
let handler2 = ParallelTestHandler::new(2, log.clone(), WatchableOnceCell::new())
.add_ticker(Duration::from_millis(1), 20);
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor = MessageHandlerExecutor::new(closed_result, clock);
task_executor
.add_handlers(
"parallel".to_string(),
vec![Box::new(handler1), Box::new(handler2)],
rx,
&Handle::current(),
)
.expect("failed to add handlers");
task_executor.monitor_on(&Handle::current()).unwrap();
wait_for_parallel_log_count(log.clone(), 4).await;
let entries = log.lock().unwrap().clone();
let h1_ticks: Vec<_> = entries
.iter()
.filter(|(id, _)| *id == 1)
.map(|(_, m)| m.clone())
.collect();
let h2_ticks: Vec<_> = entries
.iter()
.filter(|(id, _)| *id == 2)
.map(|(_, m)| m.clone())
.collect();
assert!(
h1_ticks.iter().all(|m| *m == TestMessage::Tick(10)),
"handler 1 should only see tick(10), got: {:?}",
h1_ticks,
);
assert!(
h2_ticks.iter().all(|m| *m == TestMessage::Tick(20)),
"handler 2 should only see tick(20), got: {:?}",
h2_ticks,
);
assert!(!h1_ticks.is_empty(), "handler 1 should have ticked");
assert!(!h2_ticks.is_empty(), "handler 2 should have ticked");
task_executor.shutdown_task("parallel").await.ok();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_parallel_handlers_error_cancels_sibling_and_runs_cleanup() {
let log = Arc::new(Mutex::new(Vec::<(u8, TestMessage)>::new()));
let (tx, rx) = async_channel::unbounded::<TestMessage>();
let clock = Arc::new(DefaultSystemClock::new());
let cleanup1 = WatchableOnceCell::new();
let cleanup2 = WatchableOnceCell::new();
let mut cleanup1_reader = cleanup1.reader();
let mut cleanup2_reader = cleanup2.reader();
let handler1 = ParallelTestHandler::new(1, log.clone(), cleanup1).with_error_after(1);
let handler2 = ParallelTestHandler::new(2, log.clone(), cleanup2).with_error_after(1);
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor = MessageHandlerExecutor::new(closed_result.clone(), clock);
task_executor
.add_handlers(
"parallel".to_string(),
vec![Box::new(handler1), Box::new(handler2)],
rx,
&Handle::current(),
)
.expect("failed to add handlers");
task_executor.monitor_on(&Handle::current()).unwrap();
tx.try_send(TestMessage::Channel(1)).unwrap();
tx.try_send(TestMessage::Channel(2)).unwrap();
let _ = tx.try_send(TestMessage::Channel(3));
let _ = timeout(Duration::from_secs(5), cleanup1_reader.await_value())
.await
.expect("timeout waiting for handler 1 cleanup");
let _ = timeout(Duration::from_secs(5), cleanup2_reader.await_value())
.await
.expect("timeout waiting for handler 2 cleanup");
let result = timeout(Duration::from_secs(5), task_executor.join_task("parallel"))
.await
.expect("timeout waiting for join");
assert!(matches!(result, Err(SlateDBError::Fenced)));
assert!(matches!(
closed_result.result_reader().read(),
Some(Err(SlateDBError::Fenced))
));
assert!(cleanup1_reader.read().is_some());
assert!(cleanup2_reader.read().is_some());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_parallel_handlers_join_waits_for_all_members() {
let log = Arc::new(Mutex::new(Vec::<(u8, TestMessage)>::new()));
let (tx, rx) = async_channel::unbounded();
let clock = Arc::new(DefaultSystemClock::new());
let block1 = Arc::new(tokio::sync::Notify::new());
let block2 = Arc::new(tokio::sync::Notify::new());
let handler1 = ParallelTestHandler::new(1, log.clone(), WatchableOnceCell::new())
.with_block_on(block1.clone());
let handler2 = ParallelTestHandler::new(2, log.clone(), WatchableOnceCell::new())
.with_block_on(block2.clone());
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor = Arc::new(MessageHandlerExecutor::new(closed_result, clock));
task_executor
.add_handlers(
"parallel".to_string(),
vec![Box::new(handler1), Box::new(handler2)],
rx,
&Handle::current(),
)
.expect("failed to add handlers");
task_executor.monitor_on(&Handle::current()).unwrap();
tx.try_send(TestMessage::Channel(1)).unwrap();
tx.try_send(TestMessage::Channel(2)).unwrap();
wait_for_parallel_log_count(log.clone(), 2).await;
task_executor.cancel_task("parallel");
let join_result = timeout(
Duration::from_millis(100),
task_executor.join_task("parallel"),
)
.await;
assert!(
join_result.is_err(),
"join_task should not resolve while members are still blocked"
);
block1.notify_waiters();
block2.notify_waiters();
let result = timeout(Duration::from_secs(5), task_executor.join_task("parallel"))
.await
.expect("timeout waiting for join after unblock");
assert!(result.is_ok());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_parallel_handlers_cleanup_drains_messages() {
let log = Arc::new(Mutex::new(Vec::<(u8, TestMessage)>::new()));
let (tx, rx) = async_channel::unbounded();
let clock = Arc::new(DefaultSystemClock::new());
let fp_registry = Arc::new(FailPointRegistry::default());
let handler1 = ParallelTestHandler::new(1, log.clone(), WatchableOnceCell::new());
let handler2 = ParallelTestHandler::new(2, log.clone(), WatchableOnceCell::new());
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor =
MessageHandlerExecutor::new(closed_result, clock).with_fp_registry(fp_registry.clone());
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "pause").unwrap();
task_executor
.add_handlers(
"parallel".to_string(),
vec![Box::new(handler1), Box::new(handler2)],
rx,
&Handle::current(),
)
.expect("failed to add handlers");
task_executor.monitor_on(&Handle::current()).unwrap();
for i in 1..=5 {
tx.try_send(TestMessage::Channel(i)).unwrap();
}
fail_parallel::cfg(fp_registry.clone(), "dispatcher-run-loop", "off").unwrap();
task_executor.cancel_task("parallel");
let result = timeout(Duration::from_secs(5), task_executor.join_task("parallel"))
.await
.expect("timeout waiting for join");
assert!(result.is_ok());
let entries = log.lock().unwrap().clone();
let mut values: Vec<i32> = entries
.iter()
.map(|(_, m)| match m {
TestMessage::Channel(v) => *v,
_ => panic!("unexpected tick"),
})
.collect();
values.sort();
assert_eq!(values, vec![1, 2, 3, 4, 5]);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_parallel_group_coexists_with_sequential_task() {
let parallel_log = Arc::new(Mutex::new(Vec::<(u8, TestMessage)>::new()));
let sequential_log = Arc::new(Mutex::new(Vec::<(Phase, TestMessage)>::new()));
let (par_tx, par_rx) = async_channel::unbounded();
let (seq_tx, seq_rx) = async_channel::unbounded();
let clock = Arc::new(DefaultSystemClock::new());
let handler1 = ParallelTestHandler::new(1, parallel_log.clone(), WatchableOnceCell::new());
let handler2 = ParallelTestHandler::new(2, parallel_log.clone(), WatchableOnceCell::new());
let seq_handler = TestHandler::new(
sequential_log.clone(),
WatchableOnceCell::new(),
clock.clone(),
);
let closed_result = Arc::new(WatchableOnceCell::new());
let task_executor = MessageHandlerExecutor::new(closed_result, clock);
task_executor
.add_handlers(
"parallel".to_string(),
vec![Box::new(handler1), Box::new(handler2)],
par_rx,
&Handle::current(),
)
.expect("failed to add parallel handlers");
task_executor
.add_handler(
"sequential".to_string(),
Box::new(seq_handler),
seq_rx,
&Handle::current(),
)
.expect("failed to add sequential handler");
task_executor.monitor_on(&Handle::current()).unwrap();
par_tx.try_send(TestMessage::Channel(1)).unwrap();
par_tx.try_send(TestMessage::Channel(2)).unwrap();
seq_tx.try_send(TestMessage::Channel(100)).unwrap();
wait_for_parallel_log_count(parallel_log.clone(), 2).await;
wait_for_message_count(sequential_log.clone(), 1).await;
let par_entries = parallel_log.lock().unwrap().clone();
let mut par_values: Vec<i32> = par_entries
.iter()
.map(|(_, m)| match m {
TestMessage::Channel(v) => *v,
_ => panic!("unexpected tick"),
})
.collect();
par_values.sort();
assert_eq!(par_values, vec![1, 2]);
let seq_entries = sequential_log.lock().unwrap().clone();
assert_eq!(seq_entries.len(), 1);
assert_eq!(seq_entries[0], (Phase::Pre, TestMessage::Channel(100)));
task_executor.shutdown_task("parallel").await.ok();
task_executor.shutdown_task("sequential").await.ok();
}
}