1use std::collections::HashMap;
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::time::Duration;
37
38use tokio::sync::mpsc;
39use tonic::transport::Channel;
40
41use crate::error::{ClientError, ClientResult};
42use crate::streaming::EventStream;
43use crate::transport::Transport;
44
/// gRPC bindings for the `a2a.v1` protobuf package, pulled in from the build
/// output via `tonic::include_proto!`.
mod proto {
    // Generated code is not held to the crate's lint settings.
    #![allow(
        clippy::all,
        clippy::pedantic,
        clippy::nursery,
        missing_docs,
        unused_qualifications
    )]
    tonic::include_proto!("a2a.v1");
}
56
57use proto::a2a_service_client::A2aServiceClient;
58use proto::JsonPayload;
59
/// Tuning knobs for the gRPC transport.
///
/// Start from [`GrpcTransportConfig::default`] and adjust individual values
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct GrpcTransportConfig {
    /// Request timeout. It is applied to the channel endpoint and to each unary
    /// call, and it bounds the time allowed to open a stream.
    pub timeout: Duration,
    /// Maximum time allowed to establish the underlying connection.
    pub connect_timeout: Duration,
    /// Maximum size in bytes of a single encoded or decoded gRPC message.
    pub max_message_size: usize,
    /// Number of chunks buffered between the stream reader task and the
    /// returned event stream.
    pub stream_channel_capacity: usize,
}

impl Default for GrpcTransportConfig {
    /// Uses a 30 s request timeout, a 10 s connect timeout, a 4 MiB message
    /// limit and a 64-chunk stream buffer.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            max_message_size: 4 * MIB,
            stream_channel_capacity: 64,
        }
    }
}

