1use std::fmt::Write as FmtWrite;
4
5use crate::escape::rust_raw_string;
6use crate::fixture::{CorsConfig, Fixture, StaticFilesConfig};
7
/// How the generated test issues its request through `axum_test::TestServer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerCall<'a> {
    /// A dedicated `TestServer` method, emitted as `server.<name>(path)`; holds the
    /// lowercase method name (e.g. `"get"`).
    Shorthand(&'a str),
    /// Emitted as `server.method(axum::http::Method::<NAME>, path)` for verbs without
    /// a shorthand; holds the uppercase `axum::http::Method` constant name.
    AxumMethod(&'a str),
}
15
/// How the generated test registers its handler route on the `App` under test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RouteRegistration<'a> {
    /// Emitted as `<dep>::<name>(route)`; holds the lowercase helper name (e.g. `"post"`).
    Shorthand(&'a str),
    /// Emitted as `<dep>::RouteBuilder::new(<dep>::Method::<Variant>, route)`; holds the
    /// `Method` variant name (e.g. `"Head"`).
    Explicit(&'a str),
}
23
24pub fn render_http_test_function(out: &mut String, fixture: &Fixture, dep_name: &str) {
30 let http = match &fixture.http {
31 Some(h) => h,
32 None => return,
33 };
34
35 let fn_name = crate::escape::sanitize_ident(&fixture.id);
36 let description = &fixture.description;
37
38 let route = &http.handler.route;
39
40 let route_reg = match http.handler.method.to_lowercase().as_str() {
43 "get" => RouteRegistration::Shorthand("get"),
44 "post" => RouteRegistration::Shorthand("post"),
45 "put" => RouteRegistration::Shorthand("put"),
46 "patch" => RouteRegistration::Shorthand("patch"),
47 "delete" => RouteRegistration::Shorthand("delete"),
48 "head" => RouteRegistration::Explicit("Head"),
49 "options" => RouteRegistration::Explicit("Options"),
50 "trace" => RouteRegistration::Explicit("Trace"),
51 _ => RouteRegistration::Shorthand("get"),
52 };
53
54 let server_call = match http.request.method.to_uppercase().as_str() {
57 "GET" => ServerCall::Shorthand("get"),
58 "POST" => ServerCall::Shorthand("post"),
59 "PUT" => ServerCall::Shorthand("put"),
60 "PATCH" => ServerCall::Shorthand("patch"),
61 "DELETE" => ServerCall::Shorthand("delete"),
62 "HEAD" => ServerCall::AxumMethod("HEAD"),
63 "OPTIONS" => ServerCall::AxumMethod("OPTIONS"),
64 "TRACE" => ServerCall::AxumMethod("TRACE"),
65 _ => ServerCall::Shorthand("get"),
66 };
67
68 let req_path = &http.request.path;
69 let status = http.expected_response.status_code;
70
71 let body_str = match &http.expected_response.body {
73 Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
74 None => String::new(),
75 };
76 let body_literal = rust_raw_string(&body_str);
77
78 let req_body_str = match &http.request.body {
80 Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
81 None => String::new(),
82 };
83 let has_req_body = !req_body_str.is_empty();
84
85 let middleware = http.handler.middleware.as_ref();
87 let cors_cfg: Option<&CorsConfig> = middleware.and_then(|m| m.cors.as_ref());
88 let static_files_cfgs: Option<&Vec<StaticFilesConfig>> = middleware.and_then(|m| m.static_files.as_ref());
89 let has_static_files = static_files_cfgs.is_some_and(|v| !v.is_empty());
90
91 let _ = writeln!(out, "#[tokio::test]");
92 let _ = writeln!(out, "async fn test_{fn_name}() {{");
93 let _ = writeln!(out, " // {description}");
94
95 if has_static_files {
97 render_static_files_test(out, fixture, static_files_cfgs.unwrap(), &server_call, req_path, status);
98 return;
99 }
100
101 let _ = writeln!(out, " let expected_body = {body_literal}.to_string();");
103 let _ = writeln!(out, " let mut app = {dep_name}::App::new();");
104
105 match &route_reg {
107 RouteRegistration::Shorthand(method) => {
108 let _ = writeln!(
109 out,
110 " app.route({dep_name}::{method}({route:?}), move |_ctx: {dep_name}::RequestContext| {{"
111 );
112 }
113 RouteRegistration::Explicit(variant) => {
114 let _ = writeln!(
115 out,
116 " app.route({dep_name}::RouteBuilder::new({dep_name}::Method::{variant}, {route:?}), move |_ctx: {dep_name}::RequestContext| {{"
117 );
118 }
119 }
120 let _ = writeln!(out, " let body = expected_body.clone();");
121 let _ = writeln!(out, " async move {{");
122 let _ = writeln!(out, " Ok(axum::http::Response::builder()");
123 let _ = writeln!(out, " .status({status}u16)");
124 let _ = writeln!(out, " .header(\"content-type\", \"application/json\")");
125 let _ = writeln!(out, " .body(axum::body::Body::from(body))");
126 let _ = writeln!(out, " .unwrap())");
127 let _ = writeln!(out, " }}");
128 let _ = writeln!(out, " }}).unwrap();");
129
130 let _ = writeln!(out, " let router = app.into_router().unwrap();");
132 if let Some(cors) = cors_cfg {
133 render_cors_layer(out, cors);
134 }
135 let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
136
137 match &server_call {
139 ServerCall::Shorthand(method) => {
140 let _ = writeln!(out, " let response = server.{method}({req_path:?})");
141 }
142 ServerCall::AxumMethod(method) => {
143 let _ = writeln!(
144 out,
145 " let response = server.method(axum::http::Method::{method}, {req_path:?})"
146 );
147 }
148 }
149
150 for (name, value) in &http.request.headers {
152 let n = rust_raw_string(name);
153 let v = rust_raw_string(value);
154 let _ = writeln!(out, " .add_header({n}, {v})");
155 }
156
157 if has_req_body {
159 let req_body_literal = rust_raw_string(&req_body_str);
160 let _ = writeln!(
161 out,
162 " .bytes(bytes::Bytes::copy_from_slice({req_body_literal}.as_bytes()))"
163 );
164 }
165
166 let _ = writeln!(out, " .await;");
167
168 if cors_cfg.is_some() && (200..300).contains(&status) {
172 let _ = writeln!(
173 out,
174 " assert!(response.status_code().is_success(), \"expected CORS success status, got {{}}\", response.status_code());"
175 );
176 } else {
177 let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
178 }
179
180 let _ = writeln!(out, "}}");
181}
182
183pub fn render_cors_layer(out: &mut String, cors: &CorsConfig) {
188 let needs_header_value = !cors.allow_origins.is_empty();
192 let needs_method = !cors.allow_methods.is_empty();
193 let needs_header_name = !cors.allow_headers.is_empty()
194 && cors
195 .allow_headers
196 .iter()
197 .any(|h| !matches!(h.to_lowercase().as_str(), "content-type" | "authorization" | "accept"));
198
199 let _ = writeln!(
200 out,
201 " // Apply CorsLayer from tower-http based on fixture CORS config."
202 );
203 let _ = writeln!(out, " use tower_http::cors::CorsLayer;");
204 let mut imports: Vec<&'static str> = Vec::new();
205 if needs_header_name {
206 imports.push("HeaderName");
207 }
208 if needs_header_value {
209 imports.push("HeaderValue");
210 }
211 if needs_method {
212 imports.push("Method");
213 }
214 match imports.len() {
215 0 => {}
216 1 => {
217 let _ = writeln!(out, " use axum::http::{};", imports[0]);
218 }
219 _ => {
220 let _ = writeln!(out, " use axum::http::{{{}}};", imports.join(", "));
221 }
222 }
223 let _ = writeln!(out, " let cors_layer = CorsLayer::new()");
224
225 if cors.allow_origins.is_empty() {
227 let _ = writeln!(out, " .allow_origin(tower_http::cors::Any)");
228 } else {
229 let _ = writeln!(out, " .allow_origin([");
230 for origin in &cors.allow_origins {
231 let _ = writeln!(out, " \"{origin}\".parse::<HeaderValue>().unwrap(),");
232 }
233 let _ = writeln!(out, " ])");
234 }
235
236 if cors.allow_methods.is_empty() {
238 let _ = writeln!(out, " .allow_methods(tower_http::cors::Any)");
239 } else {
240 let methods: Vec<String> = cors
241 .allow_methods
242 .iter()
243 .map(|m| format!("Method::{}", m.to_uppercase()))
244 .collect();
245 let _ = writeln!(out, " .allow_methods([{}])", methods.join(", "));
246 }
247
248 if cors.allow_headers.is_empty() {
250 let _ = writeln!(out, " .allow_headers(tower_http::cors::Any)");
251 } else {
252 let headers: Vec<String> = cors
253 .allow_headers
254 .iter()
255 .map(|h| {
256 let lower = h.to_lowercase();
257 match lower.as_str() {
258 "content-type" => "axum::http::header::CONTENT_TYPE".to_string(),
259 "authorization" => "axum::http::header::AUTHORIZATION".to_string(),
260 "accept" => "axum::http::header::ACCEPT".to_string(),
261 _ => format!("HeaderName::from_static(\"{lower}\")"),
262 }
263 })
264 .collect();
265 let _ = writeln!(out, " .allow_headers([{}])", headers.join(", "));
266 }
267
268 if let Some(secs) = cors.max_age {
270 let _ = writeln!(out, " .max_age(std::time::Duration::from_secs({secs}));");
271 } else {
272 let _ = writeln!(out, " ;");
273 }
274
275 let _ = writeln!(out, " let router = router.layer(cors_layer);");
276}
277
278fn render_static_files_test(
283 out: &mut String,
284 fixture: &Fixture,
285 cfgs: &[StaticFilesConfig],
286 server_call: &ServerCall<'_>,
287 req_path: &str,
288 status: u16,
289) {
290 let http = fixture.http.as_ref().unwrap();
291
292 let _ = writeln!(out, " use tower_http::services::ServeDir;");
293 let _ = writeln!(out, " use axum::Router;");
294 let _ = writeln!(out, " let tmp_dir = tempfile::tempdir().expect(\"tmp dir\");");
295
296 let _ = writeln!(out, " let mut router = Router::new();");
298 for cfg in cfgs {
299 for file in &cfg.files {
300 let file_path = file.path.replace('\\', "/");
301 let content = rust_raw_string(&file.content);
302 if file_path.contains('/') {
303 let parent: String = file_path.rsplitn(2, '/').last().unwrap_or("").to_string();
304 let _ = writeln!(
305 out,
306 " std::fs::create_dir_all(tmp_dir.path().join(\"{parent}\")).unwrap();"
307 );
308 }
309 let _ = writeln!(
310 out,
311 " std::fs::write(tmp_dir.path().join(\"{file_path}\"), {content}).unwrap();"
312 );
313 }
314 let prefix = &cfg.route_prefix;
315 let serve_dir_expr = if cfg.index_file {
316 "ServeDir::new(tmp_dir.path()).append_index_html_on_directories(true)".to_string()
317 } else {
318 "ServeDir::new(tmp_dir.path())".to_string()
319 };
320 let _ = writeln!(out, " router = router.nest_service({prefix:?}, {serve_dir_expr});");
321 }
322
323 let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
324
325 match server_call {
327 ServerCall::Shorthand(method) => {
328 let _ = writeln!(out, " let response = server.{method}({req_path:?})");
329 }
330 ServerCall::AxumMethod(method) => {
331 let _ = writeln!(
332 out,
333 " let response = server.method(axum::http::Method::{method}, {req_path:?})"
334 );
335 }
336 }
337
338 for (name, value) in &http.request.headers {
340 let n = rust_raw_string(name);
341 let v = rust_raw_string(value);
342 let _ = writeln!(out, " .add_header({n}, {v})");
343 }
344
345 let _ = writeln!(out, " .await;");
346 let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
347 let _ = writeln!(out, "}}");
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `cors` into a fresh buffer and returns the emitted code.
    fn render(cors: &CorsConfig) -> String {
        let mut out = String::new();
        render_cors_layer(&mut out, cors);
        out
    }

    #[test]
    fn render_cors_layer_empty_policy_uses_any() {
        let out = render(&CorsConfig::default());
        assert!(out.contains("allow_origin(tower_http::cors::Any)"));
        assert!(out.contains("allow_methods(tower_http::cors::Any)"));
        assert!(out.contains("allow_headers(tower_http::cors::Any)"));
    }

    /// An all-`Any` policy references no `axum::http` names, so importing any would
    /// trigger `unused_imports` in the generated test.
    #[test]
    fn render_cors_layer_empty_policy_emits_no_axum_http_imports() {
        let out = render(&CorsConfig::default());
        assert!(!out.contains("use axum::http::"));
    }

    /// Explicit origins are parsed as `HeaderValue`, so it must be imported.
    #[test]
    fn render_cors_layer_with_origin_imports_header_value() {
        let cors = CorsConfig {
            allow_origins: vec!["https://example.com".to_string()],
            ..CorsConfig::default()
        };
        assert!(render(&cors).contains("use axum::http::HeaderValue;"));
    }

    /// Explicit methods are emitted as `Method::...`, so `Method` must be imported.
    #[test]
    fn render_cors_layer_with_method_imports_method() {
        let cors = CorsConfig {
            allow_methods: vec!["GET".to_string()],
            ..CorsConfig::default()
        };
        assert!(render(&cors).contains("use axum::http::Method;"));
    }

    /// Well-known headers use `axum::http::header` constants, so `HeaderName` is
    /// neither imported nor referenced.
    #[test]
    fn render_cors_layer_with_only_prelude_headers_omits_header_name() {
        let cors = CorsConfig {
            allow_headers: vec!["content-type".to_string(), "Authorization".to_string()],
            ..CorsConfig::default()
        };
        assert!(!render(&cors).contains("HeaderName"));
    }

    /// A custom header needs `HeaderName::from_static`, so `HeaderName` is imported.
    #[test]
    fn render_cors_layer_with_custom_header_imports_header_name() {
        let cors = CorsConfig {
            allow_headers: vec!["X-Custom".to_string()],
            ..CorsConfig::default()
        };
        let out = render(&cors);
        assert!(out.contains("HeaderName"));
        assert!(out.contains("use axum::http::HeaderName;"));
    }

    /// When several `axum::http` names are needed they are grouped in one `use`.
    #[test]
    fn render_cors_layer_with_all_kinds_groups_imports() {
        let cors = CorsConfig {
            allow_origins: vec!["https://example.com".to_string()],
            allow_methods: vec!["GET".to_string()],
            allow_headers: vec!["X-Custom".to_string()],
            ..CorsConfig::default()
        };
        assert!(render(&cors).contains("use axum::http::{HeaderName, HeaderValue, Method};"));
    }

    /// `max_age` terminates the builder chain and the router is then layered.
    #[test]
    fn render_cors_layer_with_max_age_terminates_builder_with_max_age() {
        let cors = CorsConfig {
            max_age: Some(600),
            ..CorsConfig::default()
        };
        let out = render(&cors);
        assert!(out.contains(".max_age(std::time::Duration::from_secs(600));"));
        assert!(out.contains("let router = router.layer(cors_layer);"));
    }

    /// Without `max_age` the builder chain is closed by a bare `;` line.
    #[test]
    fn render_cors_layer_without_max_age_terminates_builder() {
        let out = render(&CorsConfig::default());
        assert!(!out.contains(".max_age("));
        assert!(out.contains("\n ;\n"));
        assert!(out.contains("let router = router.layer(cors_layer);"));
    }
}