Skip to main content

a2a_protocol_client/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! Fluent builder for [`A2aClient`].
5//!
6//! [`ClientBuilder`] validates configuration and assembles an [`A2aClient`]
7//! from its parts.
8//!
9//! # Example
10//!
11//! ```rust,no_run
12//! use a2a_protocol_client::{ClientBuilder, CredentialsStore};
13//! use a2a_protocol_client::auth::{AuthInterceptor, InMemoryCredentialsStore, SessionId};
14//! use std::sync::Arc;
15//!
16//! # fn example() -> Result<(), a2a_protocol_client::error::ClientError> {
17//! let store = Arc::new(InMemoryCredentialsStore::new());
18//! let session = SessionId::new("my-session");
19//! store.set(session.clone(), "bearer", "token".into());
20//!
21//! let client = ClientBuilder::new("http://localhost:8080")
22//!     .with_interceptor(AuthInterceptor::new(store, session))
23//!     .build()?;
24//! # Ok(())
25//! # }
26//! ```
27
28use std::time::Duration;
29
30use a2a_protocol_types::AgentCard;
31
32use crate::client::A2aClient;
33use crate::config::{ClientConfig, TlsConfig, BINDING_GRPC, BINDING_JSONRPC, BINDING_REST};
34use crate::error::{ClientError, ClientResult};
35use crate::interceptor::{CallInterceptor, InterceptorChain};
36use crate::transport::{JsonRpcTransport, RestTransport, Transport};
37
38// ── ClientBuilder ─────────────────────────────────────────────────────────────
39
40/// Builder for [`A2aClient`].
41///
42/// Start with [`ClientBuilder::new`] (URL) or [`ClientBuilder::from_card`]
43/// (agent card auto-configuration).
44pub struct ClientBuilder {
45    endpoint: String,
46    transport_override: Option<Box<dyn Transport>>,
47    interceptors: InterceptorChain,
48    config: ClientConfig,
49    preferred_binding: Option<String>,
50}
51
52impl ClientBuilder {
53    /// Creates a builder targeting `endpoint`.
54    ///
55    /// The endpoint is passed directly to the selected transport; it should be
56    /// the full base URL of the agent (e.g. `http://localhost:8080`).
57    #[must_use]
58    pub fn new(endpoint: impl Into<String>) -> Self {
59        Self {
60            endpoint: endpoint.into(),
61            transport_override: None,
62            interceptors: InterceptorChain::new(),
63            config: ClientConfig::default(),
64            preferred_binding: None,
65        }
66    }
67
68    /// Creates a builder pre-configured from an [`AgentCard`].
69    ///
70    /// Selects the first supported interface from the card.
71    #[must_use]
72    pub fn from_card(card: &AgentCard) -> Self {
73        let (endpoint, binding) = card
74            .supported_interfaces
75            .first()
76            .map(|i| (i.url.clone(), i.protocol_binding.clone()))
77            .unwrap_or_default();
78
79        Self {
80            endpoint,
81            transport_override: None,
82            interceptors: InterceptorChain::new(),
83            config: ClientConfig::default(),
84            preferred_binding: Some(binding),
85        }
86    }
87
88    // ── Configuration ─────────────────────────────────────────────────────────
89
90    /// Sets the per-request timeout for non-streaming calls.
91    #[must_use]
92    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
93        self.config.request_timeout = timeout;
94        self
95    }
96
97    /// Sets the timeout for establishing SSE stream connections.
98    ///
99    /// Once the stream is established, this timeout no longer applies.
100    /// Defaults to 30 seconds.
101    #[must_use]
102    pub const fn with_stream_connect_timeout(mut self, timeout: Duration) -> Self {
103        self.config.stream_connect_timeout = timeout;
104        self
105    }
106
107    /// Sets the TCP connection timeout (DNS + handshake).
108    ///
109    /// Defaults to 10 seconds. Prevents hanging for the OS default (~2 min)
110    /// when the server is unreachable.
111    #[must_use]
112    pub const fn with_connection_timeout(mut self, timeout: Duration) -> Self {
113        self.config.connection_timeout = timeout;
114        self
115    }
116
117    /// Sets the preferred protocol binding.
118    ///
119    /// Overrides any binding derived from the agent card.
120    #[must_use]
121    pub fn with_protocol_binding(mut self, binding: impl Into<String>) -> Self {
122        self.preferred_binding = Some(binding.into());
123        self
124    }
125
126    /// Sets the accepted output modes sent in `SendMessage` configurations.
127    #[must_use]
128    pub fn with_accepted_output_modes(mut self, modes: Vec<String>) -> Self {
129        self.config.accepted_output_modes = modes;
130        self
131    }
132
133    /// Sets the history length to request in task responses.
134    #[must_use]
135    pub const fn with_history_length(mut self, length: u32) -> Self {
136        self.config.history_length = Some(length);
137        self
138    }
139
140    /// Sets `return_immediately` for `SendMessage` calls.
141    #[must_use]
142    pub const fn with_return_immediately(mut self, val: bool) -> Self {
143        self.config.return_immediately = val;
144        self
145    }
146
147    /// Provides a fully custom transport implementation.
148    ///
149    /// Overrides the transport that would normally be built from the endpoint
150    /// URL and protocol preference.
151    #[must_use]
152    pub fn with_custom_transport(mut self, transport: impl Transport) -> Self {
153        self.transport_override = Some(Box::new(transport));
154        self
155    }
156
157    /// Disables TLS (plain HTTP only).
158    #[must_use]
159    pub const fn without_tls(mut self) -> Self {
160        self.config.tls = TlsConfig::Disabled;
161        self
162    }
163
164    /// Adds an interceptor to the chain.
165    ///
166    /// Interceptors are run in the order they are added.
167    #[must_use]
168    pub fn with_interceptor<I: CallInterceptor>(mut self, interceptor: I) -> Self {
169        self.interceptors.push(interceptor);
170        self
171    }
172
173    // ── Build ─────────────────────────────────────────────────────────────────
174
175    /// Validates configuration and constructs the [`A2aClient`].
176    ///
177    /// # Errors
178    ///
179    /// - [`ClientError::InvalidEndpoint`] if the endpoint URL is malformed.
180    /// - [`ClientError::Transport`] if the selected transport cannot be
181    ///   initialized.
182    pub fn build(self) -> ClientResult<A2aClient> {
183        if self.config.request_timeout.is_zero() {
184            return Err(ClientError::Transport(
185                "request_timeout must be non-zero".into(),
186            ));
187        }
188        if self.config.stream_connect_timeout.is_zero() {
189            return Err(ClientError::Transport(
190                "stream_connect_timeout must be non-zero".into(),
191            ));
192        }
193
194        let transport: Box<dyn Transport> = if let Some(t) = self.transport_override {
195            t
196        } else {
197            let binding = self
198                .preferred_binding
199                .unwrap_or_else(|| BINDING_JSONRPC.into());
200
201            match binding.as_str() {
202                BINDING_JSONRPC => {
203                    let t = JsonRpcTransport::with_timeouts(
204                        &self.endpoint,
205                        self.config.request_timeout,
206                        self.config.stream_connect_timeout,
207                    )?;
208                    Box::new(t)
209                }
210                BINDING_REST => {
211                    let t = RestTransport::with_timeouts(
212                        &self.endpoint,
213                        self.config.request_timeout,
214                        self.config.stream_connect_timeout,
215                    )?;
216                    Box::new(t)
217                }
218                BINDING_GRPC => {
219                    return Err(ClientError::Transport(
220                        "gRPC transport is not supported in this version".into(),
221                    ));
222                }
223                other => {
224                    return Err(ClientError::Transport(format!(
225                        "unknown protocol binding: {other}"
226                    )));
227                }
228            }
229        };
230
231        Ok(A2aClient::new(transport, self.interceptors, self.config))
232    }
233}
234
235impl std::fmt::Debug for ClientBuilder {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        f.debug_struct("ClientBuilder")
238            .field("endpoint", &self.endpoint)
239            .field("preferred_binding", &self.preferred_binding)
240            .finish_non_exhaustive()
241    }
242}
243
244// ── Tests ─────────────────────────────────────────────────────────────────────
245
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{AgentCapabilities, AgentInterface};

    /// Returns a minimal agent card exposing exactly `supported_interfaces`.
    fn card_with(supported_interfaces: Vec<AgentInterface>) -> AgentCard {
        AgentCard {
            name: "test".into(),
            version: "1.0".into(),
            description: "A test agent".into(),
            supported_interfaces,
            provider: None,
            icon_url: None,
            documentation_url: None,
            capabilities: AgentCapabilities::none(),
            security_schemes: None,
            security_requirements: None,
            default_input_modes: vec![],
            default_output_modes: vec![],
            skills: vec![],
            signatures: None,
        }
    }

    #[test]
    fn builder_defaults_to_jsonrpc() {
        let client = ClientBuilder::new("http://localhost:8080")
            .build()
            .expect("build");
        // Verify it built without error.
        let _ = client;
    }

    #[test]
    fn builder_rest_transport() {
        let client = ClientBuilder::new("http://localhost:8080")
            .with_protocol_binding(BINDING_REST)
            .build()
            .expect("build");
        let _ = client;
    }

    #[test]
    fn builder_grpc_returns_error() {
        let result = ClientBuilder::new("http://localhost:8080")
            .with_protocol_binding(BINDING_GRPC)
            .build();
        assert!(matches!(result, Err(ClientError::Transport(_))));
    }

    #[test]
    fn builder_unknown_binding_returns_error() {
        let result = ClientBuilder::new("http://localhost:8080")
            .with_protocol_binding("CARRIER-PIGEON")
            .build();
        assert!(matches!(result, Err(ClientError::Transport(_))));
    }

    #[test]
    fn builder_invalid_url_returns_error() {
        let result = ClientBuilder::new("not-a-url").build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_zero_request_timeout_returns_error() {
        let result = ClientBuilder::new("http://localhost:8080")
            .with_timeout(Duration::ZERO)
            .build();
        assert!(matches!(result, Err(ClientError::Transport(_))));
    }

    #[test]
    fn builder_zero_stream_connect_timeout_returns_error() {
        let result = ClientBuilder::new("http://localhost:8080")
            .with_stream_connect_timeout(Duration::ZERO)
            .build();
        assert!(matches!(result, Err(ClientError::Transport(_))));
    }

    #[test]
    fn builder_from_card_uses_card_url() {
        let card = card_with(vec![AgentInterface {
            url: "http://localhost:9090".into(),
            protocol_binding: "JSONRPC".into(),
            protocol_version: "1.0.0".into(),
            tenant: None,
        }]);

        let client = ClientBuilder::from_card(&card).build().expect("build");
        let _ = client;
    }

    #[test]
    fn builder_from_card_without_interfaces_returns_error() {
        // A card with no interfaces provides no endpoint, so building must
        // fail rather than yield a client.
        let card = card_with(vec![]);
        assert!(ClientBuilder::from_card(&card).build().is_err());
    }

    #[test]
    fn builder_with_timeout_sets_config() {
        let client = ClientBuilder::new("http://localhost:8080")
            .with_timeout(Duration::from_secs(60))
            .build()
            .expect("build");
        assert_eq!(client.config().request_timeout, Duration::from_secs(60));
    }
}