Skip to main content

ra2a_ext/
activator.rs

1//! Client-side extension activator interceptor.
2//!
3//! Aligned with Go's `a2aext.NewActivator`. Requests extension activation
4//! on outgoing calls by appending supported extension URIs to the
5//! `x-a2a-extensions` header.
6
7use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::EXTENSIONS_META_KEY;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// Client-side [`CallInterceptor`] that requests extension activation.
///
/// Holds the extension URIs this client would like to use. Before each
/// outgoing request, every configured URI that the server's
/// [`AgentCard`](ra2a::AgentCard) supports is appended to the
/// `x-a2a-extensions` metadata header:
///
/// - a card that declares no extensions leaves the request untouched;
/// - when the request carries no card at all, every configured URI is
///   requested.
///
/// The type is [`Clone`], so a configured activator can be duplicated when
/// more than one owner needs it.
///
/// # Example
///
/// ```rust,no_run
/// use ra2a_ext::ExtensionActivator;
///
/// let activator = ExtensionActivator::new(vec![
///     "urn:a2a:ext:duration".into(),
///     "urn:a2a:ext:custom".into(),
/// ]);
/// // client.with_interceptor(activator);
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionActivator {
    /// Extension URIs this client wishes to activate, in request order.
    extension_uris: Vec<String>,
}
38
39impl ExtensionActivator {
40    /// Creates a new activator for the given extension URIs.
41    pub const fn new(extension_uris: Vec<String>) -> Self {
42        Self { extension_uris }
43    }
44}
45
46impl CallInterceptor for ExtensionActivator {
47    fn before<'a>(
48        &'a self,
49        req: &'a mut Request,
50    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51        Box::pin(async move {
52            // If the card has no extensions declared, skip entirely.
53            if let Some(card) = &req.card
54                && card.capabilities.extensions.is_empty()
55            {
56                return Ok(());
57            }
58
59            for uri in &self.extension_uris {
60                if is_extension_supported(req.card.as_ref(), uri) {
61                    req.meta.append(EXTENSIONS_META_KEY, uri.clone());
62                }
63            }
64            Ok(())
65        })
66    }
67}
68
#[cfg(test)]
// Unwrapping is fine in tests: a failing `before` call should abort the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;

    use ra2a::client::CallMeta;
    use ra2a::types::{AgentCapabilities, AgentCard, AgentExtension};

    use super::*;

    /// Builds a minimal card declaring one non-required extension per URI.
    fn make_card(uris: &[&str]) -> AgentCard {
        let extensions = uris
            .iter()
            .map(|&uri| AgentExtension {
                uri: uri.into(),
                description: String::new(),
                required: false,
                params: HashMap::default(),
            })
            .collect();
        let capabilities = AgentCapabilities {
            extensions,
            ..AgentCapabilities::default()
        };

        AgentCard {
            name: "test".into(),
            url: "https://example.com".into(),
            version: "1.0".into(),
            capabilities,
            skills: Vec::new(),
            ..AgentCard::default()
        }
    }

    /// Builds a `message/send` request with empty metadata and the given card.
    fn make_request(card: Option<AgentCard>) -> Request {
        let meta = CallMeta::default();
        Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta,
            card,
            payload: Box::new(()),
        }
    }

    /// Only the URIs the card declares end up in the metadata.
    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let wanted = vec![
            String::from("urn:a2a:ext:duration"),
            String::from("urn:a2a:ext:missing"),
        ];
        let activator = ExtensionActivator::new(wanted);

        let server_card = make_card(&["urn:a2a:ext:duration", "urn:a2a:ext:other"]);
        let mut request = make_request(Some(server_card));

        activator.before(&mut request).await.unwrap();

        let requested = request.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(requested, &["urn:a2a:ext:duration"]);
    }

    /// Without a card every configured URI is requested.
    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let wanted = vec![String::from("urn:a2a:ext:a"), String::from("urn:a2a:ext:b")];
        let activator = ExtensionActivator::new(wanted);

        let mut request = make_request(None);
        activator.before(&mut request).await.unwrap();

        let requested = request.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(requested, &["urn:a2a:ext:a", "urn:a2a:ext:b"]);
    }

    /// A card that declares no extensions leaves the metadata empty.
    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let wanted = vec![String::from("urn:a2a:ext:duration")];
        let activator = ExtensionActivator::new(wanted);

        let bare_card = make_card(&[]);
        let mut request = make_request(Some(bare_card));

        activator.before(&mut request).await.unwrap();

        let requested = request.meta.get_all(EXTENSIONS_META_KEY);
        assert!(requested.is_empty());
    }
}
151}