use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use crate::cancellation::{CancellationHandle, CancellationToken};
use crate::inbox::InboxSender;
use crate::state::{MutationBatch, StateStore};
use super::state::{
BackgroundTaskStateAction, BackgroundTaskStateKey, BackgroundTaskStateSnapshot,
PersistedTaskMeta,
};
use super::types::{
TaskContext, TaskEvent, TaskId, TaskParentContext, TaskResult, TaskStatus, TaskSummary,
};
/// Reasons why delivering a message to a task's inbox can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendError {
    /// No handle is registered under the requested task id.
    TaskNotFound,
    /// The caller's thread id differs from the thread id that owns the task.
    NotOwner,
    /// The persisted status of the task is already terminal; carries that status.
    TaskTerminated(TaskStatus),
    /// The task has no inbox sender attached (it was not spawned as a sub-agent).
    NoInbox,
    /// The inbox sender reported that the message was not accepted.
    InboxClosed,
}

impl std::fmt::Display for SendError {
    /// Writes the human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `TaskTerminated` needs formatting; every other variant maps to
        // a fixed message that is written verbatim.
        let message = match self {
            Self::TaskTerminated(status) => {
                return write!(f, "task already {}", status.as_str());
            }
            Self::TaskNotFound => "task not found",
            Self::NotOwner => "caller does not own this task",
            Self::NoInbox => "task has no inbox (not a sub-agent)",
            Self::InboxClosed => "sub-agent inbox closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SendError {}
/// Task names that callers may not use: `validate_name` rejects any of them
/// with `SpawnError::ReservedName`.
const RESERVED_NAMES: &[&str] = &["parent", "self", "all", "broadcast"];
/// Reasons why a task could not be spawned under the requested name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpawnError {
    /// The requested name is one of the reserved names; carries that name.
    ReservedName(String),
    /// A non-terminal task of the same owner already uses the name; carries that name.
    DuplicateName(String),
}

impl std::fmt::Display for SpawnError {
    /// Writes the human-readable description, quoting the offending name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReservedName(n) => f.write_fmt(format_args!("'{n}' is a reserved name")),
            Self::DuplicateName(n) => {
                f.write_fmt(format_args!("a running task named '{n}' already exists"))
            }
        }
    }
}

