Skip to main content

modo/flash/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use cookie::{Cookie, CookieJar, SameSite};
8use http::{HeaderValue, Request, Response};
9use tower::{Layer, Service};
10
11use crate::cookie::{CookieConfig, Key};
12
13use super::state::{FlashEntry, FlashState};
14
/// Name of the cookie that carries serialized flash messages between requests.
const COOKIE_NAME: &str = "flash";
/// Lifetime of a freshly written flash cookie: five minutes, in seconds.
const MAX_AGE_SECS: i64 = 5 * 60;
17
18// --- Layer ---
19
20/// Tower [`Layer`] that enables cookie-based flash messages for a router.
21///
22/// On each request the layer reads the signed `flash` cookie and populates
23/// the [`Flash`](crate::flash::Flash) extractor. On response it either writes a new
24/// signed cookie (when messages were queued) or removes the existing one (when
25/// messages were consumed via [`Flash::messages`](crate::flash::Flash::messages)).
26///
27/// # Cookie details
28///
29/// - Name: `flash`
30/// - Signed with HMAC using the application [`Key`]
31/// - `Max-Age`: 300 seconds (5 minutes)
32/// - Path, `Secure`, `HttpOnly`, and `SameSite` attributes come from [`CookieConfig`]
33pub struct FlashLayer {
34    key: Key,
35    config: CookieConfig,
36}
37
38impl Clone for FlashLayer {
39    fn clone(&self) -> Self {
40        Self {
41            key: self.key.clone(),
42            config: self.config.clone(),
43        }
44    }
45}
46
47impl FlashLayer {
48    /// Create a new `FlashLayer` from a cookie configuration and signing key.
49    pub fn new(config: &CookieConfig, key: &Key) -> Self {
50        Self {
51            key: key.clone(),
52            config: config.clone(),
53        }
54    }
55}
56
57impl<S> Layer<S> for FlashLayer {
58    type Service = FlashMiddleware<S>;
59
60    fn layer(&self, inner: S) -> Self::Service {
61        FlashMiddleware {
62            inner,
63            key: self.key.clone(),
64            config: self.config.clone(),
65        }
66    }
67}
68
69// --- Service ---
70
71/// Tower [`Service`] produced by [`FlashLayer`].
72pub struct FlashMiddleware<S> {
73    inner: S,
74    key: Key,
75    config: CookieConfig,
76}
77
78impl<S: Clone> Clone for FlashMiddleware<S> {
79    fn clone(&self) -> Self {
80        Self {
81            inner: self.inner.clone(),
82            key: self.key.clone(),
83            config: self.config.clone(),
84        }
85    }
86}
87
88impl<S, ReqBody> Service<Request<ReqBody>> for FlashMiddleware<S>
89where
90    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
91    S::Future: Send + 'static,
92    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
93    ReqBody: Send + 'static,
94{
95    type Response = Response<Body>;
96    type Error = S::Error;
97    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
98
99    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        self.inner.poll_ready(cx)
101    }
102
103    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
104        let key = self.key.clone();
105        let config = self.config.clone();
106        let mut inner = self.inner.clone();
107        std::mem::swap(&mut self.inner, &mut inner);
108
109        Box::pin(async move {
110            // --- Request path ---
111            let incoming = read_flash_cookie(request.headers(), &key);
112            let flash_state = Arc::new(FlashState::new(incoming));
113            request.extensions_mut().insert(flash_state.clone());
114
115            // --- Run inner service ---
116            let mut response = inner.call(request).await?;
117
118            // --- Response path ---
119            let outgoing = flash_state.drain_outgoing();
120            let was_read = flash_state.was_read();
121
122            if !outgoing.is_empty() {
123                write_flash_cookie(&mut response, &outgoing, &config, &key);
124            } else if was_read {
125                remove_flash_cookie(&mut response, &config, &key);
126            }
127
128            Ok(response)
129        })
130    }
131}
132
133fn read_flash_cookie(headers: &http::HeaderMap, key: &Key) -> Vec<FlashEntry> {
134    let Some(cookie_header) = headers.get(http::header::COOKIE) else {
135        return vec![];
136    };
137    let Ok(cookie_str) = cookie_header.to_str() else {
138        return vec![];
139    };
140
141    for pair in cookie_str.split(';') {
142        let pair = pair.trim();
143        if let Some((name, value)) = pair.split_once('=')
144            && name.trim() == COOKIE_NAME
145        {
146            let mut jar = CookieJar::new();
147            jar.add_original(Cookie::new(
148                COOKIE_NAME.to_string(),
149                value.trim().to_string(),
150            ));
151            if let Some(verified) = jar.signed(key).get(COOKIE_NAME)
152                && let Ok(entries) = serde_json::from_str::<Vec<FlashEntry>>(verified.value())
153            {
154                return entries;
155            }
156            return vec![];
157        }
158    }
159    vec![]
160}
161
162fn write_flash_cookie(
163    response: &mut Response<Body>,
164    entries: &[FlashEntry],
165    config: &CookieConfig,
166    key: &Key,
167) {
168    let Ok(json) = serde_json::to_string(entries) else {
169        tracing::error!("failed to serialize flash messages");
170        return;
171    };
172
173    set_cookie(response, &json, MAX_AGE_SECS, config, key);
174}
175
176fn remove_flash_cookie(response: &mut Response<Body>, config: &CookieConfig, key: &Key) {
177    set_cookie(response, "", 0, config, key);
178}
179
180fn set_cookie(
181    response: &mut Response<Body>,
182    value: &str,
183    max_age_secs: i64,
184    config: &CookieConfig,
185    key: &Key,
186) {
187    let mut jar = CookieJar::new();
188    jar.signed_mut(key)
189        .add(Cookie::new(COOKIE_NAME.to_string(), value.to_string()));
190    let signed_value = jar
191        .get(COOKIE_NAME)
192        .expect("cookie was just added")
193        .value()
194        .to_string();
195
196    let same_site = match config.same_site.as_str() {
197        "strict" => SameSite::Strict,
198        "none" => SameSite::None,
199        _ => SameSite::Lax,
200    };
201    let set_cookie_str = Cookie::build((COOKIE_NAME.to_string(), signed_value))
202        .path("/")
203        .secure(config.secure)
204        .http_only(config.http_only)
205        .same_site(same_site)
206        .max_age(cookie::time::Duration::seconds(max_age_secs))
207        .build()
208        .to_string();
209
210    match HeaderValue::from_str(&set_cookie_str) {
211        Ok(v) => {
212            response.headers_mut().append(http::header::SET_COOKIE, v);
213        }
214        Err(e) => {
215            tracing::error!("failed to set flash cookie header: {e}");
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::routing::{MethodRouter, get};
    use http::StatusCode;
    use tower::ServiceExt;

    // --- Fixtures ---

    /// Default test configuration: `secure: false` (plain HTTP), `HttpOnly`,
    /// `SameSite=Lax`, and a 64-character secret.
    fn test_config() -> CookieConfig {
        CookieConfig {
            secret: "a".repeat(64),
            secure: false,
            http_only: true,
            same_site: "lax".into(),
        }
    }

    /// Signing key derived from the configuration's secret.
    fn test_key(config: &CookieConfig) -> Key {
        crate::cookie::key_from_config(config).unwrap()
    }

    /// A one-message payload used to seed the incoming cookie.
    fn one_entry(level: &'static str, message: &'static str) -> Vec<FlashEntry> {
        vec![FlashEntry {
            level: level.into(),
            message: message.into(),
        }]
    }

    /// Encode `entries` as a `Cookie` header value (`flash=<signed JSON>`) signed with `key`.
    fn make_signed_cookie(entries: &[FlashEntry], key: &Key) -> String {
        let payload = serde_json::to_string(entries).unwrap();
        let mut jar = CookieJar::new();
        jar.signed_mut(key)
            .add(Cookie::new(COOKIE_NAME.to_string(), payload));
        let signed = jar.get(COOKIE_NAME).unwrap().value().to_string();
        format!("{COOKIE_NAME}={signed}")
    }

    /// Find the `Set-Cookie` header that targets the flash cookie, ignoring any others.
    fn extract_flash_set_cookie(resp: &Response<Body>) -> Option<String> {
        resp.headers()
            .get_all(http::header::SET_COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .find(|header| header.starts_with("flash="))
            .map(str::to_owned)
    }

    /// Mount `route` at `/` behind a [`FlashLayer`] built from `config` and `key`.
    fn flash_app(route: MethodRouter, config: &CookieConfig, key: &Key) -> Router {
        Router::new()
            .route("/", route)
            .layer(FlashLayer::new(config, key))
    }

    /// Send a single `GET /` through `app`, attaching `cookie` as the `Cookie` header when given.
    async fn get_root(app: Router, cookie: Option<String>) -> Response<Body> {
        let mut builder = Request::builder().uri("/");
        if let Some(value) = cookie {
            builder = builder.header(http::header::COOKIE, value);
        }
        let request = builder.body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap()
    }

    // --- Module-level handlers (axum Handler bounds require module-level async fn) ---

    async fn noop_handler() -> StatusCode {
        StatusCode::OK
    }

    async fn set_flash_handler(flash: crate::flash::Flash) -> StatusCode {
        flash.success("it worked");
        StatusCode::OK
    }

    async fn set_multiple_handler(flash: crate::flash::Flash) -> StatusCode {
        flash.error("bad");
        flash.warning("careful");
        StatusCode::OK
    }

    async fn mark_read_handler(req: Request<Body>) -> StatusCode {
        let state = req.extensions().get::<Arc<FlashState>>().unwrap();
        state.mark_read();
        StatusCode::OK
    }

    async fn read_and_write_handler(req: Request<Body>) -> StatusCode {
        let state = req.extensions().get::<Arc<FlashState>>().unwrap();
        state.mark_read();
        state.push("success", "new");
        StatusCode::OK
    }

    // --- Tests ---

    #[tokio::test]
    async fn no_cookie_empty_state_no_set_cookie() {
        let config = test_config();
        let key = test_key(&config);
        let app = flash_app(get(noop_handler), &config, &key);

        let resp = get_root(app, None).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(extract_flash_set_cookie(&resp).is_none());
    }

    #[tokio::test]
    async fn outgoing_writes_cookie() {
        let config = test_config();
        let key = test_key(&config);
        let app = flash_app(get(set_flash_handler), &config, &key);

        let resp = get_root(app, None).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let flash_header = extract_flash_set_cookie(&resp).expect("should have Set-Cookie");
        assert!(flash_header.contains("flash="));
        assert!(flash_header.contains("HttpOnly"));
    }

    #[tokio::test]
    async fn valid_signed_cookie_populates_incoming() {
        let config = test_config();
        let key = test_key(&config);
        let cookie = make_signed_cookie(&one_entry("success", "saved"), &key);
        let app = flash_app(get(noop_handler), &config, &key);

        let resp = get_root(app, Some(cookie)).await;
        assert_eq!(resp.status(), StatusCode::OK);
        // Nothing was read and nothing was queued, so the cookie is left alone.
        assert!(extract_flash_set_cookie(&resp).is_none());
    }

    #[tokio::test]
    async fn invalid_cookie_gives_empty_incoming() {
        let config = test_config();
        let key = test_key(&config);
        let app = flash_app(get(noop_handler), &config, &key);

        let resp = get_root(app, Some("flash=tampered_value".to_string())).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(extract_flash_set_cookie(&resp).is_none());
    }

    #[tokio::test]
    async fn read_flag_clears_cookie() {
        let config = test_config();
        let key = test_key(&config);
        let cookie = make_signed_cookie(&one_entry("info", "hello"), &key);
        let app = flash_app(get(mark_read_handler), &config, &key);

        let resp = get_root(app, Some(cookie)).await;

        let flash_header = extract_flash_set_cookie(&resp).expect("should clear cookie");
        assert!(flash_header.contains("Max-Age=0"));
    }

    #[tokio::test]
    async fn outgoing_plus_read_writes_only_outgoing() {
        let config = test_config();
        let key = test_key(&config);
        let cookie = make_signed_cookie(&one_entry("info", "old"), &key);
        let app = flash_app(get(read_and_write_handler), &config, &key);

        let resp = get_root(app, Some(cookie)).await;

        let flash_header = extract_flash_set_cookie(&resp).expect("should have cookie");
        assert!(!flash_header.contains("Max-Age=0"));
        assert!(flash_header.contains("flash="));
    }

    #[tokio::test]
    async fn multiple_outgoing_messages() {
        let config = test_config();
        let key = test_key(&config);
        let app = flash_app(get(set_multiple_handler), &config, &key);

        let resp = get_root(app, None).await;

        let flash_header = extract_flash_set_cookie(&resp).expect("should have Set-Cookie");
        assert!(flash_header.contains("flash="));
    }

    #[tokio::test]
    async fn cookie_attributes_applied() {
        let config = CookieConfig {
            secure: true,
            same_site: "strict".into(),
            ..test_config()
        };
        let key = test_key(&config);
        let app = flash_app(get(set_flash_handler), &config, &key);

        let resp = get_root(app, None).await;

        let flash_header = extract_flash_set_cookie(&resp).expect("should have Set-Cookie");
        assert!(flash_header.contains("Secure"));
        assert!(flash_header.contains("HttpOnly"));
        assert!(flash_header.contains("SameSite=Strict"));
        assert!(flash_header.contains("Path=/"));
    }
}