appguard_rocket/
middleware.rs1use rocket::fairing::{Fairing, Info, Kind};
2use rocket::http::Status;
3use rocket::{Data, Request, Response};
4
5use crate::conversions::{
6 to_appguard_http_request, to_appguard_http_response, to_appguard_tcp_connection,
7};
8use appguard_client_authentication::Context;
9use nullnet_libappguard::appguard::AppGuardTcpResponse;
10use nullnet_libappguard::appguard_commands::FirewallPolicy;
11
/// Rocket fairing that reports connections, requests and responses to AppGuard
/// and acts on its firewall verdicts (see the `Fairing` impl in this file).
pub struct AppGuardMiddleware {
    /// AppGuard client context created by `AppGuardMiddleware::new`; the fairing
    /// reads the server handle, the token provider and the firewall defaults from it.
    ctx: Context,
}
16
17impl AppGuardMiddleware {
18 #[must_use]
20 pub async fn new() -> Option<Self> {
21 let ctx = Context::new(String::from("Rocket")).await.ok()?;
22
23 Some(AppGuardMiddleware { ctx })
24 }
25}
26
27#[rocket::async_trait]
28impl Fairing for AppGuardMiddleware {
29 fn info(&self) -> Info {
30 Info {
31 name: "AppGuard",
32 kind: Kind::Request | Kind::Response,
33 }
34 }
35
36 async fn on_request(&self, req: &mut Request<'_>, _data: &mut Data<'_>) {
37 let mut server = self.ctx.server.clone();
38 let token = self.ctx.token_provider.get().await.unwrap_or_default();
39 let fw_defaults = *self.ctx.firewall_defaults.lock().await;
40 let timeout = fw_defaults.timeout;
41 let default_policy = FirewallPolicy::try_from(fw_defaults.policy).unwrap_or_default();
42
43 let AppGuardTcpResponse { tcp_info } = server
44 .handle_tcp_connection(timeout, to_appguard_tcp_connection(req, token.clone()))
45 .await
46 .expect("Internal server error");
47
48 req.local_cache(|| tcp_info.clone());
49
50 let request_handler_res = server
51 .handle_http_request(
52 timeout,
53 default_policy,
54 to_appguard_http_request(req, tcp_info, token),
55 )
56 .await
57 .expect("Internal server error");
58
59 let policy = FirewallPolicy::try_from(request_handler_res.policy).unwrap_or_default();
60 assert_ne!(policy, FirewallPolicy::Deny, "Unauthorized");
61 }
62
63 async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) {
64 let mut server = self.ctx.server.clone();
65 let token = self.ctx.token_provider.get().await.unwrap_or_default();
66 let fw_defaults = *self.ctx.firewall_defaults.lock().await;
67 let timeout = fw_defaults.timeout;
68 let default_policy = FirewallPolicy::try_from(fw_defaults.policy).unwrap_or_default();
69
70 let tcp_info = req.local_cache(|| None);
71
72 let Ok(response_handler_res) = server
73 .handle_http_response(
74 timeout,
75 default_policy,
76 to_appguard_http_response(resp, tcp_info.to_owned(), token),
77 )
78 .await
79 else {
80 *resp = internal_server_error_response();
81 return;
82 };
83
84 let policy = FirewallPolicy::try_from(response_handler_res.policy).unwrap_or_default();
85 if policy == FirewallPolicy::Deny {
86 *resp = unauthorized_response();
87 return;
88 }
89 }
90}
91
92fn unauthorized_response<'r>() -> Response<'r> {
93 let mut response = Response::new();
94 let body = "Unauthorized";
95 response.set_sized_body(body.len(), std::io::Cursor::new(body));
96 response.set_status(Status::Unauthorized);
97 response
98}
99
100fn internal_server_error_response<'r>() -> Response<'r> {
101 let mut response = Response::new();
102 let body = "Internal server error";
103 response.set_sized_body(body.len(), std::io::Cursor::new(body));
104 response.set_status(Status::InternalServerError);
105 response
106}