salvo_flash/
lib.rs

//! Flash messages for the Salvo web framework.
//!
//! Handlers queue messages on the outgoing flash through
//! [`FlashDepotExt::outgoing_flash_mut`]. After the downstream handlers have run,
//! [`FlashHandler`] hands them to a [`FlashStore`]. On the next request they are
//! loaded back and exposed as the incoming flash via
//! [`FlashDepotExt::incoming_flash`].
//!
//! Bundled stores: `CookieStore` (feature `cookie-store`) and `SessionStore`
//! (feature `session-store`).
//!
//! Read more: <https://salvo.rs>
#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
#![cfg_attr(docsrs, feature(doc_cfg))]
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::ops::Deref;
10
11use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
12use serde::{Deserialize, Serialize};
13
14#[macro_use]
15mod cfg;
16
17cfg_feature! {
18    #![feature = "cookie-store"]
19
20    mod cookie_store;
21    pub use cookie_store::CookieStore;
22
23    /// Helper function to create a `CookieStore`.
24    #[must_use] pub fn cookie_store() -> CookieStore {
25        CookieStore::new()
26    }
27}
28
29cfg_feature! {
30    #![feature = "session-store"]
31
32    mod session_store;
33    pub use session_store::SessionStore;
34
35    /// Helper function to create a `SessionStore`.
36    #[must_use]
37    pub fn session_store() -> SessionStore {
38        SessionStore::new()
39    }
40}
41
/// `Depot` key holding the flash that was loaded from the store for the
/// current request.
pub const INCOMING_FLASH_KEY: &str = "::salvo::flash::incoming_flash";

