a2a_rust/server/
handler.rs1use std::collections::BTreeSet;
2use std::pin::Pin;
3
4use async_trait::async_trait;
5use axum::http::HeaderMap;
6use futures_core::Stream;
7
8use crate::A2AError;
9use crate::jsonrpc::PROTOCOL_VERSION;
10use crate::types::{
11 AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest,
12 GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest,
13 ListTaskPushNotificationConfigsRequest, ListTaskPushNotificationConfigsResponse,
14 ListTasksRequest, ListTasksResponse, SendMessageRequest, SendMessageResponse, StreamResponse,
15 SubscribeToTaskRequest, Task, TaskPushNotificationConfig,
16};
17
18pub type A2AStream = Pin<Box<dyn Stream<Item = StreamResponse> + Send + 'static>>;
20
21#[async_trait]
27pub trait A2AHandler: Send + Sync + 'static {
28 async fn get_agent_card(&self) -> Result<AgentCard, A2AError>;
30
31 async fn send_message(
33 &self,
34 request: SendMessageRequest,
35 ) -> Result<SendMessageResponse, A2AError>;
36
37 async fn send_streaming_message(
43 &self,
44 _request: SendMessageRequest,
45 ) -> Result<A2AStream, A2AError> {
46 self.require_streaming_capability("SendStreamingMessage")
47 .await?;
48 Err(A2AError::UnsupportedOperation(
49 "SendStreamingMessage".to_owned(),
50 ))
51 }
52
53 async fn get_task(&self, _request: GetTaskRequest) -> Result<Task, A2AError> {
55 Err(A2AError::UnsupportedOperation("GetTask".to_owned()))
56 }
57
58 async fn list_tasks(&self, _request: ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
60 Err(A2AError::UnsupportedOperation("ListTasks".to_owned()))
61 }
62
63 async fn cancel_task(&self, _request: CancelTaskRequest) -> Result<Task, A2AError> {
65 Err(A2AError::UnsupportedOperation("CancelTask".to_owned()))
66 }
67
68 async fn subscribe_to_task(
73 &self,
74 _request: SubscribeToTaskRequest,
75 ) -> Result<A2AStream, A2AError> {
76 self.require_streaming_capability("SubscribeToTask").await?;
77 Err(A2AError::UnsupportedOperation("SubscribeToTask".to_owned()))
78 }
79
80 async fn create_task_push_notification_config(
82 &self,
83 _request: TaskPushNotificationConfig,
84 ) -> Result<TaskPushNotificationConfig, A2AError> {
85 self.require_push_notifications_capability("CreateTaskPushNotificationConfig")
86 .await?;
87 Err(A2AError::UnsupportedOperation(
88 "CreateTaskPushNotificationConfig".to_owned(),
89 ))
90 }
91
92 async fn get_task_push_notification_config(
94 &self,
95 _request: GetTaskPushNotificationConfigRequest,
96 ) -> Result<TaskPushNotificationConfig, A2AError> {
97 self.require_push_notifications_capability("GetTaskPushNotificationConfig")
98 .await?;
99 Err(A2AError::UnsupportedOperation(
100 "GetTaskPushNotificationConfig".to_owned(),
101 ))
102 }
103
104 async fn list_task_push_notification_configs(
106 &self,
107 _request: ListTaskPushNotificationConfigsRequest,
108 ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
109 self.require_push_notifications_capability("ListTaskPushNotificationConfigs")
110 .await?;
111 Err(A2AError::UnsupportedOperation(
112 "ListTaskPushNotificationConfigs".to_owned(),
113 ))
114 }
115
116 async fn delete_task_push_notification_config(
118 &self,
119 _request: DeleteTaskPushNotificationConfigRequest,
120 ) -> Result<(), A2AError> {
121 self.require_push_notifications_capability("DeleteTaskPushNotificationConfig")
122 .await?;
123 Err(A2AError::UnsupportedOperation(
124 "DeleteTaskPushNotificationConfig".to_owned(),
125 ))
126 }
127
128 async fn get_extended_agent_card(
130 &self,
131 _request: GetExtendedAgentCardRequest,
132 ) -> Result<AgentCard, A2AError> {
133 self.require_extended_agent_card_capability().await?;
134 Err(A2AError::ExtendedAgentCardNotConfigured(
135 "GetExtendedAgentCard".to_owned(),
136 ))
137 }
138
139 async fn require_streaming_capability(&self, operation: &str) -> Result<(), A2AError> {
143 let card = self.get_agent_card().await?;
144 if card.capabilities.streaming == Some(true) {
145 return Ok(());
146 }
147
148 Err(A2AError::UnsupportedOperation(operation.to_owned()))
149 }
150
151 async fn require_push_notifications_capability(&self, operation: &str) -> Result<(), A2AError> {
155 let card = self.get_agent_card().await?;
156 if card.capabilities.push_notifications == Some(true) {
157 return Ok(());
158 }
159
160 Err(A2AError::PushNotificationNotSupported(operation.to_owned()))
161 }
162
163 async fn require_extended_agent_card_capability(&self) -> Result<(), A2AError> {
167 let card = self.get_agent_card().await?;
168 if card.capabilities.extended_agent_card == Some(true) {
169 return Ok(());
170 }
171
172 Err(A2AError::ExtendedAgentCardNotConfigured(
173 "GetExtendedAgentCard".to_owned(),
174 ))
175 }
176
177 async fn validate_protocol_headers(&self, headers: &HeaderMap) -> Result<(), A2AError> {
179 let card = self.get_agent_card().await?;
180 validate_supported_version(&card, headers)?;
181 validate_required_extensions(&card, headers)
182 }
183
184 async fn require_supported_version(&self, headers: &HeaderMap) -> Result<(), A2AError> {
186 let card = self.get_agent_card().await?;
187 validate_supported_version(&card, headers)
188 }
189
190 async fn require_required_extensions(&self, headers: &HeaderMap) -> Result<(), A2AError> {
192 let card = self.get_agent_card().await?;
193 validate_required_extensions(&card, headers)
194 }
195}
196
197fn header_value(headers: &HeaderMap, name: &str) -> Option<String> {
198 headers
199 .get(name)
200 .and_then(|value| value.to_str().ok())
201 .map(ToOwned::to_owned)
202}
203
204fn validate_supported_version(card: &AgentCard, headers: &HeaderMap) -> Result<(), A2AError> {
205 let requested_version = match header_value(headers, "A2A-Version") {
206 Some(version) if version.trim().is_empty() => "0.3".to_owned(),
207 Some(version) => version,
208 None => PROTOCOL_VERSION.to_owned(),
209 };
210 let supported_versions = card
211 .supported_interfaces
212 .iter()
213 .map(|interface| interface.protocol_version.as_str())
214 .collect::<BTreeSet<_>>();
215
216 if supported_versions.is_empty() || supported_versions.contains(requested_version.as_str()) {
217 return Ok(());
218 }
219
220 Err(A2AError::VersionNotSupported(requested_version))
221}
222
223fn validate_required_extensions(card: &AgentCard, headers: &HeaderMap) -> Result<(), A2AError> {
224 let required_extensions = card
225 .capabilities
226 .extensions
227 .iter()
228 .filter(|extension| extension.required)
229 .map(|extension| extension.uri.as_str())
230 .collect::<BTreeSet<_>>();
231
232 if required_extensions.is_empty() {
233 return Ok(());
234 }
235
236 let announced_extensions = header_value(headers, "A2A-Extensions")
237 .into_iter()
238 .flat_map(|value| {
239 value
240 .split(',')
241 .map(str::trim)
242 .filter(|value| !value.is_empty())
243 .map(ToOwned::to_owned)
244 .collect::<Vec<_>>()
245 })
246 .collect::<BTreeSet<_>>();
247
248 let missing = required_extensions
249 .into_iter()
250 .filter(|extension| !announced_extensions.contains(*extension))
251 .collect::<Vec<_>>();
252
253 if missing.is_empty() {
254 return Ok(());
255 }
256
257 Err(A2AError::ExtensionSupportRequired(format!(
258 "missing required extensions: {}",
259 missing.join(", ")
260 )))
261}