use crate::error::ReceiverError;
use axum::{
extract::Request,
http::{header, HeaderValue},
middleware::Next,
response::{IntoResponse, Response},
};
use tracing::debug;
/// Returns `true` if `content_type` names a media type this receiver accepts.
///
/// Only the media-type essence (`type/subtype`) is compared. Parameters such as
/// `; charset=utf-8` are ignored. Media types are case-insensitive (RFC 9110
/// §8.3.1), so the comparison is ASCII case-insensitive. Unlike a prefix check,
/// this does not accept unrelated types such as `application/jsonl`.
fn is_supported_content_type(content_type: &str) -> bool {
    let essence = content_type.split(';').next().unwrap_or(content_type).trim();
    ["application/x-protobuf", "application/json"]
        .iter()
        .any(|supported| essence.eq_ignore_ascii_case(supported))
}

/// Middleware that rejects requests whose `Content-Type` is missing, malformed,
/// or unsupported.
///
/// Requests whose path contains `/health` skip the check entirely. For all
/// other requests, the media type must be `application/x-protobuf` or
/// `application/json`. Parameters are ignored and case does not matter.
///
/// # Errors
///
/// Returns [`ReceiverError::InvalidContentType`] when the header is absent,
/// is not visible ASCII, or names an unsupported media type.
pub async fn validate_content_type(req: Request, next: Next) -> Result<Response, ReceiverError> {
    // NOTE(review): this is a substring match, so any path merely containing
    // "/health" (e.g. "/v1/health-data") also bypasses validation. Tighten it
    // to the real health route(s) if the router allows.
    if req.uri().path().contains("/health") {
        return Ok(next.run(req).await);
    }

    let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
        return Err(ReceiverError::InvalidContentType(
            "Missing Content-Type header".to_string(),
        ));
    };

    // `to_str` fails when the value contains non-visible-ASCII bytes.
    let content_type_str = content_type.to_str().map_err(|_| {
        ReceiverError::InvalidContentType("Invalid Content-Type header".to_string())
    })?;

    if !is_supported_content_type(content_type_str) {
        return Err(ReceiverError::InvalidContentType(format!(
            "Unsupported Content-Type: {}. Expected application/x-protobuf or application/json",
            content_type_str
        )));
    }

    Ok(next.run(req).await)
}
/// Middleware that rejects requests whose `Content-Encoding` is not accepted.
///
/// The following pass through to the next layer unchanged:
/// - a request with no `Content-Encoding` header;
/// - a request declaring `gzip`, `deflate`, or `identity`.
///
/// This middleware does **not** decompress the body itself; it only validates
/// the declared encoding. Content-coding tokens are matched case-insensitively
/// (RFC 9110 §8.4.1), and surrounding whitespace is ignored.
///
/// A header value that is not visible ASCII, or any other encoding, yields a
/// [`ReceiverError::CompressionError`] response instead of calling `next`.
pub async fn handle_compression(req: Request, next: Next) -> Response {
    let Some(raw_encoding) = req.headers().get(header::CONTENT_ENCODING) else {
        return next.run(req).await;
    };

    // A present-but-unreadable header must not be mistaken for "no encoding";
    // otherwise a malformed value would silently bypass this check.
    let encoding = match raw_encoding.to_str() {
        Ok(value) => value.trim(),
        Err(_) => {
            return ReceiverError::CompressionError(
                "Invalid Content-Encoding header".to_string(),
            )
            .into_response();
        },
    };

    if encoding.eq_ignore_ascii_case("gzip") {
        debug!("Request has gzip compression");
    } else if encoding.eq_ignore_ascii_case("deflate") {
        debug!("Request has deflate compression");
    } else if !encoding.eq_ignore_ascii_case("identity") {
        return ReceiverError::CompressionError(format!(
            "Unsupported Content-Encoding: {}",
            encoding
        ))
        .into_response();
    }

    next.run(req).await
}
/// Middleware that adds permissive CORS headers to every response.
///
/// The request is forwarded to the inner service untouched. Three headers are
/// then set on the resulting response, overwriting any existing values for
/// those names:
/// - `Access-Control-Allow-Origin`
/// - `Access-Control-Allow-Methods`
/// - `Access-Control-Allow-Headers`
pub async fn add_cors_headers(req: Request, next: Next) -> Response {
    let mut outgoing = next.run(req).await;
    let header_map = outgoing.headers_mut();

    for (name, value) in [
        (header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
        (header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST, OPTIONS"),
        (
            header::ACCESS_CONTROL_ALLOW_HEADERS,
            "Content-Type, Content-Encoding",
        ),
    ] {
        header_map.insert(name, HeaderValue::from_static(value));
    }

    outgoing
}
#[cfg(test)]
mod tests {
    /// Placeholder smoke test that performs no runtime checks.
    ///
    /// Its only value is that `cargo test` must compile this module, including
    /// the middleware functions above, in order to run it.
    #[test]
    fn test_middleware_module_compiles() {}
}