use std::sync::Arc;
use async_trait::async_trait;
use a2a_rs::{
adapter::storage::InMemoryTaskStorage,
domain::{
A2AError, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, TaskState,
TaskStatusUpdateEvent,
},
port::{
AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
streaming_handler::Subscriber,
},
};
/// Agent handler that implements the A2A handler ports by forwarding to an
/// [`InMemoryTaskStorage`].
///
/// The storage sits behind an [`Arc`], so cloning the handler (via the derived
/// `Clone`) produces another handle to the same storage instance rather than an
/// independent copy of it.
#[derive(Clone)]
pub struct SimpleAgentHandler {
    /// Shared backing store used by this handler's port implementations.
    storage: Arc<InMemoryTaskStorage>,
}
impl SimpleAgentHandler {
    /// Builds a handler backed by a freshly constructed [`InMemoryTaskStorage`].
    pub fn new() -> Self {
        Self::with_storage(InMemoryTaskStorage::new())
    }

    /// Builds a handler around a caller-supplied storage, taking ownership of it
    /// and placing it behind an [`Arc`] so clones of the handler share it.
    pub fn with_storage(storage: InMemoryTaskStorage) -> Self {
        let shared = Arc::new(storage);
        Self { storage: shared }
    }

    /// Returns the shared handle to the backing storage.
    // `pub` alone does not silence the dead_code lint when this module is
    // compiled into a binary crate; the accessor is presumably kept for callers
    // that want direct access to the store even if nothing currently uses it.
    #[allow(dead_code)]
    pub fn storage(&self) -> &Arc<InMemoryTaskStorage> {
        let handle = &self.storage;
        handle
    }
}
impl Default for SimpleAgentHandler {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AsyncMessageHandler for SimpleAgentHandler {
async fn process_message(
&self,
task_id: &str,
message: &Message,
session_id: Option<&str>,
) -> Result<Task, A2AError> {
let message_handler =
a2a_rs::adapter::business::DefaultMessageHandler::new((*self.storage).clone());
message_handler
.process_message(task_id, message, session_id)
.await
}
}
#[async_trait]
impl AsyncTaskManager for SimpleAgentHandler {
    /// Creates a task with the given ids in the backing storage.
    async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
        let storage = &self.storage;
        storage.create_task(task_id, context_id).await
    }

    /// Fetches a task from storage, passing `history_length` through unchanged.
    async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
        let storage = &self.storage;
        storage.get_task(task_id, history_length).await
    }

    /// Forwards a status change (and optional accompanying message) to storage.
    async fn update_task_status(
        &self,
        task_id: &str,
        state: TaskState,
        message: Option<Message>,
    ) -> Result<Task, A2AError> {
        let storage = &self.storage;
        storage.update_task_status(task_id, state, message).await
    }

    /// Asks the storage to cancel `task_id`, returning whatever task it reports.
    async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
        let storage = &self.storage;
        storage.cancel_task(task_id).await
    }

    /// Reports whether the storage knows about `task_id`.
    async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
        let storage = &self.storage;
        storage.task_exists(task_id).await
    }
}
#[async_trait]
impl AsyncNotificationManager for SimpleAgentHandler {
    /// Stores a push-notification config and returns the config the storage reports back.
    async fn set_task_notification(
        &self,
        config: &TaskPushNotificationConfig,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        let storage = &self.storage;
        storage.set_task_notification(config).await
    }

    /// Looks up the push-notification config registered for `task_id`.
    async fn get_task_notification(
        &self,
        task_id: &str,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        let storage = &self.storage;
        storage.get_task_notification(task_id).await
    }

    /// Removes the push-notification config for `task_id` from storage.
    async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
        let storage = &self.storage;
        storage.remove_task_notification(task_id).await
    }
}
#[async_trait]
impl AsyncStreamingHandler for SimpleAgentHandler {
    /// Registers a status-update subscriber for `task_id` with the storage and
    /// returns the id it hands back (presumably the id `remove_subscription` accepts).
    async fn add_status_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        let storage = &self.storage;
        storage.add_status_subscriber(task_id, subscriber).await
    }

    /// Registers an artifact-update subscriber for `task_id` with the storage and
    /// returns the id it hands back.
    async fn add_artifact_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        let storage = &self.storage;
        storage.add_artifact_subscriber(task_id, subscriber).await
    }

    /// Drops a single subscription by id.
    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
        let storage = &self.storage;
        storage.remove_subscription(subscription_id).await
    }

    /// Drops every subscriber the storage holds for `task_id`.
    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
        let storage = &self.storage;
        storage.remove_task_subscribers(task_id).await
    }

    /// Returns the storage's subscriber count for `task_id`.
    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
        let storage = &self.storage;
        storage.get_subscriber_count(task_id).await
    }

    /// Hands a status update for `task_id` to the storage for broadcasting.
    async fn broadcast_status_update(
        &self,
        task_id: &str,
        update: TaskStatusUpdateEvent,
    ) -> Result<(), A2AError> {
        let storage = &self.storage;
        storage.broadcast_status_update(task_id, update).await
    }

    /// Hands an artifact update for `task_id` to the storage for broadcasting.
    async fn broadcast_artifact_update(
        &self,
        task_id: &str,
        update: TaskArtifactUpdateEvent,
    ) -> Result<(), A2AError> {
        let storage = &self.storage;
        storage.broadcast_artifact_update(task_id, update).await
    }

    /// Returns the storage's stream of status updates for `task_id`.
    async fn status_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        std::pin::Pin<
            Box<dyn futures::Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>,
        >,
        A2AError,
    > {
        let storage = &self.storage;
        storage.status_update_stream(task_id).await
    }

    /// Returns the storage's stream of artifact updates for `task_id`.
    async fn artifact_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        std::pin::Pin<
            Box<dyn futures::Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>,
        >,
        A2AError,
    > {
        let storage = &self.storage;
        storage.artifact_update_stream(task_id).await
    }

    /// Returns the storage's combined status-and-artifact update stream for `task_id`.
    async fn combined_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        std::pin::Pin<
            Box<
                dyn futures::Stream<
                    Item = Result<a2a_rs::port::streaming_handler::UpdateEvent, A2AError>,
                > + Send,
            >,
        >,
        A2AError,
    > {
        let storage = &self.storage;
        storage.combined_update_stream(task_id).await
    }
}