use {crate::{content_types::{APPLICATION_JSON,
APPLICATION_PROTOBUF},
error::JetError},
axum::{async_trait,
body::Bytes,
extract::{FromRequest,
FromRequestParts,
Request},
http::{StatusCode,
header::{ACCEPT,
CONTENT_TYPE},
request::Parts},
response::{IntoResponse,
Response}},
prost::Message,
serde::{Serialize,
de::DeserializeOwned},
std::sync::OnceLock};
/// Name of the request header that carries a debug key unlocking JSON bodies.
pub const DEBUG_FORMAT_HEADER: &str = "x-debug-format";

/// Debug keys that permit JSON request/response bodies; set at most once.
static DEBUG_KEYS: OnceLock<Vec<String>> = OnceLock::new();

/// Installs the debug keys that permit JSON request and response bodies.
///
/// Only the first call has any effect: the keys are stored in a process-wide
/// [`OnceLock`] and later calls are silently ignored. Empty keys never match
/// a request, so a blank entry in `keys` cannot unlock JSON.
pub fn configure_debug_keys(keys: Vec<String>) {
    // Ignoring the `Err` is deliberate: the first configuration wins.
    let _ = DEBUG_KEYS.set(keys);
}

/// Returns `true` if `debug_header_value` equals one of the configured debug keys.
///
/// Returns `false` in these cases:
/// - no keys have been configured;
/// - the header is absent;
/// - the provided key is empty.
///
/// The empty-key check prevents a blank configured entry from letting any
/// client that sends an empty header enable JSON.
///
/// Every configured key is compared with [`constant_time_eq`]. That function
/// does not stop at the first differing byte, and the loop does not stop at the
/// first match. This avoids leaking, through response timing, how much of a
/// guessed key was correct.
fn is_json_allowed(debug_header_value: Option<&str>) -> bool {
    let (Some(keys), Some(provided)) = (DEBUG_KEYS.get(), debug_header_value) else {
        return false;
    };
    if provided.is_empty() {
        return false;
    }
    // `fold` with `|` instead of `any` so a match does not end the scan early.
    keys.iter().fold(false, |found, key| {
        found | constant_time_eq(key.as_bytes(), provided.as_bytes())
    })
}

/// Byte-wise equality that does not branch on byte contents.
///
/// Unequal lengths return `false` immediately, so lengths themselves are not
/// hidden. This is best effort: nothing prevents the optimizer from
/// reintroducing a short circuit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
/// Wire format used to encode an API response body.
///
/// Protobuf is the default. JSON is an opt-in debug format, chosen by
/// [`ResponseFormat::from_headers`] only when the client also presents a
/// configured debug key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ResponseFormat {
    /// Binary protobuf encoding (the default).
    #[default]
    Protobuf,
    /// JSON encoding, gated behind the debug-key check.
    Json,
}

impl ResponseFormat {
    /// Chooses the response format from the `Accept` header and the debug-key header.
    ///
    /// Returns [`ResponseFormat::Json`] only when `accept` lists the
    /// `application/json` media range *and* `debug_header` holds a configured
    /// debug key. Every other combination, including missing headers, yields
    /// [`ResponseFormat::Protobuf`].
    pub fn from_headers(accept: Option<&str>, debug_header: Option<&str>) -> Self {
        let wants_json = accept.is_some_and(Self::accept_lists_json);
        if wants_json && is_json_allowed(debug_header) {
            ResponseFormat::Json
        } else {
            ResponseFormat::Protobuf
        }
    }

    /// Returns `true` for [`ResponseFormat::Json`].
    pub fn is_json(&self) -> bool {
        matches!(self, ResponseFormat::Json)
    }

