1use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::SVC_PARAM_EXTENSIONS;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// Client-side interceptor that requests activation of a fixed set of
/// protocol extensions on outgoing calls.
///
/// For every request, the configured URIs are filtered against the agent card
/// attached to that request and the survivors are appended to the request's
/// service params under [`SVC_PARAM_EXTENSIONS`].
///
/// The type is `Clone` so one configured activator can be reused across
/// several clients without rebuilding the URI list by hand.
#[derive(Debug, Clone)]
pub struct ExtensionActivator {
    /// Extension URIs this client wants to activate, in the order they are
    /// appended to the request.
    extension_uris: Vec<String>,
}
38
39impl ExtensionActivator {
40 #[must_use]
42 pub const fn new(extension_uris: Vec<String>) -> Self {
43 Self { extension_uris }
44 }
45}
46
47impl CallInterceptor for ExtensionActivator {
48 fn before<'a>(
49 &'a self,
50 req: &'a mut Request,
51 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
52 self.activate(req);
53 Box::pin(std::future::ready(Ok(())))
54 }
55}
56
57impl ExtensionActivator {
58 fn activate(&self, req: &mut Request) {
60 if let Some(card) = &req.card
62 && card.capabilities.extensions.is_empty()
63 {
64 return;
65 }
66
67 for uri in &self.extension_uris {
68 if is_extension_supported(req.card.as_ref(), uri) {
69 req.service_params.append(SVC_PARAM_EXTENSIONS, uri.clone());
70 }
71 }
72 }
73}
74
#[cfg(test)]
#[allow(clippy::unwrap_used, reason = "tests use unwrap for brevity")]
mod tests {
    use ra2a::client::ServiceParams;
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    /// Builds a minimal agent card that advertises one non-required extension
    /// for each entry in `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        let extensions = uris
            .iter()
            .copied()
            .map(|uri| AgentExtension {
                uri: uri.into(),
                description: None,
                required: false,
                params: None,
            })
            .collect();
        let interface =
            AgentInterface::new("https://example.com", TransportProtocol::new("JSONRPC"));

        let mut card = AgentCard::new("test", "test agent", vec![interface]);
        card.capabilities = AgentCapabilities {
            extensions,
            ..AgentCapabilities::default()
        };
        card
    }

    /// Builds a `message/send` request with empty service params and a unit
    /// payload, optionally carrying `card`.
    fn make_request(card: Option<AgentCard>) -> Request {
        Request {
            method: "message/send".into(),
            card,
            service_params: ServiceParams::default(),
            payload: Box::new(()),
        }
    }

    /// Only URIs that the card advertises are sent.
    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let wanted = vec!["urn:a2a:ext:duration".into(), "urn:a2a:ext:missing".into()];
        let activator = ExtensionActivator::new(wanted);

        let advertised = make_card(&["urn:a2a:ext:duration", "urn:a2a:ext:other"]);
        let mut req = make_request(Some(advertised));

        activator.before(&mut req).await.unwrap();

        assert_eq!(
            req.service_params.get_all(SVC_PARAM_EXTENSIONS),
            &["urn:a2a:ext:duration"]
        );
    }

    /// Without a card there is nothing to filter against, so every
    /// configured URI is sent.
    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let wanted = vec!["urn:a2a:ext:a".into(), "urn:a2a:ext:b".into()];
        let activator = ExtensionActivator::new(wanted);

        let mut req = make_request(None);
        activator.before(&mut req).await.unwrap();

        assert_eq!(
            req.service_params.get_all(SVC_PARAM_EXTENSIONS),
            &["urn:a2a:ext:a", "urn:a2a:ext:b"]
        );
    }

    /// A card that advertises no extensions results in no URIs being sent.
    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:duration".into()]);

        let bare_card = make_card(&[]);
        let mut req = make_request(Some(bare_card));

        activator.before(&mut req).await.unwrap();

        assert!(req.service_params.get_all(SVC_PARAM_EXTENSIONS).is_empty());
    }
}