1use parking_lot::Mutex;
46use std::collections::HashMap;
47use std::future::Future;
48use std::sync::Arc;
49
50use asupersync::Cx;
51
52use crate::context::RequestContext;
53use crate::dependency::{DependencyOverrides, FromDependency};
54use crate::middleware::Handler;
55use crate::request::{Body, Method, Request};
56use crate::response::{Response, ResponseBody, StatusCode};
57
/// Client-side cookie store used by [`TestClient`].
///
/// Cookies are held in a flat list; every stored cookie is stamped with an
/// increasing id (`next_id`) so "which one was stored last" can be decided
/// without timestamps.
#[derive(Clone, Debug, Default)]
pub struct CookieJar {
    /// Every cookie currently held. Expired cookies are not purged eagerly;
    /// they are skipped when a `Cookie` header is built for a request.
    cookies: Vec<StoredCookie>,
    /// Id that will be given to the next cookie stored in the jar.
    next_id: u64,
}

/// Parsed value of a `SameSite` cookie attribute.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CookieSameSite {
    Lax,
    Strict,
    None,
}

/// Set of hosts a stored cookie is scoped to.
#[derive(Clone, Debug, Eq, PartialEq)]
enum CookieDomain {
    /// Matches every host (cookies added manually with `CookieJar::set`).
    Any,
    /// Matches exactly this host (the `Set-Cookie` had no `Domain` attribute).
    HostOnly(String),
    /// Matches this domain and all of its subdomains (`Domain` attribute).
    Domain(String),
}

/// One cookie as kept inside a [`CookieJar`].
#[derive(Clone, Debug)]
struct StoredCookie {
    /// Storage order: a larger id means the cookie was stored more recently.
    id: u64,
    name: String,
    value: String,
    domain: CookieDomain,
    /// Path scope; requests must be on this path (or below it).
    path: String,
    /// Only sent over connections detected as HTTPS.
    secure: bool,
    // Recorded from `Set-Cookie` but not consulted when choosing which
    // cookies to send, hence the lint allowance.
    #[allow(dead_code)]
    http_only: bool,
    // Recorded from `Set-Cookie` but not consulted when choosing which
    // cookies to send, hence the lint allowance.
    #[allow(dead_code)]
    same_site: Option<CookieSameSite>,
    /// Absolute expiry; `None` means a session cookie.
    expires_at: Option<std::time::SystemTime>,
}
99
100impl CookieJar {
101 #[must_use]
103 pub fn new() -> Self {
104 Self::default()
105 }
106
107 pub fn set(&mut self, name: impl Into<String>, value: impl Into<String>) {
109 let name = name.into();
110 let value = value.into();
111
112 self.cookies
114 .retain(|c| !(c.name == name && c.domain == CookieDomain::Any));
115
116 let id = self.next_id;
117 self.next_id = self.next_id.wrapping_add(1);
118 self.cookies.push(StoredCookie {
119 id,
120 name,
121 value,
122 domain: CookieDomain::Any,
123 path: "/".to_string(),
124 secure: false,
125 http_only: false,
126 same_site: None,
127 expires_at: None,
128 });
129 }
130
131 #[must_use]
133 pub fn get(&self, name: &str) -> Option<&str> {
134 self.cookies
135 .iter()
136 .filter(|c| c.name == name)
137 .max_by_key(|c| c.id)
138 .map(|c| c.value.as_str())
139 }
140
141 pub fn remove(&mut self, name: &str) -> Option<String> {
143 let mut removed: Option<String> = None;
144 self.cookies.retain(|c| {
145 if c.name == name {
146 removed = Some(c.value.clone());
147 false
148 } else {
149 true
150 }
151 });
152 removed
153 }
154
155 pub fn clear(&mut self) {
157 self.cookies.clear();
158 self.next_id = 0;
159 }
160
161 #[must_use]
163 pub fn len(&self) -> usize {
164 self.cookies.len()
165 }
166
167 #[must_use]
169 pub fn is_empty(&self) -> bool {
170 self.cookies.is_empty()
171 }
172
173 #[must_use]
175 pub fn to_cookie_header(&self) -> Option<String> {
176 let mut by_name: HashMap<&str, &StoredCookie> = HashMap::new();
177 for c in &self.cookies {
178 match by_name.get(c.name.as_str()) {
179 Some(existing) if existing.id >= c.id => {}
180 _ => {
181 by_name.insert(c.name.as_str(), c);
182 }
183 }
184 }
185
186 if by_name.is_empty() {
187 return None;
188 }
189
190 Some(
191 by_name
192 .into_values()
193 .map(|c| format!("{}={}", c.name, c.value))
194 .collect::<Vec<_>>()
195 .join("; "),
196 )
197 }
198
199 #[must_use]
204 pub fn cookie_header_for_request(&self, request: &Request) -> Option<String> {
205 let host = request_host(request);
206 let path = request.path();
207 let is_secure = request_is_secure(request);
208 let now = std::time::SystemTime::now();
209
210 let mut selected: HashMap<&str, &StoredCookie> = HashMap::new();
212 for c in &self.cookies {
213 if c.secure && !is_secure {
214 continue;
215 }
216 if let Some(exp) = c.expires_at {
217 if exp <= now {
218 continue;
219 }
220 }
221 if !domain_matches(&c.domain, host.as_deref()) {
222 continue;
223 }
224 if !path_matches(&c.path, path) {
225 continue;
226 }
227
228 match selected.get(c.name.as_str()) {
229 None => {
230 selected.insert(c.name.as_str(), c);
231 }
232 Some(existing) => {
233 let a = (c.path.len(), c.id);
234 let b = (existing.path.len(), existing.id);
235 if a > b {
236 selected.insert(c.name.as_str(), c);
237 }
238 }
239 }
240 }
241
242 if selected.is_empty() {
243 return None;
244 }
245
246 Some(
247 selected
248 .into_values()
249 .map(|c| format!("{}={}", c.name, c.value))
250 .collect::<Vec<_>>()
251 .join("; "),
252 )
253 }
254
255 pub fn parse_set_cookie(&mut self, request: &Request, header_value: &[u8]) {
257 let Ok(value) = std::str::from_utf8(header_value) else {
258 return;
259 };
260 self.parse_set_cookie_str(request, value);
261 }
262
263 fn parse_set_cookie_str(&mut self, request: &Request, value: &str) {
264 let mut parts = value.split(';');
265 let Some((name, val)) = parse_cookie_name_value(parts.next()) else {
266 return;
267 };
268
269 let host = request_host(request);
270 let attrs = parse_set_cookie_attrs(parts);
271
272 let Some(domain) = cookie_domain_for_set_cookie(host.as_deref(), attrs.domain) else {
273 return;
274 };
275
276 let path = attrs
277 .path
278 .unwrap_or_else(|| default_cookie_path(request.path()));
279
280 let expires_at = match compute_cookie_expiration(attrs.max_age, attrs.expires_at) {
281 CookieExpiration::Delete => {
282 self.remove_by_key(name, &domain, &path);
283 return;
284 }
285 CookieExpiration::Keep(expires_at) => expires_at,
286 };
287
288 let id = self.next_id;
289 self.next_id = self.next_id.wrapping_add(1);
290 self.upsert(StoredCookie {
291 id,
292 name: name.to_string(),
293 value: val.to_string(),
294 domain,
295 path,
296 secure: attrs.secure,
297 http_only: attrs.http_only,
298 same_site: attrs.same_site,
299 expires_at,
300 });
301 }
302
303 fn remove_by_key(&mut self, name: &str, domain: &CookieDomain, path: &str) {
304 self.cookies
305 .retain(|c| !(c.name == name && &c.domain == domain && c.path == path));
306 }
307
308 fn upsert(&mut self, cookie: StoredCookie) {
309 for existing in &mut self.cookies {
311 if existing.name == cookie.name
312 && existing.domain == cookie.domain
313 && existing.path == cookie.path
314 {
315 *existing = cookie;
316 return;
317 }
318 }
319 self.cookies.push(cookie);
320 }
321}
322
323#[derive(Debug, Default)]
324struct SetCookieAttrs {
325 domain: Option<String>,
326 path: Option<String>,
327 max_age: Option<i64>,
328 secure: bool,
329 http_only: bool,
330 same_site: Option<CookieSameSite>,
331 expires_at: Option<std::time::SystemTime>,
332}
333
/// Outcome of evaluating a cookie's `Max-Age` / `Expires` attributes.
#[derive(Clone, Copy, Debug)]
enum CookieExpiration {
    /// The cookie is already expired and must be removed from the jar.
    Delete,
    /// Keep the cookie; `Some` is its expiry instant, `None` a session cookie.
    Keep(Option<std::time::SystemTime>),
}
339
/// Splits the leading `name=value` pair of a `Set-Cookie` header.
///
/// Both halves are whitespace-trimmed and the split happens at the first `=`.
/// Returns `None` when the input is absent, has no `=`, or the name is empty.
fn parse_cookie_name_value(first: Option<&str>) -> Option<(&str, &str)> {
    let (raw_name, raw_value) = first.and_then(|pair| pair.split_once('='))?;
    let trimmed_name = raw_name.trim();
    (!trimmed_name.is_empty()).then_some((trimmed_name, raw_value.trim()))
}
349
350fn parse_set_cookie_attrs<'a>(parts: impl Iterator<Item = &'a str>) -> SetCookieAttrs {
351 let mut attrs = SetCookieAttrs::default();
352 for raw in parts {
353 let raw = raw.trim();
354 if raw.is_empty() {
355 continue;
356 }
357 if let Some((k, v)) = raw.split_once('=') {
358 let k = k.trim().to_ascii_lowercase();
359 let v = v.trim();
360 match k.as_str() {
361 "domain" => {
362 let mut d = v.trim_matches('"').trim().to_ascii_lowercase();
363 if let Some(stripped) = d.strip_prefix('.') {
364 d = stripped.to_string();
365 }
366 if !d.is_empty() {
367 attrs.domain = Some(d);
368 }
369 }
370 "path" => {
371 let p = v.trim_matches('"').trim();
372 if p.starts_with('/') {
373 attrs.path = Some(p.to_string());
374 }
375 }
376 "max-age" => {
377 if let Ok(n) = v.parse::<i64>() {
378 attrs.max_age = Some(n);
379 }
380 }
381 "samesite" => {
382 let ss = v.trim_matches('"').trim();
383 attrs.same_site = match ss.to_ascii_lowercase().as_str() {
384 "lax" => Some(CookieSameSite::Lax),
385 "strict" => Some(CookieSameSite::Strict),
386 "none" => Some(CookieSameSite::None),
387 _ => None,
388 };
389 }
390 "expires" => {
391 if let Some(t) = parse_http_date(v) {
392 attrs.expires_at = Some(t);
393 }
394 }
395 _ => {}
396 }
397 } else {
398 match raw.to_ascii_lowercase().as_str() {
399 "secure" => attrs.secure = true,
400 "httponly" => attrs.http_only = true,
401 _ => {}
402 }
403 }
404 }
405 attrs
406}
407
408fn cookie_domain_for_set_cookie(
409 host: Option<&str>,
410 domain_attr: Option<String>,
411) -> Option<CookieDomain> {
412 match domain_attr {
413 Some(d) => {
414 let h = host?;
415 let domain = CookieDomain::Domain(d);
416 if domain_matches(&domain, Some(h)) {
417 Some(domain)
418 } else {
419 None
420 }
421 }
422 None => {
423 let h = host?;
424 Some(CookieDomain::HostOnly(h.to_string()))
425 }
426 }
427}
428
429fn compute_cookie_expiration(
430 max_age: Option<i64>,
431 mut expires_at: Option<std::time::SystemTime>,
432) -> CookieExpiration {
433 let now = std::time::SystemTime::now();
434 if let Some(n) = max_age {
435 if n <= 0 {
436 return CookieExpiration::Delete;
437 }
438 let Ok(secs) = u64::try_from(n) else {
439 return CookieExpiration::Delete;
440 };
441 expires_at = now.checked_add(std::time::Duration::from_secs(secs));
442 }
443 if let Some(exp) = expires_at {
444 if exp <= now {
445 return CookieExpiration::Delete;
446 }
447 }
448 CookieExpiration::Keep(expires_at)
449}
450
451fn request_host(request: &Request) -> Option<String> {
452 let host = request.headers().get("host")?;
453 let s = std::str::from_utf8(host).ok()?;
454 let host = s.trim();
455 if host.is_empty() {
456 return None;
457 }
458 Some(host.split(':').next().unwrap_or(host).to_ascii_lowercase())
460}
461
462fn request_is_secure(request: &Request) -> bool {
463 if let Some(info) = request.get_extension::<crate::request::ConnectionInfo>() {
464 if info.is_tls {
465 return true;
466 }
467 }
468
469 if let Some(forwarded) = request.headers().get("forwarded") {
470 if let Ok(s) = std::str::from_utf8(forwarded) {
471 for entry in s.split(',') {
472 for param in entry.split(';') {
473 let param = param.trim();
474 if let Some((k, v)) = param.split_once('=') {
475 if k.trim().eq_ignore_ascii_case("proto") {
476 let proto = v.trim().trim_matches('"');
477 if proto.eq_ignore_ascii_case("https") {
478 return true;
479 }
480 }
481 }
482 }
483 }
484 }
485 }
486
487 if let Some(proto) = request.headers().get("x-forwarded-proto") {
488 let first = proto.split(|&b| b == b',').next().unwrap_or(proto);
489 let first = trim_ascii_bytes(first);
490 return first.eq_ignore_ascii_case(b"https");
491 }
492 if let Some(ssl) = request.headers().get("x-forwarded-ssl") {
493 return ssl.eq_ignore_ascii_case(b"on");
494 }
495 if let Some(https) = request.headers().get("front-end-https") {
496 return https.eq_ignore_ascii_case(b"on");
497 }
498
499 false
500}
501
/// Strips leading and trailing spaces and horizontal tabs (HTTP optional
/// whitespace) from a byte slice. Other bytes, including `\r` and `\n`, are
/// left untouched.
fn trim_ascii_bytes(bytes: &[u8]) -> &[u8] {
    let is_ows = |b: &u8| matches!(*b, b' ' | b'\t');
    let start = bytes.iter().position(|b| !is_ows(b)).unwrap_or(bytes.len());
    let end = bytes
        .iter()
        .rposition(|b| !is_ows(b))
        .map_or(start, |last| last + 1);
    &bytes[start..end]
}
511
/// Computes the default cookie path (RFC 6265 §5.1.4) for a request path:
/// everything before the last `/`, or `/` when the path is not absolute or
/// has only one `/`.
fn default_cookie_path(request_path: &str) -> String {
    let directory = match request_path.rfind('/') {
        Some(idx) if idx > 0 && request_path.starts_with('/') => &request_path[..idx],
        _ => "/",
    };
    directory.to_string()
}
525
526fn domain_matches(domain: &CookieDomain, host: Option<&str>) -> bool {
527 match domain {
528 CookieDomain::Any => true,
529 CookieDomain::HostOnly(d) => host.is_some_and(|h| h.eq_ignore_ascii_case(d)),
530 CookieDomain::Domain(d) => {
531 let Some(h) = host else { return false };
532 if h.eq_ignore_ascii_case(d) {
533 return true;
534 }
535 h.len() > d.len() && h.ends_with(d) && h.as_bytes()[h.len() - d.len() - 1] == b'.'
537 }
538 }
539}
540
/// RFC 6265 §5.1.4 path-match: the request path equals the cookie path, or
/// continues it at a `/` boundary. A cookie path of `/` matches every
/// absolute request path.
fn path_matches(cookie_path: &str, request_path: &str) -> bool {
    if cookie_path == "/" {
        return request_path.starts_with('/');
    }
    match request_path.strip_prefix(cookie_path) {
        None => false,
        Some(rest) => cookie_path.ends_with('/') || rest.is_empty() || rest.starts_with('/'),
    }
}
556
/// Parses a cookie `Expires` date using the RFC 6265 §5.1.1 algorithm.
///
/// Accepts the formats servers commonly emit: IMF-fixdate
/// (`Sun, 06 Nov 1994 08:49:37 GMT`), RFC 850 / Netscape style
/// (`Sunday, 06-Nov-94 08:49:37 GMT`) and asctime (`Sun Nov  6 08:49:37 1994`).
/// Month names are case-insensitive, two-digit years are widened
/// (70–99 → 19xx, 00–69 → 20xx), and the zone token is ignored (always UTC).
///
/// Returns `None` for missing components, out-of-range fields, years before
/// 1601, or calendar dates that do not exist (e.g. 30 Feb). Dates before the
/// Unix epoch are returned as past instants so the cookie is treated as
/// expired.
fn parse_http_date(input: &str) -> Option<std::time::SystemTime> {
    use std::time::{Duration, UNIX_EPOCH};

    /// RFC 6265 §5.1.1 `delimiter` set.
    fn is_delimiter(c: char) -> bool {
        matches!(
            c,
            '\x09' | '\x20'..='\x2F' | '\x3B'..='\x40' | '\x5B'..='\x60' | '\x7B'..='\x7E'
        )
    }

    /// Leading run of `min..=max` digits, which must not be followed by
    /// another digit; any trailing non-digit text is allowed.
    fn leading_number(token: &str, min_digits: usize, max_digits: usize) -> Option<u32> {
        let len = token.bytes().take_while(u8::is_ascii_digit).count();
        if len < min_digits || len > max_digits {
            return None;
        }
        token[..len].parse().ok()
    }

    /// `hh:mm:ss` where each field is one or two digits.
    fn parse_time(token: &str) -> Option<(u32, u32, u32)> {
        let mut rest = token;
        let mut fields = [0u32; 3];
        for (idx, field) in fields.iter_mut().enumerate() {
            if idx > 0 {
                rest = rest.strip_prefix(':')?;
            }
            let len = rest.bytes().take_while(u8::is_ascii_digit).count();
            if !(1..=2).contains(&len) {
                return None;
            }
            *field = rest[..len].parse().ok()?;
            rest = &rest[len..];
        }
        Some((fields[0], fields[1], fields[2]))
    }

    /// Month number from a token whose first three letters name a month.
    fn parse_month(token: &str) -> Option<u32> {
        const NAMES: [&str; 12] = [
            "jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec",
        ];
        let prefix = token.get(..3)?;
        (1u32..)
            .zip(NAMES)
            .find_map(|(number, name)| prefix.eq_ignore_ascii_case(name).then_some(number))
    }

    fn days_in_month(year: u32, month: u32) -> u32 {
        match month {
            2 if year % 4 == 0 && (year % 100 != 0 || year % 400 == 0) => 29,
            2 => 28,
            4 | 6 | 9 | 11 => 30,
            _ => 31,
        }
    }

    /// Days since 1970-01-01 for a proleptic Gregorian date (H. Hinnant).
    fn days_from_civil(year: i64, month: i64, day: i64) -> i64 {
        let y = if month <= 2 { year - 1 } else { year };
        let era = y.div_euclid(400);
        let yoe = y.rem_euclid(400);
        let mp = (month + 9) % 12; // March = 0 … February = 11
        let doy = (153 * mp + 2) / 5 + day - 1;
        let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
        era * 146_097 + doe - 719_468
    }

    let mut time: Option<(u32, u32, u32)> = None;
    let mut day: Option<u32> = None;
    let mut month: Option<u32> = None;
    let mut year: Option<u32> = None;

    // Each token fills the first still-missing component it matches, in the
    // order time, day-of-month, month, year (§5.1.1 step 2).
    for token in input.split(is_delimiter).filter(|t| !t.is_empty()) {
        if time.is_none() {
            if let Some(hms) = parse_time(token) {
                time = Some(hms);
                continue;
            }
        }
        if day.is_none() {
            if let Some(d) = leading_number(token, 1, 2) {
                day = Some(d);
                continue;
            }
        }
        if month.is_none() {
            if let Some(m) = parse_month(token) {
                month = Some(m);
                continue;
            }
        }
        if year.is_none() {
            if let Some(y) = leading_number(token, 2, 4) {
                year = Some(y);
            }
        }
    }

    let (hour, minute, second) = time?;
    let (day, month, mut year) = (day?, month?, year?);

    if (70..=99).contains(&year) {
        year += 1900;
    } else if year <= 69 {
        year += 2000;
    }

    if !(1..=31).contains(&day) || year < 1601 || hour > 23 || minute > 59 || second > 59 {
        return None;
    }
    if day > days_in_month(year, month) {
        return None;
    }

    let days = days_from_civil(i64::from(year), i64::from(month), i64::from(day));
    let secs = days * 86_400 + i64::from(hour) * 3_600 + i64::from(minute) * 60 + i64::from(second);
    let offset = Duration::from_secs(secs.unsigned_abs());
    if secs >= 0 {
        UNIX_EPOCH.checked_add(offset)
    } else {
        // Pre-1970 dates are still in the past. If the platform cannot
        // represent them, the epoch is an equally-expired stand-in.
        Some(UNIX_EPOCH.checked_sub(offset).unwrap_or(UNIX_EPOCH))
    }
}
616
617pub struct TestClient<H> {
651 handler: Arc<H>,
652 cookies: Arc<Mutex<CookieJar>>,
653 dependency_overrides: Arc<DependencyOverrides>,
654 seed: Option<u64>,
655 request_id_counter: Arc<std::sync::atomic::AtomicU64>,
656}
657
658impl<H: Handler + 'static> TestClient<H> {
659 pub fn new(handler: H) -> Self {
667 let dependency_overrides = handler
668 .dependency_overrides()
669 .unwrap_or_else(|| Arc::new(DependencyOverrides::new()));
670 Self {
671 handler: Arc::new(handler),
672 cookies: Arc::new(Mutex::new(CookieJar::new())),
673 dependency_overrides,
674 seed: None,
675 request_id_counter: Arc::new(std::sync::atomic::AtomicU64::new(1)),
676 }
677 }
678
679 pub fn with_seed(handler: H, seed: u64) -> Self {
690 let dependency_overrides = handler
691 .dependency_overrides()
692 .unwrap_or_else(|| Arc::new(DependencyOverrides::new()));
693 Self {
694 handler: Arc::new(handler),
695 cookies: Arc::new(Mutex::new(CookieJar::new())),
696 dependency_overrides,
697 seed: Some(seed),
698 request_id_counter: Arc::new(std::sync::atomic::AtomicU64::new(1)),
699 }
700 }
701
702 #[must_use]
704 pub fn seed(&self) -> Option<u64> {
705 self.seed
706 }
707
708 pub fn cookies(&self) -> parking_lot::MutexGuard<'_, CookieJar> {
713 self.cookies.lock()
714 }
715
716 pub fn clear_cookies(&self) {
718 self.cookies().clear();
719 }
720
721 #[must_use]
723 pub fn get(&self, path: &str) -> RequestBuilder<'_, H> {
724 RequestBuilder::new(self, Method::Get, path)
725 }
726
727 #[must_use]
729 pub fn post(&self, path: &str) -> RequestBuilder<'_, H> {
730 RequestBuilder::new(self, Method::Post, path)
731 }
732
733 #[must_use]
735 pub fn put(&self, path: &str) -> RequestBuilder<'_, H> {
736 RequestBuilder::new(self, Method::Put, path)
737 }
738
739 #[must_use]
741 pub fn delete(&self, path: &str) -> RequestBuilder<'_, H> {
742 RequestBuilder::new(self, Method::Delete, path)
743 }
744
745 #[must_use]
747 pub fn patch(&self, path: &str) -> RequestBuilder<'_, H> {
748 RequestBuilder::new(self, Method::Patch, path)
749 }
750
751 #[must_use]
753 pub fn options(&self, path: &str) -> RequestBuilder<'_, H> {
754 RequestBuilder::new(self, Method::Options, path)
755 }
756
757 #[must_use]
759 pub fn head(&self, path: &str) -> RequestBuilder<'_, H> {
760 RequestBuilder::new(self, Method::Head, path)
761 }
762
763 #[must_use]
765 pub fn request(&self, method: Method, path: &str) -> RequestBuilder<'_, H> {
766 RequestBuilder::new(self, method, path)
767 }
768
769 pub fn override_dependency<T, F, Fut>(&self, f: F)
771 where
772 T: FromDependency,
773 F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
774 Fut: Future<Output = Result<T, T::Error>> + Send + 'static,
775 {
776 self.dependency_overrides.insert::<T, F, Fut>(f);
777 }
778
779 pub fn override_dependency_value<T>(&self, value: T)
781 where
782 T: FromDependency,
783 {
784 self.dependency_overrides.insert_value(value);
785 }
786
787 pub fn clear_dependency_overrides(&self) {
789 self.dependency_overrides.clear();
790 }
791
792 fn next_request_id(&self) -> u64 {
794 self.request_id_counter
795 .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
796 }
797
798 fn execute(&self, mut request: Request) -> TestResponse {
802 if !request.headers().contains("host") {
805 request.headers_mut().insert("host", b"testserver".to_vec());
806 }
807
808 {
810 let jar = self.cookies();
811 if let Some(cookie_header) = jar.cookie_header_for_request(&request) {
812 request
813 .headers_mut()
814 .insert("cookie", cookie_header.into_bytes());
815 }
816 }
817
818 let cx = Cx::for_testing();
820 let request_id = self.next_request_id();
821 let ctx =
822 RequestContext::with_overrides(cx, request_id, Arc::clone(&self.dependency_overrides));
823
824 let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
826
827 {
829 let mut jar = self.cookies();
830 for (name, value) in response.headers() {
831 if name.eq_ignore_ascii_case("set-cookie") {
832 jar.parse_set_cookie(&request, value);
833 }
834 }
835 }
836
837 TestResponse::new(response, request_id)
838 }
839}
840
841impl<H> Clone for TestClient<H> {
842 fn clone(&self) -> Self {
843 Self {
844 handler: Arc::clone(&self.handler),
845 cookies: Arc::clone(&self.cookies),
846 dependency_overrides: Arc::clone(&self.dependency_overrides),
847 seed: self.seed,
848 request_id_counter: Arc::clone(&self.request_id_counter),
849 }
850 }
851}
852
853pub struct RequestBuilder<'a, H> {
868 client: &'a TestClient<H>,
869 method: Method,
870 path: String,
871 query: Option<String>,
872 headers: Vec<(String, Vec<u8>)>,
873 body: Body,
874}
875
876impl<'a, H: Handler + 'static> RequestBuilder<'a, H> {
877 fn new(client: &'a TestClient<H>, method: Method, path: &str) -> Self {
879 let (path, query) = if let Some(idx) = path.find('?') {
881 (path[..idx].to_string(), Some(path[idx + 1..].to_string()))
882 } else {
883 (path.to_string(), None)
884 };
885
886 Self {
887 client,
888 method,
889 path,
890 query,
891 headers: Vec::new(),
892 body: Body::Empty,
893 }
894 }
895
896 #[must_use]
906 pub fn query(mut self, key: &str, value: &str) -> Self {
907 let param = format!("{key}={value}");
908 self.query = Some(match self.query {
909 Some(q) => format!("{q}&{param}"),
910 None => param,
911 });
912 self
913 }
914
915 #[must_use]
923 pub fn header(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
924 self.headers.push((name.into(), value.into()));
925 self
926 }
927
928 #[must_use]
930 pub fn header_str(self, name: impl Into<String>, value: &str) -> Self {
931 self.header(name, value.as_bytes().to_vec())
932 }
933
934 #[must_use]
942 pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
943 self.body = Body::Bytes(body.into());
944 self
945 }
946
947 #[must_use]
949 pub fn body_str(self, body: &str) -> Self {
950 self.body(body.as_bytes().to_vec())
951 }
952
953 #[must_use]
966 pub fn json<T: serde::Serialize>(mut self, value: &T) -> Self {
967 let bytes = serde_json::to_vec(value).expect("JSON serialization failed");
968 self.body = Body::Bytes(bytes);
969 self.headers
970 .push(("content-type".to_string(), b"application/json".to_vec()));
971 self
972 }
973
974 #[must_use]
978 pub fn cookie(self, name: &str, value: &str) -> Self {
979 let cookie = format!("{name}={value}");
980 self.header("cookie", cookie.into_bytes())
981 }
982
983 #[must_use]
991 pub fn send(self) -> TestResponse {
992 let mut request = Request::new(self.method, self.path);
993 request.set_query(self.query);
994 request.set_body(self.body);
995
996 for (name, value) in self.headers {
997 request.headers_mut().insert(name, value);
998 }
999
1000 self.client.execute(request)
1001 }
1002}
1003
1004#[derive(Debug)]
1021pub struct TestResponse {
1022 inner: Response,
1023 request_id: u64,
1024}
1025
1026impl TestResponse {
1027 fn new(response: Response, request_id: u64) -> Self {
1029 Self {
1030 inner: response,
1031 request_id,
1032 }
1033 }
1034
1035 #[must_use]
1037 pub fn request_id(&self) -> u64 {
1038 self.request_id
1039 }
1040
1041 #[must_use]
1043 pub fn status(&self) -> StatusCode {
1044 self.inner.status()
1045 }
1046
1047 #[must_use]
1049 pub fn status_code(&self) -> u16 {
1050 self.inner.status().as_u16()
1051 }
1052
1053 #[must_use]
1055 pub fn is_success(&self) -> bool {
1056 let code = self.status_code();
1057 (200..300).contains(&code)
1058 }
1059
1060 #[must_use]
1062 pub fn is_redirect(&self) -> bool {
1063 let code = self.status_code();
1064 (300..400).contains(&code)
1065 }
1066
1067 #[must_use]
1069 pub fn is_client_error(&self) -> bool {
1070 let code = self.status_code();
1071 (400..500).contains(&code)
1072 }
1073
1074 #[must_use]
1076 pub fn is_server_error(&self) -> bool {
1077 let code = self.status_code();
1078 (500..600).contains(&code)
1079 }
1080
1081 #[must_use]
1083 pub fn headers(&self) -> &[(String, Vec<u8>)] {
1084 self.inner.headers()
1085 }
1086
1087 #[must_use]
1089 pub fn header(&self, name: &str) -> Option<&[u8]> {
1090 let name_lower = name.to_ascii_lowercase();
1091 self.inner
1092 .headers()
1093 .iter()
1094 .find(|(n, _)| n.to_ascii_lowercase() == name_lower)
1095 .map(|(_, v)| v.as_slice())
1096 }
1097
1098 #[must_use]
1100 pub fn header_str(&self, name: &str) -> Option<&str> {
1101 self.header(name).and_then(|v| std::str::from_utf8(v).ok())
1102 }
1103
1104 #[must_use]
1106 pub fn content_type(&self) -> Option<&str> {
1107 self.header_str("content-type")
1108 }
1109
1110 #[must_use]
1112 pub fn bytes(&self) -> &[u8] {
1113 match self.inner.body_ref() {
1114 ResponseBody::Empty => &[],
1115 ResponseBody::Bytes(b) => b,
1116 ResponseBody::Stream(_) => {
1117 panic!("streaming response body not supported in TestResponse")
1118 }
1119 }
1120 }
1121
1122 #[must_use]
1128 pub fn text(&self) -> &str {
1129 std::str::from_utf8(self.bytes()).expect("response body is not valid UTF-8")
1130 }
1131
1132 #[must_use]
1134 pub fn text_opt(&self) -> Option<&str> {
1135 std::str::from_utf8(self.bytes()).ok()
1136 }
1137
1138 pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
1153 serde_json::from_slice(self.bytes())
1154 }
1155
1156 #[must_use]
1158 pub fn content_length(&self) -> usize {
1159 self.bytes().len()
1160 }
1161
1162 #[must_use]
1164 pub fn into_inner(self) -> Response {
1165 self.inner
1166 }
1167
1168 #[must_use]
1178 pub fn assert_status(&self, expected: StatusCode) -> &Self {
1179 assert_eq!(
1180 self.status(),
1181 expected,
1182 "Expected status {}, got {} for request {}",
1183 expected.as_u16(),
1184 self.status_code(),
1185 self.request_id
1186 );
1187 self
1188 }
1189
1190 #[must_use]
1196 pub fn assert_status_code(&self, expected: u16) -> &Self {
1197 assert_eq!(
1198 self.status_code(),
1199 expected,
1200 "Expected status {expected}, got {} for request {}",
1201 self.status_code(),
1202 self.request_id
1203 );
1204 self
1205 }
1206
1207 #[must_use]
1213 pub fn assert_success(&self) -> &Self {
1214 assert!(
1215 self.is_success(),
1216 "Expected success status, got {} for request {}",
1217 self.status_code(),
1218 self.request_id
1219 );
1220 self
1221 }
1222
1223 #[must_use]
1229 pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
1230 let actual = self.header_str(name);
1231 assert_eq!(
1232 actual,
1233 Some(expected),
1234 "Expected header '{name}' to be '{expected}', got {:?} for request {}",
1235 actual,
1236 self.request_id
1237 );
1238 self
1239 }
1240
1241 #[must_use]
1247 pub fn assert_text(&self, expected: &str) -> &Self {
1248 assert_eq!(
1249 self.text(),
1250 expected,
1251 "Body mismatch for request {}",
1252 self.request_id
1253 );
1254 self
1255 }
1256
1257 #[must_use]
1263 pub fn assert_text_contains(&self, expected: &str) -> &Self {
1264 assert!(
1265 self.text().contains(expected),
1266 "Expected body to contain '{}', got '{}' for request {}",
1267 expected,
1268 self.text(),
1269 self.request_id
1270 );
1271 self
1272 }
1273
1274 #[must_use]
1280 pub fn assert_json<T>(&self, expected: &T) -> &Self
1281 where
1282 T: serde::de::DeserializeOwned + serde::Serialize + PartialEq + std::fmt::Debug,
1283 {
1284 let actual: T = self.json().expect("Failed to parse response as JSON");
1285 assert_eq!(
1286 actual, *expected,
1287 "JSON body mismatch for request {}",
1288 self.request_id
1289 );
1290 self
1291 }
1292
1293 #[must_use]
1311 pub fn assert_json_contains(&self, expected: &serde_json::Value) -> &Self {
1312 let actual: serde_json::Value = self.json().expect("Failed to parse response as JSON");
1313 if let Err(path) = json_contains(&actual, expected) {
1314 panic!(
1315 "JSON partial match failed at path '{}' for request {}\n\
1316 Expected (partial):\n{}\n\
1317 Actual:\n{}",
1318 path,
1319 self.request_id,
1320 serde_json::to_string_pretty(expected).unwrap_or_else(|_| format!("{expected:?}")),
1321 serde_json::to_string_pretty(&actual).unwrap_or_else(|_| format!("{actual:?}")),
1322 );
1323 }
1324 self
1325 }
1326
1327 #[cfg(feature = "regex")]
1339 #[must_use]
1340 pub fn assert_body_matches(&self, pattern: &str) -> &Self {
1341 let re = regex::Regex::new(pattern)
1342 .unwrap_or_else(|e| panic!("Invalid regex pattern '{pattern}': {e}"));
1343 let body = self.text();
1344 assert!(
1345 re.is_match(body),
1346 "Expected body to match pattern '{}', got '{}' for request {}",
1347 pattern,
1348 body,
1349 self.request_id
1350 );
1351 self
1352 }
1353
1354 #[cfg(feature = "regex")]
1361 #[must_use]
1362 pub fn assert_header_matches(&self, name: &str, pattern: &str) -> &Self {
1363 let re = regex::Regex::new(pattern)
1364 .unwrap_or_else(|e| panic!("Invalid regex pattern '{pattern}': {e}"));
1365 let value = self
1366 .header_str(name)
1367 .unwrap_or_else(|| panic!("Header '{name}' not found for request {}", self.request_id));
1368 assert!(
1369 re.is_match(value),
1370 "Expected header '{}' to match pattern '{}', got '{}' for request {}",
1371 name,
1372 pattern,
1373 value,
1374 self.request_id
1375 );
1376 self
1377 }
1378
1379 #[must_use]
1385 pub fn assert_header_exists(&self, name: &str) -> &Self {
1386 assert!(
1387 self.header(name).is_some(),
1388 "Expected header '{}' to exist for request {}",
1389 name,
1390 self.request_id
1391 );
1392 self
1393 }
1394
1395 #[must_use]
1401 pub fn assert_header_missing(&self, name: &str) -> &Self {
1402 assert!(
1403 self.header(name).is_none(),
1404 "Expected header '{}' to not exist for request {}, but found {:?}",
1405 name,
1406 self.request_id,
1407 self.header_str(name)
1408 );
1409 self
1410 }
1411
1412 #[must_use]
1421 pub fn assert_content_type_contains(&self, expected: &str) -> &Self {
1422 let ct = self.content_type().unwrap_or_else(|| {
1423 panic!(
1424 "Content-Type header not found for request {}",
1425 self.request_id
1426 )
1427 });
1428 assert!(
1429 ct.contains(expected),
1430 "Expected Content-Type to contain '{}', got '{}' for request {}",
1431 expected,
1432 ct,
1433 self.request_id
1434 );
1435 self
1436 }
1437}
1438
1439pub fn json_contains(
1471 actual: &serde_json::Value,
1472 expected: &serde_json::Value,
1473) -> Result<(), String> {
1474 json_contains_at_path(actual, expected, "$")
1475}
1476
1477fn json_contains_at_path(
1478 actual: &serde_json::Value,
1479 expected: &serde_json::Value,
1480 path: &str,
1481) -> Result<(), String> {
1482 use serde_json::Value;
1483
1484 match (actual, expected) {
1485 (Value::Object(actual_obj), Value::Object(expected_obj)) => {
1487 for (key, expected_val) in expected_obj {
1488 let child_path = format!("{path}.{key}");
1489 match actual_obj.get(key) {
1490 Some(actual_val) => {
1491 json_contains_at_path(actual_val, expected_val, &child_path)?;
1492 }
1493 None => {
1494 return Err(child_path);
1495 }
1496 }
1497 }
1498 Ok(())
1499 }
1500 (Value::Array(actual_arr), Value::Array(expected_arr)) => {
1502 if actual_arr.len() != expected_arr.len() {
1503 return Err(format!("{path}[length]"));
1504 }
1505 for (i, (actual_elem, expected_elem)) in
1506 actual_arr.iter().zip(expected_arr.iter()).enumerate()
1507 {
1508 let child_path = format!("{path}[{i}]");
1509 json_contains_at_path(actual_elem, expected_elem, &child_path)?;
1510 }
1511 Ok(())
1512 }
1513 _ => {
1515 if actual == expected {
1516 Ok(())
1517 } else {
1518 Err(path.to_string())
1519 }
1520 }
1521 }
1522}
1523
/// Conversion into a numeric HTTP status code.
///
/// Used by the `assert_status!` macro so its expected value can be a number
/// or a `StatusCode`.
pub trait IntoStatusU16 {
    /// Returns the status code as a `u16`.
    fn into_status_u16(self) -> u16;
}

