a2a_protocol_client/transport/
grpc.rs1use std::collections::HashMap;
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::time::Duration;
37
38use tokio::sync::mpsc;
39use tonic::transport::Channel;
40
41use crate::error::{ClientError, ClientResult};
42use crate::streaming::EventStream;
43use crate::transport::Transport;
44
/// Protobuf message and gRPC service bindings for the `a2a.v1` package,
/// pulled in with `tonic::include_proto!` (the code is generated at build
/// time from the `.proto` definitions, not written by hand).
mod proto {
    // The generated code is not written to this crate's lint standards, so
    // these lints are relaxed for this module only.
    #![allow(
        clippy::all,
        clippy::pedantic,
        clippy::nursery,
        missing_docs,
        unused_qualifications
    )]
    tonic::include_proto!("a2a.v1");
}
56
57use proto::a2a_service_client::A2aServiceClient;
58use proto::JsonPayload;
59
/// Tunables for [`GrpcTransport`].
///
/// Every field is public, so a config can be written as a struct literal or
/// derived from [`GrpcTransportConfig::default`] via the `with_*` builders.
#[derive(Clone, Debug)]
pub struct GrpcTransportConfig {
    /// Upper bound for a unary call and for opening a server stream
    /// (default: 30 s). Also set as the channel's per-request timeout.
    pub timeout: Duration,
    /// Upper bound for establishing the TCP/HTTP2 connection (default: 10 s).
    pub connect_timeout: Duration,
    /// Maximum encoded and decoded gRPC message size in bytes (default: 4 MiB).
    pub max_message_size: usize,
    /// Number of stream chunks buffered between the background reader task
    /// and the consumer of the event stream (default: 64). Should be at least 1.
    pub stream_channel_capacity: usize,
}

impl Default for GrpcTransportConfig {
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        let timeout = Duration::from_secs(30);
        let connect_timeout = Duration::from_secs(10);
        Self {
            timeout,
            connect_timeout,
            max_message_size: 4 * MIB,
            stream_channel_capacity: 64,
        }
    }
}

