1use std::cell::RefCell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
27use ra2a::EXTENSIONS_META_KEY;
28use ra2a::error::A2AError;
29use ra2a::types::AgentCard;
30
31use crate::util::is_extension_supported;
32
tokio::task_local! {
    /// Task-local slot holding the [`PropagatorContext`] for the current task.
    ///
    /// It is only reachable inside a scope opened by [`init_propagation`] or
    /// [`PropagatorContext::scope`]. Outside such a scope, reads through
    /// [`PropagatorContext::current`] yield `None` and
    /// [`PropagatorContext::install`] returns `false`.
    static PROPAGATOR_CTX: RefCell<Option<PropagatorContext>>;
}
38
/// Request data captured on the server side by [`ServerPropagator`] and read
/// back by [`ClientPropagator`] when making outgoing calls from the same task.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct PropagatorContext {
    /// Captured request headers: header name (as captured) to all of its values.
    pub request_headers: HashMap<String, Vec<String>>,
    /// Captured request-payload metadata entries, keyed by metadata key.
    pub metadata: HashMap<String, serde_json::Value>,
}
50
51impl PropagatorContext {
52 #[must_use]
54 pub fn current() -> Option<Self> {
55 PROPAGATOR_CTX
56 .try_with(|cell| cell.borrow().clone())
57 .ok()
58 .flatten()
59 }
60
61 pub fn install(self) -> bool {
66 PROPAGATOR_CTX
67 .try_with(|cell| {
68 *cell.borrow_mut() = Some(self);
69 })
70 .is_ok()
71 }
72
73 pub async fn scope<F: Future>(self, f: F) -> F::Output {
78 PROPAGATOR_CTX.scope(RefCell::new(Some(self)), f).await
79 }
80}
81
82pub async fn init_propagation<F: Future>(f: F) -> F::Output {
98 PROPAGATOR_CTX.scope(RefCell::new(None), f).await
99}
100
/// Decides whether a request-metadata entry is captured by [`ServerPropagator`].
///
/// Called with the extension URIs requested for the call and a metadata key.
/// Returns `true` to capture that entry.
pub type ServerMetadataPredicate = Arc<dyn Fn(&[String], &str) -> bool + Send + Sync>;
106
/// Decides whether an incoming header is captured by [`ServerPropagator`].
///
/// Called with the header name. Returns `true` to capture all of that
/// header's values.
pub type ServerHeaderPredicate = Arc<dyn Fn(&str) -> bool + Send + Sync>;
111
/// Optional overrides for [`ServerPropagator`]'s capture predicates.
///
/// Any predicate left as `None` is replaced with its default by
/// [`ServerPropagator::with_config`].
#[derive(Default)]
#[non_exhaustive]
pub struct ServerPropagatorConfig {
    /// Custom metadata filter. `None` selects the default, which captures a
    /// metadata key only if it exactly equals one of the requested extension URIs.
    pub metadata_predicate: Option<ServerMetadataPredicate>,
    /// Custom header filter. `None` selects the default, which captures only the
    /// header whose name equals `EXTENSIONS_META_KEY`, ignoring ASCII case.
    pub header_predicate: Option<ServerHeaderPredicate>,
}
127
128impl std::fmt::Debug for ServerPropagatorConfig {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("ServerPropagatorConfig")
131 .field("metadata_predicate", &self.metadata_predicate.is_some())
132 .field("header_predicate", &self.header_predicate.is_some())
133 .finish()
134 }
135}
136
/// Server-side interceptor that captures request data into a task-local
/// [`PropagatorContext`].
///
/// [`ClientPropagator`] later reads that context on outgoing calls.
/// Construct with [`ServerPropagator::new`] or [`ServerPropagator::with_config`].
pub struct ServerPropagator {
    /// Selects which request-payload metadata entries are captured.
    metadata_predicate: ServerMetadataPredicate,
    /// Selects which incoming headers are captured.
    header_predicate: ServerHeaderPredicate,
}
151
152impl ServerPropagator {
153 pub fn new() -> Self {
159 Self::with_config(ServerPropagatorConfig::default())
160 }
161
162 pub fn with_config(config: ServerPropagatorConfig) -> Self {
164 let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
165 Arc::new(|requested_uris: &[String], key: &str| requested_uris.iter().any(|u| u == key))
166 });
167
168 let header_predicate = config
169 .header_predicate
170 .unwrap_or_else(|| Arc::new(|key: &str| key.eq_ignore_ascii_case(EXTENSIONS_META_KEY)));
171
172 Self {
173 metadata_predicate,
174 header_predicate,
175 }
176 }
177}
178
179impl Default for ServerPropagator {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185impl std::fmt::Debug for ServerPropagator {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("ServerPropagator").finish_non_exhaustive()
188 }
189}
190
191impl ra2a::server::CallInterceptor for ServerPropagator {
192 fn before<'a>(
193 &'a self,
194 ctx: &'a mut ra2a::server::CallContext,
195 req: &'a mut ra2a::server::Request,
196 ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
197 Box::pin(async move {
198 let mut prop_ctx = PropagatorContext::default();
199
200 let requested = ctx.requested_extension_uris();
202
203 extract_metadata(
205 req,
206 &requested,
207 &self.metadata_predicate,
208 &mut prop_ctx.metadata,
209 );
210
211 let request_meta = ctx.request_meta();
213 for (header_name, header_values) in request_meta.iter() {
214 if (self.header_predicate)(header_name) {
215 prop_ctx
216 .request_headers
217 .insert(header_name.to_owned(), header_values.to_vec());
218 }
219 }
220
221 if let Some(ext_values) = prop_ctx.request_headers.get(EXTENSIONS_META_KEY) {
223 for uri in ext_values {
224 ctx.activate_extension(uri);
225 }
226 }
227
228 prop_ctx.install();
230
231 Ok(())
232 })
233 }
234
235 fn after<'a>(
236 &'a self,
237 _ctx: &'a ra2a::server::CallContext,
238 _resp: &'a mut ra2a::server::Response,
239 ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
240 Box::pin(async { Ok(()) })
241 }
242}
243
244fn extract_metadata(
246 req: &ra2a::server::Request,
247 requested: &[String],
248 predicate: &ServerMetadataPredicate,
249 out: &mut HashMap<String, serde_json::Value>,
250) {
251 if let Some(params) = req.downcast_ref::<ra2a::MessageSendParams>() {
253 collect_matching_metadata(¶ms.metadata, requested, predicate, out);
254 } else if let Some(params) = req.downcast_ref::<ra2a::TaskQueryParams>() {
255 collect_matching_metadata(¶ms.metadata, requested, predicate, out);
256 } else if let Some(params) = req.downcast_ref::<ra2a::TaskIdParams>() {
257 collect_matching_metadata(¶ms.metadata, requested, predicate, out);
258 }
259}
260
261fn collect_matching_metadata(
263 metadata: &ra2a::Metadata,
264 requested: &[String],
265 predicate: &ServerMetadataPredicate,
266 out: &mut HashMap<String, serde_json::Value>,
267) {
268 for (k, v) in metadata {
269 if predicate(requested, k) {
270 out.insert(k.clone(), v.clone());
271 }
272 }
273}
274
/// Decides whether a captured metadata entry is injected into an outgoing
/// request by [`ClientPropagator`].
///
/// Called with:
/// - the target agent's card, if known;
/// - the requested extension URIs taken from the captured extensions header;
/// - the metadata key.
///
/// Returns `true` to inject the entry.
pub type ClientMetadataPredicate =
    Arc<dyn Fn(Option<&AgentCard>, &[String], &str) -> bool + Send + Sync>;
