Skip to main content

a2a_rust/server/
handler.rs

1use std::collections::BTreeSet;
2use std::pin::Pin;
3
4use async_trait::async_trait;
5use axum::http::HeaderMap;
6use futures_core::Stream;
7
8use crate::A2AError;
9use crate::jsonrpc::PROTOCOL_VERSION;
10use crate::types::{
11    AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest,
12    GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest,
13    ListTaskPushNotificationConfigsRequest, ListTaskPushNotificationConfigsResponse,
14    ListTasksRequest, ListTasksResponse, SendMessageRequest, SendMessageResponse, StreamResponse,
15    SubscribeToTaskRequest, Task, TaskPushNotificationConfig,
16};
17
18/// Server-side stream of A2A `StreamResponse` values.
19pub type A2AStream = Pin<Box<dyn Stream<Item = StreamResponse> + Send + 'static>>;
20
21/// Core server trait for implementing an A2A agent.
22///
23/// The default capability helpers call `get_agent_card()` on each gated request.
24/// Implementations that fetch the card from storage should cache it or override
25/// the relevant operation methods.
26#[async_trait]
27pub trait A2AHandler: Send + Sync + 'static {
28    /// Return the agent card served from discovery and capability endpoints.
29    async fn get_agent_card(&self) -> Result<AgentCard, A2AError>;
30
31    /// Process a unary `SendMessage` request.
32    async fn send_message(
33        &self,
34        request: SendMessageRequest,
35    ) -> Result<SendMessageResponse, A2AError>;
36
37    /// Stream responses for a submitted message.
38    ///
39    /// Message-only flows should emit exactly one `StreamResponse::Message`.
40    /// Task-based flows should emit the initial task first, followed by status
41    /// and artifact updates until the stream closes.
42    async fn send_streaming_message(
43        &self,
44        _request: SendMessageRequest,
45    ) -> Result<A2AStream, A2AError> {
46        self.require_streaming_capability("SendStreamingMessage")
47            .await?;
48        Err(A2AError::UnsupportedOperation(
49            "SendStreamingMessage".to_owned(),
50        ))
51    }
52
53    /// Fetch a task by identifier.
54    async fn get_task(&self, _request: GetTaskRequest) -> Result<Task, A2AError> {
55        Err(A2AError::UnsupportedOperation("GetTask".to_owned()))
56    }
57
58    /// List tasks visible to the caller.
59    async fn list_tasks(&self, _request: ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
60        Err(A2AError::UnsupportedOperation("ListTasks".to_owned()))
61    }
62
63    /// Attempt to cancel a task.
64    async fn cancel_task(&self, _request: CancelTaskRequest) -> Result<Task, A2AError> {
65        Err(A2AError::UnsupportedOperation("CancelTask".to_owned()))
66    }
67
68    /// Subscribe to updates for an existing task.
69    ///
70    /// Implementations must emit the current `StreamResponse::Task` first before
71    /// any subsequent status or artifact updates.
72    async fn subscribe_to_task(
73        &self,
74        _request: SubscribeToTaskRequest,
75    ) -> Result<A2AStream, A2AError> {
76        self.require_streaming_capability("SubscribeToTask").await?;
77        Err(A2AError::UnsupportedOperation("SubscribeToTask".to_owned()))
78    }
79
80    /// Create or replace a push-notification configuration.
81    async fn create_task_push_notification_config(
82        &self,
83        _request: TaskPushNotificationConfig,
84    ) -> Result<TaskPushNotificationConfig, A2AError> {
85        self.require_push_notifications_capability("CreateTaskPushNotificationConfig")
86            .await?;
87        Err(A2AError::UnsupportedOperation(
88            "CreateTaskPushNotificationConfig".to_owned(),
89        ))
90    }
91
92    /// Fetch a stored push-notification configuration.
93    async fn get_task_push_notification_config(
94        &self,
95        _request: GetTaskPushNotificationConfigRequest,
96    ) -> Result<TaskPushNotificationConfig, A2AError> {
97        self.require_push_notifications_capability("GetTaskPushNotificationConfig")
98            .await?;
99        Err(A2AError::UnsupportedOperation(
100            "GetTaskPushNotificationConfig".to_owned(),
101        ))
102    }
103
104    /// List stored push-notification configurations.
105    async fn list_task_push_notification_configs(
106        &self,
107        _request: ListTaskPushNotificationConfigsRequest,
108    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
109        self.require_push_notifications_capability("ListTaskPushNotificationConfigs")
110            .await?;
111        Err(A2AError::UnsupportedOperation(
112            "ListTaskPushNotificationConfigs".to_owned(),
113        ))
114    }
115
116    /// Delete a stored push-notification configuration.
117    async fn delete_task_push_notification_config(
118        &self,
119        _request: DeleteTaskPushNotificationConfigRequest,
120    ) -> Result<(), A2AError> {
121        self.require_push_notifications_capability("DeleteTaskPushNotificationConfig")
122            .await?;
123        Err(A2AError::UnsupportedOperation(
124            "DeleteTaskPushNotificationConfig".to_owned(),
125        ))
126    }
127
128    /// Fetch the extended agent card.
129    async fn get_extended_agent_card(
130        &self,
131        _request: GetExtendedAgentCardRequest,
132    ) -> Result<AgentCard, A2AError> {
133        self.require_extended_agent_card_capability().await?;
134        Err(A2AError::ExtendedAgentCardNotConfigured(
135            "GetExtendedAgentCard".to_owned(),
136        ))
137    }
138
139    /// Enforce the A2A streaming capability gate.
140    ///
141    /// Do not override unless you preserve the same protocol behavior.
142    async fn require_streaming_capability(&self, operation: &str) -> Result<(), A2AError> {
143        let card = self.get_agent_card().await?;
144        if card.capabilities.streaming == Some(true) {
145            return Ok(());
146        }
147
148        Err(A2AError::UnsupportedOperation(operation.to_owned()))
149    }
150
151    /// Enforce the A2A push-notifications capability gate.
152    ///
153    /// Do not override unless you preserve the same protocol behavior.
154    async fn require_push_notifications_capability(&self, operation: &str) -> Result<(), A2AError> {
155        let card = self.get_agent_card().await?;
156        if card.capabilities.push_notifications == Some(true) {
157            return Ok(());
158        }
159
160        Err(A2AError::PushNotificationNotSupported(operation.to_owned()))
161    }
162
163    /// Enforce the A2A extended-agent-card capability gate.
164    ///
165    /// Do not override unless you preserve the same protocol behavior.
166    async fn require_extended_agent_card_capability(&self) -> Result<(), A2AError> {
167        let card = self.get_agent_card().await?;
168        if card.capabilities.extended_agent_card == Some(true) {
169            return Ok(());
170        }
171
172        Err(A2AError::ExtendedAgentCardNotConfigured(
173            "GetExtendedAgentCard".to_owned(),
174        ))
175    }
176
177    /// Validate `A2A-Version` and `A2A-Extensions` request headers.
178    async fn validate_protocol_headers(&self, headers: &HeaderMap) -> Result<(), A2AError> {
179        let card = self.get_agent_card().await?;
180        validate_supported_version(&card, headers)?;
181        validate_required_extensions(&card, headers)
182    }
183
184    /// Enforce that the request version is supported by the advertised interfaces.
185    async fn require_supported_version(&self, headers: &HeaderMap) -> Result<(), A2AError> {
186        let card = self.get_agent_card().await?;
187        validate_supported_version(&card, headers)
188    }
189
190    /// Enforce that all required agent extensions are acknowledged by the caller.
191    async fn require_required_extensions(&self, headers: &HeaderMap) -> Result<(), A2AError> {
192        let card = self.get_agent_card().await?;
193        validate_required_extensions(&card, headers)
194    }
195}
196
197fn header_value(headers: &HeaderMap, name: &str) -> Option<String> {
198    headers
199        .get(name)
200        .and_then(|value| value.to_str().ok())
201        .map(ToOwned::to_owned)
202}
203
204fn validate_supported_version(card: &AgentCard, headers: &HeaderMap) -> Result<(), A2AError> {
205    let requested_version = match header_value(headers, "A2A-Version") {
206        Some(version) if version.trim().is_empty() => "0.3".to_owned(),
207        Some(version) => version,
208        None => PROTOCOL_VERSION.to_owned(),
209    };
210    let supported_versions = card
211        .supported_interfaces
212        .iter()
213        .map(|interface| interface.protocol_version.as_str())
214        .collect::<BTreeSet<_>>();
215
216    if supported_versions.is_empty() || supported_versions.contains(requested_version.as_str()) {
217        return Ok(());
218    }
219
220    Err(A2AError::VersionNotSupported(requested_version))
221}
222
223fn validate_required_extensions(card: &AgentCard, headers: &HeaderMap) -> Result<(), A2AError> {
224    let required_extensions = card
225        .capabilities
226        .extensions
227        .iter()
228        .filter(|extension| extension.required)
229        .map(|extension| extension.uri.as_str())
230        .collect::<BTreeSet<_>>();
231
232    if required_extensions.is_empty() {
233        return Ok(());
234    }
235
236    let announced_extensions = header_value(headers, "A2A-Extensions")
237        .into_iter()
238        .flat_map(|value| {
239            value
240                .split(',')
241                .map(str::trim)
242                .filter(|value| !value.is_empty())
243                .map(ToOwned::to_owned)
244                .collect::<Vec<_>>()
245        })
246        .collect::<BTreeSet<_>>();
247
248    let missing = required_extensions
249        .into_iter()
250        .filter(|extension| !announced_extensions.contains(*extension))
251        .collect::<Vec<_>>();
252
253    if missing.is_empty() {
254        return Ok(());
255    }
256
257    Err(A2AError::ExtensionSupportRequired(format!(
258        "missing required extensions: {}",
259        missing.join(", ")
260    )))
261}