Skip to main content

modo/flash/
extractor.rs

1use std::sync::Arc;
2
3use axum::extract::FromRequestParts;
4use http::request::Parts;
5
6use crate::Error;
7
8use super::state::{FlashEntry, FlashState};
9
10/// Axum extractor for reading and writing flash messages within a request.
11///
12/// Requires [`FlashLayer`](crate::flash::FlashLayer) to be applied to the router.
13/// Extraction fails with `500 Internal Server Error` if the middleware is absent.
14pub struct Flash {
15    state: Arc<FlashState>,
16}
17
18impl Flash {
19    /// Queue a flash message with an arbitrary severity level.
20    ///
21    /// The message is stored in a signed cookie on the response and becomes
22    /// available to the next request via [`Flash::messages`].
23    pub fn set(&self, level: &str, message: &str) {
24        self.state.push(level, message);
25    }
26
27    /// Queue a flash message with level `"success"`.
28    pub fn success(&self, message: &str) {
29        self.set("success", message);
30    }
31
32    /// Queue a flash message with level `"error"`.
33    pub fn error(&self, message: &str) {
34        self.set("error", message);
35    }
36
37    /// Queue a flash message with level `"warning"`.
38    pub fn warning(&self, message: &str) {
39        self.set("warning", message);
40    }
41
42    /// Queue a flash message with level `"info"`.
43    pub fn info(&self, message: &str) {
44        self.set("info", message);
45    }
46
47    /// Read incoming flash messages and mark them as consumed.
48    ///
49    /// After calling this, the middleware clears the flash cookie on the response.
50    /// Calling this multiple times within the same request returns the same data.
51    pub fn messages(&self) -> Vec<FlashEntry> {
52        self.state.mark_read();
53        self.state.incoming.clone()
54    }
55}
56
57impl<S: Send + Sync> FromRequestParts<S> for Flash {
58    type Rejection = Error;
59
60    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61        parts
62            .extensions
63            .get::<Arc<FlashState>>()
64            .cloned()
65            .map(|state| Flash { state })
66            .ok_or_else(|| Error::internal("flash middleware not applied"))
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    /// Creates a state seeded with `incoming` plus a `Flash` sharing it, so a
    /// test can drive the extractor and then inspect the state directly.
    fn fixture(incoming: Vec<FlashEntry>) -> (Arc<FlashState>, Flash) {
        let state = Arc::new(FlashState::new(incoming));
        let flash = Flash {
            state: Arc::clone(&state),
        };
        (state, flash)
    }

    #[test]
    fn set_pushes_to_outgoing() {
        let (state, flash) = fixture(Vec::new());
        flash.set("custom", "hello");

        let queued = state.drain_outgoing();
        assert_eq!(queued.len(), 1);
        assert_eq!(queued[0].level, "custom");
        assert_eq!(queued[0].message, "hello");
    }

    #[test]
    fn success_uses_correct_level() {
        let (state, flash) = fixture(Vec::new());
        flash.success("done");

        let queued = state.drain_outgoing();
        assert_eq!(queued[0].level, "success");
    }

    #[test]
    fn error_uses_correct_level() {
        let (state, flash) = fixture(Vec::new());
        flash.error("fail");

        let queued = state.drain_outgoing();
        assert_eq!(queued[0].level, "error");
    }

    #[test]
    fn warning_uses_correct_level() {
        let (state, flash) = fixture(Vec::new());
        flash.warning("careful");

        let queued = state.drain_outgoing();
        assert_eq!(queued[0].level, "warning");
    }

    #[test]
    fn info_uses_correct_level() {
        let (state, flash) = fixture(Vec::new());
        flash.info("fyi");

        let queued = state.drain_outgoing();
        assert_eq!(queued[0].level, "info");
    }

    #[test]
    fn multiple_messages_preserved() {
        let (state, flash) = fixture(Vec::new());
        flash.success("one");
        flash.error("two");
        flash.info("three");

        let queued = state.drain_outgoing();
        assert_eq!(queued.len(), 3);
    }

    #[test]
    fn messages_returns_incoming_and_marks_read() {
        let expected = vec![
            FlashEntry {
                level: "success".into(),
                message: "saved".into(),
            },
            FlashEntry {
                level: "error".into(),
                message: "oops".into(),
            },
        ];
        let (state, flash) = fixture(expected.clone());

        let received = flash.messages();
        assert_eq!(received, expected);
        assert!(state.was_read());
    }

    #[test]
    fn messages_returns_empty_when_no_incoming() {
        let (state, flash) = fixture(Vec::new());

        let received = flash.messages();
        assert!(received.is_empty());
        assert!(state.was_read());
    }

    #[test]
    fn messages_idempotent() {
        let seeded = vec![FlashEntry {
            level: "info".into(),
            message: "hi".into(),
        }];
        let (_state, flash) = fixture(seeded);

        let first = flash.messages();
        let second = flash.messages();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let state = Arc::new(FlashState::new(Vec::new()));
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(Arc::clone(&state));

        let extracted = <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(extracted.is_ok());

        // A message queued through the extracted handle must land in the state
        // that was placed in the extensions.
        let flash = extracted.unwrap();
        flash.success("test");
        assert_eq!(state.drain_outgoing().len(), 1);
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let extracted = <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(extracted.is_err());

        // `.err().unwrap()` is kept instead of `unwrap_err()`, which would
        // require `Flash: Debug`.
        let rejection = extracted.err().unwrap();
        assert_eq!(rejection.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}