Skip to main content

alef_e2e/codegen/rust/
http.rs

//! HTTP integration test generation for Rust e2e tests.
2
3use std::fmt::Write as FmtWrite;
4
5use crate::escape::rust_raw_string;
6use crate::fixture::{CorsConfig, Fixture, StaticFilesConfig};
7
/// How to call a method on `axum_test::TestServer` in generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerCall<'a> {
    /// Emit `server.get(path)` / `server.post(path)` etc.
    ///
    /// The payload is the lowercase `TestServer` shorthand method name (e.g. `"get"`).
    Shorthand(&'a str),
    /// Emit `server.method(axum::http::Method::OPTIONS, path)` etc.
    ///
    /// The payload is the uppercase `axum::http::Method` constant name (e.g. `"HEAD"`).
    AxumMethod(&'a str),
}
15
/// How to register a route on a spikard App in generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RouteRegistration<'a> {
    /// Emit `spikard::get(path)` / `spikard::post(path)` etc.
    ///
    /// The payload is the lowercase spikard convenience-function name (e.g. `"get"`).
    Shorthand(&'a str),
    /// Emit `spikard::RouteBuilder::new(spikard::Method::Options, path)` etc.
    ///
    /// The payload is the `spikard::Method` variant name (e.g. `"Options"`).
    Explicit(&'a str),
}
23
24/// Generate a complete integration test function for an http fixture.
25///
26/// Builds a real spikard `App` with a handler that returns the expected
27/// response, then uses `axum_test::TestServer` to send the request and
28/// assert the status code.
29pub fn render_http_test_function(out: &mut String, fixture: &Fixture, dep_name: &str) {
30    let http = match &fixture.http {
31        Some(h) => h,
32        None => return,
33    };
34
35    let fn_name = crate::escape::sanitize_ident(&fixture.id);
36    let description = &fixture.description;
37
38    let route = &http.handler.route;
39
40    // spikard provides convenience functions for GET/POST/PUT/PATCH/DELETE.
41    // All other methods (HEAD, OPTIONS, TRACE, etc.) must use RouteBuilder::new directly.
42    let route_reg = match http.handler.method.to_lowercase().as_str() {
43        "get" => RouteRegistration::Shorthand("get"),
44        "post" => RouteRegistration::Shorthand("post"),
45        "put" => RouteRegistration::Shorthand("put"),
46        "patch" => RouteRegistration::Shorthand("patch"),
47        "delete" => RouteRegistration::Shorthand("delete"),
48        "head" => RouteRegistration::Explicit("Head"),
49        "options" => RouteRegistration::Explicit("Options"),
50        "trace" => RouteRegistration::Explicit("Trace"),
51        _ => RouteRegistration::Shorthand("get"),
52    };
53
54    // axum_test::TestServer has shorthand methods for GET/POST/PUT/PATCH/DELETE.
55    // For HEAD and other methods, use server.method(axum::http::Method::HEAD, path).
56    let server_call = match http.request.method.to_uppercase().as_str() {
57        "GET" => ServerCall::Shorthand("get"),
58        "POST" => ServerCall::Shorthand("post"),
59        "PUT" => ServerCall::Shorthand("put"),
60        "PATCH" => ServerCall::Shorthand("patch"),
61        "DELETE" => ServerCall::Shorthand("delete"),
62        "HEAD" => ServerCall::AxumMethod("HEAD"),
63        "OPTIONS" => ServerCall::AxumMethod("OPTIONS"),
64        "TRACE" => ServerCall::AxumMethod("TRACE"),
65        _ => ServerCall::Shorthand("get"),
66    };
67
68    let req_path = &http.request.path;
69    let status = http.expected_response.status_code;
70
71    // Serialize expected response body (if any).
72    let body_str = match &http.expected_response.body {
73        Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
74        None => String::new(),
75    };
76    let body_literal = rust_raw_string(&body_str);
77
78    // Serialize request body (if any).
79    let req_body_str = match &http.request.body {
80        Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
81        None => String::new(),
82    };
83    let has_req_body = !req_body_str.is_empty();
84
85    // Extract middleware from handler (if any).
86    let middleware = http.handler.middleware.as_ref();
87    let cors_cfg: Option<&CorsConfig> = middleware.and_then(|m| m.cors.as_ref());
88    let static_files_cfgs: Option<&Vec<StaticFilesConfig>> = middleware.and_then(|m| m.static_files.as_ref());
89    let has_static_files = static_files_cfgs.is_some_and(|v| !v.is_empty());
90
91    let _ = writeln!(out, "#[tokio::test]");
92    let _ = writeln!(out, "async fn test_{fn_name}() {{");
93    let _ = writeln!(out, "    // {description}");
94
95    // When static-files middleware is configured, serve from a temp dir via ServeDir.
96    if has_static_files {
97        render_static_files_test(out, fixture, static_files_cfgs.unwrap(), &server_call, req_path, status);
98        return;
99    }
100
101    // Build handler that returns the expected response.
102    let _ = writeln!(out, "    let expected_body = {body_literal}.to_string();");
103    let _ = writeln!(out, "    let mut app = {dep_name}::App::new();");
104
105    // Emit route registration.
106    match &route_reg {
107        RouteRegistration::Shorthand(method) => {
108            let _ = writeln!(
109                out,
110                "    app.route({dep_name}::{method}({route:?}), move |_ctx: {dep_name}::RequestContext| {{"
111            );
112        }
113        RouteRegistration::Explicit(variant) => {
114            let _ = writeln!(
115                out,
116                "    app.route({dep_name}::RouteBuilder::new({dep_name}::Method::{variant}, {route:?}), move |_ctx: {dep_name}::RequestContext| {{"
117            );
118        }
119    }
120    let _ = writeln!(out, "        let body = expected_body.clone();");
121    let _ = writeln!(out, "        async move {{");
122    let _ = writeln!(out, "            Ok(axum::http::Response::builder()");
123    let _ = writeln!(out, "                .status({status}u16)");
124    let _ = writeln!(out, "                .header(\"content-type\", \"application/json\")");
125    let _ = writeln!(out, "                .body(axum::body::Body::from(body))");
126    let _ = writeln!(out, "                .unwrap())");
127    let _ = writeln!(out, "        }}");
128    let _ = writeln!(out, "    }}).unwrap();");
129
130    // Build axum-test TestServer from the app router, optionally wrapping with CorsLayer.
131    let _ = writeln!(out, "    let router = app.into_router().unwrap();");
132    if let Some(cors) = cors_cfg {
133        render_cors_layer(out, cors);
134    }
135    let _ = writeln!(out, "    let server = axum_test::TestServer::new(router);");
136
137    // Build and send the request.
138    match &server_call {
139        ServerCall::Shorthand(method) => {
140            let _ = writeln!(out, "    let response = server.{method}({req_path:?})");
141        }
142        ServerCall::AxumMethod(method) => {
143            let _ = writeln!(
144                out,
145                "    let response = server.method(axum::http::Method::{method}, {req_path:?})"
146            );
147        }
148    }
149
150    // Add request headers (axum_test::TestRequest::add_header accepts &str via TryInto).
151    for (name, value) in &http.request.headers {
152        let n = rust_raw_string(name);
153        let v = rust_raw_string(value);
154        let _ = writeln!(out, "        .add_header({n}, {v})");
155    }
156
157    // Add request body if present (pass as a JSON string so axum-test's bytes() API gets a Bytes value).
158    if has_req_body {
159        let req_body_literal = rust_raw_string(&req_body_str);
160        let _ = writeln!(
161            out,
162            "        .bytes(bytes::Bytes::copy_from_slice({req_body_literal}.as_bytes()))"
163        );
164    }
165
166    let _ = writeln!(out, "        .await;");
167
168    // Assert status code.
169    // When a CorsLayer is applied and the fixture expects a 2xx status, tower-http may
170    // return 200 instead of 204 for preflight. Accept any 2xx status in that case.
171    if cors_cfg.is_some() && (200..300).contains(&status) {
172        let _ = writeln!(
173            out,
174            "    assert!(response.status_code().is_success(), \"expected CORS success status, got {{}}\", response.status_code());"
175        );
176    } else {
177        let _ = writeln!(out, "    assert_eq!(response.status_code().as_u16(), {status}u16);");
178    }
179
180    let _ = writeln!(out, "}}");
181}
182
183/// Emit lines that wrap the axum router with a `tower_http::cors::CorsLayer`.
184///
185/// The CORS policy is derived from the fixture's `cors` middleware config.
186/// After this function, `router` is reassigned to the layer-wrapped version.
187pub fn render_cors_layer(out: &mut String, cors: &CorsConfig) {
188    // Decide up-front which axum::http re-exports we will actually reference so we
189    // can emit a tight `use` group — emitting all three unconditionally trips
190    // `-D unused_imports` for fixtures that, say, allow no custom headers.
191    let needs_header_value = !cors.allow_origins.is_empty();
192    let needs_method = !cors.allow_methods.is_empty();
193    let needs_header_name = !cors.allow_headers.is_empty()
194        && cors
195            .allow_headers
196            .iter()
197            .any(|h| !matches!(h.to_lowercase().as_str(), "content-type" | "authorization" | "accept"));
198
199    let _ = writeln!(
200        out,
201        "    // Apply CorsLayer from tower-http based on fixture CORS config."
202    );
203    let _ = writeln!(out, "    use tower_http::cors::CorsLayer;");
204    let mut imports: Vec<&'static str> = Vec::new();
205    if needs_header_name {
206        imports.push("HeaderName");
207    }
208    if needs_header_value {
209        imports.push("HeaderValue");
210    }
211    if needs_method {
212        imports.push("Method");
213    }
214    match imports.len() {
215        0 => {}
216        1 => {
217            let _ = writeln!(out, "    use axum::http::{};", imports[0]);
218        }
219        _ => {
220            let _ = writeln!(out, "    use axum::http::{{{}}};", imports.join(", "));
221        }
222    }
223    let _ = writeln!(out, "    let cors_layer = CorsLayer::new()");
224
225    // allow_origins
226    if cors.allow_origins.is_empty() {
227        let _ = writeln!(out, "        .allow_origin(tower_http::cors::Any)");
228    } else {
229        let _ = writeln!(out, "        .allow_origin([");
230        for origin in &cors.allow_origins {
231            let _ = writeln!(out, "            \"{origin}\".parse::<HeaderValue>().unwrap(),");
232        }
233        let _ = writeln!(out, "        ])");
234    }
235
236    // allow_methods
237    if cors.allow_methods.is_empty() {
238        let _ = writeln!(out, "        .allow_methods(tower_http::cors::Any)");
239    } else {
240        let methods: Vec<String> = cors
241            .allow_methods
242            .iter()
243            .map(|m| format!("Method::{}", m.to_uppercase()))
244            .collect();
245        let _ = writeln!(out, "        .allow_methods([{}])", methods.join(", "));
246    }
247
248    // allow_headers
249    if cors.allow_headers.is_empty() {
250        let _ = writeln!(out, "        .allow_headers(tower_http::cors::Any)");
251    } else {
252        let headers: Vec<String> = cors
253            .allow_headers
254            .iter()
255            .map(|h| {
256                let lower = h.to_lowercase();
257                match lower.as_str() {
258                    "content-type" => "axum::http::header::CONTENT_TYPE".to_string(),
259                    "authorization" => "axum::http::header::AUTHORIZATION".to_string(),
260                    "accept" => "axum::http::header::ACCEPT".to_string(),
261                    _ => format!("HeaderName::from_static(\"{lower}\")"),
262                }
263            })
264            .collect();
265        let _ = writeln!(out, "        .allow_headers([{}])", headers.join(", "));
266    }
267
268    // max_age
269    if let Some(secs) = cors.max_age {
270        let _ = writeln!(out, "        .max_age(std::time::Duration::from_secs({secs}));");
271    } else {
272        let _ = writeln!(out, "        ;");
273    }
274
275    let _ = writeln!(out, "    let router = router.layer(cors_layer);");
276}
277
278/// Emit lines for a static-files integration test.
279///
280/// Writes fixture files to a temporary directory and serves them via
281/// `tower_http::services::ServeDir`, bypassing the spikard App entirely.
282fn render_static_files_test(
283    out: &mut String,
284    fixture: &Fixture,
285    cfgs: &[StaticFilesConfig],
286    server_call: &ServerCall<'_>,
287    req_path: &str,
288    status: u16,
289) {
290    let http = fixture.http.as_ref().unwrap();
291
292    let _ = writeln!(out, "    use tower_http::services::ServeDir;");
293    let _ = writeln!(out, "    use axum::Router;");
294    let _ = writeln!(out, "    let tmp_dir = tempfile::tempdir().expect(\"tmp dir\");");
295
296    // Build the router by nesting a ServeDir for each config entry.
297    let _ = writeln!(out, "    let mut router = Router::new();");
298    for cfg in cfgs {
299        for file in &cfg.files {
300            let file_path = file.path.replace('\\', "/");
301            let content = rust_raw_string(&file.content);
302            if file_path.contains('/') {
303                let parent: String = file_path.rsplitn(2, '/').last().unwrap_or("").to_string();
304                let _ = writeln!(
305                    out,
306                    "    std::fs::create_dir_all(tmp_dir.path().join(\"{parent}\")).unwrap();"
307                );
308            }
309            let _ = writeln!(
310                out,
311                "    std::fs::write(tmp_dir.path().join(\"{file_path}\"), {content}).unwrap();"
312            );
313        }
314        let prefix = &cfg.route_prefix;
315        let serve_dir_expr = if cfg.index_file {
316            "ServeDir::new(tmp_dir.path()).append_index_html_on_directories(true)".to_string()
317        } else {
318            "ServeDir::new(tmp_dir.path())".to_string()
319        };
320        let _ = writeln!(out, "    router = router.nest_service({prefix:?}, {serve_dir_expr});");
321    }
322
323    let _ = writeln!(out, "    let server = axum_test::TestServer::new(router);");
324
325    // Build and send the request.
326    match server_call {
327        ServerCall::Shorthand(method) => {
328            let _ = writeln!(out, "    let response = server.{method}({req_path:?})");
329        }
330        ServerCall::AxumMethod(method) => {
331            let _ = writeln!(
332                out,
333                "    let response = server.method(axum::http::Method::{method}, {req_path:?})"
334            );
335        }
336    }
337
338    // Add request headers.
339    for (name, value) in &http.request.headers {
340        let n = rust_raw_string(name);
341        let v = rust_raw_string(value);
342        let _ = writeln!(out, "        .add_header({n}, {v})");
343    }
344
345    let _ = writeln!(out, "        .await;");
346    let _ = writeln!(out, "    assert_eq!(response.status_code().as_u16(), {status}u16);");
347    let _ = writeln!(out, "}}");
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `render_cors_layer` on `cors` and return the emitted code.
    fn rendered(cors: &CorsConfig) -> String {
        let mut code = String::new();
        render_cors_layer(&mut code, cors);
        code
    }

    /// A default (empty) policy allows everything on all three axes.
    #[test]
    fn render_cors_layer_empty_policy_uses_any() {
        let code = rendered(&CorsConfig::default());
        for setter in ["allow_origin", "allow_methods", "allow_headers"] {
            assert!(code.contains(&format!("{setter}(tower_http::cors::Any)")));
        }
    }

    /// With nothing explicit to list, none of `HeaderName`, `HeaderValue` or
    /// `Method` is referenced. No `axum::http` import may appear, because it
    /// would trip `-D unused_imports` in the consumer.
    #[test]
    fn render_cors_layer_empty_policy_emits_no_axum_http_imports() {
        let code = rendered(&CorsConfig::default());
        assert!(!code.contains("use axum::http::"));
    }

    /// An explicit origin is parsed into a `HeaderValue`, which must be imported.
    #[test]
    fn render_cors_layer_with_origin_imports_header_value() {
        let policy = CorsConfig {
            allow_origins: vec!["https://example.com".to_string()],
            ..CorsConfig::default()
        };
        let code = rendered(&policy);
        assert!(code.contains("use axum::http::HeaderValue;"));
    }

    /// An explicit method list references `Method`, which must be imported.
    #[test]
    fn render_cors_layer_with_method_imports_method() {
        let policy = CorsConfig {
            allow_methods: vec!["GET".to_string()],
            ..CorsConfig::default()
        };
        let code = rendered(&policy);
        assert!(code.contains("use axum::http::Method;"));
    }

    /// Headers that map to qualified `axum::http::header` constants
    /// (content-type, authorization, ...) must not pull in `HeaderName`.
    #[test]
    fn render_cors_layer_with_only_prelude_headers_omits_header_name() {
        let policy = CorsConfig {
            allow_headers: vec!["content-type".to_string(), "Authorization".to_string()],
            ..CorsConfig::default()
        };
        let code = rendered(&policy);
        assert!(!code.contains("HeaderName"));
    }

    /// A custom header is emitted via `HeaderName::from_static(...)`, so
    /// `HeaderName` must be both referenced and imported.
    #[test]
    fn render_cors_layer_with_custom_header_imports_header_name() {
        let policy = CorsConfig {
            allow_headers: vec!["X-Custom".to_string()],
            ..CorsConfig::default()
        };
        let code = rendered(&policy);
        assert!(code.contains("HeaderName"));
        assert!(code.contains("use axum::http::HeaderName;"));
    }
}