1use a2a::*;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::client::A2AClient;
8use crate::jsonrpc::JsonRpcTransportFactory;
9use crate::middleware::CallInterceptor;
10use crate::rest::RestTransportFactory;
11use crate::transport::TransportFactory;
12
/// Lookup key for registered transport factories.
///
/// Two keys are equal only when both the protocol string and the major
/// version match exactly (derived `PartialEq`/`Hash`).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct TransportKey {
    /// Protocol binding name, stored exactly as supplied by the interface or factory.
    protocol: String,
    /// Leading dot-separated component of the protocol version, parsed as `u64`.
    major_version: u64,
}
19
20impl TransportKey {
21 fn from_interface(iface: &AgentInterface) -> Self {
22 let major = iface
23 .protocol_version
24 .split('.')
25 .next()
26 .and_then(|s| s.parse::<u64>().ok())
27 .unwrap_or(1);
28 TransportKey {
29 protocol: iface.protocol_binding.clone(),
30 major_version: major,
31 }
32 }
33
34 fn from_protocol(protocol: &str, version: &str) -> Self {
35 let major = version
36 .split('.')
37 .next()
38 .and_then(|s| s.parse::<u64>().ok())
39 .unwrap_or(1);
40 TransportKey {
41 protocol: protocol.to_string(),
42 major_version: major,
43 }
44 }
45}
46
/// Creates [`A2AClient`]s from agent cards by picking a registered transport
/// factory that matches one of the card's supported interfaces.
pub struct A2AClientFactory {
    /// Registered transport factories, keyed by protocol binding and major version.
    factories: HashMap<TransportKey, Arc<dyn TransportFactory>>,
    /// Protocol bindings in order of preference: an earlier position is tried
    /// first, and bindings not listed here are tried after all listed ones.
    preferred_bindings: Vec<String>,
    /// Interceptors handed (as a cloned list) to every client this factory creates.
    interceptors: Vec<Arc<dyn CallInterceptor>>,
}
58
59impl A2AClientFactory {
60 pub fn builder() -> A2AClientFactoryBuilder {
62 A2AClientFactoryBuilder::new()
63 }
64
65 pub async fn create_from_card(
73 &self,
74 card: &AgentCard,
75 ) -> Result<A2AClient<Box<dyn crate::Transport>>, A2AError> {
76 let mut candidates: Vec<(usize, &AgentInterface, &Arc<dyn TransportFactory>)> = Vec::new();
77
78 for iface in &card.supported_interfaces {
79 let key = TransportKey::from_interface(iface);
80 if let Some(factory) = self.factories.get(&key) {
81 let priority = self
82 .preferred_bindings
83 .iter()
84 .position(|b| b == &iface.protocol_binding)
85 .unwrap_or(usize::MAX);
86 candidates.push((priority, iface, factory));
87 }
88 }
89
90 if candidates.is_empty() {
91 return Err(A2AError::unsupported_operation(
92 "no compatible transport found for agent card interfaces",
93 ));
94 }
95
96 candidates.sort_by_key(|(prio, _, _)| *prio);
98
99 let mut last_err = None;
100 for (_prio, iface, factory) in &candidates {
101 match factory.create(card, iface).await {
102 Ok(transport) => {
103 return Ok(
104 A2AClient::new(transport).with_interceptors(self.interceptors.clone())
105 );
106 }
107 Err(e) => {
108 tracing::debug!(
109 protocol = %iface.protocol_binding,
110 url = %iface.url,
111 error = %e,
112 "transport creation failed, trying next"
113 );
114 last_err = Some(e);
115 }
116 }
117 }
118
119 Err(last_err.unwrap_or_else(|| A2AError::internal("failed to create transport")))
120 }
121}
122
/// Builder for [`A2AClientFactory`].
pub struct A2AClientFactoryBuilder {
    /// Transport factories registered so far, keyed by protocol and major version.
    factories: HashMap<TransportKey, Arc<dyn TransportFactory>>,
    /// Preference order for protocol bindings; starts as
    /// `TRANSPORT_PROTOCOL_JSONRPC` followed by `TRANSPORT_PROTOCOL_HTTP_JSON`.
    preferred_bindings: Vec<String>,
    /// Interceptors to attach to clients created by the built factory.
    interceptors: Vec<Arc<dyn CallInterceptor>>,
    /// Whether `build` adds the default JSON-RPC and REST factories for keys
    /// that have not been registered explicitly.
    include_defaults: bool,
}
130
131impl A2AClientFactoryBuilder {
132 fn new() -> Self {
133 A2AClientFactoryBuilder {
134 factories: HashMap::new(),
135 preferred_bindings: vec![
136 TRANSPORT_PROTOCOL_JSONRPC.to_string(),
137 TRANSPORT_PROTOCOL_HTTP_JSON.to_string(),
138 ],
139 interceptors: Vec::new(),
140 include_defaults: true,
141 }
142 }
143
144 pub fn register(mut self, factory: Arc<dyn TransportFactory>) -> Self {
146 let key = TransportKey::from_protocol(factory.protocol(), VERSION);
147 self.factories.insert(key, factory);
148 self
149 }
150
151 pub fn preferred_bindings(mut self, bindings: Vec<String>) -> Self {
153 self.preferred_bindings = bindings;
154 self
155 }
156
157 pub fn with_interceptor(mut self, interceptor: Arc<dyn CallInterceptor>) -> Self {
159 self.interceptors.push(interceptor);
160 self
161 }
162
163 pub fn no_defaults(mut self) -> Self {
165 self.include_defaults = false;
166 self
167 }
168
169 pub fn build(mut self) -> A2AClientFactory {
171 if self.include_defaults {
172 let jsonrpc_key = TransportKey::from_protocol(TRANSPORT_PROTOCOL_JSONRPC, VERSION);
173 self.factories
174 .entry(jsonrpc_key)
175 .or_insert_with(|| Arc::new(JsonRpcTransportFactory::new(None)));
176
177 let rest_key = TransportKey::from_protocol(TRANSPORT_PROTOCOL_HTTP_JSON, VERSION);
178 self.factories
179 .entry(rest_key)
180 .or_insert_with(|| Arc::new(RestTransportFactory::new(None)));
181 }
182
183 A2AClientFactory {
184 factories: self.factories,
185 preferred_bindings: self.preferred_bindings,
186 interceptors: self.interceptors,
187 }
188 }
189}
190
// Unit tests for transport-key derivation and factory-builder configuration.
#[cfg(test)]
mod tests {
    use super::*;