impl IntoStatusU16 for u16 {
    /// Already a `u16`, so the value is passed through unchanged.
    fn into_status_u16(self) -> u16 {
        self
    }
}
1541
1542impl IntoStatusU16 for StatusCode {
1543 fn into_status_u16(self) -> u16 {
1544 self.as_u16()
1545 }
1546}
1547
1548impl IntoStatusU16 for i32 {
1550 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
1551 fn into_status_u16(self) -> u16 {
1552 self as u16
1554 }
1555}
1556
/// Asserts that a response has the expected status code.
///
/// The expected value may be anything implementing
/// `testing::IntoStatusU16` (e.g. a `u16`, an integer literal or a
/// `StatusCode`). An optional trailing format message is printed first on
/// failure. The panic message includes the request id and the body, or
/// `<non-UTF8 body>`.
#[macro_export]
macro_rules! assert_status {
    ($response:expr, $expected:expr) => {{
        let resp = &$response;
        let got = resp.status_code();
        let want: u16 = $crate::testing::IntoStatusU16::into_status_u16($expected);
        if got != want {
            panic!(
                "assertion failed: `(response.status() == {want})`\n\
                 expected status: {want}\n\
                 actual status: {got}\n\
                 request id: {id}\n\
                 response body: {body}",
                want = want,
                got = got,
                id = resp.request_id(),
                body = resp.text_opt().unwrap_or("<non-UTF8 body>"),
            );
        }
    }};
    ($response:expr, $expected:expr, $($msg:tt)+) => {{
        let resp = &$response;
        let got = resp.status_code();
        let want: u16 = $crate::testing::IntoStatusU16::into_status_u16($expected);
        if got != want {
            panic!(
                "{note}\n\
                 assertion failed: `(response.status() == {want})`\n\
                 expected status: {want}\n\
                 actual status: {got}\n\
                 request id: {id}\n\
                 response body: {body}",
                note = format_args!($($msg)+),
                want = want,
                got = got,
                id = resp.request_id(),
                body = resp.text_opt().unwrap_or("<non-UTF8 body>"),
            );
        }
    }};
}
1623
/// Asserts that response header `name` has exactly the string value
/// `expected`.
///
/// A header that is missing or not UTF-8 counts as a mismatch. An optional
/// trailing format message is printed first on failure.
#[macro_export]
macro_rules! assert_header {
    ($response:expr, $name:expr, $expected:expr) => {{
        let resp = &$response;
        let header_name = $name;
        let want = $expected;
        let got = resp.header_str(header_name);
        if got != Some(want) {
            panic!(
                "assertion failed: `(response.header(\"{n}\") == \"{w}\")`\n\
                 expected header '{n}': \"{w}\"\n\
                 actual header '{n}': {g:?}\n\
                 request id: {id}",
                n = header_name,
                w = want,
                g = got,
                id = resp.request_id(),
            );
        }
    }};
    ($response:expr, $name:expr, $expected:expr, $($msg:tt)+) => {{
        let resp = &$response;
        let header_name = $name;
        let want = $expected;
        let got = resp.header_str(header_name);
        if got != Some(want) {
            panic!(
                "{note}\n\
                 assertion failed: `(response.header(\"{n}\") == \"{w}\")`\n\
                 expected header '{n}': \"{w}\"\n\
                 actual header '{n}': {g:?}\n\
                 request id: {id}",
                note = format_args!($($msg)+),
                n = header_name,
                w = want,
                g = got,
                id = resp.request_id(),
            );
        }
    }};
}
1686
/// Asserts that the response body text contains the substring `expected`.
///
/// An optional trailing format message is printed first on failure. The
/// body is read with `text()`, which panics if it is not UTF-8.
#[macro_export]
macro_rules! assert_body_contains {
    ($response:expr, $expected:expr) => {{
        let resp = &$response;
        let needle = $expected;
        let body = resp.text();
        if !body.contains(needle) {
            panic!(
                "assertion failed: response body does not contain \"{n}\"\n\
                 expected substring: \"{n}\"\n\
                 actual body: \"{b}\"\n\
                 request id: {id}",
                n = needle,
                b = body,
                id = resp.request_id(),
            );
        }
    }};
    ($response:expr, $expected:expr, $($msg:tt)+) => {{
        let resp = &$response;
        let needle = $expected;
        let body = resp.text();
        if !body.contains(needle) {
            panic!(
                "{note}\n\
                 assertion failed: response body does not contain \"{n}\"\n\
                 expected substring: \"{n}\"\n\
                 actual body: \"{b}\"\n\
                 request id: {id}",
                note = format_args!($($msg)+),
                n = needle,
                b = body,
                id = resp.request_id(),
            );
        }
    }};
}
1738
/// Asserts that the response body, parsed as JSON, contains the given JSON
/// literal as a partial match (checked by `$crate::testing::json_contains`).
///
/// The expected value is written with `serde_json::json!` syntax. Panics if
/// the body does not parse as JSON, or, when the match fails, with the failing
/// path, both documents pretty-printed and the request id. An optional
/// trailing format string (and arguments) is printed as the first line of the
/// panic message.
#[macro_export]
macro_rules! assert_json {
    ($response:expr, $expected:tt) => {{
        let resp = &$response;
        let wanted = serde_json::json!($expected);
        let parsed: serde_json::Value = resp
            .json()
            .expect("Failed to parse response body as JSON");
        // Pretty-print for the failure message, falling back to `Debug`.
        let pretty = |value: &serde_json::Value| {
            serde_json::to_string_pretty(value).unwrap_or_else(|_| format!("{:?}", value))
        };
        match $crate::testing::json_contains(&parsed, &wanted) {
            Ok(_) => {}
            Err(path) => panic!(
                "assertion failed: JSON partial match failed at path '{}'\n\
                 expected (partial):\n{}\n\
                 actual:\n{}\n\
                 request id: {}",
                path,
                pretty(&wanted),
                pretty(&parsed),
                resp.request_id()
            ),
        }
    }};
    ($response:expr, $expected:tt, $($msg:tt)+) => {{
        let resp = &$response;
        let wanted = serde_json::json!($expected);
        let parsed: serde_json::Value = resp
            .json()
            .expect("Failed to parse response body as JSON");
        let pretty = |value: &serde_json::Value| {
            serde_json::to_string_pretty(value).unwrap_or_else(|_| format!("{:?}", value))
        };
        match $crate::testing::json_contains(&parsed, &wanted) {
            Ok(_) => {}
            Err(path) => panic!(
                "{}\n\
                 assertion failed: JSON partial match failed at path '{}'\n\
                 expected (partial):\n{}\n\
                 actual:\n{}\n\
                 request id: {}",
                format_args!($($msg)+),
                path,
                pretty(&wanted),
                pretty(&parsed),
                resp.request_id()
            ),
        }
    }};
}
1810
/// Asserts that the response body (as returned by `response.text()`)
/// matches a regular expression.
///
/// Panics immediately if the pattern does not compile. When the body does not
/// match, the panic message shows the pattern, the body and the request id.
/// An optional trailing format string (and arguments) is printed as the first
/// line of the panic message. Only available with the `regex` feature.
#[cfg(feature = "regex")]
#[macro_export]
macro_rules! assert_body_matches {
    ($response:expr, $pattern:expr) => {{
        let resp = &$response;
        let pat = $pattern;
        let compiled = match regex::Regex::new(pat) {
            Ok(re) => re,
            Err(e) => panic!("Invalid regex pattern '{}': {}", pat, e),
        };
        let body = resp.text();
        assert!(
            compiled.is_match(body),
            "assertion failed: response body does not match pattern\n\
             pattern: \"{}\"\n\
             actual body: \"{}\"\n\
             request id: {}",
            pat,
            body,
            resp.request_id()
        );
    }};
    ($response:expr, $pattern:expr, $($msg:tt)+) => {{
        let resp = &$response;
        let pat = $pattern;
        let compiled = match regex::Regex::new(pat) {
            Ok(re) => re,
            Err(e) => panic!("Invalid regex pattern '{}': {}", pat, e),
        };
        let body = resp.text();
        assert!(
            compiled.is_match(body),
            "{}\n\
             assertion failed: response body does not match pattern\n\
             pattern: \"{}\"\n\
             actual body: \"{}\"\n\
             request id: {}",
            format_args!($($msg)+),
            pat,
            body,
            resp.request_id()
        );
    }};
}
1864
/// Configuration for a lab (deterministic) test run.
///
/// Built either from [`Default`] or [`LabTestConfig::new`] and refined with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct LabTestConfig {
    /// Seed for the run (default `42`).
    pub seed: u64,
    /// Whether chaos injection is enabled (default `false`).
    pub chaos_enabled: bool,
    /// Chaos intensity; the builder methods keep it within `[0.0, 1.0]`
    /// (default `0.0`).
    pub chaos_intensity: f64,
    /// Maximum number of steps, or `None` for no limit
    /// (default `Some(10_000)`).
    pub max_steps: Option<u64>,
    /// Whether traces are captured (default `false`).
    pub capture_traces: bool,
}

impl Default for LabTestConfig {
    fn default() -> Self {
        Self {
            seed: 42,
            chaos_enabled: false,
            chaos_intensity: 0.0,
            max_steps: Some(10_000),
            capture_traces: false,
        }
    }
}

impl LabTestConfig {
    /// Creates a default configuration with the given seed.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            seed,
            ..Default::default()
        }
    }

    /// Enables chaos at intensity `0.05`.
    #[must_use]
    pub fn with_light_chaos(mut self) -> Self {
        self.chaos_enabled = true;
        self.chaos_intensity = 0.05;
        self
    }

    /// Enables chaos at intensity `0.2`.
    #[must_use]
    pub fn with_heavy_chaos(mut self) -> Self {
        self.chaos_enabled = true;
        self.chaos_intensity = 0.2;
        self
    }

    /// Sets the chaos intensity, clamped into `[0.0, 1.0]`.
    ///
    /// Chaos is enabled exactly when the resulting intensity is positive.
    /// `NaN` is treated as `0.0` (chaos disabled): `f64::clamp` propagates
    /// `NaN`, which would otherwise be stored as the intensity.
    #[must_use]
    pub fn with_chaos_intensity(mut self, intensity: f64) -> Self {
        let intensity = if intensity.is_nan() {
            0.0
        } else {
            intensity.clamp(0.0, 1.0)
        };
        self.chaos_enabled = intensity > 0.0;
        self.chaos_intensity = intensity;
        self
    }

    /// Limits the run to at most `max` steps.
    #[must_use]
    pub fn with_max_steps(mut self, max: u64) -> Self {
        self.max_steps = Some(max);
        self
    }

    /// Removes the step limit.
    #[must_use]
    pub fn without_step_limit(mut self) -> Self {
        self.max_steps = None;
        self
    }

    /// Enables trace capture.
    #[must_use]
    pub fn with_traces(mut self) -> Self {
        self.capture_traces = true;
        self
    }
}
1974
/// Counters describing chaos injection during a test run.
#[derive(Debug, Clone, Default)]
pub struct TestChaosStats {
    /// Number of decision points at which chaos could have been injected.
    pub decision_points: u64,
    /// Number of delays injected.
    pub delays_injected: u64,
    /// Number of cancellations injected.
    pub cancellations_injected: u64,
    /// Number of steps executed.
    pub steps_executed: u64,
}

impl TestChaosStats {
    /// Ratio of injections (delays plus cancellations) to decision points.
    ///
    /// Returns `0.0` when no decision points were recorded. The two counters
    /// are summed as `f64` rather than `u64`, so extreme counter values cannot
    /// overflow (which panics in debug builds) before the division.
    #[must_use]
    // Counters are converted to f64 for a ratio; precision loss above 2^53 is
    // acceptable for this diagnostic.
    #[allow(clippy::cast_precision_loss)]
    pub fn injection_rate(&self) -> f64 {
        if self.decision_points == 0 {
            return 0.0;
        }
        let injected = self.delays_injected as f64 + self.cancellations_injected as f64;
        injected / self.decision_points as f64
    }

    /// Returns `true` if at least one delay or cancellation was injected.
    #[must_use]
    pub fn had_chaos(&self) -> bool {
        self.delays_injected > 0 || self.cancellations_injected > 0
    }
}
2009
/// Manually driven clock for tests.
///
/// The current time is kept as whole microseconds in an `Arc`-shared atomic,
/// so every clone of a `MockTime` reads and drives the same clock.
/// Sub-microsecond precision is discarded. Values that do not fit in `u64`
/// microseconds saturate at `u64::MAX` instead of being truncated or wrapping.
#[derive(Debug, Clone)]
pub struct MockTime {
    /// Current time in microseconds, shared between clones.
    current_us: Arc<std::sync::atomic::AtomicU64>,
}

impl Default for MockTime {
    fn default() -> Self {
        Self::new()
    }
}

impl MockTime {
    /// Creates a clock positioned at zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            current_us: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Creates a clock positioned at `initial`, saturating at `u64::MAX` microseconds.
    #[must_use]
    pub fn starting_at(initial: std::time::Duration) -> Self {
        Self {
            current_us: Arc::new(std::sync::atomic::AtomicU64::new(Self::saturating_micros(
                initial,
            ))),
        }
    }

    /// Returns the current mock time.
    #[must_use]
    pub fn now(&self) -> std::time::Duration {
        std::time::Duration::from_micros(self.current_us.load(std::sync::atomic::Ordering::Relaxed))
    }

    /// Returns the time elapsed since zero; identical to [`MockTime::now`].
    #[must_use]
    pub fn elapsed(&self) -> std::time::Duration {
        self.now()
    }

    /// Moves the clock forward by `duration`, saturating at `u64::MAX` microseconds.
    pub fn advance(&self, duration: std::time::Duration) {
        let delta = Self::saturating_micros(duration);
        // `fetch_add` would wrap on overflow, so use a saturating update loop.
        // The closure always returns `Some`, so the update cannot fail.
        let _ = self.current_us.fetch_update(
            std::sync::atomic::Ordering::Relaxed,
            std::sync::atomic::Ordering::Relaxed,
            |current| Some(current.saturating_add(delta)),
        );
    }

