use parking_lot::Mutex;
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use asupersync::Cx;
use crate::context::RequestContext;
use crate::dependency::{DependencyOverrides, FromDependency};
use crate::middleware::Handler;
use crate::request::{Body, Method, Request};
use crate::response::{Response, ResponseBody, StatusCode};
#[derive(Debug, Clone, Default)]
pub struct CookieJar {
cookies: Vec<StoredCookie>,
next_id: u64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CookieSameSite {
Lax,
Strict,
None,
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum CookieDomain {
Any,
HostOnly(String),
Domain(String),
}
#[derive(Debug, Clone)]
struct StoredCookie {
id: u64,
name: String,
value: String,
domain: CookieDomain,
path: String,
secure: bool,
#[allow(dead_code)]
http_only: bool,
#[allow(dead_code)]
same_site: Option<CookieSameSite>,
expires_at: Option<std::time::SystemTime>,
}
impl CookieJar {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn set(&mut self, name: impl Into<String>, value: impl Into<String>) {
let name = name.into();
let value = value.into();
self.cookies
.retain(|c| !(c.name == name && c.domain == CookieDomain::Any));
let id = self.next_id;
self.next_id = self.next_id.wrapping_add(1);
self.cookies.push(StoredCookie {
id,
name,
value,
domain: CookieDomain::Any,
path: "/".to_string(),
secure: false,
http_only: false,
same_site: None,
expires_at: None,
});
}
#[must_use]
pub fn get(&self, name: &str) -> Option<&str> {
self.cookies
.iter()
.filter(|c| c.name == name)
.max_by_key(|c| c.id)
.map(|c| c.value.as_str())
}
pub fn remove(&mut self, name: &str) -> Option<String> {
let mut removed: Option<String> = None;
self.cookies.retain(|c| {
if c.name == name {
removed = Some(c.value.clone());
false
} else {
true
}
});
removed
}
pub fn clear(&mut self) {
self.cookies.clear();
self.next_id = 0;
}
#[must_use]
pub fn len(&self) -> usize {
self.cookies.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.cookies.is_empty()
}
#[must_use]
pub fn to_cookie_header(&self) -> Option<String> {
let mut by_name: HashMap<&str, &StoredCookie> = HashMap::new();
for c in &self.cookies {
match by_name.get(c.name.as_str()) {
Some(existing) if existing.id >= c.id => {}
_ => {
by_name.insert(c.name.as_str(), c);
}
}
}
if by_name.is_empty() {
return None;
}
Some(
by_name
.into_values()
.map(|c| format!("{}={}", c.name, c.value))
.collect::<Vec<_>>()
.join("; "),
)
}
#[must_use]
pub fn cookie_header_for_request(&self, request: &Request) -> Option<String> {
let host = request_host(request);
let path = request.path();
let is_secure = request_is_secure(request);
let now = std::time::SystemTime::now();
let mut selected: HashMap<&str, &StoredCookie> = HashMap::new();
for c in &self.cookies {
if c.secure && !is_secure {
continue;
}
if let Some(exp) = c.expires_at {
if exp <= now {
continue;
}
}
if !domain_matches(&c.domain, host.as_deref()) {
continue;
}
if !path_matches(&c.path, path) {
continue;
}
match selected.get(c.name.as_str()) {
None => {
selected.insert(c.name.as_str(), c);
}
Some(existing) => {
let a = (c.path.len(), c.id);
let b = (existing.path.len(), existing.id);
if a > b {
selected.insert(c.name.as_str(), c);
}
}
}
}
if selected.is_empty() {
return None;
}
Some(
selected
.into_values()
.map(|c| format!("{}={}", c.name, c.value))
.collect::<Vec<_>>()
.join("; "),
)
}
pub fn parse_set_cookie(&mut self, request: &Request, header_value: &[u8]) {
let Ok(value) = std::str::from_utf8(header_value) else {
return;
};
self.parse_set_cookie_str(request, value);
}
fn parse_set_cookie_str(&mut self, request: &Request, value: &str) {
let mut parts = value.split(';');
let Some((name, val)) = parse_cookie_name_value(parts.next()) else {
return;
};
let host = request_host(request);
let attrs = parse_set_cookie_attrs(parts);
let Some(domain) = cookie_domain_for_set_cookie(host.as_deref(), attrs.domain) else {
return;
};
let path = attrs
.path
.unwrap_or_else(|| default_cookie_path(request.path()));
let expires_at = match compute_cookie_expiration(attrs.max_age, attrs.expires_at) {
CookieExpiration::Delete => {
self.remove_by_key(name, &domain, &path);
return;
}
CookieExpiration::Keep(expires_at) => expires_at,
};
let id = self.next_id;
self.next_id = self.next_id.wrapping_add(1);
self.upsert(StoredCookie {
id,
name: name.to_string(),
value: val.to_string(),
domain,
path,
secure: attrs.secure,
http_only: attrs.http_only,
same_site: attrs.same_site,
expires_at,
});
}
fn remove_by_key(&mut self, name: &str, domain: &CookieDomain, path: &str) {
self.cookies
.retain(|c| !(c.name == name && &c.domain == domain && c.path == path));
}
fn upsert(&mut self, cookie: StoredCookie) {
for existing in &mut self.cookies {
if existing.name == cookie.name
&& existing.domain == cookie.domain
&& existing.path == cookie.path
{
*existing = cookie;
return;
}
}
self.cookies.push(cookie);
}
}
#[derive(Debug, Default)]
struct SetCookieAttrs {
domain: Option<String>,
path: Option<String>,
max_age: Option<i64>,
secure: bool,
http_only: bool,
same_site: Option<CookieSameSite>,
expires_at: Option<std::time::SystemTime>,
}
#[derive(Debug, Clone, Copy)]
enum CookieExpiration {
Delete,
Keep(Option<std::time::SystemTime>),
}
fn parse_cookie_name_value(first: Option<&str>) -> Option<(&str, &str)> {
let first = first?;
let (name, val) = first.split_once('=')?;
let name = name.trim();
if name.is_empty() {
return None;
}
Some((name, val.trim()))
}
fn parse_set_cookie_attrs<'a>(parts: impl Iterator<Item = &'a str>) -> SetCookieAttrs {
let mut attrs = SetCookieAttrs::default();
for raw in parts {
let raw = raw.trim();
if raw.is_empty() {
continue;
}
if let Some((k, v)) = raw.split_once('=') {
let k = k.trim().to_ascii_lowercase();
let v = v.trim();
match k.as_str() {
"domain" => {
let mut d = v.trim_matches('"').trim().to_ascii_lowercase();
if let Some(stripped) = d.strip_prefix('.') {
d = stripped.to_string();
}
if !d.is_empty() {
attrs.domain = Some(d);
}
}
"path" => {
let p = v.trim_matches('"').trim();
if p.starts_with('/') {
attrs.path = Some(p.to_string());
}
}
"max-age" => {
if let Ok(n) = v.parse::<i64>() {
attrs.max_age = Some(n);
}
}
"samesite" => {
let ss = v.trim_matches('"').trim();
attrs.same_site = match ss.to_ascii_lowercase().as_str() {
"lax" => Some(CookieSameSite::Lax),
"strict" => Some(CookieSameSite::Strict),
"none" => Some(CookieSameSite::None),
_ => None,
};
}
"expires" => {
if let Some(t) = parse_http_date(v) {
attrs.expires_at = Some(t);
}
}
_ => {}
}
} else {
match raw.to_ascii_lowercase().as_str() {
"secure" => attrs.secure = true,
"httponly" => attrs.http_only = true,
_ => {}
}
}
}
attrs
}
fn cookie_domain_for_set_cookie(
host: Option<&str>,
domain_attr: Option<String>,
) -> Option<CookieDomain> {
match domain_attr {
Some(d) => {
let h = host?;
let domain = CookieDomain::Domain(d);
if domain_matches(&domain, Some(h)) {
Some(domain)
} else {
None
}
}
None => {
let h = host?;
Some(CookieDomain::HostOnly(h.to_string()))
}
}
}
fn compute_cookie_expiration(
max_age: Option<i64>,
mut expires_at: Option<std::time::SystemTime>,
) -> CookieExpiration {
let now = std::time::SystemTime::now();
if let Some(n) = max_age {
if n <= 0 {
return CookieExpiration::Delete;
}
let Ok(secs) = u64::try_from(n) else {
return CookieExpiration::Delete;
};
expires_at = now.checked_add(std::time::Duration::from_secs(secs));
}
if let Some(exp) = expires_at {
if exp <= now {
return CookieExpiration::Delete;
}
}
CookieExpiration::Keep(expires_at)
}
fn request_host(request: &Request) -> Option<String> {
let host = request.headers().get("host")?;
let s = std::str::from_utf8(host).ok()?;
let host = s.trim();
if host.is_empty() {
return None;
}
Some(host.split(':').next().unwrap_or(host).to_ascii_lowercase())
}
fn request_is_secure(request: &Request) -> bool {
if let Some(info) = request.get_extension::<crate::request::ConnectionInfo>() {
if info.is_tls {
return true;
}
}
if let Some(forwarded) = request.headers().get("forwarded") {
if let Ok(s) = std::str::from_utf8(forwarded) {
for entry in s.split(',') {
for param in entry.split(';') {
let param = param.trim();
if let Some((k, v)) = param.split_once('=') {
if k.trim().eq_ignore_ascii_case("proto") {
let proto = v.trim().trim_matches('"');
if proto.eq_ignore_ascii_case("https") {
return true;
}
}
}
}
}
}
}
if let Some(proto) = request.headers().get("x-forwarded-proto") {
let first = proto.split(|&b| b == b',').next().unwrap_or(proto);
let first = trim_ascii_bytes(first);
return first.eq_ignore_ascii_case(b"https");
}
if let Some(ssl) = request.headers().get("x-forwarded-ssl") {
return ssl.eq_ignore_ascii_case(b"on");
}
if let Some(https) = request.headers().get("front-end-https") {
return https.eq_ignore_ascii_case(b"on");
}
false
}
fn trim_ascii_bytes(mut bytes: &[u8]) -> &[u8] {
while matches!(bytes.first(), Some(b' ' | b'\t')) {
bytes = &bytes[1..];
}
while matches!(bytes.last(), Some(b' ' | b'\t')) {
bytes = &bytes[..bytes.len() - 1];
}
bytes
}
fn default_cookie_path(request_path: &str) -> String {
if !request_path.starts_with('/') {
return "/".to_string();
}
if request_path == "/" {
return "/".to_string();
}
match request_path.rfind('/') {
Some(0) | None => "/".to_string(),
Some(idx) => request_path[..idx].to_string(),
}
}
fn domain_matches(domain: &CookieDomain, host: Option<&str>) -> bool {
match domain {
CookieDomain::Any => true,
CookieDomain::HostOnly(d) => host.is_some_and(|h| h.eq_ignore_ascii_case(d)),
CookieDomain::Domain(d) => {
let Some(h) = host else { return false };
if h.eq_ignore_ascii_case(d) {
return true;
}
h.len() > d.len() && h.ends_with(d) && h.as_bytes()[h.len() - d.len() - 1] == b'.'
}
}
}
fn path_matches(cookie_path: &str, request_path: &str) -> bool {
if cookie_path == "/" {
return request_path.starts_with('/');
}
if !request_path.starts_with(cookie_path) {
return false;
}
if cookie_path.ends_with('/') {
return true;
}
request_path
.as_bytes()
.get(cookie_path.len())
.is_none_or(|&b| b == b'/')
}
fn parse_http_date(input: &str) -> Option<std::time::SystemTime> {
let s = input.trim().trim_matches('"').trim();
let (_dow, rest) = s.split_once(',')?;
let rest = rest.trim();
let mut it = rest.split_whitespace();
let day = it.next()?.parse::<u32>().ok()?;
let month = match it.next()? {
"Jan" => 1,
"Feb" => 2,
"Mar" => 3,
"Apr" => 4,
"May" => 5,
"Jun" => 6,
"Jul" => 7,
"Aug" => 8,
"Sep" => 9,
"Oct" => 10,
"Nov" => 11,
"Dec" => 12,
_ => return None,
};
let year = it.next()?.parse::<i32>().ok()?;
let time = it.next()?;
let tz = it.next()?;
if tz != "GMT" {
return None;
}
let (hh, mm, ss) = {
let mut t = time.split(':');
let hh = t.next()?.parse::<u32>().ok()?;
let mm = t.next()?.parse::<u32>().ok()?;
let ss = t.next()?.parse::<u32>().ok()?;
(hh, mm, ss)
};
fn days_from_civil(y: i32, m: u32, d: u32) -> i64 {
let y = i64::from(y) - i64::from(m <= 2);
let era = (if y >= 0 { y } else { y - 399 }) / 400;
let yoe = y - era * 400;
let m = i64::from(m);
let doy = (153 * (m + if m > 2 { -3 } else { 9 }) + 2) / 5 + i64::from(d) - 1;
let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
era * 146097 + doe - 719468
}
let days = days_from_civil(year, month, day);
let secs = days
.checked_mul(86_400)?
.checked_add(i64::from(hh) * 3600 + i64::from(mm) * 60 + i64::from(ss))?;
if secs < 0 {
return None;
}
let secs_u64 = u64::try_from(secs).ok()?;
Some(std::time::UNIX_EPOCH + std::time::Duration::from_secs(secs_u64))
}
pub struct TestClient<H> {
handler: Arc<H>,
cookies: Arc<Mutex<CookieJar>>,
dependency_overrides: Arc<DependencyOverrides>,
seed: Option<u64>,
request_id_counter: Arc<std::sync::atomic::AtomicU64>,
}
impl<H: Handler + 'static> TestClient<H> {
pub fn new(handler: H) -> Self {
let dependency_overrides = handler
.dependency_overrides()
.unwrap_or_else(|| Arc::new(DependencyOverrides::new()));
Self {
handler: Arc::new(handler),
cookies: Arc::new(Mutex::new(CookieJar::new())),
dependency_overrides,
seed: None,
request_id_counter: Arc::new(std::sync::atomic::AtomicU64::new(1)),
}
}
pub fn with_seed(handler: H, seed: u64) -> Self {
let dependency_overrides = handler
.dependency_overrides()
.unwrap_or_else(|| Arc::new(DependencyOverrides::new()));
Self {
handler: Arc::new(handler),
cookies: Arc::new(Mutex::new(CookieJar::new())),
dependency_overrides,
seed: Some(seed),
request_id_counter: Arc::new(std::sync::atomic::AtomicU64::new(1)),
}
}
#[must_use]
pub fn seed(&self) -> Option<u64> {
self.seed
}
pub fn cookies(&self) -> parking_lot::MutexGuard<'_, CookieJar> {
self.cookies.lock()
}
pub fn clear_cookies(&self) {
self.cookies().clear();
}
#[must_use]
pub fn get(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Get, path)
}
#[must_use]
pub fn post(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Post, path)
}
#[must_use]
pub fn put(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Put, path)
}
#[must_use]
pub fn delete(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Delete, path)
}
#[must_use]
pub fn patch(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Patch, path)
}
#[must_use]
pub fn options(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Options, path)
}
#[must_use]
pub fn head(&self, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, Method::Head, path)
}
#[must_use]
pub fn request(&self, method: Method, path: &str) -> RequestBuilder<'_, H> {
RequestBuilder::new(self, method, path)
}
pub fn override_dependency<T, F, Fut>(&self, f: F)
where
T: FromDependency,
F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<T, T::Error>> + Send + 'static,
{
self.dependency_overrides.insert::<T, F, Fut>(f);
}
pub fn override_dependency_value<T>(&self, value: T)
where
T: FromDependency,
{
self.dependency_overrides.insert_value(value);
}
pub fn clear_dependency_overrides(&self) {
self.dependency_overrides.clear();
}
fn next_request_id(&self) -> u64 {
self.request_id_counter
.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
fn execute(&self, mut request: Request) -> TestResponse {
if !request.headers().contains("host") {
request.headers_mut().insert("host", b"testserver".to_vec());
}
{
let jar = self.cookies();
if let Some(cookie_header) = jar.cookie_header_for_request(&request) {
request
.headers_mut()
.insert("cookie", cookie_header.into_bytes());
}
}
let cx = Cx::for_testing();
let request_id = self.next_request_id();
let ctx =
RequestContext::with_overrides(cx, request_id, Arc::clone(&self.dependency_overrides));
let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
{
let mut jar = self.cookies();
for (name, value) in response.headers() {
if name.eq_ignore_ascii_case("set-cookie") {
jar.parse_set_cookie(&request, value);
}
}
}
TestResponse::new(response, request_id)
}
}
impl<H> Clone for TestClient<H> {
fn clone(&self) -> Self {
Self {
handler: Arc::clone(&self.handler),
cookies: Arc::clone(&self.cookies),
dependency_overrides: Arc::clone(&self.dependency_overrides),
seed: self.seed,
request_id_counter: Arc::clone(&self.request_id_counter),
}
}
}
pub struct RequestBuilder<'a, H> {
client: &'a TestClient<H>,
method: Method,
path: String,
query: Option<String>,
headers: Vec<(String, Vec<u8>)>,
body: Body,
}
impl<'a, H: Handler + 'static> RequestBuilder<'a, H> {
fn new(client: &'a TestClient<H>, method: Method, path: &str) -> Self {
let (path, query) = if let Some(idx) = path.find('?') {
(path[..idx].to_string(), Some(path[idx + 1..].to_string()))
} else {
(path.to_string(), None)
};
Self {
client,
method,
path,
query,
headers: Vec::new(),
body: Body::Empty,
}
}
#[must_use]
pub fn query(mut self, key: &str, value: &str) -> Self {
let param = format!("{key}={value}");
self.query = Some(match self.query {
Some(q) => format!("{q}&{param}"),
None => param,
});
self
}
#[must_use]
pub fn header(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
self.headers.push((name.into(), value.into()));
self
}
#[must_use]
pub fn header_str(self, name: impl Into<String>, value: &str) -> Self {
self.header(name, value.as_bytes().to_vec())
}
#[must_use]
pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
self.body = Body::Bytes(body.into());
self
}
#[must_use]
pub fn body_str(self, body: &str) -> Self {
self.body(body.as_bytes().to_vec())
}
#[must_use]
pub fn json<T: serde::Serialize>(mut self, value: &T) -> Self {
let bytes = serde_json::to_vec(value).expect("JSON serialization failed");
self.body = Body::Bytes(bytes);
self.headers
.push(("content-type".to_string(), b"application/json".to_vec()));
self
}
#[must_use]
pub fn cookie(self, name: &str, value: &str) -> Self {
let cookie = format!("{name}={value}");
self.header("cookie", cookie.into_bytes())
}
#[must_use]
pub fn send(self) -> TestResponse {
let mut request = Request::new(self.method, self.path);
request.set_query(self.query);
request.set_body(self.body);
for (name, value) in self.headers {
request.headers_mut().insert(name, value);
}
self.client.execute(request)
}
}
#[derive(Debug)]
pub struct TestResponse {
inner: Response,
request_id: u64,
}
impl TestResponse {
fn new(response: Response, request_id: u64) -> Self {
Self {
inner: response,
request_id,
}
}
#[must_use]
pub fn request_id(&self) -> u64 {
self.request_id
}
#[must_use]
pub fn status(&self) -> StatusCode {
self.inner.status()
}
#[must_use]
pub fn status_code(&self) -> u16 {
self.inner.status().as_u16()
}
#[must_use]
pub fn is_success(&self) -> bool {
let code = self.status_code();
(200..300).contains(&code)
}
#[must_use]
pub fn is_redirect(&self) -> bool {
let code = self.status_code();
(300..400).contains(&code)
}
#[must_use]
pub fn is_client_error(&self) -> bool {
let code = self.status_code();
(400..500).contains(&code)
}
#[must_use]
pub fn is_server_error(&self) -> bool {
let code = self.status_code();
(500..600).contains(&code)
}
#[must_use]
pub fn headers(&self) -> &[(String, Vec<u8>)] {
self.inner.headers()
}
#[must_use]
pub fn header(&self, name: &str) -> Option<&[u8]> {
let name_lower = name.to_ascii_lowercase();
self.inner
.headers()
.iter()
.find(|(n, _)| n.to_ascii_lowercase() == name_lower)
.map(|(_, v)| v.as_slice())
}
#[must_use]
pub fn header_str(&self, name: &str) -> Option<&str> {
self.header(name).and_then(|v| std::str::from_utf8(v).ok())
}
#[must_use]
pub fn content_type(&self) -> Option<&str> {
self.header_str("content-type")
}
#[must_use]
pub fn bytes(&self) -> &[u8] {
match self.inner.body_ref() {
ResponseBody::Empty => &[],
ResponseBody::Bytes(b) => b,
ResponseBody::Stream(_) => {
panic!("streaming response body not supported in TestResponse")
}
}
}
#[must_use]
pub fn text(&self) -> &str {
std::str::from_utf8(self.bytes()).expect("response body is not valid UTF-8")
}
#[must_use]
pub fn text_opt(&self) -> Option<&str> {
std::str::from_utf8(self.bytes()).ok()
}
pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
serde_json::from_slice(self.bytes())
}
#[must_use]
pub fn content_length(&self) -> usize {
self.bytes().len()
}
#[must_use]
pub fn into_inner(self) -> Response {
self.inner
}
#[must_use]
pub fn assert_status(&self, expected: StatusCode) -> &Self {
assert_eq!(
self.status(),
expected,
"Expected status {}, got {} for request {}",
expected.as_u16(),
self.status_code(),
self.request_id
);
self
}
#[must_use]
pub fn assert_status_code(&self, expected: u16) -> &Self {
assert_eq!(
self.status_code(),
expected,
"Expected status {expected}, got {} for request {}",
self.status_code(),
self.request_id
);
self
}
#[must_use]
pub fn assert_success(&self) -> &Self {
assert!(
self.is_success(),
"Expected success status, got {} for request {}",
self.status_code(),
self.request_id
);
self
}
#[must_use]
pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
let actual = self.header_str(name);
assert_eq!(
actual,
Some(expected),
"Expected header '{name}' to be '{expected}', got {:?} for request {}",
actual,
self.request_id
);
self
}
#[must_use]
pub fn assert_text(&self, expected: &str) -> &Self {
assert_eq!(
self.text(),
expected,
"Body mismatch for request {}",
self.request_id
);
self
}
#[must_use]
pub fn assert_text_contains(&self, expected: &str) -> &Self {
assert!(
self.text().contains(expected),
"Expected body to contain '{}', got '{}' for request {}",
expected,
self.text(),
self.request_id
);
self
}
#[must_use]
pub fn assert_json<T>(&self, expected: &T) -> &Self
where
T: serde::de::DeserializeOwned + serde::Serialize + PartialEq + std::fmt::Debug,
{
let actual: T = self.json().expect("Failed to parse response as JSON");
assert_eq!(
actual, *expected,
"JSON body mismatch for request {}",
self.request_id
);
self
}
#[must_use]
pub fn assert_json_contains(&self, expected: &serde_json::Value) -> &Self {
let actual: serde_json::Value = self.json().expect("Failed to parse response as JSON");
if let Err(path) = json_contains(&actual, expected) {
panic!(
"JSON partial match failed at path '{}' for request {}\n\
Expected (partial):\n{}\n\
Actual:\n{}",
path,
self.request_id,
serde_json::to_string_pretty(expected).unwrap_or_else(|_| format!("{expected:?}")),
serde_json::to_string_pretty(&actual).unwrap_or_else(|_| format!("{actual:?}")),
);
}
self
}
#[cfg(feature = "regex")]
#[must_use]
pub fn assert_body_matches(&self, pattern: &str) -> &Self {
let re = regex::Regex::new(pattern)
.unwrap_or_else(|e| panic!("Invalid regex pattern '{pattern}': {e}"));
let body = self.text();
assert!(
re.is_match(body),
"Expected body to match pattern '{}', got '{}' for request {}",
pattern,
body,
self.request_id
);
self
}
#[cfg(feature = "regex")]
#[must_use]
pub fn assert_header_matches(&self, name: &str, pattern: &str) -> &Self {
let re = regex::Regex::new(pattern)
.unwrap_or_else(|e| panic!("Invalid regex pattern '{pattern}': {e}"));
let value = self
.header_str(name)
.unwrap_or_else(|| panic!("Header '{name}' not found for request {}", self.request_id));
assert!(
re.is_match(value),
"Expected header '{}' to match pattern '{}', got '{}' for request {}",
name,
pattern,
value,
self.request_id
);
self
}
#[must_use]
pub fn assert_header_exists(&self, name: &str) -> &Self {
assert!(
self.header(name).is_some(),
"Expected header '{}' to exist for request {}",
name,
self.request_id
);
self
}
#[must_use]
pub fn assert_header_missing(&self, name: &str) -> &Self {
assert!(
self.header(name).is_none(),
"Expected header '{}' to not exist for request {}, but found {:?}",
name,
self.request_id,
self.header_str(name)
);
self
}
#[must_use]
pub fn assert_content_type_contains(&self, expected: &str) -> &Self {
let ct = self.content_type().unwrap_or_else(|| {
panic!(
"Content-Type header not found for request {}",
self.request_id
)
});
assert!(
ct.contains(expected),
"Expected Content-Type to contain '{}', got '{}' for request {}",
expected,
ct,
self.request_id
);
self
}
}
pub fn json_contains(
actual: &serde_json::Value,
expected: &serde_json::Value,
) -> Result<(), String> {
json_contains_at_path(actual, expected, "$")
}
fn json_contains_at_path(
actual: &serde_json::Value,
expected: &serde_json::Value,
path: &str,
) -> Result<(), String> {
use serde_json::Value;
match (actual, expected) {
(Value::Object(actual_obj), Value::Object(expected_obj)) => {
for (key, expected_val) in expected_obj {
let child_path = format!("{path}.{key}");
match actual_obj.get(key) {
Some(actual_val) => {
json_contains_at_path(actual_val, expected_val, &child_path)?;
}
None => {
return Err(child_path);
}
}
}
Ok(())
}
(Value::Array(actual_arr), Value::Array(expected_arr)) => {
if actual_arr.len() != expected_arr.len() {
return Err(format!("{path}[length]"));
}
for (i, (actual_elem, expected_elem)) in
actual_arr.iter().zip(expected_arr.iter()).enumerate()
{
let child_path = format!("{path}[{i}]");
json_contains_at_path(actual_elem, expected_elem, &child_path)?;
}
Ok(())
}
_ => {
if actual == expected {
Ok(())
} else {
Err(path.to_string())
}
}
}
}
pub trait IntoStatusU16 {
fn into_status_u16(self) -> u16;
}
impl IntoStatusU16 for u16 {
fn into_status_u16(self) -> u16 {
self
}
}
impl IntoStatusU16 for StatusCode {
fn into_status_u16(self) -> u16 {
self.as_u16()
}
}
impl IntoStatusU16 for i32 {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn into_status_u16(self) -> u16 {
self as u16
}
}
#[macro_export]
macro_rules! assert_status {
($response:expr, $expected:expr) => {{
let response = &$response;
let actual = response.status_code();
let expected_code: u16 = $crate::testing::IntoStatusU16::into_status_u16($expected);
if actual != expected_code {
panic!(
"assertion failed: `(response.status() == {})`\n\
expected status: {}\n\
actual status: {}\n\
request id: {}\n\
response body: {}",
expected_code,
expected_code,
actual,
response.request_id(),
response.text_opt().unwrap_or("<non-UTF8 body>")
);
}
}};
($response:expr, $expected:expr, $($msg:tt)+) => {{
let response = &$response;
let actual = response.status_code();
let expected_code: u16 = $crate::testing::IntoStatusU16::into_status_u16($expected);
if actual != expected_code {
panic!(
"{}\n\
assertion failed: `(response.status() == {})`\n\
expected status: {}\n\
actual status: {}\n\
request id: {}\n\
response body: {}",
format_args!($($msg)+),
expected_code,
expected_code,
actual,
response.request_id(),
response.text_opt().unwrap_or("<non-UTF8 body>")
);
}
}};
}
#[macro_export]
macro_rules! assert_header {
($response:expr, $name:expr, $expected:expr) => {{
let response = &$response;
let name = $name;
let expected = $expected;
let actual = response.header_str(name);
if actual != Some(expected) {
panic!(
"assertion failed: `(response.header(\"{}\") == \"{}\")`\n\
expected header '{}': \"{}\"\n\
actual header '{}': {:?}\n\
request id: {}",
name,
expected,
name,
expected,
name,
actual,
response.request_id()
);
}
}};
($response:expr, $name:expr, $expected:expr, $($msg:tt)+) => {{
let response = &$response;
let name = $name;
let expected = $expected;
let actual = response.header_str(name);
if actual != Some(expected) {
panic!(
"{}\n\
assertion failed: `(response.header(\"{}\") == \"{}\")`\n\
expected header '{}': \"{}\"\n\
actual header '{}': {:?}\n\
request id: {}",
format_args!($($msg)+),
name,
expected,
name,
expected,
name,
actual,
response.request_id()
);
}
}};
}
#[macro_export]
macro_rules! assert_body_contains {
($response:expr, $expected:expr) => {{
let response = &$response;
let expected = $expected;
let body = response.text();
if !body.contains(expected) {
panic!(
"assertion failed: response body does not contain \"{}\"\n\
expected substring: \"{}\"\n\
actual body: \"{}\"\n\
request id: {}",
expected, expected, body, response.request_id()
);
}
}};
($response:expr, $expected:expr, $($msg:tt)+) => {{
let response = &$response;
let expected = $expected;
let body = response.text();
if !body.contains(expected) {
panic!(
"{}\n\
assertion failed: response body does not contain \"{}\"\n\
expected substring: \"{}\"\n\
actual body: \"{}\"\n\
request id: {}",
format_args!($($msg)+),
expected,
expected,
body,
response.request_id()
);
}
}};
}
#[macro_export]
macro_rules! assert_json {
($response:expr, $expected:tt) => {{
let response = &$response;
let expected = serde_json::json!($expected);
let actual: serde_json::Value = response
.json()
.expect("Failed to parse response body as JSON");
if let Err(path) = $crate::testing::json_contains(&actual, &expected) {
panic!(
"assertion failed: JSON partial match failed at path '{}'\n\
expected (partial):\n{}\n\
actual:\n{}\n\
request id: {}",
path,
serde_json::to_string_pretty(&expected).unwrap_or_else(|_| format!("{:?}", expected)),
serde_json::to_string_pretty(&actual).unwrap_or_else(|_| format!("{:?}", actual)),
response.request_id()
);
}
}};
($response:expr, $expected:tt, $($msg:tt)+) => {{
let response = &$response;
let expected = serde_json::json!($expected);
let actual: serde_json::Value = response
.json()
.expect("Failed to parse response body as JSON");
if let Err(path) = $crate::testing::json_contains(&actual, &expected) {
panic!(
"{}\n\
assertion failed: JSON partial match failed at path '{}'\n\
expected (partial):\n{}\n\
actual:\n{}\n\
request id: {}",
format_args!($($msg)+),
path,
serde_json::to_string_pretty(&expected).unwrap_or_else(|_| format!("{:?}", expected)),
serde_json::to_string_pretty(&actual).unwrap_or_else(|_| format!("{:?}", actual)),
response.request_id()
);
}
}};
}
#[cfg(feature = "regex")]
#[macro_export]
macro_rules! assert_body_matches {
($response:expr, $pattern:expr) => {{
let response = &$response;
let pattern = $pattern;
let re = regex::Regex::new(pattern)
.unwrap_or_else(|e| panic!("Invalid regex pattern '{}': {}", pattern, e));
let body = response.text();
if !re.is_match(body) {
panic!(
"assertion failed: response body does not match pattern\n\
pattern: \"{}\"\n\
actual body: \"{}\"\n\
request id: {}",
pattern, body, response.request_id()
);
}
}};
($response:expr, $pattern:expr, $($msg:tt)+) => {{
let response = &$response;
let pattern = $pattern;
let re = regex::Regex::new(pattern)
.unwrap_or_else(|e| panic!("Invalid regex pattern '{}': {}", pattern, e));
let body = response.text();
if !re.is_match(body) {
panic!(
"{}\n\
assertion failed: response body does not match pattern\n\
pattern: \"{}\"\n\
actual body: \"{}\"\n\
request id: {}",
format_args!($($msg)+),
pattern,
body,
response.request_id()
);
}
}};
}
#[derive(Debug, Clone)]
pub struct LabTestConfig {
pub seed: u64,
pub chaos_enabled: bool,
pub chaos_intensity: f64,
pub max_steps: Option<u64>,
pub capture_traces: bool,
}
impl Default for LabTestConfig {
fn default() -> Self {
Self {
seed: 42,
chaos_enabled: false,
chaos_intensity: 0.0,
max_steps: Some(10_000),
capture_traces: false,
}
}
}
impl LabTestConfig {
#[must_use]
pub fn new(seed: u64) -> Self {
Self {
seed,
..Default::default()
}
}
#[must_use]
pub fn with_light_chaos(mut self) -> Self {
self.chaos_enabled = true;
self.chaos_intensity = 0.05;
self
}
#[must_use]
pub fn with_heavy_chaos(mut self) -> Self {
self.chaos_enabled = true;
self.chaos_intensity = 0.2;
self
}
#[must_use]
pub fn with_chaos_intensity(mut self, intensity: f64) -> Self {
self.chaos_enabled = intensity > 0.0;
self.chaos_intensity = intensity.clamp(0.0, 1.0);
self
}
#[must_use]
pub fn with_max_steps(mut self, max: u64) -> Self {
self.max_steps = Some(max);
self
}
#[must_use]
pub fn without_step_limit(mut self) -> Self {
self.max_steps = None;
self
}
#[must_use]
pub fn with_traces(mut self) -> Self {
self.capture_traces = true;
self
}
}
#[derive(Debug, Clone, Default)]
pub struct TestChaosStats {
pub decision_points: u64,
pub delays_injected: u64,
pub cancellations_injected: u64,
pub steps_executed: u64,
}
impl TestChaosStats {
#[must_use]
#[allow(clippy::cast_precision_loss)] pub fn injection_rate(&self) -> f64 {
if self.decision_points == 0 {
0.0
} else {
(self.delays_injected + self.cancellations_injected) as f64
/ self.decision_points as f64
}
}
#[must_use]
pub fn had_chaos(&self) -> bool {
self.delays_injected > 0 || self.cancellations_injected > 0
}
}
#[derive(Debug, Clone)]
pub struct MockTime {
current_us: Arc<std::sync::atomic::AtomicU64>,
}
impl Default for MockTime {
fn default() -> Self {
Self::new()
}
}
impl MockTime {
#[must_use]
pub fn new() -> Self {
Self {
current_us: Arc::new(std::sync::atomic::AtomicU64::new(0)),
}
}
#[must_use]
pub fn starting_at(initial: std::time::Duration) -> Self {
Self {
current_us: Arc::new(std::sync::atomic::AtomicU64::new(initial.as_micros() as u64)),
}
}
#[must_use]
pub fn now(&self) -> std::time::Duration {
std::time::Duration::from_micros(self.current_us.load(std::sync::atomic::Ordering::Relaxed))
}
#[must_use]
pub fn elapsed(&self) -> std::time::Duration {
self.now()
}
pub fn advance(&self, duration: std::time::Duration) {
self.current_us.fetch_add(
duration.as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
pub fn set(&self, time: std::time::Duration) {
self.current_us.store(
time.as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
pub fn reset(&self) {
self.current_us
.store(0, std::sync::atomic::Ordering::Relaxed);
}
}
#[derive(Debug)]
pub struct CancellationTestResult {
pub completed: bool,
pub cancelled_at_checkpoint: bool,
pub response: Option<Response>,
pub cancellation_point: Option<String>,
}
impl CancellationTestResult {
#[must_use]
pub fn gracefully_cancelled(&self) -> bool {
!self.completed && self.cancelled_at_checkpoint
}
#[must_use]
pub fn completed_despite_cancel(&self) -> bool {
self.completed
}
}
pub struct CancellationTest<H> {
handler: H,
seed: u64,
}
impl<H: Handler + 'static> CancellationTest<H> {
#[must_use]
pub fn new(handler: H) -> Self {
Self { handler, seed: 42 }
}
#[must_use]
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
pub fn test_respects_cancellation(&self) -> CancellationTestResult {
let cx = asupersync::Cx::for_testing();
let ctx = RequestContext::new(cx, 1);
ctx.cx().set_cancel_requested(true);
let mut request = Request::new(Method::Get, "/test");
let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
let is_cancelled_response = response.status().as_u16() == 499
|| response.status().as_u16() == 504
|| response.status().as_u16() == 503;
CancellationTestResult {
completed: true,
cancelled_at_checkpoint: is_cancelled_response,
response: Some(response),
cancellation_point: None,
}
}
pub fn complete_normally(&self) -> CancellationTestResult {
let cx = asupersync::Cx::for_testing();
let ctx = RequestContext::new(cx, 1);
let mut request = Request::new(Method::Get, "/test");
let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
CancellationTestResult {
completed: true,
cancelled_at_checkpoint: false,
response: Some(response),
cancellation_point: None,
}
}
pub fn test_with_cancel_callback<F>(
&self,
path: &str,
mut cancel_fn: F,
) -> CancellationTestResult
where
F: FnMut(&RequestContext) -> bool,
{
let cx = asupersync::Cx::for_testing();
let ctx = RequestContext::new(cx, 1);
if cancel_fn(&ctx) {
ctx.cx().set_cancel_requested(true);
}
let mut request = Request::new(Method::Get, path);
let response = futures_executor::block_on(self.handler.call(&ctx, &mut request));
let is_cancelled = ctx.is_cancelled();
let is_cancelled_response =
response.status().as_u16() == 499 || response.status().as_u16() == 504;
CancellationTestResult {
completed: true,
cancelled_at_checkpoint: is_cancelled && is_cancelled_response,
response: Some(response),
cancellation_point: None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::app::App;
use crate::dependency::{Depends, FromDependency};
use crate::error::HttpError;
use crate::extract::FromRequest;
use crate::middleware::BoxFuture;
struct EchoHandler;
impl Handler for EchoHandler {
fn call<'a>(
&'a self,
_ctx: &'a RequestContext,
req: &'a mut Request,
) -> BoxFuture<'a, Response> {
let method = format!("{:?}", req.method());
let path = req.path().to_string();
let body = format!("Method: {method}, Path: {path}");
Box::pin(async move {
Response::ok()
.header("content-type", b"text/plain".to_vec())
.body(ResponseBody::Bytes(body.into_bytes()))
})
}
}
struct CookieHandler;
impl Handler for CookieHandler {
fn call<'a>(
&'a self,
_ctx: &'a RequestContext,
_req: &'a mut Request,
) -> BoxFuture<'a, Response> {
Box::pin(async move {
Response::ok()
.header("set-cookie", b"session=abc123".to_vec())
.body(ResponseBody::Bytes(b"Cookie set".to_vec()))
})
}
}
struct CookieEchoHandler;
impl Handler for CookieEchoHandler {
fn call<'a>(
&'a self,
_ctx: &'a RequestContext,
req: &'a mut Request,
) -> BoxFuture<'a, Response> {
let cookie = req.headers().get("cookie").map_or_else(
|| "no cookies".to_string(),
|v| String::from_utf8_lossy(v).to_string(),
);
Box::pin(async move { Response::ok().body(ResponseBody::Bytes(cookie.into_bytes())) })
}
}
#[derive(Clone)]
struct OverrideDep {
value: usize,
}
impl FromDependency for OverrideDep {
type Error = HttpError;
async fn from_dependency(
_ctx: &RequestContext,
_req: &mut Request,
) -> Result<Self, Self::Error> {
Ok(Self { value: 1 })
}
}
struct OverrideDepHandler;
impl Handler for OverrideDepHandler {
fn call<'a>(
&'a self,
ctx: &'a RequestContext,
req: &'a mut Request,
) -> BoxFuture<'a, Response> {
Box::pin(async move {
let dep = Depends::<OverrideDep>::from_request(ctx, req)
.await
.expect("dependency extraction failed");
Response::ok().body(ResponseBody::Bytes(dep.value.to_string().into_bytes()))
})
}
}
fn override_dep_route(ctx: &RequestContext, req: &mut Request) -> std::future::Ready<Response> {
let dep = futures_executor::block_on(Depends::<OverrideDep>::from_request(ctx, req))
.expect("dependency extraction failed");
std::future::ready(
Response::ok().body(ResponseBody::Bytes(dep.value.to_string().into_bytes())),
)
}
#[test]
fn test_client_get() {
let client = TestClient::new(EchoHandler);
let response = client.get("/test/path").send();
assert_eq!(response.status_code(), 200);
assert_eq!(response.text(), "Method: Get, Path: /test/path");
}
#[test]
fn test_client_post() {
let client = TestClient::new(EchoHandler);
let response = client.post("/api/items").send();
assert_eq!(response.status_code(), 200);
assert!(response.text().contains("Method: Post"));
}
#[test]
#[ignore = "nested executor issue: override_dep_route uses block_on inside TestClient's block_on"]
fn test_app_dependency_override_used_by_test_client() {
let app = App::builder()
.route("/", Method::Get, override_dep_route)
.build();
app.override_dependency_value(OverrideDep { value: 42 });
let client = TestClient::new(app);
let response = client.get("/").send();
assert_eq!(response.text(), "42");
}
#[test]
fn test_test_client_override_clear() {
let client = TestClient::new(OverrideDepHandler);
client.override_dependency_value(OverrideDep { value: 9 });
let response = client.get("/").send();
assert_eq!(response.text(), "9");
client.clear_dependency_overrides();
let response = client.get("/").send();
assert_eq!(response.text(), "1");
}
#[test]
fn test_client_all_methods() {
let client = TestClient::new(EchoHandler);
assert!(client.get("/").send().text().contains("Get"));
assert!(client.post("/").send().text().contains("Post"));
assert!(client.put("/").send().text().contains("Put"));
assert!(client.delete("/").send().text().contains("Delete"));
assert!(client.patch("/").send().text().contains("Patch"));
assert!(client.options("/").send().text().contains("Options"));
assert!(client.head("/").send().text().contains("Head"));
}
#[test]
fn test_query_params() {
let client = TestClient::new(EchoHandler);
let response = client
.get("/search")
.query("q", "rust")
.query("limit", "10")
.send();
assert_eq!(response.status_code(), 200);
}
#[test]
fn test_response_assertions() {
let client = TestClient::new(EchoHandler);
let response = client.get("/test").send();
let _ = response
.assert_status_code(200)
.assert_success()
.assert_header("content-type", "text/plain")
.assert_text_contains("Get");
}
#[test]
fn test_response_status_checks() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
assert!(response.is_success());
assert!(!response.is_redirect());
assert!(!response.is_client_error());
assert!(!response.is_server_error());
}
#[test]
fn test_cookie_jar() {
let mut jar = CookieJar::new();
assert!(jar.is_empty());
jar.set("session", "abc123");
jar.set("user", "alice");
assert_eq!(jar.len(), 2);
assert_eq!(jar.get("session"), Some("abc123"));
assert_eq!(jar.get("user"), Some("alice"));
let header = jar.to_cookie_header().unwrap();
assert!(header.contains("session=abc123"));
assert!(header.contains("user=alice"));
jar.remove("session");
assert_eq!(jar.len(), 1);
assert_eq!(jar.get("session"), None);
}
#[test]
fn test_cookie_jar_request_matching_rules() {
use crate::request::ConnectionInfo;
let mut jar = CookieJar::new();
let mut req = Request::new(Method::Get, "/account/settings");
req.headers_mut().insert("host", b"example.com".to_vec());
jar.parse_set_cookie(
&req,
b"sid=1; Path=/account; Secure; HttpOnly; SameSite=Lax",
);
assert_eq!(jar.cookie_header_for_request(&req), None);
req.insert_extension(ConnectionInfo::HTTPS);
assert_eq!(
jar.cookie_header_for_request(&req).as_deref(),
Some("sid=1")
);
let mut req2 = Request::new(Method::Get, "/other");
req2.headers_mut().insert("host", b"example.com".to_vec());
req2.insert_extension(ConnectionInfo::HTTPS);
assert_eq!(jar.cookie_header_for_request(&req2), None);
jar.parse_set_cookie(&req, b"sub=1; Domain=example.com; Path=/");
let mut req3 = Request::new(Method::Get, "/");
req3.headers_mut()
.insert("host", b"api.example.com".to_vec());
let hdr = jar.cookie_header_for_request(&req3).expect("cookie header");
assert!(hdr.contains("sub=1"));
}
#[test]
fn test_cookie_persistence() {
let client = TestClient::new(CookieHandler);
let _ = client.get("/set-cookie").send();
assert_eq!(client.cookies().get("session"), Some("abc123"));
let client2 = TestClient::new(CookieEchoHandler);
client2.cookies().set("session", "abc123");
let response = client2.get("/check-cookie").send();
assert!(response.text().contains("session=abc123"));
}
#[test]
fn test_request_id_increments() {
let client = TestClient::new(EchoHandler);
let r1 = client.get("/").send();
let r2 = client.get("/").send();
let r3 = client.get("/").send();
assert_eq!(r1.request_id(), 1);
assert_eq!(r2.request_id(), 2);
assert_eq!(r3.request_id(), 3);
}
#[test]
fn test_client_with_seed() {
let client = TestClient::with_seed(EchoHandler, 42);
assert_eq!(client.seed(), Some(42));
}
#[test]
fn test_client_clone() {
let client = TestClient::new(EchoHandler);
client.cookies().set("test", "value");
let cloned = client.clone();
assert_eq!(cloned.cookies().get("test"), Some("value"));
let r1 = client.get("/").send();
let r2 = cloned.get("/").send();
assert_eq!(r1.request_id(), 1);
assert_eq!(r2.request_id(), 2);
}
#[test]
fn test_json_contains_exact_match() {
let actual = serde_json::json!({"id": 1, "name": "Alice"});
let expected = serde_json::json!({"id": 1, "name": "Alice"});
assert!(json_contains(&actual, &expected).is_ok());
}
#[test]
fn test_json_contains_partial_match() {
let actual = serde_json::json!({"id": 1, "name": "Alice", "email": "alice@example.com"});
let expected = serde_json::json!({"name": "Alice"});
assert!(json_contains(&actual, &expected).is_ok());
}
#[test]
fn test_json_contains_nested_partial_match() {
let actual = serde_json::json!({
"user": {"id": 1, "name": "Alice", "email": "alice@example.com"},
"status": "active"
});
let expected = serde_json::json!({
"user": {"name": "Alice"}
});
assert!(json_contains(&actual, &expected).is_ok());
}
#[test]
fn test_json_contains_mismatch_value() {
let actual = serde_json::json!({"id": 1, "name": "Alice"});
let expected = serde_json::json!({"name": "Bob"});
let result = json_contains(&actual, &expected);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "$.name");
}
#[test]
fn test_json_contains_missing_key() {
let actual = serde_json::json!({"id": 1, "name": "Alice"});
let expected = serde_json::json!({"email": "alice@example.com"});
let result = json_contains(&actual, &expected);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "$.email");
}
#[test]
fn test_json_contains_array_exact_match() {
let actual = serde_json::json!({"items": [1, 2, 3]});
let expected = serde_json::json!({"items": [1, 2, 3]});
assert!(json_contains(&actual, &expected).is_ok());
}
#[test]
fn test_json_contains_array_length_mismatch() {
let actual = serde_json::json!({"items": [1, 2, 3]});
let expected = serde_json::json!({"items": [1, 2]});
let result = json_contains(&actual, &expected);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "$.items[length]");
}
#[test]
fn test_json_contains_array_element_mismatch() {
let actual = serde_json::json!({"items": [1, 2, 3]});
let expected = serde_json::json!({"items": [1, 5, 3]});
let result = json_contains(&actual, &expected);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "$.items[1]");
}
#[test]
fn test_json_contains_primitives() {
assert!(json_contains(&serde_json::json!(42), &serde_json::json!(42)).is_ok());
assert!(json_contains(&serde_json::json!(42), &serde_json::json!(43)).is_err());
assert!(json_contains(&serde_json::json!("hello"), &serde_json::json!("hello")).is_ok());
assert!(json_contains(&serde_json::json!("hello"), &serde_json::json!("world")).is_err());
assert!(json_contains(&serde_json::json!(true), &serde_json::json!(true)).is_ok());
assert!(json_contains(&serde_json::json!(true), &serde_json::json!(false)).is_err());
assert!(json_contains(&serde_json::json!(null), &serde_json::json!(null)).is_ok());
}
#[test]
fn test_json_contains_type_mismatch() {
let actual = serde_json::json!({"id": "1"});
let expected = serde_json::json!({"id": 1});
let result = json_contains(&actual, &expected);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "$.id");
}
#[test]
fn test_json_contains_deeply_nested() {
let actual = serde_json::json!({
"level1": {
"level2": {
"level3": {
"value": 42,
"extra": "ignored"
}
}
}
});
let expected = serde_json::json!({
"level1": {
"level2": {
"level3": {
"value": 42
}
}
}
});
assert!(json_contains(&actual, &expected).is_ok());
}
struct JsonHandler;
impl Handler for JsonHandler {
fn call<'a>(
&'a self,
_ctx: &'a RequestContext,
_req: &'a mut Request,
) -> BoxFuture<'a, Response> {
let json = serde_json::json!({
"id": 1,
"name": "Alice",
"email": "alice@example.com",
"active": true
});
let body = serde_json::to_vec(&json).unwrap();
Box::pin(async move {
Response::ok()
.header("content-type", b"application/json".to_vec())
.header("x-request-id", b"req-123".to_vec())
.body(ResponseBody::Bytes(body))
})
}
}
#[allow(dead_code)]
struct StatusHandler(u16);
#[allow(dead_code)]
impl Handler for StatusHandler {
fn call<'a>(
&'a self,
_ctx: &'a RequestContext,
_req: &'a mut Request,
) -> BoxFuture<'a, Response> {
let status = StatusCode::from_u16(self.0);
Box::pin(async move { Response::with_status(status) })
}
}
#[test]
fn test_assert_status_macro_with_u16() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
crate::assert_status!(response, 200);
}
#[test]
fn test_assert_status_macro_with_status_code() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
crate::assert_status!(response, StatusCode::OK);
}
#[test]
#[should_panic(expected = "assertion failed")]
fn test_assert_status_macro_failure() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
crate::assert_status!(response, 404);
}
#[test]
fn test_assert_header_macro() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
crate::assert_header!(response, "content-type", "text/plain");
}
#[test]
#[should_panic(expected = "assertion failed")]
fn test_assert_header_macro_failure() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
crate::assert_header!(response, "content-type", "application/json");
}
#[test]
fn test_assert_body_contains_macro() {
let client = TestClient::new(EchoHandler);
let response = client.get("/test").send();
crate::assert_body_contains!(response, "Method: Get");
crate::assert_body_contains!(response, "Path: /test");
}
#[test]
#[should_panic(expected = "assertion failed")]
fn test_assert_body_contains_macro_failure() {
let client = TestClient::new(EchoHandler);
let response = client.get("/test").send();
crate::assert_body_contains!(response, "nonexistent");
}
#[test]
fn test_assert_json_macro_partial_match() {
let client = TestClient::new(JsonHandler);
let response = client.get("/user").send();
crate::assert_json!(response, {"name": "Alice"});
crate::assert_json!(response, {"id": 1, "active": true});
}
#[test]
fn test_assert_json_macro_exact_match() {
let client = TestClient::new(JsonHandler);
let response = client.get("/user").send();
crate::assert_json!(response, {
"id": 1,
"name": "Alice",
"email": "alice@example.com",
"active": true
});
}
#[test]
#[should_panic(expected = "JSON partial match failed")]
fn test_assert_json_macro_failure() {
let client = TestClient::new(JsonHandler);
let response = client.get("/user").send();
crate::assert_json!(response, {"name": "Bob"});
}
#[test]
fn test_assert_json_contains_method() {
let client = TestClient::new(JsonHandler);
let response = client.get("/user").send();
let _ = response.assert_json_contains(&serde_json::json!({"name": "Alice"}));
}
#[test]
fn test_assert_header_exists() {
let client = TestClient::new(JsonHandler);
let response = client.get("/").send();
let _ = response
.assert_header_exists("content-type")
.assert_header_exists("x-request-id");
}
#[test]
#[should_panic(expected = "Expected header 'nonexistent' to exist")]
fn test_assert_header_exists_failure() {
let client = TestClient::new(JsonHandler);
let response = client.get("/").send();
let _ = response.assert_header_exists("nonexistent");
}
#[test]
fn test_assert_header_missing() {
let client = TestClient::new(JsonHandler);
let response = client.get("/").send();
let _ = response.assert_header_missing("x-nonexistent");
}
#[test]
#[should_panic(expected = "Expected header 'content-type' to not exist")]
fn test_assert_header_missing_failure() {
let client = TestClient::new(JsonHandler);
let response = client.get("/").send();
let _ = response.assert_header_missing("content-type");
}
#[test]
fn test_assert_content_type_contains() {
let client = TestClient::new(JsonHandler);
let response = client.get("/").send();
let _ = response.assert_content_type_contains("application/json");
let _ = response.assert_content_type_contains("json");
}
#[test]
#[should_panic(expected = "Expected Content-Type to contain")]
fn test_assert_content_type_contains_failure() {
let client = TestClient::new(JsonHandler);
let response = client.get("/").send();
let _ = response.assert_content_type_contains("text/html");
}
#[test]
fn test_assertion_chaining() {
let client = TestClient::new(JsonHandler);
let response = client.get("/user").send();
let _ = response
.assert_status_code(200)
.assert_success()
.assert_header_exists("content-type")
.assert_content_type_contains("json")
.assert_json_contains(&serde_json::json!({"name": "Alice"}));
}
#[test]
fn test_macro_with_custom_message() {
let client = TestClient::new(EchoHandler);
let response = client.get("/").send();
crate::assert_status!(response, 200, "Expected 200 OK from echo handler");
crate::assert_header!(
response,
"content-type",
"text/plain",
"Should have text content type"
);
crate::assert_body_contains!(response, "Get", "Should contain HTTP method");
}
#[derive(Clone)]
struct DatabasePool {
connection_string: String,
}
impl FromDependency for DatabasePool {
type Error = HttpError;
async fn from_dependency(
_ctx: &RequestContext,
_req: &mut Request,
) -> Result<Self, Self::Error> {
Ok(DatabasePool {
connection_string: "postgres://localhost/test".to_string(),
})
}
}
#[derive(Clone)]
struct UserRepository {
pool_conn_str: String,
}
impl FromDependency for UserRepository {
type Error = HttpError;
async fn from_dependency(
ctx: &RequestContext,
req: &mut Request,
) -> Result<Self, Self::Error> {
let pool = Depends::<DatabasePool>::from_request(ctx, req).await?;
Ok(UserRepository {
pool_conn_str: pool.connection_string.clone(),
})
}
}
#[derive(Clone)]
struct AuthService {
user_repo_pool: String,
}
impl FromDependency for AuthService {
type Error = HttpError;
async fn from_dependency(
ctx: &RequestContext,
req: &mut Request,
) -> Result<Self, Self::Error> {
let repo = Depends::<UserRepository>::from_request(ctx, req).await?;
Ok(AuthService {
user_repo_pool: repo.pool_conn_str.clone(),
})
}
}
struct ComplexDepHandler;
impl Handler for ComplexDepHandler {
fn call<'a>(
&'a self,
ctx: &'a RequestContext,
req: &'a mut Request,
) -> BoxFuture<'a, Response> {
Box::pin(async move {
let auth = Depends::<AuthService>::from_request(ctx, req)
.await
.expect("dependency resolution failed");
let body = format!("AuthService.pool={}", auth.user_repo_pool);
Response::ok().body(ResponseBody::Bytes(body.into_bytes()))
})
}
}
#[test]
fn test_full_request_with_complex_deps() {
let client = TestClient::new(ComplexDepHandler);
let response = client.get("/auth/check").send();
assert_eq!(response.status_code(), 200);
assert!(response.text().contains("postgres://localhost/test"));
}
#[test]
fn test_complex_deps_with_override_at_leaf() {
let client = TestClient::new(ComplexDepHandler);
client.override_dependency_value(DatabasePool {
connection_string: "mysql://prod/users".to_string(),
});
let response = client.get("/auth/check").send();
assert_eq!(response.status_code(), 200);
assert!(
response.text().contains("mysql://prod/users"),
"Override at leaf should propagate through dependency chain"
);
}
#[test]
fn test_complex_deps_with_override_at_middle() {
let client = TestClient::new(ComplexDepHandler);
client.override_dependency_value(UserRepository {
pool_conn_str: "overridden-repo-connection".to_string(),
});
let response = client.get("/auth/check").send();
assert_eq!(response.status_code(), 200);
assert!(
response.text().contains("overridden-repo-connection"),
"Override at middle level should be used"
);
}
#[test]
fn test_dependency_caching_across_handler() {
use std::sync::atomic::{AtomicUsize, Ordering};
static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
#[derive(Clone)]
struct TrackedDep {
call_number: usize,
}
impl FromDependency for TrackedDep {
type Error = HttpError;
async fn from_dependency(
_ctx: &RequestContext,
_req: &mut Request,
) -> Result<Self, Self::Error> {
let call_number = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
Ok(TrackedDep { call_number })
}
}
struct MultiDepHandler;
impl Handler for MultiDepHandler {
fn call<'a>(
&'a self,
ctx: &'a RequestContext,
req: &'a mut Request,
) -> BoxFuture<'a, Response> {
Box::pin(async move {
let dep1 = Depends::<TrackedDep>::from_request(ctx, req)
.await
.expect("first resolution failed");
let dep2 = Depends::<TrackedDep>::from_request(ctx, req)
.await
.expect("second resolution failed");
let body = format!("dep1={} dep2={}", dep1.call_number, dep2.call_number);
Response::ok().body(ResponseBody::Bytes(body.into_bytes()))
})
}
}
CALL_COUNT.store(0, Ordering::SeqCst);
let client = TestClient::new(MultiDepHandler);
let response = client.get("/").send();
let text = response.text();
assert!(
text.contains("dep1=0 dep2=0"),
"Dependencies should be cached within request. Got: {}",
text
);
assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1);
}
#[test]
#[allow(clippy::float_cmp)] fn lab_test_config_defaults() {
let config = LabTestConfig::default();
assert_eq!(config.seed, 42);
assert!(!config.chaos_enabled);
assert_eq!(config.chaos_intensity, 0.0);
assert_eq!(config.max_steps, Some(10_000));
assert!(!config.capture_traces);
}
#[test]
fn lab_test_config_with_seed() {
let config = LabTestConfig::new(12345);
assert_eq!(config.seed, 12345);
}
#[test]
#[allow(clippy::float_cmp)] fn lab_test_config_light_chaos() {
let config = LabTestConfig::new(42).with_light_chaos();
assert!(config.chaos_enabled);
assert_eq!(config.chaos_intensity, 0.05);
}
#[test]
#[allow(clippy::float_cmp)] fn lab_test_config_heavy_chaos() {
let config = LabTestConfig::new(42).with_heavy_chaos();
assert!(config.chaos_enabled);
assert_eq!(config.chaos_intensity, 0.2);
}
#[test]
#[allow(clippy::float_cmp)] fn lab_test_config_custom_intensity() {
let config = LabTestConfig::new(42).with_chaos_intensity(0.15);
assert!(config.chaos_enabled);
assert_eq!(config.chaos_intensity, 0.15);
}
#[test]
#[allow(clippy::float_cmp)] fn lab_test_config_intensity_clamps() {
let config = LabTestConfig::new(42).with_chaos_intensity(1.5);
assert_eq!(config.chaos_intensity, 1.0);
let config = LabTestConfig::new(42).with_chaos_intensity(-0.5);
assert_eq!(config.chaos_intensity, 0.0);
assert!(!config.chaos_enabled);
}
#[test]
fn lab_test_config_max_steps() {
let config = LabTestConfig::new(42).with_max_steps(1000);
assert_eq!(config.max_steps, Some(1000));
}
#[test]
fn lab_test_config_no_step_limit() {
let config = LabTestConfig::new(42).without_step_limit();
assert_eq!(config.max_steps, None);
}
#[test]
fn lab_test_config_with_traces() {
let config = LabTestConfig::new(42).with_traces();
assert!(config.capture_traces);
}
#[test]
#[allow(clippy::float_cmp)] fn test_chaos_stats_empty() {
let stats = TestChaosStats::default();
assert_eq!(stats.decision_points, 0);
assert_eq!(stats.delays_injected, 0);
assert_eq!(stats.cancellations_injected, 0);
assert_eq!(stats.injection_rate(), 0.0);
assert!(!stats.had_chaos());
}
#[test]
fn test_chaos_stats_with_injections() {
let stats = TestChaosStats {
decision_points: 100,
delays_injected: 5,
cancellations_injected: 2,
steps_executed: 50,
};
assert!((stats.injection_rate() - 0.07).abs() < 0.001);
assert!(stats.had_chaos());
}
#[test]
fn mock_time_basic() {
let time = MockTime::new();
assert_eq!(time.now(), std::time::Duration::ZERO);
assert_eq!(time.elapsed(), std::time::Duration::ZERO);
time.advance(std::time::Duration::from_secs(5));
assert_eq!(time.now(), std::time::Duration::from_secs(5));
}
#[test]
fn mock_time_set_and_reset() {
let time = MockTime::new();
time.set(std::time::Duration::from_secs(100));
assert_eq!(time.now(), std::time::Duration::from_secs(100));
time.reset();
assert_eq!(time.now(), std::time::Duration::ZERO);
}
#[test]
fn mock_time_starting_at() {
let time = MockTime::starting_at(std::time::Duration::from_secs(10));
assert_eq!(time.now(), std::time::Duration::from_secs(10));
}
#[test]
fn cancellation_test_completes_normally() {
let test = CancellationTest::new(EchoHandler);
let result = test.complete_normally();
assert!(result.completed);
assert!(!result.cancelled_at_checkpoint);
assert!(result.response.is_some());
assert_eq!(result.response.as_ref().unwrap().status().as_u16(), 200);
}
#[test]
fn cancellation_test_respects_cancellation() {
struct CheckpointHandler;
impl Handler for CheckpointHandler {
fn call<'a>(
&'a self,
ctx: &'a RequestContext,
_req: &'a mut Request,
) -> BoxFuture<'a, Response> {
Box::pin(async move {
if ctx.checkpoint().is_err() {
return Response::with_status(StatusCode::CLIENT_CLOSED_REQUEST);
}
Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()))
})
}
}
let test = CancellationTest::new(CheckpointHandler);
let result = test.test_respects_cancellation();
assert!(result.completed);
assert!(result.cancelled_at_checkpoint);
assert!(result.response.is_some());
assert_eq!(result.response.as_ref().unwrap().status().as_u16(), 499);
}
#[test]
fn cancellation_test_result_helpers() {
let graceful = CancellationTestResult {
completed: false,
cancelled_at_checkpoint: true,
response: None,
cancellation_point: None,
};
assert!(graceful.gracefully_cancelled());
assert!(!graceful.completed_despite_cancel());
let completed = CancellationTestResult {
completed: true,
cancelled_at_checkpoint: false,
response: Some(Response::ok()),
cancellation_point: None,
};
assert!(!completed.gracefully_cancelled());
assert!(completed.completed_despite_cancel());
}
#[test]
fn test_logger_captures_all_levels() {
let logger = TestLogger::new();
logger.log_message(LogLevel::Debug, "debug message", 1);
logger.log_message(LogLevel::Info, "info message", 1);
logger.log_message(LogLevel::Warn, "warn message", 1);
logger.log_message(LogLevel::Error, "error message", 1);
let logs = logger.logs();
assert_eq!(logs.len(), 4);
assert_eq!(logs[0].level, LogLevel::Debug);
assert_eq!(logs[1].level, LogLevel::Info);
assert_eq!(logs[2].level, LogLevel::Warn);
assert_eq!(logs[3].level, LogLevel::Error);
}
#[test]
fn test_logger_logs_at_level_filters_correctly() {
let logger = TestLogger::new();
logger.log_message(LogLevel::Debug, "debug", 1);
logger.log_message(LogLevel::Info, "info 1", 1);
logger.log_message(LogLevel::Info, "info 2", 2);
logger.log_message(LogLevel::Warn, "warn", 1);
logger.log_message(LogLevel::Error, "error", 1);
let info_logs = logger.logs_at_level(LogLevel::Info);
assert_eq!(info_logs.len(), 2);
assert!(info_logs[0].contains("info 1"));
assert!(info_logs[1].contains("info 2"));
let error_logs = logger.logs_at_level(LogLevel::Error);
assert_eq!(error_logs.len(), 1);
assert!(error_logs[0].contains("error"));
let trace_logs = logger.logs_at_level(LogLevel::Trace);
assert_eq!(trace_logs.len(), 0);
}
#[test]
fn test_logger_contains_message_search() {
let logger = TestLogger::new();
logger.log_message(LogLevel::Info, "User alice logged in", 100);
logger.log_message(LogLevel::Info, "Request processed for /api/users", 101);
logger.log_message(LogLevel::Warn, "Rate limit approaching for alice", 102);
assert!(logger.contains_message("alice"));
assert!(logger.contains_message("/api/users"));
assert!(logger.contains_message("Rate limit"));
assert!(!logger.contains_message("bob"));
assert!(!logger.contains_message("nonexistent"));
}
#[test]
fn test_logger_contains_multiple_messages() {
let logger = TestLogger::new();
logger.log_message(LogLevel::Info, "step 1 complete", 1);
logger.log_message(LogLevel::Info, "step 2 complete", 2);
logger.log_message(LogLevel::Info, "step 3 complete", 3);
assert!(logger.contains_message("step 1"));
assert!(logger.contains_message("step 2"));
assert!(logger.contains_message("step 3"));
assert!(logger.contains_message("complete"));
assert!(!logger.contains_message("step 4"));
}
#[test]
fn test_log_capture_captures_logs_in_closure() {
let capture = TestLogger::capture(|logger| {
logger.log_message(LogLevel::Info, "inside capture", 1);
logger.log_message(LogLevel::Warn, "warning inside", 2);
42
});
assert!(capture.passed());
assert!(!capture.failed());
assert_eq!(capture.result, Some(42));
assert_eq!(capture.logs.len(), 2);
assert!(capture.contains_message("inside capture"));
assert!(capture.contains_message("warning inside"));
}
#[test]
fn test_log_capture_count_by_level() {
let capture = TestLogger::capture(|logger| {
logger.log_message(LogLevel::Info, "info 1", 1);
logger.log_message(LogLevel::Info, "info 2", 2);
logger.log_message(LogLevel::Info, "info 3", 3);
logger.log_message(LogLevel::Error, "error 1", 4);
});
assert_eq!(capture.count_by_level(LogLevel::Info), 3);
assert_eq!(capture.count_by_level(LogLevel::Error), 1);
assert_eq!(capture.count_by_level(LogLevel::Warn), 0);
}
#[test]
fn test_log_capture_phased_all_phases() {
let capture = TestLogger::capture_phased(
|logger| {
logger.log_message(LogLevel::Info, "setup phase", 1);
},
|logger| {
logger.log_message(LogLevel::Info, "execute phase", 2);
"result"
},
|logger| {
logger.log_message(LogLevel::Info, "teardown phase", 3);
},
);
assert!(capture.passed());
assert_eq!(capture.result, Some("result"));
assert_eq!(capture.logs.len(), 3);
assert!(capture.contains_message("setup phase"));
assert!(capture.contains_message("execute phase"));
assert!(capture.contains_message("teardown phase"));
}
#[test]
fn test_log_capture_timings_recorded() {
let capture = TestLogger::capture(|_logger| {
let mut sum = 0;
for i in 0..1000 {
sum += i;
}
sum
});
assert!(capture.passed());
let timings = &capture.timings;
assert!(timings.total() >= std::time::Duration::ZERO);
}
#[test]
fn test_log_capture_failure_context() {
let capture = TestLogger::capture(|logger| {
logger.log_message(LogLevel::Info, "step 1", 1);
logger.log_message(LogLevel::Info, "step 2", 2);
logger.log_message(LogLevel::Error, "something went wrong", 3);
logger.log_message(LogLevel::Info, "step 3", 4);
});
let context = capture.failure_context(3);
assert!(context.contains("something went wrong") || context.contains("step 3"));
}
#[test]
fn test_captured_log_format() {
let log = CapturedLog::new(LogLevel::Warn, "test warning message", 12345);
let formatted = log.format();
assert!(formatted.contains("[W]"));
assert!(formatted.contains("test warning message"));
assert!(formatted.contains("12345"));
}
#[test]
fn test_captured_log_contains() {
let log = CapturedLog::new(LogLevel::Info, "user login successful for alice", 1);
assert!(log.contains("login"));
assert!(log.contains("alice"));
assert!(log.contains("successful"));
assert!(!log.contains("bob"));
assert!(!log.contains("failed"));
}
#[test]
fn test_captured_log_fields() {
let log = CapturedLog::new(LogLevel::Error, "database connection failed", 999);
assert_eq!(log.level, LogLevel::Error);
assert_eq!(log.message, "database connection failed");
assert_eq!(log.request_id, 999);
}
#[test]
fn test_multiple_loggers_isolated() {
let logger1 = TestLogger::new();
let logger2 = TestLogger::new();
logger1.log_message(LogLevel::Info, "from logger 1", 1);
logger2.log_message(LogLevel::Info, "from logger 2", 2);
assert_eq!(logger1.logs().len(), 1);
assert_eq!(logger2.logs().len(), 1);
assert!(logger1.contains_message("logger 1"));
assert!(!logger1.contains_message("logger 2"));
assert!(logger2.contains_message("logger 2"));
assert!(!logger2.contains_message("logger 1"));
}
#[test]
fn test_logger_log_entry_integration() {
let logger = TestLogger::new();
let entry = LogEntry {
level: LogLevel::Warn,
message: "warning from entry".to_string(),
request_id: 42,
region_id: "region-1".to_string(),
task_id: "task-1".to_string(),
target: None,
fields: Vec::new(),
timestamp_ns: 0,
};
logger.log_entry(&entry);
assert_eq!(logger.logs().len(), 1);
let captured = &logger.logs()[0];
assert_eq!(captured.level, LogLevel::Warn);
assert!(captured.contains("warning from entry"));
assert_eq!(captured.request_id, 42);
}
#[test]
fn test_log_capture_unwrap_on_success() {
let capture = TestLogger::capture(|_| 123);
let value = capture.unwrap();
assert_eq!(value, 123);
}
#[test]
fn test_log_capture_unwrap_or_on_success() {
let capture = TestLogger::capture(|_| 456);
let value = capture.unwrap_or(0);
assert_eq!(value, 456);
}
}
use std::io::{Read as _, Write as _};
use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
use std::sync::atomic::AtomicBool;
use std::thread;
use std::time::Duration;
#[derive(Debug, Clone)]
pub struct RecordedRequest {
pub method: String,
pub path: String,
pub query: Option<String>,
pub headers: Vec<(String, String)>,
pub body: Vec<u8>,
pub timestamp: std::time::Instant,
}
impl RecordedRequest {
#[must_use]
pub fn body_text(&self) -> &str {
std::str::from_utf8(&self.body).expect("body is not valid UTF-8")
}
#[must_use]
pub fn header(&self, name: &str) -> Option<&str> {
let name_lower = name.to_ascii_lowercase();
self.headers
.iter()
.find(|(n, _)| n.to_ascii_lowercase() == name_lower)
.map(|(_, v)| v.as_str())
}
#[must_use]
pub fn url(&self) -> String {
match &self.query {
Some(q) => format!("{}?{}", self.path, q),
None => self.path.clone(),
}
}
}
#[derive(Debug, Clone)]
pub struct MockResponse {
pub status: u16,
pub headers: Vec<(String, String)>,
pub body: Vec<u8>,
pub delay: Option<Duration>,
}
impl Default for MockResponse {
fn default() -> Self {
Self {
status: 200,
headers: vec![("content-type".to_string(), "text/plain".to_string())],
body: b"OK".to_vec(),
delay: None,
}
}
}
impl MockResponse {
#[must_use]
pub fn ok() -> Self {
Self::default()
}
#[must_use]
pub fn with_status(status: u16) -> Self {
Self {
status,
..Default::default()
}
}
#[must_use]
pub fn status(mut self, status: u16) -> Self {
self.status = status;
self
}
#[must_use]
pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.headers.push((name.into(), value.into()));
self
}
#[must_use]
pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
self.body = body.into();
self
}
#[must_use]
pub fn body_str(self, body: &str) -> Self {
self.body(body.as_bytes().to_vec())
}
#[must_use]
pub fn json<T: serde::Serialize>(mut self, value: &T) -> Self {
self.body = serde_json::to_vec(value).expect("JSON serialization failed");
self.headers
.push(("content-type".to_string(), "application/json".to_string()));
self
}
#[must_use]
pub fn delay(mut self, duration: Duration) -> Self {
self.delay = Some(duration);
self
}
fn to_http_response(&self) -> Vec<u8> {
let status_text = match self.status {
200 => "OK",
201 => "Created",
204 => "No Content",
400 => "Bad Request",
401 => "Unauthorized",
403 => "Forbidden",
404 => "Not Found",
500 => "Internal Server Error",
502 => "Bad Gateway",
503 => "Service Unavailable",
504 => "Gateway Timeout",
_ => "Unknown",
};
let mut response = format!("HTTP/1.1 {} {}\r\n", self.status, status_text);
response.push_str(&format!("content-length: {}\r\n", self.body.len()));
for (name, value) in &self.headers {
response.push_str(&format!("{}: {}\r\n", name, value));
}
response.push_str("\r\n");
let mut bytes = response.into_bytes();
bytes.extend_from_slice(&self.body);
bytes
}
}
pub struct MockServer {
addr: SocketAddr,
requests: Arc<Mutex<Vec<RecordedRequest>>>,
responses: Arc<Mutex<HashMap<String, MockResponse>>>,
default_response: Arc<Mutex<MockResponse>>,
shutdown: Arc<AtomicBool>,
handle: Option<thread::JoinHandle<()>>,
}
impl MockServer {
#[must_use]
pub fn start() -> Self {
Self::start_with_options(MockServerOptions::default())
}
#[must_use]
pub fn start_with_options(options: MockServerOptions) -> Self {
let listener =
StdTcpListener::bind("127.0.0.1:0").expect("Failed to bind mock server to port");
let addr = listener.local_addr().expect("Failed to get local address");
listener
.set_nonblocking(true)
.expect("Failed to set non-blocking");
let requests = Arc::new(Mutex::new(Vec::new()));
let responses = Arc::new(Mutex::new(HashMap::new()));
let default_response = Arc::new(Mutex::new(options.default_response));
let shutdown = Arc::new(AtomicBool::new(false));
let requests_clone = Arc::clone(&requests);
let responses_clone = Arc::clone(&responses);
let default_response_clone = Arc::clone(&default_response);
let shutdown_clone = Arc::clone(&shutdown);
let read_timeout = options.read_timeout;
let handle = thread::spawn(move || {
Self::server_loop(
listener,
requests_clone,
responses_clone,
default_response_clone,
shutdown_clone,
read_timeout,
);
});
Self {
addr,
requests,
responses,
default_response,
shutdown,
handle: Some(handle),
}
}
fn server_loop(
listener: StdTcpListener,
requests: Arc<Mutex<Vec<RecordedRequest>>>,
responses: Arc<Mutex<HashMap<String, MockResponse>>>,
default_response: Arc<Mutex<MockResponse>>,
shutdown: Arc<AtomicBool>,
read_timeout: Duration,
) {
loop {
if shutdown.load(std::sync::atomic::Ordering::Acquire) {
break;
}
match listener.accept() {
Ok((stream, _peer)) => {
let requests = Arc::clone(&requests);
let responses = Arc::clone(&responses);
let default_response = Arc::clone(&default_response);
Self::handle_connection(
stream,
requests,
responses,
default_response,
read_timeout,
);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(10));
}
Err(e) => {
eprintln!("MockServer accept error: {}", e);
break;
}
}
}
}
fn handle_connection(
mut stream: StdTcpStream,
requests: Arc<Mutex<Vec<RecordedRequest>>>,
responses: Arc<Mutex<HashMap<String, MockResponse>>>,
default_response: Arc<Mutex<MockResponse>>,
read_timeout: Duration,
) {
let _ = stream.set_read_timeout(Some(read_timeout));
let mut buffer = vec![0u8; 8192];
let Ok(bytes_read) = stream.read(&mut buffer) else {
return;
};
if bytes_read == 0 {
return;
}
buffer.truncate(bytes_read);
let Some(recorded) = Self::parse_request(&buffer) else {
return;
};
{
let mut reqs = requests.lock();
reqs.push(recorded.clone());
}
let response = {
let resps = responses.lock();
match resps.get(&recorded.path) {
Some(r) => r.clone(),
None => {
let mut matched = None;
for (pattern, resp) in resps.iter() {
if pattern.ends_with('*') {
let prefix = &pattern[..pattern.len() - 1];
if recorded.path.starts_with(prefix) {
matched = Some(resp.clone());
break;
}
}
}
matched.unwrap_or_else(|| default_response.lock().clone())
}
}
};
if let Some(delay) = response.delay {
thread::sleep(delay);
}
let response_bytes = response.to_http_response();
let _ = stream.write_all(&response_bytes);
let _ = stream.flush();
}
fn parse_request(data: &[u8]) -> Option<RecordedRequest> {
let text = std::str::from_utf8(data).ok()?;
let mut lines = text.lines();
let request_line = lines.next()?;
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() < 2 {
return None;
}
let method = parts[0].to_string();
let full_path = parts[1];
let (path, query) = if let Some(idx) = full_path.find('?') {
(
full_path[..idx].to_string(),
Some(full_path[idx + 1..].to_string()),
)
} else {
(full_path.to_string(), None)
};
let mut headers = Vec::new();
let mut content_length = 0usize;
for line in lines.by_ref() {
if line.is_empty() {
break;
}
if let Some((name, value)) = line.split_once(':') {
let name = name.trim().to_string();
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse().unwrap_or(0);
}
headers.push((name, value));
}
}
let body = if content_length > 0 {
if let Some(body_start) = text.find("\r\n\r\n") {
let body_start = body_start + 4;
if body_start < data.len() {
data[body_start..].to_vec()
} else {
Vec::new()
}
} else if let Some(body_start) = text.find("\n\n") {
let body_start = body_start + 2;
if body_start < data.len() {
data[body_start..].to_vec()
} else {
Vec::new()
}
} else {
Vec::new()
}
} else {
Vec::new()
};
Some(RecordedRequest {
method,
path,
query,
headers,
body,
timestamp: std::time::Instant::now(),
})
}
#[must_use]
pub fn addr(&self) -> SocketAddr {
self.addr
}
#[must_use]
pub fn url(&self) -> String {
format!("http://{}", self.addr)
}
#[must_use]
pub fn url_for(&self, path: &str) -> String {
let path = if path.starts_with('/') {
path
} else {
&format!("/{}", path)
};
format!("http://{}{}", self.addr, path)
}
pub fn mock_response(&self, path: impl Into<String>, response: MockResponse) {
let mut responses = self.responses.lock();
responses.insert(path.into(), response);
}
pub fn set_default_response(&self, response: MockResponse) {
let mut default = self.default_response.lock();
*default = response;
}
#[must_use]
pub fn requests(&self) -> Vec<RecordedRequest> {
let requests = self.requests.lock();
requests.clone()
}
#[must_use]
pub fn request_count(&self) -> usize {
let requests = self.requests.lock();
requests.len()
}
#[must_use]
pub fn requests_for(&self, path: &str) -> Vec<RecordedRequest> {
let requests = self.requests.lock();
requests
.iter()
.filter(|r| r.path == path)
.cloned()
.collect()
}
#[must_use]
pub fn last_request(&self) -> Option<RecordedRequest> {
let requests = self.requests.lock();
requests.last().cloned()
}
pub fn clear_requests(&self) {
let mut requests = self.requests.lock();
requests.clear();
}
pub fn clear_responses(&self) {
let mut responses = self.responses.lock();
responses.clear();
}
pub fn reset(&self) {
self.clear_requests();
self.clear_responses();
}
pub fn wait_for_requests(&self, count: usize, timeout: Duration) -> bool {
let start = std::time::Instant::now();
loop {
if self.request_count() >= count {
return true;
}
if start.elapsed() >= timeout {
return false;
}
thread::sleep(Duration::from_millis(10));
}
}
pub fn assert_received(&self, path: &str) {
let requests = self.requests_for(path);
assert!(
!requests.is_empty(),
"Expected request to path '{}', but none was received. Received paths: {:?}",
path,
self.requests().iter().map(|r| &r.path).collect::<Vec<_>>()
);
}
pub fn assert_not_received(&self, path: &str) {
let requests = self.requests_for(path);
assert!(
requests.is_empty(),
"Expected no request to path '{}', but {} were received",
path,
requests.len()
);
}
pub fn assert_request_count(&self, expected: usize) {
let actual = self.request_count();
assert_eq!(
actual, expected,
"Expected {} requests, but received {}",
expected, actual
);
}
}
impl Drop for MockServer {
fn drop(&mut self) {
self.shutdown
.store(true, std::sync::atomic::Ordering::Release);
if let Some(handle) = self.handle.take() {
let _ = handle.join();
}
}
}
#[derive(Debug, Clone)]
pub struct MockServerOptions {
pub default_response: MockResponse,
pub read_timeout: Duration,
}
impl Default for MockServerOptions {
fn default() -> Self {
Self {
default_response: MockResponse::with_status(404).body_str("Not Found"),
read_timeout: Duration::from_secs(5),
}
}
}
impl MockServerOptions {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn default_response(mut self, response: MockResponse) -> Self {
self.default_response = response;
self
}
#[must_use]
pub fn read_timeout(mut self, timeout: Duration) -> Self {
self.read_timeout = timeout;
self
}
}
#[derive(Debug, Clone)]
pub struct TestServerLogEntry {
pub method: String,
pub path: String,
pub status: u16,
pub duration: Duration,
pub timestamp: std::time::Instant,
}
#[derive(Debug, Clone)]
pub struct TestServerConfig {
pub read_timeout: Duration,
pub log_requests: bool,
}
impl Default for TestServerConfig {
fn default() -> Self {
Self {
read_timeout: Duration::from_secs(5),
log_requests: true,
}
}
}
impl TestServerConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn read_timeout(mut self, timeout: Duration) -> Self {
self.read_timeout = timeout;
self
}
#[must_use]
pub fn log_requests(mut self, log: bool) -> Self {
self.log_requests = log;
self
}
}
pub struct TestServer {
addr: SocketAddr,
shutdown: Arc<AtomicBool>,
handle: Option<thread::JoinHandle<()>>,
log_entries: Arc<Mutex<Vec<TestServerLogEntry>>>,
shutdown_controller: crate::shutdown::ShutdownController,
}
impl TestServer {
#[must_use]
pub fn start(app: crate::app::App) -> Self {
Self::start_with_config(app, TestServerConfig::default())
}
#[must_use]
pub fn start_with_config(app: crate::app::App, config: TestServerConfig) -> Self {
let listener =
StdTcpListener::bind("127.0.0.1:0").expect("Failed to bind test server to port");
let addr = listener.local_addr().expect("Failed to get local address");
listener
.set_nonblocking(true)
.expect("Failed to set non-blocking");
let app = Arc::new(app);
let shutdown = Arc::new(AtomicBool::new(false));
let log_entries = Arc::new(Mutex::new(Vec::new()));
let shutdown_controller = crate::shutdown::ShutdownController::new();
let shutdown_clone = Arc::clone(&shutdown);
let log_entries_clone = Arc::clone(&log_entries);
let app_clone = Arc::clone(&app);
let controller_clone = shutdown_controller.clone();
let handle = thread::spawn(move || {
Self::server_loop(
listener,
app_clone,
shutdown_clone,
log_entries_clone,
config,
controller_clone,
);
});
Self {
addr,
shutdown,
handle: Some(handle),
log_entries,
shutdown_controller,
}
}
fn server_loop(
listener: StdTcpListener,
app: Arc<crate::app::App>,
shutdown: Arc<AtomicBool>,
log_entries: Arc<Mutex<Vec<TestServerLogEntry>>>,
config: TestServerConfig,
controller: crate::shutdown::ShutdownController,
) {
let request_counter = std::sync::atomic::AtomicU64::new(1);
loop {
if shutdown.load(std::sync::atomic::Ordering::Acquire) {
while let Some(hook) = controller.pop_hook() {
hook.run();
}
break;
}
match listener.accept() {
Ok((stream, _peer)) => {
let _guard = controller.track_request();
if controller.is_shutting_down() {
Self::send_503(stream);
continue;
}
Self::handle_connection(stream, &app, &log_entries, &config, &request_counter);
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(5));
}
Err(_) => {
break;
}
}
}
}
fn handle_connection(
mut stream: StdTcpStream,
app: &Arc<crate::app::App>,
log_entries: &Arc<Mutex<Vec<TestServerLogEntry>>>,
config: &TestServerConfig,
request_counter: &std::sync::atomic::AtomicU64,
) {
let _ = stream.set_read_timeout(Some(config.read_timeout));
let mut buffer = vec![0u8; 65536];
let bytes_read = match stream.read(&mut buffer) {
Ok(n) if n > 0 => n,
_ => return,
};
buffer.truncate(bytes_read);
let Some(parsed) = Self::parse_raw_request(&buffer) else {
let bad_request = b"HTTP/1.1 400 Bad Request\r\ncontent-length: 11\r\n\r\nBad Request";
let _ = stream.write_all(bad_request);
let _ = stream.flush();
return;
};
let start_time = std::time::Instant::now();
let method = match parsed.method.to_uppercase().as_str() {
"GET" => Method::Get,
"POST" => Method::Post,
"PUT" => Method::Put,
"DELETE" => Method::Delete,
"PATCH" => Method::Patch,
"HEAD" => Method::Head,
"OPTIONS" => Method::Options,
_ => Method::Get,
};
let mut request = Request::new(method, &parsed.path);
if let Some(ref query) = parsed.query {
request.set_query(Some(query.clone()));
}
for (name, value) in &parsed.headers {
request
.headers_mut()
.insert(name.clone(), value.as_bytes().to_vec());
}
if !parsed.body.is_empty() {
request.set_body(Body::Bytes(parsed.body.clone()));
}
let cx = Cx::for_testing();
let request_id = request_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let dependency_overrides = Handler::dependency_overrides(app.as_ref())
.unwrap_or_else(|| Arc::new(crate::dependency::DependencyOverrides::new()));
let ctx = RequestContext::with_overrides(cx, request_id, dependency_overrides);
let response = futures_executor::block_on(app.handle(&ctx, &mut request));
let duration = start_time.elapsed();
let status_code = response.status().as_u16();
if config.log_requests {
let entry = TestServerLogEntry {
method: parsed.method.clone(),
path: parsed.path.clone(),
status: status_code,
duration,
timestamp: start_time,
};
log_entries.lock().push(entry);
}
let response_bytes = Self::serialize_response(response);
let _ = stream.write_all(&response_bytes);
let _ = stream.flush();
}
fn parse_raw_request(data: &[u8]) -> Option<ParsedRequest> {
let text = std::str::from_utf8(data).ok()?;
let mut lines = text.lines();
let request_line = lines.next()?;
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() < 2 {
return None;
}
let method = parts[0].to_string();
let full_path = parts[1];
let (path, query) = if let Some(idx) = full_path.find('?') {
(
full_path[..idx].to_string(),
Some(full_path[idx + 1..].to_string()),
)
} else {
(full_path.to_string(), None)
};
let mut headers = Vec::new();
let mut content_length = 0usize;
for line in lines.by_ref() {
if line.is_empty() {
break;
}
if let Some((name, value)) = line.split_once(':') {
let name = name.trim().to_string();
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse().unwrap_or(0);
}
headers.push((name, value));
}
}
let body = if content_length > 0 {
if let Some(body_start) = text.find("\r\n\r\n") {
let body_start = body_start + 4;
if body_start < data.len() {
data[body_start..].to_vec()
} else {
Vec::new()
}
} else if let Some(body_start) = text.find("\n\n") {
let body_start = body_start + 2;
if body_start < data.len() {
data[body_start..].to_vec()
} else {
Vec::new()
}
} else {
Vec::new()
}
} else {
Vec::new()
};
Some(ParsedRequest {
method,
path,
query,
headers,
body,
})
}
fn serialize_response(response: Response) -> Vec<u8> {
let (status, headers, body) = response.into_parts();
let body_bytes = match body {
ResponseBody::Empty => Vec::new(),
ResponseBody::Bytes(b) => b,
ResponseBody::Stream(_) => {
Vec::new()
}
};
let mut buf = Vec::with_capacity(512 + body_bytes.len());
buf.extend_from_slice(b"HTTP/1.1 ");
buf.extend_from_slice(status.as_u16().to_string().as_bytes());
buf.extend_from_slice(b" ");
buf.extend_from_slice(status.canonical_reason().as_bytes());
buf.extend_from_slice(b"\r\n");
for (name, value) in &headers {
if name.eq_ignore_ascii_case("content-length")
|| name.eq_ignore_ascii_case("transfer-encoding")
{
continue;
}
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value);
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"content-length: ");
buf.extend_from_slice(body_bytes.len().to_string().as_bytes());
buf.extend_from_slice(b"\r\n");
buf.extend_from_slice(b"\r\n");
buf.extend_from_slice(&body_bytes);
buf
}
#[must_use]
pub fn addr(&self) -> SocketAddr {
self.addr
}
#[must_use]
pub fn port(&self) -> u16 {
self.addr.port()
}
#[must_use]
pub fn url(&self) -> String {
format!("http://{}", self.addr)
}
#[must_use]
pub fn url_for(&self, path: &str) -> String {
let path = if path.starts_with('/') {
path.to_string()
} else {
format!("/{path}")
};
format!("http://{}{}", self.addr, path)
}
#[must_use]
pub fn log_entries(&self) -> Vec<TestServerLogEntry> {
self.log_entries.lock().clone()
}
#[must_use]
pub fn request_count(&self) -> usize {
self.log_entries.lock().len()
}
pub fn clear_logs(&self) {
self.log_entries.lock().clear();
}
fn send_503(mut stream: StdTcpStream) {
let response =
b"HTTP/1.1 503 Service Unavailable\r\ncontent-length: 19\r\n\r\nService Unavailable";
let _ = stream.write_all(response);
let _ = stream.flush();
}
#[must_use]
pub fn shutdown_controller(&self) -> &crate::shutdown::ShutdownController {
&self.shutdown_controller
}
#[must_use]
pub fn in_flight_count(&self) -> usize {
self.shutdown_controller.in_flight_count()
}
pub fn shutdown(&self) {
self.shutdown_controller.shutdown();
self.shutdown
.store(true, std::sync::atomic::Ordering::Release);
}
#[must_use]
pub fn is_shutdown(&self) -> bool {
self.shutdown.load(std::sync::atomic::Ordering::Acquire)
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.shutdown
.store(true, std::sync::atomic::Ordering::Release);
if let Some(handle) = self.handle.take() {
let _ = handle.join();
}
}
}
struct ParsedRequest {
method: String,
path: String,
query: Option<String>,
headers: Vec<(String, String)>,
body: Vec<u8>,
}
#[derive(Debug, Clone)]
pub enum E2EStepResult {
Passed,
Failed(String),
Skipped,
}
impl E2EStepResult {
#[must_use]
pub fn is_passed(&self) -> bool {
matches!(self, Self::Passed)
}
#[must_use]
pub fn is_failed(&self) -> bool {
matches!(self, Self::Failed(_))
}
}
#[derive(Debug, Clone)]
pub struct E2ECapture {
pub method: String,
pub path: String,
pub request_headers: Vec<(String, String)>,
pub request_body: Option<String>,
pub response_status: u16,
pub response_headers: Vec<(String, String)>,
pub response_body: String,
}
#[derive(Debug, Clone)]
pub struct E2EStep {
pub name: String,
pub started_at: std::time::Instant,
pub duration: std::time::Duration,
pub result: E2EStepResult,
pub capture: Option<E2ECapture>,
}
impl E2EStep {
fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
started_at: std::time::Instant::now(),
duration: std::time::Duration::ZERO,
result: E2EStepResult::Skipped,
capture: None,
}
}
fn complete(&mut self, result: E2EStepResult) {
self.duration = self.started_at.elapsed();
self.result = result;
}
}
pub struct E2EScenario<H> {
name: String,
description: Option<String>,
client: TestClient<H>,
steps: Vec<E2EStep>,
stop_on_failure: bool,
has_failure: bool,
log_buffer: Vec<String>,
}
impl<H: Handler + 'static> E2EScenario<H> {
pub fn new(name: impl Into<String>, client: TestClient<H>) -> Self {
let name = name.into();
Self {
name,
description: None,
client,
steps: Vec::new(),
stop_on_failure: true,
has_failure: false,
log_buffer: Vec::new(),
}
}
#[must_use]
pub fn description(mut self, desc: impl Into<String>) -> Self {
self.description = Some(desc.into());
self
}
#[must_use]
pub fn stop_on_failure(mut self, stop: bool) -> Self {
self.stop_on_failure = stop;
self
}
pub fn client(&self) -> &TestClient<H> {
&self.client
}
pub fn client_mut(&mut self) -> &mut TestClient<H> {
&mut self.client
}
pub fn log(&mut self, message: impl Into<String>) {
let msg = message.into();
self.log_buffer.push(format!(
"[{:?}] {}",
std::time::Instant::now().elapsed(),
msg
));
}
pub fn step<F>(&mut self, name: impl Into<String>, f: F)
where
F: FnOnce(&TestClient<H>) + std::panic::UnwindSafe,
{
let name = name.into();
let mut step = E2EStep::new(&name);
if self.has_failure && self.stop_on_failure {
step.complete(E2EStepResult::Skipped);
self.log_buffer.push(format!("[SKIP] {}", name));
self.steps.push(step);
return;
}
self.log_buffer.push(format!("[START] {}", name));
let client_ref = std::panic::AssertUnwindSafe(&self.client);
let result = std::panic::catch_unwind(|| {
f(&client_ref);
});
match result {
Ok(()) => {
step.complete(E2EStepResult::Passed);
self.log_buffer
.push(format!("[PASS] {} ({:?})", name, step.duration));
}
Err(panic_info) => {
let error_msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = panic_info.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
step.complete(E2EStepResult::Failed(error_msg.clone()));
self.has_failure = true;
self.log_buffer
.push(format!("[FAIL] {} - {}", name, error_msg));
}
}
self.steps.push(step);
}
pub fn try_step<F, E>(&mut self, name: impl Into<String>, f: F) -> Result<(), E>
where
F: FnOnce(&TestClient<H>) -> Result<(), E>,
E: std::fmt::Display,
{
let name = name.into();
let mut step = E2EStep::new(&name);
if self.has_failure && self.stop_on_failure {
step.complete(E2EStepResult::Skipped);
self.steps.push(step);
return Ok(());
}
self.log_buffer.push(format!("[START] {}", name));
match f(&self.client) {
Ok(()) => {
step.complete(E2EStepResult::Passed);
self.log_buffer
.push(format!("[PASS] {} ({:?})", name, step.duration));
self.steps.push(step);
Ok(())
}
Err(e) => {
let error_msg = e.to_string();
step.complete(E2EStepResult::Failed(error_msg.clone()));
self.has_failure = true;
self.log_buffer
.push(format!("[FAIL] {} - {}", name, error_msg));
self.steps.push(step);
Err(e)
}
}
}
#[must_use]
pub fn passed(&self) -> bool {
!self.has_failure
}
#[must_use]
pub fn steps(&self) -> &[E2EStep] {
&self.steps
}
#[must_use]
pub fn logs(&self) -> &[String] {
&self.log_buffer
}
#[must_use]
pub fn report(&self) -> E2EReport {
let passed = self.steps.iter().filter(|s| s.result.is_passed()).count();
let failed = self.steps.iter().filter(|s| s.result.is_failed()).count();
let skipped = self
.steps
.iter()
.filter(|s| matches!(s.result, E2EStepResult::Skipped))
.count();
let total_duration: std::time::Duration = self.steps.iter().map(|s| s.duration).sum();
E2EReport {
scenario_name: self.name.clone(),
description: self.description.clone(),
passed,
failed,
skipped,
total_duration,
steps: self.steps.clone(),
logs: self.log_buffer.clone(),
}
}
pub fn assert_passed(&self) {
if !self.passed() {
let report = self.report();
panic!(
"E2E Scenario '{}' failed!\n\n{}",
self.name,
report.to_text()
);
}
}
}
#[derive(Debug, Clone)]
pub struct E2EReport {
pub scenario_name: String,
pub description: Option<String>,
pub passed: usize,
pub failed: usize,
pub skipped: usize,
pub total_duration: std::time::Duration,
pub steps: Vec<E2EStep>,
pub logs: Vec<String>,
}
impl E2EReport {
#[must_use]
pub fn to_text(&self) -> String {
let mut output = String::new();
output.push_str(&format!("E2E Test Report: {}\n", self.scenario_name));
output.push_str(&"=".repeat(60));
output.push('\n');
if let Some(desc) = &self.description {
output.push_str(&format!("Description: {}\n", desc));
}
output.push_str(&format!(
"\nSummary: {} passed, {} failed, {} skipped\n",
self.passed, self.failed, self.skipped
));
output.push_str(&format!("Total Duration: {:?}\n", self.total_duration));
output.push_str(&"-".repeat(60));
output.push('\n');
output.push_str("\nSteps:\n");
for (i, step) in self.steps.iter().enumerate() {
let status = match &step.result {
E2EStepResult::Passed => "[PASS]",
E2EStepResult::Failed(_) => "[FAIL]",
E2EStepResult::Skipped => "[SKIP]",
};
output.push_str(&format!(
" {}. {} {} ({:?})\n",
i + 1,
status,
step.name,
step.duration
));
if let E2EStepResult::Failed(msg) = &step.result {
output.push_str(&format!(" Error: {}\n", msg));
}
}
if !self.logs.is_empty() {
output.push_str(&"-".repeat(60));
output.push_str("\n\nLogs:\n");
for log in &self.logs {
output.push_str(&format!(" {}\n", log));
}
}
output
}
#[must_use]
pub fn to_json(&self) -> String {
let steps_json: Vec<String> = self
.steps
.iter()
.map(|step| {
let status = match &step.result {
E2EStepResult::Passed => "passed",
E2EStepResult::Failed(_) => "failed",
E2EStepResult::Skipped => "skipped",
};
let error = match &step.result {
E2EStepResult::Failed(msg) => format!(r#", "error": "{}""#, escape_json(msg)),
_ => String::new(),
};
format!(
r#" {{ "name": "{}", "status": "{}", "duration_ms": {}{} }}"#,
escape_json(&step.name),
status,
step.duration.as_millis(),
error
)
})
.collect();
format!(
r#"{{
"scenario": "{}",
"description": {},
"summary": {{
"passed": {},
"failed": {},
"skipped": {},
"total_duration_ms": {}
}},
"steps": [
{}
]
}}"#,
escape_json(&self.scenario_name),
self.description
.as_ref()
.map_or("null".to_string(), |d| format!(r#""{}""#, escape_json(d))),
self.passed,
self.failed,
self.skipped,
self.total_duration.as_millis(),
steps_json.join(",\n")
)
}
#[must_use]
pub fn to_html(&self) -> String {
let status_class = if self.failed > 0 { "failed" } else { "passed" };
use std::fmt::Write;
let steps_html =
self.steps
.iter()
.enumerate()
.fold(String::new(), |mut output, (i, step)| {
let (status, class) = match &step.result {
E2EStepResult::Passed => ("✓", "pass"),
E2EStepResult::Failed(_) => ("✗", "fail"),
E2EStepResult::Skipped => ("â—‹", "skip"),
};
let error_html = match &step.result {
E2EStepResult::Failed(msg) => {
format!(r#"<div class="error">{}</div>"#, escape_html(msg))
}
_ => String::new(),
};
let _ = write!(
output,
r#" <tr class="{}">
<td>{}</td>
<td><span class="status">{}</span></td>
<td>{}</td>
<td>{:?}</td>
</tr>
{}"#,
class,
i + 1,
status,
escape_html(&step.name),
step.duration,
error_html
);
output
});
format!(
r#"<!DOCTYPE html>
<html>
<head>
<title>E2E Report: {}</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; margin: 2rem; }}
h1 {{ color: #333; }}
.summary {{ padding: 1rem; border-radius: 8px; margin: 1rem 0; }}
.summary.passed {{ background: #d4edda; }}
.summary.failed {{ background: #f8d7da; }}
table {{ width: 100%; border-collapse: collapse; margin-top: 1rem; }}
th, td {{ padding: 0.75rem; text-align: left; border-bottom: 1px solid #dee2e6; }}
th {{ background: #f8f9fa; }}
.pass {{ color: #28a745; }}
.fail {{ color: #dc3545; }}
.skip {{ color: #6c757d; }}
.status {{ font-size: 1.2rem; }}
.error {{ color: #dc3545; font-size: 0.9rem; padding: 0.5rem; background: #fff; margin-top: 0.25rem; }}
</style>
</head>
<body>
<h1>E2E Report: {}</h1>
{}
<div class="summary {}">
<strong>Summary:</strong> {} passed, {} failed, {} skipped<br>
<strong>Duration:</strong> {:?}
</div>
<table>
<thead>
<tr><th>#</th><th>Status</th><th>Step</th><th>Duration</th></tr>
</thead>
<tbody>
{}
</tbody>
</table>
</body>
</html>"#,
escape_html(&self.scenario_name),
escape_html(&self.scenario_name),
self.description
.as_ref()
.map_or(String::new(), |d| format!("<p>{}</p>", escape_html(d))),
status_class,
self.passed,
self.failed,
self.skipped,
self.total_duration,
steps_html
)
}
}
fn escape_json(s: &str) -> String {
s.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\t', "\\t")
}
fn escape_html(s: &str) -> String {
s.replace('&', "&")
.replace('<', "<")
.replace('>', ">")
.replace('"', """)
}
#[macro_export]
macro_rules! e2e_test {
(
name: $name:expr,
$(description: $desc:expr,)?
client: $client:expr,
$(step $step_name:literal => |$client_param:ident| $step_body:block),+ $(,)?
) => {{
let client = $client;
let mut scenario = $crate::testing::E2EScenario::new($name, client);
$(
scenario = scenario.description($desc);
)?
$(
scenario.step($step_name, |$client_param| $step_body);
)+
scenario.assert_passed();
scenario.report()
}};
}
pub use e2e_test;
use crate::logging::{LogEntry, LogLevel};
#[derive(Debug, Clone)]
pub struct CapturedLog {
pub level: LogLevel,
pub message: String,
pub request_id: u64,
pub captured_at: std::time::Instant,
pub fields: Vec<(String, String)>,
pub target: Option<String>,
}
impl CapturedLog {
pub fn from_entry(entry: &LogEntry) -> Self {
Self {
level: entry.level,
message: entry.message.clone(),
request_id: entry.request_id,
captured_at: std::time::Instant::now(),
fields: entry.fields.clone(),
target: entry.target.clone(),
}
}
pub fn new(level: LogLevel, message: impl Into<String>, request_id: u64) -> Self {
Self {
level,
message: message.into(),
request_id,
captured_at: std::time::Instant::now(),
fields: Vec::new(),
target: None,
}
}
#[must_use]
pub fn contains(&self, text: &str) -> bool {
self.message.contains(text)
}
#[must_use]
pub fn format(&self) -> String {
let mut output = format!(
"[{}] req={} {}",
self.level.as_char(),
self.request_id,
self.message
);
if !self.fields.is_empty() {
output.push_str(" {");
for (i, (k, v)) in self.fields.iter().enumerate() {
if i > 0 {
output.push_str(", ");
}
output.push_str(&format!("{k}={v}"));
}
output.push('}');
}
output
}
}
#[derive(Debug, Clone)]
pub struct TestLogger {
logs: Arc<Mutex<Vec<CapturedLog>>>,
timings: Arc<Mutex<TestTimings>>,
echo_logs: bool,
}
#[derive(Debug, Clone, Default)]
pub struct TestTimings {
pub setup: Option<std::time::Duration>,
pub execute: Option<std::time::Duration>,
pub teardown: Option<std::time::Duration>,
phase_start: Option<std::time::Instant>,
}
impl TestTimings {
pub fn start_phase(&mut self) {
self.phase_start = Some(std::time::Instant::now());
}
pub fn end_setup(&mut self) {
if let Some(start) = self.phase_start.take() {
self.setup = Some(start.elapsed());
}
}
pub fn end_execute(&mut self) {
if let Some(start) = self.phase_start.take() {
self.execute = Some(start.elapsed());
}
}
pub fn end_teardown(&mut self) {
if let Some(start) = self.phase_start.take() {
self.teardown = Some(start.elapsed());
}
}
#[must_use]
pub fn total(&self) -> std::time::Duration {
self.setup.unwrap_or_default()
+ self.execute.unwrap_or_default()
+ self.teardown.unwrap_or_default()
}
#[must_use]
pub fn format(&self) -> String {
format!(
"Timings: setup={:?}, execute={:?}, teardown={:?}, total={:?}",
self.setup.unwrap_or_default(),
self.execute.unwrap_or_default(),
self.teardown.unwrap_or_default(),
self.total()
)
}
}
impl TestLogger {
pub fn new() -> Self {
Self {
logs: Arc::new(Mutex::new(Vec::new())),
timings: Arc::new(Mutex::new(TestTimings::default())),
echo_logs: std::env::var("FASTAPI_TEST_ECHO_LOGS").is_ok(),
}
}
pub fn with_echo() -> Self {
let mut logger = Self::new();
logger.echo_logs = true;
logger
}
pub fn log(&self, entry: CapturedLog) {
if self.echo_logs {
eprintln!("[LOG] {}", entry.format());
}
self.logs.lock().push(entry);
}
pub fn log_entry(&self, entry: &LogEntry) {
self.log(CapturedLog::from_entry(entry));
}
pub fn log_message(&self, level: LogLevel, message: impl Into<String>, request_id: u64) {
self.log(CapturedLog::new(level, message, request_id));
}
#[must_use]
pub fn logs(&self) -> Vec<CapturedLog> {
self.logs.lock().clone()
}
#[must_use]
pub fn count(&self) -> usize {
self.logs.lock().len()
}
pub fn clear(&self) {
self.logs.lock().clear();
}
#[must_use]
pub fn contains_message(&self, text: &str) -> bool {
self.logs.lock().iter().any(|log| log.contains(text))
}
#[must_use]
pub fn count_by_level(&self, level: LogLevel) -> usize {
self.logs
.lock()
.iter()
.filter(|log| log.level == level)
.count()
}
#[must_use]
pub fn logs_at_level(&self, level: LogLevel) -> Vec<CapturedLog> {
self.logs
.lock()
.iter()
.filter(|log| log.level == level)
.cloned()
.collect()
}
#[must_use]
pub fn failure_context(&self, n: usize) -> String {
let logs = self.logs.lock();
let start = logs.len().saturating_sub(n);
let recent: Vec<_> = logs[start..].iter().map(CapturedLog::format).collect();
if recent.is_empty() {
"No logs captured".to_string()
} else {
format!(
"Last {} log(s) before failure:\n {}",
recent.len(),
recent.join("\n ")
)
}
}
#[must_use]
pub fn timings(&self) -> TestTimings {
self.timings.lock().clone()
}
pub fn start_phase(&self) {
self.timings.lock().start_phase();
}
pub fn end_setup(&self) {
self.timings.lock().end_setup();
}
pub fn end_execute(&self) {
self.timings.lock().end_execute();
}
pub fn end_teardown(&self) {
self.timings.lock().end_teardown();
}
pub fn capture<F, T>(f: F) -> LogCapture<T>
where
F: FnOnce(&TestLogger) -> T,
{
let logger = TestLogger::new();
logger.start_phase();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
logger.end_setup();
logger.start_phase();
let result = f(&logger);
logger.end_execute();
result
}));
let (ok_result, panic_info) = match result {
Ok(v) => (Some(v), None),
Err(p) => {
let msg = if let Some(s) = p.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = p.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
(None, Some(msg))
}
};
LogCapture {
logs: logger.logs(),
timings: logger.timings(),
result: ok_result,
panic_info,
}
}
pub fn capture_phased<S, E, D, T>(setup: S, execute: E, teardown: D) -> LogCapture<T>
where
S: FnOnce(&TestLogger),
E: FnOnce(&TestLogger) -> T,
D: FnOnce(&TestLogger),
{
let logger = TestLogger::new();
logger.start_phase();
let setup_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
setup(&logger);
}));
logger.end_setup();
if setup_panic.is_err() {
return LogCapture {
logs: logger.logs(),
timings: logger.timings(),
result: None,
panic_info: Some("Setup phase panicked".to_string()),
};
}
logger.start_phase();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| execute(&logger)));
logger.end_execute();
logger.start_phase();
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
teardown(&logger);
}));
logger.end_teardown();
let (ok_result, panic_info) = match result {
Ok(v) => (Some(v), None),
Err(p) => {
let msg = if let Some(s) = p.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = p.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
(None, Some(msg))
}
};
LogCapture {
logs: logger.logs(),
timings: logger.timings(),
result: ok_result,
panic_info,
}
}
}
impl Default for TestLogger {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct LogCapture<T> {
pub logs: Vec<CapturedLog>,
pub timings: TestTimings,
pub result: Option<T>,
pub panic_info: Option<String>,
}
impl<T> LogCapture<T> {
#[must_use]
pub fn passed(&self) -> bool {
self.result.is_some()
}
#[must_use]
pub fn failed(&self) -> bool {
self.panic_info.is_some()
}
#[must_use]
pub fn contains_message(&self, text: &str) -> bool {
self.logs.iter().any(|log| log.contains(text))
}
#[must_use]
pub fn count_by_level(&self, level: LogLevel) -> usize {
self.logs.iter().filter(|log| log.level == level).count()
}
#[must_use]
pub fn failure_context(&self, n: usize) -> String {
let start = self.logs.len().saturating_sub(n);
let recent: Vec<_> = self.logs[start..].iter().map(CapturedLog::format).collect();
let mut output = String::new();
if let Some(ref panic) = self.panic_info {
output.push_str(&format!("Test failed: {}\n\n", panic));
}
output.push_str(&self.timings.format());
output.push_str("\n\n");
if recent.is_empty() {
output.push_str("No logs captured");
} else {
output.push_str(&format!(
"Last {} log(s) before failure:\n {}",
recent.len(),
recent.join("\n ")
));
}
output
}
pub fn unwrap(self) -> T {
match self.result {
Some(v) => v,
None => panic!(
"Test failed with log context:\n{}",
self.failure_context(10)
),
}
}
pub fn unwrap_or(self, default: T) -> T {
self.result.unwrap_or(default)
}
}
#[macro_export]
macro_rules! assert_with_logs {
($logger:expr, $cond:expr) => {
if !$cond {
panic!(
"Assertion failed: {}\n\n{}",
stringify!($cond),
$logger.failure_context(10)
);
}
};
($logger:expr, $cond:expr, $($arg:tt)+) => {
if !$cond {
panic!(
"Assertion failed: {}\n\n{}",
format!($($arg)+),
$logger.failure_context(10)
);
}
};
}
#[macro_export]
macro_rules! assert_eq_with_logs {
($logger:expr, $left:expr, $right:expr) => {
if $left != $right {
panic!(
"Assertion failed: {} == {}\n left: {:?}\n right: {:?}\n\n{}",
stringify!($left),
stringify!($right),
$left,
$right,
$logger.failure_context(10)
);
}
};
($logger:expr, $left:expr, $right:expr, $($arg:tt)+) => {
if $left != $right {
panic!(
"Assertion failed: {}\n left: {:?}\n right: {:?}\n\n{}",
format!($($arg)+),
$left,
$right,
$logger.failure_context(10)
);
}
};
}
pub use assert_eq_with_logs;
pub use assert_with_logs;
#[derive(Debug)]
pub struct ResponseDiff {
pub expected_status: u16,
pub actual_status: u16,
pub expected_body: Option<String>,
pub actual_body: String,
pub header_diffs: Vec<(String, Option<String>, Option<String>)>,
}
impl ResponseDiff {
pub fn new(expected_status: u16, actual: &TestResponse) -> Self {
Self {
expected_status,
actual_status: actual.status().as_u16(),
expected_body: None,
actual_body: actual.text().to_string(),
header_diffs: Vec::new(),
}
}
#[must_use]
pub fn expected_body(mut self, body: impl Into<String>) -> Self {
self.expected_body = Some(body.into());
self
}
#[must_use]
pub fn expected_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.header_diffs
.push((name.into(), Some(value.into()), None));
self
}
#[must_use]
pub fn is_match(&self) -> bool {
if self.expected_status != self.actual_status {
return false;
}
if let Some(ref expected) = self.expected_body {
if !self.actual_body.contains(expected) {
return false;
}
}
true
}
#[must_use]
pub fn format(&self) -> String {
let mut output = String::new();
if self.expected_status != self.actual_status {
output.push_str(&format!(
"Status mismatch:\n expected: {}\n actual: {}\n",
self.expected_status, self.actual_status
));
}
if let Some(ref expected) = self.expected_body {
if !self.actual_body.contains(expected) {
output.push_str(&format!(
"Body mismatch:\n expected to contain: {:?}\n actual: {:?}\n",
expected, self.actual_body
));
}
}
for (name, expected, actual) in &self.header_diffs {
output.push_str(&format!(
"Header '{}' mismatch:\n expected: {:?}\n actual: {:?}\n",
name, expected, actual
));
}
if output.is_empty() {
"No differences".to_string()
} else {
output
}
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct ResponseSnapshot {
pub status: u16,
pub headers: Vec<(String, String)>,
pub body: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub body_json: Option<serde_json::Value>,
}
impl ResponseSnapshot {
pub fn from_test_response(resp: &TestResponse) -> Self {
let body = resp.text().to_string();
let body_json = serde_json::from_str::<serde_json::Value>(&body).ok();
let mut headers: Vec<(String, String)> = resp
.headers()
.iter()
.filter_map(|(name, value)| {
std::str::from_utf8(value)
.ok()
.map(|v| (name.to_lowercase(), v.to_string()))
})
.collect();
headers.sort();
Self {
status: resp.status().as_u16(),
headers,
body,
body_json,
}
}
pub fn from_test_response_with_headers(resp: &TestResponse, header_names: &[&str]) -> Self {
let mut snapshot = Self::from_test_response(resp);
let names: Vec<String> = header_names.iter().map(|n| n.to_lowercase()).collect();
snapshot.headers.retain(|(name, _)| names.contains(name));
snapshot
}
#[must_use]
pub fn mask_fields(mut self, paths: &[&str], placeholder: &str) -> Self {
if let Some(ref mut json) = self.body_json {
for path in paths {
mask_json_path(json, path, placeholder);
}
self.body = serde_json::to_string_pretty(json).unwrap_or(self.body);
}
self
}
pub fn save(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
let path = path.as_ref();
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let json = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
std::fs::write(path, json)
}
pub fn load(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
let data = std::fs::read_to_string(path)?;
serde_json::from_str(&data)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}
#[must_use]
pub fn diff(&self, other: &Self) -> String {
let mut output = String::new();
if self.status != other.status {
output.push_str(&format!("Status: {} vs {}\n", self.status, other.status));
}
for (name, value) in &self.headers {
match other.headers.iter().find(|(n, _)| n == name) {
Some((_, other_value)) if value != other_value => {
output.push_str(&format!(
"Header '{}': {:?} vs {:?}\n",
name, value, other_value
));
}
None => {
output.push_str(&format!("Header '{}': present vs missing\n", name));
}
_ => {}
}
}
for (name, _) in &other.headers {
if !self.headers.iter().any(|(n, _)| n == name) {
output.push_str(&format!("Header '{}': missing vs present\n", name));
}
}
if self.body != other.body {
output.push_str(&format!(
"Body:\n expected: {:?}\n actual: {:?}\n",
other.body, self.body
));
}
if output.is_empty() {
"No differences".to_string()
} else {
output
}
}
pub fn matches_ignoring_headers(&self, other: &Self, ignore: &[&str]) -> bool {
if self.status != other.status {
return false;
}
let ignore_lower: Vec<String> = ignore.iter().map(|s| s.to_lowercase()).collect();
let self_headers: Vec<_> = self
.headers
.iter()
.filter(|(n, _)| !ignore_lower.contains(n))
.collect();
let other_headers: Vec<_> = other
.headers
.iter()
.filter(|(n, _)| !ignore_lower.contains(n))
.collect();
if self_headers != other_headers {
return false;
}
match (&self.body_json, &other.body_json) {
(Some(a), Some(b)) => a == b,
_ => self.body == other.body,
}
}
}
fn mask_json_path(value: &mut serde_json::Value, path: &str, placeholder: &str) {
let parts: Vec<&str> = path.splitn(2, '.').collect();
match parts.as_slice() {
[key] => {
if let Some(obj) = value.as_object_mut() {
if obj.contains_key(*key) {
obj.insert(
key.to_string(),
serde_json::Value::String(placeholder.to_string()),
);
}
}
if let Some(arr) = value.as_array_mut() {
if let Ok(idx) = key.parse::<usize>() {
if idx < arr.len() {
arr[idx] = serde_json::Value::String(placeholder.to_string());
}
}
}
}
[key, rest] => {
if let Some(obj) = value.as_object_mut() {
if let Some(child) = obj.get_mut(*key) {
mask_json_path(child, rest, placeholder);
}
}
if let Some(arr) = value.as_array_mut() {
if let Ok(idx) = key.parse::<usize>() {
if let Some(child) = arr.get_mut(idx) {
mask_json_path(child, rest, placeholder);
}
}
}
}
_ => {}
}
}
#[macro_export]
macro_rules! assert_response_snapshot {
($response:expr, $path:expr) => {{
let snapshot = $crate::ResponseSnapshot::from_test_response(&$response);
let path = std::path::Path::new($path);
if std::env::var("SNAPSHOT_UPDATE").is_ok() || !path.exists() {
snapshot.save(path).expect("failed to save snapshot");
} else {
let expected =
$crate::ResponseSnapshot::load(path).expect("failed to load snapshot");
assert!(
snapshot == expected,
"Snapshot mismatch for {}:\n{}",
$path,
snapshot.diff(&expected)
);
}
}};
($response:expr, $path:expr, mask: [$($field:expr),* $(,)?]) => {{
let snapshot = $crate::ResponseSnapshot::from_test_response(&$response)
.mask_fields(&[$($field),*], "<MASKED>");
let path = std::path::Path::new($path);
if std::env::var("SNAPSHOT_UPDATE").is_ok() || !path.exists() {
snapshot.save(path).expect("failed to save snapshot");
} else {
let expected =
$crate::ResponseSnapshot::load(path).expect("failed to load snapshot");
assert!(
snapshot == expected,
"Snapshot mismatch for {}:\n{}",
$path,
snapshot.diff(&expected)
);
}
}};
}
#[cfg(test)]
mod snapshot_tests {
use super::*;
fn mock_test_response(status: u16, body: &str, headers: &[(&str, &str)]) -> TestResponse {
let mut resp =
crate::response::Response::with_status(crate::response::StatusCode::from_u16(status));
for (name, value) in headers {
resp = resp.header(*name, value.as_bytes().to_vec());
}
resp = resp.body(crate::response::ResponseBody::Bytes(
body.as_bytes().to_vec(),
));
TestResponse::new(resp, 0)
}
#[test]
fn snapshot_from_test_response() {
let resp = mock_test_response(
200,
r#"{"id":1,"name":"Alice"}"#,
&[("content-type", "application/json")],
);
let snap = ResponseSnapshot::from_test_response(&resp);
assert_eq!(snap.status, 200);
assert!(snap.body_json.is_some());
assert_eq!(snap.body_json.as_ref().unwrap()["name"], "Alice");
}
#[test]
fn snapshot_equality() {
let resp = mock_test_response(200, "hello", &[]);
let snap1 = ResponseSnapshot::from_test_response(&resp);
let snap2 = ResponseSnapshot::from_test_response(&resp);
assert_eq!(snap1, snap2);
}
#[test]
fn snapshot_diff_status() {
let s1 = ResponseSnapshot {
status: 200,
headers: vec![],
body: "ok".to_string(),
body_json: None,
};
let s2 = ResponseSnapshot {
status: 404,
..s1.clone()
};
let diff = s1.diff(&s2);
assert!(diff.contains("200 vs 404"));
}
#[test]
fn snapshot_diff_body() {
let s1 = ResponseSnapshot {
status: 200,
headers: vec![],
body: "hello".to_string(),
body_json: None,
};
let s2 = ResponseSnapshot {
body: "world".to_string(),
..s1.clone()
};
let diff = s1.diff(&s2);
assert!(diff.contains("Body:"));
}
#[test]
fn snapshot_diff_no_differences() {
let s = ResponseSnapshot {
status: 200,
headers: vec![],
body: "ok".to_string(),
body_json: None,
};
assert_eq!(s.diff(&s), "No differences");
}
#[test]
fn snapshot_mask_fields() {
let resp = mock_test_response(
200,
r#"{"id":42,"name":"Alice","created_at":"2026-01-01"}"#,
&[],
);
let snap = ResponseSnapshot::from_test_response(&resp)
.mask_fields(&["id", "created_at"], "<MASKED>");
let json = snap.body_json.unwrap();
assert_eq!(json["id"], "<MASKED>");
assert_eq!(json["name"], "Alice");
assert_eq!(json["created_at"], "<MASKED>");
}
#[test]
fn snapshot_mask_nested_fields() {
let resp = mock_test_response(200, r#"{"user":{"id":1,"name":"Bob"}}"#, &[]);
let snap =
ResponseSnapshot::from_test_response(&resp).mask_fields(&["user.id"], "<MASKED>");
let json = snap.body_json.unwrap();
assert_eq!(json["user"]["id"], "<MASKED>");
assert_eq!(json["user"]["name"], "Bob");
}
#[test]
fn snapshot_save_and_load() {
let snap = ResponseSnapshot {
status: 200,
headers: vec![("content-type".to_string(), "application/json".to_string())],
body: r#"{"ok":true}"#.to_string(),
body_json: Some(serde_json::json!({"ok": true})),
};
let dir = std::env::temp_dir().join("fastapi_snapshot_test");
let path = dir.join("test_snap.json");
snap.save(&path).unwrap();
let loaded = ResponseSnapshot::load(&path).unwrap();
assert_eq!(snap, loaded);
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn snapshot_matches_ignoring_headers() {
let s1 = ResponseSnapshot {
status: 200,
headers: vec![
("content-type".to_string(), "application/json".to_string()),
("x-request-id".to_string(), "abc".to_string()),
],
body: "ok".to_string(),
body_json: None,
};
let s2 = ResponseSnapshot {
headers: vec![
("content-type".to_string(), "application/json".to_string()),
("x-request-id".to_string(), "xyz".to_string()),
],
..s1.clone()
};
assert!(!s1.matches_ignoring_headers(&s2, &[]));
assert!(s1.matches_ignoring_headers(&s2, &["X-Request-Id"]));
}
#[test]
fn snapshot_with_selected_headers() {
let resp = mock_test_response(
200,
"ok",
&[
("content-type", "text/plain"),
("x-request-id", "abc123"),
("x-trace-id", "trace-456"),
],
);
let snap = ResponseSnapshot::from_test_response_with_headers(&resp, &["content-type"]);
assert_eq!(snap.headers.len(), 1);
assert_eq!(snap.headers[0].0, "content-type");
}
#[test]
fn snapshot_json_structural_comparison() {
let s1 = ResponseSnapshot {
status: 200,
headers: vec![],
body: r#"{"a":1,"b":2}"#.to_string(),
body_json: Some(serde_json::json!({"a": 1, "b": 2})),
};
let s2 = ResponseSnapshot {
body: r#"{"b":2,"a":1}"#.to_string(),
body_json: Some(serde_json::json!({"b": 2, "a": 1})),
..s1.clone()
};
assert_ne!(s1, s2);
assert!(s1.matches_ignoring_headers(&s2, &[]));
}
}
#[cfg(test)]
mod mock_server_tests {
use super::*;
#[test]
fn mock_server_starts_and_responds() {
let server = MockServer::start();
server.mock_response("/hello", MockResponse::ok().body_str("Hello, World!"));
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
let mut response = String::new();
stream.read_to_string(&mut response).unwrap();
assert!(response.contains("200 OK"));
assert!(response.contains("Hello, World!"));
}
#[test]
fn mock_server_records_requests() {
let server = MockServer::start();
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /api/users HTTP/1.1\r\nHost: localhost\r\nX-Custom: value\r\n\r\n")
.unwrap();
let mut response = Vec::new();
let _ = stream.read_to_end(&mut response);
thread::sleep(Duration::from_millis(50));
let requests = server.requests();
assert_eq!(requests.len(), 1);
assert_eq!(requests[0].method, "GET");
assert_eq!(requests[0].path, "/api/users");
assert_eq!(requests[0].header("x-custom"), Some("value"));
}
#[test]
fn mock_server_handles_post_with_body() {
let server = MockServer::start();
server.mock_response(
"/api/create",
MockResponse::with_status(201).body_str("Created"),
);
let body = r#"{"name":"test"}"#;
let request = format!(
"POST /api/create HTTP/1.1\r\nHost: localhost\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}",
body.len(),
body
);
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream.write_all(request.as_bytes()).unwrap();
let mut response = String::new();
stream.read_to_string(&mut response).unwrap();
assert!(response.contains("201 Created"));
thread::sleep(Duration::from_millis(50));
let requests = server.requests();
assert_eq!(requests.len(), 1);
assert_eq!(requests[0].method, "POST");
assert_eq!(requests[0].body_text(), body);
}
#[test]
fn mock_server_pattern_matching() {
let server = MockServer::start();
server.mock_response("/api/*", MockResponse::ok().body_str("API Response"));
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /api/users/123 HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
let mut response = String::new();
stream.read_to_string(&mut response).unwrap();
assert!(response.contains("API Response"));
}
#[test]
fn mock_server_default_response() {
let server = MockServer::start();
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /unknown HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
let mut response = String::new();
stream.read_to_string(&mut response).unwrap();
assert!(response.contains("404"));
}
#[test]
fn mock_server_url_helpers() {
let server = MockServer::start();
let url = server.url();
assert!(url.starts_with("http://127.0.0.1:"));
let api_url = server.url_for("/api/users");
assert!(api_url.contains("/api/users"));
}
#[test]
fn mock_server_clear_requests() {
let server = MockServer::start();
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /test HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
let mut response = Vec::new();
let _ = stream.read_to_end(&mut response);
thread::sleep(Duration::from_millis(50));
assert_eq!(server.request_count(), 1);
server.clear_requests();
assert_eq!(server.request_count(), 0);
}
#[test]
fn mock_server_wait_for_requests() {
let server = MockServer::start();
let addr = server.addr();
thread::spawn(move || {
thread::sleep(Duration::from_millis(50));
let mut stream = StdTcpStream::connect(addr).expect("Failed to connect");
stream
.write_all(b"GET /delayed HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
});
let received = server.wait_for_requests(1, Duration::from_millis(500));
assert!(received);
assert_eq!(server.request_count(), 1);
}
#[test]
fn mock_server_assert_helpers() {
let server = MockServer::start();
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /expected HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
let mut response = Vec::new();
let _ = stream.read_to_end(&mut response);
thread::sleep(Duration::from_millis(50));
server.assert_received("/expected");
server.assert_not_received("/not-expected");
server.assert_request_count(1);
}
#[test]
fn mock_server_query_string_parsing() {
let server = MockServer::start();
let mut stream = StdTcpStream::connect(server.addr()).expect("Failed to connect");
stream
.write_all(b"GET /search?q=rust&limit=10 HTTP/1.1\r\nHost: localhost\r\n\r\n")
.unwrap();
let mut response = Vec::new();
let _ = stream.read_to_end(&mut response);
thread::sleep(Duration::from_millis(50));
let requests = server.requests();
assert_eq!(requests.len(), 1);
assert_eq!(requests[0].path, "/search");
assert_eq!(requests[0].query, Some("q=rust&limit=10".to_string()));
assert_eq!(requests[0].url(), "/search?q=rust&limit=10");
}
#[test]
fn mock_response_json() {
#[derive(serde::Serialize)]
struct User {
name: String,
}
let response = MockResponse::ok().json(&User {
name: "Alice".to_string(),
});
let bytes = response.to_http_response();
let http = String::from_utf8_lossy(&bytes);
assert!(http.contains("application/json"));
assert!(http.contains("Alice"));
}
#[test]
fn recorded_request_helpers() {
let request = RecordedRequest {
method: "GET".to_string(),
path: "/api/users".to_string(),
query: Some("page=1".to_string()),
headers: vec![("Content-Type".to_string(), "application/json".to_string())],
body: b"test body".to_vec(),
timestamp: std::time::Instant::now(),
};
assert_eq!(request.body_text(), "test body");
assert_eq!(request.header("content-type"), Some("application/json"));
assert_eq!(request.url(), "/api/users?page=1");
}
}
#[cfg(test)]
mod e2e_tests {
use super::*;
fn test_handler(_ctx: &RequestContext, req: &mut Request) -> std::future::Ready<Response> {
let path = req.path();
let response = match path {
"/" => Response::ok().body(ResponseBody::Bytes(b"Home".to_vec())),
"/login" => Response::ok().body(ResponseBody::Bytes(b"Login Page".to_vec())),
"/dashboard" => Response::ok().body(ResponseBody::Bytes(b"Dashboard".to_vec())),
"/api/users" => {
Response::ok().body(ResponseBody::Bytes(b"[\"Alice\",\"Bob\"]".to_vec()))
}
"/fail" => Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
.body(ResponseBody::Bytes(b"Error".to_vec())),
_ => Response::with_status(StatusCode::NOT_FOUND)
.body(ResponseBody::Bytes(b"Not Found".to_vec())),
};
std::future::ready(response)
}
#[test]
fn e2e_scenario_all_steps_pass() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("Basic Navigation", client);
scenario.step("Visit home page", |client| {
let response = client.get("/").send();
assert_eq!(response.status().as_u16(), 200);
assert_eq!(response.text(), "Home");
});
scenario.step("Visit login page", |client| {
let response = client.get("/login").send();
assert_eq!(response.status().as_u16(), 200);
});
assert!(scenario.passed());
assert_eq!(scenario.steps().len(), 2);
assert!(scenario.steps().iter().all(|s| s.result.is_passed()));
}
#[test]
fn e2e_scenario_step_failure() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("Failure Test", client).stop_on_failure(true);
scenario.step("First step passes", |client| {
let response = client.get("/").send();
assert_eq!(response.status().as_u16(), 200);
});
scenario.step("Second step fails", |_client| {
panic!("Intentional failure");
});
scenario.step("Third step skipped", |client| {
let response = client.get("/dashboard").send();
assert_eq!(response.status().as_u16(), 200);
});
assert!(!scenario.passed());
assert_eq!(scenario.steps().len(), 3);
assert!(scenario.steps()[0].result.is_passed());
assert!(scenario.steps()[1].result.is_failed());
assert!(matches!(scenario.steps()[2].result, E2EStepResult::Skipped));
}
#[test]
fn e2e_scenario_continue_on_failure() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("Continue Test", client).stop_on_failure(false);
scenario.step("First step fails", |_client| {
panic!("First failure");
});
scenario.step("Second step still runs", |client| {
let response = client.get("/").send();
assert_eq!(response.status().as_u16(), 200);
});
assert!(!scenario.passed());
assert_eq!(scenario.steps().len(), 2);
assert!(scenario.steps()[0].result.is_failed());
assert!(scenario.steps()[1].result.is_passed());
}
#[test]
fn e2e_report_text_format() {
let client = TestClient::new(test_handler);
let mut scenario =
E2EScenario::new("Report Test", client).description("Tests report generation");
scenario.step("Step 1", |client| {
let _ = client.get("/").send();
});
let report = scenario.report();
let text = report.to_text();
assert!(text.contains("E2E Test Report: Report Test"));
assert!(text.contains("Tests report generation"));
assert!(text.contains("1 passed"));
assert!(text.contains("Step 1"));
}
#[test]
fn e2e_report_json_format() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("JSON Test", client);
scenario.step("API call", |client| {
let response = client.get("/api/users").send();
assert_eq!(response.status().as_u16(), 200);
});
let report = scenario.report();
let json = report.to_json();
assert!(json.contains(r#""scenario": "JSON Test""#));
assert!(json.contains(r#""passed": 1"#));
assert!(json.contains(r#""name": "API call""#));
assert!(json.contains(r#""status": "passed""#));
}
#[test]
fn e2e_report_html_format() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("HTML Test", client);
scenario.step("Web visit", |client| {
let _ = client.get("/").send();
});
let report = scenario.report();
let html = report.to_html();
assert!(html.contains("<!DOCTYPE html>"));
assert!(html.contains("E2E Report: HTML Test"));
assert!(html.contains("1 passed"));
assert!(html.contains("Web visit"));
}
#[test]
fn e2e_step_timing() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("Timing Test", client);
scenario.step("Timed step", |_client| {
std::thread::sleep(std::time::Duration::from_millis(10));
});
assert!(scenario.steps()[0].duration >= std::time::Duration::from_millis(10));
}
#[test]
fn e2e_logs_captured() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("Log Test", client);
scenario.log("Manual log entry");
scenario.step("Logged step", |_client| {});
assert!(
scenario
.logs()
.iter()
.any(|l| l.contains("Manual log entry"))
);
assert!(
scenario
.logs()
.iter()
.any(|l| l.contains("[START] Logged step"))
);
assert!(
scenario
.logs()
.iter()
.any(|l| l.contains("[PASS] Logged step"))
);
}
#[test]
fn e2e_try_step_with_result() {
let client = TestClient::new(test_handler);
let mut scenario = E2EScenario::new("Try Step Test", client);
let result: Result<(), &str> = scenario.try_step("Success step", |client| {
let response = client.get("/").send();
if response.status().as_u16() == 200 {
Ok(())
} else {
Err("Unexpected status")
}
});
assert!(result.is_ok());
assert!(scenario.passed());
}
#[test]
fn e2e_escape_functions() {
assert_eq!(escape_json("hello"), "hello");
assert_eq!(escape_json("a\"b"), "a\\\"b");
assert_eq!(escape_json("a\nb"), "a\\nb");
assert_eq!(escape_html("hello"), "hello");
assert_eq!(escape_html("<script>"), "<script>");
assert_eq!(escape_html("a&b"), "a&b");
}
#[test]
fn e2e_step_result_helpers() {
let passed = E2EStepResult::Passed;
let failed = E2EStepResult::Failed("error".to_string());
let skipped = E2EStepResult::Skipped;
assert!(passed.is_passed());
assert!(!passed.is_failed());
assert!(!failed.is_passed());
assert!(failed.is_failed());
assert!(!skipped.is_passed());
assert!(!skipped.is_failed());
}
}
pub trait TestFixture: Sized + Send {
fn setup() -> Self;
fn teardown(&mut self) {}
}
pub struct FixtureGuard<F: TestFixture> {
fixture: Option<F>,
}
impl<F: TestFixture> FixtureGuard<F> {
pub fn new() -> Self {
Self {
fixture: Some(F::setup()),
}
}
pub fn get(&self) -> &F {
self.fixture.as_ref().unwrap()
}
pub fn get_mut(&mut self) -> &mut F {
self.fixture.as_mut().unwrap()
}
}
impl<F: TestFixture> Default for FixtureGuard<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: TestFixture> Drop for FixtureGuard<F> {
fn drop(&mut self) {
if let Some(mut fixture) = self.fixture.take() {
fixture.teardown();
}
}
}
impl<F: TestFixture> std::ops::Deref for FixtureGuard<F> {
type Target = F;
fn deref(&self) -> &Self::Target {
self.get()
}
}
impl<F: TestFixture> std::ops::DerefMut for FixtureGuard<F> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.get_mut()
}
}
pub struct IntegrationTest<H: Handler + 'static> {
name: String,
client: TestClient<H>,
fixtures: HashMap<std::any::TypeId, Box<dyn std::any::Any + Send>>,
reset_hooks: Vec<Box<dyn Fn() + Send + Sync>>,
}
impl<H: Handler + 'static> IntegrationTest<H> {
pub fn new(name: impl Into<String>, handler: H) -> Self {
Self {
name: name.into(),
client: TestClient::new(handler),
fixtures: HashMap::new(),
reset_hooks: Vec::new(),
}
}
pub fn with_seed(name: impl Into<String>, handler: H, seed: u64) -> Self {
Self {
name: name.into(),
client: TestClient::with_seed(handler, seed),
fixtures: HashMap::new(),
reset_hooks: Vec::new(),
}
}
#[must_use]
pub fn with_fixture<F: TestFixture + 'static>(mut self) -> Self {
let guard = FixtureGuard::<F>::new();
self.fixtures
.insert(std::any::TypeId::of::<F>(), Box::new(guard));
self
}
#[must_use]
pub fn on_reset<F: Fn() + Send + Sync + 'static>(mut self, f: F) -> Self {
self.reset_hooks.push(Box::new(f));
self
}
pub fn run<F>(mut self, test_fn: F)
where
F: FnOnce(&IntegrationTestContext<'_, H>) + std::panic::UnwindSafe,
{
let ctx = IntegrationTestContext {
name: &self.name,
client: &self.client,
fixtures: &self.fixtures,
};
let ctx_ref = std::panic::AssertUnwindSafe(&ctx);
let result = std::panic::catch_unwind(|| {
test_fn(&ctx_ref);
});
for hook in &self.reset_hooks {
hook();
}
self.client.clear_cookies();
self.client.clear_dependency_overrides();
self.fixtures.clear();
if let Err(e) = result {
std::panic::resume_unwind(e);
}
}
}
pub struct IntegrationTestContext<'a, H: Handler> {
name: &'a str,
client: &'a TestClient<H>,
fixtures: &'a HashMap<std::any::TypeId, Box<dyn std::any::Any + Send>>,
}
impl<'a, H: Handler + 'static> IntegrationTestContext<'a, H> {
#[must_use]
pub fn name(&self) -> &str {
self.name
}
#[must_use]
pub fn client(&self) -> &TestClient<H> {
self.client
}
#[must_use]
pub fn fixture<F: TestFixture + 'static>(&self) -> Option<&F> {
self.fixtures
.get(&std::any::TypeId::of::<F>())
.and_then(|boxed| boxed.downcast_ref::<FixtureGuard<F>>())
.map(FixtureGuard::get)
}
#[must_use]
pub fn fixture_mut<F: TestFixture + 'static>(&self) -> Option<&mut F> {
None }
pub fn get(&self, path: &str) -> RequestBuilder<'_, H> {
self.client.get(path)
}
pub fn post(&self, path: &str) -> RequestBuilder<'_, H> {
self.client.post(path)
}
pub fn put(&self, path: &str) -> RequestBuilder<'_, H> {
self.client.put(path)
}
pub fn delete(&self, path: &str) -> RequestBuilder<'_, H> {
self.client.delete(path)
}
pub fn patch(&self, path: &str) -> RequestBuilder<'_, H> {
self.client.patch(path)
}
pub fn options(&self, path: &str) -> RequestBuilder<'_, H> {
self.client.options(path)
}
pub fn request(&self, method: Method, path: &str) -> RequestBuilder<'_, H> {
self.client.request(method, path)
}
}
#[cfg(test)]
mod test_server_tests {
use super::*;
use crate::app::App;
use std::net::TcpStream as StdTcpStreamAlias;
fn make_test_app() -> App {
App::builder()
.get("/health", |_ctx: &RequestContext, _req: &mut Request| {
std::future::ready(
Response::ok()
.header("content-type", b"text/plain".to_vec())
.body(ResponseBody::Bytes(b"OK".to_vec())),
)
})
.get("/hello", |_ctx: &RequestContext, _req: &mut Request| {
std::future::ready(
Response::ok()
.header("content-type", b"application/json".to_vec())
.body(ResponseBody::Bytes(
br#"{"message":"Hello, World!"}"#.to_vec(),
)),
)
})
.post("/echo", |_ctx: &RequestContext, req: &mut Request| {
let body = match req.body() {
Body::Bytes(b) => b.clone(),
_ => Vec::new(),
};
std::future::ready(
Response::ok()
.header("content-type", b"application/octet-stream".to_vec())
.body(ResponseBody::Bytes(body)),
)
})
.build()
}
fn send_request(addr: SocketAddr, request: &[u8]) -> String {
let mut stream = StdTcpStreamAlias::connect(addr).expect("Failed to connect to TestServer");
stream
.set_read_timeout(Some(Duration::from_secs(5)))
.expect("set_read_timeout");
stream.write_all(request).expect("Failed to write request");
stream.flush().expect("Failed to flush");
let mut buf = vec![0u8; 65536];
let n = stream.read(&mut buf).expect("Failed to read response");
String::from_utf8_lossy(&buf[..n]).to_string()
}
#[test]
fn test_server_starts_and_responds() {
let app = make_test_app();
let server = TestServer::start(app);
let response = send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert!(
response.contains("200 OK"),
"Expected 200 OK, got: {response}"
);
assert!(response.contains("OK"), "Expected body 'OK'");
}
#[test]
fn test_server_json_response() {
let app = make_test_app();
let server = TestServer::start(app);
let response = send_request(
server.addr(),
b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert!(response.contains("200 OK"));
assert!(response.contains("application/json"));
assert!(response.contains(r#"{"message":"Hello, World!"}"#));
}
#[test]
fn test_server_post_with_body() {
let app = make_test_app();
let server = TestServer::start(app);
let request =
b"POST /echo HTTP/1.1\r\nHost: localhost\r\nContent-Length: 11\r\n\r\nHello World";
let response = send_request(server.addr(), request);
assert!(response.contains("200 OK"));
assert!(response.contains("Hello World"));
}
#[test]
fn test_server_logs_requests() {
let app = make_test_app();
let server = TestServer::start(app);
send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
let logs = server.log_entries();
assert_eq!(logs.len(), 1);
assert_eq!(logs[0].method, "GET");
assert_eq!(logs[0].path, "/health");
assert_eq!(logs[0].status, 200);
}
#[test]
fn test_server_request_count() {
let app = make_test_app();
let server = TestServer::start(app);
assert_eq!(server.request_count(), 0);
send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
send_request(
server.addr(),
b"GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert_eq!(server.request_count(), 2);
}
#[test]
fn test_server_clear_logs() {
let app = make_test_app();
let server = TestServer::start(app);
send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert_eq!(server.request_count(), 1);
server.clear_logs();
assert_eq!(server.request_count(), 0);
}
#[test]
fn test_server_url_helpers() {
let app = make_test_app();
let server = TestServer::start(app);
assert!(server.url().starts_with("http://127.0.0.1:"));
assert!(server.url_for("/health").ends_with("/health"));
assert!(server.url_for("health").ends_with("/health"));
assert!(server.port() > 0);
}
#[test]
fn test_server_shutdown() {
let app = make_test_app();
let server = TestServer::start(app);
let addr = server.addr();
let response = send_request(addr, b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n");
assert!(response.contains("200 OK"));
server.shutdown();
assert!(server.is_shutdown());
}
#[test]
fn test_server_config_no_logging() {
let app = make_test_app();
let config = TestServerConfig::new().log_requests(false);
let server = TestServer::start_with_config(app, config);
send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert_eq!(server.request_count(), 0);
}
#[test]
fn test_server_bad_request() {
let app = make_test_app();
let server = TestServer::start(app);
let response = send_request(server.addr(), b"NOT_HTTP_AT_ALL");
assert!(response.contains("400 Bad Request"));
}
#[test]
fn test_server_content_length_header() {
let app = make_test_app();
let server = TestServer::start(app);
let response = send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert!(
response.contains("content-length: 2"),
"Expected content-length: 2, got: {response}"
);
}
#[test]
fn test_server_multiple_requests_sequential() {
let app = make_test_app();
let server = TestServer::start(app);
for _ in 0..5 {
let response = send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert!(response.contains("200 OK"));
}
assert_eq!(server.request_count(), 5);
}
#[test]
fn test_server_log_entry_has_timing() {
let app = make_test_app();
let server = TestServer::start(app);
send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
let logs = server.log_entries();
assert_eq!(logs.len(), 1);
assert!(logs[0].duration < Duration::from_secs(1));
}
#[test]
fn test_server_shutdown_controller_available() {
let app = make_test_app();
let server = TestServer::start(app);
let controller = server.shutdown_controller();
assert!(!controller.is_shutting_down());
assert_eq!(controller.phase(), crate::shutdown::ShutdownPhase::Running);
}
#[test]
fn test_server_shutdown_triggers_controller() {
let app = make_test_app();
let server = TestServer::start(app);
assert!(!server.shutdown_controller().is_shutting_down());
server.shutdown();
assert!(server.is_shutdown());
assert!(server.shutdown_controller().is_shutting_down());
assert_eq!(
server.shutdown_controller().phase(),
crate::shutdown::ShutdownPhase::StopAccepting
);
}
#[test]
fn test_server_requests_complete_before_shutdown() {
let app = make_test_app();
let server = TestServer::start(app);
let response = send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert!(response.contains("200 OK"));
assert_eq!(server.request_count(), 1);
server.shutdown();
let logs = server.log_entries();
assert_eq!(logs.len(), 1);
assert_eq!(logs[0].status, 200);
assert_eq!(logs[0].path, "/health");
}
#[test]
fn test_server_in_flight_tracking() {
let app = make_test_app();
let server = TestServer::start(app);
assert_eq!(server.in_flight_count(), 0);
send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(500);
while server.in_flight_count() > 0 && start.elapsed() < timeout {
std::thread::sleep(std::time::Duration::from_millis(1));
}
assert_eq!(
server.in_flight_count(),
0,
"In-flight count should return to 0 after request completes"
);
}
#[test]
fn test_server_in_flight_guard_tracks_correctly() {
let app = make_test_app();
let server = TestServer::start(app);
let controller = server.shutdown_controller();
assert_eq!(controller.in_flight_count(), 0);
let guard1 = controller.track_request();
assert_eq!(controller.in_flight_count(), 1);
let guard2 = controller.track_request();
assert_eq!(controller.in_flight_count(), 2);
drop(guard1);
assert_eq!(controller.in_flight_count(), 1);
drop(guard2);
assert_eq!(controller.in_flight_count(), 0);
}
#[test]
fn test_server_shutdown_hooks_executed() {
let app = make_test_app();
let server = TestServer::start(app);
let hook_executed = Arc::new(AtomicBool::new(false));
let hook_executed_clone = Arc::clone(&hook_executed);
server.shutdown_controller().register_hook(move || {
hook_executed_clone.store(true, std::sync::atomic::Ordering::Release);
});
assert!(!hook_executed.load(std::sync::atomic::Ordering::Acquire));
server.shutdown();
drop(server);
assert!(
hook_executed.load(std::sync::atomic::Ordering::Acquire),
"Shutdown hook should have been executed"
);
}
#[test]
fn test_server_multiple_shutdown_hooks_lifo() {
let app = make_test_app();
let server = TestServer::start(app);
let execution_order = Arc::new(Mutex::new(Vec::new()));
let order1 = Arc::clone(&execution_order);
server.shutdown_controller().register_hook(move || {
order1.lock().push(1);
});
let order2 = Arc::clone(&execution_order);
server.shutdown_controller().register_hook(move || {
order2.lock().push(2);
});
let order3 = Arc::clone(&execution_order);
server.shutdown_controller().register_hook(move || {
order3.lock().push(3);
});
server.shutdown();
drop(server);
let order = execution_order.lock();
assert_eq!(*order, vec![3, 2, 1]);
}
#[test]
fn test_server_shutdown_controller_phase_progression() {
let app = make_test_app();
let server = TestServer::start(app);
let controller = server.shutdown_controller();
assert_eq!(controller.phase(), crate::shutdown::ShutdownPhase::Running);
assert!(controller.advance_phase());
assert_eq!(
controller.phase(),
crate::shutdown::ShutdownPhase::StopAccepting
);
assert!(controller.advance_phase());
assert_eq!(
controller.phase(),
crate::shutdown::ShutdownPhase::ShutdownFlagged
);
assert!(controller.advance_phase());
assert_eq!(
controller.phase(),
crate::shutdown::ShutdownPhase::GracePeriod
);
assert!(controller.advance_phase());
assert_eq!(
controller.phase(),
crate::shutdown::ShutdownPhase::Cancelling
);
assert!(controller.advance_phase());
assert_eq!(
controller.phase(),
crate::shutdown::ShutdownPhase::RunningHooks
);
assert!(controller.advance_phase());
assert_eq!(controller.phase(), crate::shutdown::ShutdownPhase::Stopped);
assert!(!controller.advance_phase());
}
#[test]
fn test_server_receiver_notified_on_shutdown() {
let app = make_test_app();
let server = TestServer::start(app);
let receiver = server.shutdown_controller().subscribe();
assert!(!receiver.is_shutting_down());
server.shutdown();
assert!(receiver.is_shutting_down());
assert!(!receiver.is_forced());
}
#[test]
fn test_server_forced_shutdown() {
let app = make_test_app();
let server = TestServer::start(app);
let receiver = server.shutdown_controller().subscribe();
server.shutdown_controller().shutdown();
assert!(receiver.is_shutting_down());
assert!(!receiver.is_forced());
server.shutdown_controller().shutdown();
assert!(receiver.is_forced());
}
#[test]
fn test_server_requests_work_before_shutdown_signal() {
let app = make_test_app();
let server = TestServer::start(app);
for i in 0..3 {
let response = send_request(
server.addr(),
b"GET /health HTTP/1.1\r\nHost: localhost\r\n\r\n",
);
assert!(
response.contains("200 OK"),
"Request {i} should succeed before shutdown"
);
}
assert_eq!(server.request_count(), 3);
server.shutdown();
assert!(server.is_shutdown());
}
}