1use std::cell::RefCell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
27use ra2a::SVC_PARAM_EXTENSIONS;
28use ra2a::error::A2AError;
29use ra2a::types::AgentCard;
30
31use crate::util::is_extension_supported;
32
// Per-task slot shared by the server and client propagators.
//
// `ServerPropagator` writes into it through `PropagatorContext::install`, and
// `ClientPropagator` reads it through `PropagatorContext::current`. The slot only exists
// inside a `PROPAGATOR_CTX.scope(...)` future (see `init_propagation` and
// `PropagatorContext::scope`). Outside such a scope `try_with` fails, and both accessors
// treat that as "no context".
//
// The `RefCell` lets `install` replace the value in place without re-entering a new scope.
tokio::task_local! {
    static PROPAGATOR_CTX: RefCell<Option<PropagatorContext>>;
}
38
39#[derive(Debug, Clone, Default)]
43#[non_exhaustive]
44pub struct PropagatorContext {
45 pub request_headers: HashMap<String, Vec<String>>,
47 pub metadata: HashMap<String, serde_json::Value>,
49}
50
51impl PropagatorContext {
52 #[must_use]
54 pub fn current() -> Option<Self> {
55 PROPAGATOR_CTX
56 .try_with(|cell| cell.borrow().clone())
57 .ok()
58 .flatten()
59 }
60
61 #[must_use]
66 pub fn install(self) -> bool {
67 PROPAGATOR_CTX
68 .try_with(|cell| {
69 *cell.borrow_mut() = Some(self);
70 })
71 .is_ok()
72 }
73
74 pub async fn scope<F: Future>(self, f: F) -> F::Output {
79 PROPAGATOR_CTX.scope(RefCell::new(Some(self)), f).await
80 }
81}
82
83pub async fn init_propagation<F: Future>(f: F) -> F::Output {
99 PROPAGATOR_CTX.scope(RefCell::new(None), f).await
100}
101
/// Decides whether an incoming metadata entry is propagated.
///
/// Arguments: the extension URIs the caller requested, then the metadata key.
pub(crate) type ServerMetadataPredicate = Arc<dyn Fn(&[String], &str) -> bool + Send + Sync>;

/// Decides whether an incoming request header (by name) is propagated.
pub(crate) type ServerHeaderPredicate = Arc<dyn Fn(&str) -> bool + Send + Sync>;
112
113#[derive(Default)]
117#[non_exhaustive]
118pub struct ServerPropagatorConfig {
119 pub metadata_predicate: Option<ServerMetadataPredicate>,
123 pub header_predicate: Option<ServerHeaderPredicate>,
127}
128
129impl std::fmt::Debug for ServerPropagatorConfig {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("ServerPropagatorConfig")
132 .field("metadata_predicate", &self.metadata_predicate.is_some())
133 .field("header_predicate", &self.header_predicate.is_some())
134 .finish()
135 }
136}
137
138pub struct ServerPropagator {
147 metadata_predicate: ServerMetadataPredicate,
149 header_predicate: ServerHeaderPredicate,
151}
152
153impl ServerPropagator {
154 #[must_use]
160 pub fn new() -> Self {
161 Self::with_config(ServerPropagatorConfig::default())
162 }
163
164 #[must_use]
166 pub fn with_config(config: ServerPropagatorConfig) -> Self {
167 let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
168 Arc::new(|requested_uris: &[String], key: &str| requested_uris.iter().any(|u| u == key))
169 });
170
171 let header_predicate = config.header_predicate.unwrap_or_else(|| {
172 Arc::new(|key: &str| key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS))
173 });
174
175 Self {
176 metadata_predicate,
177 header_predicate,
178 }
179 }
180}
181
182impl Default for ServerPropagator {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl std::fmt::Debug for ServerPropagator {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("ServerPropagator").finish_non_exhaustive()
191 }
192}
193
194impl ServerPropagator {
195 fn propagate_server(&self, ctx: &mut ra2a::server::CallContext, req: &ra2a::server::Request) {
197 let mut prop_ctx = PropagatorContext::default();
198
199 let requested = ctx.requested_extension_uris();
200
201 extract_metadata(
202 req,
203 &requested,
204 &self.metadata_predicate,
205 &mut prop_ctx.metadata,
206 );
207
208 let request_meta = ctx.request_meta();
209 for (header_name, header_values) in request_meta.iter() {
210 if (self.header_predicate)(header_name) {
211 prop_ctx
212 .request_headers
213 .insert(header_name.to_owned(), header_values.to_vec());
214 }
215 }
216
217 if let Some(ext_values) = prop_ctx.request_headers.get(SVC_PARAM_EXTENSIONS) {
218 for uri in ext_values {
219 ctx.activate_extension(uri);
220 }
221 }
222
223 let _installed = prop_ctx.install();
225 }
226}
227
228impl ra2a::server::CallInterceptor for ServerPropagator {
229 fn before<'a>(
230 &'a self,
231 ctx: &'a mut ra2a::server::CallContext,
232 req: &'a mut ra2a::server::Request,
233 ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
234 self.propagate_server(ctx, req);
235 Box::pin(std::future::ready(Ok(())))
236 }
237
238 fn after<'a>(
239 &'a self,
240 _ctx: &'a ra2a::server::CallContext,
241 _resp: &'a mut ra2a::server::Response,
242 ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
243 Box::pin(async { Ok(()) })
244 }
245}
246
247fn extract_metadata(
249 req: &ra2a::server::Request,
250 requested: &[String],
251 predicate: &ServerMetadataPredicate,
252 out: &mut HashMap<String, serde_json::Value>,
253) {
254 if let Some(params) = req.downcast_ref::<ra2a::SendMessageRequest>()
255 && let Some(ref meta) = params.metadata
256 {
257 collect_matching_metadata(meta, requested, predicate, out);
258 }
259}
260
261fn collect_matching_metadata(
263 metadata: &ra2a::Metadata,
264 requested: &[String],
265 predicate: &ServerMetadataPredicate,
266 out: &mut HashMap<String, serde_json::Value>,
267) {
268 for (k, v) in metadata {
269 if predicate(requested, k) {
270 out.insert(k.clone(), v.clone());
271 }
272 }
273}
274
/// Decides whether a propagated metadata entry is injected into an outgoing request.
///
/// Arguments: the target agent's card (if known), the requested extension URIs, then the
/// metadata key.
pub(crate) type ClientMetadataPredicate =
    Arc<dyn Fn(Option<&AgentCard>, &[String], &str) -> bool + Send + Sync>;

