1use a2a::*;
4use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::Value;
10
11use crate::push_config_compat::{
12 deserialize_list_task_push_notification_configs_response,
13 deserialize_task_push_notification_config,
14};
15use crate::transport::{ServiceParams, Transport, TransportFactory};
16
17const REST_SEND_MESSAGE_PATH: &str = "/message:send";
18const REST_STREAM_MESSAGE_PATH: &str = "/message:stream";
19const REST_EXTENDED_AGENT_CARD_PATH: &str = "/extendedAgentCard";
20
21#[derive(Debug, Deserialize)]
22struct RestErrorEnvelope {
23 error: RestErrorStatus,
24}
25
26#[derive(Debug, Deserialize)]
27struct RestErrorStatus {
28 message: String,
29
30 #[serde(default)]
31 details: Vec<TypedDetail>,
32}
33
34pub struct RestTransport {
38 client: Client,
39 base_url: String,
40}
41
42impl RestTransport {
43 pub fn new(client: Client, base_url: String) -> Self {
45 let base_url = base_url.trim_end_matches('/').to_string();
46 RestTransport { client, base_url }
47 }
48
49 fn build_request(
50 &self,
51 method: reqwest::Method,
52 path: &str,
53 params: &ServiceParams,
54 ) -> reqwest::RequestBuilder {
55 let url = format!("{}{}", self.base_url, path);
56 let mut builder = self.client.request(method, &url);
57 for (key, values) in params {
58 for v in values {
59 builder = builder.header(key, v);
60 }
61 }
62 builder
63 }
64
65 fn build_request_with_query(
66 &self,
67 method: reqwest::Method,
68 path: &str,
69 params: &ServiceParams,
70 query: &[(String, String)],
71 ) -> reqwest::RequestBuilder {
72 let builder = self.build_request(method, path, params);
73 if query.is_empty() {
74 builder
75 } else {
76 builder.query(query)
77 }
78 }
79
80 async fn send(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::Response, A2AError> {
81 builder
82 .send()
83 .await
84 .map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))
85 }
86
87 async fn into_rest_error(resp: reqwest::Response) -> A2AError {
88 let status = resp.status();
89 let body = resp.text().await.unwrap_or_default();
90 parse_rest_error(status, &body)
91 }
92
93 async fn post_value<Req>(
94 &self,
95 path: &str,
96 params: &ServiceParams,
97 body: &Req,
98 ) -> Result<Value, A2AError>
99 where
100 Req: ProtoJsonPayload,
101 {
102 let payload = protojson_conv::to_value(body).map_err(|e| {
103 A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
104 })?;
105 let resp = self
106 .send(
107 self.build_request(reqwest::Method::POST, path, params)
108 .json(&payload),
109 )
110 .await?;
111
112 if !resp.status().is_success() {
113 return Err(Self::into_rest_error(resp).await);
114 }
115 let payload = resp
116 .json::<Value>()
117 .await
118 .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
119
120 Ok(payload)
121 }
122
123 async fn post_json<Req, Resp>(
124 &self,
125 path: &str,
126 params: &ServiceParams,
127 body: &Req,
128 ) -> Result<Resp, A2AError>
129 where
130 Req: ProtoJsonPayload,
131 Resp: ProtoJsonPayload,
132 {
133 let payload = self.post_value(path, params, body).await?;
134
135 protojson_conv::from_value(payload).map_err(|e| {
136 A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
137 })
138 }
139
140 async fn get_value(
141 &self,
142 path: &str,
143 params: &ServiceParams,
144 query: &[(String, String)],
145 ) -> Result<Value, A2AError> {
146 let resp = self
147 .send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
148 .await?;
149
150 if !resp.status().is_success() {
151 return Err(Self::into_rest_error(resp).await);
152 }
153 let payload = resp
154 .json::<Value>()
155 .await
156 .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
157
158 Ok(payload)
159 }
160
161 async fn get_json<Resp>(
162 &self,
163 path: &str,
164 params: &ServiceParams,
165 query: &[(String, String)],
166 ) -> Result<Resp, A2AError>
167 where
168 Resp: ProtoJsonPayload,
169 {
170 let payload = self.get_value(path, params, query).await?;
171
172 protojson_conv::from_value(payload).map_err(|e| {
173 A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
174 })
175 }
176
177 async fn post_streaming<Req>(
178 &self,
179 path: &str,
180 params: &ServiceParams,
181 body: &Req,
182 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
183 where
184 Req: ProtoJsonPayload,
185 {
186 let payload = protojson_conv::to_value(body).map_err(|e| {
187 A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
188 })?;
189 let resp = self
190 .send(
191 self.build_request(reqwest::Method::POST, path, params)
192 .header("Accept", "text/event-stream")
193 .json(&payload),
194 )
195 .await?;
196
197 if !resp.status().is_success() {
198 return Err(Self::into_rest_error(resp).await);
199 }
200
201 let stream = resp.bytes_stream();
202 Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
203 }
204
205 async fn get_streaming(
206 &self,
207 path: &str,
208 params: &ServiceParams,
209 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
210 let resp = self
211 .send(
212 self.build_request(reqwest::Method::GET, path, params)
213 .header("Accept", "text/event-stream"),
214 )
215 .await?;
216
217 if !resp.status().is_success() {
218 return Err(Self::into_rest_error(resp).await);
219 }
220
221 let stream = resp.bytes_stream();
222 Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
223 }
224
225 async fn delete(&self, path: &str, params: &ServiceParams) -> Result<(), A2AError> {
226 let resp = self
227 .send(self.build_request(reqwest::Method::DELETE, path, params))
228 .await?;
229
230 if !resp.status().is_success() {
231 return Err(Self::into_rest_error(resp).await);
232 }
233 Ok(())
234 }
235}
236
237fn parse_rest_error(status: reqwest::StatusCode, body: &str) -> A2AError {
238 let Ok(envelope) = serde_json::from_str::<RestErrorEnvelope>(body) else {
239 return A2AError::internal(format!("HTTP {status}: {body}"));
240 };
241
242 crate::a2a_error_from_details(
243 error_code::INTERNAL_ERROR,
244 envelope.error.message,
245 envelope.error.details,
246 )
247}
248
249#[async_trait]
250impl Transport for RestTransport {
251 async fn send_message(
252 &self,
253 params: &ServiceParams,
254 req: &SendMessageRequest,
255 ) -> Result<SendMessageResponse, A2AError> {
256 self.post_json(REST_SEND_MESSAGE_PATH, params, req).await
257 }
258
259 async fn send_streaming_message(
260 &self,
261 params: &ServiceParams,
262 req: &SendMessageRequest,
263 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
264 self.post_streaming(REST_STREAM_MESSAGE_PATH, params, req)
265 .await
266 }
267
268 async fn get_task(
269 &self,
270 params: &ServiceParams,
271 req: &GetTaskRequest,
272 ) -> Result<Task, A2AError> {
273 let path = format!("/tasks/{}", req.id);
274 let mut query_parts = Vec::new();
275 if let Some(hl) = req.history_length {
276 query_parts.push(("historyLength".to_string(), hl.to_string()));
277 }
278 self.get_json(&path, params, &query_parts).await
279 }
280
281 async fn list_tasks(
282 &self,
283 params: &ServiceParams,
284 req: &ListTasksRequest,
285 ) -> Result<ListTasksResponse, A2AError> {
286 let mut query_parts = Vec::new();
287 if let Some(ref cid) = req.context_id {
288 query_parts.push(("contextId".to_string(), cid.clone()));
289 }
290 if let Some(ref status) = req.status {
291 let s = serde_json::to_value(status)
292 .ok()
293 .and_then(|v| v.as_str().map(String::from))
294 .unwrap_or_default();
295 query_parts.push(("status".to_string(), s));
296 }
297 if let Some(ps) = req.page_size {
298 query_parts.push(("pageSize".to_string(), ps.to_string()));
299 }
300 if let Some(ref pt) = req.page_token {
301 query_parts.push(("pageToken".to_string(), pt.clone()));
302 }
303 if let Some(hl) = req.history_length {
304 query_parts.push(("historyLength".to_string(), hl.to_string()));
305 }
306 if let Some(ref ts) = req.status_timestamp_after {
307 query_parts.push(("statusTimestampAfter".to_string(), ts.to_rfc3339()));
308 }
309 if let Some(ia) = req.include_artifacts {
310 query_parts.push(("includeArtifacts".to_string(), ia.to_string()));
311 }
312 self.get_json("/tasks", params, &query_parts).await
313 }
314
315 async fn cancel_task(
316 &self,
317 params: &ServiceParams,
318 req: &CancelTaskRequest,
319 ) -> Result<Task, A2AError> {
320 self.post_json(&format!("/tasks/{}:cancel", req.id), params, req)
321 .await
322 }
323
324 async fn subscribe_to_task(
325 &self,
326 params: &ServiceParams,
327 req: &SubscribeToTaskRequest,
328 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
329 self.get_streaming(&format!("/tasks/{}:subscribe", req.id), params)
330 .await
331 }
332
333 async fn create_push_config(
334 &self,
335 params: &ServiceParams,
336 req: &TaskPushNotificationConfig,
337 ) -> Result<TaskPushNotificationConfig, A2AError> {
338 let payload = self
339 .post_value(
340 &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
341 params,
342 req,
343 )
344 .await?;
345 deserialize_task_push_notification_config(payload)
346 }
347
348 async fn get_push_config(
349 &self,
350 params: &ServiceParams,
351 req: &GetTaskPushNotificationConfigRequest,
352 ) -> Result<TaskPushNotificationConfig, A2AError> {
353 let payload = self
354 .get_value(
355 &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
356 params,
357 &[],
358 )
359 .await?;
360 deserialize_task_push_notification_config(payload)
361 }
362
363 async fn list_push_configs(
364 &self,
365 params: &ServiceParams,
366 req: &ListTaskPushNotificationConfigsRequest,
367 ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
368 let mut query_parts = Vec::new();
369 if let Some(page_size) = req.page_size {
370 query_parts.push(("pageSize".to_string(), page_size.to_string()));
371 }
372 if let Some(ref page_token) = req.page_token {
373 query_parts.push(("pageToken".to_string(), page_token.clone()));
374 }
375
376 let payload = self
377 .get_value(
378 &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
379 params,
380 &query_parts,
381 )
382 .await?;
383 deserialize_list_task_push_notification_configs_response(payload)
384 }
385
386 async fn delete_push_config(
387 &self,
388 params: &ServiceParams,
389 req: &DeleteTaskPushNotificationConfigRequest,
390 ) -> Result<(), A2AError> {
391 self.delete(
392 &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
393 params,
394 )
395 .await
396 }
397
398 async fn get_extended_agent_card(
399 &self,
400 params: &ServiceParams,
401 _req: &GetExtendedAgentCardRequest,
402 ) -> Result<AgentCard, A2AError> {
403 self.get_json(REST_EXTENDED_AGENT_CARD_PATH, params, &[])
404 .await
405 }
406
407 async fn destroy(&self) -> Result<(), A2AError> {
408 Ok(())
409 }
410}
411
412pub struct RestTransportFactory {
414 client: Client,
415}
416
417impl RestTransportFactory {
418 pub fn new(client: Option<Client>) -> Self {
419 RestTransportFactory {
420 client: client
421 .unwrap_or_else(|| crate::default_reqwest_client(None).expect("default client")),
422 }
423 }
424
425 #[cfg(any(
426 feature = "rustls-tls",
427 feature = "rustls-no-provider",
428 feature = "native-tls"
429 ))]
430 pub fn with_root_certificates_pem(pem: &[u8]) -> Result<Self, A2AError> {
431 Ok(Self {
432 client: crate::default_reqwest_client(Some(pem))?,
433 })
434 }
435}
436
437#[async_trait]
438impl TransportFactory for RestTransportFactory {
439 fn protocol(&self) -> &str {
440 TRANSPORT_PROTOCOL_HTTP_JSON
441 }
442
443 async fn create(
444 &self,
445 _card: &AgentCard,
446 iface: &AgentInterface,
447 ) -> Result<Box<dyn Transport>, A2AError> {
448 Ok(Box::new(RestTransport::new(
449 self.client.clone(),
450 iface.url.clone(),
451 )))
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rest_transport_new_strips_trailing_slash() {
        let t = RestTransport::new(
            crate::default_reqwest_client(None).unwrap(),
            "http://localhost:8080/".into(),
        );
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_no_trailing_slash() {
        let t = RestTransport::new(
            crate::default_reqwest_client(None).unwrap(),
            "http://localhost:8080".into(),
        );
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_strips_repeated_trailing_slashes() {
        let t = RestTransport::new(
            crate::default_reqwest_client(None).unwrap(),
            "http://localhost:8080/rest///".into(),
        );
        assert_eq!(t.base_url, "http://localhost:8080/rest");
    }

    #[test]
    fn test_rest_transport_factory_protocol() {
        let f = RestTransportFactory::new(None);
        assert_eq!(f.protocol(), "HTTP+JSON");
    }

    #[tokio::test]
    async fn test_rest_transport_factory_create() {
        let f = RestTransportFactory::new(None);
        let card = AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        };
        let iface = AgentInterface::new("http://localhost:8080/", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }

    #[test]
    fn test_build_request_adds_params() {
        let t = RestTransport::new(
            crate::default_reqwest_client(None).unwrap(),
            "http://localhost:8080".into(),
        );
        let mut params = ServiceParams::new();
        params.insert("X-Custom".into(), vec!["val1".into(), "val2".into()]);
        let builder = t.build_request(reqwest::Method::GET, "/test", &params);
        let req = builder.build().unwrap();
        let vals: Vec<_> = req
            .headers()
            .get_all("X-Custom")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(vals, vec!["val1", "val2"]);
    }

    #[test]
    fn test_build_request_with_query_appends_only_when_non_empty() {
        let t = RestTransport::new(
            crate::default_reqwest_client(None).unwrap(),
            "http://localhost:8080".into(),
        );
        let params = ServiceParams::new();

        let plain = t
            .build_request_with_query(reqwest::Method::GET, "/tasks", &params, &[])
            .build()
            .unwrap();
        assert_eq!(plain.url().query(), None);

        let query = vec![("pageSize".to_string(), "10".to_string())];
        let with_query = t
            .build_request_with_query(reqwest::Method::GET, "/tasks", &params, &query)
            .build()
            .unwrap();
        assert_eq!(with_query.url().path(), "/tasks");
        assert_eq!(with_query.url().query(), Some("pageSize=10"));
    }

    #[test]
    fn test_parse_rest_error_preserves_a2a_error_code() {
        let body = json!({
            "error": {
                "code": 404,
                "status": "NOT_FOUND",
                "message": "task not found: t1",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "TASK_NOT_FOUND",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {
                            "taskId": "t1"
                        }
                    },
                    {
                        "resource": "task"
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        assert_eq!(err.code, error_code::TASK_NOT_FOUND);
        assert_eq!(err.message, "task not found: t1");
        let details = err.details.expect("expected structured details");
        assert_eq!(details.len(), 2);
        assert_eq!(details[0].type_url, errordetails::ERROR_INFO_TYPE);
        assert_eq!(
            details[1].value.get("resource"),
            Some(&Value::String("task".into()))
        );
    }

    #[test]
    fn test_parse_rest_error_accepts_go_reason_aliases() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "incompatible content types",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "UNSUPPORTED_CONTENT_TYPE",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::CONTENT_TYPE_NOT_SUPPORTED);
    }

    #[test]
    fn test_parse_rest_error_bad_request_fallback() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "invalid request parameters",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {
                                "field": "message.parts",
                                "description": "At least one part is required"
                            }
                        ]
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::INVALID_PARAMS);
        assert!(
            err.message
                .contains("message.parts: At least one part is required")
        );
        let details = err.details.expect("expected details");
        assert_eq!(details.len(), 1);
        assert_eq!(details[0].type_url, errordetails::BAD_REQUEST_TYPE);
        let violations = details[0].value.get("fieldViolations").unwrap();
        assert_eq!(violations[0]["field"], "message.parts");
    }

    #[test]
    fn test_parse_rest_error_bad_request_with_error_info_uses_reason() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "bad params",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {"field": "task.id", "description": "required"}
                        ]
                    },
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "INVALID_PARAMS",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::INVALID_PARAMS);
    }

    #[test]
    fn test_parse_rest_error_non_json_body_falls_back_to_internal() {
        let err = parse_rest_error(reqwest::StatusCode::BAD_GATEWAY, "upstream unavailable");
        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        assert!(err.message.contains("upstream unavailable"));
    }

    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[test]
    fn test_with_root_certificates_pem_valid() {
        let pem = crate::test_utils::rcgen_self_signed_ca_pem();
        let f = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        assert_eq!(f.protocol(), TRANSPORT_PROTOCOL_HTTP_JSON);
    }

    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[tokio::test]
    async fn test_with_root_certificates_pem_factory_create() {
        let pem = crate::test_utils::rcgen_self_signed_ca_pem();
        let f = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        let card = AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        };
        let iface = AgentInterface::new("https://localhost:3443/rest", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }
}