/// `Depot` key holding the flash that handlers fill during the current
/// request and that is handed to the store afterwards.
pub const OUTGOING_FLASH_KEY: &str = "::salvo::flash::outgoing_flash";
47
48/// A flash is a list of messages.
49#[derive(Default, Serialize, Deserialize, Clone, Debug)]
50pub struct Flash(pub Vec<FlashMessage>);
51impl Flash {
52    /// Add a new message with level `Debug`.
53    #[inline]
54    pub fn debug(&mut self, message: impl Into<String>) -> &mut Self {
55        self.0.push(FlashMessage::debug(message));
56        self
57    }
58    /// Add a new message with level `Info`.
59    #[inline]
60    pub fn info(&mut self, message: impl Into<String>) -> &mut Self {
61        self.0.push(FlashMessage::info(message));
62        self
63    }
64    /// Add a new message with level `Success`.
65    #[inline]
66    pub fn success(&mut self, message: impl Into<String>) -> &mut Self {
67        self.0.push(FlashMessage::success(message));
68        self
69    }
70    /// Add a new message with level `Warning`.
71    #[inline]
72    pub fn warning(&mut self, message: impl Into<String>) -> &mut Self {
73        self.0.push(FlashMessage::warning(message));
74        self
75    }
76    /// Add a new message with level `Error`.
77    #[inline]
78    pub fn error(&mut self, message: impl Into<String>) -> &mut Self {
79        self.0.push(FlashMessage::error(message));
80        self
81    }
82}
83
84impl Deref for Flash {
85    type Target = Vec<FlashMessage>;
86
87    fn deref(&self) -> &Self::Target {
88        &self.0
89    }
90}
91
92/// A flash message.
93#[derive(Serialize, Deserialize, Clone, Debug)]
94#[non_exhaustive]
95pub struct FlashMessage {
96    /// Flash message level.
97    pub level: FlashLevel,
98    /// Flash message content.
99    pub value: String,
100}
101impl FlashMessage {
102    /// Create a new `FlashMessage` with `FlashLevel::Debug`.
103    #[inline]
104    pub fn debug(message: impl Into<String>) -> Self {
105        Self {
106            level: FlashLevel::Debug,
107            value: message.into(),
108        }
109    }
110    /// Create a new `FlashMessage` with `FlashLevel::Info`.
111    #[inline]
112    pub fn info(message: impl Into<String>) -> Self {
113        Self {
114            level: FlashLevel::Info,
115            value: message.into(),
116        }
117    }
118    /// Create a new `FlashMessage` with `FlashLevel::Success`.
119    #[inline]
120    pub fn success(message: impl Into<String>) -> Self {
121        Self {
122            level: FlashLevel::Success,
123            value: message.into(),
124        }
125    }
126    /// Create a new `FlashMessage` with `FlashLevel::Warning`.
127    #[inline]
128    pub fn warning(message: impl Into<String>) -> Self {
129        Self {
130            level: FlashLevel::Warning,
131            value: message.into(),
132        }
133    }
134    /// create a new `FlashMessage` with `FlashLevel::Error`.
135    #[inline]
136    pub fn error(message: impl Into<String>) -> Self {
137        Self {
138            level: FlashLevel::Error,
139            value: message.into(),
140        }
141    }
142}
143
144/// Verbosity level of a flash message.
145#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
146pub enum FlashLevel {
147    #[allow(missing_docs)]
148    Debug = 0,
149    #[allow(missing_docs)]
150    Info = 1,
151    #[allow(missing_docs)]
152    Success = 2,
153    #[allow(missing_docs)]
154    Warning = 3,
155    #[allow(missing_docs)]
156    Error = 4,
157}
158impl FlashLevel {
159    /// Convert a `FlashLevel` to a `&str`.
160    #[must_use]
161    pub fn to_str(&self) -> &'static str {
162        match self {
163            Self::Debug => "debug",
164            Self::Info => "info",
165            Self::Success => "success",
166            Self::Warning => "warning",
167            Self::Error => "error",
168        }
169    }
170}
171impl Debug for FlashLevel {
172    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
173        write!(f, "{}", self.to_str())
174    }
175}
176
177impl Display for FlashLevel {
178    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
179        write!(f, "{}", self.to_str())
180    }
181}
182
183/// `FlashStore` is for stores flash messages.
184pub trait FlashStore: Debug + Send + Sync + 'static {
185    /// Get the flash messages from the store.
186    fn load_flash(
187        &self,
188        req: &mut Request,
189        depot: &mut Depot,
190    ) -> impl Future<Output = Option<Flash>> + Send;
191    /// Save the flash messages to the store.
192    fn save_flash(
193        &self,
194        req: &mut Request,
195        depot: &mut Depot,
196        res: &mut Response,
197        flash: Flash,
198    ) -> impl Future<Output = ()> + Send;
199    /// Clear the flash store.
200    fn clear_flash(&self, depot: &mut Depot, res: &mut Response)
201    -> impl Future<Output = ()> + Send;
202}
203
204/// A trait for `Depot` to get flash messages.
205pub trait FlashDepotExt {
206    /// Get incoming flash.
207    fn incoming_flash(&mut self) -> Option<&Flash>;
208    /// Get outgoing flash.
209    fn outgoing_flash(&self) -> &Flash;
210    /// Get mutable outgoing flash.
211    fn outgoing_flash_mut(&mut self) -> &mut Flash;
212}
213
214impl FlashDepotExt for Depot {
215    #[inline]
216    fn incoming_flash(&mut self) -> Option<&Flash> {
217        self.get::<Flash>(INCOMING_FLASH_KEY).ok()
218    }
219
220    #[inline]
221    fn outgoing_flash(&self) -> &Flash {
222        self.get::<Flash>(OUTGOING_FLASH_KEY)
223            .expect("Flash should be initialized")
224    }
225
226    #[inline]
227    fn outgoing_flash_mut(&mut self) -> &mut Flash {
228        self.get_mut::<Flash>(OUTGOING_FLASH_KEY)
229            .expect("Flash should be initialized")
230    }
231}
232
233/// `FlashHandler` is a middleware for flash messages.
234#[non_exhaustive]
235pub struct FlashHandler<S> {
236    store: S,
237    /// Minimum level of messages to be displayed.
238    pub minimum_level: Option<FlashLevel>,
239}
240impl<S> FlashHandler<S> {
241    /// Create a new `FlashHandler` with the given `FlashStore`.
242    #[inline]
243    pub fn new(store: S) -> Self {
244        Self {
245            store,
246            minimum_level: None,
247        }
248    }
249
250    /// Sets the minimum level of messages to be displayed.
251    #[inline]
252    pub fn minimum_level(&mut self, level: impl Into<Option<FlashLevel>>) -> &mut Self {
253        self.minimum_level = level.into();
254        self
255    }
256}
257impl<S: FlashStore> fmt::Debug for FlashHandler<S> {
258    #[inline]
259    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
260        f.debug_struct("FlashHandler")
261            .field("store", &self.store)
262            .finish()
263    }
264}
265#[async_trait]
266impl<S> Handler for FlashHandler<S>
267where
268    S: FlashStore,
269{
270    async fn handle(
271        &self,
272        req: &mut Request,
273        depot: &mut Depot,
274        res: &mut Response,
275        ctrl: &mut FlowCtrl,
276    ) {
277        let mut has_incoming = false;
278        if let Some(flash) = self.store.load_flash(req, depot).await {
279            has_incoming = !flash.is_empty();
280            depot.insert(INCOMING_FLASH_KEY, flash);
281        }
282        depot.insert(OUTGOING_FLASH_KEY, Flash(vec![]));
283
284        ctrl.call_next(req, depot, res).await;
285        if ctrl.is_ceased() {
286            return;
287        }
288
289        let mut flash = depot
290            .remove::<Flash>(OUTGOING_FLASH_KEY)
291            .unwrap_or_default();
292        if let Some(min_level) = self.minimum_level {
293            flash.0.retain(|msg| msg.level >= min_level);
294        }
295        if !flash.is_empty() {
296            self.store.save_flash(req, depot, res, flash).await;
297        } else if has_incoming {
298            self.store.clear_flash(depot, res).await;
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use salvo_core::http::header::{COOKIE, SET_COOKIE};
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Queues an `info` and a `debug` message, then redirects to `/get`.
    #[handler]
    pub async fn set_flash(depot: &mut Depot, res: &mut Response) {
        let flash = depot.outgoing_flash_mut();
        flash.info("Hey there!").debug("How is it going?");
        res.render(Redirect::other("/get"));
    }

    /// Renders each incoming message as `"<value> - <level>"`, one per line.
    #[handler]
    pub async fn get_flash(depot: &mut Depot, _res: &mut Response) -> String {
        let mut body = String::new();
        if let Some(flash) = depot.incoming_flash() {
            for message in flash.iter() {
                writeln!(body, "{} - {}", message.value, message.level).unwrap();
            }
        }
        body
    }

    #[test]
    fn test_flash_builder_appends_in_order() {
        let mut flash = Flash::default();
        flash.info("first").error("second");
        assert_eq!(flash.len(), 2);
        assert_eq!(flash[0].level, FlashLevel::Info);
        assert_eq!(flash[0].value, "first");
        assert_eq!(flash[1].level, FlashLevel::Error);
        assert_eq!(flash[1].value, "second");
    }

    #[test]
    fn test_flash_level_ordering_and_formatting() {
        assert!(FlashLevel::Debug < FlashLevel::Info);
        assert!(FlashLevel::Info < FlashLevel::Success);
        assert!(FlashLevel::Success < FlashLevel::Warning);
        assert!(FlashLevel::Warning < FlashLevel::Error);
        assert_eq!(FlashLevel::Warning.to_string(), "warning");
        assert_eq!(format!("{:?}", FlashLevel::Success), "success");
    }

    #[cfg(feature = "cookie-store")]
    #[tokio::test]
    async fn test_cookie_store() {
        let cookie_name = "my-custom-cookie-name".to_owned();
        let router = Router::new()
            .hoop(CookieStore::new().name(&cookie_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        // Pin both the message text and its rendered level.
        assert!(
            response
                .take_string()
                .await
                .unwrap()
                .contains("Hey there! - info")
        );

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }

    #[cfg(feature = "session-store")]
    #[tokio::test]
    async fn test_session_store() {
        let session_handler = salvo_session::SessionHandler::builder(
            salvo_session::MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();

        let session_name = "my-custom-session-name".to_owned();
        let router = Router::new()
            .hoop(session_handler)
            .hoop(SessionStore::new().name(&session_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        // Pin both the message text and its rendered level.
        assert!(
            response
                .take_string()
                .await
                .unwrap()
                .contains("Hey there! - info")
        );

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }
}