use std::any::TypeId;
use std::sync::Arc;
use async_trait::async_trait;
use nest_rs_core::Layer;
use nest_rs_graphql::async_graphql::Context as GraphqlContext;
use nest_rs_http::poem::Request as HttpRequest;
use nest_rs_ws::{WsClient, WsMessageCheck};
use serde_json::Value;
use crate::denial::Denial;
/// A [`Layer`] that can allow or deny work arriving over HTTP, GraphQL, or
/// WebSocket messages.
///
/// Each hook returns `Ok(())` to allow and `Err(Denial)` to deny. Every hook
/// has a default body that returns `Ok(())`, so an implementor only needs to
/// override the hooks it wants to enforce.
#[async_trait]
pub trait Guard: Layer {
/// Checks an HTTP request.
///
/// The request is passed mutably. The default body ignores it and returns
/// `Ok(())`.
async fn check_http(&self, _req: &mut HttpRequest) -> Result<(), Denial> {
Ok(())
}
/// Checks a GraphQL operation given its context.
///
/// The default body ignores the context and returns `Ok(())`.
async fn check_graphql(&self, _ctx: &GraphqlContext<'_>) -> Result<(), Denial> {
Ok(())
}
/// Checks a single WebSocket message, described by the client it concerns,
/// an event name, and a JSON payload.
///
/// The default body ignores all three and returns `Ok(())`.
async fn check_ws_message(
&self,
_client: &WsClient,
_event: &str,
_data: &Value,
) -> Result<(), Denial> {
Ok(())
}
}
/// Lets an `Arc`-wrapped guard (including `Arc<dyn Guard>`) be used wherever a
/// [`Guard`] is expected by forwarding every check to the pointee.
///
/// All three hooks are forwarded explicitly; otherwise the trait's default
/// bodies would apply to the `Arc` and bypass the wrapped guard's own checks.
#[async_trait]
impl<T> Guard for Arc<T>
where
    T: Guard + ?Sized,
{
    /// Forwards the HTTP check to the wrapped guard.
    async fn check_http(&self, req: &mut HttpRequest) -> Result<(), Denial> {
        T::check_http(&**self, req).await
    }

    /// Forwards the GraphQL check to the wrapped guard.
    async fn check_graphql(&self, ctx: &GraphqlContext<'_>) -> Result<(), Denial> {
        T::check_graphql(&**self, ctx).await
    }

    /// Forwards the WebSocket message check to the wrapped guard.
    async fn check_ws_message(
        &self,
        client: &WsClient,
        event: &str,
        data: &Value,
    ) -> Result<(), Denial> {
        T::check_ws_message(&**self, client, event, data).await
    }
}
/// Adapter that exposes a shared [`Guard`] through the [`WsMessageCheck`]
/// interface.
pub struct GuardAsWsMessageCheck {
// The guard whose `check_ws_message` hook performs the actual check.
inner: Arc<dyn Guard>,
// Returned verbatim from `type_key`.
// NOTE(review): presumably the `TypeId` of the concrete guard type — confirm with callers.
type_id: TypeId,
// Returned verbatim from `layer_name`.
// NOTE(review): presumably a human-readable guard name — confirm with callers.
name: &'static str,
}
impl GuardAsWsMessageCheck {
    /// Builds an adapter around `inner`.
    ///
    /// `type_id` and `name` are stored unchanged and later reported by the
    /// adapter's `type_key` and `layer_name` methods respectively.
    pub fn new(inner: Arc<dyn Guard>, type_id: TypeId, name: &'static str) -> Self {
        Self { inner, type_id, name }
    }
}
#[async_trait]
impl WsMessageCheck for GuardAsWsMessageCheck {
    /// Runs the wrapped guard's WebSocket message check.
    ///
    /// A denial is reduced to an owned copy of its message text, since this
    /// interface reports failures as a plain `String`.
    async fn check(
        &self,
        client: &WsClient,
        event: &str,
        data: &Value,
    ) -> std::result::Result<(), String> {
        self.inner
            .check_ws_message(client, event, data)
            .await
            .map_err(|denial| denial.message().to_owned())
    }

    /// Returns the `TypeId` supplied when the adapter was constructed.
    fn type_key(&self) -> TypeId {
        self.type_id
    }

    /// Returns the name supplied when the adapter was constructed.
    fn layer_name(&self) -> &'static str {
        self.name
    }
}