    /// Sets the clock to `time`, saturating at `u64::MAX` microseconds.
    pub fn set(&self, time: std::time::Duration) {
        self.current_us.store(
            Self::saturating_micros(time),
            std::sync::atomic::Ordering::Relaxed,
        );
    }

    /// Resets the clock to zero.
    pub fn reset(&self) {
        self.current_us
            .store(0, std::sync::atomic::Ordering::Relaxed);
    }

    /// Converts `duration` to whole microseconds, saturating at `u64::MAX`.
    ///
    /// `Duration::as_micros` returns a `u128`; a plain `as u64` cast would
    /// silently keep only the low 64 bits for very large durations.
    fn saturating_micros(duration: std::time::Duration) -> u64 {
        u64::try_from(duration.as_micros()).unwrap_or(u64::MAX)
    }
}
2091
2092#[derive(Debug)]
2096pub struct CancellationTestResult {
2097 pub completed: bool,
2099 pub cancelled_at_checkpoint: bool,
2101 pub response: Option<Response>,
2103 pub cancellation_point: Option<String>,
2105}
2106
2107impl CancellationTestResult {
2108 #[must_use]
2110 pub fn gracefully_cancelled(&self) -> bool {
2111 !self.completed && self.cancelled_at_checkpoint
2112 }
2113
2114 #[must_use]
2116 pub fn completed_despite_cancel(&self) -> bool {
2117 self.completed
2118 }
2119}
2120
2121pub struct CancellationTest<H> {
2135 handler: H,
2136 seed: u64,
2137}
2138
2139impl<H: Handler + 'static> CancellationTest<H> {
2140 #[must_use]
2142 pub fn new(handler: H) -> Self {
2143 Self { handler, seed: 42 }
2144 }
2145
2146 #[must_use]
2148 pub fn with_seed(mut self, seed: u64) -> Self {
2149 self.seed = seed;
2150 self
2151 }
2152
2153 pub fn test_respects_cancellation(&self) -> CancellationTestResult {
2159 let cx = asupersync::Cx::for_testing();
2160 let ctx = RequestContext::new(cx, 1);
2161
2162 ctx.cx().set_cancel_requested(true);
2164
2165 let mut request = Request::new(Method::Get, "/test");
2166 let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
2167
2168 let is_cancelled_response = response.status().as_u16() == 499
2170 || response.status().as_u16() == 504
2171 || response.status().as_u16() == 503;
2172
2173 CancellationTestResult {
2174 completed: true,
2175 cancelled_at_checkpoint: is_cancelled_response,
2176 response: Some(response),
2177 cancellation_point: None,
2178 }
2179 }
2180
2181 pub fn complete_normally(&self) -> CancellationTestResult {
2183 let cx = asupersync::Cx::for_testing();
2184 let ctx = RequestContext::new(cx, 1);
2185 let mut request = Request::new(Method::Get, "/test");
2186
2187 let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
2188
2189 CancellationTestResult {
2190 completed: true,
2191 cancelled_at_checkpoint: false,
2192 response: Some(response),
2193 cancellation_point: None,
2194 }
2195 }
2196
2197 pub fn test_with_cancel_callback<F>(
2202 &self,
2203 path: &str,
2204 mut cancel_fn: F,
2205 ) -> CancellationTestResult
2206 where
2207 F: FnMut(&RequestContext) -> bool,
2208 {
2209 let cx = asupersync::Cx::for_testing();
2210 let ctx = RequestContext::new(cx, 1);
2211
2212 if cancel_fn(&ctx) {
2214 ctx.cx().set_cancel_requested(true);
2215 }
2216
2217 let mut request = Request::new(Method::Get, path);
2218 let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
2219
2220 let is_cancelled = ctx.is_cancelled();
2221 let is_cancelled_response =
2222 response.status().as_u16() == 499 || response.status().as_u16() == 504;
2223
2224 CancellationTestResult {
2225 completed: true,
2226 cancelled_at_checkpoint: is_cancelled && is_cancelled_response,
2227 response: Some(response),
2228 cancellation_point: None,
2229 }
2230 }
2231}
2232
2233#[cfg(test)]
2234mod tests {
2235 use super::*;
2236 use crate::app::App;
2237 use crate::dependency::{Depends, FromDependency};
2238 use crate::error::HttpError;
2239 use crate::extract::FromRequest;
2240 use crate::middleware::BoxFuture;
2241
2242 struct EchoHandler;
2244
2245 impl Handler for EchoHandler {
2246 fn call<'a>(
2247 &'a self,
2248 _ctx: &'a RequestContext,
2249 req: &'a mut Request,
2250 ) -> BoxFuture<'a, Response> {
2251 let method = format!("{:?}", req.method());
2252 let path = req.path().to_string();
2253 let body = format!("Method: {method}, Path: {path}");
2254 Box::pin(async move {
2255 Response::ok()
2256 .header("content-type", b"text/plain".to_vec())
2257 .body(ResponseBody::Bytes(body.into_bytes()))
2258 })
2259 }
2260 }
2261
2262 struct CookieHandler;
2264
2265 impl Handler for CookieHandler {
2266 fn call<'a>(
2267 &'a self,
2268 _ctx: &'a RequestContext,
2269 _req: &'a mut Request,
2270 ) -> BoxFuture<'a, Response> {
2271 Box::pin(async move {
2272 Response::ok()
2273 .header("set-cookie", b"session=abc123".to_vec())
2274 .body(ResponseBody::Bytes(b"Cookie set".to_vec()))
2275 })
2276 }
2277 }
2278
2279 struct CookieEchoHandler;
2281
2282 impl Handler for CookieEchoHandler {
2283 fn call<'a>(
2284 &'a self,
2285 _ctx: &'a RequestContext,
2286 req: &'a mut Request,
2287 ) -> BoxFuture<'a, Response> {
2288 let cookie = req.headers().get("cookie").map_or_else(
2289 || "no cookies".to_string(),
2290 |v| String::from_utf8_lossy(v).to_string(),
2291 );
2292 Box::pin(async move { Response::ok().body(ResponseBody::Bytes(cookie.into_bytes())) })
2293 }
2294 }
2295
2296 #[derive(Clone)]
2297 struct OverrideDep {
2298 value: usize,
2299 }
2300
2301 impl FromDependency for OverrideDep {
2302 type Error = HttpError;
2303
2304 async fn from_dependency(
2305 _ctx: &RequestContext,
2306 _req: &mut Request,
2307 ) -> Result<Self, Self::Error> {
2308 Ok(Self { value: 1 })
2309 }
2310 }
2311
2312 struct OverrideDepHandler;
2313
2314 impl Handler for OverrideDepHandler {
2315 fn call<'a>(
2316 &'a self,
2317 ctx: &'a RequestContext,
2318 req: &'a mut Request,
2319 ) -> BoxFuture<'a, Response> {
2320 Box::pin(async move {
2321 let dep = Depends::<OverrideDep>::from_request(ctx, req)
2322 .await
2323 .expect("dependency extraction failed");
2324 Response::ok().body(ResponseBody::Bytes(dep.value.to_string().into_bytes()))
2325 })
2326 }
2327 }
2328
2329 fn override_dep_route(ctx: &RequestContext, req: &mut Request) -> std::future::Ready<Response> {
2330 let dep = futures_executor::block_on(Depends::<OverrideDep>::from_request(ctx, req))
2331 .expect("dependency extraction failed");
2332 std::future::ready(
2333 Response::ok().body(ResponseBody::Bytes(dep.value.to_string().into_bytes())),
2334 )
2335 }
2336
2337 #[test]
2338 fn test_client_get() {
2339 let client = TestClient::new(EchoHandler);
2340 let response = client.get("/test/path").send();
2341
2342 assert_eq!(response.status_code(), 200);
2343 assert_eq!(response.text(), "Method: Get, Path: /test/path");
2344 }
2345
2346 #[test]
2347 fn test_client_post() {
2348 let client = TestClient::new(EchoHandler);
2349 let response = client.post("/api/items").send();
2350
2351 assert_eq!(response.status_code(), 200);
2352 assert!(response.text().contains("Method: Post"));
2353 }
2354
2355 #[test]
2360 #[ignore = "nested executor issue: override_dep_route uses block_on inside TestClient's block_on"]
2361 fn test_app_dependency_override_used_by_test_client() {
2362 let app = App::builder()
2363 .route("/", Method::Get, override_dep_route)
2364 .build();
2365
2366 app.override_dependency_value(OverrideDep { value: 42 });
2367
2368 let client = TestClient::new(app);
2369
2370 let response = client.get("/").send();
2371
2372 assert_eq!(response.text(), "42");
2373 }
2374
2375 #[test]
2376 fn test_test_client_override_clear() {
2377 let client = TestClient::new(OverrideDepHandler);
2378
2379 client.override_dependency_value(OverrideDep { value: 9 });
2380 let response = client.get("/").send();
2381 assert_eq!(response.text(), "9");
2382
2383 client.clear_dependency_overrides();
2384 let response = client.get("/").send();
2385 assert_eq!(response.text(), "1");
2386 }
2387
2388 #[test]
2389 fn test_client_all_methods() {
2390 let client = TestClient::new(EchoHandler);
2391
2392 assert!(client.get("/").send().text().contains("Get"));
2393 assert!(client.post("/").send().text().contains("Post"));
2394 assert!(client.put("/").send().text().contains("Put"));
2395 assert!(client.delete("/").send().text().contains("Delete"));
2396 assert!(client.patch("/").send().text().contains("Patch"));
2397 assert!(client.options("/").send().text().contains("Options"));
2398 assert!(client.head("/").send().text().contains("Head"));
2399 }
2400
2401 #[test]
2402 fn test_query_params() {
2403 let client = TestClient::new(EchoHandler);
2404 let response = client
2405 .get("/search")
2406 .query("q", "rust")
2407 .query("limit", "10")
2408 .send();
2409
2410 assert_eq!(response.status_code(), 200);
2411 }
2412
2413 #[test]
2414 fn test_response_assertions() {
2415 let client = TestClient::new(EchoHandler);
2416 let response = client.get("/test").send();
2417
2418 let _ = response
2419 .assert_status_code(200)
2420 .assert_success()
2421 .assert_header("content-type", "text/plain")
2422 .assert_text_contains("Get");
2423 }
2424
2425 #[test]
2426 fn test_response_status_checks() {
2427 let client = TestClient::new(EchoHandler);
2428 let response = client.get("/").send();
2429
2430 assert!(response.is_success());
2431 assert!(!response.is_redirect());
2432 assert!(!response.is_client_error());
2433 assert!(!response.is_server_error());
2434 }
2435
2436 #[test]
2437 fn test_cookie_jar() {
2438 let mut jar = CookieJar::new();
2439 assert!(jar.is_empty());
2440
2441 jar.set("session", "abc123");
2442 jar.set("user", "alice");
2443
2444 assert_eq!(jar.len(), 2);
2445 assert_eq!(jar.get("session"), Some("abc123"));
2446 assert_eq!(jar.get("user"), Some("alice"));
2447
2448 let header = jar.to_cookie_header().unwrap();
2449 assert!(header.contains("session=abc123"));
2450 assert!(header.contains("user=alice"));
2451
2452 jar.remove("session");
2453 assert_eq!(jar.len(), 1);
2454 assert_eq!(jar.get("session"), None);
2455 }
2456
2457 #[test]
2458 fn test_cookie_jar_request_matching_rules() {
2459 use crate::request::ConnectionInfo;
2460
2461 let mut jar = CookieJar::new();
2462
2463 let mut req = Request::new(Method::Get, "/account/settings");
2464 req.headers_mut().insert("host", b"example.com".to_vec());
2465
2466 jar.parse_set_cookie(
2468 &req,
2469 b"sid=1; Path=/account; Secure; HttpOnly; SameSite=Lax",
2470 );
2471 assert_eq!(jar.cookie_header_for_request(&req), None);
2472
2473 req.insert_extension(ConnectionInfo::HTTPS);
2475 assert_eq!(
2476 jar.cookie_header_for_request(&req).as_deref(),
2477 Some("sid=1")
2478 );
2479
2480 let mut req2 = Request::new(Method::Get, "/other");
2482 req2.headers_mut().insert("host", b"example.com".to_vec());
2483 req2.insert_extension(ConnectionInfo::HTTPS);
2484 assert_eq!(jar.cookie_header_for_request(&req2), None);
2485
2486 jar.parse_set_cookie(&req, b"sub=1; Domain=example.com; Path=/");
2488 let mut req3 = Request::new(Method::Get, "/");
2489 req3.headers_mut()
2490 .insert("host", b"api.example.com".to_vec());
2491 let hdr = jar.cookie_header_for_request(&req3).expect("cookie header");
2492 assert!(hdr.contains("sub=1"));
2493 }
2494
2495 #[test]
2496 fn test_cookie_persistence() {
2497 let client = TestClient::new(CookieHandler);
2498
2499 let _ = client.get("/set-cookie").send();
2501
2502 assert_eq!(client.cookies().get("session"), Some("abc123"));
2504
2505 let client2 = TestClient::new(CookieEchoHandler);
2507 client2.cookies().set("session", "abc123");
2508
2509 let response = client2.get("/check-cookie").send();
2510 assert!(response.text().contains("session=abc123"));
2511 }
2512
2513 #[test]
2514 fn test_request_id_increments() {
2515 let client = TestClient::new(EchoHandler);
2516
2517 let r1 = client.get("/").send();
2518 let r2 = client.get("/").send();
2519 let r3 = client.get("/").send();
2520
2521 assert_eq!(r1.request_id(), 1);
2522 assert_eq!(r2.request_id(), 2);
2523 assert_eq!(r3.request_id(), 3);
2524 }
2525
2526 #[test]
2527 fn test_client_with_seed() {
2528 let client = TestClient::with_seed(EchoHandler, 42);
2529 assert_eq!(client.seed(), Some(42));
2530 }
2531
2532 #[test]
2533 fn test_client_clone() {
2534 let client = TestClient::new(EchoHandler);
2535 client.cookies().set("test", "value");
2536
2537 let cloned = client.clone();
2538
2539 assert_eq!(cloned.cookies().get("test"), Some("value"));
2541
2542 let r1 = client.get("/").send();
2544 let r2 = cloned.get("/").send();
2545 assert_eq!(r1.request_id(), 1);
2546 assert_eq!(r2.request_id(), 2);
2547 }
2548
2549 #[test]
2554 fn test_json_contains_exact_match() {
2555 let actual = serde_json::json!({"id": 1, "name": "Alice"});
2556 let expected = serde_json::json!({"id": 1, "name": "Alice"});
2557 assert!(json_contains(&actual, &expected).is_ok());
2558 }
2559
2560 #[test]
2561 fn test_json_contains_partial_match() {
2562 let actual = serde_json::json!({"id": 1, "name": "Alice", "email": "alice@example.com"});
2563 let expected = serde_json::json!({"name": "Alice"});
2564 assert!(json_contains(&actual, &expected).is_ok());
2565 }
2566
2567 #[test]
2568 fn test_json_contains_nested_partial_match() {
2569 let actual = serde_json::json!({
2570 "user": {"id": 1, "name": "Alice", "email": "alice@example.com"},
2571 "status": "active"
2572 });
2573 let expected = serde_json::json!({
2574 "user": {"name": "Alice"}
2575 });
2576 assert!(json_contains(&actual, &expected).is_ok());
2577 }
2578
2579 #[test]
2580 fn test_json_contains_mismatch_value() {
2581 let actual = serde_json::json!({"id": 1, "name": "Alice"});
2582 let expected = serde_json::json!({"name": "Bob"});
2583 let result = json_contains(&actual, &expected);
2584 assert!(result.is_err());
2585 assert_eq!(result.unwrap_err(), "$.name");
2586 }
2587
2588 #[test]
2589 fn test_json_contains_missing_key() {
2590 let actual = serde_json::json!({"id": 1, "name": "Alice"});
2591 let expected = serde_json::json!({"email": "alice@example.com"});
2592 let result = json_contains(&actual, &expected);
2593 assert!(result.is_err());
2594 assert_eq!(result.unwrap_err(), "$.email");
2595 }
2596
2597 #[test]
2598 fn test_json_contains_array_exact_match() {
2599 let actual = serde_json::json!({"items": [1, 2, 3]});
2600 let expected = serde_json::json!({"items": [1, 2, 3]});
2601 assert!(json_contains(&actual, &expected).is_ok());
2602 }
2603
2604 #[test]
2605 fn test_json_contains_array_length_mismatch() {
2606 let actual = serde_json::json!({"items": [1, 2, 3]});
2607 let expected = serde_json::json!({"items": [1, 2]});
2608 let result = json_contains(&actual, &expected);
2609 assert!(result.is_err());
2610 assert_eq!(result.unwrap_err(), "$.items[length]");
2611 }
2612
2613 #[test]
2614 fn test_json_contains_array_element_mismatch() {
2615 let actual = serde_json::json!({"items": [1, 2, 3]});
2616 let expected = serde_json::json!({"items": [1, 5, 3]});
2617 let result = json_contains(&actual, &expected);
2618 assert!(result.is_err());
2619 assert_eq!(result.unwrap_err(), "$.items[1]");
2620 }
2621
2622 #[test]
2623 fn test_json_contains_primitives() {
2624 assert!(json_contains(&serde_json::json!(42), &serde_json::json!(42)).is_ok());
2626 assert!(json_contains(&serde_json::json!(42), &serde_json::json!(43)).is_err());
2627
2628 assert!(json_contains(&serde_json::json!("hello"), &serde_json::json!("hello")).is_ok());
2630 assert!(json_contains(&serde_json::json!("hello"), &serde_json::json!("world")).is_err());
2631
2632 assert!(json_contains(&serde_json::json!(true), &serde_json::json!(true)).is_ok());
2634 assert!(json_contains(&serde_json::json!(true), &serde_json::json!(false)).is_err());
2635
2636 assert!(json_contains(&serde_json::json!(null), &serde_json::json!(null)).is_ok());
2638 }
2639
2640 #[test]
2641 fn test_json_contains_type_mismatch() {
2642 let actual = serde_json::json!({"id": "1"});
2643 let expected = serde_json::json!({"id": 1});
2644 let result = json_contains(&actual, &expected);
2645 assert!(result.is_err());
2646 assert_eq!(result.unwrap_err(), "$.id");
2647 }
2648
2649 #[test]
2650 fn test_json_contains_deeply_nested() {
2651 let actual = serde_json::json!({
2652 "level1": {
2653 "level2": {
2654 "level3": {
2655 "value": 42,
2656 "extra": "ignored"
2657 }
2658 }
2659 }
2660 });
2661 let expected = serde_json::json!({
2662 "level1": {
2663 "level2": {
2664 "level3": {
2665 "value": 42
2666 }
2667 }
2668 }
2669 });
2670 assert!(json_contains(&actual, &expected).is_ok());
2671 }
2672
2673 struct JsonHandler;
2679
2680 impl Handler for JsonHandler {
2681 fn call<'a>(
2682 &'a self,
2683 _ctx: &'a RequestContext,
2684 _req: &'a mut Request,
2685 ) -> BoxFuture<'a, Response> {
2686 let json = serde_json::json!({
2687 "id": 1,
2688 "name": "Alice",
2689 "email": "alice@example.com",
2690 "active": true
2691 });
2692 let body = serde_json::to_vec(&json).unwrap();
2693 Box::pin(async move {
2694 Response::ok()
2695 .header("content-type", b"application/json".to_vec())
2696 .header("x-request-id", b"req-123".to_vec())
2697 .body(ResponseBody::Bytes(body))
2698 })
2699 }
2700 }
2701
2702 #[allow(dead_code)]
2704 struct StatusHandler(u16);
2705
2706 #[allow(dead_code)]
2707 impl Handler for StatusHandler {
2708 fn call<'a>(
2709 &'a self,
2710 _ctx: &'a RequestContext,
2711 _req: &'a mut Request,
2712 ) -> BoxFuture<'a, Response> {
2713 let status = StatusCode::from_u16(self.0);
2714 Box::pin(async move { Response::with_status(status) })
2715 }
2716 }
2717
2718 #[test]
2719 fn test_assert_status_macro_with_u16() {
2720 let client = TestClient::new(EchoHandler);
2721 let response = client.get("/").send();
2722 crate::assert_status!(response, 200);
2723 }
2724
2725 #[test]
2726 fn test_assert_status_macro_with_status_code() {
2727 let client = TestClient::new(EchoHandler);
2728 let response = client.get("/").send();
2729 crate::assert_status!(response, StatusCode::OK);
2730 }
2731
2732 #[test]
2733 #[should_panic(expected = "assertion failed")]
2734 fn test_assert_status_macro_failure() {
2735 let client = TestClient::new(EchoHandler);
2736 let response = client.get("/").send();
2737 crate::assert_status!(response, 404);
2738 }
2739
2740 #[test]
2741 fn test_assert_header_macro() {
2742 let client = TestClient::new(EchoHandler);
2743 let response = client.get("/").send();
2744 crate::assert_header!(response, "content-type", "text/plain");
2745 }
2746
2747 #[test]
2748 #[should_panic(expected = "assertion failed")]
2749 fn test_assert_header_macro_failure() {
2750 let client = TestClient::new(EchoHandler);
2751 let response = client.get("/").send();
2752 crate::assert_header!(response, "content-type", "application/json");
2753 }
2754
2755 #[test]
2756 fn test_assert_body_contains_macro() {
2757 let client = TestClient::new(EchoHandler);
2758 let response = client.get("/test").send();
2759 crate::assert_body_contains!(response, "Method: Get");
2760 crate::assert_body_contains!(response, "Path: /test");
2761 }
2762
2763 #[test]
2764 #[should_panic(expected = "assertion failed")]
2765 fn test_assert_body_contains_macro_failure() {
2766 let client = TestClient::new(EchoHandler);
2767 let response = client.get("/test").send();
2768 crate::assert_body_contains!(response, "nonexistent");
2769 }
2770
2771 #[test]
2772 fn test_assert_json_macro_partial_match() {
2773 let client = TestClient::new(JsonHandler);
2774 let response = client.get("/user").send();
2775
2776 crate::assert_json!(response, {"name": "Alice"});
2778 crate::assert_json!(response, {"id": 1, "active": true});
2779 }
2780
2781 #[test]
2782 fn test_assert_json_macro_exact_match() {
2783 let client = TestClient::new(JsonHandler);
2784 let response = client.get("/user").send();
2785
2786 crate::assert_json!(response, {
2788 "id": 1,
2789 "name": "Alice",
2790 "email": "alice@example.com",
2791 "active": true
2792 });
2793 }
2794
2795 #[test]
2796 #[should_panic(expected = "JSON partial match failed")]
2797 fn test_assert_json_macro_failure() {
2798 let client = TestClient::new(JsonHandler);
2799 let response = client.get("/user").send();
2800 crate::assert_json!(response, {"name": "Bob"});
2801 }
2802
2803 #[test]
2808 fn test_assert_json_contains_method() {
2809 let client = TestClient::new(JsonHandler);
2810 let response = client.get("/user").send();
2811
2812 let _ = response.assert_json_contains(&serde_json::json!({"name": "Alice"}));
2813 }
2814
2815 #[test]
2816 fn test_assert_header_exists() {
2817 let client = TestClient::new(JsonHandler);
2818 let response = client.get("/").send();
2819
2820 let _ = response
2821 .assert_header_exists("content-type")
2822 .assert_header_exists("x-request-id");
2823 }
2824
2825 #[test]
2826 #[should_panic(expected = "Expected header 'nonexistent' to exist")]
2827 fn test_assert_header_exists_failure() {
2828 let client = TestClient::new(JsonHandler);
2829 let response = client.get("/").send();
2830 let _ = response.assert_header_exists("nonexistent");
2831 }
2832
2833 #[test]
2834 fn test_assert_header_missing() {
2835 let client = TestClient::new(JsonHandler);
2836 let response = client.get("/").send();
2837
2838 let _ = response.assert_header_missing("x-nonexistent");
2839 }
2840
2841 #[test]
2842 #[should_panic(expected = "Expected header 'content-type' to not exist")]
2843 fn test_assert_header_missing_failure() {
2844 let client = TestClient::new(JsonHandler);
2845 let response = client.get("/").send();
2846 let _ = response.assert_header_missing("content-type");
2847 }
2848
2849 #[test]
2850 fn test_assert_content_type_contains() {
2851 let client = TestClient::new(JsonHandler);
2852 let response = client.get("/").send();
2853
2854 let _ = response.assert_content_type_contains("application/json");
2855 let _ = response.assert_content_type_contains("json");
2856 }
2857
2858 #[test]
2859 #[should_panic(expected = "Expected Content-Type to contain")]
2860 fn test_assert_content_type_contains_failure() {
2861 let client = TestClient::new(JsonHandler);
2862 let response = client.get("/").send();
2863 let _ = response.assert_content_type_contains("text/html");
2864 }
2865
2866 #[test]
2867 fn test_assertion_chaining() {
2868 let client = TestClient::new(JsonHandler);
2869 let response = client.get("/user").send();
2870
2871 let _ = response
2873 .assert_status_code(200)
2874 .assert_success()
2875 .assert_header_exists("content-type")
2876 .assert_content_type_contains("json")
2877 .assert_json_contains(&serde_json::json!({"name": "Alice"}));
2878 }
2879
2880 #[test]
2881 fn test_macro_with_custom_message() {
2882 let client = TestClient::new(EchoHandler);
2883 let response = client.get("/").send();
2884
2885 crate::assert_status!(response, 200, "Expected 200 OK from echo handler");
2887 crate::assert_header!(
2888 response,
2889 "content-type",
2890 "text/plain",
2891 "Should have text content type"
2892 );
2893 crate::assert_body_contains!(response, "Get", "Should contain HTTP method");
2894 }
2895
2896 #[derive(Clone)]
2902 struct DatabasePool {
2903 connection_string: String,
2904 }
2905
2906 impl FromDependency for DatabasePool {
2907 type Error = HttpError;
2908 async fn from_dependency(
2909 _ctx: &RequestContext,
2910 _req: &mut Request,
2911 ) -> Result<Self, Self::Error> {
2912 Ok(DatabasePool {
2913 connection_string: "postgres://localhost/test".to_string(),
2914 })
2915 }
2916 }
2917
2918 #[derive(Clone)]
2919 struct UserRepository {
2920 pool_conn_str: String,
2921 }
2922
2923 impl FromDependency for UserRepository {
2924 type Error = HttpError;
2925 async fn from_dependency(
2926 ctx: &RequestContext,
2927 req: &mut Request,
2928 ) -> Result<Self, Self::Error> {
2929 let pool = Depends::<DatabasePool>::from_request(ctx, req).await?;
2930 Ok(UserRepository {
2931 pool_conn_str: pool.connection_string.clone(),
2932 })
2933 }
2934 }
2935
2936 #[derive(Clone)]
2937 struct AuthService {
2938 user_repo_pool: String,
2939 }
2940
2941 impl FromDependency for AuthService {
2942 type Error = HttpError;
2943 async fn from_dependency(
2944 ctx: &RequestContext,
2945 req: &mut Request,
2946 ) -> Result<Self, Self::Error> {
2947 let repo = Depends::<UserRepository>::from_request(ctx, req).await?;
2948 Ok(AuthService {
2949 user_repo_pool: repo.pool_conn_str.clone(),
2950 })
2951 }
2952 }
2953
2954 struct ComplexDepHandler;
2955
2956 impl Handler for ComplexDepHandler {
2957 fn call<'a>(
2958 &'a self,
2959 ctx: &'a RequestContext,
2960 req: &'a mut Request,
2961 ) -> BoxFuture<'a, Response> {
2962 Box::pin(async move {
2963 let auth = Depends::<AuthService>::from_request(ctx, req)
2964 .await
2965 .expect("dependency resolution failed");
2966 let body = format!("AuthService.pool={}", auth.user_repo_pool);
2967 Response::ok().body(ResponseBody::Bytes(body.into_bytes()))
2968 })
2969 }
2970 }
2971
2972 #[test]
2973 fn test_full_request_with_complex_deps() {
2974 let client = TestClient::new(ComplexDepHandler);
2977 let response = client.get("/auth/check").send();
2978
2979 assert_eq!(response.status_code(), 200);
2980 assert!(response.text().contains("postgres://localhost/test"));
2981 }
2982
2983 #[test]
2984 fn test_complex_deps_with_override_at_leaf() {
2985 let client = TestClient::new(ComplexDepHandler);
2987 client.override_dependency_value(DatabasePool {
2988 connection_string: "mysql://prod/users".to_string(),
2989 });
2990
2991 let response = client.get("/auth/check").send();
2992
2993 assert_eq!(response.status_code(), 200);
2994 assert!(
2995 response.text().contains("mysql://prod/users"),
2996 "Override at leaf should propagate through dependency chain"
2997 );
2998 }
2999
3000 #[test]
3001 fn test_complex_deps_with_override_at_middle() {
3002 let client = TestClient::new(ComplexDepHandler);
3004 client.override_dependency_value(UserRepository {
3005 pool_conn_str: "overridden-repo-connection".to_string(),
3006 });
3007
3008 let response = client.get("/auth/check").send();
3009
3010 assert_eq!(response.status_code(), 200);
3011 assert!(
3012 response.text().contains("overridden-repo-connection"),
3013 "Override at middle level should be used"
3014 );
3015 }
3016
3017 #[test]
3018 fn test_dependency_caching_across_handler() {
3019 use std::sync::atomic::{AtomicUsize, Ordering};
3021
3022 static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
3023
3024 #[derive(Clone)]
3025 struct TrackedDep {
3026 call_number: usize,
3027 }
3028
3029 impl FromDependency for TrackedDep {
3030 type Error = HttpError;
3031 async fn from_dependency(
3032 _ctx: &RequestContext,
3033 _req: &mut Request,
3034 ) -> Result<Self, Self::Error> {
3035 let call_number = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
3036 Ok(TrackedDep { call_number })
3037 }
3038 }
3039
3040 struct MultiDepHandler;
3041
3042 impl Handler for MultiDepHandler {
3043 fn call<'a>(
3044 &'a self,
3045 ctx: &'a RequestContext,
3046 req: &'a mut Request,
3047 ) -> BoxFuture<'a, Response> {
3048 Box::pin(async move {
3049 let dep1 = Depends::<TrackedDep>::from_request(ctx, req)
3051 .await
3052 .expect("first resolution failed");
3053 let dep2 = Depends::<TrackedDep>::from_request(ctx, req)
3054 .await
3055 .expect("second resolution failed");
3056
3057 let body = format!("dep1={} dep2={}", dep1.call_number, dep2.call_number);
3059 Response::ok().body(ResponseBody::Bytes(body.into_bytes()))
3060 })
3061 }
3062 }
3063
3064 CALL_COUNT.store(0, Ordering::SeqCst);
3066
3067 let client = TestClient::new(MultiDepHandler);
3068 let response = client.get("/").send();
3069
3070 let text = response.text();
3071 assert!(
3073 text.contains("dep1=0 dep2=0"),
3074 "Dependencies should be cached within request. Got: {}",
3075 text
3076 );
3077
3078 assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1);
3080 }
3081
3082 #[test]
3087 #[allow(clippy::float_cmp)] fn lab_test_config_defaults() {
3089 let config = LabTestConfig::default();
3090 assert_eq!(config.seed, 42);
3091 assert!(!config.chaos_enabled);
3092 assert_eq!(config.chaos_intensity, 0.0);
3093 assert_eq!(config.max_steps, Some(10_000));
3094 assert!(!config.capture_traces);
3095 }
3096
3097 #[test]
3098 fn lab_test_config_with_seed() {
3099 let config = LabTestConfig::new(12345);
3100 assert_eq!(config.seed, 12345);
3101 }
3102
3103 #[test]
3104 #[allow(clippy::float_cmp)] fn lab_test_config_light_chaos() {
3106 let config = LabTestConfig::new(42).with_light_chaos();
3107 assert!(config.chaos_enabled);
3108 assert_eq!(config.chaos_intensity, 0.05);
3109 }
3110
3111 #[test]
3112 #[allow(clippy::float_cmp)] fn lab_test_config_heavy_chaos() {
3114 let config = LabTestConfig::new(42).with_heavy_chaos();
3115 assert!(config.chaos_enabled);
3116 assert_eq!(config.chaos_intensity, 0.2);
3117 }
3118
3119 #[test]
3120 #[allow(clippy::float_cmp)] fn lab_test_config_custom_intensity() {
3122 let config = LabTestConfig::new(42).with_chaos_intensity(0.15);
3123 assert!(config.chaos_enabled);
3124 assert_eq!(config.chaos_intensity, 0.15);
3125 }
3126
3127 #[test]
3128 #[allow(clippy::float_cmp)] fn lab_test_config_intensity_clamps() {
3130 let config = LabTestConfig::new(42).with_chaos_intensity(1.5);
3131 assert_eq!(config.chaos_intensity, 1.0);
3132
3133 let config = LabTestConfig::new(42).with_chaos_intensity(-0.5);
3134 assert_eq!(config.chaos_intensity, 0.0);
3135 assert!(!config.chaos_enabled);
3136 }
3137
3138 #[test]
3139 fn lab_test_config_max_steps() {
3140 let config = LabTestConfig::new(42).with_max_steps(1000);
3141 assert_eq!(config.max_steps, Some(1000));
3142 }
3143
3144 #[test]
3145 fn lab_test_config_no_step_limit() {
3146 let config = LabTestConfig::new(42).without_step_limit();
3147 assert_eq!(config.max_steps, None);
3148 }
3149
3150 #[test]
3151 fn lab_test_config_with_traces() {
3152 let config = LabTestConfig::new(42).with_traces();
3153 assert!(config.capture_traces);
3154 }
3155
3156 #[test]
3157 #[allow(clippy::float_cmp)] fn test_chaos_stats_empty() {
3159 let stats = TestChaosStats::default();
3160 assert_eq!(stats.decision_points, 0);
3161 assert_eq!(stats.delays_injected, 0);
3162 assert_eq!(stats.cancellations_injected, 0);
3163 assert_eq!(stats.injection_rate(), 0.0);
3164 assert!(!stats.had_chaos());
3165 }
3166
3167 #[test]
3168 fn test_chaos_stats_with_injections() {
3169 let stats = TestChaosStats {
3170 decision_points: 100,
3171 delays_injected: 5,
3172 cancellations_injected: 2,
3173 steps_executed: 50,
3174 };
3175 assert!((stats.injection_rate() - 0.07).abs() < 0.001);
3176 assert!(stats.had_chaos());
3177 }
3178
3179 #[test]
3180 fn mock_time_basic() {
3181 let time = MockTime::new();
3182 assert_eq!(time.now(), std::time::Duration::ZERO);
3183 assert_eq!(time.elapsed(), std::time::Duration::ZERO);
3184
3185 time.advance(std::time::Duration::from_secs(5));
3186 assert_eq!(time.now(), std::time::Duration::from_secs(5));
3187 }
3188
3189 #[test]
3190 fn mock_time_set_and_reset() {
3191 let time = MockTime::new();
3192 time.set(std::time::Duration::from_secs(100));
3193 assert_eq!(time.now(), std::time::Duration::from_secs(100));
3194
3195 time.reset();
3196 assert_eq!(time.now(), std::time::Duration::ZERO);
3197 }
3198
3199 #[test]
3200 fn mock_time_starting_at() {
3201 let time = MockTime::starting_at(std::time::Duration::from_secs(10));
3202 assert_eq!(time.now(), std::time::Duration::from_secs(10));
3203 }
3204
3205 #[test]
3206 fn cancellation_test_completes_normally() {
3207 let test = CancellationTest::new(EchoHandler);
3208 let result = test.complete_normally();
3209
3210 assert!(result.completed);
3211 assert!(!result.cancelled_at_checkpoint);
3212 assert!(result.response.is_some());
3213 assert_eq!(result.response.as_ref().unwrap().status().as_u16(), 200);
3214 }
3215
3216 #[test]
3217 fn cancellation_test_respects_cancellation() {
3218 struct CheckpointHandler;
3220
3221 impl Handler for CheckpointHandler {
3222 fn call<'a>(
3223 &'a self,
3224 ctx: &'a RequestContext,
3225 _req: &'a mut Request,
3226 ) -> BoxFuture<'a, Response> {
3227 Box::pin(async move {
3228 if ctx.checkpoint().is_err() {
3230 return Response::with_status(StatusCode::CLIENT_CLOSED_REQUEST);
3231 }
3232 Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()))
3233 })
3234 }
3235 }
3236
3237 let test = CancellationTest::new(CheckpointHandler);
3238 let result = test.test_respects_cancellation();
3239
3240 assert!(result.completed);
3241 assert!(result.cancelled_at_checkpoint);
3242 assert!(result.response.is_some());
3243 assert_eq!(result.response.as_ref().unwrap().status().as_u16(), 499);
3245 }
3246
3247 #[test]
3248 fn cancellation_test_result_helpers() {
3249 let graceful = CancellationTestResult {
3250 completed: false,
3251 cancelled_at_checkpoint: true,
3252 response: None,
3253 cancellation_point: None,
3254 };
3255 assert!(graceful.gracefully_cancelled());
3256 assert!(!graceful.completed_despite_cancel());
3257
3258 let completed = CancellationTestResult {
3259 completed: true,
3260 cancelled_at_checkpoint: false,
3261 response: Some(Response::ok()),
3262 cancellation_point: None,
3263 };
3264 assert!(!completed.gracefully_cancelled());
3265 assert!(completed.completed_despite_cancel());
3266 }
3267
3268 #[test]
3273 fn test_logger_captures_all_levels() {
3274 let logger = TestLogger::new();
3275
3276 logger.log_message(LogLevel::Debug, "debug message", 1);
3277 logger.log_message(LogLevel::Info, "info message", 1);
3278 logger.log_message(LogLevel::Warn, "warn message", 1);
3279 logger.log_message(LogLevel::Error, "error message", 1);
3280
3281 let logs = logger.logs();
3282 assert_eq!(logs.len(), 4);
3283
3284 assert_eq!(logs[0].level, LogLevel::Debug);
3285 assert_eq!(logs[1].level, LogLevel::Info);
3286 assert_eq!(logs[2].level, LogLevel::Warn);
3287 assert_eq!(logs[3].level, LogLevel::Error);
3288 }
3289
3290 #[test]
3291 fn test_logger_logs_at_level_filters_correctly() {
3292 let logger = TestLogger::new();
3293
3294 logger.log_message(LogLevel::Debug, "debug", 1);
3295 logger.log_message(LogLevel::Info, "info 1", 1);
3296 logger.log_message(LogLevel::Info, "info 2", 2);
3297 logger.log_message(LogLevel::Warn, "warn", 1);
3298 logger.log_message(LogLevel::Error, "error", 1);
3299
3300 let info_logs = logger.logs_at_level(LogLevel::Info);
3301 assert_eq!(info_logs.len(), 2);
3302 assert!(info_logs[0].contains("info 1"));
3303 assert!(info_logs[1].contains("info 2"));
3304
3305 let error_logs = logger.logs_at_level(LogLevel::Error);
3306 assert_eq!(error_logs.len(), 1);
3307 assert!(error_logs[0].contains("error"));
3308
3309 let trace_logs = logger.logs_at_level(LogLevel::Trace);
3310 assert_eq!(trace_logs.len(), 0);
3311 }
3312
3313 #[test]
3314 fn test_logger_contains_message_search() {
3315 let logger = TestLogger::new();
3316
3317 logger.log_message(LogLevel::Info, "User alice logged in", 100);
3318 logger.log_message(LogLevel::Info, "Request processed for /api/users", 101);
3319 logger.log_message(LogLevel::Warn, "Rate limit approaching for alice", 102);
3320
3321 assert!(logger.contains_message("alice"));
3322 assert!(logger.contains_message("/api/users"));
3323 assert!(logger.contains_message("Rate limit"));
3324 assert!(!logger.contains_message("bob"));
3325 assert!(!logger.contains_message("nonexistent"));
3326 }
3327
3328 #[test]
3329 fn test_logger_contains_multiple_messages() {
3330 let logger = TestLogger::new();
3331
3332 logger.log_message(LogLevel::Info, "step 1 complete", 1);
3333 logger.log_message(LogLevel::Info, "step 2 complete", 2);
3334 logger.log_message(LogLevel::Info, "step 3 complete", 3);
3335
3336 assert!(logger.contains_message("step 1"));
3338 assert!(logger.contains_message("step 2"));
3339 assert!(logger.contains_message("step 3"));
3340 assert!(logger.contains_message("complete"));
3341 assert!(!logger.contains_message("step 4"));
3343 }
3344
3345 #[test]
3346 fn test_log_capture_captures_logs_in_closure() {
3347 let capture = TestLogger::capture(|logger| {
3348 logger.log_message(LogLevel::Info, "inside capture", 1);
3349 logger.log_message(LogLevel::Warn, "warning inside", 2);
3350 42
3351 });
3352
3353 assert!(capture.passed());
3354 assert!(!capture.failed());
3355 assert_eq!(capture.result, Some(42));
3356 assert_eq!(capture.logs.len(), 2);
3357 assert!(capture.contains_message("inside capture"));
3358 assert!(capture.contains_message("warning inside"));
3359 }
3360
3361 #[test]
3362 fn test_log_capture_count_by_level() {
3363 let capture = TestLogger::capture(|logger| {
3364 logger.log_message(LogLevel::Info, "info 1", 1);
3365 logger.log_message(LogLevel::Info, "info 2", 2);
3366 logger.log_message(LogLevel::Info, "info 3", 3);
3367 logger.log_message(LogLevel::Error, "error 1", 4);
3368 });
3369
3370 assert_eq!(capture.count_by_level(LogLevel::Info), 3);
3371 assert_eq!(capture.count_by_level(LogLevel::Error), 1);
3372 assert_eq!(capture.count_by_level(LogLevel::Warn), 0);
3373 }
3374
3375 #[test]
3376 fn test_log_capture_phased_all_phases() {
3377 let capture = TestLogger::capture_phased(
3378 |logger| {
3379 logger.log_message(LogLevel::Info, "setup phase", 1);
3380 },
3381 |logger| {
3382 logger.log_message(LogLevel::Info, "execute phase", 2);
3383 "result"
3384 },
3385 |logger| {
3386 logger.log_message(LogLevel::Info, "teardown phase", 3);
3387 },
3388 );
3389
3390 assert!(capture.passed());
3391 assert_eq!(capture.result, Some("result"));
3392 assert_eq!(capture.logs.len(), 3);
3393 assert!(capture.contains_message("setup phase"));
3394 assert!(capture.contains_message("execute phase"));
3395 assert!(capture.contains_message("teardown phase"));
3396 }
3397
3398 #[test]
3399 fn test_log_capture_timings_recorded() {
3400 let capture = TestLogger::capture(|_logger| {
3401 let mut sum = 0;
3403 for i in 0..1000 {
3404 sum += i;
3405 }
3406 sum
3407 });
3408
3409 assert!(capture.passed());
3410 let timings = &capture.timings;
3412 assert!(timings.total() >= std::time::Duration::ZERO);
3413 }
3414
3415 #[test]
3416 fn test_log_capture_failure_context() {
3417 let capture = TestLogger::capture(|logger| {
3418 logger.log_message(LogLevel::Info, "step 1", 1);
3419 logger.log_message(LogLevel::Info, "step 2", 2);
3420 logger.log_message(LogLevel::Error, "something went wrong", 3);
3421 logger.log_message(LogLevel::Info, "step 3", 4);
3422 });
3423
3424 let context = capture.failure_context(3);
3425 assert!(context.contains("something went wrong") || context.contains("step 3"));
3427 }
3428
3429 #[test]
3430 fn test_captured_log_format() {
3431 let log = CapturedLog::new(LogLevel::Warn, "test warning message", 12345);
3432
3433 let formatted = log.format();
3434 assert!(formatted.contains("[W]"));
3436 assert!(formatted.contains("test warning message"));
3437 assert!(formatted.contains("12345"));
3438 }
3439
3440 #[test]
3441 fn test_captured_log_contains() {
3442 let log = CapturedLog::new(LogLevel::Info, "user login successful for alice", 1);
3443
3444 assert!(log.contains("login"));
3445 assert!(log.contains("alice"));
3446 assert!(log.contains("successful"));
3447 assert!(!log.contains("bob"));
3448 assert!(!log.contains("failed"));
3449 }
3450
3451 #[test]
3452 fn test_captured_log_fields() {
3453 let log = CapturedLog::new(LogLevel::Error, "database connection failed", 999);
3454
3455 assert_eq!(log.level, LogLevel::Error);
3456 assert_eq!(log.message, "database connection failed");
3457 assert_eq!(log.request_id, 999);
3458 }
3459
3460 #[test]
3461 fn test_multiple_loggers_isolated() {
3462 let logger1 = TestLogger::new();
3463 let logger2 = TestLogger::new();
3464
3465 logger1.log_message(LogLevel::Info, "from logger 1", 1);
3466 logger2.log_message(LogLevel::Info, "from logger 2", 2);
3467
3468 assert_eq!(logger1.logs().len(), 1);
3469 assert_eq!(logger2.logs().len(), 1);
3470 assert!(logger1.contains_message("logger 1"));
3471 assert!(!logger1.contains_message("logger 2"));
3472 assert!(logger2.contains_message("logger 2"));
3473 assert!(!logger2.contains_message("logger 1"));
3474 }
3475
3476 #[test]
3477 fn test_logger_log_entry_integration() {
3478 let logger = TestLogger::new();
3479
3480 let entry = LogEntry {
3481 level: LogLevel::Warn,
3482 message: "warning from entry".to_string(),
3483 request_id: 42,
3484 region_id: "region-1".to_string(),
3485 task_id: "task-1".to_string(),
3486 target: None,
3487 fields: Vec::new(),
3488 timestamp_ns: 0,
3489 };
3490
3491 logger.log_entry(&entry);
3492
3493 assert_eq!(logger.logs().len(), 1);
3494 let captured = &logger.logs()[0];
3495 assert_eq!(captured.level, LogLevel::Warn);
3496 assert!(captured.contains("warning from entry"));
3497 assert_eq!(captured.request_id, 42);
3498 }
3499
3500 #[test]
3501 fn test_log_capture_unwrap_on_success() {
3502 let capture = TestLogger::capture(|_| 123);
3503 let value = capture.unwrap();
3504 assert_eq!(value, 123);
3505 }
3506
3507 #[test]
3508 fn test_log_capture_unwrap_or_on_success() {
3509 let capture = TestLogger::capture(|_| 456);
3510 let value = capture.unwrap_or(0);
3511 assert_eq!(value, 456);
3512 }
3513}
3514
3515use std::io::{Read as _, Write as _};
3520use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
3521use std::sync::atomic::AtomicBool;
3522use std::thread;
3523use std::time::Duration;
3524
/// A request captured by [`MockServer`], in the form it arrived on the wire.
#[derive(Debug, Clone)]
pub struct RecordedRequest {
    /// Request method exactly as sent on the request line (e.g. `"GET"`).
    pub method: String,
    /// Request path, without the query string.
    pub path: String,
    /// Query string (the text after the first `?`), if any.
    pub query: Option<String>,
    /// Headers in arrival order, with names and values trimmed.
    pub headers: Vec<(String, String)>,
    /// Raw request body bytes (empty when no `Content-Length` was sent).
    pub body: Vec<u8>,
    /// When the mock server parsed the request.
    pub timestamp: std::time::Instant,
}

