1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
use std::{
    fmt,
    ops::{Deref, DerefMut},
};

use viz_utils::{anyhow::anyhow, futures::future::BoxFuture, tracing};

use crate::{http, Context, Error, Extract, Result};

/// Application state factory.
///
/// Implementors know how to insert a state value into an
/// [`http::Extensions`] map so that it can later be read back from a request
/// context.
pub trait StateFactory: Send + Sync + 'static {
    /// Injects a state into `extensions`.
    ///
    /// Returns whether a value was inserted. The [`State`] implementation
    /// returns `false`, and leaves the map untouched, when a `State<T>` of the
    /// same type is already present.
    fn create(&self, extensions: &mut http::Extensions) -> bool;
}

/// Application state.
///
/// A thin newtype around a value of type `T`. It is registered through
/// [`StateFactory`] and read back through [`Extract`] or `Context::state`.
/// It dereferences to the wrapped value.
#[derive(Clone)]
pub struct State<T>(pub T);

impl<T> State<T> {
    /// Creates a new `State` wrapping `t`.
    ///
    /// This is a `const fn`, so states can also be built in `const` and
    /// `static` items.
    pub const fn new(t: T) -> Self {
        Self(t)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> AsRef<T> for State<T> {
    fn as_ref(&self) -> &T {
        &self.0
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.0
    }
}

impl<T> DerefMut for State<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0
    }
}

impl<T: fmt::Debug> fmt::Debug for State<T> {
    /// Formats exactly like the inner value, with no `State(..)` wrapper.
    /// Formatter flags such as `{:#?}` are passed through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}

impl<T> StateFactory for State<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Stores a clone of this state in `extensions`, unless a `State<T>` is
    /// already registered there.
    ///
    /// Returns `true` when the clone was inserted. Returns `false` when an
    /// existing value was found and left untouched.
    fn create(&self, extensions: &mut http::Extensions) -> bool {
        let already_registered = extensions.get::<Self>().is_some();
        if already_registered {
            return false;
        }
        extensions.insert(self.clone());
        true
    }
}

impl<T> Extract for State<T>
where
    T: Clone + Send + Sync + 'static,
{
    type Error = Error;

    #[inline]
    fn extract(cx: &mut Context) -> BoxFuture<'_, Result<Self, Self::Error>> {
        let state = cx.extensions().get::<Self>().cloned().ok_or_else(|| {
            tracing::error!("State extract error: {}", std::any::type_name::<T>());
            anyhow!("State is not configured")
        });
        Box::pin(async move { state })
    }
}

impl Context {
    /// Gets an app state.
    pub fn state<T>(&self) -> Result<T, Error>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.extensions()
            .get::<State<T>>()
            .cloned()
            .ok_or_else(|| {
                tracing::error!("State extract error: {}", std::any::type_name::<T>());
                anyhow!("State is not configured")
            })
            .map(State::into_inner)
    }
}