use std::{future::Future, ops::Deref, sync::Arc};
use axum::{Json, extract::FromRequestParts, response::IntoResponse};
use http::{StatusCode, request::Parts};
use nidus_core::{Inject, NidusError, SharedRequestScope};
use serde::Serialize;
/// Axum extractor that resolves a `T` from the request's [`SharedRequestScope`].
///
/// The wrapped [`Inject<T>`] handle comes from calling `inject::<T>()` on the
/// [`SharedRequestScope`] found in the request extensions. The wrapper dereferences
/// to `T`, so handlers can use it like a `&T`.
#[derive(Debug, Clone)]
pub struct RequestScoped<T>(Inject<T>)
where
    T: Send + Sync + 'static;
impl<T> RequestScoped<T>
where
T: Send + Sync + 'static,
{
pub fn new(value: Inject<T>) -> Self {
Self(value)
}
pub fn into_inject(self) -> Inject<T> {
self.0
}
pub fn into_inner(self) -> Arc<T> {
self.0.into_inner()
}
}
impl<T> Deref for RequestScoped<T>
where
T: Send + Sync + 'static,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Extracts a request-scoped `T` from the [`SharedRequestScope`] stored in the
/// request extensions.
///
/// # Errors
///
/// * [`RequestScopeRejection::MissingScope`] if the extensions hold no
///   [`SharedRequestScope`].
/// * [`RequestScopeRejection::ResolutionFailed`] if `inject::<T>()` on the scope
///   returns an error.
impl<S, T> FromRequestParts<S> for RequestScoped<T>
where
    S: Send + Sync,
    T: Send + Sync + 'static,
{
    type Rejection = RequestScopeRejection;

    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // Clone an owned handle up front so the returned future does not borrow `parts`.
        let maybe_scope: Option<SharedRequestScope> =
            parts.extensions.get::<SharedRequestScope>().cloned();

        async move {
            let Some(request_scope) = maybe_scope else {
                return Err(RequestScopeRejection::MissingScope);
            };
            match request_scope.inject::<T>() {
                Ok(injected) => Ok(Self::new(injected)),
                Err(cause) => Err(RequestScopeRejection::ResolutionFailed(cause)),
            }
        }
    }
}
#[derive(Debug, thiserror::Error)]
pub enum RequestScopeRejection {
#[error("request scope is not available; attach request_scope_layer to the router")]
MissingScope,
#[error("request-scoped provider resolution failed: {0}")]
ResolutionFailed(#[source] NidusError),
}
/// Renders a rejection as `500 Internal Server Error` with a JSON body shaped like
/// `{"error": {"code": "...", "message": "..."}}`.
///
/// The `message` is the rejection's `Display` text (generated from its
/// `#[error(...)]` attributes). The body therefore always matches the error
/// messages instead of repeating them by hand. Note that for `ResolutionFailed` this
/// includes the underlying `NidusError` text in the client-visible response.
impl IntoResponse for RequestScopeRejection {
    fn into_response(self) -> axum::response::Response {
        // Only the machine-readable code is variant-specific here; the exhaustive
        // match forces a decision when a new variant is added.
        let code = match &self {
            Self::MissingScope => "request_scope_unavailable",
            Self::ResolutionFailed(_) => "request_scope_resolution_failed",
        };
        // Single source of truth for the human-readable text: the thiserror Display impl.
        let message = self.to_string();

        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorBody {
                error: ErrorDetails { code, message },
            }),
        )
            .into_response()
    }
}
/// Top-level JSON envelope for rejection responses, serialized as `{"error": {...}}`.
#[derive(Serialize, Debug)]
struct ErrorBody {
    /// The nested code/message pair.
    error: ErrorDetails,
}
/// Payload nested under the `error` key of the rejection response body.
#[derive(Serialize, Debug)]
struct ErrorDetails {
    /// Short machine-readable identifier, e.g. `request_scope_unavailable`.
    code: &'static str,
    /// Human-readable description of the failure.
    message: String,
}