1use a2a::*;
4use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::Value;
10
11use crate::push_config_compat::{
12 deserialize_list_task_push_notification_configs_response,
13 deserialize_task_push_notification_config,
14};
15use crate::transport::{ServiceParams, Transport, TransportFactory};
16
/// Endpoint path (relative to the base URL) for unary `SendMessage`.
const REST_SEND_MESSAGE_PATH: &str = "/message:send";
/// Endpoint path for `SendStreamingMessage`; the response is read as an SSE stream.
const REST_STREAM_MESSAGE_PATH: &str = "/message:stream";
/// Endpoint path for `GetExtendedAgentCard`, fetched with a plain GET.
const REST_EXTENDED_AGENT_CARD_PATH: &str = "/extendedAgentCard";
20
21#[derive(Debug, Deserialize)]
22struct RestErrorEnvelope {
23 error: RestErrorStatus,
24}
25
26#[derive(Debug, Deserialize)]
27struct RestErrorStatus {
28 message: String,
29
30 #[serde(default)]
31 details: Vec<TypedDetail>,
32}
33
/// Client-side [`Transport`] for the A2A HTTP+JSON (REST) binding.
///
/// Unary calls exchange ProtoJSON bodies; streaming calls request
/// `text/event-stream` and parse the response as server-sent events.
pub struct RestTransport {
    /// HTTP client used for every request issued by this transport.
    client: Client,
    /// Endpoint root; `RestTransport::new` strips trailing `/` so request
    /// paths (which all start with `/`) can be appended directly.
    base_url: String,
}
41
42impl RestTransport {
43 pub fn new(client: Client, base_url: String) -> Self {
44 let base_url = base_url.trim_end_matches('/').to_string();
45 RestTransport { client, base_url }
46 }
47
48 fn build_request(
49 &self,
50 method: reqwest::Method,
51 path: &str,
52 params: &ServiceParams,
53 ) -> reqwest::RequestBuilder {
54 let url = format!("{}{}", self.base_url, path);
55 let mut builder = self.client.request(method, &url);
56 for (key, values) in params {
57 for v in values {
58 builder = builder.header(key, v);
59 }
60 }
61 builder
62 }
63
64 fn build_request_with_query(
65 &self,
66 method: reqwest::Method,
67 path: &str,
68 params: &ServiceParams,
69 query: &[(String, String)],
70 ) -> reqwest::RequestBuilder {
71 let builder = self.build_request(method, path, params);
72 if query.is_empty() {
73 builder
74 } else {
75 builder.query(query)
76 }
77 }
78
79 async fn send(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::Response, A2AError> {
80 builder
81 .send()
82 .await
83 .map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))
84 }
85
86 async fn into_rest_error(resp: reqwest::Response) -> A2AError {
87 let status = resp.status();
88 let body = resp.text().await.unwrap_or_default();
89 parse_rest_error(status, &body)
90 }
91
92 async fn post_value<Req>(
93 &self,
94 path: &str,
95 params: &ServiceParams,
96 body: &Req,
97 ) -> Result<Value, A2AError>
98 where
99 Req: ProtoJsonPayload,
100 {
101 let payload = protojson_conv::to_value(body).map_err(|e| {
102 A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
103 })?;
104 let resp = self
105 .send(
106 self.build_request(reqwest::Method::POST, path, params)
107 .json(&payload),
108 )
109 .await?;
110
111 if !resp.status().is_success() {
112 return Err(Self::into_rest_error(resp).await);
113 }
114 let payload = resp
115 .json::<Value>()
116 .await
117 .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
118
119 Ok(payload)
120 }
121
122 async fn post_json<Req, Resp>(
123 &self,
124 path: &str,
125 params: &ServiceParams,
126 body: &Req,
127 ) -> Result<Resp, A2AError>
128 where
129 Req: ProtoJsonPayload,
130 Resp: ProtoJsonPayload,
131 {
132 let payload = self.post_value(path, params, body).await?;
133
134 protojson_conv::from_value(payload).map_err(|e| {
135 A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
136 })
137 }
138
139 async fn get_value(
140 &self,
141 path: &str,
142 params: &ServiceParams,
143 query: &[(String, String)],
144 ) -> Result<Value, A2AError> {
145 let resp = self
146 .send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
147 .await?;
148
149 if !resp.status().is_success() {
150 return Err(Self::into_rest_error(resp).await);
151 }
152 let payload = resp
153 .json::<Value>()
154 .await
155 .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
156
157 Ok(payload)
158 }
159
160 async fn get_json<Resp>(
161 &self,
162 path: &str,
163 params: &ServiceParams,
164 query: &[(String, String)],
165 ) -> Result<Resp, A2AError>
166 where
167 Resp: ProtoJsonPayload,
168 {
169 let payload = self.get_value(path, params, query).await?;
170
171 protojson_conv::from_value(payload).map_err(|e| {
172 A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
173 })
174 }
175
176 async fn post_streaming<Req>(
177 &self,
178 path: &str,
179 params: &ServiceParams,
180 body: &Req,
181 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
182 where
183 Req: ProtoJsonPayload,
184 {
185 let payload = protojson_conv::to_value(body).map_err(|e| {
186 A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
187 })?;
188 let resp = self
189 .send(
190 self.build_request(reqwest::Method::POST, path, params)
191 .header("Accept", "text/event-stream")
192 .json(&payload),
193 )
194 .await?;
195
196 if !resp.status().is_success() {
197 return Err(Self::into_rest_error(resp).await);
198 }
199
200 let stream = resp.bytes_stream();
201 Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
202 }
203
204 async fn get_streaming(
205 &self,
206 path: &str,
207 params: &ServiceParams,
208 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
209 let resp = self
210 .send(
211 self.build_request(reqwest::Method::GET, path, params)
212 .header("Accept", "text/event-stream"),
213 )
214 .await?;
215
216 if !resp.status().is_success() {
217 return Err(Self::into_rest_error(resp).await);
218 }
219
220 let stream = resp.bytes_stream();
221 Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
222 }
223
224 async fn delete(&self, path: &str, params: &ServiceParams) -> Result<(), A2AError> {
225 let resp = self
226 .send(self.build_request(reqwest::Method::DELETE, path, params))
227 .await?;
228
229 if !resp.status().is_success() {
230 return Err(Self::into_rest_error(resp).await);
231 }
232 Ok(())
233 }
234}
235
236fn parse_rest_error(status: reqwest::StatusCode, body: &str) -> A2AError {
237 let Ok(envelope) = serde_json::from_str::<RestErrorEnvelope>(body) else {
238 return A2AError::internal(format!("HTTP {status}: {body}"));
239 };
240
241 crate::a2a_error_from_details(
242 error_code::INTERNAL_ERROR,
243 envelope.error.message,
244 envelope.error.details,
245 )
246}
247
248#[async_trait]
249impl Transport for RestTransport {
250 async fn send_message(
251 &self,
252 params: &ServiceParams,
253 req: &SendMessageRequest,
254 ) -> Result<SendMessageResponse, A2AError> {
255 self.post_json(REST_SEND_MESSAGE_PATH, params, req).await
256 }
257
258 async fn send_streaming_message(
259 &self,
260 params: &ServiceParams,
261 req: &SendMessageRequest,
262 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
263 self.post_streaming(REST_STREAM_MESSAGE_PATH, params, req)
264 .await
265 }
266
267 async fn get_task(
268 &self,
269 params: &ServiceParams,
270 req: &GetTaskRequest,
271 ) -> Result<Task, A2AError> {
272 let path = format!("/tasks/{}", req.id);
273 let mut query_parts = Vec::new();
274 if let Some(hl) = req.history_length {
275 query_parts.push(("historyLength".to_string(), hl.to_string()));
276 }
277 self.get_json(&path, params, &query_parts).await
278 }
279
280 async fn list_tasks(
281 &self,
282 params: &ServiceParams,
283 req: &ListTasksRequest,
284 ) -> Result<ListTasksResponse, A2AError> {
285 let mut query_parts = Vec::new();
286 if let Some(ref cid) = req.context_id {
287 query_parts.push(("contextId".to_string(), cid.clone()));
288 }
289 if let Some(ref status) = req.status {
290 let s = serde_json::to_value(status)
291 .ok()
292 .and_then(|v| v.as_str().map(String::from))
293 .unwrap_or_default();
294 query_parts.push(("status".to_string(), s));
295 }
296 if let Some(ps) = req.page_size {
297 query_parts.push(("pageSize".to_string(), ps.to_string()));
298 }
299 if let Some(ref pt) = req.page_token {
300 query_parts.push(("pageToken".to_string(), pt.clone()));
301 }
302 if let Some(hl) = req.history_length {
303 query_parts.push(("historyLength".to_string(), hl.to_string()));
304 }
305 if let Some(ref ts) = req.status_timestamp_after {
306 query_parts.push(("statusTimestampAfter".to_string(), ts.to_rfc3339()));
307 }
308 if let Some(ia) = req.include_artifacts {
309 query_parts.push(("includeArtifacts".to_string(), ia.to_string()));
310 }
311 self.get_json("/tasks", params, &query_parts).await
312 }
313
314 async fn cancel_task(
315 &self,
316 params: &ServiceParams,
317 req: &CancelTaskRequest,
318 ) -> Result<Task, A2AError> {
319 self.post_json(&format!("/tasks/{}:cancel", req.id), params, req)
320 .await
321 }
322
323 async fn subscribe_to_task(
324 &self,
325 params: &ServiceParams,
326 req: &SubscribeToTaskRequest,
327 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
328 self.get_streaming(&format!("/tasks/{}:subscribe", req.id), params)
329 .await
330 }
331
332 async fn create_push_config(
333 &self,
334 params: &ServiceParams,
335 req: &TaskPushNotificationConfig,
336 ) -> Result<TaskPushNotificationConfig, A2AError> {
337 let payload = self
338 .post_value(
339 &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
340 params,
341 req,
342 )
343 .await?;
344 deserialize_task_push_notification_config(payload)
345 }
346
347 async fn get_push_config(
348 &self,
349 params: &ServiceParams,
350 req: &GetTaskPushNotificationConfigRequest,
351 ) -> Result<TaskPushNotificationConfig, A2AError> {
352 let payload = self
353 .get_value(
354 &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
355 params,
356 &[],
357 )
358 .await?;
359 deserialize_task_push_notification_config(payload)
360 }
361
362 async fn list_push_configs(
363 &self,
364 params: &ServiceParams,
365 req: &ListTaskPushNotificationConfigsRequest,
366 ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
367 let mut query_parts = Vec::new();
368 if let Some(page_size) = req.page_size {
369 query_parts.push(("pageSize".to_string(), page_size.to_string()));
370 }
371 if let Some(ref page_token) = req.page_token {
372 query_parts.push(("pageToken".to_string(), page_token.clone()));
373 }
374
375 let payload = self
376 .get_value(
377 &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
378 params,
379 &query_parts,
380 )
381 .await?;
382 deserialize_list_task_push_notification_configs_response(payload)
383 }
384
385 async fn delete_push_config(
386 &self,
387 params: &ServiceParams,
388 req: &DeleteTaskPushNotificationConfigRequest,
389 ) -> Result<(), A2AError> {
390 self.delete(
391 &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
392 params,
393 )
394 .await
395 }
396
397 async fn get_extended_agent_card(
398 &self,
399 params: &ServiceParams,
400 _req: &GetExtendedAgentCardRequest,
401 ) -> Result<AgentCard, A2AError> {
402 self.get_json(REST_EXTENDED_AGENT_CARD_PATH, params, &[])
403 .await
404 }
405
406 async fn destroy(&self) -> Result<(), A2AError> {
407 Ok(())
408 }
409}
410
/// [`TransportFactory`] that builds [`RestTransport`]s for HTTP+JSON interfaces.
pub struct RestTransportFactory {
    /// Client cloned into every transport this factory creates.
    client: Client,
}
415
416impl RestTransportFactory {
417 pub fn new(client: Option<Client>) -> Self {
418 RestTransportFactory {
419 client: client.unwrap_or_default(),
420 }
421 }
422
423 #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
424 pub fn with_root_certificates_pem(pem: &[u8]) -> Result<Self, A2AError> {
425 Ok(Self {
426 client: crate::build_reqwest_client_with_root_pem(pem)?,
427 })
428 }
429}
430
431#[async_trait]
432impl TransportFactory for RestTransportFactory {
433 fn protocol(&self) -> &str {
434 TRANSPORT_PROTOCOL_HTTP_JSON
435 }
436
437 async fn create(
438 &self,
439 _card: &AgentCard,
440 iface: &AgentInterface,
441 ) -> Result<Box<dyn Transport>, A2AError> {
442 Ok(Box::new(RestTransport::new(
443 self.client.clone(),
444 iface.url.clone(),
445 )))
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    // Only needed by the TLS-gated tests below; gated so builds without a TLS
    // feature neither fail to resolve it nor warn about an unused import.
    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    use crate::test_utils::rcgen_self_signed_ca_pem;
    use serde_json::json;

    /// Minimal agent card for the factory `create` tests. The REST factory
    /// ignores the card, so only the required fields are populated.
    fn test_agent_card() -> AgentCard {
        AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        }
    }

    #[test]
    fn test_rest_transport_new_strips_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080/".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_strips_repeated_trailing_slashes() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080/rest//".into());
        assert_eq!(t.base_url, "http://localhost:8080/rest");
    }

    #[test]
    fn test_rest_transport_new_no_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_factory_protocol() {
        let f = RestTransportFactory::new(None);
        assert_eq!(f.protocol(), "HTTP+JSON");
    }

    #[tokio::test]
    async fn test_rest_transport_factory_create() {
        let f = RestTransportFactory::new(None);
        let card = test_agent_card();
        let iface = AgentInterface::new("http://localhost:8080/", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }

    #[test]
    fn test_build_request_adds_params() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        let mut params = ServiceParams::new();
        params.insert("X-Custom".into(), vec!["val1".into(), "val2".into()]);
        let builder = t.build_request(reqwest::Method::GET, "/test", &params);
        let req = builder.build().unwrap();
        let vals: Vec<_> = req
            .headers()
            .get_all("X-Custom")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(vals, vec!["val1", "val2"]);
    }

    #[test]
    fn test_build_request_with_query_only_appends_non_empty_query() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        let params = ServiceParams::new();

        let with_query = t
            .build_request_with_query(
                reqwest::Method::GET,
                "/tasks",
                &params,
                &[("pageSize".to_string(), "10".to_string())],
            )
            .build()
            .unwrap();
        assert_eq!(with_query.url().query(), Some("pageSize=10"));

        let without_query = t
            .build_request_with_query(reqwest::Method::GET, "/tasks", &params, &[])
            .build()
            .unwrap();
        assert_eq!(without_query.url().query(), None);
    }

    #[test]
    fn test_parse_rest_error_non_envelope_body_is_internal() {
        let err = parse_rest_error(reqwest::StatusCode::BAD_GATEWAY, "upstream unavailable");
        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        assert!(err.message.contains("502"));
        assert!(err.message.contains("upstream unavailable"));
    }

    #[test]
    fn test_parse_rest_error_preserves_a2a_error_code() {
        let body = json!({
            "error": {
                "code": 404,
                "status": "NOT_FOUND",
                "message": "task not found: t1",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "TASK_NOT_FOUND",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {
                            "taskId": "t1"
                        }
                    },
                    {
                        "resource": "task"
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        assert_eq!(err.code, error_code::TASK_NOT_FOUND);
        assert_eq!(err.message, "task not found: t1");
        let details = err.details.expect("expected structured details");
        assert_eq!(details.len(), 2);
        assert_eq!(details[0].type_url, errordetails::ERROR_INFO_TYPE);
        assert_eq!(
            details[1].value.get("resource"),
            Some(&Value::String("task".into()))
        );
    }

    #[test]
    fn test_parse_rest_error_accepts_go_reason_aliases() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "incompatible content types",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "UNSUPPORTED_CONTENT_TYPE",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::CONTENT_TYPE_NOT_SUPPORTED);
    }

    #[test]
    fn test_parse_rest_error_bad_request_fallback() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "invalid request parameters",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {
                                "field": "message.parts",
                                "description": "At least one part is required"
                            }
                        ]
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::INVALID_PARAMS);
        assert!(
            err.message
                .contains("message.parts: At least one part is required")
        );
        let details = err.details.expect("expected details");
        assert_eq!(details.len(), 1);
        assert_eq!(details[0].type_url, errordetails::BAD_REQUEST_TYPE);
        let violations = details[0].value.get("fieldViolations").unwrap();
        assert_eq!(violations[0]["field"], "message.parts");
    }

    #[test]
    fn test_parse_rest_error_bad_request_with_error_info_uses_reason() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "bad params",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {"field": "task.id", "description": "required"}
                        ]
                    },
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "INVALID_PARAMS",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::INVALID_PARAMS);
    }

    // `with_root_certificates_pem` only exists when a TLS feature is enabled,
    // so these tests must be gated the same way to keep the test build green.
    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[test]
    fn test_with_root_certificates_pem_valid() {
        let pem = rcgen_self_signed_ca_pem();
        let f = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        assert_eq!(f.protocol(), TRANSPORT_PROTOCOL_HTTP_JSON);
    }

    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[tokio::test]
    async fn test_with_root_certificates_pem_factory_create() {
        let pem = rcgen_self_signed_ca_pem();
        let f = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        let card = test_agent_card();
        let iface = AgentInterface::new("https://localhost:3443/rest", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }
}