Skip to main content

nest_rs_guards/
guard.rs

1//! The unified [`Guard`] trait — extends [`Layer`] so guards plug into the
2//! Layer System (dedup-by-`TypeId`, declaration-order chain).
3
4use std::any::TypeId;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use nest_rs_core::Layer;
9use nest_rs_graphql::async_graphql::Context as GraphqlContext;
10use nest_rs_http::poem::Request as HttpRequest;
11use nest_rs_ws::{WsClient, WsMessageCheck};
12use serde_json::Value;
13
14use crate::denial::Denial;
15
16/// A transport-spanning guard.
17///
18/// One impl, three transports. Override only the `check_*` method(s) where
19/// this guard has work to do; the rest inherit `Ok(())` defaults — a
20/// no-op means "doesn't apply to this transport," not "skip security."
21///
22/// `Guard` extends [`Layer`] (priority + name + dedup-by-TypeId). The
23/// `#[public]` marker is NOT a framework skip: it attaches the
24/// [`Public`](nest_rs_core::Public) data to the request and each guard
25/// decides whether to honor it. An `AbilityGuard` may want to apply
26/// visitor rules on public routes; an `AuthGuard` may want to skip
27/// rejection when no token is present. Both are policy decisions the
28/// guard owns, not the framework.
29///
30/// See the crate-level docs for copy-paste templates.
31#[async_trait]
32pub trait Guard: Layer {
33    /// HTTP request entry. Default = no-op (this guard doesn't apply to HTTP).
34    async fn check_http(&self, _req: &mut HttpRequest) -> Result<(), Denial> {
35        Ok(())
36    }
37
38    /// GraphQL resolver entry. Default = no-op.
39    async fn check_graphql(&self, _ctx: &GraphqlContext<'_>) -> Result<(), Denial> {
40        Ok(())
41    }
42
43    /// WebSocket per-message entry. Default = no-op.
44    async fn check_ws_message(
45        &self,
46        _client: &WsClient,
47        _event: &str,
48        _data: &Value,
49    ) -> Result<(), Denial> {
50        Ok(())
51    }
52}
53
54#[async_trait]
55impl<T: Guard + ?Sized> Guard for Arc<T> {
56    async fn check_http(&self, req: &mut HttpRequest) -> Result<(), Denial> {
57        (**self).check_http(req).await
58    }
59
60    async fn check_graphql(&self, ctx: &GraphqlContext<'_>) -> Result<(), Denial> {
61        (**self).check_graphql(ctx).await
62    }
63
64    async fn check_ws_message(
65        &self,
66        client: &WsClient,
67        event: &str,
68        data: &Value,
69    ) -> Result<(), Denial> {
70        (**self).check_ws_message(client, event, data).await
71    }
72}
73
74/// Newtype adapter that lets any [`Guard`] satisfy the
75/// [`WsMessageCheck`](nest_rs_ws::WsMessageCheck) interface — the bridge the
76/// `#[messages]` macro uses to put guards in the per-event chain table
77/// without nest-rs-ws depending on nest-rs-guards.
78pub struct GuardAsWsMessageCheck {
79    inner: Arc<dyn Guard>,
80    type_id: TypeId,
81    name: &'static str,
82}
83
84impl GuardAsWsMessageCheck {
85    pub fn new(inner: Arc<dyn Guard>, type_id: TypeId, name: &'static str) -> Self {
86        Self {
87            inner,
88            type_id,
89            name,
90        }
91    }
92}
93
94#[async_trait]
95impl WsMessageCheck for GuardAsWsMessageCheck {
96    async fn check(
97        &self,
98        client: &WsClient,
99        event: &str,
100        data: &Value,
101    ) -> std::result::Result<(), String> {
102        match self.inner.check_ws_message(client, event, data).await {
103            Ok(()) => Ok(()),
104            Err(denial) => Err(denial.message().to_owned()),
105        }
106    }
107
108    fn type_key(&self) -> TypeId {
109        self.type_id
110    }
111
112    fn layer_name(&self) -> &'static str {
113        self.name
114    }
115}