Skip to main content

ra2a_ext/
propagator.rs

1//! Extension propagation interceptors for agent-to-agent chaining.
2//!
3//! Aligned with Go's `a2aext.NewServerPropagator` and `a2aext.NewClientPropagator`.
4//!
5//! When an agent (B) acts as both server and client in a chain (A → B → C),
6//! the [`ServerPropagator`] extracts extension-related metadata and headers
7//! from the incoming request (A → B), stores them in a [`PropagatorContext`],
8//! and the [`ClientPropagator`] injects them into the outgoing request (B → C).
9//!
10//! ## Data flow
11//!
12//! ```text
13//! A → [HTTP] → B (ServerPropagator.before extracts → PropagatorContext)
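//!                   → stored in the task-local cell opened by init_propagation()
//!                     → B calls C (ClientPropagator.before injects from task_local)
//! ```
//!
//! The request handler must run inside [`init_propagation`] so that the
//! [`ServerPropagator`] has a task-local cell to store the extracted data in and
//! the [`ClientPropagator`] can read it back when B calls C. If the downstream
//! call runs on a different task (task-local values are not inherited by
//! spawned tasks), capture the data with [`PropagatorContext::current`] and
//! re-establish it there with [`PropagatorContext::scope`].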
20
21use std::cell::RefCell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
27use ra2a::EXTENSIONS_META_KEY;
28use ra2a::error::A2AError;
29use ra2a::types::AgentCard;
30
31use crate::util::is_extension_supported;
32
tokio::task_local! {
    /// Task-local slot holding the [`PropagatorContext`] for the current request.
    ///
    /// The slot only exists inside a scope opened by [`init_propagation`]
    /// (starts empty) or [`PropagatorContext::scope`] (starts pre-filled).
    /// [`ServerPropagator`] writes into it via [`PropagatorContext::install`];
    /// [`ClientPropagator`] reads it via [`PropagatorContext::current`].
    /// The `RefCell` allows the value to be replaced after the scope is entered.
    static PROPAGATOR_CTX: RefCell<Option<PropagatorContext>>;
}
38
/// Extension data extracted by [`ServerPropagator`] for downstream propagation.
///
/// Aligned with Go's internal `propagatorContext` struct.
///
/// Instances normally come from the server side, but a context can also be
/// built by hand (start from `Default` and fill the fields) and made available
/// to [`ClientPropagator`] with [`PropagatorContext::scope`].
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct PropagatorContext {
    /// HTTP headers to propagate (header name → every value received for it).
    ///
    /// [`ClientPropagator`] evaluates and forwards each value individually.
    pub request_headers: HashMap<String, Vec<String>>,
    /// Payload metadata to propagate (metadata key → JSON value).
    pub metadata: HashMap<String, serde_json::Value>,
}
50
51impl PropagatorContext {
52    /// Reads the current task-local propagator context, if set.
53    #[must_use]
54    pub fn current() -> Option<Self> {
55        PROPAGATOR_CTX
56            .try_with(|cell| cell.borrow().clone())
57            .ok()
58            .flatten()
59    }
60
61    /// Stores this context in the task-local cell.
62    ///
63    /// Requires that the current task is running within [`init_propagation`].
64    /// Returns `true` if stored successfully.
65    pub fn install(self) -> bool {
66        PROPAGATOR_CTX
67            .try_with(|cell| {
68                *cell.borrow_mut() = Some(self);
69            })
70            .is_ok()
71    }
72
73    /// Executes a future with this context directly available via task-local.
74    ///
75    /// This is a convenience wrapper for simple cases where you already have
76    /// the context and want to make it available to [`ClientPropagator`].
77    pub async fn scope<F: Future>(self, f: F) -> F::Output {
78        PROPAGATOR_CTX.scope(RefCell::new(Some(self)), f).await
79    }
80}
81
82/// Wraps a future with an empty propagation scope.
83///
84/// Call this around your request handler so that [`ServerPropagator`] can store
85/// extracted data and [`ClientPropagator`] can read it later.
86///
87/// # Example
88///
89/// ```rust,ignore
90/// let result = ra2a_ext::init_propagation(async {
91///     // ServerPropagator.before() stores data here
92///     // handler runs
93///     // ClientPropagator.before() reads data here
94///     handle_request(req).await
95/// }).await;
96/// ```
97pub async fn init_propagation<F: Future>(f: F) -> F::Output {
98    PROPAGATOR_CTX.scope(RefCell::new(None), f).await
99}
100
/// Server-side filter deciding which payload metadata keys are propagated.
///
/// Called with the extension URIs requested by the client and one metadata
/// key; returning `true` keeps that key.
pub type ServerMetadataPredicate = Arc<dyn Fn(&[String], &str) -> bool + Sync + Send>;

/// Server-side filter deciding which incoming request headers are propagated.
///
/// Called with a header name; returning `true` keeps that header.
pub type ServerHeaderPredicate = Arc<dyn Fn(&str) -> bool + Sync + Send>;
111
/// Configuration for [`ServerPropagator`].
///
/// Both predicates are optional — sensible defaults are used when `None`.
/// Because the type is `#[non_exhaustive]`, code outside this crate creates it
/// with `Default::default()` and then assigns the fields it wants to override.
#[derive(Default)]
#[non_exhaustive]
pub struct ServerPropagatorConfig {
    /// Determines which payload metadata keys are propagated.
    ///
    /// Default: propagate keys whose name matches a client-requested extension URI.
    pub metadata_predicate: Option<ServerMetadataPredicate>,
    /// Determines which request headers are propagated.
    ///
    /// Default: propagate only the `x-a2a-extensions` header (the header name is
    /// compared case-insensitively against `EXTENSIONS_META_KEY`).
    pub header_predicate: Option<ServerHeaderPredicate>,
}
127
128impl std::fmt::Debug for ServerPropagatorConfig {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("ServerPropagatorConfig")
131            .field("metadata_predicate", &self.metadata_predicate.is_some())
132            .field("header_predicate", &self.header_predicate.is_some())
133            .finish()
134    }
135}
136
/// Server-side [`CallInterceptor`](ra2a::server::CallInterceptor) that extracts
/// extension-related metadata and headers from incoming requests.
///
/// The extracted data is stored in a [`PropagatorContext`] via `task_local`.
/// The handler must be wrapped in [`init_propagation`] for this to work; when no
/// such scope is active, the extracted data is silently discarded.
/// [`ClientPropagator`] reads the stored context when making downstream calls.
///
/// Aligned with Go's `a2aext.NewServerPropagator`.
pub struct ServerPropagator {
    /// Metadata filter predicate (taken from the config, or the built-in default).
    metadata_predicate: ServerMetadataPredicate,
    /// Header filter predicate (taken from the config, or the built-in default).
    header_predicate: ServerHeaderPredicate,
}
151
152impl ServerPropagator {
153    /// Creates a new server propagator with default configuration.
154    ///
155    /// Default behavior:
156    /// - Propagates metadata keys matching client-requested extension URIs
157    /// - Propagates the `x-a2a-extensions` header
158    pub fn new() -> Self {
159        Self::with_config(ServerPropagatorConfig::default())
160    }
161
162    /// Creates a new server propagator with custom configuration.
163    pub fn with_config(config: ServerPropagatorConfig) -> Self {
164        let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
165            Arc::new(|requested_uris: &[String], key: &str| requested_uris.iter().any(|u| u == key))
166        });
167
168        let header_predicate = config
169            .header_predicate
170            .unwrap_or_else(|| Arc::new(|key: &str| key.eq_ignore_ascii_case(EXTENSIONS_META_KEY)));
171
172        Self {
173            metadata_predicate,
174            header_predicate,
175        }
176    }
177}
178
179impl Default for ServerPropagator {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185impl std::fmt::Debug for ServerPropagator {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("ServerPropagator").finish_non_exhaustive()
188    }
189}
190
191impl ra2a::server::CallInterceptor for ServerPropagator {
192    fn before<'a>(
193        &'a self,
194        ctx: &'a mut ra2a::server::CallContext,
195        req: &'a mut ra2a::server::Request,
196    ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
197        Box::pin(async move {
198            let mut prop_ctx = PropagatorContext::default();
199
200            // Collect requested extension URIs for the metadata predicate.
201            let requested = ctx.requested_extension_uris();
202
203            // Extract matching metadata from the request payload.
204            extract_metadata(
205                req,
206                &requested,
207                &self.metadata_predicate,
208                &mut prop_ctx.metadata,
209            );
210
211            // Extract matching headers from request metadata.
212            let request_meta = ctx.request_meta();
213            for (header_name, header_values) in request_meta.iter() {
214                if (self.header_predicate)(header_name) {
215                    prop_ctx
216                        .request_headers
217                        .insert(header_name.to_owned(), header_values.to_vec());
218                }
219            }
220
221            // Also activate extensions in the CallContext for downstream use.
222            if let Some(ext_values) = prop_ctx.request_headers.get(EXTENSIONS_META_KEY) {
223                for uri in ext_values {
224                    ctx.activate_extension(uri);
225                }
226            }
227
228            // Store in task-local (requires init_propagation wrapper).
229            prop_ctx.install();
230
231            Ok(())
232        })
233    }
234
235    fn after<'a>(
236        &'a self,
237        _ctx: &'a ra2a::server::CallContext,
238        _resp: &'a mut ra2a::server::Response,
239    ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
240        Box::pin(async { Ok(()) })
241    }
242}
243
244/// Extracts matching metadata from known request payload types.
245fn extract_metadata(
246    req: &ra2a::server::Request,
247    requested: &[String],
248    predicate: &ServerMetadataPredicate,
249    out: &mut HashMap<String, serde_json::Value>,
250) {
251    // Try each known param type that carries metadata.
252    if let Some(params) = req.downcast_ref::<ra2a::MessageSendParams>() {
253        collect_matching_metadata(&params.metadata, requested, predicate, out);
254    } else if let Some(params) = req.downcast_ref::<ra2a::TaskQueryParams>() {
255        collect_matching_metadata(&params.metadata, requested, predicate, out);
256    } else if let Some(params) = req.downcast_ref::<ra2a::TaskIdParams>() {
257        collect_matching_metadata(&params.metadata, requested, predicate, out);
258    }
259}
260
261/// Collects metadata entries that pass the predicate.
262fn collect_matching_metadata(
263    metadata: &ra2a::Metadata,
264    requested: &[String],
265    predicate: &ServerMetadataPredicate,
266    out: &mut HashMap<String, serde_json::Value>,
267) {
268    for (k, v) in metadata {
269        if predicate(requested, k) {
270            out.insert(k.clone(), v.clone());
271        }
272    }
273}
274
275/// Predicate function for filtering metadata keys on the client side.
276///
277/// Receives the target server's agent card (if available), the list of
278/// requested extension URIs, and the metadata key.
279pub type ClientMetadataPredicate =
280    Arc<dyn Fn(Option<&AgentCard>, &[String], &str) -> bool + Send + Sync>;
281
282/// Predicate function for filtering request headers on the client side.
283///
284/// Receives the target server's agent card (if available), the header key
285/// and value. Returns `true` if the header should be forwarded.
286pub type ClientHeaderPredicate = Arc<dyn Fn(Option<&AgentCard>, &str, &str) -> bool + Send + Sync>;
287
/// Configuration for [`ClientPropagator`].
///
/// Both predicates are optional — the built-in defaults are used when `None`.
#[derive(Default)]
#[non_exhaustive]
pub struct ClientPropagatorConfig {
    /// Determines which payload metadata keys are propagated.
    ///
    /// Default: propagate keys that are requested extensions and supported by
    /// the downstream server (as judged by `is_extension_supported` against the
    /// downstream agent card, if any).
    pub metadata_predicate: Option<ClientMetadataPredicate>,
    /// Determines which request headers are propagated.
    ///
    /// Default: propagate `x-a2a-extensions` header values (header name compared
    /// case-insensitively) for extensions supported by the downstream server.
    pub header_predicate: Option<ClientHeaderPredicate>,
}
303
304impl std::fmt::Debug for ClientPropagatorConfig {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        f.debug_struct("ClientPropagatorConfig")
307            .field("metadata_predicate", &self.metadata_predicate.is_some())
308            .field("header_predicate", &self.header_predicate.is_some())
309            .finish()
310    }
311}
312
/// Client-side [`CallInterceptor`](ra2a::client::CallInterceptor) that injects
/// propagated extension data into outgoing requests.
///
/// Reads [`PropagatorContext`] from the task-local (set by [`ServerPropagator`]
/// or provided via [`PropagatorContext::scope`]) and injects matching metadata
/// and headers into the outgoing request. When no context is available, the
/// request is left untouched.
///
/// Aligned with Go's `a2aext.NewClientPropagator`.
pub struct ClientPropagator {
    /// Metadata filter predicate (taken from the config, or the built-in default).
    metadata_predicate: ClientMetadataPredicate,
    /// Header filter predicate (taken from the config, or the built-in default).
    header_predicate: ClientHeaderPredicate,
}
326
327impl ClientPropagator {
328    /// Creates a new client propagator with default configuration.
329    pub fn new() -> Self {
330        Self::with_config(ClientPropagatorConfig::default())
331    }
332
333    /// Creates a new client propagator with custom configuration.
334    pub fn with_config(config: ClientPropagatorConfig) -> Self {
335        let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
336            Arc::new(
337                |card: Option<&AgentCard>, requested: &[String], key: &str| {
338                    if !requested.iter().any(|u| u == key) {
339                        return false;
340                    }
341                    is_extension_supported(card, key)
342                },
343            )
344        });
345
346        let header_predicate = config.header_predicate.unwrap_or_else(|| {
347            Arc::new(|card: Option<&AgentCard>, key: &str, val: &str| {
348                if !key.eq_ignore_ascii_case(EXTENSIONS_META_KEY) {
349                    return false;
350                }
351                is_extension_supported(card, val)
352            })
353        });
354
355        Self {
356            metadata_predicate,
357            header_predicate,
358        }
359    }
360}
361
362impl Default for ClientPropagator {
363    fn default() -> Self {
364        Self::new()
365    }
366}
367
368impl std::fmt::Debug for ClientPropagator {
369    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        f.debug_struct("ClientPropagator").finish_non_exhaustive()
371    }
372}
373
374impl ra2a::client::CallInterceptor for ClientPropagator {
375    fn before<'a>(
376        &'a self,
377        req: &'a mut ra2a::client::Request,
378    ) -> Pin<Box<dyn Future<Output = ra2a::error::Result<()>> + Send + 'a>> {
379        Box::pin(async move {
380            let Some(prop_ctx) = PropagatorContext::current() else {
381                return Ok(());
382            };
383
384            // Collect requested URIs from propagated headers for the predicate.
385            let requested: Vec<String> = prop_ctx
386                .request_headers
387                .get(EXTENSIONS_META_KEY)
388                .cloned()
389                .unwrap_or_default();
390
391            // Inject matching metadata into the outgoing payload.
392            if !prop_ctx.metadata.is_empty() {
393                inject_metadata(
394                    &mut *req.payload,
395                    &prop_ctx.metadata,
396                    req.card.as_ref(),
397                    &requested,
398                    &self.metadata_predicate,
399                );
400            }
401
402            // Inject matching headers.
403            for (header_name, header_values) in &prop_ctx.request_headers {
404                for header_value in header_values {
405                    if (self.header_predicate)(req.card.as_ref(), header_name, header_value) {
406                        req.meta.append(header_name, header_value);
407                    }
408                }
409            }
410
411            Ok(())
412        })
413    }
414}
415
416/// Injects matching metadata into known outgoing payload types.
417fn inject_metadata(
418    payload: &mut dyn std::any::Any,
419    metadata: &HashMap<String, serde_json::Value>,
420    card: Option<&AgentCard>,
421    requested: &[String],
422    predicate: &ClientMetadataPredicate,
423) {
424    if let Some(params) = payload.downcast_mut::<ra2a::MessageSendParams>() {
425        inject_matching_metadata(&mut params.metadata, metadata, card, requested, predicate);
426    } else if let Some(params) = payload.downcast_mut::<ra2a::TaskQueryParams>() {
427        inject_matching_metadata(&mut params.metadata, metadata, card, requested, predicate);
428    } else if let Some(params) = payload.downcast_mut::<ra2a::TaskIdParams>() {
429        inject_matching_metadata(&mut params.metadata, metadata, card, requested, predicate);
430    }
431}
432
433/// Inserts metadata entries that pass the predicate into the target map.
434fn inject_matching_metadata(
435    target: &mut ra2a::Metadata,
436    source: &HashMap<String, serde_json::Value>,
437    card: Option<&AgentCard>,
438    requested: &[String],
439    predicate: &ClientMetadataPredicate,
440) {
441    for (k, v) in source {
442        if predicate(card, requested, k) {
443            target.insert(k.clone(), v.clone());
444        }
445    }
446}
447
#[cfg(test)]
mod tests {
    use ra2a::client::{CallInterceptor as _, CallMeta};
    use ra2a::types::{AgentCapabilities, AgentCard, AgentExtension};

    use super::*;

    /// Builds a minimal agent card advertising the given extension URIs.
    fn make_card(uris: &[&str]) -> AgentCard {
        AgentCard {
            name: "test".into(),
            url: "https://example.com".into(),
            version: "1.0".into(),
            capabilities: AgentCapabilities {
                extensions: uris
                    .iter()
                    .map(|u| AgentExtension {
                        uri: (*u).into(),
                        description: String::new(),
                        required: false,
                        params: Default::default(),
                    })
                    .collect(),
                ..AgentCapabilities::default()
            },
            skills: vec![],
            ..AgentCard::default()
        }
    }

    /// A supported extension header value is forwarded to the outgoing request.
    #[tokio::test]
    async fn test_client_propagator_injects_headers() {
        let propagator = ClientPropagator::new();
        let card = make_card(&["urn:a2a:ext:duration"]);

        let mut prop_ctx = PropagatorContext::default();
        prop_ctx.request_headers.insert(
            EXTENSIONS_META_KEY.to_owned(),
            vec!["urn:a2a:ext:duration".into()],
        );

        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta: CallMeta::default(),
            card: Some(card),
            payload: Box::new(()),
        };

        // Run within propagator context scope.
        prop_ctx
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:duration"]);
    }

    /// Extension header values the downstream card does not list are dropped.
    #[tokio::test]
    async fn test_client_propagator_filters_unsupported() {
        let propagator = ClientPropagator::new();
        let card = make_card(&["urn:a2a:ext:other"]);

        let mut prop_ctx = PropagatorContext::default();
        prop_ctx.request_headers.insert(
            EXTENSIONS_META_KEY.to_owned(),
            vec!["urn:a2a:ext:duration".into()],
        );

        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta: CallMeta::default(),
            card: Some(card),
            payload: Box::new(()),
        };

        prop_ctx
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }

    /// Without any propagation context the client interceptor changes nothing.
    #[tokio::test]
    async fn test_client_propagator_no_context_is_noop() {
        let propagator = ClientPropagator::new();

        let mut req = ra2a::client::Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta: CallMeta::default(),
            card: None,
            payload: Box::new(()),
        };

        // No PropagatorContext in scope — should be a no-op.
        propagator.before(&mut req).await.unwrap();
        assert!(req.meta.is_empty());
    }

    /// Inside `init_propagation`, an installed context is visible to `current`.
    #[tokio::test]
    async fn test_propagator_context_install_and_read() {
        let ctx = PropagatorContext {
            request_headers: {
                let mut m = HashMap::new();
                m.insert("x-test".into(), vec!["val1".into()]);
                m
            },
            metadata: HashMap::new(),
        };

        init_propagation(async {
            assert!(PropagatorContext::current().is_none());
            assert!(ctx.install());
            let read = PropagatorContext::current().unwrap();
            assert_eq!(
                read.request_headers.get("x-test").unwrap(),
                &["val1".to_owned()]
            );
        })
        .await;
    }

    /// Outside any propagation scope there is no task-local cell: `install`
    /// reports failure and `current` stays empty.
    #[tokio::test]
    async fn test_propagator_context_install_outside_scope_fails() {
        let ctx = PropagatorContext::default();
        assert!(!ctx.install());
        assert!(PropagatorContext::current().is_none());
    }

    /// A context passed to `scope` is readable only while the scoped future runs.
    #[tokio::test]
    async fn test_propagator_context_scope_is_limited_to_future() {
        let mut ctx = PropagatorContext::default();
        ctx.metadata.insert(
            "urn:a2a:ext:duration".into(),
            serde_json::Value::from(42),
        );

        ctx.scope(async {
            let read = PropagatorContext::current().unwrap();
            assert_eq!(
                read.metadata.get("urn:a2a:ext:duration"),
                Some(&serde_json::Value::from(42))
            );
        })
        .await;

        assert!(PropagatorContext::current().is_none());
    }
}