mod transport_factory;
use std::time::Duration;
use a2a_protocol_types::AgentCard;
use crate::config::{ClientConfig, TlsConfig};
use crate::error::{ClientError, ClientResult};
use crate::interceptor::{CallInterceptor, InterceptorChain};
use crate::retry::RetryPolicy;
use crate::transport::Transport;
/// Protocol major version this client is written against.
///
/// Only compiled with the `tracing` feature: `ClientBuilder::from_card` compares it with
/// the leading (major) component of the selected interface's `protocol_version` and
/// logs a warning on mismatch.
#[cfg(feature = "tracing")]
const SUPPORTED_PROTOCOL_MAJOR: u32 = 1;
/// Builder that collects the endpoint, transport, interceptors, retry policy and
/// [`ClientConfig`] used to construct a client.
///
/// Create one with [`ClientBuilder::new`] or [`ClientBuilder::from_card`], adjust it
/// with the `with_*` / `without_*` setters, then finish it with `build` (defined
/// outside this view — NOTE(review): presumably in the `transport_factory` submodule).
pub struct ClientBuilder {
    /// Endpoint URL the client will target.
    pub(super) endpoint: String,
    /// Caller-supplied transport set by `with_custom_transport`; `None` when none was given.
    pub(super) transport_override: Option<Box<dyn Transport>>,
    /// Interceptors registered through `with_interceptor`.
    pub(super) interceptors: InterceptorChain,
    /// Client configuration (timeouts, TLS, tenant, output modes, ...) edited by the setters.
    pub(super) config: ClientConfig,
    /// Preferred protocol binding name (e.g. `"JSONRPC"`), if one was chosen.
    pub(super) preferred_binding: Option<String>,
    /// Retry policy set by `with_retry_policy`; `None` when not configured.
    pub(super) retry_policy: Option<RetryPolicy>,
}
impl ClientBuilder {
    /// Creates a builder that targets `endpoint`.
    ///
    /// Every other setting starts empty or at its default: `ClientConfig::default()`,
    /// an empty interceptor chain, no custom transport, no preferred protocol binding
    /// and no retry policy.
    #[must_use]
    pub fn new(endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            transport_override: None,
            interceptors: InterceptorChain::new(),
            config: ClientConfig::default(),
            preferred_binding: None,
            retry_policy: None,
        }
    }

    /// Creates a builder pre-configured from an agent card.
    ///
    /// Only the first entry of `card.supported_interfaces` is consulted. It supplies
    /// the endpoint URL, the preferred protocol binding and the tenant. Later entries
    /// are ignored, and all other settings match [`ClientBuilder::new`].
    ///
    /// With the `tracing` feature enabled, a warning is logged when that interface
    /// advertises a non-empty `protocol_version` whose leading (major) component is not
    /// `SUPPORTED_PROTOCOL_MAJOR` or cannot be parsed as an integer. A mismatch is only
    /// logged; the builder is still returned.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidEndpoint`] if the card lists no supported interfaces.
    pub fn from_card(card: &AgentCard) -> ClientResult<Self> {
        let first = card.supported_interfaces.first().ok_or_else(|| {
            ClientError::InvalidEndpoint("agent card has no supported interfaces".into())
        })?;

        #[cfg(feature = "tracing")]
        {
            // Check the interface that was actually selected above (not a fresh lookup),
            // so the warning always describes the endpoint this builder will use.
            let version: &str = &first.protocol_version;
            let major = version
                .split('.')
                .next()
                .and_then(|s| s.parse::<u32>().ok());
            if !version.is_empty() && major != Some(SUPPORTED_PROTOCOL_MAJOR) {
                trace_warn!(
                    agent = %card.name,
                    protocol_version = %version,
                    supported_major = SUPPORTED_PROTOCOL_MAJOR,
                    "agent protocol version may be incompatible with this client"
                );
            }
        }

        // Start from `new` so the defaults are defined in exactly one place, then
        // apply the card-derived settings on top.
        let mut builder = Self::new(first.url.clone());
        builder.preferred_binding = Some(first.protocol_binding.clone());
        builder.config.tenant = first.tenant.clone();
        Ok(builder)
    }

    /// Sets `config.request_timeout`.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.config.request_timeout = timeout;
        self
    }

    /// Sets `config.stream_connect_timeout`.
    #[must_use]
    pub const fn with_stream_connect_timeout(mut self, timeout: Duration) -> Self {
        self.config.stream_connect_timeout = timeout;
        self
    }

    /// Sets `config.connection_timeout`.
    #[must_use]
    pub const fn with_connection_timeout(mut self, timeout: Duration) -> Self {
        self.config.connection_timeout = timeout;
        self
    }

    /// Sets the preferred protocol binding by name (e.g. `"JSONRPC"`).
    ///
    /// This overrides any binding chosen by [`ClientBuilder::from_card`].
    #[must_use]
    pub fn with_protocol_binding(mut self, binding: impl Into<String>) -> Self {
        self.preferred_binding = Some(binding.into());
        self
    }

    /// Replaces `config.accepted_output_modes` with `modes`.
    ///
    /// The previous list is discarded, not extended.
    #[must_use]
    pub fn with_accepted_output_modes(mut self, modes: Vec<String>) -> Self {
        self.config.accepted_output_modes = modes;
        self
    }

    /// Sets `config.history_length` to `Some(length)`.
    #[must_use]
    pub const fn with_history_length(mut self, length: u32) -> Self {
        self.config.history_length = Some(length);
        self
    }

    /// Sets `config.tenant`.
    ///
    /// This overrides any tenant taken from an agent card by [`ClientBuilder::from_card`].
    #[must_use]
    pub fn with_tenant(mut self, tenant: impl Into<String>) -> Self {
        self.config.tenant = Some(tenant.into());
        self
    }

    /// Sets `config.return_immediately` to `val`.
    #[must_use]
    pub const fn with_return_immediately(mut self, val: bool) -> Self {
        self.config.return_immediately = val;
        self
    }

    /// Supplies a caller-provided [`Transport`], stored as the transport override.
    ///
    /// NOTE(review): presumably `build` uses this instead of deriving a transport from
    /// the endpoint and binding — confirm in `transport_factory`.
    #[must_use]
    pub fn with_custom_transport(mut self, transport: impl Transport) -> Self {
        self.transport_override = Some(Box::new(transport));
        self
    }

    /// Sets `config.tls` to [`TlsConfig::Disabled`].
    #[must_use]
    pub const fn without_tls(mut self) -> Self {
        self.config.tls = TlsConfig::Disabled;
        self
    }

    /// Sets the retry policy.
    ///
    /// Without this call the builder's retry policy stays `None`.
    #[must_use]
    pub const fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
        self.retry_policy = Some(policy);
        self
    }

    /// Adds `interceptor` to the builder's chain via `InterceptorChain::push`.
    ///
    /// May be called repeatedly to register several interceptors.
    #[must_use]
    pub fn with_interceptor<I: CallInterceptor>(mut self, interceptor: I) -> Self {
        self.interceptors.push(interceptor);
        self
    }
}
impl std::fmt::Debug for ClientBuilder {
    /// Renders the builder with only its endpoint and preferred binding.
    ///
    /// The remaining fields are omitted, and `finish_non_exhaustive` marks that
    /// omission with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ClientBuilder");
        out.field("endpoint", &self.endpoint);
        out.field("preferred_binding", &self.preferred_binding);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use a2a_protocol_types::{AgentCapabilities, AgentCard, AgentInterface};

