1use std::fmt;
13
14use a2a_protocol_types::error::{A2aError, ErrorCode};
15use a2a_protocol_types::task::TaskId;
16
17#[derive(Debug)]
23#[non_exhaustive]
24pub enum ServerError {
25 TaskNotFound(TaskId),
27 TaskNotCancelable(TaskId),
29 InvalidParams(String),
31 Serialization(serde_json::Error),
33 Http(hyper::Error),
35 HttpClient(String),
37 Transport(String),
39 PushNotSupported,
41 Internal(String),
43 MethodNotFound(String),
45 Protocol(A2aError),
47 PayloadTooLarge(String),
49 UnsupportedOperation(String),
52 InvalidStateTransition {
54 task_id: TaskId,
56 from: a2a_protocol_types::task::TaskState,
58 to: a2a_protocol_types::task::TaskState,
60 },
61}
62
63impl fmt::Display for ServerError {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 match self {
66 Self::TaskNotFound(id) => write!(f, "task not found: {id}"),
67 Self::TaskNotCancelable(id) => write!(f, "task not cancelable: {id}"),
68 Self::InvalidParams(msg) => write!(f, "invalid params: {msg}"),
69 Self::Serialization(e) => write!(f, "serialization error: {e}"),
70 Self::Http(e) => write!(f, "HTTP error: {e}"),
71 Self::HttpClient(msg) => write!(f, "HTTP client error: {msg}"),
72 Self::Transport(msg) => write!(f, "transport error: {msg}"),
73 Self::PushNotSupported => f.write_str("push notifications not supported"),
74 Self::UnsupportedOperation(msg) => write!(f, "unsupported operation: {msg}"),
75 Self::Internal(msg) => write!(f, "internal error: {msg}"),
76 Self::MethodNotFound(m) => write!(f, "method not found: {m}"),
77 Self::Protocol(e) => write!(f, "protocol error: {e}"),
78 Self::PayloadTooLarge(msg) => write!(f, "payload too large: {msg}"),
79 Self::InvalidStateTransition { task_id, from, to } => {
80 write!(
81 f,
82 "invalid state transition for task {task_id}: {from} → {to}"
83 )
84 }
85 }
86 }
87}
88
89impl std::error::Error for ServerError {
90 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
91 match self {
92 Self::Serialization(e) => Some(e),
93 Self::Http(e) => Some(e),
94 Self::Protocol(e) => Some(e),
95 _ => None,
96 }
97 }
98}
99
100impl ServerError {
101 #[must_use]
116 pub fn to_a2a_error(&self) -> A2aError {
117 match self {
118 Self::TaskNotFound(id) => A2aError::task_not_found(id),
119 Self::TaskNotCancelable(id) => A2aError::task_not_cancelable(id),
120 Self::InvalidParams(msg) => A2aError::invalid_params(msg.clone()),
121 Self::Serialization(e) => A2aError::parse_error(e.to_string()),
122 Self::MethodNotFound(m) => {
123 A2aError::new(ErrorCode::MethodNotFound, format!("Method not found: {m}"))
124 }
125 Self::PushNotSupported => A2aError::new(
126 ErrorCode::PushNotificationNotSupported,
127 "Push notifications not supported",
128 ),
129 Self::UnsupportedOperation(msg) => {
130 A2aError::new(ErrorCode::UnsupportedOperation, msg.clone())
131 }
132 Self::Protocol(e) => e.clone(),
133 Self::Http(e) => A2aError::internal(e.to_string()),
134 Self::HttpClient(msg) | Self::Transport(msg) | Self::Internal(msg) => {
135 A2aError::internal(msg.clone())
136 }
137 Self::PayloadTooLarge(msg) => A2aError::new(ErrorCode::InvalidRequest, msg.clone()),
138 Self::InvalidStateTransition { task_id, from, to } => A2aError::invalid_params(
139 format!("invalid state transition for task {task_id}: {from} → {to}"),
140 ),
141 }
142 }
143}
144
145impl From<A2aError> for ServerError {
148 fn from(e: A2aError) -> Self {
149 Self::Protocol(e)
150 }
151}
152
153impl From<serde_json::Error> for ServerError {
154 fn from(e: serde_json::Error) -> Self {
155 Self::Serialization(e)
156 }
157}
158
159impl From<hyper::Error> for ServerError {
160 fn from(e: hyper::Error) -> Self {
161 Self::Http(e)
162 }
163}
164
165pub type ServerResult<T> = Result<T, ServerError>;
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    use a2a_protocol_types::task::TaskState;

    /// Obtains a genuine `hyper::Error` for tests.
    ///
    /// It writes a malformed request line into one end of an in-memory duplex
    /// pipe while an HTTP/1 server connection reads the other end. The
    /// connection is expected to fail, and that failure is returned.
    async fn hyper_parse_error() -> hyper::Error {
        use tokio::io::AsyncWriteExt;

        let (mut client, server) = tokio::io::duplex(256);
        let client_task = tokio::spawn(async move {
            client.write_all(b"NOT VALID HTTP\r\n\r\n").await.unwrap();
            client.shutdown().await.unwrap();
        });
        let hyper_err = hyper::server::conn::http1::Builder::new()
            .serve_connection(
                hyper_util::rt::TokioIo::new(server),
                hyper::service::service_fn(|_req: hyper::Request<hyper::body::Incoming>| async {
                    Ok::<_, hyper::Error>(hyper::Response::new(http_body_util::Full::new(
                        hyper::body::Bytes::new(),
                    )))
                }),
            )
            .await
            .unwrap_err();
        client_task.await.unwrap();
        hyper_err
    }

    #[test]
    fn source_serialization_returns_some() {
        let err = ServerError::Serialization(serde_json::from_str::<String>("x").unwrap_err());
        assert!(err.source().is_some());
    }

    #[test]
    fn source_protocol_returns_some() {
        let err = ServerError::Protocol(A2aError::task_not_found("t"));
        assert!(err.source().is_some());
    }

    #[tokio::test]
    async fn source_http_returns_some() {
        let err = ServerError::Http(hyper_parse_error().await);
        assert!(err.source().is_some());
    }

    #[test]
    fn source_transport_returns_none() {
        let err = ServerError::Transport("test".into());
        assert!(err.source().is_none());
    }

    #[test]
    fn source_task_not_found_returns_none() {
        let err = ServerError::TaskNotFound("t".into());
        assert!(err.source().is_none());
    }

    #[test]
    fn source_internal_returns_none() {
        let err = ServerError::Internal("oops".into());
        assert!(err.source().is_none());
    }

    #[test]
    fn source_none_for_remaining_non_wrapping_variants() {
        let cases = [
            ServerError::TaskNotCancelable("t".into()),
            ServerError::InvalidParams("x".into()),
            ServerError::HttpClient("x".into()),
            ServerError::PushNotSupported,
            ServerError::MethodNotFound("m".into()),
            ServerError::PayloadTooLarge("x".into()),
            ServerError::UnsupportedOperation("x".into()),
            ServerError::InvalidStateTransition {
                task_id: "t".into(),
                from: TaskState::Working,
                to: TaskState::Submitted,
            },
        ];
        for err in &cases {
            assert!(err.source().is_none(), "unexpected source for: {err}");
        }
    }

    #[test]
    fn display_all_variants() {
        assert!(ServerError::TaskNotFound("t1".into())
            .to_string()
            .contains("t1"));
        assert!(ServerError::TaskNotCancelable("t2".into())
            .to_string()
            .contains("t2"));
        assert!(ServerError::InvalidParams("bad".into())
            .to_string()
            .contains("bad"));
        assert!(ServerError::HttpClient("conn".into())
            .to_string()
            .contains("conn"));
        assert!(ServerError::Transport("tcp".into())
            .to_string()
            .contains("tcp"));
        assert_eq!(
            ServerError::PushNotSupported.to_string(),
            "push notifications not supported"
        );
        assert!(ServerError::UnsupportedOperation("cannot do this".into())
            .to_string()
            .contains("cannot do this"));
        assert!(ServerError::Internal("oops".into())
            .to_string()
            .contains("oops"));
        assert!(ServerError::MethodNotFound("foo/bar".into())
            .to_string()
            .contains("foo/bar"));
        assert!(ServerError::Protocol(A2aError::task_not_found("t"))
            .to_string()
            .contains("protocol error"));
        assert!(ServerError::PayloadTooLarge("too big".into())
            .to_string()
            .contains("too big"));
        let ist = ServerError::InvalidStateTransition {
            task_id: "t3".into(),
            from: TaskState::Working,
            to: TaskState::Submitted,
        };
        let s = ist.to_string();
        assert!(s.contains("t3"), "missing task_id: {s}");
        assert!(
            s.contains("working") || s.contains("WORKING") || s.contains("Working"),
            "missing from state: {s}"
        );
    }

    #[test]
    // One assertion per variant, kept in a single test on purpose, so it reads
    // as a complete table of the ServerError → ErrorCode mapping.
    #[allow(clippy::too_many_lines)]
    fn to_a2a_error_all_variants() {
        assert_eq!(
            ServerError::TaskNotFound("t".into()).to_a2a_error().code,
            ErrorCode::TaskNotFound
        );
        assert_eq!(
            ServerError::TaskNotCancelable("t".into())
                .to_a2a_error()
                .code,
            ErrorCode::TaskNotCancelable
        );
        assert_eq!(
            ServerError::InvalidParams("x".into()).to_a2a_error().code,
            ErrorCode::InvalidParams
        );
        assert_eq!(
            ServerError::Serialization(serde_json::from_str::<String>("x").unwrap_err())
                .to_a2a_error()
                .code,
            ErrorCode::ParseError
        );
        assert_eq!(
            ServerError::MethodNotFound("m".into()).to_a2a_error().code,
            ErrorCode::MethodNotFound
        );
        assert_eq!(
            ServerError::PushNotSupported.to_a2a_error().code,
            ErrorCode::PushNotificationNotSupported
        );
        assert_eq!(
            ServerError::UnsupportedOperation("test".into())
                .to_a2a_error()
                .code,
            ErrorCode::UnsupportedOperation
        );
        assert_eq!(
            ServerError::Protocol(A2aError::task_not_found("t"))
                .to_a2a_error()
                .code,
            ErrorCode::TaskNotFound
        );
        assert_eq!(
            ServerError::HttpClient("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        assert_eq!(
            ServerError::Transport("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        assert_eq!(
            ServerError::Internal("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        assert_eq!(
            ServerError::PayloadTooLarge("x".into()).to_a2a_error().code,
            ErrorCode::InvalidRequest
        );
        let ist = ServerError::InvalidStateTransition {
            task_id: "t".into(),
            from: TaskState::Working,
            to: TaskState::Submitted,
        };
        assert_eq!(ist.to_a2a_error().code, ErrorCode::InvalidParams);
    }

    #[test]
    fn from_a2a_error() {
        let e: ServerError = A2aError::internal("test").into();
        assert!(matches!(e, ServerError::Protocol(_)));
    }

    #[test]
    fn from_serde_error() {
        let e: ServerError = serde_json::from_str::<String>("bad").unwrap_err().into();
        assert!(matches!(e, ServerError::Serialization(_)));
    }

    #[tokio::test]
    async fn display_http_variant() {
        let err = ServerError::Http(hyper_parse_error().await);
        let display = err.to_string();
        assert!(
            display.contains("HTTP error"),
            "Display for Http variant should contain 'HTTP error', got: {display}"
        );
    }

    #[tokio::test]
    async fn from_hyper_error() {
        let e: ServerError = hyper_parse_error().await.into();
        assert!(matches!(e, ServerError::Http(_)));
    }

    #[test]
    fn display_serialization_variant() {
        let err = ServerError::Serialization(serde_json::from_str::<String>("x").unwrap_err());
        let display = err.to_string();
        assert!(
            display.contains("serialization error"),
            "Display for Serialization should contain 'serialization error', got: {display}"
        );
    }

    #[tokio::test]
    async fn to_a2a_error_http_variant() {
        let err = ServerError::Http(hyper_parse_error().await);
        let a2a_err = err.to_a2a_error();
        assert_eq!(a2a_err.code, ErrorCode::InternalError);
    }
}