impl GrpcTransportConfig {
    /// Returns this config with the request timeout replaced by `timeout`.
    #[must_use]
    pub const fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Returns this config with the connection-establishment timeout replaced.
    #[must_use]
    pub const fn with_connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Returns this config with the per-message size limit (bytes) replaced.
    #[must_use]
    pub const fn with_max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: size,
            ..self
        }
    }

    /// Returns this config with the stream buffer capacity replaced.
    #[must_use]
    pub const fn with_stream_channel_capacity(self, capacity: usize) -> Self {
        Self {
            stream_channel_capacity: capacity,
            ..self
        }
    }
}
126
/// gRPC implementation of [`Transport`] for the A2A `a2a.v1` service.
///
/// Cloning is cheap: every clone shares the same channel, endpoint and
/// config behind a single `Arc`.
#[derive(Clone, Debug)]
pub struct GrpcTransport {
    inner: Arc<Inner>,
}
138
/// State shared by all clones of a `GrpcTransport`.
#[derive(Debug)]
struct Inner {
    /// tonic channel; each call clones it into a fresh `A2aServiceClient`.
    channel: Channel,
    /// Endpoint URL exactly as supplied by the caller; returned by `endpoint()`.
    endpoint: String,
    /// Timeouts, message-size limit and stream buffer size used by every call.
    config: GrpcTransportConfig,
}
148
149impl GrpcTransport {
150 pub async fn connect(endpoint: impl Into<String>) -> ClientResult<Self> {
158 Self::connect_with_config(endpoint, GrpcTransportConfig::default()).await
159 }
160
161 pub async fn connect_with_config(
167 endpoint: impl Into<String>,
168 config: GrpcTransportConfig,
169 ) -> ClientResult<Self> {
170 let endpoint_str = endpoint.into();
171 validate_url(&endpoint_str)?;
172
173 let channel = tonic::transport::Channel::from_shared(endpoint_str.clone())
174 .map_err(|e| ClientError::InvalidEndpoint(format!("invalid gRPC endpoint: {e}")))?
175 .connect_timeout(config.connect_timeout)
176 .timeout(config.timeout)
177 .connect()
178 .await
179 .map_err(|e| ClientError::Transport(format!("gRPC connect failed: {e}")))?;
180
181 Ok(Self {
182 inner: Arc::new(Inner {
183 channel,
184 endpoint: endpoint_str,
185 config,
186 }),
187 })
188 }
189
190 #[must_use]
192 pub fn endpoint(&self) -> &str {
193 &self.inner.endpoint
194 }
195
196 fn encode_params(params: &serde_json::Value) -> ClientResult<JsonPayload> {
199 let data = serde_json::to_vec(params).map_err(ClientError::Serialization)?;
200 Ok(JsonPayload { data })
201 }
202
203 fn add_metadata(
204 req: &mut tonic::Request<JsonPayload>,
205 extra_headers: &HashMap<String, String>,
206 ) {
207 let md = req.metadata_mut();
208 md.insert(
209 "a2a-version",
210 a2a_protocol_types::A2A_VERSION
211 .parse()
212 .unwrap_or_else(|_| tonic::metadata::MetadataValue::from_static("")),
213 );
214 for (k, v) in extra_headers {
215 if let (Ok(key), Ok(val)) = (
216 k.parse::<tonic::metadata::MetadataKey<_>>(),
217 v.parse::<tonic::metadata::MetadataValue<_>>(),
218 ) {
219 md.insert(key, val);
220 }
221 }
222 }
223
224 fn decode_response(payload: &JsonPayload) -> ClientResult<serde_json::Value> {
225 serde_json::from_slice(&payload.data).map_err(ClientError::Serialization)
226 }
227
228 fn status_to_error(status: &tonic::Status) -> ClientError {
229 match status.code() {
232 tonic::Code::DeadlineExceeded => {
233 ClientError::Timeout(format!("gRPC deadline exceeded: {}", status.message()))
234 }
235 tonic::Code::Cancelled => {
236 ClientError::Timeout(format!("gRPC request cancelled: {}", status.message()))
237 }
238 tonic::Code::Unavailable => {
239 ClientError::HttpClient(format!("gRPC unavailable: {}", status.message()))
240 }
241 _ => {
242 let a2a = a2a_protocol_types::A2aError::new(
243 grpc_code_to_error_code(status.code()),
244 status.message().to_owned(),
245 );
246 ClientError::Protocol(a2a)
247 }
248 }
249 }
250
251 async fn execute_unary(
252 &self,
253 method: &str,
254 params: serde_json::Value,
255 extra_headers: &HashMap<String, String>,
256 ) -> ClientResult<serde_json::Value> {
257 trace_info!(
258 method,
259 endpoint = %self.inner.endpoint,
260 "sending gRPC request"
261 );
262
263 let payload = Self::encode_params(¶ms)?;
264 let mut req = tonic::Request::new(payload);
265 req.set_timeout(self.inner.config.timeout);
266 Self::add_metadata(&mut req, extra_headers);
267
268 let mut client = A2aServiceClient::new(self.inner.channel.clone())
272 .max_decoding_message_size(self.inner.config.max_message_size)
273 .max_encoding_message_size(self.inner.config.max_message_size);
274
275 let response = tokio::time::timeout(self.inner.config.timeout, async {
276 match method {
277 "SendMessage" => client.send_message(req).await,
278 "GetTask" => client.get_task(req).await,
279 "ListTasks" => client.list_tasks(req).await,
280 "CancelTask" => client.cancel_task(req).await,
281 "CreateTaskPushNotificationConfig" => {
282 client.create_task_push_notification_config(req).await
283 }
284 "GetTaskPushNotificationConfig" => {
285 client.get_task_push_notification_config(req).await
286 }
287 "ListTaskPushNotificationConfigs" => {
288 client.list_task_push_notification_configs(req).await
289 }
290 "DeleteTaskPushNotificationConfig" => {
291 client.delete_task_push_notification_config(req).await
292 }
293 "GetExtendedAgentCard" => client.get_extended_agent_card(req).await,
294 other => Err(tonic::Status::unimplemented(format!(
295 "unknown gRPC method: {other}"
296 ))),
297 }
298 })
299 .await
300 .map_err(|_| {
301 trace_error!(method, "gRPC request timed out");
302 ClientError::Timeout("gRPC request timed out".into())
303 })?;
304
305 match response {
306 Ok(resp) => Self::decode_response(&resp.into_inner()),
307 Err(status) => Err(Self::status_to_error(&status)),
308 }
309 }
310
311 async fn execute_streaming(
312 &self,
313 method: &str,
314 params: serde_json::Value,
315 extra_headers: &HashMap<String, String>,
316 ) -> ClientResult<EventStream> {
317 trace_info!(
318 method,
319 endpoint = %self.inner.endpoint,
320 "opening gRPC stream"
321 );
322
323 let payload = Self::encode_params(¶ms)?;
324 let mut req = tonic::Request::new(payload);
325 Self::add_metadata(&mut req, extra_headers);
326
327 let mut client = A2aServiceClient::new(self.inner.channel.clone())
329 .max_decoding_message_size(self.inner.config.max_message_size)
330 .max_encoding_message_size(self.inner.config.max_message_size);
331
332 let stream = tokio::time::timeout(self.inner.config.timeout, async {
333 let response = match method {
334 "SendStreamingMessage" => client.send_streaming_message(req).await,
335 "SubscribeToTask" => client.subscribe_to_task(req).await,
336 #[allow(clippy::needless_return)]
337 other => {
338 return Err(tonic::Status::unimplemented(format!(
339 "unknown streaming gRPC method: {other}"
340 )));
341 }
342 };
343 match response {
344 Ok(resp) => Ok(resp.into_inner()),
345 Err(status) => Err(status),
346 }
347 })
348 .await
349 .map_err(|_| {
350 trace_error!(method, "gRPC stream connect timed out");
351 ClientError::Timeout("gRPC stream connect timed out".into())
352 })?
353 .map_err(|status| Self::status_to_error(&status))?;
354
355 let cap = self.inner.config.stream_channel_capacity;
356 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(cap);
357
358 let task_handle = tokio::spawn(async move {
359 grpc_stream_reader_task(stream, tx).await;
360 });
361
362 Ok(EventStream::with_status(
365 rx,
366 task_handle.abort_handle(),
367 200,
368 ))
369 }
370}
371
372impl Transport for GrpcTransport {
373 fn send_request<'a>(
374 &'a self,
375 method: &'a str,
376 params: serde_json::Value,
377 extra_headers: &'a HashMap<String, String>,
378 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
379 Box::pin(self.execute_unary(method, params, extra_headers))
380 }
381
382 fn send_streaming_request<'a>(
383 &'a self,
384 method: &'a str,
385 params: serde_json::Value,
386 extra_headers: &'a HashMap<String, String>,
387 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
388 Box::pin(self.execute_streaming(method, params, extra_headers))
389 }
390}
391
392async fn grpc_stream_reader_task<S>(
401 mut stream: S,
402 tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
403) where
404 S: tonic::codegen::tokio_stream::Stream<Item = Result<JsonPayload, tonic::Status>> + Unpin,
405{
406 use tonic::codegen::tokio_stream::StreamExt;
407
408 loop {
409 match stream.next().await {
410 Some(Ok(payload)) => {
411 let json_str = match String::from_utf8(payload.data) {
415 Ok(s) => s,
416 Err(e) => {
417 let _ = tx
418 .send(Err(ClientError::Transport(format!(
419 "invalid UTF-8 in gRPC payload: {e}"
420 ))))
421 .await;
422 break;
423 }
424 };
425 let envelope =
427 format!("data: {{\"jsonrpc\":\"2.0\",\"id\":null,\"result\":{json_str}}}\n\n");
428 if tx
429 .send(Ok(hyper::body::Bytes::from(envelope)))
430 .await
431 .is_err()
432 {
433 break;
434 }
435 }
436 Some(Err(status)) => {
437 let a2a = a2a_protocol_types::A2aError::new(
441 grpc_code_to_error_code(status.code()),
442 status.message().to_owned(),
443 );
444 let _ = tx.send(Err(ClientError::Protocol(a2a))).await;
445 break;
446 }
447 None => break,
448 }
449 }
450}
451
452fn validate_url(url: &str) -> ClientResult<()> {
455 if url.is_empty() {
456 return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
457 }
458 if !url.starts_with("http://") && !url.starts_with("https://") {
459 return Err(ClientError::InvalidEndpoint(format!(
460 "URL must start with http:// or https://: {url}"
461 )));
462 }
463 Ok(())
464}
465
466const fn grpc_code_to_error_code(code: tonic::Code) -> a2a_protocol_types::ErrorCode {
467 match code {
471 tonic::Code::NotFound => a2a_protocol_types::ErrorCode::TaskNotFound,
472 tonic::Code::InvalidArgument
473 | tonic::Code::Unauthenticated
474 | tonic::Code::PermissionDenied
475 | tonic::Code::ResourceExhausted => a2a_protocol_types::ErrorCode::InvalidParams,
476 tonic::Code::Unimplemented => a2a_protocol_types::ErrorCode::MethodNotFound,
477 tonic::Code::FailedPrecondition => a2a_protocol_types::ErrorCode::TaskNotCancelable,
478 _ => a2a_protocol_types::ErrorCode::InternalError,
479 }
480}
481
#[cfg(test)]
mod tests {
    // Unit tests for URL validation, config builders, status/code mapping,
    // metadata injection and the stream reader task. None of them need a
    // live gRPC server.
    use super::*;

