1#![warn(
59 clippy::all,
60 clippy::dbg_macro,
61 clippy::todo,
62 clippy::empty_enum,
63 clippy::enum_glob_use,
64 clippy::mem_forget,
65 clippy::unused_self,
66 clippy::filter_map_next,
67 clippy::needless_continue,
68 clippy::needless_borrow,
69 clippy::match_wildcard_for_single_variants,
70 clippy::if_let_mutex,
71 clippy::mismatched_target_os,
72 clippy::await_holding_lock,
73 clippy::match_on_vec_items,
74 clippy::imprecise_flops,
75 clippy::suboptimal_flops,
76 clippy::lossy_float_literal,
77 clippy::rest_pat_in_fully_bound_structs,
78 clippy::fn_params_excessive_bools,
79 clippy::exit,
80 clippy::inefficient_to_string,
81 clippy::linkedlist,
82 clippy::macro_use_imports,
83 clippy::option_option,
84 clippy::verbose_file_reads,
85 clippy::unnested_or_patterns,
86 rust_2018_idioms,
87 future_incompatible,
88 nonstandard_style,
89 missing_debug_implementations,
90 missing_docs
91)]
92#![deny(unreachable_pub)]
93#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
94#![forbid(unsafe_code)]
95#![cfg_attr(docsrs, feature(doc_cfg))]
96#![cfg_attr(test, allow(clippy::float_cmp))]
97
98use async_trait::async_trait;
99use axum_core::{
100 extract::{FromRef, FromRequestParts},
101 response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
102};
103use axum_extra::extract::cookie::{Cookie, SignedCookieJar};
104use http::{request::Parts, StatusCode};
105use serde::{Deserialize, Serialize};
106use std::{borrow::Cow, fmt};
107use std::{
108 convert::{Infallible, TryInto},
109 time::Duration,
110};
111
112pub use axum_extra::extract::cookie::Key;
113
114#[derive(Clone)]
118pub struct Flash {
119 flashes: Vec<FlashMessage>,
120 use_secure_cookies: bool,
121 key: Key,
122}
123
124impl fmt::Debug for Flash {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("Flash")
127 .field("flashes", &self.flashes)
128 .field("use_secure_cookies", &self.use_secure_cookies)
129 .field("key", &"REDACTED")
130 .finish()
131 }
132}
133
134impl Flash {
135 pub fn debug(self, message: impl Into<String>) -> Self {
137 self.push(Level::Debug, message)
138 }
139
140 pub fn info(self, message: impl Into<String>) -> Self {
142 self.push(Level::Info, message)
143 }
144
145 pub fn success(self, message: impl Into<String>) -> Self {
147 self.push(Level::Success, message)
148 }
149
150 pub fn warning(self, message: impl Into<String>) -> Self {
152 self.push(Level::Warning, message)
153 }
154
155 pub fn error(self, message: impl Into<String>) -> Self {
157 self.push(Level::Error, message)
158 }
159
160 pub fn push(mut self, level: Level, message: impl Into<String>) -> Self {
162 self.flashes.push(FlashMessage {
163 message: message.into(),
164 level,
165 });
166 self
167 }
168}
169
170#[async_trait]
171impl<S> FromRequestParts<S> for Flash
172where
173 S: Send + Sync,
174 Config: FromRef<S>,
175{
176 type Rejection = Infallible;
177
178 async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
179 let config = Config::from_ref(state);
180
181 Ok(Self {
182 key: config.key,
183 use_secure_cookies: config.use_secure_cookies,
184 flashes: Default::default(),
185 })
186 }
187}
188
189const COOKIE_NAME: &str = "axum-flash";
190
191impl IntoResponseParts for Flash {
192 type Error = Infallible;
193
194 fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
195 let json =
196 serde_json::to_string(&self.flashes).expect("failed to serialize flash messages");
197
198 let cookies = SignedCookieJar::new(self.key.clone());
199
200 let cookies = cookies.add(create_cookie(json, self.use_secure_cookies));
201 cookies.into_response_parts(res)
202 }
203}
204
205pub(crate) fn create_cookie(
206 value: impl Into<Cow<'static, str>>,
207 use_secure_cookies: bool,
208) -> Cookie<'static> {
209 Cookie::build((COOKIE_NAME, value))
212 .secure(use_secure_cookies)
214 .http_only(true)
216 .same_site(cookie::SameSite::Strict)
218 .path("/")
220 .max_age(
222 Duration::from_secs(10 * 60)
223 .try_into()
224 .expect("failed to convert `std::time::Duration` to `time::Duration`"),
225 )
226 .build()
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize)]
230struct FlashMessage {
231 #[serde(rename = "l")]
232 level: Level,
233 #[serde(rename = "m")]
234 message: String,
235}
236
237#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
239pub enum Level {
240 #[allow(missing_docs)]
241 Debug = 0,
242 #[allow(missing_docs)]
243 Info = 1,
244 #[allow(missing_docs)]
245 Success = 2,
246 #[allow(missing_docs)]
247 Warning = 3,
248 #[allow(missing_docs)]
249 Error = 4,
250}
251
252#[derive(Clone)]
254pub struct Config {
255 use_secure_cookies: bool,
256 key: Key,
257}
258
259impl Config {
260 pub fn new(key: Key) -> Self {
264 Self {
265 use_secure_cookies: true,
266 key,
267 }
268 }
269
270 pub fn use_secure_cookies(mut self, use_secure_cookies: bool) -> Self {
281 self.use_secure_cookies = use_secure_cookies;
282 self
283 }
284}
285
286impl fmt::Debug for Config {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 f.debug_struct("Config")
289 .field("use_secure_cookies", &self.use_secure_cookies)
290 .field("key", &"REDACTED")
291 .finish()
292 }
293}
294
295#[derive(Clone)]
299pub struct IncomingFlashes {
300 flashes: Vec<FlashMessage>,
301 use_secure_cookies: bool,
302 key: Key,
303}
304
305impl fmt::Debug for IncomingFlashes {
306 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307 f.debug_struct("IncomingFlashes")
308 .field("flashes", &self.flashes)
309 .field("use_secure_cookies", &self.use_secure_cookies)
310 .field("key", &"REDACTED")
311 .finish()
312 }
313}
314
315impl IncomingFlashes {
316 pub fn iter(&self) -> Iter<'_> {
318 Iter(self.flashes.iter())
319 }
320
321 pub fn len(&self) -> usize {
323 self.flashes.len()
324 }
325
326 pub fn is_empty(&self) -> bool {
328 self.flashes.is_empty()
329 }
330}
331
332#[derive(Debug)]
334pub struct Iter<'a>(std::slice::Iter<'a, FlashMessage>);
335
336impl<'a> Iterator for Iter<'a> {
337 type Item = (Level, &'a str);
338
339 fn next(&mut self) -> Option<Self::Item> {
340 let message = self.0.next()?;
341 Some((message.level, &message.message))
342 }
343}
344
345impl<'a> IntoIterator for &'a IncomingFlashes {
346 type Item = (Level, &'a str);
347 type IntoIter = Iter<'a>;
348
349 fn into_iter(self) -> Self::IntoIter {
350 self.iter()
351 }
352}
353
354#[async_trait]
355impl<S> FromRequestParts<S> for IncomingFlashes
356where
357 S: Send + Sync,
358 Config: FromRef<S>,
359{
360 type Rejection = (StatusCode, &'static str);
361
362 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
363 let config = Config::from_ref(state);
364 let cookies = SignedCookieJar::from_headers(&parts.headers, config.key.clone());
365
366 let flashes = cookies
367 .get(COOKIE_NAME)
368 .map(|cookie| cookie.into_owned())
369 .and_then(|cookie| serde_json::from_str::<Vec<FlashMessage>>(cookie.value()).ok())
370 .unwrap_or_default();
371
372 Ok(Self {
373 flashes,
374 use_secure_cookies: config.use_secure_cookies,
375 key: config.key,
376 })
377 }
378}
379
380impl IntoResponseParts for IncomingFlashes {
381 type Error = Infallible;
382
383 fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
384 let cookies = SignedCookieJar::from_headers(res.headers(), self.key);
385
386 let mut cookie = create_cookie("".to_owned(), self.use_secure_cookies);
387 cookie.make_removal();
388 let cookies = cookies.add(cookie);
389 cookies.into_response_parts(res)
390 }
391}
392
393impl IntoResponse for IncomingFlashes {
394 fn into_response(self) -> Response {
395 (self, ()).into_response()
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{header, Request},
        response::Redirect,
        routing::get,
        Router,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Test app: `/set-flash` queues a debug message and redirects, while `/`
    /// renders the incoming messages and clears the flash cookie.
    fn app() -> Router {
        async fn root(flash: IncomingFlashes) -> (IncomingFlashes, String) {
            let messages = flash
                .into_iter()
                .map(|(level, text)| format!("{:?}: {}", level, text))
                .collect::<Vec<_>>()
                .join(", ");
            (flash, messages)
        }

        #[axum::debug_handler(state = Config)]
        async fn set_flash(flash: Flash) -> (Flash, Redirect) {
            (flash.debug("Hi from flash!"), Redirect::to("/"))
        }

        // Tests talk plain HTTP, so the cookie must not be `Secure`.
        let config = Config::new(Key::generate()).use_secure_cookies(false);

        Router::new()
            .route("/", get(root))
            .route("/set-flash", get(set_flash))
            .with_state(config)
    }

    #[tokio::test]
    async fn basic() {
        let app = app();

        let request = Request::builder()
            .uri("/set-flash")
            .body(Body::empty())
            .unwrap();
        let mut response = app.clone().oneshot(request).await.unwrap();
        assert!(response.status().is_redirection());
        let cookie = response.headers_mut().remove(header::SET_COOKIE).unwrap();

        {
            let set_cookie = cookie.to_str().unwrap();
            assert!(set_cookie.starts_with("axum-flash="));
            assert!(set_cookie.contains("HttpOnly"));
            assert!(set_cookie.contains("SameSite=Strict"));
            assert!(set_cookie.contains("Path=/"));
            assert!(set_cookie.contains("Max-Age=600"));
        }

        let request = Request::builder()
            .uri("/")
            .header(header::COOKIE, cookie)
            .body(Body::empty())
            .unwrap();
        let response = app.clone().oneshot(request).await.unwrap();

        assert!(response.headers()[header::SET_COOKIE]
            .to_str()
            .unwrap()
            .contains("Max-Age=0"));

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(bytes.to_vec()).unwrap();
        assert_eq!(body, "Debug: Hi from flash!");
    }

    #[tokio::test]
    async fn no_flash_cookie_yields_no_messages() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app().oneshot(request).await.unwrap();

        // The removal cookie is sent even when there was nothing to read.
        assert!(response.headers()[header::SET_COOKIE]
            .to_str()
            .unwrap()
            .contains("Max-Age=0"));

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }

    #[tokio::test]
    async fn unsigned_flash_cookie_is_ignored() {
        let request = Request::builder()
            .uri("/")
            .header(header::COOKIE, "axum-flash=not-signed")
            .body(Body::empty())
            .unwrap();
        let response = app().oneshot(request).await.unwrap();

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }
}