1use crate::config::E2eConfig;
4use crate::escape::{go_string_literal, sanitize_filename};
5use crate::field_access::FieldResolver;
6use crate::fixture::{Assertion, CallbackAction, Fixture, FixtureGroup};
7use alef_codegen::naming::go_param_name;
8use alef_core::backend::GeneratedFile;
9use alef_core::config::AlefConfig;
10use alef_core::hash::{self, CommentStyle};
11use anyhow::Result;
12use heck::ToUpperCamelCase;
13use std::fmt::Write as FmtWrite;
14use std::path::PathBuf;
15
16use super::E2eCodegen;
17
/// Unit type that implements [`E2eCodegen`] for the Go target; its
/// `language_name()` is `"go"`.
pub struct GoCodegen;
20
21impl E2eCodegen for GoCodegen {
22 fn generate(
23 &self,
24 groups: &[FixtureGroup],
25 e2e_config: &E2eConfig,
26 alef_config: &AlefConfig,
27 ) -> Result<Vec<GeneratedFile>> {
28 let lang = self.language_name();
29 let output_base = PathBuf::from(e2e_config.effective_output()).join(lang);
30
31 let mut files = Vec::new();
32
33 let call = &e2e_config.call;
35 let overrides = call.overrides.get(lang);
36 let module_path = overrides
37 .and_then(|o| o.module.as_ref())
38 .cloned()
39 .unwrap_or_else(|| call.module.clone());
40 let import_alias = overrides
41 .and_then(|o| o.alias.as_ref())
42 .cloned()
43 .unwrap_or_else(|| "pkg".to_string());
44
45 let go_pkg = e2e_config.resolve_package("go");
47 let go_module_path = go_pkg
48 .as_ref()
49 .and_then(|p| p.module.as_ref())
50 .cloned()
51 .unwrap_or_else(|| module_path.clone());
52 let replace_path = go_pkg.as_ref().and_then(|p| p.path.as_ref()).cloned();
53 let go_version = go_pkg
54 .as_ref()
55 .and_then(|p| p.version.as_ref())
56 .cloned()
57 .unwrap_or_else(|| {
58 alef_config
59 .resolved_version()
60 .map(|v| format!("v{v}"))
61 .unwrap_or_else(|| "v0.0.0".to_string())
62 });
63 let field_resolver = FieldResolver::new(
64 &e2e_config.fields,
65 &e2e_config.fields_optional,
66 &e2e_config.result_fields,
67 &e2e_config.fields_array,
68 );
69
70 let effective_replace = match e2e_config.dep_mode {
73 crate::config::DependencyMode::Registry => None,
74 crate::config::DependencyMode::Local => replace_path.as_deref().map(String::from),
75 };
76 files.push(GeneratedFile {
77 path: output_base.join("go.mod"),
78 content: render_go_mod(&go_module_path, effective_replace.as_deref(), &go_version),
79 generated_header: false,
80 });
81
82 for group in groups {
84 let active: Vec<&Fixture> = group
85 .fixtures
86 .iter()
87 .filter(|f| f.skip.as_ref().is_none_or(|s| !s.should_skip(lang)))
88 .collect();
89
90 if active.is_empty() {
91 continue;
92 }
93
94 let filename = format!("{}_test.go", sanitize_filename(&group.category));
95 let content = render_test_file(
96 &group.category,
97 &active,
98 &module_path,
99 &import_alias,
100 &field_resolver,
101 e2e_config,
102 );
103 files.push(GeneratedFile {
104 path: output_base.join(filename),
105 content,
106 generated_header: true,
107 });
108 }
109
110 Ok(files)
111 }
112
113 fn language_name(&self) -> &'static str {
114 "go"
115 }
116}
117
/// Builds the text of the generated `go.mod`.
///
/// The module is always named `e2e_go`, requires `go_module_path` at
/// `version`, and — when `replace_path` is given — redirects that module to
/// the local path with a `replace` directive.
fn render_go_mod(go_module_path: &str, replace_path: Option<&str>, version: &str) -> String {
    let mut go_mod = format!("module e2e_go\n\ngo 1.26\n\nrequire {go_module_path} {version}\n");

    if let Some(local_path) = replace_path {
        // Blank separator line, then the replace directive.
        let _ = write!(go_mod, "\nreplace {go_module_path} => {local_path}\n");
    }

    go_mod
}
133
134fn render_test_file(
135 category: &str,
136 fixtures: &[&Fixture],
137 go_module_path: &str,
138 import_alias: &str,
139 field_resolver: &FieldResolver,
140 e2e_config: &crate::config::E2eConfig,
141) -> String {
142 let mut out = String::new();
143
144 out.push_str(&hash::header(CommentStyle::DoubleSlash));
146 let _ = writeln!(out);
147
148 let needs_os = fixtures.iter().any(|f| {
151 let call_args = &e2e_config.resolve_call(f.call.as_deref()).args;
152 call_args.iter().any(|a| a.arg_type == "mock_url")
153 });
154
155 let needs_json = fixtures.iter().any(|f| {
157 let call_args = &e2e_config.resolve_call(f.call.as_deref()).args;
158 call_args.iter().any(|a| a.arg_type == "handle") && {
159 call_args.iter().filter(|a| a.arg_type == "handle").any(|a| {
160 let v = f.input.get(&a.field).unwrap_or(&serde_json::Value::Null);
161 !(v.is_null() || v.is_object() && v.as_object().is_some_and(|o| o.is_empty()))
162 })
163 }
164 });
165
166 let needs_fmt = fixtures.iter().any(|f| {
168 f.visitor.as_ref().is_some_and(|v| {
169 v.callbacks.values().any(|action| {
170 if let CallbackAction::CustomTemplate { template } = action {
171 template.contains('{')
172 } else {
173 false
174 }
175 })
176 })
177 });
178
179 let needs_strings = fixtures.iter().any(|f| {
182 f.assertions.iter().any(|a| {
183 let type_needs_strings = if a.assertion_type == "equals" {
184 a.value.as_ref().is_some_and(|v| v.is_string())
186 } else {
187 matches!(
188 a.assertion_type.as_str(),
189 "contains" | "contains_all" | "not_contains" | "starts_with" | "ends_with"
190 )
191 };
192 let field_valid = a
193 .field
194 .as_ref()
195 .map(|f| f.is_empty() || field_resolver.is_valid_for_result(f))
196 .unwrap_or(true);
197 type_needs_strings && field_valid
198 })
199 });
200
201 let needs_assert = fixtures.iter().any(|f| {
204 f.assertions.iter().any(|a| {
205 let field_valid = a
206 .field
207 .as_ref()
208 .map(|f| f.is_empty() || field_resolver.is_valid_for_result(f))
209 .unwrap_or(true);
210 let type_needs_assert = matches!(
211 a.assertion_type.as_str(),
212 "count_min"
213 | "count_max"
214 | "is_true"
215 | "is_false"
216 | "method_result"
217 | "min_length"
218 | "max_length"
219 | "matches_regex"
220 );
221 type_needs_assert && field_valid
222 })
223 });
224
225 let _ = writeln!(out, "// E2e tests for category: {category}");
226 let _ = writeln!(out, "package e2e_test");
227 let _ = writeln!(out);
228 let _ = writeln!(out, "import (");
229 if needs_json {
230 let _ = writeln!(out, "\t\"encoding/json\"");
231 }
232 if needs_fmt {
233 let _ = writeln!(out, "\t\"fmt\"");
234 }
235 if needs_os {
236 let _ = writeln!(out, "\t\"os\"");
237 }
238 if needs_strings {
239 let _ = writeln!(out, "\t\"strings\"");
240 }
241 let _ = writeln!(out, "\t\"testing\"");
242 if needs_assert {
243 let _ = writeln!(out);
244 let _ = writeln!(out, "\t\"github.com/stretchr/testify/assert\"");
245 }
246 let _ = writeln!(out);
247 let _ = writeln!(out, "\t{import_alias} \"{go_module_path}\"");
248 let _ = writeln!(out, ")");
249 let _ = writeln!(out);
250
251 for fixture in fixtures.iter() {
253 if let Some(visitor_spec) = &fixture.visitor {
254 let struct_name = visitor_struct_name(&fixture.id);
255 emit_go_visitor_struct(&mut out, &struct_name, visitor_spec, import_alias);
256 let _ = writeln!(out);
257 }
258 }
259
260 for (i, fixture) in fixtures.iter().enumerate() {
261 render_test_function(&mut out, fixture, import_alias, field_resolver, e2e_config);
262 if i + 1 < fixtures.len() {
263 let _ = writeln!(out);
264 }
265 }
266
267 while out.ends_with("\n\n") {
269 out.pop();
270 }
271 if !out.ends_with('\n') {
272 out.push('\n');
273 }
274 out
275}
276
277fn render_test_function(
278 out: &mut String,
279 fixture: &Fixture,
280 import_alias: &str,
281 field_resolver: &FieldResolver,
282 e2e_config: &crate::config::E2eConfig,
283) {
284 let fn_name = fixture.id.to_upper_camel_case();
285 let description = &fixture.description;
286
287 let call_config = e2e_config.resolve_call(fixture.call.as_deref());
289 let lang = "go";
290 let overrides = call_config.overrides.get(lang);
291 let function_name = overrides
292 .and_then(|o| o.function.as_ref())
293 .cloned()
294 .unwrap_or_else(|| call_config.function.clone());
295 let result_var = &call_config.result_var;
296 let args = &call_config.args;
297
298 let expects_error = fixture.assertions.iter().any(|a| a.assertion_type == "error");
299
300 let (mut setup_lines, args_str) = build_args_and_setup(&fixture.input, args, import_alias, e2e_config, &fixture.id);
301
302 let mut visitor_arg = String::new();
304 if fixture.visitor.is_some() {
305 let struct_name = visitor_struct_name(&fixture.id);
306 setup_lines.push(format!("visitor := &{struct_name}{{}}"));
307 visitor_arg = "visitor".to_string();
308 }
309
310 let final_args = if visitor_arg.is_empty() {
311 args_str
312 } else {
313 format!("{args_str}, {visitor_arg}")
314 };
315
316 let _ = writeln!(out, "func Test_{fn_name}(t *testing.T) {{");
317 let _ = writeln!(out, "\t// {description}");
318
319 for line in &setup_lines {
320 let _ = writeln!(out, "\t{line}");
321 }
322
323 if expects_error {
324 let _ = writeln!(out, "\t_, err := {import_alias}.{function_name}({final_args})");
325 let _ = writeln!(out, "\tif err == nil {{");
326 let _ = writeln!(out, "\t\tt.Errorf(\"expected an error, but call succeeded\")");
327 let _ = writeln!(out, "\t}}");
328 let _ = writeln!(out, "}}");
329 return;
330 }
331
332 let has_usable_assertion = fixture.assertions.iter().any(|a| {
336 if a.assertion_type == "not_error" || a.assertion_type == "error" {
337 return false;
338 }
339 if a.assertion_type == "method_result" {
341 return true;
342 }
343 match &a.field {
344 Some(f) if !f.is_empty() => field_resolver.is_valid_for_result(f),
345 _ => true,
346 }
347 });
348
349 let result_binding = if has_usable_assertion {
350 result_var.to_string()
351 } else {
352 "_".to_string()
353 };
354
355 let _ = writeln!(
357 out,
358 "\t{result_binding}, err := {import_alias}.{function_name}({final_args})"
359 );
360 let _ = writeln!(out, "\tif err != nil {{");
361 let _ = writeln!(out, "\t\tt.Fatalf(\"call failed: %v\", err)");
362 let _ = writeln!(out, "\t}}");
363
364 let mut optional_locals: std::collections::HashMap<String, String> = std::collections::HashMap::new();
369 for assertion in &fixture.assertions {
370 if let Some(f) = &assertion.field {
371 if !f.is_empty() {
372 let resolved = field_resolver.resolve(f);
373 if field_resolver.is_optional(resolved) && !optional_locals.contains_key(f.as_str()) {
374 let is_string_field = assertion.value.as_ref().is_some_and(|v| v.is_string());
377 if !is_string_field {
378 continue;
381 }
382 let field_expr = field_resolver.accessor(f, "go", result_var);
383 let local_var = go_param_name(&resolved.replace(['.', '[', ']'], "_"));
384 if field_resolver.has_map_access(f) {
385 let _ = writeln!(out, "\t{local_var} := {field_expr}");
388 } else {
389 let _ = writeln!(out, "\tvar {local_var} string");
390 let _ = writeln!(out, "\tif {field_expr} != nil {{");
391 let _ = writeln!(out, "\t\t{local_var} = *{field_expr}");
392 let _ = writeln!(out, "\t}}");
393 }
394 optional_locals.insert(f.clone(), local_var);
395 }
396 }
397 }
398 }
399
400 for assertion in &fixture.assertions {
402 if let Some(f) = &assertion.field {
403 if !f.is_empty() && !optional_locals.contains_key(f.as_str()) {
404 let parts: Vec<&str> = f.split('.').collect();
407 let mut guard_expr: Option<String> = None;
408 for i in 1..parts.len() {
409 let prefix = parts[..i].join(".");
410 let resolved_prefix = field_resolver.resolve(&prefix);
411 if field_resolver.is_optional(resolved_prefix) {
412 let accessor = field_resolver.accessor(&prefix, "go", result_var);
413 guard_expr = Some(accessor);
414 break;
415 }
416 }
417 if let Some(guard) = guard_expr {
418 if field_resolver.is_valid_for_result(f) {
421 let _ = writeln!(out, "\tif {guard} != nil {{");
422 let mut nil_buf = String::new();
425 render_assertion(
426 &mut nil_buf,
427 assertion,
428 result_var,
429 import_alias,
430 field_resolver,
431 &optional_locals,
432 );
433 for line in nil_buf.lines() {
434 let _ = writeln!(out, "\t{line}");
435 }
436 let _ = writeln!(out, "\t}}");
437 } else {
438 render_assertion(
439 out,
440 assertion,
441 result_var,
442 import_alias,
443 field_resolver,
444 &optional_locals,
445 );
446 }
447 continue;
448 }
449 }
450 }
451 render_assertion(
452 out,
453 assertion,
454 result_var,
455 import_alias,
456 field_resolver,
457 &optional_locals,
458 );
459 }
460
461 let _ = writeln!(out, "}}");
462}
463
464fn build_args_and_setup(
468 input: &serde_json::Value,
469 args: &[crate::config::ArgMapping],
470 import_alias: &str,
471 e2e_config: &crate::config::E2eConfig,
472 fixture_id: &str,
473) -> (Vec<String>, String) {
474 use heck::ToUpperCamelCase;
475
476 if args.is_empty() {
477 return (Vec::new(), json_to_go(input));
478 }
479
480 let overrides = e2e_config.call.overrides.get("go");
481 let options_type = overrides.and_then(|o| o.options_type.as_deref());
482
483 let mut setup_lines: Vec<String> = Vec::new();
484 let mut parts: Vec<String> = Vec::new();
485
486 for arg in args {
487 if arg.arg_type == "mock_url" {
488 setup_lines.push(format!(
489 "{} := os.Getenv(\"MOCK_SERVER_URL\") + \"/fixtures/{fixture_id}\"",
490 arg.name,
491 ));
492 parts.push(arg.name.clone());
493 continue;
494 }
495
496 if arg.arg_type == "handle" {
497 let constructor_name = format!("Create{}", arg.name.to_upper_camel_case());
499 let field = arg.field.strip_prefix("input.").unwrap_or(&arg.field);
500 let config_value = input.get(field).unwrap_or(&serde_json::Value::Null);
501 if config_value.is_null()
502 || config_value.is_object() && config_value.as_object().is_some_and(|o| o.is_empty())
503 {
504 setup_lines.push(format!(
505 "{name}, createErr := {import_alias}.{constructor_name}(nil)\n\tif createErr != nil {{\n\t\tt.Fatalf(\"create handle failed: %v\", createErr)\n\t}}",
506 name = arg.name,
507 ));
508 } else {
509 let json_str = serde_json::to_string(config_value).unwrap_or_default();
510 let go_literal = go_string_literal(&json_str);
511 let name = &arg.name;
512 setup_lines.push(format!(
513 "var {name}Config {import_alias}.CrawlConfig\n\tif err := json.Unmarshal([]byte({go_literal}), &{name}Config); err != nil {{\n\t\tt.Fatalf(\"config parse failed: %v\", err)\n\t}}"
514 ));
515 setup_lines.push(format!(
516 "{name}, createErr := {import_alias}.{constructor_name}(&{name}Config)\n\tif createErr != nil {{\n\t\tt.Fatalf(\"create handle failed: %v\", createErr)\n\t}}"
517 ));
518 }
519 parts.push(arg.name.clone());
520 continue;
521 }
522
523 let field = arg.field.strip_prefix("input.").unwrap_or(&arg.field);
524 let val = input.get(field);
525 match val {
526 None | Some(serde_json::Value::Null) if arg.optional => {
527 continue;
529 }
530 None | Some(serde_json::Value::Null) => {
531 let default_val = match arg.arg_type.as_str() {
533 "string" => "\"\"".to_string(),
534 "int" | "integer" => "0".to_string(),
535 "float" | "number" => "0.0".to_string(),
536 "bool" | "boolean" => "false".to_string(),
537 _ => "nil".to_string(),
538 };
539 parts.push(default_val);
540 }
541 Some(v) => {
542 if let (Some(opts_type), "json_object") = (options_type, arg.arg_type.as_str()) {
544 if let Some(obj) = v.as_object() {
545 let with_calls: Vec<String> = obj
546 .iter()
547 .map(|(k, vv)| {
548 let func_name = format!("With{}{}", opts_type, k.to_upper_camel_case());
549 let go_val = json_to_go(vv);
550 format!("htmd.{func_name}({go_val})")
551 })
552 .collect();
553 let new_fn = format!("New{opts_type}");
554 parts.push(format!("htmd.{new_fn}({})", with_calls.join(", ")));
555 continue;
556 }
557 }
558 parts.push(json_to_go(v));
559 }
560 }
561 }
562
563 (setup_lines, parts.join(", "))
564}
565
566fn render_assertion(
567 out: &mut String,
568 assertion: &Assertion,
569 result_var: &str,
570 import_alias: &str,
571 field_resolver: &FieldResolver,
572 optional_locals: &std::collections::HashMap<String, String>,
573) {
574 if let Some(f) = &assertion.field {
576 if !f.is_empty() && !field_resolver.is_valid_for_result(f) {
577 let _ = writeln!(out, "\t// skipped: field '{f}' not available on result type");
578 return;
579 }
580 }
581
582 let field_expr = match &assertion.field {
583 Some(f) if !f.is_empty() => {
584 if let Some(local_var) = optional_locals.get(f.as_str()) {
586 local_var.clone()
587 } else {
588 field_resolver.accessor(f, "go", result_var)
589 }
590 }
591 _ => result_var.to_string(),
592 };
593
594 let is_optional = assertion
598 .field
599 .as_ref()
600 .map(|f| {
601 let resolved = field_resolver.resolve(f);
602 let check_path = resolved
603 .strip_suffix(".length")
604 .or_else(|| resolved.strip_suffix(".count"))
605 .or_else(|| resolved.strip_suffix(".size"))
606 .unwrap_or(resolved);
607 field_resolver.is_optional(check_path) && !optional_locals.contains_key(f.as_str())
608 })
609 .unwrap_or(false);
610
611 let field_expr = if is_optional && field_expr.starts_with("len(") && field_expr.ends_with(')') {
614 let inner = &field_expr[4..field_expr.len() - 1];
615 format!("len(*{inner})")
616 } else {
617 field_expr
618 };
619 let nil_guard_expr = if is_optional && field_expr.starts_with("len(*") {
621 Some(field_expr[5..field_expr.len() - 1].to_string())
622 } else {
623 None
624 };
625
626 let deref_field_expr = if is_optional && !field_expr.starts_with("len(") {
629 format!("*{field_expr}")
630 } else {
631 field_expr.clone()
632 };
633
634 let array_guard: Option<String> = if let Some(idx) = field_expr.find("[0]") {
639 let array_expr = &field_expr[..idx];
640 Some(array_expr.to_string())
641 } else {
642 None
643 };
644
645 let mut assertion_buf = String::new();
648 let out_ref = &mut assertion_buf;
649
650 match assertion.assertion_type.as_str() {
651 "equals" => {
652 if let Some(expected) = &assertion.value {
653 let go_val = json_to_go(expected);
654 if expected.is_string() {
656 let trimmed_field = if is_optional && !field_expr.starts_with("len(") {
658 format!("strings.TrimSpace(*{field_expr})")
659 } else {
660 format!("strings.TrimSpace({field_expr})")
661 };
662 if is_optional && !field_expr.starts_with("len(") {
663 let _ = writeln!(out_ref, "\tif {field_expr} != nil && {trimmed_field} != {go_val} {{");
664 } else {
665 let _ = writeln!(out_ref, "\tif {trimmed_field} != {go_val} {{");
666 }
667 } else if is_optional && !field_expr.starts_with("len(") {
668 let _ = writeln!(out_ref, "\tif {field_expr} != nil && {deref_field_expr} != {go_val} {{");
669 } else {
670 let _ = writeln!(out_ref, "\tif {field_expr} != {go_val} {{");
671 }
672 let _ = writeln!(out_ref, "\t\tt.Errorf(\"equals mismatch: got %v\", {field_expr})");
673 let _ = writeln!(out_ref, "\t}}");
674 }
675 }
676 "contains" => {
677 if let Some(expected) = &assertion.value {
678 let go_val = json_to_go(expected);
679 let field_for_contains = if is_optional
680 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
681 {
682 format!("string(*{field_expr})")
683 } else {
684 format!("string({field_expr})")
685 };
686 let _ = writeln!(out_ref, "\tif !strings.Contains({field_for_contains}, {go_val}) {{");
687 let _ = writeln!(
688 out_ref,
689 "\t\tt.Errorf(\"expected to contain %s, got %v\", {go_val}, {field_expr})"
690 );
691 let _ = writeln!(out_ref, "\t}}");
692 }
693 }
694 "contains_all" => {
695 if let Some(values) = &assertion.values {
696 for val in values {
697 let go_val = json_to_go(val);
698 let field_for_contains = if is_optional
699 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
700 {
701 format!("string(*{field_expr})")
702 } else {
703 format!("string({field_expr})")
704 };
705 let _ = writeln!(out_ref, "\tif !strings.Contains({field_for_contains}, {go_val}) {{");
706 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected to contain %s\", {go_val})");
707 let _ = writeln!(out_ref, "\t}}");
708 }
709 }
710 }
711 "not_contains" => {
712 if let Some(expected) = &assertion.value {
713 let go_val = json_to_go(expected);
714 let field_for_contains = if is_optional
715 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
716 {
717 format!("string(*{field_expr})")
718 } else {
719 format!("string({field_expr})")
720 };
721 let _ = writeln!(out_ref, "\tif strings.Contains({field_for_contains}, {go_val}) {{");
722 let _ = writeln!(
723 out_ref,
724 "\t\tt.Errorf(\"expected NOT to contain %s, got %v\", {go_val}, {field_expr})"
725 );
726 let _ = writeln!(out_ref, "\t}}");
727 }
728 }
729 "not_empty" => {
730 if is_optional {
731 let _ = writeln!(out_ref, "\tif {field_expr} == nil || len(*{field_expr}) == 0 {{");
732 } else {
733 let _ = writeln!(out_ref, "\tif len({field_expr}) == 0 {{");
734 }
735 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected non-empty value\")");
736 let _ = writeln!(out_ref, "\t}}");
737 }
738 "is_empty" => {
739 if is_optional {
740 let _ = writeln!(out_ref, "\tif {field_expr} != nil && len(*{field_expr}) != 0 {{");
741 } else {
742 let _ = writeln!(out_ref, "\tif len({field_expr}) != 0 {{");
743 }
744 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected empty value, got %v\", {field_expr})");
745 let _ = writeln!(out_ref, "\t}}");
746 }
747 "contains_any" => {
748 if let Some(values) = &assertion.values {
749 let field_for_contains = if is_optional
750 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
751 {
752 format!("*{field_expr}")
753 } else {
754 field_expr.clone()
755 };
756 let _ = writeln!(out_ref, "\t{{");
757 let _ = writeln!(out_ref, "\t\tfound := false");
758 for val in values {
759 let go_val = json_to_go(val);
760 let _ = writeln!(
761 out_ref,
762 "\t\tif strings.Contains({field_for_contains}, {go_val}) {{ found = true }}"
763 );
764 }
765 let _ = writeln!(out_ref, "\t\tif !found {{");
766 let _ = writeln!(
767 out_ref,
768 "\t\t\tt.Errorf(\"expected to contain at least one of the specified values\")"
769 );
770 let _ = writeln!(out_ref, "\t\t}}");
771 let _ = writeln!(out_ref, "\t}}");
772 }
773 }
774 "greater_than" => {
775 if let Some(val) = &assertion.value {
776 let go_val = json_to_go(val);
777 if let Some(n) = val.as_u64() {
780 let next = n + 1;
781 let _ = writeln!(out_ref, "\tif {field_expr} < {next} {{");
782 } else {
783 let _ = writeln!(out_ref, "\tif {field_expr} <= {go_val} {{");
784 }
785 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected > {go_val}, got %v\", {field_expr})");
786 let _ = writeln!(out_ref, "\t}}");
787 }
788 }
789 "less_than" => {
790 if let Some(val) = &assertion.value {
791 let go_val = json_to_go(val);
792 let _ = writeln!(out_ref, "\tif {field_expr} >= {go_val} {{");
793 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected < {go_val}, got %v\", {field_expr})");
794 let _ = writeln!(out_ref, "\t}}");
795 }
796 }
797 "greater_than_or_equal" => {
798 if let Some(val) = &assertion.value {
799 let go_val = json_to_go(val);
800 if let Some(ref guard) = nil_guard_expr {
801 let _ = writeln!(out_ref, "\tif {guard} != nil {{");
802 let _ = writeln!(out_ref, "\t\tif {field_expr} < {go_val} {{");
803 let _ = writeln!(
804 out_ref,
805 "\t\t\tt.Errorf(\"expected >= {go_val}, got %v\", {field_expr})"
806 );
807 let _ = writeln!(out_ref, "\t\t}}");
808 let _ = writeln!(out_ref, "\t}}");
809 } else {
810 let _ = writeln!(out_ref, "\tif {field_expr} < {go_val} {{");
811 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected >= {go_val}, got %v\", {field_expr})");
812 let _ = writeln!(out_ref, "\t}}");
813 }
814 }
815 }
816 "less_than_or_equal" => {
817 if let Some(val) = &assertion.value {
818 let go_val = json_to_go(val);
819 let _ = writeln!(out_ref, "\tif {field_expr} > {go_val} {{");
820 let _ = writeln!(out_ref, "\t\tt.Errorf(\"expected <= {go_val}, got %v\", {field_expr})");
821 let _ = writeln!(out_ref, "\t}}");
822 }
823 }
824 "starts_with" => {
825 if let Some(expected) = &assertion.value {
826 let go_val = json_to_go(expected);
827 let field_for_prefix = if is_optional
828 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
829 {
830 format!("string(*{field_expr})")
831 } else {
832 format!("string({field_expr})")
833 };
834 let _ = writeln!(out_ref, "\tif !strings.HasPrefix({field_for_prefix}, {go_val}) {{");
835 let _ = writeln!(
836 out_ref,
837 "\t\tt.Errorf(\"expected to start with %s, got %v\", {go_val}, {field_expr})"
838 );
839 let _ = writeln!(out_ref, "\t}}");
840 }
841 }
842 "count_min" => {
843 if let Some(val) = &assertion.value {
844 if let Some(n) = val.as_u64() {
845 if is_optional {
846 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
847 let _ = writeln!(
848 out_ref,
849 "\t\tassert.GreaterOrEqual(t, len(*{field_expr}), {n}, \"expected at least {n} elements\")"
850 );
851 let _ = writeln!(out_ref, "\t}}");
852 } else {
853 let _ = writeln!(
854 out_ref,
855 "\tassert.GreaterOrEqual(t, len({field_expr}), {n}, \"expected at least {n} elements\")"
856 );
857 }
858 }
859 }
860 }
861 "count_equals" => {
862 if let Some(val) = &assertion.value {
863 if let Some(n) = val.as_u64() {
864 if is_optional {
865 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
866 let _ = writeln!(
867 out_ref,
868 "\t\tassert.Equal(t, len(*{field_expr}), {n}, \"expected exactly {n} elements\")"
869 );
870 let _ = writeln!(out_ref, "\t}}");
871 } else {
872 let _ = writeln!(
873 out_ref,
874 "\tassert.Equal(t, len({field_expr}), {n}, \"expected exactly {n} elements\")"
875 );
876 }
877 }
878 }
879 }
880 "is_true" => {
881 if is_optional {
882 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
883 let _ = writeln!(out_ref, "\t\tassert.True(t, *{field_expr}, \"expected true\")");
884 let _ = writeln!(out_ref, "\t}}");
885 } else {
886 let _ = writeln!(out_ref, "\tassert.True(t, {field_expr}, \"expected true\")");
887 }
888 }
889 "is_false" => {
890 if is_optional {
891 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
892 let _ = writeln!(out_ref, "\t\tassert.False(t, *{field_expr}, \"expected false\")");
893 let _ = writeln!(out_ref, "\t}}");
894 } else {
895 let _ = writeln!(out_ref, "\tassert.False(t, {field_expr}, \"expected false\")");
896 }
897 }
898 "method_result" => {
899 if let Some(method_name) = &assertion.method {
900 let info = build_go_method_call(result_var, method_name, assertion.args.as_ref(), import_alias);
901 let check = assertion.check.as_deref().unwrap_or("is_true");
902 let deref_expr = if info.is_pointer {
905 format!("*{}", info.call_expr)
906 } else {
907 info.call_expr.clone()
908 };
909 match check {
910 "equals" => {
911 if let Some(val) = &assertion.value {
912 if val.is_boolean() {
913 if val.as_bool() == Some(true) {
914 let _ = writeln!(out_ref, "\tassert.True(t, {deref_expr}, \"expected true\")");
915 } else {
916 let _ = writeln!(out_ref, "\tassert.False(t, {deref_expr}, \"expected false\")");
917 }
918 } else {
919 let go_val = if let Some(cast) = info.value_cast {
923 if val.is_number() {
924 format!("{cast}({})", json_to_go(val))
925 } else {
926 json_to_go(val)
927 }
928 } else {
929 json_to_go(val)
930 };
931 let _ = writeln!(
932 out_ref,
933 "\tassert.Equal(t, {go_val}, {deref_expr}, \"method_result equals assertion failed\")"
934 );
935 }
936 }
937 }
938 "is_true" => {
939 let _ = writeln!(out_ref, "\tassert.True(t, {deref_expr}, \"expected true\")");
940 }
941 "is_false" => {
942 let _ = writeln!(out_ref, "\tassert.False(t, {deref_expr}, \"expected false\")");
943 }
944 "greater_than_or_equal" => {
945 if let Some(val) = &assertion.value {
946 let n = val.as_u64().unwrap_or(0);
947 let cast = info.value_cast.unwrap_or("uint");
949 let _ = writeln!(
950 out_ref,
951 "\tassert.GreaterOrEqual(t, {deref_expr}, {cast}({n}), \"expected >= {n}\")"
952 );
953 }
954 }
955 "count_min" => {
956 if let Some(val) = &assertion.value {
957 let n = val.as_u64().unwrap_or(0);
958 let _ = writeln!(
959 out_ref,
960 "\tassert.GreaterOrEqual(t, len({deref_expr}), {n}, \"expected at least {n} elements\")"
961 );
962 }
963 }
964 "contains" => {
965 if let Some(val) = &assertion.value {
966 let go_val = json_to_go(val);
967 let _ = writeln!(
968 out_ref,
969 "\tassert.Contains(t, {deref_expr}, {go_val}, \"expected result to contain value\")"
970 );
971 }
972 }
973 "is_error" => {
974 let _ = writeln!(out_ref, "\t{{");
975 let _ = writeln!(out_ref, "\t\t_, methodErr := {}", info.call_expr);
976 let _ = writeln!(out_ref, "\t\tassert.Error(t, methodErr)");
977 let _ = writeln!(out_ref, "\t}}");
978 }
979 other_check => {
980 panic!("Go e2e generator: unsupported method_result check type: {other_check}");
981 }
982 }
983 } else {
984 panic!("Go e2e generator: method_result assertion missing 'method' field");
985 }
986 }
987 "min_length" => {
988 if let Some(val) = &assertion.value {
989 if let Some(n) = val.as_u64() {
990 if is_optional {
991 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
992 let _ = writeln!(
993 out_ref,
994 "\t\tassert.GreaterOrEqual(t, len(*{field_expr}), {n}, \"expected length >= {n}\")"
995 );
996 let _ = writeln!(out_ref, "\t}}");
997 } else {
998 let _ = writeln!(
999 out_ref,
1000 "\tassert.GreaterOrEqual(t, len({field_expr}), {n}, \"expected length >= {n}\")"
1001 );
1002 }
1003 }
1004 }
1005 }
1006 "max_length" => {
1007 if let Some(val) = &assertion.value {
1008 if let Some(n) = val.as_u64() {
1009 if is_optional {
1010 let _ = writeln!(out_ref, "\tif {field_expr} != nil {{");
1011 let _ = writeln!(
1012 out_ref,
1013 "\t\tassert.LessOrEqual(t, len(*{field_expr}), {n}, \"expected length <= {n}\")"
1014 );
1015 let _ = writeln!(out_ref, "\t}}");
1016 } else {
1017 let _ = writeln!(
1018 out_ref,
1019 "\tassert.LessOrEqual(t, len({field_expr}), {n}, \"expected length <= {n}\")"
1020 );
1021 }
1022 }
1023 }
1024 }
1025 "ends_with" => {
1026 if let Some(expected) = &assertion.value {
1027 let go_val = json_to_go(expected);
1028 let field_for_suffix = if is_optional
1029 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
1030 {
1031 format!("string(*{field_expr})")
1032 } else {
1033 format!("string({field_expr})")
1034 };
1035 let _ = writeln!(out_ref, "\tif !strings.HasSuffix({field_for_suffix}, {go_val}) {{");
1036 let _ = writeln!(
1037 out_ref,
1038 "\t\tt.Errorf(\"expected to end with %s, got %v\", {go_val}, {field_expr})"
1039 );
1040 let _ = writeln!(out_ref, "\t}}");
1041 }
1042 }
1043 "matches_regex" => {
1044 if let Some(expected) = &assertion.value {
1045 let go_val = json_to_go(expected);
1046 let field_for_regex = if is_optional
1047 && !optional_locals.contains_key(assertion.field.as_ref().unwrap_or(&String::new()))
1048 {
1049 format!("*{field_expr}")
1050 } else {
1051 field_expr.clone()
1052 };
1053 let _ = writeln!(
1054 out_ref,
1055 "\tassert.Regexp(t, {go_val}, {field_for_regex}, \"expected value to match regex\")"
1056 );
1057 }
1058 }
1059 "not_error" => {
1060 }
1062 "error" => {
1063 }
1065 other => {
1066 panic!("Go e2e generator: unsupported assertion type: {other}");
1067 }
1068 }
1069
1070 if let Some(ref arr) = array_guard {
1073 if !assertion_buf.is_empty() {
1074 let _ = writeln!(out, "\tif len({arr}) > 0 {{");
1075 for line in assertion_buf.lines() {
1077 let _ = writeln!(out, "\t{line}");
1078 }
1079 let _ = writeln!(out, "\t}}");
1080 }
1081 } else {
1082 out.push_str(&assertion_buf);
1083 }
1084}
1085
/// Describes how a `method_result` assertion reaches its value in generated Go code.
struct GoMethodCallInfo {
    /// Go expression text that performs the call or field access.
    call_expr: String,
    /// When `true`, the assertion renderer prefixes `call_expr` with `*`
    /// before asserting on it.
    is_pointer: bool,
    /// Go type name wrapped around numeric expected values (rendered as
    /// `<cast>(<value>)`), if any.
    value_cast: Option<&'static str>,
}
1096
1097fn build_go_method_call(
1112 result_var: &str,
1113 method_name: &str,
1114 args: Option<&serde_json::Value>,
1115 import_alias: &str,
1116) -> GoMethodCallInfo {
1117 match method_name {
1118 "root_node_type" => GoMethodCallInfo {
1119 call_expr: format!("{import_alias}.RootNodeInfo({result_var}).Kind"),
1120 is_pointer: false,
1121 value_cast: None,
1122 },
1123 "named_children_count" => GoMethodCallInfo {
1124 call_expr: format!("{import_alias}.RootNodeInfo({result_var}).NamedChildCount"),
1125 is_pointer: false,
1126 value_cast: Some("uint"),
1127 },
1128 "has_error_nodes" => GoMethodCallInfo {
1129 call_expr: format!("{import_alias}.TreeHasErrorNodes({result_var})"),
1130 is_pointer: true,
1131 value_cast: None,
1132 },
1133 "error_count" | "tree_error_count" => GoMethodCallInfo {
1134 call_expr: format!("{import_alias}.TreeErrorCount({result_var})"),
1135 is_pointer: true,
1136 value_cast: Some("uint"),
1137 },
1138 "tree_to_sexp" => GoMethodCallInfo {
1139 call_expr: format!("{import_alias}.TreeToSexp({result_var})"),
1140 is_pointer: true,
1141 value_cast: None,
1142 },
1143 "contains_node_type" => {
1144 let node_type = args
1145 .and_then(|a| a.get("node_type"))
1146 .and_then(|v| v.as_str())
1147 .unwrap_or("");
1148 GoMethodCallInfo {
1149 call_expr: format!("{import_alias}.TreeContainsNodeType({result_var}, \"{node_type}\")"),
1150 is_pointer: true,
1151 value_cast: None,
1152 }
1153 }
1154 "find_nodes_by_type" => {
1155 let node_type = args
1156 .and_then(|a| a.get("node_type"))
1157 .and_then(|v| v.as_str())
1158 .unwrap_or("");
1159 GoMethodCallInfo {
1160 call_expr: format!("{import_alias}.FindNodesByType({result_var}, \"{node_type}\")"),
1161 is_pointer: true,
1162 value_cast: None,
1163 }
1164 }
1165 "run_query" => {
1166 let query_source = args
1167 .and_then(|a| a.get("query_source"))
1168 .and_then(|v| v.as_str())
1169 .unwrap_or("");
1170 let language = args
1171 .and_then(|a| a.get("language"))
1172 .and_then(|v| v.as_str())
1173 .unwrap_or("");
1174 let query_lit = go_string_literal(query_source);
1175 let lang_lit = go_string_literal(language);
1176 GoMethodCallInfo {
1178 call_expr: format!("{import_alias}.RunQuery({result_var}, {lang_lit}, {query_lit}, []byte(source))"),
1179 is_pointer: false,
1180 value_cast: None,
1181 }
1182 }
1183 other => {
1184 let method_pascal = other.to_upper_camel_case();
1185 GoMethodCallInfo {
1186 call_expr: format!("{result_var}.{method_pascal}()"),
1187 is_pointer: false,
1188 value_cast: None,
1189 }
1190 }
1191 }
1192}
1193
1194fn json_to_go(value: &serde_json::Value) -> String {
1196 match value {
1197 serde_json::Value::String(s) => go_string_literal(s),
1198 serde_json::Value::Bool(b) => b.to_string(),
1199 serde_json::Value::Number(n) => n.to_string(),
1200 serde_json::Value::Null => "nil".to_string(),
1201 other => go_string_literal(&other.to_string()),
1203 }
1204}
1205
1206fn visitor_struct_name(fixture_id: &str) -> String {
1215 use heck::ToUpperCamelCase;
1216 format!("testVisitor{}", fixture_id.to_upper_camel_case())
1218}
1219
1220fn emit_go_visitor_struct(
1222 out: &mut String,
1223 struct_name: &str,
1224 visitor_spec: &crate::fixture::VisitorSpec,
1225 import_alias: &str,
1226) {
1227 let _ = writeln!(out, "type {struct_name} struct{{}}");
1228 for (method_name, action) in &visitor_spec.callbacks {
1229 emit_go_visitor_method(out, struct_name, method_name, action, import_alias);
1230 }
1231}
1232
1233fn emit_go_visitor_method(
1235 out: &mut String,
1236 struct_name: &str,
1237 method_name: &str,
1238 action: &CallbackAction,
1239 import_alias: &str,
1240) {
1241 let camel_method = method_to_camel(method_name);
1242 let params = match method_name {
1243 "visit_link" => format!("_ {import_alias}.NodeContext, href, text, title string"),
1244 "visit_image" => format!("_ {import_alias}.NodeContext, src, alt, title string"),
1245 "visit_heading" => format!("_ {import_alias}.NodeContext, level int, text, id string"),
1246 "visit_code_block" => format!("_ {import_alias}.NodeContext, lang, code string"),
1247 "visit_code_inline"
1248 | "visit_strong"
1249 | "visit_emphasis"
1250 | "visit_strikethrough"
1251 | "visit_underline"
1252 | "visit_subscript"
1253 | "visit_superscript"
1254 | "visit_mark"
1255 | "visit_button"
1256 | "visit_summary"
1257 | "visit_figcaption"
1258 | "visit_definition_term"
1259 | "visit_definition_description" => format!("_ {import_alias}.NodeContext, text string"),
1260 "visit_text" => format!("_ {import_alias}.NodeContext, text string"),
1261 "visit_list_item" => {
1262 format!("_ {import_alias}.NodeContext, ordered bool, marker, text string")
1263 }
1264 "visit_blockquote" => format!("_ {import_alias}.NodeContext, content string, depth int"),
1265 "visit_table_row" => format!("_ {import_alias}.NodeContext, cells []string, isHeader bool"),
1266 "visit_custom_element" => format!("_ {import_alias}.NodeContext, tagName, html string"),
1267 "visit_form" => format!("_ {import_alias}.NodeContext, actionUrl, method string"),
1268 "visit_input" => format!("_ {import_alias}.NodeContext, inputType, name, value string"),
1269 "visit_audio" | "visit_video" | "visit_iframe" => {
1270 format!("_ {import_alias}.NodeContext, src string")
1271 }
1272 "visit_details" => format!("_ {import_alias}.NodeContext, isOpen bool"),
1273 "visit_element_end" | "visit_table_end" | "visit_definition_list_end" | "visit_figure_end" => {
1274 format!("_ {import_alias}.NodeContext, output string")
1275 }
1276 "visit_list_start" => format!("_ {import_alias}.NodeContext, ordered bool"),
1277 "visit_list_end" => format!("_ {import_alias}.NodeContext, ordered bool, output string"),
1278 _ => format!("_ {import_alias}.NodeContext"),
1279 };
1280
1281 let _ = writeln!(
1282 out,
1283 "func (v *{struct_name}) {camel_method}({params}) {import_alias}.VisitResult {{"
1284 );
1285 match action {
1286 CallbackAction::Skip => {
1287 let _ = writeln!(out, "\treturn {import_alias}.VisitResultSkip");
1288 }
1289 CallbackAction::Continue => {
1290 let _ = writeln!(out, "\treturn {import_alias}.VisitResultContinue");
1291 }
1292 CallbackAction::PreserveHtml => {
1293 let _ = writeln!(out, "\treturn {import_alias}.VisitResultPreserveHtml");
1294 }
1295 CallbackAction::Custom { output } => {
1296 let escaped = go_string_literal(output);
1297 let _ = writeln!(out, "\treturn {import_alias}.VisitResultCustom({escaped})");
1298 }
1299 CallbackAction::CustomTemplate { template } => {
1300 let (fmt_str, fmt_args) = template_to_sprintf(template);
1303 let escaped_fmt = go_string_literal(&fmt_str);
1304 if fmt_args.is_empty() {
1305 let _ = writeln!(out, "\treturn {import_alias}.VisitResultCustom({escaped_fmt})");
1306 } else {
1307 let args_str = fmt_args.join(", ");
1308 let _ = writeln!(
1309 out,
1310 "\treturn {import_alias}.VisitResultCustom(fmt.Sprintf({escaped_fmt}, {args_str}))"
1311 );
1312 }
1313 }
1314 }
1315 let _ = writeln!(out, "}}");
1316}
1317
/// Converts a `{name}` template into a `fmt.Sprintf` format string plus the
/// list of placeholder names, in order of appearance.
///
/// Each `{name}` span is replaced by `%s` and `name` is appended to the
/// returned argument list. A `{` without a closing `}` consumes the rest of
/// the template as the placeholder name.
///
/// When the template contains at least one placeholder, literal `%`
/// characters are escaped as `%%` so that `fmt.Sprintf` prints them verbatim
/// instead of parsing them as formatting verbs. When there are no
/// placeholders the template is returned unchanged, because the caller emits
/// it as a plain Go string literal rather than passing it through
/// `fmt.Sprintf`.
///
/// NOTE(review): every placeholder is rendered with `%s`; callbacks whose Go
/// parameters are not strings (e.g. `level int`, `ordered bool`) would need
/// `%v` — confirm whether templates ever reference those.
fn template_to_sprintf(template: &str) -> (String, Vec<String>) {
    let mut fmt_str = String::with_capacity(template.len());
    let mut args: Vec<String> = Vec::new();
    let mut chars = template.chars();

    while let Some(c) = chars.next() {
        match c {
            '{' => {
                // `take_while` also consumes the terminating `}`.
                let name: String = chars.by_ref().take_while(|&inner| inner != '}').collect();
                fmt_str.push_str("%s");
                args.push(name);
            }
            '%' => fmt_str.push_str("%%"),
            other => fmt_str.push(other),
        }
    }

    if args.is_empty() {
        // No Sprintf call will be generated, so no escaping must be applied.
        return (template.to_owned(), args);
    }
    (fmt_str, args)
}
1343
1344fn method_to_camel(snake: &str) -> String {
1346 use heck::ToUpperCamelCase;
1347 snake.to_upper_camel_case()
1348}