1use std::cell::RefCell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
27use ra2a::SVC_PARAM_EXTENSIONS;
28use ra2a::error::A2AError;
29use ra2a::types::AgentCard;
30
31use crate::util::is_extension_supported;
32
33tokio::task_local! {
34 static PROPAGATOR_CTX: RefCell<Option<PropagatorContext>>;
37}
38
39#[derive(Debug, Clone, Default)]
43#[non_exhaustive]
44pub struct PropagatorContext {
45 pub request_headers: HashMap<String, Vec<String>>,
47 pub metadata: HashMap<String, serde_json::Value>,
49}
50
51impl PropagatorContext {
52 #[must_use]
54 pub fn current() -> Option<Self> {
55 PROPAGATOR_CTX
56 .try_with(|cell| cell.borrow().clone())
57 .ok()
58 .flatten()
59 }
60
61 pub fn install(self) -> bool {
66 PROPAGATOR_CTX
67 .try_with(|cell| {
68 *cell.borrow_mut() = Some(self);
69 })
70 .is_ok()
71 }
72
73 pub async fn scope<F: Future>(self, f: F) -> F::Output {
78 PROPAGATOR_CTX.scope(RefCell::new(Some(self)), f).await
79 }
80}
81
82pub async fn init_propagation<F: Future>(f: F) -> F::Output {
98 PROPAGATOR_CTX.scope(RefCell::new(None), f).await
99}
100
/// Decides whether an incoming message-metadata entry is captured.
///
/// Invoked as `predicate(requested_extension_uris, metadata_key)`.
pub type ServerMetadataPredicate = Arc<dyn Fn(&[String], &str) -> bool + Sync + Send>;

