use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use dashmap::DashMap;
use once_cell::sync::Lazy;
use std::fmt::Debug;
use std::sync::Arc;
/// Process-wide concurrent map from the key produced by `get_task_id` to the
/// request parts registered through `provide_request_parts`.
///
/// Initialised lazily on first access.
static REQUEST_PARTS_STORAGE: Lazy<DashMap<usize, Parts>> = Lazy::new(DashMap::default);
/// Returns the storage key for the caller: a hash of the current OS thread's
/// [`std::thread::ThreadId`].
///
/// The hasher is `DefaultHasher::new()`, so the same thread always yields the
/// same key within a process.
///
/// NOTE(review): despite the name, this identifies the *thread*, not the async
/// task. On a multi-threaded async runtime, concurrent requests polled on the
/// same worker thread would share one key, and a task resumed on a different
/// thread after an `.await` would get a different key. Confirm this is acceptable
/// for the runtime in use, or consider task-local storage instead.
fn get_task_id() -> usize {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let current_thread = std::thread::current().id();
    let mut state = DefaultHasher::new();
    current_thread.hash(&mut state);
    let digest: u64 = state.finish();
    // `as` keeps only the low bits on targets where usize is narrower than u64.
    digest as usize
}
/// Error returned by [`extract`] and [`extract_with_state`].
#[derive(Debug)]
pub enum ExtractError {
    /// No request parts were stored for the current key; carries a
    /// human-readable explanation.
    MissingParts(String),
    /// The extractor's `from_request_parts` rejected the request; carries a
    /// textual rendering of the rejection.
    ExtractionFailed(String),
}

impl std::fmt::Display for ExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingParts(detail) => write!(f, "Missing request parts: {}", detail),
            Self::ExtractionFailed(detail) => write!(f, "Extraction failed: {}", detail),
        }
    }
}

/// Marker impl so `ExtractError` can be boxed / used with `?` into
/// `Box<dyn std::error::Error>`.
impl std::error::Error for ExtractError {}
/// Stores `parts` under the caller's key (see `get_task_id`) so later calls to
/// [`extract`] / [`extract_with_state`] with the same key can read them.
///
/// Any parts previously stored under that key are replaced and dropped.
///
/// NOTE(review): the key is derived from the current OS thread, so this only
/// pairs up with later `extract*` / `clear_request_parts` calls made on the same
/// thread — confirm that holds for the runtime in use.
pub async fn provide_request_parts(parts: Parts) {
    REQUEST_PARTS_STORAGE.insert(get_task_id(), parts);
}
/// Removes the request parts stored under the caller's key (see
/// `get_task_id`), if any.
///
/// Does nothing when no entry exists for that key.
pub async fn clear_request_parts() {
    REQUEST_PARTS_STORAGE.remove(&get_task_id());
}
/// Extracts `T` from the stored request parts using the unit state `()`.
///
/// Shorthand for [`extract_with_state`] with `&()`; it returns the same
/// errors as that function.
pub async fn extract<T>() -> Result<T, ExtractError>
where
    T: Sized + FromRequestParts<()>,
    T::Rejection: Debug,
{
    extract_with_state::<T, _>(&()).await
}
/// Extracts `T` from the request parts stored under the caller's key (see
/// `get_task_id`), passing `state` to `T`'s [`FromRequestParts`] impl.
///
/// The stored [`Parts`] are cloned before extraction, so the stored entry is
/// left untouched and can be extracted from again.
///
/// # Errors
///
/// - [`ExtractError::MissingParts`] if nothing is stored under the caller's key
///   (e.g. `provide_request_parts` was not called, or the entry was cleared).
/// - [`ExtractError::ExtractionFailed`] if `T::from_request_parts` rejects the
///   request; the message is the `Debug` rendering of the rejection.
pub async fn extract_with_state<T, S>(state: &S) -> Result<T, ExtractError>
where
    T: Sized + FromRequestParts<S>,
    T::Rejection: Debug,
{
    let task_id = get_task_id();

    // Copy the parts out within a single expression so the DashMap read guard
    // is dropped *before* the `.await` below. Keeping the guard alive across the
    // await would hold the shard's read lock while this future is suspended, so a
    // concurrent insert/remove on the same shard would block its thread (and
    // deadlock outright on a single-threaded runtime).
    let mut parts = REQUEST_PARTS_STORAGE
        .get(&task_id)
        .map(|entry| entry.value().clone())
        .ok_or_else(|| {
            ExtractError::MissingParts(
                "Request parts not found. Make sure provide_request_parts() was called."
                    .to_string(),
            )
        })?;

    T::from_request_parts(&mut parts, state)
        .await
        .map_err(|e| ExtractError::ExtractionFailed(format!("{:?}", e)))
}