1use std::sync::Arc;
2
3use axum::extract::FromRequestParts;
4use http::request::Parts;
5
6use crate::Error;
7
8use super::state::{FlashEntry, FlashState};
9
10pub struct Flash {
20 state: Arc<FlashState>,
21}
22
23impl Flash {
24 pub fn set(&self, level: &str, message: &str) {
29 self.state.push(level, message);
30 }
31
32 pub fn success(&self, message: &str) {
34 self.set("success", message);
35 }
36
37 pub fn error(&self, message: &str) {
39 self.set("error", message);
40 }
41
42 pub fn warning(&self, message: &str) {
44 self.set("warning", message);
45 }
46
47 pub fn info(&self, message: &str) {
49 self.set("info", message);
50 }
51
52 pub fn messages(&self) -> Vec<FlashEntry> {
57 self.state.mark_read();
58 self.state.incoming.clone()
59 }
60}
61
62impl<S: Send + Sync> FromRequestParts<S> for Flash {
63 type Rejection = Error;
64
65 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66 parts
67 .extensions
68 .get::<Arc<FlashState>>()
69 .cloned()
70 .map(|state| Flash { state })
71 .ok_or_else(|| Error::internal("flash middleware not applied"))
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    /// Builds fresh state seeded with `incoming` and a `Flash` handle over it.
    ///
    /// The shared state is returned too, so tests can inspect it afterwards.
    fn flash_with(incoming: Vec<FlashEntry>) -> (Arc<FlashState>, Flash) {
        let state = Arc::new(FlashState::new(incoming));
        let flash = Flash {
            state: Arc::clone(&state),
        };
        (state, flash)
    }

    /// Drains the outgoing queue and asserts it holds exactly one entry with
    /// the given level and message.
    ///
    /// Checking the length first gives a clear failure instead of an index
    /// panic when nothing was pushed.
    fn assert_single_outgoing(state: &Arc<FlashState>, level: &str, message: &str) {
        let outgoing = state.drain_outgoing();
        assert_eq!(outgoing.len(), 1, "expected exactly one outgoing message");
        assert_eq!(outgoing[0].level, level);
        assert_eq!(outgoing[0].message, message);
    }

    #[test]
    fn set_pushes_to_outgoing() {
        let (state, flash) = flash_with(vec![]);
        flash.set("custom", "hello");
        assert_single_outgoing(&state, "custom", "hello");
    }

    #[test]
    fn success_uses_correct_level() {
        let (state, flash) = flash_with(vec![]);
        flash.success("done");
        assert_single_outgoing(&state, "success", "done");
    }

    #[test]
    fn error_uses_correct_level() {
        let (state, flash) = flash_with(vec![]);
        flash.error("fail");
        assert_single_outgoing(&state, "error", "fail");
    }

    #[test]
    fn warning_uses_correct_level() {
        let (state, flash) = flash_with(vec![]);
        flash.warning("careful");
        assert_single_outgoing(&state, "warning", "careful");
    }

    #[test]
    fn info_uses_correct_level() {
        let (state, flash) = flash_with(vec![]);
        flash.info("fyi");
        assert_single_outgoing(&state, "info", "fyi");
    }

    #[test]
    fn multiple_messages_preserved() {
        let (state, flash) = flash_with(vec![]);
        flash.success("one");
        flash.error("two");
        flash.info("three");

        let outgoing = state.drain_outgoing();
        assert_eq!(outgoing.len(), 3);

        // Messages should come back in the order they were queued.
        let expected = [("success", "one"), ("error", "two"), ("info", "three")];
        for (entry, (level, message)) in outgoing.iter().zip(expected) {
            assert_eq!(entry.level, level);
            assert_eq!(entry.message, message);
        }
    }

    #[test]
    fn messages_returns_incoming_and_marks_read() {
        let entries = vec![
            FlashEntry {
                level: "success".into(),
                message: "saved".into(),
            },
            FlashEntry {
                level: "error".into(),
                message: "oops".into(),
            },
        ];
        let (state, flash) = flash_with(entries.clone());

        assert_eq!(flash.messages(), entries);
        assert!(state.was_read());
    }

    #[test]
    fn messages_returns_empty_when_no_incoming() {
        let (state, flash) = flash_with(vec![]);

        assert!(flash.messages().is_empty());
        assert!(state.was_read());
    }

    #[test]
    fn messages_idempotent() {
        let entries = vec![FlashEntry {
            level: "info".into(),
            message: "hi".into(),
        }];
        let (_state, flash) = flash_with(entries.clone());

        // Both calls must return the full incoming list, not just match each other.
        assert_eq!(flash.messages(), entries);
        assert_eq!(flash.messages(), entries);
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let state = Arc::new(FlashState::new(vec![]));
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts.extensions.insert(Arc::clone(&state));

        let flash = <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("extraction should succeed when FlashState is in extensions");

        // The extracted handle must share the state inserted above.
        flash.success("test");
        assert_single_outgoing(&state, "success", "test");
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let (mut parts, _) = http::Request::builder().body(()).unwrap().into_parts();

        let result = <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        let Err(err) = result else {
            panic!("extraction must fail when the flash middleware is not applied");
        };
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}