    /// Returns `true` if a comma-separated `Accept` value lists `application/json`.
    ///
    /// Each media range is compared case-insensitively (media types are
    /// case-insensitive in HTTP). Parameters after `;` (e.g. `q=0.9`) are
    /// ignored. Unlike a substring test, this does not match unrelated types
    /// such as `application/json-patch+json`.
    fn accept_lists_json(accept: &str) -> bool {
        accept.split(',').any(|range| {
            let media_type = range.split(';').next().unwrap_or("").trim();
            media_type.eq_ignore_ascii_case("application/json")
        })
    }
}
/// Extractor that yields only the negotiated [`ResponseFormat`] and leaves the
/// request body untouched.
///
/// Useful for handlers that take no [`ApiRequest`] body but still need to pick
/// the response encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AcceptFormat(pub ResponseFormat);

#[async_trait]
impl<S> FromRequestParts<S> for AcceptFormat
where
    S: Send + Sync,
{
    /// Extraction never fails; missing or unreadable headers yield protobuf.
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let headers = &parts.headers;
        // Values that `HeaderValue::to_str` rejects are treated as absent.
        let accept = headers.get(ACCEPT).and_then(|value| value.to_str().ok());
        let debug_header = headers
            .get(DEBUG_FORMAT_HEADER)
            .and_then(|value| value.to_str().ok());
        Ok(Self(ResponseFormat::from_headers(accept, debug_header)))
    }
}
/// Upper bound, in bytes (10 MiB), on request bodies accepted by `ApiRequest`.
///
/// NOTE(review): axum's `Bytes` extractor applies its own `DefaultBodyLimit`
/// before this limit is checked. Confirm the router's limit is at least this
/// large; otherwise bigger bodies are rejected earlier as a body-read error.
const MAX_BODY_SIZE: usize = 10 << 20;
/// A decoded request body together with the response format negotiated for it.
///
/// The `ok` / `created` / `accepted` / `respond` helpers build an
/// [`ApiResponse`] that reuses `format`, so handlers need not pass it along
/// manually.
pub struct ApiRequest<T> {
    /// The decoded request message.
    pub body: T,
    /// Format the response should be encoded in.
    pub format: ResponseFormat,
}

impl<T> ApiRequest<T> {
    /// Responds with `200 OK` in the negotiated format.
    pub fn ok<R>(self, response: R) -> ApiResponse<R>
    where
        R: Message + Serialize,
    {
        ApiResponse::ok(self.format, response)
    }

    /// Responds with an explicit `status` in the negotiated format.
    pub fn respond<R>(self, status: StatusCode, response: R) -> ApiResponse<R>
    where
        R: Message + Serialize,
    {
        ApiResponse::new(self.format, status, response)
    }

    /// Responds with `201 Created` in the negotiated format.
    pub fn created<R>(self, response: R) -> ApiResponse<R>
    where
        R: Message + Serialize,
    {
        ApiResponse::created(self.format, response)
    }

    /// Responds with `202 Accepted` in the negotiated format.
    ///
    /// Counterpart of [`ApiResponse::accepted`].
    pub fn accepted<R>(self, response: R) -> ApiResponse<R>
    where
        R: Message + Serialize,
    {
        ApiResponse::accepted(self.format, response)
    }
}
#[async_trait]
impl<S, T> FromRequest<S> for ApiRequest<T>
where
S: Send + Sync,
T: Message + Default + DeserializeOwned,
{
type Rejection = JetError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let accept = req
.headers()
.get(ACCEPT)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let debug_header = req
.headers()
.get(DEBUG_FORMAT_HEADER)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let content_type = req
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let format = ResponseFormat::from_headers(accept.as_deref(), debug_header.as_deref());
let json_allowed = is_json_allowed(debug_header.as_deref());
let wants_json_input = content_type.contains("application/json");
let bytes = Bytes::from_request(req, state)
.await
.map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
if bytes.len() > MAX_BODY_SIZE {
return Err(JetError::BodyTooLarge {
size: bytes.len(),
max: MAX_BODY_SIZE,
});
}
let body = if wants_json_input {
if !json_allowed {
return Err(JetError::InvalidContentType {
expected: APPLICATION_PROTOBUF.to_string(),
actual: "application/json (requires valid X-Debug-Format header)".to_string(),
});
}
serde_json::from_slice(&bytes).map_err(|e| JetError::BadRequest(format!("Invalid JSON: {}", e)))?
} else if content_type.contains(APPLICATION_PROTOBUF) || bytes.is_empty() {
if bytes.is_empty() {
T::default()
} else {
T::decode(bytes)?
}
} else {
T::decode(bytes)?
};
Ok(ApiRequest { body, format })
}
}
/// A protobuf message paired with an HTTP status and the format to encode it in.
///
/// Converted into an HTTP response by its `IntoResponse` implementation. The
/// struct itself carries no trait bounds; they live on the impls that need
/// them.
pub struct ApiResponse<T> {
    format: ResponseFormat,
    status: StatusCode,
    message: T,
}