impl RecordedRequest {
    /// Returns the body as UTF-8 text.
    ///
    /// # Panics
    ///
    /// Panics if the body is not valid UTF-8.
    #[must_use]
    pub fn body_text(&self) -> &str {
        std::str::from_utf8(&self.body).expect("body is not valid UTF-8")
    }

    /// Returns the value of the first header whose name matches `name`,
    /// compared ASCII case-insensitively as HTTP header names are.
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        // `eq_ignore_ascii_case` compares in place instead of allocating a
        // lowercased copy of every header name.
        self.headers
            .iter()
            .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.as_str())
    }

    /// Returns the path joined with the query string (`path?query`), or just
    /// the path when there is no query.
    #[must_use]
    pub fn url(&self) -> String {
        match &self.query {
            Some(q) => format!("{}?{}", self.path, q),
            None => self.path.clone(),
        }
    }
}
3575
/// A canned response served by [`MockServer`].
#[derive(Debug, Clone)]
pub struct MockResponse {
    /// HTTP status code.
    pub status: u16,
    /// Response headers, written in this order.
    pub headers: Vec<(String, String)>,
    /// Response body bytes.
    pub body: Vec<u8>,
    /// Optional delay before the response is written.
    pub delay: Option<Duration>,
}

impl Default for MockResponse {
    /// `200` with a single `content-type: text/plain` header, body `"OK"`,
    /// and no delay.
    fn default() -> Self {
        let plain_text = (String::from("content-type"), String::from("text/plain"));
        Self {
            status: 200,
            headers: vec![plain_text],
            body: Vec::from(*b"OK"),
            delay: None,
        }
    }
}
3599
3600impl MockResponse {
3601 #[must_use]
3603 pub fn ok() -> Self {
3604 Self::default()
3605 }
3606
3607 #[must_use]
3609 pub fn with_status(status: u16) -> Self {
3610 Self {
3611 status,
3612 ..Default::default()
3613 }
3614 }
3615
3616 #[must_use]
3618 pub fn status(mut self, status: u16) -> Self {
3619 self.status = status;
3620 self
3621 }
3622
3623 #[must_use]
3625 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
3626 self.headers.push((name.into(), value.into()));
3627 self
3628 }
3629
3630 #[must_use]
3632 pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
3633 self.body = body.into();
3634 self
3635 }
3636
3637 #[must_use]
3639 pub fn body_str(self, body: &str) -> Self {
3640 self.body(body.as_bytes().to_vec())
3641 }
3642
3643 #[must_use]
3645 pub fn json<T: serde::Serialize>(mut self, value: &T) -> Self {
3646 self.body = serde_json::to_vec(value).expect("JSON serialization failed");
3647 self.headers
3648 .push(("content-type".to_string(), "application/json".to_string()));
3649 self
3650 }
3651
3652 #[must_use]
3654 pub fn delay(mut self, duration: Duration) -> Self {
3655 self.delay = Some(duration);
3656 self
3657 }
3658
3659 fn to_http_response(&self) -> Vec<u8> {
3661 let status_text = match self.status {
3662 200 => "OK",
3663 201 => "Created",
3664 204 => "No Content",
3665 400 => "Bad Request",
3666 401 => "Unauthorized",
3667 403 => "Forbidden",
3668 404 => "Not Found",
3669 500 => "Internal Server Error",
3670 502 => "Bad Gateway",
3671 503 => "Service Unavailable",
3672 504 => "Gateway Timeout",
3673 _ => "Unknown",
3674 };
3675
3676 let mut response = format!("HTTP/1.1 {} {}\r\n", self.status, status_text);
3677
3678 response.push_str(&format!("content-length: {}\r\n", self.body.len()));
3680
3681 for (name, value) in &self.headers {
3683 response.push_str(&format!("{}: {}\r\n", name, value));
3684 }
3685
3686 response.push_str("\r\n");
3687
3688 let mut bytes = response.into_bytes();
3689 bytes.extend_from_slice(&self.body);
3690 bytes
3691 }
3692}
3693
/// A minimal HTTP server for tests that records every request it receives and
/// answers with pre-configured [`MockResponse`]s.
///
/// The server listens on `127.0.0.1` with an OS-assigned port and serves
/// connections on a background thread until it is dropped.
pub struct MockServer {
    /// Address the listener is bound to.
    addr: SocketAddr,
    /// Every parsed request, in the order it was received.
    requests: Arc<Mutex<Vec<RecordedRequest>>>,
    /// Configured responses keyed by exact path, or by a prefix pattern ending in `*`.
    responses: Arc<Mutex<HashMap<String, MockResponse>>>,
    /// Response used when no entry in `responses` matches.
    default_response: Arc<Mutex<MockResponse>>,
    /// Set on drop to stop the accept loop.
    shutdown: Arc<AtomicBool>,
    /// Background server thread; taken and joined on drop.
    handle: Option<thread::JoinHandle<()>>,
}
3735
3736impl MockServer {
3737 #[must_use]
3748 pub fn start() -> Self {
3749 Self::start_with_options(MockServerOptions::default())
3750 }
3751
3752 #[must_use]
3754 pub fn start_with_options(options: MockServerOptions) -> Self {
3755 let listener =
3757 StdTcpListener::bind("127.0.0.1:0").expect("Failed to bind mock server to port");
3758 let addr = listener.local_addr().expect("Failed to get local address");
3759
3760 listener
3762 .set_nonblocking(true)
3763 .expect("Failed to set non-blocking");
3764
3765 let requests = Arc::new(Mutex::new(Vec::new()));
3766 let responses = Arc::new(Mutex::new(HashMap::new()));
3767 let default_response = Arc::new(Mutex::new(options.default_response));
3768 let shutdown = Arc::new(AtomicBool::new(false));
3769
3770 let requests_clone = Arc::clone(&requests);
3771 let responses_clone = Arc::clone(&responses);
3772 let default_response_clone = Arc::clone(&default_response);
3773 let shutdown_clone = Arc::clone(&shutdown);
3774 let read_timeout = options.read_timeout;
3775
3776 let handle = thread::spawn(move || {
3777 Self::server_loop(
3778 listener,
3779 requests_clone,
3780 responses_clone,
3781 default_response_clone,
3782 shutdown_clone,
3783 read_timeout,
3784 );
3785 });
3786
3787 Self {
3788 addr,
3789 requests,
3790 responses,
3791 default_response,
3792 shutdown,
3793 handle: Some(handle),
3794 }
3795 }
3796
3797 fn server_loop(
3799 listener: StdTcpListener,
3800 requests: Arc<Mutex<Vec<RecordedRequest>>>,
3801 responses: Arc<Mutex<HashMap<String, MockResponse>>>,
3802 default_response: Arc<Mutex<MockResponse>>,
3803 shutdown: Arc<AtomicBool>,
3804 read_timeout: Duration,
3805 ) {
3806 loop {
3807 if shutdown.load(std::sync::atomic::Ordering::Acquire) {
3808 break;
3809 }
3810
3811 match listener.accept() {
3812 Ok((stream, _peer)) => {
3813 let requests = Arc::clone(&requests);
3815 let responses = Arc::clone(&responses);
3816 let default_response = Arc::clone(&default_response);
3817
3818 Self::handle_connection(
3820 stream,
3821 requests,
3822 responses,
3823 default_response,
3824 read_timeout,
3825 );
3826 }
3827 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
3828 thread::sleep(Duration::from_millis(10));
3830 }
3831 Err(e) => {
3832 eprintln!("MockServer accept error: {}", e);
3833 break;
3834 }
3835 }
3836 }
3837 }
3838
3839 fn handle_connection(
3841 mut stream: StdTcpStream,
3842 requests: Arc<Mutex<Vec<RecordedRequest>>>,
3843 responses: Arc<Mutex<HashMap<String, MockResponse>>>,
3844 default_response: Arc<Mutex<MockResponse>>,
3845 read_timeout: Duration,
3846 ) {
3847 let _ = stream.set_read_timeout(Some(read_timeout));
3849
3850 let mut buffer = vec![0u8; 8192];
3852 let Ok(bytes_read) = stream.read(&mut buffer) else {
3853 return;
3854 };
3855
3856 if bytes_read == 0 {
3857 return;
3858 }
3859
3860 buffer.truncate(bytes_read);
3861
3862 let Some(recorded) = Self::parse_request(&buffer) else {
3864 return;
3865 };
3866
3867 {
3869 let mut reqs = requests.lock();
3870 reqs.push(recorded.clone());
3871 }
3872
3873 let response = {
3875 let resps = responses.lock();
3876 match resps.get(&recorded.path) {
3877 Some(r) => r.clone(),
3878 None => {
3879 let mut matched = None;
3881 for (pattern, resp) in resps.iter() {
3882 if pattern.ends_with('*') {
3883 let prefix = &pattern[..pattern.len() - 1];
3884 if recorded.path.starts_with(prefix) {
3885 matched = Some(resp.clone());
3886 break;
3887 }
3888 }
3889 }
3890 matched.unwrap_or_else(|| default_response.lock().clone())
3891 }
3892 }
3893 };
3894
3895 if let Some(delay) = response.delay {
3897 thread::sleep(delay);
3898 }
3899
3900 let response_bytes = response.to_http_response();
3902 let _ = stream.write_all(&response_bytes);
3903 let _ = stream.flush();
3904 }
3905
3906 fn parse_request(data: &[u8]) -> Option<RecordedRequest> {
3908 let text = std::str::from_utf8(data).ok()?;
3909 let mut lines = text.lines();
3910
3911 let request_line = lines.next()?;
3913 let parts: Vec<&str> = request_line.split_whitespace().collect();
3914 if parts.len() < 2 {
3915 return None;
3916 }
3917
3918 let method = parts[0].to_string();
3919 let full_path = parts[1];
3920
3921 let (path, query) = if let Some(idx) = full_path.find('?') {
3923 (
3924 full_path[..idx].to_string(),
3925 Some(full_path[idx + 1..].to_string()),
3926 )
3927 } else {
3928 (full_path.to_string(), None)
3929 };
3930
3931 let mut headers = Vec::new();
3933 let mut content_length = 0usize;
3934 for line in lines.by_ref() {
3935 if line.is_empty() {
3936 break;
3937 }
3938 if let Some((name, value)) = line.split_once(':') {
3939 let name = name.trim().to_string();
3940 let value = value.trim().to_string();
3941 if name.eq_ignore_ascii_case("content-length") {
3942 content_length = value.parse().unwrap_or(0);
3943 }
3944 headers.push((name, value));
3945 }
3946 }
3947
3948 let body = if content_length > 0 {
3950 if let Some(body_start) = text.find("\r\n\r\n") {
3952 let body_start = body_start + 4;
3953 if body_start < data.len() {
3954 data[body_start..].to_vec()
3955 } else {
3956 Vec::new()
3957 }
3958 } else if let Some(body_start) = text.find("\n\n") {
3959 let body_start = body_start + 2;
3960 if body_start < data.len() {
3961 data[body_start..].to_vec()
3962 } else {
3963 Vec::new()
3964 }
3965 } else {
3966 Vec::new()
3967 }
3968 } else {
3969 Vec::new()
3970 };
3971
3972 Some(RecordedRequest {
3973 method,
3974 path,
3975 query,
3976 headers,
3977 body,
3978 timestamp: std::time::Instant::now(),
3979 })
3980 }
3981
3982 #[must_use]
3984 pub fn addr(&self) -> SocketAddr {
3985 self.addr
3986 }
3987
3988 #[must_use]
3990 pub fn url(&self) -> String {
3991 format!("http://{}", self.addr)
3992 }
3993
3994 #[must_use]
3996 pub fn url_for(&self, path: &str) -> String {
3997 let path = if path.starts_with('/') {
3998 path
3999 } else {
4000 &format!("/{}", path)
4001 };
4002 format!("http://{}{}", self.addr, path)
4003 }
4004
4005 pub fn mock_response(&self, path: impl Into<String>, response: MockResponse) {
4016 let mut responses = self.responses.lock();
4017 responses.insert(path.into(), response);
4018 }
4019
4020 pub fn set_default_response(&self, response: MockResponse) {
4022 let mut default = self.default_response.lock();
4023 *default = response;
4024 }
4025
4026 #[must_use]
4028 pub fn requests(&self) -> Vec<RecordedRequest> {
4029 let requests = self.requests.lock();
4030 requests.clone()
4031 }
4032
4033 #[must_use]
4035 pub fn request_count(&self) -> usize {
4036 let requests = self.requests.lock();
4037 requests.len()
4038 }
4039
4040 #[must_use]
4042 pub fn requests_for(&self, path: &str) -> Vec<RecordedRequest> {
4043 let requests = self.requests.lock();
4044 requests
4045 .iter()
4046 .filter(|r| r.path == path)
4047 .cloned()
4048 .collect()
4049 }
4050
4051 #[must_use]
4053 pub fn last_request(&self) -> Option<RecordedRequest> {
4054 let requests = self.requests.lock();
4055 requests.last().cloned()
4056 }
4057
4058 pub fn clear_requests(&self) {
4060 let mut requests = self.requests.lock();
4061 requests.clear();
4062 }
4063
4064 pub fn clear_responses(&self) {
4066 let mut responses = self.responses.lock();
4067 responses.clear();
4068 }
4069
4070 pub fn reset(&self) {
4072 self.clear_requests();
4073 self.clear_responses();
4074 }
4075
4076 pub fn wait_for_requests(&self, count: usize, timeout: Duration) -> bool {
4081 let start = std::time::Instant::now();
4082 loop {
4083 if self.request_count() >= count {
4084 return true;
4085 }
4086 if start.elapsed() >= timeout {
4087 return false;
4088 }
4089 thread::sleep(Duration::from_millis(10));
4090 }
4091 }
4092
4093 pub fn assert_received(&self, path: &str) {
4099 let requests = self.requests_for(path);
4100 assert!(
4101 !requests.is_empty(),
4102 "Expected request to path '{}', but none was received. Received paths: {:?}",
4103 path,
4104 self.requests().iter().map(|r| &r.path).collect::<Vec<_>>()
4105 );
4106 }
4107
4108 pub fn assert_not_received(&self, path: &str) {
4114 let requests = self.requests_for(path);
4115 assert!(
4116 requests.is_empty(),
4117 "Expected no request to path '{}', but {} were received",
4118 path,
4119 requests.len()
4120 );
4121 }
4122
4123 pub fn assert_request_count(&self, expected: usize) {
4129 let actual = self.request_count();
4130 assert_eq!(
4131 actual, expected,
4132 "Expected {} requests, but received {}",
4133 expected, actual
4134 );
4135 }
4136}
4137
4138impl Drop for MockServer {
4139 fn drop(&mut self) {
4140 self.shutdown
4142 .store(true, std::sync::atomic::Ordering::Release);
4143
4144 if let Some(handle) = self.handle.take() {
4146 let _ = handle.join();
4147 }
4148 }
4149}
4150
4151#[derive(Debug, Clone)]
4153pub struct MockServerOptions {
4154 pub default_response: MockResponse,
4156 pub read_timeout: Duration,
4158}
4159
4160impl Default for MockServerOptions {
4161 fn default() -> Self {
4162 Self {
4163 default_response: MockResponse::with_status(404).body_str("Not Found"),
4164 read_timeout: Duration::from_secs(5),
4165 }
4166 }
4167}
4168
4169impl MockServerOptions {
4170 #[must_use]
4172 pub fn new() -> Self {
4173 Self::default()
4174 }
4175
4176 #[must_use]
4178 pub fn default_response(mut self, response: MockResponse) -> Self {
4179 self.default_response = response;
4180 self
4181 }
4182
4183 #[must_use]
4185 pub fn read_timeout(mut self, timeout: Duration) -> Self {
4186 self.read_timeout = timeout;
4187 self
4188 }
4189}
4190
/// One request handled by `TestServer`, recorded when the server's config has
/// `log_requests` enabled.
#[derive(Debug, Clone)]
pub struct TestServerLogEntry {
    /// Request method as it appeared on the request line.
    pub method: String,
    /// Request path as parsed from the request line.
    pub path: String,
    /// Status code of the response the app produced.
    pub status: u16,
    /// Time from just after the request was parsed until the app returned its response.
    pub duration: Duration,
    /// Instant at which that timing started.
    pub timestamp: std::time::Instant,
}
4212
/// Configuration for `TestServer::start_with_config`.
#[derive(Debug, Clone)]
pub struct TestServerConfig {
    /// Per-connection socket read timeout (default: 5 seconds).
    pub read_timeout: Duration,
    /// Whether each handled request is appended to the server's log (default: `true`).
    pub log_requests: bool,
}

