shield 0.2.4

Web authentication for Rust.
Documentation
use std::{any::Any, collections::HashMap, sync::Arc};

#[cfg(feature = "utoipa")]
use convert_case::{Case, Casing};
use ordered_hash_map::OrderedHashMap;
use tracing::warn;
#[cfg(feature = "utoipa")]
use utoipa::{
    IntoParams,
    openapi::{
        OpenApi, PathItem, Paths,
        path::{Operation, ParameterIn},
    },
};

#[cfg(feature = "utoipa")]
use crate::path::{ActionPathParams, MethodActionPathParams};
use crate::{
    SignOutAction,
    action::{Action, ActionForms, ActionMethodForm, ActionProviderForm},
    error::{ActionError, MethodError, ProviderError, SessionError, ShieldError},
    method::ErasedMethod,
    options::ShieldOptions,
    request::Request,
    response::ResponseType,
    session::Session,
    storage::Storage,
    user::User,
};

/// Central authentication object tying together user storage, shield-level actions and the
/// enabled authentication methods.
///
/// Cloning shares the storage, actions and methods through their `Arc`s; `options` is cloned
/// by value.
#[derive(Clone)]
pub struct Shield<U>
where
    U: User,
{
    /// Backend used to look users up.
    storage: Arc<dyn Storage<U>>,
    /// Shield-level actions, keyed by action ID.
    actions: Arc<HashMap<String, Arc<dyn Action>>>,
    /// Authentication methods, keyed by method ID (ordered map, presumably registration
    /// order — confirm against `ordered_hash_map` semantics).
    methods: Arc<OrderedHashMap<String, Arc<dyn ErasedMethod>>>,
    /// Configuration supplied to [`Shield::new`].
    options: ShieldOptions,
}

impl<U: User> Shield<U> {
    /// Creates a [`Shield`] from a storage backend, the enabled authentication methods and
    /// options.
    ///
    /// The shield-level [`SignOutAction`] is always registered. Actions are keyed by their
    /// `id()` and methods by their `erased_id()`; if two entries share an ID a warning is
    /// logged and only one of them is kept.
    pub fn new<S>(storage: S, methods: Vec<Arc<dyn ErasedMethod>>, options: ShieldOptions) -> Self
    where
        S: Storage<U> + 'static,
    {
        let actions: [Arc<dyn Action>; 1] = [Arc::new(SignOutAction)];

        // Pair every entry with its ID up front so duplicates can be reported before the maps
        // are built; collecting into a map silently keeps only one entry per key.
        let actions: Vec<(String, Arc<dyn Action>)> = actions
            .into_iter()
            .map(|action| (action.id().to_owned(), action))
            .collect();
        Self::warn_duplicate_ids("action", actions.iter().map(|(id, _)| id.as_str()));

        let methods: Vec<(String, Arc<dyn ErasedMethod>)> = methods
            .into_iter()
            .map(|method| (method.erased_id(), method))
            .collect();
        Self::warn_duplicate_ids("method", methods.iter().map(|(id, _)| id.as_str()));

        Self {
            storage: Arc::new(storage),
            actions: Arc::new(actions.into_iter().collect()),
            methods: Arc::new(methods.into_iter().collect()),
            options,
        }
    }

