use std::collections::HashMap;
use axum::{
Router,
http::{StatusCode, header::LOCATION},
};
use crate::http::request::HttpRequest;
pub use super::utils::{assert_form_error, assert_template_used};
use super::{
client::{TestClient, TestResponse},
request_factory::RequestFactory,
runner::{TestRunner, TestSuiteResult},
};
// Generates the helper methods shared by every test-case type in this file:
// building requests, building a test client, and running a list of tests.
// It is invoked once per test-case struct, right after that struct's inherent impl.
macro_rules! impl_case_helpers {
    ($case:ident) => {
        impl $case {
            /// Returns a fresh [`RequestFactory`] for building requests by hand.
            #[must_use]
            pub fn request_factory(&self) -> RequestFactory {
                RequestFactory::new()
            }

            /// Builds a GET request for `path` using a fresh [`RequestFactory`].
            #[must_use]
            pub fn get(&self, path: &str) -> HttpRequest {
                RequestFactory::new().get(path)
            }

            /// Wraps `router` in a [`TestClient`] created with `TestClient::new`.
            #[must_use]
            pub fn client(&self, router: Router) -> TestClient {
                TestClient::new(router)
            }

            /// Runs `tests` with a runner created by `TestRunner::new`.
            ///
            /// Each entry pairs a test name with a closure that returns whether it passed.
            #[must_use]
            pub fn run_tests(
                &self,
                tests: Vec<(&str, Box<dyn Fn() -> bool>)>,
            ) -> TestSuiteResult {
                self.run_tests_with(&TestRunner::new(), tests)
            }

            /// Runs `tests` with a caller-configured `runner` (for example one built with
            /// `with_failfast`).
            #[must_use]
            pub fn run_tests_with(
                &self,
                runner: &TestRunner,
                tests: Vec<(&str, Box<dyn Fn() -> bool>)>,
            ) -> TestSuiteResult {
                runner.run_tests(tests)
            }
        }
    };
}
/// The simplest test case: it records per-test setting overrides.
///
/// Request, client and runner helpers are attached separately via `impl_case_helpers!`.
/// `Debug` and `Clone` are derived so a case can be printed in assertion failures
/// and duplicated between tests.
#[derive(Debug, Clone)]
pub struct SimpleTestCase {
    /// Overrides recorded by [`SimpleTestCase::override_settings`], mapping each
    /// setting key to its override value.
    pub settings_overrides: HashMap<String, String>,
}

impl SimpleTestCase {
    /// Creates a test case with no setting overrides.
    #[must_use]
    pub fn new() -> Self {
        Self {
            settings_overrides: HashMap::new(),
        }
    }

    /// Records `value` as the override for setting `key`.
    ///
    /// A later call with the same `key` replaces the earlier value. This method only
    /// stores the pair in `settings_overrides`; it does not apply it anywhere itself.
    pub fn override_settings(&mut self, key: &str, value: &str) {
        self.settings_overrides.insert(key.to_owned(), value.to_owned());
    }
}
// Attach the shared request/client/runner helpers to SimpleTestCase.
impl_case_helpers!(SimpleTestCase);
impl Default for SimpleTestCase {
fn default() -> Self {
Self::new()
}
}
/// Test case wrapping a [`SimpleTestCase`] that can also build a redirect-following client.
#[derive(Default)]
pub struct TestCase {
    /// Shared state, including recorded setting overrides.
    pub base: SimpleTestCase,
}

impl TestCase {
    /// Equivalent to `TestCase::default()`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds a [`TestClient`] over `router` with redirect following switched on.
    #[must_use]
    pub fn client_following_redirects(&self, router: Router) -> TestClient {
        let follow_redirects = true;
        TestClient::with_follow_redirects(router, follow_redirects)
    }
}
// Attach the shared request/client/runner helpers to TestCase.
impl_case_helpers!(TestCase);
/// Test case wrapping a [`SimpleTestCase`] that reports its database-transaction usage.
#[derive(Default)]
pub struct TransactionTestCase {
    /// Shared state, including recorded setting overrides.
    pub base: SimpleTestCase,
}

