Skip to main content

ra2a_ext/
propagator.rs

1//! Extension propagation interceptors for agent-to-agent chaining.
2//!
3//! Aligned with Go's `a2aext.NewServerPropagator` and `a2aext.NewClientPropagator`.
4//!
5//! When an agent (B) acts as both server and client in a chain (A → B → C),
6//! the [`ServerPropagator`] extracts extension-related metadata and headers
7//! from the incoming request (A → B), stores them in a [`PropagatorContext`],
8//! and the [`ClientPropagator`] injects them into the outgoing request (B → C).
9//!
10//! ## Data flow
11//!
12//! ```text
13//! A → [HTTP] → B (ServerPropagator.before extracts → PropagatorContext)
14//!                   → handler wraps executor in PropagatorContext::scope()
15//!                     → B calls C (ClientPropagator.before injects from task_local)
16//! ```
17//!
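//! The request handler must run inside [`init_propagation`] so that the
//! [`ServerPropagator`] has a task-local slot to store the extracted data in and the
//! [`ClientPropagator`] can read it back. Code that already holds a
//! [`PropagatorContext`] can instead wrap downstream client calls in
//! [`PropagatorContext::scope()`].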
20
21use std::cell::RefCell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26
27use ra2a::SVC_PARAM_EXTENSIONS;
28use ra2a::error::A2AError;
29use ra2a::types::AgentCard;
30
31use crate::util::is_extension_supported;
32
33tokio::task_local! {
34    /// Mutable cell for propagator data. Must be initialized via
35    /// [`init_propagation`] before [`ServerPropagator`] can store data.
36    static PROPAGATOR_CTX: RefCell<Option<PropagatorContext>>;
37}
38
39/// Extension data extracted by [`ServerPropagator`] for downstream propagation.
40///
41/// Aligned with Go's internal `propagatorContext` struct.
42#[derive(Debug, Clone, Default)]
43#[non_exhaustive]
44pub struct PropagatorContext {
45    /// HTTP headers to propagate (key → values).
46    pub request_headers: HashMap<String, Vec<String>>,
47    /// Payload metadata to propagate (key → value).
48    pub metadata: HashMap<String, serde_json::Value>,
49}
50
51impl PropagatorContext {
52    /// Reads the current task-local propagator context, if set.
53    #[must_use]
54    pub fn current() -> Option<Self> {
55        PROPAGATOR_CTX
56            .try_with(|cell| cell.borrow().clone())
57            .ok()
58            .flatten()
59    }
60
61    /// Stores this context in the task-local cell.
62    ///
63    /// Requires that the current task is running within [`init_propagation`].
64    /// Returns `true` if stored successfully.
65    pub fn install(self) -> bool {
66        PROPAGATOR_CTX
67            .try_with(|cell| {
68                *cell.borrow_mut() = Some(self);
69            })
70            .is_ok()
71    }
72
73    /// Executes a future with this context directly available via task-local.
74    ///
75    /// This is a convenience wrapper for simple cases where you already have
76    /// the context and want to make it available to [`ClientPropagator`].
77    pub async fn scope<F: Future>(self, f: F) -> F::Output {
78        PROPAGATOR_CTX.scope(RefCell::new(Some(self)), f).await
79    }
80}
81
82/// Wraps a future with an empty propagation scope.
83///
84/// Call this around your request handler so that [`ServerPropagator`] can store
85/// extracted data and [`ClientPropagator`] can read it later.
86///
87/// # Example
88///
89/// ```rust,ignore
90/// let result = ra2a_ext::init_propagation(async {
91///     // ServerPropagator.before() stores data here
92///     // handler runs
93///     // ClientPropagator.before() reads data here
94///     handle_request(req).await
95/// }).await;
96/// ```
97pub async fn init_propagation<F: Future>(f: F) -> F::Output {
98    PROPAGATOR_CTX.scope(RefCell::new(None), f).await
99}
100
/// Server-side filter deciding which payload metadata keys get propagated.
///
/// Called with the requested extension URIs and the metadata key under
/// consideration; a `true` result marks the key for propagation.
pub type ServerMetadataPredicate = Arc<dyn Fn(&[String], &str) -> bool + Send + Sync>;
106
/// Server-side filter deciding which request headers get propagated.
///
/// Called with the header name; a `true` result marks the header for propagation.
pub type ServerHeaderPredicate = Arc<dyn Fn(&str) -> bool + Send + Sync>;
111
112/// Configuration for [`ServerPropagator`].
113///
114/// Both predicates are optional — sensible defaults are used when `None`.
115#[derive(Default)]
116#[non_exhaustive]
117pub struct ServerPropagatorConfig {
118    /// Determines which payload metadata keys are propagated.
119    ///
120    /// Default: propagate keys whose name matches a client-requested extension URI.
121    pub metadata_predicate: Option<ServerMetadataPredicate>,
122    /// Determines which request headers are propagated.
123    ///
124    /// Default: propagate only the `x-a2a-extensions` header.
125    pub header_predicate: Option<ServerHeaderPredicate>,
126}
127
128impl std::fmt::Debug for ServerPropagatorConfig {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("ServerPropagatorConfig")
131            .field("metadata_predicate", &self.metadata_predicate.is_some())
132            .field("header_predicate", &self.header_predicate.is_some())
133            .finish()
134    }
135}
136
137/// Server-side [`CallInterceptor`](ra2a::server::CallInterceptor) that extracts
138/// extension-related metadata and headers from incoming requests.
139///
140/// The extracted data is stored in a [`PropagatorContext`] via `task_local`.
141/// The handler must be wrapped in [`init_propagation`] for this to work.
142/// [`ClientPropagator`] reads the stored context when making downstream calls.
143///
144/// Aligned with Go's `a2aext.NewServerPropagator`.
145pub struct ServerPropagator {
146    /// Metadata filter predicate.
147    metadata_predicate: ServerMetadataPredicate,
148    /// Header filter predicate.
149    header_predicate: ServerHeaderPredicate,
150}
151
152impl ServerPropagator {
153    /// Creates a new server propagator with default configuration.
154    ///
155    /// Default behavior:
156    /// - Propagates metadata keys matching client-requested extension URIs
157    /// - Propagates the `x-a2a-extensions` header
158    pub fn new() -> Self {
159        Self::with_config(ServerPropagatorConfig::default())
160    }
161
162    /// Creates a new server propagator with custom configuration.
163    pub fn with_config(config: ServerPropagatorConfig) -> Self {
164        let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
165            Arc::new(|requested_uris: &[String], key: &str| requested_uris.iter().any(|u| u == key))
166        });
167
168        let header_predicate = config.header_predicate.unwrap_or_else(|| {
169            Arc::new(|key: &str| key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS))
170        });
171
172        Self {
173            metadata_predicate,
174            header_predicate,
175        }
176    }
177}
178
179impl Default for ServerPropagator {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185impl std::fmt::Debug for ServerPropagator {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("ServerPropagator").finish_non_exhaustive()
188    }
189}
190
191impl ra2a::server::CallInterceptor for ServerPropagator {
192    fn before<'a>(
193        &'a self,
194        ctx: &'a mut ra2a::server::CallContext,
195        req: &'a mut ra2a::server::Request,
196    ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
197        Box::pin(async move {
198            let mut prop_ctx = PropagatorContext::default();
199
200            // Collect requested extension URIs for the metadata predicate.
201            let requested = ctx.requested_extension_uris();
202
203            // Extract matching metadata from the request payload.
204            extract_metadata(
205                req,
206                &requested,
207                &self.metadata_predicate,
208                &mut prop_ctx.metadata,
209            );
210
211            // Extract matching headers from request metadata.
212            let request_meta = ctx.request_meta();
213            for (header_name, header_values) in request_meta.iter() {
214                if (self.header_predicate)(header_name) {
215                    prop_ctx
216                        .request_headers
217                        .insert(header_name.to_owned(), header_values.to_vec());
218                }
219            }
220
221            // Also activate extensions in the CallContext for downstream use.
222            if let Some(ext_values) = prop_ctx.request_headers.get(SVC_PARAM_EXTENSIONS) {
223                for uri in ext_values {
224                    ctx.activate_extension(uri);
225                }
226            }
227
228            // Store in task-local (requires init_propagation wrapper).
229            prop_ctx.install();
230
231            Ok(())
232        })
233    }
234
235    fn after<'a>(
236        &'a self,
237        _ctx: &'a ra2a::server::CallContext,
238        _resp: &'a mut ra2a::server::Response,
239    ) -> Pin<Box<dyn Future<Output = Result<(), A2AError>> + Send + 'a>> {
240        Box::pin(async { Ok(()) })
241    }
242}
243
244/// Extracts matching metadata from known request payload types.
245fn extract_metadata(
246    req: &ra2a::server::Request,
247    requested: &[String],
248    predicate: &ServerMetadataPredicate,
249    out: &mut HashMap<String, serde_json::Value>,
250) {
251    if let Some(params) = req.downcast_ref::<ra2a::SendMessageRequest>()
252        && let Some(ref meta) = params.metadata
253    {
254        collect_matching_metadata(meta, requested, predicate, out);
255    }
256}
257
258/// Collects metadata entries that pass the predicate.
259fn collect_matching_metadata(
260    metadata: &ra2a::Metadata,
261    requested: &[String],
262    predicate: &ServerMetadataPredicate,
263    out: &mut HashMap<String, serde_json::Value>,
264) {
265    for (k, v) in metadata {
266        if predicate(requested, k) {
267            out.insert(k.clone(), v.clone());
268        }
269    }
270}
271
272/// Predicate function for filtering metadata keys on the client side.
273///
274/// Receives the target server's agent card (if available), the list of
275/// requested extension URIs, and the metadata key.
276pub type ClientMetadataPredicate =
277    Arc<dyn Fn(Option<&AgentCard>, &[String], &str) -> bool + Send + Sync>;
278
279/// Predicate function for filtering request headers on the client side.
280///
281/// Receives the target server's agent card (if available), the header key
282/// and value. Returns `true` if the header should be forwarded.
283pub type ClientHeaderPredicate = Arc<dyn Fn(Option<&AgentCard>, &str, &str) -> bool + Send + Sync>;
284
285/// Configuration for [`ClientPropagator`].
286#[derive(Default)]
287#[non_exhaustive]
288pub struct ClientPropagatorConfig {
289    /// Determines which payload metadata keys are propagated.
290    ///
291    /// Default: propagate keys that are requested extensions and supported by
292    /// the downstream server.
293    pub metadata_predicate: Option<ClientMetadataPredicate>,
294    /// Determines which request headers are propagated.
295    ///
296    /// Default: propagate `x-a2a-extensions` header values for extensions
297    /// supported by the downstream server.
298    pub header_predicate: Option<ClientHeaderPredicate>,
299}
300
301impl std::fmt::Debug for ClientPropagatorConfig {
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        f.debug_struct("ClientPropagatorConfig")
304            .field("metadata_predicate", &self.metadata_predicate.is_some())
305            .field("header_predicate", &self.header_predicate.is_some())
306            .finish()
307    }
308}
309
310/// Client-side [`CallInterceptor`](ra2a::client::CallInterceptor) that injects
311/// propagated extension data into outgoing requests.
312///
313/// Reads [`PropagatorContext`] from the task-local (set by [`ServerPropagator`])
314/// and injects matching metadata and headers into the outgoing request.
315///
316/// Aligned with Go's `a2aext.NewClientPropagator`.
317pub struct ClientPropagator {
318    /// Metadata filter predicate.
319    metadata_predicate: ClientMetadataPredicate,
320    /// Header filter predicate.
321    header_predicate: ClientHeaderPredicate,
322}
323
324impl ClientPropagator {
325    /// Creates a new client propagator with default configuration.
326    pub fn new() -> Self {
327        Self::with_config(ClientPropagatorConfig::default())
328    }
329
330    /// Creates a new client propagator with custom configuration.
331    pub fn with_config(config: ClientPropagatorConfig) -> Self {
332        let metadata_predicate = config.metadata_predicate.unwrap_or_else(|| {
333            Arc::new(
334                |card: Option<&AgentCard>, requested: &[String], key: &str| {
335                    if !requested.iter().any(|u| u == key) {
336                        return false;
337                    }
338                    is_extension_supported(card, key)
339                },
340            )
341        });
342
343        let header_predicate = config.header_predicate.unwrap_or_else(|| {
344            Arc::new(|card: Option<&AgentCard>, key: &str, val: &str| {
345                if !key.eq_ignore_ascii_case(SVC_PARAM_EXTENSIONS) {
346                    return false;
347                }
348                is_extension_supported(card, val)
349            })
350        });
351
352        Self {
353            metadata_predicate,
354            header_predicate,
355        }
356    }
357}
358
359impl Default for ClientPropagator {
360    fn default() -> Self {
361        Self::new()
362    }
363}
364
365impl std::fmt::Debug for ClientPropagator {
366    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        f.debug_struct("ClientPropagator").finish_non_exhaustive()
368    }
369}
370
371impl ra2a::client::CallInterceptor for ClientPropagator {
372    fn before<'a>(
373        &'a self,
374        req: &'a mut ra2a::client::Request,
375    ) -> Pin<Box<dyn Future<Output = ra2a::error::Result<()>> + Send + 'a>> {
376        Box::pin(async move {
377            let Some(prop_ctx) = PropagatorContext::current() else {
378                return Ok(());
379            };
380
381            // Collect requested URIs from propagated headers for the predicate.
382            let requested: Vec<String> = prop_ctx
383                .request_headers
384                .get(SVC_PARAM_EXTENSIONS)
385                .cloned()
386                .unwrap_or_default();
387
388            // Inject matching metadata into the outgoing payload.
389            if !prop_ctx.metadata.is_empty() {
390                inject_metadata(
391                    &mut *req.payload,
392                    &prop_ctx.metadata,
393                    req.card.as_ref(),
394                    &requested,
395                    &self.metadata_predicate,
396                );
397            }
398
399            // Inject matching headers.
400            for (header_name, header_values) in &prop_ctx.request_headers {
401                for header_value in header_values {
402                    if (self.header_predicate)(req.card.as_ref(), header_name, header_value) {
403                        req.service_params.append(header_name, header_value);
404                    }
405                }
406            }
407
408            Ok(())
409        })
410    }
411}
412
413/// Injects matching metadata into known outgoing payload types.
414fn inject_metadata(
415    payload: &mut dyn std::any::Any,
416    metadata: &HashMap<String, serde_json::Value>,
417    card: Option<&AgentCard>,
418    requested: &[String],
419    predicate: &ClientMetadataPredicate,
420) {
421    if let Some(params) = payload.downcast_mut::<ra2a::SendMessageRequest>() {
422        let meta = params.metadata.get_or_insert_with(Default::default);
423        inject_matching_metadata(meta, metadata, card, requested, predicate);
424    }
425}
426
427/// Inserts metadata entries that pass the predicate into the target map.
428fn inject_matching_metadata(
429    target: &mut ra2a::Metadata,
430    source: &HashMap<String, serde_json::Value>,
431    card: Option<&AgentCard>,
432    requested: &[String],
433    predicate: &ClientMetadataPredicate,
434) {
435    for (k, v) in source {
436        if predicate(card, requested, k) {
437            target.insert(k.clone(), v.clone());
438        }
439    }
440}
441
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ra2a::client::{CallInterceptor as _, ServiceParams};
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    /// Extension URI that the simulated upstream caller requested.
    const DURATION_EXT: &str = "urn:a2a:ext:duration";

    /// Builds an agent card that declares one optional extension per URI in `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        let interface = AgentInterface::new(
            "https://example.com",
            TransportProtocol::new("JSONRPC"),
        );
        let extensions = uris
            .iter()
            .map(|uri| AgentExtension {
                uri: (*uri).into(),
                description: None,
                required: false,
                params: None,
            })
            .collect();

        let mut card = AgentCard::new("test", "test agent", vec![interface]);
        card.capabilities = AgentCapabilities {
            extensions,
            ..AgentCapabilities::default()
        };
        card
    }

    /// Builds an outgoing `message/send` request with no service params and a unit payload.
    fn make_request(card: Option<AgentCard>) -> ra2a::client::Request {
        ra2a::client::Request {
            method: "message/send".into(),
            service_params: ServiceParams::default(),
            card,
            payload: Box::new(()),
        }
    }

    /// Builds a context whose extensions header carries exactly `uri`.
    fn context_requesting(uri: &str) -> PropagatorContext {
        let mut prop_ctx = PropagatorContext::default();
        prop_ctx
            .request_headers
            .insert(SVC_PARAM_EXTENSIONS.to_owned(), vec![uri.into()]);
        prop_ctx
    }

    /// A requested extension that the downstream card supports is forwarded.
    #[tokio::test]
    async fn test_client_propagator_injects_headers() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(Some(make_card(&[DURATION_EXT])));

        context_requesting(DURATION_EXT)
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(vals, &["urn:a2a:ext:duration"]);
    }

    /// A requested extension that the downstream card does not list is dropped.
    #[tokio::test]
    async fn test_client_propagator_filters_unsupported() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(Some(make_card(&["urn:a2a:ext:other"])));

        context_requesting(DURATION_EXT)
            .scope(async {
                propagator.before(&mut req).await.unwrap();
            })
            .await;

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(vals.is_empty());
    }

    /// Outside any propagation scope the interceptor leaves the request alone.
    #[tokio::test]
    async fn test_client_propagator_no_context_is_noop() {
        let propagator = ClientPropagator::new();
        let mut req = make_request(None);

        propagator.before(&mut req).await.unwrap();
        assert!(req.service_params.is_empty());
    }

    /// Inside `init_propagation` the slot starts empty and returns what was installed.
    #[tokio::test]
    async fn test_propagator_context_install_and_read() {
        let mut ctx = PropagatorContext::default();
        ctx.request_headers
            .insert("x-test".into(), vec!["val1".into()]);

        init_propagation(async {
            assert!(PropagatorContext::current().is_none());
            assert!(ctx.install());
            let read = PropagatorContext::current().unwrap();
            assert_eq!(
                read.request_headers.get("x-test").unwrap(),
                &["val1".to_owned()]
            );
        })
        .await;
    }
}