Skip to main content

alef_e2e/codegen/rust/
http.rs

1//! HTTP integration test generation for Rust e2e tests.
2
3use std::fmt::Write as FmtWrite;
4
5use crate::escape::rust_raw_string;
6use crate::fixture::{CorsConfig, Fixture, StaticFilesConfig};
7
/// Strategy for invoking an HTTP method on `axum_test::TestServer` in generated code.
enum ServerCall<'m> {
    /// The test server has a dedicated helper for this method; the generated code is
    /// `server.<name>(path)`, e.g. `server.get(path)` or `server.post(path)`.
    Shorthand(&'m str),
    /// No dedicated helper exists; the generated code is
    /// `server.method(axum::http::Method::<NAME>, path)`, where the payload is the
    /// upper-case constant name such as `OPTIONS`.
    AxumMethod(&'m str),
}
15
/// Strategy for registering a route on the configured HTTP framework's `App` in
/// generated code.
enum RouteRegistration<'m> {
    /// The framework crate exposes a convenience constructor; the generated code is
    /// `<dep>::<name>(path)`, e.g. `<dep>::get(path)`.
    Shorthand(&'m str),
    /// No convenience constructor exists; the generated code is
    /// `<dep>::RouteBuilder::new(<dep>::Method::<Variant>, path)`, where the payload is
    /// the framework's `Method` variant name such as `Options`.
    Explicit(&'m str),
}
23
24/// Generate a complete integration test function for an http fixture.
25///
26/// Builds a real `App` from the configured HTTP framework crate with a handler
27/// that returns the expected response, then uses `axum_test::TestServer` to send
28/// the request and assert the status code.
29pub fn render_http_test_function(out: &mut String, fixture: &Fixture, dep_name: &str) {
30    let http = match &fixture.http {
31        Some(h) => h,
32        None => return,
33    };
34
35    let fn_name = crate::escape::sanitize_ident(&fixture.id);
36    let description = &fixture.description;
37
38    let route = &http.handler.route;
39
40    // The configured HTTP framework crate is expected to expose convenience functions
41    // for GET/POST/PUT/PATCH/DELETE. All other methods (HEAD, OPTIONS, TRACE, etc.) must
42    // use RouteBuilder::new directly.
43    let route_reg = match http.handler.method.to_lowercase().as_str() {
44        "get" => RouteRegistration::Shorthand("get"),
45        "post" => RouteRegistration::Shorthand("post"),
46        "put" => RouteRegistration::Shorthand("put"),
47        "patch" => RouteRegistration::Shorthand("patch"),
48        "delete" => RouteRegistration::Shorthand("delete"),
49        "head" => RouteRegistration::Explicit("Head"),
50        "options" => RouteRegistration::Explicit("Options"),
51        "trace" => RouteRegistration::Explicit("Trace"),
52        _ => RouteRegistration::Shorthand("get"),
53    };
54
55    // axum_test::TestServer has shorthand methods for GET/POST/PUT/PATCH/DELETE.
56    // For HEAD and other methods, use server.method(axum::http::Method::HEAD, path).
57    let server_call = match http.request.method.to_uppercase().as_str() {
58        "GET" => ServerCall::Shorthand("get"),
59        "POST" => ServerCall::Shorthand("post"),
60        "PUT" => ServerCall::Shorthand("put"),
61        "PATCH" => ServerCall::Shorthand("patch"),
62        "DELETE" => ServerCall::Shorthand("delete"),
63        "HEAD" => ServerCall::AxumMethod("HEAD"),
64        "OPTIONS" => ServerCall::AxumMethod("OPTIONS"),
65        "TRACE" => ServerCall::AxumMethod("TRACE"),
66        _ => ServerCall::Shorthand("get"),
67    };
68
69    let req_path = &http.request.path;
70    let status = http.expected_response.status_code;
71
72    // Serialize expected response body (if any).
73    let body_str = match &http.expected_response.body {
74        Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
75        None => String::new(),
76    };
77    let body_literal = rust_raw_string(&body_str);
78
79    // Serialize request body (if any).
80    let req_body_str = match &http.request.body {
81        Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
82        None => String::new(),
83    };
84    let has_req_body = !req_body_str.is_empty();
85
86    // Extract middleware from handler (if any).
87    let middleware = http.handler.middleware.as_ref();
88    let cors_cfg: Option<&CorsConfig> = middleware.and_then(|m| m.cors.as_ref());
89    let static_files_cfgs: Option<&Vec<StaticFilesConfig>> = middleware.and_then(|m| m.static_files.as_ref());
90    let has_static_files = static_files_cfgs.is_some_and(|v| !v.is_empty());
91
92    let _ = writeln!(out, "#[tokio::test]");
93    let _ = writeln!(out, "async fn test_{fn_name}() {{");
94    let _ = writeln!(out, "    // {description}");
95
96    // When static-files middleware is configured, serve from a temp dir via ServeDir.
97    if has_static_files {
98        render_static_files_test(out, fixture, static_files_cfgs.unwrap(), &server_call, req_path, status);
99        return;
100    }
101
102    // Build handler that returns the expected response.
103    let _ = writeln!(out, "    let expected_body = {body_literal}.to_string();");
104    let _ = writeln!(out, "    let mut app = {dep_name}::App::new();");
105
106    // Emit route registration.
107    match &route_reg {
108        RouteRegistration::Shorthand(method) => {
109            let _ = writeln!(
110                out,
111                "    app.route({dep_name}::{method}({route:?}), move |_ctx: {dep_name}::RequestContext| {{"
112            );
113        }
114        RouteRegistration::Explicit(variant) => {
115            let _ = writeln!(
116                out,
117                "    app.route({dep_name}::RouteBuilder::new({dep_name}::Method::{variant}, {route:?}), move |_ctx: {dep_name}::RequestContext| {{"
118            );
119        }
120    }
121    let _ = writeln!(out, "        let body = expected_body.clone();");
122    let _ = writeln!(out, "        async move {{");
123    let _ = writeln!(out, "            Ok(axum::http::Response::builder()");
124    let _ = writeln!(out, "                .status({status}u16)");
125    let _ = writeln!(out, "                .header(\"content-type\", \"application/json\")");
126    let _ = writeln!(out, "                .body(axum::body::Body::from(body))");
127    let _ = writeln!(out, "                .unwrap())");
128    let _ = writeln!(out, "        }}");
129    let _ = writeln!(out, "    }}).unwrap();");
130
131    // Build axum-test TestServer from the app router, optionally wrapping with CorsLayer.
132    let _ = writeln!(out, "    let router = app.into_router().unwrap();");
133    if let Some(cors) = cors_cfg {
134        render_cors_layer(out, cors);
135    }
136    let _ = writeln!(out, "    let server = axum_test::TestServer::new(router);");
137
138    // Build and send the request.
139    match &server_call {
140        ServerCall::Shorthand(method) => {
141            let _ = writeln!(out, "    let response = server.{method}({req_path:?})");
142        }
143        ServerCall::AxumMethod(method) => {
144            let _ = writeln!(
145                out,
146                "    let response = server.method(axum::http::Method::{method}, {req_path:?})"
147            );
148        }
149    }
150
151    // Add request headers (axum_test::TestRequest::add_header accepts &str via TryInto).
152    for (name, value) in &http.request.headers {
153        let n = rust_raw_string(name);
154        let v = rust_raw_string(value);
155        let _ = writeln!(out, "        .add_header({n}, {v})");
156    }
157
158    // Add request body if present (pass as a JSON string so axum-test's bytes() API gets a Bytes value).
159    if has_req_body {
160        let req_body_literal = rust_raw_string(&req_body_str);
161        let _ = writeln!(
162            out,
163            "        .bytes(bytes::Bytes::copy_from_slice({req_body_literal}.as_bytes()))"
164        );
165    }
166
167    let _ = writeln!(out, "        .await;");
168
169    // Assert status code.
170    // When a CorsLayer is applied and the fixture expects a 2xx status, tower-http may
171    // return 200 instead of 204 for preflight. Accept any 2xx status in that case.
172    if cors_cfg.is_some() && (200..300).contains(&status) {
173        let _ = writeln!(
174            out,
175            "    assert!(response.status_code().is_success(), \"expected CORS success status, got {{}}\", response.status_code());"
176        );
177    } else {
178        let _ = writeln!(out, "    assert_eq!(response.status_code().as_u16(), {status}u16);");
179    }
180
181    let _ = writeln!(out, "}}");
182}
183
184/// Emit lines that wrap the axum router with a `tower_http::cors::CorsLayer`.
185///
186/// The CORS policy is derived from the fixture's `cors` middleware config.
187/// After this function, `router` is reassigned to the layer-wrapped version.
188pub fn render_cors_layer(out: &mut String, cors: &CorsConfig) {
189    // Decide up-front which axum::http re-exports we will actually reference so we
190    // can emit a tight `use` group — emitting all three unconditionally trips
191    // `-D unused_imports` for fixtures that, say, allow no custom headers.
192    let needs_header_value = !cors.allow_origins.is_empty();
193    let needs_method = !cors.allow_methods.is_empty();
194    let needs_header_name = !cors.allow_headers.is_empty()
195        && cors
196            .allow_headers
197            .iter()
198            .any(|h| !matches!(h.to_lowercase().as_str(), "content-type" | "authorization" | "accept"));
199
200    let _ = writeln!(
201        out,
202        "    // Apply CorsLayer from tower-http based on fixture CORS config."
203    );
204    let _ = writeln!(out, "    use tower_http::cors::CorsLayer;");
205    let mut imports: Vec<&'static str> = Vec::new();
206    if needs_header_name {
207        imports.push("HeaderName");
208    }
209    if needs_header_value {
210        imports.push("HeaderValue");
211    }
212    if needs_method {
213        imports.push("Method");
214    }
215    match imports.len() {
216        0 => {}
217        1 => {
218            let _ = writeln!(out, "    use axum::http::{};", imports[0]);
219        }
220        _ => {
221            let _ = writeln!(out, "    use axum::http::{{{}}};", imports.join(", "));
222        }
223    }
224    let _ = writeln!(out, "    let cors_layer = CorsLayer::new()");
225
226    // allow_origins
227    if cors.allow_origins.is_empty() {
228        let _ = writeln!(out, "        .allow_origin(tower_http::cors::Any)");
229    } else {
230        let _ = writeln!(out, "        .allow_origin([");
231        for origin in &cors.allow_origins {
232            let _ = writeln!(out, "            \"{origin}\".parse::<HeaderValue>().unwrap(),");
233        }
234        let _ = writeln!(out, "        ])");
235    }
236
237    // allow_methods
238    if cors.allow_methods.is_empty() {
239        let _ = writeln!(out, "        .allow_methods(tower_http::cors::Any)");
240    } else {
241        let methods: Vec<String> = cors
242            .allow_methods
243            .iter()
244            .map(|m| format!("Method::{}", m.to_uppercase()))
245            .collect();
246        let _ = writeln!(out, "        .allow_methods([{}])", methods.join(", "));
247    }
248
249    // allow_headers
250    if cors.allow_headers.is_empty() {
251        let _ = writeln!(out, "        .allow_headers(tower_http::cors::Any)");
252    } else {
253        let headers: Vec<String> = cors
254            .allow_headers
255            .iter()
256            .map(|h| {
257                let lower = h.to_lowercase();
258                match lower.as_str() {
259                    "content-type" => "axum::http::header::CONTENT_TYPE".to_string(),
260                    "authorization" => "axum::http::header::AUTHORIZATION".to_string(),
261                    "accept" => "axum::http::header::ACCEPT".to_string(),
262                    _ => format!("HeaderName::from_static(\"{lower}\")"),
263                }
264            })
265            .collect();
266        let _ = writeln!(out, "        .allow_headers([{}])", headers.join(", "));
267    }
268
269    // max_age
270    if let Some(secs) = cors.max_age {
271        let _ = writeln!(out, "        .max_age(std::time::Duration::from_secs({secs}));");
272    } else {
273        let _ = writeln!(out, "        ;");
274    }
275
276    let _ = writeln!(out, "    let router = router.layer(cors_layer);");
277}
278
279/// Emit lines for a static-files integration test.
280///
281/// Writes fixture files to a temporary directory and serves them via
282/// `tower_http::services::ServeDir`, bypassing the framework App entirely.
283fn render_static_files_test(
284    out: &mut String,
285    fixture: &Fixture,
286    cfgs: &[StaticFilesConfig],
287    server_call: &ServerCall<'_>,
288    req_path: &str,
289    status: u16,
290) {
291    let http = fixture.http.as_ref().unwrap();
292
293    let _ = writeln!(out, "    use tower_http::services::ServeDir;");
294    let _ = writeln!(out, "    use axum::Router;");
295    let _ = writeln!(out, "    let tmp_dir = tempfile::tempdir().expect(\"tmp dir\");");
296
297    // Build the router by nesting a ServeDir for each config entry.
298    let _ = writeln!(out, "    let mut router = Router::new();");
299    for cfg in cfgs {
300        for file in &cfg.files {
301            let file_path = file.path.replace('\\', "/");
302            let content = rust_raw_string(&file.content);
303            if file_path.contains('/') {
304                let parent: String = file_path.rsplitn(2, '/').last().unwrap_or("").to_string();
305                let _ = writeln!(
306                    out,
307                    "    std::fs::create_dir_all(tmp_dir.path().join(\"{parent}\")).unwrap();"
308                );
309            }
310            let _ = writeln!(
311                out,
312                "    std::fs::write(tmp_dir.path().join(\"{file_path}\"), {content}).unwrap();"
313            );
314        }
315        let prefix = &cfg.route_prefix;
316        let serve_dir_expr = if cfg.index_file {
317            "ServeDir::new(tmp_dir.path()).append_index_html_on_directories(true)".to_string()
318        } else {
319            "ServeDir::new(tmp_dir.path())".to_string()
320        };
321        let _ = writeln!(out, "    router = router.nest_service({prefix:?}, {serve_dir_expr});");
322    }
323
324    let _ = writeln!(out, "    let server = axum_test::TestServer::new(router);");
325
326    // Build and send the request.
327    match server_call {
328        ServerCall::Shorthand(method) => {
329            let _ = writeln!(out, "    let response = server.{method}({req_path:?})");
330        }
331        ServerCall::AxumMethod(method) => {
332            let _ = writeln!(
333                out,
334                "    let response = server.method(axum::http::Method::{method}, {req_path:?})"
335            );
336        }
337    }
338
339    // Add request headers.
340    for (name, value) in &http.request.headers {
341        let n = rust_raw_string(name);
342        let v = rust_raw_string(value);
343        let _ = writeln!(out, "        .add_header({n}, {v})");
344    }
345
346    let _ = writeln!(out, "        .await;");
347    let _ = writeln!(out, "    assert_eq!(response.status_code().as_u16(), {status}u16);");
348    let _ = writeln!(out, "}}");
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `render_cors_layer` for `cors` and hand back the generated source text.
    fn rendered(cors: &CorsConfig) -> String {
        let mut code = String::new();
        render_cors_layer(&mut code, cors);
        code
    }

    #[test]
    fn render_cors_layer_empty_policy_uses_any() {
        let code = rendered(&CorsConfig::default());
        for needle in [
            "allow_origin(tower_http::cors::Any)",
            "allow_methods(tower_http::cors::Any)",
            "allow_headers(tower_http::cors::Any)",
        ] {
            assert!(code.contains(needle), "missing `{needle}` in:\n{code}");
        }
    }

    /// With nothing configured, no `axum::http` item is referenced, so none may be
    /// imported; a stray import fails consumers built with `-D unused_imports`.
    #[test]
    fn render_cors_layer_empty_policy_emits_no_axum_http_imports() {
        let code = rendered(&CorsConfig::default());
        assert!(!code.contains("use axum::http::"), "unexpected import in:\n{code}");
    }

    /// Explicit origins are parsed as `HeaderValue`, which therefore must be imported.
    #[test]
    fn render_cors_layer_with_origin_imports_header_value() {
        let code = rendered(&CorsConfig {
            allow_origins: vec!["https://example.com".to_string()],
            ..CorsConfig::default()
        });
        assert!(code.contains("use axum::http::HeaderValue;"), "{code}");
    }

    /// Explicit methods reference `Method`, which therefore must be imported.
    #[test]
    fn render_cors_layer_with_method_imports_method() {
        let code = rendered(&CorsConfig {
            allow_methods: vec!["GET".to_string()],
            ..CorsConfig::default()
        });
        assert!(code.contains("use axum::http::Method;"), "{code}");
    }

    /// Well-known header names expand to fully qualified `axum::http::header`
    /// constants, so `HeaderName` must not appear at all.
    #[test]
    fn render_cors_layer_with_only_prelude_headers_omits_header_name() {
        let code = rendered(&CorsConfig {
            allow_headers: vec!["content-type".to_string(), "Authorization".to_string()],
            ..CorsConfig::default()
        });
        assert!(!code.contains("HeaderName"), "{code}");
    }

    /// A custom header is built with `HeaderName::from_static(...)`, so the
    /// `HeaderName` import has to be emitted alongside it.
    #[test]
    fn render_cors_layer_with_custom_header_imports_header_name() {
        let code = rendered(&CorsConfig {
            allow_headers: vec!["X-Custom".to_string()],
            ..CorsConfig::default()
        });
        assert!(code.contains("HeaderName"), "{code}");
        assert!(code.contains("use axum::http::HeaderName;"), "{code}");
    }
}