impl TransactionTestCase {
    /// Equivalent to `TransactionTestCase::default()`.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Reports whether this test case uses database transactions.
    ///
    /// The answer is hard-coded: this implementation always returns `false`.
    #[must_use]
    pub fn uses_database_transactions(&self) -> bool {
        false
    }
}
// Attach the shared request/client/runner helpers to TransactionTestCase.
impl_case_helpers!(TransactionTestCase);
// Scheme and loopback host used to build live-server URLs; `LiveServerTestCase::port` is appended to it.
const TEST_SERVER_HOST: &str = "http://127.0.0.1";
/// Test case that builds absolute URLs of the form `TEST_SERVER_HOST:port/path`.
// NOTE(review): nothing visible in this file starts a server on `port` — confirm whether one is started elsewhere.
pub struct LiveServerTestCase {
/// Shared state, including recorded setting overrides.
pub base: SimpleTestCase,
/// Port appended to `TEST_SERVER_HOST` by `live_server_url`.
pub port: u16,
}
impl LiveServerTestCase {
    /// Equivalent to `LiveServerTestCase::default()` (no overrides, port 8081).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Base URL of the live server: `TEST_SERVER_HOST`, a colon, then `self.port`.
    #[must_use]
    pub fn live_server_url(&self) -> String {
        let port = self.port;
        format!("{TEST_SERVER_HOST}:{port}")
    }

