1use bytes::Bytes;
27use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
28use serde::de::DeserializeOwned;
29use serde_json::Value;
30use std::collections::HashMap;
31use std::fmt::Debug;
32use std::sync::Arc;
33
34use crate::routing::Router;
35
/// In-process HTTP client used to build requests against the application in tests.
///
/// Holds an optional [`Router`] and a set of default headers; every request
/// builder created from this client starts with a copy of those headers.
pub struct TestClient {
    // Stored for future dispatch; `TestRequestBuilder::send` does not read it yet,
    // so the lint is silenced rather than dropping the field.
    #[allow(dead_code)]
    router: Option<Arc<Router>>,
    /// Headers copied into every request built from this client.
    default_headers: HeaderMap,
}
42
43impl TestClient {
44 pub fn new() -> Self {
46 Self {
47 router: None,
48 default_headers: HeaderMap::new(),
49 }
50 }
51
52 pub fn with_router(router: Router) -> Self {
54 Self {
55 router: Some(Arc::new(router)),
56 default_headers: HeaderMap::new(),
57 }
58 }
59
60 pub fn with_header(mut self, name: &str, value: &str) -> Self {
62 if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::try_from(value)) {
63 self.default_headers.insert(name, value);
64 }
65 self
66 }
67
68 pub fn json(self) -> Self {
70 self.with_header("Accept", "application/json")
71 .with_header("Content-Type", "application/json")
72 }
73
74 pub fn get(&self, path: &str) -> TestRequestBuilder<'_> {
76 TestRequestBuilder::new(self, Method::GET, path)
77 }
78
79 pub fn post(&self, path: &str) -> TestRequestBuilder<'_> {
81 TestRequestBuilder::new(self, Method::POST, path)
82 }
83
84 pub fn put(&self, path: &str) -> TestRequestBuilder<'_> {
86 TestRequestBuilder::new(self, Method::PUT, path)
87 }
88
89 pub fn patch(&self, path: &str) -> TestRequestBuilder<'_> {
91 TestRequestBuilder::new(self, Method::PATCH, path)
92 }
93
94 pub fn delete(&self, path: &str) -> TestRequestBuilder<'_> {
96 TestRequestBuilder::new(self, Method::DELETE, path)
97 }
98}
99
100impl Default for TestClient {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
/// Builder for a single request issued through a [`TestClient`].
///
/// Starts with a copy of the client's default headers; per-request headers,
/// body and query parameters are layered on top before `send`.
pub struct TestRequestBuilder<'a> {
    // Not read yet: `send` is currently a stub that does not consult the client.
    #[allow(dead_code)]
    client: &'a TestClient,
    // Not read yet, for the same reason as `client`.
    #[allow(dead_code)]
    method: Method,
    /// Request path exactly as passed by the caller.
    path: String,
    /// Client default headers plus any per-request overrides.
    headers: HeaderMap,
    /// Raw request body, if one was set.
    body: Option<Bytes>,
    /// Query parameters appended to `path` when building the URL; one value per key.
    query_params: HashMap<String, String>,
}
117
118impl<'a> TestRequestBuilder<'a> {
119 fn new(client: &'a TestClient, method: Method, path: &str) -> Self {
120 let headers = client.default_headers.clone();
121
122 Self {
123 client,
124 method,
125 path: path.to_string(),
126 headers,
127 body: None,
128 query_params: HashMap::new(),
129 }
130 }
131
132 pub fn header(mut self, name: &str, value: &str) -> Self {
134 if let (Ok(name), Ok(value)) = (HeaderName::try_from(name), HeaderValue::try_from(value)) {
135 self.headers.insert(name, value);
136 }
137 self
138 }
139
140 pub fn bearer_token(self, token: &str) -> Self {
142 self.header("Authorization", &format!("Bearer {token}"))
143 }
144
145 pub fn basic_auth(self, username: &str, password: &str) -> Self {
147 use base64::Engine;
148 let credentials =
149 base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
150 self.header("Authorization", &format!("Basic {credentials}"))
151 }
152
153 pub fn query(mut self, key: &str, value: &str) -> Self {
155 self.query_params.insert(key.to_string(), value.to_string());
156 self
157 }
158
159 pub fn body(mut self, body: impl Into<Bytes>) -> Self {
161 self.body = Some(body.into());
162 self
163 }
164
165 pub fn json<T: serde::Serialize>(mut self, data: &T) -> Self {
167 if let Ok(bytes) = serde_json::to_vec(data) {
168 self.body = Some(Bytes::from(bytes));
169 self.headers.insert(
170 HeaderName::from_static("content-type"),
171 HeaderValue::from_static("application/json"),
172 );
173 }
174 self
175 }
176
177 pub fn form(mut self, data: &[(String, String)]) -> Self {
179 let encoded = serde_urlencoded::to_string(data).unwrap_or_default();
180 self.body = Some(Bytes::from(encoded));
181 self.headers.insert(
182 HeaderName::from_static("content-type"),
183 HeaderValue::from_static("application/x-www-form-urlencoded"),
184 );
185 self
186 }
187
188 fn build_path(&self) -> String {
190 if self.query_params.is_empty() {
191 self.path.clone()
192 } else {
193 let query = self
194 .query_params
195 .iter()
196 .map(|(k, v)| format!("{k}={v}"))
197 .collect::<Vec<_>>()
198 .join("&");
199 format!("{}?{}", self.path, query)
200 }
201 }
202
203 pub async fn send(self) -> TestResponse {
208 let _full_path = self.build_path();
210
211 TestResponse {
215 status: StatusCode::OK,
216 headers: HeaderMap::new(),
217 body: Bytes::new(),
218 location: None,
219 }
220 }
221}
222
/// A captured HTTP response with fluent, panicking assertion helpers.
///
/// Each `assert_*` method consumes the response and returns it on success so
/// assertions can be chained; on failure it panics with a descriptive message.
#[derive(Debug, Clone)]
pub struct TestResponse {
    status: StatusCode,
    headers: HeaderMap,
    body: Bytes,
    // `Location` header value cached by `new`/`from_parts`; `None` when absent
    // or when the value is not convertible with `HeaderValue::to_str`.
    location: Option<String>,
}
231
232impl TestResponse {
233 pub fn new(status: StatusCode, headers: HeaderMap, body: Bytes) -> Self {
235 let location = headers
236 .get("location")
237 .and_then(|v| v.to_str().ok())
238 .map(|s| s.to_string());
239
240 Self {
241 status,
242 headers,
243 body,
244 location,
245 }
246 }
247
248 pub fn from_parts(status: u16, headers: Vec<(&str, &str)>, body: impl Into<Bytes>) -> Self {
250 let mut header_map = HeaderMap::new();
251 for (name, value) in headers {
252 if let (Ok(n), Ok(v)) = (HeaderName::try_from(name), HeaderValue::try_from(value)) {
253 header_map.insert(n, v);
254 }
255 }
256
257 let location = header_map
258 .get("location")
259 .and_then(|v| v.to_str().ok())
260 .map(|s| s.to_string());
261
262 Self {
263 status: StatusCode::from_u16(status).unwrap_or(StatusCode::OK),
264 headers: header_map,
265 body: body.into(),
266 location,
267 }
268 }
269
270 pub fn status(&self) -> StatusCode {
272 self.status
273 }
274
275 pub fn headers(&self) -> &HeaderMap {
277 &self.headers
278 }
279
280 pub fn header(&self, name: &str) -> Option<&str> {
282 self.headers.get(name).and_then(|v| v.to_str().ok())
283 }
284
285 pub fn body(&self) -> &Bytes {
287 &self.body
288 }
289
290 pub fn text(&self) -> String {
292 String::from_utf8_lossy(&self.body).to_string()
293 }
294
295 pub fn json<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
297 serde_json::from_slice(&self.body)
298 }
299
300 pub fn json_value(&self) -> Result<Value, serde_json::Error> {
302 serde_json::from_slice(&self.body)
303 }
304
305 pub fn location(&self) -> Option<&str> {
307 self.location.as_deref()
308 }
309
310 pub fn assert_status(self, expected: u16) -> Self {
314 let actual = self.status.as_u16();
315 if actual != expected {
316 panic!(
317 "\nHTTP Status Assertion Failed\n\n Expected: {}\n Received: {}\n Body: {}\n",
318 expected,
319 actual,
320 self.text()
321 );
322 }
323 self
324 }
325
326 pub fn assert_ok(self) -> Self {
328 if !self.status.is_success() {
329 panic!(
330 "\nHTTP Status Assertion Failed\n\n Expected: 2xx (success)\n Received: {}\n Body: {}\n",
331 self.status.as_u16(),
332 self.text()
333 );
334 }
335 self
336 }
337
338 pub fn assert_redirect(self) -> Self {
340 if !self.status.is_redirection() {
341 panic!(
342 "\nHTTP Status Assertion Failed\n\n Expected: 3xx (redirect)\n Received: {}\n",
343 self.status.as_u16()
344 );
345 }
346 self
347 }
348
349 pub fn assert_redirect_to(self, expected_path: &str) -> Self {
351 if !self.status.is_redirection() {
352 panic!(
353 "\nHTTP Status Assertion Failed\n\n Expected: 3xx (redirect)\n Received: {}\n",
354 self.status.as_u16()
355 );
356 }
357
358 match &self.location {
359 Some(location) if location.contains(expected_path) => self,
360 Some(location) => {
361 panic!(
362 "\nRedirect Location Assertion Failed\n\n Expected to contain: {expected_path}\n Received: {location}\n"
363 );
364 }
365 None => {
366 panic!(
367 "\nRedirect Location Assertion Failed\n\n Expected Location header but none found\n"
368 );
369 }
370 }
371 }
372
373 pub fn assert_client_error(self) -> Self {
375 if !self.status.is_client_error() {
376 panic!(
377 "\nHTTP Status Assertion Failed\n\n Expected: 4xx (client error)\n Received: {}\n",
378 self.status.as_u16()
379 );
380 }
381 self
382 }
383
384 pub fn assert_server_error(self) -> Self {
386 if !self.status.is_server_error() {
387 panic!(
388 "\nHTTP Status Assertion Failed\n\n Expected: 5xx (server error)\n Received: {}\n",
389 self.status.as_u16()
390 );
391 }
392 self
393 }
394
395 pub fn assert_not_found(self) -> Self {
397 self.assert_status(404)
398 }
399
400 pub fn assert_unauthorized(self) -> Self {
402 self.assert_status(401)
403 }
404
405 pub fn assert_forbidden(self) -> Self {
407 self.assert_status(403)
408 }
409
410 pub fn assert_unprocessable(self) -> Self {
412 self.assert_status(422)
413 }
414
415 pub fn assert_header(self, name: &str, expected: &str) -> Self {
417 match self.header(name) {
418 Some(actual) if actual == expected => self,
419 Some(actual) => {
420 panic!(
421 "\nHeader Assertion Failed\n\n Header: {name}\n Expected: {expected}\n Received: {actual}\n"
422 );
423 }
424 None => {
425 panic!(
426 "\nHeader Assertion Failed\n\n Header '{}' not found in response\n Available headers: {:?}\n",
427 name,
428 self.headers.keys().map(|k| k.as_str()).collect::<Vec<_>>()
429 );
430 }
431 }
432 }
433
434 pub fn assert_header_exists(self, name: &str) -> Self {
436 if self.header(name).is_none() {
437 panic!(
438 "\nHeader Assertion Failed\n\n Expected header '{}' to exist\n Available headers: {:?}\n",
439 name,
440 self.headers.keys().map(|k| k.as_str()).collect::<Vec<_>>()
441 );
442 }
443 self
444 }
445
446 pub fn assert_json(self) -> Self {
448 let content_type = self.header("content-type").unwrap_or("");
449 if !content_type.contains("application/json") {
450 panic!(
451 "\nContent-Type Assertion Failed\n\n Expected: application/json\n Received: {content_type}\n"
452 );
453 }
454 self
455 }
456
457 pub fn assert_json_has(self, key: &str) -> Self {
459 match self.json_value() {
460 Ok(json) => {
461 if json.get(key).is_none() {
462 panic!(
463 "\nJSON Assertion Failed\n\n Expected key '{}' in JSON\n Received: {}\n",
464 key,
465 serde_json::to_string_pretty(&json).unwrap_or_default()
466 );
467 }
468 self
469 }
470 Err(e) => {
471 panic!(
472 "\nJSON Parse Error\n\n Error: {}\n Body: {}\n",
473 e,
474 self.text()
475 );
476 }
477 }
478 }
479
480 pub fn assert_json_is<T: serde::Serialize + Debug>(self, key: &str, expected: T) -> Self {
482 match self.json_value() {
483 Ok(json) => {
484 let expected_value = serde_json::to_value(&expected).unwrap();
485 match json.get(key) {
486 Some(actual) if actual == &expected_value => self,
487 Some(actual) => {
488 panic!(
489 "\nJSON Value Assertion Failed\n\n Key: {key}\n Expected: {expected_value:?}\n Received: {actual:?}\n"
490 );
491 }
492 None => {
493 panic!(
494 "\nJSON Assertion Failed\n\n Key '{}' not found in JSON\n Available keys: {:?}\n",
495 key,
496 json.as_object().map(|o| o.keys().collect::<Vec<_>>()).unwrap_or_default()
497 );
498 }
499 }
500 }
501 Err(e) => {
502 panic!(
503 "\nJSON Parse Error\n\n Error: {}\n Body: {}\n",
504 e,
505 self.text()
506 );
507 }
508 }
509 }
510
511 pub fn assert_json_matches<F>(self, key: &str, predicate: F) -> Self
513 where
514 F: FnOnce(&Value) -> bool,
515 {
516 match self.json_value() {
517 Ok(json) => match json.get(key) {
518 Some(value) if predicate(value) => self,
519 Some(value) => {
520 panic!(
521 "\nJSON Predicate Assertion Failed\n\n Key: {key}\n Value: {value:?}\n The predicate returned false\n"
522 );
523 }
524 None => {
525 panic!("\nJSON Assertion Failed\n\n Key '{key}' not found in JSON\n");
526 }
527 },
528 Err(e) => {
529 panic!(
530 "\nJSON Parse Error\n\n Error: {}\n Body: {}\n",
531 e,
532 self.text()
533 );
534 }
535 }
536 }
537
538 pub fn assert_json_equals<T: serde::Serialize + Debug>(self, expected: T) -> Self {
540 match self.json_value() {
541 Ok(actual) => {
542 let expected_value = serde_json::to_value(&expected).unwrap();
543 if actual != expected_value {
544 panic!(
545 "\nJSON Equality Assertion Failed\n\n Expected:\n{}\n\n Received:\n{}\n",
546 serde_json::to_string_pretty(&expected_value).unwrap_or_default(),
547 serde_json::to_string_pretty(&actual).unwrap_or_default()
548 );
549 }
550 self
551 }
552 Err(e) => {
553 panic!(
554 "\nJSON Parse Error\n\n Error: {}\n Body: {}\n",
555 e,
556 self.text()
557 );
558 }
559 }
560 }
561
562 pub fn assert_see(self, needle: &str) -> Self {
564 let body = self.text();
565 if !body.contains(needle) {
566 panic!("\nBody Assertion Failed\n\n Expected to see: {needle}\n Body:\n{body}\n");
567 }
568 self
569 }
570
571 pub fn assert_dont_see(self, needle: &str) -> Self {
573 let body = self.text();
574 if body.contains(needle) {
575 panic!("\nBody Assertion Failed\n\n Expected NOT to see: {needle}\n Body:\n{body}\n");
576 }
577 self
578 }
579
580 pub fn assert_validation_errors(self, fields: &[&str]) -> Self {
582 match self.json_value() {
583 Ok(json) => {
584 let errors = json
586 .get("errors")
587 .or_else(|| json.get("validation_errors"))
588 .or_else(|| {
589 json.get("message")
590 .and_then(|m| if m.is_object() { Some(m) } else { None })
591 });
592
593 match errors {
594 Some(errors_obj) => {
595 for field in fields {
596 if errors_obj.get(*field).is_none() {
597 panic!(
598 "\nValidation Error Assertion Failed\n\n Expected error for field: {}\n Errors: {}\n",
599 field,
600 serde_json::to_string_pretty(errors_obj).unwrap_or_default()
601 );
602 }
603 }
604 self
605 }
606 None => {
607 panic!(
608 "\nValidation Error Assertion Failed\n\n Expected 'errors' key in response\n Response: {}\n",
609 serde_json::to_string_pretty(&json).unwrap_or_default()
610 );
611 }
612 }
613 }
614 Err(e) => {
615 panic!(
616 "\nJSON Parse Error\n\n Error: {}\n Body: {}\n",
617 e,
618 self.text()
619 );
620 }
621 }
622 }
623
624 pub fn assert_json_count(self, key: &str, expected: usize) -> Self {
626 match self.json_value() {
627 Ok(json) => match json.get(key) {
628 Some(Value::Array(arr)) if arr.len() == expected => self,
629 Some(Value::Array(arr)) => {
630 panic!(
631 "\nJSON Count Assertion Failed\n\n Key: {}\n Expected: {} items\n Received: {} items\n",
632 key, expected, arr.len()
633 );
634 }
635 Some(other) => {
636 panic!(
637 "\nJSON Count Assertion Failed\n\n Key '{}' is not an array\n Type: {}\n",
638 key,
639 match other {
640 Value::Null => "null",
641 Value::Bool(_) => "boolean",
642 Value::Number(_) => "number",
643 Value::String(_) => "string",
644 Value::Object(_) => "object",
645 Value::Array(_) => "array",
646 }
647 );
648 }
649 None => {
650 panic!("\nJSON Count Assertion Failed\n\n Key '{key}' not found\n");
651 }
652 },
653 Err(e) => {
654 panic!(
655 "\nJSON Parse Error\n\n Error: {}\n Body: {}\n",
656 e,
657 self.text()
658 );
659 }
660 }
661 }
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_assert_status() {
        let response = TestResponse::from_parts(200, vec![], "");
        response.assert_status(200);
    }

    #[test]
    #[should_panic(expected = "HTTP Status Assertion Failed")]
    fn test_response_assert_status_mismatch_panics() {
        TestResponse::from_parts(500, vec![], "boom").assert_status(200);
    }

    #[test]
    fn test_response_assert_ok() {
        let response = TestResponse::from_parts(201, vec![], "");
        response.assert_ok();
    }

    #[test]
    fn test_response_assert_json_has() {
        let body = r#"{"name": "test", "email": "test@example.com"}"#;
        let response =
            TestResponse::from_parts(200, vec![("content-type", "application/json")], body);
        response.assert_json_has("name").assert_json_has("email");
    }

    #[test]
    #[should_panic(expected = "JSON Parse Error")]
    fn test_response_assert_json_has_invalid_body_panics() {
        TestResponse::from_parts(200, vec![], "not json").assert_json_has("name");
    }

    #[test]
    fn test_response_assert_json_is() {
        let body = r#"{"count": 5, "name": "test"}"#;
        let response =
            TestResponse::from_parts(200, vec![("content-type", "application/json")], body);
        response
            .assert_json_is("count", 5)
            .assert_json_is("name", "test");
    }

    #[test]
    fn test_response_assert_json_equals_and_matches() {
        let body = r#"{"id": 1, "tags": ["a", "b"]}"#;
        TestResponse::from_parts(200, vec![("content-type", "application/json")], body)
            .assert_json()
            .assert_json_equals(serde_json::json!({"id": 1, "tags": ["a", "b"]}))
            .assert_json_matches("id", |v| v.as_u64() == Some(1));
    }

    #[test]
    fn test_response_assert_see() {
        let body = "Hello, World!";
        let response = TestResponse::from_parts(200, vec![], body);
        response.assert_see("Hello").assert_dont_see("Goodbye");
    }

    #[test]
    fn test_response_assert_redirect() {
        let response = TestResponse::from_parts(302, vec![("location", "/dashboard")], "");
        response.assert_redirect().assert_redirect_to("/dashboard");
    }

    #[test]
    #[should_panic(expected = "Redirect Location Assertion Failed")]
    fn test_response_assert_redirect_to_without_location_panics() {
        TestResponse::from_parts(302, vec![], "").assert_redirect_to("/dashboard");
    }

    #[test]
    fn test_response_assert_header() {
        let response = TestResponse::from_parts(200, vec![("x-request-id", "abc")], "");
        response
            .assert_header("x-request-id", "abc")
            .assert_header_exists("x-request-id");
    }

    #[test]
    #[should_panic(expected = "Header Assertion Failed")]
    fn test_response_assert_header_missing_panics() {
        TestResponse::from_parts(200, vec![], "").assert_header_exists("x-missing");
    }

    #[test]
    fn test_response_assert_validation_errors() {
        let body = r#"{"errors": {"email": ["required"], "name": ["too short"]}}"#;
        TestResponse::from_parts(422, vec![], body)
            .assert_unprocessable()
            .assert_validation_errors(&["email", "name"]);
    }

    #[test]
    fn test_response_assert_json_count() {
        let body = r#"{"items": [1, 2, 3]}"#;
        let response =
            TestResponse::from_parts(200, vec![("content-type", "application/json")], body);
        response.assert_json_count("items", 3);
    }

    #[test]
    #[should_panic(expected = "is not an array")]
    fn test_response_assert_json_count_non_array_panics() {
        TestResponse::from_parts(200, vec![], r#"{"items": 3}"#).assert_json_count("items", 3);
    }

    #[test]
    fn test_client_default_headers_are_copied() {
        let client = TestClient::new().json();
        let builder = client.get("/");
        assert_eq!(builder.headers.get("accept").unwrap(), "application/json");
    }

    #[test]
    fn test_request_builder_single_query_param() {
        let client = TestClient::new();
        assert_eq!(client.get("/users").build_path(), "/users");
        assert_eq!(
            client.get("/users").query("page", "2").build_path(),
            "/users?page=2"
        );
    }

    #[test]
    fn test_request_builder_json_sets_body_and_content_type() {
        let client = TestClient::new();
        let builder = client
            .post("/users")
            .json(&serde_json::json!({"name": "test"}));
        assert_eq!(
            builder.headers.get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(builder.body.as_deref(), Some(&br#"{"name":"test"}"#[..]));
    }

    #[test]
    fn test_request_builder_auth_headers() {
        let client = TestClient::new();
        let builder = client.get("/me").bearer_token("t0k");
        assert_eq!(builder.headers.get("authorization").unwrap(), "Bearer t0k");
        let builder = client.get("/me").basic_auth("user", "pass");
        assert_eq!(
            builder.headers.get("authorization").unwrap(),
            "Basic dXNlcjpwYXNz"
        );
    }
}