Skip to main content

ra2a_ext/
propagator.rs

1//! Extension propagation interceptors for agent-to-agent chaining.
2//!
3//! Aligned with Go's `a2aext.NewServerPropagator` and `a2aext.NewClientPropagator`.
4//!
5//! When an agent (B) acts as both server and client in a chain (A → B → C),
6//! the [`ServerPropagator`] extracts extension-related metadata and headers
7//! from the incoming request (A → B), stores them in a [`PropagatorContext`],
8//! and the [`ClientPropagator`] injects them into the outgoing request (B → C).
9//!
10//! ## Data flow
11//!
12//! ```text
13//! A → [HTTP] → B (ServerPropagator.before extracts → PropagatorContext)
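//!                   → handler runs inside init_propagation() (task-local cell)
//!                     → B calls C (ClientPropagator.before injects from task_local)
//! ```
//!
//! The request handler must run inside [`init_propagation()`] so that the
//! [`ServerPropagator`] has a task-local cell to store the extracted data in and
//! the [`ClientPropagator`] can read it back on downstream calls. When a
//! [`PropagatorContext`] is already in hand, [`PropagatorContext::scope()`] can
//! instead be used to make it available to downstream calls.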
20
21use std::cell::RefCell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
27use ra2a::SVC_PARAM_EXTENSIONS;
28use ra2a::error::A2AError;
29use ra2a::types::AgentCard;
30
31use crate::util::is_extension_supported;
32
tokio::task_local! {
    /// Per-task cell holding the [`PropagatorContext`] shared between the
    /// server-side and client-side propagators.
    ///
    /// The cell only exists while a future runs inside [`init_propagation`]
    /// (which starts it as `None`) or [`PropagatorContext::scope`] (which starts
    /// it with a context). Outside those scopes, `try_with` fails, so
    /// [`ServerPropagator`] cannot store data and [`ClientPropagator`] sees none.
    ///
    /// A `RefCell` is used so `PropagatorContext::install` can replace the value
    /// in place through the shared task-local reference.
    static PROPAGATOR_CTX: RefCell<Option<PropagatorContext>>;
}
38
39/// Extension data extracted by [`ServerPropagator`] for downstream propagation.
40///
41/// Aligned with Go's internal `propagatorContext` struct.
42#[derive(Debug, Clone, Default)]
43#[non_exhaustive]
44pub struct PropagatorContext {
45    /// HTTP headers to propagate (key → values).
46    pub request_headers: HashMap<String, Vec<String>>,
47    /// Payload metadata to propagate (key → value).
48    pub metadata: HashMap<String, serde_json::Value>,
49}
50
51impl PropagatorContext {
52    /// Reads the current task-local propagator context, if set.
53    #[must_use]
54    pub fn current() -> Option<Self> {
55        PROPAGATOR_CTX
56            .try_with(|cell| cell.borrow().clone())
57            .ok()
58            .flatten()
59    }
60
61    /// Stores this context in the task-local cell.
62    ///
63    /// Requires that the current task is running within [`init_propagation`].
64    /// Returns `true` if stored successfully.
65    #[must_use]
66    pub fn install(self) -> bool {
67        PROPAGATOR_CTX
68            .try_with(|cell| {
69                *cell.borrow_mut() = Some(self);
70            })
71            .is_ok()
72    }
73
74    /// Executes a future with this context directly available via task-local.
75    ///
76    /// This is a convenience wrapper for simple cases where you already have
77    /// the context and want to make it available to [`ClientPropagator`].
78    pub async fn scope<F: Future>(self, f: F) -> F::Output {
79        PROPAGATOR_CTX.scope(RefCell::new(Some(self)), f).await
80    }
81}
82
83/// Wraps a future with an empty propagation scope.
84///
85/// Call this around your request handler so that [`ServerPropagator`] can store
86/// extracted data and [`ClientPropagator`] can read it later.
87///
88/// # Example
89///
90/// ```rust,ignore
91/// let result = ra2a_ext::init_propagation(async {
92///     // ServerPropagator.before() stores data here
93///     // handler runs
94///     // ClientPropagator.before() reads data here
95///     handle_request(req).await
96/// }).await;
97/// ```
98pub async fn init_propagation<F: Future>(f: F) -> F::Output {
99    PROPAGATOR_CTX.scope(RefCell::new(None), f).await
100}
101
/// Predicate function for filtering metadata keys on the server side.
///
/// Receives the extension URIs requested by the caller (as returned by the
/// call context's `requested_extension_uris`) and one metadata key from the
/// incoming payload. Returns `true` if the key should be propagated.
///
/// Crate-private as a name, but it appears in the public
/// `ServerPropagatorConfig` fields; callers supply an `Arc` of a closure with
/// this signature.
pub(crate) type ServerMetadataPredicate = Arc<dyn Fn(&[String], &str) -> bool + Send + Sync>;

/// Predicate function for filtering request headers on the server side.
///
/// Receives the header name. Returns `true` if the header, together with all
/// of its values, should be propagated.
pub(crate) type ServerHeaderPredicate = Arc<dyn Fn(&str) -> bool + Send + Sync>;
112
113/// Configuration for [`ServerPropagator`].
114///
115/// Both predicates are optional — sensible defaults are used when `None`.
116#[derive(Default)]
117#[non_exhaustive]
118pub struct ServerPropagatorConfig {
119    /// Determines which payload metadata keys are propagated.
120    ///
121    /// Default: propagate keys whose name matches a client-requested extension URI.
122    pub metadata_predicate: Option<ServerMetadataPredicate>,
123    /// Determines which request headers are propagated.
124    ///
125    /// Default: propagate only the `x-a2a-extensions` header.
126    pub header_predicate: Option<ServerHeaderPredicate>,
127}
128
129impl std::fmt::Debug for ServerPropagatorConfig {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("ServerPropagatorConfig")
132            .field("metadata_predicate", &self.metadata_predicate.is_some())
133            .field("header_predicate", &self.header_predicate.is_some())
134            .finish()
135    }
136}
137
138/// Server-side [`CallInterceptor`](ra2a::server::CallInterceptor) that extracts
139/// extension-related metadata and headers from incoming requests.
140///
141/// The extracted data is stored in a [`PropagatorContext`] via `task_local`.
142/// The handler must be wrapped in [`init_propagation`] for this to work.
143/// [`ClientPropagator`] reads the stored context when making downstream calls.
144///
145/// Aligned with Go's `a2aext.NewServerPropagator`.
146pub struct ServerPropagator {
147    /// Metadata filter predicate.
148    metadata_predicate: ServerMetadataPredicate,
149    /// Header filter predicate.
150    header_predicate: ServerHeaderPredicate,
151}
152
153impl ServerPropagator {
154    /// Creates a new server propagator with default configuration.
155    ///
156    /// Default behavior:
157    /// - Propagates metadata keys matching client-requested extension URIs
158    /// - Propagates the `x-a2a-extensions` header
159    #[must_use]
160    pub fn new() -> Self {
161        Self::with_config(ServerPropagatorConfig::default())
162    }
163
164    /// Creates a new server propagator with custom configuration.
165    #[must_use]
166    pub fn with_config(config: ServerPropagatorConfig) -> Self {
167        let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
168            Arc::new(|requested_uris: &[String], key: &str| requested_uris.iter().any(|u| u == key))
169        });
170
171        let header_predicate = config.header_predicate.unwrap_or_else(|| {
172            Arc::new(|key: &str| key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS))
173        });
174
175        Self {
176            metadata_predicate,
177            header_predicate,
178        }
179    }
180}
181
182impl Default for ServerPropagator {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl std::fmt::Debug for ServerPropagator {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("ServerPropagator").finish_non_exhaustive()
191    }
192}
193
194impl ServerPropagator {
195    /// Extracts extension metadata from the incoming request into the call context.
196    fn propagate_server(&self, ctx: &mut ra2a::server::CallContext, req: &ra2a::server::Request) {
197        let mut prop_ctx = PropagatorContext::default();
198
199        let requested = ctx.requested_extension_uris();
200
201        extract_metadata(
202            req,
203            &requested,
204            &self.metadata_predicate,
205            &mut prop_ctx.metadata,
206        );
207
208        let request_meta = ctx.request_meta();
209        for (header_name, header_values) in request_meta.iter() {
210            if (self.header_predicate)(header_name) {
211                prop_ctx
212                    .request_headers
213                    .insert(header_name.to_owned(), header_values.to_vec());
214            }
215        }
216
217        if let Some(ext_values) = prop_ctx.request_headers.get(SVC_PARAM_EXTENSIONS) {
218            for uri in ext_values {
219                ctx.activate_extension(uri);
220            }
221        }
222
223        // Best-effort install; fails silently if not inside init_propagation.
224        let _installed = prop_ctx.install();
225    }
226}
227
228impl ra2a::server::CallInterceptor for ServerPropagator {
229    fn before<'a>(
230        &'a self,
231        ctx: &'a mut ra2a::server::CallContext,
232        req: &'a mut ra2a::server::Request,
233    ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
234        self.propagate_server(ctx, req);
235        Box::pin(std::future::ready(Ok(())))
236    }
237
238    fn after<'a>(
239        &'a self,
240        _ctx: &'a ra2a::server::CallContext,
241        _resp: &'a mut ra2a::server::Response,
242    ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
243        Box::pin(async { Ok(()) })
244    }
245}
246
247/// Extracts matching metadata from known request payload types.
248fn extract_metadata(
249    req: &ra2a::server::Request,
250    requested: &[String],
251    predicate: &ServerMetadataPredicate,
252    out: &mut HashMap<String, serde_json::Value>,
253) {
254    if let Some(params) = req.downcast_ref::<ra2a::SendMessageRequest>()
255        && let Some(ref meta) = params.metadata
256    {
257        collect_matching_metadata(meta, requested, predicate, out);
258    }
259}
260
261/// Collects metadata entries that pass the predicate.
262fn collect_matching_metadata(
263    metadata: &ra2a::Metadata,
264    requested: &[String],
265    predicate: &ServerMetadataPredicate,
266    out: &mut HashMap<String, serde_json::Value>,
267) {
268    for (k, v) in metadata {
269        if predicate(requested, k) {
270            out.insert(k.clone(), v.clone());
271        }
272    }
273}
274
/// Predicate function for filtering metadata keys on the client side.
///
/// Receives the target server's agent card (if available), the extension URIs
/// listed in the propagated extensions header, and one metadata key. Returns
/// `true` if the key should be injected into the outgoing payload.
pub(crate) type ClientMetadataPredicate =
    Arc<dyn Fn(Option<&AgentCard>, &[String], &str) -> bool + Send + Sync>;

/// Predicate function for filtering request headers on the client side.
///
/// Receives the target server's agent card (if available), the header name
/// and a single header value. It is invoked once per (name, value) pair.
/// Returns `true` if that value should be forwarded.
pub(crate) type ClientHeaderPredicate =
    Arc<dyn Fn(Option<&AgentCard>, &str, &str) -> bool + Send + Sync>;
288
289/// Configuration for [`ClientPropagator`].
290#[derive(Default)]
291#[non_exhaustive]
292pub struct ClientPropagatorConfig {
293    /// Determines which payload metadata keys are propagated.
294    ///
295    /// Default: propagate keys that are requested extensions and supported by
296    /// the downstream server.
297    pub metadata_predicate: Option<ClientMetadataPredicate>,
298    /// Determines which request headers are propagated.
299    ///
300    /// Default: propagate `x-a2a-extensions` header values for extensions
301    /// supported by the downstream server.
302    pub header_predicate: Option<ClientHeaderPredicate>,
303}
304
305impl std::fmt::Debug for ClientPropagatorConfig {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        f.debug_struct("ClientPropagatorConfig")
308            .field("metadata_predicate", &self.metadata_predicate.is_some())
309            .field("header_predicate", &self.header_predicate.is_some())
310            .finish()
311    }
312}
313
314/// Client-side [`CallInterceptor`](ra2a::client::CallInterceptor) that injects
315/// propagated extension data into outgoing requests.
316///
317/// Reads [`PropagatorContext`] from the task-local (set by [`ServerPropagator`])
318/// and injects matching metadata and headers into the outgoing request.
319///
320/// Aligned with Go's `a2aext.NewClientPropagator`.
321pub struct ClientPropagator {
322    /// Metadata filter predicate.
323    metadata_predicate: ClientMetadataPredicate,
324    /// Header filter predicate.
325    header_predicate: ClientHeaderPredicate,
326}
327
328impl ClientPropagator {
329    /// Creates a new client propagator with default configuration.
330    #[must_use]
331    pub fn new() -> Self {
332        Self::with_config(ClientPropagatorConfig::default())
333    }
334
335    /// Creates a new client propagator with custom configuration.
336    #[must_use]
337    pub fn with_config(config: ClientPropagatorConfig) -> Self {
338        let metadata_predicate = config
339            .metadata_predicate
340            .unwrap_or_else(|| Arc::new(default_client_metadata_predicate));
341
342        let header_predicate = config
343            .header_predicate
344            .unwrap_or_else(|| Arc::new(default_client_header_predicate));
345
346        Self {
347            metadata_predicate,
348            header_predicate,
349        }
350    }
351}
352
353impl Default for ClientPropagator {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359impl std::fmt::Debug for ClientPropagator {
360    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        f.debug_struct("ClientPropagator").finish_non_exhaustive()
362    }
363}
364
365impl ClientPropagator {
366    /// Injects extension metadata from the current propagator context into the outgoing request.
367    fn propagate_client(&self, req: &mut ra2a::client::Request) {
368        let Some(prop_ctx) = PropagatorContext::current() else {
369            return;
370        };
371
372        let requested: Vec<String> = prop_ctx
373            .request_headers
374            .get(SVC_PARAM_EXTENSIONS)
375            .cloned()
376            .unwrap_or_default();
377
378        if !prop_ctx.metadata.is_empty() {
379            inject_metadata(
380                &mut *req.payload,
381                &prop_ctx.metadata,
382                req.card.as_ref(),
383                &requested,
384                &self.metadata_predicate,
385            );
386        }
387
388        for (name, val) in prop_ctx
389            .request_headers
390            .iter()
391            .flat_map(|(k, vs)| vs.iter().map(move |v| (k, v)))
392        {
393            if (self.header_predicate)(req.card.as_ref(), name, val) {
394                req.service_params.append(name, val);
395            }
396        }
397    }
398}
399
400impl ra2a::client::CallInterceptor for ClientPropagator {
401    fn before<'a>(
402        &'a self,
403        req: &'a mut ra2a::client::Request,
404    ) -> Pin<Box<dyn Future<Output = ra2a::error::Result<()>> + Send + 'a>> {
405        self.propagate_client(req);
406        Box::pin(std::future::ready(Ok(())))
407    }
408}
409
410/// Default metadata predicate for [`ClientPropagator`]: propagates keys that
411/// are in the requested list and supported by the downstream agent card.
412fn default_client_metadata_predicate(
413    card: Option<&AgentCard>,
414    requested: &[String],
415    key: &str,
416) -> bool {
417    requested.iter().any(|u| u == key) && is_extension_supported(card, key)
418}
419
420/// Default header predicate for [`ClientPropagator`]: propagates
421/// `x-a2a-extensions` header values for extensions supported by the downstream agent.
422fn default_client_header_predicate(card: Option<&AgentCard>, key: &str, val: &str) -> bool {
423    key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS) && is_extension_supported(card, val)
424}
425
426/// Injects matching metadata into known outgoing payload types.
427fn inject_metadata(
428    payload: &mut dyn std::any::Any,
429    metadata: &HashMap<String, serde_json::Value>,
430    card: Option<&AgentCard>,
431    requested: &[String],
432    predicate: &ClientMetadataPredicate,
433) {
434    if let Some(params) = payload.downcast_mut::<ra2a::SendMessageRequest>() {
435        let meta = params.metadata.get_or_insert_with(Default::default);
436        inject_matching_metadata(meta, metadata, card, requested, predicate);
437    }
438}
439
440/// Inserts metadata entries that pass the predicate into the target map.
441fn inject_matching_metadata(
442    target: &mut ra2a::Metadata,
443    source: &HashMap<String, serde_json::Value>,
444    card: Option<&AgentCard>,
445    requested: &[String],
446    predicate: &ClientMetadataPredicate,
447) {
448    for (k, v) in source {
449        if predicate(card, requested, k) {
450            target.insert(k.clone(), v.clone());
451        }
452    }
453}
454
#[cfg(test)]
#[allow(clippy::unwrap_used, reason = "tests use unwrap for brevity")]
mod tests {
    use ra2a::client::{CallInterceptor as _, ServiceParams};
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    /// Builds a minimal agent card advertising exactly `uris` as extensions.
    fn make_card(uris: &[&str]) -> AgentCard {
        let interface =
            AgentInterface::new("https://example.com", TransportProtocol::new("JSONRPC"));
        let mut card = AgentCard::new("test", "test agent", vec![interface]);
        card.capabilities = AgentCapabilities {
            extensions: uris
                .iter()
                .map(|uri| AgentExtension {
                    uri: (*uri).into(),
                    description: None,
                    required: false,
                    params: None,
                })
                .collect(),
            ..AgentCapabilities::default()
        };
        card
    }

    /// Builds an outgoing `message/send` request with a unit payload.
    fn make_request(card: Option<AgentCard>) -> ra2a::client::Request {
        ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card,
            payload: Box::new(()),
        }
    }

    /// Builds a propagator context whose extensions header lists only `uri`.
    fn context_requesting(uri: &str) -> PropagatorContext {
        let mut prop_ctx = PropagatorContext::default();
        prop_ctx
            .request_headers
            .insert(SVC_PARAM_EXTENSIONS.to_owned(), vec![uri.to_owned()]);
        prop_ctx
    }

    #[tokio::test]
    async fn test_client_propagator_injects_headers() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(Some(make_card(&["urn:a2a:ext:duration"])));

        context_requesting("urn:a2a:ext:duration")
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let forwarded = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(forwarded, &["urn:a2a:ext:duration"]);
    }

    #[tokio::test]
    async fn test_client_propagator_filters_unsupported() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(Some(make_card(&["urn:a2a:ext:other"])));

        context_requesting("urn:a2a:ext:duration")
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let forwarded = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(forwarded.is_empty());
    }

    #[tokio::test]
    async fn test_client_propagator_no_context_is_noop() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(None);

        propagator.before(&mut req).await.unwrap();

        assert!(req.service_params.is_empty());
    }

    #[tokio::test]
    async fn test_propagator_context_install_and_read() {
        let mut request_headers = HashMap::new();
        request_headers.insert("x-test".to_owned(), vec!["val1".to_owned()]);
        let ctx = PropagatorContext {
            request_headers,
            metadata: HashMap::new(),
        };

        init_propagation(async {
            assert!(PropagatorContext::current().is_none());
            assert!(ctx.install());
            let read = PropagatorContext::current().unwrap();
            assert_eq!(
                read.request_headers.get("x-test").unwrap(),
                &["val1".to_owned()]
            );
        })
        .await;
    }
}