1use std::sync::Arc;
2
3use axum::extract::FromRequestParts;
4use http::request::Parts;
5
6use crate::Error;
7
8use super::state::{FlashEntry, FlashState};
9
10pub struct Flash {
15 state: Arc<FlashState>,
16}
17
18impl Flash {
19 pub fn set(&self, level: &str, message: &str) {
24 self.state.push(level, message);
25 }
26
27 pub fn success(&self, message: &str) {
29 self.set("success", message);
30 }
31
32 pub fn error(&self, message: &str) {
34 self.set("error", message);
35 }
36
37 pub fn warning(&self, message: &str) {
39 self.set("warning", message);
40 }
41
42 pub fn info(&self, message: &str) {
44 self.set("info", message);
45 }
46
47 pub fn messages(&self) -> Vec<FlashEntry> {
52 self.state.mark_read();
53 self.state.incoming.clone()
54 }
55}
56
57impl<S: Send + Sync> FromRequestParts<S> for Flash {
58 type Rejection = Error;
59
60 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61 parts
62 .extensions
63 .get::<Arc<FlashState>>()
64 .cloned()
65 .map(|state| Flash { state })
66 .ok_or_else(|| Error::internal("flash middleware not applied"))
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    /// Builds a `FlashState` seeded with `incoming` and a `Flash` sharing it.
    fn fixture(incoming: Vec<FlashEntry>) -> (Arc<FlashState>, Flash) {
        let state = Arc::new(FlashState::new(incoming));
        let flash = Flash {
            state: Arc::clone(&state),
        };
        (state, flash)
    }

    /// Shorthand for constructing a `FlashEntry` from string literals.
    fn entry(level: &'static str, message: &'static str) -> FlashEntry {
        FlashEntry {
            level: level.into(),
            message: message.into(),
        }
    }

    /// Runs `emit` against a fresh `Flash` and checks the level of the first
    /// queued entry. `#[track_caller]` keeps failure locations in the test.
    #[track_caller]
    fn assert_first_level(emit: impl FnOnce(&Flash), expected: &'static str) {
        let (state, flash) = fixture(Vec::new());
        emit(&flash);
        let outgoing = state.drain_outgoing();
        assert_eq!(outgoing[0].level, expected);
    }

    /// Request parts with no extensions set.
    fn empty_parts() -> Parts {
        let (parts, _body) = http::Request::builder().body(()).unwrap().into_parts();
        parts
    }

    #[test]
    fn set_pushes_to_outgoing() {
        let (state, flash) = fixture(Vec::new());
        flash.set("custom", "hello");

        let outgoing = state.drain_outgoing();
        assert_eq!(outgoing.len(), 1);
        assert_eq!(outgoing[0].level, "custom");
        assert_eq!(outgoing[0].message, "hello");
    }

    #[test]
    fn success_uses_correct_level() {
        assert_first_level(|flash| flash.success("done"), "success");
    }

    #[test]
    fn error_uses_correct_level() {
        assert_first_level(|flash| flash.error("fail"), "error");
    }

    #[test]
    fn warning_uses_correct_level() {
        assert_first_level(|flash| flash.warning("careful"), "warning");
    }

    #[test]
    fn info_uses_correct_level() {
        assert_first_level(|flash| flash.info("fyi"), "info");
    }

    #[test]
    fn multiple_messages_preserved() {
        let (state, flash) = fixture(Vec::new());
        flash.success("one");
        flash.error("two");
        flash.info("three");

        assert_eq!(state.drain_outgoing().len(), 3);
    }

    #[test]
    fn messages_returns_incoming_and_marks_read() {
        let expected = vec![entry("success", "saved"), entry("error", "oops")];
        let (state, flash) = fixture(expected.clone());

        let received = flash.messages();
        assert_eq!(received, expected);
        assert!(state.was_read());
    }

    #[test]
    fn messages_returns_empty_when_no_incoming() {
        let (state, flash) = fixture(Vec::new());

        let received = flash.messages();
        assert!(received.is_empty());
        assert!(state.was_read());
    }

    #[test]
    fn messages_idempotent() {
        let (_state, flash) = fixture(vec![entry("info", "hi")]);

        let first = flash.messages();
        let second = flash.messages();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let state = Arc::new(FlashState::new(Vec::new()));
        let mut parts = empty_parts();
        parts.extensions.insert(Arc::clone(&state));

        let extracted =
            <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(extracted.is_ok());

        // A message queued through the extracted handle must land in the
        // state that was placed in the extensions.
        let flash = extracted.unwrap();
        flash.success("test");
        assert_eq!(state.drain_outgoing().len(), 1);
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let outcome =
            <Flash as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(outcome.is_err());

        // `err().unwrap()` instead of `unwrap_err()`: `Flash` has no `Debug` impl.
        let err = outcome.err().unwrap();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}