impl<T> ApiResponse<T>
where
    T: Message + Serialize,
{
    /// Builds a response with an explicit status code.
    pub fn new(format: ResponseFormat, status: StatusCode, message: T) -> Self {
        Self {
            format,
            status,
            message,
        }
    }

    /// Builds a `200 OK` response.
    pub fn ok(format: ResponseFormat, message: T) -> Self {
        Self::new(format, StatusCode::OK, message)
    }

    /// Builds a `201 Created` response.
    pub fn created(format: ResponseFormat, message: T) -> Self {
        Self::new(format, StatusCode::CREATED, message)
    }

    /// Builds a `202 Accepted` response.
    pub fn accepted(format: ResponseFormat, message: T) -> Self {
        Self::new(format, StatusCode::ACCEPTED, message)
    }
}
impl<T> IntoResponse for ApiResponse<T>
where
    T: Message + Serialize,
{
    /// Encodes the message in the negotiated format and sets the matching
    /// `Content-Type`.
    ///
    /// A JSON serialization failure becomes a `500` with a JSON
    /// `{"error": ...}` body. Protobuf encoding (`encode_to_vec`) returns bytes
    /// directly and has no error path.
    fn into_response(self) -> Response {
        let Self {
            format,
            status,
            message,
        } = self;
        match format {
            ResponseFormat::Json => match serde_json::to_vec(&message) {
                Ok(bytes) => (status, [(CONTENT_TYPE, APPLICATION_JSON)], bytes).into_response(),
                Err(e) => {
                    // Build the body through serde_json so the error text is
                    // escaped. Splicing `e` into a JSON literal by hand yields
                    // invalid JSON whenever the message contains `"` or `\`.
                    let body = serde_json::json!({
                        "error": format!("JSON serialization failed: {}", e)
                    });
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        [(CONTENT_TYPE, APPLICATION_JSON)],
                        body.to_string(),
                    )
                        .into_response()
                }
            },
            ResponseFormat::Protobuf => {
                let bytes = message.encode_to_vec();
                (status, [(CONTENT_TYPE, APPLICATION_PROTOBUF)], bytes).into_response()
            }
        }
    }
}
/// Result type for handlers: an [`ApiResponse`] on success, a [`JetError`] on failure.
pub type ApiResult<T> = Result<ApiResponse<T>, JetError>;
#[cfg(test)]
mod tests {
    use super::*;

    // None of these tests call `configure_debug_keys`. `DEBUG_KEYS` is a
    // process-wide `OnceLock`, so the "not configured" expectations below hold
    // only while nothing else in this test binary sets it.

    #[test]
    fn test_response_format_without_debug_keys() {
        let cases: [(Option<&str>, Option<&str>); 4] = [
            (None, None),
            (Some("application/x-protobuf"), None),
            (Some("application/json"), None),
            (Some("application/json"), Some("any-key")),
        ];
        for (accept, debug) in cases {
            assert_eq!(
                ResponseFormat::from_headers(accept, debug),
                ResponseFormat::Protobuf,
                "accept={accept:?} debug={debug:?}"
            );
        }
    }

    #[test]
    fn test_is_json_allowed_not_configured() {
        for header in [None, Some("any-value")] {
            assert!(!is_json_allowed(header), "header={header:?}");
        }
    }

    #[test]
    fn test_response_format_is_json() {
        assert!(ResponseFormat::Json.is_json());
        assert!(!ResponseFormat::Protobuf.is_json());
    }
}