1use std::fmt;
12
13use a2a_protocol_types::{A2aError, TaskId};
14
/// Errors produced by the A2A client.
///
/// Marked `#[non_exhaustive]`, so code outside this crate that matches on it
/// must include a wildcard arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum ClientError {
    /// An error reported by the underlying `hyper` HTTP implementation.
    Http(hyper::Error),

    /// An HTTP client failure described by a free-form message.
    HttpClient(String),

    /// A JSON (de)serialization failure from `serde_json`.
    Serialization(serde_json::Error),

    /// A protocol-level error carried as an [`A2aError`].
    Protocol(A2aError),

    /// A transport-layer failure described by a free-form message.
    Transport(String),

    /// An endpoint was invalid; the message describes the problem.
    InvalidEndpoint(String),

    /// The server answered with an HTTP status the client did not expect.
    UnexpectedStatus {
        /// The numeric HTTP status code that was received.
        status: u16,

        /// The response body as text, included in the display message.
        body: String,
    },

    /// Authentication is required to continue the identified task.
    AuthRequired {
        /// Identifier of the task that requires authentication.
        task_id: TaskId,
    },

    /// A timeout occurred; the message describes what timed out.
    Timeout(String),

    /// The protocol binding in use does not match what the agent exposes.
    ///
    /// The display message advises checking the agent card's
    /// `supported_interfaces`.
    ProtocolBindingMismatch(String),
}
63
64impl fmt::Display for ClientError {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::Http(e) => write!(f, "HTTP error: {e}"),
68 Self::HttpClient(msg) => write!(f, "HTTP client error: {msg}"),
69 Self::Serialization(e) => write!(f, "serialization error: {e}"),
70 Self::Protocol(e) => write!(f, "protocol error: {e}"),
71 Self::Transport(msg) => write!(f, "transport error: {msg}"),
72 Self::InvalidEndpoint(msg) => write!(f, "invalid endpoint: {msg}"),
73 Self::UnexpectedStatus { status, body } => {
74 write!(f, "unexpected HTTP status {status}: {body}")
75 }
76 Self::AuthRequired { task_id } => {
77 write!(f, "authentication required for task: {task_id}")
78 }
79 Self::Timeout(msg) => write!(f, "timeout: {msg}"),
80 Self::ProtocolBindingMismatch(msg) => {
81 write!(
82 f,
83 "protocol binding mismatch: {msg}; check the agent card's supported_interfaces"
84 )
85 }
86 }
87 }
88}
89
90impl std::error::Error for ClientError {
91 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
92 match self {
93 Self::Http(e) => Some(e),
94 Self::Serialization(e) => Some(e),
95 Self::Protocol(e) => Some(e),
96 _ => None,
97 }
98 }
99}
100
101impl From<A2aError> for ClientError {
102 fn from(e: A2aError) -> Self {
103 Self::Protocol(e)
104 }
105}
106
107impl From<hyper::Error> for ClientError {
108 fn from(e: hyper::Error) -> Self {
109 Self::Http(e)
110 }
111}
112
113impl From<serde_json::Error> for ClientError {
114 fn from(e: serde_json::Error) -> Self {
115 Self::Serialization(e)
116 }
117}
118
119pub type ClientResult<T> = Result<T, ClientError>;
123
#[cfg(test)]
mod tests {
    //! Unit tests for `ClientError`: display text, `source()` chaining,
    //! `From` conversions and retry classification.

    use super::*;
    use a2a_protocol_types::ErrorCode;

    #[test]
    fn client_error_display_http_client() {
        let e = ClientError::HttpClient("connection refused".into());
        assert!(e.to_string().contains("connection refused"));
    }

    #[test]
    fn client_error_display_protocol() {
        let a2a = A2aError::task_not_found("task-99");
        let e = ClientError::Protocol(a2a);
        assert!(e.to_string().contains("task-99"));
    }

