use std::{any::Any, collections::HashMap, sync::Arc};
#[cfg(feature = "utoipa")]
use convert_case::{Case, Casing};
use ordered_hash_map::OrderedHashMap;
use tracing::warn;
#[cfg(feature = "utoipa")]
use utoipa::{
IntoParams,
openapi::{
OpenApi, PathItem, Paths,
path::{Operation, ParameterIn},
},
};
#[cfg(feature = "utoipa")]
use crate::path::{ActionPathParams, MethodActionPathParams};
use crate::{
SignOutAction,
action::{Action, ActionForms, ActionMethodForm, ActionProviderForm},
error::{ActionError, MethodError, ProviderError, SessionError, ShieldError},
method::ErasedMethod,
options::ShieldOptions,
request::Request,
response::ResponseType,
session::Session,
storage::Storage,
user::User,
};
/// Central handle combining a storage backend, the registered actions and authentication
/// methods, and the configuration options.
///
/// The storage backend and both registries live behind [`Arc`], so cloning a `Shield` shares
/// them instead of copying them; only the options are cloned by value.
#[derive(Clone)]
pub struct Shield<U: User> {
    /// Configuration options.
    options: ShieldOptions,
    /// Authentication methods, keyed by their erased method ID, in an ordered map.
    methods: Arc<OrderedHashMap<String, Arc<dyn ErasedMethod>>>,
    /// Method-independent actions, keyed by action ID.
    actions: Arc<HashMap<String, Arc<dyn Action>>>,
    /// Backend used to look up users.
    storage: Arc<dyn Storage<U>>,
}
impl<U: User> Shield<U> {
    /// Builds a [`Shield`] from a storage backend, a list of authentication methods and options.
    ///
    /// The built-in [`SignOutAction`] is always registered as a method-independent action.
    /// Methods are keyed by their `erased_id()`. Duplicate IDs are not rejected here.
    /// NOTE(review): presumably the later method replaces the earlier one; confirm against
    /// `OrderedHashMap`'s `FromIterator` behaviour.
    pub fn new<S>(storage: S, methods: Vec<Arc<dyn ErasedMethod>>, options: ShieldOptions) -> Self
    where
        S: Storage<U> + 'static,
    {
        let actions: [Arc<dyn Action>; 1] = [Arc::new(SignOutAction)];

        Self {
            storage: Arc::new(storage),
            actions: Arc::new(
                actions
                    .into_iter()
                    .map(|action| (action.id().to_owned(), action))
                    .collect(),
            ),
            methods: Arc::new(
                methods
                    .into_iter()
                    .map(|method| (method.erased_id(), method))
                    .collect(),
            ),
            options,
        }
    }

    /// Returns the storage backend this shield was created with.
    pub fn storage(&self) -> &dyn Storage<U> {
        &*self.storage
    }

    /// Returns the configuration options.
    pub fn options(&self) -> &ShieldOptions {
        &self.options
    }

    /// Looks up a method-independent action by its ID.
    pub fn action_by_id(&self, action_id: &str) -> Option<&dyn Action> {
        self.actions.get(action_id).map(|v| &**v)
    }

    /// Looks up a registered authentication method by its ID.
    pub fn method_by_id(&self, method_id: &str) -> Option<&dyn ErasedMethod> {
        self.methods.get(method_id).map(|v| &**v)
    }

    /// Looks up a (type-erased) provider of the method `method_id`.
    ///
    /// Returns `Ok(None)` when no method with that ID is registered. Otherwise the lookup is
    /// delegated to the method's `erased_provider_by_id`, and its result is returned as-is.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the method's provider lookup.
    pub async fn provider_by_id(
        &self,
        method_id: &str,
        provider_id: Option<&str>,
    ) -> Result<Option<Box<dyn Any + Send + Sync>>, ShieldError> {
        match self.method_by_id(method_id) {
            Some(method) => method.erased_provider_by_id(provider_id).await,
            None => Ok(None),
        }
    }

