use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
use crate::adapter::business::push_notification::{
PushNotificationRegistry, PushNotificationSender,
};
#[cfg(feature = "http-client")]
use crate::adapter::business::push_notification::HttpPushNotificationSender;
#[cfg(not(feature = "http-client"))]
use crate::adapter::business::push_notification::NoopPushNotificationSender;
use crate::domain::{
A2AError, Artifact, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig,
TaskState, TaskStatus, TaskStatusUpdateEvent,
};
use crate::port::{
AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
streaming_handler::Subscriber,
};
/// Boxed `Send + Sync` subscribers that receive [`TaskStatusUpdateEvent`]s.
type StatusSubscribers = Vec<Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>>;
/// Boxed `Send + Sync` subscribers that receive [`TaskArtifactUpdateEvent`]s.
type ArtifactSubscribers = Vec<Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>>;
pub(crate) struct TaskSubscribers {
status: StatusSubscribers,
artifacts: ArtifactSubscribers,
}
impl TaskSubscribers {
fn new() -> Self {
Self {
status: Vec::new(),
artifacts: Vec::new(),
}
}
}
/// In-memory storage for tasks, their update subscribers, and push-notification configuration.
///
/// Every field sits behind an `Arc`, so the `Clone` implementation in this file yields handles
/// that share the same underlying maps and registry.
pub struct InMemoryTaskStorage {
/// Stored tasks behind an async mutex; the methods in this file key the map by task id.
pub(crate) tasks: Arc<Mutex<HashMap<String, Task>>>,
/// Status and artifact subscribers behind an async mutex, keyed by task id.
pub(crate) subscribers: Arc<Mutex<HashMap<String, TaskSubscribers>>>,
/// Registry used to look up push-notification configs and to send push updates.
pub(crate) push_notification_registry: Arc<PushNotificationRegistry>,
}
impl InMemoryTaskStorage {
    /// Creates empty storage with the default push-notification sender.
    ///
    /// The sender is `HttpPushNotificationSender` when the `http-client` feature is enabled and
    /// `NoopPushNotificationSender` otherwise.
    pub fn new() -> Self {
        #[cfg(feature = "http-client")]
        let default_sender = HttpPushNotificationSender::new();
        #[cfg(not(feature = "http-client"))]
        let default_sender = NoopPushNotificationSender;

        Self::with_push_sender(default_sender)
    }

    /// Creates empty storage whose push-notification registry is built around `push_sender`.
    pub fn with_push_sender(push_sender: impl PushNotificationSender + 'static) -> Self {
        let registry = PushNotificationRegistry::new(push_sender);
        let tasks = Arc::new(Mutex::new(HashMap::new()));
        let subscribers = Arc::new(Mutex::new(HashMap::new()));

        Self {
            tasks,
            subscribers,
            push_notification_registry: Arc::new(registry),
        }
    }

