use std::any::Any;
use crate::{
error::ShieldError,
form::Form,
provider::Provider,
request::{Request, RequestMethod},
response::Response,
session::{BaseSession, MethodSession},
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// An action that is called with a [`BaseSession`] and a [`Request`] and
/// produces a [`Response`].
///
/// Implementors must be `Send + Sync`; the `async` methods are provided
/// through `#[async_trait]`.
#[async_trait]
pub trait Action: Send + Sync {
/// Returns the identifier of this action.
fn id(&self) -> &'static str;
/// Returns the name of this action.
fn name(&self) -> &'static str;
/// Returns the summary text of this action.
// NOTE(review): presumably emitted as the OpenAPI operation summary — confirm at the consumer.
fn openapi_summary(&self) -> &'static str;
/// Returns the description text of this action.
// NOTE(review): presumably emitted as the OpenAPI operation description — confirm at the consumer.
fn openapi_description(&self) -> &'static str;
/// Returns the [`RequestMethod`] associated with this action.
fn method(&self) -> RequestMethod;
/// Asynchronously produces the forms of this action.
///
/// # Errors
///
/// Returns a [`ShieldError`] when the forms cannot be produced.
async fn forms(&self) -> Result<Vec<Form>, ShieldError>;
/// Asynchronously runs the action for `request` with access to `session`.
///
/// The request is taken by value; the session is only borrowed.
///
/// # Errors
///
/// Returns a [`ShieldError`] when the action fails.
async fn call(&self, session: &BaseSession, request: Request) -> Result<Response, ShieldError>;
}
/// Serializable set of forms labelled with an identifier and a name.
///
/// Fields are (de)serialized with camelCase keys, so `method_forms` appears
/// as `methodForms`.
///
/// `PartialEq` is derived so values can be compared directly, matching
/// [`ActionMethodForm`] and [`ActionProviderForm`].
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ActionForms {
    /// Identifier.
    pub id: String,
    /// Name.
    pub name: String,
    /// Forms carried directly by this value.
    pub forms: Vec<Form>,
    /// Form groups, one [`ActionMethodForm`] per entry.
    pub method_forms: Vec<ActionMethodForm>,
}
/// Serializable group of provider forms labelled with an identifier.
///
/// Fields are (de)serialized with camelCase keys, so `provider_forms` appears
/// as `providerForms`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ActionMethodForm {
/// Identifier.
// NOTE(review): presumably the id of the method these forms belong to — confirm where values are built.
pub id: String,
/// Forms, one [`ActionProviderForm`] per entry.
pub provider_forms: Vec<ActionProviderForm>,
}
/// Serializable pairing of an optional identifier with a single [`Form`].
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "camelCase")]
pub struct ActionProviderForm {
/// Optional identifier.
// NOTE(review): presumably the provider id, absent for providers without one — confirm where values are built.
pub id: Option<String>,
/// The form.
pub form: Form,
}
/// An action that operates on a provider of type `P` and a
/// [`MethodSession<S>`](MethodSession).
///
/// Every implementor must also implement [`ErasedMethodAction`] (it is a
/// supertrait) and be `Send + Sync`. The `erased_method_action!` macro in this
/// module generates that supertrait implementation by forwarding to the
/// methods declared here.
#[async_trait]
pub trait MethodAction<P: Provider, S>: ErasedMethodAction + Send + Sync {
/// Returns the identifier of this action as an owned string.
fn id(&self) -> String;
/// Returns the name of this action as an owned string.
fn name(&self) -> String;
/// Returns the summary text of this action.
// NOTE(review): presumably emitted as the OpenAPI operation summary — confirm at the consumer.
fn openapi_summary(&self) -> &'static str;
/// Returns the description text of this action.
// NOTE(review): presumably emitted as the OpenAPI operation description — confirm at the consumer.
fn openapi_description(&self) -> &'static str;
/// Returns the [`RequestMethod`] associated with this action.
fn method(&self) -> RequestMethod;
/// Evaluates a condition for the given provider and session.
///
/// The default implementation ignores both arguments and returns
/// `Ok(true)`.
///
/// # Errors
///
/// Implementations may return a [`ShieldError`] when the condition cannot
/// be evaluated.
// NOTE(review): presumably `false` hides or disables the action — confirm with the caller.
fn condition(&self, _provider: &P, _session: &MethodSession<S>) -> Result<bool, ShieldError> {
Ok(true)
}
/// Asynchronously produces the forms of this action for `provider`.
///
/// The provider is taken by value.
///
/// # Errors
///
/// Returns a [`ShieldError`] when the forms cannot be produced.
async fn forms(&self, provider: P) -> Result<Vec<Form>, ShieldError>;
/// Asynchronously runs the action.
///
/// The provider and the request are taken by value; the session is only
/// borrowed.
///
/// # Errors
///
/// Returns a [`ShieldError`] when the action fails.
async fn call(
&self,
provider: P,
session: &MethodSession<S>,
request: Request,
) -> Result<Response, ShieldError>;
}
/// Counterpart of [`MethodAction`] without generic parameters: the provider
/// and the method-specific session data are passed as `dyn Any + Send + Sync`.
///
/// The `erased_method_action!` macro in this module generates an
/// implementation that downcasts those values to their concrete types and
/// forwards each `erased_*` method to the method of the same name without the
/// prefix. That generated implementation panics when a downcast fails.
#[async_trait]
pub trait ErasedMethodAction: Send + Sync {
/// Returns the identifier of this action as an owned string.
fn erased_id(&self) -> String;
/// Returns the name of this action as an owned string.
fn erased_name(&self) -> String;
/// Returns the summary text of this action.
fn erased_openapi_summary(&self) -> &'static str;
/// Returns the description text of this action.
fn erased_openapi_description(&self) -> &'static str;
/// Returns the [`RequestMethod`] associated with this action.
fn erased_method(&self) -> RequestMethod;
/// Evaluates the action's condition.
///
/// `provider` and `method_session` are borrowed type-erased values;
/// `base_session` is the shared [`BaseSession`].
///
/// # Errors
///
/// Returns a [`ShieldError`] when the condition cannot be evaluated.
fn erased_condition(
&self,
provider: &(dyn Any + Send + Sync),
base_session: &BaseSession,
method_session: &(dyn Any + Send + Sync),
) -> Result<bool, ShieldError>;
/// Asynchronously produces the forms of this action for a boxed,
/// type-erased provider, which is taken by value.
///
/// # Errors
///
/// Returns a [`ShieldError`] when the forms cannot be produced.
async fn erased_forms(
&self,
provider: Box<dyn Any + Send + Sync>,
) -> Result<Vec<Form>, ShieldError>;
/// Asynchronously runs the action.
///
/// The boxed provider and the request are taken by value; both session
/// parts are only borrowed.
///
/// # Errors
///
/// Returns a [`ShieldError`] when the action fails.
async fn erased_call(
&self,
provider: Box<dyn Any + Send + Sync>,
base_session: &BaseSession,
method_session: &(dyn Any + Send + Sync),
request: Request,
) -> Result<Response, ShieldError>;
}
/// Implements `ErasedMethodAction` for an action type by forwarding every
/// `erased_*` method to the method of the same name without the prefix.
///
/// # Usage
///
/// ```ignore
/// erased_method_action!(MyAction);
/// erased_method_action!(MyGenericAction, <P: MyProviderBound>);
/// ```
///
/// Each listed generic parameter takes exactly one bound, written as a single
/// identifier; a `'static` bound is added to it in the generated `impl`.
///
/// # Requirements at the call site
///
/// * The `async_trait` attribute must be in scope, because the expansion uses
///   it unqualified.
/// * `id`, `name`, `openapi_summary`, `openapi_description`, `method`,
///   `condition`, `forms` and `call` must resolve on `self`.
///
/// All other crate items are referenced through `$crate`, so they need no
/// imports at the call site.
///
/// # Panics
///
/// The generated methods panic when the type-erased provider or method
/// session is not of the concrete type the forwarded method expects.
#[macro_export]
macro_rules! erased_method_action {
    ($action:ident $(, < $( $generic_name:ident : $generic_type:ident ),+ > )*) => {
        #[async_trait]
        impl $( < $( $generic_name: $generic_type + 'static ),+ > )* $crate::ErasedMethodAction for $action $( < $( $generic_name ),+ > )* {
            fn erased_id(&self) -> String {
                self.id()
            }

            fn erased_name(&self) -> String {
                self.name()
            }

            fn erased_openapi_summary(&self) -> &'static str {
                self.openapi_summary()
            }

            fn erased_openapi_description(&self) -> &'static str {
                self.openapi_description()
            }

            fn erased_method(&self) -> $crate::RequestMethod {
                self.method()
            }

            fn erased_condition(
                &self,
                provider: &(dyn std::any::Any + Send + Sync),
                base_session: &$crate::BaseSession,
                method_session: &(dyn std::any::Any + Send + Sync)
            ) -> Result<bool, $crate::ShieldError> {
                self.condition(
                    provider.downcast_ref().expect("Provider should be downcast"),
                    // Qualified with `$crate::` so the expansion compiles even when
                    // the call site has not imported `MethodSession`.
                    &$crate::MethodSession {
                        base: base_session,
                        method: method_session.downcast_ref().expect("Session should be downcast"),
                    },
                )
            }

            async fn erased_forms(
                &self,
                provider: Box<dyn std::any::Any + Send + Sync>
            ) -> Result<Vec<$crate::Form>, $crate::ShieldError> {
                // Unbox the provider after recovering its concrete type.
                self.forms(*provider.downcast().expect("Provider should be downcast")).await
            }

            async fn erased_call(
                &self,
                provider: Box<dyn std::any::Any + Send + Sync>,
                base_session: &$crate::BaseSession,
                method_session: &(dyn std::any::Any + Send + Sync),
                request: $crate::Request,
            ) -> Result<$crate::Response, $crate::ShieldError> {
                self
                    .call(
                        *provider.downcast().expect("Provider should be downcast"),
                        &$crate::MethodSession {
                            base: base_session,
                            method: method_session.downcast_ref().expect("Session should be downcast"),
                        },
                        request
                    )
                    .await
            }
        }
    };
}