1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::TestClient;
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
/// Fully buffered copy of an HTTP response, so tests can inspect status,
/// headers and body after the live response has been consumed.
#[derive(Debug, Clone)]
pub struct ResponseSnapshot {
    /// Numeric HTTP status code (for example `200`).
    pub status: u16,
    /// Response headers. The snapshot helpers in this module insert
    /// lowercase names, and `header()` lowercases its lookup key, so keys
    /// are expected to be lowercase.
    pub headers: HashMap<String, String>,
    /// Response body bytes. When produced by the snapshot helpers, any
    /// recognised `Content-Encoding` has already been undone.
    pub body: Vec<u8>,
}
31
32impl ResponseSnapshot {
33 pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35 String::from_utf8(self.body.clone())
36 }
37
38 pub fn json(&self) -> Result<Value, serde_json::Error> {
40 serde_json::from_slice(&self.body)
41 }
42
43 pub fn header(&self, name: &str) -> Option<&str> {
45 self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46 }
47}
48
/// Reasons a response could not be turned into a [`ResponseSnapshot`].
#[derive(Debug)]
pub enum SnapshotError {
    /// A header value could not be converted to text by
    /// `HeaderValue::to_str`. Carries that conversion error's message.
    InvalidHeader(String),
    /// The response body could not be obtained or decoded. The snapshot
    /// helpers use it both for decoding failures and for body-collection
    /// failures.
    Decompression(String),
}
57
58impl std::fmt::Display for SnapshotError {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 SnapshotError::InvalidHeader(msg) => write!(f, "Invalid header: {}", msg),
62 SnapshotError::Decompression(msg) => write!(f, "Failed to decode body: {}", msg),
63 }
64 }
65}
66
67impl std::error::Error for SnapshotError {}
68
69pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
72 let (parts, body) = request.into_parts();
73
74 let mut path = parts.uri.path().to_string();
75 if let Some(query) = parts.uri.query()
76 && !query.is_empty()
77 {
78 path.push('?');
79 path.push_str(query);
80 }
81
82 let mut test_request = server.method(parts.method.clone(), &path);
83
84 for (name, value) in parts.headers.iter() {
85 test_request = test_request.add_header(name.clone(), value.clone());
86 }
87
88 let collected = body
89 .collect()
90 .await
91 .expect("failed to read request body for test dispatch");
92 let bytes = collected.to_bytes();
93 if !bytes.is_empty() {
94 test_request = test_request.bytes(bytes);
95 }
96
97 test_request.await
98}
99
100pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
102 let status = response.status_code().as_u16();
103
104 let mut headers = HashMap::new();
105 for (name, value) in response.headers() {
106 let header_value = value
107 .to_str()
108 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
109 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
110 }
111
112 let body = response.into_bytes();
113 let decoded_body = decode_body(&headers, body.to_vec())?;
114
115 Ok(ResponseSnapshot {
116 status,
117 headers,
118 body: decoded_body,
119 })
120}
121
122pub async fn snapshot_http_response(
124 response: axum::response::Response<Body>,
125) -> Result<ResponseSnapshot, SnapshotError> {
126 let (parts, body) = response.into_parts();
127 let status = parts.status.as_u16();
128
129 let mut headers = HashMap::new();
130 for (name, value) in parts.headers.iter() {
131 let header_value = value
132 .to_str()
133 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
134 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
135 }
136
137 let collected = body
138 .collect()
139 .await
140 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
141 let bytes = collected.to_bytes();
142 let decoded_body = decode_body(&headers, bytes.to_vec())?;
143
144 Ok(ResponseSnapshot {
145 status,
146 headers,
147 body: decoded_body,
148 })
149}
150
151fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
152 let encoding = headers
153 .get("content-encoding")
154 .map(|value| value.trim().to_ascii_lowercase());
155
156 match encoding.as_deref() {
157 Some("gzip" | "x-gzip") => decode_gzip(body),
158 Some("br") => decode_brotli(body),
159 _ => Ok(body),
160 }
161}
162
163fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
164 let mut decoder = GzDecoder::new(Cursor::new(body));
165 let mut decoded_bytes = Vec::new();
166 decoder
167 .read_to_end(&mut decoded_bytes)
168 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
169 Ok(decoded_bytes)
170}
171
172fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
173 let mut decoder = Decompressor::new(Cursor::new(body), 4096);
174 let mut decoded_bytes = Vec::new();
175 decoder
176 .read_to_end(&mut decoded_bytes)
177 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
178 Ok(decoded_bytes)
179}
180
181pub struct WebSocketConnection {
186 inner: TestWebSocket,
187}
188
189impl WebSocketConnection {
190 pub fn new(inner: TestWebSocket) -> Self {
192 Self { inner }
193 }
194
195 pub async fn send_text(&mut self, text: impl std::fmt::Display) {
197 self.inner.send_text(text).await;
198 }
199
200 pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
202 self.inner.send_json(value).await;
203 }
204
205 pub async fn send_message(&mut self, msg: WsMessage) {
207 self.inner.send_message(msg).await;
208 }
209
210 pub async fn receive_text(&mut self) -> String {
212 self.inner.receive_text().await
213 }
214
215 pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
217 self.inner.receive_json().await
218 }
219
220 pub async fn receive_bytes(&mut self) -> bytes::Bytes {
222 self.inner.receive_bytes().await
223 }
224
225 pub async fn receive_message(&mut self) -> WebSocketMessage {
227 let msg = self.inner.receive_message().await;
228 WebSocketMessage::from_ws_message(msg)
229 }
230
231 pub async fn close(self) {
233 self.inner.close().await;
234 }
235}
236
/// Owned, library-independent form of a received WebSocket message.
#[derive(Debug, Clone)]
pub enum WebSocketMessage {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A close message, with the close reason when a close frame was present.
    Close(Option<String>),
    /// A ping with its payload.
    Ping(Vec<u8>),
    /// A pong with its payload.
    Pong(Vec<u8>),
}
251
252impl WebSocketMessage {
253 fn from_ws_message(msg: WsMessage) -> Self {
254 match msg {
255 WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
256 WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
257 WsMessage::Close(frame) => WebSocketMessage::Close(frame.map(|f| f.reason.to_string())),
258 WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
259 WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
260 WsMessage::Frame(_) => WebSocketMessage::Close(None),
261 }
262 }
263
264 pub fn as_text(&self) -> Option<&str> {
266 match self {
267 WebSocketMessage::Text(text) => Some(text),
268 _ => None,
269 }
270 }
271
272 pub fn as_json(&self) -> Result<Value, String> {
274 match self {
275 WebSocketMessage::Text(text) => {
276 serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
277 }
278 _ => Err("Message is not text".to_string()),
279 }
280 }
281
282 pub fn as_binary(&self) -> Option<&[u8]> {
284 match self {
285 WebSocketMessage::Binary(data) => Some(data),
286 _ => None,
287 }
288 }
289
290 pub fn is_close(&self) -> bool {
292 matches!(self, WebSocketMessage::Close(_))
293 }
294}
295
296pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
298 let ws = server.get_websocket(path).await.into_websocket().await;
299 WebSocketConnection::new(ws)
300}
301
302#[derive(Debug)]
306pub struct SseStream {
307 body: String,
308 events: Vec<SseEvent>,
309}
310
311impl SseStream {
312 pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
314 let body = response
315 .text()
316 .map_err(|e| format!("Failed to read response body: {}", e))?;
317
318 let events = Self::parse_events(&body);
319
320 Ok(Self { body, events })
321 }
322
323 fn parse_events(body: &str) -> Vec<SseEvent> {
324 let mut events = Vec::new();
325 let lines: Vec<&str> = body.lines().collect();
326 let mut i = 0;
327
328 while i < lines.len() {
329 if lines[i].starts_with("data:") {
330 let data = lines[i].trim_start_matches("data:").trim().to_string();
331 events.push(SseEvent { data });
332 } else if lines[i].starts_with("data") {
333 let data = lines[i].trim_start_matches("data").trim().to_string();
334 if !data.is_empty() || lines[i].len() == 4 {
335 events.push(SseEvent { data });
336 }
337 }
338 i += 1;
339 }
340
341 events
342 }
343
344 pub fn events(&self) -> &[SseEvent] {
346 &self.events
347 }
348
349 pub fn body(&self) -> &str {
351 &self.body
352 }
353
354 pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
356 self.events
357 .iter()
358 .map(|event| event.as_json())
359 .collect::<Result<Vec<_>, _>>()
360 }
361}
362
/// One Server-Sent Event, reduced to its `data` payload.
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// The event's `data` value as text.
    pub data: String,
}
369
370impl SseEvent {
371 pub fn as_json(&self) -> Result<Value, String> {
373 serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Two blank-line-separated events are parsed in order, and their
    /// payloads are reachable both raw and as JSON.
    #[test]
    fn sse_stream_parses_multiple_events() {
        let headers = HashMap::from([(
            "content-type".to_string(),
            "text/event-stream".to_string(),
        )]);
        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n".to_vec(),
        };

        let stream = SseStream::from_response(&snapshot).expect("stream");
        let parsed = stream.events();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(parsed[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    /// A payload that is not JSON surfaces as an error rather than a panic.
    #[test]
    fn sse_event_reports_invalid_json() {
        let bogus = SseEvent {
            data: String::from("not-json"),
        };
        assert!(bogus.as_json().is_err());
    }
}