impl Default for TestServerConfig {
    fn default() -> Self {
        const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
        Self {
            log_requests: true,
            read_timeout: DEFAULT_READ_TIMEOUT,
        }
    }
}

impl TestServerConfig {
    /// Equivalent to [`TestServerConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the per-connection read timeout.
    #[must_use]
    pub fn read_timeout(self, timeout: Duration) -> Self {
        Self {
            read_timeout: timeout,
            ..self
        }
    }

    /// Enables or disables request logging.
    #[must_use]
    pub fn log_requests(self, log: bool) -> Self {
        Self {
            log_requests: log,
            ..self
        }
    }
}
4252
/// Runs an application on a real `127.0.0.1` TCP port for end-to-end tests.
///
/// Requests are accepted and handled on a background thread.
pub struct TestServer {
    /// Address the listener is bound to (OS-assigned port).
    addr: SocketAddr,
    /// Polled by the accept loop; once set, the loop runs pending shutdown hooks
    /// and exits.
    shutdown: Arc<AtomicBool>,
    /// Background server thread.
    handle: Option<thread::JoinHandle<()>>,
    /// Per-request log, appended when `TestServerConfig::log_requests` is set.
    log_entries: Arc<Mutex<Vec<TestServerLogEntry>>>,
    /// Shared with the server thread, which uses it to track requests, to answer
    /// `503` once it reports shutting down, and to pop shutdown hooks.
    shutdown_controller: crate::shutdown::ShutdownController,
}
4319
4320impl TestServer {
4321 #[must_use]
4331 pub fn start(app: crate::app::App) -> Self {
4332 Self::start_with_config(app, TestServerConfig::default())
4333 }
4334
4335 #[must_use]
4337 pub fn start_with_config(app: crate::app::App, config: TestServerConfig) -> Self {
4338 let listener =
4339 StdTcpListener::bind("127.0.0.1:0").expect("Failed to bind test server to port");
4340 let addr = listener.local_addr().expect("Failed to get local address");
4341
4342 listener
4343 .set_nonblocking(true)
4344 .expect("Failed to set non-blocking");
4345
4346 let app = Arc::new(app);
4347 let shutdown = Arc::new(AtomicBool::new(false));
4348 let log_entries = Arc::new(Mutex::new(Vec::new()));
4349 let shutdown_controller = crate::shutdown::ShutdownController::new();
4350
4351 let shutdown_clone = Arc::clone(&shutdown);
4352 let log_entries_clone = Arc::clone(&log_entries);
4353 let app_clone = Arc::clone(&app);
4354 let controller_clone = shutdown_controller.clone();
4355
4356 let handle = thread::spawn(move || {
4357 Self::server_loop(
4358 listener,
4359 app_clone,
4360 shutdown_clone,
4361 log_entries_clone,
4362 config,
4363 controller_clone,
4364 );
4365 });
4366
4367 Self {
4368 addr,
4369 shutdown,
4370 handle: Some(handle),
4371 log_entries,
4372 shutdown_controller,
4373 }
4374 }
4375
4376 fn server_loop(
4378 listener: StdTcpListener,
4379 app: Arc<crate::app::App>,
4380 shutdown: Arc<AtomicBool>,
4381 log_entries: Arc<Mutex<Vec<TestServerLogEntry>>>,
4382 config: TestServerConfig,
4383 controller: crate::shutdown::ShutdownController,
4384 ) {
4385 let request_counter = std::sync::atomic::AtomicU64::new(1);
4386
4387 loop {
4388 if shutdown.load(std::sync::atomic::Ordering::Acquire) {
4389 while let Some(hook) = controller.pop_hook() {
4391 hook.run();
4392 }
4393 break;
4394 }
4395
4396 match listener.accept() {
4397 Ok((stream, _peer)) => {
4398 let _guard = controller.track_request();
4400
4401 if controller.is_shutting_down() {
4403 Self::send_503(stream);
4404 continue;
4405 }
4406
4407 Self::handle_connection(stream, &app, &log_entries, &config, &request_counter);
4408 }
4409 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
4410 thread::sleep(Duration::from_millis(5));
4411 }
4412 Err(_) => {
4413 break;
4414 }
4415 }
4416 }
4417 }
4418
4419 fn handle_connection(
4421 mut stream: StdTcpStream,
4422 app: &Arc<crate::app::App>,
4423 log_entries: &Arc<Mutex<Vec<TestServerLogEntry>>>,
4424 config: &TestServerConfig,
4425 request_counter: &std::sync::atomic::AtomicU64,
4426 ) {
4427 let _ = stream.set_read_timeout(Some(config.read_timeout));
4428
4429 let mut buffer = vec![0u8; 65536];
4431 let bytes_read = match stream.read(&mut buffer) {
4432 Ok(n) if n > 0 => n,
4433 _ => return,
4434 };
4435 buffer.truncate(bytes_read);
4436
4437 let Some(parsed) = Self::parse_raw_request(&buffer) else {
4439 let bad_request = b"HTTP/1.1 400 Bad Request\r\ncontent-length: 11\r\n\r\nBad Request";
4441 let _ = stream.write_all(bad_request);
4442 let _ = stream.flush();
4443 return;
4444 };
4445
4446 let start_time = std::time::Instant::now();
4447
4448 let method = match parsed.method.to_uppercase().as_str() {
4450 "GET" => Method::Get,
4451 "POST" => Method::Post,
4452 "PUT" => Method::Put,
4453 "DELETE" => Method::Delete,
4454 "PATCH" => Method::Patch,
4455 "HEAD" => Method::Head,
4456 "OPTIONS" => Method::Options,
4457 _ => Method::Get,
4458 };
4459
4460 let mut request = Request::new(method, &parsed.path);
4461
4462 if let Some(ref query) = parsed.query {
4464 request.set_query(Some(query.clone()));
4465 }
4466
4467 for (name, value) in &parsed.headers {
4469 request
4470 .headers_mut()
4471 .insert(name.clone(), value.as_bytes().to_vec());
4472 }
4473
4474 if !parsed.body.is_empty() {
4476 request.set_body(Body::Bytes(parsed.body.clone()));
4477 }
4478
4479 let cx = Cx::for_testing();
4481 let request_id = request_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
4482 let dependency_overrides = Handler::dependency_overrides(app.as_ref())
4483 .unwrap_or_else(|| Arc::new(crate::dependency::DependencyOverrides::new()));
4484 let ctx = RequestContext::with_overrides(cx, request_id, dependency_overrides);
4485
4486 let response = futures_executor::block_on(app.handle(&ctx, &mut request));
4488
4489 let duration = start_time.elapsed();
4490 let status_code = response.status().as_u16();
4491
4492 if config.log_requests {
4494 let entry = TestServerLogEntry {
4495 method: parsed.method.clone(),
4496 path: parsed.path.clone(),
4497 status: status_code,
4498 duration,
4499 timestamp: start_time,
4500 };
4501 log_entries.lock().push(entry);
4502 }
4503
4504 let response_bytes = Self::serialize_response(response);
4506 let _ = stream.write_all(&response_bytes);
4507 let _ = stream.flush();
4508 }
4509
4510 fn parse_raw_request(data: &[u8]) -> Option<ParsedRequest> {
4512 let text = std::str::from_utf8(data).ok()?;
4513 let mut lines = text.lines();
4514
4515 let request_line = lines.next()?;
4517 let parts: Vec<&str> = request_line.split_whitespace().collect();
4518 if parts.len() < 2 {
4519 return None;
4520 }
4521
4522 let method = parts[0].to_string();
4523 let full_path = parts[1];
4524
4525 let (path, query) = if let Some(idx) = full_path.find('?') {
4527 (
4528 full_path[..idx].to_string(),
4529 Some(full_path[idx + 1..].to_string()),
4530 )
4531 } else {
4532 (full_path.to_string(), None)
4533 };
4534
4535 let mut headers = Vec::new();
4537 let mut content_length = 0usize;
4538 for line in lines.by_ref() {
4539 if line.is_empty() {
4540 break;
4541 }
4542 if let Some((name, value)) = line.split_once(':') {
4543 let name = name.trim().to_string();
4544 let value = value.trim().to_string();
4545 if name.eq_ignore_ascii_case("content-length") {
4546 content_length = value.parse().unwrap_or(0);
4547 }
4548 headers.push((name, value));
4549 }
4550 }
4551
4552 let body = if content_length > 0 {
4554 if let Some(body_start) = text.find("\r\n\r\n") {
4555 let body_start = body_start + 4;
4556 if body_start < data.len() {
4557 data[body_start..].to_vec()
4558 } else {
4559 Vec::new()
4560 }
4561 } else if let Some(body_start) = text.find("\n\n") {
4562 let body_start = body_start + 2;
4563 if body_start < data.len() {
4564 data[body_start..].to_vec()
4565 } else {
4566 Vec::new()
4567 }
4568 } else {
4569 Vec::new()
4570 }
4571 } else {
4572 Vec::new()
4573 };
4574
4575 Some(ParsedRequest {
4576 method,
4577 path,
4578 query,
4579 headers,
4580 body,
4581 })
4582 }
4583
4584 fn serialize_response(response: Response) -> Vec<u8> {
4586 let (status, headers, body) = response.into_parts();
4587
4588 let body_bytes = match body {
4589 ResponseBody::Empty => Vec::new(),
4590 ResponseBody::Bytes(b) => b,
4591 ResponseBody::Stream(_) => {
4592 Vec::new()
4595 }
4596 };
4597
4598 let mut buf = Vec::with_capacity(512 + body_bytes.len());
4599
4600 buf.extend_from_slice(b"HTTP/1.1 ");
4602 buf.extend_from_slice(status.as_u16().to_string().as_bytes());
4603 buf.extend_from_slice(b" ");
4604 buf.extend_from_slice(status.canonical_reason().as_bytes());
4605 buf.extend_from_slice(b"\r\n");
4606
4607 for (name, value) in &headers {
4609 if name.eq_ignore_ascii_case("content-length")
4610 || name.eq_ignore_ascii_case("transfer-encoding")
4611 {
4612 continue;
4613 }
4614 buf.extend_from_slice(name.as_bytes());
4615 buf.extend_from_slice(b": ");
4616 buf.extend_from_slice(value);
4617 buf.extend_from_slice(b"\r\n");
4618 }
4619
4620 buf.extend_from_slice(b"content-length: ");
4622 buf.extend_from_slice(body_bytes.len().to_string().as_bytes());
4623 buf.extend_from_slice(b"\r\n");
4624
4625 buf.extend_from_slice(b"\r\n");
4627
4628 buf.extend_from_slice(&body_bytes);
4630
4631 buf
4632 }
4633
4634 #[must_use]
4636 pub fn addr(&self) -> SocketAddr {
4637 self.addr
4638 }
4639
4640 #[must_use]
4642 pub fn port(&self) -> u16 {
4643 self.addr.port()
4644 }
4645
4646 #[must_use]
4648 pub fn url(&self) -> String {
4649 format!("http://{}", self.addr)
4650 }
4651
4652 #[must_use]
4654 pub fn url_for(&self, path: &str) -> String {
4655 let path = if path.starts_with('/') {
4656 path.to_string()
4657 } else {
4658 format!("/{path}")
4659 };
4660 format!("http://{}{}", self.addr, path)
4661 }
4662
4663 #[must_use]
4665 pub fn log_entries(&self) -> Vec<TestServerLogEntry> {
4666 self.log_entries.lock().clone()
4667 }
4668
4669 #[must_use]
4671 pub fn request_count(&self) -> usize {
4672 self.log_entries.lock().len()
4673 }
4674
4675 pub fn clear_logs(&self) {
4677 self.log_entries.lock().clear();
4678 }
4679
/// Answers a connection accepted during shutdown with `503 Service Unavailable`.
///
/// Like every other response from this server, it advertises `connection: close`
/// because the stream is dropped right after the write. Write errors are ignored.
fn send_503(mut stream: StdTcpStream) {
    // "Service Unavailable" is 19 bytes.
    const RESPONSE: &[u8] = b"HTTP/1.1 503 Service Unavailable\r\ncontent-length: 19\r\nconnection: close\r\n\r\nService Unavailable";
    let _ = stream.write_all(RESPONSE);
    let _ = stream.flush();
}
4687
4688 #[must_use]
4695 pub fn shutdown_controller(&self) -> &crate::shutdown::ShutdownController {
4696 &self.shutdown_controller
4697 }
4698
4699 #[must_use]
4701 pub fn in_flight_count(&self) -> usize {
4702 self.shutdown_controller.in_flight_count()
4703 }
4704
4705 pub fn shutdown(&self) {
4711 self.shutdown_controller.shutdown();
4712 self.shutdown
4713 .store(true, std::sync::atomic::Ordering::Release);
4714 }
4715
4716 #[must_use]
4718 pub fn is_shutdown(&self) -> bool {
4719 self.shutdown.load(std::sync::atomic::Ordering::Acquire)
4720 }
4721}
4722
4723impl Drop for TestServer {
4724 fn drop(&mut self) {
4725 self.shutdown
4726 .store(true, std::sync::atomic::Ordering::Release);
4727 if let Some(handle) = self.handle.take() {
4728 let _ = handle.join();
4729 }
4730 }
4731}
4732
/// A raw HTTP request as decoded by `TestServer::parse_raw_request`.
struct ParsedRequest {
    /// Method token exactly as sent (case preserved).
    method: String,
    /// Request target without the query string.
    path: String,
    /// Text after the first `?` in the target, if any.
    query: Option<String>,
    /// Header `(name, value)` pairs in arrival order, both trimmed.
    headers: Vec<(String, String)>,
    /// Raw request body bytes.
    body: Vec<u8>,
}
4741
/// Outcome of a single E2E scenario step.
#[derive(Debug, Clone)]
pub enum E2EStepResult {
    /// The step ran to completion.
    Passed,
    /// The step failed; carries the failure message.
    Failed(String),
    /// The step was not run.
    Skipped,
}
4756
4757impl E2EStepResult {
4758 #[must_use]
4760 pub fn is_passed(&self) -> bool {
4761 matches!(self, Self::Passed)
4762 }
4763
4764 #[must_use]
4766 pub fn is_failed(&self) -> bool {
4767 matches!(self, Self::Failed(_))
4768 }
4769}
4770
/// Request/response details that can be attached to an E2E step.
#[derive(Debug, Clone)]
pub struct E2ECapture {
    /// HTTP method of the request.
    pub method: String,
    /// Request path.
    pub path: String,
    /// Request headers as `(name, value)` pairs.
    pub request_headers: Vec<(String, String)>,
    /// Request body as text, when there was one.
    pub request_body: Option<String>,
    /// Response status code.
    pub response_status: u16,
    /// Response headers as `(name, value)` pairs.
    pub response_headers: Vec<(String, String)>,
    /// Response body as text.
    pub response_body: String,
}
4789
4790#[derive(Debug, Clone)]
4792pub struct E2EStep {
4793 pub name: String,
4795 pub started_at: std::time::Instant,
4797 pub duration: std::time::Duration,
4799 pub result: E2EStepResult,
4801 pub capture: Option<E2ECapture>,
4803}
4804
4805impl E2EStep {
4806 fn new(name: impl Into<String>) -> Self {
4808 Self {
4809 name: name.into(),
4810 started_at: std::time::Instant::now(),
4811 duration: std::time::Duration::ZERO,
4812 result: E2EStepResult::Skipped,
4813 capture: None,
4814 }
4815 }
4816
4817 fn complete(&mut self, result: E2EStepResult) {
4819 self.duration = self.started_at.elapsed();
4820 self.result = result;
4821 }
4822}
4823
4824pub struct E2EScenario<H> {
4855 name: String,
4857 description: Option<String>,
4859 client: TestClient<H>,
4861 steps: Vec<E2EStep>,
4863 stop_on_failure: bool,
4865 has_failure: bool,
4867 log_buffer: Vec<String>,
4869}
4870
4871impl<H: Handler + 'static> E2EScenario<H> {
4872 pub fn new(name: impl Into<String>, client: TestClient<H>) -> Self {
4874 let name = name.into();
4875 Self {
4876 name,
4877 description: None,
4878 client,
4879 steps: Vec::new(),
4880 stop_on_failure: true,
4881 has_failure: false,
4882 log_buffer: Vec::new(),
4883 }
4884 }
4885
4886 #[must_use]
4888 pub fn description(mut self, desc: impl Into<String>) -> Self {
4889 self.description = Some(desc.into());
4890 self
4891 }
4892
4893 #[must_use]
4895 pub fn stop_on_failure(mut self, stop: bool) -> Self {
4896 self.stop_on_failure = stop;
4897 self
4898 }
4899
4900 pub fn client(&self) -> &TestClient<H> {
4902 &self.client
4903 }
4904
4905 pub fn client_mut(&mut self) -> &mut TestClient<H> {
4907 &mut self.client
4908 }
4909
4910 pub fn log(&mut self, message: impl Into<String>) {
4912 let msg = message.into();
4913 self.log_buffer.push(format!(
4914 "[{:?}] {}",
4915 std::time::Instant::now().elapsed(),
4916 msg
4917 ));
4918 }
4919
4920 pub fn step<F>(&mut self, name: impl Into<String>, f: F)
4925 where
4926 F: FnOnce(&TestClient<H>) + std::panic::UnwindSafe,
4927 {
4928 let name = name.into();
4929 let mut step = E2EStep::new(&name);
4930
4931 if self.has_failure && self.stop_on_failure {
4933 step.complete(E2EStepResult::Skipped);
4934 self.log_buffer.push(format!("[SKIP] {}", name));
4935 self.steps.push(step);
4936 return;
4937 }
4938
4939 self.log_buffer.push(format!("[START] {}", name));
4940
4941 let client_ref = std::panic::AssertUnwindSafe(&self.client);
4943
4944 let result = std::panic::catch_unwind(|| {
4946 f(&client_ref);
4947 });
4948
4949 match result {
4950 Ok(()) => {
4951 step.complete(E2EStepResult::Passed);
4952 self.log_buffer
4953 .push(format!("[PASS] {} ({:?})", name, step.duration));
4954 }
4955 Err(panic_info) => {
4956 let error_msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
4957 (*s).to_string()
4958 } else if let Some(s) = panic_info.downcast_ref::<String>() {
4959 s.clone()
4960 } else {
4961 "Unknown panic".to_string()
4962 };
4963
4964 step.complete(E2EStepResult::Failed(error_msg.clone()));
4965 self.has_failure = true;
4966 self.log_buffer
4967 .push(format!("[FAIL] {} - {}", name, error_msg));
4968 }
4969 }
4970
4971 self.steps.push(step);
4972 }
4973
4974 pub fn try_step<F, E>(&mut self, name: impl Into<String>, f: F) -> Result<(), E>
4976 where
4977 F: FnOnce(&TestClient<H>) -> Result<(), E>,
4978 E: std::fmt::Display,
4979 {
4980 let name = name.into();
4981 let mut step = E2EStep::new(&name);
4982
4983 if self.has_failure && self.stop_on_failure {
4984 step.complete(E2EStepResult::Skipped);
4985 self.steps.push(step);
4986 return Ok(());
4987 }
4988
4989 self.log_buffer.push(format!("[START] {}", name));
4990
4991 match f(&self.client) {
4992 Ok(()) => {
4993 step.complete(E2EStepResult::Passed);
4994 self.log_buffer
4995 .push(format!("[PASS] {} ({:?})", name, step.duration));
4996 self.steps.push(step);
4997 Ok(())
4998 }
4999 Err(e) => {
5000 let error_msg = e.to_string();
5001 step.complete(E2EStepResult::Failed(error_msg.clone()));
5002 self.has_failure = true;
5003 self.log_buffer
5004 .push(format!("[FAIL] {} - {}", name, error_msg));
5005 self.steps.push(step);
5006 Err(e)
5007 }
5008 }
5009 }
5010
5011 #[must_use]
5013 pub fn passed(&self) -> bool {
5014 !self.has_failure
5015 }
5016
5017 #[must_use]
5019 pub fn steps(&self) -> &[E2EStep] {
5020 &self.steps
5021 }
5022
5023 #[must_use]
5025 pub fn logs(&self) -> &[String] {
5026 &self.log_buffer
5027 }
5028
5029 #[must_use]
5031 pub fn report(&self) -> E2EReport {
5032 let passed = self.steps.iter().filter(|s| s.result.is_passed()).count();
5033 let failed = self.steps.iter().filter(|s| s.result.is_failed()).count();
5034 let skipped = self
5035 .steps
5036 .iter()
5037 .filter(|s| matches!(s.result, E2EStepResult::Skipped))
5038 .count();
5039 let total_duration: std::time::Duration = self.steps.iter().map(|s| s.duration).sum();
5040
5041 E2EReport {
5042 scenario_name: self.name.clone(),
5043 description: self.description.clone(),
5044 passed,
5045 failed,
5046 skipped,
5047 total_duration,
5048 steps: self.steps.clone(),
5049 logs: self.log_buffer.clone(),
5050 }
5051 }
5052
5053 pub fn assert_passed(&self) {
5057 if !self.passed() {
5058 let report = self.report();
5059 panic!(
5060 "E2E Scenario '{}' failed!\n\n{}",
5061 self.name,
5062 report.to_text()
5063 );
5064 }
5065 }
5066}
5067
5068#[derive(Debug, Clone)]
5070pub struct E2EReport {
5071 pub scenario_name: String,
5073 pub description: Option<String>,
5075 pub passed: usize,
5077 pub failed: usize,
5079 pub skipped: usize,
5081 pub total_duration: std::time::Duration,
5083 pub steps: Vec<E2EStep>,
5085 pub logs: Vec<String>,
5087}
5088
5089impl E2EReport {
5090 #[must_use]
5092 pub fn to_text(&self) -> String {
5093 let mut output = String::new();
5094
5095 output.push_str(&format!("E2E Test Report: {}\n", self.scenario_name));
5097 output.push_str(&"=".repeat(60));
5098 output.push('\n');
5099
5100 if let Some(desc) = &self.description {
5101 output.push_str(&format!("Description: {}\n", desc));
5102 }
5103
5104 output.push_str(&format!(
5106 "\nSummary: {} passed, {} failed, {} skipped\n",
5107 self.passed, self.failed, self.skipped
5108 ));
5109 output.push_str(&format!("Total Duration: {:?}\n", self.total_duration));
5110 output.push_str(&"-".repeat(60));
5111 output.push('\n');
5112
5113 output.push_str("\nSteps:\n");
5115 for (i, step) in self.steps.iter().enumerate() {
5116 let status = match &step.result {
5117 E2EStepResult::Passed => "[PASS]",
5118 E2EStepResult::Failed(_) => "[FAIL]",
5119 E2EStepResult::Skipped => "[SKIP]",
5120 };
5121 output.push_str(&format!(
5122 " {}. {} {} ({:?})\n",
5123 i + 1,
5124 status,
5125 step.name,
5126 step.duration
5127 ));
5128 if let E2EStepResult::Failed(msg) = &step.result {
5129 output.push_str(&format!(" Error: {}\n", msg));
5130 }
5131 }
5132
5133 if !self.logs.is_empty() {
5135 output.push_str(&"-".repeat(60));
5136 output.push_str("\n\nLogs:\n");
5137 for log in &self.logs {
5138 output.push_str(&format!(" {}\n", log));
5139 }
5140 }
5141
5142 output
5143 }
5144
5145 #[must_use]
5147 pub fn to_json(&self) -> String {
5148 let steps_json: Vec<String> = self
5149 .steps
5150 .iter()
5151 .map(|step| {
5152 let status = match &step.result {
5153 E2EStepResult::Passed => "passed",
5154 E2EStepResult::Failed(_) => "failed",
5155 E2EStepResult::Skipped => "skipped",
5156 };
5157 let error = match &step.result {
5158 E2EStepResult::Failed(msg) => format!(r#", "error": "{}""#, escape_json(msg)),
5159 _ => String::new(),
5160 };
5161 format!(
5162 r#" {{ "name": "{}", "status": "{}", "duration_ms": {}{} }}"#,
5163 escape_json(&step.name),
5164 status,
5165 step.duration.as_millis(),
5166 error
5167 )
5168 })
5169 .collect();
5170
5171 format!(
5172 r#"{{
5173 "scenario": "{}",
5174 "description": {},
5175 "summary": {{
5176 "passed": {},
5177 "failed": {},
5178 "skipped": {},
5179 "total_duration_ms": {}
5180 }},
5181 "steps": [
5182{}
5183 ]
5184}}"#,
5185 escape_json(&self.scenario_name),
5186 self.description
5187 .as_ref()
5188 .map_or("null".to_string(), |d| format!(r#""{}""#, escape_json(d))),
5189 self.passed,
5190 self.failed,
5191 self.skipped,
5192 self.total_duration.as_millis(),
5193 steps_json.join(",\n")
5194 )
5195 }
5196
5197 #[must_use]
5199 pub fn to_html(&self) -> String {
5200 let status_class = if self.failed > 0 { "failed" } else { "passed" };
5201
5202 use std::fmt::Write;
5203 let steps_html =
5204 self.steps
5205 .iter()
5206 .enumerate()
5207 .fold(String::new(), |mut output, (i, step)| {
5208 let (status, class) = match &step.result {
5209 E2EStepResult::Passed => ("✓", "pass"),
5210 E2EStepResult::Failed(_) => ("✗", "fail"),
5211 E2EStepResult::Skipped => ("○", "skip"),
5212 };
5213 let error_html = match &step.result {
5214 E2EStepResult::Failed(msg) => {
5215 format!(r#"<div class="error">{}</div>"#, escape_html(msg))
5216 }
5217 _ => String::new(),
5218 };
5219 let _ = write!(
5220 output,
5221 r#" <tr class="{}">
5222 <td>{}</td>
5223 <td><span class="status">{}</span></td>
5224 <td>{}</td>
5225 <td>{:?}</td>
5226 </tr>
5227 {}"#,
5228 class,
5229 i + 1,
5230 status,
5231 escape_html(&step.name),
5232 step.duration,
5233 error_html
5234 );
5235 output
5236 });
5237
5238 format!(
5239 r#"<!DOCTYPE html>
5240<html>
5241<head>
5242 <title>E2E Report: {}</title>
5243 <style>
5244 body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; margin: 2rem; }}
5245 h1 {{ color: #333; }}
5246 .summary {{ padding: 1rem; border-radius: 8px; margin: 1rem 0; }}
5247 .summary.passed {{ background: #d4edda; }}
5248 .summary.failed {{ background: #f8d7da; }}
5249 table {{ width: 100%; border-collapse: collapse; margin-top: 1rem; }}
5250 th, td {{ padding: 0.75rem; text-align: left; border-bottom: 1px solid #dee2e6; }}
5251 th {{ background: #f8f9fa; }}
5252 .pass {{ color: #28a745; }}
5253 .fail {{ color: #dc3545; }}
5254 .skip {{ color: #6c757d; }}
5255 .status {{ font-size: 1.2rem; }}
5256 .error {{ color: #dc3545; font-size: 0.9rem; padding: 0.5rem; background: #fff; margin-top: 0.25rem; }}
5257 </style>
5258</head>
5259<body>
5260 <h1>E2E Report: {}</h1>
5261 {}
5262 <div class="summary {}">
5263 <strong>Summary:</strong> {} passed, {} failed, {} skipped<br>
5264 <strong>Duration:</strong> {:?}
5265 </div>
5266 <table>
5267 <thead>
5268 <tr><th>#</th><th>Status</th><th>Step</th><th>Duration</th></tr>
5269 </thead>
5270 <tbody>
5271{}
5272 </tbody>
5273 </table>
5274</body>
5275</html>"#,
5276 escape_html(&self.scenario_name),
5277 escape_html(&self.scenario_name),
5278 self.description
5279 .as_ref()
5280 .map_or(String::new(), |d| format!("<p>{}</p>", escape_html(d))),
5281 status_class,
5282 self.passed,
5283 self.failed,
5284 self.skipped,
5285 self.total_duration,
5286 steps_html
5287 )
5288 }
5289}
5290
/// Escapes `s` for use inside a JSON string literal (quotes not included).
///
/// Backslash and double quote are escaped, and so is every control character
/// U+0000–U+001F, as JSON requires. `\n`, `\r` and `\t` use their short forms;
/// the other control characters become `\u00XX`. All other characters pass
/// through unchanged.
fn escape_json(s: &str) -> String {
    use std::fmt::Write;

    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if u32::from(c) < 0x20 => {
                let _ = write!(out, "\\u{:04x}", u32::from(c));
            }
            c => out.push(c),
        }
    }
    out
}
5299
/// Escapes `s` for safe inclusion in HTML text and quoted attribute values.
///
/// `&`, `<`, `>`, `"` and `'` become character references. The string is
/// processed in a single pass, so already-produced entities are never
/// escaped again.
fn escape_html(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            c => out.push(c),
        }
    }
    out
}
5307
/// Builds and runs an `E2EScenario` inline.
///
/// Each `step "name" => |client| { ... }` becomes one `E2EScenario::step`
/// call, and the optional `description:` sets the scenario description.
/// The macro panics with the full text report if any step fails. Otherwise it
/// evaluates to the scenario's `E2EReport`.
#[macro_export]
macro_rules! e2e_test {
    (
        name: $name:expr,
        $(description: $desc:expr,)?
        client: $client:expr,
        $(step $step_name:literal => |$client_param:ident| $step_body:block),+ $(,)?
    ) => {{
        // The client expression is evaluated before the name expression.
        let e2e_client = $client;
        let mut e2e_scenario = $crate::testing::E2EScenario::new($name, e2e_client);
        $(
            e2e_scenario = e2e_scenario.description($desc);
        )?
        $(
            e2e_scenario.step($step_name, |$client_param| $step_body);
        )+
        e2e_scenario.assert_passed();
        e2e_scenario.report()
    }};
}
5360
5361pub use e2e_test;
5362
5363use crate::logging::{LogEntry, LogLevel};
5368
5369#[derive(Debug, Clone)]
5371pub struct CapturedLog {
5372 pub level: LogLevel,
5374 pub message: String,
5376 pub request_id: u64,
5378 pub captured_at: std::time::Instant,
5380 pub fields: Vec<(String, String)>,
5382 pub target: Option<String>,
5384}
5385
5386impl CapturedLog {
5387 pub fn from_entry(entry: &LogEntry) -> Self {
5389 Self {
5390 level: entry.level,
5391 message: entry.message.clone(),
5392 request_id: entry.request_id,
5393 captured_at: std::time::Instant::now(),
5394 fields: entry.fields.clone(),
5395 target: entry.target.clone(),
5396 }
5397 }
5398
5399 pub fn new(level: LogLevel, message: impl Into<String>, request_id: u64) -> Self {
5401 Self {
5402 level,
5403 message: message.into(),
5404 request_id,
5405 captured_at: std::time::Instant::now(),
5406 fields: Vec::new(),
5407 target: None,
5408 }
5409 }
5410
5411 #[must_use]
5413 pub fn contains(&self, text: &str) -> bool {
5414 self.message.contains(text)
5415 }
5416
5417 #[must_use]
5419 pub fn format(&self) -> String {
5420 let mut output = format!(
5421 "[{}] req={} {}",
5422 self.level.as_char(),
5423 self.request_id,
5424 self.message
5425 );
5426 if !self.fields.is_empty() {
5427 output.push_str(" {");
5428 for (i, (k, v)) in self.fields.iter().enumerate() {
5429 if i > 0 {
5430 output.push_str(", ");
5431 }
5432 output.push_str(&format!("{k}={v}"));
5433 }
5434 output.push('}');
5435 }
5436 output
5437 }
5438}
5439
5440#[derive(Debug, Clone)]
5465pub struct TestLogger {
5466 logs: Arc<Mutex<Vec<CapturedLog>>>,
5468 timings: Arc<Mutex<TestTimings>>,
5470 echo_logs: bool,
5472}
5473
/// Durations of the setup, execute and teardown phases of a test.
///
/// A phase stays `None` until it has been both started and ended.
#[derive(Debug, Clone, Default)]
pub struct TestTimings {
    /// Duration of the setup phase.
    pub setup: Option<std::time::Duration>,
    /// Duration of the execute phase.
    pub execute: Option<std::time::Duration>,
    /// Duration of the teardown phase.
    pub teardown: Option<std::time::Duration>,
    /// Start of the phase currently being timed, if any.
    phase_start: Option<std::time::Instant>,
}
5486
5487impl TestTimings {
5488 pub fn start_phase(&mut self) {
5490 self.phase_start = Some(std::time::Instant::now());
5491 }
5492
5493 pub fn end_setup(&mut self) {
5495 if let Some(start) = self.phase_start.take() {
5496 self.setup = Some(start.elapsed());
5497 }
5498 }
5499
5500 pub fn end_execute(&mut self) {
5502 if let Some(start) = self.phase_start.take() {
5503 self.execute = Some(start.elapsed());
5504 }
5505 }
5506
5507 pub fn end_teardown(&mut self) {
5509 if let Some(start) = self.phase_start.take() {
5510 self.teardown = Some(start.elapsed());
5511 }
5512 }
5513
5514 #[must_use]
5516 pub fn total(&self) -> std::time::Duration {
5517 self.setup.unwrap_or_default()
5518 + self.execute.unwrap_or_default()
5519 + self.teardown.unwrap_or_default()
5520 }
5521
5522 #[must_use]
5524 pub fn format(&self) -> String {
5525 format!(
5526 "Timings: setup={:?}, execute={:?}, teardown={:?}, total={:?}",
5527 self.setup.unwrap_or_default(),
5528 self.execute.unwrap_or_default(),
5529 self.teardown.unwrap_or_default(),
5530 self.total()
5531 )
5532 }
5533}
5534
5535impl TestLogger {
5536 pub fn new() -> Self {
5538 Self {
5539 logs: Arc::new(Mutex::new(Vec::new())),
5540 timings: Arc::new(Mutex::new(TestTimings::default())),
5541 echo_logs: std::env::var("FASTAPI_TEST_ECHO_LOGS").is_ok(),
5542 }
5543 }
5544
5545 pub fn with_echo() -> Self {
5547 let mut logger = Self::new();
5548 logger.echo_logs = true;
5549 logger
5550 }
5551
5552 pub fn log(&self, entry: CapturedLog) {
5554 if self.echo_logs {
5555 eprintln!("[LOG] {}", entry.format());
5556 }
5557 self.logs.lock().push(entry);
5558 }
5559
5560 pub fn log_entry(&self, entry: &LogEntry) {
5562 self.log(CapturedLog::from_entry(entry));
5563 }
5564
5565 pub fn log_message(&self, level: LogLevel, message: impl Into<String>, request_id: u64) {
5567 self.log(CapturedLog::new(level, message, request_id));
5568 }
5569
5570 #[must_use]
5572 pub fn logs(&self) -> Vec<CapturedLog> {
5573 self.logs.lock().clone()
5574 }
5575
5576 #[must_use]
5578 pub fn count(&self) -> usize {
5579 self.logs.lock().len()
5580 }
5581
5582 pub fn clear(&self) {
5584 self.logs.lock().clear();
5585 }
5586
5587 #[must_use]
5589 pub fn contains_message(&self, text: &str) -> bool {
5590 self.logs.lock().iter().any(|log| log.contains(text))
5591 }
5592
5593 #[must_use]
5595 pub fn count_by_level(&self, level: LogLevel) -> usize {
5596 self.logs
5597 .lock()
5598 .iter()
5599 .filter(|log| log.level == level)
5600 .count()
5601 }
5602
5603 #[must_use]
5605 pub fn logs_at_level(&self, level: LogLevel) -> Vec<CapturedLog> {
5606 self.logs
5607 .lock()
5608 .iter()
5609 .filter(|log| log.level == level)
5610 .cloned()
5611 .collect()
5612 }
5613
5614 #[must_use]
5616 pub fn failure_context(&self, n: usize) -> String {
5617 let logs = self.logs.lock();
5618 let start = logs.len().saturating_sub(n);
5619 let recent: Vec<_> = logs[start..].iter().map(CapturedLog::format).collect();
5620
5621 if recent.is_empty() {
5622 "No logs captured".to_string()
5623 } else {
5624 format!(
5625 "Last {} log(s) before failure:\n {}",
5626 recent.len(),
5627 recent.join("\n ")
5628 )
5629 }
5630 }
5631
5632 #[must_use]
5634 pub fn timings(&self) -> TestTimings {
5635 self.timings.lock().clone()
5636 }
5637
5638 pub fn start_phase(&self) {
5640 self.timings.lock().start_phase();
5641 }
5642
5643 pub fn end_setup(&self) {
5645 self.timings.lock().end_setup();
5646 }
5647
5648 pub fn end_execute(&self) {
5650 self.timings.lock().end_execute();
5651 }
5652
5653 pub fn end_teardown(&self) {
5655 self.timings.lock().end_teardown();
5656 }
5657
5658 pub fn capture<F, T>(f: F) -> LogCapture<T>
5662 where
5663 F: FnOnce(&TestLogger) -> T,
5664 {
5665 let logger = TestLogger::new();
5666
5667 logger.start_phase();
5668 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
5669 logger.end_setup();
5670 logger.start_phase();
5671 let result = f(&logger);
5672 logger.end_execute();
5673 result
5674 }));
5675
5676 let (ok_result, panic_info) = match result {
5677 Ok(v) => (Some(v), None),
5678 Err(p) => {
5679 let msg = if let Some(s) = p.downcast_ref::<&str>() {
5680 (*s).to_string()
5681 } else if let Some(s) = p.downcast_ref::<String>() {
5682 s.clone()
5683 } else {
5684 "Unknown panic".to_string()
5685 };
5686 (None, Some(msg))
5687 }
5688 };
5689
5690 LogCapture {
5691 logs: logger.logs(),
5692 timings: logger.timings(),
5693 result: ok_result,
5694 panic_info,
5695 }
5696 }
5697
5698 pub fn capture_phased<S, E, D, T>(setup: S, execute: E, teardown: D) -> LogCapture<T>
5700 where
5701 S: FnOnce(&TestLogger),
5702 E: FnOnce(&TestLogger) -> T,
5703 D: FnOnce(&TestLogger),
5704 {
5705 let logger = TestLogger::new();
5706
5707 logger.start_phase();
5709 let setup_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
5710 setup(&logger);
5711 }));
5712 logger.end_setup();
5713
5714 if setup_panic.is_err() {
5715 return LogCapture {
5716 logs: logger.logs(),
5717 timings: logger.timings(),
5718 result: None,
5719 panic_info: Some("Setup phase panicked".to_string()),
5720 };
5721 }
5722
5723 logger.start_phase();
5725 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| execute(&logger)));
5726 logger.end_execute();
5727
5728 logger.start_phase();
5730 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
5731 teardown(&logger);
5732 }));
5733 logger.end_teardown();
5734
5735 let (ok_result, panic_info) = match result {
5736 Ok(v) => (Some(v), None),
5737 Err(p) => {
5738 let msg = if let Some(s) = p.downcast_ref::<&str>() {
5739 (*s).to_string()
5740 } else if let Some(s) = p.downcast_ref::<String>() {
5741 s.clone()
5742 } else {
5743 "Unknown panic".to_string()
5744 };
5745 (None, Some(msg))
5746 }
5747 };
5748
5749 LogCapture {
5750 logs: logger.logs(),
5751 timings: logger.timings(),
5752 result: ok_result,
5753 panic_info,
5754 }
5755 }
5756}
5757
5758impl Default for TestLogger {
5759 fn default() -> Self {
5760 Self::new()
5761 }
5762}
5763
5764#[derive(Debug)]
5766pub struct LogCapture<T> {
5767 pub logs: Vec<CapturedLog>,
5769 pub timings: TestTimings,
5771 pub result: Option<T>,
5773 pub panic_info: Option<String>,
5775}
5776
5777impl<T> LogCapture<T> {
5778 #[must_use]
5780 pub fn passed(&self) -> bool {
5781 self.result.is_some()
5782 }
5783
5784 #[must_use]
5786 pub fn failed(&self) -> bool {
5787 self.panic_info.is_some()
5788 }
5789
5790 #[must_use]
5792 pub fn contains_message(&self, text: &str) -> bool {
5793 self.logs.iter().any(|log| log.contains(text))
5794 }
5795
5796 #[must_use]
5798 pub fn count_by_level(&self, level: LogLevel) -> usize {
5799 self.logs.iter().filter(|log| log.level == level).count()
5800 }
5801
5802 #[must_use]
5804 pub fn failure_context(&self, n: usize) -> String {
5805 let start = self.logs.len().saturating_sub(n);
5806 let recent: Vec<_> = self.logs[start..].iter().map(CapturedLog::format).collect();
5807
5808 let mut output = String::new();
5809
5810 if let Some(ref panic) = self.panic_info {
5811 output.push_str(&format!("Test failed: {}\n\n", panic));
5812 }
5813
5814 output.push_str(&self.timings.format());
5815 output.push_str("\n\n");
5816
5817 if recent.is_empty() {
5818 output.push_str("No logs captured");
5819 } else {
5820 output.push_str(&format!(
5821 "Last {} log(s) before failure:\n {}",
5822 recent.len(),
5823 recent.join("\n ")
5824 ));
5825 }
5826
5827 output
5828 }
5829
5830 pub fn unwrap(self) -> T {
5832 match self.result {
5833 Some(v) => v,
5834 None => panic!(
5835 "Test failed with log context:\n{}",
5836 self.failure_context(10)
5837 ),
5838 }
5839 }
5840
5841 pub fn unwrap_or(self, default: T) -> T {
5843 self.result.unwrap_or(default)
5844 }
5845}
5846
/// Asserts that a boolean condition holds.
///
/// On failure it panics with the stringified condition (or a custom format
/// message) followed by `$logger.failure_context(10)`.
#[macro_export]
macro_rules! assert_with_logs {
    ($logger:expr, $cond:expr) => {
        if !($cond) {
            let expression = stringify!($cond);
            let context = $logger.failure_context(10);
            panic!("Assertion failed: {}\n\n{}", expression, context);
        }
    };
    ($logger:expr, $cond:expr, $($arg:tt)+) => {
        if !($cond) {
            // Custom message first, then the logger context (original evaluation order).
            let message = format!($($arg)+);
            let context = $logger.failure_context(10);
            panic!("Assertion failed: {}\n\n{}", message, context);
        }
    };
}
5872
/// Asserts that two expressions are equal (via `PartialEq`).
///
/// On failure it panics with both values (`Debug`), the stringified
/// expressions or a custom format message, and `$logger.failure_context(10)`.
/// Each operand is evaluated exactly once, like `std::assert_eq!`, so
/// side-effecting expressions are reported with the value that was compared.
#[macro_export]
macro_rules! assert_eq_with_logs {
    ($logger:expr, $left:expr, $right:expr) => {
        match (&$left, &$right) {
            (left_val, right_val) => {
                if *left_val != *right_val {
                    panic!(
                        "Assertion failed: {} == {}\n left: {:?}\n right: {:?}\n\n{}",
                        stringify!($left),
                        stringify!($right),
                        left_val,
                        right_val,
                        $logger.failure_context(10)
                    );
                }
            }
        }
    };
    ($logger:expr, $left:expr, $right:expr, $($arg:tt)+) => {
        match (&$left, &$right) {
            (left_val, right_val) => {
                if *left_val != *right_val {
                    panic!(
                        "Assertion failed: {}\n left: {:?}\n right: {:?}\n\n{}",
                        format!($($arg)+),
                        left_val,
                        right_val,
                        $logger.failure_context(10)
                    );
                }
            }
        }
    };
}
5900
5901pub use assert_eq_with_logs;
5902pub use assert_with_logs;
5903
/// Expected-vs-actual comparison of a test response, for readable assertion failures.
#[derive(Debug)]
pub struct ResponseDiff {
    /// Status code the test expected.
    pub expected_status: u16,
    /// Status code actually returned.
    pub actual_status: u16,
    /// Substring the body is expected to contain, if checked.
    pub expected_body: Option<String>,
    /// Actual response body text.
    pub actual_body: String,
    /// Header comparisons as `(name, expected, actual)`.
    pub header_diffs: Vec<(String, Option<String>, Option<String>)>,
}
5918
5919impl ResponseDiff {
5920 pub fn new(expected_status: u16, actual: &TestResponse) -> Self {
5922 Self {
5923 expected_status,
5924 actual_status: actual.status().as_u16(),
5925 expected_body: None,
5926 actual_body: actual.text().to_string(),
5927 header_diffs: Vec::new(),
5928 }
5929 }
5930
5931 #[must_use]
5933 pub fn expected_body(mut self, body: impl Into<String>) -> Self {
5934 self.expected_body = Some(body.into());
5935 self
5936 }
5937
5938 #[must_use]
5940 pub fn expected_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
5941 self.header_diffs
5942 .push((name.into(), Some(value.into()), None));
5943 self
5944 }
5945
5946 #[must_use]
5948 pub fn is_match(&self) -> bool {
5949 if self.expected_status != self.actual_status {
5950 return false;
5951 }
5952 if let Some(ref expected) = self.expected_body {
5953 if !self.actual_body.contains(expected) {
5954 return false;
5955 }
5956 }
5957 true
5958 }
5959
5960 #[must_use]
5962 pub fn format(&self) -> String {
5963 let mut output = String::new();
5964
5965 if self.expected_status != self.actual_status {
5966 output.push_str(&format!(
5967 "Status mismatch:\n expected: {}\n actual: {}\n",
5968 self.expected_status, self.actual_status
5969 ));
5970 }
5971
5972 if let Some(ref expected) = self.expected_body {
5973 if !self.actual_body.contains(expected) {
5974 output.push_str(&format!(
5975 "Body mismatch:\n expected to contain: {:?}\n actual: {:?}\n",
5976 expected, self.actual_body
5977 ));
5978 }
5979 }
5980
5981 for (name, expected, actual) in &self.header_diffs {
5982 output.push_str(&format!(
5983 "Header '{}' mismatch:\n expected: {:?}\n actual: {:?}\n",
5984 name, expected, actual
5985 ));
5986 }
5987
5988 if output.is_empty() {
5989 "No differences".to_string()
5990 } else {
5991 output
5992 }
5993 }
5994}
5995
/// A serializable capture of a response (status, headers, body) used for
/// snapshot testing.
///
/// The derived `PartialEq` compares every field. Two snapshots are equal only
/// when status, headers, raw body text and parsed JSON body all match.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct ResponseSnapshot {
    /// HTTP status code as a plain integer.
    pub status: u16,
    /// Header `(name, value)` pairs. When built by `from_test_response`, names
    /// are lowercased, non-UTF-8 values are dropped, and the list is sorted.
    pub headers: Vec<(String, String)>,
    /// Response body as text.
    pub body: String,
    /// The body parsed as JSON when it parses (`None` otherwise). Omitted from
    /// the serialized snapshot when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body_json: Option<serde_json::Value>,
}
6031
6032impl ResponseSnapshot {
6033 pub fn from_test_response(resp: &TestResponse) -> Self {
6038 let body = resp.text().to_string();
6039 let body_json = serde_json::from_str::<serde_json::Value>(&body).ok();
6040
6041 let mut headers: Vec<(String, String)> = resp
6042 .headers()
6043 .iter()
6044 .filter_map(|(name, value)| {
6045 std::str::from_utf8(value)
6046 .ok()
6047 .map(|v| (name.to_lowercase(), v.to_string()))
6048 })
6049 .collect();
6050 headers.sort();
6051
6052 Self {
6053 status: resp.status().as_u16(),
6054 headers,
6055 body,
6056 body_json,
6057 }
6058 }
6059
6060 pub fn from_test_response_with_headers(resp: &TestResponse, header_names: &[&str]) -> Self {
6062 let mut snapshot = Self::from_test_response(resp);
6063 let names: Vec<String> = header_names.iter().map(|n| n.to_lowercase()).collect();
6064 snapshot.headers.retain(|(name, _)| names.contains(name));
6065 snapshot
6066 }
6067
6068 #[must_use]
6075 pub fn mask_fields(mut self, paths: &[&str], placeholder: &str) -> Self {
6076 if let Some(ref mut json) = self.body_json {
6077 for path in paths {
6078 mask_json_path(json, path, placeholder);
6079 }
6080 self.body = serde_json::to_string_pretty(json).unwrap_or(self.body);
6081 }
6082 self
6083 }
6084
6085 pub fn save(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
6087 let path = path.as_ref();
6088 if let Some(parent) = path.parent() {
6089 std::fs::create_dir_all(parent)?;
6090 }
6091 let json = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
6092 std::fs::write(path, json)
6093 }
6094
6095 pub fn load(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
6097 let data = std::fs::read_to_string(path)?;
6098 serde_json::from_str(&data)
6099 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
6100 }
6101
6102 #[must_use]
6104 pub fn diff(&self, other: &Self) -> String {
6105 let mut output = String::new();
6106
6107 if self.status != other.status {
6108 output.push_str(&format!("Status: {} vs {}\n", self.status, other.status));
6109 }
6110
6111 for (name, value) in &self.headers {
6113 match other.headers.iter().find(|(n, _)| n == name) {
6114 Some((_, other_value)) if value != other_value => {
6115 output.push_str(&format!(
6116 "Header '{}': {:?} vs {:?}\n",
6117 name, value, other_value
6118 ));
6119 }
6120 None => {
6121 output.push_str(&format!("Header '{}': present vs missing\n", name));
6122 }
6123 _ => {}
6124 }
6125 }
6126 for (name, _) in &other.headers {
6127 if !self.headers.iter().any(|(n, _)| n == name) {
6128 output.push_str(&format!("Header '{}': missing vs present\n", name));
6129 }
6130 }
6131
6132 if self.body != other.body {
6134 output.push_str(&format!(
6135 "Body:\n expected: {:?}\n actual: {:?}\n",
6136 other.body, self.body
6137 ));
6138 }
6139
6140 if output.is_empty() {
6141 "No differences".to_string()
6142 } else {
6143 output
6144 }
6145 }
6146
6147 pub fn matches_ignoring_headers(&self, other: &Self, ignore: &[&str]) -> bool {
6149 if self.status != other.status {
6150 return false;
6151 }
6152
6153 let ignore_lower: Vec<String> = ignore.iter().map(|s| s.to_lowercase()).collect();
6154
6155 let self_headers: Vec<_> = self
6156 .headers
6157 .iter()
6158 .filter(|(n, _)| !ignore_lower.contains(n))
6159 .collect();
6160 let other_headers: Vec<_> = other
6161 .headers
6162 .iter()
6163 .filter(|(n, _)| !ignore_lower.contains(n))
6164 .collect();
6165
6166 if self_headers != other_headers {
6167 return false;
6168 }
6169
6170 match (&self.body_json, &other.body_json) {
6172 (Some(a), Some(b)) => a == b,
6173 _ => self.body == other.body,
6174 }
6175 }
6176}
6177
6178fn mask_json_path(value: &mut serde_json::Value, path: &str, placeholder: &str) {
6180 let parts: Vec<&str> = path.splitn(2, '.').collect();
6181 match parts.as_slice() {
6182 [key] => {
6183 if let Some(obj) = value.as_object_mut() {
6184 if obj.contains_key(*key) {
6185 obj.insert(
6186 key.to_string(),
6187 serde_json::Value::String(placeholder.to_string()),
6188 );
6189 }
6190 }
6191 if let Some(arr) = value.as_array_mut() {
6192 if let Ok(idx) = key.parse::<usize>() {
6193 if idx < arr.len() {
6194 arr[idx] = serde_json::Value::String(placeholder.to_string());
6195 }
6196 }
6197 }
6198 }
6199 [key, rest] => {
6200 if let Some(obj) = value.as_object_mut() {
6201 if let Some(child) = obj.get_mut(*key) {
6202 mask_json_path(child, rest, placeholder);
6203 }
6204 }
6205 if let Some(arr) = value.as_array_mut() {
6206 if let Ok(idx) = key.parse::<usize>() {
6207 if let Some(child) = arr.get_mut(idx) {
6208 mask_json_path(child, rest, placeholder);
6209 }
6210 }
6211 }
6212 }
6213 _ => {}
6214 }
6215}
6216
/// Asserts that a response matches a JSON snapshot stored on disk.
///
/// If the snapshot file does not exist, or the `SNAPSHOT_UPDATE` environment
/// variable is set, the snapshot is written instead of compared. Otherwise the
/// stored snapshot is loaded and compared with `==`. On mismatch, the panic
/// message includes the output of `ResponseSnapshot::diff`.
///
/// The `mask: [...]` form replaces the listed dot-separated JSON paths with
/// `"<MASKED>"` before saving or comparing.
///
/// # Panics
/// Panics if the snapshot cannot be saved or loaded, or if it does not match.
#[macro_export]
macro_rules! assert_response_snapshot {
    // Internal arm: shared save-or-compare logic. It is listed first so that
    // user invocations, which never start with `@compare`, fall through to the
    // public arms below.
    (@compare $snapshot:ident, $path:expr) => {{
        // Evaluate `$path` exactly once. The binding also keeps borrowed
        // temporaries such as `&format!(..)` alive for the whole check.
        let path_arg = $path;
        let path = ::std::path::Path::new(&path_arg);

        if ::std::env::var("SNAPSHOT_UPDATE").is_ok() || !path.exists() {
            $snapshot.save(path).expect("failed to save snapshot");
        } else {
            let expected =
                $crate::ResponseSnapshot::load(path).expect("failed to load snapshot");
            assert!(
                $snapshot == expected,
                "Snapshot mismatch for {}:\n{}",
                path_arg,
                $snapshot.diff(&expected)
            );
        }
    }};
    ($response:expr, $path:expr) => {{
        let snapshot = $crate::ResponseSnapshot::from_test_response(&$response);
        $crate::assert_response_snapshot!(@compare snapshot, $path)
    }};
    ($response:expr, $path:expr, mask: [$($field:expr),* $(,)?]) => {{
        let snapshot = $crate::ResponseSnapshot::from_test_response(&$response)
            .mask_fields(&[$($field),*], "<MASKED>");
        $crate::assert_response_snapshot!(@compare snapshot, $path)
    }};
}
6269
/// Unit tests for `ResponseSnapshot` construction, diffing, masking and persistence.
#[cfg(test)]
mod snapshot_tests {
    use super::*;

    /// Builds a `TestResponse` with the given status, body bytes and headers.
    fn mock_test_response(status: u16, body: &str, headers: &[(&str, &str)]) -> TestResponse {
        let mut resp =
            crate::response::Response::with_status(crate::response::StatusCode::from_u16(status));
        for (name, value) in headers {
            resp = resp.header(*name, value.as_bytes().to_vec());
        }
        resp = resp.body(crate::response::ResponseBody::Bytes(
            body.as_bytes().to_vec(),
        ));
        TestResponse::new(resp, 0)
    }

    #[test]
    fn snapshot_from_test_response() {
        let resp = mock_test_response(
            200,
            r#"{"id":1,"name":"Alice"}"#,
            &[("content-type", "application/json")],
        );
        let snap = ResponseSnapshot::from_test_response(&resp);

        assert_eq!(snap.status, 200);
        assert!(snap.body_json.is_some());
        assert_eq!(snap.body_json.as_ref().unwrap()["name"], "Alice");
    }

    #[test]
    fn snapshot_equality() {
        let resp = mock_test_response(200, "hello", &[]);
        let snap1 = ResponseSnapshot::from_test_response(&resp);
        let snap2 = ResponseSnapshot::from_test_response(&resp);
        assert_eq!(snap1, snap2);
    }

    #[test]
    fn snapshot_diff_status() {
        let s1 = ResponseSnapshot {
            status: 200,
            headers: vec![],
            body: "ok".to_string(),
            body_json: None,
        };
        let s2 = ResponseSnapshot {
            status: 404,
            ..s1.clone()
        };
        let diff = s1.diff(&s2);
        assert!(diff.contains("200 vs 404"));
    }

    #[test]
    fn snapshot_diff_body() {
        let s1 = ResponseSnapshot {
            status: 200,
            headers: vec![],
            body: "hello".to_string(),
            body_json: None,
        };
        let s2 = ResponseSnapshot {
            body: "world".to_string(),
            ..s1.clone()
        };
        let diff = s1.diff(&s2);
        assert!(diff.contains("Body:"));
    }

    #[test]
    fn snapshot_diff_no_differences() {
        let s = ResponseSnapshot {
            status: 200,
            headers: vec![],
            body: "ok".to_string(),
            body_json: None,
        };
        assert_eq!(s.diff(&s), "No differences");
    }

    #[test]
    fn snapshot_mask_fields() {
        let resp = mock_test_response(
            200,
            r#"{"id":42,"name":"Alice","created_at":"2026-01-01"}"#,
            &[],
        );
        let snap = ResponseSnapshot::from_test_response(&resp)
            .mask_fields(&["id", "created_at"], "<MASKED>");

        let json = snap.body_json.unwrap();
        assert_eq!(json["id"], "<MASKED>");
        assert_eq!(json["name"], "Alice");
        assert_eq!(json["created_at"], "<MASKED>");
    }

    #[test]
    fn snapshot_mask_nested_fields() {
        let resp = mock_test_response(200, r#"{"user":{"id":1,"name":"Bob"}}"#, &[]);
        let snap =
            ResponseSnapshot::from_test_response(&resp).mask_fields(&["user.id"], "<MASKED>");

        let json = snap.body_json.unwrap();
        assert_eq!(json["user"]["id"], "<MASKED>");
        assert_eq!(json["user"]["name"], "Bob");
    }

    #[test]
    fn snapshot_save_and_load() {
        let snap = ResponseSnapshot {
            status: 200,
            headers: vec![("content-type".to_string(), "application/json".to_string())],
            body: r#"{"ok":true}"#.to_string(),
            body_json: Some(serde_json::json!({"ok": true})),
        };

        // The system temp dir is shared, so include the process id to avoid
        // colliding with concurrent test runs.
        let dir = std::env::temp_dir().join(format!(
            "fastapi_snapshot_test_{}",
            std::process::id()
        ));
        let path = dir.join("test_snap.json");
        snap.save(&path).unwrap();

        let loaded = ResponseSnapshot::load(&path);
        // Clean up before asserting so a failed comparison does not leak the dir.
        let _ = std::fs::remove_dir_all(&dir);

        assert_eq!(snap, loaded.unwrap());
    }

    #[test]
    fn snapshot_matches_ignoring_headers() {
        let s1 = ResponseSnapshot {
            status: 200,
            headers: vec![
                ("content-type".to_string(), "application/json".to_string()),
                ("x-request-id".to_string(), "abc".to_string()),
            ],
            body: "ok".to_string(),
            body_json: None,
        };
        let s2 = ResponseSnapshot {
            headers: vec![
                ("content-type".to_string(), "application/json".to_string()),
                ("x-request-id".to_string(), "xyz".to_string()),
            ],
            ..s1.clone()
        };

        assert!(!s1.matches_ignoring_headers(&s2, &[]));
        assert!(s1.matches_ignoring_headers(&s2, &["X-Request-Id"]));
    }

    #[test]
    fn snapshot_with_selected_headers() {
        let resp = mock_test_response(
            200,
            "ok",
            &[
                ("content-type", "text/plain"),
                ("x-request-id", "abc123"),
                ("x-trace-id", "trace-456"),
            ],
        );
        let snap = ResponseSnapshot::from_test_response_with_headers(&resp, &["content-type"]);

        assert_eq!(snap.headers.len(), 1);
        assert_eq!(snap.headers[0].0, "content-type");
    }

    #[test]
    fn snapshot_json_structural_comparison() {
        // Same JSON object, different key order in the raw text.
        let s1 = ResponseSnapshot {
            status: 200,
            headers: vec![],
            body: r#"{"a":1,"b":2}"#.to_string(),
            body_json: Some(serde_json::json!({"a": 1, "b": 2})),
        };
        let s2 = ResponseSnapshot {
            body: r#"{"b":2,"a":1}"#.to_string(),
            body_json: Some(serde_json::json!({"b": 2, "a": 1})),
            ..s1.clone()
        };

        // Strict equality sees the differing raw bodies...
        assert_ne!(s1, s2);
        // ...while the structural comparison treats them as equal.
        assert!(s1.matches_ignoring_headers(&s2, &[]));
    }
}
6459
/// Tests for `MockServer`, `MockResponse` and `RecordedRequest` over real TCP connections.
#[cfg(test)]
mod mock_server_tests {
    use super::*;

    /// Upper bound for waiting on request recording. `wait_for_requests`
    /// returns as soon as the requests are seen, so this only bounds failures
    /// instead of adding a fixed delay.
    const RECORD_TIMEOUT: Duration = Duration::from_secs(2);

    #[test]
    fn mock_server_starts_and_responds() {
        let server = MockServer::start();
        server.mock_response("/hello", MockResponse::ok().body_str("Hello, World!"));

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .unwrap();

        let mut response = String::new();
        stream.read_to_string(&mut response).unwrap();

        assert!(response.contains("200 OK"));
        assert!(response.contains("Hello, World!"));
    }

    #[test]
    fn mock_server_records_requests() {
        let server = MockServer::start();

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /api/users HTTP/1.1\r\nHost: localhost\r\nX-Custom: value\r\n\r\n")
            .unwrap();
        let mut response = Vec::new();
        let _ = stream.read_to_end(&mut response);

        // Wait for the request to be recorded instead of sleeping for a fixed time.
        assert!(
            server.wait_for_requests(1, RECORD_TIMEOUT),
            "request was not recorded in time"
        );

        let requests = server.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "GET");
        assert_eq!(requests[0].path, "/api/users");
        assert_eq!(requests[0].header("x-custom"), Some("value"));
    }

    #[test]
    fn mock_server_handles_post_with_body() {
        let server = MockServer::start();
        server.mock_response(
            "/api/create",
            MockResponse::with_status(201).body_str("Created"),
        );

        let body = r#"{"name":"test"}"#;
        let request = format!(
            "POST /api/create HTTP/1.1\r\nHost: localhost\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}",
            body.len(),
            body
        );

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream.write_all(request.as_bytes()).unwrap();
        let mut response = String::new();
        stream.read_to_string(&mut response).unwrap();

        assert!(response.contains("201 Created"));

        assert!(
            server.wait_for_requests(1, RECORD_TIMEOUT),
            "request was not recorded in time"
        );
        let requests = server.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].method, "POST");
        assert_eq!(requests[0].body_text(), body);
    }

    #[test]
    fn mock_server_pattern_matching() {
        let server = MockServer::start();
        server.mock_response("/api/*", MockResponse::ok().body_str("API Response"));

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /api/users/123 HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .unwrap();
        let mut response = String::new();
        stream.read_to_string(&mut response).unwrap();

        assert!(response.contains("API Response"));
    }

    #[test]
    fn mock_server_default_response() {
        let server = MockServer::start();

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /unknown HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .unwrap();
        let mut response = String::new();
        stream.read_to_string(&mut response).unwrap();

        assert!(response.contains("404"));
    }

    #[test]
    fn mock_server_url_helpers() {
        let server = MockServer::start();

        let url = server.url();
        assert!(url.starts_with("http://127.0.0.1:"));

        let api_url = server.url_for("/api/users");
        assert!(api_url.contains("/api/users"));
    }

    #[test]
    fn mock_server_clear_requests() {
        let server = MockServer::start();

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /test HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .unwrap();
        let mut response = Vec::new();
        let _ = stream.read_to_end(&mut response);

        assert!(
            server.wait_for_requests(1, RECORD_TIMEOUT),
            "request was not recorded in time"
        );
        assert_eq!(server.request_count(), 1);

        server.clear_requests();
        assert_eq!(server.request_count(), 0);
    }

    #[test]
    fn mock_server_wait_for_requests() {
        let server = MockServer::start();

        let addr = server.addr();
        // Send the request from another thread after a short delay, so the
        // main thread is already waiting when it arrives.
        let writer = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            let mut stream = StdTcpStream::connect(addr).expect("Failed to connect");
            stream
                .write_all(b"GET /delayed HTTP/1.1\r\nHost: localhost\r\n\r\n")
                .unwrap();
        });

        let received = server.wait_for_requests(1, Duration::from_millis(500));
        assert!(received);
        assert_eq!(server.request_count(), 1);

        // Surface any panic from the writer thread instead of silently dropping it.
        writer.join().expect("writer thread panicked");
    }

    #[test]
    fn mock_server_assert_helpers() {
        let server = MockServer::start();

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /expected HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .unwrap();
        let mut response = Vec::new();
        let _ = stream.read_to_end(&mut response);

        assert!(
            server.wait_for_requests(1, RECORD_TIMEOUT),
            "request was not recorded in time"
        );

        server.assert_received("/expected");
        server.assert_not_received("/not-expected");
        server.assert_request_count(1);
    }

    #[test]
    fn mock_server_query_string_parsing() {
        let server = MockServer::start();

        let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
        stream
            .write_all(b"GET /search?q=rust&limit=10 HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .unwrap();
        let mut response = Vec::new();
        let _ = stream.read_to_end(&mut response);

        assert!(
            server.wait_for_requests(1, RECORD_TIMEOUT),
            "request was not recorded in time"
        );

        let requests = server.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].path, "/search");
        assert_eq!(requests[0].query, Some("q=rust&limit=10".to_string()));
        assert_eq!(requests[0].url(), "/search?q=rust&limit=10");
    }

    #[test]
    fn mock_response_json() {
        #[derive(serde::Serialize)]
        struct User {
            name: String,
        }

        let response = MockResponse::ok().json(&User {
            name: "Alice".to_string(),
        });
        let bytes = response.to_http_response();
        let http = String::from_utf8_lossy(&bytes);

        assert!(http.contains("application/json"));
        assert!(http.contains("Alice"));
    }

    #[test]
    fn recorded_request_helpers() {
        let request = RecordedRequest {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
            query: Some("page=1".to_string()),
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: b"test body".to_vec(),
            timestamp: std::time::Instant::now(),
        };

        assert_eq!(request.body_text(), "test body");
        assert_eq!(request.header("content-type"), Some("application/json"));
        assert_eq!(request.url(), "/api/users?page=1");
    }
}
6682
/// Tests for `E2EScenario` step execution, failure handling, logging and
/// report rendering (text / JSON / HTML).
#[cfg(test)]
mod e2e_tests {
    use super::*;