    /// Appends `path` to [`LiveServerTestCase::live_server_url`], first adding a
    /// leading `/` to `path` if it has none.
    #[must_use]
    pub fn absolute_url(&self, path: &str) -> String {
        let mut url = self.live_server_url();
        url.push_str(&normalize_live_server_path(path));
        url
    }
}
// Attach the shared request/client/runner helpers to LiveServerTestCase.
impl_case_helpers!(LiveServerTestCase);
/// A live-server case with no setting overrides, listening address port 8081.
impl Default for LiveServerTestCase {
    fn default() -> Self {
        const DEFAULT_PORT: u16 = 8081;
        Self {
            port: DEFAULT_PORT,
            base: SimpleTestCase::new(),
        }
    }
}
/// Returns `path` guaranteed to begin with `/`.
///
/// A path that already starts with `/` comes back unchanged (including `"//x"`).
/// Any other path, including the empty string, gets a single `/` prepended.
#[must_use]
fn normalize_live_server_path(path: &str) -> String {
    // Strip at most one leading slash, then add exactly one back.
    let rest = path.strip_prefix('/').unwrap_or(path);
    format!("/{rest}")
}
/// Asserts that `response` is a redirect whose `Location` header equals `expected_url`.
///
/// # Panics
///
/// Panics if any of the following hold:
/// - the status code is not a 3xx redirection;
/// - the response has no `Location` header;
/// - the `Location` header is not visible ASCII (so it cannot be read as `&str`);
/// - the `Location` value differs from `expected_url`.
///
/// A missing or unreadable header used to be treated as an empty location. That made
/// failures read `got ""` and let `expected_url == ""` pass without any header.
/// `#[track_caller]` makes each panic point at the calling test, not at this helper.
#[track_caller]
pub fn assert_redirects(response: &TestResponse, expected_url: &str) {
    assert!(
        response.status_code.is_redirection(),
        "expected redirect status for {}, got {}",
        response.url,
        response.status_code
    );
    // `match` (not closures) keeps the panics inside this fn, so `#[track_caller]` applies.
    let actual_url = match response.headers.get(LOCATION) {
        Some(value) => match value.to_str() {
            Ok(url) => url,
            Err(_) => panic!(
                "expected redirect location {expected_url:?} for {}, but the Location header is not visible ASCII: {value:?}",
                response.url
            ),
        },
        None => panic!(
            "expected redirect location {expected_url:?} for {}, but the response has no Location header",
            response.url
        ),
    };
    assert_eq!(
        actual_url, expected_url,
        "expected redirect location {expected_url:?}, got {actual_url:?}"
    );
}
/// Asserts that the body of `response` contains `text` as a substring.
///
/// # Panics
///
/// Panics if `text` does not occur in `response.content`. The message includes the
/// response URL and the full body. `#[track_caller]` reports the failure at the
/// calling test rather than at this helper.
#[track_caller]
pub fn assert_contains(response: &TestResponse, text: &str) {
    assert!(
        response.content.contains(text),
        "expected response body for {} to contain {text:?}, got {:?}",
        response.url,
        response.content
    );
}
/// Asserts that the body of `response` does not contain `text` anywhere.
///
/// # Panics
///
/// Panics if `text` occurs in `response.content`. The message includes the response
/// URL and the full body. `#[track_caller]` reports the failure at the calling test
/// rather than at this helper.
#[track_caller]
pub fn assert_not_contains(response: &TestResponse, text: &str) {
    assert!(
        !response.content.contains(text),
        "expected response body for {} to not contain {text:?}, got {:?}",
        response.url,
        response.content
    );
}
/// Asserts that `response` has exactly the `expected` status code.
///
/// # Panics
///
/// Panics if `response.status_code != expected`. The message names the response URL.
/// `#[track_caller]` reports the failure at the calling test rather than at this helper.
#[track_caller]
pub fn assert_status_code(response: &TestResponse, expected: StatusCode) {
    assert_eq!(
        response.status_code, expected,
        "expected status {} for {}, got {}",
        expected, response.url, response.status_code
    );
}
// Unit tests for the test-case types, URL helpers, and assertion helpers in this module.
// The assertion helpers signal failure by panicking, so failure paths are checked with `catch_unwind`.
#[cfg(test)]
mod tests {
use super::*;
use axum::{
Router,
http::{HeaderMap, HeaderValue, Method},
response::Redirect,
routing::get,
};
use std::panic::catch_unwind;
// Builds a `TestResponse` with the given status, body, and URL, and an empty header map.
fn response_with(status: StatusCode, content: &str, url: &str) -> TestResponse {
TestResponse {
status_code: status,
content: content.to_string(),
headers: HeaderMap::new(),
url: url.to_string(),
}
}
// Router with a plain route (`/hello/`) and a redirect from `/source/` to `/target/`.
fn test_router() -> Router {
Router::new()
.route("/hello/", get(|| async { "hello from router" }))
.route("/source/", get(|| async { Redirect::to("/target/") }))
.route("/target/", get(|| async { "redirect target" }))
}
#[test]
fn simple_test_case_starts_with_no_overrides() {
let case = SimpleTestCase::new();
assert!(case.settings_overrides.is_empty());
}
#[test]
fn simple_test_case_records_setting_overrides() {
let mut case = SimpleTestCase::new();
case.override_settings("DEBUG", "true");
assert_eq!(
case.settings_overrides.get("DEBUG"),
Some(&"true".to_string())
);
}
#[test]
fn default_test_case_wraps_simple_test_case() {
let case = TestCase::default();
assert!(case.base.settings_overrides.is_empty());
}
#[test]
fn live_server_test_case_uses_default_port() {
let case = LiveServerTestCase::default();
assert_eq!(case.port, 8081);
assert!(case.base.settings_overrides.is_empty());
}
#[test]
fn test_assert_redirects_success() {
let mut response = response_with(StatusCode::FOUND, "", "/redirect/");
response
.headers
.insert(LOCATION, HeaderValue::from_static("/target/"));
assert_redirects(&response, "/target/");
}
#[test]
fn test_assert_redirects_wrong_url_panics() {
let mut response = response_with(StatusCode::FOUND, "", "/redirect/");
response
.headers
.insert(LOCATION, HeaderValue::from_static("/target/"));
let result = catch_unwind(|| assert_redirects(&response, "/other/"));
assert!(result.is_err());
}
#[test]
fn test_assert_contains_found() {
let response = response_with(StatusCode::OK, "hello rust", "/hello/");
assert_contains(&response, "rust");
}
#[test]
fn test_assert_contains_missing_panics() {
let response = response_with(StatusCode::OK, "hello rust", "/hello/");
let result = catch_unwind(|| assert_contains(&response, "django"));
assert!(result.is_err());
}
#[test]
fn test_assert_not_contains_success() {
let response = response_with(StatusCode::OK, "hello rust", "/hello/");
assert_not_contains(&response, "django");
}
#[test]
fn assert_not_contains_panics_when_text_is_present() {
let response = response_with(StatusCode::OK, "hello rust", "/hello/");
let result = catch_unwind(|| assert_not_contains(&response, "rust"));
assert!(result.is_err());
}
#[test]
fn test_assert_status_code_match() {
let response = response_with(StatusCode::CREATED, "", "/items/");
assert_status_code(&response, StatusCode::CREATED);
}
#[test]
fn test_assert_status_code_mismatch_panics() {
let response = response_with(StatusCode::OK, "", "/items/");
let result = catch_unwind(|| assert_status_code(&response, StatusCode::NOT_FOUND));
assert!(result.is_err());
}
// The next four tests exercise helpers re-exported from `super::utils`.
#[test]
fn assert_template_used_accepts_plain_template_marker() {
let response = response_with(StatusCode::OK, "rendered by app/detail.html", "/items/1/");
assert_template_used(&response, "app/detail.html");
}
#[test]
fn assert_template_used_panics_when_marker_is_missing() {
let response = response_with(StatusCode::OK, "plain body", "/items/1/");
let result = catch_unwind(|| assert_template_used(&response, "app/detail.html"));
assert!(result.is_err());
}
#[test]
fn assert_form_error_requires_field_and_message() {
let response = response_with(StatusCode::OK, "email: This field is required.", "/signup/");
assert_form_error(&response, "email", "This field is required.");
}
#[test]
fn assert_form_error_panics_when_message_is_missing() {
let response = response_with(StatusCode::OK, "email", "/signup/");
let result = catch_unwind(|| assert_form_error(&response, "email", "required"));
assert!(result.is_err());
}
#[test]
fn simple_test_case_builds_get_requests() {
let request = SimpleTestCase::new().get("/hello/");
assert_eq!(request.method, Method::GET);
assert_eq!(request.path, "/hello/");
}
// With failfast enabled, the assertions expect the runner to stop after the first
// failing test, so the panicking second closure must never be invoked.
#[test]
fn simple_test_case_runs_tests_with_failfast_runner() {
let case = SimpleTestCase::new();
let runner = TestRunner::new().with_failfast();
let result = case.run_tests_with(
&runner,
vec![
("first", Box::new(|| false)),
(
"second",
Box::new(|| panic!("failfast should stop before this test")),
),
],
);
assert_eq!(result.total(), 1);
assert_eq!(result.failed(), 1);
}
// The redirect-following client is expected to land on `/target/` and report that as the URL.
#[test]
fn test_case_can_follow_redirects() {
let mut client = TestCase::new().client_following_redirects(test_router());
let response = client.get("/source/");
assert_eq!(response.status_code, StatusCode::OK);
assert_eq!(response.url, "/target/");
assert!(response.content.contains("redirect target"));
}
#[test]
fn transaction_test_case_reports_transactions_are_unavailable() {
assert!(!TransactionTestCase::new().uses_database_transactions());
}
// Passes an absolute live-server URL to the in-process test client. The assertions expect
// the client to route it by path (`/hello/`).
#[test]
fn live_server_test_case_builds_absolute_urls_for_test_client() {
let case = LiveServerTestCase::new();
let mut client = case.client(test_router());
let response = client.get(&case.absolute_url("hello/"));
assert_eq!(
case.live_server_url(),
format!("{TEST_SERVER_HOST}:{}", case.port)
);
assert_eq!(response.status_code, StatusCode::OK);
assert!(response.content.contains("hello from router"));
}
}