use std::any::Any;
use crate::{
error::ShieldError,
form::Form,
provider::Provider,
request::Request,
response::Response,
session::{BaseSession, MethodSession},
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[cfg(feature = "utoipa")]
use utoipa::openapi::HttpMethod;
/// HTTP verb an action is exposed under.
///
/// Variants are declared in a fixed order; the derived `PartialOrd`/`Ord`
/// compare by that declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ActionMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `OPTIONS`
    Options,
    /// `HEAD`
    Head,
    /// `PATCH`
    Patch,
    /// `TRACE`
    Trace,
}
#[cfg(feature = "utoipa")]
impl From<ActionMethod> for HttpMethod {
    /// Maps each `ActionMethod` variant onto the identically named
    /// utoipa `HttpMethod` variant (one-to-one, no fallback case).
    fn from(method: ActionMethod) -> Self {
        match method {
            ActionMethod::Get => HttpMethod::Get,
            ActionMethod::Post => HttpMethod::Post,
            ActionMethod::Put => HttpMethod::Put,
            ActionMethod::Delete => HttpMethod::Delete,
            ActionMethod::Options => HttpMethod::Options,
            ActionMethod::Head => HttpMethod::Head,
            ActionMethod::Patch => HttpMethod::Patch,
            ActionMethod::Trace => HttpMethod::Trace,
        }
    }
}
/// The forms of a single action, grouped into per-method entries.
///
/// Serialized with camelCase field names (`id`, `name`, `methodForms`).
/// Derives `PartialEq` for consistency with `ActionMethodForm` and
/// `ActionProviderForm`, whose fields it is built from.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ActionForms {
    /// Identifier of the action.
    pub id: String,
    /// Name of the action.
    pub name: String,
    /// Form groups, one `ActionMethodForm` per entry.
    pub method_forms: Vec<ActionMethodForm>,
}
/// A group of provider forms identified by `id`.
///
/// Serialized with camelCase field names (`id`, `providerForms`).
/// NOTE(review): `id` presumably names a method; confirm against callers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ActionMethodForm {
    /// Identifier of this group.
    pub id: String,
    /// The provider-level forms in this group.
    pub provider_forms: Vec<ActionProviderForm>,
}
/// A single `Form` paired with an optional provider identifier.
///
/// Serialized with camelCase field names (`id`, `form`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ActionProviderForm {
    /// Optional identifier; the meaning of `None` is defined by the
    /// producer of this value (TODO confirm).
    pub id: Option<String>,
    /// The form itself.
    pub form: Form,
}
/// A typed action for provider type `P` operating on method-session data `S`.
///
/// `ErasedAction` is a supertrait; the `erased_action!` macro in this module
/// generates that impl by forwarding each erased method to the typed method
/// declared here.
#[async_trait]
pub trait Action<P: Provider, S>: ErasedAction + Send + Sync {
    /// Identifier of this action.
    fn id(&self) -> String;

    /// Name of this action.
    fn name(&self) -> String;

    /// Static summary text (presumably for the generated OpenAPI operation).
    fn openapi_summary(&self) -> &'static str;

    /// Static description text (presumably for the generated OpenAPI operation).
    fn openapi_description(&self) -> &'static str;

    /// HTTP verb this action is associated with.
    fn method(&self) -> ActionMethod;

    /// Evaluates this action's condition for `provider` and the current session.
    ///
    /// The default implementation ignores both arguments and returns `Ok(true)`.
    fn condition(
        &self,
        _provider: &P,
        _session: &MethodSession<S>,
    ) -> Result<bool, ShieldError> {
        Ok(true)
    }

    /// Produces this action's forms, taking ownership of `provider`.
    async fn forms(&self, provider: P) -> Result<Vec<Form>, ShieldError>;

    /// Handles `request` for this action with an owned `provider` and the
    /// borrowed method session, producing a `Response`.
    async fn call(
        &self,
        provider: P,
        session: &MethodSession<S>,
        request: Request,
    ) -> Result<Response, ShieldError>;
}
/// Type-erased counterpart of `Action` with no generic type parameters.
///
/// Providers and method sessions are passed as `dyn Any + Send + Sync`
/// (presumably so actions with different `P`/`S` can be stored and invoked
/// uniformly). Implementations generated by `erased_action!` downcast these
/// back to the concrete types and panic if the types do not match.
#[async_trait]
pub trait ErasedAction: Send + Sync {
    /// Erased form of `Action::id`.
    fn erased_id(&self) -> String;

