// modo-rs 0.8.0
//
// Rust web framework for small monolithic apps.
// Documentation
use axum::extract::{FromRequestParts, OptionalFromRequestParts};
use http::request::Parts;

use crate::error::Error;

pub use super::types::TierInfo;

/// Extracts [`TierInfo`] from request extensions.
///
/// # Errors
///
/// Returns [`Error::internal`](crate::Error::internal) if [`TierLayer`](super::TierLayer)
/// has not been applied (i.e., `TierInfo` is missing from extensions).
impl<S: Send + Sync> FromRequestParts<S> for TierInfo {
    type Rejection = Error;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .extensions
            .get::<TierInfo>()
            .cloned()
            .ok_or_else(|| Error::internal("Tier middleware not applied"))
    }
}

/// Optional extractor: yields a clone of the [`TierInfo`] stored in the
/// request extensions, if there is one.
///
/// A missing `TierInfo` is not an error here; the result is `Ok(None)`. That
/// happens when [`TierLayer`](super::TierLayer) has not been applied, or when
/// the owner extractor returned `None` and no default tier is configured.
impl<S: Send + Sync> OptionalFromRequestParts<S> for TierInfo {
    type Rejection = Error;

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        // This body never produces the rejection: absence maps to `None`.
        let maybe_tier = parts.extensions.get::<Self>().cloned();
        Ok(maybe_tier)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use super::super::types::FeatureAccess;

    /// Builds the fixture tier: named "pro" with the `sso` feature toggled on.
    fn test_tier() -> TierInfo {
        let features = HashMap::from([("sso".into(), FeatureAccess::Toggle(true))]);
        TierInfo {
            name: "pro".into(),
            features,
        }
    }

    /// Request parts of a default request with a unit body and no extensions added.
    fn empty_parts() -> Parts {
        let (parts, _body) = http::Request::builder().body(()).unwrap().into_parts();
        parts
    }

    /// Request parts that already carry the fixture tier in their extensions.
    fn parts_with_tier() -> Parts {
        let mut parts = empty_parts();
        parts.extensions.insert(test_tier());
        parts
    }

    /// Runs the required extractor with `()` as the state.
    async fn extract(parts: &mut Parts) -> Result<TierInfo, Error> {
        <TierInfo as FromRequestParts<()>>::from_request_parts(parts, &()).await
    }

    /// Runs the optional extractor with `()` as the state.
    async fn extract_optional(parts: &mut Parts) -> Result<Option<TierInfo>, Error> {
        <TierInfo as OptionalFromRequestParts<()>>::from_request_parts(parts, &()).await
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = parts_with_tier();

        let outcome = extract(&mut parts).await;
        assert!(outcome.is_ok());
        let tier = outcome.unwrap();
        assert_eq!(tier.name, "pro");
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let outcome = extract(&mut parts).await;
        assert!(outcome.is_err());
        let status = outcome.unwrap_err().status();
        assert_eq!(status, http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn optional_none_when_missing() {
        let mut parts = empty_parts();

        let outcome = extract_optional(&mut parts).await;
        assert!(outcome.is_ok());
        let maybe_tier = outcome.unwrap();
        assert!(maybe_tier.is_none());
    }

    #[tokio::test]
    async fn optional_some_when_present() {
        let mut parts = parts_with_tier();

        let outcome = extract_optional(&mut parts).await;
        assert!(outcome.is_ok());
        let tier = outcome.unwrap().unwrap();
        assert_eq!(tier.name, "pro");
    }
}