1use std::future::Future;
8use std::pin::Pin;
9
10use ra2a::SVC_PARAM_EXTENSIONS;
11use ra2a::client::{CallInterceptor, Request};
12use ra2a::error::Result;
13
14use crate::util::is_extension_supported;
15
/// Client-side [`CallInterceptor`] that asks the remote agent to activate a
/// fixed list of protocol extensions.
///
/// Before each call, every configured URI that `is_extension_supported`
/// accepts for the request's agent card (if any) is appended to the request's
/// [`SVC_PARAM_EXTENSIONS`] service parameter. A request whose card declares
/// no extensions at all is left untouched.
///
/// The type is cheap to clone and compare, so a single configuration can be
/// shared across several clients or asserted on in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtensionActivator {
    /// Extension URIs to request, in the order they are considered per call.
    extension_uris: Vec<String>,
}
38
39impl ExtensionActivator {
40 pub const fn new(extension_uris: Vec<String>) -> Self {
42 Self { extension_uris }
43 }
44}
45
46impl CallInterceptor for ExtensionActivator {
47 fn before<'a>(
48 &'a self,
49 req: &'a mut Request,
50 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51 Box::pin(async move {
52 if let Some(card) = &req.card
54 && card.capabilities.extensions.is_empty()
55 {
56 return Ok(());
57 }
58
59 for uri in &self.extension_uris {
60 if is_extension_supported(req.card.as_ref(), uri) {
61 req.service_params.append(SVC_PARAM_EXTENSIONS, uri.clone());
62 }
63 }
64 Ok(())
65 })
66 }
67}
68
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ra2a::client::ServiceParams;
    use ra2a::types::{
        AgentCapabilities, AgentCard, AgentExtension, AgentInterface, TransportProtocol,
    };

    use super::*;

    /// Agent card fixture whose capabilities advertise exactly `uris`.
    fn card_advertising(uris: &[&str]) -> AgentCard {
        let interface =
            AgentInterface::new("https://example.com", TransportProtocol::new("JSONRPC"));
        let mut card = AgentCard::new("test", "test agent", vec![interface]);
        let extensions = uris
            .iter()
            .map(|&uri| AgentExtension {
                uri: uri.into(),
                description: None,
                required: false,
                params: None,
            })
            .collect();
        card.capabilities = AgentCapabilities {
            extensions,
            ..AgentCapabilities::default()
        };
        card
    }

    /// Bare `message/send` request with empty service params and a unit payload.
    fn request_for(card: Option<AgentCard>) -> Request {
        Request {
            method: "message/send".into(),
            card,
            service_params: ServiceParams::default(),
            payload: Box::new(()),
        }
    }

    #[tokio::test]
    async fn test_activator_filters_by_card() {
        let activator = ExtensionActivator::new(Vec::from(
            ["urn:a2a:ext:duration", "urn:a2a:ext:missing"].map(String::from),
        ));
        let mut req = request_for(Some(card_advertising(&[
            "urn:a2a:ext:duration",
            "urn:a2a:ext:other",
        ])));

        activator.before(&mut req).await.unwrap();

        // Only the URI both configured and advertised by the card survives.
        assert_eq!(
            req.service_params.get_all(SVC_PARAM_EXTENSIONS),
            &["urn:a2a:ext:duration"]
        );
    }

    #[tokio::test]
    async fn test_activator_no_card_sends_all() {
        let activator =
            ExtensionActivator::new(Vec::from(["urn:a2a:ext:a", "urn:a2a:ext:b"].map(String::from)));
        let mut req = request_for(None);

        activator.before(&mut req).await.unwrap();

        // Without a card there is nothing to filter against.
        assert_eq!(
            req.service_params.get_all(SVC_PARAM_EXTENSIONS),
            &["urn:a2a:ext:a", "urn:a2a:ext:b"]
        );
    }

    #[tokio::test]
    async fn test_activator_empty_card_extensions_skips() {
        let activator = ExtensionActivator::new(vec![String::from("urn:a2a:ext:duration")]);
        let mut req = request_for(Some(card_advertising(&[])));

        activator.before(&mut req).await.unwrap();

        assert!(req.service_params.get_all(SVC_PARAM_EXTENSIONS).is_empty());
    }
}
152}