use std::fmt;
use a2a_protocol_types::{A2aError, TaskId};
#[derive(Debug)]
#[non_exhaustive]
pub enum ClientError {
Http(hyper::Error),
HttpClient(String),
Serialization(serde_json::Error),
Protocol(A2aError),
Transport(String),
InvalidEndpoint(String),
UnexpectedStatus {
status: u16,
body: String,
},
AuthRequired {
task_id: TaskId,
},
Timeout(String),
ProtocolBindingMismatch(String),
}
impl fmt::Display for ClientError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Http(e) => write!(f, "HTTP error: {e}"),
Self::HttpClient(msg) => write!(f, "HTTP client error: {msg}"),
Self::Serialization(e) => write!(f, "serialization error: {e}"),
Self::Protocol(e) => write!(f, "protocol error: {e}"),
Self::Transport(msg) => write!(f, "transport error: {msg}"),
Self::InvalidEndpoint(msg) => write!(f, "invalid endpoint: {msg}"),
Self::UnexpectedStatus { status, body } => {
write!(f, "unexpected HTTP status {status}: {body}")
}
Self::AuthRequired { task_id } => {
write!(f, "authentication required for task: {task_id}")
}
Self::Timeout(msg) => write!(f, "timeout: {msg}"),
Self::ProtocolBindingMismatch(msg) => {
write!(
f,
"protocol binding mismatch: {msg}; check the agent card's supported_interfaces"
)
}
}
}
}
impl std::error::Error for ClientError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Http(e) => Some(e),
Self::Serialization(e) => Some(e),
Self::Protocol(e) => Some(e),
_ => None,
}
}
}
impl From<A2aError> for ClientError {
fn from(e: A2aError) -> Self {
Self::Protocol(e)
}
}
impl From<hyper::Error> for ClientError {
fn from(e: hyper::Error) -> Self {
Self::Http(e)
}
}
impl From<serde_json::Error> for ClientError {
fn from(e: serde_json::Error) -> Self {
Self::Serialization(e)
}
}
pub type ClientResult<T> = Result<T, ClientError>;
#[cfg(test)]
mod tests {
use super::*;
use a2a_protocol_types::ErrorCode;
#[test]
fn client_error_display_http_client() {
let e = ClientError::HttpClient("connection refused".into());
assert!(e.to_string().contains("connection refused"));
}
#[test]
fn client_error_display_protocol() {
let a2a = A2aError::task_not_found("task-99");
let e = ClientError::Protocol(a2a);
assert!(e.to_string().contains("task-99"));
}
#[test]
fn client_error_from_a2a_error() {
let a2a = A2aError::new(ErrorCode::TaskNotFound, "missing");
let e: ClientError = a2a.into();
assert!(matches!(e, ClientError::Protocol(_)));
}
#[test]
fn client_error_unexpected_status() {
let e = ClientError::UnexpectedStatus {
status: 404,
body: "Not Found".into(),
};
assert!(e.to_string().contains("404"));
}
#[test]
fn timeout_is_retryable_transport_is_not() {
let timeout = ClientError::Timeout("request timed out".into());
assert!(timeout.is_retryable(), "Timeout errors must be retryable");
let transport = ClientError::Transport("config error".into());
assert!(
!transport.is_retryable(),
"Transport errors must not be retryable"
);
}
#[test]
fn client_error_source_http() {
use std::error::Error;
let http_err: ClientError = ClientError::HttpClient("test".into());
assert!(http_err.source().is_none());
let ser_err =
ClientError::Serialization(serde_json::from_str::<String>("not json").unwrap_err());
assert!(
ser_err.source().is_some(),
"Serialization error should have a source"
);
let proto_err = ClientError::Protocol(a2a_protocol_types::A2aError::task_not_found("t"));
assert!(
proto_err.source().is_some(),
"Protocol error should have a source"
);
let transport_err = ClientError::Transport("config".into());
assert!(transport_err.source().is_none());
}
#[test]
fn client_error_display_transport() {
let e = ClientError::Transport("socket closed".into());
let s = e.to_string();
assert!(s.contains("transport error"), "missing prefix: {s}");
assert!(s.contains("socket closed"), "missing message: {s}");
}
#[test]
fn client_error_display_invalid_endpoint() {
let e = ClientError::InvalidEndpoint("bad url".into());
let s = e.to_string();
assert!(s.contains("invalid endpoint"), "missing prefix: {s}");
assert!(s.contains("bad url"), "missing message: {s}");
}
#[test]
fn client_error_display_auth_required() {
let e = ClientError::AuthRequired {
task_id: TaskId::new("task-7"),
};
let s = e.to_string();
assert!(s.contains("authentication required"), "missing prefix: {s}");
assert!(s.contains("task-7"), "missing task_id: {s}");
}
#[test]
fn client_error_display_timeout() {
let e = ClientError::Timeout("30s elapsed".into());
let s = e.to_string();
assert!(s.contains("timeout"), "missing prefix: {s}");
assert!(s.contains("30s elapsed"), "missing message: {s}");
}
#[test]
fn client_error_display_protocol_binding_mismatch() {
let e = ClientError::ProtocolBindingMismatch("expected REST".into());
let s = e.to_string();
assert!(
s.contains("protocol binding mismatch"),
"missing prefix: {s}"
);
assert!(s.contains("expected REST"), "missing message: {s}");
assert!(s.contains("supported_interfaces"), "missing advice: {s}");
}
#[test]
fn client_error_display_serialization() {
let e = ClientError::Serialization(serde_json::from_str::<String>("bad").unwrap_err());
let s = e.to_string();
assert!(s.contains("serialization error"), "missing prefix: {s}");
}
#[test]
fn client_error_display_unexpected_status() {
let e = ClientError::UnexpectedStatus {
status: 500,
body: "Internal Server Error".into(),
};
let s = e.to_string();
assert!(s.contains("500"), "missing status code: {s}");
assert!(s.contains("Internal Server Error"), "missing body: {s}");
}
#[test]
fn client_error_source_none_for_string_variants() {
use std::error::Error;
let cases: Vec<ClientError> = vec![
ClientError::HttpClient("msg".into()),
ClientError::Transport("msg".into()),
ClientError::InvalidEndpoint("msg".into()),
ClientError::UnexpectedStatus {
status: 404,
body: String::new(),
},
ClientError::AuthRequired {
task_id: TaskId::new("t"),
},
ClientError::Timeout("msg".into()),
ClientError::ProtocolBindingMismatch("msg".into()),
];
for e in &cases {
assert!(
e.source().is_none(),
"{:?} should have no source",
std::mem::discriminant(e)
);
}
}
#[tokio::test]
async fn client_error_display_and_source_http() {
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use tokio::io::AsyncWriteExt;
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf).await;
let resp = "HTTP/1.1 200 OK\r\ncontent-length: 1000\r\n\r\nhello";
let _ = stream.write_all(resp.as_bytes()).await;
drop(stream);
});
let client: hyper_util::client::legacy::Client<
hyper_util::client::legacy::connect::HttpConnector,
Full<Bytes>,
> = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.build(hyper_util::client::legacy::connect::HttpConnector::new());
let req = hyper::Request::builder()
.uri(format!("http://127.0.0.1:{}", addr.port()))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.unwrap();
let body_result = resp.collect().await;
if let Err(hyper_err) = body_result {
use std::error::Error;
let client_err: ClientError = ClientError::Http(hyper_err);
let display = client_err.to_string();
assert!(display.contains("HTTP error"), "Display: {display}");
assert!(
client_err.source().is_some(),
"Http variant should have a source"
);
} else {
}
}
#[test]
fn client_error_from_serde_json_error() {
let serde_err = serde_json::from_str::<String>("not json").unwrap_err();
let e: ClientError = serde_err.into();
assert!(matches!(e, ClientError::Serialization(_)));
}
#[test]
fn retryable_classification_exhaustive() {
assert!(ClientError::HttpClient("conn reset".into()).is_retryable());
assert!(ClientError::Timeout("deadline".into()).is_retryable());
assert!(ClientError::UnexpectedStatus {
status: 429,
body: String::new()
}
.is_retryable());
assert!(ClientError::UnexpectedStatus {
status: 502,
body: String::new()
}
.is_retryable());
assert!(ClientError::UnexpectedStatus {
status: 503,
body: String::new()
}
.is_retryable());
assert!(ClientError::UnexpectedStatus {
status: 504,
body: String::new()
}
.is_retryable());
assert!(!ClientError::Transport("bad config".into()).is_retryable());
assert!(!ClientError::InvalidEndpoint("bad url".into()).is_retryable());
assert!(!ClientError::UnexpectedStatus {
status: 400,
body: String::new()
}
.is_retryable());
assert!(!ClientError::UnexpectedStatus {
status: 401,
body: String::new()
}
.is_retryable());
assert!(!ClientError::UnexpectedStatus {
status: 404,
body: String::new()
}
.is_retryable());
assert!(!ClientError::ProtocolBindingMismatch("wrong".into()).is_retryable());
assert!(!ClientError::AuthRequired {
task_id: TaskId::new("t")
}
.is_retryable());
}
}