/// Decides whether an incoming request header is captured.
///
/// Invoked with the header name; a `true` result captures all of that
/// header's values.
pub type ServerHeaderPredicate = Arc<dyn Fn(&str) -> bool + Sync + Send>;
111
112#[derive(Default)]
116#[non_exhaustive]
117pub struct ServerPropagatorConfig {
118 pub metadata_predicate: Option<ServerMetadataPredicate>,
122 pub header_predicate: Option<ServerHeaderPredicate>,
126}
127
128impl std::fmt::Debug for ServerPropagatorConfig {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("ServerPropagatorConfig")
131 .field("metadata_predicate", &self.metadata_predicate.is_some())
132 .field("header_predicate", &self.header_predicate.is_some())
133 .finish()
134 }
135}
136
137pub struct ServerPropagator {
146 metadata_predicate: ServerMetadataPredicate,
148 header_predicate: ServerHeaderPredicate,
150}
151
152impl ServerPropagator {
153 pub fn new() -> Self {
159 Self::with_config(ServerPropagatorConfig::default())
160 }
161
162 pub fn with_config(config: ServerPropagatorConfig) -> Self {
164 let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
165 Arc::new(|requested_uris: &[String], key: &str| requested_uris.iter().any(|u| u == key))
166 });
167
168 let header_predicate = config.header_predicate.unwrap_or_else(|| {
169 Arc::new(|key: &str| key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS))
170 });
171
172 Self {
173 metadata_predicate,
174 header_predicate,
175 }
176 }
177}
178
179impl Default for ServerPropagator {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185impl std::fmt::Debug for ServerPropagator {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("ServerPropagator").finish_non_exhaustive()
188 }
189}
190
191impl ra2a::server::CallInterceptor for ServerPropagator {
192 fn before<'a>(
193 &'a self,
194 ctx: &'a mut ra2a::server::CallContext,
195 req: &'a mut ra2a::server::Request,
196 ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
197 Box::pin(async move {
198 let mut prop_ctx = PropagatorContext::default();
199
200 let requested = ctx.requested_extension_uris();
202
203 extract_metadata(
205 req,
206 &requested,
207 &self.metadata_predicate,
208 &mut prop_ctx.metadata,
209 );
210
211 let request_meta = ctx.request_meta();
213 for (header_name, header_values) in request_meta.iter() {
214 if (self.header_predicate)(header_name) {
215 prop_ctx
216 .request_headers
217 .insert(header_name.to_owned(), header_values.to_vec());
218 }
219 }
220
221 if let Some(ext_values) = prop_ctx.request_headers.get(SVC_PARAM_EXTENSIONS) {
223 for uri in ext_values {
224 ctx.activate_extension(uri);
225 }
226 }
227
228 prop_ctx.install();
230
231 Ok(())
232 })
233 }
234
235 fn after<'a>(
236 &'a self,
237 _ctx: &'a ra2a::server::CallContext,
238 _resp: &'a mut ra2a::server::Response,
239 ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
240 Box::pin(async { Ok(()) })
241 }
242}
243
244fn extract_metadata(
246 req: &ra2a::server::Request,
247 requested: &[String],
248 predicate: &ServerMetadataPredicate,
249 out: &mut HashMap<String, serde_json::Value>,
250) {
251 if let Some(params) = req.downcast_ref::<ra2a::SendMessageRequest>()
252 && let Some(ref meta) = params.metadata
253 {
254 collect_matching_metadata(meta, requested, predicate, out);
255 }
256}
257
258fn collect_matching_metadata(
260 metadata: &ra2a::Metadata,
261 requested: &[String],
262 predicate: &ServerMetadataPredicate,
263 out: &mut HashMap<String, serde_json::Value>,
264) {
265 for (k, v) in metadata {
266 if predicate(requested, k) {
267 out.insert(k.clone(), v.clone());
268 }
269 }
270}
271
272pub type ClientMetadataPredicate =
277 Arc<dyn Fn(Option<&AgentCard>, &[String], &str) -> bool + Send + Sync>;
278
279pub type ClientHeaderPredicate = Arc<dyn Fn(Option<&AgentCard>, &str, &str) -> bool + Send + Sync>;
284
285#[derive(Default)]
287#[non_exhaustive]
288pub struct ClientPropagatorConfig {
289 pub metadata_predicate: Option<ClientMetadataPredicate>,
294 pub header_predicate: Option<ClientHeaderPredicate>,
299}
300
301impl std::fmt::Debug for ClientPropagatorConfig {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("ClientPropagatorConfig")
304 .field("metadata_predicate", &self.metadata_predicate.is_some())
305 .field("header_predicate", &self.header_predicate.is_some())
306 .finish()
307 }
308}
309
310pub struct ClientPropagator {
318 metadata_predicate: ClientMetadataPredicate,
320 header_predicate: ClientHeaderPredicate,
322}
323
324impl ClientPropagator {
325 pub fn new() -> Self {
327 Self::with_config(ClientPropagatorConfig::default())
328 }
329
330 pub fn with_config(config: ClientPropagatorConfig) -> Self {
332 let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
333 Arc::new(
334 |card: Option<&AgentCard>, requested: &[String], key: &str| {
335 if !requested.iter().any(|u| u == key) {
336 return false;
337 }
338 is_extension_supported(card, key)
339 },
340 )
341 });
342
343 let header_predicate = config.header_predicate.unwrap_or_else(|| {
344 Arc::new(|card: Option<&AgentCard>, key: &str, val: &str| {
345 if !key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS) {
346 return false;
347 }
348 is_extension_supported(card, val)
349 })
350 });
351
352 Self {
353 metadata_predicate,
354 header_predicate,
355 }
356 }
357}
358
359impl Default for ClientPropagator {
360 fn default() -> Self {
361 Self::new()
362 }
363}
364
365impl std::fmt::Debug for ClientPropagator {
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.debug_struct("ClientPropagator").finish_non_exhaustive()
368 }
369}
370
371impl ra2a::client::CallInterceptor for ClientPropagator {
372 fn before<'a>(
373 &'a self,
374 req: &'a mut ra2a::client::Request,
375 ) -> Pin<Box<dyn Future<Output = ra2a::error::Result<()>> + Send + 'a>> {
376 Box::pin(async move {
377 let Some(prop_ctx) = PropagatorContext::current() else {
378 return Ok(());
379 };
380
381 let requested: Vec<String> = prop_ctx
383 .request_headers
384 .get(SVC_PARAM_EXTENSIONS)
385 .cloned()
386 .unwrap_or_default();
387
388 if !prop_ctx.metadata.is_empty() {
390 inject_metadata(
391 &mut *req.payload,
392 &prop_ctx.metadata,
393 req.card.as_ref(),
394 &requested,
395 &self.metadata_predicate,
396 );
397 }
398
399 for (header_name, header_values) in &prop_ctx.request_headers {
401 for header_value in header_values {
402 if (self.header_predicate)(req.card.as_ref(), header_name, header_value) {
403 req.service_params.append(header_name, header_value);
404 }
405 }
406 }
407
408 Ok(())
409 })
410 }
411}
412
413fn inject_metadata(
415 payload: &mut dyn std::any::Any,
416 metadata: &HashMap<String, serde_json::Value>,
417 card: Option<&AgentCard>,
418 requested: &[String],
419 predicate: &ClientMetadataPredicate,
420) {
421 if let Some(params) = payload.downcast_mut::<ra2a::SendMessageRequest>() {
422 let meta = params.metadata.get_or_insert_with(Default::default);
423 inject_matching_metadata(meta, metadata, card, requested, predicate);
424 }
425}
426
427fn inject_matching_metadata(
429 target: &mut ra2a::Metadata,
430 source: &HashMap<String, serde_json::Value>,
431 card: Option<&AgentCard>,
432 requested: &[String],
433 predicate: &ClientMetadataPredicate,
434) {
435 for (k, v) in source {
436 if predicate(card, requested, k) {
437 target.insert(k.clone(), v.clone());
438 }
439 }
440}
441
#[cfg(test)]
// Unwrapping is the intended way for a test to fail on an unexpected `Err`/`None`.
#[allow(clippy::unwrap_used)]
mod tests {
    use ra2a::client::{CallInterceptor as _, ServiceParams};
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    /// Extension URI used throughout these tests.
    const DURATION_EXT: &str = "urn:a2a:ext:duration";

