use std::{
any::Any,
collections::HashMap,
future::Future,
panic::AssertUnwindSafe,
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
};
use futures_util::FutureExt;
use tokio::sync::broadcast;
use crate::{
SendOutsideWasm,
executor::{AbortHandle, spawn},
locks::RwLock,
};
/// Opaque identifier given to every task spawned through a [`TaskMonitor`].
///
/// Identifiers are drawn from one process-wide counter, so two tasks never
/// share an id, even when they were spawned by different monitors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(u64);

impl TaskId {
    /// Hands out the next unused identifier.
    fn new() -> Self {
        // One counter for the whole process; it starts at zero and only grows.
        static COUNTER: AtomicU64 = AtomicU64::new(0);

        let raw = COUNTER.fetch_add(1, Ordering::SeqCst);
        Self(raw)
    }
}

impl std::fmt::Display for TaskId {
    /// Renders the identifier as `TaskId(<number>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "TaskId({raw})")
    }
}
/// Identifying details of a background task, attached to every failure report.
#[derive(Debug, Clone)]
pub struct BackgroundTaskInfo {
    /// Identifier assigned to the task when it was spawned.
    pub id: TaskId,
    /// Name the caller supplied when spawning the task.
    pub name: String,
}
/// Why a monitored background task was reported as failed.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum BackgroundTaskFailureReason {
    /// The task panicked.
    Panic {
        /// The panic message, when the panic payload was a `&str` or a
        /// `String`; `None` for any other payload type.
        message: Option<String>,
        /// Text of a backtrace, when one was captured; `None` otherwise
        /// (always `None` on wasm targets).
        panic_backtrace: Option<String>,
    },
    /// A fallible task returned an error.
    Error {
        /// The error rendered with `to_string()`.
        error: String,
    },
    /// A task that was spawned as never-ending returned.
    EarlyTermination,
}
/// A failure report, broadcast to the receivers obtained from
/// [`TaskMonitor::subscribe`].
#[derive(Debug, Clone)]
pub struct BackgroundTaskFailure {
    /// The task that failed.
    pub task: BackgroundTaskInfo,
    /// What went wrong.
    pub reason: BackgroundTaskFailureReason,
}
/// Bookkeeping entry the [`TaskMonitor`] keeps for a task it has spawned.
#[derive(Debug)]
struct ActiveTask {
    /// Abort handle of the spawned task. Nothing in this file reads it once
    /// it has been stored, hence the leading underscore.
    _abort_handle: AbortHandle,
}
/// Capacity passed to the broadcast channel that carries failure reports.
///
/// NOTE(review): per tokio's broadcast semantics, a receiver that falls more
/// than this many reports behind should observe a lag error instead of the
/// oldest reports — confirm that 8 is enough for the expected failure rate.
const FAILURE_CHANNEL_CAPACITY: usize = 8;
#[derive(Debug)]
pub struct TaskMonitor {
failure_sender: broadcast::Sender<BackgroundTaskFailure>,
active_task_handles: Arc<RwLock<HashMap<TaskId, ActiveTask>>>,
}
impl Default for TaskMonitor {
fn default() -> Self {
Self::new()
}
}
impl TaskMonitor {
pub fn new() -> Self {
let (failure_sender, _) = broadcast::channel(FAILURE_CHANNEL_CAPACITY);
Self { failure_sender, active_task_handles: Default::default() }
}
pub fn subscribe(&self) -> broadcast::Receiver<BackgroundTaskFailure> {
self.failure_sender.subscribe()
}
pub fn spawn_infinite_task<F>(&self, name: impl Into<String>, future: F) -> BackgroundTaskHandle
where
F: Future<Output = ()> + SendOutsideWasm + 'static,
{
self.spawn_task_internal(name, future, true)
}
pub fn spawn_finite_task<F>(&self, name: impl Into<String>, future: F) -> BackgroundTaskHandle
where
F: Future<Output = ()> + SendOutsideWasm + 'static,
{
self.spawn_task_internal(name, future, false)
}
fn spawn_task_internal<F>(
&self,
name: impl Into<String>,
future: F,
runs_forever: bool,
) -> BackgroundTaskHandle
where
F: Future<Output = ()> + SendOutsideWasm + 'static,
{
let name = name.into();
let task_id = TaskId::new();
let task_info = BackgroundTaskInfo { id: task_id, name };
let intentionally_aborted = Arc::new(AtomicBool::new(false));
let active_tasks = self.active_task_handles.clone();
let failure_sender = self.failure_sender.clone();
let aborted_flag = intentionally_aborted.clone();
let wrapped = async move {
let result = AssertUnwindSafe(future).catch_unwind().await;
active_tasks.write().remove(&task_id);
if aborted_flag.load(Ordering::Acquire) {
return;
}
let failure_reason = match result {
Ok(()) => {
if runs_forever {
BackgroundTaskFailureReason::EarlyTermination
} else {
return;
}
}
Err(panic_payload) => BackgroundTaskFailureReason::Panic {
message: extract_panic_message(&panic_payload),
panic_backtrace: capture_backtrace(),
},
};
let failure = BackgroundTaskFailure { task: task_info, reason: failure_reason };
let _ = failure_sender.send(failure);
};
let join_handle = spawn(wrapped);
let abort_handle = join_handle.abort_handle();
self.active_task_handles
.write()
.insert(task_id, ActiveTask { _abort_handle: abort_handle.clone() });
BackgroundTaskHandle { abort_on_drop: false, abort_handle, intentionally_aborted }
}
pub fn spawn_fallible_task<F, E>(
&self,
name: impl Into<String>,
future: F,
) -> BackgroundTaskHandle
where
F: Future<Output = Result<(), E>> + SendOutsideWasm + 'static,
E: std::error::Error + SendOutsideWasm + 'static,
{
let name = name.into();
let task_id = TaskId::new();
let task_info = BackgroundTaskInfo { id: task_id, name };
let intentionally_aborted = Arc::new(AtomicBool::new(false));
let active_tasks = self.active_task_handles.clone();
let failure_sender = self.failure_sender.clone();
let aborted_flag = intentionally_aborted.clone();
let wrapped = async move {
let result = AssertUnwindSafe(future).catch_unwind().await;
active_tasks.write().remove(&task_id);
if aborted_flag.load(Ordering::Acquire) {
return;
}
let failure_reason = match result {
Ok(Ok(())) => {
return;
}
Ok(Err(e)) => BackgroundTaskFailureReason::Error { error: e.to_string() },
Err(panic_payload) => BackgroundTaskFailureReason::Panic {
message: extract_panic_message(&panic_payload),
panic_backtrace: capture_backtrace(),
},
};
let _ = failure_sender
.send(BackgroundTaskFailure { task: task_info, reason: failure_reason });
};
let join_handle = spawn(wrapped);
let abort_handle = join_handle.abort_handle();
self.active_task_handles
.write()
.insert(task_id, ActiveTask { _abort_handle: abort_handle.clone() });
BackgroundTaskHandle { abort_on_drop: false, abort_handle, intentionally_aborted }
}
}
/// Handle to a task spawned through a [`TaskMonitor`].
///
/// Dropping the handle leaves the task running unless
/// [`BackgroundTaskHandle::abort_on_drop`] was called on it.
#[derive(Debug)]
pub struct BackgroundTaskHandle {
    /// Handle used to cancel the underlying task.
    abort_handle: AbortHandle,
    /// Whether dropping this handle aborts the task.
    abort_on_drop: bool,
    /// Shared with the task wrapper; set before aborting so that the abort is
    /// not reported as a failure.
    intentionally_aborted: Arc<AtomicBool>,
}