    /// Registers a status subscriber and discards the generated subscription id.
    ///
    /// # Errors
    /// Propagates any error returned by `add_status_subscriber`.
    pub async fn add_status_subscriber_legacy(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<(), A2AError> {
        let _subscription_id = self.add_status_subscriber(task_id, subscriber).await?;
        Ok(())
    }

    /// Registers an artifact subscriber and discards the generated subscription id.
    ///
    /// # Errors
    /// Propagates any error returned by `add_artifact_subscriber`.
    pub async fn add_artifact_subscriber_legacy(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<(), A2AError> {
        let _subscription_id = self.add_artifact_subscriber(task_id, subscriber).await?;
        Ok(())
    }
}
impl Default for InMemoryTaskStorage {
fn default() -> Self {
Self::new()
}
}
impl InMemoryTaskStorage {
/// Returns the context id stored for `task_id`, or `"default"` when no such task is stored.
async fn get_task_context_id(&self, task_id: &str) -> String {
    let tasks = self.tasks.lock().await;
    match tasks.get(task_id) {
        Some(task) => task.context_id.clone(),
        None => String::from("default"),
    }
}
/// Delivers a status change for `task_id` to every registered status subscriber, then to the
/// push-notification registry.
///
/// The event's context id comes from the stored task, or is `"default"` when the task is not
/// stored. Delivery failures (subscriber or push) are logged and otherwise ignored, so this
/// currently always returns `Ok(())`.
pub(crate) async fn broadcast_status_update(
    &self,
    task_id: &str,
    status: TaskStatus,
) -> Result<(), A2AError> {
    let context_id = self.get_task_context_id(task_id).await;
    let event = TaskStatusUpdateEvent {
        task_id: task_id.to_owned(),
        context_id,
        kind: String::from("status-update"),
        status: status.clone(),
        metadata: None,
    };

    #[cfg(feature = "tracing")]
    tracing::debug!(
        task_id = %task_id,
        state = ?status.state,
        "📡 Broadcasting status update to subscribers"
    );

    // The subscriber map stays locked while each subscriber is awaited.
    let notified = {
        let registry = self.subscribers.lock().await;
        match registry.get(task_id) {
            Some(entry) => {
                let count = entry.status.len();

                #[cfg(feature = "tracing")]
                tracing::info!(
                    task_id = %task_id,
                    subscriber_count = count,
                    state = ?status.state,
                    "📡 Notifying WebSocket subscribers of status update"
                );

                for (index, subscriber) in entry.status.iter().enumerate() {
                    match subscriber.on_update(event.clone()).await {
                        Ok(_) => {
                            #[cfg(feature = "tracing")]
                            tracing::debug!(
                                task_id = %task_id,
                                subscriber_index = index,
                                "✅ Successfully notified subscriber"
                            );
                        }
                        Err(e) => {
                            #[cfg(feature = "tracing")]
                            tracing::error!(
                                task_id = %task_id,
                                subscriber_index = index,
                                error = %e,
                                "❌ Failed to notify subscriber"
                            );
                            eprintln!("Failed to notify subscriber: {}", e);
                        }
                    }
                }

                count
            }
            None => {
                #[cfg(feature = "tracing")]
                tracing::debug!(
                    task_id = %task_id,
                    "no WebSocket subscribers for task; status broadcast skipped"
                );
                0
            }
        }
    };

    #[cfg(feature = "tracing")]
    tracing::debug!(
        task_id = %task_id,
        notified_count = notified,
        "📡 Finished broadcasting to WebSocket subscribers"
    );

    // Push delivery is best-effort: a failure is reported but does not fail the broadcast.
    let push_outcome = self
        .push_notification_registry
        .send_status_update(task_id, &event)
        .await;
    if let Err(e) = push_outcome {
        eprintln!("Failed to send push notification: {}", e);
    }

    Ok(())
}
/// Delivers `artifact` for `task_id` to every registered artifact subscriber, then to the
/// push-notification registry.
///
/// `_index` and `_final` are accepted but unused; the emitted event always carries
/// `append: None` and `last_chunk: None`. Delivery failures are printed to stderr and otherwise
/// ignored, so this currently always returns `Ok(())`.
pub(crate) async fn broadcast_artifact_update(
    &self,
    task_id: &str,
    artifact: Artifact,
    _index: Option<u32>,
    _final: bool,
) -> Result<(), A2AError> {
    let context_id = self.get_task_context_id(task_id).await;
    let event = TaskArtifactUpdateEvent {
        task_id: task_id.to_owned(),
        context_id,
        kind: String::from("artifact-update"),
        artifact,
        append: None,
        last_chunk: None,
        metadata: None,
    };

    {
        // The subscriber map stays locked while each subscriber is awaited.
        let registry = self.subscribers.lock().await;
        let targets = registry
            .get(task_id)
            .map(|entry| entry.artifacts.as_slice())
            .unwrap_or_default();
        for subscriber in targets {
            if let Err(e) = subscriber.on_update(event.clone()).await {
                eprintln!("Failed to notify subscriber: {}", e);
            }
        }
    }

    // Push delivery is best-effort: a failure is reported but does not fail the broadcast.
    let push_outcome = self
        .push_notification_registry
        .send_artifact_update(task_id, &event)
        .await;
    if let Err(e) = push_outcome {
        eprintln!("Failed to send push notification: {}", e);
    }

    Ok(())
}
}
#[async_trait]
impl AsyncTaskManager for InMemoryTaskStorage {
/// Creates and stores a new task under `task_id` within `context_id`.
///
/// # Errors
/// Returns `A2AError::TaskNotFound` carrying an "already exists" message when `task_id` is
/// already stored.
// NOTE(review): the duplicate-id case is reported through the `TaskNotFound` variant — confirm
// whether a more specific variant should be used.
async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
    let mut tasks = self.tasks.lock().await;

    if tasks.contains_key(task_id) {
        let reason = format!("Task {} already exists", task_id);
        return Err(A2AError::TaskNotFound(reason));
    }

    let created = Task::new(task_id.to_owned(), context_id.to_owned());
    tasks.insert(task_id.to_owned(), created.clone());
    Ok(created)
}
/// Applies `state` (and optional `message`) to the stored task, then broadcasts the new status.
///
/// The task lock is released before broadcasting. Returns a snapshot of the task taken right
/// after the update.
///
/// # Errors
/// Returns `A2AError::TaskNotFound` when `task_id` is not stored; propagates any broadcast error.
async fn update_task_status(
    &self,
    task_id: &str,
    state: TaskState,
    message: Option<Message>,
) -> Result<Task, A2AError> {
    let (snapshot, new_status) = {
        let mut tasks = self.tasks.lock().await;
        let Some(task) = tasks.get_mut(task_id) else {
            return Err(A2AError::TaskNotFound(task_id.to_string()));
        };
        task.update_status(state, message);
        let new_status = task.status.clone().into_option().unwrap_or_default();
        let snapshot = task.clone();
        (snapshot, new_status)
    };

    self.broadcast_status_update(task_id, new_status).await?;
    Ok(snapshot)
}
/// Reports whether a task is stored under `task_id`. Never returns an error.
async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
    let present = self.tasks.lock().await.contains_key(task_id);
    Ok(present)
}
/// Returns a copy of the stored task, produced by `with_limited_history(history_length)`.
///
/// # Errors
/// Returns `A2AError::TaskNotFound` when `task_id` is not stored.
async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
    let tasks = self.tasks.lock().await;
    let found = tasks
        .get(task_id)
        .map(|task| task.with_limited_history(history_length));
    drop(tasks);

