Skip to main content

modo/flash/
extractor.rs

1use std::sync::Arc;
2
3use axum::extract::FromRequestParts;
4use http::request::Parts;
5
6use crate::Error;
7
8use super::state::{FlashEntry, FlashState};
9
10/// Axum extractor for reading and writing flash messages within a request.
11///
12/// Requires [`FlashLayer`](crate::flash::FlashLayer) to be applied to the router.
13///
14/// # Errors
15///
16/// Extraction fails with [`Error::internal`](crate::Error::internal)
17/// (`500 Internal Server Error`) if [`FlashLayer`](crate::flash::FlashLayer) has
18/// not been applied to the router.
19pub struct Flash {
20    state: Arc<FlashState>,
21}
22
23impl Flash {
24    /// Queue a flash message with an arbitrary severity level.
25    ///
26    /// The message is stored in a signed cookie on the response and becomes
27    /// available to the next request via [`Flash::messages`].
28    pub fn set(&self, level: &str, message: &str) {
29        self.state.push(level, message);
30    }
31
32    /// Queue a flash message with level `"success"`.
33    pub fn success(&self, message: &str) {
34        self.set("success", message);
35    }
36
37    /// Queue a flash message with level `"error"`.
38    pub fn error(&self, message: &str) {
39        self.set("error", message);
40    }
41
42    /// Queue a flash message with level `"warning"`.
43    pub fn warning(&self, message: &str) {
44        self.set("warning", message);
45    }
46
47    /// Queue a flash message with level `"info"`.
48    pub fn info(&self, message: &str) {
49        self.set("info", message);
50    }
51
52    /// Read incoming flash messages and mark them as consumed.
53    ///
54    /// After calling this, the middleware clears the flash cookie on the response.
55    /// Calling this multiple times within the same request returns the same data.
56    pub fn messages(&self) -> Vec<FlashEntry> {
57        self.state.mark_read();
58        self.state.incoming.clone()
59    }
60}
61
62impl<S: Send + Sync> FromRequestParts<S> for Flash {
63    type Rejection = Error;
64
65    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66        parts
67            .extensions
68            .get::<Arc<FlashState>>()
69            .cloned()
70            .map(|state| Flash { state })
71            .ok_or_else(|| Error::internal("flash middleware not applied"))
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    /// Creates a `Flash` handle plus a second reference to its shared state,
    /// so a test can inspect what the handle queued or consumed.
    fn handle_with(incoming: Vec<FlashEntry>) -> (Flash, Arc<FlashState>) {
        let shared = Arc::new(FlashState::new(incoming));
        let flash = Flash {
            state: Arc::clone(&shared),
        };
        (flash, shared)
    }

    /// Shorthand for building an incoming `FlashEntry`.
    fn entry(level: &'static str, message: &'static str) -> FlashEntry {
        FlashEntry {
            level: level.into(),
            message: message.into(),
        }
    }

    /// Runs `queue` against a fresh handle and asserts that the first queued
    /// message carries the `expected` level.
    fn assert_queued_level(queue: impl FnOnce(&Flash), expected: &str) {
        let (flash, shared) = handle_with(vec![]);
        queue(&flash);
        let outgoing = shared.drain_outgoing();
        assert_eq!(outgoing[0].level, expected);
    }

    /// Request parts of a bare request with no extensions inserted.
    fn empty_parts() -> Parts {
        let (parts, _body) = http::Request::builder().body(()).unwrap().into_parts();
        parts
    }

    #[test]
    fn set_pushes_to_outgoing() {
        let (flash, shared) = handle_with(vec![]);
        flash.set("custom", "hello");
        let outgoing = shared.drain_outgoing();
        assert_eq!(outgoing.len(), 1);
        assert_eq!(outgoing[0].level, "custom");
        assert_eq!(outgoing[0].message, "hello");
    }

    #[test]
    fn success_uses_correct_level() {
        assert_queued_level(|flash| flash.success("done"), "success");
    }

    #[test]
    fn error_uses_correct_level() {
        assert_queued_level(|flash| flash.error("fail"), "error");
    }

    #[test]
    fn warning_uses_correct_level() {
        assert_queued_level(|flash| flash.warning("careful"), "warning");
    }

    #[test]
    fn info_uses_correct_level() {
        assert_queued_level(|flash| flash.info("fyi"), "info");
    }

    #[test]
    fn multiple_messages_preserved() {
        let (flash, shared) = handle_with(vec![]);
        flash.success("one");
        flash.error("two");
        flash.info("three");
        assert_eq!(shared.drain_outgoing().len(), 3);
    }

    #[test]
    fn messages_returns_incoming_and_marks_read() {
        let entries = vec![entry("success", "saved"), entry("error", "oops")];
        let (flash, shared) = handle_with(entries.clone());

        assert_eq!(flash.messages(), entries);
        assert!(shared.was_read());
    }

    #[test]
    fn messages_returns_empty_when_no_incoming() {
        let (flash, shared) = handle_with(vec![]);

        assert!(flash.messages().is_empty());
        assert!(shared.was_read());
    }

    #[test]
    fn messages_idempotent() {
        let (flash, _shared) = handle_with(vec![entry("info", "hi")]);

        let first = flash.messages();
        let second = flash.messages();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let shared = Arc::new(FlashState::new(vec![]));
        let mut parts = empty_parts();
        parts.extensions.insert(Arc::clone(&shared));

        let result = <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let Ok(flash) = result else {
            panic!("extraction should succeed when FlashState is in the extensions");
        };
        flash.success("test");
        assert_eq!(shared.drain_outgoing().len(), 1);
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let result = <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let Err(err) = result else {
            panic!("extraction should fail when FlashState is missing");
        };
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}