Skip to main content

ra2a_ext/
activator.rs

//! Client-side extension activator interceptor.
//!
//! Aligned with Go's `a2aext.NewActivator`. Requests extension activation
//! on outgoing calls by appending the configured extension URIs that the
//! target agent supports to the `x-a2a-extensions` header.

7use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::EXTENSIONS_META_KEY;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// Client-side [`CallInterceptor`] that requests extension activation.
///
/// For each outgoing request, checks the server's [`AgentCard`](ra2a::AgentCard)
/// for supported extensions. It appends the matching URIs to the
/// `x-a2a-extensions` metadata header ([`EXTENSIONS_META_KEY`]), in the order
/// they were configured.
///
/// - If the request's card declares no extensions at all, nothing is appended.
/// - If the request carries no card, whether each URI is sent is decided by the
///   crate's `is_extension_supported` helper. The crate's tests expect every
///   configured URI to be sent in that case.
///
/// The activator implements [`Clone`], so a single configuration can be
/// reused across several clients.
///
/// # Example
///
/// ```rust,no_run
/// use ra2a_ext::ExtensionActivator;
///
/// let activator = ExtensionActivator::new(vec![
///     "urn:a2a:ext:duration".into(),
///     "urn:a2a:ext:custom".into(),
/// ]);
/// // client.with_interceptor(activator);
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionActivator {
    /// Extension URIs this client wishes to activate, in request order.
    extension_uris: Vec<String>,
}
38
39impl ExtensionActivator {
40    /// Creates a new activator for the given extension URIs.
41    pub const fn new(extension_uris: Vec<String>) -> Self {
42        Self { extension_uris }
43    }
44}
45
46impl CallInterceptor for ExtensionActivator {
47    fn before<'a>(
48        &'a self,
49        req: &'a mut Request,
50    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51        Box::pin(async move {
52            // If the card has no extensions declared, skip entirely.
53            if let Some(card) = &req.card
54                && card.capabilities.extensions.is_empty()
55            {
56                return Ok(());
57            }
58
59            for uri in &self.extension_uris {
60                if is_extension_supported(req.card.as_ref(), uri) {
61                    req.meta.append(EXTENSIONS_META_KEY, uri.clone());
62                }
63            }
64            Ok(())
65        })
66    }
67}
68
#[cfg(test)]
mod tests {
    use ra2a::client::CallMeta;
    use ra2a::types::{AgentCapabilities, AgentCard, AgentExtension};

    use super::*;

    /// Builds a minimal agent card whose capabilities declare exactly `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        AgentCard {
            name: "test".into(),
            url: "https://example.com".into(),
            version: "1.0".into(),
            capabilities: AgentCapabilities {
                extensions: uris
                    .iter()
                    .map(|u| AgentExtension {
                        uri: (*u).into(),
                        description: String::new(),
                        required: false,
                        params: Default::default(),
                    })
                    .collect(),
                ..AgentCapabilities::default()
            },
            skills: vec![],
            ..AgentCard::default()
        }
    }

    /// Builds a `message/send` request with empty metadata and the given card.
    fn make_request(card: Option<AgentCard>) -> Request {
        Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta: CallMeta::default(),
            card,
            payload: Box::new(()),
        }
    }

    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let activator = ExtensionActivator::new(vec![
            "urn:a2a:ext:duration".into(),
            "urn:a2a:ext:missing".into(),
        ]);

        let card = make_card(&["urn:a2a:ext:duration", "urn:a2a:ext:other"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:duration"]);
    }

    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:a".into(), "urn:a2a:ext:b".into()]);

        let mut req = make_request(None);
        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:a", "urn:a2a:ext:b"]);
    }

    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:duration".into()]);

        let card = make_card(&[]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }

    /// The card declares extensions, but none of the requested ones: this goes
    /// through the per-URI filter rather than the empty-card early return.
    #[tokio::test]
    async fn test_activator_no_matching_extensions() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:missing".into()]);

        let card = make_card(&["urn:a2a:ext:other"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }

    /// URIs are appended in the activator's configured order, not the card's.
    #[tokio::test]
    async fn test_activator_preserves_configured_order() {
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:b".into(), "urn:a2a:ext:a".into()]);

        let card = make_card(&["urn:a2a:ext:a", "urn:a2a:ext:b"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:b", "urn:a2a:ext:a"]);
    }

    /// An activator with nothing configured never touches the metadata.
    #[tokio::test]
    async fn test_activator_empty_config_appends_nothing() {
        let activator = ExtensionActivator::new(vec![]);

        let card = make_card(&["urn:a2a:ext:duration"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }
}