use std::pin::Pin;
use crate::{Middleware, Request, Response};
use super::Next;
/// Newtype wrapper around a value of type `T`.
///
/// `StateMiddleware` in this module stores values of this type in the
/// request extensions. The wrapped value is public (`.0`), can be borrowed
/// through [`Deref`](std::ops::Deref) / [`DerefMut`](std::ops::DerefMut),
/// and can be taken back out with [`State::into_inner`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct State<T>(pub T);

impl<T> State<T> {
    /// Consumes the wrapper and hands back the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for State<T> {
    type Target = T;

    /// Borrows the wrapped value immutably.
    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::DerefMut for State<T> {
    /// Borrows the wrapped value mutably.
    fn deref_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }
}
/// Middleware that carries a value of type `T`.
///
/// Its `Middleware` implementation in this module clones the stored value,
/// wraps it in [`State`], and inserts it into the extensions of every request
/// it is applied to.
#[derive(Clone)]
pub struct StateMiddleware<T>(T);

impl<T> StateMiddleware<T> {
    /// Builds a middleware that holds `value`.
    pub fn new(value: T) -> Self {
        Self(value)
    }
}
#[async_trait]
impl<T: Clone + Send + Sync + 'static> Middleware for StateMiddleware<T> {
    /// Attaches a fresh copy of the stored value to `request` and forwards it.
    ///
    /// The stored value is cloned, wrapped in [`State`], and inserted into the
    /// request's extensions; the request is then passed to `next`. The result
    /// of `next.apply` is returned unchanged.
    async fn apply(
        self: Pin<&Self>,
        mut request: Request,
        next: Next<'_>,
    ) -> Result<Response, anyhow::Error> {
        // Each request gets its own clone, so the middleware keeps its original.
        let state = State(self.0.clone());
        request.extensions_mut().insert(state);

        next.apply(request).await
    }
}
impl<T> std::fmt::Debug for StateMiddleware<T> {
    /// Formats as a `StateMiddleware` tuple whose single field is the type
    /// name of `T`.
    ///
    /// The stored value itself is not printed; `T` is not required to
    /// implement `Debug` here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("StateMiddleware");
        tuple.field(&std::any::type_name::<T>());
        tuple.finish()
    }
}