use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{debug, error, warn};
use uuid::Uuid;
use crate::context::Context;
use crate::errors::{ErrorCode, ModuleError};
use crate::executor::Executor;
/// Lifecycle state of a task tracked by [`AsyncTaskManager`].
///
/// Serialized in `snake_case` (for example `"pending"` or `"cancelled"`); the
/// `Default` value is [`TaskStatus::Pending`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    /// Accepted but not executing: the initial state, and the state between retries.
    #[default]
    Pending,
    /// Holding a concurrency permit and executing.
    Running,
    /// Finished successfully; `TaskInfo::result` carries the output.
    Completed,
    /// Finished with an error after any retries; `TaskInfo::error` carries the message.
    Failed,
    /// Stopped by cancellation (or because the concurrency semaphore was closed).
    Cancelled,
}

impl TaskStatus {
    /// `true` for the states a task never leaves on its own.
    fn is_terminal(self) -> bool {
        match self {
            Self::Completed | Self::Failed | Self::Cancelled => true,
            Self::Pending | Self::Running => false,
        }
    }

    /// `true` while the task is queued or executing, i.e. it still counts
    /// against the manager's task limit.
    fn is_active(self) -> bool {
        !self.is_terminal()
    }
}
/// Snapshot of one task's state as persisted in a [`TaskStore`].
///
/// Timestamps are fractional seconds since the Unix epoch (as produced by
/// `now_secs`). `Option` fields are omitted from serialized output while `None`
/// and default to `None` when absent on deserialization.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TaskInfo {
    /// Unique task identifier (a v4 UUID string when created by `AsyncTaskManager`).
    pub task_id: String,
    /// Identifier of the module the executor is asked to run.
    pub module_id: String,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// When the task was accepted by `submit*`.
    pub submitted_at: f64,
    /// When the task first became `Running`; kept unchanged across retries.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub started_at: Option<f64>,
    /// When the task reached a terminal state (completed, failed or cancelled).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<f64>,
    /// Executor output; set on successful completion.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error message of the final failed attempt.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Number of retries already scheduled (0 during the first attempt).
    #[serde(default)]
    pub retry_count: u32,
    /// Retry budget copied from the submission's `RetryConfig` (0 when none was given).
    #[serde(default)]
    pub max_retries: u32,
}
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0.0` instead of panicking.
fn now_secs() -> f64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    since_epoch.as_secs_f64()
}
/// Persistence backend for [`TaskInfo`] records used by [`AsyncTaskManager`].
///
/// The manager shares one instance behind an `Arc` across its spawned tasks,
/// hence the `Send + Sync` bound.
///
/// NOTE(review): the manager's synchronous accessors drive these futures with
/// `block_on_local`, which panics if a future is not ready on its first poll.
/// Only stores whose operations never suspend (such as `InMemoryTaskStore`)
/// are usable through those accessors.
#[async_trait]
pub trait TaskStore: Send + Sync {
    /// Persists `task` under `task.task_id`. The manager calls this repeatedly
    /// for the same id to advance a task's state, so it is expected to replace
    /// any existing record.
    async fn save(&self, task: &TaskInfo) -> Result<(), ModuleError>;
    /// Fetches the record for `id`; `Ok(None)` when no such task exists.
    async fn get(&self, id: &str) -> Result<Option<TaskInfo>, ModuleError>;
    /// Lists records, restricted to `status` when it is `Some`.
    async fn list(&self, status: Option<TaskStatus>) -> Result<Vec<TaskInfo>, ModuleError>;
    /// Removes the record for `id`.
    async fn delete(&self, id: &str) -> Result<(), ModuleError>;
    /// Returns terminal records that finished before `before_timestamp`
    /// (epoch seconds); the reaper deletes whatever this returns.
    async fn list_expired(&self, before_timestamp: f64) -> Result<Vec<TaskInfo>, ModuleError>;
    /// Static name identifying the backend (e.g. `"InMemoryTaskStore"`).
    fn store_type_name(&self) -> &'static str;
}
/// Process-local [`TaskStore`] backed by a concurrent [`DashMap`].
///
/// Every operation completes without suspending, so its futures can be driven
/// by the manager's synchronous accessors. Contents are lost when the store is
/// dropped.
#[derive(Default)]
pub struct InMemoryTaskStore {
    tasks: DashMap<String, TaskInfo>,
}

impl InMemoryTaskStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Clones every record accepted by `keep`, ordered by `task_id` so listings
    /// are deterministic regardless of map iteration order.
    fn collect_sorted(&self, keep: impl Fn(&TaskInfo) -> bool) -> Vec<TaskInfo> {
        let mut out: Vec<TaskInfo> = self
            .tasks
            .iter()
            .filter(|entry| keep(entry.value()))
            .map(|entry| entry.value().clone())
            .collect();
        // Map keys are unique, so an unstable sort yields the same order as a stable one.
        out.sort_unstable_by(|a, b| a.task_id.cmp(&b.task_id));
        out
    }
}

#[async_trait]
impl TaskStore for InMemoryTaskStore {
    /// Inserts `task`, replacing any record with the same id.
    async fn save(&self, task: &TaskInfo) -> Result<(), ModuleError> {
        self.tasks.insert(task.task_id.clone(), task.clone());
        Ok(())
    }

    /// Returns a clone of the record for `id`, if present.
    async fn get(&self, id: &str) -> Result<Option<TaskInfo>, ModuleError> {
        Ok(self.tasks.get(id).map(|entry| entry.value().clone()))
    }

    /// Returns all records (or those with the given status), sorted by id.
    async fn list(&self, status: Option<TaskStatus>) -> Result<Vec<TaskInfo>, ModuleError> {
        Ok(self.collect_sorted(|info| status.map_or(true, |wanted| info.status == wanted)))
    }

    /// Removes the record for `id`; a missing id is not an error.
    async fn delete(&self, id: &str) -> Result<(), ModuleError> {
        self.tasks.remove(id);
        Ok(())
    }

    /// Returns terminal records whose completion time is strictly earlier than
    /// `before_timestamp`, sorted by id. Active tasks are never returned.
    async fn list_expired(&self, before_timestamp: f64) -> Result<Vec<TaskInfo>, ModuleError> {
        Ok(self.collect_sorted(|info| {
            // A terminal record without `completed_at` (e.g. built via `Default`
            // or deserialized from older data) ages from `submitted_at`, as in
            // `AsyncTaskManager::cleanup`. Without the fallback the reaper would
            // keep such records forever.
            info.status.is_terminal()
                && info.completed_at.unwrap_or(info.submitted_at) < before_timestamp
        }))
    }

    fn store_type_name(&self) -> &'static str {
        "InMemoryTaskStore"
    }
}
/// Retry policy accepted by `AsyncTaskManager::submit_with_retry`.
///
/// The delay before retry number `n` (0-based) is
/// `retry_delay_ms * backoff_multiplier^n`, clamped to `max_retry_delay_ms`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct RetryConfig {
    /// Retries allowed after the first failed attempt (0 disables retrying).
    pub max_retries: u32,
    /// Delay before the first retry, in milliseconds.
    pub retry_delay_ms: u64,
    /// Growth factor applied to the delay for each further retry.
    pub backoff_multiplier: f64,
    /// Upper bound for any single delay, in milliseconds.
    pub max_retry_delay_ms: u64,
}

impl Default for RetryConfig {
    /// No retries; when retries are enabled, back-off starts at 1 s, doubles
    /// each time and is capped at 60 s.
    fn default() -> Self {
        Self {
            max_retry_delay_ms: 60_000,
            backoff_multiplier: 2.0,
            retry_delay_ms: 1000,
            max_retries: 0,
        }
    }
}

impl RetryConfig {
    /// Milliseconds to wait before retry number `attempt` (0-based).
    ///
    /// Returns `0` when the computed delay is not a positive finite number
    /// (e.g. a zero base delay or a negative multiplier on an odd attempt).
    #[must_use]
    pub fn compute_delay_ms(&self, attempt: u32) -> u64 {
        let growth = self.backoff_multiplier.powf(f64::from(attempt));
        #[allow(clippy::cast_precision_loss)]
        let (initial, ceiling) = (self.retry_delay_ms as f64, self.max_retry_delay_ms as f64);
        let delay = (initial * growth).min(ceiling);
        if delay.is_finite() && delay > 0.0 {
            // `delay` is positive, finite and at most `ceiling`, so the cast only drops the fraction.
            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
            let whole_ms = delay as u64;
            whole_ms
        } else {
            0
        }
    }

    /// Legacy alias of [`RetryConfig::compute_delay_ms`].
    #[must_use]
    #[deprecated(since = "0.21.0", note = "use compute_delay_ms")]
    pub fn delay_for_attempt(&self, attempt: u32) -> u64 {
        Self::compute_delay_ms(self, attempt)
    }
}
/// Settings for the background sweeper started by `AsyncTaskManager::start_reaper`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct ReaperConfig {
    /// Terminal tasks that finished more than this many seconds ago are deleted.
    pub ttl_seconds: f64,
    /// Pause between sweeps, in milliseconds.
    pub sweep_interval_ms: u64,
}

impl Default for ReaperConfig {
    /// One-hour TTL, sweeping every five minutes.
    fn default() -> Self {
        const ONE_HOUR_SECS: f64 = 60.0 * 60.0;
        const FIVE_MINUTES_MS: u64 = 5 * 60 * 1000;
        Self {
            sweep_interval_ms: FIVE_MINUTES_MS,
            ttl_seconds: ONE_HOUR_SECS,
        }
    }
}
/// Owner handle for the reaper task started by `AsyncTaskManager::start_reaper`.
///
/// [`ReaperHandle::stop`] shuts the reaper down gracefully and waits for it.
/// Dropping the handle without calling `stop()` signals the reaper and aborts
/// it. The background task therefore never outlives the handle that holds the
/// manager's single-reaper flag.
#[derive(Debug)]
pub struct ReaperHandle {
    // `None` once `stop` has taken the join handle to await it.
    handle: Option<JoinHandle<()>>,
    stop_tx: watch::Sender<bool>,
    // Shared with the owning manager; `true` while a reaper is registered.
    running_flag: Arc<std::sync::atomic::AtomicBool>,
}

impl ReaperHandle {
    /// Signals the reaper to stop and waits for its task to finish. It then
    /// clears the manager's running flag so that a new reaper may be started.
    ///
    /// A join failure other than cancellation (e.g. a panic inside the reaper)
    /// is logged as a warning rather than propagated.
    pub async fn stop(mut self) {
        let _ = self.stop_tx.send(true);
        if let Some(handle) = self.handle.take() {
            if let Err(err) = handle.await {
                if !err.is_cancelled() {
                    warn!("reaper task join failed: {err}");
                }
            }
        }
        self.running_flag
            .store(false, std::sync::atomic::Ordering::SeqCst);
    }
}

impl Drop for ReaperHandle {
    fn drop(&mut self) {
        // Dropping the JoinHandle alone would only detach the task. The old
        // reaper would keep running after the flag below is cleared, letting a
        // second reaper start next to it. So signal it and abort it outright.
        // After `stop()` the handle is already taken and this is a no-op.
        let _ = self.stop_tx.send(true);
        if let Some(handle) = self.handle.take() {
            handle.abort();
        }
        self.running_flag
            .store(false, std::sync::atomic::Ordering::SeqCst);
    }
}
/// Runs module executions in the background with bounded concurrency and
/// records each task's lifecycle in a pluggable [`TaskStore`].
pub struct AsyncTaskManager {
    /// Executes the submitted modules.
    executor: Arc<Executor>,
    /// Admission cap: `submit*` fails once this many tasks are active (pending or running).
    max_tasks: usize,
    /// Where task records are persisted.
    store: Arc<dyn TaskStore>,
    /// Join handles of spawned task wrappers keyed by task id; `cancel` aborts through these.
    handles: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
    /// Limits how many tasks execute at once (created with `max_concurrent` permits).
    semaphore: Arc<Semaphore>,
    /// Serializes the capacity check and initial save performed by `submit_with_retry`.
    admission_lock: Arc<tokio::sync::Mutex<()>>,
    /// `true` while a reaper is registered; enforces at most one reaper per manager.
    reaper_running: Arc<std::sync::atomic::AtomicBool>,
}
impl AsyncTaskManager {
pub fn new(executor: Arc<Executor>, max_concurrent: usize, max_tasks: usize) -> Self {
Self::with_store(
executor,
max_concurrent,
max_tasks,
Arc::new(InMemoryTaskStore::new()),
)
}
pub fn with_store(
executor: Arc<Executor>,
max_concurrent: usize,
max_tasks: usize,
store: Arc<dyn TaskStore>,
) -> Self {
Self {
executor,
max_tasks,
store,
handles: Arc::new(Mutex::new(HashMap::new())),
semaphore: Arc::new(Semaphore::new(max_concurrent)),
admission_lock: Arc::new(tokio::sync::Mutex::new(())),
reaper_running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
}
}
pub fn store_type_name(&self) -> &'static str {
self.store.store_type_name()
}
pub fn store(&self) -> Arc<dyn TaskStore> {
Arc::clone(&self.store)
}
pub async fn submit(
&self,
module_id: &str,
inputs: serde_json::Value,
context: Option<Context<serde_json::Value>>,
) -> Result<String, ModuleError> {
self.submit_with_retry(module_id, inputs, context, None)
.await
}
pub async fn submit_with_retry(
&self,
module_id: &str,
inputs: serde_json::Value,
context: Option<Context<serde_json::Value>>,
retry: Option<RetryConfig>,
) -> Result<String, ModuleError> {
let task_id = Uuid::new_v4().to_string();
let max_retries = retry.as_ref().map_or(0, |r| r.max_retries);
let info = TaskInfo {
task_id: task_id.clone(),
module_id: module_id.to_string(),
status: TaskStatus::Pending,
submitted_at: now_secs(),
started_at: None,
completed_at: None,
result: None,
error: None,
retry_count: 0,
max_retries,
};
{
let _admit = self.admission_lock.lock().await;
check_capacity_and_save(&*self.store, self.max_tasks, &info).await?;
}
let handles = Arc::clone(&self.handles);
let semaphore = Arc::clone(&self.semaphore);
let executor = Arc::clone(&self.executor);
let store_for_run = Arc::clone(&self.store);
let mid = module_id.to_string();
let tid = task_id.clone();
let handle = tokio::spawn(async move {
run_task(
tid.clone(),
mid,
inputs,
context,
retry,
executor,
semaphore,
store_for_run,
)
.await;
handles.lock().remove(&tid);
});
self.handles.lock().insert(task_id.clone(), handle);
Ok(task_id)
}
pub fn get_status(&self, task_id: &str) -> Option<TaskInfo> {
block_on_local(self.store.get(task_id)).ok().flatten()
}
pub async fn get_status_async(&self, task_id: &str) -> Option<TaskInfo> {
self.store.get(task_id).await.ok().flatten()
}
pub fn get_result(&self, task_id: &str) -> Result<serde_json::Value, ModuleError> {
block_on_local(self.get_result_async(task_id))
}
pub async fn get_result_async(&self, task_id: &str) -> Result<serde_json::Value, ModuleError> {
let info = self.store.get(task_id).await?.ok_or_else(|| {
ModuleError::new(
ErrorCode::GeneralInternalError,
format!("Task not found: {task_id}"),
)
})?;
if info.status != TaskStatus::Completed {
return Err(ModuleError::new(
ErrorCode::GeneralInternalError,
format!("Task {task_id} is not completed (status={:?})", info.status),
));
}
Ok(info.result.unwrap_or(serde_json::Value::Null))
}
pub async fn cancel(&self, task_id: &str) -> bool {
let Some(info) = self.store.get(task_id).await.ok().flatten() else {
return false;
};
if !info.status.is_active() {
return false;
}
if let Some(handle) = self.handles.lock().remove(task_id) {
handle.abort();
}
let mut updated = info;
if updated.status.is_active() {
updated.status = TaskStatus::Cancelled;
updated.completed_at = Some(now_secs());
let _ = self.store.save(&updated).await;
}
true
}
pub async fn shutdown(&self) {
let task_ids: Vec<String> = self
.store
.list(None)
.await
.unwrap_or_default()
.into_iter()
.filter_map(|info| info.status.is_active().then_some(info.task_id))
.collect();
for task_id in task_ids {
self.cancel(&task_id).await;
}
}
pub fn list_tasks(&self, status: Option<TaskStatus>) -> Vec<TaskInfo> {
block_on_local(self.store.list(status)).unwrap_or_default()
}
pub fn cleanup(&self, max_age_seconds: f64) -> usize {
let now = now_secs();
let to_remove: Vec<String> = block_on_local(self.store.list(None))
.unwrap_or_default()
.into_iter()
.filter(|info| info.status.is_terminal())
.filter(|info| {
let ref_time = info.completed_at.unwrap_or(info.submitted_at);
(now - ref_time) >= max_age_seconds
})
.map(|info| info.task_id)
.collect();
let count = to_remove.len();
for id in &to_remove {
let _ = block_on_local(self.store.delete(id));
self.handles.lock().remove(id);
}
count
}
pub fn task_count(&self) -> usize {
block_on_local(self.store.list(None)).map_or(0, |v| v.len())
}
pub fn start_reaper(&self, config: ReaperConfig) -> Result<ReaperHandle, ModuleError> {
if self
.reaper_running
.compare_exchange(
false,
true,
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
)
.is_err()
{
return Err(ModuleError::new(
ErrorCode::ReaperAlreadyRunning,
"AsyncTaskManager reaper is already running; call stop() first",
));
}
let store = Arc::clone(&self.store);
let (stop_tx, mut stop_rx) = watch::channel(false);
let interval = Duration::from_millis(config.sweep_interval_ms);
let ttl = config.ttl_seconds;
let handle = tokio::spawn(async move {
loop {
tokio::select! {
biased;
changed = stop_rx.changed() => {
if changed.is_ok() && *stop_rx.borrow() {
debug!("reaper received stop signal");
return;
}
}
() = tokio::time::sleep(interval) => {
let before = now_secs() - ttl;
match store.list_expired(before).await {
Ok(expired) => {
let count = expired.len();
for info in &expired {
if let Err(err) = store.delete(&info.task_id).await {
warn!(task_id = %info.task_id, "reaper delete failed: {err}");
}
}
if count > 0 {
debug!("reaper deleted {count} expired tasks");
}
}
Err(err) => {
warn!("reaper list_expired failed: {err}");
}
}
}
}
}
});
Ok(ReaperHandle {
handle: Some(handle),
stop_tx,
running_flag: Arc::clone(&self.reaper_running),
})
}
}
/// Rejects the submission when `max_tasks` tasks are already active (pending
/// or running); otherwise persists `info`.
///
/// The caller in `submit_with_retry` holds `admission_lock` around this call,
/// so the count and the save are not interleaved with other submissions.
async fn check_capacity_and_save(
    store: &dyn TaskStore,
    max_tasks: usize,
    info: &TaskInfo,
) -> Result<(), ModuleError> {
    let in_flight = store
        .list(None)
        .await?
        .into_iter()
        .filter(|task| task.status.is_active())
        .count();
    if in_flight < max_tasks {
        store.save(info).await
    } else {
        Err(ModuleError::new(
            ErrorCode::TaskLimitExceeded,
            format!("Task limit reached ({max_tasks})"),
        ))
    }
}
/// Body of a spawned task. Each attempt acquires a concurrency permit and
/// executes the module; the outcome is recorded in `store`, and failures are
/// retried according to `retry`.
///
/// Status flow: `Pending` -> `Running` -> `Completed` | `Failed`, going back
/// from `Running` to `Pending` while waiting for a retry. If the semaphore has
/// been closed, the task is marked `Cancelled`. The function returns without
/// writing in two cases:
/// - the record is missing or unreadable;
/// - the record is already `Cancelled`.
///
/// NOTE(review): each "is it cancelled?" check and the save after it are
/// separate store calls, not an atomic compare-and-set. A concurrent `cancel`
/// landing between them can still be overwritten (e.g. by the `Running` save).
/// An early return on an unreadable record leaves its stored status unchanged.
#[allow(clippy::too_many_arguments)]
async fn run_task(
    task_id: String,
    module_id: String,
    inputs: serde_json::Value,
    context: Option<Context<serde_json::Value>>,
    retry: Option<RetryConfig>,
    executor: Arc<Executor>,
    semaphore: Arc<Semaphore>,
    store: Arc<dyn TaskStore>,
) {
    let max_retries = retry.as_ref().map_or(0, |r| r.max_retries);
    loop {
        // Every attempt, retries included, waits for its own permit.
        let Ok(permit) = semaphore.acquire().await else {
            // `acquire` fails only once the semaphore has been closed.
            mark_cancelled(&store, &task_id).await;
            return;
        };
        let Ok(Some(mut info)) = store.get(&task_id).await else {
            return;
        };
        if info.status == TaskStatus::Cancelled {
            return;
        }
        info.status = TaskStatus::Running;
        // Keep the first start time across retries.
        if info.started_at.is_none() {
            info.started_at = Some(now_secs());
        }
        if let Err(err) = store.save(&info).await {
            error!(task_id = %task_id, "store.save(running) failed: {err}");
            return;
        }
        let result = executor
            .call(&module_id, inputs.clone(), context.as_ref(), None)
            .await;
        // Release the permit before any retry back-off so queued tasks can run.
        drop(permit);
        // Re-read: `cancel` may have rewritten the record while the module ran.
        let Ok(Some(mut info)) = store.get(&task_id).await else {
            return;
        };
        if info.status == TaskStatus::Cancelled {
            return;
        }
        match result {
            Ok(output) => {
                info.status = TaskStatus::Completed;
                info.completed_at = Some(now_secs());
                info.result = Some(output);
                save_terminal_if_not_cancelled(&store, &task_id, &info).await;
                return;
            }
            Err(err) => {
                if let Some(cfg) = retry.as_ref() {
                    if info.retry_count < max_retries {
                        // The delay is based on how many retries were already made (0-based).
                        let delay_ms = cfg.compute_delay_ms(info.retry_count);
                        info.retry_count += 1;
                        // Back to Pending: still active (counts toward the limit) but not running.
                        info.status = TaskStatus::Pending;
                        let _ = store.save(&info).await;
                        debug!(
                            task_id = %task_id,
                            attempt = info.retry_count,
                            delay_ms,
                            "scheduling retry"
                        );
                        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                        continue;
                    }
                }
                // No retry policy, or the retry budget is exhausted.
                info.status = TaskStatus::Failed;
                info.completed_at = Some(now_secs());
                info.error = Some(err.to_string());
                save_terminal_if_not_cancelled(&store, &task_id, &info).await;
                error!(task_id = %task_id, "task failed: {err}");
                return;
            }
        }
    }
}
/// Persists a terminal `info` unless the stored record is already `Cancelled`.
///
/// A missing or unreadable record does not prevent the save, and save errors
/// are ignored (best effort). NOTE(review): the read and the write are separate
/// store calls, so a cancellation landing between them can still be overwritten.
pub(crate) async fn save_terminal_if_not_cancelled(
    store: &Arc<dyn TaskStore>,
    task_id: &str,
    info: &TaskInfo,
) {
    let already_cancelled = matches!(
        store.get(task_id).await,
        Ok(Some(current)) if current.status == TaskStatus::Cancelled
    );
    if !already_cancelled {
        let _ = store.save(info).await;
    }
}
/// Best-effort transition of an active (pending or running) task to
/// `Cancelled`, stamping `completed_at`.
///
/// Terminal, missing, or unreadable records are left untouched, and save
/// errors are ignored.
async fn mark_cancelled(store: &Arc<dyn TaskStore>, task_id: &str) {
    let Ok(Some(mut info)) = store.get(task_id).await else {
        return;
    };
    if !info.status.is_active() {
        return;
    }
    info.status = TaskStatus::Cancelled;
    info.completed_at = Some(now_secs());
    let _ = store.save(&info).await;
}
/// Drives `fut` to completion with a single poll on the current thread.
///
/// Only meant for futures that never suspend, such as the
/// [`InMemoryTaskStore`] operations behind the manager's synchronous
/// accessors. The waker discards wake-ups, so a suspended future could never
/// be resumed.
///
/// # Panics
///
/// Panics if `fut` returns `Poll::Pending` on its first poll.
fn block_on_local<F, T>(fut: F) -> T
where
    F: std::future::Future<Output = T>,
{
    use std::pin::pin;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Waker that ignores every wake-up; a single poll never needs one.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    // `Waker::from(Arc<impl Wake>)` builds a waker without a hand-written
    // `RawWakerVTable`, removing the `unsafe` block entirely.
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(v) => v,
        Poll::Pending => panic!(
            "block_on_local: TaskStore future yielded — use the _async variants for non-blocking stores"
        ),
    }
}
// Unit tests for retry back-off, the in-memory store, and the guard that keeps
// terminal saves from overwriting a cancellation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::Executor;
    use crate::registry::registry::Registry;

    /// Builds an executor over a default (empty) registry and default config.
    fn make_executor() -> Arc<Executor> {
        let registry = Arc::new(Registry::default());
        let config = Arc::new(crate::config::Config::default());
        Arc::new(Executor::new(registry, config))
    }

    /// Delays double from 1 s per attempt and clamp at the 30 s ceiling.
    #[test]
    fn retry_delay_grows_exponentially_and_caps() {
        let cfg = RetryConfig {
            max_retries: 5,
            retry_delay_ms: 1000,
            backoff_multiplier: 2.0,
            max_retry_delay_ms: 30_000,
        };
        assert_eq!(cfg.compute_delay_ms(0), 1000);
        assert_eq!(cfg.compute_delay_ms(1), 2000);
        assert_eq!(cfg.compute_delay_ms(2), 4000);
        assert_eq!(cfg.compute_delay_ms(3), 8000);
        assert_eq!(cfg.compute_delay_ms(4), 16_000);
        // 1000 * 2^5 = 32_000 exceeds the cap.
        assert_eq!(cfg.compute_delay_ms(5), 30_000);
    }

    /// `AsyncTaskManager::new` wires in the in-memory backend.
    #[tokio::test]
    async fn default_store_is_in_memory() {
        let mgr = AsyncTaskManager::new(make_executor(), 4, 100);
        assert_eq!(mgr.store_type_name(), "InMemoryTaskStore");
    }

    /// A saved record can be read back by id.
    #[tokio::test]
    async fn in_memory_store_save_and_get_round_trip() {
        let store = InMemoryTaskStore::new();
        let info = TaskInfo {
            task_id: "abc".into(),
            module_id: "data.process".into(),
            status: TaskStatus::Completed,
            submitted_at: 1.0,
            started_at: Some(2.0),
            completed_at: Some(3.0),
            result: Some(serde_json::json!({"ok": true})),
            error: None,
            retry_count: 0,
            max_retries: 0,
        };
        store.save(&info).await.unwrap();
        let got = store.get("abc").await.unwrap().unwrap();
        assert_eq!(got.task_id, "abc");
        assert_eq!(got.status, TaskStatus::Completed);
    }

    /// `list(Some(status))` returns only records with that status.
    #[tokio::test]
    async fn in_memory_store_list_filters_by_status() {
        let store = InMemoryTaskStore::new();
        for (id, status) in [
            ("c1", TaskStatus::Completed),
            ("c2", TaskStatus::Running),
            ("c3", TaskStatus::Failed),
        ] {
            store
                .save(&TaskInfo {
                    task_id: id.into(),
                    module_id: "m".into(),
                    status,
                    submitted_at: 0.0,
                    started_at: None,
                    completed_at: None,
                    result: None,
                    error: None,
                    retry_count: 0,
                    max_retries: 0,
                })
                .await
                .unwrap();
        }
        let completed = store.list(Some(TaskStatus::Completed)).await.unwrap();
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].task_id, "c1");
    }

    /// An old but still running task is never reported as expired.
    #[tokio::test]
    async fn in_memory_store_list_expired_skips_active_tasks() {
        let store = InMemoryTaskStore::new();
        store
            .save(&TaskInfo {
                task_id: "old-completed".into(),
                module_id: "m".into(),
                status: TaskStatus::Completed,
                submitted_at: 0.0,
                started_at: Some(0.0),
                completed_at: Some(100.0),
                result: None,
                error: None,
                retry_count: 0,
                max_retries: 0,
            })
            .await
            .unwrap();
        store
            .save(&TaskInfo {
                task_id: "old-running".into(),
                module_id: "m".into(),
                status: TaskStatus::Running,
                submitted_at: 0.0,
                started_at: Some(0.0),
                completed_at: None,
                result: None,
                error: None,
                retry_count: 0,
                max_retries: 0,
            })
            .await
            .unwrap();
        let expired = store.list_expired(1000.0).await.unwrap();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].task_id, "old-completed");
    }

    /// A late `Completed` save must not clobber a record already marked `Cancelled`.
    #[tokio::test]
    async fn save_terminal_does_not_overwrite_cancelled() {
        let store: Arc<dyn TaskStore> = Arc::new(InMemoryTaskStore::new());
        let task_id = "race-task";
        store
            .save(&TaskInfo {
                task_id: task_id.into(),
                module_id: "m".into(),
                status: TaskStatus::Cancelled,
                submitted_at: 0.0,
                started_at: Some(0.0),
                completed_at: Some(1.0),
                result: None,
                error: None,
                retry_count: 0,
                max_retries: 0,
            })
            .await
            .unwrap();
        let terminal = TaskInfo {
            task_id: task_id.into(),
            module_id: "m".into(),
            status: TaskStatus::Completed,
            submitted_at: 0.0,
            started_at: Some(0.0),
            completed_at: Some(2.0),
            result: Some(serde_json::json!({"value": 42})),
            error: None,
            retry_count: 0,
            max_retries: 0,
        };
        save_terminal_if_not_cancelled(&store, task_id, &terminal).await;
        let after = store.get(task_id).await.unwrap().expect("task present");
        assert_eq!(
            after.status,
            TaskStatus::Cancelled,
            "terminal save MUST NOT overwrite a concurrent cancellation"
        );
        assert!(
            after.result.is_none(),
            "the cancel-time TaskInfo had no result; the overwriting Completed payload must not leak through"
        );
    }
}