1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use cookie::{Cookie, CookieJar, SameSite};
8use http::{HeaderValue, Request, Response};
9use tower::{Layer, Service};
10
11use crate::cookie::{CookieConfig, Key};
12
13use super::state::{FlashEntry, FlashState};
14
/// Name of the cookie that carries the signed, JSON-encoded flash messages.
const COOKIE_NAME: &str = "flash";
/// Lifetime, in seconds, given to a freshly written flash cookie (five minutes).
const MAX_AGE_SECS: i64 = 5 * 60;
17
18pub struct FlashLayer {
34 key: Key,
35 config: CookieConfig,
36}
37
38impl Clone for FlashLayer {
39 fn clone(&self) -> Self {
40 Self {
41 key: self.key.clone(),
42 config: self.config.clone(),
43 }
44 }
45}
46
47impl FlashLayer {
48 pub fn new(config: &CookieConfig, key: &Key) -> Self {
50 Self {
51 key: key.clone(),
52 config: config.clone(),
53 }
54 }
55}
56
57impl<S> Layer<S> for FlashLayer {
58 type Service = FlashMiddleware<S>;
59
60 fn layer(&self, inner: S) -> Self::Service {
61 FlashMiddleware {
62 inner,
63 key: self.key.clone(),
64 config: self.config.clone(),
65 }
66 }
67}
68
69pub struct FlashMiddleware<S> {
73 inner: S,
74 key: Key,
75 config: CookieConfig,
76}
77
78impl<S: Clone> Clone for FlashMiddleware<S> {
79 fn clone(&self) -> Self {
80 Self {
81 inner: self.inner.clone(),
82 key: self.key.clone(),
83 config: self.config.clone(),
84 }
85 }
86}
87
88impl<S, ReqBody> Service<Request<ReqBody>> for FlashMiddleware<S>
89where
90 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
91 S::Future: Send + 'static,
92 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
93 ReqBody: Send + 'static,
94{
95 type Response = Response<Body>;
96 type Error = S::Error;
97 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
98
99 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 self.inner.poll_ready(cx)
101 }
102
103 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
104 let key = self.key.clone();
105 let config = self.config.clone();
106 let mut inner = self.inner.clone();
107 std::mem::swap(&mut self.inner, &mut inner);
108
109 Box::pin(async move {
110 let incoming = read_flash_cookie(request.headers(), &key);
112 let flash_state = Arc::new(FlashState::new(incoming));
113 request.extensions_mut().insert(flash_state.clone());
114
115 let mut response = inner.call(request).await?;
117
118 let outgoing = flash_state.drain_outgoing();
120 let was_read = flash_state.was_read();
121
122 if !outgoing.is_empty() {
123 write_flash_cookie(&mut response, &outgoing, &config, &key);
124 } else if was_read {
125 remove_flash_cookie(&mut response, &config, &key);
126 }
127
128 Ok(response)
129 })
130 }
131}
132
133fn read_flash_cookie(headers: &http::HeaderMap, key: &Key) -> Vec<FlashEntry> {
134 let Some(cookie_header) = headers.get(http::header::COOKIE) else {
135 return vec![];
136 };
137 let Ok(cookie_str) = cookie_header.to_str() else {
138 return vec![];
139 };
140
141 for pair in cookie_str.split(';') {
142 let pair = pair.trim();
143 if let Some((name, value)) = pair.split_once('=')
144 && name.trim() == COOKIE_NAME
145 {
146 let mut jar = CookieJar::new();
147 jar.add_original(Cookie::new(
148 COOKIE_NAME.to_string(),
149 value.trim().to_string(),
150 ));
151 if let Some(verified) = jar.signed(key).get(COOKIE_NAME)
152 && let Ok(entries) = serde_json::from_str::<Vec<FlashEntry>>(verified.value())
153 {
154 return entries;
155 }
156 return vec![];
157 }
158 }
159 vec![]
160}
161
162fn write_flash_cookie(
163 response: &mut Response<Body>,
164 entries: &[FlashEntry],
165 config: &CookieConfig,
166 key: &Key,
167) {
168 let Ok(json) = serde_json::to_string(entries) else {
169 tracing::error!("failed to serialize flash messages");
170 return;
171 };
172
173 set_cookie(response, &json, MAX_AGE_SECS, config, key);
174}
175
176fn remove_flash_cookie(response: &mut Response<Body>, config: &CookieConfig, key: &Key) {
177 set_cookie(response, "", 0, config, key);
178}
179
180fn set_cookie(
181 response: &mut Response<Body>,
182 value: &str,
183 max_age_secs: i64,
184 config: &CookieConfig,
185 key: &Key,
186) {
187 let mut jar = CookieJar::new();
188 jar.signed_mut(key)
189 .add(Cookie::new(COOKIE_NAME.to_string(), value.to_string()));
190 let signed_value = jar
191 .get(COOKIE_NAME)
192 .expect("cookie was just added")
193 .value()
194 .to_string();
195
196 let same_site = match config.same_site.as_str() {
197 "strict" => SameSite::Strict,
198 "none" => SameSite::None,
199 _ => SameSite::Lax,
200 };
201 let set_cookie_str = Cookie::build((COOKIE_NAME.to_string(), signed_value))
202 .path("/")
203 .secure(config.secure)
204 .http_only(config.http_only)
205 .same_site(same_site)
206 .max_age(cookie::time::Duration::seconds(max_age_secs))
207 .build()
208 .to_string();
209
210 match HeaderValue::from_str(&set_cookie_str) {
211 Ok(v) => {
212 response.headers_mut().append(http::header::SET_COOKIE, v);
213 }
214 Err(e) => {
215 tracing::error!("failed to set flash cookie header: {e}");
216 }
217 }
218}
219
// Each test builds an axum `Router` wrapped in `FlashLayer`, sends one request
// through it with `oneshot`, and inspects the flash `Set-Cookie` header (if any).
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use tower::ServiceExt;

    /// Cookie settings used by most tests: not `Secure`, `HttpOnly`, `SameSite` "lax".
    fn test_config() -> CookieConfig {
        CookieConfig {
            secret: "a".repeat(64),
            secure: false,
            http_only: true,
            same_site: "lax".into(),
        }
    }

    /// Derives the signing key from `config` via `crate::cookie::key_from_config`.
    fn test_key(config: &CookieConfig) -> Key {
        crate::cookie::key_from_config(config).unwrap()
    }

    /// Builds a `Cookie` request-header value (`flash=<signed value>`) whose
    /// payload is `entries` as JSON, signed through a `CookieJar` signed jar.
    fn make_signed_cookie(entries: &[FlashEntry], key: &Key) -> String {
        let json = serde_json::to_string(entries).unwrap();
        let mut jar = CookieJar::new();
        jar.signed_mut(key)
            .add(Cookie::new(COOKIE_NAME.to_string(), json));
        let signed_value = jar.get(COOKIE_NAME).unwrap().value().to_string();
        format!("{COOKIE_NAME}={signed_value}")
    }

    /// Returns the first `Set-Cookie` header on `resp` that starts with `flash=`.
    fn extract_flash_set_cookie(resp: &Response<Body>) -> Option<String> {
        resp.headers()
            .get_all(http::header::SET_COOKIE)
            .iter()
            .find_map(|v| {
                let s = v.to_str().ok()?;
                if s.starts_with("flash=") {
                    Some(s.to_string())
                } else {
                    None
                }
            })
    }

    /// Handler that returns 200 without touching the flash state.
    async fn noop_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Handler that queues one "it worked" message via the `Flash` extractor's
    /// `success` method (presumably one outgoing entry — see `crate::flash`).
    async fn set_flash_handler(flash: crate::flash::Flash) -> StatusCode {
        flash.success("it worked");
        StatusCode::OK
    }

    /// Handler that queues two messages (`error`, then `warning`) via `Flash`.
    async fn set_multiple_handler(flash: crate::flash::Flash) -> StatusCode {
        flash.error("bad");
        flash.warning("careful");
        StatusCode::OK
    }

    /// Handler that marks the incoming flash state as read and queues nothing.
    /// Unwraps the `Arc<FlashState>` extension the middleware inserts.
    async fn mark_read_handler(req: Request<Body>) -> StatusCode {
        let state = req.extensions().get::<Arc<FlashState>>().unwrap();
        state.mark_read();
        StatusCode::OK
    }

    /// Handler that both marks the incoming state read and queues a new message.
    async fn read_and_write_handler(req: Request<Body>) -> StatusCode {
        let state = req.extensions().get::<Arc<FlashState>>().unwrap();
        state.mark_read();
        state.push("success", "new");
        StatusCode::OK
    }

    /// No cookie and an untouched state: the middleware must not emit a flash cookie.
    #[tokio::test]
    async fn no_cookie_empty_state_no_set_cookie() {
        let config = test_config();
        let key = test_key(&config);
        let app = Router::new()
            .route("/", get(noop_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(extract_flash_set_cookie(&resp).is_none());
    }

    /// Queuing a message produces a flash `Set-Cookie` carrying `HttpOnly`.
    #[tokio::test]
    async fn outgoing_writes_cookie() {
        let config = test_config();
        let key = test_key(&config);
        let app = Router::new()
            .route("/", get(set_flash_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let cookie_str = extract_flash_set_cookie(&resp).expect("should have Set-Cookie");
        assert!(cookie_str.contains("flash="));
        assert!(cookie_str.contains("HttpOnly"));
    }

    /// A validly signed incoming cookie that the handler never reads is left alone.
    // NOTE(review): despite its name, this test does not inspect the parsed
    // incoming entries; it only asserts that no flash `Set-Cookie` is emitted.
    #[tokio::test]
    async fn valid_signed_cookie_populates_incoming() {
        let config = test_config();
        let key = test_key(&config);

        let entries = vec![FlashEntry {
            level: "success".into(),
            message: "saved".into(),
        }];
        let cookie_val = make_signed_cookie(&entries, &key);

        let app = Router::new()
            .route("/", get(noop_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder()
            .uri("/")
            .header(http::header::COOKIE, cookie_val)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(extract_flash_set_cookie(&resp).is_none());
    }

    /// An unsigned or tampered flash cookie is accepted without error and,
    /// when unread, produces no `Set-Cookie`.
    // NOTE(review): like the test above, this checks only the absence of a
    // `Set-Cookie`, not that the incoming list is actually empty.
    #[tokio::test]
    async fn invalid_cookie_gives_empty_incoming() {
        let config = test_config();
        let key = test_key(&config);

        let app = Router::new()
            .route("/", get(noop_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder()
            .uri("/")
            .header(http::header::COOKIE, "flash=tampered_value")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(extract_flash_set_cookie(&resp).is_none());
    }

    /// Marking the state read with nothing queued expires the cookie (`Max-Age=0`).
    #[tokio::test]
    async fn read_flag_clears_cookie() {
        let config = test_config();
        let key = test_key(&config);

        let entries = vec![FlashEntry {
            level: "info".into(),
            message: "hello".into(),
        }];
        let cookie_val = make_signed_cookie(&entries, &key);

        let app = Router::new()
            .route("/", get(mark_read_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder()
            .uri("/")
            .header(http::header::COOKIE, cookie_val)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let cookie_str = extract_flash_set_cookie(&resp).expect("should clear cookie");
        assert!(cookie_str.contains("Max-Age=0"));
    }

    /// When the state is both read and given new messages, the new cookie wins
    /// over the removal: a live (non `Max-Age=0`) flash cookie is written.
    #[tokio::test]
    async fn outgoing_plus_read_writes_only_outgoing() {
        let config = test_config();
        let key = test_key(&config);

        let entries = vec![FlashEntry {
            level: "info".into(),
            message: "old".into(),
        }];
        let cookie_val = make_signed_cookie(&entries, &key);

        let app = Router::new()
            .route("/", get(read_and_write_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder()
            .uri("/")
            .header(http::header::COOKIE, cookie_val)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let cookie_str = extract_flash_set_cookie(&resp).expect("should have cookie");
        assert!(!cookie_str.contains("Max-Age=0"));
        assert!(cookie_str.contains("flash="));
    }

    /// Several queued messages still produce a flash `Set-Cookie`.
    // NOTE(review): only presence is asserted; the payload holding both
    // messages is not decoded and checked.
    #[tokio::test]
    async fn multiple_outgoing_messages() {
        let config = test_config();
        let key = test_key(&config);
        let app = Router::new()
            .route("/", get(set_multiple_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let cookie_str = extract_flash_set_cookie(&resp).expect("should have Set-Cookie");
        assert!(cookie_str.contains("flash="));
    }

    /// `secure`, `http_only` and `same_site = "strict"` from the config appear as
    /// `Secure`, `HttpOnly` and `SameSite=Strict`, alongside `Path=/`.
    #[tokio::test]
    async fn cookie_attributes_applied() {
        let config = CookieConfig {
            secret: "a".repeat(64),
            secure: true,
            http_only: true,
            same_site: "strict".into(),
        };
        let key = test_key(&config);
        let app = Router::new()
            .route("/", get(set_flash_handler))
            .layer(FlashLayer::new(&config, &key));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let cookie_str = extract_flash_set_cookie(&resp).expect("should have Set-Cookie");
        assert!(cookie_str.contains("Secure"));
        assert!(cookie_str.contains("HttpOnly"));
        assert!(cookie_str.contains("SameSite=Strict"));
        assert!(cookie_str.contains("Path=/"));
    }
}