a2a_protocol_server/agent_card/
dynamic_handler.rs1use std::future::Future;
14use std::pin::Pin;
15
16use a2a_protocol_types::agent_card::AgentCard;
17use a2a_protocol_types::error::A2aResult;
18use bytes::Bytes;
19use http_body_util::Full;
20
21use crate::agent_card::caching::{format_http_date, make_etag, CacheConfig};
22use crate::agent_card::CORS_ALLOW_ALL;
23
/// Source of the agent card served by [`DynamicAgentCardHandler`].
///
/// The handler calls [`produce`](Self::produce) once for every request it
/// handles (conditional or unconditional), so each response reflects
/// whatever card the implementation returns at that moment.
pub trait AgentCardProducer: Send + Sync + 'static {
    /// Produces the agent card for the current request.
    ///
    /// The returned future must be `Send` and may borrow `self`.
    ///
    /// # Errors
    ///
    /// Any error is reported by the handler as a `500` JSON response whose
    /// message starts with `card producer error:`.
    fn produce<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<AgentCard>> + Send + 'a>>;
}
35
/// HTTP handler that serves an agent card obtained from an
/// [`AgentCardProducer`] on each request, adding `ETag`, `Last-Modified`
/// and `Cache-Control` headers and answering conditional requests.
///
/// The derived `Debug` impl is only available when `P: Debug`.
#[derive(Debug)]
pub struct DynamicAgentCardHandler<P> {
    /// Supplies the card for every handled request.
    producer: P,
    /// Source of the `Cache-Control` header value sent with card responses.
    cache_config: CacheConfig,
}
42
43impl<P: AgentCardProducer> DynamicAgentCardHandler<P> {
44 #[must_use]
46 pub fn new(producer: P) -> Self {
47 Self {
48 producer,
49 cache_config: CacheConfig::default(),
50 }
51 }
52
53 #[must_use]
55 pub const fn with_max_age(mut self, seconds: u32) -> Self {
56 self.cache_config = CacheConfig::with_max_age(seconds);
57 self
58 }
59
60 #[allow(clippy::future_not_send)] pub async fn handle(
66 &self,
67 req: &hyper::Request<impl hyper::body::Body>,
68 ) -> hyper::Response<Full<Bytes>> {
69 let if_none_match = req
72 .headers()
73 .get("if-none-match")
74 .and_then(|v| v.to_str().ok())
75 .map(str::to_owned);
76 let if_modified_since = req
77 .headers()
78 .get("if-modified-since")
79 .and_then(|v| v.to_str().ok())
80 .map(str::to_owned);
81
82 match self.producer.produce().await {
83 Ok(card) => match serde_json::to_vec(&card) {
84 Ok(json) => {
85 let etag = make_etag(&json);
86 let last_modified = format_http_date(std::time::SystemTime::now());
87
88 let not_modified = is_not_modified(
89 if_none_match.as_deref(),
90 if_modified_since.as_deref(),
91 &etag,
92 &last_modified,
93 );
94
95 if not_modified {
96 hyper::Response::builder()
97 .status(304)
98 .header("etag", &etag)
99 .header("last-modified", &last_modified)
100 .header("cache-control", self.cache_config.header_value())
101 .body(Full::new(Bytes::new()))
102 .unwrap_or_else(|_| fallback_error_response())
103 } else {
104 hyper::Response::builder()
105 .status(200)
106 .header("content-type", "application/json")
107 .header("access-control-allow-origin", CORS_ALLOW_ALL)
108 .header("etag", &etag)
109 .header("last-modified", &last_modified)
110 .header("cache-control", self.cache_config.header_value())
111 .body(Full::new(Bytes::from(json)))
112 .unwrap_or_else(|_| fallback_error_response())
113 }
114 }
115 Err(e) => error_response(500, &format!("serialization error: {e}")),
116 },
117 Err(e) => error_response(500, &format!("card producer error: {e}")),
118 }
119 }
120
121 pub async fn handle_unconditional(&self) -> hyper::Response<Full<Bytes>> {
123 match self.producer.produce().await {
124 Ok(card) => match serde_json::to_vec(&card) {
125 Ok(json) => {
126 let etag = make_etag(&json);
127 let last_modified = format_http_date(std::time::SystemTime::now());
128 hyper::Response::builder()
129 .status(200)
130 .header("content-type", "application/json")
131 .header("access-control-allow-origin", CORS_ALLOW_ALL)
132 .header("etag", &etag)
133 .header("last-modified", &last_modified)
134 .header("cache-control", self.cache_config.header_value())
135 .body(Full::new(Bytes::from(json)))
136 .unwrap_or_else(|_| fallback_error_response())
137 }
138 Err(e) => error_response(500, &format!("serialization error: {e}")),
139 },
140 Err(e) => error_response(500, &format!("card producer error: {e}")),
141 }
142 }
143}
144
145fn is_not_modified(
147 if_none_match: Option<&str>,
148 if_modified_since: Option<&str>,
149 current_etag: &str,
150 current_last_modified: &str,
151) -> bool {
152 if let Some(inm) = if_none_match {
154 return etag_matches(inm, current_etag);
155 }
156 if let Some(ims) = if_modified_since {
157 return ims == current_last_modified;
158 }
159 false
160}
161
/// Reports whether an `If-None-Match` header value matches `current`, the
/// entity-tag of the representation about to be served.
///
/// Uses the weak comparison required for `If-None-Match` (RFC 9110 §13.1.2):
/// a leading `W/` on either side is ignored. A lone `*` matches any current
/// representation. The header may list several entity-tags separated by
/// commas; quoted tags are read up to their closing quote, because a comma
/// is a legal character inside an opaque tag and must not split it.
fn etag_matches(header_value: &str, current: &str) -> bool {
    let header_value = header_value.trim();
    if header_value == "*" {
        return true;
    }
    let current_bare = current.strip_prefix("W/").unwrap_or(current);
    let mut remaining = header_value;
    while let Some((candidate, rest)) = next_entity_tag(remaining) {
        if candidate.strip_prefix("W/").unwrap_or(candidate) == current_bare {
            return true;
        }
        remaining = rest;
    }
    false
}

