use std::task::{Context, Poll};
use axum::body::Body;
use http::{Request, Response};
use tower::{Layer, Service};
use crate::error::A2aError;
/// Version string that the `a2a-version` request header is compared against.
const SUPPORTED_VERSION: &str = "1.0";

/// Request paths for which the version-header check is skipped.
const VERSION_EXEMPT_PATHS: &[&str] = &["/.well-known/agent-card.json"];

/// Reports whether `path` is one of the entries in `VERSION_EXEMPT_PATHS`.
///
/// The comparison is an exact, case-sensitive string match: a trailing slash
/// or any other variation of an exempt path is *not* considered exempt.
fn is_version_exempt(path: &str) -> bool {
    VERSION_EXEMPT_PATHS.iter().any(|&exempt| exempt == path)
}
/// Tower layer that wraps a service in a `TransportComplianceService`.
///
/// The layer carries no configuration, so it can be created with the unit
/// literal `TransportComplianceLayer` or via `Default::default()`.
#[derive(Clone, Debug, Default)]
pub struct TransportComplianceLayer;
/// Lets `TransportComplianceLayer` be applied to any service `S`.
impl<S> Layer<S> for TransportComplianceLayer {
    type Service = TransportComplianceService<S>;

    /// Wraps `inner` so that requests pass through the compliance checks
    /// before reaching it.
    fn layer(&self, inner: S) -> Self::Service {
        TransportComplianceService { inner }
    }
}
/// Middleware service that validates incoming requests before forwarding
/// them to the wrapped service.
///
/// `Debug` is available whenever the wrapped service `S` is itself `Debug`.
#[derive(Clone, Debug)]
pub struct TransportComplianceService<S> {
    /// The wrapped service that receives requests which pass validation.
    inner: S,
}
/// Tower `Service` implementation that enforces transport-level rules before
/// delegating to the wrapped service.
///
/// Checks performed, in order:
///
/// 1. Unless the request path is listed in `VERSION_EXEMPT_PATHS`, the
///    `a2a-version` header must equal `SUPPORTED_VERSION`. A header that is
///    absent (or not valid visible ASCII) is rejected unless the crate is
///    built with the `compat-v03` feature, in which case it is let through.
/// 2. A `POST` request that carries a `Content-Type` header must have a value
///    containing `application/json` (compared case-insensitively). `POST`
///    requests without a `Content-Type` header are let through.
///
/// A failed check short-circuits with the response produced by
/// `A2aError::into_response_body`, returned as `Ok(..)`; the wrapped service
/// is not called in that case.
impl<S> Service<Request<Body>> for TransportComplianceService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Readiness is entirely determined by the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Validates `req` and, if it passes, forwards it to the wrapped service.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that is the instance
        // that is known to be ready. A fresh clone carries no such guarantee,
        // so move the ready instance into the future and leave the clone
        // behind for the next `poll_ready`/`call` cycle.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(async move {
            if !is_version_exempt(req.uri().path()) {
                match req
                    .headers()
                    .get("a2a-version")
                    .and_then(|v| v.to_str().ok())
                {
                    Some(v) if v == SUPPORTED_VERSION => {}
                    Some(v) => {
                        let err = A2aError::VersionNotSupported {
                            version: v.to_string(),
                        };
                        return Ok(err.into_response_body());
                    }
                    None => {
                        // With `compat-v03` enabled a missing header is
                        // tolerated and the request falls through.
                        #[cfg(not(feature = "compat-v03"))]
                        {
                            let err = A2aError::VersionNotSupported {
                                version: "missing (A2A-Version header is required)".to_string(),
                            };
                            return Ok(err.into_response_body());
                        }
                    }
                }
            }

            if *req.method() == http::Method::POST {
                if let Some(value) = req.headers().get(http::header::CONTENT_TYPE) {
                    // A value that cannot be read as a string is treated as
                    // empty and therefore rejected below.
                    let content_type = value.to_str().unwrap_or("");
                    // Media types are case-insensitive, so compare on a
                    // lower-cased copy; the error still reports the value
                    // exactly as the client sent it.
                    if !content_type
                        .to_ascii_lowercase()
                        .contains("application/json")
                    {
                        let err = A2aError::ContentTypeNotSupported {
                            content_type: content_type.to_string(),
                        };
                        return Ok(err.into_response_body());
                    }
                }
            }

            inner.call(req).await
        })
    }
}
impl A2aError {
    /// Renders this error as a JSON HTTP response.
    ///
    /// The status is taken from `http_status()`, falling back to
    /// `500 Internal Server Error` when that value is not a valid status
    /// code. The body is the JSON serialization of `to_http_error_body()`;
    /// if serialization fails the body is left empty. Should the response
    /// builder itself fail, a bare 500 response with an empty body is
    /// returned instead.
    // The method borrows `self` despite its `into_` prefix, hence the allow.
    #[allow(clippy::wrong_self_convention)]
    fn into_response_body(&self) -> Response<Body> {
        let status = match axum::http::StatusCode::from_u16(self.http_status()) {
            Ok(code) => code,
            Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
        };

        let error_body = self.to_http_error_body();
        let json = serde_json::to_string(&error_body).unwrap_or_default();

        let built = Response::builder()
            .status(status)
            .header(http::header::CONTENT_TYPE, "application/json")
            .body(Body::from(json));

        match built {
            Ok(response) => response,
            Err(_) => Response::builder()
                .status(500)
                .body(Body::empty())
                .unwrap(),
        }
    }
}