    // ---- validate_url -------------------------------------------------

    #[test]
    fn validate_url_rejects_empty() {
        assert!(validate_url("").is_err());
    }

    #[test]
    fn validate_url_rejects_non_http() {
        assert!(validate_url("ftp://example.com").is_err());
    }

    #[test]
    fn validate_url_accepts_http() {
        assert!(validate_url("http://localhost:50051").is_ok());
    }

    #[test]
    fn validate_url_accepts_https() {
        assert!(validate_url("https://example.com:443").is_ok());
    }

    #[test]
    fn validate_url_errors_are_invalid_endpoint() {
        assert!(matches!(
            validate_url(""),
            Err(ClientError::InvalidEndpoint(_))
        ));
        // A bare host:port without a scheme is rejected.
        assert!(matches!(
            validate_url("localhost:50051"),
            Err(ClientError::InvalidEndpoint(_))
        ));
    }

    // ---- GrpcTransportConfig -----------------------------------------

    #[test]
    fn config_default_timeout() {
        let cfg = GrpcTransportConfig::default();
        assert_eq!(cfg.timeout, Duration::from_secs(30));
    }

    #[test]
    fn config_default_remaining_values() {
        let cfg = GrpcTransportConfig::default();
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.max_message_size, 4 * 1024 * 1024);
        assert_eq!(cfg.stream_channel_capacity, 64);
    }

    #[test]
    fn config_builder() {
        let cfg = GrpcTransportConfig::default()
            .with_timeout(Duration::from_secs(60))
            .with_max_message_size(8 * 1024 * 1024)
            .with_stream_channel_capacity(128);
        assert_eq!(cfg.timeout, Duration::from_secs(60));
        assert_eq!(cfg.max_message_size, 8 * 1024 * 1024);
        assert_eq!(cfg.stream_channel_capacity, 128);
    }

    #[test]
    fn config_with_connect_timeout_leaves_other_fields() {
        let cfg = GrpcTransportConfig::default().with_connect_timeout(Duration::from_secs(3));
        assert_eq!(cfg.connect_timeout, Duration::from_secs(3));
        assert_eq!(cfg.timeout, Duration::from_secs(30));
    }

    // ---- grpc_code_to_error_code -------------------------------------

    #[test]
    fn grpc_code_not_found_maps_to_task_not_found() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::NotFound),
            a2a_protocol_types::ErrorCode::TaskNotFound,
        );
    }

    #[test]
    fn grpc_code_invalid_argument_maps_to_invalid_params() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::InvalidArgument),
            a2a_protocol_types::ErrorCode::InvalidParams,
        );
    }

    #[test]
    fn grpc_code_unauthenticated_maps_to_invalid_params() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::Unauthenticated),
            a2a_protocol_types::ErrorCode::InvalidParams,
        );
    }

    #[test]
    fn grpc_code_permission_denied_maps_to_invalid_params() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::PermissionDenied),
            a2a_protocol_types::ErrorCode::InvalidParams,
        );
    }

    #[test]
    fn grpc_code_resource_exhausted_maps_to_invalid_params() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::ResourceExhausted),
            a2a_protocol_types::ErrorCode::InvalidParams,
        );
    }

    #[test]
    fn grpc_code_unimplemented_maps_to_method_not_found() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::Unimplemented),
            a2a_protocol_types::ErrorCode::MethodNotFound,
        );
    }

    #[test]
    fn grpc_code_failed_precondition_maps_to_task_not_cancelable() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::FailedPrecondition),
            a2a_protocol_types::ErrorCode::TaskNotCancelable,
        );
    }

    #[test]
    fn grpc_code_deadline_exceeded_maps_to_internal() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::DeadlineExceeded),
            a2a_protocol_types::ErrorCode::InternalError,
        );
    }

    #[test]
    fn grpc_code_cancelled_maps_to_internal() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::Cancelled),
            a2a_protocol_types::ErrorCode::InternalError,
        );
    }

    #[test]
    fn grpc_code_unknown_maps_to_internal() {
        assert_eq!(
            grpc_code_to_error_code(tonic::Code::Unknown),
            a2a_protocol_types::ErrorCode::InternalError,
        );
    }

    // ---- add_metadata -------------------------------------------------

    #[test]
    fn add_metadata_injects_a2a_version() {
        let payload = JsonPayload { data: vec![] };
        let mut req = tonic::Request::new(payload);
        let headers = HashMap::new();
        GrpcTransport::add_metadata(&mut req, &headers);
        let md = req.metadata();
        let version_value = md
            .get("a2a-version")
            .expect("a2a-version header should be present");
        assert_eq!(
            version_value.to_str().unwrap(),
            a2a_protocol_types::A2A_VERSION,
        );
    }

    #[test]
    fn add_metadata_injects_extra_headers() {
        let payload = JsonPayload { data: vec![] };
        let mut req = tonic::Request::new(payload);
        let mut headers = HashMap::new();
        headers.insert("x-custom".to_string(), "value123".to_string());
        GrpcTransport::add_metadata(&mut req, &headers);
        let md = req.metadata();
        assert_eq!(md.get("x-custom").unwrap().to_str().unwrap(), "value123");
    }

    #[test]
    fn add_metadata_skips_invalid_extra_headers() {
        let mut req = tonic::Request::new(JsonPayload { data: vec![] });
        let mut headers = HashMap::new();
        // A space is not allowed in a header name.
        headers.insert("bad key".to_string(), "ok".to_string());
        // A control character is not allowed in a header value.
        headers.insert("x-bad-value".to_string(), "line\nbreak".to_string());
        headers.insert("x-good".to_string(), "fine".to_string());
        GrpcTransport::add_metadata(&mut req, &headers);
        let md = req.metadata();
        assert_eq!(md.get("x-good").unwrap().to_str().unwrap(), "fine");
        assert!(md.get("x-bad-value").is_none());
        // Only `a2a-version` and `x-good` were inserted.
        assert_eq!(md.len(), 2);
    }

    // ---- status_to_error ----------------------------------------------

    #[test]
    fn status_to_error_deadline_exceeded_is_timeout() {
        let status = tonic::Status::deadline_exceeded("test deadline");
        let err = GrpcTransport::status_to_error(&status);
        assert!(
            matches!(err, ClientError::Timeout(_)),
            "DeadlineExceeded should map to Timeout, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_cancelled_is_timeout() {
        let status = tonic::Status::cancelled("test cancel");
        let err = GrpcTransport::status_to_error(&status);
        assert!(
            matches!(err, ClientError::Timeout(_)),
            "Cancelled should map to Timeout, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_unavailable_is_http_client() {
        let status = tonic::Status::unavailable("test unavailable");
        let err = GrpcTransport::status_to_error(&status);
        assert!(
            matches!(err, ClientError::HttpClient(_)),
            "Unavailable should map to HttpClient, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_other_is_protocol() {
        let status = tonic::Status::internal("test internal");
        let err = GrpcTransport::status_to_error(&status);
        assert!(
            matches!(err, ClientError::Protocol(_)),
            "other codes should map to Protocol, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_unimplemented_is_method_not_found() {
        let status = tonic::Status::unimplemented("nope");
        match GrpcTransport::status_to_error(&status) {
            ClientError::Protocol(a2a) => {
                assert_eq!(a2a.code, a2a_protocol_types::ErrorCode::MethodNotFound);
                assert!(a2a.message.contains("nope"));
            }
            other => panic!("expected Protocol(MethodNotFound), got {other:?}"),
        }
    }

    // ---- grpc_stream_reader_task --------------------------------------

    #[tokio::test]
    async fn grpc_stream_reader_task_forwards_payload_as_sse() {
        let payloads = vec![Ok(JsonPayload {
            data: br#"{"status":{"state":"working"}}"#.to_vec(),
        })];
        let stream = tonic::codegen::tokio_stream::iter(payloads);
        let (tx, mut rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(8);

        grpc_stream_reader_task(stream, tx).await;

        let first = rx.recv().await.expect("expected one chunk");
        let bytes = first.expect("expected Ok chunk");
        let text = std::str::from_utf8(&bytes).expect("utf8");
        assert!(
            text.starts_with("data: "),
            "chunk must be SSE-framed: {text}"
        );
        assert!(
            text.contains("\"jsonrpc\":\"2.0\""),
            "chunk must be JSON-RPC envelope: {text}"
        );
        assert!(
            text.contains("\"working\""),
            "chunk must include inner payload: {text}"
        );
        // The sender was dropped when the task returned, so the channel is closed.
        assert!(rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn grpc_stream_reader_task_forwards_multiple_payloads() {
        let payloads = vec![
            Ok(JsonPayload {
                data: br#"{"n":1}"#.to_vec(),
            }),
            Ok(JsonPayload {
                data: br#"{"n":2}"#.to_vec(),
            }),
            Ok(JsonPayload {
                data: br#"{"n":3}"#.to_vec(),
            }),
        ];
        let stream = tonic::codegen::tokio_stream::iter(payloads);
        let (tx, mut rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(8);

        grpc_stream_reader_task(stream, tx).await;

        let mut received = 0;
        while let Some(item) = rx.recv().await {
            assert!(item.is_ok());
            received += 1;
        }
        assert_eq!(received, 3, "all three payloads must be forwarded");
    }

    #[tokio::test]
    async fn grpc_stream_reader_task_empty_stream_sends_nothing() {
        let payloads: Vec<Result<JsonPayload, tonic::Status>> = vec![];
        let stream = tonic::codegen::tokio_stream::iter(payloads);
        let (tx, mut rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(8);

        grpc_stream_reader_task(stream, tx).await;

        assert!(rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn grpc_stream_reader_task_maps_status_error_to_protocol_error() {
        let payloads: Vec<Result<JsonPayload, tonic::Status>> =
            vec![Err(tonic::Status::not_found("missing"))];
        let stream = tonic::codegen::tokio_stream::iter(payloads);
        let (tx, mut rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(8);

        grpc_stream_reader_task(stream, tx).await;

        let chunk = rx.recv().await.expect("expected an error chunk");
        match chunk {
            Err(ClientError::Protocol(a2a)) => {
                assert_eq!(a2a.code, a2a_protocol_types::ErrorCode::TaskNotFound);
                assert!(a2a.message.contains("missing"));
            }
            other => panic!("expected Protocol(TaskNotFound), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn grpc_stream_reader_task_stops_after_first_error() {
        let payloads: Vec<Result<JsonPayload, tonic::Status>> = vec![
            Err(tonic::Status::not_found("missing")),
            Ok(JsonPayload {
                data: br#"{"n":1}"#.to_vec(),
            }),
        ];
        let stream = tonic::codegen::tokio_stream::iter(payloads);
        let (tx, mut rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(8);

        grpc_stream_reader_task(stream, tx).await;

        assert!(rx.recv().await.expect("expected an error chunk").is_err());
        // The payload after the error must not be forwarded.
        assert!(rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn grpc_stream_reader_task_handles_invalid_utf8() {
        let payloads: Vec<Result<JsonPayload, tonic::Status>> = vec![Ok(JsonPayload {
            // Not a valid UTF-8 sequence.
            data: vec![0xff, 0xfe, 0xfd],
        })];
        let stream = tonic::codegen::tokio_stream::iter(payloads);
        let (tx, mut rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(8);

        grpc_stream_reader_task(stream, tx).await;

        let chunk = rx.recv().await.expect("expected an error chunk");
        match chunk {
            Err(ClientError::Transport(msg)) => {
                assert!(msg.contains("UTF-8"), "msg should mention UTF-8: {msg}");
            }
            other => panic!("expected Transport error, got {other:?}"),
        }
    }

    // ---- endpoint() ---------------------------------------------------

    #[tokio::test]
    async fn grpc_transport_endpoint_returns_input_url() {
        let endpoint_str = "http://localhost:50055".to_string();
        // `connect_lazy` avoids needing a server; no request is ever sent.
        let channel = tonic::transport::Channel::from_shared(endpoint_str.clone())
            .expect("valid endpoint")
            .connect_lazy();
        let transport = GrpcTransport {
            inner: Arc::new(Inner {
                channel,
                endpoint: endpoint_str.clone(),
                config: GrpcTransportConfig::default(),
            }),
        };
        assert_eq!(transport.endpoint(), endpoint_str);
    }

    #[tokio::test]
    async fn grpc_transport_endpoint_preserves_distinct_urls() {
        let a = "http://example.com:1234".to_string();
        let b = "https://other.test:9000".to_string();
        let mk = |s: String| {
            let ch = tonic::transport::Channel::from_shared(s.clone())
                .unwrap()
                .connect_lazy();
            GrpcTransport {
                inner: Arc::new(Inner {
                    channel: ch,
                    endpoint: s,
                    config: GrpcTransportConfig::default(),
                }),
            }
        };
        let ta = mk(a.clone());
        let tb = mk(b.clone());
        assert_eq!(ta.endpoint(), a);
        assert_eq!(tb.endpoint(), b);
        assert_ne!(ta.endpoint(), tb.endpoint());
    }
}