Skip to main content

ra2a_ext/
activator.rs

1//! Client-side extension activator interceptor.
2//!
3//! Aligned with Go's `a2aext.NewActivator`. Requests extension activation
4//! on outgoing calls by appending supported extension URIs to the
5//! `x-a2a-extensions` header.
6
7use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::SVC_PARAM_EXTENSIONS;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// Client-side [`CallInterceptor`] that requests extension activation.
///
/// Before each outgoing call, the activator compares its configured extension
/// URIs against the server's [`AgentCard`](ra2a::AgentCard), when the request
/// carries one. It appends the selected URIs to the request's
/// [`SVC_PARAM_EXTENSIONS`] service parameter (the `x-a2a-extensions` header).
/// The rules are:
///
/// - the request has no card: every configured URI is requested;
/// - the card declares no extensions: nothing is requested;
/// - otherwise: only the configured URIs that the card supports are requested.
///
/// URIs are appended in the order they were configured. The activator
/// implements [`Clone`], so the same configuration can be reused by several
/// clients.
///
/// # Example
///
/// ```rust,no_run
/// use ra2a_ext::ExtensionActivator;
///
/// let activator = ExtensionActivator::new(vec![
///     "urn:a2a:ext:duration".into(),
///     "urn:a2a:ext:custom".into(),
/// ]);
/// // client.with_interceptor(activator);
/// ```
#[derive(Debug, Clone, Default)]
pub struct ExtensionActivator {
    /// Extension URIs this client wishes to activate, in request order.
    extension_uris: Vec<String>,
}

impl ExtensionActivator {
    /// Creates an activator that requests the given extension URIs.
    ///
    /// The URIs are requested in the order given. On each call they are
    /// filtered against the server's agent card, as described on the type.
    #[must_use]
    pub const fn new(extension_uris: Vec<String>) -> Self {
        Self { extension_uris }
    }
}
46
47impl CallInterceptor for ExtensionActivator {
48    fn before<'a>(
49        &'a self,
50        req: &'a mut Request,
51    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
52        self.activate(req);
53        Box::pin(std::future::ready(Ok(())))
54    }
55}
56
57impl ExtensionActivator {
58    /// Appends supported extension URIs to the request's service params.
59    fn activate(&self, req: &mut Request) {
60        // If the card has no extensions declared, skip entirely.
61        if let Some(card) = &req.card
62            && card.capabilities.extensions.is_empty()
63        {
64            return;
65        }
66
67        for uri in &self.extension_uris {
68            if is_extension_supported(req.card.as_ref(), uri) {
69                req.service_params.append(SVC_PARAM_EXTENSIONS, uri.clone());
70            }
71        }
72    }
73}
74
#[cfg(test)]
#[allow(clippy::unwrap_used, reason = "tests use unwrap for brevity")]
mod tests {
    use ra2a::client::ServiceParams;
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    /// Builds a minimal agent card declaring exactly the given extension URIs.
    fn make_card(uris: &[&str]) -> AgentCard {
        let mut card = AgentCard::new(
            "test",
            "test agent",
            vec![AgentInterface::new(
                "https://example.com",
                TransportProtocol::new("JSONRPC"),
            )],
        );
        card.capabilities = AgentCapabilities {
            extensions: uris
                .iter()
                .map(|u| AgentExtension {
                    uri: (*u).into(),
                    description: None,
                    required: false,
                    params: None,
                })
                .collect(),
            ..AgentCapabilities::default()
        };
        card
    }

    /// Builds a `message/send` request with empty service params.
    fn make_request(card: Option<AgentCard>) -> Request {
        Request {
            method: "message/send".into(),
            card,
            service_params: ServiceParams::default(),
            payload: Box::new(()),
        }
    }

    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let activator = ExtensionActivator::new(vec![
            "urn:a2a:ext:duration".into(),
            "urn:a2a:ext:missing".into(),
        ]);

        let card = make_card(&["urn:a2a:ext:duration", "urn:a2a:ext:other"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(vals, &["urn:a2a:ext:duration"]);
    }

    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:a".into(), "urn:a2a:ext:b".into()]);

        let mut req = make_request(None);
        activator.before(&mut req).await.unwrap();

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(vals, &["urn:a2a:ext:a", "urn:a2a:ext:b"]);
    }

    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:duration".into()]);

        let card = make_card(&[]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(vals.is_empty());
    }

    /// A card that declares extensions, none of which are configured, must not
    /// cause anything to be requested.
    #[tokio::test]
    async fn test_activator_no_matching_extensions_sends_none() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:missing".into()]);

        let card = make_card(&["urn:a2a:ext:other"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(vals.is_empty());
    }

    /// An activator configured with no URIs is a no-op, even without a card.
    #[tokio::test]
    async fn test_activator_without_uris_sends_none() {
        let activator = ExtensionActivator::new(Vec::new());

        let mut req = make_request(None);
        activator.before(&mut req).await.unwrap();

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert!(vals.is_empty());
    }

    /// URIs are appended in configuration order, not in the card's order.
    #[tokio::test]
    async fn test_activator_preserves_configured_order() {
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:a".into(), "urn:a2a:ext:b".into()]);

        let card = make_card(&["urn:a2a:ext:b", "urn:a2a:ext:a"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.service_params.get_all(SVC_PARAM_EXTENSIONS);
        assert_eq!(vals, &["urn:a2a:ext:a", "urn:a2a:ext:b"]);
    }
}