// File: a2a_rs/adapter/transport/http/client.rs

1//! HTTP client adapter for the A2A protocol using ConnectRPC
2
3use async_trait::async_trait;
4use futures::stream::Stream;
5use reqwest::{
6    Client,
7    header::{HeaderMap, HeaderValue},
8};
9use std::{pin::Pin, sync::Arc, time::Duration};
10
11#[cfg(feature = "tracing")]
12use tracing::{debug, instrument};
13
14use crate::{
15    adapter::error::HttpClientError,
16    adapter::transport::codec::stream_response_to_item,
17    domain::{
18        A2AError, AgentCard, ListTasksParams, ListTasksResult, Message, Task,
19        TaskPushNotificationConfig,
20        generated::{
21            A2aServiceClient, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest,
22            GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest,
23            ListTaskPushNotificationConfigsRequest, ListTasksRequest, SendMessageConfiguration,
24            SendMessageRequest, SubscribeToTaskRequest, TaskState, send_message_response,
25        },
26    },
27    port::{StreamEvent, Transport},
28};
29
30fn map_connect_err(err: connectrpc::ConnectError) -> A2AError {
31    let code = match err.code {
32        connectrpc::ErrorCode::NotFound => crate::domain::error::TASK_NOT_FOUND,
33        connectrpc::ErrorCode::Unimplemented => crate::domain::error::METHOD_NOT_FOUND,
34        connectrpc::ErrorCode::InvalidArgument => crate::domain::error::INVALID_PARAMS,
35        connectrpc::ErrorCode::Internal => crate::domain::error::INTERNAL_ERROR,
36        connectrpc::ErrorCode::FailedPrecondition => {
37            crate::domain::error::AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED
38        }
39        _ => {
40            let code_val = err.code as i32;
41            if code_val != 0 {
42                code_val
43            } else {
44                crate::domain::error::INTERNAL_ERROR
45            }
46        }
47    };
48    A2AError::JsonRpc {
49        code,
50        message: err.message.clone().unwrap_or_default(),
51        data: None,
52    }
53}
54
55/// HTTP client for interacting with the A2A protocol via ConnectRPC
56pub struct HttpClient {
57    /// Base URL of the A2A API
58    base_url: String,
59    /// reqwest Client for standard GET operations like agent card
60    client: Client,
61    /// ConnectRPC Client
62    connect_client: A2aServiceClient<connectrpc::client::HttpClient>,
63    /// Authorization token, if any
64    auth_token: Option<String>,
65    /// Timeout in seconds
66    timeout: u64,
67}
68
69impl HttpClient {
70    /// Create a new HTTP client with the given base URL
71    pub fn new(base_url: String) -> Self {
72        let uri = base_url.parse::<http::Uri>().expect("Invalid base URL");
73        let is_https = uri.scheme_str() == Some("https");
74
75        let transport = if is_https {
76            let _ = rustls::crypto::ring::default_provider().install_default();
77            let mut root_store = rustls::RootCertStore::empty();
78            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
79            let tls_config = rustls::ClientConfig::builder()
80                .with_root_certificates(root_store)
81                .with_no_client_auth();
82            connectrpc::client::HttpClient::with_tls(Arc::new(tls_config))
83        } else {
84            connectrpc::client::HttpClient::plaintext()
85        };
86
87        let mut config = connectrpc::client::ClientConfig::new(uri);
88        config = config.default_timeout(Duration::from_secs(30));
89
90        let connect_client = A2aServiceClient::new(transport, config);
91
92        Self {
93            base_url,
94            client: Client::new(),
95            connect_client,
96            auth_token: None,
97            timeout: 30,
98        }
99    }
100
101    /// Create a new HTTP client with authentication
102    pub fn with_auth(base_url: String, auth_token: String) -> Self {
103        let uri = base_url.parse::<http::Uri>().expect("Invalid base URL");
104        let is_https = uri.scheme_str() == Some("https");
105
106        let transport = if is_https {
107            let _ = rustls::crypto::ring::default_provider().install_default();
108            let mut root_store = rustls::RootCertStore::empty();
109            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
110            let tls_config = rustls::ClientConfig::builder()
111                .with_root_certificates(root_store)
112                .with_no_client_auth();
113            connectrpc::client::HttpClient::with_tls(Arc::new(tls_config))
114        } else {
115            connectrpc::client::HttpClient::plaintext()
116        };
117
118        let mut config = connectrpc::client::ClientConfig::new(uri);
119        config = config
120            .default_timeout(Duration::from_secs(30))
121            .default_header("authorization", format!("Bearer {}", auth_token));
122
123        let connect_client = A2aServiceClient::new(transport, config);
124
125        Self {
126            base_url,
127            client: Client::new(),
128            connect_client,
129            auth_token: Some(auth_token),
130            timeout: 30,
131        }
132    }
133
134    /// Set the timeout for requests
135    pub fn with_timeout(mut self, timeout: u64) -> Self {
136        self.timeout = timeout;
137        *self.connect_client.config_mut() = self
138            .connect_client
139            .config()
140            .clone()
141            .default_timeout(Duration::from_secs(timeout));
142        self
143    }
144
145    /// Get the headers for a request (used for reqwest)
146    fn get_headers(&self) -> Result<HeaderMap, A2AError> {
147        let mut headers = HeaderMap::new();
148        headers.insert(
149            reqwest::header::CONTENT_TYPE,
150            HeaderValue::from_static("application/json"),
151        );
152
153        if let Some(token) = &self.auth_token {
154            let auth_value = HeaderValue::from_str(&format!("Bearer {}", token)).map_err(|e| {
155                A2AError::Internal(format!("Invalid auth token for HTTP header: {}", e))
156            })?;
157            headers.insert(reqwest::header::AUTHORIZATION, auth_value);
158        }
159
160        Ok(headers)
161    }
162
163    /// Get the base URL of the client
164    pub fn base_url(&self) -> &str {
165        &self.base_url
166    }
167
168    /// Fetch the agent card from the agent's `/agent-card` endpoint (plain HTTP GET)
169    pub async fn get_agent_card(&self) -> Result<AgentCard, A2AError> {
170        let url = if self.base_url.ends_with('/') {
171            format!("{}agent-card", self.base_url)
172        } else {
173            match reqwest::Url::parse(&self.base_url) {
174                Ok(parsed) => {
175                    if !parsed.path().ends_with('/') {
176                        match parsed.join("/agent-card") {
177                            Ok(resolved) => resolved.to_string(),
178                            Err(_) => format!("{}/agent-card", self.base_url),
179                        }
180                    } else {
181                        match parsed.join("agent-card") {
182                            Ok(resolved) => resolved.to_string(),
183                            Err(_) => format!("{}/agent-card", self.base_url),
184                        }
185                    }
186                }
187                Err(_) => format!("{}/agent-card", self.base_url),
188            }
189        };
190
191        #[cfg(feature = "tracing")]
192        debug!("Fetching agent card from URL: {}", url);
193
194        let response = self
195            .client
196            .get(&url)
197            .headers(self.get_headers()?)
198            .timeout(Duration::from_secs(self.timeout))
199            .send()
200            .await
201            .map_err(HttpClientError::Reqwest)?;
202
203        if response.status().is_success() {
204            let card: AgentCard = response.json().await.map_err(|e| {
205                A2AError::Internal(format!("Failed to parse agent card JSON: {}", e))
206            })?;
207            Ok(card)
208        } else {
209            let status = response.status();
210            let body = response.text().await.unwrap_or_default();
211            Err(HttpClientError::Response {
212                status: status.as_u16(),
213                message: body,
214            }
215            .into())
216        }
217    }
218
219    /// Fetch the extended agent card using ConnectRPC
220    pub async fn get_extended_agent_card(
221        &self,
222        tenant: Option<String>,
223    ) -> Result<AgentCard, A2AError> {
224        let request = GetExtendedAgentCardRequest {
225            tenant: tenant.unwrap_or_default(),
226            ..Default::default()
227        };
228        let response = self
229            .connect_client
230            .get_extended_agent_card(request)
231            .await
232            .map_err(map_connect_err)?;
233        Ok(response.into_owned())
234    }
235}
236
237#[async_trait]
238impl Transport for HttpClient {
239    fn protocol(&self) -> &str {
240        "CONNECTRPC"
241    }
242
243    #[cfg_attr(
244        feature = "tracing",
245        instrument(skip(self, message), fields(task_id, session_id, history_length))
246    )]
247    async fn send_task_message(
248        &self,
249        task_id: &str,
250        message: &Message,
251        session_id: Option<&str>,
252        history_length: Option<u32>,
253    ) -> Result<Task, A2AError> {
254        let mut msg = message.clone();
255        msg.task_id = task_id.to_string();
256        if let Some(sid) = session_id {
257            msg.context_id = sid.to_string();
258        }
259
260        let config = SendMessageConfiguration {
261            history_length: history_length.map(|l| l as i32),
262            ..Default::default()
263        };
264
265        let request = SendMessageRequest {
266            message: ::buffa::MessageField::some(msg),
267            configuration: ::buffa::MessageField::some(config),
268            ..Default::default()
269        };
270
271        let response = self
272            .connect_client
273            .send_message(request)
274            .await
275            .map_err(map_connect_err)?;
276        let owned_response = response.into_owned();
277
278        match owned_response.payload {
279            Some(send_message_response::Payload::Task(task)) => Ok(*task),
280            _ => Err(A2AError::Internal(
281                "Expected task in SendMessageResponse payload".to_string(),
282            )),
283        }
284    }
285
286    #[cfg_attr(
287        feature = "tracing",
288        instrument(skip(self), fields(task_id, history_length))
289    )]
290    async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
291        let request = GetTaskRequest {
292            id: task_id.to_string(),
293            history_length: history_length.map(|l| l as i32),
294            ..Default::default()
295        };
296        let response = self
297            .connect_client
298            .get_task(request)
299            .await
300            .map_err(map_connect_err)?;
301        Ok(response.into_owned())
302    }
303
304    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(task_id)))]
305    async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
306        let request = CancelTaskRequest {
307            id: task_id.to_string(),
308            ..Default::default()
309        };
310        let response = self
311            .connect_client
312            .cancel_task(request)
313            .await
314            .map_err(map_connect_err)?;
315        Ok(response.into_owned())
316    }
317
318    async fn set_task_push_notification(
319        &self,
320        config: &TaskPushNotificationConfig,
321    ) -> Result<TaskPushNotificationConfig, A2AError> {
322        let request = config.clone();
323        let response = self
324            .connect_client
325            .create_task_push_notification_config(request)
326            .await
327            .map_err(map_connect_err)?;
328        Ok(response.into_owned())
329    }
330
331    async fn get_task_push_notification(
332        &self,
333        task_id: &str,
334    ) -> Result<TaskPushNotificationConfig, A2AError> {
335        let request = ListTaskPushNotificationConfigsRequest {
336            task_id: task_id.to_string(),
337            ..Default::default()
338        };
339        let response = self
340            .connect_client
341            .list_task_push_notification_configs(request)
342            .await
343            .map_err(map_connect_err)?;
344        let configs = response.into_owned().configs;
345        if let Some(config) = configs.into_iter().next() {
346            Ok(config)
347        } else {
348            Err(A2AError::TaskNotFound(format!(
349                "No push notification config found for task {}",
350                task_id
351            )))
352        }
353    }
354
355    #[cfg_attr(feature = "tracing", instrument(skip(self, params)))]
356    async fn list_tasks(&self, params: &ListTasksParams) -> Result<ListTasksResult, A2AError> {
357        let mut request = ListTasksRequest {
358            context_id: params.context_id.clone().unwrap_or_default(),
359            status: ::buffa::EnumValue::from(
360                params.status.unwrap_or(TaskState::TASK_STATE_UNSPECIFIED),
361            ),
362            page_size: params.page_size,
363            page_token: params.page_token.clone().unwrap_or_default(),
364            history_length: params.history_length,
365            include_artifacts: params.include_artifacts,
366            ..Default::default()
367        };
368        if let Some(ref t_str) = params.status_timestamp_after {
369            if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(t_str) {
370                let utc_dt = dt.with_timezone(&chrono::Utc);
371                request.status_timestamp_after =
372                    ::buffa::MessageField::some(::buffa_types::google::protobuf::Timestamp {
373                        seconds: utc_dt.timestamp(),
374                        nanos: utc_dt.timestamp_subsec_nanos() as i32,
375                        ..Default::default()
376                    });
377            }
378        }
379
380        let response = self
381            .connect_client
382            .list_tasks(request)
383            .await
384            .map_err(map_connect_err)?;
385        let owned = response.into_owned();
386        Ok(ListTasksResult {
387            tasks: owned.tasks,
388            total_size: owned.total_size,
389            page_size: owned.page_size,
390            next_page_token: owned.next_page_token,
391        })
392    }
393
394    async fn list_push_notification_configs(
395        &self,
396        task_id: &str,
397    ) -> Result<Vec<TaskPushNotificationConfig>, A2AError> {
398        let request = ListTaskPushNotificationConfigsRequest {
399            task_id: task_id.to_string(),
400            ..Default::default()
401        };
402        let response = self
403            .connect_client
404            .list_task_push_notification_configs(request)
405            .await
406            .map_err(map_connect_err)?;
407        Ok(response.into_owned().configs)
408    }
409
410    async fn get_push_notification_config(
411        &self,
412        task_id: &str,
413        config_id: &str,
414    ) -> Result<TaskPushNotificationConfig, A2AError> {
415        let request = GetTaskPushNotificationConfigRequest {
416            task_id: task_id.to_string(),
417            id: config_id.to_string(),
418            ..Default::default()
419        };
420        let response = self
421            .connect_client
422            .get_task_push_notification_config(request)
423            .await
424            .map_err(map_connect_err)?;
425        Ok(response.into_owned())
426    }
427
428    async fn delete_push_notification_config(
429        &self,
430        task_id: &str,
431        config_id: &str,
432    ) -> Result<(), A2AError> {
433        let request = DeleteTaskPushNotificationConfigRequest {
434            task_id: task_id.to_string(),
435            id: config_id.to_string(),
436            ..Default::default()
437        };
438        self.connect_client
439            .delete_task_push_notification_config(request)
440            .await
441            .map_err(map_connect_err)?;
442        Ok(())
443    }
444
445    async fn subscribe_to_task(
446        &self,
447        task_id: &str,
448        _history_length: Option<u32>,
449        // ConnectRPC streaming has no SSE `Last-Event-ID`; resumption is not
450        // supported on this transport, so the hint is ignored.
451        _last_event_id: Option<&str>,
452    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, A2AError>> + Send>>, A2AError> {
453        let request = SubscribeToTaskRequest {
454            id: task_id.to_string(),
455            ..Default::default()
456        };
457        let stream = self
458            .connect_client
459            .subscribe_to_task(request)
460            .await
461            .map_err(map_connect_err)?;
462
463        let mapped = futures::stream::unfold(stream, |mut s| async move {
464            match s.message().await {
465                Ok(Some(view)) => {
466                    let resp = view.to_owned_message();
467                    if let Some(item) = stream_response_to_item(resp) {
468                        Some((Ok(StreamEvent::untagged(item)), s))
469                    } else {
470                        Some((
471                            Err(A2AError::Internal(
472                                "Empty or unhandled stream response payload".to_string(),
473                            )),
474                            s,
475                        ))
476                    }
477                }
478                Ok(None) => None,
479                Err(e) => Some((Err(map_connect_err(e)), s)),
480            }
481        });
482
483        Ok(Box::pin(mapped))
484    }
485}