use std::{any::Any, collections::HashMap, sync::Arc};
#[cfg(feature = "utoipa")]
use convert_case::{Case, Casing};
use futures::future::try_join_all;
use tracing::warn;
#[cfg(feature = "utoipa")]
use utoipa::{
IntoParams,
openapi::{
OpenApi, PathItem, Paths,
path::{Operation, ParameterIn},
},
};
#[cfg(feature = "utoipa")]
use crate::path::ActionPathParams;
use crate::{
action::{ActionForms, ActionMethodForm, ActionProviderForm},
error::{ActionError, MethodError, ProviderError, SessionError, ShieldError},
method::ErasedMethod,
options::ShieldOptions,
request::Request,
response::ResponseType,
session::Session,
storage::Storage,
user::User,
};
#[derive(Clone)]
pub struct Shield<U: User> {
storage: Arc<dyn Storage<U>>,
methods: Arc<HashMap<String, Arc<dyn ErasedMethod>>>,
options: ShieldOptions,
}
impl<U: User> Shield<U> {
pub fn new<S>(storage: S, methods: Vec<Arc<dyn ErasedMethod>>, options: ShieldOptions) -> Self
where
S: Storage<U> + 'static,
{
Self {
storage: Arc::new(storage),
methods: Arc::new(
methods
.into_iter()
.map(|method| (method.erased_id(), method))
.collect(),
),
options,
}
}
pub fn storage(&self) -> &dyn Storage<U> {
&*self.storage
}
pub fn options(&self) -> &ShieldOptions {
&self.options
}
pub fn method_by_id(&self, method_id: &str) -> Option<&dyn ErasedMethod> {
self.methods.get(method_id).map(|v| &**v)
}
pub async fn providers(&self) -> Result<Vec<Box<dyn Any + Send + Sync>>, ShieldError> {
try_join_all(
self.methods
.values()
.map(|provider| provider.erased_providers()),
)
.await
.map(|providers| {
providers
.into_iter()
.flat_map(|providers| providers.into_iter().map(|(_, provider)| provider))
.collect::<Vec<_>>()
})
}
pub async fn provider_by_id(
&self,
method_id: &str,
provider_id: Option<&str>,
) -> Result<Option<Box<dyn Any + Send + Sync>>, ShieldError> {
match self.method_by_id(method_id) {
Some(provider) => provider.erased_provider_by_id(provider_id).await,
None => Ok(None),
}
}
pub async fn action_forms(
&self,
action_id: &str,
session: Session,
) -> Result<ActionForms, ShieldError> {
let mut action_name = None::<String>;
let mut method_forms = vec![];
for (method_id, method) in self.methods.iter() {
let Some(action) = method.erased_action_by_id(action_id) else {
continue;
};
let (base_session, method_session) = {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;
(
session_data.base.clone(),
method.erased_deserialize_session(session_data.method_str(method_id))?,
)
};
let name = action.erased_name();
if let Some(action_name) = &action_name
&& *action_name != name
{
warn!("Action name mismatch `{}` != `{}`", action_name, name);
}
action_name = Some(name);
let mut provider_forms = vec![];
for (provider_id, provider) in method.erased_providers().await? {
if !action.erased_condition(&*provider, &base_session, &*method_session)? {
continue;
}
let forms = action.erased_forms(provider).await?;
for form in forms {
provider_forms.push(ActionProviderForm {
id: provider_id.clone(),
form,
});
}
}
method_forms.push(ActionMethodForm {
id: method_id.clone(),
provider_forms,
});
}
Ok(ActionForms {
id: action_id.to_owned(),
name: action_name.unwrap_or(action_id.to_owned()),
method_forms,
})
}
pub async fn call(
&self,
action_id: &str,
method_id: &str,
provider_id: Option<&str>,
session: Session,
request: Request,
) -> Result<ResponseType, ShieldError> {
let method =
self.method_by_id(method_id)
.ok_or(ShieldError::Method(MethodError::NotFound(
method_id.to_owned(),
)))?;
let action = method
.erased_action_by_id(action_id)
.ok_or(ShieldError::Action(ActionError::NotFound(
action_id.to_owned(),
)))?;
let provider =
method
.erased_provider_by_id(provider_id)
.await?
.ok_or(ShieldError::Provider(ProviderError::NotFound(
provider_id.map(ToOwned::to_owned),
)))?;
let (base_session, method_session) = {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;
(
session_data.base.clone(),
method.erased_deserialize_session(session_data.method_str(method_id))?,
)
};
let response = action
.erased_call(provider, &base_session, &*method_session, request)
.await?;
for session_action in &response.session_actions {
session_action
.call(method_id, provider_id, &session)
.await?;
}
Ok(response.r#type)
}
pub async fn user(&self, session: &Session) -> Result<Option<U>, ShieldError> {
let authentication = {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;
session_data.base.authentication.clone()
};
match authentication {
Some(authentication) => {
if self
.provider_by_id(
&authentication.method_id,
authentication.provider_id.as_deref(),
)
.await?
.is_none()
{
session.purge().await?;
return Ok(None);
}
let user = self.storage().user_by_id(&authentication.user_id).await?;
if user.is_none() {
session.purge().await?;
}
Ok(user)
}
None => Ok(None),
}
}
#[cfg(feature = "utoipa")]
pub fn openapi(&self) -> OpenApi {
let mut paths = Paths::builder();
for method in self.methods.values() {
for action in method.erased_actions() {
use utoipa::openapi::Response;
let method_id = method.erased_id();
let action_id = action.erased_id();
paths = paths.path(
format!("/{}/{}/{{providerId}}", method_id, action_id),
PathItem::builder()
.operation(
action.erased_method().into(),
Operation::builder()
.operation_id(Some(format!(
"{}{}",
action_id.to_case(Case::Camel),
method_id.to_case(Case::UpperCamel)
)))
.summary(Some(action.erased_openapi_summary()))
.description(Some(action.erased_openapi_description()))
.tag("auth")
.parameters(Some(ActionPathParams::into_params(|| {
Some(ParameterIn::Path)
})))
.response(
"500",
Response::builder().description("Internal server error."),
),
)
.build(),
);
}
}
OpenApi::builder().paths(paths.build()).build()
}
}
#[cfg(test)]
mod tests {
use crate::{
options::ShieldOptions,
storage::tests::{TEST_STORAGE_ID, TestStorage},
};
use super::Shield;
#[test]
fn test_storage() {
let shield = Shield::new(TestStorage::default(), vec![], ShieldOptions::default());
assert_eq!(TEST_STORAGE_ID, shield.storage().id());
}
}