    found.ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))
}
/// Marks a working task as canceled, stores it, and broadcasts the new status.
///
/// The stored task is replaced by a copy whose status is `TaskState::Canceled` with an agent
/// message "Task {id} canceled."; the task lock is released before broadcasting.
///
/// # Errors
/// * `A2AError::TaskNotFound` when `task_id` is not stored.
/// * `A2AError::TaskNotCancelable` when the task's state is anything other than
///   `TaskState::Working`.
/// * Any error propagated from the status broadcast.
async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
    let (canceled, new_status) = {
        let mut tasks = self.tasks.lock().await;

        let mut candidate = tasks
            .get(task_id)
            .cloned()
            .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;

        if candidate.status.state != TaskState::Working {
            return Err(A2AError::TaskNotCancelable(format!(
                "Task {} is in state {:?} and cannot be canceled",
                task_id, candidate.status.state
            )));
        }

        let notice = Message {
            role: ::buffa::EnumValue::from(crate::domain::Role::Agent),
            parts: vec![crate::domain::Part::text(format!(
                "Task {} canceled.",
                task_id
            ))],
            message_id: uuid::Uuid::new_v4().to_string(),
            task_id: task_id.to_string(),
            context_id: candidate.context_id.clone(),
            ..Default::default()
        };
        candidate.update_status(TaskState::Canceled, Some(notice));

        let new_status = candidate.status.clone().into_option().unwrap_or_default();
        tasks.insert(task_id.to_string(), candidate.clone());
        (candidate, new_status)
    };

    self.broadcast_status_update(task_id, new_status).await?;
    Ok(canceled)
}
async fn list_tasks_v3(
&self,
params: &crate::domain::ListTasksParams,
) -> Result<crate::domain::ListTasksResult, A2AError> {
use crate::domain::ListTasksResult;
let tasks_guard = self.tasks.lock().await;
let mut filtered_tasks: Vec<_> = tasks_guard
.values()
.filter(|task| {
if let Some(ref context_id) = params.context_id {
if &task.context_id != context_id {
return false;
}
}
if let Some(ref status) = params.status {
if &task.status.state != status {
return false;
}
}
if let Some(status_timestamp_after) = ¶ms.status_timestamp_after {
if let Ok(after_dt) =
chrono::DateTime::parse_from_rfc3339(status_timestamp_after)
{
let after_utc = after_dt.with_timezone(&chrono::Utc);
if let Some(timestamp) = task.status.timestamp_utc() {
if timestamp <= after_utc {
return false;
}
}
}
}
true
})
.cloned()
.collect();
filtered_tasks.sort_by(|a, b| {
let a_time = a
.status
.timestamp_utc()
.map(|t| t.timestamp_millis())
.unwrap_or(0);
let b_time = b
.status
.timestamp_utc()
.map(|t| t.timestamp_millis())
.unwrap_or(0);
b_time.cmp(&a_time)
});
let total_size = filtered_tasks.len() as i32;
let page_size = params.page_size.unwrap_or(50).clamp(1, 100) as usize;
let page_start = if let Some(ref token) = params.page_token {
token.parse::<usize>().unwrap_or(0)
} else {
0
};
let page_end = (page_start + page_size).min(filtered_tasks.len());
let has_more = page_end < filtered_tasks.len();
let mut page_tasks: Vec<_> = filtered_tasks[page_start..page_end].to_vec();
let history_length = params.history_length.unwrap_or(0);
for task in &mut page_tasks {
*task = task.with_limited_history(Some(history_length as u32));
if !params.include_artifacts.unwrap_or(false) {
task.artifacts.clear();
}
}
let next_page_token = if has_more {
page_end.to_string()
} else {
String::new()
};
Ok(ListTasksResult {
tasks: page_tasks,
total_size,
page_size: page_size as i32,
next_page_token,
})
}
/// Returns the push-notification config registered for `params.id`.
///
/// Delegates to `get_task_notification`, so a missing config is reported as
/// `A2AError::PushNotificationNotSupported`.
async fn get_push_notification_config(
    &self,
    params: &crate::domain::GetTaskPushNotificationConfigParams,
) -> Result<crate::domain::TaskPushNotificationConfig, A2AError> {
    self.get_task_notification(&params.id).await
}
/// Returns the push-notification configs registered for `params.id`.
///
/// The registry lookup yields at most one config, so the result holds zero or one element.
///
/// # Errors
/// Propagates any error from the registry lookup.
async fn list_push_notification_configs(
    &self,
    params: &crate::domain::ListTaskPushNotificationConfigsParams,
) -> Result<Vec<crate::domain::TaskPushNotificationConfig>, A2AError> {
    let config = self
        .push_notification_registry
        .get_config(&params.id)
        .await?;
    Ok(config.into_iter().collect())
}
/// Removes the push-notification config registered for `params.id`.
///
/// Delegates to `remove_task_notification` and propagates its result.
async fn delete_push_notification_config(
    &self,
    params: &crate::domain::DeleteTaskPushNotificationConfigParams,
) -> Result<(), A2AError> {
    self.remove_task_notification(&params.id).await
}
}
#[async_trait]
impl AsyncNotificationManager for InMemoryTaskStorage {
    /// Registers `config` in the push-notification registry under `config.task_id` and returns a
    /// copy of it.
    ///
    /// # Errors
    /// Propagates any error from the registry.
    async fn set_task_notification(
        &self,
        config: &TaskPushNotificationConfig,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %config.task_id,
            url = %config.url,
            "🚀 Registering push notification config for task"
        );