    /// Returns a JSON-RPC interface entry with no tenant.
    fn jsonrpc_interface(url: &'static str, protocol_version: &'static str) -> AgentInterface {
        AgentInterface {
            url: url.into(),
            protocol_binding: "JSONRPC".into(),
            protocol_version: protocol_version.into(),
            tenant: None,
        }
    }

    /// Returns a minimal agent card with the given interfaces.
    ///
    /// Every optional field is left empty, so each test only spells out what it exercises.
    fn card_with_interfaces(
        name: &'static str,
        description: &'static str,
        supported_interfaces: Vec<AgentInterface>,
    ) -> AgentCard {
        AgentCard {
            url: None,
            name: name.into(),
            version: "1.0".into(),
            description: description.into(),
            supported_interfaces,
            provider: None,
            icon_url: None,
            documentation_url: None,
            capabilities: AgentCapabilities::none(),
            security_schemes: None,
            security_requirements: None,
            default_input_modes: vec![],
            default_output_modes: vec![],
            skills: vec![],
            signatures: None,
        }
    }

    #[test]
    fn builder_from_card_uses_card_url() {
        let card = card_with_interfaces(
            "test",
            "A test agent",
            vec![jsonrpc_interface("http://localhost:9090", "1.0.0")],
        );
        let builder = ClientBuilder::from_card(&card).expect("card has an interface");
        assert_eq!(builder.endpoint, "http://localhost:9090");
        assert_eq!(builder.preferred_binding.as_deref(), Some("JSONRPC"));
        let _client = builder.build().expect("build");
    }

