use std::sync::Arc;
use ahash::AHashMap;
use chrono::Utc;
use serde::Serialize;
use tokio::sync::broadcast;
use crate::a2a::core::bus::{Event, MessageBus};
use crate::a2a::core::task_types::{
ContextId, Task, TaskFilter, TaskId, TaskMessage, TaskState, TaskStatus,
};
use crate::a2a::core::types::AgentId;
/// Failures reported by [`TaskManager`] operations.
///
/// String fields carry pre-rendered values; the state fields
/// (`from`, `to`, `state`) hold the `Debug` rendering of a `TaskState`
/// as produced by `TaskManager::update_status`.
// Every variant shares the `Task` prefix, which trips clippy's
// `enum_variant_names`; the names are kept as-is (presumably for API
// stability with existing callers).
#[derive(Debug, thiserror::Error)]
#[allow(clippy::enum_variant_names)]
pub enum TaskError {
    /// No task is stored under the requested id.
    #[error("no task with id '{id}'")]
    TaskNotFound { id: String },

    /// The task's current state does not permit moving to the requested one.
    #[error("invalid task transition for '{task_id}': {from} -> {to}")]
    TaskInvalidTransition { task_id: String, from: String, to: String },

    /// The task is already in a terminal state, so it may no longer change.
    #[error("task '{task_id}' is in terminal state {state} and cannot be modified")]
    TaskAlreadyTerminal { task_id: String, state: String },
}
/// Lifecycle notifications emitted by [`TaskManager`].
///
/// Serializes as an internally tagged object: a `type` field holds the
/// snake_case variant name (`task_created` or `task_status_changed`).
// Both variants carry the `Task` prefix on purpose, which clippy's
// `enum_variant_names` would otherwise flag.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::enum_variant_names)]
pub enum TaskEvent {
    /// A task was registered; carries a shared snapshot of it.
    TaskCreated(Arc<Task>),

    /// A task moved from `old_state` to `new_state`; `task` is a snapshot
    /// taken immediately after the change was applied.
    TaskStatusChanged { task_id: TaskId, old_state: TaskState, new_state: TaskState, task: Arc<Task> },
}
/// In-memory registry of tasks that enforces state transitions and
/// announces every change on a shared bus and a local broadcast channel.
pub struct TaskManager {
    /// Every known task, keyed by its id.
    tasks: AHashMap<TaskId, Task>,
    /// Task ids grouped by context, in creation order (nothing in this impl
    /// removes entries).
    context_index: AHashMap<ContextId, Vec<TaskId>>,
    /// Bus that receives a mirrored `Event` for each task event.
    bus: Arc<MessageBus>,
    /// Sender half of the local task-event channel.
    // NOTE(review): no accessor handing out receivers is visible in this file;
    // confirm subscribers obtain one elsewhere, otherwise every send is dropped.
    event_tx: broadcast::Sender<TaskEvent>,
}

impl TaskManager {
    /// Creates an empty manager that mirrors its events onto `bus`.
    ///
    /// A broadcast channel (capacity 64) is opened for [`TaskEvent`]s; the
    /// receiver created with it is discarded immediately.
    pub fn new(bus: Arc<MessageBus>) -> Self {
        // Buffer size of the local task-event channel.
        const EVENT_CHANNEL_CAPACITY: usize = 64;
        let (event_tx, _) = broadcast::channel::<TaskEvent>(EVENT_CHANNEL_CAPACITY);
        Self {
            tasks: AHashMap::new(),
            context_index: AHashMap::new(),
            bus,
            event_tx,
        }
    }