    /// Erased form of `Action::name`.
    fn erased_name(&self) -> String;

    /// Erased form of `Action::openapi_summary`.
    fn erased_openapi_summary(&self) -> &'static str;

    /// Erased form of `Action::openapi_description`.
    fn erased_openapi_description(&self) -> &'static str;

    /// Erased form of `Action::method`.
    fn erased_method(&self) -> ActionMethod;

    /// Erased form of `Action::condition`; `provider` and `method_session`
    /// must hold the concrete provider and method-session types.
    fn erased_condition(
        &self,
        provider: &(dyn Any + Send + Sync),
        base_session: &BaseSession,
        method_session: &(dyn Any + Send + Sync),
    ) -> Result<bool, ShieldError>;

    /// Erased form of `Action::forms`; `provider` is a boxed concrete provider.
    async fn erased_forms(
        &self,
        provider: Box<dyn Any + Send + Sync>,
    ) -> Result<Vec<Form>, ShieldError>;

    /// Erased form of `Action::call`; `provider` is a boxed concrete provider
    /// and `method_session` holds the concrete method-session data.
    async fn erased_call(
        &self,
        provider: Box<dyn Any + Send + Sync>,
        base_session: &BaseSession,
        method_session: &(dyn Any + Send + Sync),
        request: Request,
    ) -> Result<Response, ShieldError>;
}
/// Implements `ErasedAction` for a type that already implements `Action`,
/// forwarding every erased method to its typed counterpart.
///
/// Usage:
///
/// ```ignore
/// erased_action!(MyAction);
/// erased_action!(MyGenericAction, <D: Storage>);
/// ```
///
/// Each listed generic parameter additionally receives a `'static` bound,
/// which `Any` downcasting requires.
///
/// The generated methods downcast the `dyn Any` provider and method session
/// to the concrete types expected by the type's `Action` impl and panic
/// (via `expect`) if they do not match.
///
/// Crate items are referenced through `$crate`, but the expansion still uses
/// the `#[async_trait]` attribute and method-call syntax on `Action` methods,
/// so `async_trait` and the `Action` trait must be in scope where this macro
/// is invoked.
#[macro_export]
macro_rules! erased_action {
    // At most one generic-parameter group is accepted: two or more groups
    // would expand to invalid `impl<..> <..>` syntax.
    ($action:ident $(, < $( $generic_name:ident : $generic_type:ident ),+ > )?) => {
        #[async_trait]
        impl $( < $( $generic_name: $generic_type + 'static ),+ > )? $crate::ErasedAction
            for $action $( < $( $generic_name ),+ > )?
        {
            fn erased_id(&self) -> String {
                self.id()
            }

            fn erased_name(&self) -> String {
                self.name()
            }

            fn erased_openapi_summary(&self) -> &'static str {
                self.openapi_summary()
            }

            fn erased_openapi_description(&self) -> &'static str {
                self.openapi_description()
            }

            fn erased_method(&self) -> $crate::ActionMethod {
                self.method()
            }

            fn erased_condition(
                &self,
                provider: &(dyn ::std::any::Any + Send + Sync),
                base_session: &$crate::BaseSession,
                method_session: &(dyn ::std::any::Any + Send + Sync),
            ) -> Result<bool, $crate::ShieldError> {
                self.condition(
                    provider.downcast_ref().expect("Provider should be downcast"),
                    // Fully qualified (like in `erased_call`) so the invoking
                    // module does not need `MethodSession` imported.
                    &$crate::MethodSession {
                        base: base_session,
                        method: method_session
                            .downcast_ref()
                            .expect("Session should be downcast"),
                    },
                )
            }

            async fn erased_forms(
                &self,
                provider: Box<dyn ::std::any::Any + Send + Sync>,
            ) -> Result<Vec<$crate::Form>, $crate::ShieldError> {
                self.forms(*provider.downcast().expect("Provider should be downcast"))
                    .await
            }

            async fn erased_call(
                &self,
                provider: Box<dyn ::std::any::Any + Send + Sync>,
                base_session: &$crate::BaseSession,
                method_session: &(dyn ::std::any::Any + Send + Sync),
                request: $crate::Request,
            ) -> Result<$crate::Response, $crate::ShieldError> {
                self.call(
                    *provider.downcast().expect("Provider should be downcast"),
                    &$crate::MethodSession {
                        base: base_session,
                        method: method_session
                            .downcast_ref()
                            .expect("Session should be downcast"),
                    },
                    request,
                )
                .await
            }
        }
    };
}