        let to_store = config.clone();
        self.push_notification_registry
            .register(&config.task_id, to_store)
            .await?;

        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %config.task_id,
            "✅ Push notification config registered successfully"
        );

        Ok(config.clone())
    }

    /// Returns the push-notification config registered for `task_id`.
    ///
    /// # Errors
    /// Returns `A2AError::PushNotificationNotSupported` when nothing is registered for the task;
    /// propagates any error from the registry lookup.
    async fn get_task_notification(
        &self,
        task_id: &str,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        let registered = self.push_notification_registry.get_config(task_id).await?;
        registered.ok_or(A2AError::PushNotificationNotSupported)
    }

    /// Unregisters the push-notification config for `task_id`.
    ///
    /// # Errors
    /// Propagates any error from the registry.
    async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
        let registry = &self.push_notification_registry;
        registry.unregister(task_id).await?;
        Ok(())
    }
}
#[async_trait]
impl AsyncStreamingHandler for InMemoryTaskStorage {
/// Registers `subscriber` for status updates of `task_id` and returns a freshly generated
/// subscription id of the form `status-{task_id}-{uuid}`.
///
/// If the task is currently stored, its present status is broadcast right after registration;
/// the outcome of that broadcast is ignored.
// NOTE(review): the replay goes through the shared broadcast path, so subscribers that were
// already registered and the push registry receive it as well — confirm this is intended.
async fn add_status_subscriber(
    &self,
    task_id: &str,
    subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
) -> Result<String, A2AError> {
    #[cfg(feature = "tracing")]
    tracing::info!(
        task_id = %task_id,
        "✅ Adding WebSocket subscriber for status updates"
    );

    {
        let mut registry = self.subscribers.lock().await;
        let entry = registry
            .entry(task_id.to_string())
            .or_insert_with(TaskSubscribers::new);
        entry.status.push(subscriber);

        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %task_id,
            subscriber_count = entry.status.len(),
            "✅ WebSocket subscriber added successfully"
        );
    }