    /// Collects every form that can be shown for the action `action_id`.
    ///
    /// The result combines two sources:
    /// - the forms of the method-independent action with this ID, if one is registered;
    /// - for each method that defines an action with this ID, the forms of every provider
    ///   whose `erased_condition` holds for the current session, sorted by provider ID.
    ///
    /// The reported name is chosen in this order:
    /// 1. the last matching method action's name;
    /// 2. the method-independent action's name;
    /// 3. `action_id` itself.
    ///
    /// Differing names are logged as a warning but are not an error.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the session data cannot be locked;
    /// - a method session cannot be deserialized;
    /// - any action, provider or form lookup fails.
    pub async fn action_forms(
        &self,
        action_id: &str,
        session: Session,
    ) -> Result<ActionForms, ShieldError> {
        let mut action_name = None::<String>;
        let mut forms = vec![];
        let mut method_forms = vec![];

        if let Some(action) = self.actions.get(action_id) {
            action_name = Some(action.name().to_owned());
            forms = action.forms().await?;
        }

        for (method_id, method) in self.methods.iter() {
            let Some(action) = method.erased_action_by_id(action_id) else {
                continue;
            };

            // Copy what we need out of the session; the guard is dropped at the end of this
            // block, before any `.await` below.
            let (base_session, method_session) = {
                let session_data = session.data();
                let session_data = session_data
                    .lock()
                    .map_err(|err| SessionError::Lock(err.to_string()))?;
                (
                    session_data.base.clone(),
                    method.erased_deserialize_session(session_data.method_str(method_id))?,
                )
            };

            let name = action.erased_name();
            if let Some(action_name) = &action_name
                && *action_name != name
            {
                warn!("Action name mismatch `{}` != `{}`", action_name, name);
            }
            action_name = Some(name);

            let mut provider_forms = vec![];
            for (provider_id, provider) in method.erased_providers().await? {
                // Skip providers for which this action does not apply to the current session.
                if !action.erased_condition(&*provider, &base_session, &*method_session)? {
                    continue;
                }

                for form in action.erased_forms(provider).await? {
                    provider_forms.push(ActionProviderForm {
                        id: provider_id.clone(),
                        form,
                    });
                }
            }
            // Order provider forms by provider ID.
            provider_forms.sort_by(|a, b| a.id.cmp(&b.id));

            method_forms.push(ActionMethodForm {
                id: method_id.clone(),
                provider_forms,
            });
        }

        Ok(ActionForms {
            id: action_id.to_owned(),
            // Lazy fallback: only allocate when no action supplied a name.
            name: action_name.unwrap_or_else(|| action_id.to_owned()),
            forms,
            method_forms,
        })
    }

