1use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::EXTENSIONS_META_KEY;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// Client-side call interceptor that asks the remote agent to activate a
/// fixed list of A2A extensions.
///
/// Before each call, every configured extension URI that the target agent
/// supports (as decided by `is_extension_supported` against the request's
/// agent card, if one is attached) is appended to the request metadata under
/// `EXTENSIONS_META_KEY`. Without a card, every configured URI is sent; with
/// a card that advertises no extensions, nothing is sent.
///
/// `Default` yields an activator with no URIs, which never modifies requests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtensionActivator {
    /// Extension URIs to request, in the order they are appended to the
    /// request metadata.
    extension_uris: Vec<String>,
}
38
39impl ExtensionActivator {
40 pub const fn new(extension_uris: Vec<String>) -> Self {
42 Self { extension_uris }
43 }
44}
45
46impl CallInterceptor for ExtensionActivator {
47 fn before<'a>(
48 &'a self,
49 req: &'a mut Request,
50 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51 Box::pin(async move {
52 if let Some(card) = &req.card
54 && card.capabilities.extensions.is_empty()
55 {
56 return Ok(());
57 }
58
59 for uri in &self.extension_uris {
60 if is_extension_supported(req.card.as_ref(), uri) {
61 req.meta.append(EXTENSIONS_META_KEY, uri.clone());
62 }
63 }
64 Ok(())
65 })
66 }
67}
68
#[cfg(test)]
mod tests {
    use ra2a::client::CallMeta;
    use ra2a::types::{AgentCapabilities, AgentCard, AgentExtension};

    use super::*;

    /// Builds a minimal agent card whose capabilities advertise `uris`.
    fn make_card(uris: &[&str]) -> AgentCard {
        AgentCard {
            name: "test".into(),
            url: "https://example.com".into(),
            version: "1.0".into(),
            capabilities: AgentCapabilities {
                extensions: uris
                    .iter()
                    .map(|u| AgentExtension {
                        uri: (*u).into(),
                        description: String::new(),
                        required: false,
                        params: Default::default(),
                    })
                    .collect(),
                ..AgentCapabilities::default()
            },
            skills: vec![],
            ..AgentCard::default()
        }
    }

    /// Builds a `message/send` request with empty metadata and the given card.
    fn make_request(card: Option<AgentCard>) -> Request {
        Request {
            method: "message/send".into(),
            base_url: "https://example.com".into(),
            meta: CallMeta::default(),
            card,
            payload: Box::new(()),
        }
    }

    /// Only configured URIs that the card advertises are activated.
    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let activator = ExtensionActivator::new(vec![
            "urn:a2a:ext:duration".into(),
            "urn:a2a:ext:missing".into(),
        ]);

        let card = make_card(&["urn:a2a:ext:duration", "urn:a2a:ext:other"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:duration"]);
    }

    /// Without a card there is nothing to filter against, so every
    /// configured URI is sent.
    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:a".into(), "urn:a2a:ext:b".into()]);

        let mut req = make_request(None);
        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:a", "urn:a2a:ext:b"]);
    }

    /// A card that advertises no extensions short-circuits the activator.
    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:duration".into()]);

        let card = make_card(&[]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }

    /// A card with extensions, none of which were requested, yields no
    /// metadata.
    #[tokio::test]
    async fn test_activator_no_overlap_adds_nothing() {
        let activator = ExtensionActivator::new(vec!["urn:a2a:ext:duration".into()]);

        let card = make_card(&["urn:a2a:ext:other"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }

    /// An activator with no configured URIs never touches the metadata.
    #[tokio::test]
    async fn test_activator_without_uris_adds_nothing() {
        let activator = ExtensionActivator::new(Vec::new());

        let mut req = make_request(None);
        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert!(vals.is_empty());
    }

    /// Activated URIs follow the activator's configured order, not the
    /// card's.
    #[tokio::test]
    async fn test_activator_preserves_configured_order() {
        let activator =
            ExtensionActivator::new(vec!["urn:a2a:ext:b".into(), "urn:a2a:ext:a".into()]);

        let card = make_card(&["urn:a2a:ext:a", "urn:a2a:ext:b"]);
        let mut req = make_request(Some(card));

        activator.before(&mut req).await.unwrap();

        let vals = req.meta.get_all(EXTENSIONS_META_KEY);
        assert_eq!(vals, &["urn:a2a:ext:b", "urn:a2a:ext:a"]);
    }
}