/// Splits the first member off a comma-separated entity-tag list, returning
/// `(member, rest)`, or `None` once only separators and whitespace remain.
///
/// A quoted member (optionally `W/`-prefixed) ends at its closing quote; an
/// unterminated quote consumes the rest of the input. Unquoted members are
/// tolerated for leniency and end at the next comma. Every returned member
/// is non-empty, so repeated calls always make progress.
fn next_entity_tag(list: &str) -> Option<(&str, &str)> {
    let list = list.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
    if list.is_empty() {
        return None;
    }
    let opaque = list.strip_prefix("W/").unwrap_or(list);
    let end = match opaque.strip_prefix('"') {
        // Offset of `opaque` within `list`, the opening quote, the tag bytes,
        // and the closing quote. All split points are ASCII, so they are
        // valid char boundaries.
        Some(inner) => inner
            .find('"')
            .map_or(list.len(), |close| (list.len() - opaque.len()) + 1 + close + 1),
        None => list.find(',').unwrap_or(list.len()),
    };
    let (member, rest) = list.split_at(end);
    Some((member.trim_end(), rest))
}
178
179fn error_response(status: u16, message: &str) -> hyper::Response<Full<Bytes>> {
181 let body = serde_json::json!({ "error": message });
182 let bytes = serde_json::to_vec(&body).unwrap_or_default();
183 hyper::Response::builder()
184 .status(status)
185 .header("content-type", "application/json")
186 .body(Full::new(Bytes::from(bytes)))
187 .unwrap_or_else(|_| fallback_error_response())
188}
189
190fn fallback_error_response() -> hyper::Response<Full<Bytes>> {
193 hyper::Response::new(Full::new(Bytes::from_static(
194 br#"{"error":"internal server error"}"#,
195 )))
196}
197
198#[cfg(test)]
201mod tests {
202 use super::*;
203 use crate::agent_card::caching::tests::minimal_agent_card;
204
205 struct MockProducer;
207
208 impl AgentCardProducer for MockProducer {
209 fn produce<'a>(
210 &'a self,
211 ) -> Pin<Box<dyn Future<Output = A2aResult<AgentCard>> + Send + 'a>> {
212 Box::pin(async { Ok(minimal_agent_card()) })
213 }
214 }
215
216 struct FailingProducer;
218
219 impl AgentCardProducer for FailingProducer {
220 fn produce<'a>(
221 &'a self,
222 ) -> Pin<Box<dyn Future<Output = A2aResult<AgentCard>> + Send + 'a>> {
223 Box::pin(async {
224 Err(a2a_protocol_types::error::A2aError::internal(
225 "producer failure",
226 ))
227 })
228 }
229 }
230
231 #[test]
232 fn construction_with_defaults() {
233 let handler = DynamicAgentCardHandler::new(MockProducer);
234 assert_eq!(
235 handler.cache_config.max_age, 3600,
236 "default max_age should be 3600 seconds"
237 );
238 }
239
240 #[test]
241 fn with_max_age_overrides_default() {
242 let handler = DynamicAgentCardHandler::new(MockProducer).with_max_age(120);
243 assert_eq!(
244 handler.cache_config.max_age, 120,
245 "with_max_age should set the configured value"
246 );
247 }
248
249 #[tokio::test]
250 async fn handle_returns_correct_content_type() {
251 let handler = DynamicAgentCardHandler::new(MockProducer);
252 let req = hyper::Request::builder()
253 .body(Full::new(Bytes::new()))
254 .unwrap();
255 let resp = handler.handle(&req).await;
256
257 assert_eq!(resp.status(), 200, "response should be 200 OK");
258 assert_eq!(
259 resp.headers().get("content-type").unwrap(),
260 "application/json",
261 "content-type should be application/json"
262 );
263 }
264
265 #[tokio::test]
266 async fn handle_includes_etag_header() {
267 let handler = DynamicAgentCardHandler::new(MockProducer);
268 let req = hyper::Request::builder()
269 .body(Full::new(Bytes::new()))
270 .unwrap();
271 let resp = handler.handle(&req).await;
272
273 let etag = resp
274 .headers()
275 .get("etag")
276 .expect("response should include an ETag header");
277 let etag_str = etag.to_str().unwrap();
278 assert!(
279 etag_str.starts_with("W/\""),
280 "ETag should be a weak validator starting with W/\""
281 );
282 }
283
284 #[tokio::test]
285 async fn handle_includes_cache_control_header() {
286 let handler = DynamicAgentCardHandler::new(MockProducer).with_max_age(300);
287 let req = hyper::Request::builder()
288 .body(Full::new(Bytes::new()))
289 .unwrap();
290 let resp = handler.handle(&req).await;
291
292 assert_eq!(
293 resp.headers().get("cache-control").unwrap(),
294 "public, max-age=300",
295 "cache-control should reflect with_max_age setting"
296 );
297 }
298
299 #[tokio::test]
300 async fn handle_includes_cors_header() {
301 let handler = DynamicAgentCardHandler::new(MockProducer);
302 let req = hyper::Request::builder()
303 .body(Full::new(Bytes::new()))
304 .unwrap();
305 let resp = handler.handle(&req).await;
306
307 assert_eq!(
308 resp.headers().get("access-control-allow-origin").unwrap(),
309 "*",
310 "CORS header should allow all origins"
311 );
312 }
313
314 #[tokio::test]
315 async fn conditional_request_with_matching_etag_returns_304() {
316 let handler = DynamicAgentCardHandler::new(MockProducer);
317
318 let req1 = hyper::Request::builder()
320 .body(Full::new(Bytes::new()))
321 .unwrap();
322 let resp1 = handler.handle(&req1).await;
323 assert_eq!(resp1.status(), 200, "first request should return 200");
324 let etag = resp1
325 .headers()
326 .get("etag")
327 .unwrap()
328 .to_str()
329 .unwrap()
330 .to_owned();
331
332 let req2 = hyper::Request::builder()
334 .header("if-none-match", &etag)
335 .body(Full::new(Bytes::new()))
336 .unwrap();
337 let resp2 = handler.handle(&req2).await;
338 assert_eq!(
339 resp2.status(),
340 304,
341 "conditional request with matching ETag should return 304 Not Modified"
342 );
343 }
344
345 #[tokio::test]
346 async fn conditional_request_with_non_matching_etag_returns_200() {
347 let handler = DynamicAgentCardHandler::new(MockProducer);
348 let req = hyper::Request::builder()
349 .header("if-none-match", "W/\"does-not-match\"")
350 .body(Full::new(Bytes::new()))
351 .unwrap();
352 let resp = handler.handle(&req).await;
353
354 assert_eq!(
355 resp.status(),
356 200,
357 "non-matching ETag should return 200 with full body"
358 );
359 }
360
361 #[tokio::test]
362 async fn handle_unconditional_always_returns_full_response() {
363 let handler = DynamicAgentCardHandler::new(MockProducer);
364
365 let resp = handler.handle_unconditional().await;
366 assert_eq!(resp.status(), 200, "unconditional handle should return 200");
367 assert_eq!(
368 resp.headers().get("content-type").unwrap(),
369 "application/json",
370 "unconditional response should have JSON content-type"
371 );
372 assert!(
373 resp.headers().get("etag").is_some(),
374 "unconditional response should still include ETag"
375 );
376 }
377
378 #[tokio::test]
379 async fn handle_returns_500_on_producer_error() {
380 let handler = DynamicAgentCardHandler::new(FailingProducer);
381 let req = hyper::Request::builder()
382 .body(Full::new(Bytes::new()))
383 .unwrap();
384 let resp = handler.handle(&req).await;
385
386 assert_eq!(
387 resp.status(),
388 500,
389 "producer error should result in 500 status"
390 );
391 }
392
393 #[tokio::test]
394 async fn handle_unconditional_returns_500_on_producer_error() {
395 let handler = DynamicAgentCardHandler::new(FailingProducer);
396 let resp = handler.handle_unconditional().await;
397
398 assert_eq!(
399 resp.status(),
400 500,
401 "producer error in unconditional handle should result in 500 status"
402 );
403 }
404
405 #[test]
407 fn fallback_error_response_returns_internal_error_json() {
408 let resp = fallback_error_response();
409 assert_eq!(resp.status(), 200); }
412
413 #[tokio::test]
417 async fn error_response_returns_correct_status() {
418 let resp = error_response(503, "service unavailable");
419 assert_eq!(resp.status(), 503);
420 let body = {
421 use http_body_util::BodyExt;
422 resp.into_body().collect().await.unwrap().to_bytes()
423 };
424 let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
425 assert_eq!(val["error"], "service unavailable");
426 }
427
428 #[tokio::test]
429 async fn response_body_deserializes_to_agent_card() {
430 use http_body_util::BodyExt;
431
432 let handler = DynamicAgentCardHandler::new(MockProducer);
433 let req = hyper::Request::builder()
434 .body(Full::new(Bytes::new()))
435 .unwrap();
436 let resp = handler.handle(&req).await;
437 let body = resp.into_body().collect().await.unwrap().to_bytes();
438 let card: AgentCard =
439 serde_json::from_slice(&body).expect("response body should be valid AgentCard JSON");
440 assert_eq!(
441 card.name, "Test Agent",
442 "deserialized card name should match"
443 );
444 }
445}