1use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::EXTENSIONS_META_KEY;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// A [`CallInterceptor`] that requests a fixed list of extension URIs on
/// outgoing calls.
///
/// Before each call, every configured URI that `is_extension_supported`
/// accepts for the request's agent card is appended to the request metadata
/// under [`EXTENSIONS_META_KEY`]. If the request carries an agent card that
/// lists no extensions, nothing is appended.
///
/// The type is plain data (a list of URIs), so it is cheap to clone when the
/// same configuration must be installed on several clients.
#[derive(Debug, Clone)]
pub struct ExtensionActivator {
    /// Extension URIs to request, in the order they are appended.
    extension_uris: Vec<String>,
}
38
39impl ExtensionActivator {
40 pub const fn new(extension_uris: Vec<String>) -> Self {
42 Self { extension_uris }
43 }
44}
45
46impl CallInterceptor for ExtensionActivator {
47 fn before<'a>(
48 &'a self,
49 req: &'a mut Request,
50 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51 Box::pin(async move {
52 if let Some(card) = &req.card
54 && card.capabilities.extensions.is_empty()
55 {
56 return Ok(());
57 }
58
59 for uri in &self.extension_uris {
60 if is_extension_supported(req.card.as_ref(), uri) {
61 req.meta.append(EXTENSIONS_META_KEY, uri.clone());
62 }
63 }
64 Ok(())
65 })
66 }
67}
68
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;

    use ra2a::client::CallMeta;
    use ra2a::types::{AgentCapabilities, AgentCard, AgentExtension};

    use super::*;

    /// Builds a minimal agent card whose capabilities list exactly `uris` as
    /// non-required extensions with empty descriptions and params.
    fn card_advertising(uris: &[&str]) -> AgentCard {
        let extensions = uris
            .iter()
            .copied()
            .map(|uri| AgentExtension {
                uri: uri.into(),
                description: String::new(),
                required: false,
                params: HashMap::default(),
            })
            .collect();

        AgentCard {
            name: "test".into(),
            url: "https://example.com".into(),
            version: "1.0".into(),
            capabilities: AgentCapabilities {
                extensions,
                ..AgentCapabilities::default()
            },
            skills: vec![],
            ..AgentCard::default()
        }
    }

    /// Builds a `message/send` request with empty metadata and the given card.
    fn request_with(card: Option<AgentCard>) -> Request {
        Request {
            card,
            meta: CallMeta::default(),
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            payload: Box::new(()),
        }
    }

    /// Only URIs the card advertises are forwarded; unknown ones are dropped.
    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let mut req = request_with(Some(card_advertising(&[
            "urn:a2a:ext:duration",
            "urn:a2a:ext:other",
        ])));
        let activator = ExtensionActivator::new(vec![
            "urn:a2a:ext:duration".into(),
            "urn:a2a:ext:missing".into(),
        ]);

        activator.before(&mut req).await.unwrap();

        assert_eq!(
            req.meta.get_all(EXTENSIONS_META_KEY),
            &["urn:a2a:ext:duration"]
        );
    }

    /// Without a card, every configured URI is forwarded in order.
    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let mut req = request_with(None);
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:a".into(), "urn:a2a:ext:b".into()]);

        activator.before(&mut req).await.unwrap();

        assert_eq!(
            req.meta.get_all(EXTENSIONS_META_KEY),
            &["urn:a2a:ext:a", "urn:a2a:ext:b"]
        );
    }

    /// A card with an empty extension list results in no metadata at all.
    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let mut req = request_with(Some(card_advertising(&[])));
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:duration".into()]);

        activator.before(&mut req).await.unwrap();

        assert!(req.meta.get_all(EXTENSIONS_META_KEY).is_empty());
    }
}