1use std::fmt::Write as FmtWrite;
4
5use crate::escape::rust_raw_string;
6use crate::fixture::{CorsConfig, Fixture, StaticFilesConfig};
7
/// How the generated test sends its request through `axum_test::TestServer`.
#[derive(Debug, Clone, Copy)]
enum ServerCall<'a> {
    /// A verb with a dedicated `TestServer` method, emitted as
    /// `server.<name>(path)`; holds that method name.
    Shorthand(&'a str),
    /// A verb emitted as `server.method(axum::http::Method::<NAME>, path)`;
    /// holds the `Method` constant name.
    AxumMethod(&'a str),
}
15
/// How the generated test registers its handler route on the dependency's app.
#[derive(Debug, Clone, Copy)]
enum RouteRegistration<'a> {
    /// Registered via a free function, emitted as `{dep}::<name>(route)`;
    /// holds the function name.
    Shorthand(&'a str),
    /// Registered via
    /// `{dep}::RouteBuilder::new({dep}::Method::<Variant>, route)`;
    /// holds the `Method` variant name.
    Explicit(&'a str),
}
23
24pub fn render_http_test_function(out: &mut String, fixture: &Fixture, dep_name: &str) {
30 let http = match &fixture.http {
31 Some(h) => h,
32 None => return,
33 };
34
35 let fn_name = crate::escape::sanitize_ident(&fixture.id);
36 let description = &fixture.description;
37
38 let route = &http.handler.route;
39
40 let route_reg = match http.handler.method.to_lowercase().as_str() {
44 "get" => RouteRegistration::Shorthand("get"),
45 "post" => RouteRegistration::Shorthand("post"),
46 "put" => RouteRegistration::Shorthand("put"),
47 "patch" => RouteRegistration::Shorthand("patch"),
48 "delete" => RouteRegistration::Shorthand("delete"),
49 "head" => RouteRegistration::Explicit("Head"),
50 "options" => RouteRegistration::Explicit("Options"),
51 "trace" => RouteRegistration::Explicit("Trace"),
52 _ => RouteRegistration::Shorthand("get"),
53 };
54
55 let server_call = match http.request.method.to_uppercase().as_str() {
58 "GET" => ServerCall::Shorthand("get"),
59 "POST" => ServerCall::Shorthand("post"),
60 "PUT" => ServerCall::Shorthand("put"),
61 "PATCH" => ServerCall::Shorthand("patch"),
62 "DELETE" => ServerCall::Shorthand("delete"),
63 "HEAD" => ServerCall::AxumMethod("HEAD"),
64 "OPTIONS" => ServerCall::AxumMethod("OPTIONS"),
65 "TRACE" => ServerCall::AxumMethod("TRACE"),
66 _ => ServerCall::Shorthand("get"),
67 };
68
69 let req_path = &http.request.path;
70 let status = http.expected_response.status_code;
71
72 let body_str = match &http.expected_response.body {
74 Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
75 None => String::new(),
76 };
77 let body_literal = rust_raw_string(&body_str);
78
79 let req_body_str = match &http.request.body {
81 Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
82 None => String::new(),
83 };
84 let has_req_body = !req_body_str.is_empty();
85
86 let middleware = http.handler.middleware.as_ref();
88 let cors_cfg: Option<&CorsConfig> = middleware.and_then(|m| m.cors.as_ref());
89 let static_files_cfgs: Option<&Vec<StaticFilesConfig>> = middleware.and_then(|m| m.static_files.as_ref());
90 let has_static_files = static_files_cfgs.is_some_and(|v| !v.is_empty());
91
92 let _ = writeln!(out, "#[tokio::test]");
93 let _ = writeln!(out, "async fn test_{fn_name}() {{");
94 let _ = writeln!(out, " // {description}");
95
96 if has_static_files {
98 render_static_files_test(out, fixture, static_files_cfgs.unwrap(), &server_call, req_path, status);
99 return;
100 }
101
102 let _ = writeln!(out, " let expected_body = {body_literal}.to_string();");
104 let _ = writeln!(out, " let mut app = {dep_name}::App::new();");
105
106 match &route_reg {
108 RouteRegistration::Shorthand(method) => {
109 let _ = writeln!(
110 out,
111 " app.route({dep_name}::{method}({route:?}), move |_ctx: {dep_name}::RequestContext| {{"
112 );
113 }
114 RouteRegistration::Explicit(variant) => {
115 let _ = writeln!(
116 out,
117 " app.route({dep_name}::RouteBuilder::new({dep_name}::Method::{variant}, {route:?}), move |_ctx: {dep_name}::RequestContext| {{"
118 );
119 }
120 }
121 let _ = writeln!(out, " let body = expected_body.clone();");
122 let _ = writeln!(out, " async move {{");
123 let _ = writeln!(out, " Ok(axum::http::Response::builder()");
124 let _ = writeln!(out, " .status({status}u16)");
125 let _ = writeln!(out, " .header(\"content-type\", \"application/json\")");
126 let _ = writeln!(out, " .body(axum::body::Body::from(body))");
127 let _ = writeln!(out, " .unwrap())");
128 let _ = writeln!(out, " }}");
129 let _ = writeln!(out, " }}).unwrap();");
130
131 let _ = writeln!(out, " let router = app.into_router().unwrap();");
133 if let Some(cors) = cors_cfg {
134 render_cors_layer(out, cors);
135 }
136 let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
137
138 match &server_call {
140 ServerCall::Shorthand(method) => {
141 let _ = writeln!(out, " let response = server.{method}({req_path:?})");
142 }
143 ServerCall::AxumMethod(method) => {
144 let _ = writeln!(
145 out,
146 " let response = server.method(axum::http::Method::{method}, {req_path:?})"
147 );
148 }
149 }
150
151 for (name, value) in &http.request.headers {
153 let n = rust_raw_string(name);
154 let v = rust_raw_string(value);
155 let _ = writeln!(out, " .add_header({n}, {v})");
156 }
157
158 if has_req_body {
160 let req_body_literal = rust_raw_string(&req_body_str);
161 let _ = writeln!(
162 out,
163 " .bytes(bytes::Bytes::copy_from_slice({req_body_literal}.as_bytes()))"
164 );
165 }
166
167 let _ = writeln!(out, " .await;");
168
169 if cors_cfg.is_some() && (200..300).contains(&status) {
173 let _ = writeln!(
174 out,
175 " assert!(response.status_code().is_success(), \"expected CORS success status, got {{}}\", response.status_code());"
176 );
177 } else {
178 let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
179 }
180
181 let _ = writeln!(out, "}}");
182}
183
184pub fn render_cors_layer(out: &mut String, cors: &CorsConfig) {
189 let needs_header_value = !cors.allow_origins.is_empty();
193 let needs_method = !cors.allow_methods.is_empty();
194 let needs_header_name = !cors.allow_headers.is_empty()
195 && cors
196 .allow_headers
197 .iter()
198 .any(|h| !matches!(h.to_lowercase().as_str(), "content-type" | "authorization" | "accept"));
199
200 let _ = writeln!(
201 out,
202 " // Apply CorsLayer from tower-http based on fixture CORS config."
203 );
204 let _ = writeln!(out, " use tower_http::cors::CorsLayer;");
205 let mut imports: Vec<&'static str> = Vec::new();
206 if needs_header_name {
207 imports.push("HeaderName");
208 }
209 if needs_header_value {
210 imports.push("HeaderValue");
211 }
212 if needs_method {
213 imports.push("Method");
214 }
215 match imports.len() {
216 0 => {}
217 1 => {
218 let _ = writeln!(out, " use axum::http::{};", imports[0]);
219 }
220 _ => {
221 let _ = writeln!(out, " use axum::http::{{{}}};", imports.join(", "));
222 }
223 }
224 let _ = writeln!(out, " let cors_layer = CorsLayer::new()");
225
226 if cors.allow_origins.is_empty() {
228 let _ = writeln!(out, " .allow_origin(tower_http::cors::Any)");
229 } else {
230 let _ = writeln!(out, " .allow_origin([");
231 for origin in &cors.allow_origins {
232 let _ = writeln!(out, " \"{origin}\".parse::<HeaderValue>().unwrap(),");
233 }
234 let _ = writeln!(out, " ])");
235 }
236
237 if cors.allow_methods.is_empty() {
239 let _ = writeln!(out, " .allow_methods(tower_http::cors::Any)");
240 } else {
241 let methods: Vec<String> = cors
242 .allow_methods
243 .iter()
244 .map(|m| format!("Method::{}", m.to_uppercase()))
245 .collect();
246 let _ = writeln!(out, " .allow_methods([{}])", methods.join(", "));
247 }
248
249 if cors.allow_headers.is_empty() {
251 let _ = writeln!(out, " .allow_headers(tower_http::cors::Any)");
252 } else {
253 let headers: Vec<String> = cors
254 .allow_headers
255 .iter()
256 .map(|h| {
257 let lower = h.to_lowercase();
258 match lower.as_str() {
259 "content-type" => "axum::http::header::CONTENT_TYPE".to_string(),
260 "authorization" => "axum::http::header::AUTHORIZATION".to_string(),
261 "accept" => "axum::http::header::ACCEPT".to_string(),
262 _ => format!("HeaderName::from_static(\"{lower}\")"),
263 }
264 })
265 .collect();
266 let _ = writeln!(out, " .allow_headers([{}])", headers.join(", "));
267 }
268
269 if let Some(secs) = cors.max_age {
271 let _ = writeln!(out, " .max_age(std::time::Duration::from_secs({secs}));");
272 } else {
273 let _ = writeln!(out, " ;");
274 }
275
276 let _ = writeln!(out, " let router = router.layer(cors_layer);");
277}
278
279fn render_static_files_test(
284 out: &mut String,
285 fixture: &Fixture,
286 cfgs: &[StaticFilesConfig],
287 server_call: &ServerCall<'_>,
288 req_path: &str,
289 status: u16,
290) {
291 let http = fixture.http.as_ref().unwrap();
292
293 let _ = writeln!(out, " use tower_http::services::ServeDir;");
294 let _ = writeln!(out, " use axum::Router;");
295 let _ = writeln!(out, " let tmp_dir = tempfile::tempdir().expect(\"tmp dir\");");
296
297 let _ = writeln!(out, " let mut router = Router::new();");
299 for cfg in cfgs {
300 for file in &cfg.files {
301 let file_path = file.path.replace('\\', "/");
302 let content = rust_raw_string(&file.content);
303 if file_path.contains('/') {
304 let parent: String = file_path.rsplitn(2, '/').last().unwrap_or("").to_string();
305 let _ = writeln!(
306 out,
307 " std::fs::create_dir_all(tmp_dir.path().join(\"{parent}\")).unwrap();"
308 );
309 }
310 let _ = writeln!(
311 out,
312 " std::fs::write(tmp_dir.path().join(\"{file_path}\"), {content}).unwrap();"
313 );
314 }
315 let prefix = &cfg.route_prefix;
316 let serve_dir_expr = if cfg.index_file {
317 "ServeDir::new(tmp_dir.path()).append_index_html_on_directories(true)".to_string()
318 } else {
319 "ServeDir::new(tmp_dir.path())".to_string()
320 };
321 let _ = writeln!(out, " router = router.nest_service({prefix:?}, {serve_dir_expr});");
322 }
323
324 let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
325
326 match server_call {
328 ServerCall::Shorthand(method) => {
329 let _ = writeln!(out, " let response = server.{method}({req_path:?})");
330 }
331 ServerCall::AxumMethod(method) => {
332 let _ = writeln!(
333 out,
334 " let response = server.method(axum::http::Method::{method}, {req_path:?})"
335 );
336 }
337 }
338
339 for (name, value) in &http.request.headers {
341 let n = rust_raw_string(name);
342 let v = rust_raw_string(value);
343 let _ = writeln!(out, " .add_header({n}, {v})");
344 }
345
346 let _ = writeln!(out, " .await;");
347 let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
348 let _ = writeln!(out, "}}");
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `render_cors_layer` on `cors` and returns the generated snippet.
    fn render_to_string(cors: &CorsConfig) -> String {
        let mut buf = String::new();
        render_cors_layer(&mut buf, cors);
        buf
    }

    /// A default config maps every allow-list to `tower_http::cors::Any`.
    #[test]
    fn render_cors_layer_empty_policy_uses_any() {
        let generated = render_to_string(&CorsConfig::default());
        let needles = [
            "allow_origin(tower_http::cors::Any)",
            "allow_methods(tower_http::cors::Any)",
            "allow_headers(tower_http::cors::Any)",
        ];
        for needle in needles {
            assert!(generated.contains(needle));
        }
    }

    /// Nothing in a default policy references `axum::http`, so nothing is imported.
    #[test]
    fn render_cors_layer_empty_policy_emits_no_axum_http_imports() {
        let generated = render_to_string(&CorsConfig::default());
        assert!(!generated.contains("use axum::http::"));
    }

    /// Explicit origins are parsed as `HeaderValue`, which must be imported.
    #[test]
    fn render_cors_layer_with_origin_imports_header_value() {
        let policy = CorsConfig {
            allow_origins: vec![String::from("https://example.com")],
            ..CorsConfig::default()
        };
        assert!(render_to_string(&policy).contains("use axum::http::HeaderValue;"));
    }

    /// Explicit methods are emitted as `Method::...`, which must be imported.
    #[test]
    fn render_cors_layer_with_method_imports_method() {
        let policy = CorsConfig {
            allow_methods: vec![String::from("GET")],
            ..CorsConfig::default()
        };
        assert!(render_to_string(&policy).contains("use axum::http::Method;"));
    }

    /// Headers with predefined constants never mention `HeaderName`.
    #[test]
    fn render_cors_layer_with_only_prelude_headers_omits_header_name() {
        let policy = CorsConfig {
            allow_headers: vec![String::from("content-type"), String::from("Authorization")],
            ..CorsConfig::default()
        };
        assert!(!render_to_string(&policy).contains("HeaderName"));
    }

    /// A custom header needs `HeaderName::from_static`, so `HeaderName` is imported.
    #[test]
    fn render_cors_layer_with_custom_header_imports_header_name() {
        let policy = CorsConfig {
            allow_headers: vec![String::from("X-Custom")],
            ..CorsConfig::default()
        };
        let generated = render_to_string(&policy);
        assert!(generated.contains("HeaderName"));
        assert!(generated.contains("use axum::http::HeaderName;"));
    }
}