use std::fmt::Write as FmtWrite;
use crate::escape::rust_raw_string;
use crate::fixture::{CorsConfig, Fixture, StaticFilesConfig};
/// How a generated test issues its request against `axum_test::TestServer`.
enum ServerCall<'a> {
    /// Emit a dedicated helper call such as `server.get(path)`; holds the lowercase method name.
    Shorthand(&'a str),
    /// Emit `server.method(axum::http::Method::X, path)`; holds the uppercase constant name.
    AxumMethod(&'a str),
}
/// How a generated test registers the fixture handler on the application router.
enum RouteRegistration<'a> {
    /// Emit the crate's shorthand builder, e.g. `{dep_name}::get(route)`; holds the lowercase method name.
    Shorthand(&'a str),
    /// Emit `{dep_name}::RouteBuilder::new({dep_name}::Method::X, route)`; holds the `Method` variant name.
    Explicit(&'a str),
}
pub fn render_http_test_function(out: &mut String, fixture: &Fixture, dep_name: &str) {
let http = match &fixture.http {
Some(h) => h,
None => return,
};
let fn_name = crate::escape::sanitize_ident(&fixture.id);
let description = &fixture.description;
let route = &http.handler.route;
let route_reg = match http.handler.method.to_lowercase().as_str() {
"get" => RouteRegistration::Shorthand("get"),
"post" => RouteRegistration::Shorthand("post"),
"put" => RouteRegistration::Shorthand("put"),
"patch" => RouteRegistration::Shorthand("patch"),
"delete" => RouteRegistration::Shorthand("delete"),
"head" => RouteRegistration::Explicit("Head"),
"options" => RouteRegistration::Explicit("Options"),
"trace" => RouteRegistration::Explicit("Trace"),
_ => RouteRegistration::Shorthand("get"),
};
let server_call = match http.request.method.to_uppercase().as_str() {
"GET" => ServerCall::Shorthand("get"),
"POST" => ServerCall::Shorthand("post"),
"PUT" => ServerCall::Shorthand("put"),
"PATCH" => ServerCall::Shorthand("patch"),
"DELETE" => ServerCall::Shorthand("delete"),
"HEAD" => ServerCall::AxumMethod("HEAD"),
"OPTIONS" => ServerCall::AxumMethod("OPTIONS"),
"TRACE" => ServerCall::AxumMethod("TRACE"),
_ => ServerCall::Shorthand("get"),
};
let req_path = &http.request.path;
let status = http.expected_response.status_code;
let body_str = match &http.expected_response.body {
Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
None => String::new(),
};
let body_literal = rust_raw_string(&body_str);
let req_body_str = match &http.request.body {
Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
None => String::new(),
};
let has_req_body = !req_body_str.is_empty();
let middleware = http.handler.middleware.as_ref();
let cors_cfg: Option<&CorsConfig> = middleware.and_then(|m| m.cors.as_ref());
let static_files_cfgs: Option<&Vec<StaticFilesConfig>> = middleware.and_then(|m| m.static_files.as_ref());
let has_static_files = static_files_cfgs.is_some_and(|v| !v.is_empty());
let _ = writeln!(out, "#[tokio::test]");
let _ = writeln!(out, "async fn test_{fn_name}() {{");
let _ = writeln!(out, " // {description}");
if has_static_files {
render_static_files_test(out, fixture, static_files_cfgs.unwrap(), &server_call, req_path, status);
return;
}
let _ = writeln!(out, " let expected_body = {body_literal}.to_string();");
let _ = writeln!(out, " let mut app = {dep_name}::App::new();");
match &route_reg {
RouteRegistration::Shorthand(method) => {
let _ = writeln!(
out,
" app.route({dep_name}::{method}({route:?}), move |_ctx: {dep_name}::RequestContext| {{"
);
}
RouteRegistration::Explicit(variant) => {
let _ = writeln!(
out,
" app.route({dep_name}::RouteBuilder::new({dep_name}::Method::{variant}, {route:?}), move |_ctx: {dep_name}::RequestContext| {{"
);
}
}
let _ = writeln!(out, " let body = expected_body.clone();");
let _ = writeln!(out, " async move {{");
let _ = writeln!(out, " Ok(axum::http::Response::builder()");
let _ = writeln!(out, " .status({status}u16)");
let _ = writeln!(out, " .header(\"content-type\", \"application/json\")");
let _ = writeln!(out, " .body(axum::body::Body::from(body))");
let _ = writeln!(out, " .unwrap())");
let _ = writeln!(out, " }}");
let _ = writeln!(out, " }}).unwrap();");
let _ = writeln!(out, " let router = app.into_router().unwrap();");
if let Some(cors) = cors_cfg {
render_cors_layer(out, cors);
}
let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
match &server_call {
ServerCall::Shorthand(method) => {
let _ = writeln!(out, " let response = server.{method}({req_path:?})");
}
ServerCall::AxumMethod(method) => {
let _ = writeln!(
out,
" let response = server.method(axum::http::Method::{method}, {req_path:?})"
);
}
}
for (name, value) in &http.request.headers {
let n = rust_raw_string(name);
let v = rust_raw_string(value);
let _ = writeln!(out, " .add_header({n}, {v})");
}
if has_req_body {
let req_body_literal = rust_raw_string(&req_body_str);
let _ = writeln!(
out,
" .bytes(bytes::Bytes::copy_from_slice({req_body_literal}.as_bytes()))"
);
}
let _ = writeln!(out, " .await;");
if cors_cfg.is_some() && (200..300).contains(&status) {
let _ = writeln!(
out,
" assert!(response.status_code().is_success(), \"expected CORS success status, got {{}}\", response.status_code());"
);
} else {
let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
}
let _ = writeln!(out, "}}");
}
/// Appends statements that wrap the in-scope `router` binding in a `tower_http` `CorsLayer`.
///
/// The emitted code assumes a `router` binding already exists. It shadows that binding
/// with `let router = router.layer(cors_layer);`.
///
/// Each empty list in `cors` becomes `tower_http::cors::Any`. This applies to
/// `allow_origins`, `allow_methods` and `allow_headers`.
///
/// Fixture-supplied strings (origins, header names, extension methods) are emitted with
/// `Debug` formatting. That turns them into properly escaped Rust string literals, so
/// quotes or backslashes cannot break the generated code.
pub fn render_cors_layer(out: &mut String, cors: &CorsConfig) {
    let _ = writeln!(
        out,
        " // Apply CorsLayer from tower-http based on fixture CORS config."
    );
    let _ = writeln!(out, " use tower_http::cors::CorsLayer;");
    let _ = writeln!(out, " use axum::http::{{HeaderName, HeaderValue, Method}};");
    let _ = writeln!(out, " let cors_layer = CorsLayer::new()");

    if cors.allow_origins.is_empty() {
        let _ = writeln!(out, " .allow_origin(tower_http::cors::Any)");
    } else {
        let _ = writeln!(out, " .allow_origin([");
        for origin in &cors.allow_origins {
            let _ = writeln!(out, " {origin:?}.parse::<HeaderValue>().unwrap(),");
        }
        let _ = writeln!(out, " ])");
    }

    if cors.allow_methods.is_empty() {
        let _ = writeln!(out, " .allow_methods(tower_http::cors::Any)");
    } else {
        let methods: Vec<String> = cors
            .allow_methods
            .iter()
            .map(|m| {
                let upper = m.to_uppercase();
                match upper.as_str() {
                    "GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" | "CONNECT"
                    | "PATCH" | "TRACE" => format!("Method::{upper}"),
                    // Extension methods have no associated constant on `Method`.
                    _ => format!("Method::from_bytes({upper:?}.as_bytes()).unwrap()"),
                }
            })
            .collect();
        let _ = writeln!(out, " .allow_methods([{}])", methods.join(", "));
    }

    if cors.allow_headers.is_empty() {
        let _ = writeln!(out, " .allow_headers(tower_http::cors::Any)");
    } else {
        let headers: Vec<String> = cors
            .allow_headers
            .iter()
            .map(|h| {
                // `HeaderName::from_static` panics on uppercase input, so normalise first.
                let lower = h.to_lowercase();
                match lower.as_str() {
                    "content-type" => "axum::http::header::CONTENT_TYPE".to_string(),
                    "authorization" => "axum::http::header::AUTHORIZATION".to_string(),
                    "accept" => "axum::http::header::ACCEPT".to_string(),
                    _ => format!("HeaderName::from_static({lower:?})"),
                }
            })
            .collect();
        let _ = writeln!(out, " .allow_headers([{}])", headers.join(", "));
    }

    // The builder chain is terminated here, with or without a trailing `max_age`.
    if let Some(secs) = cors.max_age {
        let _ = writeln!(out, " .max_age(std::time::Duration::from_secs({secs}));");
    } else {
        let _ = writeln!(out, " ;");
    }
    let _ = writeln!(out, " let router = router.layer(cors_layer);");
}
/// Emits the body of a test whose handler serves static files via `tower_http`'s `ServeDir`.
///
/// `render_http_test_function` calls this after emitting the test attribute, signature
/// and description comment. This function emits everything that follows, including the
/// closing brace. The generated body:
/// - writes every configured file into a fresh `tempfile` directory;
/// - mounts each config with `nest_service` at its `route_prefix`;
/// - sends the fixture request (path and headers only, no body) through
///   `axum_test::TestServer`;
/// - asserts on the response status.
///
/// # Panics
///
/// Panics if `fixture.http` is `None`; callers only reach this function for HTTP fixtures.
fn render_static_files_test(
    out: &mut String,
    fixture: &Fixture,
    cfgs: &[StaticFilesConfig],
    server_call: &ServerCall<'_>,
    req_path: &str,
    status: u16,
) {
    let http = fixture
        .http
        .as_ref()
        .expect("render_static_files_test requires a fixture with an http section");
    let _ = writeln!(out, " use tower_http::services::ServeDir;");
    let _ = writeln!(out, " use axum::Router;");
    let _ = writeln!(out, " let tmp_dir = tempfile::tempdir().expect(\"tmp dir\");");
    let _ = writeln!(out, " let mut router = Router::new();");
    for cfg in cfgs {
        for file in &cfg.files {
            // Normalise separators and strip leading slashes. `Path::join` with an
            // absolute path discards the base, so without this, "/x.txt" would be written
            // outside `tmp_dir`.
            let normalized = file.path.replace('\\', "/");
            let file_path = normalized.trim_start_matches('/');
            let content = rust_raw_string(&file.content);
            if let Some((parent, _)) = file_path.rsplit_once('/') {
                let _ = writeln!(
                    out,
                    " std::fs::create_dir_all(tmp_dir.path().join({parent:?})).unwrap();"
                );
            }
            // `{:?}` escapes quotes/backslashes into a valid Rust string literal.
            let _ = writeln!(
                out,
                " std::fs::write(tmp_dir.path().join({file_path:?}), {content}).unwrap();"
            );
        }
        let prefix = &cfg.route_prefix;
        let serve_dir_expr = if cfg.index_file {
            "ServeDir::new(tmp_dir.path()).append_index_html_on_directories(true)"
        } else {
            "ServeDir::new(tmp_dir.path())"
        };
        let _ = writeln!(out, " router = router.nest_service({prefix:?}, {serve_dir_expr});");
    }
    let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
    match server_call {
        ServerCall::Shorthand(method) => {
            let _ = writeln!(out, " let response = server.{method}({req_path:?})");
        }
        ServerCall::AxumMethod(method) => {
            let _ = writeln!(
                out,
                " let response = server.method(axum::http::Method::{method}, {req_path:?})"
            );
        }
    }
    for (name, value) in &http.request.headers {
        let n = rust_raw_string(name);
        let v = rust_raw_string(value);
        let _ = writeln!(out, " .add_header({n}, {v})");
    }
    let _ = writeln!(out, " .await;");
    let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
    let _ = writeln!(out, "}}");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default (empty) CORS policy must fall back to `Any` for origins, methods and headers.
    #[test]
    fn render_cors_layer_empty_policy_uses_any() {
        let mut rendered = String::new();
        render_cors_layer(&mut rendered, &CorsConfig::default());
        let needles = [
            "allow_origin(tower_http::cors::Any)",
            "allow_methods(tower_http::cors::Any)",
            "allow_headers(tower_http::cors::Any)",
        ];
        for needle in needles {
            assert!(rendered.contains(needle), "missing `{needle}` in:\n{rendered}");
        }
    }
}