1use a2a::*;
4use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::Value;
10use std::collections::HashMap;
11
12use crate::push_config_compat::{
13 deserialize_list_task_push_notification_configs_response,
14 deserialize_task_push_notification_config,
15};
16use crate::transport::{ServiceParams, Transport, TransportFactory};
17
// Route suffixes appended to the agent's base URL.

/// Route for a non-streaming `SendMessage` call.
const REST_SEND_MESSAGE_PATH: &str = "/message:send";
/// Route for a streaming send; requests to it ask for `text/event-stream`.
const REST_STREAM_MESSAGE_PATH: &str = "/message:stream";
/// Route that serves the extended agent card.
const REST_EXTENDED_AGENT_CARD_PATH: &str = "/extendedAgentCard";

// Markers used to recognise A2A `google.rpc.ErrorInfo` entries in REST error bodies.

/// `@type` value of an `ErrorInfo` detail entry.
const REST_ERROR_INFO_TYPE_URL: &str = "type.googleapis.com/google.rpc.ErrorInfo";
/// `domain` value identifying an `ErrorInfo` as defined by the A2A protocol.
const REST_ERROR_DOMAIN: &str = "a2a-protocol.org";
23
24#[derive(Debug, Deserialize)]
25struct RestErrorEnvelope {
26 error: RestErrorStatus,
27}
28
29#[derive(Debug, Deserialize)]
30struct RestErrorStatus {
31 message: String,
32
33 #[serde(default)]
34 details: Vec<Value>,
35}
36
37#[derive(Debug, Deserialize)]
38struct RestErrorInfo {
39 #[serde(rename = "@type")]
40 type_url: String,
41 reason: String,
42 domain: String,
43
44 #[serde(default)]
45 metadata: HashMap<String, String>,
46}
47
48pub struct RestTransport {
52 client: Client,
53 base_url: String,
54}
55
56impl RestTransport {
57 pub fn new(client: Client, base_url: String) -> Self {
58 let base_url = base_url.trim_end_matches('/').to_string();
59 RestTransport { client, base_url }
60 }
61
62 fn build_request(
63 &self,
64 method: reqwest::Method,
65 path: &str,
66 params: &ServiceParams,
67 ) -> reqwest::RequestBuilder {
68 let url = format!("{}{}", self.base_url, path);
69 let mut builder = self.client.request(method, &url);
70 for (key, values) in params {
71 for v in values {
72 builder = builder.header(key, v);
73 }
74 }
75 builder
76 }
77
78 fn build_request_with_query(
79 &self,
80 method: reqwest::Method,
81 path: &str,
82 params: &ServiceParams,
83 query: &[(String, String)],
84 ) -> reqwest::RequestBuilder {
85 let builder = self.build_request(method, path, params);
86 if query.is_empty() {
87 builder
88 } else {
89 builder.query(query)
90 }
91 }
92
93 async fn send(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::Response, A2AError> {
94 builder
95 .send()
96 .await
97 .map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))
98 }
99
100 async fn into_rest_error(resp: reqwest::Response) -> A2AError {
101 let status = resp.status();
102 let body = resp.text().await.unwrap_or_default();
103 parse_rest_error(status, &body)
104 }
105
106 async fn post_value<Req>(
107 &self,
108 path: &str,
109 params: &ServiceParams,
110 body: &Req,
111 ) -> Result<Value, A2AError>
112 where
113 Req: ProtoJsonPayload,
114 {
115 let payload = protojson_conv::to_value(body).map_err(|e| {
116 A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
117 })?;
118 let resp = self
119 .send(
120 self.build_request(reqwest::Method::POST, path, params)
121 .json(&payload),
122 )
123 .await?;
124
125 if !resp.status().is_success() {
126 return Err(Self::into_rest_error(resp).await);
127 }
128 let payload = resp
129 .json::<Value>()
130 .await
131 .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
132
133 Ok(payload)
134 }
135
136 async fn post_json<Req, Resp>(
137 &self,
138 path: &str,
139 params: &ServiceParams,
140 body: &Req,
141 ) -> Result<Resp, A2AError>
142 where
143 Req: ProtoJsonPayload,
144 Resp: ProtoJsonPayload,
145 {
146 let payload = self.post_value(path, params, body).await?;
147
148 protojson_conv::from_value(payload).map_err(|e| {
149 A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
150 })
151 }
152
153 async fn get_value(
154 &self,
155 path: &str,
156 params: &ServiceParams,
157 query: &[(String, String)],
158 ) -> Result<Value, A2AError> {
159 let resp = self
160 .send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
161 .await?;
162
163 if !resp.status().is_success() {
164 return Err(Self::into_rest_error(resp).await);
165 }
166 let payload = resp
167 .json::<Value>()
168 .await
169 .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
170
171 Ok(payload)
172 }
173
174 async fn get_json<Resp>(
175 &self,
176 path: &str,
177 params: &ServiceParams,
178 query: &[(String, String)],
179 ) -> Result<Resp, A2AError>
180 where
181 Resp: ProtoJsonPayload,
182 {
183 let payload = self.get_value(path, params, query).await?;
184
185 protojson_conv::from_value(payload).map_err(|e| {
186 A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
187 })
188 }
189
190 async fn post_streaming<Req>(
191 &self,
192 path: &str,
193 params: &ServiceParams,
194 body: &Req,
195 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
196 where
197 Req: ProtoJsonPayload,
198 {
199 let payload = protojson_conv::to_value(body).map_err(|e| {
200 A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
201 })?;
202 let resp = self
203 .send(
204 self.build_request(reqwest::Method::POST, path, params)
205 .header("Accept", "text/event-stream")
206 .json(&payload),
207 )
208 .await?;
209
210 if !resp.status().is_success() {
211 return Err(Self::into_rest_error(resp).await);
212 }
213
214 let stream = resp.bytes_stream();
215 Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
216 }
217
218 async fn get_streaming(
219 &self,
220 path: &str,
221 params: &ServiceParams,
222 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
223 let resp = self
224 .send(
225 self.build_request(reqwest::Method::GET, path, params)
226 .header("Accept", "text/event-stream"),
227 )
228 .await?;
229
230 if !resp.status().is_success() {
231 return Err(Self::into_rest_error(resp).await);
232 }
233
234 let stream = resp.bytes_stream();
235 Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
236 }
237
238 async fn delete(&self, path: &str, params: &ServiceParams) -> Result<(), A2AError> {
239 let resp = self
240 .send(self.build_request(reqwest::Method::DELETE, path, params))
241 .await?;
242
243 if !resp.status().is_success() {
244 return Err(Self::into_rest_error(resp).await);
245 }
246 Ok(())
247 }
248}
249
250fn parse_rest_error(status: reqwest::StatusCode, body: &str) -> A2AError {
251 let Ok(envelope) = serde_json::from_str::<RestErrorEnvelope>(body) else {
252 return A2AError::internal(format!("HTTP {status}: {body}"));
253 };
254
255 let mut details = HashMap::new();
256 let mut code = None;
257
258 for raw_detail in envelope.error.details {
259 if let Ok(info) = serde_json::from_value::<RestErrorInfo>(raw_detail.clone()) {
260 if info.type_url == REST_ERROR_INFO_TYPE_URL && info.domain == REST_ERROR_DOMAIN {
261 code = reason_to_error_code(&info.reason).or(code);
262 for (key, value) in info.metadata {
263 details.insert(key, Value::String(value));
264 }
265 continue;
266 }
267 }
268
269 if let Value::Object(values) = raw_detail {
270 details.extend(values.into_iter());
271 }
272 }
273
274 A2AError {
275 code: code.unwrap_or(error_code::INTERNAL_ERROR),
276 message: envelope.error.message,
277 details: (!details.is_empty()).then_some(details),
278 }
279}
280
281fn reason_to_error_code(reason: &str) -> Option<i32> {
282 match reason {
283 "TASK_NOT_FOUND" => Some(error_code::TASK_NOT_FOUND),
284 "TASK_NOT_CANCELABLE" => Some(error_code::TASK_NOT_CANCELABLE),
285 "PUSH_NOTIFICATION_NOT_SUPPORTED" => Some(error_code::PUSH_NOTIFICATION_NOT_SUPPORTED),
286 "UNSUPPORTED_OPERATION" => Some(error_code::UNSUPPORTED_OPERATION),
287 "UNSUPPORTED_CONTENT_TYPE" | "CONTENT_TYPE_NOT_SUPPORTED" => {
288 Some(error_code::CONTENT_TYPE_NOT_SUPPORTED)
289 }
290 "INVALID_AGENT_RESPONSE" => Some(error_code::INVALID_AGENT_RESPONSE),
291 "EXTENDED_AGENT_CARD_NOT_CONFIGURED" | "EXTENDED_CARD_NOT_CONFIGURED" => {
292 Some(error_code::EXTENDED_CARD_NOT_CONFIGURED)
293 }
294 "EXTENSION_SUPPORT_REQUIRED" => Some(error_code::EXTENSION_SUPPORT_REQUIRED),
295 "VERSION_NOT_SUPPORTED" => Some(error_code::VERSION_NOT_SUPPORTED),
296 "PARSE_ERROR" => Some(error_code::PARSE_ERROR),
297 "INVALID_REQUEST" => Some(error_code::INVALID_REQUEST),
298 "METHOD_NOT_FOUND" => Some(error_code::METHOD_NOT_FOUND),
299 "INVALID_PARAMS" => Some(error_code::INVALID_PARAMS),
300 "INTERNAL_ERROR" => Some(error_code::INTERNAL_ERROR),
301 _ => None,
302 }
303}
304
305#[async_trait]
306impl Transport for RestTransport {
307 async fn send_message(
308 &self,
309 params: &ServiceParams,
310 req: &SendMessageRequest,
311 ) -> Result<SendMessageResponse, A2AError> {
312 self.post_json(REST_SEND_MESSAGE_PATH, params, req).await
313 }
314
315 async fn send_streaming_message(
316 &self,
317 params: &ServiceParams,
318 req: &SendMessageRequest,
319 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
320 self.post_streaming(REST_STREAM_MESSAGE_PATH, params, req)
321 .await
322 }
323
324 async fn get_task(
325 &self,
326 params: &ServiceParams,
327 req: &GetTaskRequest,
328 ) -> Result<Task, A2AError> {
329 let path = format!("/tasks/{}", req.id);
330 let mut query_parts = Vec::new();
331 if let Some(hl) = req.history_length {
332 query_parts.push(("historyLength".to_string(), hl.to_string()));
333 }
334 self.get_json(&path, params, &query_parts).await
335 }
336
337 async fn list_tasks(
338 &self,
339 params: &ServiceParams,
340 req: &ListTasksRequest,
341 ) -> Result<ListTasksResponse, A2AError> {
342 let mut query_parts = Vec::new();
343 if let Some(ref cid) = req.context_id {
344 query_parts.push(("contextId".to_string(), cid.clone()));
345 }
346 if let Some(ref status) = req.status {
347 let s = serde_json::to_value(status)
348 .ok()
349 .and_then(|v| v.as_str().map(String::from))
350 .unwrap_or_default();
351 query_parts.push(("status".to_string(), s));
352 }
353 if let Some(ps) = req.page_size {
354 query_parts.push(("pageSize".to_string(), ps.to_string()));
355 }
356 if let Some(ref pt) = req.page_token {
357 query_parts.push(("pageToken".to_string(), pt.clone()));
358 }
359 if let Some(hl) = req.history_length {
360 query_parts.push(("historyLength".to_string(), hl.to_string()));
361 }
362 if let Some(ref ts) = req.status_timestamp_after {
363 query_parts.push(("statusTimestampAfter".to_string(), ts.to_rfc3339()));
364 }
365 if let Some(ia) = req.include_artifacts {
366 query_parts.push(("includeArtifacts".to_string(), ia.to_string()));
367 }
368 self.get_json("/tasks", params, &query_parts).await
369 }
370
371 async fn cancel_task(
372 &self,
373 params: &ServiceParams,
374 req: &CancelTaskRequest,
375 ) -> Result<Task, A2AError> {
376 self.post_json(&format!("/tasks/{}:cancel", req.id), params, req)
377 .await
378 }
379
380 async fn subscribe_to_task(
381 &self,
382 params: &ServiceParams,
383 req: &SubscribeToTaskRequest,
384 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
385 self.get_streaming(&format!("/tasks/{}:subscribe", req.id), params)
386 .await
387 }
388
389 async fn create_push_config(
390 &self,
391 params: &ServiceParams,
392 req: &CreateTaskPushNotificationConfigRequest,
393 ) -> Result<TaskPushNotificationConfig, A2AError> {
394 let payload = self
395 .post_value(
396 &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
397 params,
398 &req.config,
399 )
400 .await?;
401 deserialize_task_push_notification_config(payload)
402 }
403
404 async fn get_push_config(
405 &self,
406 params: &ServiceParams,
407 req: &GetTaskPushNotificationConfigRequest,
408 ) -> Result<TaskPushNotificationConfig, A2AError> {
409 let payload = self
410 .get_value(
411 &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
412 params,
413 &[],
414 )
415 .await?;
416 deserialize_task_push_notification_config(payload)
417 }
418
419 async fn list_push_configs(
420 &self,
421 params: &ServiceParams,
422 req: &ListTaskPushNotificationConfigsRequest,
423 ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
424 let mut query_parts = Vec::new();
425 if let Some(page_size) = req.page_size {
426 query_parts.push(("pageSize".to_string(), page_size.to_string()));
427 }
428 if let Some(ref page_token) = req.page_token {
429 query_parts.push(("pageToken".to_string(), page_token.clone()));
430 }
431
432 let payload = self
433 .get_value(
434 &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
435 params,
436 &query_parts,
437 )
438 .await?;
439 deserialize_list_task_push_notification_configs_response(payload)
440 }
441
442 async fn delete_push_config(
443 &self,
444 params: &ServiceParams,
445 req: &DeleteTaskPushNotificationConfigRequest,
446 ) -> Result<(), A2AError> {
447 self.delete(
448 &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
449 params,
450 )
451 .await
452 }
453
454 async fn get_extended_agent_card(
455 &self,
456 params: &ServiceParams,
457 _req: &GetExtendedAgentCardRequest,
458 ) -> Result<AgentCard, A2AError> {
459 self.get_json(REST_EXTENDED_AGENT_CARD_PATH, params, &[])
460 .await
461 }
462
463 async fn destroy(&self) -> Result<(), A2AError> {
464 Ok(())
465 }
466}
467
468pub struct RestTransportFactory {
470 client: Client,
471}
472
473impl RestTransportFactory {
474 pub fn new(client: Option<Client>) -> Self {
475 RestTransportFactory {
476 client: client.unwrap_or_default(),
477 }
478 }
479}
480
481#[async_trait]
482impl TransportFactory for RestTransportFactory {
483 fn protocol(&self) -> &str {
484 TRANSPORT_PROTOCOL_HTTP_JSON
485 }
486
487 async fn create(
488 &self,
489 _card: &AgentCard,
490 iface: &AgentInterface,
491 ) -> Result<Box<dyn Transport>, A2AError> {
492 Ok(Box::new(RestTransport::new(
493 self.client.clone(),
494 iface.url.clone(),
495 )))
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rest_transport_new_strips_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080/".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_no_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_factory_protocol() {
        let f = RestTransportFactory::new(None);
        assert_eq!(f.protocol(), "HTTP+JSON");
    }

    #[tokio::test]
    async fn test_rest_transport_factory_create() {
        let f = RestTransportFactory::new(None);
        let card = AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        };
        let iface = AgentInterface::new("http://localhost:8080/", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }

    #[test]
    fn test_build_request_adds_params() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        let mut params = ServiceParams::new();
        params.insert("X-Custom".into(), vec!["val1".into(), "val2".into()]);
        let builder = t.build_request(reqwest::Method::GET, "/test", &params);
        let req = builder.build().unwrap();
        let vals: Vec<_> = req
            .headers()
            .get_all("X-Custom")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(vals, vec!["val1", "val2"]);
    }

    #[test]
    fn test_build_request_with_query_appends_pairs_only_when_present() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        let params = ServiceParams::new();

        let query = vec![("pageSize".to_string(), "10".to_string())];
        let req = t
            .build_request_with_query(reqwest::Method::GET, "/tasks", &params, &query)
            .build()
            .unwrap();
        assert_eq!(req.url().query(), Some("pageSize=10"));

        let req = t
            .build_request_with_query(reqwest::Method::GET, "/tasks", &params, &[])
            .build()
            .unwrap();
        assert_eq!(req.url().query(), None);
    }

    #[test]
    fn test_parse_rest_error_preserves_a2a_error_code() {
        let body = json!({
            "error": {
                "code": 404,
                "status": "NOT_FOUND",
                "message": "task not found: t1",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "TASK_NOT_FOUND",
                        "domain": REST_ERROR_DOMAIN,
                        "metadata": {
                            "taskId": "t1"
                        }
                    },
                    {
                        "resource": "task"
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        assert_eq!(err.code, error_code::TASK_NOT_FOUND);
        assert_eq!(err.message, "task not found: t1");
        let details = err.details.expect("expected structured details");
        assert_eq!(details.get("taskId"), Some(&Value::String("t1".into())));
        assert_eq!(details.get("resource"), Some(&Value::String("task".into())));
    }

    #[test]
    fn test_parse_rest_error_accepts_go_reason_aliases() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "incompatible content types",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "UNSUPPORTED_CONTENT_TYPE",
                        "domain": REST_ERROR_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::CONTENT_TYPE_NOT_SUPPORTED);
    }

    #[test]
    fn test_parse_rest_error_falls_back_for_non_envelope_body() {
        let err = parse_rest_error(reqwest::StatusCode::BAD_GATEWAY, "upstream unavailable");
        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        assert_eq!(err.message, "HTTP 502 Bad Gateway: upstream unavailable");
    }

    #[test]
    fn test_parse_rest_error_treats_foreign_error_info_as_plain_detail() {
        let body = json!({
            "error": {
                "message": "boom",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "TASK_NOT_FOUND",
                        "domain": "example.com"
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        // A non-A2A domain must not be trusted for the error code.
        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        let details = err.details.expect("raw detail members should be preserved");
        assert_eq!(
            details.get("domain"),
            Some(&Value::String("example.com".into()))
        );
    }

    #[test]
    fn test_reason_to_error_code_aliases_and_unknown() {
        assert_eq!(
            reason_to_error_code("EXTENDED_AGENT_CARD_NOT_CONFIGURED"),
            Some(error_code::EXTENDED_CARD_NOT_CONFIGURED)
        );
        assert_eq!(
            reason_to_error_code("EXTENDED_CARD_NOT_CONFIGURED"),
            Some(error_code::EXTENDED_CARD_NOT_CONFIGURED)
        );
        assert_eq!(reason_to_error_code("NOT_A_REAL_REASON"), None);
    }
}