    #[test]
    fn client_error_from_a2a_error() {
        let a2a = A2aError::new(ErrorCode::TaskNotFound, "missing");
        let e: ClientError = a2a.into();
        assert!(matches!(e, ClientError::Protocol(_)));
    }

    #[test]
    fn client_error_unexpected_status() {
        let e = ClientError::UnexpectedStatus {
            status: 404,
            body: "Not Found".into(),
        };
        assert!(e.to_string().contains("404"));
    }

    #[test]
    fn timeout_is_retryable_transport_is_not() {
        let timeout = ClientError::Timeout("request timed out".into());
        assert!(timeout.is_retryable(), "Timeout errors must be retryable");

        let transport = ClientError::Transport("config error".into());
        assert!(
            !transport.is_retryable(),
            "Transport errors must not be retryable"
        );
    }

    #[test]
    fn client_error_source_http() {
        use std::error::Error;
        let http_err: ClientError = ClientError::HttpClient("test".into());
        assert!(http_err.source().is_none());

        let ser_err =
            ClientError::Serialization(serde_json::from_str::<String>("not json").unwrap_err());
        assert!(
            ser_err.source().is_some(),
            "Serialization error should have a source"
        );

        let proto_err = ClientError::Protocol(a2a_protocol_types::A2aError::task_not_found("t"));
        assert!(
            proto_err.source().is_some(),
            "Protocol error should have a source"
        );

        let transport_err = ClientError::Transport("config".into());
        assert!(transport_err.source().is_none());
    }

    #[test]
    fn client_error_display_transport() {
        let e = ClientError::Transport("socket closed".into());
        let s = e.to_string();
        assert!(s.contains("transport error"), "missing prefix: {s}");
        assert!(s.contains("socket closed"), "missing message: {s}");
    }

    #[test]
    fn client_error_display_invalid_endpoint() {
        let e = ClientError::InvalidEndpoint("bad url".into());
        let s = e.to_string();
        assert!(s.contains("invalid endpoint"), "missing prefix: {s}");
        assert!(s.contains("bad url"), "missing message: {s}");
    }

    #[test]
    fn client_error_display_auth_required() {
        let e = ClientError::AuthRequired {
            task_id: TaskId::new("task-7"),
        };
        let s = e.to_string();
        assert!(s.contains("authentication required"), "missing prefix: {s}");
        assert!(s.contains("task-7"), "missing task_id: {s}");
    }

    #[test]
    fn client_error_display_timeout() {
        let e = ClientError::Timeout("30s elapsed".into());
        let s = e.to_string();
        assert!(s.contains("timeout"), "missing prefix: {s}");
        assert!(s.contains("30s elapsed"), "missing message: {s}");
    }

    #[test]
    fn client_error_display_protocol_binding_mismatch() {
        let e = ClientError::ProtocolBindingMismatch("expected REST".into());
        let s = e.to_string();
        assert!(
            s.contains("protocol binding mismatch"),
            "missing prefix: {s}"
        );
        assert!(s.contains("expected REST"), "missing message: {s}");
        assert!(s.contains("supported_interfaces"), "missing advice: {s}");
    }

    #[test]
    fn client_error_display_serialization() {
        let e = ClientError::Serialization(serde_json::from_str::<String>("bad").unwrap_err());
        let s = e.to_string();
        assert!(s.contains("serialization error"), "missing prefix: {s}");
    }

    #[test]
    fn client_error_display_unexpected_status() {
        let e = ClientError::UnexpectedStatus {
            status: 500,
            body: "Internal Server Error".into(),
        };
        let s = e.to_string();
        assert!(s.contains("500"), "missing status code: {s}");
        assert!(s.contains("Internal Server Error"), "missing body: {s}");
    }

