appguard_rocket/
middleware.rs

1use rocket::fairing::{Fairing, Info, Kind};
2use rocket::http::Status;
3use rocket::{Data, Request, Response};
4
5use appguard_client_authentication::AuthHandler;
6use nullnet_libappguard::{
7    AppGuardFirewall, AppGuardGrpcInterface, AppGuardTcpResponse, FirewallPolicy,
8};
9
10use crate::conversions::{
11    to_appguard_http_request, to_appguard_http_response, to_appguard_tcp_connection,
12};
13
14/// `AppGuard` client configuration.
15pub struct AppGuardConfig {
16    client: AppGuardGrpcInterface,
17    timeout: Option<u64>,
18    default_policy: FirewallPolicy,
19    auth: AuthHandler,
20}
21
22impl AppGuardConfig {
23    /// Create a new configuration for the client.
24    ///
25    /// # Arguments
26    ///
27    /// * `host` - Hostname of the `AppGuard` server.
28    /// * `port` - Port of the `AppGuard` server.
29    /// * `tls` - Whether traffic to the `AppGuard` server should be secured with TLS.
30    /// * `timeout` - Timeout for calls to the `AppGuard` server (milliseconds).
31    /// * `default_policy` - Default firewall policy to apply when the `AppGuard` server times out.
32    /// * `firewall` - Firewall expressions (infix notation).
33    #[must_use]
34    pub async fn new(
35        host: &'static str,
36        port: u16,
37        tls: bool,
38        timeout: Option<u64>,
39        default_policy: FirewallPolicy,
40        firewall: String,
41    ) -> Option<Self> {
42        let mut client = AppGuardGrpcInterface::new(host, port, tls).await.ok()?;
43        let auth = AuthHandler::new(client.clone()).await;
44
45        let token = auth.get_token().await;
46        client
47            .update_firewall(AppGuardFirewall { token, firewall })
48            .await
49            .ok()?;
50
51        Some(AppGuardConfig {
52            client,
53            timeout,
54            default_policy,
55            auth,
56        })
57    }
58}
59
60#[rocket::async_trait]
61impl Fairing for AppGuardConfig {
62    fn info(&self) -> Info {
63        Info {
64            name: "AppGuard",
65            kind: Kind::Request | Kind::Response,
66        }
67    }
68
69    async fn on_request(&self, req: &mut Request<'_>, _data: &mut Data<'_>) {
70        let mut client = self.client.clone();
71        let token = self.auth.get_token().await;
72
73        let AppGuardTcpResponse { tcp_info } = client
74            .handle_tcp_connection(self.timeout, to_appguard_tcp_connection(req, token.clone()))
75            .await
76            .expect("Internal server error");
77
78        req.local_cache(|| tcp_info.clone());
79
80        let request_handler_res = client
81            .handle_http_request(
82                self.timeout,
83                self.default_policy,
84                to_appguard_http_request(req, tcp_info, token),
85            )
86            .await
87            .expect("Internal server error");
88
89        let policy = FirewallPolicy::try_from(request_handler_res.policy).unwrap_or_default();
90        assert_ne!(policy, FirewallPolicy::Deny, "Unauthorized");
91    }
92
93    async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) {
94        let mut client = self.client.clone();
95        let token = self.auth.get_token().await;
96
97        let tcp_info = req.local_cache(|| None);
98
99        let Ok(response_handler_res) = client
100            .handle_http_response(
101                self.timeout,
102                self.default_policy,
103                to_appguard_http_response(resp, tcp_info.to_owned(), token),
104            )
105            .await
106        else {
107            *resp = internal_server_error_response();
108            return;
109        };
110
111        let policy = FirewallPolicy::try_from(response_handler_res.policy).unwrap_or_default();
112        if policy == FirewallPolicy::Deny {
113            *resp = unauthorized_response();
114            return;
115        }
116    }
117}
118
119fn unauthorized_response<'r>() -> Response<'r> {
120    let mut response = Response::new();
121    let body = "Unauthorized";
122    response.set_sized_body(body.len(), std::io::Cursor::new(body));
123    response.set_status(Status::Unauthorized);
124    response
125}
126
127fn internal_server_error_response<'r>() -> Response<'r> {
128    let mut response = Response::new();
129    let body = "Internal server error";
130    response.set_sized_body(body.len(), std::io::Cursor::new(body));
131    response.set_status(Status::InternalServerError);
132    response
133}