    // The key's protocol must equal the binding the interface was created with.
    #[test]
    fn test_transport_key_from_interface() {
        let iface = AgentInterface::new("http://localhost", "jsonrpc");
        let key = TransportKey::from_interface(&iface);
        assert_eq!(key.protocol, "jsonrpc");
    }

    // A protocol version whose leading component is not a number falls back
    // to major version 1.
    #[test]
    fn test_transport_key_from_interface_no_version() {
        let mut iface = AgentInterface::new("http://localhost", "rest");
        iface.protocol_version = "bad".to_string();
        let key = TransportKey::from_interface(&iface);
        assert_eq!(key.major_version, 1);
    }

    // Only the component before the first '.' contributes to the major version.
    #[test]
    fn test_transport_key_from_protocol() {
        let key = TransportKey::from_protocol("jsonrpc", "2.3.4");
        assert_eq!(key.protocol, "jsonrpc");
        assert_eq!(key.major_version, 2);
    }

    // With defaults enabled, `build` registers the JSON-RPC and REST factories
    // and keeps the two initial preferred bindings.
    #[test]
    fn test_builder_defaults() {
        let factory = A2AClientFactory::builder().build();
        assert_eq!(factory.factories.len(), 2);
        assert_eq!(factory.preferred_bindings.len(), 2);
    }

    // `no_defaults` leaves the factory map empty when nothing was registered.
    #[test]
    fn test_builder_no_defaults() {
        let factory = A2AClientFactory::builder().no_defaults().build();
        assert!(factory.factories.is_empty());
    }

    // `preferred_bindings` replaces the initial preference list rather than
    // appending to it.
    #[test]
    fn test_builder_preferred_bindings() {
        let factory = A2AClientFactory::builder()
            .preferred_bindings(vec!["grpc".to_string()])
            .build();
        assert_eq!(factory.preferred_bindings, vec!["grpc"]);
    }

    // Each `with_interceptor` call adds one interceptor to the built factory.
    #[test]
    fn test_builder_with_interceptor() {
        use crate::middleware::LoggingInterceptor;
        let factory = A2AClientFactory::builder()
            .with_interceptor(Arc::new(LoggingInterceptor))
            .build();
        assert_eq!(factory.interceptors.len(), 1);
    }

    // With no factories registered, no interface on the card can match, so
    // client creation must return an error.
    #[tokio::test]
    async fn test_create_from_card_no_matching_transport() {
        let factory = A2AClientFactory::builder().no_defaults().build();
        let card = AgentCard {
            name: "test".into(),
            description: "test agent".into(),
            version: "1.0".into(),
            supported_interfaces: vec![AgentInterface::new("http://localhost", "unknown")],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text".into()],
            default_output_modes: vec!["text".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        };
        let result = factory.create_from_card(&card).await;
        assert!(result.is_err());
    }
}