    #[test]
    fn client_error_source_none_for_string_variants() {
        use std::error::Error;
        let cases: Vec<ClientError> = vec![
            ClientError::HttpClient("msg".into()),
            ClientError::Transport("msg".into()),
            ClientError::InvalidEndpoint("msg".into()),
            ClientError::UnexpectedStatus {
                status: 404,
                body: String::new(),
            },
            ClientError::AuthRequired {
                task_id: TaskId::new("t"),
            },
            ClientError::Timeout("msg".into()),
            ClientError::ProtocolBindingMismatch("msg".into()),
        ];
        for e in &cases {
            // Print the variant itself (derived `Debug`) rather than an
            // opaque `Discriminant(n)` so a failure names the culprit.
            assert!(e.source().is_none(), "{e:?} should have no source");
        }
    }

    #[tokio::test]
    async fn client_error_display_and_source_http() {
        use http_body_util::{BodyExt, Full};
        use hyper::body::Bytes;
        use std::error::Error;
        use tokio::io::AsyncWriteExt;

        // Obtain a real `hyper::Error` by provoking one: the server advertises
        // a 1000-byte body but sends only 5 bytes before closing.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4096];
            // Read the incoming request before replying.
            let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf).await;
            let resp = "HTTP/1.1 200 OK\r\ncontent-length: 1000\r\n\r\nhello";
            let _ = stream.write_all(resp.as_bytes()).await;
            drop(stream);
        });

        let client: hyper_util::client::legacy::Client<
            hyper_util::client::legacy::connect::HttpConnector,
            Full<Bytes>,
        > = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
            .build(hyper_util::client::legacy::connect::HttpConnector::new());

        let req = hyper::Request::builder()
            .uri(format!("http://127.0.0.1:{}", addr.port()))
            .body(Full::new(Bytes::new()))
            .unwrap();

        let resp = client.request(req).await.unwrap();
        // A body that ends before its declared length must fail to collect.
        // Treat `Ok` as a test failure instead of silently skipping the
        // assertions below.
        let hyper_err = match resp.collect().await {
            Err(e) => e,
            Ok(_) => panic!("truncated body should produce a hyper::Error"),
        };

        let client_err: ClientError = ClientError::Http(hyper_err);

        let display = client_err.to_string();
        assert!(display.contains("HTTP error"), "Display: {display}");

        assert!(
            client_err.source().is_some(),
            "Http variant should have a source"
        );
    }

    #[test]
    fn client_error_from_serde_json_error() {
        let serde_err = serde_json::from_str::<String>("not json").unwrap_err();
        let e: ClientError = serde_err.into();
        assert!(matches!(e, ClientError::Serialization(_)));
    }

    #[test]
    fn retryable_classification_exhaustive() {
        assert!(ClientError::HttpClient("conn reset".into()).is_retryable());
        assert!(ClientError::Timeout("deadline".into()).is_retryable());
        assert!(ClientError::UnexpectedStatus {
            status: 429,
            body: String::new()
        }
        .is_retryable());
        assert!(ClientError::UnexpectedStatus {
            status: 502,
            body: String::new()
        }
        .is_retryable());
        assert!(ClientError::UnexpectedStatus {
            status: 503,
            body: String::new()
        }
        .is_retryable());
        assert!(ClientError::UnexpectedStatus {
            status: 504,
            body: String::new()
        }
        .is_retryable());

        assert!(!ClientError::Transport("bad config".into()).is_retryable());
        assert!(!ClientError::InvalidEndpoint("bad url".into()).is_retryable());
        assert!(!ClientError::UnexpectedStatus {
            status: 400,
            body: String::new()
        }
        .is_retryable());
        assert!(!ClientError::UnexpectedStatus {
            status: 401,
            body: String::new()
        }
        .is_retryable());
        assert!(!ClientError::UnexpectedStatus {
            status: 404,
            body: String::new()
        }
        .is_retryable());
        assert!(!ClientError::ProtocolBindingMismatch("wrong".into()).is_retryable());
        assert!(!ClientError::AuthRequired {
            task_id: TaskId::new("t")
        }
        .is_retryable());
    }
}