use std::sync::Arc;
use async_trait::async_trait;
use a2a_rs::{
adapter::{business::DefaultMessageHandler, storage::InMemoryTaskStorage},
domain::{A2AError, Message, Task, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent},
port::{
AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
MessageHandler, NotificationManager, StreamingHandler, TaskManager,
streaming_handler::Subscriber,
},
};
#[derive(Clone)]
pub struct TestBusinessHandler {
storage: Arc<InMemoryTaskStorage>,
}
impl TestBusinessHandler {
    /// Builds a handler around a newly constructed [`InMemoryTaskStorage`].
    pub fn new() -> Self {
        Self::with_storage(InMemoryTaskStorage::new())
    }

    /// Builds a handler around a storage value supplied by the caller.
    // Presumably not every test binary that compiles this module calls this
    // helper; the allow keeps such binaries free of unused-item warnings.
    #[allow(dead_code)]
    pub fn with_storage(storage: InMemoryTaskStorage) -> Self {
        let shared: Arc<InMemoryTaskStorage> = storage.into();
        Self { storage: shared }
    }

    /// Borrows the shared storage so tests can inspect it directly.
    // Same rationale as above: some test binaries may never call this accessor.
    #[allow(dead_code)]
    pub fn storage(&self) -> &Arc<InMemoryTaskStorage> {
        &self.storage
    }
}
impl Default for TestBusinessHandler {
fn default() -> Self {
Self::new()
}
}
/// Synchronous message handling is intentionally rejected; use the
/// [`AsyncMessageHandler`] implementation instead.
impl MessageHandler for TestBusinessHandler {
    /// Always returns `Err(A2AError::UnsupportedOperation(..))`; arguments are ignored.
    fn process_message(
        &self,
        _task_id: &str,
        _message: &Message,
        _session_id: Option<&str>,
    ) -> Result<Task, A2AError> {
        let reason =
            String::from("Synchronous message processing not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }
}
/// Synchronous task management is intentionally rejected: every method returns
/// `A2AError::UnsupportedOperation` and ignores its arguments. Use the
/// [`AsyncTaskManager`] implementation instead.
impl TaskManager for TestBusinessHandler {
    fn create_task(&self, _task_id: &str, _context_id: &str) -> Result<Task, A2AError> {
        let reason =
            String::from("Synchronous task creation not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn get_task(&self, _task_id: &str, _history_length: Option<u32>) -> Result<Task, A2AError> {
        let reason =
            String::from("Synchronous task retrieval not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn update_task_status(
        &self,
        _task_id: &str,
        _state: TaskState,
        _message: Option<Message>,
    ) -> Result<Task, A2AError> {
        let reason =
            String::from("Synchronous task status update not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn cancel_task(&self, _task_id: &str) -> Result<Task, A2AError> {
        let reason =
            String::from("Synchronous task cancellation not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn task_exists(&self, _task_id: &str) -> Result<bool, A2AError> {
        let reason = String::from(
            "Synchronous task existence check not supported. Use async version.",
        );
        Err(A2AError::UnsupportedOperation(reason))
    }
}
/// Synchronous push-notification management is intentionally rejected: every
/// method returns `A2AError::UnsupportedOperation` and ignores its arguments.
/// Use the [`AsyncNotificationManager`] implementation instead.
impl NotificationManager for TestBusinessHandler {
    fn set_task_notification(
        &self,
        _config: &a2a_rs::domain::TaskPushNotificationConfig,
    ) -> Result<a2a_rs::domain::TaskPushNotificationConfig, A2AError> {
        let reason =
            String::from("Synchronous notification setup not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn get_task_notification(
        &self,
        _task_id: &str,
    ) -> Result<a2a_rs::domain::TaskPushNotificationConfig, A2AError> {
        let reason = String::from(
            "Synchronous notification retrieval not supported. Use async version.",
        );
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn remove_task_notification(&self, _task_id: &str) -> Result<(), A2AError> {
        let reason =
            String::from("Synchronous notification removal not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }
}
/// Synchronous streaming is intentionally rejected: every method returns
/// `A2AError::UnsupportedOperation` and ignores its arguments, and any subscriber
/// passed in is dropped. Use the [`AsyncStreamingHandler`] implementation instead.
impl StreamingHandler for TestBusinessHandler {
    fn add_status_subscriber(
        &self,
        _task_id: &str,
        _subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        let reason = String::from(
            "Synchronous streaming subscription not supported. Use async version.",
        );
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn add_artifact_subscriber(
        &self,
        _task_id: &str,
        _subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        let reason = String::from(
            "Synchronous streaming subscription not supported. Use async version.",
        );
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
        let reason = String::from(
            "Synchronous streaming unsubscription not supported. Use async version.",
        );
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn remove_task_subscribers(&self, _task_id: &str) -> Result<(), A2AError> {
        let reason = String::from(
            "Synchronous streaming unsubscription not supported. Use async version.",
        );
        Err(A2AError::UnsupportedOperation(reason))
    }

    fn get_subscriber_count(&self, _task_id: &str) -> Result<usize, A2AError> {
        let reason =
            String::from("Synchronous subscriber count not supported. Use async version.");
        Err(A2AError::UnsupportedOperation(reason))
    }
}
#[async_trait]
impl AsyncMessageHandler for TestBusinessHandler {
async fn process_message(
&self,
task_id: &str,
message: &Message,
session_id: Option<&str>,
) -> Result<Task, A2AError> {
let message_handler = DefaultMessageHandler::new((*self.storage).clone());
message_handler
.process_message(task_id, message, session_id)
.await
}
}
#[async_trait]
impl AsyncTaskManager for TestBusinessHandler {
async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
self.storage.create_task(task_id, context_id).await
}
async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
self.storage.get_task(task_id, history_length).await
}
async fn update_task_status(
&self,
task_id: &str,
state: TaskState,
message: Option<Message>,
) -> Result<Task, A2AError> {
self.storage
.update_task_status(task_id, state, message)
.await
}
async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
self.storage.cancel_task(task_id).await
}
async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
self.storage.task_exists(task_id).await
}
async fn list_tasks_v3(
&self,
params: &a2a_rs::domain::ListTasksParams,
) -> Result<a2a_rs::domain::ListTasksResult, A2AError> {
self.storage.list_tasks_v3(params).await
}
async fn get_push_notification_config(
&self,
params: &a2a_rs::domain::GetTaskPushNotificationConfigParams,
) -> Result<a2a_rs::domain::TaskPushNotificationConfig, A2AError> {
self.storage.get_push_notification_config(params).await
}
async fn list_push_notification_configs(
&self,
params: &a2a_rs::domain::ListTaskPushNotificationConfigsParams,
) -> Result<Vec<a2a_rs::domain::TaskPushNotificationConfig>, A2AError> {
self.storage.list_push_notification_configs(params).await
}
async fn delete_push_notification_config(
&self,
params: &a2a_rs::domain::DeleteTaskPushNotificationConfigParams,
) -> Result<(), A2AError> {
self.storage.delete_push_notification_config(params).await
}
}
#[async_trait]
impl AsyncNotificationManager for TestBusinessHandler {
async fn set_task_notification(
&self,
config: &a2a_rs::domain::TaskPushNotificationConfig,
) -> Result<a2a_rs::domain::TaskPushNotificationConfig, A2AError> {
self.storage.set_task_notification(config).await
}
async fn get_task_notification(
&self,
task_id: &str,
) -> Result<a2a_rs::domain::TaskPushNotificationConfig, A2AError> {
self.storage.get_task_notification(task_id).await
}
async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
self.storage.remove_task_notification(task_id).await
}
}
#[async_trait]
impl AsyncStreamingHandler for TestBusinessHandler {
async fn add_status_subscriber(
&self,
task_id: &str,
subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
) -> Result<String, A2AError> {
self.storage
.add_status_subscriber(task_id, subscriber)
.await
}
async fn add_artifact_subscriber(
&self,
task_id: &str,
subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
) -> Result<String, A2AError> {
self.storage
.add_artifact_subscriber(task_id, subscriber)
.await
}
async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
self.storage.remove_subscription(subscription_id).await
}
async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
self.storage.remove_task_subscribers(task_id).await
}
async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
self.storage.get_subscriber_count(task_id).await
}
async fn broadcast_status_update(
&self,
task_id: &str,
update: TaskStatusUpdateEvent,
) -> Result<(), A2AError> {
self.storage.broadcast_status_update(task_id, update).await
}
async fn broadcast_artifact_update(
&self,
task_id: &str,
update: TaskArtifactUpdateEvent,
) -> Result<(), A2AError> {
self.storage
.broadcast_artifact_update(task_id, update)
.await
}
async fn status_update_stream(
&self,
task_id: &str,
) -> Result<
std::pin::Pin<
Box<dyn futures::Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>,
>,
A2AError,
> {
self.storage.status_update_stream(task_id).await
}
async fn artifact_update_stream(
&self,
task_id: &str,
) -> Result<
std::pin::Pin<
Box<dyn futures::Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>,
>,
A2AError,
> {
self.storage.artifact_update_stream(task_id).await
}
async fn combined_update_stream(
&self,
task_id: &str,
) -> Result<
std::pin::Pin<
Box<
dyn futures::Stream<
Item = Result<a2a_rs::port::streaming_handler::UpdateEvent, A2AError>,
> + Send,
>,
>,
A2AError,
> {
self.storage.combined_update_stream(task_id).await
}
}