use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use awaken_contract::StateError;
use parking_lot::RwLock;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use crate::cancellation::{CancellationHandle, CancellationToken};
use crate::inbox::InboxSender;
use crate::state::{MutationBatch, StateStore};
use super::state::{
BackgroundTaskStateAction, BackgroundTaskStateKey, BackgroundTaskStateSnapshot,
PersistedTaskMeta,
};
use super::types::{
AgentTaskContext, TaskContext, TaskEvent, TaskId, TaskParentContext, TaskResult, TaskStatus,
TaskSummary,
};
use super::{
BackgroundTaskExecutionContext, current_background_task_context, current_tool_lineage_context,
scope_background_task_context,
};
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendError {
TaskNotFound,
NotOwner,
TaskTerminated(TaskStatus),
NoInbox,
InboxClosed,
}
impl std::fmt::Display for SendError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TaskNotFound => write!(f, "task not found"),
Self::NotOwner => write!(f, "caller does not own this task"),
Self::TaskTerminated(s) => write!(f, "task already {}", s.as_str()),
Self::NoInbox => write!(f, "task has no inbox (not a sub-agent)"),
Self::InboxClosed => write!(f, "sub-agent inbox closed"),
}
}
}
impl std::error::Error for SendError {}
const RESERVED_NAMES: &[&str] = &["parent", "self", "all", "broadcast"];
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum SpawnError {
#[error("'{0}' is a reserved name")]
ReservedName(String),
#[error("a running task named '{0}' already exists")]
DuplicateName(String),
#[error("background task state store is not configured")]
StoreNotConfigured,
#[error(transparent)]
State(#[from] StateError),
}
#[derive(Debug, thiserror::Error)]
enum MetaCommitError {
#[error("background task state store is not configured")]
StoreUnavailable,
#[error(transparent)]
State(#[from] StateError),
}
impl From<MetaCommitError> for SpawnError {
fn from(err: MetaCommitError) -> Self {
match err {
MetaCommitError::StoreUnavailable => Self::StoreNotConfigured,
MetaCommitError::State(e) => Self::State(e),
}
}
}
struct TaskHandle {
task_id: TaskId,
owner_thread_id: String,
cancel_handle: CancellationHandle,
_join_handle: JoinHandle<()>,
agent_inbox: Option<InboxSender>,
}
pub struct BackgroundTaskManager {
handles: Mutex<HashMap<TaskId, TaskHandle>>,
counter: AtomicU64,
owner_inbox: RwLock<Option<InboxSender>>,
store: std::sync::OnceLock<StateStore>,
}
impl BackgroundTaskManager {
pub fn new() -> Self {
Self {
handles: Mutex::new(HashMap::new()),
counter: AtomicU64::new(0),
owner_inbox: RwLock::new(None),
store: std::sync::OnceLock::new(),
}
}
pub fn set_owner_inbox(&self, inbox: InboxSender) {
*self.owner_inbox.write() = Some(inbox);
}
pub fn set_store(&self, store: StateStore) {
let _ = self.store.set(store);
}
fn validate_name(&self, name: &str, owner_thread_id: &str) -> Result<(), SpawnError> {
if RESERVED_NAMES.contains(&name) {
return Err(SpawnError::ReservedName(name.to_string()));
}
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
{
for meta in snap.tasks.values() {
if meta.owner_thread_id == owner_thread_id
&& !meta.status.is_terminal()
&& meta.name.as_deref() == Some(name)
{
return Err(SpawnError::DuplicateName(name.to_string()));
}
}
}
Ok(())
}
fn store(&self) -> Option<&StateStore> {
self.store.get()
}
fn owner_inbox(&self) -> Option<InboxSender> {
self.owner_inbox.read().clone()
}
#[cfg(test)]
pub(crate) fn panic_while_holding_owner_inbox_lock_for_test(&self) {
let _guard = self.owner_inbox.write();
panic!("owner_inbox lock test panic");
}
#[cfg(test)]
pub(crate) fn has_owner_inbox_for_test(&self) -> bool {
self.owner_inbox().is_some()
}
fn next_task_id(&self) -> TaskId {
let n = self.counter.fetch_add(1, Ordering::Relaxed);
format!("bg_{n}")
}
fn merge_ambient_parent_context(
&self,
mut parent_context: TaskParentContext,
) -> TaskParentContext {
if parent_context.task_id.is_none()
&& let Some(context) = current_background_task_context()
{
parent_context.task_id = Some(context.task_id);
}
if let Some(context) = current_tool_lineage_context() {
if parent_context.run_id.is_none() {
parent_context.run_id = Some(context.run_id);
}
if parent_context.call_id.is_none() && !context.call_id.is_empty() {
parent_context.call_id = Some(context.call_id);
}
if parent_context.agent_id.is_none() && !context.agent_id.is_empty() {
parent_context.agent_id = Some(context.agent_id);
}
}
parent_context
}
fn commit_meta(&self, action: BackgroundTaskStateAction) -> Result<u64, MetaCommitError> {
let Some(store) = self.store() else {
return Err(MetaCommitError::StoreUnavailable);
};
let mut batch = MutationBatch::new();
batch.update::<BackgroundTaskStateKey>(action);
store.commit(batch).map_err(Into::into)
}
fn commit_meta_or_warn(
&self,
action: BackgroundTaskStateAction,
operation: &'static str,
task_id: &str,
) {
if let Err(error) = self.commit_meta(action) {
metrics::counter!(
"awaken_background_task_state_commit_failures_total",
"operation" => operation
)
.increment(1);
tracing::warn!(
operation,
task_id,
error = %error,
"background task metadata commit failed"
);
}
}
fn terminal_event(task_id: &str, result: &TaskResult) -> TaskEvent {
match result {
TaskResult::Success(val) => TaskEvent::Completed {
task_id: task_id.to_string(),
result: Some(val.clone()),
},
TaskResult::Failed(err) => TaskEvent::Failed {
task_id: task_id.to_string(),
error: err.clone(),
},
TaskResult::Cancelled => TaskEvent::Cancelled {
task_id: task_id.to_string(),
},
}
}
pub async fn spawn<F, Fut>(
self: &Arc<Self>,
owner_thread_id: &str,
task_type: &str,
name: Option<&str>,
description: &str,
parent_context: TaskParentContext,
task_fn: F,
) -> Result<TaskId, SpawnError>
where
F: FnOnce(TaskContext) -> Fut + Send + 'static,
Fut: std::future::Future<Output = TaskResult> + Send + 'static,
{
let parent_context = self.merge_ambient_parent_context(parent_context);
if let Some(n) = name {
self.validate_name(n, owner_thread_id)?;
}
let task_id = self.next_task_id();
let (cancel_handle, cancel_token) = CancellationToken::new_pair();
let now = now_ms();
let ctx = TaskContext {
task_id: task_id.clone(),
cancel_token,
inbox: self.owner_inbox(),
};
let task_name = name.map(|n| n.to_string());
self.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(
PersistedTaskMeta {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
task_type: task_type.to_string(),
name: task_name.clone(),
description: description.to_string(),
status: TaskStatus::Running,
error: None,
result: None,
created_at_ms: now,
completed_at_ms: None,
parent_context: parent_context.clone(),
},
)))
.map_err(SpawnError::from)?;
let manager = Arc::clone(self);
let tid = task_id.clone();
let owner_inbox = self.owner_inbox();
let owner = owner_thread_id.to_string();
let ttype = task_type.to_string();
let tname = task_name.clone();
let desc = description.to_string();
let join_handle = tokio::spawn(async move {
let result = scope_background_task_context(
BackgroundTaskExecutionContext {
manager: manager.clone(),
task_id: tid.clone(),
},
task_fn(ctx),
)
.await;
let completed_at = now_ms();
let (status, error, result_val) = match &result {
TaskResult::Success(val) => (TaskStatus::Completed, None, Some(val.clone())),
TaskResult::Failed(err) => (TaskStatus::Failed, Some(err.clone()), None),
TaskResult::Cancelled => (TaskStatus::Cancelled, None, None),
};
manager.commit_meta_or_warn(
BackgroundTaskStateAction::Upsert(Box::new(PersistedTaskMeta {
task_id: tid.clone(),
owner_thread_id: owner,
task_type: ttype,
name: tname,
description: desc,
status,
error,
result: result_val,
created_at_ms: now,
completed_at_ms: Some(completed_at),
parent_context,
})),
"task_completion",
&tid,
);
if let Some(ref inbox) = owner_inbox {
let event = Self::terminal_event(&tid, &result);
inbox.send(
serde_json::to_value(&event).expect("TaskEvent serialization is infallible"),
);
}
});
let handle = TaskHandle {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
cancel_handle,
_join_handle: join_handle,
agent_inbox: None,
};
self.handles.lock().await.insert(task_id.clone(), handle);
Ok(task_id)
}
pub async fn cancel(&self, task_id: &str) -> bool {
let handles = self.handles.lock().await;
if let Some(handle) = handles.get(task_id) {
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
&& let Some(meta) = snap.tasks.get(task_id)
&& meta.status.is_terminal()
{
return false;
}
handle.cancel_handle.cancel();
return true;
}
false
}
pub async fn cancel_tree(&self, task_id: &str) -> usize {
let Some(task_ids) = self.task_tree_ids(task_id) else {
return 0;
};
let handles = self.handles.lock().await;
let store_snap = self
.store()
.and_then(|s| s.read::<BackgroundTaskStateKey>());
let mut count = 0;
for task_id in task_ids {
let Some(handle) = handles.get(&task_id) else {
continue;
};
let is_terminal = store_snap
.as_ref()
.and_then(|snap| snap.tasks.get(&task_id))
.map(|meta| meta.status.is_terminal())
.unwrap_or(false);
if !is_terminal {
handle.cancel_handle.cancel();
count += 1;
}
}
count
}
pub async fn cancel_all(&self, owner_thread_id: &str) -> usize {
let handles = self.handles.lock().await;
let store_snap = self
.store()
.and_then(|s| s.read::<BackgroundTaskStateKey>());
let mut count = 0;
for handle in handles.values() {
if handle.owner_thread_id != owner_thread_id {
continue;
}
let is_terminal = store_snap
.as_ref()
.and_then(|snap| snap.tasks.get(&handle.task_id))
.map(|m| m.status.is_terminal())
.unwrap_or(false);
if !is_terminal {
handle.cancel_handle.cancel();
count += 1;
}
}
count
}
pub async fn list(&self, owner_thread_id: &str) -> Vec<TaskSummary> {
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
{
return snap
.tasks
.values()
.filter(|m| m.owner_thread_id == owner_thread_id)
.map(Self::meta_to_summary)
.collect();
}
Vec::new()
}
pub async fn get(&self, task_id: &str) -> Option<TaskSummary> {
self.store()
.and_then(|s| s.read::<BackgroundTaskStateKey>())
.and_then(|snap| snap.tasks.get(task_id).map(Self::meta_to_summary))
}
fn meta_to_summary(m: &PersistedTaskMeta) -> TaskSummary {
TaskSummary {
task_id: m.task_id.clone(),
task_type: m.task_type.clone(),
description: m.description.clone(),
status: m.status,
error: m.error.clone(),
result: m.result.clone(),
created_at_ms: m.created_at_ms,
completed_at_ms: m.completed_at_ms,
parent_context: m.parent_context.clone(),
}
}
pub(crate) async fn restore_for_thread(
&self,
owner_thread_id: &str,
snapshot: &BackgroundTaskStateSnapshot,
) {
if let Some(store) = self.store() {
let existing = store.read::<BackgroundTaskStateKey>().unwrap_or_default();
for (task_id, meta) in &snapshot.tasks {
if existing.tasks.contains_key(task_id) {
continue;
}
if let Some(n) = task_id
.strip_prefix("bg_")
.and_then(|s| s.parse::<u64>().ok())
{
self.counter
.fetch_max(n.saturating_add(1), Ordering::Relaxed);
}
let handles = self.handles.lock().await;
let has_live_handle = handles.contains_key(task_id);
drop(handles);
let mut to_store = meta.clone();
to_store.owner_thread_id = owner_thread_id.to_string();
if meta.status == TaskStatus::Running && !has_live_handle {
to_store.status = TaskStatus::Failed;
to_store.error =
Some("task orphaned: runtime restarted while running".to_string());
}
self.commit_meta_or_warn(
BackgroundTaskStateAction::Upsert(Box::new(to_store)),
"restore_task_metadata",
task_id,
);
}
}
}
pub async fn has_running(&self, owner_thread_id: &str) -> bool {
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
{
return snap
.tasks
.values()
.any(|m| m.owner_thread_id == owner_thread_id && !m.status.is_terminal());
}
self.handles
.lock()
.await
.values()
.any(|h| h.owner_thread_id == owner_thread_id)
}
pub async fn spawn_agent<F, Fut>(
self: &Arc<Self>,
owner_thread_id: &str,
name: Option<&str>,
description: &str,
parent_context: TaskParentContext,
task_fn: F,
) -> Result<TaskId, SpawnError>
where
F: FnOnce(CancellationToken, InboxSender, crate::inbox::InboxReceiver) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = TaskResult> + Send + 'static,
{
self.spawn_agent_with_context(owner_thread_id, name, description, parent_context, |ctx| {
task_fn(ctx.cancel_token, ctx.inbox_sender, ctx.inbox_receiver)
})
.await
}
pub async fn spawn_agent_with_context<F, Fut>(
self: &Arc<Self>,
owner_thread_id: &str,
name: Option<&str>,
description: &str,
parent_context: TaskParentContext,
task_fn: F,
) -> Result<TaskId, SpawnError>
where
F: FnOnce(AgentTaskContext) -> Fut + Send + 'static,
Fut: std::future::Future<Output = TaskResult> + Send + 'static,
{
let parent_context = self.merge_ambient_parent_context(parent_context);
if let Some(n) = name {
self.validate_name(n, owner_thread_id)?;
}
let task_id = self.next_task_id();
let (cancel_handle, cancel_token) = CancellationToken::new_pair();
let now = now_ms();
let (child_inbox_tx, child_inbox_rx) = crate::inbox::inbox_channel();
let stored_sender = child_inbox_tx.clone();
let task_name = name.map(|n| n.to_string());
self.commit_meta(BackgroundTaskStateAction::Upsert(Box::new(
PersistedTaskMeta {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
task_type: "sub_agent".to_string(),
name: task_name.clone(),
description: description.to_string(),
status: TaskStatus::Running,
error: None,
result: None,
created_at_ms: now,
completed_at_ms: None,
parent_context: parent_context.clone(),
},
)))
.map_err(SpawnError::from)?;
let manager = Arc::clone(self);
let tid = task_id.clone();
let owner_inbox = self.owner_inbox();
let owner = owner_thread_id.to_string();
let tname = task_name.clone();
let desc = description.to_string();
let join_handle = tokio::spawn(async move {
let result = scope_background_task_context(
BackgroundTaskExecutionContext {
manager: manager.clone(),
task_id: tid.clone(),
},
task_fn(AgentTaskContext {
task_id: tid.clone(),
cancel_token,
inbox_sender: child_inbox_tx,
inbox_receiver: child_inbox_rx,
}),
)
.await;
let completed_at = now_ms();
let (status, error, result_val) = match &result {
TaskResult::Success(val) => (TaskStatus::Completed, None, Some(val.clone())),
TaskResult::Failed(err) => (TaskStatus::Failed, Some(err.clone()), None),
TaskResult::Cancelled => (TaskStatus::Cancelled, None, None),
};
manager.commit_meta_or_warn(
BackgroundTaskStateAction::Upsert(Box::new(PersistedTaskMeta {
task_id: tid.clone(),
owner_thread_id: owner,
task_type: "sub_agent".to_string(),
name: tname,
description: desc,
status,
error,
result: result_val,
created_at_ms: now,
completed_at_ms: Some(completed_at),
parent_context,
})),
"sub_agent_completion",
&tid,
);
let event = Self::terminal_event(&tid, &result);
if let Some(ref inbox) = owner_inbox {
inbox.send(
serde_json::to_value(&event).expect("TaskEvent serialization is infallible"),
);
}
});
let handle = TaskHandle {
task_id: task_id.clone(),
owner_thread_id: owner_thread_id.to_string(),
cancel_handle,
_join_handle: join_handle,
agent_inbox: Some(stored_sender),
};
self.handles.lock().await.insert(task_id.clone(), handle);
Ok(task_id)
}
pub async fn send_task_inbox_message(
&self,
task_id: &str,
owner_thread_id: &str,
sender_agent_id: &str,
content: &str,
) -> Result<(), SendError> {
let handles = self.handles.lock().await;
let handle = handles.get(task_id).ok_or(SendError::TaskNotFound)?;
if handle.owner_thread_id != owner_thread_id {
return Err(SendError::NotOwner);
}
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
&& let Some(meta) = snap.tasks.get(task_id)
&& meta.status.is_terminal()
{
return Err(SendError::TaskTerminated(meta.status));
}
let inbox = handle.agent_inbox.as_ref().ok_or(SendError::NoInbox)?;
let event = TaskEvent::Custom {
task_id: task_id.to_string(),
event_type: "agent_message".to_string(),
payload: serde_json::json!({
"from": sender_agent_id,
"content": content,
}),
};
if inbox.send(serde_json::to_value(&event).expect("TaskEvent serialization is infallible"))
{
Ok(())
} else {
Err(SendError::InboxClosed)
}
}
pub(crate) fn task_tree_ids(&self, task_id: &str) -> Option<Vec<TaskId>> {
let snapshot = self
.store()
.and_then(|store| store.read::<BackgroundTaskStateKey>())?;
if !snapshot.tasks.contains_key(task_id) {
return None;
}
let mut ordered = Vec::new();
let mut stack = vec![task_id.to_string()];
while let Some(current) = stack.pop() {
if ordered.iter().any(|seen| seen == ¤t) {
continue;
}
ordered.push(current.clone());
for meta in snapshot.tasks.values() {
if meta.parent_context.task_id.as_deref() == Some(current.as_str()) {
stack.push(meta.task_id.clone());
}
}
}
Some(ordered)
}
pub(crate) fn resolve_live_child_task(
&self,
parent_task_id: &str,
name_or_task_id: &str,
) -> Option<TaskId> {
let snapshot = self.store()?.read::<BackgroundTaskStateKey>()?;
for meta in snapshot.tasks.values() {
if meta.status.is_terminal() {
continue;
}
if meta.parent_context.task_id.as_deref() != Some(parent_task_id) {
continue;
}
if meta.task_id == name_or_task_id || meta.name.as_deref() == Some(name_or_task_id) {
return Some(meta.task_id.clone());
}
}
None
}
pub(crate) fn resolve_live_child_run(
&self,
parent_run_id: &str,
name_or_task_id: &str,
) -> Option<TaskId> {
let snapshot = self.store()?.read::<BackgroundTaskStateKey>()?;
for meta in snapshot.tasks.values() {
if meta.status.is_terminal() {
continue;
}
if meta.parent_context.run_id.as_deref() != Some(parent_run_id)
|| meta.parent_context.task_id.is_some()
{
continue;
}
if meta.task_id == name_or_task_id || meta.name.as_deref() == Some(name_or_task_id) {
return Some(meta.task_id.clone());
}
}
None
}
pub async fn cancel_descendants_for_run(&self, parent_run_id: &str) -> usize {
let root_task_ids = self
.store()
.and_then(|store| store.read::<BackgroundTaskStateKey>())
.map(|snapshot| {
snapshot
.tasks
.values()
.filter(|meta| {
!meta.status.is_terminal()
&& meta.parent_context.run_id.as_deref() == Some(parent_run_id)
&& meta.parent_context.task_id.is_none()
})
.map(|meta| meta.task_id.clone())
.collect::<Vec<_>>()
})
.unwrap_or_default();
let mut cancelled = 0usize;
for task_id in root_task_ids {
cancelled += self.cancel_tree(&task_id).await;
}
cancelled
}
#[cfg(test)]
pub(crate) async fn persisted_snapshot(&self) -> HashMap<TaskId, PersistedTaskMeta> {
if let Some(store) = self.store()
&& let Some(snap) = store.read::<BackgroundTaskStateKey>()
{
return snap.tasks;
}
HashMap::new()
}
}
impl Default for BackgroundTaskManager {
fn default() -> Self {
Self::new()
}
}
use awaken_contract::now_ms;