Skip to main content

alterion_encrypt/
interceptor.rs

1// SPDX-License-Identifier: GPL-3.0
2use actix_web::{
3    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
4    web, Error, HttpMessage,
5    body::{BoxBody, EitherBody, MessageBody},
6};
7use futures_util::future::{ready, LocalBoxFuture, Ready};
8use futures_util::TryStreamExt;
9use std::{
10    collections::HashMap,
11    rc::Rc,
12    sync::Arc,
13    time::{Duration, Instant},
14};
15use tokio::sync::{Mutex, RwLock};
16use alterion_ecdh::{KeyStore, HandshakeStore, ecdh, ecdh_ephemeral};
17use redis::aio::ConnectionManager;
18use serde_bytes::ByteBuf;
19use crate::tools::crypt::aes_decrypt;
20use crate::tools::serializer::{
21    deserialize_packet, deserialize, decompress,
22    build_signed_response_raw, derive_wrap_key,
23    MAX_DECOMPRESSED_SIZE,
24};
25use zeroize::ZeroizeOnDrop;
26
27/// Default maximum raw request body size (1 MiB). Override via [`Interceptor::max_body_bytes`].
28pub const DEFAULT_MAX_BODY_BYTES: usize = 1024 * 1024;
29
30/// In-memory replay protection store for environments without Redis.
31///
32/// Tracks seen `kx` hashes for a configurable TTL and prunes expired entries on each check.
33/// Wrap in `Arc` (via [`MemoryReplayStore::new`]) and share across all Actix workers via [`Interceptor`].
34pub struct MemoryReplayStore {
35    seen: Mutex<HashMap<String, Instant>>,
36    ttl:  Duration,
37}
38
39impl MemoryReplayStore {
40    /// Creates an `Arc`-wrapped store whose entries expire after `ttl`.
41    pub fn new(ttl: Duration) -> Arc<Self> {
42        Arc::new(Self { seen: Mutex::default(), ttl })
43    }
44
45    /// Returns `true` if `key` is new (not seen within `ttl`), inserting it.
46    /// Returns `false` on replay. Prunes expired entries on every call.
47    pub async fn is_new(&self, key: &str) -> bool {
48        let mut map = self.seen.lock().await;
49        let now = Instant::now();
50        map.retain(|_, inserted_at| now.duration_since(*inserted_at) < self.ttl);
51        if map.contains_key(key) {
52            return false;
53        }
54        map.insert(key.to_string(), now);
55        true
56    }
57}
58
/// Plaintext request body produced by [`Interceptor`] and stored in the Actix request extensions
/// once a packet has been validated and decrypted.
///
/// Fetch it from a handler with:
/// ```rust,ignore
/// let body = req.extensions().get::<DecryptedBody>().cloned();
/// ```
/// `body.0` holds the bytes exactly as they come out of the AES-GCM decrypt step — no
/// application-level decoding has been applied. The interceptor only checks that they decode as a
/// msgpack `ByteBuf` whose contents decompress within the configured limit; the client is
/// expected to have packed deflate-compressed JSON inside that `ByteBuf`.
/// Use [`crate::tools::serializer::decode_request_payload`] to finish decoding.
#[derive(Clone)]
pub struct DecryptedBody(pub Vec<u8>);
72
73/// Per-request AES-256 session key, injected alongside [`DecryptedBody`].
74///
75/// The interceptor stores this so the **response** can be encrypted with the exact same key that
76/// the client generated for this request. The client holds the key in memory indexed by request
77/// ID and passes it to [`crate::tools::serializer::decode_response_packet`] to decrypt the reply.
78///
79/// Zeroized on drop — the key material is cleared from memory as soon as the response has been
80/// sent and this struct is dropped.
81#[derive(Clone, ZeroizeOnDrop)]
82pub struct RequestSessionKeys {
83    pub enc_key: [u8; 32],
84}
85
86/// Actix-web middleware that transparently decrypts incoming request bodies and encrypts outgoing
87/// response bodies using the X25519 ECDH + AES-256-GCM + HMAC-SHA256 pipeline.
88///
89/// # Usage
90///
91/// Prefer [`Interceptor::new_with_memory_replay`] for new deployments — it enables in-memory
92/// replay protection and sensible body/decompression size limits without requiring Redis:
93///
94/// ```rust,no_run
95/// use alterion_encrypt::interceptor::Interceptor;
96/// use alterion_encrypt::{init_key_store, init_handshake_store, start_rotation};
97///
98/// let store = init_key_store(3600);
99/// let hs    = init_handshake_store();
100/// start_rotation(store.clone(), 3600, hs.clone());
101/// // App::new().wrap(Interceptor::new_with_memory_replay(store, hs))
102/// ```
103///
104/// To tune size limits or add Redis replay protection after construction:
105/// ```rust,no_run
106/// # use alterion_encrypt::interceptor::Interceptor;
107/// # use alterion_encrypt::{init_key_store, init_handshake_store};
108/// # let store = init_key_store(3600);
109/// # let hs    = init_handshake_store();
110/// let mut interceptor = Interceptor::new_with_memory_replay(store, hs);
111/// interceptor.max_body_bytes        = 5 * 1024 * 1024;  // 5 MiB raw body
112/// interceptor.max_decompressed_bytes = 50 * 1024 * 1024; // 50 MiB decompressed
113/// // interceptor.replay_store = Some(redis_connection_manager);
114/// ```
115///
116/// **Request path** (POST / PUT / PATCH, and GET when `allow_encrypted_get` is `true`):
117/// 1. Collect raw body bytes up to `max_body_bytes` — reject 413 if exceeded.
118/// 2. MessagePack-decode a [`Request`](crate::tools::serializer::Request) and validate timestamp.
119/// 3. Check the replay store (Redis → in-memory fallback). Fails closed on store error.
120/// 4. ECDH → wrap key → AES-GCM unwrap `enc_key` → AES-256-GCM decrypt payload.
121/// 5. Preflight decompress against `max_decompressed_bytes` — reject 413 if exceeded.
122/// 6. Inject `DecryptedBody` and `RequestSessionKeys` into request extensions.
123///
124/// Requests whose body is not a valid encrypted `Request` are passed through unchanged.
125///
126/// **Response path** (only when `RequestSessionKeys` is present):
127/// JSON → deflate → msgpack → AES-256-GCM → HMAC-SHA256 → [`Response`](crate::tools::serializer::Response) → msgpack.
128pub struct Interceptor {
129    pub key_store:             Arc<RwLock<KeyStore>>,
130    pub handshake_store:       HandshakeStore,
131    /// Redis-backed replay store. Takes precedence over `memory_replay_store` when `Some`.
132    pub replay_store:          Option<ConnectionManager>,
133    /// In-memory replay store used when `replay_store` is `None`.
134    /// Initialized automatically by [`Interceptor::new_with_memory_replay`].
135    pub memory_replay_store:   Option<Arc<MemoryReplayStore>>,
136    /// Maximum raw (compressed + encrypted) request body in bytes. Requests exceeding this are
137    /// rejected with 413 before any decryption occurs. Default: [`DEFAULT_MAX_BODY_BYTES`] (1 MiB).
138    pub max_body_bytes:        usize,
139    /// Maximum decompressed payload size in bytes. Requests whose payload would expand beyond this
140    /// are rejected with 413 after decryption but before the handler sees the body. Set this to
141    /// whatever your largest valid request body is — there is no upper bound imposed by the library.
142    /// Default: [`MAX_DECOMPRESSED_SIZE`] (10 MiB).
143    pub max_decompressed_bytes: usize,
144    /// When `true`, GET requests that carry a body are processed through the full encrypt/decrypt
145    /// pipeline identically to POST/PUT/PATCH. The client sends the msgpack-encoded [`Request`]
146    /// as the GET body using the same `buildRequestPacket` function. Default: `false`.
147    pub allow_encrypted_get:   bool,
148}
149
150impl Interceptor {
151    /// Creates an `Interceptor` with in-memory replay protection and default size limits
152    /// (1 MiB raw body, 10 MiB decompressed).
153    ///
154    /// This is the recommended constructor for new deployments. Tune `max_body_bytes` and
155    /// `max_decompressed_bytes` on the returned value for your workload, or assign `replay_store`
156    /// to upgrade to Redis-backed replay detection for multi-instance deployments.
157    pub fn new_with_memory_replay(
158        key_store:       Arc<RwLock<KeyStore>>,
159        handshake_store: HandshakeStore,
160    ) -> Self {
161        Self {
162            key_store,
163            handshake_store,
164            replay_store:           None,
165            memory_replay_store:    Some(MemoryReplayStore::new(Duration::from_secs(90))),
166            max_body_bytes:         DEFAULT_MAX_BODY_BYTES,
167            max_decompressed_bytes: MAX_DECOMPRESSED_SIZE,
168            allow_encrypted_get:    false,
169        }
170    }
171}
172
173impl<S, B> Transform<S, ServiceRequest> for Interceptor
174where
175    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
176    B: MessageBody + 'static,
177{
178    type Response  = ServiceResponse<EitherBody<B>>;
179    type Error     = Error;
180    type Transform = InterceptorService<S>;
181    type InitError = ();
182    type Future    = Ready<Result<Self::Transform, Self::InitError>>;
183
184    fn new_transform(&self, service: S) -> Self::Future {
185        if self.replay_store.is_none() && self.memory_replay_store.is_none() {
186            tracing::warn!(
187                "alterion-encrypt: no replay_store configured — replay attacks are possible \
188                 within the 30-second timestamp window. Use Interceptor::new_with_memory_replay() \
189                 or configure a Redis ConnectionManager for production deployments."
190            );
191        }
192        ready(Ok(InterceptorService {
193            service:               Rc::new(service),
194            key_store:             self.key_store.clone(),
195            handshake_store:       self.handshake_store.clone(),
196            replay_store:          self.replay_store.clone(),
197            memory_replay_store:   self.memory_replay_store.clone(),
198            max_body_bytes:        self.max_body_bytes,
199            max_decompressed_bytes: self.max_decompressed_bytes,
200            allow_encrypted_get:   self.allow_encrypted_get,
201        }))
202    }
203}
204
205/// The concrete [`Service`](actix_web::dev::Service) produced by [`Interceptor::new_transform`].
206///
207/// One instance is created per worker thread. Holds `Rc`-wrapped references to the inner service
208/// and `Arc`-shared references to the key/handshake/replay stores. Not constructed directly —
209/// Actix creates it automatically when the middleware is mounted.
210pub struct InterceptorService<S> {
211    service:               Rc<S>,
212    key_store:             Arc<RwLock<KeyStore>>,
213    handshake_store:       HandshakeStore,
214    replay_store:          Option<ConnectionManager>,
215    memory_replay_store:   Option<Arc<MemoryReplayStore>>,
216    max_body_bytes:        usize,
217    max_decompressed_bytes: usize,
218    allow_encrypted_get:   bool,
219}
220
221impl<S, B> Service<ServiceRequest> for InterceptorService<S>
222where
223    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
224    B: MessageBody + 'static,
225{
226    type Response = ServiceResponse<EitherBody<B>>;
227    type Error    = Error;
228    type Future   = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
229
230    forward_ready!(service);
231
232    fn call(&self, mut req: ServiceRequest) -> Self::Future {
233        let service               = self.service.clone();
234        let key_store             = self.key_store.clone();
235        let handshake_store       = self.handshake_store.clone();
236        let replay_store          = self.replay_store.clone();
237        let memory_replay_store   = self.memory_replay_store.clone();
238        let max_body_bytes        = self.max_body_bytes;
239        let max_decompressed_bytes = self.max_decompressed_bytes;
240        let allow_encrypted_get   = self.allow_encrypted_get;
241
242        Box::pin(async move {
243            let method = req.method().as_str();
244            let has_body = match method {
245                "HEAD" | "OPTIONS" => false,
246                "GET"              => allow_encrypted_get,
247                _                  => true,
248            };
249
250            if has_body {
251                let mut payload = req.take_payload();
252                let mut raw = web::BytesMut::new();
253                while let Some(chunk) = payload
254                    .try_next().await
255                    .map_err(actix_web::error::ErrorBadRequest)?
256                {
257                    raw.extend_from_slice(&chunk);
258                    if raw.len() > max_body_bytes {
259                        return Err(actix_web::error::ErrorPayloadTooLarge(
260                            "request body exceeds maximum allowed size",
261                        ));
262                    }
263                }
264
265                if !raw.is_empty() {
266                    match deserialize_packet(&raw) {
267                        Ok(packet) => {
268                            let client_pk_bytes: [u8; 32] = packet.client_pk.as_ref()
269                                .try_into()
270                                .map_err(|_| actix_web::error::ErrorBadRequest("client_pk must be 32 bytes"))?;
271
272                            let (shared_secret, server_pk) =
273                                if packet.key_id.starts_with("hs_") {
274                                    ecdh_ephemeral(&handshake_store, &packet.key_id, &client_pk_bytes)
275                                        .await
276                                        .map_err(|e| actix_web::error::ErrorUnauthorized(e.to_string()))?
277                                } else {
278                                    ecdh(&key_store, &packet.key_id, &client_pk_bytes)
279                                        .await
280                                        .map_err(|e| actix_web::error::ErrorUnauthorized(e.to_string()))?
281                                };
282
283                            let shared_bytes: &[u8; 32] = shared_secret.as_ref()
284                                .try_into()
285                                .map_err(|_| actix_web::error::ErrorInternalServerError("shared secret length invalid"))?;
286                            let wrap_key = derive_wrap_key(shared_bytes, &client_pk_bytes, &server_pk);
287
288                            let enc_key_bytes = aes_decrypt(packet.kx.as_ref(), &wrap_key)
289                                .map_err(|e| actix_web::error::ErrorUnauthorized(e.to_string()))?;
290                            let enc_key: [u8; 32] = enc_key_bytes.as_slice()
291                                .try_into()
292                                .map_err(|_| actix_web::error::ErrorBadRequest("enc_key must be 32 bytes"))?;
293
294                            let seen_key = format!("replay:seen:{}", hex::encode(packet.kx.as_ref()));
295                            if let Some(mut redis) = replay_store {
296                                let is_new: bool = redis::cmd("SET")
297                                    .arg(&seen_key).arg(1u8)
298                                    .arg("NX").arg("EX").arg(60u64)
299                                    .query_async::<Option<String>>(&mut redis)
300                                    .await
301                                    .map_err(|e| {
302                                        tracing::error!("replay store unavailable: {e}");
303                                        actix_web::error::ErrorInternalServerError("replay store unavailable")
304                                    })?
305                                    .is_some();
306                                if !is_new {
307                                    return Err(actix_web::error::ErrorUnauthorized("replay detected"));
308                                }
309                            } else if let Some(mem) = &memory_replay_store {
310                                if !mem.is_new(&seen_key).await {
311                                    return Err(actix_web::error::ErrorUnauthorized("replay detected"));
312                                }
313                            }
314
315                            let decrypted = aes_decrypt(packet.data.as_ref(), &enc_key)
316                                .map_err(|e| actix_web::error::ErrorBadRequest(e.to_string()))?;
317
318                            let compressed: ByteBuf = deserialize(&decrypted)
319                                .map_err(|_| actix_web::error::ErrorBadRequest("payload msgpack decode failed"))?;
320                            decompress(&compressed, max_decompressed_bytes)
321                                .map_err(|_| actix_web::error::ErrorPayloadTooLarge("decompressed payload exceeds limit"))?;
322
323                            req.extensions_mut().insert(DecryptedBody(decrypted));
324                            req.extensions_mut().insert(RequestSessionKeys { enc_key });
325                        }
326                        Err(_) => {
327                            let frozen: actix_web::web::Bytes = raw.freeze();
328                            let (_, mut pl) = actix_http::h1::Payload::create(true);
329                            pl.unread_data(frozen);
330                            req.set_payload(actix_web::dev::Payload::from(pl));
331                        }
332                    }
333                }
334            }
335
336            let session_keys = req.extensions().get::<RequestSessionKeys>().cloned();
337            let res          = service.call(req).await?;
338
339            let session_keys = match session_keys {
340                Some(k) => k,
341                None    => return Ok(res.map_into_left_body()),
342            };
343
344            let (req, res)   = res.into_parts();
345            let (head, body) = res.into_parts();
346
347            let body_bytes = actix_web::body::to_bytes(body)
348                .await
349                .map_err(|_| actix_web::error::ErrorInternalServerError("body collect failed"))?;
350
351            let encrypted = match build_signed_response_raw(&body_bytes, &session_keys.enc_key) {
352                Ok(b)  => b,
353                Err(_) => return Ok(ServiceResponse::new(
354                    req,
355                    head.set_body(BoxBody::new(body_bytes)).map_into_right_body(),
356                )),
357            };
358
359            Ok(ServiceResponse::new(
360                req,
361                head.set_body(BoxBody::new(encrypted)).map_into_right_body(),
362            ))
363        })
364    }
365}