1use std::any::TypeId;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use nest_rs_core::Layer;
9use nest_rs_graphql::async_graphql::Context as GraphqlContext;
10use nest_rs_http::poem::Request as HttpRequest;
11use nest_rs_ws::{WsClient, WsMessageCheck};
12use serde_json::Value;
13
14use crate::denial::Denial;
15
16#[async_trait]
32pub trait Guard: Layer {
33 async fn check_http(&self, _req: &mut HttpRequest) -> Result<(), Denial> {
35 Ok(())
36 }
37
38 async fn check_graphql(&self, _ctx: &GraphqlContext<'_>) -> Result<(), Denial> {
40 Ok(())
41 }
42
43 async fn check_ws_message(
45 &self,
46 _client: &WsClient,
47 _event: &str,
48 _data: &Value,
49 ) -> Result<(), Denial> {
50 Ok(())
51 }
52}
53
54#[async_trait]
55impl<T: Guard + ?Sized> Guard for Arc<T> {
56 async fn check_http(&self, req: &mut HttpRequest) -> Result<(), Denial> {
57 (**self).check_http(req).await
58 }
59
60 async fn check_graphql(&self, ctx: &GraphqlContext<'_>) -> Result<(), Denial> {
61 (**self).check_graphql(ctx).await
62 }
63
64 async fn check_ws_message(
65 &self,
66 client: &WsClient,
67 event: &str,
68 data: &Value,
69 ) -> Result<(), Denial> {
70 (**self).check_ws_message(client, event, data).await
71 }
72}
73
74pub struct GuardAsWsMessageCheck {
79 inner: Arc<dyn Guard>,
80 type_id: TypeId,
81 name: &'static str,
82}
83
84impl GuardAsWsMessageCheck {
85 pub fn new(inner: Arc<dyn Guard>, type_id: TypeId, name: &'static str) -> Self {
86 Self {
87 inner,
88 type_id,
89 name,
90 }
91 }
92}
93
94#[async_trait]
95impl WsMessageCheck for GuardAsWsMessageCheck {
96 async fn check(
97 &self,
98 client: &WsClient,
99 event: &str,
100 data: &Value,
101 ) -> std::result::Result<(), String> {
102 match self.inner.check_ws_message(client, event, data).await {
103 Ok(()) => Ok(()),
104 Err(denial) => Err(denial.message().to_owned()),
105 }
106 }
107
108 fn type_key(&self) -> TypeId {
109 self.type_id
110 }
111
112 fn layer_name(&self) -> &'static str {
113 self.name
114 }
115}