    // The lock above is released before the task lookup and broadcast.
    let current = self.get_task(task_id, None).await;
    if let Ok(task) = current {
        let present_status = task.status.clone().into_option().unwrap_or_default();
        let _ = self.broadcast_status_update(task_id, present_status).await;
    }

    let subscription_id = format!("status-{}-{}", task_id, uuid::Uuid::new_v4());
    Ok(subscription_id)
}
/// Registers `subscriber` for artifact updates of `task_id` and returns a freshly generated
/// subscription id of the form `artifact-{task_id}-{uuid}`.
///
/// If the task is currently stored, each of its artifacts is broadcast right after
/// registration; the outcome of each broadcast is ignored.
// NOTE(review): the replay goes through the shared broadcast path, so subscribers that were
// already registered and the push registry receive it as well — confirm this is intended.
async fn add_artifact_subscriber(
    &self,
    task_id: &str,
    subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
) -> Result<String, A2AError> {
    {
        let mut registry = self.subscribers.lock().await;
        let entry = registry
            .entry(task_id.to_string())
            .or_insert_with(TaskSubscribers::new);
        entry.artifacts.push(subscriber);
    }

    // The lock above is released before the task lookup and broadcasts.
    let current = self.get_task(task_id, None).await;
    if let Ok(task) = current {
        for existing in task.artifacts.iter() {
            let replay = existing.clone();
            let _ = self
                .broadcast_artifact_update(task_id, replay, None, false)
                .await;
        }
    }

    let subscription_id = format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4());
    Ok(subscription_id)
}
/// Not supported by this storage: always returns `A2AError::UnsupportedOperation`.
async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
    let reason = String::from("Subscription removal by ID requires storage layer refactoring");
    Err(A2AError::UnsupportedOperation(reason))
}
/// Drops every status and artifact subscriber registered for `task_id`.
///
/// Succeeds whether or not the task had subscribers.
async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
    let mut registry = self.subscribers.lock().await;
    let _removed = registry.remove(task_id);
    drop(registry);
    Ok(())
}
/// Returns the combined number of status and artifact subscribers for `task_id`, or 0 when
/// none are registered.
async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
    let registry = self.subscribers.lock().await;
    let total = registry
        .get(task_id)
        .map_or(0, |entry| entry.status.len() + entry.artifacts.len());
    Ok(total)
}
async fn broadcast_status_update(
&self,
task_id: &str,
update: TaskStatusUpdateEvent,
) -> Result<(), A2AError> {
self.broadcast_status_update(task_id, update.status).await
}
async fn broadcast_artifact_update(
&self,
task_id: &str,
update: TaskArtifactUpdateEvent,
) -> Result<(), A2AError> {
self.broadcast_artifact_update(
task_id,
update.artifact,
None,
update.last_chunk.unwrap_or(false),
)
.await
}
/// Not supported by this storage: always returns `A2AError::UnsupportedOperation`.
async fn status_update_stream(
    &self,
    _task_id: &str,
) -> Result<
    std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>,
    >,
    A2AError,
> {
    let reason = String::from("Status update stream requires storage layer refactoring");
    Err(A2AError::UnsupportedOperation(reason))
}
/// Not supported by this storage: always returns `A2AError::UnsupportedOperation`.
async fn artifact_update_stream(
    &self,
    _task_id: &str,
) -> Result<
    std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>,
    >,
    A2AError,
> {
    let reason = String::from("Artifact update stream requires storage layer refactoring");
    Err(A2AError::UnsupportedOperation(reason))
}
/// Not supported by this storage: always returns `A2AError::UnsupportedOperation`.
async fn combined_update_stream(
    &self,
    _task_id: &str,
) -> Result<
    std::pin::Pin<
        Box<
            dyn futures::Stream<
                    Item = Result<crate::port::streaming_handler::UpdateEvent, A2AError>,
                > + Send,
        >,
    >,
    A2AError,
> {
    let reason = String::from("Combined update stream requires storage layer refactoring");
    Err(A2AError::UnsupportedOperation(reason))
}
}
impl Clone for InMemoryTaskStorage {
fn clone(&self) -> Self {
Self {
tasks: self.tasks.clone(),
subscribers: self.subscribers.clone(),
push_notification_registry: self.push_notification_registry.clone(),
}
}
}