    /// Builds a minimal JSON-RPC agent card that advertises exactly `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        let extensions = uris
            .iter()
            .map(|&uri| AgentExtension {
                uri: uri.into(),
                description: None,
                required: false,
                params: None,
            })
            .collect();
        let interface =
            AgentInterface::new("https://example.com", TransportProtocol::new("JSONRPC"));

        let mut card = AgentCard::new("test", "test agent", vec![interface]);
        card.capabilities = AgentCapabilities {
            extensions,
            ..AgentCapabilities::default()
        };
        card
    }

    /// A propagation context whose extensions header lists only `uri`.
    fn extensions_ctx(uri: &str) -> PropagatorContext {
        let mut ctx = PropagatorContext::default();
        ctx.request_headers
            .insert(SVC_PARAM_EXTENSIONS.to_owned(), vec![uri.to_owned()]);
        ctx
    }

    #[tokio::test]
    async fn test_client_propagator_injects_headers() {
        let propagator = ClientPropagator::new();
        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card: Some(make_card(&[DURATION_EXT])),
            payload: Box::new(()),
        };

        extensions_ctx(DURATION_EXT)
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(vals, &[DURATION_EXT]);
    }

    #[tokio::test]
    async fn test_client_propagator_filters_unsupported() {
        let propagator = ClientPropagator::new();
        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card: Some(make_card(&["urn:a2a:ext:other"])),
            payload: Box::new(()),
        };

        extensions_ctx(DURATION_EXT)
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(vals.is_empty());
    }

    #[tokio::test]
    async fn test_client_propagator_no_context_is_noop() {
        let propagator = ClientPropagator::new();
        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card: None,
            payload: Box::new(()),
        };

        // Deliberately outside any propagation scope.
        propagator.before(&mut req).await.unwrap();
        assert!(req.service_params.is_empty());
    }

    #[tokio::test]
    async fn test_propagator_context_install_and_read() {
        let ctx = PropagatorContext {
            request_headers: HashMap::from([("x-test".to_owned(), vec!["val1".to_owned()])]),
            metadata: HashMap::new(),
        };

        init_propagation(async {
            assert!(PropagatorContext::current().is_none());
            assert!(ctx.install());
            let read = PropagatorContext::current().unwrap();
            assert_eq!(read.request_headers["x-test"], ["val1".to_owned()]);
        })
        .await;
    }
}