impl Drop for BackgroundTaskHandle {
    /// Aborts the task, but only if the handle opted in via `abort_on_drop`.
    fn drop(&mut self) {
        if !self.abort_on_drop {
            return;
        }

        self.abort();
    }
}

impl BackgroundTaskHandle {
    /// Turns this handle into one that aborts its task when dropped.
    pub fn abort_on_drop(self) -> Self {
        let mut handle = self;
        handle.abort_on_drop = true;
        handle
    }

    /// Aborts the task and marks the abort as intentional.
    pub fn abort(&self) {
        let Self { abort_handle, intentionally_aborted, .. } = self;

        // The flag must be raised first: the task wrapper reads it to tell an
        // intentional abort apart from a failure.
        intentionally_aborted.store(true, Ordering::Release);
        abort_handle.abort();
    }

    /// Reports whether the task has finished.
    ///
    /// On wasm targets this reports whether the task has been aborted
    /// instead.
    pub fn is_finished(&self) -> bool {
        #[cfg(not(target_family = "wasm"))]
        let finished = self.abort_handle.is_finished();

        #[cfg(target_family = "wasm")]
        let finished = self.abort_handle.is_aborted();

        finished
    }
}
/// Captures a backtrace of the current call stack as text.
///
/// Returns `None` unless the standard library reports the backtrace as
/// actually captured.
#[cfg(not(target_family = "wasm"))]
fn capture_backtrace() -> Option<String> {
    use std::backtrace::{Backtrace, BacktraceStatus};

    let backtrace = Backtrace::capture();
    let was_captured = matches!(backtrace.status(), BacktraceStatus::Captured);
    was_captured.then(|| backtrace.to_string())
}