    /// Minimal routing handler for the scenarios below. It returns fixed bodies
    /// for a few paths, 500 for `/fail`, and 404 for anything else.
    fn test_handler(_ctx: &RequestContext, req: &mut Request) -> std::future::Ready<Response> {
        let path = req.path();
        let response = match path {
            "/" => Response::ok().body(ResponseBody::Bytes(b"Home".to_vec())),
            "/login" => Response::ok().body(ResponseBody::Bytes(b"Login Page".to_vec())),
            "/dashboard" => Response::ok().body(ResponseBody::Bytes(b"Dashboard".to_vec())),
            "/api/users" => {
                Response::ok().body(ResponseBody::Bytes(b"[\"Alice\",\"Bob\"]".to_vec()))
            }
            "/fail" => Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(ResponseBody::Bytes(b"Error".to_vec())),
            _ => Response::with_status(StatusCode::NOT_FOUND)
                .body(ResponseBody::Bytes(b"Not Found".to_vec())),
        };
        std::future::ready(response)
    }

    #[test]
    fn e2e_scenario_all_steps_pass() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("Basic Navigation", client);

        scenario.step("Visit home page", |client| {
            let response = client.get("/").send();
            assert_eq!(response.status().as_u16(), 200);
            assert_eq!(response.text(), "Home");
        });

