use axum::async_trait;
use axum::body::Bytes;
use axum::extract::{FromRequest, FromRequestParts, Request};
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::Response;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use uuid::Uuid;
use crate::error::error_response;
use crate::runtime::SharedState;
use crate::schema::{check_unknown_fields, AcubeValidate, ValidationError};
use crate::security::AuthIdentity;
/// Newtype around a request identifier string.
///
/// The extractors in this file look it up in the request extensions and fall
/// back to a freshly generated UUID v4 when it is absent.
// NOTE(review): presumably inserted into the extensions by a middleware layer;
// the producer is not visible in this file — confirm.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
/// Per-request context assembled by the `FromRequestParts` implementation for
/// this type.
#[derive(Clone)]
pub struct AcubeContext {
/// Identifier of the current request: copied from the `RequestId` request
/// extension, or a generated UUID v4 when that extension is missing.
pub request_id: String,
/// Clone of the `AuthIdentity` request extension; `None` when the extension
/// is not present on the request.
pub auth: Option<AuthIdentity>,
/// Raw path parameters keyed by name; parsed on demand by `path()`.
/// Empty when path extraction fails.
path_params: HashMap<String, String>,
/// Handle queried by `state()` to look values up by type.
shared_state: SharedState,
}
/// Debug output lists `request_id`, `auth` and `path_params`.
/// `shared_state` is intentionally not printed.
impl std::fmt::Debug for AcubeContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that the omitted field is spelled out explicitly.
        let Self {
            request_id,
            auth,
            path_params,
            shared_state: _,
        } = self;

        let mut builder = f.debug_struct("AcubeContext");
        builder.field("request_id", request_id);
        builder.field("auth", auth);
        builder.field("path_params", path_params);
        builder.finish()
    }
}
impl AcubeContext {
    /// Returns the path parameter `name` parsed as `T`.
    ///
    /// # Panics
    ///
    /// Panics when no parameter called `name` was captured for this request,
    /// or when its raw string value cannot be parsed into `T`.
    pub fn path<T: std::str::FromStr>(&self, name: &str) -> T
    where
        T::Err: std::fmt::Display,
    {
        let raw = match self.path_params.get(name) {
            Some(raw) => raw,
            None => panic!("path parameter '{}' not found", name),
        };
        match raw.parse::<T>() {
            Ok(parsed) => parsed,
            Err(e) => panic!("path parameter '{}' parse error: {}", name, e),
        }
    }

    /// Returns the value of type `T` obtained from the shared state.
    ///
    /// # Panics
    ///
    /// Panics when the shared state has no value of type `T`.
    pub fn state<T: Clone + Send + Sync + 'static>(&self) -> T {
        match self.shared_state.get::<T>() {
            Some(value) => value,
            None => panic!(
                "state type '{}' not registered — call .state(value) on the ServiceBuilder",
                std::any::type_name::<T>()
            ),
        }
    }

    /// Returns the `subject` of the authenticated identity.
    ///
    /// # Panics
    ///
    /// Panics when `auth` is `None`, i.e. no `AuthIdentity` was attached to
    /// the request.
    pub fn user_id(&self) -> &str {
        let identity = self
            .auth
            .as_ref()
            .expect("user_id() called on unauthenticated endpoint — use #[acube_security(jwt)]");
        &identity.subject
    }
}
#[async_trait]
impl<S> FromRequestParts<S> for AcubeContext
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let request_id = parts
.extensions
.get::<RequestId>()
.map(|r| r.0.clone())
.unwrap_or_else(|| Uuid::new_v4().to_string());
let auth = parts.extensions.get::<AuthIdentity>().cloned();
let path_params =
axum::extract::Path::<HashMap<String, String>>::from_request_parts(parts, _state)
.await
.map(|p| p.0)
.unwrap_or_default();
let shared_state = parts
.extensions
.get::<SharedState>()
.cloned()
.unwrap_or_default();
Ok(AcubeContext {
request_id,
auth,
path_params,
shared_state,
})
}
}
/// Wrapper around a request payload of type `T`.
///
/// The field is private; within this file the only place a `Valid` is built
/// is the `FromRequest` implementation, after `T::validate` has succeeded.
pub struct Valid<T>(T);

impl<T> Valid<T> {
    /// Consumes the wrapper and hands back the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
#[async_trait]
impl<T, S> FromRequest<S> for Valid<T>
where
T: DeserializeOwned + AcubeValidate + Send,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let request_id = req
.extensions()
.get::<RequestId>()
.map(|r| r.0.clone())
.unwrap_or_else(|| Uuid::new_v4().to_string());
let bytes = Bytes::from_request(req, state).await.map_err(|e| {
let status = e.status();
if status == StatusCode::PAYLOAD_TOO_LARGE {
error_response(
StatusCode::PAYLOAD_TOO_LARGE,
"payload_too_large",
"Request body too large",
&request_id,
false,
None,
)
} else {
error_response(
StatusCode::BAD_REQUEST,
"invalid_body",
"Failed to read request body",
&request_id,
false,
None,
)
}
})?;
let value: serde_json::Value = serde_json::from_slice(&bytes).map_err(|_| {
error_response(
StatusCode::BAD_REQUEST,
"invalid_json",
"Invalid JSON",
&request_id,
false,
None,
)
})?;
let unknown_errors = check_unknown_fields(&value, T::known_fields());
if !unknown_errors.is_empty() {
return Err(validation_error_response(&request_id, unknown_errors));
}
let mut input: T = serde_json::from_value(value).map_err(|e| {
tracing::debug!(request_id = %request_id, error = %e, "deserialization failed");
error_response(
StatusCode::BAD_REQUEST,
"deserialization_error",
"Invalid request body",
&request_id,
false,
None,
)
})?;
if let Err(errors) = input.validate() {
return Err(validation_error_response(&request_id, errors));
}
Ok(Valid(input))
}
}
/// Logs every validation error at warn level and builds the
/// `validation_error` (400) response.
///
/// The response's final argument is a JSON array holding the `field` of each
/// error, in the order given; codes and messages are only logged.
fn validation_error_response(request_id: &str, errors: Vec<ValidationError>) -> Response {
    let mut field_names = Vec::with_capacity(errors.len());
    for failure in &errors {
        tracing::warn!(
            request_id = %request_id,
            field = %failure.field,
            code = %failure.code,
            message = %failure.message,
            "validation failed"
        );
        field_names.push(serde_json::Value::String(failure.field.as_str().to_owned()));
    }

    let details = serde_json::Value::Array(field_names);
    error_response(
        StatusCode::BAD_REQUEST,
        "validation_error",
        "Validation failed",
        request_id,
        false,
        Some(details),
    )
}