Skip to main content

a2a_client/
factory.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use a2a::*;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::client::A2AClient;
8use crate::jsonrpc::JsonRpcTransportFactory;
9use crate::middleware::CallInterceptor;
10use crate::rest::RestTransportFactory;
11use crate::transport::TransportFactory;
12
13/// Key for looking up transport factories.
14#[derive(Debug, Clone, Hash, PartialEq, Eq)]
15struct TransportKey {
16    protocol: String,
17    major_version: u64,
18}
19
20impl TransportKey {
21    fn from_interface(iface: &AgentInterface) -> Self {
22        let major = iface
23            .protocol_version
24            .split('.')
25            .next()
26            .and_then(|s| s.parse::<u64>().ok())
27            .unwrap_or(1);
28        TransportKey {
29            protocol: iface.protocol_binding.clone(),
30            major_version: major,
31        }
32    }
33
34    fn from_protocol(protocol: &str, version: &str) -> Self {
35        let major = version
36            .split('.')
37            .next()
38            .and_then(|s| s.parse::<u64>().ok())
39            .unwrap_or(1);
40        TransportKey {
41            protocol: protocol.to_string(),
42            major_version: major,
43        }
44    }
45}
46
47/// Factory for creating [`A2AClient`] instances with automatic protocol negotiation.
48///
49/// Maintains a registry of [`TransportFactory`] implementations indexed by protocol
50/// and major version. When creating a client from an [`AgentCard`], it selects the
51/// best matching transport based on the agent's declared interfaces and the client's
52/// preferred bindings.
53pub struct A2AClientFactory {
54    factories: HashMap<TransportKey, Arc<dyn TransportFactory>>,
55    preferred_bindings: Vec<String>,
56    interceptors: Vec<Arc<dyn CallInterceptor>>,
57}
58
59impl A2AClientFactory {
60    /// Create a builder for configuring the factory.
61    pub fn builder() -> A2AClientFactoryBuilder {
62        A2AClientFactoryBuilder::new()
63    }
64
65    /// Create a client by negotiating the best transport with the agent card.
66    ///
67    /// Selection algorithm (mirrors Go SDK):
68    /// 1. For each interface in the agent card's `supported_interfaces`
69    /// 2. Look up matching factory by (protocol, major_version)
70    /// 3. Rank candidates: client preference order, then newest version
71    /// 4. Try connecting in rank order; first success wins
72    pub async fn create_from_card(
73        &self,
74        card: &AgentCard,
75    ) -> Result<A2AClient<Box<dyn crate::Transport>>, A2AError> {
76        let mut candidates: Vec<(usize, &AgentInterface, &Arc<dyn TransportFactory>)> = Vec::new();
77
78        for iface in &card.supported_interfaces {
79            let key = TransportKey::from_interface(iface);
80            if let Some(factory) = self.factories.get(&key) {
81                let priority = self
82                    .preferred_bindings
83                    .iter()
84                    .position(|b| b == &iface.protocol_binding)
85                    .unwrap_or(usize::MAX);
86                candidates.push((priority, iface, factory));
87            }
88        }
89
90        if candidates.is_empty() {
91            return Err(A2AError::unsupported_operation(
92                "no compatible transport found for agent card interfaces",
93            ));
94        }
95
96        // Sort by preference (lower is better)
97        candidates.sort_by_key(|(prio, _, _)| *prio);
98
99        let mut last_err = None;
100        for (_prio, iface, factory) in &candidates {
101            match factory.create(card, iface).await {
102                Ok(transport) => {
103                    return Ok(
104                        A2AClient::new(transport).with_interceptors(self.interceptors.clone())
105                    );
106                }
107                Err(e) => {
108                    tracing::debug!(
109                        protocol = %iface.protocol_binding,
110                        url = %iface.url,
111                        error = %e,
112                        "transport creation failed, trying next"
113                    );
114                    last_err = Some(e);
115                }
116            }
117        }
118
119        Err(last_err.unwrap_or_else(|| A2AError::internal("failed to create transport")))
120    }
121}
122
123/// Builder for [`A2AClientFactory`].
124pub struct A2AClientFactoryBuilder {
125    factories: HashMap<TransportKey, Arc<dyn TransportFactory>>,
126    preferred_bindings: Vec<String>,
127    interceptors: Vec<Arc<dyn CallInterceptor>>,
128    include_defaults: bool,
129}
130
131impl A2AClientFactoryBuilder {
132    fn new() -> Self {
133        A2AClientFactoryBuilder {
134            factories: HashMap::new(),
135            preferred_bindings: vec![
136                TRANSPORT_PROTOCOL_JSONRPC.to_string(),
137                TRANSPORT_PROTOCOL_HTTP_JSON.to_string(),
138            ],
139            interceptors: Vec::new(),
140            include_defaults: true,
141        }
142    }
143
144    /// Register a transport factory for a protocol binding.
145    pub fn register(mut self, factory: Arc<dyn TransportFactory>) -> Self {
146        let key = TransportKey::from_protocol(factory.protocol(), VERSION);
147        self.factories.insert(key, factory);
148        self
149    }
150
151    /// Set preferred binding order. First is most preferred.
152    pub fn preferred_bindings(mut self, bindings: Vec<String>) -> Self {
153        self.preferred_bindings = bindings;
154        self
155    }
156
157    /// Add a call interceptor.
158    pub fn with_interceptor(mut self, interceptor: Arc<dyn CallInterceptor>) -> Self {
159        self.interceptors.push(interceptor);
160        self
161    }
162
163    /// Disable default JSON-RPC and REST transport factories.
164    pub fn no_defaults(mut self) -> Self {
165        self.include_defaults = false;
166        self
167    }
168
169    /// Build the factory.
170    pub fn build(mut self) -> A2AClientFactory {
171        if self.include_defaults {
172            let jsonrpc_key = TransportKey::from_protocol(TRANSPORT_PROTOCOL_JSONRPC, VERSION);
173            self.factories
174                .entry(jsonrpc_key)
175                .or_insert_with(|| Arc::new(JsonRpcTransportFactory::new(None)));
176
177            let rest_key = TransportKey::from_protocol(TRANSPORT_PROTOCOL_HTTP_JSON, VERSION);
178            self.factories
179                .entry(rest_key)
180                .or_insert_with(|| Arc::new(RestTransportFactory::new(None)));
181        }
182
183        A2AClientFactory {
184            factories: self.factories,
185            preferred_bindings: self.preferred_bindings,
186            interceptors: self.interceptors,
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal agent card advertising only the given interfaces.
    fn card_with_interfaces(supported_interfaces: Vec<AgentInterface>) -> AgentCard {
        AgentCard {
            name: "test".into(),
            description: "test agent".into(),
            version: "1.0".into(),
            supported_interfaces,
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text".into()],
            default_output_modes: vec!["text".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        }
    }

    #[test]
    fn test_transport_key_from_interface() {
        let iface = AgentInterface::new("http://localhost", "jsonrpc");
        let key = TransportKey::from_interface(&iface);
        assert_eq!(key.protocol, "jsonrpc");
        // Registration keys come from `from_protocol` while lookups use `from_interface`;
        // both must produce the same key for the same binding and version.
        assert_eq!(
            key,
            TransportKey::from_protocol("jsonrpc", &iface.protocol_version)
        );
    }

    #[test]
    fn test_transport_key_from_interface_unparseable_version() {
        let mut iface = AgentInterface::new("http://localhost", "rest");
        iface.protocol_version = "bad".to_string();
        assert_eq!(TransportKey::from_interface(&iface).major_version, 1); // defaults to 1

        iface.protocol_version = String::new();
        assert_eq!(TransportKey::from_interface(&iface).major_version, 1); // empty also defaults
    }

    #[test]
    fn test_transport_key_from_protocol() {
        let key = TransportKey::from_protocol("jsonrpc", "2.3.4");
        assert_eq!(key.protocol, "jsonrpc");
        assert_eq!(key.major_version, 2);
    }

    #[test]
    fn test_transport_key_from_protocol_major_only() {
        let key = TransportKey::from_protocol("rest", "3");
        assert_eq!(key.protocol, "rest");
        assert_eq!(key.major_version, 3);
    }

    #[test]
    fn test_builder_defaults() {
        let factory = A2AClientFactory::builder().build();
        assert_eq!(factory.factories.len(), 2); // jsonrpc + rest
        assert!(factory
            .factories
            .contains_key(&TransportKey::from_protocol(TRANSPORT_PROTOCOL_JSONRPC, VERSION)));
        assert!(factory
            .factories
            .contains_key(&TransportKey::from_protocol(TRANSPORT_PROTOCOL_HTTP_JSON, VERSION)));
        assert_eq!(
            factory.preferred_bindings,
            vec![
                TRANSPORT_PROTOCOL_JSONRPC.to_string(),
                TRANSPORT_PROTOCOL_HTTP_JSON.to_string(),
            ]
        );
        assert!(factory.interceptors.is_empty());
    }

    #[test]
    fn test_builder_no_defaults() {
        let factory = A2AClientFactory::builder().no_defaults().build();
        assert!(factory.factories.is_empty());
    }

    #[test]
    fn test_builder_preferred_bindings() {
        let factory = A2AClientFactory::builder()
            .preferred_bindings(vec!["grpc".to_string()])
            .build();
        assert_eq!(factory.preferred_bindings, vec!["grpc"]);
    }

    #[test]
    fn test_builder_with_interceptor() {
        use crate::middleware::LoggingInterceptor;
        let factory = A2AClientFactory::builder()
            .with_interceptor(Arc::new(LoggingInterceptor))
            .build();
        assert_eq!(factory.interceptors.len(), 1);
    }

    #[tokio::test]
    async fn test_create_from_card_no_matching_transport() {
        let factory = A2AClientFactory::builder().no_defaults().build();
        let card =
            card_with_interfaces(vec![AgentInterface::new("http://localhost", "unknown")]);
        let result = factory.create_from_card(&card).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_create_from_card_unknown_binding_with_defaults() {
        // Default transports are registered, but none matches the advertised binding.
        let factory = A2AClientFactory::builder().build();
        let card =
            card_with_interfaces(vec![AgentInterface::new("http://localhost", "unknown")]);
        assert!(factory.create_from_card(&card).await.is_err());
    }

    #[tokio::test]
    async fn test_create_from_card_no_interfaces() {
        let factory = A2AClientFactory::builder().build();
        let card = card_with_interfaces(vec![]);
        assert!(factory.create_from_card(&card).await.is_err());
    }
}