    /// Runs the method-independent action `action_id`.
    ///
    /// The action receives a copy of the base session. Every session action in its response is
    /// then applied to `session`, in order, before the response type is returned.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - no such action is registered (`ShieldError::Action(ActionError::NotFound(_))`);
    /// - the session data cannot be locked;
    /// - the action or any of its session actions fails (the error is propagated).
    pub async fn call(
        &self,
        action_id: &str,
        session: Session,
        request: Request,
    ) -> Result<ResponseType, ShieldError> {
        // `ok_or_else` avoids building (and allocating) the error on the success path.
        let action = self.action_by_id(action_id).ok_or_else(|| {
            ShieldError::Action(ActionError::NotFound(action_id.to_owned()))
        })?;

        let base_session = {
            let session_data = session.data();
            let session_data = session_data
                .lock()
                .map_err(|err| SessionError::Lock(err.to_string()))?;
            session_data.base.clone()
        };

        let response = action.call(&base_session, request).await?;
        for session_action in &response.session_actions {
            session_action.call(&session).await?;
        }

        Ok(response.r#type)
    }

    /// Runs the action `action_id` of method `method_id` against provider `provider_id`.
    ///
    /// The action receives copies of the base session and of this method's deserialized
    /// session. Every session action in its response is then applied to `session`, in order,
    /// before the response type is returned.
    ///
    /// # Errors
    ///
    /// Fails with a `NotFound` error (wrapped in `ShieldError::Method`, `ShieldError::Action`
    /// or `ShieldError::Provider`) when any of the following is unknown:
    /// - the method;
    /// - the action within that method;
    /// - the provider.
    ///
    /// Session lock/deserialization failures and errors from the action or its session actions
    /// are propagated.
    pub async fn call_method(
        &self,
        action_id: &str,
        method_id: &str,
        provider_id: Option<&str>,
        session: Session,
        request: Request,
    ) -> Result<ResponseType, ShieldError> {
        let method = self.method_by_id(method_id).ok_or_else(|| {
            ShieldError::Method(MethodError::NotFound(method_id.to_owned()))
        })?;

        let action = method.erased_action_by_id(action_id).ok_or_else(|| {
            ShieldError::Action(ActionError::NotFound(action_id.to_owned()))
        })?;

        let provider = method
            .erased_provider_by_id(provider_id)
            .await?
            .ok_or_else(|| {
                ShieldError::Provider(ProviderError::NotFound(
                    provider_id.map(ToOwned::to_owned),
                ))
            })?;

        let (base_session, method_session) = {
            let session_data = session.data();
            let session_data = session_data
                .lock()
                .map_err(|err| SessionError::Lock(err.to_string()))?;
            (
                session_data.base.clone(),
                method.erased_deserialize_session(session_data.method_str(method_id))?,
            )
        };

        let response = action
            .erased_call(provider, &base_session, &*method_session, request)
            .await?;
        for session_action in &response.session_actions {
            session_action.call(&session).await?;
        }

        Ok(response.r#type)
    }

    /// Returns the user authenticated by `session`, if any.
    ///
    /// Returns `Ok(None)` when the session carries no authentication. The session is purged,
    /// and `Ok(None)` returned, in either of these cases:
    /// - the method/provider that authenticated it can no longer be resolved;
    /// - the referenced user is not found in storage.
    ///
    /// # Errors
    ///
    /// Propagates session lock failures, provider lookup errors, storage errors and purge
    /// errors.
    pub async fn user(&self, session: &Session) -> Result<Option<U>, ShieldError> {
        let authentication = {
            let session_data = session.data();
            let session_data = session_data
                .lock()
                .map_err(|err| SessionError::Lock(err.to_string()))?;
            session_data.base.authentication.clone()
        };

        let Some(authentication) = authentication else {
            return Ok(None);
        };

        // The method/provider recorded in the session is no longer available.
        if self
            .provider_by_id(
                &authentication.method_id,
                authentication.provider_id.as_deref(),
            )
            .await?
            .is_none()
        {
            session.purge().await?;
            return Ok(None);
        }

        let user = self.storage().user_by_id(&authentication.user_id).await?;
        if user.is_none() {
            session.purge().await?;
        }

        Ok(user)
    }

    /// Builds an OpenAPI document for the action endpoints (requires the `utoipa` feature).
    ///
    /// Two kinds of path are emitted:
    /// - `/<action id>` for each method-independent action;
    /// - `/<action id>/<method id>/{providerId}` for each method action.
    ///
    /// Every operation is tagged `auth` and documents a `500` response.
    #[cfg(feature = "utoipa")]
    pub fn openapi(&self) -> OpenApi {
        use utoipa::openapi::Response;

        let mut paths = Paths::builder();

        for action in self.actions.values() {
            let action_id = action.id();
            paths = paths.path(
                format!("/{action_id}"),
                PathItem::builder()
                    .operation(
                        action.method().into(),
                        Operation::builder()
                            .operation_id(Some(action_id.to_case(Case::Camel)))
                            .summary(Some(action.openapi_summary()))
                            .description(Some(action.openapi_description()))
                            .tag("auth")
                            .parameters(Some(ActionPathParams::into_params(|| {
                                Some(ParameterIn::Path)
                            })))
                            .response(
                                "500",
                                Response::builder().description("Internal server error."),
                            ),
                    )
                    .build(),
            );
        }

        for method in self.methods.values() {
            // Loop-invariant: compute once per method, not once per action.
            let method_id = method.erased_id();
            for action in method.erased_actions() {
                let action_id = action.erased_id();
                paths = paths.path(
                    format!("/{action_id}/{method_id}/{{providerId}}"),
                    PathItem::builder()
                        .operation(
                            action.erased_method().into(),
                            Operation::builder()
                                .operation_id(Some(format!(
                                    "{}{}",
                                    action_id.to_case(Case::Camel),
                                    method_id.to_case(Case::UpperCamel)
                                )))
                                .summary(Some(action.erased_openapi_summary()))
                                .description(Some(action.erased_openapi_description()))
                                .tag("auth")
                                .parameters(Some(MethodActionPathParams::into_params(|| {
                                    Some(ParameterIn::Path)
                                })))
                                .response(
                                    "500",
                                    Response::builder().description("Internal server error."),
                                ),
                        )
                        .build(),
                );
            }
        }

        OpenApi::builder().paths(paths.build()).build()
    }
}
#[cfg(test)]
mod tests {
    use super::Shield;
    use crate::{
        options::ShieldOptions,
        storage::tests::{TEST_STORAGE_ID, TestStorage},
    };

    /// `Shield::storage` exposes the backend handed to `Shield::new`, identified by its ID.
    #[test]
    fn test_storage() {
        let backend = TestStorage::default();
        let shield = Shield::new(backend, Vec::new(), ShieldOptions::default());

        let storage_id = shield.storage().id();
        assert_eq!(TEST_STORAGE_ID, storage_id);
    }
}