/// Backtraces are never captured on wasm targets; always returns `None`.
#[cfg(target_family = "wasm")]
fn capture_backtrace() -> Option<String> {
    Option::None
}
/// Extracts a readable message from a panic payload.
///
/// Returns the text when the payload is a `&str` or a `String`, and `None`
/// for any other payload type.
fn extract_panic_message(payload: &Box<dyn Any + Send>) -> Option<String> {
    // The boxed payload is passed through as-is on purpose, so the downcasts
    // below inspect the payload itself rather than the `Box` around it.
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::Duration,
    };

    use assert_matches::assert_matches;
    use matrix_sdk_test_macros::async_test;

    use super::{BackgroundTaskFailureReason, TaskMonitor};
    use crate::{sleep::sleep, timeout::timeout};

    /// An "infinite" task that returns must be reported as an early
    /// termination.
    #[async_test]
    async fn test_early_termination_is_reported() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let _handle = monitor.spawn_infinite_task("test_task", async {
            // Returns immediately, which an infinite task must never do.
        });

        let failure = timeout(failures.recv(), Duration::from_secs(1))
            .await
            .expect("timeout waiting for failure")
            .expect("channel closed");

        assert_eq!(failure.task.name, "test_task");
        assert_matches!(failure.reason, BackgroundTaskFailureReason::EarlyTermination);
    }

    /// A panicking task must be reported together with its panic message.
    // NOTE(review): presumably excluded on wasm because panics cannot be
    // caught there — confirm.
    #[async_test]
    #[cfg(not(target_family = "wasm"))]
    async fn test_panic_is_captured() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let _handle = monitor.spawn_infinite_task("panicking_task", async {
            panic!("test panic message");
        });

        let failure = timeout(failures.recv(), Duration::from_secs(1))
            .await
            .expect("timeout waiting for failure")
            .expect("channel closed");

        assert_eq!(failure.task.name, "panicking_task");
        assert_matches!(
            failure.reason,
            BackgroundTaskFailureReason::Panic { message, .. } => {
                assert_eq!(message.as_deref(), Some("test panic message"));
            }
        );
    }

    /// A fallible task returning `Err` must be reported with the error text.
    #[async_test]
    async fn test_error_is_captured() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let _handle = monitor.spawn_fallible_task("fallible_task", async {
            Err::<(), _>(std::io::Error::other("test error message"))
        });

        let failure = timeout(failures.recv(), Duration::from_secs(1))
            .await
            .expect("timeout waiting for failure")
            .expect("channel closed");

        assert_eq!(failure.task.name, "fallible_task");
        assert_matches!(
            failure.reason,
            BackgroundTaskFailureReason::Error { error } => {
                assert!(error.contains("test error message"));
            }
        );
    }

    /// A fallible task returning `Ok(())` must not produce any report.
    #[async_test]
    async fn test_successful_fallible_task_no_failure() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let _handle =
            monitor.spawn_fallible_task("success_task", async { Ok::<(), std::io::Error>(()) });

        // No report is expected, so waiting for one has to time out.
        let result = timeout(failures.recv(), Duration::from_millis(100)).await;
        assert!(result.is_err(), "should timeout, no failure expected");
    }

    /// Aborting through the handle must not be reported as a failure.
    #[async_test]
    async fn test_abort_does_not_report_failure() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let handle = monitor.spawn_infinite_task("aborted_task", async {
            loop {
                sleep(Duration::from_secs(10)).await;
            }
        });

        // Give the task a chance to start before aborting it.
        sleep(Duration::from_millis(10)).await;
        handle.abort();

        let result = timeout(failures.recv(), Duration::from_millis(100)).await;
        assert!(result.is_err(), "should timeout, no failure expected for abort");
        assert!(handle.is_finished(), "task should be finished after abort");
    }

    /// Dropping an abort-on-drop handle aborts the task without a report.
    #[async_test]
    async fn test_abort_on_drop_does_not_report_failure() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let handle = monitor
            .spawn_infinite_task("aborted_task", async {
                loop {
                    sleep(Duration::from_secs(10)).await;
                }
            })
            .abort_on_drop();

        // Give the task a chance to start before dropping its handle.
        sleep(Duration::from_millis(10)).await;
        drop(handle);

        let result = timeout(failures.recv(), Duration::from_millis(100)).await;
        assert!(result.is_err(), "should timeout, no failure expected for abort");
    }

    /// A finite task that runs to completion must not produce any report.
    #[async_test]
    async fn test_spawn_finite_task() {
        let monitor = TaskMonitor::new();
        let mut failures = monitor.subscribe();

        let successful_completion = Arc::new(AtomicBool::new(false));
        let successful_completion_clone = successful_completion.clone();

        let _handle = monitor.spawn_finite_task("one-shot job", async move {
            sleep(Duration::from_millis(10)).await;
            successful_completion_clone.store(true, Ordering::SeqCst);
        });

        // Wait longer than the job takes so that it has completed.
        sleep(Duration::from_millis(20)).await;

        let result = timeout(failures.recv(), Duration::from_millis(100)).await;
        assert!(
            result.is_err(),
            "should timeout, no failure expected for a finite task that completed"
        );
        assert!(
            successful_completion.load(Ordering::SeqCst),
            "background job should have completed successfully"
        );
    }
}