use super::error::A2aError;
use axum::{
http::{HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
/// Protocol versions this server accepts in the `a2a-version` request header.
///
/// Any other non-empty requested version is rejected by [`negotiate_version`] with
/// [`A2aError::VersionNotSupported`], which echoes this list back to the client.
pub const SUPPORTED_VERSIONS: &[&str] = &["0.3", "1.0"];
/// Name of the HTTP header used both to request a protocol version (on the request)
/// and to report the negotiated version (on the response).
const A2A_VERSION_HEADER: &str = "a2a-version";
/// Resolves the protocol version a client asked for against [`SUPPORTED_VERSIONS`].
///
/// A missing (`None`) or empty request falls back to `"0.3"`. Any other value must equal
/// one of [`SUPPORTED_VERSIONS`] exactly; no trimming or case folding is applied.
///
/// On success the returned string is the `'static` entry itself (or the `"0.3"` default),
/// so callers can use it without copying.
///
/// # Errors
///
/// Returns [`A2aError::VersionNotSupported`] when `requested` is non-empty and not listed.
/// The error carries the requested string and every supported version.
pub fn negotiate_version(requested: Option<&str>) -> Result<&'static str, A2aError> {
    // Guard clause: absent or blank requests get the default version.
    let wanted = match requested {
        Some(v) if !v.is_empty() => v,
        _ => return Ok("0.3"),
    };

    SUPPORTED_VERSIONS
        .iter()
        .copied()
        .find(|&candidate| candidate == wanted)
        .ok_or_else(|| A2aError::VersionNotSupported {
            requested: wanted.to_owned(),
            supported: SUPPORTED_VERSIONS.iter().map(|&s| String::from(s)).collect(),
        })
}
/// Axum middleware that negotiates the A2A protocol version for each request.
///
/// The `a2a-version` request header is resolved with [`negotiate_version`]:
///
/// * On success the request is forwarded to `next`. The negotiated version is then written
///   to the response's `a2a-version` header, replacing any value the handler set.
/// * On failure the request is not forwarded. The error is rendered as a JSON body with the
///   status from [`A2aError::http_status`], or `400 Bad Request` if that code is not a valid
///   HTTP status.
///
/// A header value containing bytes outside visible ASCII is treated as an unsupported
/// version, not as a missing header. Malformed input is therefore rejected rather than
/// silently falling back to the default version.
pub async fn version_negotiation(
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    // Scope the header borrow so it ends before `req` is moved into `next.run`.
    let negotiated = {
        // Lossy decoding borrows (no allocation) for ordinary ASCII values. It also keeps
        // malformed values non-empty, so they reach the error path with a readable
        // `requested` string instead of being discarded.
        let requested = req
            .headers()
            .get(A2A_VERSION_HEADER)
            .map(|value| String::from_utf8_lossy(value.as_bytes()));
        negotiate_version(requested.as_deref())
    };

    match negotiated {
        Ok(version) => {
            let mut response = next.run(req).await;
            // `version` is either the "0.3" default or an entry of `SUPPORTED_VERSIONS`.
            // Both are `'static` ASCII literals, so `from_static` cannot fail here.
            response
                .headers_mut()
                .insert(A2A_VERSION_HEADER, HeaderValue::from_static(version));
            response
        }
        Err(err) => {
            let status =
                StatusCode::from_u16(err.http_status()).unwrap_or(StatusCode::BAD_REQUEST);
            // `axum::Json` sets `Content-Type: application/json` itself. Overriding it would
            // also mislabel Json's plain-text fallback if serialization ever failed.
            (status, axum::Json(err.to_http_error_response())).into_response()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Inputs that negotiate successfully ------------------------------------------

    #[test]
    fn supported_version_0_3_returns_ok() {
        assert_eq!(negotiate_version(Some("0.3")).unwrap(), "0.3");
    }

    #[test]
    fn supported_version_1_0_returns_ok() {
        assert_eq!(negotiate_version(Some("1.0")).unwrap(), "1.0");
    }

    #[test]
    fn missing_header_defaults_to_0_3() {
        assert_eq!(negotiate_version(None).unwrap(), "0.3");
    }

    #[test]
    fn empty_header_defaults_to_0_3() {
        assert_eq!(negotiate_version(Some("")).unwrap(), "0.3");
    }

    #[test]
    fn all_supported_versions_return_ok() {
        // Every advertised version must round-trip to itself.
        SUPPORTED_VERSIONS.iter().for_each(|&version| {
            assert_eq!(negotiate_version(Some(version)).unwrap(), version);
        });
    }

    // ---- Inputs that must be rejected ------------------------------------------------

    #[test]
    fn unsupported_version_returns_error_with_supported_list() {
        let err = negotiate_version(Some("2.0")).unwrap_err();
        if let A2aError::VersionNotSupported { requested, supported } = &err {
            assert_eq!(requested, "2.0");
            assert_eq!(supported, &["0.3", "1.0"]);
        } else {
            panic!("expected VersionNotSupported, got: {err}");
        }
        // The error must map to the protocol's JSON-RPC code and a client-error status.
        assert_eq!(err.json_rpc_code(), -32009);
        assert_eq!(err.http_status(), 400);
    }

    #[test]
    fn unsupported_version_0_1_returns_error() {
        let result = negotiate_version(Some("0.1"));
        assert!(result.is_err());
        match &result.unwrap_err() {
            A2aError::VersionNotSupported { requested, .. } => assert_eq!(requested, "0.1"),
            other => panic!("expected VersionNotSupported, got: {other}"),
        }
    }

    #[test]
    fn unsupported_version_garbage_returns_error() {
        let result = negotiate_version(Some("not-a-version"));
        assert!(result.is_err());
    }
}