281
/// Decides whether a captured header value is forwarded by [`ClientPropagator`].
///
/// Called once per value, with the target agent's card (if known), the header
/// name, and that single value. Returns `true` to append the value to the
/// outgoing request metadata.
pub type ClientHeaderPredicate = Arc<dyn Fn(Option<&AgentCard>, &str, &str) -> bool + Send + Sync>;
287
/// Optional overrides for [`ClientPropagator`]'s forwarding predicates.
///
/// Any predicate left as `None` is replaced with its default by
/// [`ClientPropagator::with_config`].
#[derive(Default)]
#[non_exhaustive]
pub struct ClientPropagatorConfig {
    /// Custom metadata filter. `None` selects the default, which injects a key
    /// only if it is one of the requested extension URIs and
    /// `is_extension_supported(card, key)` holds.
    pub metadata_predicate: Option<ClientMetadataPredicate>,
    /// Custom header filter. `None` selects the default, which forwards only
    /// values of the header named `EXTENSIONS_META_KEY` (ignoring ASCII case)
    /// for which `is_extension_supported(card, value)` holds.
    pub header_predicate: Option<ClientHeaderPredicate>,
}
303
304impl std::fmt::Debug for ClientPropagatorConfig {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 f.debug_struct("ClientPropagatorConfig")
307 .field("metadata_predicate", &self.metadata_predicate.is_some())
308 .field("header_predicate", &self.header_predicate.is_some())
309 .finish()
310 }
311}
312
/// Client-side interceptor that forwards the task-local [`PropagatorContext`]
/// (as captured by [`ServerPropagator`]) onto outgoing requests.
///
/// Data is forwarded both as request metadata entries and as headers.
/// Construct with [`ClientPropagator::new`] or [`ClientPropagator::with_config`].
pub struct ClientPropagator {
    /// Selects which captured metadata entries are injected into the payload.
    metadata_predicate: ClientMetadataPredicate,
    /// Selects which captured header values are appended to the request.
    header_predicate: ClientHeaderPredicate,
}
326
327impl ClientPropagator {
328 pub fn new() -> Self {
330 Self::with_config(ClientPropagatorConfig::default())
331 }
332
333 pub fn with_config(config: ClientPropagatorConfig) -> Self {
335 let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
336 Arc::new(
337 |card: Option<&AgentCard>, requested: &[String], key: &str| {
338 if !requested.iter().any(|u| u == key) {
339 return false;
340 }
341 is_extension_supported(card, key)
342 },
343 )
344 });
345
346 let header_predicate = config.header_predicate.unwrap_or_else(|| {
347 Arc::new(|card: Option<&AgentCard>, key: &str, val: &str| {
348 if !key.eq_ignore_ascii_case(EXTENSIONS_META_KEY) {
349 return false;
350 }
351 is_extension_supported(card, val)
352 })
353 });
354
355 Self {
356 metadata_predicate,
357 header_predicate,
358 }
359 }
360}
361
362impl Default for ClientPropagator {
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368impl std::fmt::Debug for ClientPropagator {
369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 f.debug_struct("ClientPropagator").finish_non_exhaustive()
371 }
372}
373
374impl ra2a::client::CallInterceptor for ClientPropagator {
375 fn before<'a>(
376 &'a self,
377 req: &'a mut ra2a::client::Request,
378 ) -> Pin<Box<dyn Future<Output = ra2a::error::Result<()>> + Send + 'a>> {
379 Box::pin(async move {
380 let Some(prop_ctx) = PropagatorContext::current() else {
381 return Ok(());
382 };
383
384 let requested: Vec<String> = prop_ctx
386 .request_headers
387 .get(EXTENSIONS_META_KEY)
388 .cloned()
389 .unwrap_or_default();
390
391 if !prop_ctx.metadata.is_empty() {
393 inject_metadata(
394 &mut *req.payload,
395 &prop_ctx.metadata,
396 req.card.as_ref(),
397 &requested,
398 &self.metadata_predicate,
399 );
400 }
401
402 for (header_name, header_values) in &prop_ctx.request_headers {
404 for header_value in header_values {
405 if (self.header_predicate)(req.card.as_ref(), header_name, header_value) {
406 req.meta.append(header_name, header_value);
407 }
408 }
409 }
410
411 Ok(())
412 })
413 }
414}
415
416fn inject_metadata(
418 payload: &mut dyn std::any::Any,
419 metadata: &HashMap<String, serde_json::Value>,
420 card: Option<&AgentCard>,
421 requested: &[String],
422 predicate: &ClientMetadataPredicate,
423) {
424 if let Some(params) = payload.downcast_mut::<ra2a::MessageSendParams>() {
425 inject_matching_metadata(&mut params.metadata, metadata, card, requested, predicate);
426 } else if let Some(params) = payload.downcast_mut::<ra2a::TaskQueryParams>() {
427 inject_matching_metadata(&mut params.metadata, metadata, card, requested, predicate);
428 } else if let Some(params) = payload.downcast_mut::<ra2a::TaskIdParams>() {
429 inject_matching_metadata(&mut params.metadata, metadata, card, requested, predicate);
430 }
431}
432
433fn inject_matching_metadata(
435 target: &mut ra2a::Metadata,
436 source: &HashMap<String, serde_json::Value>,
437 card: Option<&AgentCard>,
438 requested: &[String],
439 predicate: &ClientMetadataPredicate,
440) {
441 for (k, v) in source {
442 if predicate(card, requested, k) {
443 target.insert(k.clone(), v.clone());
444 }
445 }
446}
447
#[cfg(test)]
mod tests {
    use ra2a::client::{CallInterceptor as _, CallMeta};
    use ra2a::types::{AgentCapabilities, AgentCard, AgentExtension};

