1use std::fmt;
13
14use a2a_protocol_types::error::{A2aError, ErrorCode};
15use a2a_protocol_types::task::TaskId;
16
/// Error type for server-side operations.
///
/// Every variant has a human-readable rendering through [`fmt::Display`] and a
/// protocol-level mapping through [`ServerError::to_a2a_error`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ServerError {
    /// No task was found for the given ID.
    TaskNotFound(TaskId),
    /// The task with the given ID cannot be canceled.
    TaskNotCancelable(TaskId),
    /// Invalid parameters; the payload is a descriptive message.
    InvalidParams(String),
    /// A `serde_json` error; exposed as the error's `source()`.
    Serialization(serde_json::Error),
    /// A `hyper` error; exposed as the error's `source()`.
    Http(hyper::Error),
    /// An HTTP client failure, described by a message.
    HttpClient(String),
    /// A transport failure, described by a message.
    Transport(String),
    /// Push notifications are not supported.
    PushNotSupported,
    /// An internal failure, described by a message.
    Internal(String),
    /// The named method was not found.
    MethodNotFound(String),
    /// A protocol-level [`A2aError`]; `to_a2a_error` returns a clone of it unchanged.
    Protocol(A2aError),
    /// The payload was too large; the string carries the details.
    PayloadTooLarge(String),
    /// A task state change from `from` to `to` was rejected.
    InvalidStateTransition {
        /// ID of the task whose transition was rejected.
        task_id: TaskId,
        /// State the transition started from.
        from: a2a_protocol_types::task::TaskState,
        /// State the transition attempted to reach.
        to: a2a_protocol_types::task::TaskState,
    },
}
59
60impl fmt::Display for ServerError {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 match self {
63 Self::TaskNotFound(id) => write!(f, "task not found: {id}"),
64 Self::TaskNotCancelable(id) => write!(f, "task not cancelable: {id}"),
65 Self::InvalidParams(msg) => write!(f, "invalid params: {msg}"),
66 Self::Serialization(e) => write!(f, "serialization error: {e}"),
67 Self::Http(e) => write!(f, "HTTP error: {e}"),
68 Self::HttpClient(msg) => write!(f, "HTTP client error: {msg}"),
69 Self::Transport(msg) => write!(f, "transport error: {msg}"),
70 Self::PushNotSupported => f.write_str("push notifications not supported"),
71 Self::Internal(msg) => write!(f, "internal error: {msg}"),
72 Self::MethodNotFound(m) => write!(f, "method not found: {m}"),
73 Self::Protocol(e) => write!(f, "protocol error: {e}"),
74 Self::PayloadTooLarge(msg) => write!(f, "payload too large: {msg}"),
75 Self::InvalidStateTransition { task_id, from, to } => {
76 write!(
77 f,
78 "invalid state transition for task {task_id}: {from} → {to}"
79 )
80 }
81 }
82 }
83}
84
85impl std::error::Error for ServerError {
86 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
87 match self {
88 Self::Serialization(e) => Some(e),
89 Self::Http(e) => Some(e),
90 Self::Protocol(e) => Some(e),
91 _ => None,
92 }
93 }
94}
95
96impl ServerError {
97 #[must_use]
111 pub fn to_a2a_error(&self) -> A2aError {
112 match self {
113 Self::TaskNotFound(id) => A2aError::task_not_found(id),
114 Self::TaskNotCancelable(id) => A2aError::task_not_cancelable(id),
115 Self::InvalidParams(msg) => A2aError::invalid_params(msg.clone()),
116 Self::Serialization(e) => A2aError::parse_error(e.to_string()),
117 Self::MethodNotFound(m) => {
118 A2aError::new(ErrorCode::MethodNotFound, format!("Method not found: {m}"))
119 }
120 Self::PushNotSupported => A2aError::new(
121 ErrorCode::PushNotificationNotSupported,
122 "Push notifications not supported",
123 ),
124 Self::Protocol(e) => e.clone(),
125 Self::Http(e) => A2aError::internal(e.to_string()),
126 Self::HttpClient(msg)
127 | Self::Transport(msg)
128 | Self::Internal(msg)
129 | Self::PayloadTooLarge(msg) => A2aError::internal(msg.clone()),
130 Self::InvalidStateTransition { task_id, from, to } => A2aError::invalid_params(
131 format!("invalid state transition for task {task_id}: {from} → {to}"),
132 ),
133 }
134 }
135}
136
137impl From<A2aError> for ServerError {
140 fn from(e: A2aError) -> Self {
141 Self::Protocol(e)
142 }
143}
144
145impl From<serde_json::Error> for ServerError {
146 fn from(e: serde_json::Error) -> Self {
147 Self::Serialization(e)
148 }
149}
150
151impl From<hyper::Error> for ServerError {
152 fn from(e: hyper::Error) -> Self {
153 Self::Http(e)
154 }
155}
156
/// Shorthand for a [`Result`] whose error type is [`ServerError`].
pub type ServerResult<T> = Result<T, ServerError>;
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Produces a genuine `hyper::Error` for tests that need one.
    ///
    /// Bytes that are not a valid HTTP request are written into an in-memory
    /// duplex pipe whose other end is served as an HTTP/1 connection; the error
    /// returned by that connection is handed back to the caller.
    async fn make_hyper_error() -> hyper::Error {
        use tokio::io::AsyncWriteExt;

        let (mut client, server) = tokio::io::duplex(256);
        let client_task = tokio::spawn(async move {
            client.write_all(b"NOT VALID HTTP\r\n\r\n").await.unwrap();
            client.shutdown().await.unwrap();
        });
        let hyper_err = hyper::server::conn::http1::Builder::new()
            .serve_connection(
                hyper_util::rt::TokioIo::new(server),
                hyper::service::service_fn(|_req: hyper::Request<hyper::body::Incoming>| async {
                    Ok::<_, hyper::Error>(hyper::Response::new(http_body_util::Full::new(
                        hyper::body::Bytes::new(),
                    )))
                }),
            )
            .await
            .unwrap_err();
        // Make sure the writer task finished cleanly before returning.
        client_task.await.unwrap();
        hyper_err
    }

    // ── Error::source ────────────────────────────────────────────────────

    #[test]
    fn source_serialization_returns_some() {
        let err = ServerError::Serialization(serde_json::from_str::<String>("x").unwrap_err());
        assert!(err.source().is_some());
    }

    #[test]
    fn source_protocol_returns_some() {
        let err = ServerError::Protocol(A2aError::task_not_found("t"));
        assert!(err.source().is_some());
    }

    #[tokio::test]
    async fn source_http_returns_some() {
        let err = ServerError::Http(make_hyper_error().await);
        assert!(err.source().is_some());
    }

    #[test]
    fn source_transport_returns_none() {
        let err = ServerError::Transport("test".into());
        assert!(err.source().is_none());
    }

    #[test]
    fn source_task_not_found_returns_none() {
        let err = ServerError::TaskNotFound("t".into());
        assert!(err.source().is_none());
    }

    #[test]
    fn source_internal_returns_none() {
        let err = ServerError::Internal("oops".into());
        assert!(err.source().is_none());
    }

    // ── Display ──────────────────────────────────────────────────────────

    #[test]
    fn display_all_variants() {
        assert!(ServerError::TaskNotFound("t1".into())
            .to_string()
            .contains("t1"));
        assert!(ServerError::TaskNotCancelable("t2".into())
            .to_string()
            .contains("t2"));
        assert!(ServerError::InvalidParams("bad".into())
            .to_string()
            .contains("bad"));
        assert!(ServerError::HttpClient("conn".into())
            .to_string()
            .contains("conn"));
        assert!(ServerError::Transport("tcp".into())
            .to_string()
            .contains("tcp"));
        assert_eq!(
            ServerError::PushNotSupported.to_string(),
            "push notifications not supported"
        );
        assert!(ServerError::Internal("oops".into())
            .to_string()
            .contains("oops"));
        assert!(ServerError::MethodNotFound("foo/bar".into())
            .to_string()
            .contains("foo/bar"));
        assert!(ServerError::Protocol(A2aError::task_not_found("t"))
            .to_string()
            .contains("protocol error"));
        assert!(ServerError::PayloadTooLarge("too big".into())
            .to_string()
            .contains("too big"));
        let ist = ServerError::InvalidStateTransition {
            task_id: "t3".into(),
            from: a2a_protocol_types::task::TaskState::Working,
            to: a2a_protocol_types::task::TaskState::Submitted,
        };
        let s = ist.to_string();
        assert!(s.contains("t3"), "missing task_id: {s}");
        // The exact casing of the state names is owned by `TaskState`'s Display
        // impl, so accept the common spellings.
        assert!(
            s.contains("working") || s.contains("WORKING") || s.contains("Working"),
            "missing from state: {s}"
        );
        assert!(
            s.contains("submitted") || s.contains("SUBMITTED") || s.contains("Submitted"),
            "missing to state: {s}"
        );
    }

    #[tokio::test]
    async fn display_http_variant() {
        let err = ServerError::Http(make_hyper_error().await);
        let display = err.to_string();
        assert!(
            display.contains("HTTP error"),
            "Display for Http variant should contain 'HTTP error', got: {display}"
        );
    }

    #[test]
    fn display_serialization_variant() {
        let err = ServerError::Serialization(serde_json::from_str::<String>("x").unwrap_err());
        let display = err.to_string();
        assert!(
            display.contains("serialization error"),
            "Display for Serialization should contain 'serialization error', got: {display}"
        );
    }

    // ── to_a2a_error ─────────────────────────────────────────────────────

    #[test]
    fn to_a2a_error_all_variants() {
        assert_eq!(
            ServerError::TaskNotFound("t".into()).to_a2a_error().code,
            ErrorCode::TaskNotFound
        );
        assert_eq!(
            ServerError::TaskNotCancelable("t".into())
                .to_a2a_error()
                .code,
            ErrorCode::TaskNotCancelable
        );
        assert_eq!(
            ServerError::InvalidParams("x".into()).to_a2a_error().code,
            ErrorCode::InvalidParams
        );
        assert_eq!(
            ServerError::Serialization(serde_json::from_str::<String>("x").unwrap_err())
                .to_a2a_error()
                .code,
            ErrorCode::ParseError
        );
        assert_eq!(
            ServerError::MethodNotFound("m".into()).to_a2a_error().code,
            ErrorCode::MethodNotFound
        );
        assert_eq!(
            ServerError::PushNotSupported.to_a2a_error().code,
            ErrorCode::PushNotificationNotSupported
        );
        assert_eq!(
            ServerError::Protocol(A2aError::task_not_found("t"))
                .to_a2a_error()
                .code,
            ErrorCode::TaskNotFound
        );
        assert_eq!(
            ServerError::HttpClient("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        assert_eq!(
            ServerError::Transport("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        assert_eq!(
            ServerError::Internal("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        assert_eq!(
            ServerError::PayloadTooLarge("x".into()).to_a2a_error().code,
            ErrorCode::InternalError
        );
        let ist = ServerError::InvalidStateTransition {
            task_id: "t".into(),
            from: a2a_protocol_types::task::TaskState::Working,
            to: a2a_protocol_types::task::TaskState::Submitted,
        };
        assert_eq!(ist.to_a2a_error().code, ErrorCode::InvalidParams);
    }

    #[tokio::test]
    async fn to_a2a_error_http_variant() {
        let err = ServerError::Http(make_hyper_error().await);
        let a2a_err = err.to_a2a_error();
        assert_eq!(a2a_err.code, ErrorCode::InternalError);
    }

    // ── From conversions ─────────────────────────────────────────────────

    #[test]
    fn from_a2a_error() {
        let e: ServerError = A2aError::internal("test").into();
        assert!(matches!(e, ServerError::Protocol(_)));
    }

    #[test]
    fn from_serde_error() {
        let e: ServerError = serde_json::from_str::<String>("bad").unwrap_err().into();
        assert!(matches!(e, ServerError::Serialization(_)));
    }

    #[tokio::test]
    async fn from_hyper_error() {
        let e: ServerError = make_hyper_error().await.into();
        assert!(matches!(e, ServerError::Http(_)));
    }
}