    /// Logs a warning for every ID in `ids` that already appeared earlier in the sequence.
    ///
    /// `kind` is only used in the log message (e.g. `"action"` or `"method"`).
    fn warn_duplicate_ids<'a>(kind: &str, ids: impl IntoIterator<Item = &'a str>) {
        let mut seen = std::collections::HashSet::new();
        for id in ids {
            if !seen.insert(id) {
                warn!(
                    "Duplicate {} ID `{}`; only one registration will be kept",
                    kind, id
                );
            }
        }
    }

    /// Returns the user storage backend passed to [`Shield::new`].
    pub fn storage(&self) -> &dyn Storage<U> {
        &*self.storage
    }

    /// Returns the options passed to [`Shield::new`].
    pub fn options(&self) -> &ShieldOptions {
        &self.options
    }

    /// Looks up a shield-level action (registered in [`Shield::new`]) by its ID.
    pub fn action_by_id(&self, action_id: &str) -> Option<&dyn Action> {
        self.actions.get(action_id).map(|v| &**v)
    }

    /// Looks up an authentication method by its ID.
    pub fn method_by_id(&self, method_id: &str) -> Option<&dyn ErasedMethod> {
        self.methods.get(method_id).map(|v| &**v)
    }

    /// Resolves a provider of the method `method_id`.
    ///
    /// Returns `Ok(None)` when no method with that ID is registered; otherwise returns
    /// whatever the method's `erased_provider_by_id` yields for `provider_id`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the method's provider lookup.
    pub async fn provider_by_id(
        &self,
        method_id: &str,
        provider_id: Option<&str>,
    ) -> Result<Option<Box<dyn Any + Send + Sync>>, ShieldError> {
        match self.method_by_id(method_id) {
            Some(method) => method.erased_provider_by_id(provider_id).await,
            None => Ok(None),
        }
    }

    /// Builds the forms to render for the action `action_id`.
    ///
    /// Combines the forms of the shield-level action with that ID (if any) with, for every
    /// method implementing the action, the forms of each provider whose `erased_condition`
    /// holds for the current session. Provider forms are sorted by provider ID.
    ///
    /// The display name is taken from the last method action found (a mismatch with an
    /// earlier name is logged), falling back to the shield-level action's name and finally to
    /// `action_id` itself.
    ///
    /// # Errors
    ///
    /// Fails if the session lock cannot be taken, or if deserializing a method session,
    /// listing providers, evaluating a condition or building forms fails.
    pub async fn action_forms(
        &self,
        action_id: &str,
        session: Session,
    ) -> Result<ActionForms, ShieldError> {
        let mut action_name = None::<String>;
        let mut forms = vec![];
        let mut method_forms = vec![];

        if let Some(action) = self.actions.get(action_id) {
            action_name = Some(action.name().to_owned());
            forms = action.forms().await?;
        }

        for (method_id, method) in self.methods.iter() {
            let Some(action) = method.erased_action_by_id(action_id) else {
                continue;
            };

            // Copy what is needed out of the session in its own scope so the lock guard is
            // dropped before any `.await` below.
            let (base_session, method_session) = {
                let session_data = session.data();
                let session_data = session_data
                    .lock()
                    .map_err(|err| SessionError::Lock(err.to_string()))?;

                (
                    session_data.base.clone(),
                    method.erased_deserialize_session(session_data.method_str(method_id))?,
                )
            };

            let name = action.erased_name();
            if let Some(action_name) = &action_name
                && *action_name != name
            {
                warn!("Action name mismatch `{}` != `{}`", action_name, name);
            }
            action_name = Some(name);

            let mut provider_forms = vec![];
            for (provider_id, provider) in method.erased_providers().await? {
                // Skip providers for which the action does not apply in this session.
                if !action.erased_condition(&*provider, &base_session, &*method_session)? {
                    continue;
                }

                let forms = action.erased_forms(provider).await?;
                for form in forms {
                    provider_forms.push(ActionProviderForm {
                        id: provider_id.clone(),
                        form,
                    });
                }
            }

            provider_forms.sort_by(|a, b| a.id.cmp(&b.id));

            method_forms.push(ActionMethodForm {
                id: method_id.clone(),
                provider_forms,
            });
        }

        Ok(ActionForms {
            id: action_id.to_owned(),
            name: action_name.unwrap_or_else(|| action_id.to_owned()),
            forms,
            method_forms,
        })
    }

    /// Executes the shield-level action `action_id` and applies its session actions.
    ///
    /// # Errors
    ///
    /// Returns [`ActionError::NotFound`] if no shield-level action has that ID; otherwise
    /// propagates session-lock, action and session-action errors.
    pub async fn call(
        &self,
        action_id: &str,
        session: Session,
        request: Request,
    ) -> Result<ResponseType, ShieldError> {
        let action = self.action_by_id(action_id).ok_or_else(|| {
            ShieldError::Action(ActionError::NotFound(action_id.to_owned()))
        })?;

        // Clone the base session in its own scope so the lock is released before awaiting.
        let base_session = {
            let session_data = session.data();
            let session_data = session_data
                .lock()
                .map_err(|err| SessionError::Lock(err.to_string()))?;

            session_data.base.clone()
        };

        let response = action.call(&base_session, request).await?;

        for session_action in &response.session_actions {
            session_action.call(&session).await?;
        }

        Ok(response.r#type)
    }

    /// Executes action `action_id` of method `method_id` for the given provider and applies
    /// its session actions.
    ///
    /// # Errors
    ///
    /// Returns [`MethodError::NotFound`], [`ActionError::NotFound`] or
    /// [`ProviderError::NotFound`] when the method, the action on that method or the provider
    /// cannot be resolved; otherwise propagates session, action and session-action errors.
    pub async fn call_method(
        &self,
        action_id: &str,
        method_id: &str,
        provider_id: Option<&str>,
        session: Session,
        request: Request,
    ) -> Result<ResponseType, ShieldError> {
        let method = self.method_by_id(method_id).ok_or_else(|| {
            ShieldError::Method(MethodError::NotFound(method_id.to_owned()))
        })?;

        let action = method.erased_action_by_id(action_id).ok_or_else(|| {
            ShieldError::Action(ActionError::NotFound(action_id.to_owned()))
        })?;

        let provider = method
            .erased_provider_by_id(provider_id)
            .await?
            .ok_or_else(|| {
                ShieldError::Provider(ProviderError::NotFound(
                    provider_id.map(ToOwned::to_owned),
                ))
            })?;

        // Copy the session state in its own scope so the lock is released before awaiting.
        let (base_session, method_session) = {
            let session_data = session.data();
            let session_data = session_data
                .lock()
                .map_err(|err| SessionError::Lock(err.to_string()))?;

            (
                session_data.base.clone(),
                method.erased_deserialize_session(session_data.method_str(method_id))?,
            )
        };

        let response = action
            .erased_call(provider, &base_session, &*method_session, request)
            .await?;

        for session_action in &response.session_actions {
            session_action.call(&session).await?;
        }

        Ok(response.r#type)
    }

    /// Returns the user authenticated in `session`, if any.
    ///
    /// The session is purged and `Ok(None)` is returned when the method/provider recorded in
    /// the session's authentication can no longer be resolved, or when storage has no user
    /// with the recorded user ID.
    ///
    /// # Errors
    ///
    /// Propagates session-lock, provider lookup, storage and purge errors.
    pub async fn user(&self, session: &Session) -> Result<Option<U>, ShieldError> {
        let authentication = {
            let session_data = session.data();
            let session_data = session_data
                .lock()
                .map_err(|err| SessionError::Lock(err.to_string()))?;

            session_data.base.authentication.clone()
        };

        let Some(authentication) = authentication else {
            return Ok(None);
        };

        // The method or provider that authenticated this session may have been removed
        // since; drop the session in that case.
        if self
            .provider_by_id(
                &authentication.method_id,
                authentication.provider_id.as_deref(),
            )
            .await?
            .is_none()
        {
            session.purge().await?;
            return Ok(None);
        }

        let user = self.storage().user_by_id(&authentication.user_id).await?;

        if user.is_none() {
            session.purge().await?;
        }

        Ok(user)
    }

    /// Generates an OpenAPI document describing the routes exposed by this shield.
    ///
    /// Shield-level actions are published at `/{action_id}` and method actions at
    /// `/{action_id}/{method_id}/{providerId}`, all tagged `auth`. Each operation currently
    /// documents only its path parameters and a `500` response.
    #[cfg(feature = "utoipa")]
    pub fn openapi(&self) -> OpenApi {
        use utoipa::openapi::Response;

        let mut paths = Paths::builder();

        for action in self.actions.values() {
            let action_id = action.id();

            // TODO: Query, request body, responses.

            paths = paths.path(
                format!("/{action_id}"),
                PathItem::builder()
                    .operation(
                        action.method().into(),
                        Operation::builder()
                            .operation_id(Some(action_id.to_case(Case::Camel)))
                            .summary(Some(action.openapi_summary()))
                            .description(Some(action.openapi_description()))
                            .tag("auth")
                            .parameters(Some(ActionPathParams::into_params(|| {
                                Some(ParameterIn::Path)
                            })))
                            .response(
                                "500",
                                Response::builder().description("Internal server error."),
                            ),
                    )
                    .build(),
            );
        }

        for method in self.methods.values() {
            // The method ID is the same for every action of this method; compute it once.
            let method_id = method.erased_id();

            for action in method.erased_actions() {
                let action_id = action.erased_id();

                // TODO: Query, request body, responses.

                paths = paths.path(
                    format!("/{action_id}/{method_id}/{{providerId}}"),
                    PathItem::builder()
                        .operation(
                            action.erased_method().into(),
                            Operation::builder()
                                .operation_id(Some(format!(
                                    "{}{}",
                                    action_id.to_case(Case::Camel),
                                    method_id.to_case(Case::UpperCamel)
                                )))
                                .summary(Some(action.erased_openapi_summary()))
                                .description(Some(action.erased_openapi_description()))
                                .tag("auth")
                                .parameters(Some(MethodActionPathParams::into_params(|| {
                                    Some(ParameterIn::Path)
                                })))
                                .response(
                                    "500",
                                    Response::builder().description("Internal server error."),
                                ),
                        )
                        .build(),
                );
            }
        }

        OpenApi::builder().paths(paths.build()).build()
    }
}

#[cfg(test)]
mod tests {
    use super::Shield;
    use crate::{
        options::ShieldOptions,
        storage::tests::{TEST_STORAGE_ID, TestStorage},
    };

    /// `Shield::storage` must expose the backend handed to `Shield::new`.
    #[test]
    fn test_storage() {
        let storage = TestStorage::default();
        let shield = Shield::new(storage, Vec::new(), ShieldOptions::default());

        let storage_id = shield.storage().id();
        assert_eq!(TEST_STORAGE_ID, storage_id);
    }
}