    use super::*;

    const DURATION_EXT: &str = "urn:a2a:ext:duration";

    /// Builds an agent card whose capabilities advertise exactly `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        let extensions = uris
            .iter()
            .map(|&uri| AgentExtension {
                uri: uri.into(),
                description: String::new(),
                required: false,
                params: Default::default(),
            })
            .collect();

        AgentCard {
            name: "test".into(),
            url: "https://example.com".into(),
            version: "1.0".into(),
            capabilities: AgentCapabilities {
                extensions,
                ..AgentCapabilities::default()
            },
            skills: vec![],
            ..AgentCard::default()
        }
    }

    /// Builds a `message/send` request with an empty payload aimed at `card`.
    fn make_request(card: Option<AgentCard>) -> ra2a::client::Request {
        ra2a::client::Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta: CallMeta::default(),
            card,
            payload: Box::new(()),
        }
    }

    /// Builds a propagation context whose extensions header requests `uri`.
    fn context_requesting(uri: &str) -> PropagatorContext {
        let mut ctx = PropagatorContext::default();
        ctx.request_headers
            .insert(EXTENSIONS_META_KEY.to_owned(), vec![uri.to_owned()]);
        ctx
    }

    #[tokio::test]
    async fn test_client_propagator_injects_headers() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(Some(make_card(&[DURATION_EXT])));

        context_requesting(DURATION_EXT)
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        assert_eq!(req.meta.get_all(EXTENSIONS_META_KEY), &[DURATION_EXT]);
    }

    #[tokio::test]
    async fn test_client_propagator_filters_unsupported() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(Some(make_card(&["urn:a2a:ext:other"])));

        context_requesting(DURATION_EXT)
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        assert!(req.meta.get_all(EXTENSIONS_META_KEY).is_empty());
    }

    #[tokio::test]
    async fn test_client_propagator_no_context_is_noop() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(None);

        propagator.before(&mut req).await.unwrap();

        assert!(req.meta.is_empty());
    }

    #[tokio::test]
    async fn test_propagator_context_install_and_read() {
        let ctx = PropagatorContext {
            request_headers: HashMap::from([("x-test".to_owned(), vec!["val1".to_owned()])]),
            metadata: HashMap::new(),
        };

        init_propagation(async {
            assert!(PropagatorContext::current().is_none());
            assert!(ctx.install());
            let read = PropagatorContext::current().unwrap();
            assert_eq!(
                read.request_headers.get("x-test").unwrap(),
                &["val1".to_owned()]
            );
        })
        .await;
    }
}