1use crate::config::E2eConfig;
4use crate::escape::{go_string_literal, sanitize_filename};
5use crate::field_access::FieldResolver;
6use crate::fixture::{Assertion, CallbackAction, Fixture, FixtureGroup};
7use alef_core::backend::GeneratedFile;
8use alef_core::config::AlefConfig;
9use anyhow::Result;
10use heck::ToUpperCamelCase;
11use std::fmt::Write as FmtWrite;
12use std::path::PathBuf;
13
14use super::E2eCodegen;
15
/// E2E code generator for Go: emits a `go.mod` plus one `<category>_test.go`
/// file per fixture group that has fixtures not skipped for Go.
pub struct GoCodegen;
18
19impl E2eCodegen for GoCodegen {
20 fn generate(
21 &self,
22 groups: &[FixtureGroup],
23 e2e_config: &E2eConfig,
24 alef_config: &AlefConfig,
25 ) -> Result<Vec<GeneratedFile>> {
26 let lang = self.language_name();
27 let output_base = PathBuf::from(e2e_config.effective_output()).join(lang);
28
29 let mut files = Vec::new();
30
31 let call = &e2e_config.call;
33 let overrides = call.overrides.get(lang);
34 let module_path = overrides
35 .and_then(|o| o.module.as_ref())
36 .cloned()
37 .unwrap_or_else(|| call.module.clone());
38 let import_alias = overrides
39 .and_then(|o| o.alias.as_ref())
40 .cloned()
41 .unwrap_or_else(|| "pkg".to_string());
42
43 let go_pkg = e2e_config.resolve_package("go");
45 let go_module_path = go_pkg
46 .as_ref()
47 .and_then(|p| p.module.as_ref())
48 .cloned()
49 .unwrap_or_else(|| module_path.clone());
50 let replace_path = go_pkg.as_ref().and_then(|p| p.path.as_ref()).cloned();
51 let go_version = go_pkg
52 .as_ref()
53 .and_then(|p| p.version.as_ref())
54 .cloned()
55 .unwrap_or_else(|| {
56 alef_config
57 .resolved_version()
58 .map(|v| format!("v{v}"))
59 .unwrap_or_else(|| "v0.0.0".to_string())
60 });
61 let field_resolver = FieldResolver::new(
62 &e2e_config.fields,
63 &e2e_config.fields_optional,
64 &e2e_config.result_fields,
65 &e2e_config.fields_array,
66 );
67
68 let effective_replace = match e2e_config.dep_mode {
71 crate::config::DependencyMode::Registry => None,
72 crate::config::DependencyMode::Local => replace_path.as_deref().map(String::from),
73 };
74 files.push(GeneratedFile {
75 path: output_base.join("go.mod"),
76 content: render_go_mod(&go_module_path, effective_replace.as_deref(), &go_version),
77 generated_header: false,
78 });
79
80 for group in groups {
82 let active: Vec<&Fixture> = group
83 .fixtures
84 .iter()
85 .filter(|f| f.skip.as_ref().is_none_or(|s| !s.should_skip(lang)))
86 .collect();
87
88 if active.is_empty() {
89 continue;
90 }
91
92 let filename = format!("{}_test.go", sanitize_filename(&group.category));
93 let content = render_test_file(
94 &group.category,
95 &active,
96 &module_path,
97 &import_alias,
98 &field_resolver,
99 e2e_config,
100 );
101 files.push(GeneratedFile {
102 path: output_base.join(filename),
103 content,
104 generated_header: true,
105 });
106 }
107
108 Ok(files)
109 }
110
111 fn language_name(&self) -> &'static str {
112 "go"
113 }
114}
115
/// Renders the `go.mod` of the generated e2e module.
///
/// The module is always named `e2e_go` and requires `go_module_path` at
/// `version`. When `replace_path` is given, a `replace` directive redirects
/// that dependency to the given path.
fn render_go_mod(go_module_path: &str, replace_path: Option<&str>, version: &str) -> String {
    let mut text = format!("module e2e_go\n\ngo 1.26\n\nrequire {go_module_path} {version}\n");
    if let Some(target) = replace_path {
        text.push_str(&format!("\nreplace {go_module_path} => {target}\n"));
    }
    text
}
131
132fn render_test_file(
133 category: &str,
134 fixtures: &[&Fixture],
135 go_module_path: &str,
136 import_alias: &str,
137 field_resolver: &FieldResolver,
138 e2e_config: &crate::config::E2eConfig,
139) -> String {
140 let mut out = String::new();
141
142 let _ = writeln!(out, "// Code generated by alef. DO NOT EDIT.");
144 let _ = writeln!(out);
145
146 let needs_os = fixtures.iter().any(|f| {
149 let call_args = &e2e_config.resolve_call(f.call.as_deref()).args;
150 call_args.iter().any(|a| a.arg_type == "mock_url")
151 });
152
153 let needs_json = fixtures.iter().any(|f| {
155 let call_args = &e2e_config.resolve_call(f.call.as_deref()).args;
156 call_args.iter().any(|a| a.arg_type == "handle") && {
157 call_args.iter().filter(|a| a.arg_type == "handle").any(|a| {
158 let v = f.input.get(&a.field).unwrap_or(&serde_json::Value::Null);
159 !(v.is_null() || v.is_object() && v.as_object().is_some_and(|o| o.is_empty()))
160 })
161 }
162 });
163
164 let needs_fmt = fixtures.iter().any(|f| {
166 f.visitor.as_ref().is_some_and(|v| {
167 v.callbacks.values().any(|action| {
168 if let CallbackAction::CustomTemplate { template } = action {
169 template.contains('{')
170 } else {
171 false
172 }
173 })
174 })
175 });
176
177 let needs_strings = fixtures.iter().any(|f| {
180 f.assertions.iter().any(|a| {
181 let type_needs_strings = if a.assertion_type == "equals" {
182 a.value.as_ref().is_some_and(|v| v.is_string())
184 } else {
185 matches!(
186 a.assertion_type.as_str(),
187 "contains" | "contains_all" | "not_contains" | "starts_with"
188 )
189 };
190 let field_valid = a
191 .field
192 .as_ref()
193 .map(|f| f.is_empty() || field_resolver.is_valid_for_result(f))
194 .unwrap_or(true);
195 type_needs_strings && field_valid
196 })
197 });
198
199 let needs_assert = fixtures.iter().any(|f| {
202 f.assertions.iter().any(|a| {
203 let field_valid = a
204 .field
205 .as_ref()
206 .map(|f| f.is_empty() || field_resolver.is_valid_for_result(f))
207 .unwrap_or(true);
208 let type_needs_assert = matches!(
209 a.assertion_type.as_str(),
210 "count_min" | "count_max" | "is_true" | "is_false" | "method_result"
211 );
212 type_needs_assert && field_valid
213 })
214 });
215
216 let _ = writeln!(out, "// E2e tests for category: {category}");
217 let _ = writeln!(out, "package e2e_test");
218 let _ = writeln!(out);
219 let _ = writeln!(out, "import (");
220 if needs_json {
221 let _ = writeln!(out, "\t\"encoding/json\"");
222 }
223 if needs_fmt {
224 let _ = writeln!(out, "\t\"fmt\"");
225 }
226 if needs_os {
227 let _ = writeln!(out, "\t\"os\"");
228 }
229 if needs_strings {
230 let _ = writeln!(out, "\t\"strings\"");
231 }
232 let _ = writeln!(out, "\t\"testing\"");
233 if needs_assert {
234 let _ = writeln!(out);
235 let _ = writeln!(out, "\t\"github.com/stretchr/testify/assert\"");
236 }
237 let _ = writeln!(out);
238 let _ = writeln!(out, "\t{import_alias} \"{go_module_path}\"");
239 let _ = writeln!(out, ")");
240 let _ = writeln!(out);
241
242 for fixture in fixtures.iter() {
244 if let Some(visitor_spec) = &fixture.visitor {
245 let struct_name = visitor_struct_name(&fixture.id);
246 emit_go_visitor_struct(&mut out, &struct_name, visitor_spec, import_alias);
247 let _ = writeln!(out);
248 }
249 }
250
251 for (i, fixture) in fixtures.iter().enumerate() {
252 render_test_function(&mut out, fixture, import_alias, field_resolver, e2e_config);
253 if i + 1 < fixtures.len() {
254 let _ = writeln!(out);
255 }
256 }
257
258 while out.ends_with("\n\n") {
260 out.pop();
261 }
262 if !out.ends_with('\n') {
263 out.push('\n');
264 }
265 out
266}
267
268fn render_test_function(
269 out: &mut String,
270 fixture: &Fixture,
271 import_alias: &str,
272 field_resolver: &FieldResolver,
273 e2e_config: &crate::config::E2eConfig,
274) {
275 let fn_name = fixture.id.to_upper_camel_case();
276 let description = &fixture.description;
277
278 let call_config = e2e_config.resolve_call(fixture.call.as_deref());
280 let lang = "go";
281 let overrides = call_config.overrides.get(lang);
282 let function_name = overrides
283 .and_then(|o| o.function.as_ref())
284 .cloned()
285 .unwrap_or_else(|| call_config.function.clone());
286 let result_var = &call_config.result_var;
287 let args = &call_config.args;
288
289 let expects_error = fixture.assertions.iter().any(|a| a.assertion_type == "error");
290
291 let (mut setup_lines, args_str) = build_args_and_setup(&fixture.input, args, import_alias, e2e_config, &fixture.id);
292
293 let mut visitor_arg = String::new();
295 if fixture.visitor.is_some() {
296 let struct_name = visitor_struct_name(&fixture.id);
297 setup_lines.push(format!("visitor := &{struct_name}{{}}"));
298 visitor_arg = "visitor".to_string();
299 }
300
301 let final_args = if visitor_arg.is_empty() {
302 args_str
303 } else {
304 format!("{args_str}, {visitor_arg}")
305 };
306
307 let _ = writeln!(out, "func Test_{fn_name}(t *testing.T) {{");
308 let _ = writeln!(out, "\t// {description}");
309
310 for line in &setup_lines {
311 let _ = writeln!(out, "\t{line}");
312 }
313
314 if expects_error {
315 let _ = writeln!(out, "\t_, err := {import_alias}.{function_name}({final_args})");
316 let _ = writeln!(out, "\tif err == nil {{");
317 let _ = writeln!(out, "\t\tt.Errorf(\"expected an error, but call succeeded\")");
318 let _ = writeln!(out, "\t}}");
319 let _ = writeln!(out, "}}");
320 return;
321 }
322
323 let has_usable_assertion = fixture.assertions.iter().any(|a| {
327 if a.assertion_type == "not_error" || a.assertion_type == "error" {
328 return false;
329 }
330 if a.assertion_type == "method_result" {
332 return true;
333 }
334 match &a.field {
335 Some(f) if !f.is_empty() => field_resolver.is_valid_for_result(f),
336 _ => true,
337 }
338 });
339
340 let result_binding = if has_usable_assertion {
341 result_var.to_string()
342 } else {
343 "_".to_string()
344 };
345
346 let _ = writeln!(
348 out,
349 "\t{result_binding}, err := {import_alias}.{function_name}({final_args})"
350 );
351 let _ = writeln!(out, "\tif err != nil {{");
352 let _ = writeln!(out, "\t\tt.Fatalf(\"call failed: %v\", err)");
353 let _ = writeln!(out, "\t}}");
354
355 let mut optional_locals: std::collections::HashMap<String, String> = std::collections::HashMap::new();
360 for assertion in &fixture.assertions {
361 if let Some(f) = &assertion.field {
362 if !f.is_empty() {
363 let resolved = field_resolver.resolve(f);
364 if field_resolver.is_optional(resolved) && !optional_locals.contains_key(f.as_str()) {
365 let is_string_field = assertion.value.as_ref().is_some_and(|v| v.is_string());
368 if !is_string_field {
369 continue;
372 }
373 let field_expr = field_resolver.accessor(f, "go", result_var);
374 let local_var = go_local_name(&resolved.replace(['.', '[', ']'], "_"));
375 if field_resolver.has_map_access(f) {
376 let _ = writeln!(out, "\t{local_var} := {field_expr}");
379 } else {
380 let _ = writeln!(out, "\tvar {local_var} string");
381 let _ = writeln!(out, "\tif {field_expr} != nil {{");
382 let _ = writeln!(out, "\t\t{local_var} = *{field_expr}");
383 let _ = writeln!(out, "\t}}");
384 }
385 optional_locals.insert(f.clone(), local_var);
386 }
387 }
388 }
389 }
390
391 for assertion in &fixture.assertions {
393 if let Some(f) = &assertion.field {
394 if !f.is_empty() && !optional_locals.contains_key(f.as_str()) {
395 let parts: Vec<&str> = f.split('.').collect();
398 let mut guard_expr: Option<String> = None;
399 for i in 1..parts.len() {
400 let prefix = parts[..i].join(".");
401 let resolved_prefix = field_resolver.resolve(&prefix);
402 if field_resolver.is_optional(resolved_prefix) {
403 let accessor = field_resolver.accessor(&prefix, "go", result_var);
404 guard_expr = Some(accessor);
405 break;
406 }
407 }
408 if let Some(guard) = guard_expr {
409 if field_resolver.is_valid_for_result(f) {
412 let _ = writeln!(out, "\tif {guard} != nil {{");
413 let mut nil_buf = String::new();
416 render_assertion(
417 &mut nil_buf,
418 assertion,
419 result_var,
420 import_alias,
421 field_resolver,
422 &optional_locals,
423 );
424 for line in nil_buf.lines() {
425 let _ = writeln!(out, "\t{line}");
426 }
427 let _ = writeln!(out, "\t}}");
428 } else {
429 render_assertion(
430 out,
431 assertion,
432 result_var,
433 import_alias,
434 field_resolver,
435 &optional_locals,
436 );
437 }
438 continue;
439 }
440 }
441 }
442 render_assertion(
443 out,
444 assertion,
445 result_var,
446 import_alias,
447 field_resolver,
448 &optional_locals,
449 );
450 }
451
452 let _ = writeln!(out, "}}");
453}
454
455fn build_args_and_setup(
459 input: &serde_json::Value,
460 args: &[crate::config::ArgMapping],
461 import_alias: &str,
462 e2e_config: &crate::config::E2eConfig,
463 fixture_id: &str,
464) -> (Vec<String>, String) {
465 use heck::ToUpperCamelCase;
466
467 if args.is_empty() {
468 return (Vec::new(), json_to_go(input));
469 }
470
471 let overrides = e2e_config.call.overrides.get("go");
472 let options_type = overrides.and_then(|o| o.options_type.as_deref());
473
474 let mut setup_lines: Vec<String> = Vec::new();
475 let mut parts: Vec<String> = Vec::new();
476
477 for arg in args {
478 if arg.arg_type == "mock_url" {
479 setup_lines.push(format!(
480 "{} := os.Getenv(\"MOCK_SERVER_URL\") + \"/fixtures/{fixture_id}\"",
481 arg.name,
482 ));
483 parts.push(arg.name.clone());
484 continue;
485 }
486
487 if arg.arg_type == "handle" {
488 let constructor_name = format!("Create{}", arg.name.to_upper_camel_case());
490 let field = arg.field.strip_prefix("input.").unwrap_or(&arg.field);
491 let config_value = input.get(field).unwrap_or(&serde_json::Value::Null);
492 if config_value.is_null()
493 || config_value.is_object() && config_value.as_object().is_some_and(|o| o.is_empty())
494 {
495 setup_lines.push(format!(
496 "{name}, createErr := {import_alias}.{constructor_name}(nil)\n\tif createErr != nil {{\n\t\tt.Fatalf(\"create handle failed: %v\", createErr)\n\t}}",
497 name = arg.name,
498 ));
499 } else {
500 let json_str = serde_json::to_string(config_value).unwrap_or_default();
501 let go_literal = go_string_literal(&json_str);
502 let name = &arg.name;
503 setup_lines.push(format!(
504 "var {name}Config {import_alias}.CrawlConfig\n\tif err := json.Unmarshal([]byte({go_literal}), &{name}Config); err != nil {{\n\t\tt.Fatalf(\"config parse failed: %v\", err)\n\t}}"
505 ));
506 setup_lines.push(format!(
507 "{name}, createErr := {import_alias}.{constructor_name}(&{name}Config)\n\tif createErr != nil {{\n\t\tt.Fatalf(\"create handle failed: %v\", createErr)\n\t}}"
508 ));
509 }
510 parts.push(arg.name.clone());
511 continue;
512 }
513
514 let field = arg.field.strip_prefix("input.").unwrap_or(&arg.field);
515 let val = input.get(field);
516 match val {
517 None | Some(serde_json::Value::Null) if arg.optional => {
518 continue;
520 }
521 None | Some(serde_json::Value::Null) => {
522 let default_val = match arg.arg_type.as_str() {
524 "string" => "\"\"".to_string(),
525 "int" | "integer" => "0".to_string(),
526 "float" | "number" => "0.0".to_string(),
527 "bool" | "boolean" => "false".to_string(),
528 _ => "nil".to_string(),
529 };
530 parts.push(default_val);
531 }
532 Some(v) => {
533 if let (Some(opts_type), "json_object") = (options_type, arg.arg_type.as_str()) {
535 if let Some(obj) = v.as_object() {
536 let with_calls: Vec<String> = obj
537 .iter()
538 .map(|(k, vv)| {
539 let func_name = format!("With{}{}", opts_type, k.to_upper_camel_case());
540 let go_val = json_to_go(vv);
541 format!("htmd.{func_name}({go_val})")
542 })
543 .collect();
544 let new_fn = format!("New{opts_type}");
545 parts.push(format!("htmd.{new_fn}({})", with_calls.join(", ")));
546 continue;
547 }
548 }
549 parts.push(json_to_go(v));
550 }
551 }
552 }
553
554 (setup_lines, parts.join(", "))
555}
556
557fn render_assertion(
558 out: &mut String,
559 assertion: &Assertion,
560 result_var: &str,
561 import_alias: &str,
562 field_resolver: &FieldResolver,
563 optional_locals: &std::collections::HashMap<String, String>,
564) {
565 if let Some(f) = &assertion.field {
567 if !f.is_empty() && !field_resolver.is_valid_for_result(f) {
568 let _ = writeln!(out, "\t// skipped: field '{f}' not available on result type");
569 return;
570 }
571 }
572
573 let field_expr = match &assertion.field {
574 Some(f) if !f.is_empty() => {
575 if let Some(local_var) = optional_locals.get(f.as_str()) {
577 local_var.clone()
578 } else {
579 field_resolver.accessor(f, "go", result_var)
580 }
581 }
582 _ => result_var.to_string(),
583 };
584
585 let is_optional = assertion
589 .field
590 .as_ref()
591 .map(|f| {
592 let resolved = field_resolver.resolve(f);
593 let check_path = resolved
594 .strip_suffix(".length")
595 .or_else(|| resolved.strip_suffix(".count"))
596 .or_else(|| resolved.strip_suffix(".size"))
597 .unwrap_or(resolved);
598 field_resolver.is_optional(check_path) && !optional_locals.contains_key(f.as_str())
599 })
600 .unwrap_or(false);
601
602 let field_expr = if is_optional && field_expr.starts_with("len(") && field_expr.ends_with(')') {
605 let inner = &field_expr[4..field_expr.len() - 1];
606 format!("len(*{inner})")
607 } else {
608 field_expr
609 };
610 let nil_guard_expr = if is_optional && field_expr.starts_with("len(*") {
612 Some(field_expr[5..field_expr.len() - 1].to_string())
613 } else {
614 None
615 };
616
617 let deref_field_expr = if is_optional && !field_expr.starts_with("len(") {
620 format!("*{field_expr}")
621 } else {
622 field_expr.clone()
623 };
624
625 let array_guard: Option<String> = if let Some(idx) = field_expr.find("[0]") {
630 let array_expr = &field_expr[..idx];
631 Some(array_expr.to_string())
632 } else {
633 None
634 };
635
636 let mut assertion_buf = String::new();
639 let out_ref = &mut assertion_buf;
640
641 match assertion.assertion_type.as_str() {
642 "equals" => {
643 if let Some(expected) = &assertion.value {
644 let go_val = json_to_go(expected);
645 if expected.is_string() {
647 let trimmed_field = if is_optional && !field_expr.starts_with("len(") {
649 format!("strings.TrimSpace(*{field_expr})")
650 } else {
651 format!("strings.TrimSpace({field_expr})")
652 };
653 if is_optional && !field_expr.starts_with("len(") {
654 let _ = writeln!(out_ref, "\tif {field_expr} != nil && {trimmed_field} != {go_val} {{");
655 } else {
656 let _ = writeln!(out_ref, "\tif {trimmed_field} != {go_val} {{");
657 }
658 } else if is_optional && !field_expr.starts_with("len(") {
659 let _ = writeln!(out_ref, "\tif {field_expr} != nil && {deref_field_expr} != {go_val} {{");
660 } else {
661 let _ = writeln!(out_ref, "\tif {field_expr} != {go_val} {{");
662 }
663 let _ = writeln!(out_ref, "\t\tt.Errorf(\"equals mismatch: got %v\", {field_expr})");
664 let _ = writeln!(out_ref, "\t}}");
665 }
666 }
667 "contains" => {
668 if let Some(expected) = &assertion.value {
669 let go_val = json_to_go(expected);
670 let field_for_contains = if is_optional
671 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
672 {
673 format!("string(*{field_expr})")
674 } else {
675 format!("string({field_expr})")
676 };
677 let _ = writeln!(out_ref, "\tif !strings.Contains({field_for_contains}, {go_val}) {{");
678 let _ = writeln!(
679 out_ref,
680 "\t\tt.Errorf(\"expected to contain %s, got %v\", {go_val}, {field_expr})"
681 );
682 let _ = writeln!(out_ref, "\t}}");
683 }
684 }
685 "contains_all" => {
686 if let Some(values) = &assertion.values {
687 for val in values {
688 let go_val = json_to_go(val);
689 let field_for_contains = if is_optional
690 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
691 {
692 format!("string(*{field_expr})")
693 } else {
694 format!("string({field_expr})")
695 };
696 let _ = writeln!(out_ref, "\tif !strings.Contains({field_for_contains}, {go_val}) {{");
697 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected to contain %s\", {go_val})");
698 let _ = writeln!(out_ref, "\t}}");
699 }
700 }
701 }
702 "not_contains" => {
703 if let Some(expected) = &assertion.value {
704 let go_val = json_to_go(expected);
705 let field_for_contains = if is_optional
706 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
707 {
708 format!("string(*{field_expr})")
709 } else {
710 format!("string({field_expr})")
711 };
712 let _ = writeln!(out_ref, "\tif strings.Contains({field_for_contains}, {go_val}) {{");
713 let _ = writeln!(
714 out_ref,
715 "\t\tt.Errorf(\"expected NOT to contain %s, got %v\", {go_val}, {field_expr})"
716 );
717 let _ = writeln!(out_ref, "\t}}");
718 }
719 }
720 "not_empty" => {
721 if is_optional {
722 let _ = writeln!(out_ref, "\tif {field_expr} == nil || len(*{field_expr}) == 0 {{");
723 } else {
724 let _ = writeln!(out_ref, "\tif len({field_expr}) == 0 {{");
725 }
726 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected non-empty value\")");
727 let _ = writeln!(out_ref, "\t}}");
728 }
729 "is_empty" => {
730 if is_optional {
731 let _ = writeln!(out_ref, "\tif {field_expr} != nil && len(*{field_expr}) != 0 {{");
732 } else {
733 let _ = writeln!(out_ref, "\tif len({field_expr}) != 0 {{");
734 }
735 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected empty value, got %v\", {field_expr})");
736 let _ = writeln!(out_ref, "\t}}");
737 }
738 "contains_any" => {
739 if let Some(values) = &assertion.values {
740 let field_for_contains = if is_optional
741 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
742 {
743 format!("*{field_expr}")
744 } else {
745 field_expr.clone()
746 };
747 let _ = writeln!(out_ref, "\t{{");
748 let _ = writeln!(out_ref, "\t\tfound := false");
749 for val in values {
750 let go_val = json_to_go(val);
751 let _ = writeln!(
752 out_ref,
753 "\t\tif strings.Contains({field_for_contains}, {go_val}) {{ found = true }}"
754 );
755 }
756 let _ = writeln!(out_ref, "\t\tif !found {{");
757 let _ = writeln!(
758 out_ref,
759 "\t\t\tt.Errorf(\"expected to contain at least one of the specified values\")"
760 );
761 let _ = writeln!(out_ref, "\t\t}}");
762 let _ = writeln!(out_ref, "\t}}");
763 }
764 }
765 "greater_than" => {
766 if let Some(val) = &assertion.value {
767 let go_val = json_to_go(val);
768 if let Some(n) = val.as_u64() {
771 let next = n + 1;
772 let _ = writeln!(out_ref, "\tif {field_expr} < {next} {{");
773 } else {
774 let _ = writeln!(out_ref, "\tif {field_expr} <= {go_val} {{");
775 }
776 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected > {go_val}, got %v\", {field_expr})");
777 let _ = writeln!(out_ref, "\t}}");
778 }
779 }
780 "less_than" => {
781 if let Some(val) = &assertion.value {
782 let go_val = json_to_go(val);
783 let _ = writeln!(out_ref, "\tif {field_expr} >= {go_val} {{");
784 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected < {go_val}, got %v\", {field_expr})");
785 let _ = writeln!(out_ref, "\t}}");
786 }
787 }
788 "greater_than_or_equal" => {
789 if let Some(val) = &assertion.value {
790 let go_val = json_to_go(val);
791 if let Some(ref guard) = nil_guard_expr {
792 let _ = writeln!(out_ref, "\tif {guard} != nil {{");
793 let _ = writeln!(out_ref, "\t\tif {field_expr} < {go_val} {{");
794 let _ = writeln!(
795 out_ref,
796 "\t\t\tt.Errorf(\"expected >= {go_val}, got %v\", {field_expr})"
797 );
798 let _ = writeln!(out_ref, "\t\t}}");
799 let _ = writeln!(out_ref, "\t}}");
800 } else {
801 let _ = writeln!(out_ref, "\tif {field_expr} < {go_val} {{");
802 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected >= {go_val}, got %v\", {field_expr})");
803 let _ = writeln!(out_ref, "\t}}");
804 }
805 }
806 }
807 "less_than_or_equal" => {
808 if let Some(val) = &assertion.value {
809 let go_val = json_to_go(val);
810 let _ = writeln!(out_ref, "\tif {field_expr} > {go_val} {{");
811 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected <= {go_val}, got %v\", {field_expr})");
812 let _ = writeln!(out_ref, "\t}}");
813 }
814 }
815 "starts_with" => {
816 if let Some(expected) = &assertion.value {
817 let go_val = json_to_go(expected);
818 let field_for_prefix = if is_optional
819 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
820 {
821 format!("string(*{field_expr})")
822 } else {
823 format!("string({field_expr})")
824 };
825 let _ = writeln!(out_ref, "\tif !strings.HasPrefix({field_for_prefix}, {go_val}) {{");
826 let _ = writeln!(
827 out_ref,
828 "\t\tt.Errorf(\"expected to start with %s, got %v\", {go_val}, {field_expr})"
829 );
830 let _ = writeln!(out_ref, "\t}}");
831 }
832 }
833 "count_min" => {
834 if let Some(val) = &assertion.value {
835 if let Some(n) = val.as_u64() {
836 if is_optional {
837 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
838 let _ = writeln!(
839 out_ref,
840 "\t\tassert.GreaterOrEqual(t, len(*{field_expr}), {n}, \"expected at least {n} elements\")"
841 );
842 let _ = writeln!(out_ref, "\t}}");
843 } else {
844 let _ = writeln!(
845 out_ref,
846 "\tassert.GreaterOrEqual(t, len({field_expr}), {n}, \"expected at least {n} elements\")"
847 );
848 }
849 }
850 }
851 }
852 "count_equals" => {
853 if let Some(val) = &assertion.value {
854 if let Some(n) = val.as_u64() {
855 if is_optional {
856 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
857 let _ = writeln!(
858 out_ref,
859 "\t\tassert.Equal(t, len(*{field_expr}), {n}, \"expected exactly {n} elements\")"
860 );
861 let _ = writeln!(out_ref, "\t}}");
862 } else {
863 let _ = writeln!(
864 out_ref,
865 "\tassert.Equal(t, len({field_expr}), {n}, \"expected exactly {n} elements\")"
866 );
867 }
868 }
869 }
870 }
871 "is_true" => {
872 if is_optional {
873 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
874 let _ = writeln!(out_ref, "\t\tassert.True(t, *{field_expr}, \"expected true\")");
875 let _ = writeln!(out_ref, "\t}}");
876 } else {
877 let _ = writeln!(out_ref, "\tassert.True(t, {field_expr}, \"expected true\")");
878 }
879 }
880 "is_false" => {
881 if is_optional {
882 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
883 let _ = writeln!(out_ref, "\t\tassert.False(t, *{field_expr}, \"expected false\")");
884 let _ = writeln!(out_ref, "\t}}");
885 } else {
886 let _ = writeln!(out_ref, "\tassert.False(t, {field_expr}, \"expected false\")");
887 }
888 }
889 "method_result" => {
890 if let Some(method_name) = &assertion.method {
891 let info = build_go_method_call(result_var, method_name, assertion.args.as_ref(), import_alias);
892 let check = assertion.check.as_deref().unwrap_or("is_true");
893 let deref_expr = if info.is_pointer {
896 format!("*{}", info.call_expr)
897 } else {
898 info.call_expr.clone()
899 };
900 match check {
901 "equals" => {
902 if let Some(val) = &assertion.value {
903 if val.is_boolean() {
904 if val.as_bool() == Some(true) {
905 let _ = writeln!(out_ref, "\tassert.True(t, {deref_expr}, \"expected true\")");
906 } else {
907 let _ = writeln!(out_ref, "\tassert.False(t, {deref_expr}, \"expected false\")");
908 }
909 } else {
910 let go_val = if let Some(cast) = info.value_cast {
914 if val.is_number() {
915 format!("{cast}({})", json_to_go(val))
916 } else {
917 json_to_go(val)
918 }
919 } else {
920 json_to_go(val)
921 };
922 let _ = writeln!(
923 out_ref,
924 "\tassert.Equal(t, {go_val}, {deref_expr}, \"method_result equals assertion failed\")"
925 );
926 }
927 }
928 }
929 "is_true" => {
930 let _ = writeln!(out_ref, "\tassert.True(t, {deref_expr}, \"expected true\")");
931 }
932 "is_false" => {
933 let _ = writeln!(out_ref, "\tassert.False(t, {deref_expr}, \"expected false\")");
934 }
935 "greater_than_or_equal" => {
936 if let Some(val) = &assertion.value {
937 let n = val.as_u64().unwrap_or(0);
938 let cast = info.value_cast.unwrap_or("uint");
940 let _ = writeln!(
941 out_ref,
942 "\tassert.GreaterOrEqual(t, {deref_expr}, {cast}({n}), \"expected >= {n}\")"
943 );
944 }
945 }
946 "count_min" => {
947 if let Some(val) = &assertion.value {
948 let n = val.as_u64().unwrap_or(0);
949 let _ = writeln!(
950 out_ref,
951 "\tassert.GreaterOrEqual(t, len({deref_expr}), {n}, \"expected at least {n} elements\")"
952 );
953 }
954 }
955 "contains" => {
956 if let Some(val) = &assertion.value {
957 let go_val = json_to_go(val);
958 let _ = writeln!(
959 out_ref,
960 "\tassert.Contains(t, {deref_expr}, {go_val}, \"expected result to contain value\")"
961 );
962 }
963 }
964 "is_error" => {
965 let _ = writeln!(out_ref, "\t{{");
966 let _ = writeln!(out_ref, "\t\t_, methodErr := {}", info.call_expr);
967 let _ = writeln!(out_ref, "\t\tassert.Error(t, methodErr)");
968 let _ = writeln!(out_ref, "\t}}");
969 }
970 other_check => {
971 panic!("Go e2e generator: unsupported method_result check type: {other_check}");
972 }
973 }
974 } else {
975 panic!("Go e2e generator: method_result assertion missing 'method' field");
976 }
977 }
978 "not_error" => {
979 }
981 "error" => {
982 }
984 other => {
985 panic!("Go e2e generator: unsupported assertion type: {other}");
986 }
987 }
988
989 if let Some(ref arr) = array_guard {
992 if !assertion_buf.is_empty() {
993 let _ = writeln!(out, "\tif len({arr}) > 0 {{");
994 for line in assertion_buf.lines() {
996 let _ = writeln!(out, "\t{line}");
997 }
998 let _ = writeln!(out, "\t}}");
999 }
1000 } else {
1001 out.push_str(&assertion_buf);
1002 }
1003}
1004
/// Go expression that evaluates a `method_result` assertion's method,
/// as built by `build_go_method_call`.
struct GoMethodCallInfo {
    /// Go type used to cast numeric expected values (e.g. `uint`); `None`
    /// emits the literal uncast.
    value_cast: Option<&'static str>,
    /// Whether `call_expr` must be dereferenced (`*expr`) before asserting.
    is_pointer: bool,
    /// The Go expression that performs the call.
    call_expr: String,
}
1015
1016fn build_go_method_call(
1031 result_var: &str,
1032 method_name: &str,
1033 args: Option<&serde_json::Value>,
1034 import_alias: &str,
1035) -> GoMethodCallInfo {
1036 match method_name {
1037 "root_node_type" => GoMethodCallInfo {
1038 call_expr: format!("{import_alias}.RootNodeInfo({result_var}).Kind"),
1039 is_pointer: false,
1040 value_cast: None,
1041 },
1042 "named_children_count" => GoMethodCallInfo {
1043 call_expr: format!("{import_alias}.RootNodeInfo({result_var}).NamedChildCount"),
1044 is_pointer: false,
1045 value_cast: Some("uint"),
1046 },
1047 "has_error_nodes" => GoMethodCallInfo {
1048 call_expr: format!("{import_alias}.TreeHasErrorNodes({result_var})"),
1049 is_pointer: true,
1050 value_cast: None,
1051 },
1052 "error_count" | "tree_error_count" => GoMethodCallInfo {
1053 call_expr: format!("{import_alias}.TreeErrorCount({result_var})"),
1054 is_pointer: true,
1055 value_cast: Some("uint"),
1056 },
1057 "tree_to_sexp" => GoMethodCallInfo {
1058 call_expr: format!("{import_alias}.TreeToSexp({result_var})"),
1059 is_pointer: true,
1060 value_cast: None,
1061 },
1062 "contains_node_type" => {
1063 let node_type = args
1064 .and_then(|a| a.get("node_type"))
1065 .and_then(|v| v.as_str())
1066 .unwrap_or("");
1067 GoMethodCallInfo {
1068 call_expr: format!("{import_alias}.TreeContainsNodeType({result_var}, \"{node_type}\")"),
1069 is_pointer: true,
1070 value_cast: None,
1071 }
1072 }
1073 "find_nodes_by_type" => {
1074 let node_type = args
1075 .and_then(|a| a.get("node_type"))
1076 .and_then(|v| v.as_str())
1077 .unwrap_or("");
1078 GoMethodCallInfo {
1079 call_expr: format!("{import_alias}.FindNodesByType({result_var}, \"{node_type}\")"),
1080 is_pointer: true,
1081 value_cast: None,
1082 }
1083 }
1084 "run_query" => {
1085 let query_source = args
1086 .and_then(|a| a.get("query_source"))
1087 .and_then(|v| v.as_str())
1088 .unwrap_or("");
1089 let language = args
1090 .and_then(|a| a.get("language"))
1091 .and_then(|v| v.as_str())
1092 .unwrap_or("");
1093 let query_lit = go_string_literal(query_source);
1094 let lang_lit = go_string_literal(language);
1095 GoMethodCallInfo {
1097 call_expr: format!("{import_alias}.RunQuery({result_var}, {lang_lit}, {query_lit}, []byte(source))"),
1098 is_pointer: false,
1099 value_cast: None,
1100 }
1101 }
1102 other => {
1103 let method_pascal = other.to_upper_camel_case();
1104 GoMethodCallInfo {
1105 call_expr: format!("{result_var}.{method_pascal}()"),
1106 is_pointer: false,
1107 value_cast: None,
1108 }
1109 }
1110 }
1111}
1112
/// Acronyms that `go_local_name` writes in a single case instead of
/// capitalizing only the first letter. They are all-lowercase as the first
/// word and all-uppercase afterwards (`result_url` -> `resultURL`). Entries
/// are compared against the upper-cased word, so they must be upper-case.
const GO_INITIALISMS: &[&str] = &[
    "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IDS", "IP", "JSON",
    "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "UID", "UUID",
    "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
];
1120
1121fn go_local_name(snake: &str) -> String {
1125 let words: Vec<&str> = snake.split('_').filter(|w| !w.is_empty()).collect();
1126 if words.is_empty() {
1127 return String::new();
1128 }
1129 let mut result = String::new();
1130 for (i, word) in words.iter().enumerate() {
1131 let upper = word.to_uppercase();
1132 if GO_INITIALISMS.contains(&upper.as_str()) {
1133 if i == 0 {
1134 result.push_str(&upper.to_lowercase());
1136 } else {
1137 result.push_str(&upper);
1138 }
1139 } else if i == 0 {
1140 result.push_str(&word.to_lowercase());
1142 } else {
1143 let mut chars = word.chars();
1145 if let Some(c) = chars.next() {
1146 result.extend(c.to_uppercase());
1147 result.push_str(&chars.as_str().to_lowercase());
1148 }
1149 }
1150 }
1151 result
1152}
1153
1154fn json_to_go(value: &serde_json::Value) -> String {
1156 match value {
1157 serde_json::Value::String(s) => go_string_literal(s),
1158 serde_json::Value::Bool(b) => b.to_string(),
1159 serde_json::Value::Number(n) => n.to_string(),
1160 serde_json::Value::Null => "nil".to_string(),
1161 other => go_string_literal(&other.to_string()),
1163 }
1164}
1165
1166fn visitor_struct_name(fixture_id: &str) -> String {
1175 use heck::ToUpperCamelCase;
1176 format!("testVisitor{}", fixture_id.to_upper_camel_case())
1178}
1179
1180fn emit_go_visitor_struct(
1182 out: &mut String,
1183 struct_name: &str,
1184 visitor_spec: &crate::fixture::VisitorSpec,
1185 import_alias: &str,
1186) {
1187 let _ = writeln!(out, "type {struct_name} struct{{}}");
1188 for (method_name, action) in &visitor_spec.callbacks {
1189 emit_go_visitor_method(out, struct_name, method_name, action, import_alias);
1190 }
1191}
1192
1193fn emit_go_visitor_method(
1195 out: &mut String,
1196 struct_name: &str,
1197 method_name: &str,
1198 action: &CallbackAction,
1199 import_alias: &str,
1200) {
1201 let camel_method = method_to_camel(method_name);
1202 let params = match method_name {
1203 "visit_link" => format!("_ {import_alias}.NodeContext, href, text, title string"),
1204 "visit_image" => format!("_ {import_alias}.NodeContext, src, alt, title string"),
1205 "visit_heading" => format!("_ {import_alias}.NodeContext, level int, text, id string"),
1206 "visit_code_block" => format!("_ {import_alias}.NodeContext, lang, code string"),
1207 "visit_code_inline"
1208 | "visit_strong"
1209 | "visit_emphasis"
1210 | "visit_strikethrough"
1211 | "visit_underline"
1212 | "visit_subscript"
1213 | "visit_superscript"
1214 | "visit_mark"
1215 | "visit_button"
1216 | "visit_summary"
1217 | "visit_figcaption"
1218 | "visit_definition_term"
1219 | "visit_definition_description" => format!("_ {import_alias}.NodeContext, text string"),
1220 "visit_text" => format!("_ {import_alias}.NodeContext, text string"),
1221 "visit_list_item" => {
1222 format!("_ {import_alias}.NodeContext, ordered bool, marker, text string")
1223 }
1224 "visit_blockquote" => format!("_ {import_alias}.NodeContext, content string, depth int"),
1225 "visit_table_row" => format!("_ {import_alias}.NodeContext, cells []string, isHeader bool"),
1226 "visit_custom_element" => format!("_ {import_alias}.NodeContext, tagName, html string"),
1227 "visit_form" => format!("_ {import_alias}.NodeContext, actionUrl, method string"),
1228 "visit_input" => format!("_ {import_alias}.NodeContext, inputType, name, value string"),
1229 "visit_audio" | "visit_video" | "visit_iframe" => {
1230 format!("_ {import_alias}.NodeContext, src string")
1231 }
1232 "visit_details" => format!("_ {import_alias}.NodeContext, isOpen bool"),
1233 _ => format!("_ {import_alias}.NodeContext"),
1234 };
1235
1236 let _ = writeln!(
1237 out,
1238 "func (v *{struct_name}) {camel_method}({params}) {import_alias}.VisitResult {{"
1239 );
1240 match action {
1241 CallbackAction::Skip => {
1242 let _ = writeln!(out, "\treturn {import_alias}.VisitResultSkip");
1243 }
1244 CallbackAction::Continue => {
1245 let _ = writeln!(out, "\treturn {import_alias}.VisitResultContinue");
1246 }
1247 CallbackAction::PreserveHtml => {
1248 let _ = writeln!(out, "\treturn {import_alias}.VisitResultPreserveHtml");
1249 }
1250 CallbackAction::Custom { output } => {
1251 let escaped = go_string_literal(output);
1252 let _ = writeln!(out, "\treturn {import_alias}.VisitResultCustom({escaped})");
1253 }
1254 CallbackAction::CustomTemplate { template } => {
1255 let (fmt_str, fmt_args) = template_to_sprintf(template);
1258 let escaped_fmt = go_string_literal(&fmt_str);
1259 if fmt_args.is_empty() {
1260 let _ = writeln!(out, "\treturn {import_alias}.VisitResultCustom({escaped_fmt})");
1261 } else {
1262 let args_str = fmt_args.join(", ");
1263 let _ = writeln!(
1264 out,
1265 "\treturn {import_alias}.VisitResultCustom(fmt.Sprintf({escaped_fmt}, {args_str}))"
1266 );
1267 }
1268 }
1269 }
1270 let _ = writeln!(out, "}}");
1271}
1272
/// Converts a `{name}` placeholder template into a Go `fmt.Sprintf` format string and the
/// ordered list of argument expressions.
///
/// Each well-formed `{name}` becomes a `%v` verb, and `name` is appended verbatim to the
/// returned argument list. `%v` prints strings exactly like `%s`. Unlike `%s`, it also
/// renders the `int`, `bool` and `[]string` visitor parameters correctly; `%s` would print
/// `%!s(int=2)`.
///
/// An empty `{}` or an unterminated `{...` is copied into the format string literally and
/// produces no argument. Otherwise it would emit an empty or garbage Go expression.
///
/// Every other character, including `%`, is copied unchanged. Callers that feed the result
/// to `fmt.Sprintf` must double literal `%` themselves. When no argument is returned, the
/// format string equals `template`.
fn template_to_sprintf(template: &str) -> (String, Vec<String>) {
    let mut fmt_str = String::with_capacity(template.len());
    let mut args: Vec<String> = Vec::new();
    let mut chars = template.chars();
    while let Some(c) = chars.next() {
        if c != '{' {
            fmt_str.push(c);
            continue;
        }
        let mut name = String::new();
        let mut closed = false;
        for inner in chars.by_ref() {
            if inner == '}' {
                closed = true;
                break;
            }
            name.push(inner);
        }
        if closed && !name.is_empty() {
            fmt_str.push_str("%v");
            args.push(name);
        } else {
            // Not a usable placeholder: keep the original text as-is.
            fmt_str.push('{');
            fmt_str.push_str(&name);
            if closed {
                fmt_str.push('}');
            }
        }
    }
    (fmt_str, args)
}
1298
1299fn method_to_camel(snake: &str) -> String {
1301 use heck::ToUpperCamelCase;
1302 snake.to_upper_camel_case()
1303}