1use std::convert::Infallible;
8use std::sync::{Arc, Mutex};
9
10use anvil_core::Application;
11use axum::body::{Body, Bytes};
12use axum::Router;
13use http::{HeaderMap, Method, Request, StatusCode};
14use http_body_util::BodyExt;
15use serde::de::DeserializeOwned;
16use tower::ServiceExt;
17
/// Shared, lock-protected cookie store holding `(name, value)` pairs.
type CookieJar = std::sync::Arc<std::sync::Mutex<Vec<(String, String)>>>;
21
22pub struct TestClient {
23 router: Router,
24 base_headers: HeaderMap,
25 cookies: Option<CookieJar>,
33}
34
35impl TestClient {
36 pub async fn new(app: Application) -> Self {
37 let _ = anvil_core::config::load_dotenv();
43 Self {
44 router: app.into_router(),
45 base_headers: HeaderMap::new(),
46 cookies: None,
47 }
48 }
49
50 pub fn from_router(router: Router) -> Self {
51 let _ = anvil_core::config::load_dotenv();
52 Self {
53 router,
54 base_headers: HeaderMap::new(),
55 cookies: None,
56 }
57 }
58
59 pub fn with_cookie_jar(mut self) -> Self {
77 self.cookies = Some(Arc::new(Mutex::new(Vec::new())));
78 self
79 }
80
81 pub fn cookies(&self) -> Vec<(String, String)> {
85 self.cookies
86 .as_ref()
87 .map(|jar| jar.lock().unwrap().clone())
88 .unwrap_or_default()
89 }
90
91 pub fn clear_cookies(&self) {
94 if let Some(jar) = &self.cookies {
95 jar.lock().unwrap().clear();
96 }
97 }
98
99 pub fn with_header(mut self, name: &str, value: &str) -> Self {
101 if let (Ok(name), Ok(val)) = (
102 http::HeaderName::try_from(name),
103 http::HeaderValue::try_from(value),
104 ) {
105 self.base_headers.insert(name, val);
106 }
107 self
108 }
109
110 pub fn with_bearer(self, token: &str) -> Self {
112 self.with_header("authorization", &format!("Bearer {token}"))
113 }
114
115 pub fn with_ajax(self) -> Self {
117 self.with_header("x-requested-with", "XMLHttpRequest")
118 }
119
120 pub async fn get(&self, path: &str) -> TestResponse {
121 self.request(Method::GET, path, None, &[]).await
122 }
123
124 pub async fn post(&self, path: &str, body: serde_json::Value) -> TestResponse {
125 self.request(Method::POST, path, Some(body), &[]).await
126 }
127
128 pub async fn put(&self, path: &str, body: serde_json::Value) -> TestResponse {
129 self.request(Method::PUT, path, Some(body), &[]).await
130 }
131
132 pub async fn patch(&self, path: &str, body: serde_json::Value) -> TestResponse {
133 self.request(Method::PATCH, path, Some(body), &[]).await
134 }
135
136 pub async fn delete(&self, path: &str) -> TestResponse {
137 self.request(Method::DELETE, path, None, &[]).await
138 }
139
140 pub async fn post_form(&self, path: &str, form: &[(&str, &str)]) -> TestResponse {
142 let body = serde_urlencoded::to_string(form).unwrap_or_default();
143 let req = Request::builder()
144 .method(Method::POST)
145 .uri(path)
146 .header("content-type", "application/x-www-form-urlencoded")
147 .body(Body::from(body))
148 .unwrap();
149 self.send(req).await
150 }
151
152 pub async fn post_bytes(
156 &self,
157 path: &str,
158 body: impl Into<Bytes>,
159 content_type: &str,
160 ) -> TestResponse {
161 self.bytes_request(Method::POST, path, body.into(), content_type)
162 .await
163 }
164
165 pub async fn put_bytes(
167 &self,
168 path: &str,
169 body: impl Into<Bytes>,
170 content_type: &str,
171 ) -> TestResponse {
172 self.bytes_request(Method::PUT, path, body.into(), content_type)
173 .await
174 }
175
176 pub async fn patch_bytes(
178 &self,
179 path: &str,
180 body: impl Into<Bytes>,
181 content_type: &str,
182 ) -> TestResponse {
183 self.bytes_request(Method::PATCH, path, body.into(), content_type)
184 .await
185 }
186
187 async fn bytes_request(
188 &self,
189 method: Method,
190 path: &str,
191 body: Bytes,
192 content_type: &str,
193 ) -> TestResponse {
194 let req = Request::builder()
195 .method(method)
196 .uri(path)
197 .header("content-type", content_type)
198 .body(Body::from(body))
199 .unwrap();
200 self.send(req).await
201 }
202
203 async fn request(
204 &self,
205 method: Method,
206 path: &str,
207 body: Option<serde_json::Value>,
208 extra_headers: &[(&str, &str)],
209 ) -> TestResponse {
210 let mut req = Request::builder().method(method).uri(path);
211 let body = match body {
212 Some(v) => {
213 req = req.header("content-type", "application/json");
214 Body::from(serde_json::to_vec(&v).unwrap())
215 }
216 None => Body::empty(),
217 };
218 for (n, v) in extra_headers {
219 req = req.header(*n, *v);
220 }
221 let mut http_req = req.body(body).unwrap();
222 for (name, value) in &self.base_headers {
223 http_req.headers_mut().insert(name.clone(), value.clone());
224 }
225 self.send(http_req).await
226 }
227
228 async fn send(&self, req: Request<Body>) -> TestResponse {
229 let mut req = req;
230 for (name, value) in &self.base_headers {
231 req.headers_mut()
232 .entry(name.clone())
233 .or_insert_with(|| value.clone());
234 }
235 if let Some(jar) = &self.cookies {
239 let cookies = jar.lock().unwrap();
240 if !cookies.is_empty() {
241 let joined = cookies
242 .iter()
243 .map(|(n, v)| format!("{n}={v}"))
244 .collect::<Vec<_>>()
245 .join("; ");
246 if let Ok(val) = http::HeaderValue::from_str(&joined) {
247 req.headers_mut().insert("cookie", val);
248 }
249 }
250 }
251
252 let response = self.router.clone().oneshot(req).await.unwrap();
253
254 if let Some(jar) = &self.cookies {
257 let mut cookies = jar.lock().unwrap();
258 for raw in response.headers().get_all("set-cookie").iter() {
259 let Ok(s) = raw.to_str() else { continue };
260 let pair = s.split(';').next().unwrap_or(s);
263 let Some((name, value)) = pair.split_once('=') else {
264 continue;
265 };
266 let name = name.trim().to_string();
267 let value = value.trim().to_string();
268 cookies.retain(|(n, _)| n != &name);
270 if !value.is_empty() {
271 cookies.push((name, value));
272 }
273 }
274 }
275
276 let status = response.status();
277 let headers = response.headers().clone();
278 let bytes = response
279 .into_body()
280 .collect()
281 .await
282 .map(|c| c.to_bytes())
283 .unwrap_or_default();
284
285 TestResponse {
286 status,
287 headers,
288 body: bytes.to_vec(),
289 }
290 }
291}
292
293pub struct TestResponse {
294 pub status: StatusCode,
295 pub headers: HeaderMap,
296 pub body: Vec<u8>,
303}
304
305impl TestResponse {
306 pub fn assert_status(&self, expected: u16) -> &Self {
309 assert_eq!(
310 self.status.as_u16(),
311 expected,
312 "expected status {expected}, got {} — body: {}",
313 self.status,
314 self.body_text()
315 );
316 self
317 }
318
319 pub fn assert_ok(&self) -> &Self {
320 assert!(
321 self.status.is_success(),
322 "expected success, got {} — body: {}",
323 self.status,
324 self.body_text()
325 );
326 self
327 }
328
329 pub fn assert_created(&self) -> &Self {
330 self.assert_status(201)
331 }
332 pub fn assert_no_content(&self) -> &Self {
333 self.assert_status(204)
334 }
335 pub fn assert_bad_request(&self) -> &Self {
336 self.assert_status(400)
337 }
338 pub fn assert_unauthorized(&self) -> &Self {
339 self.assert_status(401)
340 }
341 pub fn assert_forbidden(&self) -> &Self {
342 self.assert_status(403)
343 }
344 pub fn assert_not_found(&self) -> &Self {
345 self.assert_status(404)
346 }
347 pub fn assert_unprocessable(&self) -> &Self {
348 self.assert_status(422)
349 }
350 pub fn assert_too_many_requests(&self) -> &Self {
351 self.assert_status(429)
352 }
353 pub fn assert_server_error(&self) -> &Self {
354 assert!(
355 self.status.is_server_error(),
356 "expected 5xx, got {} — body: {}",
357 self.status,
358 self.body_text()
359 );
360 self
361 }
362
363 pub fn assert_redirect(&self) -> &Self {
364 assert!(
365 self.status.is_redirection(),
366 "expected 3xx redirect, got {} — body: {}",
367 self.status,
368 self.body_text()
369 );
370 self
371 }
372
373 pub fn assert_redirect_to(&self, location: &str) -> &Self {
374 self.assert_redirect();
375 let actual = self
376 .headers
377 .get("location")
378 .and_then(|v| v.to_str().ok())
379 .unwrap_or("");
380 assert_eq!(actual, location, "redirect Location mismatch");
381 self
382 }
383
384 pub fn assert_header(&self, name: &str, value: &str) -> &Self {
387 let actual = self
388 .headers
389 .get(name)
390 .and_then(|v| v.to_str().ok())
391 .unwrap_or("");
392 assert_eq!(actual, value, "header `{name}` mismatch");
393 self
394 }
395
396 pub fn assert_header_present(&self, name: &str) -> &Self {
397 assert!(
398 self.headers.contains_key(name),
399 "expected header `{name}` to be present"
400 );
401 self
402 }
403
404 pub fn assert_header_missing(&self, name: &str) -> &Self {
405 assert!(
406 !self.headers.contains_key(name),
407 "expected header `{name}` NOT to be present"
408 );
409 self
410 }
411
412 pub fn header(&self, name: &str) -> Option<String> {
413 self.headers
414 .get(name)
415 .and_then(|v| v.to_str().ok().map(String::from))
416 }
417
418 pub fn body_bytes(&self) -> &[u8] {
428 &self.body
429 }
430
431 pub fn body_text(&self) -> String {
432 String::from_utf8_lossy(&self.body).to_string()
433 }
434
435 pub fn assert_body_bytes(&self, expected: impl AsRef<[u8]>) -> &Self {
438 let expected = expected.as_ref();
439 assert_eq!(
440 self.body.as_slice(),
441 expected,
442 "body byte mismatch — got {} bytes, expected {} bytes",
443 self.body.len(),
444 expected.len()
445 );
446 self
447 }
448
449 pub fn json<T: DeserializeOwned>(&self) -> T {
450 serde_json::from_slice(&self.body).expect("response was not valid JSON")
451 }
452
453 pub fn json_value(&self) -> serde_json::Value {
454 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
455 }
456
457 pub fn assert_contains(&self, needle: &str) -> &Self {
458 let body = self.body_text();
459 assert!(
460 body.contains(needle),
461 "expected response body to contain '{needle}', got: {body}"
462 );
463 self
464 }
465 pub fn assert_dont_contain(&self, needle: &str) -> &Self {
466 let body = self.body_text();
467 assert!(
468 !body.contains(needle),
469 "expected response body NOT to contain '{needle}', got: {body}"
470 );
471 self
472 }
473 pub fn assert_see(&self, text: &str) -> &Self {
475 self.assert_contains(text)
476 }
477 pub fn assert_dont_see(&self, text: &str) -> &Self {
478 self.assert_dont_contain(text)
479 }
480
481 pub fn assert_json(&self, expected: serde_json::Value) -> &Self {
485 let actual = self.json_value();
486 assert_eq!(actual, expected, "JSON body mismatch");
487 self
488 }
489
490 pub fn assert_json_path(&self, path: &str, expected: serde_json::Value) -> &Self {
493 let actual = json_dig(&self.json_value(), path);
494 assert_eq!(
495 actual.as_ref(),
496 Some(&expected),
497 "JSON path `{path}` mismatch — full body: {}",
498 self.body_text()
499 );
500 self
501 }
502
503 pub fn assert_json_fragment(&self, subset: serde_json::Value) -> &Self {
506 let actual = self.json_value();
507 assert!(
508 json_contains(&actual, &subset),
509 "JSON body missing fragment {subset} — got {actual}"
510 );
511 self
512 }
513
514 pub fn assert_validation_error(&self, field: &str) -> &Self {
517 let v = self.json_value();
518 let arr = v
519 .get("errors")
520 .and_then(|e| e.get(field))
521 .and_then(|f| f.as_array());
522 assert!(
523 arr.map(|a| !a.is_empty()).unwrap_or(false),
524 "expected validation error on field `{field}` — body: {}",
525 self.body_text()
526 );
527 self
528 }
529}
530
531fn json_contains(actual: &serde_json::Value, expected: &serde_json::Value) -> bool {
534 use serde_json::Value::*;
535 match (actual, expected) {
536 (Object(a), Object(e)) => e
537 .iter()
538 .all(|(k, ev)| a.get(k).is_some_and(|av| json_contains(av, ev))),
539 (Array(a), Array(e)) => e.iter().all(|ev| a.iter().any(|av| json_contains(av, ev))),
540 (a, e) => a == e,
541 }
542}
543
544fn json_dig(v: &serde_json::Value, path: &str) -> Option<serde_json::Value> {
546 let mut current = v;
547 for segment in path.split('.') {
548 current = if let Ok(idx) = segment.parse::<usize>() {
549 current.get(idx)?
550 } else {
551 current.get(segment)?
552 };
553 }
554 Some(current.clone())
555}
556
/// No-op that references [`Infallible`] so its file-level import is not reported unused.
fn _force_link() {
    let _type_name: &'static str = std::any::type_name::<Infallible>();
}
561
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::Response;
    use axum::routing::{get, post, put};

    /// Handler that echoes the raw request body back unchanged.
    async fn echo(body: Bytes) -> Bytes {
        body
    }

    /// Handler that returns the incoming `Cookie` header, or `(none)` when absent.
    async fn read_cookie(headers: HeaderMap) -> String {
        headers
            .get("cookie")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("(none)")
            .to_string()
    }

    /// Builds a 200 response carrying one `Set-Cookie` header and the given body.
    fn set_cookie_response(cookie: &'static str, body: &'static str) -> Response {
        Response::builder()
            .status(200)
            .header("set-cookie", cookie)
            .body(Body::from(body))
            .unwrap()
    }

    #[tokio::test]
    async fn post_bytes_round_trips_arbitrary_bytes() {
        let router = Router::new().route("/echo", post(echo));
        let client = TestClient::from_router(router);

        // CBOR encoding of {"ok": true}.
        let cbor = vec![0xA1, 0x62, 0x6F, 0x6B, 0xF5];
        let resp = client
            .post_bytes("/echo", cbor.clone(), "application/cbor")
            .await;

        resp.assert_ok();
        assert_eq!(resp.body, cbor);
    }

    #[tokio::test]
    async fn post_bytes_sets_content_type_header_for_handler_dispatch() {
        async fn ct(headers: HeaderMap) -> String {
            headers
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("missing")
                .to_string()
        }
        let router = Router::new().route("/ct", post(ct));
        let client = TestClient::from_router(router);

        let resp = client
            .post_bytes("/ct", b"x".to_vec(), "application/x-protobuf")
            .await;
        resp.assert_ok();
        assert_eq!(resp.body_text(), "application/x-protobuf");
    }

    #[tokio::test]
    async fn body_bytes_preserves_non_utf8_payload() {
        async fn binary() -> Vec<u8> {
            vec![0xFF, 0xFE, 0xFD, 0x00, 0x80, 0xC0]
        }
        let router = Router::new().route("/bin", get(binary));
        let client = TestClient::from_router(router);

        let resp = client.get("/bin").await;
        resp.assert_ok();

        resp.assert_body_bytes([0xFF, 0xFE, 0xFD, 0x00, 0x80, 0xC0]);
        assert_eq!(resp.body_bytes(), &[0xFF, 0xFE, 0xFD, 0x00, 0x80, 0xC0]);

        // The text view is lossy: invalid UTF-8 becomes the replacement character.
        let text = resp.body_text();
        assert!(text.contains('\u{FFFD}'), "body_text lossy-decodes");
    }

    #[tokio::test]
    async fn put_and_patch_bytes_dispatch_correctly() {
        async fn method_name(method: Method) -> String {
            method.as_str().to_string()
        }
        let router = Router::new().route("/m", put(method_name).patch(method_name));
        let client = TestClient::from_router(router);

        let resp = client
            .put_bytes("/m", b"_".to_vec(), "application/octet-stream")
            .await;
        resp.assert_ok();
        assert_eq!(resp.body_text(), "PUT");

        let resp = client
            .patch_bytes("/m", b"_".to_vec(), "application/octet-stream")
            .await;
        resp.assert_ok();
        assert_eq!(resp.body_text(), "PATCH");
    }

    #[tokio::test]
    async fn cookie_jar_persists_set_cookie_across_requests() {
        async fn login() -> Response {
            set_cookie_response("session_id=abc123; Path=/", "set")
        }

        let router = Router::new()
            .route("/login", get(login))
            .route("/me", get(read_cookie));
        let client = TestClient::from_router(router).with_cookie_jar();

        let r1 = client.get("/login").await;
        r1.assert_ok();

        // Attributes such as `Path` are dropped; only name=value is replayed.
        let r2 = client.get("/me").await;
        r2.assert_ok();
        assert_eq!(r2.body_text(), "session_id=abc123");

        let snap = client.cookies();
        assert_eq!(snap, vec![("session_id".to_string(), "abc123".to_string())]);
    }

    #[tokio::test]
    async fn cookie_jar_replaces_same_name_and_deletes_on_empty_value() {
        async fn rotate() -> Response {
            set_cookie_response("session_id=v2", "")
        }
        async fn logout() -> Response {
            set_cookie_response("session_id=; Max-Age=0", "")
        }

        let router = Router::new()
            .route("/rotate", get(rotate))
            .route("/logout", get(logout));
        let client = TestClient::from_router(router).with_cookie_jar();

        client.get("/rotate").await.assert_ok();
        assert_eq!(client.cookies(), vec![("session_id".into(), "v2".into())]);

        // Receiving the same cookie again must not duplicate it.
        client.get("/rotate").await.assert_ok();
        assert_eq!(client.cookies(), vec![("session_id".into(), "v2".into())]);

        client.get("/logout").await.assert_ok();
        assert!(client.cookies().is_empty());
    }

    #[tokio::test]
    async fn cookie_jar_off_by_default_does_not_carry_state() {
        async fn set_cookie() -> Response {
            set_cookie_response("x=1", "")
        }

        let router = Router::new()
            .route("/set", get(set_cookie))
            .route("/read", get(read_cookie));
        let client = TestClient::from_router(router);
        client.get("/set").await.assert_ok();

        let r2 = client.get("/read").await;
        assert_eq!(r2.body_text(), "(none)");
        assert!(client.cookies().is_empty());
    }
}