        scenario.step("Visit login page", |client| {
            let response = client.get("/login").send();
            assert_eq!(response.status().as_u16(), 200);
        });

        assert!(scenario.passed());
        assert_eq!(scenario.steps().len(), 2);
        assert!(scenario.steps().iter().all(|s| s.result.is_passed()));
    }

    // With stop_on_failure(true), steps after the first panicking step are
    // recorded as Skipped rather than executed.
    #[test]
    fn e2e_scenario_step_failure() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("Failure Test", client).stop_on_failure(true);

        scenario.step("First step passes", |client| {
            let response = client.get("/").send();
            assert_eq!(response.status().as_u16(), 200);
        });

        scenario.step("Second step fails", |_client| {
            panic!("Intentional failure");
        });

        scenario.step("Third step skipped", |client| {
            let response = client.get("/dashboard").send();
            assert_eq!(response.status().as_u16(), 200);
        });

        assert!(!scenario.passed());
        assert_eq!(scenario.steps().len(), 3);
        assert!(scenario.steps()[0].result.is_passed());
        assert!(scenario.steps()[1].result.is_failed());
        assert!(matches!(scenario.steps()[2].result, E2EStepResult::Skipped));
    }

    // With stop_on_failure(false), later steps still run after a failure.
    #[test]
    fn e2e_scenario_continue_on_failure() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("Continue Test", client).stop_on_failure(false);

        scenario.step("First step fails", |_client| {
            panic!("First failure");
        });

        scenario.step("Second step still runs", |client| {
            let response = client.get("/").send();
            assert_eq!(response.status().as_u16(), 200);
        });

        assert!(!scenario.passed());
        assert_eq!(scenario.steps().len(), 2);
        assert!(scenario.steps()[0].result.is_failed());
        assert!(scenario.steps()[1].result.is_passed());
    }

    #[test]
    fn e2e_report_text_format() {
        let client = TestClient::new(test_handler);
        let mut scenario =
            E2EScenario::new("Report Test", client).description("Tests report generation");

        scenario.step("Step 1", |client| {
            let _ = client.get("/").send();
        });

        let report = scenario.report();
        let text = report.to_text();

        assert!(text.contains("E2E Test Report: Report Test"));
        assert!(text.contains("Tests report generation"));
        assert!(text.contains("1 passed"));
        assert!(text.contains("Step 1"));
    }

    #[test]
    fn e2e_report_json_format() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("JSON Test", client);

        scenario.step("API call", |client| {
            let response = client.get("/api/users").send();
            assert_eq!(response.status().as_u16(), 200);
        });

        let report = scenario.report();
        let json = report.to_json();

        assert!(json.contains(r#""scenario": "JSON Test""#));
        assert!(json.contains(r#""passed": 1"#));
        assert!(json.contains(r#""name": "API call""#));
        assert!(json.contains(r#""status": "passed""#));
    }

    #[test]
    fn e2e_report_html_format() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("HTML Test", client);

        scenario.step("Web visit", |client| {
            let _ = client.get("/").send();
        });

        let report = scenario.report();
        let html = report.to_html();

        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("E2E Report: HTML Test"));
        assert!(html.contains("1 passed"));
        assert!(html.contains("Web visit"));
    }

    // The recorded step duration must cover at least the time spent inside the step.
    #[test]
    fn e2e_step_timing() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("Timing Test", client);

        scenario.step("Timed step", |_client| {
            std::thread::sleep(std::time::Duration::from_millis(10));
        });

        assert!(scenario.steps()[0].duration >= std::time::Duration::from_millis(10));
    }

    // Manual log entries and the automatic [START]/[PASS] step markers all end
    // up in logs().
    #[test]
    fn e2e_logs_captured() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("Log Test", client);

        scenario.log("Manual log entry");
        scenario.step("Logged step", |_client| {});

        assert!(
            scenario
                .logs()
                .iter()
                .any(|l| l.contains("Manual log entry"))
        );
        assert!(
            scenario
                .logs()
                .iter()
                .any(|l| l.contains("[START] Logged step"))
        );
        assert!(
            scenario
                .logs()
                .iter()
                .any(|l| l.contains("[PASS] Logged step"))
        );
    }

    #[test]
    fn e2e_try_step_with_result() {
        let client = TestClient::new(test_handler);
        let mut scenario = E2EScenario::new("Try Step Test", client);

        let result: Result<(), &str> = scenario.try_step("Success step", |client| {
            let response = client.get("/").send();
            if response.status().as_u16() == 200 {
                Ok(())
            } else {
                Err("Unexpected status")
            }
        });

        assert!(result.is_ok());
        assert!(scenario.passed());
    }

    #[test]
    fn e2e_escape_functions() {
        assert_eq!(escape_json("hello"), "hello");
        assert_eq!(escape_json("a\"b"), "a\\\"b");
        assert_eq!(escape_json("a\nb"), "a\\nb");

        assert_eq!(escape_html("hello"), "hello");
        // NOTE(review): the next two expectations assert that `escape_html`
        // leaves `<`, `>` and `&` unchanged. That is at odds with the function's
        // name and with the quoting `escape_json` does above. If `escape_html`
        // emits entities, the expected values should be the escaped forms
        // (`&lt;script&gt;`, `a&amp;b`). Confirm against its implementation,
        // which is not visible here.
        assert_eq!(escape_html("<script>"), "<script>");
        assert_eq!(escape_html("a&b"), "a&b");
    }

    #[test]
    fn e2e_step_result_helpers() {
        let passed = E2EStepResult::Passed;
        let failed = E2EStepResult::Failed("error".to_string());
        let skipped = E2EStepResult::Skipped;

        assert!(passed.is_passed());
        assert!(!passed.is_failed());

        assert!(!failed.is_passed());
        assert!(failed.is_failed());

        // Skipped is neither passed nor failed.
        assert!(!skipped.is_passed());
        assert!(!skipped.is_failed());
    }
}
6917
/// A reusable piece of test state with explicit setup and teardown.
///
/// Implementors build their state in `setup` and may release it in
/// `teardown`. Wrap a fixture in `FixtureGuard` to have `teardown` run
/// automatically when the guard is dropped.
pub trait TestFixture: Sized + Send {
    /// Creates the fixture.
    fn setup() -> Self;

    /// Releases resources held by the fixture. The default implementation does
    /// nothing.
    fn teardown(&mut self) {}
}

/// Guard that owns a `TestFixture` and calls its `teardown` exactly once,
/// when the guard is dropped.
///
/// Dereferences to the fixture, so its fields and methods can be used directly.
pub struct FixtureGuard<F: TestFixture> {
    // Always `Some` while the guard is alive. `Drop` takes it out so teardown
    // runs exactly once.
    fixture: Option<F>,
}

impl<F: TestFixture> FixtureGuard<F> {
    /// Runs `F::setup()` and wraps the result in a guard.
    pub fn new() -> Self {
        Self {
            fixture: Some(F::setup()),
        }
    }

    /// Borrows the fixture.
    ///
    /// # Panics
    /// Only if the internal invariant is broken. The fixture is removed solely
    /// in `Drop`, after which the guard can no longer be used.
    pub fn get(&self) -> &F {
        self.fixture
            .as_ref()
            .expect("FixtureGuard invariant: fixture is only taken during drop")
    }

    /// Mutably borrows the fixture.
    ///
    /// # Panics
    /// Only if the internal invariant is broken. See `get`.
    pub fn get_mut(&mut self) -> &mut F {
        self.fixture
            .as_mut()
            .expect("FixtureGuard invariant: fixture is only taken during drop")
    }
}

impl<F: TestFixture> Default for FixtureGuard<F> {
    /// Equivalent to `FixtureGuard::new()`; it runs `F::setup()`.
    fn default() -> Self {
        Self::new()
    }
}

impl<F: TestFixture> Drop for FixtureGuard<F> {
    fn drop(&mut self) {
        // `take` leaves `None` behind, so teardown cannot run twice.
        if let Some(mut fixture) = self.fixture.take() {
            fixture.teardown();
        }
    }
}

impl<F: TestFixture> std::ops::Deref for FixtureGuard<F> {
    type Target = F;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<F: TestFixture> std::ops::DerefMut for FixtureGuard<F> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.get_mut()
    }
}
7014
/// Harness that runs one integration test against a `TestClient`, with
/// registered fixtures and reset hooks, and then cleans up.
pub struct IntegrationTest<H: Handler + 'static> {
    /// Test name, exposed to the test body through the context.
    name: String,
    /// Client used to send requests to the handler under test.
    client: TestClient<H>,
    /// Registered fixtures keyed by the fixture type's `TypeId`. Each value is a
    /// boxed `FixtureGuard<F>` inserted by `with_fixture`.
    fixtures: HashMap<std::any::TypeId, Box<dyn std::any::Any + Send>>,
    /// Callbacks invoked by `run` after the test body, whether it passed or panicked.
    reset_hooks: Vec<Box<dyn Fn() + Send + Sync>>,
}
7070
7071impl<H: Handler + 'static> IntegrationTest<H> {
7072 pub fn new(name: impl Into<String>, handler: H) -> Self {
7074 Self {
7075 name: name.into(),
7076 client: TestClient::new(handler),
7077 fixtures: HashMap::new(),
7078 reset_hooks: Vec::new(),
7079 }
7080 }
7081
7082 pub fn with_seed(name: impl Into<String>, handler: H, seed: u64) -> Self {
7084 Self {
7085 name: name.into(),
7086 client: TestClient::with_seed(handler, seed),
7087 fixtures: HashMap::new(),
7088 reset_hooks: Vec::new(),
7089 }
7090 }
7091
7092 #[must_use]
7096 pub fn with_fixture<F: TestFixture + 'static>(mut self) -> Self {
7097 let guard = FixtureGuard::<F>::new();
7098 self.fixtures
7099 .insert(std::any::TypeId::of::<F>(), Box::new(guard));
7100 self
7101 }
7102
7103 #[must_use]
7107 pub fn on_reset<F: Fn() + Send + Sync + 'static>(mut self, f: F) -> Self {
7108 self.reset_hooks.push(Box::new(f));
7109 self
7110 }
7111
7112 pub fn run<F>(mut self, test_fn: F)
7117 where
7118 F: FnOnce(&IntegrationTestContext<'_, H>) + std::panic::UnwindSafe,
7119 {
7120 let ctx = IntegrationTestContext {
7122 name: &self.name,
7123 client: &self.client,
7124 fixtures: &self.fixtures,
7125 };
7126
7127 let ctx_ref = std::panic::AssertUnwindSafe(&ctx);
7129
7130 let result = std::panic::catch_unwind(|| {
7132 test_fn(&ctx_ref);
7133 });
7134
7135 for hook in &self.reset_hooks {
7137 hook();
7138 }
7139
7140 self.client.clear_cookies();
7142 self.client.clear_dependency_overrides();
7143
7144 self.fixtures.clear();
7146
7147 if let Err(e) = result {
7149 std::panic::resume_unwind(e);
7150 }
7151 }
7152}
7153
/// Borrowed view of an `IntegrationTest`, handed to the test body by `run`.
pub struct IntegrationTestContext<'a, H: Handler> {
    /// Name of the running test.
    name: &'a str,
    /// Client owned by the enclosing `IntegrationTest`.
    client: &'a TestClient<H>,
    /// Shared (read-only) borrow of the registered fixtures, keyed by `TypeId`.
    fixtures: &'a HashMap<std::any::TypeId, Box<dyn std::any::Any + Send>>,
}
7163
7164impl<'a, H: Handler + 'static> IntegrationTestContext<'a, H> {
7165 #[must_use]
7167 pub fn name(&self) -> &str {
7168 self.name
7169 }
7170
7171 #[must_use]
7173 pub fn client(&self) -> &TestClient<H> {
7174 self.client
7175 }
7176
7177 #[must_use]
7181 pub fn fixture<F: TestFixture + 'static>(&self) -> Option<&F> {
7182 self.fixtures
7183 .get(&std::any::TypeId::of::<F>())
7184 .and_then(|boxed| boxed.downcast_ref::<FixtureGuard<F>>())
7185 .map(FixtureGuard::get)
7186 }
7187
7188 #[must_use]
7192 pub fn fixture_mut<F: TestFixture + 'static>(&self) -> Option<&mut F> {
7193 None }
7198
7199 pub fn get(&self, path: &str) -> RequestBuilder<'_, H> {
7203 self.client.get(path)
7204 }
7205
7206 pub fn post(&self, path: &str) -> RequestBuilder<'_, H> {
7208 self.client.post(path)
7209 }
7210
7211 pub fn put(&self, path: &str) -> RequestBuilder<'_, H> {
7213 self.client.put(path)
7214 }
7215
7216 pub fn delete(&self, path: &str) -> RequestBuilder<'_, H> {
7218 self.client.delete(path)
7219 }
7220
7221 pub fn patch(&self, path: &str) -> RequestBuilder<'_, H> {
7223 self.client.patch(path)
7224 }
7225
7226 pub fn options(&self, path: &str) -> RequestBuilder<'_, H> {
7228 self.client.options(path)
7229 }
7230
7231 pub fn request(&self, method: Method, path: &str) -> RequestBuilder<'_, H> {
7233 self.client.request(method, path)
7234 }
7235}
7236
7237#[cfg(test)]
7242mod test_server_tests {
7243 use super::*;
7244 use crate::app::App;
7245 use std::net::TcpStream as StdTcpStreamAlias;
7246
7247 fn make_test_app() -> App {
7248 App::builder()
7249 .get("/health", |_ctx: &RequestContext, _req: &mut Request| {
7250 std::future::ready(
7251 Response::ok()
7252 .header("content-type", b"text/plain".to_vec())
7253 .body(ResponseBody::Bytes(b"OK".to_vec())),
7254 )
7255 })
7256 .get("/hello", |_ctx: &RequestContext, _req: &mut Request| {
7257 std::future::ready(
7258 Response::ok()
7259 .header("content-type", b"application/json".to_vec())
7260 .body(ResponseBody::Bytes(
7261 br#"{"message":"Hello, World!"}"#.to_vec(),
7262 )),
7263 )
7264 })
7265 .post("/echo", |_ctx: &RequestContext, req: &mut Request| {
7266 let body = match req.body() {
7267 Body::Bytes(b) => b.clone(),
7268 _ => Vec::new(),
7269 };
7270 std::future::ready(
7271 Response::ok()
7272 .header("content-type", b"application/octet-stream".to_vec())
7273 .body(ResponseBody::Bytes(body)),
7274 )
7275 })
7276 .build()
7277 }
7278
7279 fn send_request(addr: SocketAddr, request: &[u8]) -> String {
7280 let mut stream = StdTcpStreamAlias::connect(addr).expect("Failed to connect to TestServer");
7281 stream
7282 .set_read_timeout(Some(Duration::from_secs(5)))
7283 .expect("set_read_timeout");
7284 stream.write_all(request).expect("Failed to write request");
7285 stream.flush().expect("Failed to flush");
7286
7287 let mut buf = vec![0u8; 65536];
7288 let n = stream.read(&mut buf).expect("Failed to read response");
7289 String::from_utf8_lossy(&buf[..n]).to_string()
7290 }
7291
7292 #[test]
7293 fn test_server_starts_and_responds() {
7294 let app = make_test_app();
7295 let server = TestServer::start(app);
7296
7297 let response = send_request(
7298 server.addr(),
7299 b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
7300 );
7301
7302 assert!(
7303 response.contains("200 OK"),
7304 "Expected 200 OK, got: {response}"
7305 );
7306 assert!(response.contains("OK"), "Expected body 'OK'");
7307 }
7308
/// `GET /hello` returns the JSON greeting with a JSON content type.
#[test]
fn test_server_json_response() {
    let server = TestServer::start(make_test_app());
    let reply = send_request(server.addr(), b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n");

    assert!(reply.contains("200 OK"));
    assert!(reply.contains("application/json"));
    assert!(reply.contains(r#"{"message":"Hello, World!"}"#));
}
7323
/// `POST /echo` sends the request body back unchanged.
#[test]
fn test_server_post_with_body() {
    let server = TestServer::start(make_test_app());

    let payload: &[u8] =
        b"POST /echo HTTP/1.1\r\nHost: localhost\r\nContent-Length: 11\r\n\r\nHello World";
    let echoed = send_request(server.addr(), payload);

    assert!(echoed.contains("200 OK"));
    assert!(echoed.contains("Hello World"));
}
7336
/// A single request produces exactly one log entry with its method, path and status.
#[test]
fn test_server_logs_requests() {
    let server = TestServer::start(make_test_app());
    let _ = send_request(
        server.addr(),
        b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
    );

    let entries = server.log_entries();
    assert_eq!(entries.len(), 1);

    let first = &entries[0];
    assert_eq!(first.method, "GET");
    assert_eq!(first.path, "/health");
    assert_eq!(first.status, 200);
}
7354
/// `request_count` starts at zero and increments once per handled request.
#[test]
fn test_server_request_count() {
    let server = TestServer::start(make_test_app());
    assert_eq!(server.request_count(), 0);

    for path in ["/health", "/hello"] {
        let raw = format!("GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n");
        send_request(server.addr(), raw.as_bytes());
    }

    assert_eq!(server.request_count(), 2);
}
7373
/// `clear_logs` resets the recorded request count back to zero.
#[test]
fn test_server_clear_logs() {
    let server = TestServer::start(make_test_app());
    let health = b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n";

    send_request(server.addr(), health);
    assert_eq!(server.request_count(), 1);

    server.clear_logs();
    assert_eq!(server.request_count(), 0);
}
7388
/// The URL helpers point at the loopback address. `url_for` accepts paths with
/// or without a leading slash.
#[test]
fn test_server_url_helpers() {
    let server = TestServer::start(make_test_app());

    let base = server.url();
    assert!(base.starts_with("http://127.0.0.1:"));

    for raw_path in ["/health", "health"] {
        assert!(server.url_for(raw_path).ends_with("/health"));
    }

    assert!(server.port() > 0);
}
7399
/// The server serves requests, and after `shutdown()` it reports itself shut down.
#[test]
fn test_server_shutdown() {
    let server = TestServer::start(make_test_app());

    let before = send_request(server.addr(), b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n");
    assert!(before.contains("200 OK"));

    server.shutdown();
    assert!(server.is_shutdown());
}
7414
/// With `log_requests(false)`, handled requests are not counted.
#[test]
fn test_server_config_no_logging() {
    let quiet = TestServerConfig::new().log_requests(false);
    let server = TestServer::start_with_config(make_test_app(), quiet);

    send_request(
        server.addr(),
        b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
    );

    assert_eq!(server.request_count(), 0);
}
7429
/// Bytes that are not an HTTP request are answered with `400 Bad Request`.
#[test]
fn test_server_bad_request() {
    let server = TestServer::start(make_test_app());
    let garbage: &[u8] = b"NOT_HTTP_AT_ALL";

    let reply = send_request(server.addr(), garbage);
    assert!(reply.contains("400 Bad Request"));
}
7440
/// `GET /health` (body `OK`) advertises `Content-Length: 2`.
#[test]
fn test_server_content_length_header() {
    let app = make_test_app();
    let server = TestServer::start(app);

    let response = send_request(
        server.addr(),
        b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
    );

    // Header names are case-insensitive, and a substring match on
    // "content-length: 2" would also accept e.g. "content-length: 20".
    // Parse the header block (after the status line) and compare the exact value.
    let head = response.split("\r\n\r\n").next().unwrap_or("");
    let has_expected_length = head
        .lines()
        .skip(1)
        .filter_map(|line| line.split_once(':'))
        .any(|(name, value)| {
            name.trim().eq_ignore_ascii_case("content-length") && value.trim() == "2"
        });
    assert!(
        has_expected_length,
        "Expected content-length: 2, got: {response}"
    );
}
7457
/// Several back-to-back connections are all served and all counted.
#[test]
fn test_server_multiple_requests_sequential() {
    let server = TestServer::start(make_test_app());

    for _round in 0..5 {
        let reply = send_request(
            server.addr(),
            b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(reply.contains("200 OK"));
    }

    assert_eq!(server.request_count(), 5);
}
7473
/// Each log entry records a handling duration, which for a trivial route is under a second.
#[test]
fn test_server_log_entry_has_timing() {
    let server = TestServer::start(make_test_app());
    send_request(
        server.addr(),
        b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
    );

    let entries = server.log_entries();
    assert_eq!(entries.len(), 1);
    // Exactly one entry exists here, so "all" checks just that entry.
    assert!(entries
        .iter()
        .all(|entry| entry.duration < Duration::from_secs(1)));
}
7489
/// A running server exposes a shutdown controller that is still in the `Running` phase.
#[test]
fn test_server_shutdown_controller_available() {
    use crate::shutdown::ShutdownPhase;

    let server = TestServer::start(make_test_app());
    let ctl = server.shutdown_controller();

    assert!(!ctl.is_shutting_down());
    assert_eq!(ctl.phase(), ShutdownPhase::Running);
}
7504
/// `TestServer::shutdown` moves the controller into the `StopAccepting` phase.
#[test]
fn test_server_shutdown_triggers_controller() {
    use crate::shutdown::ShutdownPhase;

    let server = TestServer::start(make_test_app());
    assert!(!server.shutdown_controller().is_shutting_down());

    server.shutdown();
    assert!(server.is_shutdown());

    let ctl = server.shutdown_controller();
    assert!(ctl.is_shutting_down());
    assert_eq!(ctl.phase(), ShutdownPhase::StopAccepting);
}
7524
/// Requests finished before shutdown stay recorded in the log afterwards.
#[test]
fn test_server_requests_complete_before_shutdown() {
    let server = TestServer::start(make_test_app());

    let reply = send_request(
        server.addr(),
        b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
    );
    assert!(reply.contains("200 OK"));
    assert_eq!(server.request_count(), 1);

    server.shutdown();

    let entries = server.log_entries();
    assert_eq!(entries.len(), 1);
    let only = &entries[0];
    assert_eq!(only.status, 200);
    assert_eq!(only.path, "/health");
}
7547
/// The in-flight counter returns to zero once a request has been handled.
///
/// The counter may be decremented slightly after the response bytes reach the
/// client, so it is polled for up to 500 ms before asserting.
#[test]
fn test_server_in_flight_tracking() {
    let server = TestServer::start(make_test_app());
    assert_eq!(server.in_flight_count(), 0);

    send_request(
        server.addr(),
        b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
    );

    let deadline = std::time::Instant::now() + std::time::Duration::from_millis(500);
    while server.in_flight_count() > 0 && std::time::Instant::now() < deadline {
        std::thread::sleep(std::time::Duration::from_millis(1));
    }

    assert_eq!(
        server.in_flight_count(),
        0,
        "In-flight count should return to 0 after request completes"
    );
}
7577
/// Each `track_request` guard adds one to the in-flight count while alive and
/// subtracts one when dropped.
#[test]
fn test_server_in_flight_guard_tracks_correctly() {
    let server = TestServer::start(make_test_app());
    let ctl = server.shutdown_controller();
    assert_eq!(ctl.in_flight_count(), 0);

    let mut guards = Vec::new();
    for expected in 1..=2 {
        guards.push(ctl.track_request());
        assert_eq!(ctl.in_flight_count(), expected);
    }

    // Release the guards in the order they were acquired.
    for expected in (0..=1).rev() {
        drop(guards.remove(0));
        assert_eq!(ctl.in_flight_count(), expected);
    }
}
7599
/// A hook registered on the shutdown controller has run by the time the
/// server has been shut down and dropped.
#[test]
fn test_server_shutdown_hooks_executed() {
    use std::sync::atomic::Ordering;

    let server = TestServer::start(make_test_app());

    let fired = Arc::new(AtomicBool::new(false));
    {
        let fired = Arc::clone(&fired);
        server.shutdown_controller().register_hook(move || {
            fired.store(true, Ordering::Release);
        });
    }

    assert!(!fired.load(Ordering::Acquire));

    server.shutdown();
    drop(server);

    assert!(
        fired.load(Ordering::Acquire),
        "Shutdown hook should have been executed"
    );
}
7626
/// Shutdown hooks run in reverse registration order (last registered, first run).
#[test]
fn test_server_multiple_shutdown_hooks_lifo() {
    let server = TestServer::start(make_test_app());

    let execution_order = Arc::new(Mutex::new(Vec::new()));
    for id in 1..=3 {
        let sink = Arc::clone(&execution_order);
        server.shutdown_controller().register_hook(move || {
            sink.lock().push(id);
        });
    }

    server.shutdown();
    drop(server);

    let recorded = execution_order.lock();
    assert_eq!(*recorded, vec![3, 2, 1]);
}
7657
/// `advance_phase` steps through every shutdown phase in order. It returns
/// `false` once `Stopped` has been reached.
#[test]
fn test_server_shutdown_controller_phase_progression() {
    use crate::shutdown::ShutdownPhase;

    let server = TestServer::start(make_test_app());
    let ctl = server.shutdown_controller();
    assert_eq!(ctl.phase(), ShutdownPhase::Running);

    let expected_sequence = [
        ShutdownPhase::StopAccepting,
        ShutdownPhase::ShutdownFlagged,
        ShutdownPhase::GracePeriod,
        ShutdownPhase::Cancelling,
        ShutdownPhase::RunningHooks,
        ShutdownPhase::Stopped,
    ];
    for next in expected_sequence {
        assert!(ctl.advance_phase());
        assert_eq!(ctl.phase(), next);
    }

    // Terminal phase: no further advancement.
    assert!(!ctl.advance_phase());
}
7703
/// A subscribed receiver observes a graceful (non-forced) shutdown.
#[test]
fn test_server_receiver_notified_on_shutdown() {
    let server = TestServer::start(make_test_app());
    let rx = server.shutdown_controller().subscribe();
    assert!(!rx.is_shutting_down());

    server.shutdown();

    assert!(rx.is_shutting_down());
    assert!(!rx.is_forced());
}
7716
/// The first controller `shutdown()` is graceful. A second call is seen by
/// receivers as forced.
#[test]
fn test_server_forced_shutdown() {
    let server = TestServer::start(make_test_app());
    let rx = server.shutdown_controller().subscribe();

    // First signal: graceful.
    server.shutdown_controller().shutdown();
    assert!(rx.is_shutting_down());
    assert!(!rx.is_forced());

    // Second signal: escalates to forced.
    server.shutdown_controller().shutdown();
    assert!(rx.is_forced());
}
7733
/// Requests are served normally until shutdown is requested, after which the
/// server reports itself shut down.
#[test]
fn test_server_requests_work_before_shutdown_signal() {
    let server = TestServer::start(make_test_app());

    (0..3).for_each(|i| {
        let response = send_request(
            server.addr(),
            b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(
            response.contains("200 OK"),
            "Request {i} should succeed before shutdown"
        );
    });

    assert_eq!(server.request_count(), 3);

    server.shutdown();
    assert!(server.is_shutdown());
}
7757}