1use async_trait::async_trait;
4use futures::stream::Stream;
5use reqwest::{
6 Client,
7 header::{HeaderMap, HeaderValue},
8};
9use std::{pin::Pin, sync::Arc, time::Duration};
10
11#[cfg(feature = "tracing")]
12use tracing::{debug, instrument};
13
14use crate::{
15 adapter::error::HttpClientError,
16 adapter::transport::codec::stream_response_to_item,
17 domain::{
18 A2AError, AgentCard, ListTasksParams, ListTasksResult, Message, Task,
19 TaskPushNotificationConfig,
20 generated::{
21 A2aServiceClient, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest,
22 GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest,
23 ListTaskPushNotificationConfigsRequest, ListTasksRequest, SendMessageConfiguration,
24 SendMessageRequest, SubscribeToTaskRequest, TaskState, send_message_response,
25 },
26 },
27 port::{StreamEvent, Transport},
28};
29
30fn map_connect_err(err: connectrpc::ConnectError) -> A2AError {
31 let code = match err.code {
32 connectrpc::ErrorCode::NotFound => crate::domain::error::TASK_NOT_FOUND,
33 connectrpc::ErrorCode::Unimplemented => crate::domain::error::METHOD_NOT_FOUND,
34 connectrpc::ErrorCode::InvalidArgument => crate::domain::error::INVALID_PARAMS,
35 connectrpc::ErrorCode::Internal => crate::domain::error::INTERNAL_ERROR,
36 connectrpc::ErrorCode::FailedPrecondition => {
37 crate::domain::error::AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED
38 }
39 _ => {
40 let code_val = err.code as i32;
41 if code_val != 0 {
42 code_val
43 } else {
44 crate::domain::error::INTERNAL_ERROR
45 }
46 }
47 };
48 A2AError::JsonRpc {
49 code,
50 message: err.message.clone().unwrap_or_default(),
51 data: None,
52 }
53}
54
55pub struct HttpClient {
57 base_url: String,
59 client: Client,
61 connect_client: A2aServiceClient<connectrpc::client::HttpClient>,
63 auth_token: Option<String>,
65 timeout: u64,
67}
68
69impl HttpClient {
70 pub fn new(base_url: String) -> Self {
72 let uri = base_url.parse::<http::Uri>().expect("Invalid base URL");
73 let is_https = uri.scheme_str() == Some("https");
74
75 let transport = if is_https {
76 let _ = rustls::crypto::ring::default_provider().install_default();
77 let mut root_store = rustls::RootCertStore::empty();
78 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
79 let tls_config = rustls::ClientConfig::builder()
80 .with_root_certificates(root_store)
81 .with_no_client_auth();
82 connectrpc::client::HttpClient::with_tls(Arc::new(tls_config))
83 } else {
84 connectrpc::client::HttpClient::plaintext()
85 };
86
87 let mut config = connectrpc::client::ClientConfig::new(uri);
88 config = config.default_timeout(Duration::from_secs(30));
89
90 let connect_client = A2aServiceClient::new(transport, config);
91
92 Self {
93 base_url,
94 client: Client::new(),
95 connect_client,
96 auth_token: None,
97 timeout: 30,
98 }
99 }
100
101 pub fn with_auth(base_url: String, auth_token: String) -> Self {
103 let uri = base_url.parse::<http::Uri>().expect("Invalid base URL");
104 let is_https = uri.scheme_str() == Some("https");
105
106 let transport = if is_https {
107 let _ = rustls::crypto::ring::default_provider().install_default();
108 let mut root_store = rustls::RootCertStore::empty();
109 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
110 let tls_config = rustls::ClientConfig::builder()
111 .with_root_certificates(root_store)
112 .with_no_client_auth();
113 connectrpc::client::HttpClient::with_tls(Arc::new(tls_config))
114 } else {
115 connectrpc::client::HttpClient::plaintext()
116 };
117
118 let mut config = connectrpc::client::ClientConfig::new(uri);
119 config = config
120 .default_timeout(Duration::from_secs(30))
121 .default_header("authorization", format!("Bearer {}", auth_token));
122
123 let connect_client = A2aServiceClient::new(transport, config);
124
125 Self {
126 base_url,
127 client: Client::new(),
128 connect_client,
129 auth_token: Some(auth_token),
130 timeout: 30,
131 }
132 }
133
134 pub fn with_timeout(mut self, timeout: u64) -> Self {
136 self.timeout = timeout;
137 *self.connect_client.config_mut() = self
138 .connect_client
139 .config()
140 .clone()
141 .default_timeout(Duration::from_secs(timeout));
142 self
143 }
144
145 fn get_headers(&self) -> Result<HeaderMap, A2AError> {
147 let mut headers = HeaderMap::new();
148 headers.insert(
149 reqwest::header::CONTENT_TYPE,
150 HeaderValue::from_static("application/json"),
151 );
152
153 if let Some(token) = &self.auth_token {
154 let auth_value = HeaderValue::from_str(&format!("Bearer {}", token)).map_err(|e| {
155 A2AError::Internal(format!("Invalid auth token for HTTP header: {}", e))
156 })?;
157 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
158 }
159
160 Ok(headers)
161 }
162
163 pub fn base_url(&self) -> &str {
165 &self.base_url
166 }
167
168 pub async fn get_agent_card(&self) -> Result<AgentCard, A2AError> {
170 let url = if self.base_url.ends_with('/') {
171 format!("{}agent-card", self.base_url)
172 } else {
173 match reqwest::Url::parse(&self.base_url) {
174 Ok(parsed) => {
175 if !parsed.path().ends_with('/') {
176 match parsed.join("/agent-card") {
177 Ok(resolved) => resolved.to_string(),
178 Err(_) => format!("{}/agent-card", self.base_url),
179 }
180 } else {
181 match parsed.join("agent-card") {
182 Ok(resolved) => resolved.to_string(),
183 Err(_) => format!("{}/agent-card", self.base_url),
184 }
185 }
186 }
187 Err(_) => format!("{}/agent-card", self.base_url),
188 }
189 };
190
191 #[cfg(feature = "tracing")]
192 debug!("Fetching agent card from URL: {}", url);
193
194 let response = self
195 .client
196 .get(&url)
197 .headers(self.get_headers()?)
198 .timeout(Duration::from_secs(self.timeout))
199 .send()
200 .await
201 .map_err(HttpClientError::Reqwest)?;
202
203 if response.status().is_success() {
204 let card: AgentCard = response.json().await.map_err(|e| {
205 A2AError::Internal(format!("Failed to parse agent card JSON: {}", e))
206 })?;
207 Ok(card)
208 } else {
209 let status = response.status();
210 let body = response.text().await.unwrap_or_default();
211 Err(HttpClientError::Response {
212 status: status.as_u16(),
213 message: body,
214 }
215 .into())
216 }
217 }
218
219 pub async fn get_extended_agent_card(
221 &self,
222 tenant: Option<String>,
223 ) -> Result<AgentCard, A2AError> {
224 let request = GetExtendedAgentCardRequest {
225 tenant: tenant.unwrap_or_default(),
226 ..Default::default()
227 };
228 let response = self
229 .connect_client
230 .get_extended_agent_card(request)
231 .await
232 .map_err(map_connect_err)?;
233 Ok(response.into_owned())
234 }
235}
236
237#[async_trait]
238impl Transport for HttpClient {
239 fn protocol(&self) -> &str {
240 "CONNECTRPC"
241 }
242
243 #[cfg_attr(
244 feature = "tracing",
245 instrument(skip(self, message), fields(task_id, session_id, history_length))
246 )]
247 async fn send_task_message(
248 &self,
249 task_id: &str,
250 message: &Message,
251 session_id: Option<&str>,
252 history_length: Option<u32>,
253 ) -> Result<Task, A2AError> {
254 let mut msg = message.clone();
255 msg.task_id = task_id.to_string();
256 if let Some(sid) = session_id {
257 msg.context_id = sid.to_string();
258 }
259
260 let config = SendMessageConfiguration {
261 history_length: history_length.map(|l| l as i32),
262 ..Default::default()
263 };
264
265 let request = SendMessageRequest {
266 message: ::buffa::MessageField::some(msg),
267 configuration: ::buffa::MessageField::some(config),
268 ..Default::default()
269 };
270
271 let response = self
272 .connect_client
273 .send_message(request)
274 .await
275 .map_err(map_connect_err)?;
276 let owned_response = response.into_owned();
277
278 match owned_response.payload {
279 Some(send_message_response::Payload::Task(task)) => Ok(*task),
280 _ => Err(A2AError::Internal(
281 "Expected task in SendMessageResponse payload".to_string(),
282 )),
283 }
284 }
285
286 #[cfg_attr(
287 feature = "tracing",
288 instrument(skip(self), fields(task_id, history_length))
289 )]
290 async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
291 let request = GetTaskRequest {
292 id: task_id.to_string(),
293 history_length: history_length.map(|l| l as i32),
294 ..Default::default()
295 };
296 let response = self
297 .connect_client
298 .get_task(request)
299 .await
300 .map_err(map_connect_err)?;
301 Ok(response.into_owned())
302 }
303
304 #[cfg_attr(feature = "tracing", instrument(skip(self), fields(task_id)))]
305 async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
306 let request = CancelTaskRequest {
307 id: task_id.to_string(),
308 ..Default::default()
309 };
310 let response = self
311 .connect_client
312 .cancel_task(request)
313 .await
314 .map_err(map_connect_err)?;
315 Ok(response.into_owned())
316 }
317
318 async fn set_task_push_notification(
319 &self,
320 config: &TaskPushNotificationConfig,
321 ) -> Result<TaskPushNotificationConfig, A2AError> {
322 let request = config.clone();
323 let response = self
324 .connect_client
325 .create_task_push_notification_config(request)
326 .await
327 .map_err(map_connect_err)?;
328 Ok(response.into_owned())
329 }
330
331 async fn get_task_push_notification(
332 &self,
333 task_id: &str,
334 ) -> Result<TaskPushNotificationConfig, A2AError> {
335 let request = ListTaskPushNotificationConfigsRequest {
336 task_id: task_id.to_string(),
337 ..Default::default()
338 };
339 let response = self
340 .connect_client
341 .list_task_push_notification_configs(request)
342 .await
343 .map_err(map_connect_err)?;
344 let configs = response.into_owned().configs;
345 if let Some(config) = configs.into_iter().next() {
346 Ok(config)
347 } else {
348 Err(A2AError::TaskNotFound(format!(
349 "No push notification config found for task {}",
350 task_id
351 )))
352 }
353 }
354
355 #[cfg_attr(feature = "tracing", instrument(skip(self, params)))]
356 async fn list_tasks(&self, params: &ListTasksParams) -> Result<ListTasksResult, A2AError> {
357 let mut request = ListTasksRequest {
358 context_id: params.context_id.clone().unwrap_or_default(),
359 status: ::buffa::EnumValue::from(
360 params.status.unwrap_or(TaskState::TASK_STATE_UNSPECIFIED),
361 ),
362 page_size: params.page_size,
363 page_token: params.page_token.clone().unwrap_or_default(),
364 history_length: params.history_length,
365 include_artifacts: params.include_artifacts,
366 ..Default::default()
367 };
368 if let Some(ref t_str) = params.status_timestamp_after {
369 if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(t_str) {
370 let utc_dt = dt.with_timezone(&chrono::Utc);
371 request.status_timestamp_after =
372 ::buffa::MessageField::some(::buffa_types::google::protobuf::Timestamp {
373 seconds: utc_dt.timestamp(),
374 nanos: utc_dt.timestamp_subsec_nanos() as i32,
375 ..Default::default()
376 });
377 }
378 }
379
380 let response = self
381 .connect_client
382 .list_tasks(request)
383 .await
384 .map_err(map_connect_err)?;
385 let owned = response.into_owned();
386 Ok(ListTasksResult {
387 tasks: owned.tasks,
388 total_size: owned.total_size,
389 page_size: owned.page_size,
390 next_page_token: owned.next_page_token,
391 })
392 }
393
394 async fn list_push_notification_configs(
395 &self,
396 task_id: &str,
397 ) -> Result<Vec<TaskPushNotificationConfig>, A2AError> {
398 let request = ListTaskPushNotificationConfigsRequest {
399 task_id: task_id.to_string(),
400 ..Default::default()
401 };
402 let response = self
403 .connect_client
404 .list_task_push_notification_configs(request)
405 .await
406 .map_err(map_connect_err)?;
407 Ok(response.into_owned().configs)
408 }
409
410 async fn get_push_notification_config(
411 &self,
412 task_id: &str,
413 config_id: &str,
414 ) -> Result<TaskPushNotificationConfig, A2AError> {
415 let request = GetTaskPushNotificationConfigRequest {
416 task_id: task_id.to_string(),
417 id: config_id.to_string(),
418 ..Default::default()
419 };
420 let response = self
421 .connect_client
422 .get_task_push_notification_config(request)
423 .await
424 .map_err(map_connect_err)?;
425 Ok(response.into_owned())
426 }
427
428 async fn delete_push_notification_config(
429 &self,
430 task_id: &str,
431 config_id: &str,
432 ) -> Result<(), A2AError> {
433 let request = DeleteTaskPushNotificationConfigRequest {
434 task_id: task_id.to_string(),
435 id: config_id.to_string(),
436 ..Default::default()
437 };
438 self.connect_client
439 .delete_task_push_notification_config(request)
440 .await
441 .map_err(map_connect_err)?;
442 Ok(())
443 }
444
445 async fn subscribe_to_task(
446 &self,
447 task_id: &str,
448 _history_length: Option<u32>,
449 _last_event_id: Option<&str>,
452 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, A2AError>> + Send>>, A2AError> {
453 let request = SubscribeToTaskRequest {
454 id: task_id.to_string(),
455 ..Default::default()
456 };
457 let stream = self
458 .connect_client
459 .subscribe_to_task(request)
460 .await
461 .map_err(map_connect_err)?;
462
463 let mapped = futures::stream::unfold(stream, |mut s| async move {
464 match s.message().await {
465 Ok(Some(view)) => {
466 let resp = view.to_owned_message();
467 if let Some(item) = stream_response_to_item(resp) {
468 Some((Ok(StreamEvent::untagged(item)), s))
469 } else {
470 Some((
471 Err(A2AError::Internal(
472 "Empty or unhandled stream response payload".to_string(),
473 )),
474 s,
475 ))
476 }
477 }
478 Ok(None) => None,
479 Err(e) => Some((Err(map_connect_err(e)), s)),
480 }
481 });
482
483 Ok(Box::pin(mapped))
484 }
485}