1use std::convert::Infallible;
8
9use anvil_core::Application;
10use axum::body::{Body, Bytes};
11use axum::Router;
12use http::{HeaderMap, Method, Request, StatusCode};
13use http_body_util::BodyExt;
14use serde::de::DeserializeOwned;
15use tower::ServiceExt;
16
17pub struct TestClient {
18 router: Router,
19 base_headers: HeaderMap,
20}
21
22impl TestClient {
23 pub async fn new(app: Application) -> Self {
24 Self {
25 router: app.into_router(),
26 base_headers: HeaderMap::new(),
27 }
28 }
29
30 pub fn from_router(router: Router) -> Self {
31 Self {
32 router,
33 base_headers: HeaderMap::new(),
34 }
35 }
36
37 pub fn with_header(mut self, name: &str, value: &str) -> Self {
39 if let (Ok(name), Ok(val)) = (
40 http::HeaderName::try_from(name),
41 http::HeaderValue::try_from(value),
42 ) {
43 self.base_headers.insert(name, val);
44 }
45 self
46 }
47
48 pub fn with_bearer(self, token: &str) -> Self {
50 self.with_header("authorization", &format!("Bearer {token}"))
51 }
52
53 pub fn with_ajax(self) -> Self {
55 self.with_header("x-requested-with", "XMLHttpRequest")
56 }
57
58 pub async fn get(&self, path: &str) -> TestResponse {
59 self.request(Method::GET, path, None, &[]).await
60 }
61
62 pub async fn post(&self, path: &str, body: serde_json::Value) -> TestResponse {
63 self.request(Method::POST, path, Some(body), &[]).await
64 }
65
66 pub async fn put(&self, path: &str, body: serde_json::Value) -> TestResponse {
67 self.request(Method::PUT, path, Some(body), &[]).await
68 }
69
70 pub async fn patch(&self, path: &str, body: serde_json::Value) -> TestResponse {
71 self.request(Method::PATCH, path, Some(body), &[]).await
72 }
73
74 pub async fn delete(&self, path: &str) -> TestResponse {
75 self.request(Method::DELETE, path, None, &[]).await
76 }
77
78 pub async fn post_form(&self, path: &str, form: &[(&str, &str)]) -> TestResponse {
80 let body = serde_urlencoded::to_string(form).unwrap_or_default();
81 let req = Request::builder()
82 .method(Method::POST)
83 .uri(path)
84 .header("content-type", "application/x-www-form-urlencoded")
85 .body(Body::from(body))
86 .unwrap();
87 self.send(req).await
88 }
89
90 pub async fn post_bytes(
94 &self,
95 path: &str,
96 body: impl Into<Bytes>,
97 content_type: &str,
98 ) -> TestResponse {
99 self.bytes_request(Method::POST, path, body.into(), content_type)
100 .await
101 }
102
103 pub async fn put_bytes(
105 &self,
106 path: &str,
107 body: impl Into<Bytes>,
108 content_type: &str,
109 ) -> TestResponse {
110 self.bytes_request(Method::PUT, path, body.into(), content_type)
111 .await
112 }
113
114 pub async fn patch_bytes(
116 &self,
117 path: &str,
118 body: impl Into<Bytes>,
119 content_type: &str,
120 ) -> TestResponse {
121 self.bytes_request(Method::PATCH, path, body.into(), content_type)
122 .await
123 }
124
125 async fn bytes_request(
126 &self,
127 method: Method,
128 path: &str,
129 body: Bytes,
130 content_type: &str,
131 ) -> TestResponse {
132 let req = Request::builder()
133 .method(method)
134 .uri(path)
135 .header("content-type", content_type)
136 .body(Body::from(body))
137 .unwrap();
138 self.send(req).await
139 }
140
141 async fn request(
142 &self,
143 method: Method,
144 path: &str,
145 body: Option<serde_json::Value>,
146 extra_headers: &[(&str, &str)],
147 ) -> TestResponse {
148 let mut req = Request::builder().method(method).uri(path);
149 let body = match body {
150 Some(v) => {
151 req = req.header("content-type", "application/json");
152 Body::from(serde_json::to_vec(&v).unwrap())
153 }
154 None => Body::empty(),
155 };
156 for (n, v) in extra_headers {
157 req = req.header(*n, *v);
158 }
159 let mut http_req = req.body(body).unwrap();
160 for (name, value) in &self.base_headers {
161 http_req.headers_mut().insert(name.clone(), value.clone());
162 }
163 self.send(http_req).await
164 }
165
166 async fn send(&self, req: Request<Body>) -> TestResponse {
167 let mut req = req;
168 for (name, value) in &self.base_headers {
169 req.headers_mut()
170 .entry(name.clone())
171 .or_insert_with(|| value.clone());
172 }
173 let response = self.router.clone().oneshot(req).await.unwrap();
174
175 let status = response.status();
176 let headers = response.headers().clone();
177 let bytes = response
178 .into_body()
179 .collect()
180 .await
181 .map(|c| c.to_bytes())
182 .unwrap_or_default();
183
184 TestResponse {
185 status,
186 headers,
187 body: bytes.to_vec(),
188 }
189 }
190}
191
192pub struct TestResponse {
193 pub status: StatusCode,
194 pub headers: HeaderMap,
195 pub body: Vec<u8>,
196}
197
198impl TestResponse {
199 pub fn assert_status(&self, expected: u16) -> &Self {
202 assert_eq!(
203 self.status.as_u16(),
204 expected,
205 "expected status {expected}, got {} — body: {}",
206 self.status,
207 self.body_text()
208 );
209 self
210 }
211
212 pub fn assert_ok(&self) -> &Self {
213 assert!(
214 self.status.is_success(),
215 "expected success, got {} — body: {}",
216 self.status,
217 self.body_text()
218 );
219 self
220 }
221
222 pub fn assert_created(&self) -> &Self {
223 self.assert_status(201)
224 }
225 pub fn assert_no_content(&self) -> &Self {
226 self.assert_status(204)
227 }
228 pub fn assert_bad_request(&self) -> &Self {
229 self.assert_status(400)
230 }
231 pub fn assert_unauthorized(&self) -> &Self {
232 self.assert_status(401)
233 }
234 pub fn assert_forbidden(&self) -> &Self {
235 self.assert_status(403)
236 }
237 pub fn assert_not_found(&self) -> &Self {
238 self.assert_status(404)
239 }
240 pub fn assert_unprocessable(&self) -> &Self {
241 self.assert_status(422)
242 }
243 pub fn assert_too_many_requests(&self) -> &Self {
244 self.assert_status(429)
245 }
246 pub fn assert_server_error(&self) -> &Self {
247 assert!(
248 self.status.is_server_error(),
249 "expected 5xx, got {} — body: {}",
250 self.status,
251 self.body_text()
252 );
253 self
254 }
255
256 pub fn assert_redirect(&self) -> &Self {
257 assert!(
258 self.status.is_redirection(),
259 "expected 3xx redirect, got {} — body: {}",
260 self.status,
261 self.body_text()
262 );
263 self
264 }
265
266 pub fn assert_redirect_to(&self, location: &str) -> &Self {
267 self.assert_redirect();
268 let actual = self
269 .headers
270 .get("location")
271 .and_then(|v| v.to_str().ok())
272 .unwrap_or("");
273 assert_eq!(actual, location, "redirect Location mismatch");
274 self
275 }
276
277 pub fn assert_header(&self, name: &str, value: &str) -> &Self {
280 let actual = self
281 .headers
282 .get(name)
283 .and_then(|v| v.to_str().ok())
284 .unwrap_or("");
285 assert_eq!(actual, value, "header `{name}` mismatch");
286 self
287 }
288
289 pub fn assert_header_present(&self, name: &str) -> &Self {
290 assert!(
291 self.headers.contains_key(name),
292 "expected header `{name}` to be present"
293 );
294 self
295 }
296
297 pub fn assert_header_missing(&self, name: &str) -> &Self {
298 assert!(
299 !self.headers.contains_key(name),
300 "expected header `{name}` NOT to be present"
301 );
302 self
303 }
304
305 pub fn header(&self, name: &str) -> Option<String> {
306 self.headers
307 .get(name)
308 .and_then(|v| v.to_str().ok().map(String::from))
309 }
310
311 pub fn body_text(&self) -> String {
314 String::from_utf8_lossy(&self.body).to_string()
315 }
316
317 pub fn json<T: DeserializeOwned>(&self) -> T {
318 serde_json::from_slice(&self.body).expect("response was not valid JSON")
319 }
320
321 pub fn json_value(&self) -> serde_json::Value {
322 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
323 }
324
325 pub fn assert_contains(&self, needle: &str) -> &Self {
326 let body = self.body_text();
327 assert!(
328 body.contains(needle),
329 "expected response body to contain '{needle}', got: {body}"
330 );
331 self
332 }
333 pub fn assert_dont_contain(&self, needle: &str) -> &Self {
334 let body = self.body_text();
335 assert!(
336 !body.contains(needle),
337 "expected response body NOT to contain '{needle}', got: {body}"
338 );
339 self
340 }
341 pub fn assert_see(&self, text: &str) -> &Self {
343 self.assert_contains(text)
344 }
345 pub fn assert_dont_see(&self, text: &str) -> &Self {
346 self.assert_dont_contain(text)
347 }
348
349 pub fn assert_json(&self, expected: serde_json::Value) -> &Self {
353 let actual = self.json_value();
354 assert_eq!(actual, expected, "JSON body mismatch");
355 self
356 }
357
358 pub fn assert_json_path(&self, path: &str, expected: serde_json::Value) -> &Self {
361 let actual = json_dig(&self.json_value(), path);
362 assert_eq!(
363 actual.as_ref(),
364 Some(&expected),
365 "JSON path `{path}` mismatch — full body: {}",
366 self.body_text()
367 );
368 self
369 }
370
371 pub fn assert_json_fragment(&self, subset: serde_json::Value) -> &Self {
374 let actual = self.json_value();
375 assert!(
376 json_contains(&actual, &subset),
377 "JSON body missing fragment {subset} — got {actual}"
378 );
379 self
380 }
381
382 pub fn assert_validation_error(&self, field: &str) -> &Self {
385 let v = self.json_value();
386 let arr = v
387 .get("errors")
388 .and_then(|e| e.get(field))
389 .and_then(|f| f.as_array());
390 assert!(
391 arr.map(|a| !a.is_empty()).unwrap_or(false),
392 "expected validation error on field `{field}` — body: {}",
393 self.body_text()
394 );
395 self
396 }
397}
398
399fn json_contains(actual: &serde_json::Value, expected: &serde_json::Value) -> bool {
402 use serde_json::Value::*;
403 match (actual, expected) {
404 (Object(a), Object(e)) => e
405 .iter()
406 .all(|(k, ev)| a.get(k).is_some_and(|av| json_contains(av, ev))),
407 (Array(a), Array(e)) => e.iter().all(|ev| a.iter().any(|av| json_contains(av, ev))),
408 (a, e) => a == e,
409 }
410}
411
412fn json_dig(v: &serde_json::Value, path: &str) -> Option<serde_json::Value> {
414 let mut current = v;
415 for segment in path.split('.') {
416 current = if let Ok(idx) = segment.parse::<usize>() {
417 current.get(idx)?
418 } else {
419 current.get(segment)?
420 };
421 }
422 Some(current.clone())
423}
424
/// Never called. Its only effect is to reference `Infallible`, presumably so the
/// `std::convert::Infallible` import does not trigger an unused-import warning.
fn _force_link() {
    let _: &'static str = std::any::type_name::<Infallible>();
}
429
#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::post;

    /// Handler that answers with exactly the bytes it received.
    async fn echo(body: Bytes) -> Bytes {
        body
    }

    #[tokio::test]
    async fn post_bytes_round_trips_arbitrary_bytes() {
        let client = TestClient::from_router(Router::new().route("/echo", post(echo)));

        // CBOR encoding of `{"ok": true}`. The leading 0xA1 also makes it invalid UTF-8, so the
        // round trip cannot be relying on text handling.
        let payload: Vec<u8> = vec![0xA1, 0x62, 0x6F, 0x6B, 0xF5];
        let response = client
            .post_bytes("/echo", payload.clone(), "application/cbor")
            .await;

        response.assert_ok();
        assert_eq!(response.body, payload);
    }

    #[tokio::test]
    async fn post_bytes_sets_content_type_header_for_handler_dispatch() {
        // Handler that echoes the request's content-type, or "missing" if absent/unreadable.
        async fn content_type_of(headers: http::HeaderMap) -> String {
            match headers.get("content-type").and_then(|v| v.to_str().ok()) {
                Some(value) => value.to_string(),
                None => "missing".to_string(),
            }
        }

        let client = TestClient::from_router(Router::new().route("/ct", post(content_type_of)));
        let response = client
            .post_bytes("/ct", b"x".to_vec(), "application/x-protobuf")
            .await;

        response.assert_ok();
        assert_eq!(response.body_text(), "application/x-protobuf");
    }

    #[tokio::test]
    async fn put_and_patch_bytes_dispatch_correctly() {
        // Handler that answers with the name of the request method.
        async fn method_name(method: Method) -> String {
            String::from(method.as_str())
        }

        let router = Router::new().route("/m", axum::routing::put(method_name).patch(method_name));
        let client = TestClient::from_router(router);

        let put_response = client
            .put_bytes("/m", b"_".to_vec(), "application/octet-stream")
            .await;
        put_response.assert_ok();
        assert_eq!(put_response.body_text(), "PUT");

        let patch_response = client
            .patch_bytes("/m", b"_".to_vec(), "application/octet-stream")
            .await;
        patch_response.assert_ok();
        assert_eq!(patch_response.body_text(), "PATCH");
    }
}