use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use acton_reactive::prelude::{Reply, *};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use super::messages::{CancelTask, GetAllTaskStatuses, GetTaskStatus, TaskStatusResponse};
/// Serde default for `BackgroundWorkerConfig::task_shutdown_timeout_secs`.
///
/// Returns the number of seconds (5) to wait for a task to finish during
/// cancellation or shutdown when the config field is absent.
fn default_task_shutdown_timeout_secs() -> u64 {
    const DEFAULT_SECS: u64 = 5;
    DEFAULT_SECS
}
/// Configuration for [`BackgroundWorker`], deserializable from application config.
///
/// Every field has a serde default, so an empty object (`{}`) deserializes
/// to a fully-populated (disabled) config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackgroundWorkerConfig {
    /// Whether the background-worker subsystem is enabled (defaults to `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Maximum number of tasks running concurrently; `0` (the serde default)
    /// means unlimited (no semaphore is created).
    #[serde(default)]
    pub max_concurrent_tasks: usize,
    /// Seconds to wait for a task to exit during cancellation/shutdown
    /// (defaults to 5 via `default_task_shutdown_timeout_secs`).
    #[serde(default = "default_task_shutdown_timeout_secs")]
    pub task_shutdown_timeout_secs: u64,
    /// Interval for the periodic finished-task cleanup loop; `0` (the serde
    /// default) disables the loop entirely.
    #[serde(default)]
    pub cleanup_interval_secs: u64,
}
impl Default for BackgroundWorkerConfig {
    /// Disabled worker with unlimited concurrency, a 5-second shutdown
    /// timeout, and no periodic cleanup.
    fn default() -> Self {
        Self {
            enabled: false,
            max_concurrent_tasks: 0,
            // NOTE(review): this literal duplicates the value returned by
            // `default_task_shutdown_timeout_secs()` (serde's field default);
            // the two must be kept in sync manually.
            task_shutdown_timeout_secs: 5,
            cleanup_interval_secs: 0,
        }
    }
}
/// Lifecycle state of a background task.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum TaskStatus {
    /// Not yet started. Also what status queries report for unknown task ids.
    #[default]
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully (`Ok(())`).
    Completed,
    /// Finished with an error; carries the error's display string.
    Failed(String),
    /// Stopped via its cancellation token before completing.
    Cancelled,
}
/// Bookkeeping for a single spawned background task.
#[derive(Debug)]
pub(crate) struct TaskInfo {
    // Caller-supplied identifier; also the key in `BackgroundWorker::tasks`.
    task_id: String,
    // Join handle for the spawned tokio task. Wrapped in Option so whoever
    // awaits shutdown can `take()` it exactly once; Mutex because multiple
    // paths (cancel handler, before_stop hook) may race to claim it.
    join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    // Child of the worker's root token; cancelling it stops only this task.
    cancellation_token: CancellationToken,
    // Current status, shared with the task body which updates it on exit.
    status: Arc<Mutex<TaskStatus>>,
}
/// Actor model/state for the background-worker agent.
#[derive(Debug, Default)]
pub struct BackgroundWorkerState {
    // Copy of the worker's root cancellation token, stored on the actor model
    // at spawn time so the agent side can reach it if needed.
    pub(crate) root_token: Option<CancellationToken>,
}
/// Handle for submitting, querying, and cancelling background tasks.
///
/// Cheap to clone: all clones share the same task registry, root token,
/// and semaphore.
#[derive(Clone)]
pub struct BackgroundWorker {
    // Handle to the actor that services CancelTask / status-query messages.
    agent_handle: ActorHandle,
    // Live task registry keyed by task id.
    tasks: Arc<DashMap<String, TaskInfo>>,
    // Parent of every task's cancellation token; cancelling it stops all tasks.
    root_token: CancellationToken,
    // Concurrency limiter; `None` when `max_concurrent_tasks == 0` (unlimited).
    semaphore: Option<Arc<Semaphore>>,
    // How long to wait for a task to exit during cancel/shutdown.
    shutdown_timeout: Duration,
}
impl BackgroundWorker {
    /// Spawns the background-worker agent on `runtime`, wires up its message
    /// handlers (cancel + status queries) and lifecycle hooks, and optionally
    /// starts a periodic cleanup loop for finished tasks.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `anyhow::Result` leaves room for
    /// fallible setup in the future.
    pub async fn spawn(
        runtime: &mut ActorRuntime,
        config: &BackgroundWorkerConfig,
    ) -> anyhow::Result<Self> {
        // Shared registry of live tasks, keyed by caller-supplied task id.
        let tasks: Arc<DashMap<String, TaskInfo>> = Arc::new(DashMap::new());
        // Root token: every task gets a child of this, so cancelling the root
        // cancels all tasks at once.
        let root_token = CancellationToken::new();
        let shutdown_timeout = Duration::from_secs(config.task_shutdown_timeout_secs);
        // max_concurrent_tasks == 0 means "unlimited": no semaphore at all.
        let semaphore = if config.max_concurrent_tasks > 0 {
            Some(Arc::new(Semaphore::new(config.max_concurrent_tasks)))
        } else {
            None
        };
        // Separate clones per closure: each handler below is `move` and must
        // own its captures.
        let tasks_for_shutdown = tasks.clone();
        let root_token_for_agent = root_token.clone();
        let shutdown_timeout_for_cancel = shutdown_timeout;
        let shutdown_timeout_for_stop = shutdown_timeout;
        let mut agent = runtime.new_actor::<BackgroundWorkerState>();
        agent.model.root_token = Some(root_token.clone());
        let tasks_for_cancel = tasks.clone();
        // CancelTask handler: cancel the task's token, then wait (bounded by
        // the shutdown timeout) for the task to actually exit.
        agent.mutate_on::<CancelTask>(move |_agent, envelope| {
            let msg = envelope.message().clone();
            let tasks = tasks_for_cancel.clone();
            let timeout = shutdown_timeout_for_cancel;
            Reply::pending(async move {
                if let Some(task_info) = tasks.get(&msg.task_id) {
                    task_info.cancellation_token.cancel();
                    tracing::info!(task_id = %msg.task_id, "Task cancellation requested");
                    // Claim the join handle (take() leaves None so nobody else
                    // awaits it twice) and wait for the task to finish.
                    let mut handle_lock = task_info.join_handle.lock().await;
                    if let Some(handle) = handle_lock.take() {
                        // Result deliberately ignored: a timeout or panic here
                        // is non-fatal for the cancel path.
                        let _ = tokio::time::timeout(timeout, handle).await;
                    }
                } else {
                    tracing::warn!(task_id = %msg.task_id, "Task not found for cancellation");
                }
            })
        });
        let tasks_for_status = tasks.clone();
        // GetTaskStatus handler: reply with the task's current status.
        agent.act_on::<GetTaskStatus>(move |_agent, envelope| {
            let msg = envelope.message().clone();
            let tasks = tasks_for_status.clone();
            let reply = envelope.reply_envelope();
            Box::pin(async move {
                let status = if let Some(task_info) = tasks.get(&msg.task_id) {
                    task_info.status.lock().await.clone()
                } else {
                    // NOTE(review): unknown ids report `Pending`, which is
                    // indistinguishable from a real not-yet-started task;
                    // a dedicated "not found" signal may be clearer.
                    TaskStatus::Pending
                };
                reply
                    .send(TaskStatusResponse {
                        task_id: msg.task_id,
                        status,
                    })
                    .await;
            })
        });
        let tasks_for_all_status = tasks.clone();
        // GetAllTaskStatuses handler: snapshot every registered task's status.
        agent.act_on::<GetAllTaskStatuses>(move |_agent, envelope| {
            let tasks = tasks_for_all_status.clone();
            let reply = envelope.reply_envelope();
            Box::pin(async move {
                let mut statuses = Vec::new();
                for entry in tasks.iter() {
                    let status = entry.status.lock().await.clone();
                    statuses.push(TaskStatusResponse {
                        task_id: entry.task_id.clone(),
                        status,
                    });
                }
                reply.send(statuses).await;
            })
        });
        // before_stop hook: cancel the root token (and thus every task), then
        // await each task's handle with a per-task timeout.
        agent.before_stop(move |_agent| {
            let tasks = tasks_for_shutdown.clone();
            let root_token = root_token_for_agent.clone();
            let timeout = shutdown_timeout_for_stop;
            Box::pin(async move {
                let task_count = tasks.len();
                if task_count == 0 {
                    tracing::info!("BackgroundWorker stopping with no active tasks");
                    return;
                }
                tracing::info!(
                    task_count,
                    "BackgroundWorker stopping, cancelling all tasks..."
                );
                // One cancel fans out to all tasks via their child tokens.
                root_token.cancel();
                for entry in tasks.iter() {
                    let mut handle_lock = entry.join_handle.lock().await;
                    // take() skips handles already claimed by a CancelTask.
                    if let Some(handle) = handle_lock.take() {
                        match tokio::time::timeout(timeout, handle).await {
                            Ok(Ok(())) => {
                                tracing::debug!(task_id = %entry.task_id, "Task shutdown complete");
                            }
                            Ok(Err(e)) => {
                                tracing::warn!(
                                    task_id = %entry.task_id,
                                    error = %e,
                                    "Task panicked during shutdown"
                                );
                            }
                            Err(_) => {
                                tracing::warn!(
                                    task_id = %entry.task_id,
                                    "Task shutdown timed out"
                                );
                            }
                        }
                    }
                }
                tracing::info!("All background tasks stopped");
            })
        });
        agent.after_start(|_agent| {
            Box::pin(async move {
                tracing::info!("BackgroundWorker agent started");
            })
        });
        let handle = agent.start().await;
        let worker = Self {
            agent_handle: handle,
            tasks,
            root_token,
            semaphore,
            shutdown_timeout,
        };
        // Optional periodic cleanup loop; runs until the root token is
        // cancelled. Disabled when cleanup_interval_secs == 0.
        if config.cleanup_interval_secs > 0 {
            let cleanup_worker = worker.clone();
            let cleanup_token = worker.root_token.child_token();
            let interval = Duration::from_secs(config.cleanup_interval_secs);
            tokio::spawn(async move {
                loop {
                    tokio::select! {
                        // biased: prefer shutdown over another sleep tick.
                        biased;
                        () = cleanup_token.cancelled() => break,
                        () = tokio::time::sleep(interval) => {
                            cleanup_worker.cleanup_finished_tasks().await;
                            tracing::debug!("Periodic background task cleanup completed");
                        }
                    }
                }
            });
        }
        Ok(worker)
    }
    /// Submits `work` as a new background task under `task_id`.
    ///
    /// If a concurrency limit is configured, this awaits a semaphore permit
    /// first (so `submit` itself can block); the permit is held for the task's
    /// whole lifetime. The task's status is updated to `Completed`, `Failed`,
    /// or `Cancelled` when it finishes.
    ///
    /// NOTE(review): submitting with an id that already exists replaces the
    /// old registry entry without cancelling the old task — verify callers
    /// never reuse ids for live tasks.
    pub async fn submit<F, Fut>(&self, task_id: impl Into<String>, work: F)
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
    {
        let task_id = task_id.into();
        // Acquire an owned permit so it can move into the spawned task.
        let permit = if let Some(ref sem) = self.semaphore {
            match sem.clone().acquire_owned().await {
                Ok(permit) => Some(permit),
                Err(_) => {
                    // Semaphore closed: worker is shutting down; drop the task.
                    tracing::warn!(task_id = %task_id, "Semaphore closed, task not submitted");
                    return;
                }
            }
        } else {
            None
        };
        let cancel_token = self.root_token.child_token();
        let cancel_token_clone = cancel_token.clone();
        // Status starts as Running (the task is spawned immediately below).
        let status = Arc::new(Mutex::new(TaskStatus::Running));
        let status_for_task = status.clone();
        let task_id_clone = task_id.clone();
        let handle = tokio::spawn(async move {
            // Keep the permit alive for the task's full duration.
            let _permit = permit;
            let task_id = task_id_clone;
            tokio::select! {
                // biased: if cancellation already fired, skip starting work.
                biased;
                () = cancel_token_clone.cancelled() => {
                    tracing::debug!(task_id = %task_id, "Task cancelled");
                    let mut s = status_for_task.lock().await;
                    *s = TaskStatus::Cancelled;
                }
                result = work() => {
                    match result {
                        Ok(()) => {
                            tracing::debug!(task_id = %task_id, "Task completed successfully");
                            let mut s = status_for_task.lock().await;
                            *s = TaskStatus::Completed;
                        }
                        Err(e) => {
                            tracing::warn!(task_id = %task_id, error = %e, "Task failed");
                            let mut s = status_for_task.lock().await;
                            *s = TaskStatus::Failed(e.to_string());
                        }
                    }
                }
            }
        });
        let task_info = TaskInfo {
            task_id: task_id.clone(),
            join_handle: Arc::new(Mutex::new(Some(handle))),
            cancellation_token: cancel_token,
            status,
        };
        self.tasks.insert(task_id.clone(), task_info);
        tracing::info!(task_id = %task_id, "Background task submitted");
    }
    /// Requests cancellation of `task_id` via the agent (see the CancelTask
    /// handler in `spawn`). Fire-and-forget from the caller's perspective.
    pub async fn cancel(&self, task_id: impl Into<String>) {
        self.agent_handle
            .send(CancelTask {
                task_id: task_id.into(),
            })
            .await;
    }
    /// Returns the current status of `task_id`, or `TaskStatus::Pending` if
    /// the id is unknown (see the NOTE on the GetTaskStatus handler).
    pub async fn get_task_status(&self, task_id: &str) -> TaskStatus {
        if let Some(task_info) = self.tasks.get(task_id) {
            task_info.status.lock().await.clone()
        } else {
            TaskStatus::Pending
        }
    }
    /// Number of tasks currently in the registry (any status, including
    /// finished tasks not yet cleaned up).
    #[must_use]
    pub fn task_count(&self) -> usize {
        self.tasks.len()
    }
    /// Number of registered tasks whose status is currently `Running`.
    pub async fn running_task_count(&self) -> usize {
        let mut count = 0;
        for entry in self.tasks.iter() {
            if *entry.status.lock().await == TaskStatus::Running {
                count += 1;
            }
        }
        count
    }
    /// Whether `task_id` is present in the registry (any status).
    #[must_use]
    pub fn has_task(&self, task_id: &str) -> bool {
        self.tasks.contains_key(task_id)
    }
    /// Removes all tasks in a terminal state (`Completed`, `Failed`, or
    /// `Cancelled`) from the registry. Collects ids first, then removes, to
    /// avoid mutating the map while iterating it.
    pub async fn cleanup_finished_tasks(&self) {
        let mut to_remove = Vec::new();
        for entry in self.tasks.iter() {
            let status = entry.status.lock().await.clone();
            match status {
                TaskStatus::Completed | TaskStatus::Failed(_) | TaskStatus::Cancelled => {
                    to_remove.push(entry.task_id.clone());
                }
                _ => {}
            }
        }
        for task_id in to_remove {
            self.tasks.remove(&task_id);
        }
    }
    /// Borrow the underlying actor handle.
    #[must_use]
    pub fn handle(&self) -> &ActorHandle {
        &self.agent_handle
    }
    /// The configured per-task shutdown/cancellation timeout.
    #[must_use]
    pub fn shutdown_timeout(&self) -> Duration {
        self.shutdown_timeout
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// `Default` yields a disabled worker with a 5s timeout and no limits.
    #[test]
    fn test_config_defaults() {
        let config = BackgroundWorkerConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_concurrent_tasks, 0);
        assert_eq!(config.task_shutdown_timeout_secs, 5);
        assert_eq!(config.cleanup_interval_secs, 0);
    }
    /// An empty JSON object deserializes using all serde defaults.
    #[test]
    fn test_config_serde_empty_object() {
        let config: BackgroundWorkerConfig = serde_json::from_str("{}").unwrap();
        assert!(!config.enabled);
        assert_eq!(config.max_concurrent_tasks, 0);
        assert_eq!(config.task_shutdown_timeout_secs, 5);
        assert_eq!(config.cleanup_interval_secs, 0);
    }
    /// Explicit fields override; omitted fields fall back to serde defaults.
    #[test]
    fn test_config_serde_partial() {
        let config: BackgroundWorkerConfig =
            serde_json::from_str(r#"{"enabled": true, "max_concurrent_tasks": 10}"#).unwrap();
        assert!(config.enabled);
        assert_eq!(config.max_concurrent_tasks, 10);
        assert_eq!(config.task_shutdown_timeout_secs, 5);
        assert_eq!(config.cleanup_interval_secs, 0);
    }
    /// With max_concurrent_tasks = 2, at most 2 tasks run at once.
    /// NOTE(review): timing-based (sleeps) — may be flaky on slow CI.
    #[tokio::test]
    async fn test_semaphore_concurrency_limiting() {
        let mut runtime = ActonApp::launch_async().await;
        let config = BackgroundWorkerConfig {
            enabled: true,
            max_concurrent_tasks: 2,
            task_shutdown_timeout_secs: 5,
            cleanup_interval_secs: 0,
        };
        let worker = BackgroundWorker::spawn(&mut runtime, &config).await.unwrap();
        // Watch channel acts as a "release all tasks" gate.
        let (tx, rx) = tokio::sync::watch::channel(false);
        let running_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let max_observed = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        // Submit via spawned helpers: submit() itself blocks on the semaphore
        // once the limit is reached, so it must not run on the test task.
        for i in 0..4 {
            let rx = rx.clone();
            let running = running_count.clone();
            let max_obs = max_observed.clone();
            let w = worker.clone();
            tokio::spawn(async move {
                w.submit(format!("task-{i}"), move || {
                    let rx = rx;
                    let running = running;
                    let max_obs = max_obs;
                    async move {
                        // Track the high-water mark of concurrently running tasks.
                        let current =
                            running.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
                        max_obs.fetch_max(current, std::sync::atomic::Ordering::SeqCst);
                        // Block until the test flips the gate to true.
                        let mut rx = rx;
                        let _ = rx.wait_for(|v| *v).await;
                        running.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                        Ok(())
                    }
                })
                .await;
            });
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
        let max = max_observed.load(std::sync::atomic::Ordering::SeqCst);
        assert!(max <= 2, "Max concurrent tasks was {max}, expected <= 2");
        // Release all tasks, let them drain, then shut down.
        tx.send(true).unwrap();
        tokio::time::sleep(Duration::from_millis(200)).await;
        runtime.shutdown_all().await.unwrap();
    }
    /// Finished tasks stay registered until cleanup_finished_tasks() removes them.
    #[tokio::test]
    async fn test_cleanup_finished_tasks() {
        let mut runtime = ActonApp::launch_async().await;
        let config = BackgroundWorkerConfig::default();
        let worker = BackgroundWorker::spawn(&mut runtime, &config).await.unwrap();
        for i in 0..3 {
            worker
                .submit(format!("task-{i}"), || async { Ok(()) })
                .await;
        }
        // Give the (trivial) tasks time to complete.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(worker.task_count(), 3);
        worker.cleanup_finished_tasks().await;
        assert_eq!(worker.task_count(), 0);
        runtime.shutdown_all().await.unwrap();
    }
    /// The configured shutdown timeout is surfaced via shutdown_timeout().
    #[tokio::test]
    async fn test_configurable_shutdown_timeout() {
        let mut runtime = ActonApp::launch_async().await;
        let config = BackgroundWorkerConfig {
            enabled: true,
            max_concurrent_tasks: 0,
            task_shutdown_timeout_secs: 10,
            cleanup_interval_secs: 0,
        };
        let worker = BackgroundWorker::spawn(&mut runtime, &config).await.unwrap();
        assert_eq!(worker.shutdown_timeout(), Duration::from_secs(10));
        runtime.shutdown_all().await.unwrap();
    }
}