Skip to main content

intent_codegen/
go.rs

1//! Go skeleton code generator.
2
3use intent_parser::ast;
4
5use crate::types::map_type;
6use crate::{Language, doc_text, format_ensures_item, format_expr, to_snake_case};
7
/// The 25 reserved keywords of the Go language. A generated identifier that
/// equals one of these cannot be used verbatim and must be escaped first.
const GO_KEYWORDS: &[&str] = &[
    // declarations
    "const", "func", "import", "package", "type", "var",
    // composite-type constructors
    "chan", "interface", "map", "struct",
    // control flow
    "break", "case", "continue", "default", "defer", "else", "fallthrough", "for", "go", "goto",
    "if", "range", "return", "select", "switch",
];
36
37/// Escape a Go identifier if it collides with a reserved keyword.
38fn safe_ident(name: &str) -> String {
39    let snake = to_snake_case(name);
40    if GO_KEYWORDS.contains(&snake.as_str()) {
41        format!("{snake}_")
42    } else {
43        snake
44    }
45}
46
/// Upper-case the first character of `s` and keep the remainder unchanged.
///
/// Go exports identifiers that begin with an upper-case letter. The first
/// character may expand to several characters (e.g. `ß` becomes `SS`).
/// An empty input yields an empty string.
fn capitalize(s: &str) -> String {
    match s.chars().next() {
        Some(first) => first
            .to_uppercase()
            .chain(s[first.len_utf8()..].chars())
            .collect(),
        None => String::new(),
    }
}
55
/// Turn a snake_case (or already camel/lower-case) name into PascalCase, the
/// spelling Go uses for exported identifiers.
///
/// Each `_`-separated segment has its first character upper-cased and the
/// rest kept as-is. Empty segments (from leading, trailing or doubled
/// underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}
60
61/// Convert a field name to a JSON struct tag value (camelCase).
62fn to_json_tag(s: &str) -> String {
63    crate::to_camel_case(s)
64}
65
66/// Generate Go skeleton code from a parsed intent file.
67pub fn generate(file: &ast::File) -> String {
68    let lang = Language::Go;
69    let mut out = String::new();
70
71    // Header
72    out.push_str(&format!(
73        "// Code generated from {}.intent. DO NOT EDIT.\n",
74        file.module.name
75    ));
76    if let Some(doc) = &file.doc {
77        out.push('\n');
78        for line in &doc.lines {
79            out.push_str(&format!("// {line}\n"));
80        }
81    }
82    out.push('\n');
83
84    // Package declaration (lowercase module name)
85    let pkg_name = file.module.name.to_lowercase();
86    out.push_str(&format!("package {pkg_name}\n\n"));
87
88    // Imports
89    let imports = generate_imports(file);
90    if !imports.is_empty() {
91        out.push_str(&imports);
92        out.push('\n');
93    }
94
95    for item in &file.items {
96        match item {
97            ast::TopLevelItem::Entity(e) => generate_entity(&mut out, e, &lang),
98            ast::TopLevelItem::Action(a) => generate_action(&mut out, a, &lang),
99            ast::TopLevelItem::Invariant(inv) => generate_invariant(&mut out, inv),
100            ast::TopLevelItem::EdgeCases(ec) => generate_edge_cases(&mut out, ec),
101            ast::TopLevelItem::Test(_) => {}
102        }
103    }
104
105    out
106}
107
108fn generate_imports(file: &ast::File) -> String {
109    let source = collect_type_names(file);
110    let has_union = file.items.iter().any(|item| {
111        if let ast::TopLevelItem::Entity(e) = item {
112            e.fields
113                .iter()
114                .any(|f| matches!(f.ty.ty, ast::TypeKind::Union(_)))
115        } else {
116            false
117        }
118    });
119    let has_action = file
120        .items
121        .iter()
122        .any(|item| matches!(item, ast::TopLevelItem::Action(_)));
123
124    let mut imports = Vec::new();
125
126    if has_action || has_union {
127        // errors and fmt are stdlib, always available
128        if has_action {
129            imports.push("\"errors\"");
130        }
131        if has_union {
132            imports.push("\"fmt\"");
133        }
134    }
135    if source.contains("DateTime") {
136        imports.push("\"time\"");
137    }
138    // External packages
139    if source.contains("Decimal") {
140        imports.push("\"github.com/shopspring/decimal\"");
141    }
142    if source.contains("UUID") {
143        imports.push("\"github.com/google/uuid\"");
144    }
145
146    if imports.is_empty() {
147        return String::new();
148    }
149
150    if imports.len() == 1 {
151        return format!("import {}\n", imports[0]);
152    }
153
154    let mut out = String::from("import (\n");
155    for imp in &imports {
156        out.push_str(&format!("\t{imp}\n"));
157    }
158    out.push_str(")\n");
159    out
160}
161
162/// Collect all type names as a single string for import detection.
163fn collect_type_names(file: &ast::File) -> String {
164    let mut names = String::new();
165    for item in &file.items {
166        match item {
167            ast::TopLevelItem::Entity(e) => {
168                for f in &e.fields {
169                    collect_type_name(&f.ty, &mut names);
170                }
171            }
172            ast::TopLevelItem::Action(a) => {
173                for p in &a.params {
174                    collect_type_name(&p.ty, &mut names);
175                }
176            }
177            _ => {}
178        }
179    }
180    names
181}
182
183fn collect_type_name(ty: &ast::TypeExpr, out: &mut String) {
184    match &ty.ty {
185        ast::TypeKind::Simple(n) => {
186            out.push_str(n);
187            out.push(' ');
188        }
189        ast::TypeKind::Parameterized { name, .. } => {
190            out.push_str(name);
191            out.push(' ');
192        }
193        ast::TypeKind::List(inner) | ast::TypeKind::Set(inner) => collect_type_name(inner, out),
194        ast::TypeKind::Map(k, v) => {
195            collect_type_name(k, out);
196            collect_type_name(v, out);
197        }
198        ast::TypeKind::Union(_) => {}
199    }
200}
201
202fn generate_entity(out: &mut String, entity: &ast::EntityDecl, lang: &Language) {
203    // Emit union type aliases with const blocks for union-typed fields
204    for field in &entity.fields {
205        if let ast::TypeKind::Union(variants) = &field.ty.ty {
206            let type_name = format!("{}{}", entity.name, capitalize(&field.name));
207            generate_union_type(out, &type_name, variants);
208        }
209    }
210
211    // Doc comment
212    if let Some(doc) = &entity.doc {
213        for line in doc_text(doc).lines() {
214            out.push_str(&format!("// {line}\n"));
215        }
216    }
217
218    out.push_str(&format!("type {} struct {{\n", entity.name));
219
220    for field in &entity.fields {
221        let field_name = to_pascal_case(&field.name);
222        let json_tag = to_json_tag(&field.name);
223        let ty = if let ast::TypeKind::Union(_) = &field.ty.ty {
224            let type_name = format!("{}{}", entity.name, capitalize(&field.name));
225            if field.ty.optional {
226                format!("*{type_name}")
227            } else {
228                type_name
229            }
230        } else {
231            map_type(&field.ty, lang)
232        };
233        out.push_str(&format!("\t{field_name} {ty} `json:\"{json_tag}\"`\n"));
234    }
235
236    out.push_str("}\n\n");
237}
238
239fn generate_union_type(out: &mut String, name: &str, variants: &[ast::TypeKind]) {
240    let names: Vec<&str> = variants
241        .iter()
242        .filter_map(|v| match v {
243            ast::TypeKind::Simple(n) => Some(n.as_str()),
244            _ => None,
245        })
246        .collect();
247
248    out.push_str(&format!(
249        "// {name} represents the allowed values for this field.\n"
250    ));
251    out.push_str(&format!("type {name} string\n\n"));
252
253    out.push_str("const (\n");
254    for n in &names {
255        let const_name = format!("{name}{n}");
256        out.push_str(&format!("\t{const_name} {name} = \"{n}\"\n"));
257    }
258    out.push_str(")\n\n");
259
260    // Validate method
261    out.push_str(&format!(
262        "// Valid returns true if v is a known {name} value.\n"
263    ));
264    out.push_str(&format!("func (v {name}) Valid() bool {{\n"));
265    out.push_str("\tswitch v {\n");
266    out.push_str(&format!(
267        "\tcase {}:\n",
268        names
269            .iter()
270            .map(|n| format!("{name}{n}"))
271            .collect::<Vec<_>>()
272            .join(", ")
273    ));
274    out.push_str("\t\treturn true\n");
275    out.push_str("\tdefault:\n");
276    out.push_str("\t\treturn false\n");
277    out.push_str("\t}\n");
278    out.push_str("}\n\n");
279
280    // String method
281    out.push_str(&format!("func (v {name}) String() string {{\n"));
282    out.push_str("\treturn string(v)\n");
283    out.push_str("}\n\n");
284
285    // UnmarshalText for safe JSON deserialization
286    out.push_str(&format!(
287        "// UnmarshalText implements encoding.TextUnmarshaler for {name}.\n"
288    ));
289    out.push_str(&format!(
290        "func (v *{name}) UnmarshalText(data []byte) error {{\n"
291    ));
292    out.push_str(&format!("\ts := {name}(data)\n"));
293    out.push_str("\tif !s.Valid() {\n");
294    out.push_str(&format!(
295        "\t\treturn fmt.Errorf(\"invalid {name}: %q\", string(data))\n"
296    ));
297    out.push_str("\t}\n");
298    out.push_str("\t*v = s\n");
299    out.push_str("\treturn nil\n");
300    out.push_str("}\n\n");
301}
302
303fn generate_action(out: &mut String, action: &ast::ActionDecl, lang: &Language) {
304    let fn_name = to_pascal_case(&to_snake_case(&action.name));
305
306    // Doc comment
307    if let Some(doc) = &action.doc {
308        out.push_str(&format!("// {fn_name} — {}\n", doc_text(doc)));
309    }
310
311    // Requires
312    if let Some(req) = &action.requires {
313        out.push_str("//\n// Requires:\n");
314        for cond in &req.conditions {
315            out.push_str(&format!("//   - {}\n", format_expr(cond)));
316        }
317    }
318
319    // Ensures
320    if let Some(ens) = &action.ensures {
321        out.push_str("//\n// Ensures:\n");
322        for item in &ens.items {
323            out.push_str(&format!("//   - {}\n", format_ensures_item(item)));
324        }
325    }
326
327    // Properties
328    if let Some(props) = &action.properties {
329        out.push_str("//\n// Properties:\n");
330        for entry in &props.entries {
331            out.push_str(&format!(
332                "//   - {}: {}\n",
333                entry.key,
334                crate::format_prop_value(&entry.value)
335            ));
336        }
337    }
338
339    // Function signature
340    let params: Vec<String> = action
341        .params
342        .iter()
343        .map(|p| {
344            let ty = map_type(&p.ty, lang);
345            format!("{} {ty}", safe_ident(&p.name))
346        })
347        .collect();
348
349    out.push_str(&format!("func {fn_name}({}) error {{\n", params.join(", ")));
350    out.push_str(&format!(
351        "\treturn errors.New(\"TODO: implement {fn_name}\")\n"
352    ));
353    out.push_str("}\n\n");
354}
355
356fn generate_invariant(out: &mut String, inv: &ast::InvariantDecl) {
357    out.push_str(&format!("// Invariant: {}\n", inv.name));
358    if let Some(doc) = &inv.doc {
359        for line in doc_text(doc).lines() {
360            out.push_str(&format!("// {line}\n"));
361        }
362    }
363    out.push_str(&format!("// {}\n\n", format_expr(&inv.body)));
364}
365
366fn generate_edge_cases(out: &mut String, ec: &ast::EdgeCasesDecl) {
367    out.push_str("// Edge cases:\n");
368    for rule in &ec.rules {
369        out.push_str(&format!(
370            "// when {} => {}()\n",
371            format_expr(&rule.condition),
372            rule.action.name,
373        ));
374    }
375    out.push('\n');
376}