appguard_rocket/
middleware.rs

1use rocket::fairing::{Fairing, Info, Kind};
2use rocket::http::Status;
3use rocket::{Data, Request, Response};
4
5use crate::conversions::{
6    to_appguard_http_request, to_appguard_http_response, to_appguard_tcp_connection,
7};
8use appguard_client_authentication::Context;
9use nullnet_libappguard::appguard::AppGuardTcpResponse;
10use nullnet_libappguard::appguard_commands::FirewallPolicy;
11
12/// `AppGuard` middleware.
13pub struct AppGuardMiddleware {
14    ctx: Context,
15}
16
17impl AppGuardMiddleware {
18    /// Create a new `AppGuard` middleware instance.
19    #[must_use]
20    pub async fn new() -> Option<Self> {
21        let ctx = Context::new(String::from("Rocket")).await.ok()?;
22
23        Some(AppGuardMiddleware { ctx })
24    }
25}
26
27#[rocket::async_trait]
28impl Fairing for AppGuardMiddleware {
29    fn info(&self) -> Info {
30        Info {
31            name: "AppGuard",
32            kind: Kind::Request | Kind::Response,
33        }
34    }
35
36    async fn on_request(&self, req: &mut Request<'_>, _data: &mut Data<'_>) {
37        let mut server = self.ctx.server.clone();
38        let token = self.ctx.token_provider.get().await.unwrap_or_default();
39        let fw_defaults = *self.ctx.firewall_defaults.lock().await;
40        let timeout = fw_defaults.timeout;
41        let default_policy = FirewallPolicy::try_from(fw_defaults.policy).unwrap_or_default();
42
43        let AppGuardTcpResponse { tcp_info } = server
44            .handle_tcp_connection(timeout, to_appguard_tcp_connection(req, token.clone()))
45            .await
46            .expect("Internal server error");
47
48        req.local_cache(|| tcp_info.clone());
49
50        let request_handler_res = server
51            .handle_http_request(
52                timeout,
53                default_policy,
54                to_appguard_http_request(req, tcp_info, token),
55            )
56            .await
57            .expect("Internal server error");
58
59        let policy = FirewallPolicy::try_from(request_handler_res.policy).unwrap_or_default();
60        assert_ne!(policy, FirewallPolicy::Deny, "Unauthorized");
61    }
62
63    async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) {
64        let mut server = self.ctx.server.clone();
65        let token = self.ctx.token_provider.get().await.unwrap_or_default();
66        let fw_defaults = *self.ctx.firewall_defaults.lock().await;
67        let timeout = fw_defaults.timeout;
68        let default_policy = FirewallPolicy::try_from(fw_defaults.policy).unwrap_or_default();
69
70        let tcp_info = req.local_cache(|| None);
71
72        let Ok(response_handler_res) = server
73            .handle_http_response(
74                timeout,
75                default_policy,
76                to_appguard_http_response(resp, tcp_info.to_owned(), token),
77            )
78            .await
79        else {
80            *resp = internal_server_error_response();
81            return;
82        };
83
84        let policy = FirewallPolicy::try_from(response_handler_res.policy).unwrap_or_default();
85        if policy == FirewallPolicy::Deny {
86            *resp = unauthorized_response();
87            return;
88        }
89    }
90}
91
92fn unauthorized_response<'r>() -> Response<'r> {
93    let mut response = Response::new();
94    let body = "Unauthorized";
95    response.set_sized_body(body.len(), std::io::Cursor::new(body));
96    response.set_status(Status::Unauthorized);
97    response
98}
99
100fn internal_server_error_response<'r>() -> Response<'r> {
101    let mut response = Response::new();
102    let body = "Internal server error";
103    response.set_sized_body(body.len(), std::io::Cursor::new(body));
104    response.set_status(Status::InternalServerError);
105    response
106}