use motore::{layer::Layer, service::Service};
use volo::context::Context;
/// A value to be made available through a context's extensions.
///
/// Used as a [`Layer`], the wrapped value is moved into an [`ExtensionService`], which inserts a
/// clone of it into the context's extensions on every call before delegating to the inner
/// service. With the `server` feature enabled, `Extension<T>` can also be extracted, cloning the
/// `T` back out of the server context.
#[derive(Clone, Copy, Debug, Default)]
pub struct Extension<T>(pub T);
impl<S, T> Layer<S> for Extension<T>
where
S: Send + Sync + 'static,
T: Sync,
{
type Service = ExtensionService<S, T>;
fn layer(self, inner: S) -> Self::Service {
ExtensionService { inner, ext: self.0 }
}
}
/// Service produced by layering an [`Extension`] over an inner service.
///
/// It holds the wrapped service together with the value to insert into the context's
/// extensions on each call.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExtensionService<I, T> {
    /// The service that requests are forwarded to.
    inner: I,
    /// The value that is cloned into the context's extensions on every call.
    ext: T,
}
impl<S, Cx, Req, Resp, E, T> Service<Cx, Req> for ExtensionService<S, T>
where
    S: Service<Cx, Req, Response = Resp, Error = E> + Send + Sync + 'static,
    Req: Send,
    Cx: Context + Send,
    T: Clone + Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Makes the stored value visible to the inner service, then delegates to it.
    ///
    /// On every call, a fresh clone of the stored value is inserted into the context's
    /// extensions, so code running under the inner service can look it up by type `T`. The inner
    /// service's result is returned unchanged.
    async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
        let value: T = Clone::clone(&self.ext);
        cx.extensions_mut().insert(value);

        let Self { inner, .. } = self;
        inner.call(cx, req).await
    }
}
#[cfg(feature = "server")]
mod server {
    use http::{StatusCode, request::Parts};
    use volo::context::Context;

    use super::Extension;
    use crate::{
        context::ServerContext,
        response::Response,
        server::{IntoResponse, extract::FromContext},
    };

    /// Extracts an `Extension<T>` by cloning the `T` stored in the server context's extensions.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionRejection::NotExist`] when no value of type `T` is present in the
    /// context's extensions.
    impl<T> FromContext for Extension<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Rejection = ExtensionRejection;

        async fn from_context(
            cx: &mut ServerContext,
            _parts: &mut Parts,
        ) -> Result<Self, Self::Rejection> {
            cx.extensions()
                .get::<T>()
                .cloned()
                .map(Extension)
                .ok_or(ExtensionRejection::NotExist)
        }
    }

    /// Rejection produced when extracting an [`Extension`] from the server context fails.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ExtensionRejection {
        /// No value of the requested type is present in the context's extensions.
        NotExist,
    }

    impl IntoResponse for ExtensionRejection {
        /// Converts the rejection into a `500 Internal Server Error` response.
        fn into_response(self) -> Response {
            // A missing extension presumably points to a server-side setup problem (e.g. the
            // matching `Extension` layer was not applied) rather than a bad client request,
            // hence a 5xx status instead of a 4xx one.
            StatusCode::INTERNAL_SERVER_ERROR.into_response()
        }
    }
}