a2a_protocol_client/transport/rest/
mod.rs1mod query;
38mod request;
39mod routing;
40mod streaming;
41
42use std::collections::HashMap;
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::time::Duration;
47
48#[cfg(not(feature = "tls-rustls"))]
49use http_body_util::Full;
50#[cfg(not(feature = "tls-rustls"))]
51use hyper::body::Bytes;
52#[cfg(not(feature = "tls-rustls"))]
53use hyper_util::client::legacy::connect::HttpConnector;
54#[cfg(not(feature = "tls-rustls"))]
55use hyper_util::client::legacy::Client;
56#[cfg(not(feature = "tls-rustls"))]
57use hyper_util::rt::TokioExecutor;
58
59use crate::error::{ClientError, ClientResult};
60use crate::streaming::EventStream;
61use crate::transport::Transport;
62
63#[cfg(not(feature = "tls-rustls"))]
66type HttpClient = Client<HttpConnector, Full<Bytes>>;
67
68#[cfg(feature = "tls-rustls")]
69type HttpClient = crate::tls::HttpsClient;
70
71#[derive(Clone, Debug)]
78pub struct RestTransport {
79 inner: Arc<Inner>,
80}
81
82#[derive(Debug)]
83struct Inner {
84 client: HttpClient,
85 base_url: String,
86 request_timeout: Duration,
87 stream_connect_timeout: Duration,
88}
89
90impl RestTransport {
91 pub fn new(base_url: impl Into<String>) -> ClientResult<Self> {
97 Self::with_timeout(base_url, Duration::from_secs(30))
98 }
99
100 pub fn with_timeout(
106 base_url: impl Into<String>,
107 request_timeout: Duration,
108 ) -> ClientResult<Self> {
109 Self::with_timeouts(base_url, request_timeout, request_timeout)
110 }
111
112 pub fn with_timeouts(
120 base_url: impl Into<String>,
121 request_timeout: Duration,
122 stream_connect_timeout: Duration,
123 ) -> ClientResult<Self> {
124 Self::with_all_timeouts(
125 base_url,
126 request_timeout,
127 stream_connect_timeout,
128 Duration::from_secs(10),
129 )
130 }
131
132 pub fn with_all_timeouts(
141 base_url: impl Into<String>,
142 request_timeout: Duration,
143 stream_connect_timeout: Duration,
144 connection_timeout: Duration,
145 ) -> ClientResult<Self> {
146 let base_url = base_url.into();
147 if base_url.is_empty()
148 || (!base_url.starts_with("http://") && !base_url.starts_with("https://"))
149 {
150 return Err(ClientError::InvalidEndpoint(format!(
151 "invalid base URL: {base_url}"
152 )));
153 }
154
155 #[cfg(not(feature = "tls-rustls"))]
156 let client = {
157 let mut connector = HttpConnector::new();
158 connector.set_connect_timeout(Some(connection_timeout));
159 connector.set_nodelay(true);
160 Client::builder(TokioExecutor::new())
161 .pool_idle_timeout(Duration::from_secs(90))
162 .build(connector)
163 };
164
165 #[cfg(feature = "tls-rustls")]
166 let client = crate::tls::build_https_client_with_connect_timeout(
167 crate::tls::default_tls_config(),
168 connection_timeout,
169 );
170
171 Ok(Self {
172 inner: Arc::new(Inner {
173 client,
174 base_url: base_url.trim_end_matches('/').to_owned(),
175 request_timeout,
176 stream_connect_timeout,
177 }),
178 })
179 }
180
181 #[must_use]
183 pub fn base_url(&self) -> &str {
184 &self.inner.base_url
185 }
186}
187
188impl Transport for RestTransport {
189 fn send_request<'a>(
190 &'a self,
191 method: &'a str,
192 params: serde_json::Value,
193 extra_headers: &'a HashMap<String, String>,
194 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
195 Box::pin(self.execute_request(method, params, extra_headers))
196 }
197
198 fn send_streaming_request<'a>(
199 &'a self,
200 method: &'a str,
201 params: serde_json::Value,
202 extra_headers: &'a HashMap<String, String>,
203 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
204 Box::pin(self.execute_streaming_request(method, params, extra_headers))
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a local server on an ephemeral port and returns its address.
    ///
    /// The server answers every request with status 200, a `content-type`
    /// header of `content_type`, and `body` as the response body. It uses
    /// hyper-util's `auto` connection builder. The accept loop runs in a
    /// detached task for the rest of the test.
    async fn spawn_mock_server(
        content_type: &'static str,
        body: &'static str,
    ) -> std::net::SocketAddr {
        use http_body_util::Full;
        use hyper::body::Bytes;

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            loop {
                let (stream, _) = listener.accept().await.unwrap();
                let io = hyper_util::rt::TokioIo::new(stream);
                tokio::spawn(async move {
                    let service = hyper::service::service_fn(move |_req| async move {
                        Ok::<_, hyper::Error>(
                            hyper::Response::builder()
                                .status(200)
                                .header("content-type", content_type)
                                .body(Full::new(Bytes::from(body)))
                                .unwrap(),
                        )
                    });
                    let _ = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    )
                    .serve_connection(io, service)
                    .await;
                });
            }
        });

        addr
    }

    #[test]
    fn rest_transport_rejects_invalid_url() {
        assert!(RestTransport::new("not-a-url").is_err());
        assert!(RestTransport::new("").is_err());
        assert!(RestTransport::new("ftp://example.com").is_err());
    }

    #[test]
    fn rest_transport_stores_base_url() {
        let t = RestTransport::new("http://localhost:9090").unwrap();
        assert_eq!(t.base_url(), "http://localhost:9090");
    }

    #[test]
    fn rest_transport_trims_trailing_slashes() {
        let t = RestTransport::new("http://localhost:9090//").unwrap();
        assert_eq!(t.base_url(), "http://localhost:9090");
    }

    #[tokio::test]
    async fn send_request_via_trait_delegation() {
        let addr = spawn_mock_server("application/json", r#"{"status":"ok","data":42}"#).await;

        let url = format!("http://127.0.0.1:{}", addr.port());
        let transport = RestTransport::new(&url).unwrap();
        let dyn_transport: &dyn crate::transport::Transport = &transport;
        let result = dyn_transport
            .send_request("SendMessage", serde_json::json!({}), &HashMap::new())
            .await;
        assert!(result.is_ok(), "send_request via trait should succeed");
    }

    #[tokio::test]
    async fn send_streaming_request_via_trait_delegation() {
        let addr = spawn_mock_server("text/event-stream", "data: {\"hello\":\"world\"}\n\n").await;

        let url = format!("http://127.0.0.1:{}", addr.port());
        let transport = RestTransport::new(&url).unwrap();
        let dyn_transport: &dyn crate::transport::Transport = &transport;
        let result = dyn_transport
            .send_streaming_request(
                "SendStreamingMessage",
                serde_json::json!({}),
                &HashMap::new(),
            )
            .await;
        assert!(
            result.is_ok(),
            "send_streaming_request via trait should succeed"
        );
    }
}