axum_flash/
lib.rs

1//! One-time notifications (aka flash messages) for [axum].
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     response::{IntoResponse, Redirect},
8//!     extract::FromRef,
9//!     routing::get,
10//!     Router,
11//! };
12//! use axum_flash::{IncomingFlashes, Flash, Key};
13//!
14//! #[derive(Clone)]
15//! struct AppState {
16//!     flash_config: axum_flash::Config,
17//! }
18//!
19//! let app_state = AppState {
20//!     // The key should probably come from configuration
21//!     flash_config: axum_flash::Config::new(Key::generate()),
22//! };
23//!
//! // `axum_flash::Config` must implement `FromRef` for our state type. That
//! // is how the config is passed to axum-flash in a type-safe way.
26//! impl FromRef<AppState> for axum_flash::Config {
27//!     fn from_ref(state: &AppState) -> axum_flash::Config {
28//!         state.flash_config.clone()
29//!     }
30//! }
31//!
32//! let app = Router::new()
33//!     .route("/", get(root))
34//!     .route("/set-flash", get(set_flash))
35//!     .with_state(app_state);
36//!
37//! async fn root(flashes: IncomingFlashes) -> IncomingFlashes {
38//!     for (level, text) in &flashes {
39//!         // ...
40//!     }
41//!
42//!     // The flashes must be returned so the cookie is removed
43//!     flashes
44//! }
45//!
46//! async fn set_flash(flash: Flash) -> (Flash, Redirect) {
47//!     (
48//!         // The flash must be returned so the cookie is set
49//!         flash.debug("Hi from flash!"),
50//!         Redirect::to("/"),
51//!     )
52//! }
53//! # let _: Router = app;
54//! ```
55//!
56//! [axum]: https://crates.io/crates/axum
57
// Crate-wide lint configuration: a broad set of clippy and rustc lints are
// raised to warnings, including `missing_docs` and
// `missing_debug_implementations` for the public API.
#![warn(
    clippy::all,
    clippy::dbg_macro,
    clippy::todo,
    clippy::empty_enum,
    clippy::enum_glob_use,
    clippy::mem_forget,
    clippy::unused_self,
    clippy::filter_map_next,
    clippy::needless_continue,
    clippy::needless_borrow,
    clippy::match_wildcard_for_single_variants,
    clippy::if_let_mutex,
    clippy::mismatched_target_os,
    clippy::await_holding_lock,
    clippy::match_on_vec_items,
    clippy::imprecise_flops,
    clippy::suboptimal_flops,
    clippy::lossy_float_literal,
    clippy::rest_pat_in_fully_bound_structs,
    clippy::fn_params_excessive_bools,
    clippy::exit,
    clippy::inefficient_to_string,
    clippy::linkedlist,
    clippy::macro_use_imports,
    clippy::option_option,
    clippy::verbose_file_reads,
    clippy::unnested_or_patterns,
    rust_2018_idioms,
    future_incompatible,
    nonstandard_style,
    missing_debug_implementations,
    missing_docs
)]
// Items that are not reachable from the crate root must not be declared plain
// `pub` (use `pub(crate)` instead, as `create_cookie` does).
#![deny(unreachable_pub)]
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
// The crate contains no `unsafe` code and is not allowed to gain any.
#![forbid(unsafe_code)]
// Only when building docs with `--cfg docsrs`: enable the nightly `doc_cfg`
// feature.
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]
97
98use async_trait::async_trait;
99use axum_core::{
100    extract::{FromRef, FromRequestParts},
101    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
102};
103use axum_extra::extract::cookie::{Cookie, SignedCookieJar};
104use http::{request::Parts, StatusCode};
105use serde::{Deserialize, Serialize};
106use std::{borrow::Cow, fmt};
107use std::{
108    convert::{Infallible, TryInto},
109    time::Duration,
110};
111
112pub use axum_extra::extract::cookie::Key;
113
114/// Extractor for setting outgoing flash messages.
115///
116/// The flashes will be stored in a signed cookie.
117#[derive(Clone)]
118pub struct Flash {
119    flashes: Vec<FlashMessage>,
120    use_secure_cookies: bool,
121    key: Key,
122}
123
124impl fmt::Debug for Flash {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.debug_struct("Flash")
127            .field("flashes", &self.flashes)
128            .field("use_secure_cookies", &self.use_secure_cookies)
129            .field("key", &"REDACTED")
130            .finish()
131    }
132}
133
134impl Flash {
135    /// Push an `Debug` flash message.
136    pub fn debug(self, message: impl Into<String>) -> Self {
137        self.push(Level::Debug, message)
138    }
139
140    /// Push an `Info` flash message.
141    pub fn info(self, message: impl Into<String>) -> Self {
142        self.push(Level::Info, message)
143    }
144
145    /// Push an `Success` flash message.
146    pub fn success(self, message: impl Into<String>) -> Self {
147        self.push(Level::Success, message)
148    }
149
150    /// Push an `Warning` flash message.
151    pub fn warning(self, message: impl Into<String>) -> Self {
152        self.push(Level::Warning, message)
153    }
154
155    /// Push an `Error` flash message.
156    pub fn error(self, message: impl Into<String>) -> Self {
157        self.push(Level::Error, message)
158    }
159
160    /// Push a flash message with the given level and message.
161    pub fn push(mut self, level: Level, message: impl Into<String>) -> Self {
162        self.flashes.push(FlashMessage {
163            message: message.into(),
164            level,
165        });
166        self
167    }
168}
169
170#[async_trait]
171impl<S> FromRequestParts<S> for Flash
172where
173    S: Send + Sync,
174    Config: FromRef<S>,
175{
176    type Rejection = Infallible;
177
178    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
179        let config = Config::from_ref(state);
180
181        Ok(Self {
182            key: config.key,
183            use_secure_cookies: config.use_secure_cookies,
184            flashes: Default::default(),
185        })
186    }
187}
188
/// Name of the signed cookie that carries pending flash messages between requests.
const COOKIE_NAME: &str = "axum-flash";
190
191impl IntoResponseParts for Flash {
192    type Error = Infallible;
193
194    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
195        let json =
196            serde_json::to_string(&self.flashes).expect("failed to serialize flash messages");
197
198        let cookies = SignedCookieJar::new(self.key.clone());
199
200        let cookies = cookies.add(create_cookie(json, self.use_secure_cookies));
201        cookies.into_response_parts(res)
202    }
203}
204
205pub(crate) fn create_cookie(
206    value: impl Into<Cow<'static, str>>,
207    use_secure_cookies: bool,
208) -> Cookie<'static> {
209    // process is inspired by
210    // https://github.com/LukeMathWalker/actix-web-flash-messages/blob/main/src/storage/cookies.rs#L54
211    Cookie::build((COOKIE_NAME, value))
212        // only send the cookie for https (maybe)
213        .secure(use_secure_cookies)
214        // don't allow javascript to access the cookie
215        .http_only(true)
216        // don't send the cookie to other domains
217        .same_site(cookie::SameSite::Strict)
218        // allow the cookie for all paths
219        .path("/")
220        // expire after 10 minutes
221        .max_age(
222            Duration::from_secs(10 * 60)
223                .try_into()
224                .expect("failed to convert `std::time::Duration` to `time::Duration`"),
225        )
226        .build()
227}
228
/// A single flash message as stored (JSON-encoded) in the flash cookie.
///
/// NOTE: serde emits fields in declaration order and uses the renamed keys
/// below, so changing the field order or the renames changes the cookie
/// contents. Presumably the one-letter keys are meant to keep the cookie small.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FlashMessage {
    // Serialized under the key `"l"`.
    #[serde(rename = "l")]
    level: Level,
    // Serialized under the key `"m"`.
    #[serde(rename = "m")]
    message: String,
}
236
237/// Verbosity level of a flash message.
238#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
239pub enum Level {
240    #[allow(missing_docs)]
241    Debug = 0,
242    #[allow(missing_docs)]
243    Info = 1,
244    #[allow(missing_docs)]
245    Success = 2,
246    #[allow(missing_docs)]
247    Warning = 3,
248    #[allow(missing_docs)]
249    Error = 4,
250}
251
252/// Configuration for axum-flash.
253#[derive(Clone)]
254pub struct Config {
255    use_secure_cookies: bool,
256    key: Key,
257}
258
259impl Config {
260    /// Create a new `Config` using the given key.
261    ///
262    /// Cookies will be signed using `key`.
263    pub fn new(key: Key) -> Self {
264        Self {
265            use_secure_cookies: true,
266            key,
267        }
268    }
269
270    /// Mark the cookie as secure so the cookie will only be sent on `https`.
271    ///
272    /// Defaults to marking cookies as secure.
273    ///
274    /// For local development, depending on your brwoser, you might have to set
275    /// this to `false` for flash messages to show up.
276    ///
277    /// See [mdn] for more details on secure cookies.
278    ///
279    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
280    pub fn use_secure_cookies(mut self, use_secure_cookies: bool) -> Self {
281        self.use_secure_cookies = use_secure_cookies;
282        self
283    }
284}
285
286impl fmt::Debug for Config {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        f.debug_struct("Config")
289            .field("use_secure_cookies", &self.use_secure_cookies)
290            .field("key", &"REDACTED")
291            .finish()
292    }
293}
294
295/// Extractor for incoming flash messages.
296///
297/// See [root module docs](crate) for an example.
298#[derive(Clone)]
299pub struct IncomingFlashes {
300    flashes: Vec<FlashMessage>,
301    use_secure_cookies: bool,
302    key: Key,
303}
304
305impl fmt::Debug for IncomingFlashes {
306    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307        f.debug_struct("IncomingFlashes")
308            .field("flashes", &self.flashes)
309            .field("use_secure_cookies", &self.use_secure_cookies)
310            .field("key", &"REDACTED")
311            .finish()
312    }
313}
314
315impl IncomingFlashes {
316    /// Get an iterator over the flash messages.
317    pub fn iter(&self) -> Iter<'_> {
318        Iter(self.flashes.iter())
319    }
320
321    /// Get the number of flash messages.
322    pub fn len(&self) -> usize {
323        self.flashes.len()
324    }
325
326    /// Whether there are any flash messages or not.
327    pub fn is_empty(&self) -> bool {
328        self.flashes.is_empty()
329    }
330}
331
332/// An iterator over the flash messages.
333#[derive(Debug)]
334pub struct Iter<'a>(std::slice::Iter<'a, FlashMessage>);
335
336impl<'a> Iterator for Iter<'a> {
337    type Item = (Level, &'a str);
338
339    fn next(&mut self) -> Option<Self::Item> {
340        let message = self.0.next()?;
341        Some((message.level, &message.message))
342    }
343}
344
345impl<'a> IntoIterator for &'a IncomingFlashes {
346    type Item = (Level, &'a str);
347    type IntoIter = Iter<'a>;
348
349    fn into_iter(self) -> Self::IntoIter {
350        self.iter()
351    }
352}
353
354#[async_trait]
355impl<S> FromRequestParts<S> for IncomingFlashes
356where
357    S: Send + Sync,
358    Config: FromRef<S>,
359{
360    type Rejection = (StatusCode, &'static str);
361
362    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
363        let config = Config::from_ref(state);
364        let cookies = SignedCookieJar::from_headers(&parts.headers, config.key.clone());
365
366        let flashes = cookies
367            .get(COOKIE_NAME)
368            .map(|cookie| cookie.into_owned())
369            .and_then(|cookie| serde_json::from_str::<Vec<FlashMessage>>(cookie.value()).ok())
370            .unwrap_or_default();
371
372        Ok(Self {
373            flashes,
374            use_secure_cookies: config.use_secure_cookies,
375            key: config.key,
376        })
377    }
378}
379
380impl IntoResponseParts for IncomingFlashes {
381    type Error = Infallible;
382
383    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
384        let cookies = SignedCookieJar::from_headers(res.headers(), self.key);
385
386        let mut cookie = create_cookie("".to_owned(), self.use_secure_cookies);
387        cookie.make_removal();
388        let cookies = cookies.add(cookie);
389        cookies.into_response_parts(res)
390    }
391}
392
393impl IntoResponse for IncomingFlashes {
394    fn into_response(self) -> Response {
395        (self, ()).into_response()
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{header, Request},
        response::Redirect,
        routing::get,
        Router,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Sets a flash via `/set-flash`, reads it back via `/`, and checks that the
    /// cookie is cleared afterwards.
    #[tokio::test]
    async fn basic() {
        let config = Config::new(Key::generate()).use_secure_cookies(false);

        let app = Router::new()
            .route("/", get(root))
            .route("/set-flash", get(set_flash))
            .with_state(config);

        async fn root(flashes: IncomingFlashes) -> (IncomingFlashes, String) {
            let messages = flashes
                .iter()
                .map(|(level, text)| format!("{:?}: {}", level, text))
                .collect::<Vec<_>>()
                .join(", ");
            (flashes, messages)
        }

        #[axum::debug_handler(state = Config)]
        async fn set_flash(flash: Flash) -> (Flash, Redirect) {
            (flash.debug("Hi from flash!"), Redirect::to("/"))
        }

        // Setting a flash redirects and emits the signed flash cookie.
        let request = Request::builder()
            .uri("/set-flash")
            .body(Body::empty())
            .unwrap();
        let mut response = app.clone().oneshot(request).await.unwrap();
        assert!(response.status().is_redirection());
        let cookie = response.headers_mut().remove(header::SET_COOKIE).unwrap();

        let set_cookie = cookie.to_str().unwrap();
        assert!(set_cookie.contains("HttpOnly"));
        assert!(set_cookie.contains("SameSite=Strict"));
        assert!(set_cookie.contains("Max-Age=600"));
        // `use_secure_cookies(false)` must leave the `Secure` attribute off.
        assert!(!set_cookie.contains("; Secure"));

        // Reading the flash returns it in the body and removes the cookie.
        let request = Request::builder()
            .uri("/")
            .header(header::COOKIE, cookie)
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(request).await.unwrap();

        assert!(response.headers()[header::SET_COOKIE]
            .to_str()
            .unwrap()
            .contains("Max-Age=0"));

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(bytes.to_vec()).unwrap();
        assert_eq!(body, "Debug: Hi from flash!");
    }
}