use std::collections::BTreeSet;
use std::pin::Pin;
use async_trait::async_trait;
use axum::http::HeaderMap;
use futures_core::Stream;
use crate::A2AError;
use crate::jsonrpc::PROTOCOL_VERSION;
use crate::types::{
AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest,
ListTaskPushNotificationConfigsRequest, ListTaskPushNotificationConfigsResponse,
ListTasksRequest, ListTasksResponse, SendMessageRequest, SendMessageResponse, StreamResponse,
SubscribeToTaskRequest, Task, TaskPushNotificationConfig,
};
/// Heap-allocated, pinned, `Send + 'static` stream of [`StreamResponse`] items.
///
/// This is the success type returned by the streaming methods of [`A2AHandler`]
/// (`send_streaming_message` and `subscribe_to_task`).
pub type A2AStream = Pin<Box<dyn Stream<Item = StreamResponse> + Send + 'static>>;
/// Request-handling contract for an A2A agent.
///
/// Implementors must provide `get_agent_card` and `send_message`. Every other method ships with a
/// default body: the operation methods reject the call with an [`A2AError`], and the `require_*` /
/// `validate_*` helpers inspect the card returned by `get_agent_card`.
#[async_trait]
pub trait A2AHandler: Send + Sync + 'static {
    /// Returns the agent card describing this agent.
    ///
    /// The default capability and header checks on this trait call this method each time they run.
    async fn get_agent_card(&self) -> Result<AgentCard, A2AError>;

    /// Handles a non-streaming message request.
    async fn send_message(
        &self,
        request: SendMessageRequest,
    ) -> Result<SendMessageResponse, A2AError>;

    /// Handles a streaming message request.
    ///
    /// # Errors
    ///
    /// The default body first runs the streaming capability check (propagating its error) and then
    /// returns `A2AError::UnsupportedOperation("SendStreamingMessage")`.
    async fn send_streaming_message(
        &self,
        _request: SendMessageRequest,
    ) -> Result<A2AStream, A2AError> {
        let operation = "SendStreamingMessage";
        self.require_streaming_capability(operation).await?;
        Err(A2AError::UnsupportedOperation(String::from(operation)))
    }

    /// Fetches a task.
    ///
    /// # Errors
    ///
    /// The default body always returns `A2AError::UnsupportedOperation("GetTask")`.
    async fn get_task(&self, _request: GetTaskRequest) -> Result<Task, A2AError> {
        Err(A2AError::UnsupportedOperation(String::from("GetTask")))
    }

    /// Lists tasks.
    ///
    /// # Errors
    ///
    /// The default body always returns `A2AError::UnsupportedOperation("ListTasks")`.
    async fn list_tasks(&self, _request: ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
        Err(A2AError::UnsupportedOperation(String::from("ListTasks")))
    }

    /// Cancels a task.
    ///
    /// # Errors
    ///
    /// The default body always returns `A2AError::UnsupportedOperation("CancelTask")`.
    async fn cancel_task(&self, _request: CancelTaskRequest) -> Result<Task, A2AError> {
        Err(A2AError::UnsupportedOperation(String::from("CancelTask")))
    }

    /// Subscribes to updates for a task.
    ///
    /// # Errors
    ///
    /// The default body first runs the streaming capability check (propagating its error) and then
    /// returns `A2AError::UnsupportedOperation("SubscribeToTask")`.
    async fn subscribe_to_task(
        &self,
        _request: SubscribeToTaskRequest,
    ) -> Result<A2AStream, A2AError> {
        let operation = "SubscribeToTask";
        self.require_streaming_capability(operation).await?;
        Err(A2AError::UnsupportedOperation(String::from(operation)))
    }

    /// Creates a push-notification configuration for a task.
    ///
    /// # Errors
    ///
    /// The default body first runs the push-notification capability check (propagating its error)
    /// and then returns `A2AError::UnsupportedOperation("CreateTaskPushNotificationConfig")`.
    async fn create_task_push_notification_config(
        &self,
        _request: TaskPushNotificationConfig,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        let operation = "CreateTaskPushNotificationConfig";
        self.require_push_notifications_capability(operation)
            .await?;
        Err(A2AError::UnsupportedOperation(String::from(operation)))
    }

    /// Fetches a push-notification configuration.
    ///
    /// # Errors
    ///
    /// The default body first runs the push-notification capability check (propagating its error)
    /// and then returns `A2AError::UnsupportedOperation("GetTaskPushNotificationConfig")`.
    async fn get_task_push_notification_config(
        &self,
        _request: GetTaskPushNotificationConfigRequest,
    ) -> Result<TaskPushNotificationConfig, A2AError> {
        let operation = "GetTaskPushNotificationConfig";
        self.require_push_notifications_capability(operation)
            .await?;
        Err(A2AError::UnsupportedOperation(String::from(operation)))
    }

    /// Lists push-notification configurations.
    ///
    /// # Errors
    ///
    /// The default body first runs the push-notification capability check (propagating its error)
    /// and then returns `A2AError::UnsupportedOperation("ListTaskPushNotificationConfigs")`.
    async fn list_task_push_notification_configs(
        &self,
        _request: ListTaskPushNotificationConfigsRequest,
    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
        let operation = "ListTaskPushNotificationConfigs";
        self.require_push_notifications_capability(operation)
            .await?;
        Err(A2AError::UnsupportedOperation(String::from(operation)))
    }

    /// Deletes a push-notification configuration.
    ///
    /// # Errors
    ///
    /// The default body first runs the push-notification capability check (propagating its error)
    /// and then returns `A2AError::UnsupportedOperation("DeleteTaskPushNotificationConfig")`.
    async fn delete_task_push_notification_config(
        &self,
        _request: DeleteTaskPushNotificationConfigRequest,
    ) -> Result<(), A2AError> {
        let operation = "DeleteTaskPushNotificationConfig";
        self.require_push_notifications_capability(operation)
            .await?;
        Err(A2AError::UnsupportedOperation(String::from(operation)))
    }

    /// Returns the extended agent card.
    ///
    /// # Errors
    ///
    /// The default body first runs the extended-agent-card capability check (propagating its
    /// error) and then returns `A2AError::ExtendedAgentCardNotConfigured("GetExtendedAgentCard")`.
    async fn get_extended_agent_card(
        &self,
        _request: GetExtendedAgentCardRequest,
    ) -> Result<AgentCard, A2AError> {
        self.require_extended_agent_card_capability().await?;
        Err(A2AError::ExtendedAgentCardNotConfigured(String::from(
            "GetExtendedAgentCard",
        )))
    }

    /// Succeeds only when the agent card sets `capabilities.streaming` to `Some(true)`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_agent_card`; otherwise returns
    /// `A2AError::UnsupportedOperation` carrying `operation`.
    async fn require_streaming_capability(&self, operation: &str) -> Result<(), A2AError> {
        let card = self.get_agent_card().await?;
        match card.capabilities.streaming {
            Some(true) => Ok(()),
            _ => Err(A2AError::UnsupportedOperation(operation.to_owned())),
        }
    }

    /// Succeeds only when the agent card sets `capabilities.push_notifications` to `Some(true)`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_agent_card`; otherwise returns
    /// `A2AError::PushNotificationNotSupported` carrying `operation`.
    async fn require_push_notifications_capability(&self, operation: &str) -> Result<(), A2AError> {
        let card = self.get_agent_card().await?;
        match card.capabilities.push_notifications {
            Some(true) => Ok(()),
            _ => Err(A2AError::PushNotificationNotSupported(
                operation.to_owned(),
            )),
        }
    }

    /// Succeeds only when the agent card sets `capabilities.extended_agent_card` to `Some(true)`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_agent_card`; otherwise returns
    /// `A2AError::ExtendedAgentCardNotConfigured("GetExtendedAgentCard")`.
    async fn require_extended_agent_card_capability(&self) -> Result<(), A2AError> {
        let card = self.get_agent_card().await?;
        match card.capabilities.extended_agent_card {
            Some(true) => Ok(()),
            _ => Err(A2AError::ExtendedAgentCardNotConfigured(String::from(
                "GetExtendedAgentCard",
            ))),
        }
    }

    /// Runs the version check followed by the required-extensions check against a single fetch of
    /// the agent card. The extensions check is skipped when the version check fails.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_agent_card` or from either check.
    async fn validate_protocol_headers(&self, headers: &HeaderMap) -> Result<(), A2AError> {
        let card = self.get_agent_card().await?;
        validate_supported_version(&card, headers)
            .and_then(|()| validate_required_extensions(&card, headers))
    }

    /// Runs only the protocol-version check against the current agent card.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_agent_card` or from the version check.
    async fn require_supported_version(&self, headers: &HeaderMap) -> Result<(), A2AError> {
        let card = self.get_agent_card().await?;
        validate_supported_version(&card, headers)
    }

    /// Runs only the required-extensions check against the current agent card.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_agent_card` or from the extensions check.
    async fn require_required_extensions(&self, headers: &HeaderMap) -> Result<(), A2AError> {
        let card = self.get_agent_card().await?;
        validate_required_extensions(&card, headers)
    }
}
/// Returns an owned copy of the first value stored under `name` in `headers`.
///
/// Yields `None` when the header is absent or when its value cannot be converted with `to_str`.
fn header_value(headers: &HeaderMap, name: &str) -> Option<String> {
    let raw = headers.get(name)?;
    let text = raw.to_str().ok()?;
    Some(text.to_owned())
}
/// Checks the protocol version requested through the `A2A-Version` header against the versions
/// listed in `card.supported_interfaces`.
///
/// The requested version is resolved as follows:
/// - no readable header value: `PROTOCOL_VERSION`;
/// - a header that is empty or whitespace-only: `"0.3"`;
/// - anything else: the header value exactly as received.
///
/// A card that lists no interfaces accepts every version.
///
/// # Errors
///
/// Returns `A2AError::VersionNotSupported` carrying the resolved version when the card lists
/// interfaces and none of them uses that version.
fn validate_supported_version(card: &AgentCard, headers: &HeaderMap) -> Result<(), A2AError> {
    let requested_version = header_value(headers, "A2A-Version").map_or_else(
        || PROTOCOL_VERSION.to_owned(),
        |version| {
            if version.trim().is_empty() {
                "0.3".to_owned()
            } else {
                version
            }
        },
    );

    let mut advertised = card
        .supported_interfaces
        .iter()
        .map(|interface| interface.protocol_version.as_str())
        .peekable();
    // An empty interface list places no restriction on the version.
    let accepted = advertised.peek().is_none()
        || advertised.any(|version| version == requested_version.as_str());

    if accepted {
        Ok(())
    } else {
        Err(A2AError::VersionNotSupported(requested_version))
    }
}
/// Checks that the request announces every extension the agent card marks as required.
///
/// Announced extensions are read from the `A2A-Extensions` request header. Each field line is
/// treated as a comma-separated list of URIs; entries are trimmed and empty entries are dropped.
/// All field lines with that name are merged, so a client may send the list in one header or
/// split it across repeated headers. Field lines whose value cannot be converted with `to_str`
/// are ignored.
///
/// When the card has no required extensions the headers are not inspected at all.
///
/// # Errors
///
/// Returns `A2AError::ExtensionSupportRequired` whose message lists the missing extension URIs
/// (sorted, comma-separated) when at least one required extension was not announced.
fn validate_required_extensions(card: &AgentCard, headers: &HeaderMap) -> Result<(), A2AError> {
    let required_extensions = card
        .capabilities
        .extensions
        .iter()
        .filter(|extension| extension.required)
        .map(|extension| extension.uri.as_str())
        .collect::<BTreeSet<_>>();
    if required_extensions.is_empty() {
        return Ok(());
    }

    // `get_all` (rather than `get`) so that a list split across repeated `A2A-Extensions`
    // field lines is honoured, not just the first line. Borrowing `&str` straight from the
    // header map also avoids allocating a `String` per announced URI.
    let announced_extensions = headers
        .get_all("A2A-Extensions")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|uri| !uri.is_empty())
        .collect::<BTreeSet<_>>();

    let missing = required_extensions
        .into_iter()
        .filter(|extension| !announced_extensions.contains(*extension))
        .collect::<Vec<_>>();
    if missing.is_empty() {
        return Ok(());
    }

    Err(A2AError::ExtensionSupportRequired(format!(
        "missing required extensions: {}",
        missing.join(", ")
    )))
}