impl std::error::Error for SpawnError {}
/// In-memory record the manager keeps for each task it has spawned.
struct TaskHandle {
/// Identifier of the task; the same value is used as the key in the manager's handle map.
task_id: TaskId,
/// Thread id of the caller that spawned the task; used for ownership checks.
owner_thread_id: String,
/// Handle whose `cancel()` is invoked to request cancellation of the task.
cancel_handle: CancellationHandle,
/// Join handle returned by `tokio::spawn` for the task's future.
_join_handle: JoinHandle<()>,
/// Sender for the task's own inbox; `Some` only for tasks created via `spawn_agent`.
agent_inbox: Option<InboxSender>,
}
/// Spawns background tasks, tracks their handles in memory and mirrors their
/// metadata into an optional `StateStore`.
pub struct BackgroundTaskManager {
/// Handles of spawned tasks, keyed by task id.
handles: Mutex<HashMap<TaskId, TaskHandle>>,
/// Source of the numeric suffix used when generating `bg_{n}` task ids.
counter: AtomicU64,
/// Inbox that receives task completion events, if one has been set.
owner_inbox: std::sync::RwLock<Option<InboxSender>>,
/// Store that task metadata is committed to; can be set at most once.
store: std::sync::OnceLock<StateStore>,
}
impl BackgroundTaskManager {
/// Creates a manager with no registered tasks, no owner inbox and no state
/// store. The id counter starts at zero, so the first generated id is `bg_0`.
pub fn new() -> Self {
    let handles = Mutex::new(HashMap::new());
    let counter = AtomicU64::new(0);
    let owner_inbox = std::sync::RwLock::new(None);
    let store = std::sync::OnceLock::new();
    Self {
        handles,
        counter,
        owner_inbox,
        store,
    }
}
/// Installs (or replaces) the inbox that receives task completion events.
///
/// # Panics
/// Panics if the internal `RwLock` is poisoned.
pub fn set_owner_inbox(&self, inbox: InboxSender) {
    let mut slot = self.owner_inbox.write().expect("owner_inbox poisoned");
    *slot = Some(inbox);
}
/// Installs the state store used to persist task metadata.
///
/// The store lives in a `OnceLock`, so only the first call takes effect; a
/// later call is silently ignored and its argument is dropped.
pub fn set_store(&self, store: StateStore) {
    // The outcome is intentionally unused: a rejected second store is not an error here.
    let _already_installed = self.store.set(store).is_err();
}
/// Checks whether `name` may be used for a new task owned by `owner_thread_id`.
///
/// # Errors
/// * `SpawnError::ReservedName` if `name` is listed in `RESERVED_NAMES`.
/// * `SpawnError::DuplicateName` if the store holds a non-terminal task of
///   the same owner with the same name.
///
/// When no store is installed, or it has no background-task state, only the
/// reserved-name check is performed.
fn validate_name(&self, name: &str, owner_thread_id: &str) -> Result<(), SpawnError> {
    if RESERVED_NAMES.iter().any(|reserved| *reserved == name) {
        return Err(SpawnError::ReservedName(name.to_string()));
    }

    let Some(snapshot) = self
        .store()
        .and_then(|store| store.read::<BackgroundTaskStateKey>())
    else {
        return Ok(());
    };

    let name_in_use = snapshot.tasks.values().any(|task| {
        task.owner_thread_id == owner_thread_id
            && !task.status.is_terminal()
            && task.name.as_deref() == Some(name)
    });
    if name_in_use {
        Err(SpawnError::DuplicateName(name.to_string()))
    } else {
        Ok(())
    }
}
fn store(&self) -> Option<&StateStore> {
self.store.get()
}
/// Returns a clone of the currently installed owner inbox, if any.
///
/// # Panics
/// Panics if the internal `RwLock` is poisoned.
fn owner_inbox(&self) -> Option<InboxSender> {
    let slot = self.owner_inbox.read().expect("owner_inbox poisoned");
    (*slot).clone()
}
/// Allocates the next task id of the form `bg_{n}`, where `n` is the counter
/// value before it is incremented.
fn next_task_id(&self) -> TaskId {
    format!("bg_{n}", n = self.counter.fetch_add(1, Ordering::Relaxed))
}
/// Applies `action` to the background-task state in the store as a
/// single-update batch.
///
/// Does nothing when no store is installed. The result of the commit is
/// deliberately ignored, so persistence is best-effort.
fn commit_meta(&self, action: BackgroundTaskStateAction) {
    let Some(store) = self.store() else {
        return;
    };
    let mut batch = MutationBatch::new();
    batch.update::<BackgroundTaskStateKey>(action);
    let _ = store.commit(batch);
}
pub async fn spawn<F, Fut>(
self: &Arc<Self>,
owner_thread_id: &str,
task_type: &str,
name: Option<&str>,
description: &str,
parent_context: TaskParentContext,
task_fn: F,
) -> Result<TaskId, SpawnError>
where
F: FnOnce(TaskContext) -> Fut + Send + 'static,
Fut: std::future::Future<Output = TaskResult> + Send + 'static,
{
if let Some(n) = name {
self.validate_name(n, owner_thread_id)?;
}
let task_id = self.next_task_id();
let (cancel_handle, cancel_token) = CancellationToken::new_pair();
let now = now_ms();
let ctx = TaskContext {
task_id: task_id.clone(),
cancel_token,
inbox: self.owner_inbox(),
};
let task_name = name.map(|n| n.to_string());
self.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(
PersistedTaskMeta {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
task_type: task_type.to_string(),
name: task_name.clone(),
description: description.to_string(),
status: TaskStatus::Running,
error: None,
result: None,
created_at_ms: now,
completed_at_ms: None,
parent_context: parent_context.clone(),
},
)));
let manager = Arc::clone(self);
let tid = task_id.clone();
let owner_inbox = self.owner_inbox();
let owner = owner_thread_id.to_string();
let ttype = task_type.to_string();
let tname = task_name.clone();
let desc = description.to_string();
let join_handle = tokio::spawn(async move {
let result = task_fn(ctx).await;
let completed_at = now_ms();
let (status, error, result_val) = match &result {
TaskResult::Success(val) => (TaskStatus::Completed, None, Some(val.clone())),
TaskResult::Failed(err) => (TaskStatus::Failed, Some(err.clone()), None),
TaskResult::Cancelled => (TaskStatus::Cancelled, None, None),
};
manager.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(
PersistedTaskMeta {
task_id: tid.clone(),
owner_thread_id: owner,
task_type: ttype,
name: tname,
description: desc,
status,
error,
result: result_val,
created_at_ms: now,
completed_at_ms: Some(completed_at),
parent_context,
},
)));
if let Some(ref inbox) = owner_inbox {
let event = match &result {
TaskResult::Success(val) => TaskEvent::Completed {
task_id: tid.clone(),
result: Some(val.clone()),
},
TaskResult::Failed(err) => TaskEvent::Failed {
task_id: tid.clone(),
error: err.clone(),
},
TaskResult::Cancelled => TaskEvent::Cancelled {
task_id: tid.clone(),
},
};
inbox.send(serde_json::to_value(&event).unwrap_or_default());
}
});
let handle = TaskHandle {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
cancel_handle,
_join_handle: join_handle,
agent_inbox: None,
};
self.handles.lock().await.insert(task_id.clone(), handle);
Ok(task_id)
}
pub async fn cancel(&self, task_id: &str) -> bool {
let handles = self.handles.lock().await;
if let Some(handle) = handles.get(task_id) {
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
&& let Some(meta) = snap.tasks.get(task_id)
&& meta.status.is_terminal()
{
return false;
}
handle.cancel_handle.cancel();
return true;
}
false
}
/// Requests cancellation of every still-active task owned by `owner_thread_id`.
///
/// A task is skipped when its spawned future has already finished or when
/// its persisted status is terminal. Returns the number of tasks whose
/// cancellation handle was triggered.
///
/// The `is_finished` check prevents finished tasks from being counted when
/// no store is installed or when a task ended without a terminal record.
pub async fn cancel_all(&self, owner_thread_id: &str) -> usize {
    let handles = self.handles.lock().await;
    // Read the persisted state once; it is consulted for every owned handle.
    let store_snap = self
        .store()
        .and_then(|s| s.read::<BackgroundTaskStateKey>());

    let mut count = 0;
    for handle in handles
        .values()
        .filter(|h| h.owner_thread_id == owner_thread_id)
    {
        if handle._join_handle.is_finished() {
            continue;
        }
        let is_terminal = store_snap
            .as_ref()
            .and_then(|snap| snap.tasks.get(&handle.task_id))
            .is_some_and(|m| m.status.is_terminal());
        if is_terminal {
            continue;
        }
        handle.cancel_handle.cancel();
        count += 1;
    }
    count
}
/// Returns a summary of every persisted task owned by `owner_thread_id`.
///
/// The result is empty when no store is installed or the store holds no
/// background-task state. Order follows the iteration order of the stored map.
pub async fn list(&self, owner_thread_id: &str) -> Vec<TaskSummary> {
    let Some(snapshot) = self
        .store()
        .and_then(|store| store.read::<BackgroundTaskStateKey>())
    else {
        return Vec::new();
    };

    let mut summaries = Vec::new();
    for meta in snapshot.tasks.values() {
        if meta.owner_thread_id == owner_thread_id {
            summaries.push(Self::meta_to_summary(meta));
        }
    }
    summaries
}
/// Looks up the persisted record for `task_id` and returns its summary.
///
/// Returns `None` when no store is installed, the store holds no
/// background-task state, or the id is unknown.
pub async fn get(&self, task_id: &str) -> Option<TaskSummary> {
    let snapshot = self.store()?.read::<BackgroundTaskStateKey>()?;
    let meta = snapshot.tasks.get(task_id)?;
    Some(Self::meta_to_summary(meta))
}
fn meta_to_summary(m: &PersistedTaskMeta) -> TaskSummary {
TaskSummary {
task_id: m.task_id.clone(),
task_type: m.task_type.clone(),
description: m.description.clone(),
status: m.status,
error: m.error.clone(),
result: m.result.clone(),
created_at_ms: m.created_at_ms,
completed_at_ms: m.completed_at_ms,
parent_context: m.parent_context.clone(),
}
}
/// Merges the tasks of a previously persisted `snapshot` into the store on
/// behalf of `owner_thread_id`.
///
/// For every task in `snapshot`:
/// * the id counter is raised above the task's numeric `bg_{n}` suffix, so
///   ids generated later cannot collide with a restored id. This happens
///   even for tasks the store already knows; otherwise a fresh manager
///   attached to a pre-populated store could re-issue an existing id and
///   overwrite that record;
/// * tasks already present in the store are otherwise left untouched;
/// * the remaining tasks are stored with `owner_thread_id` as their owner. A
///   task recorded as `Running` that has no handle in this manager is stored
///   as `Failed` with an "orphaned" error message instead.
///
/// Does nothing when no store is installed.
pub(crate) async fn restore_for_thread(
    &self,
    owner_thread_id: &str,
    snapshot: &BackgroundTaskStateSnapshot,
) {
    let Some(store) = self.store() else {
        return;
    };
    let existing = store.read::<BackgroundTaskStateKey>().unwrap_or_default();

    for (task_id, meta) in &snapshot.tasks {
        // Reserve the id before deciding whether the record needs restoring.
        if let Some(n) = task_id
            .strip_prefix("bg_")
            .and_then(|s| s.parse::<u64>().ok())
        {
            self.counter
                .fetch_max(n.saturating_add(1), Ordering::Relaxed);
        }

        if existing.tasks.contains_key(task_id) {
            continue;
        }

        // The lock guard is a temporary and is released at the end of this statement.
        let has_live_handle = self.handles.lock().await.contains_key(task_id);

        let mut to_store = meta.clone();
        to_store.owner_thread_id = owner_thread_id.to_string();
        if meta.status == TaskStatus::Running && !has_live_handle {
            to_store.status = TaskStatus::Failed;
            to_store.error =
                Some("task orphaned: runtime restarted while running".to_string());
        }
        self.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(to_store)));
    }
}
pub async fn has_running(&self, owner_thread_id: &str) -> bool {
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
{
return snap
.tasks
.values()
.any(|m| m.owner_thread_id == owner_thread_id && !m.status.is_terminal());
}
self.handles
.lock()
.await
.values()
.any(|h| h.owner_thread_id == owner_thread_id)
}
pub async fn spawn_agent<F, Fut>(
self: &Arc<Self>,
owner_thread_id: &str,
name: Option<&str>,
description: &str,
parent_context: TaskParentContext,
task_fn: F,
) -> Result<TaskId, SpawnError>
where
F: FnOnce(CancellationToken, InboxSender, crate::inbox::InboxReceiver) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = TaskResult> + Send + 'static,
{
if let Some(n) = name {
self.validate_name(n, owner_thread_id)?;
}
let task_id = self.next_task_id();
let (cancel_handle, cancel_token) = CancellationToken::new_pair();
let now = now_ms();
let (child_inbox_tx, child_inbox_rx) = crate::inbox::inbox_channel();
let stored_sender = child_inbox_tx.clone();
let task_name = name.map(|n| n.to_string());
self.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(
PersistedTaskMeta {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
task_type: "sub_agent".to_string(),
name: task_name.clone(),
description: description.to_string(),
status: TaskStatus::Running,
error: None,
result: None,
created_at_ms: now,
completed_at_ms: None,
parent_context: parent_context.clone(),
},
)));
let manager = Arc::clone(self);
let tid = task_id.clone();
let owner_inbox = self.owner_inbox();
let owner = owner_thread_id.to_string();
let tname = task_name.clone();
let desc = description.to_string();
let join_handle = tokio::spawn(async move {
let result = task_fn(cancel_token, child_inbox_tx, child_inbox_rx).await;
let completed_at = now_ms();
let (status, error, result_val) = match &result {
TaskResult::Success(val) => (TaskStatus::Completed, None, Some(val.clone())),
TaskResult::Failed(err) => (TaskStatus::Failed, Some(err.clone()), None),
TaskResult::Cancelled => (TaskStatus::Cancelled, None, None),
};
manager.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(
PersistedTaskMeta {
task_id: tid.clone(),
owner_thread_id: owner,
task_type: "sub_agent".to_string(),
name: tname,
description: desc,
status,
error,
result: result_val,
created_at_ms: now,
completed_at_ms: Some(completed_at),
parent_context,
},
)));
let event = match &result {
TaskResult::Success(val) => TaskEvent::Completed {
task_id: tid.clone(),
result: Some(val.clone()),
},
TaskResult::Failed(err) => TaskEvent::Failed {
task_id: tid.clone(),
error: err.clone(),
},
TaskResult::Cancelled => TaskEvent::Cancelled {
task_id: tid.clone(),
},
};
if let Some(ref inbox) = owner_inbox {
inbox.send(serde_json::to_value(&event).unwrap_or_default());
}
});
let handle = TaskHandle {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
cancel_handle,
_join_handle: join_handle,
agent_inbox: Some(stored_sender),
};
self.handles.lock().await.insert(task_id.clone(), handle);
Ok(task_id)
}
/// Delivers a text message from `sender_agent_id` to the inbox of the task
/// registered under `task_id`.
///
/// The message is wrapped in a `TaskEvent::Custom` with event type
/// `"agent_message"` and a payload holding `from` and `content`, then sent
/// as JSON.
///
/// # Errors
/// Checked in this order:
/// * `SendError::TaskNotFound` if no handle is registered for `task_id`;
/// * `SendError::NotOwner` if the handle's owner differs from `owner_thread_id`;
/// * `SendError::TaskTerminated` if the persisted status is terminal;
/// * `SendError::NoInbox` if the task has no inbox sender;
/// * `SendError::InboxClosed` if the sender reports the message was not accepted.
pub async fn send_task_inbox_message(
    &self,
    task_id: &str,
    owner_thread_id: &str,
    sender_agent_id: &str,
    content: &str,
) -> Result<(), SendError> {
    let handles = self.handles.lock().await;
    let Some(handle) = handles.get(task_id) else {
        return Err(SendError::TaskNotFound);
    };
    if handle.owner_thread_id != owner_thread_id {
        return Err(SendError::NotOwner);
    }

    let terminal_status = self
        .store()
        .and_then(|store| store.read::<BackgroundTaskStateKey>())
        .and_then(|snap| snap.tasks.get(task_id).map(|meta| meta.status))
        .filter(|status| status.is_terminal());
    if let Some(status) = terminal_status {
        return Err(SendError::TaskTerminated(status));
    }

    let Some(inbox) = handle.agent_inbox.as_ref() else {
        return Err(SendError::NoInbox);
    };

    let event = TaskEvent::Custom {
        task_id: task_id.to_string(),
        event_type: "agent_message".to_string(),
        payload: serde_json::json!({
            "from": sender_agent_id,
            "content": content,
        }),
    };
    let message = serde_json::to_value(&event).unwrap_or_default();
    if inbox.send(message) {
        Ok(())
    } else {
        Err(SendError::InboxClosed)
    }
}
/// Test-only helper returning the persisted task map, or an empty map when
/// no store is installed or it holds no background-task state.
#[cfg(test)]
pub(crate) async fn persisted_snapshot(&self) -> HashMap<TaskId, PersistedTaskMeta> {
    self.store()
        .and_then(|store| store.read::<BackgroundTaskStateKey>())
        .map(|snap| snap.tasks)
        .unwrap_or_default()
}
}
impl Default for BackgroundTaskManager {
fn default() -> Self {
Self::new()
}
}
use awaken_contract::now_ms;