    #[test]
    fn builder_from_card_uses_first_interface() {
        let card = card_with_interfaces(
            "multi",
            "Two interfaces",
            vec![
                jsonrpc_interface("http://localhost:9093", "1.0.0"),
                jsonrpc_interface("http://localhost:9094", "1.0.0"),
            ],
        );
        let builder = ClientBuilder::from_card(&card).expect("card has interfaces");
        assert_eq!(builder.endpoint, "http://localhost:9093");
    }

    #[test]
    fn builder_from_card_propagates_tenant() {
        let mut iface = jsonrpc_interface("http://localhost:9092", "1.0.0");
        iface.tenant = Some("acme".into());
        let card = card_with_interfaces("tenant", "Tenant test", vec![iface]);
        let builder = ClientBuilder::from_card(&card).expect("card has an interface");
        assert_eq!(builder.config.tenant.as_deref(), Some("acme"));
    }

    #[test]
    fn builder_with_timeout_sets_config() {
        let client = ClientBuilder::new("http://localhost:8080")
            .with_timeout(Duration::from_secs(60))
            .build()
            .expect("build");
        assert_eq!(client.config().request_timeout, Duration::from_secs(60));
    }

    #[test]
    fn builder_from_card_empty_interfaces_returns_error() {
        let card = card_with_interfaces("empty", "No interfaces", vec![]);
        let result = ClientBuilder::from_card(&card);
        assert!(
            matches!(result, Err(ClientError::InvalidEndpoint(_))),
            "empty interfaces should return InvalidEndpoint"
        );
    }

    #[test]
    fn builder_with_return_immediately() {
        let client = ClientBuilder::new("http://localhost:8080")
            .with_return_immediately(true)
            .build()
            .expect("build");
        assert!(client.config().return_immediately);
    }

    #[test]
    fn builder_with_history_length() {
        let client = ClientBuilder::new("http://localhost:8080")
            .with_history_length(10)
            .build()
            .expect("build");
        assert_eq!(client.config().history_length, Some(10));
    }

    #[test]
    fn builder_with_protocol_binding_sets_preference() {
        let builder = ClientBuilder::new("http://localhost:8080").with_protocol_binding("REST");
        assert_eq!(builder.preferred_binding.as_deref(), Some("REST"));
    }

    #[test]
    fn builder_with_tenant_sets_config() {
        let builder = ClientBuilder::new("http://localhost:8080").with_tenant("t1");
        assert_eq!(builder.config.tenant.as_deref(), Some("t1"));
    }

    #[test]
    fn builder_debug_contains_fields() {
        let builder = ClientBuilder::new("http://localhost:8080");
        let debug = format!("{builder:?}");
        assert!(
            debug.contains("ClientBuilder"),
            "debug output missing struct name: {debug}"
        );
        assert!(
            debug.contains("http://localhost:8080"),
            "debug output missing endpoint: {debug}"
        );
        assert!(
            debug.contains("preferred_binding"),
            "debug output missing preferred_binding: {debug}"
        );
    }

    #[test]
    fn builder_from_card_mismatched_version() {
        // A mismatched major version must only warn (under `tracing`), never fail.
        let card = card_with_interfaces(
            "mismatch",
            "Version mismatch test",
            vec![jsonrpc_interface("http://localhost:9091", "99.0.0")],
        );
        let builder = ClientBuilder::from_card(&card).expect("mismatch is not an error");
        assert_eq!(builder.endpoint, "http://localhost:9091");
        assert_eq!(builder.preferred_binding.as_deref(), Some("JSONRPC"));
    }

    #[test]
    fn builder_with_connection_timeout_and_retry_policy() {
        let builder = ClientBuilder::new("http://localhost:8080")
            .with_connection_timeout(Duration::from_secs(5))
            .with_retry_policy(RetryPolicy::default());
        assert!(builder.retry_policy.is_some(), "retry policy should be stored");
        let client = builder.build().expect("build");
        assert_eq!(client.config().connection_timeout, Duration::from_secs(5));
    }

    #[test]
    fn builder_with_stream_connect_timeout() {
        let client = ClientBuilder::new("http://localhost:8080")
            .with_stream_connect_timeout(Duration::from_secs(15))
            .build()
            .expect("build");
        assert_eq!(
            client.config().stream_connect_timeout,
            Duration::from_secs(15)
        );
    }
}