use a2a::*;
use std::collections::HashMap;
use std::sync::Arc;
use crate::client::A2AClient;
use crate::jsonrpc::JsonRpcTransportFactory;
use crate::middleware::CallInterceptor;
use crate::rest::RestTransportFactory;
use crate::transport::TransportFactory;
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct TransportKey {
protocol: String,
major_version: u64,
}
impl TransportKey {
fn from_interface(iface: &AgentInterface) -> Self {
let major = iface
.protocol_version
.split('.')
.next()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(1);
TransportKey {
protocol: iface.protocol_binding.clone(),
major_version: major,
}
}
fn from_protocol(protocol: &str, version: &str) -> Self {
let major = version
.split('.')
.next()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(1);
TransportKey {
protocol: protocol.to_string(),
major_version: major,
}
}
}
pub struct A2AClientFactory {
factories: HashMap<TransportKey, Arc<dyn TransportFactory>>,
preferred_bindings: Vec<String>,
interceptors: Vec<Arc<dyn CallInterceptor>>,
}
impl A2AClientFactory {
pub fn builder() -> A2AClientFactoryBuilder {
A2AClientFactoryBuilder::new()
}
pub async fn create_from_card(
&self,
card: &AgentCard,
) -> Result<A2AClient<Box<dyn crate::Transport>>, A2AError> {
let mut candidates: Vec<(usize, &AgentInterface, &Arc<dyn TransportFactory>)> = Vec::new();
for iface in &card.supported_interfaces {
let key = TransportKey::from_interface(iface);
if let Some(factory) = self.factories.get(&key) {
let priority = self
.preferred_bindings
.iter()
.position(|b| b == &iface.protocol_binding)
.unwrap_or(usize::MAX);
candidates.push((priority, iface, factory));
}
}
if candidates.is_empty() {
return Err(A2AError::unsupported_operation(
"no compatible transport found for agent card interfaces",
));
}
candidates.sort_by_key(|(prio, _, _)| *prio);
let mut last_err = None;
for (_prio, iface, factory) in &candidates {
match factory.create(card, iface).await {
Ok(transport) => {
return Ok(
A2AClient::new(transport).with_interceptors(self.interceptors.clone())
);
}
Err(e) => {
tracing::debug!(
protocol = %iface.protocol_binding,
url = %iface.url,
error = %e,
"transport creation failed, trying next"
);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| A2AError::internal("failed to create transport")))
}
}
pub struct A2AClientFactoryBuilder {
factories: HashMap<TransportKey, Arc<dyn TransportFactory>>,
preferred_bindings: Vec<String>,
interceptors: Vec<Arc<dyn CallInterceptor>>,
include_defaults: bool,
}
impl A2AClientFactoryBuilder {
fn new() -> Self {
A2AClientFactoryBuilder {
factories: HashMap::new(),
preferred_bindings: vec![
TRANSPORT_PROTOCOL_JSONRPC.to_string(),
TRANSPORT_PROTOCOL_HTTP_JSON.to_string(),
],
interceptors: Vec::new(),
include_defaults: true,
}
}
pub fn register(mut self, factory: Arc<dyn TransportFactory>) -> Self {
let key = TransportKey::from_protocol(factory.protocol(), VERSION);
self.factories.insert(key, factory);
self
}
pub fn preferred_bindings(mut self, bindings: Vec<String>) -> Self {
self.preferred_bindings = bindings;
self
}
pub fn with_interceptor(mut self, interceptor: Arc<dyn CallInterceptor>) -> Self {
self.interceptors.push(interceptor);
self
}
pub fn no_defaults(mut self) -> Self {
self.include_defaults = false;
self
}
pub fn build(mut self) -> A2AClientFactory {
if self.include_defaults {
let jsonrpc_key = TransportKey::from_protocol(TRANSPORT_PROTOCOL_JSONRPC, VERSION);
self.factories
.entry(jsonrpc_key)
.or_insert_with(|| Arc::new(JsonRpcTransportFactory::new(None)));
let rest_key = TransportKey::from_protocol(TRANSPORT_PROTOCOL_HTTP_JSON, VERSION);
self.factories
.entry(rest_key)
.or_insert_with(|| Arc::new(RestTransportFactory::new(None)));
}
A2AClientFactory {
factories: self.factories,
preferred_bindings: self.preferred_bindings,
interceptors: self.interceptors,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transport_key_from_interface() {
let iface = AgentInterface::new("http://localhost", "jsonrpc");
let key = TransportKey::from_interface(&iface);
assert_eq!(key.protocol, "jsonrpc");
}
#[test]
fn test_transport_key_from_interface_no_version() {
let mut iface = AgentInterface::new("http://localhost", "rest");
iface.protocol_version = "bad".to_string();
let key = TransportKey::from_interface(&iface);
assert_eq!(key.major_version, 1); }
#[test]
fn test_transport_key_from_protocol() {
let key = TransportKey::from_protocol("jsonrpc", "2.3.4");
assert_eq!(key.protocol, "jsonrpc");
assert_eq!(key.major_version, 2);
}
#[test]
fn test_builder_defaults() {
let factory = A2AClientFactory::builder().build();
assert_eq!(factory.factories.len(), 2); assert_eq!(factory.preferred_bindings.len(), 2);
}
#[test]
fn test_builder_no_defaults() {
let factory = A2AClientFactory::builder().no_defaults().build();
assert!(factory.factories.is_empty());
}
#[test]
fn test_builder_preferred_bindings() {
let factory = A2AClientFactory::builder()
.preferred_bindings(vec!["grpc".to_string()])
.build();
assert_eq!(factory.preferred_bindings, vec!["grpc"]);
}
#[test]
fn test_builder_with_interceptor() {
use crate::middleware::LoggingInterceptor;
let factory = A2AClientFactory::builder()
.with_interceptor(Arc::new(LoggingInterceptor))
.build();
assert_eq!(factory.interceptors.len(), 1);
}
#[tokio::test]
async fn test_create_from_card_no_matching_transport() {
let factory = A2AClientFactory::builder().no_defaults().build();
let card = AgentCard {
name: "test".into(),
description: "test agent".into(),
version: "1.0".into(),
supported_interfaces: vec![AgentInterface::new("http://localhost", "unknown")],
capabilities: AgentCapabilities::default(),
default_input_modes: vec!["text".into()],
default_output_modes: vec!["text".into()],
skills: vec![],
provider: None,
documentation_url: None,
icon_url: None,
security_schemes: None,
security_requirements: None,
signatures: None,
};
let result = factory.create_from_card(&card).await;
assert!(result.is_err());
}
}