impl GrpcTransportConfig {
    /// Returns a copy of this config with the request/stream-open timeout replaced.
    #[must_use]
    pub const fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Returns a copy of this config with the connect timeout replaced.
    #[must_use]
    pub const fn with_connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Returns a copy of this config with the maximum message size (bytes) replaced.
    #[must_use]
    pub const fn with_max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: size,
            ..self
        }
    }

    /// Returns a copy of this config with the stream buffer capacity replaced.
    #[must_use]
    pub const fn with_stream_channel_capacity(self, capacity: usize) -> Self {
        Self {
            stream_channel_capacity: capacity,
            ..self
        }
    }
}
126
127#[derive(Clone, Debug)]
135pub struct GrpcTransport {
136 inner: Arc<Inner>,
137}
138
139#[derive(Debug)]
140struct Inner {
141 channel: Channel,
145 endpoint: String,
146 config: GrpcTransportConfig,
147}
148
149impl GrpcTransport {
150 pub async fn connect(endpoint: impl Into<String>) -> ClientResult<Self> {
158 Self::connect_with_config(endpoint, GrpcTransportConfig::default()).await
159 }
160
161 pub async fn connect_with_config(
167 endpoint: impl Into<String>,
168 config: GrpcTransportConfig,
169 ) -> ClientResult<Self> {
170 let endpoint_str = endpoint.into();
171 validate_url(&endpoint_str)?;
172
173 let channel = tonic::transport::Channel::from_shared(endpoint_str.clone())
174 .map_err(|e| ClientError::InvalidEndpoint(format!("invalid gRPC endpoint: {e}")))?
175 .connect_timeout(config.connect_timeout)
176 .timeout(config.timeout)
177 .connect()
178 .await
179 .map_err(|e| ClientError::Transport(format!("gRPC connect failed: {e}")))?;
180
181 Ok(Self {
182 inner: Arc::new(Inner {
183 channel,
184 endpoint: endpoint_str,
185 config,
186 }),
187 })
188 }
189
190 #[must_use]
192 pub fn endpoint(&self) -> &str {
193 &self.inner.endpoint
194 }
195
196 fn encode_params(params: &serde_json::Value) -> ClientResult<JsonPayload> {
199 let data = serde_json::to_vec(params).map_err(ClientError::Serialization)?;
200 Ok(JsonPayload { data })
201 }
202
203 fn add_metadata(
204 req: &mut tonic::Request<JsonPayload>,
205 extra_headers: &HashMap<String, String>,
206 ) {
207 let md = req.metadata_mut();
208 md.insert(
209 "a2a-version",
210 a2a_protocol_types::A2A_VERSION
211 .parse()
212 .unwrap_or_else(|_| tonic::metadata::MetadataValue::from_static("")),
213 );
214 for (k, v) in extra_headers {
215 if let (Ok(key), Ok(val)) = (
216 k.parse::<tonic::metadata::MetadataKey<_>>(),
217 v.parse::<tonic::metadata::MetadataValue<_>>(),
218 ) {
219 md.insert(key, val);
220 }
221 }
222 }
223
224 fn decode_response(payload: &JsonPayload) -> ClientResult<serde_json::Value> {
225 serde_json::from_slice(&payload.data).map_err(ClientError::Serialization)
226 }
227
228 fn status_to_error(status: &tonic::Status) -> ClientError {
229 match status.code() {
232 tonic::Code::DeadlineExceeded => {
233 ClientError::Timeout(format!("gRPC deadline exceeded: {}", status.message()))
234 }
235 tonic::Code::Cancelled => {
236 ClientError::Timeout(format!("gRPC request cancelled: {}", status.message()))
237 }
238 tonic::Code::Unavailable => {
239 ClientError::HttpClient(format!("gRPC unavailable: {}", status.message()))
240 }
241 _ => {
242 let a2a = a2a_protocol_types::A2aError::new(
243 grpc_code_to_error_code(status.code()),
244 status.message().to_owned(),
245 );
246 ClientError::Protocol(a2a)
247 }
248 }
249 }
250
251 async fn execute_unary(
252 &self,
253 method: &str,
254 params: serde_json::Value,
255 extra_headers: &HashMap<String, String>,
256 ) -> ClientResult<serde_json::Value> {
257 trace_info!(
258 method,
259 endpoint = %self.inner.endpoint,
260 "sending gRPC request"
261 );
262
263 let payload = Self::encode_params(¶ms)?;
264 let mut req = tonic::Request::new(payload);
265 req.set_timeout(self.inner.config.timeout);
266 Self::add_metadata(&mut req, extra_headers);
267
268 let mut client = A2aServiceClient::new(self.inner.channel.clone())
272 .max_decoding_message_size(self.inner.config.max_message_size)
273 .max_encoding_message_size(self.inner.config.max_message_size);
274
275 let response = tokio::time::timeout(self.inner.config.timeout, async {
276 match method {
277 "SendMessage" => client.send_message(req).await,
278 "GetTask" => client.get_task(req).await,
279 "ListTasks" => client.list_tasks(req).await,
280 "CancelTask" => client.cancel_task(req).await,
281 "CreateTaskPushNotificationConfig" => {
282 client.create_task_push_notification_config(req).await
283 }
284 "GetTaskPushNotificationConfig" => {
285 client.get_task_push_notification_config(req).await
286 }
287 "ListTaskPushNotificationConfigs" => {
288 client.list_task_push_notification_configs(req).await
289 }
290 "DeleteTaskPushNotificationConfig" => {
291 client.delete_task_push_notification_config(req).await
292 }
293 "GetExtendedAgentCard" => client.get_extended_agent_card(req).await,
294 other => Err(tonic::Status::unimplemented(format!(
295 "unknown gRPC method: {other}"
296 ))),
297 }
298 })
299 .await
300 .map_err(|_| {
301 trace_error!(method, "gRPC request timed out");
302 ClientError::Timeout("gRPC request timed out".into())
303 })?;
304
305 match response {
306 Ok(resp) => Self::decode_response(&resp.into_inner()),
307 Err(status) => Err(Self::status_to_error(&status)),
308 }
309 }
310
311 async fn execute_streaming(
312 &self,
313 method: &str,
314 params: serde_json::Value,
315 extra_headers: &HashMap<String, String>,
316 ) -> ClientResult<EventStream> {
317 trace_info!(
318 method,
319 endpoint = %self.inner.endpoint,
320 "opening gRPC stream"
321 );
322
323 let payload = Self::encode_params(¶ms)?;
324 let mut req = tonic::Request::new(payload);
325 Self::add_metadata(&mut req, extra_headers);
326
327 let mut client = A2aServiceClient::new(self.inner.channel.clone())
329 .max_decoding_message_size(self.inner.config.max_message_size)
330 .max_encoding_message_size(self.inner.config.max_message_size);
331
332 let stream = tokio::time::timeout(self.inner.config.timeout, async {
333 let response = match method {
334 "SendStreamingMessage" => client.send_streaming_message(req).await,
335 "SubscribeToTask" => client.subscribe_to_task(req).await,
336 #[allow(clippy::needless_return)]
337 other => {
338 return Err(tonic::Status::unimplemented(format!(
339 "unknown streaming gRPC method: {other}"
340 )));
341 }
342 };
343 match response {
344 Ok(resp) => Ok(resp.into_inner()),
345 Err(status) => Err(status),
346 }
347 })
348 .await
349 .map_err(|_| {
350 trace_error!(method, "gRPC stream connect timed out");
351 ClientError::Timeout("gRPC stream connect timed out".into())
352 })?
353 .map_err(|status| Self::status_to_error(&status))?;
354
355 let cap = self.inner.config.stream_channel_capacity;
356 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(cap);
357
358 let task_handle = tokio::spawn(async move {
359 grpc_stream_reader_task(stream, tx).await;
360 });
361
362 Ok(EventStream::with_status(
365 rx,
366 task_handle.abort_handle(),
367 200,
368 ))
369 }
370}
371
372impl Transport for GrpcTransport {
373 fn send_request<'a>(
374 &'a self,
375 method: &'a str,
376 params: serde_json::Value,
377 extra_headers: &'a HashMap<String, String>,
378 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
379 Box::pin(self.execute_unary(method, params, extra_headers))
380 }
381
382 fn send_streaming_request<'a>(
383 &'a self,
384 method: &'a str,
385 params: serde_json::Value,
386 extra_headers: &'a HashMap<String, String>,
387 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
388 Box::pin(self.execute_streaming(method, params, extra_headers))
389 }
390}
391
392async fn grpc_stream_reader_task(
398 mut stream: tonic::Streaming<JsonPayload>,
399 tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
400) {
401 use tonic::codegen::tokio_stream::StreamExt;
402
403 loop {
404 match stream.next().await {
405 Some(Ok(payload)) => {
406 let json_str = match String::from_utf8(payload.data) {
410 Ok(s) => s,
411 Err(e) => {
412 let _ = tx
413 .send(Err(ClientError::Transport(format!(
414 "invalid UTF-8 in gRPC payload: {e}"
415 ))))
416 .await;
417 break;
418 }
419 };
420 let envelope =
422 format!("data: {{\"jsonrpc\":\"2.0\",\"id\":null,\"result\":{json_str}}}\n\n");
423 if tx
424 .send(Ok(hyper::body::Bytes::from(envelope)))
425 .await
426 .is_err()
427 {
428 break;
429 }
430 }
431 Some(Err(status)) => {
432 let a2a = a2a_protocol_types::A2aError::new(
436 grpc_code_to_error_code(status.code()),
437 status.message().to_owned(),
438 );
439 let _ = tx.send(Err(ClientError::Protocol(a2a))).await;
440 break;
441 }
442 None => break,
443 }
444 }
445}
446
447fn validate_url(url: &str) -> ClientResult<()> {
450 if url.is_empty() {
451 return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
452 }
453 if !url.starts_with("http://") && !url.starts_with("https://") {
454 return Err(ClientError::InvalidEndpoint(format!(
455 "URL must start with http:// or https://: {url}"
456 )));
457 }
458 Ok(())
459}
460
461const fn grpc_code_to_error_code(code: tonic::Code) -> a2a_protocol_types::ErrorCode {
462 match code {
463 tonic::Code::NotFound => a2a_protocol_types::ErrorCode::TaskNotFound,
464 tonic::Code::InvalidArgument
465 | tonic::Code::Unauthenticated
466 | tonic::Code::PermissionDenied
467 | tonic::Code::ResourceExhausted => a2a_protocol_types::ErrorCode::InvalidParams,
468 tonic::Code::Unimplemented => a2a_protocol_types::ErrorCode::MethodNotFound,
469 tonic::Code::FailedPrecondition => a2a_protocol_types::ErrorCode::TaskNotCancelable,
470 tonic::Code::DeadlineExceeded | tonic::Code::Cancelled => {
471 a2a_protocol_types::ErrorCode::InternalError
472 }
473 _ => a2a_protocol_types::ErrorCode::InternalError,
474 }
475}
476
#[cfg(test)]
mod tests {
    //! Unit tests for the pure pieces of the gRPC transport: URL validation,
    //! config builders, gRPC code / status mapping and metadata injection.

