1use std::convert::Infallible;
8
9use anvil_core::Application;
10use axum::body::{Body, Bytes};
11use axum::Router;
12use http::{HeaderMap, Method, Request, StatusCode};
13use http_body_util::BodyExt;
14use serde::de::DeserializeOwned;
15use tower::ServiceExt;
16
17pub struct TestClient {
18 router: Router,
19 base_headers: HeaderMap,
20}
21
22impl TestClient {
23 pub async fn new(app: Application) -> Self {
24 let _ = anvil_core::config::load_dotenv();
30 Self {
31 router: app.into_router(),
32 base_headers: HeaderMap::new(),
33 }
34 }
35
36 pub fn from_router(router: Router) -> Self {
37 let _ = anvil_core::config::load_dotenv();
38 Self {
39 router,
40 base_headers: HeaderMap::new(),
41 }
42 }
43
44 pub fn with_header(mut self, name: &str, value: &str) -> Self {
46 if let (Ok(name), Ok(val)) = (
47 http::HeaderName::try_from(name),
48 http::HeaderValue::try_from(value),
49 ) {
50 self.base_headers.insert(name, val);
51 }
52 self
53 }
54
55 pub fn with_bearer(self, token: &str) -> Self {
57 self.with_header("authorization", &format!("Bearer {token}"))
58 }
59
60 pub fn with_ajax(self) -> Self {
62 self.with_header("x-requested-with", "XMLHttpRequest")
63 }
64
65 pub async fn get(&self, path: &str) -> TestResponse {
66 self.request(Method::GET, path, None, &[]).await
67 }
68
69 pub async fn post(&self, path: &str, body: serde_json::Value) -> TestResponse {
70 self.request(Method::POST, path, Some(body), &[]).await
71 }
72
73 pub async fn put(&self, path: &str, body: serde_json::Value) -> TestResponse {
74 self.request(Method::PUT, path, Some(body), &[]).await
75 }
76
77 pub async fn patch(&self, path: &str, body: serde_json::Value) -> TestResponse {
78 self.request(Method::PATCH, path, Some(body), &[]).await
79 }
80
81 pub async fn delete(&self, path: &str) -> TestResponse {
82 self.request(Method::DELETE, path, None, &[]).await
83 }
84
85 pub async fn post_form(&self, path: &str, form: &[(&str, &str)]) -> TestResponse {
87 let body = serde_urlencoded::to_string(form).unwrap_or_default();
88 let req = Request::builder()
89 .method(Method::POST)
90 .uri(path)
91 .header("content-type", "application/x-www-form-urlencoded")
92 .body(Body::from(body))
93 .unwrap();
94 self.send(req).await
95 }
96
97 pub async fn post_bytes(
101 &self,
102 path: &str,
103 body: impl Into<Bytes>,
104 content_type: &str,
105 ) -> TestResponse {
106 self.bytes_request(Method::POST, path, body.into(), content_type)
107 .await
108 }
109
110 pub async fn put_bytes(
112 &self,
113 path: &str,
114 body: impl Into<Bytes>,
115 content_type: &str,
116 ) -> TestResponse {
117 self.bytes_request(Method::PUT, path, body.into(), content_type)
118 .await
119 }
120
121 pub async fn patch_bytes(
123 &self,
124 path: &str,
125 body: impl Into<Bytes>,
126 content_type: &str,
127 ) -> TestResponse {
128 self.bytes_request(Method::PATCH, path, body.into(), content_type)
129 .await
130 }
131
132 async fn bytes_request(
133 &self,
134 method: Method,
135 path: &str,
136 body: Bytes,
137 content_type: &str,
138 ) -> TestResponse {
139 let req = Request::builder()
140 .method(method)
141 .uri(path)
142 .header("content-type", content_type)
143 .body(Body::from(body))
144 .unwrap();
145 self.send(req).await
146 }
147
148 async fn request(
149 &self,
150 method: Method,
151 path: &str,
152 body: Option<serde_json::Value>,
153 extra_headers: &[(&str, &str)],
154 ) -> TestResponse {
155 let mut req = Request::builder().method(method).uri(path);
156 let body = match body {
157 Some(v) => {
158 req = req.header("content-type", "application/json");
159 Body::from(serde_json::to_vec(&v).unwrap())
160 }
161 None => Body::empty(),
162 };
163 for (n, v) in extra_headers {
164 req = req.header(*n, *v);
165 }
166 let mut http_req = req.body(body).unwrap();
167 for (name, value) in &self.base_headers {
168 http_req.headers_mut().insert(name.clone(), value.clone());
169 }
170 self.send(http_req).await
171 }
172
173 async fn send(&self, req: Request<Body>) -> TestResponse {
174 let mut req = req;
175 for (name, value) in &self.base_headers {
176 req.headers_mut()
177 .entry(name.clone())
178 .or_insert_with(|| value.clone());
179 }
180 let response = self.router.clone().oneshot(req).await.unwrap();
181
182 let status = response.status();
183 let headers = response.headers().clone();
184 let bytes = response
185 .into_body()
186 .collect()
187 .await
188 .map(|c| c.to_bytes())
189 .unwrap_or_default();
190
191 TestResponse {
192 status,
193 headers,
194 body: bytes.to_vec(),
195 }
196 }
197}
198
199pub struct TestResponse {
200 pub status: StatusCode,
201 pub headers: HeaderMap,
202 pub body: Vec<u8>,
209}
210
211impl TestResponse {
212 pub fn assert_status(&self, expected: u16) -> &Self {
215 assert_eq!(
216 self.status.as_u16(),
217 expected,
218 "expected status {expected}, got {} — body: {}",
219 self.status,
220 self.body_text()
221 );
222 self
223 }
224
225 pub fn assert_ok(&self) -> &Self {
226 assert!(
227 self.status.is_success(),
228 "expected success, got {} — body: {}",
229 self.status,
230 self.body_text()
231 );
232 self
233 }
234
235 pub fn assert_created(&self) -> &Self {
236 self.assert_status(201)
237 }
238 pub fn assert_no_content(&self) -> &Self {
239 self.assert_status(204)
240 }
241 pub fn assert_bad_request(&self) -> &Self {
242 self.assert_status(400)
243 }
244 pub fn assert_unauthorized(&self) -> &Self {
245 self.assert_status(401)
246 }
247 pub fn assert_forbidden(&self) -> &Self {
248 self.assert_status(403)
249 }
250 pub fn assert_not_found(&self) -> &Self {
251 self.assert_status(404)
252 }
253 pub fn assert_unprocessable(&self) -> &Self {
254 self.assert_status(422)
255 }
256 pub fn assert_too_many_requests(&self) -> &Self {
257 self.assert_status(429)
258 }
259 pub fn assert_server_error(&self) -> &Self {
260 assert!(
261 self.status.is_server_error(),
262 "expected 5xx, got {} — body: {}",
263 self.status,
264 self.body_text()
265 );
266 self
267 }
268
269 pub fn assert_redirect(&self) -> &Self {
270 assert!(
271 self.status.is_redirection(),
272 "expected 3xx redirect, got {} — body: {}",
273 self.status,
274 self.body_text()
275 );
276 self
277 }
278
279 pub fn assert_redirect_to(&self, location: &str) -> &Self {
280 self.assert_redirect();
281 let actual = self
282 .headers
283 .get("location")
284 .and_then(|v| v.to_str().ok())
285 .unwrap_or("");
286 assert_eq!(actual, location, "redirect Location mismatch");
287 self
288 }
289
290 pub fn assert_header(&self, name: &str, value: &str) -> &Self {
293 let actual = self
294 .headers
295 .get(name)
296 .and_then(|v| v.to_str().ok())
297 .unwrap_or("");
298 assert_eq!(actual, value, "header `{name}` mismatch");
299 self
300 }
301
302 pub fn assert_header_present(&self, name: &str) -> &Self {
303 assert!(
304 self.headers.contains_key(name),
305 "expected header `{name}` to be present"
306 );
307 self
308 }
309
310 pub fn assert_header_missing(&self, name: &str) -> &Self {
311 assert!(
312 !self.headers.contains_key(name),
313 "expected header `{name}` NOT to be present"
314 );
315 self
316 }
317
318 pub fn header(&self, name: &str) -> Option<String> {
319 self.headers
320 .get(name)
321 .and_then(|v| v.to_str().ok().map(String::from))
322 }
323
324 pub fn body_bytes(&self) -> &[u8] {
334 &self.body
335 }
336
337 pub fn body_text(&self) -> String {
338 String::from_utf8_lossy(&self.body).to_string()
339 }
340
341 pub fn assert_body_bytes(&self, expected: impl AsRef<[u8]>) -> &Self {
344 let expected = expected.as_ref();
345 assert_eq!(
346 self.body.as_slice(),
347 expected,
348 "body byte mismatch — got {} bytes, expected {} bytes",
349 self.body.len(),
350 expected.len()
351 );
352 self
353 }
354
355 pub fn json<T: DeserializeOwned>(&self) -> T {
356 serde_json::from_slice(&self.body).expect("response was not valid JSON")
357 }
358
359 pub fn json_value(&self) -> serde_json::Value {
360 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
361 }
362
363 pub fn assert_contains(&self, needle: &str) -> &Self {
364 let body = self.body_text();
365 assert!(
366 body.contains(needle),
367 "expected response body to contain '{needle}', got: {body}"
368 );
369 self
370 }
371 pub fn assert_dont_contain(&self, needle: &str) -> &Self {
372 let body = self.body_text();
373 assert!(
374 !body.contains(needle),
375 "expected response body NOT to contain '{needle}', got: {body}"
376 );
377 self
378 }
379 pub fn assert_see(&self, text: &str) -> &Self {
381 self.assert_contains(text)
382 }
383 pub fn assert_dont_see(&self, text: &str) -> &Self {
384 self.assert_dont_contain(text)
385 }
386
387 pub fn assert_json(&self, expected: serde_json::Value) -> &Self {
391 let actual = self.json_value();
392 assert_eq!(actual, expected, "JSON body mismatch");
393 self
394 }
395
396 pub fn assert_json_path(&self, path: &str, expected: serde_json::Value) -> &Self {
399 let actual = json_dig(&self.json_value(), path);
400 assert_eq!(
401 actual.as_ref(),
402 Some(&expected),
403 "JSON path `{path}` mismatch — full body: {}",
404 self.body_text()
405 );
406 self
407 }
408
409 pub fn assert_json_fragment(&self, subset: serde_json::Value) -> &Self {
412 let actual = self.json_value();
413 assert!(
414 json_contains(&actual, &subset),
415 "JSON body missing fragment {subset} — got {actual}"
416 );
417 self
418 }
419
420 pub fn assert_validation_error(&self, field: &str) -> &Self {
423 let v = self.json_value();
424 let arr = v
425 .get("errors")
426 .and_then(|e| e.get(field))
427 .and_then(|f| f.as_array());
428 assert!(
429 arr.map(|a| !a.is_empty()).unwrap_or(false),
430 "expected validation error on field `{field}` — body: {}",
431 self.body_text()
432 );
433 self
434 }
435}
436
437fn json_contains(actual: &serde_json::Value, expected: &serde_json::Value) -> bool {
440 use serde_json::Value::*;
441 match (actual, expected) {
442 (Object(a), Object(e)) => e
443 .iter()
444 .all(|(k, ev)| a.get(k).is_some_and(|av| json_contains(av, ev))),
445 (Array(a), Array(e)) => e.iter().all(|ev| a.iter().any(|av| json_contains(av, ev))),
446 (a, e) => a == e,
447 }
448}
449
450fn json_dig(v: &serde_json::Value, path: &str) -> Option<serde_json::Value> {
452 let mut current = v;
453 for segment in path.split('.') {
454 current = if let Ok(idx) = segment.parse::<usize>() {
455 current.get(idx)?
456 } else {
457 current.get(segment)?
458 };
459 }
460 Some(current.clone())
461}
462
/// Never-called helper whose only effect is to name `Infallible`.
///
/// Presumably it exists so the `use std::convert::Infallible` import is not reported as
/// unused. The leading underscore keeps the function itself out of dead-code warnings.
fn _force_link() {
    let _type_name: &'static str = std::any::type_name::<Infallible>();
}
467
#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::{get, post, put};

    /// Handler that returns the request body unchanged.
    async fn echo(payload: Bytes) -> Bytes {
        payload
    }

    #[tokio::test]
    async fn post_bytes_round_trips_arbitrary_bytes() {
        let client = TestClient::from_router(Router::new().route("/echo", post(echo)));

        // CBOR encoding of the map {"ok": true}.
        let payload: Vec<u8> = vec![0xA1, 0x62, 0x6F, 0x6B, 0xF5];
        let response = client
            .post_bytes("/echo", payload.clone(), "application/cbor")
            .await;

        response.assert_ok();
        assert_eq!(response.body, payload);
    }

    #[tokio::test]
    async fn post_bytes_sets_content_type_header_for_handler_dispatch() {
        /// Echoes the request's `content-type`, or `"missing"` if absent or not visible ASCII.
        async fn ct(headers: http::HeaderMap) -> String {
            match headers.get("content-type").and_then(|v| v.to_str().ok()) {
                Some(value) => value.to_string(),
                None => "missing".to_string(),
            }
        }

        let client = TestClient::from_router(Router::new().route("/ct", post(ct)));
        let response = client
            .post_bytes("/ct", b"x".to_vec(), "application/x-protobuf")
            .await;

        response.assert_ok();
        assert_eq!(response.body_text(), "application/x-protobuf");
    }

    #[tokio::test]
    async fn body_bytes_preserves_non_utf8_payload() {
        const PAYLOAD: [u8; 6] = [0xFF, 0xFE, 0xFD, 0x00, 0x80, 0xC0];

        /// Returns a body that is deliberately not valid UTF-8.
        async fn binary() -> Vec<u8> {
            PAYLOAD.to_vec()
        }

        let client = TestClient::from_router(Router::new().route("/bin", get(binary)));
        let response = client.get("/bin").await;
        response.assert_ok();

        response.assert_body_bytes(PAYLOAD);
        assert_eq!(response.body_bytes(), &PAYLOAD);

        // body_text() decodes lossily, so the invalid bytes surface as U+FFFD.
        let text = response.body_text();
        assert!(text.contains('\u{FFFD}'), "body_text lossy-decodes");
    }

    #[tokio::test]
    async fn put_and_patch_bytes_dispatch_correctly() {
        /// Responds with the request method's name.
        async fn method_name(method: Method) -> String {
            method.as_str().to_owned()
        }

        let router = Router::new().route("/m", put(method_name).patch(method_name));
        let client = TestClient::from_router(router);

        let put_response = client
            .put_bytes("/m", b"_".to_vec(), "application/octet-stream")
            .await;
        put_response.assert_ok();
        assert_eq!(put_response.body_text(), "PUT");

        let patch_response = client
            .patch_bytes("/m", b"_".to_vec(), "application/octet-stream")
            .await;
        patch_response.assert_ok();
        assert_eq!(patch_response.body_text(), "PATCH");
    }
}