    /// Registers a new task in the `Submitted` state and publishes
    /// `TaskCreated`.
    ///
    /// `message` becomes both the initial status message and the first
    /// history entry. When `context_id` is `None`, `ContextId::default()` is
    /// used. Returns a copy of the stored task.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn create_task_with_deadline(
        &mut self,
        message: TaskMessage,
        context_id: Option<ContextId>,
        assignee: Option<AgentId>,
        creator: Option<AgentId>,
        metadata: Option<serde_json::Value>,
        deadline: Option<chrono::DateTime<chrono::Utc>>,
    ) -> Result<Task, TaskError> {
        let task_id = TaskId::new();
        let context_id = context_id.unwrap_or_default();
        let created_at = Utc::now();

        let shared = Arc::new(Task {
            id: task_id,
            context_id,
            status: TaskStatus {
                state: TaskState::Submitted,
                message: Some(message.clone()),
                timestamp: created_at,
            },
            history: vec![message],
            artifacts: Vec::new(),
            assignee,
            creator,
            metadata,
            deadline,
        });

        self.context_index.entry(context_id).or_default().push(task_id);
        self.tasks.insert(task_id, Task::clone(&shared));
        self.publish(TaskEvent::TaskCreated(Arc::clone(&shared)));
        Ok(Task::clone(&shared))
    }

    /// Looks up a task by id.
    pub fn get(&self, id: &TaskId) -> Option<&Task> {
        self.tasks.get(id)
    }

    /// Returns every task matching all criteria set in `filter`.
    ///
    /// Unset (`None`) criteria match everything. When `context_id` is set,
    /// results come from the context index in creation order; otherwise they
    /// follow the task map's iteration order.
    pub fn list_filtered(&self, filter: &TaskFilter) -> Vec<&Task> {
        let keep = |task: &&Task| -> bool {
            let state_ok = match &filter.state {
                Some(wanted) => task.status.state == *wanted,
                None => true,
            };
            state_ok
                && match &filter.assignee {
                    Some(wanted) => task.assignee.as_ref() == Some(wanted),
                    None => true,
                }
        };

        match &filter.context_id {
            Some(ctx) => self
                .context_index
                .get(ctx)
                .into_iter()
                .flatten()
                .filter_map(|id| self.tasks.get(id))
                .filter(keep)
                .collect(),
            None => self.tasks.values().filter(keep).collect(),
        }
    }

    /// Moves a task to `new_state` and publishes `TaskStatusChanged`.
    ///
    /// The status message is replaced by `message` (a `None` clears it); a
    /// `Some` message is also appended to the history. Returns a copy of the
    /// updated task.
    ///
    /// # Errors
    ///
    /// Checked in this order, leaving the task untouched and publishing
    /// nothing on failure:
    /// * [`TaskError::TaskNotFound`] — no task has id `task_id`.
    /// * [`TaskError::TaskAlreadyTerminal`] — the current state is terminal.
    /// * [`TaskError::TaskInvalidTransition`] — the current state cannot move
    ///   to `new_state`.
    pub fn update_status(
        &mut self,
        task_id: &TaskId,
        new_state: TaskState,
        message: Option<TaskMessage>,
    ) -> Result<Task, TaskError> {
        let Some(entry) = self.tasks.get_mut(task_id) else {
            return Err(TaskError::TaskNotFound {
                id: task_id.to_string(),
            });
        };

        let old_state = entry.status.state;
        if old_state.is_terminal() {
            return Err(TaskError::TaskAlreadyTerminal {
                task_id: task_id.to_string(),
                state: format!("{old_state:?}"),
            });
        }
        if !old_state.can_transition_to(new_state) {
            return Err(TaskError::TaskInvalidTransition {
                task_id: task_id.to_string(),
                from: format!("{old_state:?}"),
                to: format!("{new_state:?}"),
            });
        }

        entry.status = TaskStatus {
            state: new_state,
            message: message.clone(),
            timestamp: Utc::now(),
        };
        // An `Option` yields zero or one item, so only a `Some` is appended.
        entry.history.extend(message);

        let shared = Arc::new(entry.clone());
        self.publish(TaskEvent::TaskStatusChanged {
            task_id: *task_id,
            old_state,
            new_state,
            task: Arc::clone(&shared),
        });
        Ok(Task::clone(&shared))
    }

    /// Moves a task to `Canceled`; see [`Self::update_status`] for the
    /// errors this can return.
    pub fn cancel(
        &mut self,
        task_id: &TaskId,
        message: Option<TaskMessage>,
    ) -> Result<Task, TaskError> {
        self.update_status(task_id, TaskState::Canceled, message)
    }

    /// Sends `event` to the bus (as its `Event` mirror) and then to the local
    /// broadcast channel, emitting a trace log when the channel has no
    /// receivers.
    fn publish(&self, event: TaskEvent) {
        let (bus_event, event_type) = match &event {
            TaskEvent::TaskCreated(task) => {
                (Event::TaskCreated(Arc::clone(task)), "task_created")
            }
            TaskEvent::TaskStatusChanged {
                task_id,
                old_state,
                new_state,
                task,
            } => (
                Event::TaskStatusChanged {
                    task_id: *task_id,
                    old_state: *old_state,
                    new_state: *new_state,
                    task: Arc::clone(task),
                },
                "task_status_changed",
            ),
        };

        self.bus.publish(bus_event);
        if self.event_tx.send(event).is_err() {
            tracing::trace!(event_type, "no task-event subscribers; event dropped");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::a2a::core::task_types::{MessageRole, Part};
    use crate::a2a::core::types::MessageId;

    /// Builds a manager backed by a fresh message bus.
    fn make_manager() -> TaskManager {
        let bus = Arc::new(MessageBus::new(16));
        TaskManager::new(bus)
    }

    /// Builds a single-part user text message.
    fn make_message() -> TaskMessage {
        TaskMessage {
            id: MessageId::new(),
            role: MessageRole::User,
            parts: vec![Part::Text {
                text: "hello".to_owned(),
            }],
            metadata: None,
        }
    }

    /// Creates a task with only a message and an optional context set.
    fn submit(mgr: &mut TaskManager, context_id: Option<ContextId>) -> Task {
        mgr.create_task_with_deadline(make_message(), context_id, None, None, None, None)
            .expect("create_task_with_deadline must succeed")
    }

    #[tokio::test]
    async fn create_task_succeeds() {
        let mut mgr = make_manager();
        let task = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("create_task_with_deadline must succeed");
        assert_eq!(
            task.status.state,
            TaskState::Submitted,
            "new task must start in Submitted state"
        );
        assert_eq!(
            task.history.len(),
            1,
            "history must contain the initial message"
        );
    }

    #[tokio::test]
    async fn create_task_generates_context_id_when_none() {
        let mut mgr = make_manager();
        let task1 = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("first create_task must succeed");
        let task2 = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("second create_task must succeed");
        assert_ne!(
            task1.context_id, task2.context_id,
            "each task with no explicit context_id must get a unique one"
        );
    }

    #[tokio::test]
    async fn get_returns_created_task() {
        let mut mgr = make_manager();
        let task = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("create_task must succeed");
        let found = mgr.get(&task.id).expect("get must return the created task");
        assert_eq!(found.id, task.id, "retrieved task id must match");
    }

    #[tokio::test]
    async fn list_filtered_by_state() {
        let mut mgr = make_manager();
        let task1 = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("first create must succeed");
        let task2 = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("second create must succeed");
        mgr.update_status(&task2.id, TaskState::Working, None)
            .expect("transition to Working must succeed");
        let filter = TaskFilter {
            state: Some(TaskState::Submitted),
            context_id: None,
            assignee: None,
        };
        let results = mgr.list_filtered(&filter);
        assert_eq!(results.len(), 1, "only one task should be Submitted");
        assert_eq!(results[0].id, task1.id, "the Submitted task must be task1");
    }

    #[tokio::test]
    async fn list_filtered_by_context_returns_tasks_in_creation_order() {
        let mut mgr = make_manager();
        let ctx = ContextId::default();
        let first = submit(&mut mgr, Some(ctx));
        let second = submit(&mut mgr, Some(ctx));
        // A task in a different context must not leak into the result.
        let _other = submit(&mut mgr, None);
        let filter = TaskFilter {
            state: None,
            context_id: Some(ctx),
            assignee: None,
        };
        let ids: Vec<TaskId> = mgr.list_filtered(&filter).iter().map(|t| t.id).collect();
        assert_eq!(
            ids,
            vec![first.id, second.id],
            "context filter must return exactly that context's tasks, oldest first"
        );
    }

    #[tokio::test]
    async fn list_filtered_unknown_context_is_empty() {
        let mut mgr = make_manager();
        let _existing = submit(&mut mgr, None);
        let filter = TaskFilter {
            state: None,
            context_id: Some(ContextId::default()),
            assignee: None,
        };
        assert!(
            mgr.list_filtered(&filter).is_empty(),
            "a context with no tasks must yield no results"
        );
    }

    #[tokio::test]
    async fn update_status_valid_transition() {
        let mut mgr = make_manager();
        let task = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("create must succeed");
        let updated = mgr
            .update_status(&task.id, TaskState::Working, None)
            .expect("Submitted → Working is a valid transition");
        assert_eq!(
            updated.status.state,
            TaskState::Working,
            "task must be in Working state after transition"
        );
    }

    #[tokio::test]
    async fn update_status_appends_message_to_history() {
        let mut mgr = make_manager();
        let task = submit(&mut mgr, None);
        let updated = mgr
            .update_status(&task.id, TaskState::Working, Some(make_message()))
            .expect("Submitted → Working is a valid transition");
        assert_eq!(
            updated.history.len(),
            2,
            "a status message must be appended to the history"
        );
        assert!(
            updated.status.message.is_some(),
            "the status must carry the supplied message"
        );
        let stored = mgr.get(&task.id).expect("task must still be stored");
        assert_eq!(
            stored.history.len(),
            2,
            "the stored task must reflect the update, not just the returned copy"
        );
    }

    #[tokio::test]
    async fn update_status_without_message_keeps_history() {
        let mut mgr = make_manager();
        let task = submit(&mut mgr, None);
        let updated = mgr
            .update_status(&task.id, TaskState::Working, None)
            .expect("Submitted → Working is a valid transition");
        assert_eq!(
            updated.history.len(),
            1,
            "no message must leave the history unchanged"
        );
        assert!(
            updated.status.message.is_none(),
            "a None message must clear the status message"
        );
    }

    #[tokio::test]
    async fn update_status_unknown_task_is_not_found() {
        let mut mgr = make_manager();
        let err = mgr
            .update_status(&TaskId::new(), TaskState::Working, None)
            .expect_err("updating an unknown task must fail");
        assert!(
            matches!(err, TaskError::TaskNotFound { .. }),
            "expected TaskNotFound, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn update_status_invalid_transition() {
        let mut mgr = make_manager();
        let task = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("create must succeed");
        let err = mgr
            .update_status(&task.id, TaskState::Completed, None)
            .expect_err("Submitted → Completed must be rejected");
        assert!(
            matches!(err, TaskError::TaskInvalidTransition { .. }),
            "expected TaskInvalidTransition, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn update_status_terminal_rejects() {
        let mut mgr = make_manager();
        let task = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("create must succeed");
        mgr.update_status(&task.id, TaskState::Working, None)
            .expect("Submitted → Working");
        mgr.update_status(&task.id, TaskState::Completed, None)
            .expect("Working → Completed");
        let err = mgr
            .update_status(&task.id, TaskState::Working, None)
            .expect_err("Completed → Working must be rejected");
        assert!(
            matches!(err, TaskError::TaskAlreadyTerminal { .. }),
            "expected TaskAlreadyTerminal, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn cancel_from_working_succeeds() {
        let mut mgr = make_manager();
        let task = mgr
            .create_task_with_deadline(make_message(), None, None, None, None, None)
            .expect("create must succeed");
        mgr.update_status(&task.id, TaskState::Working, None)
            .expect("Submitted → Working");
        let canceled = mgr
            .cancel(&task.id, None)
            .expect("cancel must succeed from Working");
        assert_eq!(
            canceled.status.state,
            TaskState::Canceled,
            "task must be in Canceled state after cancel()"
        );
    }
}