use crate::codegen::resolve_field;
use crate::config::E2eConfig;
use crate::escape::{escape_rust, rust_raw_string, sanitize_filename, sanitize_ident};
use crate::field_access::FieldResolver;
use crate::fixture::{Assertion, CallbackAction, CorsConfig, Fixture, FixtureGroup, StaticFilesConfig};
use alef_core::backend::GeneratedFile;
use alef_core::config::AlefConfig;
use alef_core::hash::{self, CommentStyle};
use alef_core::template_versions as tv;
use anyhow::Result;
use std::fmt::Write as FmtWrite;
use std::path::PathBuf;
/// E2e code generator that emits a standalone Rust test crate.
pub struct RustE2eCodegen;

impl super::E2eCodegen for RustE2eCodegen {
    /// Generates every file of the Rust e2e crate under `<effective_output>/rust`.
    ///
    /// Emitted files:
    /// - `Cargo.toml` (always);
    /// - `tests/mock_server.rs` when any non-skipped fixture has a mock response;
    /// - `src/main.rs` when mock-server or HTTP fixtures exist;
    /// - one `tests/<category>_test.rs` per group with at least one fixture
    ///   that is not skipped for Rust.
    fn generate(
        &self,
        groups: &[FixtureGroup],
        e2e_config: &E2eConfig,
        alef_config: &AlefConfig,
    ) -> Result<Vec<GeneratedFile>> {
        let mut files = Vec::new();
        let output_base = PathBuf::from(e2e_config.effective_output()).join("rust");
        let crate_name = resolve_crate_name(e2e_config, alef_config);
        let crate_path = resolve_crate_path(e2e_config, &crate_name);
        // Cargo package names may contain '-', Rust paths may not.
        let dep_name = crate_name.replace('-', "_");

        // Fixtures that are not skipped for Rust; collected once because every
        // dependency flag below is derived from the same set.
        let active: Vec<&Fixture> = groups
            .iter()
            .flat_map(|g| g.fixtures.iter())
            .filter(|f| !is_skipped(f, "rust"))
            .collect();

        let needs_serde_json = std::iter::once(&e2e_config.call)
            .chain(e2e_config.calls.values())
            .flat_map(|c| c.args.iter())
            .any(|a| a.arg_type == "json_object" || a.arg_type == "handle");
        let needs_mock_server = active.iter().any(|f| f.mock_response.is_some());
        let needs_http_tests = active.iter().any(|f| f.http.is_some());
        // Fixtures with neither an HTTP spec nor a mock response are rendered
        // by `render_test_function` as `#[tokio::test]` stubs, so they need
        // tokio even when nothing else does.
        let needs_stub_tests = active
            .iter()
            .any(|f| f.http.is_none() && f.mock_response.is_none());
        let needs_tower_http = active
            .iter()
            .filter_map(|f| f.http.as_ref())
            .filter_map(|h| h.handler.middleware.as_ref())
            .any(|m| m.cors.is_some() || m.static_files.is_some());
        let any_async_call = std::iter::once(&e2e_config.call)
            .chain(e2e_config.calls.values())
            .any(|c| c.r#async);
        let needs_tokio = needs_mock_server || needs_http_tests || any_async_call || needs_stub_tests;
        let crate_version = resolve_crate_version(e2e_config);

        files.push(GeneratedFile {
            path: output_base.join("Cargo.toml"),
            content: render_cargo_toml(
                &crate_name,
                &dep_name,
                &crate_path,
                needs_serde_json,
                needs_mock_server,
                needs_http_tests,
                needs_tokio,
                needs_tower_http,
                e2e_config.dep_mode,
                crate_version.as_deref(),
                &alef_config.crate_config.features,
            ),
            generated_header: true,
        });

        if needs_mock_server {
            files.push(GeneratedFile {
                path: output_base.join("tests").join("mock_server.rs"),
                content: render_mock_server_module(),
                generated_header: true,
            });
        }

        // Must stay in sync with the `[[bin]]` section of the generated Cargo.toml.
        if needs_mock_server || needs_http_tests {
            files.push(GeneratedFile {
                path: output_base.join("src").join("main.rs"),
                content: render_mock_server_binary(),
                generated_header: true,
            });
        }

        for group in groups {
            let fixtures: Vec<&Fixture> = group
                .fixtures
                .iter()
                .filter(|f| !is_skipped(f, "rust"))
                .collect();
            if fixtures.is_empty() {
                continue;
            }
            let filename = format!("{}_test.rs", sanitize_filename(&group.category));
            let content = render_test_file(&group.category, &fixtures, e2e_config, &dep_name, needs_mock_server);
            files.push(GeneratedFile {
                path: output_base.join("tests").join(filename),
                content,
                generated_header: true,
            });
        }

        Ok(files)
    }

    /// Language identifier of this generator.
    fn language_name(&self) -> &'static str {
        "rust"
    }
}
/// Returns the name of the crate under test as configured in the alef crate config.
///
/// The e2e configuration is accepted but currently not consulted.
fn resolve_crate_name(_e2e_config: &E2eConfig, alef_config: &AlefConfig) -> String {
    let configured = &alef_config.crate_config.name;
    configured.clone()
}
/// Returns the path to the crate under test.
///
/// Uses the `path` of the `"rust"` package entry when one is configured and
/// otherwise falls back to `../../crates/<crate_name>`.
fn resolve_crate_path(e2e_config: &E2eConfig, crate_name: &str) -> String {
    let configured = e2e_config
        .resolve_package("rust")
        .and_then(|pkg| pkg.path.clone());
    match configured {
        Some(path) => path,
        None => format!("../../crates/{crate_name}"),
    }
}
/// Returns the version of the `"rust"` package entry, if both the entry and
/// its version are configured.
fn resolve_crate_version(e2e_config: &E2eConfig) -> Option<String> {
    let package = e2e_config.resolve_package("rust")?;
    package.version.clone()
}
/// Returns the function name to call for `call_config`.
///
/// A `function` set in the `"rust"` override wins over the call's own `function`.
fn resolve_function_name_for_call(call_config: &crate::config::CallConfig) -> String {
    let overridden = call_config
        .overrides
        .get("rust")
        .and_then(|o| o.function.as_ref());
    match overridden {
        Some(name) => name.clone(),
        None => call_config.function.clone(),
    }
}
/// Resolves the module path for the default call (`e2e_config.call`).
fn resolve_module(e2e_config: &E2eConfig, dep_name: &str) -> String {
    let default_call = &e2e_config.call;
    resolve_module_for_call(default_call, dep_name)
}
/// Resolves the module path that generated code imports from for `call_config`.
///
/// Precedence: the `"rust"` override's `crate_name`, then its `module`, then `dep_name`.
fn resolve_module_for_call(call_config: &crate::config::CallConfig, dep_name: &str) -> String {
    if let Some(rust_override) = call_config.overrides.get("rust") {
        if let Some(crate_name) = &rust_override.crate_name {
            return crate_name.clone();
        }
        if let Some(module) = &rust_override.module {
            return module.clone();
        }
    }
    dep_name.to_string()
}
/// Returns `true` when the fixture carries a skip rule that applies to `language`.
///
/// Fixtures without a skip rule are never skipped.
fn is_skipped(fixture: &Fixture, language: &str) -> bool {
    match &fixture.skip {
        Some(rule) => rule.should_skip(language),
        None => false,
    }
}
#[allow(clippy::too_many_arguments)]
pub fn render_cargo_toml(
crate_name: &str,
dep_name: &str,
crate_path: &str,
needs_serde_json: bool,
needs_mock_server: bool,
needs_http_tests: bool,
needs_tokio: bool,
needs_tower_http: bool,
dep_mode: crate::config::DependencyMode,
version: Option<&str>,
features: &[String],
) -> String {
let e2e_name = format!("{dep_name}-e2e-rust");
let effective_features: Vec<&str> = features.iter().map(|s| s.as_str()).collect();
let features_str = if effective_features.is_empty() {
String::new()
} else {
format!(", default-features = false, features = {:?}", effective_features)
};
let dep_spec = match dep_mode {
crate::config::DependencyMode::Registry => {
let ver = version.unwrap_or("0.1.0");
if crate_name != dep_name {
format!("{dep_name} = {{ package = \"{crate_name}\", version = \"{ver}\"{features_str} }}")
} else if effective_features.is_empty() {
format!("{dep_name} = \"{ver}\"")
} else {
format!("{dep_name} = {{ version = \"{ver}\"{features_str} }}")
}
}
crate::config::DependencyMode::Local => {
if crate_name != dep_name {
format!("{dep_name} = {{ package = \"{crate_name}\", path = \"{crate_path}\"{features_str} }}")
} else if effective_features.is_empty() {
format!("{dep_name} = {{ path = \"{crate_path}\" }}")
} else {
format!("{dep_name} = {{ path = \"{crate_path}\"{features_str} }}")
}
}
};
let effective_needs_serde_json = needs_serde_json || needs_mock_server || needs_http_tests;
let serde_line = if effective_needs_serde_json {
"\nserde_json = \"1\""
} else {
""
};
let needs_axum = needs_mock_server || needs_http_tests;
let mock_lines = if needs_axum {
let mut lines = format!(
"\naxum = \"{axum}\"\nserde = {{ version = \"1\", features = [\"derive\"] }}\nwalkdir = \"{walkdir}\"",
axum = tv::cargo::AXUM,
walkdir = tv::cargo::WALKDIR,
);
if needs_mock_server {
lines.push_str(&format!(
"\ntokio-stream = \"{tokio_stream}\"",
tokio_stream = tv::cargo::TOKIO_STREAM
));
}
if needs_http_tests {
lines.push_str("\naxum-test = \"20\"\nbytes = \"1\"");
}
if needs_tower_http {
lines.push_str(&format!(
"\ntower-http = {{ version = \"{tower_http}\", features = [\"cors\", \"fs\"] }}\ntempfile = \"{tempfile}\"",
tower_http = tv::cargo::TOWER_HTTP,
tempfile = tv::cargo::TEMPFILE,
));
}
lines
} else {
String::new()
};
let mut machete_ignored: Vec<&str> = Vec::new();
if effective_needs_serde_json {
machete_ignored.push("\"serde_json\"");
}
if needs_axum {
machete_ignored.push("\"axum\"");
machete_ignored.push("\"serde\"");
machete_ignored.push("\"walkdir\"");
}
if needs_mock_server {
machete_ignored.push("\"tokio-stream\"");
}
if needs_http_tests {
machete_ignored.push("\"axum-test\"");
machete_ignored.push("\"bytes\"");
}
if needs_tower_http {
machete_ignored.push("\"tower-http\"");
machete_ignored.push("\"tempfile\"");
}
let machete_section = if machete_ignored.is_empty() {
String::new()
} else {
format!(
"\n[package.metadata.cargo-machete]\nignored = [{}]\n",
machete_ignored.join(", ")
)
};
let tokio_line = if needs_tokio {
"\ntokio = { version = \"1\", features = [\"full\"] }"
} else {
""
};
let bin_section = if needs_mock_server || needs_http_tests {
"\n[[bin]]\nname = \"mock-server\"\npath = \"src/main.rs\"\n"
} else {
""
};
let header = hash::header(CommentStyle::Hash);
format!(
r#"{header}
[workspace]
[package]
name = "{e2e_name}"
version = "0.1.0"
edition = "2021"
license = "MIT"
publish = false
{bin_section}
[dependencies]
{dep_spec}{serde_line}{mock_lines}{tokio_line}
{machete_section}"#
)
}
/// Renders the full contents of one `tests/<category>_test.rs` file.
///
/// The file consists of the generated header, a module doc line, the `use`
/// items required by the fixtures in the file, and one test function per
/// fixture. The result always ends with a newline.
fn render_test_file(
    category: &str,
    fixtures: &[&Fixture],
    e2e_config: &E2eConfig,
    dep_name: &str,
    needs_mock_server: bool,
) -> String {
    use heck::ToSnakeCase;
    use std::collections::{BTreeMap, BTreeSet};

    let mut out = hash::header(CommentStyle::DoubleSlash);
    let _ = writeln!(out, "//! E2e tests for category: {category}");
    out.push('\n');

    let module = resolve_module(e2e_config, dep_name);
    let field_resolver = FieldResolver::new(
        &e2e_config.fields,
        &e2e_config.fields_optional,
        &e2e_config.result_fields,
        &e2e_config.fields_array,
    );
    let has_mocked_fixture = fixtures.iter().any(|f| f.mock_response.is_some());

    // Functions called by mock-backed fixtures, grouped per module. The sorted
    // map/set make the emitted imports deterministic and free of duplicates.
    let mut fns_by_module: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for fixture in fixtures.iter().filter(|f| f.mock_response.is_some()) {
        let call = e2e_config.resolve_call(fixture.call.as_deref());
        fns_by_module
            .entry(resolve_module_for_call(call, dep_name))
            .or_default()
            .insert(resolve_function_name_for_call(call));
    }
    for (mod_name, fn_names) in &fns_by_module {
        let names: Vec<&str> = fn_names.iter().map(String::as_str).collect();
        let _ = match names.as_slice() {
            [only] => writeln!(out, "use {mod_name}::{only};"),
            several => writeln!(out, "use {mod_name}::{{{}}};", several.join(", ")),
        };
    }

    if fixtures.iter().any(|f| f.http.is_some()) {
        let _ = writeln!(out, "use {module}::{{App, RequestContext}};");
    }

    // Handle arguments of the default call need the config type plus one
    // constructor import per argument.
    let handle_args: Vec<_> = e2e_config
        .call
        .args
        .iter()
        .filter(|a| a.arg_type == "handle")
        .collect();
    if !handle_args.is_empty() {
        let _ = writeln!(out, "use {module}::CrawlConfig;");
    }
    for arg in handle_args {
        let constructor_name = format!("create_{}", arg.name.to_snake_case());
        let _ = writeln!(out, "use {module}::{constructor_name};");
    }

    if needs_mock_server && has_mocked_fixture {
        out.push_str("mod mock_server;\n");
        out.push_str("use mock_server::{MockRoute, MockServer};\n");
    }

    if fixtures.iter().any(|f| f.visitor.is_some()) {
        let visitor_trait = resolve_visitor_trait(&module);
        let _ = writeln!(out, "use {module}::{{{visitor_trait}, NodeContext, VisitResult}};");
    }
    out.push('\n');

    for fixture in fixtures {
        render_test_function(&mut out, fixture, e2e_config, dep_name, &field_resolver);
        out.push('\n');
    }

    if !out.ends_with('\n') {
        out.push('\n');
    }
    out
}
/// Appends the generated test function for a single fixture to `out`.
///
/// Dispatch:
/// - fixtures with an `http` spec are delegated to `render_http_test_function`;
/// - fixtures with neither `http` nor `mock_response` get a `#[tokio::test]`
///   stub that only contains a TODO comment;
/// - all other fixtures get a call-based test: optional mock-server setup,
///   argument bindings, the call itself, then the assertions.
fn render_test_function(
out: &mut String,
fixture: &Fixture,
e2e_config: &E2eConfig,
dep_name: &str,
field_resolver: &FieldResolver,
) {
if fixture.http.is_some() {
render_http_test_function(out, fixture, dep_name);
return;
}
// No HTTP spec and no mock response: there is nothing to call, emit a stub.
if fixture.http.is_none() && fixture.mock_response.is_none() {
let fn_name = sanitize_ident(&fixture.id);
let description = &fixture.description;
let _ = writeln!(out, "#[tokio::test]");
let _ = writeln!(out, "async fn test_{fn_name}() {{");
let _ = writeln!(out, " // {description}");
let _ = writeln!(
out,
" // TODO: implement when a callable API is available for this fixture type."
);
let _ = writeln!(out, "}}");
return;
}
let fn_name = sanitize_ident(&fixture.id);
let description = &fixture.description;
let call_config = e2e_config.resolve_call(fixture.call.as_deref());
let function_name = resolve_function_name_for_call(call_config);
let module = resolve_module_for_call(call_config, dep_name);
let result_var = &call_config.result_var;
let has_mock = fixture.mock_response.is_some();
// The mock server is started with `.await`, so mock-backed tests are async
// even when the call itself is synchronous.
let is_async = call_config.r#async || has_mock;
if is_async {
let _ = writeln!(out, "#[tokio::test]");
let _ = writeln!(out, "async fn test_{fn_name}() {{");
} else {
let _ = writeln!(out, "#[test]");
let _ = writeln!(out, "fn test_{fn_name}() {{");
}
let _ = writeln!(out, " // {description}");
if has_mock {
render_mock_server_setup(out, fixture, e2e_config);
}
let has_error_assertion = fixture.assertions.iter().any(|a| a.assertion_type == "error");
let rust_overrides = call_config.overrides.get("rust");
let wrap_options_in_some = rust_overrides.is_some_and(|o| o.wrap_options_in_some);
let extra_args: Vec<String> = rust_overrides.map(|o| o.extra_args.clone()).unwrap_or_default();
// Emit one or more `let` bindings per configured argument and collect the
// expressions that are passed at the call site.
let mut arg_exprs: Vec<String> = Vec::new();
for arg in &call_config.args {
let value = resolve_field(&fixture.input, &arg.field);
let var_name = &arg.name;
let (bindings, expr) = render_rust_arg(
var_name,
value,
&arg.arg_type,
arg.optional,
&module,
&fixture.id,
if has_mock {
Some("mock_server.url.as_str()")
} else {
None
},
arg.owned,
arg.element_type.as_deref(),
);
for binding in &bindings {
let _ = writeln!(out, " {binding}");
}
// With `wrap_options_in_some`, json_object arguments are passed as
// `Some(..)`; a by-reference expression is turned into an owned clone.
let final_expr = if wrap_options_in_some && arg.arg_type == "json_object" {
if let Some(rest) = expr.strip_prefix('&') {
format!("Some({rest}.clone())")
} else {
format!("Some({expr})")
}
} else {
expr
};
arg_exprs.push(final_expr);
}
// A visitor fixture defines a local `_TestVisitor` and passes it as the
// trailing argument; otherwise the configured extra arguments are appended.
if let Some(visitor_spec) = &fixture.visitor {
let _ = writeln!(out, " struct _TestVisitor;");
let _ = writeln!(out, " impl {} for _TestVisitor {{", resolve_visitor_trait(&module));
for (method_name, action) in &visitor_spec.callbacks {
emit_rust_visitor_method(out, method_name, action);
}
let _ = writeln!(out, " }}");
let _ = writeln!(
out,
" let visitor = std::rc::Rc::new(std::cell::RefCell::new(_TestVisitor));"
);
arg_exprs.push("Some(visitor)".to_string());
} else {
arg_exprs.extend(extra_args);
}
let args_str = arg_exprs.join(", ");
let await_suffix = if is_async { ".await" } else { "" };
let result_is_tree = call_config.result_var == "tree";
let result_is_simple = rust_overrides.is_some_and(|o| o.result_is_simple);
let result_is_vec = rust_overrides.is_some_and(|o| o.result_is_vec);
let result_is_option = rust_overrides.is_some_and(|o| o.result_is_option);
// Error-expecting fixtures: bind the raw call result without unwrapping,
// render every assertion in error mode, and finish the function.
if has_error_assertion {
let _ = writeln!(out, " let {result_var} = {function_name}({args_str}){await_suffix};");
for assertion in &fixture.assertions {
render_assertion(
out,
assertion,
result_var,
&module,
dep_name,
true,
&[],
field_resolver,
result_is_tree,
result_is_simple,
false,
false,
);
}
let _ = writeln!(out, "}}");
return;
}
let has_not_error = fixture.assertions.iter().any(|a| a.assertion_type == "not_error");
// An assertion is "usable" when it is neither `not_error` nor `error`, is
// not a `method_result` lacking a method or using a check outside the list
// below, and any field it names is valid for the result. Without a usable
// assertion the call result is bound to `_`.
let has_usable_assertion = fixture.assertions.iter().any(|a| {
if a.assertion_type == "not_error" || a.assertion_type == "error" {
return false;
}
if a.assertion_type == "method_result" {
let supported_checks = [
"equals",
"is_true",
"is_false",
"greater_than_or_equal",
"count_min",
"is_error",
"contains",
"not_empty",
"is_empty",
];
let check = a.check.as_deref().unwrap_or("is_true");
if a.method.is_none() || !supported_checks.contains(&check) {
return false;
}
}
match &a.field {
Some(f) if !f.is_empty() => field_resolver.is_valid_for_result(f),
_ => true,
}
});
let result_binding = if has_usable_assertion {
result_var.to_string()
} else {
"_".to_string()
};
let has_field_access = fixture
.assertions
.iter()
.any(|a| a.field.as_ref().is_some_and(|f| !f.is_empty()));
// True when no assertion names a field and every assertion is one of the
// listed types; in that case the call result is bound without unwrapping.
let only_emptiness_checks = !has_field_access
&& fixture.assertions.iter().all(|a| {
matches!(
a.assertion_type.as_str(),
"is_empty" | "is_false" | "not_empty" | "is_true" | "not_error"
)
});
// The rust override's `returns_result`, when set, wins over the call config.
let returns_result = rust_overrides
.and_then(|o| o.returns_result)
.unwrap_or(call_config.returns_result);
let unwrap_suffix = if returns_result {
".expect(\"should succeed\")"
} else {
""
};
// Only the middle branch appends the `.expect(..)` suffix; the first and
// last branches emit the same unwrapped binding.
if only_emptiness_checks || !returns_result {
let _ = writeln!(
out,
" let {result_binding} = {function_name}({args_str}){await_suffix};"
);
} else if has_not_error || !fixture.assertions.is_empty() {
let _ = writeln!(
out,
" let {result_binding} = {function_name}({args_str}){await_suffix}{unwrap_suffix};"
);
} else {
let _ = writeln!(
out,
" let {result_binding} = {function_name}({args_str}){await_suffix};"
);
}
let string_assertion_types = [
"equals",
"contains",
"contains_all",
"contains_any",
"not_contains",
"starts_with",
"ends_with",
"min_length",
"max_length",
"matches_regex",
];
// For string-style assertions on a field, emit at most one unwrap binding
// per field (when the resolver provides one) and remember the local
// variable so `render_assertion` can use it. Skipped for vec results.
let mut unwrapped_fields: Vec<(String, String)> = Vec::new(); if !result_is_vec {
for assertion in &fixture.assertions {
if let Some(f) = &assertion.field {
if !f.is_empty()
&& string_assertion_types.contains(&assertion.assertion_type.as_str())
&& !unwrapped_fields.iter().any(|(ff, _)| ff == f)
{
// Assertions whose expected value is present but not a string are skipped here.
let is_string_assertion = assertion.value.as_ref().is_none_or(|v| v.is_string());
if !is_string_assertion {
continue;
}
if let Some((binding, local_var)) = field_resolver.rust_unwrap_binding(f, result_var) {
let _ = writeln!(out, " {binding}");
unwrapped_fields.push((f.clone(), local_var));
}
}
}
}
}
// `not_error` assertions emit nothing here and are skipped.
for assertion in &fixture.assertions {
if assertion.assertion_type == "not_error" {
continue;
}
render_assertion(
out,
assertion,
result_var,
&module,
dep_name,
false,
&unwrapped_fields,
field_resolver,
result_is_tree,
result_is_simple,
result_is_vec,
result_is_option,
);
}
let _ = writeln!(out, "}}");
}
/// Renders one call argument as `(bindings, expression)`.
///
/// `bindings` are `let` statements to emit before the call and `expression`
/// is what is passed at the call site. Special argument types:
/// - `mock_url`: URL built from the `MOCK_SERVER_URL` environment variable and
///   the fixture id, passed by reference;
/// - `base_url`: `mock_base_url` verbatim when one is supplied, otherwise
///   treated like any other argument;
/// - `handle`: built through `create_<snake_case name>`, with a parsed
///   `CrawlConfig` unless the value is null or an empty object;
/// - `json_object`: delegated to `render_json_object_arg`.
#[allow(clippy::too_many_arguments)]
fn render_rust_arg(
    name: &str,
    value: &serde_json::Value,
    arg_type: &str,
    optional: bool,
    module: &str,
    fixture_id: &str,
    mock_base_url: Option<&str>,
    owned: bool,
    element_type: Option<&str>,
) -> (Vec<String>, String) {
    use heck::ToSnakeCase;

    match arg_type {
        "mock_url" => {
            let binding = format!(
                "let {name} = format!(\"{{}}/fixtures/{{}}\", std::env::var(\"MOCK_SERVER_URL\").expect(\"MOCK_SERVER_URL not set\"), \"{fixture_id}\");"
            );
            return (vec![binding], format!("&{name}"));
        }
        "base_url" => {
            // Without a mock server the argument falls through to the generic path.
            if let Some(url_expr) = mock_base_url {
                return (Vec::new(), url_expr.to_string());
            }
        }
        "handle" => {
            let constructor_name = format!("create_{}", name.to_snake_case());
            let no_config = value.is_null() || value.as_object().is_some_and(|obj| obj.is_empty());
            let bindings = if no_config {
                vec![format!(
                    "let {name} = {constructor_name}(None).expect(\"handle creation should succeed\");"
                )]
            } else {
                // Embed the JSON in a regular string literal.
                let escaped = serde_json::to_string(value)
                    .unwrap_or_default()
                    .replace('\\', "\\\\")
                    .replace('"', "\\\"");
                vec![
                    format!(
                        "let {name}_config: CrawlConfig = serde_json::from_str(\"{escaped}\").expect(\"config should parse\");"
                    ),
                    format!(
                        "let {name} = {constructor_name}(Some({name}_config)).expect(\"handle creation should succeed\");"
                    ),
                ]
            };
            return (bindings, format!("&{name}"));
        }
        "json_object" => {
            return render_json_object_arg(name, value, optional, owned, element_type, module);
        }
        _ => {}
    }

    let is_string = arg_type == "string";
    let is_bytes = arg_type == "bytes";
    let is_null = value.is_null();

    if !optional {
        if is_null {
            // Required argument without a fixture value: bind a type-appropriate default.
            let default_val = match arg_type {
                "string" => "String::new()",
                "int" | "integer" => "0",
                "float" | "number" => "0.0_f64",
                "bool" | "boolean" => "false",
                _ => "Default::default()",
            };
            let call_expr = if is_string {
                format!("&{name}")
            } else {
                name.to_string()
            };
            return (vec![format!("let {name} = {default_val};")], call_expr);
        }
        let literal = json_to_rust_literal(value, arg_type);
        let call_expr = if is_bytes {
            format!("{name}.as_bytes()")
        } else {
            name.to_string()
        };
        return (vec![format!("let {name} = {literal};")], call_expr);
    }

    // Optional argument: bind an `Option` and adapt it at the call site.
    let declaration = if is_null {
        match arg_type {
            "string" => format!("let {name}: Option<String> = None;"),
            "bytes" => format!("let {name}: Option<Vec<u8>> = None;"),
            _ => format!("let {name} = None;"),
        }
    } else {
        let literal = json_to_rust_literal(value, arg_type);
        format!("let {name} = Some({literal});")
    };
    let call_expr = if is_string {
        format!("{name}.as_deref()")
    } else if is_bytes {
        format!("{name}.as_deref().map(|v| v.as_slice())")
    } else {
        name.to_string()
    };
    (vec![declaration], call_expr)
}
/// Renders a `json_object` argument as `(bindings, expression)`.
///
/// An optional argument without a value is bound to `Default::default()`.
/// Otherwise the fixture value (keys normalised to snake_case) is embedded
/// through `serde_json::json!` and deserialised with `serde_json::from_value`,
/// into `Vec<element_type>` when an element type is given. The argument is
/// passed by reference unless `owned` is set.
fn render_json_object_arg(
    name: &str,
    value: &serde_json::Value,
    optional: bool,
    owned: bool,
    element_type: Option<&str>,
    _module: &str,
) -> (Vec<String>, String) {
    // The call-site expression is the same for every path below.
    let call_expr = if owned {
        name.to_string()
    } else {
        format!("&{name}")
    };

    if optional && value.is_null() {
        return (vec![format!("let {name} = Default::default();")], call_expr);
    }

    let snake_cased = super::normalize_json_keys_to_snake_case(value);
    let json_literal = json_value_to_macro_literal(&snake_cased);
    let deser_expr = match element_type {
        Some(elem) => format!("serde_json::from_value::<Vec<{elem}>>({name}_json).unwrap()"),
        None => format!("serde_json::from_value({name}_json).unwrap()"),
    };
    let bindings = vec![
        format!("let {name}_json = serde_json::json!({json_literal});"),
        format!("let {name} = {deser_expr};"),
    ];
    (bindings, call_expr)
}
fn json_value_to_macro_literal(value: &serde_json::Value) -> String {
match value {
serde_json::Value::Null => "null".to_string(),
serde_json::Value::Bool(b) => format!("{b}"),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::String(s) => {
let escaped = s.replace('\\', "\\\\").replace('"', "\\\"");
format!("\"{escaped}\"")
}
serde_json::Value::Array(arr) => {
let items: Vec<String> = arr.iter().map(json_value_to_macro_literal).collect();
format!("[{}]", items.join(", "))
}
serde_json::Value::Object(obj) => {
let entries: Vec<String> = obj
.iter()
.map(|(k, v)| {
let escaped_key = k.replace('\\', "\\\\").replace('"', "\\\"");
format!("\"{escaped_key}\": {}", json_value_to_macro_literal(v))
})
.collect();
format!("{{{}}}", entries.join(", "))
}
}
}
fn json_to_rust_literal(value: &serde_json::Value, arg_type: &str) -> String {
match value {
serde_json::Value::Null => "None".to_string(),
serde_json::Value::Bool(b) => format!("{b}"),
serde_json::Value::Number(n) => {
if arg_type.contains("float") || arg_type.contains("f64") || arg_type.contains("f32") {
if let Some(f) = n.as_f64() {
return format!("{f}_f64");
}
}
n.to_string()
}
serde_json::Value::String(s) => rust_raw_string(s),
serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
let json_str = serde_json::to_string(value).unwrap_or_default();
let literal = rust_raw_string(&json_str);
format!("serde_json::from_str({literal}).unwrap()")
}
}
}
/// How a generated HTTP test issues its request on the test server.
enum ServerCall<'a> {
/// Rendered as `server.<name>(path)`; the payload is the method name
/// (`get`, `post`, `put`, `patch` or `delete`).
Shorthand(&'a str),
/// Rendered as `server.method(axum::http::Method::<NAME>, path)`; the payload
/// is the method constant name (`HEAD`, `OPTIONS` or `TRACE`).
AxumMethod(&'a str),
}
/// How a generated HTTP test registers its route on the app under test.
enum RouteRegistration<'a> {
/// Rendered as `<dep>::<name>(route)`; the payload is the helper name
/// (`get`, `post`, `put`, `patch` or `delete`).
Shorthand(&'a str),
/// Rendered as `<dep>::RouteBuilder::new(<dep>::Method::<Variant>, route)`;
/// the payload is the variant name (`Head`, `Options` or `Trace`).
Explicit(&'a str),
}
/// Appends a `#[tokio::test]` function for an HTTP fixture to `out`.
///
/// The generated test registers a handler on `<dep_name>::App` that always
/// answers with the fixture's expected status and JSON body, optionally wraps
/// the router in a CORS layer, sends the fixture's request through
/// `axum_test::TestServer` and asserts on the response status. Fixtures with
/// static-files middleware are delegated to `render_static_files_test`.
/// Returns without emitting anything when the fixture has no `http` spec.
fn render_http_test_function(out: &mut String, fixture: &Fixture, dep_name: &str) {
let http = match &fixture.http {
Some(h) => h,
None => return,
};
let fn_name = sanitize_ident(&fixture.id);
let description = &fixture.description;
let route = &http.handler.route;
// Handler method -> route registration form; unrecognised methods fall back to `get`.
let route_reg = match http.handler.method.to_lowercase().as_str() {
"get" => RouteRegistration::Shorthand("get"),
"post" => RouteRegistration::Shorthand("post"),
"put" => RouteRegistration::Shorthand("put"),
"patch" => RouteRegistration::Shorthand("patch"),
"delete" => RouteRegistration::Shorthand("delete"),
"head" => RouteRegistration::Explicit("Head"),
"options" => RouteRegistration::Explicit("Options"),
"trace" => RouteRegistration::Explicit("Trace"),
_ => RouteRegistration::Shorthand("get"),
};
// Request method -> test-server call form; unrecognised methods fall back to `get`.
let server_call = match http.request.method.to_uppercase().as_str() {
"GET" => ServerCall::Shorthand("get"),
"POST" => ServerCall::Shorthand("post"),
"PUT" => ServerCall::Shorthand("put"),
"PATCH" => ServerCall::Shorthand("patch"),
"DELETE" => ServerCall::Shorthand("delete"),
"HEAD" => ServerCall::AxumMethod("HEAD"),
"OPTIONS" => ServerCall::AxumMethod("OPTIONS"),
"TRACE" => ServerCall::AxumMethod("TRACE"),
_ => ServerCall::Shorthand("get"),
};
let req_path = &http.request.path;
let status = http.expected_response.status_code;
// Expected response body serialised to JSON; empty string when the fixture has none.
let body_str = match &http.expected_response.body {
Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
None => String::new(),
};
let body_literal = rust_raw_string(&body_str);
// Request body serialised to JSON; an empty string means "send no body".
let req_body_str = match &http.request.body {
Some(b) => serde_json::to_string(b).unwrap_or_else(|_| "{}".to_string()),
None => String::new(),
};
let has_req_body = !req_body_str.is_empty();
let middleware = http.handler.middleware.as_ref();
let cors_cfg: Option<&CorsConfig> = middleware.and_then(|m| m.cors.as_ref());
let static_files_cfgs: Option<&Vec<StaticFilesConfig>> = middleware.and_then(|m| m.static_files.as_ref());
let has_static_files = static_files_cfgs.is_some_and(|v| !v.is_empty());
let _ = writeln!(out, "#[tokio::test]");
let _ = writeln!(out, "async fn test_{fn_name}() {{");
let _ = writeln!(out, " // {description}");
// Static-file fixtures use a different test body; the helper also closes the function.
if has_static_files {
render_static_files_test(out, fixture, static_files_cfgs.unwrap(), &server_call, req_path, status);
return;
}
let _ = writeln!(out, " let expected_body = {body_literal}.to_string();");
let _ = writeln!(out, " let mut app = {dep_name}::App::new();");
match &route_reg {
RouteRegistration::Shorthand(method) => {
let _ = writeln!(
out,
" app.route({dep_name}::{method}({route:?}), move |_ctx: {dep_name}::RequestContext| {{"
);
}
RouteRegistration::Explicit(variant) => {
let _ = writeln!(
out,
" app.route({dep_name}::RouteBuilder::new({dep_name}::Method::{variant}, {route:?}), move |_ctx: {dep_name}::RequestContext| {{"
);
}
}
// Handler body: respond with the expected status and body as JSON.
let _ = writeln!(out, " let body = expected_body.clone();");
let _ = writeln!(out, " async move {{");
let _ = writeln!(out, " Ok(axum::http::Response::builder()");
let _ = writeln!(out, " .status({status}u16)");
let _ = writeln!(out, " .header(\"content-type\", \"application/json\")");
let _ = writeln!(out, " .body(axum::body::Body::from(body))");
let _ = writeln!(out, " .unwrap())");
let _ = writeln!(out, " }}");
let _ = writeln!(out, " }}).unwrap();");
let _ = writeln!(out, " let router = app.into_router().unwrap();");
if let Some(cors) = cors_cfg {
render_cors_layer(out, cors);
}
let _ = writeln!(out, " let server = axum_test::TestServer::new(router);");
// Start of the request builder chain; headers and body are appended
// below and the chain is terminated by `.await;`.
match &server_call {
ServerCall::Shorthand(method) => {
let _ = writeln!(out, " let response = server.{method}({req_path:?})");
}
ServerCall::AxumMethod(method) => {
let _ = writeln!(
out,
" let response = server.method(axum::http::Method::{method}, {req_path:?})"
);
}
}
for (name, value) in &http.request.headers {
let n = rust_raw_string(name);
let v = rust_raw_string(value);
let _ = writeln!(out, " .add_header({n}, {v})");
}
if has_req_body {
let req_body_literal = rust_raw_string(&req_body_str);
let _ = writeln!(
out,
" .bytes(bytes::Bytes::copy_from_slice({req_body_literal}.as_bytes()))"
);
}
let _ = writeln!(out, " .await;");
// With a CORS config and an expected 2xx status, any success status is
// accepted; otherwise the exact status code is asserted.
if cors_cfg.is_some() && (200..300).contains(&status) {
let _ = writeln!(
out,
" assert!(response.status_code().is_success(), \"expected CORS success status, got {{}}\", response.status_code());"
);
} else {
let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
}
let _ = writeln!(out, "}}");
}
/// Appends code that builds a `tower_http` `CorsLayer` from the fixture's
/// CORS config and applies it to the local `router`.
///
/// Empty origin, method or header lists are rendered as `Any`. A `max_age`,
/// when present, ends the builder chain; otherwise a bare `;` does.
fn render_cors_layer(out: &mut String, cors: &CorsConfig) {
    out.push_str(" // Apply CorsLayer from tower-http based on fixture CORS config.\n");
    out.push_str(" use tower_http::cors::CorsLayer;\n");
    out.push_str(" use axum::http::{HeaderName, HeaderValue, Method};\n");
    out.push_str(" let cors_layer = CorsLayer::new()\n");

    // Origins
    if cors.allow_origins.is_empty() {
        out.push_str(" .allow_origin(tower_http::cors::Any)\n");
    } else {
        out.push_str(" .allow_origin([\n");
        for origin in &cors.allow_origins {
            let _ = writeln!(out, " \"{origin}\".parse::<HeaderValue>().unwrap(),");
        }
        out.push_str(" ])\n");
    }

    // Methods
    if cors.allow_methods.is_empty() {
        out.push_str(" .allow_methods(tower_http::cors::Any)\n");
    } else {
        let mut method_exprs: Vec<String> = Vec::new();
        for method in &cors.allow_methods {
            method_exprs.push(format!("Method::{}", method.to_uppercase()));
        }
        let _ = writeln!(out, " .allow_methods([{}])", method_exprs.join(", "));
    }

    // Headers: well-known names map to the axum constants, the rest to
    // `HeaderName::from_static` on the lower-cased name.
    if cors.allow_headers.is_empty() {
        out.push_str(" .allow_headers(tower_http::cors::Any)\n");
    } else {
        let mut header_exprs: Vec<String> = Vec::new();
        for header in &cors.allow_headers {
            let lower = header.to_lowercase();
            let rendered = if lower == "content-type" {
                "axum::http::header::CONTENT_TYPE".to_string()
            } else if lower == "authorization" {
                "axum::http::header::AUTHORIZATION".to_string()
            } else if lower == "accept" {
                "axum::http::header::ACCEPT".to_string()
            } else {
                format!("HeaderName::from_static(\"{lower}\")")
            };
            header_exprs.push(rendered);
        }
        let _ = writeln!(out, " .allow_headers([{}])", header_exprs.join(", "));
    }

    match cors.max_age {
        Some(secs) => {
            let _ = writeln!(out, " .max_age(std::time::Duration::from_secs({secs}));");
        }
        None => out.push_str(" ;\n"),
    }
    out.push_str(" let router = router.layer(cors_layer);\n");
}
/// Appends the body (and closing brace) of a static-files HTTP test.
///
/// The generated test writes every configured file into a temporary
/// directory, mounts one `ServeDir` per config under its route prefix, sends
/// the fixture's request and asserts the exact status code.
///
/// Panics if the fixture has no `http` spec.
fn render_static_files_test(
    out: &mut String,
    fixture: &Fixture,
    cfgs: &[StaticFilesConfig],
    server_call: &ServerCall<'_>,
    req_path: &str,
    status: u16,
) {
    let http = fixture.http.as_ref().unwrap();

    out.push_str(" use tower_http::services::ServeDir;\n");
    out.push_str(" use axum::Router;\n");
    out.push_str(" let tmp_dir = tempfile::tempdir().expect(\"tmp dir\");\n");
    out.push_str(" let mut router = Router::new();\n");

    for cfg in cfgs {
        for file in &cfg.files {
            // Normalise Windows-style separators before embedding the path.
            let file_path = file.path.replace('\\', "/");
            let content = rust_raw_string(&file.content);
            // Files in sub-directories need their parent created first.
            if let Some((parent, _file_name)) = file_path.rsplit_once('/') {
                let _ = writeln!(
                    out,
                    " std::fs::create_dir_all(tmp_dir.path().join(\"{parent}\")).unwrap();"
                );
            }
            let _ = writeln!(
                out,
                " std::fs::write(tmp_dir.path().join(\"{file_path}\"), {content}).unwrap();"
            );
        }

        let serve_dir_expr = match cfg.index_file {
            true => "ServeDir::new(tmp_dir.path()).append_index_html_on_directories(true)",
            false => "ServeDir::new(tmp_dir.path())",
        };
        let _ = writeln!(
            out,
            " router = router.nest_service({:?}, {serve_dir_expr});",
            cfg.route_prefix
        );
    }

    out.push_str(" let server = axum_test::TestServer::new(router);\n");
    let _ = match server_call {
        ServerCall::Shorthand(method) => {
            writeln!(out, " let response = server.{method}({req_path:?})")
        }
        ServerCall::AxumMethod(method) => writeln!(
            out,
            " let response = server.method(axum::http::Method::{method}, {req_path:?})"
        ),
    };
    for (header_name, header_value) in &http.request.headers {
        let _ = writeln!(
            out,
            " .add_header({}, {})",
            rust_raw_string(header_name),
            rust_raw_string(header_value)
        );
    }
    out.push_str(" .await;\n");
    let _ = writeln!(out, " assert_eq!(response.status_code().as_u16(), {status}u16);");
    out.push_str("}\n");
}
/// Appends code that declares a `MockRoute` for the fixture's mock response
/// and starts a `MockServer` bound to the local `mock_server`.
///
/// Path and method come from the resolved call config (defaults `/` and
/// `POST`). Streaming fixtures fill `stream_chunks` and leave `body` empty;
/// other fixtures fill `body` (default `{}`) and leave `stream_chunks` empty.
/// Headers are emitted sorted by name. Does nothing when the fixture has no
/// mock response.
fn render_mock_server_setup(out: &mut String, fixture: &Fixture, e2e_config: &E2eConfig) {
    let Some(mock) = fixture.mock_response.as_ref() else {
        return;
    };
    let call_config = e2e_config.resolve_call(fixture.call.as_deref());
    let path = call_config.path.as_deref().unwrap_or("/");
    let method = call_config.method.as_deref().unwrap_or("POST");
    let status = mock.status;

    // Sort so the generated file does not depend on map iteration order.
    let mut sorted_headers: Vec<(&String, &String)> = mock.headers.iter().collect();
    sorted_headers.sort_by_key(|entry| entry.0);

    // Fields shared by the streaming and non-streaming variants.
    out.push_str(" let mock_route = MockRoute {\n");
    let _ = writeln!(out, " path: \"{path}\",");
    let _ = writeln!(out, " method: \"{method}\",");
    let _ = writeln!(out, " status: {status},");

    match &mock.stream_chunks {
        Some(chunks) => {
            out.push_str(" body: String::new(),\n");
            out.push_str(" stream_chunks: vec![\n");
            for chunk in chunks {
                // String chunks are embedded verbatim, anything else as serialised JSON.
                let chunk_literal = match chunk {
                    serde_json::Value::String(text) => rust_raw_string(text),
                    other => rust_raw_string(&serde_json::to_string(other).unwrap_or_default()),
                };
                let _ = writeln!(out, " {chunk_literal}.to_string(),");
            }
            out.push_str(" ],\n");
        }
        None => {
            let body_json = mock
                .body
                .as_ref()
                .map_or_else(|| "{}".to_string(), |b| serde_json::to_string(b).unwrap_or_default());
            let _ = writeln!(out, " body: {}.to_string(),", rust_raw_string(&body_json));
            out.push_str(" stream_chunks: vec![],\n");
        }
    }

    out.push_str(" headers: vec![\n");
    for (header_name, header_value) in &sorted_headers {
        let _ = writeln!(
            out,
            " ({}.to_string(), {}.to_string()),",
            rust_raw_string(header_name),
            rust_raw_string(header_value)
        );
    }
    out.push_str(" ],\n");
    out.push_str(" };\n");
    out.push_str(" let mock_server = MockServer::start(vec![mock_route]).await;\n");
}
/// Returns the source text of the in-process mock server module used by the
/// generated Rust e2e tests.
///
/// The result is the generated-file header (double-slash comment style)
/// followed by a fixed template. The template defines `MockRoute`,
/// `MockServer` (with `start` / `shutdown` and an aborting `Drop`), and an
/// axum fallback handler that matches requests by path and method.
pub fn render_mock_server_module() -> String {
    let mut module_source = hash::header(CommentStyle::DoubleSlash);
    module_source.push_str(
        r#"//
// Minimal axum-based mock HTTP server for e2e tests.
use std::net::SocketAddr;
use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use tokio::net::TcpListener;
/// A single mock route: match by path + method, return a configured response.
#[derive(Clone, Debug)]
pub struct MockRoute {
/// URL path to match, e.g. `"/v1/chat/completions"`.
pub path: &'static str,
/// HTTP method to match, e.g. `"POST"` or `"GET"`.
pub method: &'static str,
/// HTTP status code to return.
pub status: u16,
/// Response body JSON string (used when `stream_chunks` is empty).
pub body: String,
/// Ordered SSE data payloads for streaming responses.
/// Each entry becomes `data: <chunk>\n\n` in the response.
/// A final `data: [DONE]\n\n` is always appended.
pub stream_chunks: Vec<String>,
/// Response headers to apply (name, value) pairs.
/// Multiple entries with the same name produce multiple header lines.
pub headers: Vec<(String, String)>,
}
struct ServerState {
routes: Vec<MockRoute>,
}
pub struct MockServer {
/// Base URL of the mock server, e.g. `"http://127.0.0.1:54321"`.
pub url: String,
handle: tokio::task::JoinHandle<()>,
}
impl MockServer {
/// Start a mock server with the given routes. Binds to a random port on
/// localhost and returns immediately once the server is listening.
pub async fn start(routes: Vec<MockRoute>) -> Self {
let state = Arc::new(ServerState { routes });
let app = Router::new().fallback(handle_request).with_state(state);
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Failed to bind mock server port");
let addr: SocketAddr = listener.local_addr().expect("Failed to get local addr");
let url = format!("http://{addr}");
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.expect("Mock server failed");
});
MockServer { url, handle }
}
/// Stop the mock server.
pub fn shutdown(self) {
self.handle.abort();
}
}
impl Drop for MockServer {
fn drop(&mut self) {
self.handle.abort();
}
}
async fn handle_request(State(state): State<Arc<ServerState>>, req: Request<Body>) -> Response {
let path = req.uri().path().to_owned();
let method = req.method().as_str().to_uppercase();
for route in &state.routes {
if route.path == path && route.method.to_uppercase() == method {
let status =
StatusCode::from_u16(route.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !route.stream_chunks.is_empty() {
// Build SSE body: data: <chunk>\n\n ... data: [DONE]\n\n
let mut sse = String::new();
for chunk in &route.stream_chunks {
sse.push_str("data: ");
sse.push_str(chunk);
sse.push_str("\n\n");
}
sse.push_str("data: [DONE]\n\n");
let mut builder = Response::builder()
.status(status)
.header("content-type", "text/event-stream")
.header("cache-control", "no-cache");
for (name, value) in &route.headers {
builder = builder.header(name, value);
}
return builder.body(Body::from(sse)).unwrap().into_response();
}
let mut builder =
Response::builder().status(status).header("content-type", "application/json");
for (name, value) in &route.headers {
builder = builder.header(name, value);
}
return builder.body(Body::from(route.body.clone())).unwrap().into_response();
}
}
// No matching route → 404.
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from(format!("No mock route for {method} {path}")))
.unwrap()
.into_response()
}
"#,
    );
    module_source
}
/// Returns the source text of the standalone `mock-server` binary used by
/// cross-language e2e tests.
///
/// The result is the generated-file header (double-slash comment style)
/// followed by a fixed template. The template loads fixture JSON files
/// recursively, registers one route per fixture under `/fixtures/{id}`,
/// prints `MOCK_SERVER_URL=...` once bound, and serves until stdin closes.
pub fn render_mock_server_binary() -> String {
    let mut binary_source = hash::header(CommentStyle::DoubleSlash);
    binary_source.push_str(
        r#"//
// Standalone mock HTTP server binary for cross-language e2e tests.
// Reads fixture JSON files and serves mock responses on /fixtures/{fixture_id}.
//
// Usage: mock-server [fixtures-dir]
// fixtures-dir defaults to "../../fixtures"
//
// Prints `MOCK_SERVER_URL=http://127.0.0.1:<port>` to stdout once listening,
// then blocks until stdin is closed (parent process exit triggers cleanup).
use std::collections::HashMap;
use std::io::{self, BufRead};
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use serde::Deserialize;
use tokio::net::TcpListener;
// ---------------------------------------------------------------------------
// Fixture types (mirrors alef-e2e's fixture.rs for runtime deserialization)
// Supports both schemas:
// liter-llm: mock_response: { status, body, stream_chunks }
// spikard: http.expected_response: { status_code, body, headers }
// ---------------------------------------------------------------------------
#[derive(Debug, Deserialize)]
struct MockResponse {
status: u16,
#[serde(default)]
body: Option<serde_json::Value>,
#[serde(default)]
stream_chunks: Option<Vec<serde_json::Value>>,
#[serde(default)]
headers: HashMap<String, String>,
}
#[derive(Debug, Deserialize)]
struct HttpExpectedResponse {
status_code: u16,
#[serde(default)]
body: Option<serde_json::Value>,
#[serde(default)]
headers: HashMap<String, String>,
}
#[derive(Debug, Deserialize)]
struct HttpFixture {
expected_response: HttpExpectedResponse,
}
#[derive(Debug, Deserialize)]
struct Fixture {
id: String,
#[serde(default)]
mock_response: Option<MockResponse>,
#[serde(default)]
http: Option<HttpFixture>,
}
impl Fixture {
/// Bridge both schemas into a unified MockResponse.
fn as_mock_response(&self) -> Option<MockResponse> {
if let Some(mock) = &self.mock_response {
return Some(MockResponse {
status: mock.status,
body: mock.body.clone(),
stream_chunks: mock.stream_chunks.clone(),
headers: mock.headers.clone(),
});
}
if let Some(http) = &self.http {
return Some(MockResponse {
status: http.expected_response.status_code,
body: http.expected_response.body.clone(),
stream_chunks: None,
headers: http.expected_response.headers.clone(),
});
}
None
}
}
// ---------------------------------------------------------------------------
// Route table
// ---------------------------------------------------------------------------
#[derive(Clone, Debug)]
struct MockRoute {
status: u16,
body: String,
stream_chunks: Vec<String>,
headers: Vec<(String, String)>,
}
type RouteTable = Arc<HashMap<String, MockRoute>>;
// ---------------------------------------------------------------------------
// Axum handler
// ---------------------------------------------------------------------------
async fn handle_request(State(routes): State<RouteTable>, req: Request<Body>) -> Response {
let path = req.uri().path().to_owned();
// Try exact match first
if let Some(route) = routes.get(&path) {
return serve_route(route);
}
// Try prefix match: find a route that is a prefix of the request path
// This allows /fixtures/basic_chat/v1/chat/completions to match /fixtures/basic_chat
for (route_path, route) in routes.iter() {
if path.starts_with(route_path) && (path.len() == route_path.len() || path.as_bytes()[route_path.len()] == b'/') {
return serve_route(route);
}
}
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from(format!("No mock route for {path}")))
.unwrap()
.into_response()
}
fn serve_route(route: &MockRoute) -> Response {
let status = StatusCode::from_u16(route.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !route.stream_chunks.is_empty() {
let mut sse = String::new();
for chunk in &route.stream_chunks {
sse.push_str("data: ");
sse.push_str(chunk);
sse.push_str("\n\n");
}
sse.push_str("data: [DONE]\n\n");
let mut builder = Response::builder()
.status(status)
.header("content-type", "text/event-stream")
.header("cache-control", "no-cache");
for (name, value) in &route.headers {
builder = builder.header(name, value);
}
return builder.body(Body::from(sse)).unwrap().into_response();
}
// Only set the default content-type if the fixture does not override it.
// Use application/json when the body looks like JSON (starts with { or [),
// otherwise fall back to text/plain to avoid clients failing JSON-decode.
let has_content_type = route.headers.iter().any(|(k, _)| k.to_lowercase() == "content-type");
let mut builder = Response::builder().status(status);
if !has_content_type {
let trimmed = route.body.trim_start();
let default_ct = if trimmed.starts_with('{') || trimmed.starts_with('[') {
"application/json"
} else {
"text/plain"
};
builder = builder.header("content-type", default_ct);
}
for (name, value) in &route.headers {
// Skip content-encoding headers — the mock server returns uncompressed bodies.
// Sending a content-encoding without actually encoding the body would cause
// clients to fail decompression.
if name.to_lowercase() == "content-encoding" {
continue;
}
// The <<absent>> sentinel means this header must NOT be present in the
// real server response — do not emit it from the mock server either.
if value == "<<absent>>" {
continue;
}
// Replace the <<uuid>> sentinel with a real UUID v4 so clients can
// assert the header value matches the UUID pattern.
if value == "<<uuid>>" {
let uuid = format!(
"{:08x}-{:04x}-4{:03x}-{:04x}-{:012x}",
rand_u32(),
rand_u16(),
rand_u16() & 0x0fff,
(rand_u16() & 0x3fff) | 0x8000,
rand_u48(),
);
builder = builder.header(name, uuid);
continue;
}
builder = builder.header(name, value);
}
builder.body(Body::from(route.body.clone())).unwrap().into_response()
}
/// Generate a pseudo-random u32 using the current time nanoseconds.
fn rand_u32() -> u32 {
use std::time::{SystemTime, UNIX_EPOCH};
let ns = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.subsec_nanos())
.unwrap_or(0);
ns ^ (ns.wrapping_shl(13)) ^ (ns.wrapping_shr(17))
}
fn rand_u16() -> u16 {
(rand_u32() & 0xffff) as u16
}
fn rand_u48() -> u64 {
((rand_u32() as u64) << 16) | (rand_u16() as u64)
}
// ---------------------------------------------------------------------------
// Fixture loading
// ---------------------------------------------------------------------------
fn load_routes(fixtures_dir: &Path) -> HashMap<String, MockRoute> {
let mut routes = HashMap::new();
load_routes_recursive(fixtures_dir, &mut routes);
routes
}
fn load_routes_recursive(dir: &Path, routes: &mut HashMap<String, MockRoute>) {
let entries = match std::fs::read_dir(dir) {
Ok(e) => e,
Err(err) => {
eprintln!("warning: cannot read directory {}: {err}", dir.display());
return;
}
};
let mut paths: Vec<_> = entries.filter_map(|e| e.ok()).map(|e| e.path()).collect();
paths.sort();
for path in paths {
if path.is_dir() {
load_routes_recursive(&path, routes);
} else if path.extension().is_some_and(|ext| ext == "json") {
let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
if filename == "schema.json" || filename.starts_with('_') {
continue;
}
let content = match std::fs::read_to_string(&path) {
Ok(c) => c,
Err(err) => {
eprintln!("warning: cannot read {}: {err}", path.display());
continue;
}
};
let fixtures: Vec<Fixture> = if content.trim_start().starts_with('[') {
match serde_json::from_str(&content) {
Ok(v) => v,
Err(err) => {
eprintln!("warning: cannot parse {}: {err}", path.display());
continue;
}
}
} else {
match serde_json::from_str::<Fixture>(&content) {
Ok(f) => vec![f],
Err(err) => {
eprintln!("warning: cannot parse {}: {err}", path.display());
continue;
}
}
};
for fixture in fixtures {
if let Some(mock) = fixture.as_mock_response() {
let route_path = format!("/fixtures/{}", fixture.id);
let body = mock
.body
.as_ref()
.map(|b| match b {
// Plain strings (e.g. text/plain bodies) are stored as JSON strings in
// fixtures. Return the raw value so clients receive the string itself,
// not its JSON-encoded form with extra surrounding quotes.
serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_default(),
})
.unwrap_or_default();
let stream_chunks = mock
.stream_chunks
.unwrap_or_default()
.into_iter()
.map(|c| match c {
serde_json::Value::String(s) => s,
other => serde_json::to_string(&other).unwrap_or_default(),
})
.collect();
let mut headers: Vec<(String, String)> =
mock.headers.into_iter().collect();
headers.sort_by(|a, b| a.0.cmp(&b.0));
routes.insert(route_path, MockRoute { status: mock.status, body, stream_chunks, headers });
}
}
}
}
}
// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------
#[tokio::main]
async fn main() {
let fixtures_dir_arg = std::env::args().nth(1).unwrap_or_else(|| "../../fixtures".to_string());
let fixtures_dir = Path::new(&fixtures_dir_arg);
let routes = load_routes(fixtures_dir);
eprintln!("mock-server: loaded {} routes from {}", routes.len(), fixtures_dir.display());
let route_table: RouteTable = Arc::new(routes);
let app = Router::new().fallback(handle_request).with_state(route_table);
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("mock-server: failed to bind port");
let addr: SocketAddr = listener.local_addr().expect("mock-server: failed to get local addr");
// Print the URL so the parent process can read it.
println!("MOCK_SERVER_URL=http://{addr}");
// Flush stdout explicitly so the parent does not block waiting.
use std::io::Write;
std::io::stdout().flush().expect("mock-server: failed to flush stdout");
// Spawn the server in the background.
tokio::spawn(async move {
axum::serve(listener, app).await.expect("mock-server: server error");
});
// Block until stdin is closed — the parent process controls lifetime.
let stdin = io::stdin();
let mut lines = stdin.lock().lines();
while lines.next().is_some() {}
}
"#,
    );
    binary_source
}
#[allow(clippy::too_many_arguments)]
fn render_assertion(
out: &mut String,
assertion: &Assertion,
result_var: &str,
module: &str,
dep_name: &str,
is_error_context: bool,
unwrapped_fields: &[(String, String)], field_resolver: &FieldResolver,
result_is_tree: bool,
result_is_simple: bool,
result_is_vec: bool,
result_is_option: bool,
) {
let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
if result_is_vec && has_field && !is_error_context {
let _ = writeln!(out, " for r in &{result_var} {{");
render_assertion(
out,
assertion,
"r",
module,
dep_name,
is_error_context,
unwrapped_fields,
field_resolver,
result_is_tree,
result_is_simple,
false, result_is_option,
);
let _ = writeln!(out, " }}");
return;
}
if result_is_option && !is_error_context {
let assertion_type = assertion.assertion_type.as_str();
if !has_field && (assertion_type == "is_empty" || assertion_type == "not_empty") {
let check = if assertion_type == "is_empty" {
"is_none"
} else {
"is_some"
};
let _ = writeln!(
out,
" assert!({result_var}.{check}(), \"expected Option to be {check}\");"
);
return;
}
let _ = writeln!(
out,
" let r = {result_var}.as_ref().expect(\"Option<T> should be Some\");"
);
render_assertion(
out,
assertion,
"r",
module,
dep_name,
is_error_context,
unwrapped_fields,
field_resolver,
result_is_tree,
result_is_simple,
result_is_vec,
false, );
return;
}
let _ = dep_name;
if let Some(f) = &assertion.field {
match f.as_str() {
"chunks_have_content" => {
match assertion.assertion_type.as_str() {
"is_true" => {
let _ = writeln!(
out,
" assert!({result_var}.chunks.as_ref().is_some_and(|chunks| !chunks.is_empty() && chunks.iter().all(|c| !c.content.is_empty())), \"expected all chunks to have content\");"
);
}
"is_false" => {
let _ = writeln!(
out,
" assert!({result_var}.chunks.as_ref().is_none() || {result_var}.chunks.as_ref().unwrap().iter().any(|c| c.content.is_empty()), \"expected some chunks to be empty\");"
);
}
_ => {
let _ = writeln!(
out,
" // unsupported assertion type on synthetic field chunks_have_content"
);
}
}
return;
}
"chunks_have_embeddings" => {
match assertion.assertion_type.as_str() {
"is_true" => {
let _ = writeln!(
out,
" assert!({result_var}.chunks.as_ref().is_some_and(|c| c.iter().all(|ch| ch.embedding.as_ref().is_some_and(|e| !e.is_empty()))), \"expected all chunks to have embeddings\");"
);
}
"is_false" => {
let _ = writeln!(
out,
" assert!({result_var}.chunks.as_ref().is_none_or(|c| c.iter().any(|ch| ch.embedding.as_ref().is_none_or(|e| e.is_empty()))), \"expected some chunks to lack embeddings\");"
);
}
_ => {
let _ = writeln!(
out,
" // unsupported assertion type on synthetic field chunks_have_embeddings"
);
}
}
return;
}
"embeddings" => {
let embed_list = result_var.to_string();
match assertion.assertion_type.as_str() {
"count_equals" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
let _ = writeln!(
out,
" assert_eq!({embed_list}.len(), {n}, \"expected exactly {n} elements, got {{}}\", {embed_list}.len());"
);
}
}
}
"count_min" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
if n <= 1 {
let _ =
writeln!(out, " assert!(!{embed_list}.is_empty(), \"expected >= {n}\");");
} else {
let _ = writeln!(
out,
" assert!({embed_list}.len() >= {n}, \"expected at least {n} elements, got {{}}\", {embed_list}.len());"
);
}
}
}
}
"not_empty" => {
let _ = writeln!(
out,
" assert!(!{embed_list}.is_empty(), \"expected non-empty embeddings\");"
);
}
"is_empty" => {
let _ = writeln!(
out,
" assert!({embed_list}.is_empty(), \"expected empty embeddings\");"
);
}
_ => {
let _ = writeln!(
out,
" // skipped: unsupported assertion type on synthetic field 'embeddings'"
);
}
}
return;
}
"embedding_dimensions" => {
let embed_list = result_var;
let expr = format!("{embed_list}.first().map_or(0, |e| e.len())");
match assertion.assertion_type.as_str() {
"equals" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(
out,
" assert_eq!({expr}, {lit} as usize, \"equals assertion failed\");"
);
}
}
"greater_than" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({expr} > {lit} as usize, \"expected > {lit}\");");
}
}
_ => {
let _ = writeln!(
out,
" // skipped: unsupported assertion type on synthetic field 'embedding_dimensions'"
);
}
}
return;
}
"embeddings_valid" | "embeddings_finite" | "embeddings_non_zero" | "embeddings_normalized" => {
let embed_list = result_var;
let pred = match f.as_str() {
"embeddings_valid" => {
format!("{embed_list}.iter().all(|e| !e.is_empty())")
}
"embeddings_finite" => {
format!("{embed_list}.iter().all(|e| e.iter().all(|v| v.is_finite()))")
}
"embeddings_non_zero" => {
format!("{embed_list}.iter().all(|e| e.iter().any(|v| *v != 0.0_f32))")
}
"embeddings_normalized" => {
format!(
"{embed_list}.iter().all(|e| {{ let n: f64 = e.iter().map(|v| f64::from(*v) * f64::from(*v)).sum(); (n - 1.0_f64).abs() < 1e-3 }})"
)
}
_ => unreachable!(),
};
match assertion.assertion_type.as_str() {
"is_true" => {
let _ = writeln!(out, " assert!({pred}, \"expected true\");");
}
"is_false" => {
let _ = writeln!(out, " assert!(!({pred}), \"expected false\");");
}
_ => {
let _ = writeln!(
out,
" // skipped: unsupported assertion type on synthetic field '{f}'"
);
}
}
return;
}
"keywords" => {
let accessor = format!("{result_var}.extracted_keywords");
match assertion.assertion_type.as_str() {
"not_empty" => {
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_some_and(|v| !v.is_empty()), \"expected keywords to be present and non-empty\");"
);
}
"is_empty" => {
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_none_or(|v| v.is_empty()), \"expected keywords to be empty or absent\");"
);
}
"count_min" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
if n <= 1 {
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_some_and(|v| !v.is_empty()), \"expected >= {n}\");"
);
} else {
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_some_and(|v| v.len() >= {n}), \"expected at least {n} keywords\");"
);
}
}
}
}
"count_equals" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_some_and(|v| v.len() == {n}), \"expected exactly {n} keywords\");"
);
}
}
}
_ => {
let _ = writeln!(
out,
" // skipped: unsupported assertion type on synthetic field 'keywords'"
);
}
}
return;
}
"keywords_count" => {
let expr = format!("{result_var}.extracted_keywords.as_ref().map_or(0, |v| v.len())");
match assertion.assertion_type.as_str() {
"equals" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(
out,
" assert_eq!({expr}, {lit} as usize, \"equals assertion failed\");"
);
}
}
"less_than_or_equal" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({expr} <= {lit} as usize, \"expected <= {lit}\");");
}
}
"greater_than_or_equal" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({expr} >= {lit} as usize, \"expected >= {lit}\");");
}
}
"greater_than" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({expr} > {lit} as usize, \"expected > {lit}\");");
}
}
"less_than" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({expr} < {lit} as usize, \"expected < {lit}\");");
}
}
_ => {
let _ = writeln!(
out,
" // skipped: unsupported assertion type on synthetic field 'keywords_count'"
);
}
}
return;
}
_ => {}
}
}
if let Some(f) = &assertion.field {
if !f.is_empty() && !field_resolver.is_valid_for_result(f) {
let _ = writeln!(out, " // skipped: field '{f}' not available on result type");
return;
}
}
let field_access = match &assertion.field {
Some(f) if !f.is_empty() => {
if let Some((_, local_var)) = unwrapped_fields.iter().find(|(ff, _)| ff == f) {
local_var.clone()
} else if result_is_simple {
result_var.to_string()
} else if f == result_var {
result_var.to_string()
} else if result_is_tree {
tree_field_access_expr(f, result_var, module)
} else {
field_resolver.accessor(f, "rust", result_var)
}
}
_ => result_var.to_string(),
};
let is_unwrapped = assertion
.field
.as_ref()
.is_some_and(|f| unwrapped_fields.iter().any(|(ff, _)| ff == f));
match assertion.assertion_type.as_str() {
"error" => {
let _ = writeln!(out, " assert!({result_var}.is_err(), \"expected call to fail\");");
if let Some(serde_json::Value::String(msg)) = &assertion.value {
let escaped = escape_rust(msg);
let _ = writeln!(
out,
" assert!({result_var}.as_ref().unwrap_err().to_string().contains(\"{escaped}\"), \"error message mismatch\");"
);
}
}
"not_error" => {
}
"equals" => {
if let Some(val) = &assertion.value {
let expected = value_to_rust_string(val);
if is_error_context {
return;
}
if val.is_string() {
let is_opt_str_not_unwrapped = assertion.field.as_ref().is_some_and(|f| {
let resolved = field_resolver.resolve(f);
let is_opt = field_resolver.is_optional(resolved);
let is_arr = field_resolver.is_array(resolved);
is_opt && !is_arr && !is_unwrapped
});
let field_expr = if is_opt_str_not_unwrapped {
format!("{field_access}.as_deref().unwrap_or(\"\").trim()")
} else {
format!("{field_access}.trim()")
};
let _ = writeln!(
out,
" assert_eq!({field_expr}, {expected}, \"equals assertion failed\");"
);
} else if val.is_boolean() {
if val.as_bool() == Some(true) {
let _ = writeln!(out, " assert!({field_access}, \"equals assertion failed\");");
} else {
let _ = writeln!(out, " assert!(!{field_access}, \"equals assertion failed\");");
}
} else {
let is_opt = assertion.field.as_ref().is_some_and(|f| {
let resolved = field_resolver.resolve(f);
field_resolver.is_optional(resolved)
});
if is_opt
&& !unwrapped_fields
.iter()
.any(|(ff, _)| assertion.field.as_ref() == Some(ff))
{
let _ = writeln!(
out,
" assert_eq!({field_access}, Some({expected}), \"equals assertion failed\");"
);
} else {
let _ = writeln!(
out,
" assert_eq!({field_access}, {expected}, \"equals assertion failed\");"
);
}
}
}
}
"contains" => {
if let Some(val) = &assertion.value {
let expected = value_to_rust_string(val);
let line = format!(
" assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
);
let _ = writeln!(out, "{line}");
}
}
"contains_all" => {
if let Some(values) = &assertion.values {
for val in values {
let expected = value_to_rust_string(val);
let line = format!(
" assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
);
let _ = writeln!(out, "{line}");
}
}
}
"not_contains" => {
if let Some(val) = &assertion.value {
let expected = value_to_rust_string(val);
let line = format!(
" assert!(!format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected NOT to contain: {{}}\", {expected});"
);
let _ = writeln!(out, "{line}");
}
}
"not_empty" => {
if let Some(f) = &assertion.field {
let resolved = field_resolver.resolve(f);
let is_opt = !is_unwrapped && field_resolver.is_optional(resolved);
let is_arr = field_resolver.is_array(resolved);
if is_opt && is_arr {
let accessor = field_resolver.accessor(f, "rust", result_var);
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_some_and(|v| !v.is_empty()), \"expected {f} to be present and non-empty\");"
);
} else if is_opt {
let accessor = field_resolver.accessor(f, "rust", result_var);
let _ = writeln!(
out,
" assert!({accessor}.is_some(), \"expected {f} to be present\");"
);
} else {
let _ = writeln!(
out,
" assert!(!{field_access}.is_empty(), \"expected non-empty value\");"
);
}
} else if result_is_option {
let _ = writeln!(
out,
" assert!({field_access}.is_some(), \"expected non-empty value\");"
);
} else {
let _ = writeln!(
out,
" assert!(!{field_access}.is_empty(), \"expected non-empty value\");"
);
}
}
"is_empty" => {
if let Some(f) = &assertion.field {
let resolved = field_resolver.resolve(f);
let is_opt = !is_unwrapped && field_resolver.is_optional(resolved);
let is_arr = field_resolver.is_array(resolved);
if is_opt && is_arr {
let accessor = field_resolver.accessor(f, "rust", result_var);
let _ = writeln!(
out,
" assert!({accessor}.as_ref().is_none_or(|v| v.is_empty()), \"expected {f} to be empty or absent\");"
);
} else if is_opt {
let accessor = field_resolver.accessor(f, "rust", result_var);
let _ = writeln!(out, " assert!({accessor}.is_none(), \"expected {f} to be absent\");");
} else {
let _ = writeln!(out, " assert!({field_access}.is_empty(), \"expected empty value\");");
}
} else {
let _ = writeln!(out, " assert!({field_access}.is_none(), \"expected empty value\");");
}
}
"contains_any" => {
if let Some(values) = &assertion.values {
let checks: Vec<String> = values
.iter()
.map(|v| {
let expected = value_to_rust_string(v);
format!("{field_access}.contains({expected})")
})
.collect();
let joined = checks.join(" || ");
let _ = writeln!(
out,
" assert!({joined}, \"expected to contain at least one of the specified values\");"
);
}
}
"greater_than" => {
if let Some(val) = &assertion.value {
if val.as_f64().is_some_and(|n| n < 0.0) {
let _ = writeln!(
out,
" // skipped: greater_than with negative value is always true for unsigned types"
);
} else if val.as_u64() == Some(0) {
let base = field_access.strip_suffix(".len()").unwrap_or(&field_access);
let _ = writeln!(out, " assert!(!{base}.is_empty(), \"expected > 0\");");
} else {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({field_access} > {lit}, \"expected > {lit}\");");
}
}
}
"less_than" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({field_access} < {lit}, \"expected < {lit}\");");
}
}
"greater_than_or_equal" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let is_opt_numeric = assertion.field.as_ref().is_some_and(|f| {
let resolved = field_resolver.resolve(f);
let is_opt = !is_unwrapped && field_resolver.is_optional(resolved);
let is_arr = field_resolver.is_array(resolved);
is_opt && !is_arr
});
if val.as_u64() == Some(1) && field_access.ends_with(".len()") {
let base = field_access.strip_suffix(".len()").unwrap_or(&field_access);
let _ = writeln!(out, " assert!(!{base}.is_empty(), \"expected >= 1\");");
} else if is_opt_numeric {
let _ = writeln!(
out,
" assert!({field_access}.unwrap_or(0) >= {lit}, \"expected >= {lit}\");"
);
} else {
let _ = writeln!(out, " assert!({field_access} >= {lit}, \"expected >= {lit}\");");
}
}
}
"less_than_or_equal" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
let _ = writeln!(out, " assert!({field_access} <= {lit}, \"expected <= {lit}\");");
}
}
"starts_with" => {
if let Some(val) = &assertion.value {
let expected = value_to_rust_string(val);
let _ = writeln!(
out,
" assert!({field_access}.starts_with({expected}), \"expected to start with: {{}}\", {expected});"
);
}
}
"ends_with" => {
if let Some(val) = &assertion.value {
let expected = value_to_rust_string(val);
let _ = writeln!(
out,
" assert!({field_access}.ends_with({expected}), \"expected to end with: {{}}\", {expected});"
);
}
}
"min_length" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
let _ = writeln!(
out,
" assert!({field_access}.len() >= {n}, \"expected length >= {n}, got {{}}\", {field_access}.len());"
);
}
}
}
"max_length" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
let _ = writeln!(
out,
" assert!({field_access}.len() <= {n}, \"expected length <= {n}, got {{}}\", {field_access}.len());"
);
}
}
}
"count_min" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
let opt_arr_field = assertion.field.as_ref().is_some_and(|f| {
let resolved = field_resolver.resolve(f);
let is_opt = !is_unwrapped && field_resolver.is_optional(resolved);
let is_arr = field_resolver.is_array(resolved);
is_opt && is_arr
});
let base = field_access.strip_suffix(".len()").unwrap_or(&field_access);
if opt_arr_field {
if n <= 1 {
let _ = writeln!(
out,
" assert!({base}.as_ref().is_some_and(|v| !v.is_empty()), \"expected >= {n}\");"
);
} else {
let _ = writeln!(
out,
" assert!({base}.as_ref().is_some_and(|v| v.len() >= {n}), \"expected at least {n} elements\");"
);
}
} else if n <= 1 {
let _ = writeln!(out, " assert!(!{base}.is_empty(), \"expected >= {n}\");");
} else {
let _ = writeln!(
out,
" assert!({field_access}.len() >= {n}, \"expected at least {n} elements, got {{}}\", {field_access}.len());"
);
}
}
}
}
"count_equals" => {
if let Some(val) = &assertion.value {
if let Some(n) = val.as_u64() {
let opt_arr_field = assertion.field.as_ref().is_some_and(|f| {
let resolved = field_resolver.resolve(f);
let is_opt = !is_unwrapped && field_resolver.is_optional(resolved);
let is_arr = field_resolver.is_array(resolved);
is_opt && is_arr
});
let base = field_access.strip_suffix(".len()").unwrap_or(&field_access);
if opt_arr_field {
let _ = writeln!(
out,
" assert!({base}.as_ref().is_some_and(|v| v.len() == {n}), \"expected exactly {n} elements\");"
);
} else {
let _ = writeln!(
out,
" assert_eq!({field_access}.len(), {n}, \"expected exactly {n} elements, got {{}}\", {field_access}.len());"
);
}
}
}
}
"is_true" => {
let _ = writeln!(out, " assert!({field_access}, \"expected true\");");
}
"is_false" => {
let _ = writeln!(out, " assert!(!{field_access}, \"expected false\");");
}
"method_result" => {
if let Some(method_name) = &assertion.method {
let call_expr = if result_is_tree {
build_tree_call_expr(field_access.as_str(), method_name, assertion.args.as_ref(), module)
} else if let Some(args) = &assertion.args {
let arg_lit = json_to_rust_literal(args, "");
format!("{field_access}.{method_name}({arg_lit})")
} else {
format!("{field_access}.{method_name}()")
};
let returns_numeric = result_is_tree && is_tree_numeric_method(method_name);
let check = assertion.check.as_deref().unwrap_or("is_true");
match check {
"equals" => {
if let Some(val) = &assertion.value {
if val.is_boolean() {
if val.as_bool() == Some(true) {
let _ = writeln!(
out,
" assert!({call_expr}, \"method_result equals assertion failed\");"
);
} else {
let _ = writeln!(
out,
" assert!(!{call_expr}, \"method_result equals assertion failed\");"
);
}
} else {
let expected = value_to_rust_string(val);
let _ = writeln!(
out,
" assert_eq!({call_expr}, {expected}, \"method_result equals assertion failed\");"
);
}
}
}
"is_true" => {
let _ = writeln!(
out,
" assert!({call_expr}, \"method_result is_true assertion failed\");"
);
}
"is_false" => {
let _ = writeln!(
out,
" assert!(!{call_expr}, \"method_result is_false assertion failed\");"
);
}
"greater_than_or_equal" => {
if let Some(val) = &assertion.value {
let lit = numeric_literal(val);
if returns_numeric {
let _ = writeln!(out, " assert!({call_expr} >= {lit}, \"expected >= {lit}\");");
} else if val.as_u64() == Some(1) {
let _ = writeln!(out, " assert!(!{call_expr}.is_empty(), \"expected >= 1\");");
} else {
let _ = writeln!(out, " assert!({call_expr} >= {lit}, \"expected >= {lit}\");");
}
}
}
"count_min" => {
if let Some(val) = &assertion.value {
let n = val.as_u64().unwrap_or(0);
if n <= 1 {
let _ = writeln!(out, " assert!(!{call_expr}.is_empty(), \"expected >= {n}\");");
} else {
let _ = writeln!(
out,
" assert!({call_expr}.len() >= {n}, \"expected at least {n} elements, got {{}}\", {call_expr}.len());"
);
}
}
}
"is_error" => {
let raw_call = call_expr.strip_suffix(".unwrap()").unwrap_or(&call_expr);
let _ = writeln!(
out,
" assert!({raw_call}.is_err(), \"expected method to return error\");"
);
}
"contains" => {
if let Some(val) = &assertion.value {
let expected = value_to_rust_string(val);
let _ = writeln!(
out,
" assert!({call_expr}.contains({expected}), \"expected result to contain {{}}\", {expected});"
);
}
}
"not_empty" => {
let _ = writeln!(
out,
" assert!(!{call_expr}.is_empty(), \"expected non-empty result\");"
);
}
"is_empty" => {
let _ = writeln!(out, " assert!({call_expr}.is_empty(), \"expected empty result\");");
}
other_check => {
panic!("Rust e2e generator: unsupported method_result check type: {other_check}");
}
}
} else {
panic!("Rust e2e generator: method_result assertion missing 'method' field");
}
}
other => {
panic!("Rust e2e generator: unsupported assertion type: {other}");
}
}
}
/// Builds the Rust expression that reads `field` from the tree held in `result_var`.
///
/// Well-known pseudo-fields are mapped either to an accessor on the tree's
/// root node or to a free helper function in `module`; any other name is
/// emitted as a plain field access `result_var.field`.
fn tree_field_access_expr(field: &str, result_var: &str, module: &str) -> String {
    // Pseudo-fields answered by a method on `root_node()`.
    let root_accessor = match field {
        "root_child_count" => Some("child_count"),
        "root_node_type" => Some("kind"),
        "named_children_count" => Some("named_child_count"),
        _ => None,
    };
    if let Some(accessor) = root_accessor {
        return format!("{result_var}.root_node().{accessor}()");
    }

    // Pseudo-fields answered by a helper function that borrows the tree.
    let module_helper = match field {
        "has_error_nodes" => Some("tree_has_error_nodes"),
        "error_count" | "tree_error_count" => Some("tree_error_count"),
        "tree_to_sexp" => Some("tree_to_sexp"),
        _ => None,
    };
    match module_helper {
        Some(helper) => format!("{module}::{helper}(&{result_var})"),
        None => format!("{result_var}.{field}"),
    }
}
fn build_tree_call_expr(
field_access: &str,
method_name: &str,
args: Option<&serde_json::Value>,
module: &str,
) -> String {
match method_name {
"root_child_count" => format!("{field_access}.root_node().child_count()"),
"root_node_type" => format!("{field_access}.root_node().kind()"),
"named_children_count" => format!("{field_access}.root_node().named_child_count()"),
"has_error_nodes" => format!("{module}::tree_has_error_nodes(&{field_access})"),
"error_count" | "tree_error_count" => format!("{module}::tree_error_count(&{field_access})"),
"tree_to_sexp" => format!("{module}::tree_to_sexp(&{field_access})"),
"contains_node_type" => {
let node_type = args
.and_then(|a| a.get("node_type"))
.and_then(|v| v.as_str())
.unwrap_or("");
format!("{module}::tree_contains_node_type(&{field_access}, \"{node_type}\")")
}
"find_nodes_by_type" => {
let node_type = args
.and_then(|a| a.get("node_type"))
.and_then(|v| v.as_str())
.unwrap_or("");
format!("{module}::find_nodes_by_type(&{field_access}, \"{node_type}\")")
}
"run_query" => {
let query_source = args
.and_then(|a| a.get("query_source"))
.and_then(|v| v.as_str())
.unwrap_or("");
let language = args
.and_then(|a| a.get("language"))
.and_then(|v| v.as_str())
.unwrap_or("");
format!(
"{module}::run_query(&{field_access}, \"{language}\", r#\"{query_source}\"#, source.as_bytes()).unwrap()"
)
}
_ => {
if let Some(args) = args {
let arg_lit = json_to_rust_literal(args, "");
format!("{field_access}.{method_name}({arg_lit})")
} else {
format!("{field_access}.{method_name}()")
}
}
}
}
/// Reports whether `method_name` is one of the tree methods listed below,
/// which callers treat as returning a number rather than a collection.
fn is_tree_numeric_method(method_name: &str) -> bool {
    const NUMERIC_TREE_METHODS: [&str; 4] = [
        "root_child_count",
        "named_children_count",
        "error_count",
        "tree_error_count",
    ];
    NUMERIC_TREE_METHODS.contains(&method_name)
}
/// Renders a JSON value as a Rust numeric literal for use in generated comparisons.
///
/// * JSON integers are printed exactly (no round-trip through `f64`, which
///   would lose precision above 2^53 and saturate above `i64::MAX`).
/// * Floats with no fractional part that fit in `i64` are printed as an
///   integer literal, e.g. `3.0` becomes `3`.
/// * All other floats are printed with an `_f64` suffix.
/// * Non-numeric values fall back to their JSON text.
fn numeric_literal(value: &serde_json::Value) -> String {
    // 2^63: the first f64 above the i64 range; `i64::MIN` is exactly -2^63.
    const TWO_POW_63: f64 = 9_223_372_036_854_775_808.0;

    if let Some(int) = value.as_i64() {
        return int.to_string();
    }
    if let Some(uint) = value.as_u64() {
        return uint.to_string();
    }
    match value.as_f64() {
        // In-range whole floats convert to i64 exactly; the range check keeps
        // the cast from silently saturating.
        Some(n) if n.fract() == 0.0 && (-TWO_POW_63..TWO_POW_63).contains(&n) => (n as i64).to_string(),
        Some(n) => format!("{n}_f64"),
        None => value.to_string(),
    }
}
/// Renders a JSON value as a Rust expression usable as an expected value in
/// generated assertions.
///
/// Strings become a string literal produced by `rust_raw_string`; booleans
/// and numbers are printed verbatim. Any other value (null, array, object) is
/// serialized to its JSON text and emitted as a string literal through the
/// same helper, so embedded quotes cannot break the generated source.
fn value_to_rust_string(value: &serde_json::Value) -> String {
    match value {
        serde_json::Value::String(s) => rust_raw_string(s),
        serde_json::Value::Bool(b) => b.to_string(),
        serde_json::Value::Number(n) => n.to_string(),
        // JSON text of arrays/objects contains `"`; wrapping it in plain
        // quotes without escaping would yield invalid Rust.
        other => rust_raw_string(&other.to_string()),
    }
}
/// Picks the visitor trait name for `module`: `HtmlVisitor` when the module
/// path contains `html_to_markdown`, otherwise `Visitor`.
fn resolve_visitor_trait(module: &str) -> String {
    let trait_name = match module.find("html_to_markdown") {
        Some(_) => "HtmlVisitor",
        None => "Visitor",
    };
    String::from(trait_name)
}
/// Appends the source of one visitor method to `out`.
///
/// The emitted method is named `method_name`, takes `&mut self` plus an
/// all-underscore parameter list chosen from the method name, and returns the
/// `VisitResult` selected by `action`. Method names not listed here get a
/// single `&NodeContext` parameter. Formatting errors from writing into `out`
/// are deliberately ignored.
fn emit_rust_visitor_method(out: &mut String, method_name: &str, action: &CallbackAction) {
    // Callbacks sharing a parameter list are grouped into a single arm.
    let params = match method_name {
        "visit_link" | "visit_image" | "visit_input" => "_: &NodeContext, _: &str, _: &str, _: &str",
        "visit_heading" => "_: &NodeContext, _: u8, _: &str, _: Option<&str>",
        "visit_code_block" => "_: &NodeContext, _: Option<&str>, _: &str",
        "visit_list_item" => "_: &NodeContext, _: bool, _: &str, _: &str",
        "visit_blockquote" => "_: &NodeContext, _: &str, _: u32",
        "visit_table_row" => "_: &NodeContext, _: &[String], _: bool",
        "visit_custom_element" | "visit_form" => "_: &NodeContext, _: &str, _: &str",
        "visit_details" | "visit_list_start" => "_: &NodeContext, _: bool",
        "visit_list_end" => "_: &NodeContext, _: bool, _: &str",
        "visit_code_inline"
        | "visit_strong"
        | "visit_emphasis"
        | "visit_strikethrough"
        | "visit_underline"
        | "visit_subscript"
        | "visit_superscript"
        | "visit_mark"
        | "visit_button"
        | "visit_summary"
        | "visit_figcaption"
        | "visit_definition_term"
        | "visit_definition_description"
        | "visit_text"
        | "visit_audio"
        | "visit_video"
        | "visit_iframe"
        | "visit_element_end"
        | "visit_table_end"
        | "visit_definition_list_end"
        | "visit_figure_end" => "_: &NodeContext, _: &str",
        _ => "_: &NodeContext",
    };

    let _ = writeln!(out, " fn {method_name}(&mut self, {params}) -> VisitResult {{");

    // The single expression that forms the generated method body.
    let return_expr = match action {
        CallbackAction::Skip => String::from("VisitResult::Skip"),
        CallbackAction::Continue => String::from("VisitResult::Continue"),
        CallbackAction::PreserveHtml => String::from("VisitResult::PreserveHtml"),
        CallbackAction::Custom { output } => {
            format!("VisitResult::Custom(\"{}\".to_string())", escape_rust(output))
        }
        CallbackAction::CustomTemplate { template } => {
            format!("VisitResult::Custom(format!(\"{}\"))", escape_rust(template))
        }
    };
    let _ = writeln!(out, " {return_expr}");

    let _ = writeln!(out, " }}");
}