/// Decides whether one propagated header value is forwarded on an outgoing request.
///
/// Arguments: the target agent's card (if known), the header name, then a single value.
pub(crate) type ClientHeaderPredicate =
    Arc<dyn Fn(Option<&AgentCard>, &str, &str) -> bool + Send + Sync>;
288
289#[derive(Default)]
291#[non_exhaustive]
292pub struct ClientPropagatorConfig {
293 pub metadata_predicate: Option<ClientMetadataPredicate>,
298 pub header_predicate: Option<ClientHeaderPredicate>,
303}
304
305impl std::fmt::Debug for ClientPropagatorConfig {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 f.debug_struct("ClientPropagatorConfig")
308 .field("metadata_predicate", &self.metadata_predicate.is_some())
309 .field("header_predicate", &self.header_predicate.is_some())
310 .finish()
311 }
312}
313
314pub struct ClientPropagator {
322 metadata_predicate: ClientMetadataPredicate,
324 header_predicate: ClientHeaderPredicate,
326}
327
328impl ClientPropagator {
329 #[must_use]
331 pub fn new() -> Self {
332 Self::with_config(ClientPropagatorConfig::default())
333 }
334
335 #[must_use]
337 pub fn with_config(config: ClientPropagatorConfig) -> Self {
338 let metadata_predicate = config
339 .metadata_predicate
340 .unwrap_or_else(|| Arc::new(default_client_metadata_predicate));
341
342 let header_predicate = config
343 .header_predicate
344 .unwrap_or_else(|| Arc::new(default_client_header_predicate));
345
346 Self {
347 metadata_predicate,
348 header_predicate,
349 }
350 }
351}
352
353impl Default for ClientPropagator {
354 fn default() -> Self {
355 Self::new()
356 }
357}
358
359impl std::fmt::Debug for ClientPropagator {
360 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361 f.debug_struct("ClientPropagator").finish_non_exhaustive()
362 }
363}
364
365impl ClientPropagator {
366 fn propagate_client(&self, req: &mut ra2a::client::Request) {
368 let Some(prop_ctx) = PropagatorContext::current() else {
369 return;
370 };
371
372 let requested: Vec<String> = prop_ctx
373 .request_headers
374 .get(SVC_PARAM_EXTENSIONS)
375 .cloned()
376 .unwrap_or_default();
377
378 if !prop_ctx.metadata.is_empty() {
379 inject_metadata(
380 &mut *req.payload,
381 &prop_ctx.metadata,
382 req.card.as_ref(),
383 &requested,
384 &self.metadata_predicate,
385 );
386 }
387
388 for (name, val) in prop_ctx
389 .request_headers
390 .iter()
391 .flat_map(|(k, vs)| vs.iter().map(move |v| (k, v)))
392 {
393 if (self.header_predicate)(req.card.as_ref(), name, val) {
394 req.service_params.append(name, val);
395 }
396 }
397 }
398}
399
400impl ra2a::client::CallInterceptor for ClientPropagator {
401 fn before<'a>(
402 &'a self,
403 req: &'a mut ra2a::client::Request,
404 ) -> Pin<Box<dyn Future<Output = ra2a::error::Result<()>> + Send + 'a>> {
405 self.propagate_client(req);
406 Box::pin(std::future::ready(Ok(())))
407 }
408}
409
410fn default_client_metadata_predicate(
413 card: Option<&AgentCard>,
414 requested: &[String],
415 key: &str,
416) -> bool {
417 requested.iter().any(|u| u == key) && is_extension_supported(card, key)
418}
419
420fn default_client_header_predicate(card: Option<&AgentCard>, key: &str, val: &str) -> bool {
423 key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS) && is_extension_supported(card, val)
424}
425
426fn inject_metadata(
428 payload: &mut dyn std::any::Any,
429 metadata: &HashMap<String, serde_json::Value>,
430 card: Option<&AgentCard>,
431 requested: &[String],
432 predicate: &ClientMetadataPredicate,
433) {
434 if let Some(params) = payload.downcast_mut::<ra2a::SendMessageRequest>() {
435 let meta = params.metadata.get_or_insert_with(Default::default);
436 inject_matching_metadata(meta, metadata, card, requested, predicate);
437 }
438}
439
440fn inject_matching_metadata(
442 target: &mut ra2a::Metadata,
443 source: &HashMap<String, serde_json::Value>,
444 card: Option<&AgentCard>,
445 requested: &[String],
446 predicate: &ClientMetadataPredicate,
447) {
448 for (k, v) in source {
449 if predicate(card, requested, k) {
450 target.insert(k.clone(), v.clone());
451 }
452 }
453}
454
// Unit tests for the client-side propagator and the task-local context plumbing.
#[cfg(test)]
#[allow(clippy::unwrap_used, reason = "tests use unwrap for brevity")]
mod tests {
    use ra2a::client::{CallInterceptor as _, ServiceParams};
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    // Builds a minimal `AgentCard` with one JSON-RPC interface whose capabilities list one
    // non-required, parameterless extension per URI in `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        let mut card = AgentCard::new(
            "test",
            "test agent",
            vec![AgentInterface::new(
                "https://example.com",
                TransportProtocol::new("JSONRPC"),
            )],
        );
        card.capabilities = AgentCapabilities {
            extensions: uris
                .iter()
                .map(|u| AgentExtension {
                    uri: (*u).into(),
                    description: None,
                    required: false,
                    params: None,
                })
                .collect(),
            ..AgentCapabilities::default()
        };
        card
    }

    // A propagated extensions-header value whose URI the target card lists is appended to
    // the outgoing request's service params.
    #[tokio::test]
    async fn test_client_propagator_injects_headers() {
        let propagator = ClientPropagator::new();
        let card = make_card(&["urn:a2a:ext:duration"]);

        let mut prop_ctx = PropagatorContext::default();
        prop_ctx.request_headers.insert(
            SVC_PARAM_EXTENSIONS.to_owned(),
            vec!["urn:a2a:ext:duration".into()],
        );

        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card: Some(card),
            payload: Box::new(()),
        };

        // `before` must run inside the scope so `PropagatorContext::current()` sees prop_ctx.
        prop_ctx
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(vals, &["urn:a2a:ext:duration"]);
    }

    // A propagated extension URI that the target card does not list is not forwarded.
    #[tokio::test]
    async fn test_client_propagator_filters_unsupported() {
        let propagator = ClientPropagator::new();
        let card = make_card(&["urn:a2a:ext:other"]);

        let mut prop_ctx = PropagatorContext::default();
        prop_ctx.request_headers.insert(
            SVC_PARAM_EXTENSIONS.to_owned(),
            vec!["urn:a2a:ext:duration".into()],
        );

        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card: Some(card),
            payload: Box::new(()),
        };

        prop_ctx
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(vals.is_empty());
    }

    // Outside any propagation scope, the client propagator leaves the request unchanged.
    #[tokio::test]
    async fn test_client_propagator_no_context_is_noop() {
        let propagator = ClientPropagator::new();

        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card: None,
            payload: Box::new(()),
        };

        propagator.before(&mut req).await.unwrap();
        assert!(req.service_params.is_empty());
    }

    // Inside `init_propagation`, the slot starts empty. `install` then succeeds, and
    // `current()` returns the installed context.
    #[tokio::test]
    async fn test_propagator_context_install_and_read() {
        let ctx = PropagatorContext {
            request_headers: {
                let mut m = HashMap::new();
                m.insert("x-test".into(), vec!["val1".into()]);
                m
            },
            metadata: HashMap::new(),
        };

        init_propagation(async {
            assert!(PropagatorContext::current().is_none());
            assert!(ctx.install());
            let read = PropagatorContext::current().unwrap();
            assert_eq!(
                read.request_headers.get("x-test").unwrap(),
                &["val1".to_owned()]
            );
        })
        .await;
    }
}