1use std::convert::Infallible;
8
9use anvil_core::Application;
10use axum::body::Body;
11use axum::Router;
12use http::{HeaderMap, Method, Request, StatusCode};
13use http_body_util::BodyExt;
14use serde::de::DeserializeOwned;
15use tower::ServiceExt;
16
17pub struct TestClient {
18 router: Router,
19 base_headers: HeaderMap,
20}
21
22impl TestClient {
23 pub async fn new(app: Application) -> Self {
24 Self {
25 router: app.into_router(),
26 base_headers: HeaderMap::new(),
27 }
28 }
29
30 pub fn from_router(router: Router) -> Self {
31 Self {
32 router,
33 base_headers: HeaderMap::new(),
34 }
35 }
36
37 pub fn with_header(mut self, name: &str, value: &str) -> Self {
39 if let (Ok(name), Ok(val)) = (
40 http::HeaderName::try_from(name),
41 http::HeaderValue::try_from(value),
42 ) {
43 self.base_headers.insert(name, val);
44 }
45 self
46 }
47
48 pub fn with_bearer(self, token: &str) -> Self {
50 self.with_header("authorization", &format!("Bearer {token}"))
51 }
52
53 pub fn with_ajax(self) -> Self {
55 self.with_header("x-requested-with", "XMLHttpRequest")
56 }
57
58 pub async fn get(&self, path: &str) -> TestResponse {
59 self.request(Method::GET, path, None, &[]).await
60 }
61
62 pub async fn post(&self, path: &str, body: serde_json::Value) -> TestResponse {
63 self.request(Method::POST, path, Some(body), &[]).await
64 }
65
66 pub async fn put(&self, path: &str, body: serde_json::Value) -> TestResponse {
67 self.request(Method::PUT, path, Some(body), &[]).await
68 }
69
70 pub async fn patch(&self, path: &str, body: serde_json::Value) -> TestResponse {
71 self.request(Method::PATCH, path, Some(body), &[]).await
72 }
73
74 pub async fn delete(&self, path: &str) -> TestResponse {
75 self.request(Method::DELETE, path, None, &[]).await
76 }
77
78 pub async fn post_form(&self, path: &str, form: &[(&str, &str)]) -> TestResponse {
80 let body = serde_urlencoded::to_string(form).unwrap_or_default();
81 let req = Request::builder()
82 .method(Method::POST)
83 .uri(path)
84 .header("content-type", "application/x-www-form-urlencoded")
85 .body(Body::from(body))
86 .unwrap();
87 self.send(req).await
88 }
89
90 async fn request(
91 &self,
92 method: Method,
93 path: &str,
94 body: Option<serde_json::Value>,
95 extra_headers: &[(&str, &str)],
96 ) -> TestResponse {
97 let mut req = Request::builder().method(method).uri(path);
98 let body = match body {
99 Some(v) => {
100 req = req.header("content-type", "application/json");
101 Body::from(serde_json::to_vec(&v).unwrap())
102 }
103 None => Body::empty(),
104 };
105 for (n, v) in extra_headers {
106 req = req.header(*n, *v);
107 }
108 let mut http_req = req.body(body).unwrap();
109 for (name, value) in &self.base_headers {
110 http_req.headers_mut().insert(name.clone(), value.clone());
111 }
112 self.send(http_req).await
113 }
114
115 async fn send(&self, req: Request<Body>) -> TestResponse {
116 let mut req = req;
117 for (name, value) in &self.base_headers {
118 req.headers_mut()
119 .entry(name.clone())
120 .or_insert_with(|| value.clone());
121 }
122 let response = self.router.clone().oneshot(req).await.unwrap();
123
124 let status = response.status();
125 let headers = response.headers().clone();
126 let bytes = response
127 .into_body()
128 .collect()
129 .await
130 .map(|c| c.to_bytes())
131 .unwrap_or_default();
132
133 TestResponse {
134 status,
135 headers,
136 body: bytes.to_vec(),
137 }
138 }
139}
140
141pub struct TestResponse {
142 pub status: StatusCode,
143 pub headers: HeaderMap,
144 pub body: Vec<u8>,
145}
146
147impl TestResponse {
148 pub fn assert_status(&self, expected: u16) -> &Self {
151 assert_eq!(
152 self.status.as_u16(),
153 expected,
154 "expected status {expected}, got {} — body: {}",
155 self.status,
156 self.body_text()
157 );
158 self
159 }
160
161 pub fn assert_ok(&self) -> &Self {
162 assert!(
163 self.status.is_success(),
164 "expected success, got {} — body: {}",
165 self.status,
166 self.body_text()
167 );
168 self
169 }
170
171 pub fn assert_created(&self) -> &Self {
172 self.assert_status(201)
173 }
174 pub fn assert_no_content(&self) -> &Self {
175 self.assert_status(204)
176 }
177 pub fn assert_bad_request(&self) -> &Self {
178 self.assert_status(400)
179 }
180 pub fn assert_unauthorized(&self) -> &Self {
181 self.assert_status(401)
182 }
183 pub fn assert_forbidden(&self) -> &Self {
184 self.assert_status(403)
185 }
186 pub fn assert_not_found(&self) -> &Self {
187 self.assert_status(404)
188 }
189 pub fn assert_unprocessable(&self) -> &Self {
190 self.assert_status(422)
191 }
192 pub fn assert_too_many_requests(&self) -> &Self {
193 self.assert_status(429)
194 }
195 pub fn assert_server_error(&self) -> &Self {
196 assert!(
197 self.status.is_server_error(),
198 "expected 5xx, got {} — body: {}",
199 self.status,
200 self.body_text()
201 );
202 self
203 }
204
205 pub fn assert_redirect(&self) -> &Self {
206 assert!(
207 self.status.is_redirection(),
208 "expected 3xx redirect, got {} — body: {}",
209 self.status,
210 self.body_text()
211 );
212 self
213 }
214
215 pub fn assert_redirect_to(&self, location: &str) -> &Self {
216 self.assert_redirect();
217 let actual = self
218 .headers
219 .get("location")
220 .and_then(|v| v.to_str().ok())
221 .unwrap_or("");
222 assert_eq!(actual, location, "redirect Location mismatch");
223 self
224 }
225
226 pub fn assert_header(&self, name: &str, value: &str) -> &Self {
229 let actual = self
230 .headers
231 .get(name)
232 .and_then(|v| v.to_str().ok())
233 .unwrap_or("");
234 assert_eq!(actual, value, "header `{name}` mismatch");
235 self
236 }
237
238 pub fn assert_header_present(&self, name: &str) -> &Self {
239 assert!(
240 self.headers.contains_key(name),
241 "expected header `{name}` to be present"
242 );
243 self
244 }
245
246 pub fn assert_header_missing(&self, name: &str) -> &Self {
247 assert!(
248 !self.headers.contains_key(name),
249 "expected header `{name}` NOT to be present"
250 );
251 self
252 }
253
254 pub fn header(&self, name: &str) -> Option<String> {
255 self.headers
256 .get(name)
257 .and_then(|v| v.to_str().ok().map(String::from))
258 }
259
260 pub fn body_text(&self) -> String {
263 String::from_utf8_lossy(&self.body).to_string()
264 }
265
266 pub fn json<T: DeserializeOwned>(&self) -> T {
267 serde_json::from_slice(&self.body).expect("response was not valid JSON")
268 }
269
270 pub fn json_value(&self) -> serde_json::Value {
271 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
272 }
273
274 pub fn assert_contains(&self, needle: &str) -> &Self {
275 let body = self.body_text();
276 assert!(
277 body.contains(needle),
278 "expected response body to contain '{needle}', got: {body}"
279 );
280 self
281 }
282 pub fn assert_dont_contain(&self, needle: &str) -> &Self {
283 let body = self.body_text();
284 assert!(
285 !body.contains(needle),
286 "expected response body NOT to contain '{needle}', got: {body}"
287 );
288 self
289 }
290 pub fn assert_see(&self, text: &str) -> &Self {
292 self.assert_contains(text)
293 }
294 pub fn assert_dont_see(&self, text: &str) -> &Self {
295 self.assert_dont_contain(text)
296 }
297
298 pub fn assert_json(&self, expected: serde_json::Value) -> &Self {
302 let actual = self.json_value();
303 assert_eq!(actual, expected, "JSON body mismatch");
304 self
305 }
306
307 pub fn assert_json_path(&self, path: &str, expected: serde_json::Value) -> &Self {
310 let actual = json_dig(&self.json_value(), path);
311 assert_eq!(
312 actual.as_ref(),
313 Some(&expected),
314 "JSON path `{path}` mismatch — full body: {}",
315 self.body_text()
316 );
317 self
318 }
319
320 pub fn assert_json_fragment(&self, subset: serde_json::Value) -> &Self {
323 let actual = self.json_value();
324 assert!(
325 json_contains(&actual, &subset),
326 "JSON body missing fragment {subset} — got {actual}"
327 );
328 self
329 }
330
331 pub fn assert_validation_error(&self, field: &str) -> &Self {
334 let v = self.json_value();
335 let arr = v
336 .get("errors")
337 .and_then(|e| e.get(field))
338 .and_then(|f| f.as_array());
339 assert!(
340 arr.map(|a| !a.is_empty()).unwrap_or(false),
341 "expected validation error on field `{field}` — body: {}",
342 self.body_text()
343 );
344 self
345 }
346}
347
348fn json_contains(actual: &serde_json::Value, expected: &serde_json::Value) -> bool {
351 use serde_json::Value::*;
352 match (actual, expected) {
353 (Object(a), Object(e)) => e
354 .iter()
355 .all(|(k, ev)| a.get(k).is_some_and(|av| json_contains(av, ev))),
356 (Array(a), Array(e)) => e.iter().all(|ev| a.iter().any(|av| json_contains(av, ev))),
357 (a, e) => a == e,
358 }
359}
360
361fn json_dig(v: &serde_json::Value, path: &str) -> Option<serde_json::Value> {
363 let mut current = v;
364 for segment in path.split('.') {
365 current = if let Ok(idx) = segment.parse::<usize>() {
366 current.get(idx)?
367 } else {
368 current.get(segment)?
369 };
370 }
371 Some(current.clone())
372}
373
374fn _force_link() {
376 let _ = std::any::type_name::<Infallible>();
377}