    use super::*;
    use a2a_protocol_types::ErrorCode;
    use tonic::Code;

    /// Asserts that `code` maps to `expected` via `grpc_code_to_error_code`.
    #[track_caller]
    fn assert_code_maps(code: Code, expected: ErrorCode) {
        assert_eq!(grpc_code_to_error_code(code), expected);
    }

    /// Builds an empty request and runs `GrpcTransport::add_metadata` on it.
    fn request_with(headers: &HashMap<String, String>) -> tonic::Request<JsonPayload> {
        let mut req = tonic::Request::new(JsonPayload { data: Vec::new() });
        GrpcTransport::add_metadata(&mut req, headers);
        req
    }

    // ---- validate_url -------------------------------------------------

    #[test]
    fn validate_url_rejects_empty() {
        let outcome = validate_url("");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_url_rejects_non_http() {
        let outcome = validate_url("ftp://example.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_url_accepts_http() {
        let outcome = validate_url("http://localhost:50051");
        assert!(outcome.is_ok());
    }

    // ---- GrpcTransportConfig ------------------------------------------

    #[test]
    fn config_default_timeout() {
        assert_eq!(
            GrpcTransportConfig::default().timeout,
            Duration::from_secs(30)
        );
    }

    #[test]
    fn config_builder() {
        const EIGHT_MIB: usize = 8 * 1024 * 1024;
        let cfg = GrpcTransportConfig::default()
            .with_timeout(Duration::from_secs(60))
            .with_max_message_size(EIGHT_MIB)
            .with_stream_channel_capacity(128);
        assert_eq!(
            (cfg.timeout, cfg.max_message_size, cfg.stream_channel_capacity),
            (Duration::from_secs(60), EIGHT_MIB, 128)
        );
    }

    // ---- grpc_code_to_error_code --------------------------------------

    #[test]
    fn grpc_code_not_found_maps_to_task_not_found() {
        assert_code_maps(Code::NotFound, ErrorCode::TaskNotFound);
    }

    #[test]
    fn grpc_code_invalid_argument_maps_to_invalid_params() {
        assert_code_maps(Code::InvalidArgument, ErrorCode::InvalidParams);
    }

    #[test]
    fn grpc_code_unauthenticated_maps_to_invalid_params() {
        assert_code_maps(Code::Unauthenticated, ErrorCode::InvalidParams);
    }

    #[test]
    fn grpc_code_permission_denied_maps_to_invalid_params() {
        assert_code_maps(Code::PermissionDenied, ErrorCode::InvalidParams);
    }

    #[test]
    fn grpc_code_resource_exhausted_maps_to_invalid_params() {
        assert_code_maps(Code::ResourceExhausted, ErrorCode::InvalidParams);
    }

    #[test]
    fn grpc_code_unimplemented_maps_to_method_not_found() {
        assert_code_maps(Code::Unimplemented, ErrorCode::MethodNotFound);
    }

    #[test]
    fn grpc_code_failed_precondition_maps_to_task_not_cancelable() {
        assert_code_maps(Code::FailedPrecondition, ErrorCode::TaskNotCancelable);
    }

    #[test]
    fn grpc_code_deadline_exceeded_maps_to_internal() {
        assert_code_maps(Code::DeadlineExceeded, ErrorCode::InternalError);
    }

    #[test]
    fn grpc_code_cancelled_maps_to_internal() {
        assert_code_maps(Code::Cancelled, ErrorCode::InternalError);
    }

    #[test]
    fn grpc_code_unknown_maps_to_internal() {
        assert_code_maps(Code::Unknown, ErrorCode::InternalError);
    }

    // ---- add_metadata -------------------------------------------------

    #[test]
    fn add_metadata_injects_a2a_version() {
        let req = request_with(&HashMap::new());
        let version_value = req
            .metadata()
            .get("a2a-version")
            .expect("a2a-version header should be present");
        assert_eq!(
            version_value.to_str().unwrap(),
            a2a_protocol_types::A2A_VERSION,
        );
    }

    #[test]
    fn add_metadata_injects_extra_headers() {
        let mut headers = HashMap::new();
        headers.insert("x-custom".to_owned(), "value123".to_owned());
        let req = request_with(&headers);
        let custom = req.metadata().get("x-custom").and_then(|v| v.to_str().ok());
        assert_eq!(custom, Some("value123"));
    }

    // ---- status_to_error ----------------------------------------------
    // Transport-level codes must surface as transport errors, not as
    // JSON-RPC protocol errors.

    #[test]
    fn status_to_error_deadline_exceeded_is_timeout() {
        let err = GrpcTransport::status_to_error(&tonic::Status::deadline_exceeded("test deadline"));
        assert!(
            matches!(err, ClientError::Timeout(_)),
            "DeadlineExceeded should map to Timeout, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_cancelled_is_timeout() {
        let err = GrpcTransport::status_to_error(&tonic::Status::cancelled("test cancel"));
        assert!(
            matches!(err, ClientError::Timeout(_)),
            "Cancelled should map to Timeout, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_unavailable_is_http_client() {
        let err = GrpcTransport::status_to_error(&tonic::Status::unavailable("test unavailable"));
        assert!(
            matches!(err, ClientError::HttpClient(_)),
            "Unavailable should map to HttpClient, got: {err:?}"
        );
    }

    #[test]
    fn status_to_error_other_is_protocol() {
        let err = GrpcTransport::status_to_error(&tonic::Status::internal("test internal"));
        assert!(
            matches!(err, ClientError::Protocol(_)),
            "other codes should map to Protocol, got: {err:?}"
        );
    }
}