Skip to main content

alterion_encrypt/
interceptor.rs

1// SPDX-License-Identifier: GPL-3.0
2use actix_web::{
3    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
4    web, Error, HttpMessage,
5    body::{BoxBody, EitherBody, MessageBody},
6};
7use futures_util::future::{ready, LocalBoxFuture, Ready};
8use futures_util::TryStreamExt;
9use std::{rc::Rc, sync::Arc};
10use tokio::sync::RwLock;
11use alterion_ecdh::{KeyStore, HandshakeStore, ecdh, ecdh_ephemeral};
12use redis::aio::ConnectionManager;
13use crate::tools::crypt::aes_decrypt;
14use crate::tools::serializer::{deserialize_packet, build_signed_response_raw, derive_wrap_key};
15use zeroize::ZeroizeOnDrop;
16
/// Raw decrypted request body, injected into Actix request extensions by [`Interceptor`] after a
/// packet is successfully validated and decrypted.
///
/// Retrieve it inside a handler with:
/// ```rust,ignore
/// let body = req.extensions().get::<DecryptedBody>().cloned();
/// ```
/// The lookup yields an `Option<DecryptedBody>`. It is `None` whenever the interceptor did not
/// decrypt the request: GET / HEAD / OPTIONS requests, empty bodies, and bodies that did not
/// parse as an encrypted packet (those are passed through to the handler untouched).
///
/// When present, field `0` contains the original plaintext bytes as sent by the client
/// (post-AES-GCM decrypt, before any application-level deserialisation). The bytes are in the
/// same format the client packed them: msgpack-encoded `ByteBuf` wrapping deflate-compressed JSON.
/// Use [`crate::tools::serializer::decode_request_payload`] to complete the decode.
#[derive(Clone)]
pub struct DecryptedBody(pub Vec<u8>);
30
/// Per-request AES-256 session key, injected alongside [`DecryptedBody`].
///
/// The interceptor stores this so the **response** can be encrypted with the exact same key that
/// the client generated for this request. The client holds the key in memory indexed by request
/// ID and passes it to [`crate::tools::serializer::decode_response_packet`] to decrypt the reply.
///
/// Zeroized on drop — the key material is cleared from memory when this struct is dropped.
/// Note that the interceptor clones the value out of the request extensions before calling the
/// inner service, so two copies exist for the duration of a request; each copy is zeroized
/// independently when it is dropped.
#[derive(Clone, ZeroizeOnDrop)]
pub struct RequestSessionKeys {
    /// The 32-byte AES-256 key unwrapped from the request packet. It decrypts the request
    /// payload and is reused to encrypt the response body.
    pub enc_key: [u8; 32],
}
43
/// Actix-web middleware that transparently decrypts incoming request bodies and encrypts outgoing
/// response bodies using the X25519 ECDH + AES-256-GCM + HMAC-SHA256 pipeline.
///
/// # Usage
/// ```rust,no_run
/// use alterion_encrypt::interceptor::Interceptor;
/// use alterion_encrypt::{init_key_store, init_handshake_store, start_rotation};
///
/// let store = init_key_store(3600);
/// let hs    = init_handshake_store();
/// start_rotation(store.clone(), 3600, hs.clone());
/// // App::new().wrap(Interceptor { key_store: store, handshake_store: hs, replay_store: None })
/// ```
///
/// **Request path** (every method except GET / HEAD / OPTIONS, and only when the body is
/// non-empty):
/// 1. Collect raw body bytes.
/// 2. MessagePack-decode a [`Request`](crate::tools::serializer::Request) and validate timestamp.
/// 3. Perform X25519 ECDH using the server key identified by `key_id` and the client's ephemeral
///    public key from the packet. A `key_id` starting with `hs_` is resolved against the
///    handshake store; any other `key_id` is resolved against the key store.
/// 4. Derive a wrap key via HKDF-SHA256 and use it to AES-GCM-unwrap the client's `enc_key`.
/// 5. If a replay store is configured, record the packet's wrapped key in Redis (`SET … NX EX 60`)
///    and reject the request with 401 when it has already been seen.
/// 6. AES-256-GCM-decrypt the payload using `enc_key`.
/// 7. Inject `DecryptedBody` and `RequestSessionKeys` into request extensions.
///
/// Requests whose body is not a valid `Request` are passed through unchanged.
///
/// **Response path** (only when `RequestSessionKeys` was set):
/// JSON → deflate compress → msgpack → AES-256-GCM (`enc_key`) → HMAC-SHA256 (mac key derived
/// from `enc_key`) → [`Response`](crate::tools::serializer::Response) → msgpack.
pub struct Interceptor {
    /// Shared store of server keys, used for packets whose `key_id` does not start with `hs_`.
    pub key_store:       Arc<RwLock<KeyStore>>,
    /// Store of handshake keys, used for packets whose `key_id` starts with `hs_`.
    pub handshake_store: HandshakeStore,
    /// Optional Redis connection used for replay detection; `None` disables the replay check.
    pub replay_store:    Option<ConnectionManager>,
}
77
78impl<S, B> Transform<S, ServiceRequest> for Interceptor
79where
80    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
81    B: MessageBody + 'static,
82{
83    type Response  = ServiceResponse<EitherBody<B>>;
84    type Error     = Error;
85    type Transform = InterceptorService<S>;
86    type InitError = ();
87    type Future    = Ready<Result<Self::Transform, Self::InitError>>;
88
89    fn new_transform(&self, service: S) -> Self::Future {
90        ready(Ok(InterceptorService {
91            service:         Rc::new(service),
92            key_store:       self.key_store.clone(),
93            handshake_store: self.handshake_store.clone(),
94            replay_store:    self.replay_store.clone(),
95        }))
96    }
97}
98
/// The concrete [`Service`](actix_web::dev::Service) produced by [`Interceptor::new_transform`].
///
/// One instance is created per worker thread. Holds an `Rc` to the inner service, an `Arc` to
/// the shared key store, and its own clones of the handshake store and the optional replay
/// connection. Not constructed directly — Actix creates it automatically when the middleware is
/// mounted.
pub struct InterceptorService<S> {
    /// The wrapped inner service; `Rc` so every request future can own a handle to it.
    service:         Rc<S>,
    /// Shared server key store, consulted for packets whose `key_id` does not start with `hs_`.
    key_store:       Arc<RwLock<KeyStore>>,
    /// Handshake key store, consulted for packets whose `key_id` starts with `hs_`.
    handshake_store: HandshakeStore,
    /// Optional Redis connection for replay detection; `None` skips the replay check.
    replay_store:    Option<ConnectionManager>,
}
110
111impl<S, B> Service<ServiceRequest> for InterceptorService<S>
112where
113    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
114    B: MessageBody + 'static,
115{
116    type Response = ServiceResponse<EitherBody<B>>;
117    type Error    = Error;
118    type Future   = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
119
120    forward_ready!(service);
121
122    fn call(&self, mut req: ServiceRequest) -> Self::Future {
123        let service         = self.service.clone();
124        let key_store       = self.key_store.clone();
125        let handshake_store = self.handshake_store.clone();
126        let replay_store    = self.replay_store.clone();
127
128        Box::pin(async move {
129            let has_body = !matches!(req.method().as_str(), "GET" | "HEAD" | "OPTIONS");
130
131            if has_body {
132                let mut payload = req.take_payload();
133                let mut raw = web::BytesMut::new();
134                while let Some(chunk) = payload
135                    .try_next().await
136                    .map_err(actix_web::error::ErrorBadRequest)?
137                {
138                    raw.extend_from_slice(&chunk);
139                }
140
141                if !raw.is_empty() {
142                    match deserialize_packet(&raw) {
143                        Ok(packet) => {
144                            let client_pk_bytes: [u8; 32] = packet.client_pk.as_ref()
145                                .try_into()
146                                .map_err(|_| actix_web::error::ErrorBadRequest("client_pk must be 32 bytes"))?;
147
148                            let (shared_secret, server_pk) =
149                                if packet.key_id.starts_with("hs_") {
150                                    ecdh_ephemeral(&handshake_store, &packet.key_id, &client_pk_bytes)
151                                        .await
152                                        .map_err(|e| actix_web::error::ErrorUnauthorized(e.to_string()))?
153                                } else {
154                                    ecdh(&key_store, &packet.key_id, &client_pk_bytes)
155                                        .await
156                                        .map_err(|e| actix_web::error::ErrorUnauthorized(e.to_string()))?
157                                };
158
159                            let shared_bytes: &[u8; 32] = shared_secret.as_ref()
160                                .try_into()
161                                .map_err(|_| actix_web::error::ErrorInternalServerError("shared secret length invalid"))?;
162                            let wrap_key = derive_wrap_key(shared_bytes, &client_pk_bytes, &server_pk);
163
164                            let enc_key_bytes = aes_decrypt(packet.kx.as_ref(), &wrap_key)
165                                .map_err(|e| actix_web::error::ErrorUnauthorized(e.to_string()))?;
166                            let enc_key: [u8; 32] = enc_key_bytes.as_slice()
167                                .try_into()
168                                .map_err(|_| actix_web::error::ErrorBadRequest("enc_key must be 32 bytes"))?;
169
170                            if let Some(mut redis) = replay_store {
171                                let seen_key = format!("replay:seen:{}", hex::encode(packet.kx.as_ref()));
172                                let is_new: bool = redis::cmd("SET")
173                                    .arg(&seen_key).arg(1u8)
174                                    .arg("NX").arg("EX").arg(60u64)
175                                    .query_async(&mut redis).await
176                                    .map(|v: Option<String>| v.is_some())
177                                    .unwrap_or(true);
178                                if !is_new {
179                                    return Err(actix_web::error::ErrorUnauthorized("replay detected"));
180                                }
181                            }
182
183                            let decrypted = aes_decrypt(packet.data.as_ref(), &enc_key)
184                                .map_err(|e| actix_web::error::ErrorBadRequest(e.to_string()))?;
185
186                            req.extensions_mut().insert(DecryptedBody(decrypted));
187                            req.extensions_mut().insert(RequestSessionKeys { enc_key });
188                        }
189                        Err(_) => {
190                            let frozen: actix_web::web::Bytes = raw.freeze();
191                            let (_, mut pl) = actix_http::h1::Payload::create(true);
192                            pl.unread_data(frozen);
193                            req.set_payload(actix_web::dev::Payload::from(pl));
194                        }
195                    }
196                }
197            }
198
199            let session_keys = req.extensions().get::<RequestSessionKeys>().cloned();
200            let res          = service.call(req).await?;
201
202            let session_keys = match session_keys {
203                Some(k) => k,
204                None    => return Ok(res.map_into_left_body()),
205            };
206
207            let (req, res)   = res.into_parts();
208            let (head, body) = res.into_parts();
209
210            let body_bytes = actix_web::body::to_bytes(body)
211                .await
212                .map_err(|_| actix_web::error::ErrorInternalServerError("body collect failed"))?;
213
214            let encrypted = match build_signed_response_raw(&body_bytes, &session_keys.enc_key) {
215                Ok(b)  => b,
216                Err(_) => return Ok(ServiceResponse::new(
217                    req,
218                    head.set_body(BoxBody::new(body_bytes)).map_into_right_body(),
219                )),
220            };
221
222            Ok(ServiceResponse::new(
223                req,
224                head.set_body(BoxBody::new(encrypted)).map_into_right_body(),
225            ))
226        })
227    }
228}