1use std::collections::HashMap;
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::time::Duration;
31
32use http_body_util::{BodyExt, Full};
33use hyper::body::Bytes;
34use hyper::header;
35#[cfg(not(feature = "tls-rustls"))]
36use hyper_util::client::legacy::connect::HttpConnector;
37#[cfg(not(feature = "tls-rustls"))]
38use hyper_util::client::legacy::Client;
39#[cfg(not(feature = "tls-rustls"))]
40use hyper_util::rt::TokioExecutor;
41use tokio::sync::mpsc;
42
43use a2a_protocol_types::JsonRpcResponse;
44
45use crate::error::{ClientError, ClientResult};
46use crate::streaming::EventStream;
47use crate::transport::Transport;
48
/// Hyper client type used when the `tls-rustls` feature is disabled: a plain
/// [`HttpConnector`] whose request bodies are [`Full<Bytes>`].
#[cfg(not(feature = "tls-rustls"))]
type HttpClient = Client<HttpConnector, Full<Bytes>>;

/// Client type used when the `tls-rustls` feature is enabled; it is built by
/// `crate::tls::build_https_client` (presumably HTTPS-capable via rustls —
/// see `crate::tls`).
#[cfg(feature = "tls-rustls")]
type HttpClient = crate::tls::HttpsClient;
56
/// HTTP verb used by a REST [`Route`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HttpMethod {
    /// `GET`: parameters not consumed by the path go into the query string and
    /// the request body is empty.
    Get,
    /// `POST`: parameters not consumed by the path are sent as a JSON body.
    Post,
    /// `DELETE`: parameters not consumed by the path go into the query string
    /// and the request body is empty.
    Delete,
}
65
/// Static description of how one A2A method name maps onto a REST endpoint.
#[derive(Debug)]
struct Route {
    /// HTTP verb for the endpoint.
    http_method: HttpMethod,
    /// Path appended to the transport's base URL; `{name}` placeholders are
    /// filled from the request parameters.
    path_template: &'static str,
    /// Names of the `{name}` placeholders in `path_template`. Each must be
    /// supplied as a string parameter and is removed from the query/body.
    path_params: &'static [&'static str],
    /// Whether the endpoint answers with an SSE stream.
    // Only read by the unit tests: the transport decides streaming vs. unary
    // from which entry point the caller used, not from this flag.
    #[allow(dead_code)]
    streaming: bool,
}
76
77#[allow(clippy::too_many_lines)]
80fn route_for(method: &str) -> Option<Route> {
81 match method {
82 "SendMessage" => Some(Route {
83 http_method: HttpMethod::Post,
84 path_template: "/message:send",
85 path_params: &[],
86 streaming: false,
87 }),
88 "SendStreamingMessage" => Some(Route {
89 http_method: HttpMethod::Post,
90 path_template: "/message:stream",
91 path_params: &[],
92 streaming: true,
93 }),
94 "GetTask" => Some(Route {
95 http_method: HttpMethod::Get,
96 path_template: "/tasks/{id}",
97 path_params: &["id"],
98 streaming: false,
99 }),
100 "CancelTask" => Some(Route {
101 http_method: HttpMethod::Post,
102 path_template: "/tasks/{id}:cancel",
103 path_params: &["id"],
104 streaming: false,
105 }),
106 "ListTasks" => Some(Route {
107 http_method: HttpMethod::Get,
108 path_template: "/tasks",
109 path_params: &[],
110 streaming: false,
111 }),
112 "SubscribeToTask" => Some(Route {
113 http_method: HttpMethod::Post,
114 path_template: "/tasks/{id}:subscribe",
115 path_params: &["id"],
116 streaming: true,
117 }),
118 "CreateTaskPushNotificationConfig" => Some(Route {
119 http_method: HttpMethod::Post,
120 path_template: "/tasks/{taskId}/pushNotificationConfigs",
121 path_params: &["taskId"],
122 streaming: false,
123 }),
124 "GetTaskPushNotificationConfig" => Some(Route {
125 http_method: HttpMethod::Get,
126 path_template: "/tasks/{taskId}/pushNotificationConfigs/{id}",
127 path_params: &["taskId", "id"],
128 streaming: false,
129 }),
130 "ListTaskPushNotificationConfigs" => Some(Route {
131 http_method: HttpMethod::Get,
132 path_template: "/tasks/{taskId}/pushNotificationConfigs",
133 path_params: &["taskId"],
134 streaming: false,
135 }),
136 "DeleteTaskPushNotificationConfig" => Some(Route {
137 http_method: HttpMethod::Delete,
138 path_template: "/tasks/{taskId}/pushNotificationConfigs/{id}",
139 path_params: &["taskId", "id"],
140 streaming: false,
141 }),
142 "GetExtendedAgentCard" => Some(Route {
143 http_method: HttpMethod::Get,
144 path_template: "/extendedAgentCard",
145 path_params: &[],
146 streaming: false,
147 }),
148 _ => None,
149 }
150}
151
/// A2A client transport that maps A2A method names (e.g. `SendMessage`,
/// `GetTask`) onto REST endpoints under a base URL.
///
/// Cloning is cheap: all clones share the same HTTP client and configuration
/// through an [`Arc`].
#[derive(Clone, Debug)]
pub struct RestTransport {
    inner: Arc<Inner>,
}
162
/// Shared state behind every clone of a [`RestTransport`].
#[derive(Debug)]
struct Inner {
    /// Underlying HTTP client (type selected by the `tls-rustls` feature).
    client: HttpClient,
    /// Base URL with trailing `/` characters removed; route paths are
    /// appended directly to it.
    base_url: String,
    /// Timeout applied to unary (non-streaming) requests.
    request_timeout: Duration,
    /// Timeout for opening an SSE stream (waiting for the initial response).
    stream_connect_timeout: Duration,
}
170
171impl RestTransport {
172 pub fn new(base_url: impl Into<String>) -> ClientResult<Self> {
178 Self::with_timeout(base_url, Duration::from_secs(30))
179 }
180
181 pub fn with_timeout(
187 base_url: impl Into<String>,
188 request_timeout: Duration,
189 ) -> ClientResult<Self> {
190 Self::with_timeouts(base_url, request_timeout, request_timeout)
191 }
192
193 pub fn with_timeouts(
199 base_url: impl Into<String>,
200 request_timeout: Duration,
201 stream_connect_timeout: Duration,
202 ) -> ClientResult<Self> {
203 let base_url = base_url.into();
204 if base_url.is_empty()
205 || (!base_url.starts_with("http://") && !base_url.starts_with("https://"))
206 {
207 return Err(ClientError::InvalidEndpoint(format!(
208 "invalid base URL: {base_url}"
209 )));
210 }
211
212 #[cfg(not(feature = "tls-rustls"))]
213 let client = Client::builder(TokioExecutor::new()).build_http::<Full<Bytes>>();
214
215 #[cfg(feature = "tls-rustls")]
216 let client = crate::tls::build_https_client();
217
218 Ok(Self {
219 inner: Arc::new(Inner {
220 client,
221 base_url: base_url.trim_end_matches('/').to_owned(),
222 request_timeout,
223 stream_connect_timeout,
224 }),
225 })
226 }
227
228 #[must_use]
230 pub fn base_url(&self) -> &str {
231 &self.inner.base_url
232 }
233
234 fn build_uri(
237 &self,
238 route: &Route,
239 params: &serde_json::Value,
240 ) -> ClientResult<(String, serde_json::Value)> {
241 let mut path = route.path_template.to_owned();
242 let mut remaining = params.clone();
243
244 for ¶m in route.path_params {
245 let value = remaining
246 .get(param)
247 .and_then(serde_json::Value::as_str)
248 .ok_or_else(|| ClientError::Transport(format!("missing path parameter: {param}")))?
249 .to_owned();
250
251 path = path.replace(&format!("{{{param}}}"), &value);
252
253 if let Some(obj) = remaining.as_object_mut() {
254 obj.remove(param);
255 }
256 }
257
258 let mut uri = format!("{}{path}", self.inner.base_url);
259
260 if route.http_method == HttpMethod::Get || route.http_method == HttpMethod::Delete {
262 let query = build_query_string(&remaining);
263 if !query.is_empty() {
264 uri.push('?');
265 uri.push_str(&query);
266 }
267 }
268
269 Ok((uri, remaining))
270 }
271
272 fn build_request(
273 &self,
274 method: &str,
275 params: &serde_json::Value,
276 extra_headers: &HashMap<String, String>,
277 streaming: bool,
278 ) -> ClientResult<hyper::Request<Full<Bytes>>> {
279 let route = route_for(method)
280 .ok_or_else(|| ClientError::Transport(format!("no REST route for method: {method}")))?;
281
282 let (uri, body_params) = self.build_uri(&route, params)?;
283 let accept = if streaming {
284 "text/event-stream"
285 } else {
286 "application/json"
287 };
288
289 let hyper_method = match route.http_method {
290 HttpMethod::Get => hyper::Method::GET,
291 HttpMethod::Post => hyper::Method::POST,
292 HttpMethod::Delete => hyper::Method::DELETE,
293 };
294
295 let body =
296 if route.http_method == HttpMethod::Get || route.http_method == HttpMethod::Delete {
297 Full::new(Bytes::new())
299 } else {
300 let bytes = serde_json::to_vec(&body_params).map_err(ClientError::Serialization)?;
301 Full::new(Bytes::from(bytes))
302 };
303
304 let mut builder = hyper::Request::builder()
305 .method(hyper_method)
306 .uri(uri)
307 .header(header::CONTENT_TYPE, a2a_protocol_types::A2A_CONTENT_TYPE)
308 .header(
309 a2a_protocol_types::A2A_VERSION_HEADER,
310 a2a_protocol_types::A2A_VERSION,
311 )
312 .header(header::ACCEPT, accept);
313
314 for (k, v) in extra_headers {
315 builder = builder.header(k.as_str(), v.as_str());
316 }
317
318 builder
319 .body(body)
320 .map_err(|e| ClientError::Transport(e.to_string()))
321 }
322
323 async fn execute_request(
324 &self,
325 method: &str,
326 params: serde_json::Value,
327 extra_headers: &HashMap<String, String>,
328 ) -> ClientResult<serde_json::Value> {
329 trace_info!(method, base_url = %self.inner.base_url, "sending REST request");
330
331 let req = self.build_request(method, ¶ms, extra_headers, false)?;
332
333 let resp = tokio::time::timeout(self.inner.request_timeout, self.inner.client.request(req))
334 .await
335 .map_err(|_| {
336 trace_error!(method, "request timed out");
337 ClientError::Transport("request timed out".into())
338 })?
339 .map_err(|e| {
340 trace_error!(method, error = %e, "HTTP client error");
341 ClientError::HttpClient(e.to_string())
342 })?;
343
344 let status = resp.status();
345 trace_debug!(method, %status, "received response");
346 let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
347
348 if !status.is_success() {
349 let body_str = String::from_utf8_lossy(&body_bytes);
350 return Err(ClientError::UnexpectedStatus {
351 status: status.as_u16(),
352 body: super::truncate_body(&body_str),
353 });
354 }
355
356 if let Ok(envelope) =
358 serde_json::from_slice::<JsonRpcResponse<serde_json::Value>>(&body_bytes)
359 {
360 return match envelope {
361 JsonRpcResponse::Success(ok) => Ok(ok.result),
362 JsonRpcResponse::Error(err) => {
363 let a2a = a2a_protocol_types::A2aError::new(
364 a2a_protocol_types::ErrorCode::try_from(err.error.code)
365 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
366 err.error.message,
367 );
368 Err(ClientError::Protocol(a2a))
369 }
370 };
371 }
372
373 serde_json::from_slice(&body_bytes).map_err(ClientError::Serialization)
375 }
376
377 async fn execute_streaming_request(
378 &self,
379 method: &str,
380 params: serde_json::Value,
381 extra_headers: &HashMap<String, String>,
382 ) -> ClientResult<EventStream> {
383 trace_info!(method, base_url = %self.inner.base_url, "opening REST SSE stream");
384
385 let req = self.build_request(method, ¶ms, extra_headers, true)?;
386
387 let resp = tokio::time::timeout(
388 self.inner.stream_connect_timeout,
389 self.inner.client.request(req),
390 )
391 .await
392 .map_err(|_| {
393 trace_error!(method, "stream connect timed out");
394 ClientError::Timeout("stream connect timed out".into())
395 })?
396 .map_err(|e| {
397 trace_error!(method, error = %e, "HTTP client error");
398 ClientError::HttpClient(e.to_string())
399 })?;
400
401 let status = resp.status();
402 if !status.is_success() {
403 let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
404 let body_str = String::from_utf8_lossy(&body_bytes);
405 return Err(ClientError::UnexpectedStatus {
406 status: status.as_u16(),
407 body: super::truncate_body(&body_str),
408 });
409 }
410
411 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
412 let body = resp.into_body();
413
414 let task_handle = tokio::spawn(async move {
415 body_reader_task(body, tx).await;
416 });
417
418 Ok(EventStream::with_abort_handle(
419 rx,
420 task_handle.abort_handle(),
421 ))
422 }
423}
424
425impl Transport for RestTransport {
426 fn send_request<'a>(
427 &'a self,
428 method: &'a str,
429 params: serde_json::Value,
430 extra_headers: &'a HashMap<String, String>,
431 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
432 Box::pin(self.execute_request(method, params, extra_headers))
433 }
434
435 fn send_streaming_request<'a>(
436 &'a self,
437 method: &'a str,
438 params: serde_json::Value,
439 extra_headers: &'a HashMap<String, String>,
440 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
441 Box::pin(self.execute_streaming_request(method, params, extra_headers))
442 }
443}
444
445async fn body_reader_task(
448 body: hyper::body::Incoming,
449 tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
450) {
451 tokio::pin!(body);
452 loop {
453 let frame = std::future::poll_fn(|cx| {
454 use hyper::body::Body;
455 let pinned = unsafe { Pin::new_unchecked(&mut *body) };
457 pinned.poll_frame(cx)
458 })
459 .await;
460
461 match frame {
462 None => break,
463 Some(Err(e)) => {
464 let _ = tx.send(Err(ClientError::Http(e))).await;
465 break;
466 }
467 Some(Ok(f)) => {
468 if let Ok(data) = f.into_data() {
469 if tx.send(Ok(data)).await.is_err() {
470 break;
471 }
472 }
473 }
474 }
475 }
476}
477
478fn build_query_string(params: &serde_json::Value) -> String {
482 let Some(obj) = params.as_object() else {
483 return String::new();
484 };
485 let mut parts = Vec::new();
486 for (k, v) in obj {
487 match v {
488 serde_json::Value::Null => {}
489 serde_json::Value::String(s) => parts.push(format!("{k}={s}")),
490 serde_json::Value::Number(n) => parts.push(format!("{k}={n}")),
491 serde_json::Value::Bool(b) => parts.push(format!("{k}={b}")),
492 _ => {
493 if let Ok(s) = serde_json::to_string(v) {
495 parts.push(format!("{k}={s}"));
496 }
497 }
498 }
499 }
500 parts.join("&")
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Transport pointed at the local endpoint shared by the URI tests.
    fn local_transport() -> RestTransport {
        RestTransport::new("http://localhost:8080").unwrap()
    }

    #[test]
    fn route_for_known_methods() {
        for name in ["SendMessage", "GetTask", "ListTasks"] {
            assert!(route_for(name).is_some());
        }
        let streaming_flag = route_for("SendStreamingMessage").map(|route| route.streaming);
        assert_eq!(streaming_flag, Some(true));
    }

    #[test]
    fn route_for_unknown_method_returns_none() {
        let lookup = route_for("unknown/method");
        assert!(lookup.is_none());
    }

    #[test]
    fn build_uri_extracts_path_param_and_appends_query() {
        let transport = local_transport();
        let route = route_for("GetTask").unwrap();
        let params = serde_json::json!({"id": "task-123", "historyLength": 5});

        let (uri, _) = transport.build_uri(&route, &params).unwrap();

        assert!(
            uri.starts_with("http://localhost:8080/tasks/task-123"),
            "should have task ID in path"
        );
        assert!(
            uri.contains("historyLength=5"),
            "should have historyLength in query"
        );
    }

    #[test]
    fn build_uri_appends_query_for_get() {
        let transport = local_transport();
        let route = route_for("ListTasks").unwrap();
        let params = serde_json::json!({"pageSize": 10});

        let (uri, _) = transport.build_uri(&route, &params).unwrap();

        assert!(uri.contains("pageSize=10"), "should have pageSize in query");
    }

    #[test]
    fn rest_transport_rejects_invalid_url() {
        let outcome = RestTransport::new("not-a-url");
        assert!(outcome.is_err());
    }

    #[test]
    fn rest_transport_stores_base_url() {
        let transport = RestTransport::new("http://localhost:9090").unwrap();
        assert_eq!(transport.base_url(), "http://localhost:9090");
    }
}