Skip to main content

intent_codegen/
rust.rs

1//! Rust skeleton code generator.
2
3use intent_parser::ast;
4
5use crate::types::map_type;
6use crate::{Language, doc_text, format_ensures_item, format_expr, to_snake_case};
7
8/// Rust reserved keywords that cannot be used as identifiers.
9const RUST_KEYWORDS: &[&str] = &[
10    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern",
11    "false", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut",
12    "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true", "type", "unsafe",
13    "use", "where", "while", "yield",
14];
15
16/// Escape a Rust identifier if it collides with a reserved keyword.
17fn safe_ident(name: &str) -> String {
18    let snake = to_snake_case(name);
19    if RUST_KEYWORDS.contains(&snake.as_str()) {
20        format!("r#{snake}")
21    } else {
22        snake
23    }
24}
25
26/// Generate Rust skeleton code from a parsed intent file.
27pub fn generate(file: &ast::File) -> String {
28    let lang = Language::Rust;
29    let mut out = String::new();
30
31    // Header
32    out.push_str(&format!("//! Generated from {}.intent\n", file.module.name));
33    if let Some(doc) = &file.doc {
34        for line in &doc.lines {
35            out.push_str(&format!("//! {line}\n"));
36        }
37    }
38    out.push('\n');
39
40    // Imports (based on types used)
41    out.push_str(&generate_imports(file));
42    out.push('\n');
43
44    for item in &file.items {
45        match item {
46            ast::TopLevelItem::Entity(e) => generate_entity(&mut out, e, &lang),
47            ast::TopLevelItem::Action(a) => generate_action(&mut out, a, &lang),
48            ast::TopLevelItem::Invariant(inv) => generate_invariant(&mut out, inv),
49            ast::TopLevelItem::EdgeCases(ec) => generate_edge_cases(&mut out, ec),
50            ast::TopLevelItem::Test(_) => {}
51        }
52    }
53
54    out
55}
56
57fn generate_imports(file: &ast::File) -> String {
58    let mut imports = Vec::new();
59    let source = collect_type_names(file);
60
61    if source.contains("UUID") {
62        imports.push("use uuid::Uuid;");
63    }
64    if source.contains("Decimal") {
65        imports.push("use rust_decimal::Decimal;");
66    }
67    if source.contains("DateTime") {
68        imports.push("use chrono::{DateTime, Utc};");
69    }
70    if source.contains("Set<") {
71        imports.push("use std::collections::HashSet;");
72    }
73    if source.contains("Map<") {
74        imports.push("use std::collections::HashMap;");
75    }
76
77    if imports.is_empty() {
78        String::new()
79    } else {
80        imports.join("\n") + "\n"
81    }
82}
83
84/// Collect all type names as a single string for import detection.
85fn collect_type_names(file: &ast::File) -> String {
86    let mut names = String::new();
87    for item in &file.items {
88        match item {
89            ast::TopLevelItem::Entity(e) => {
90                for f in &e.fields {
91                    names.push_str(&format_type_for_scan(&f.ty));
92                    names.push(' ');
93                }
94            }
95            ast::TopLevelItem::Action(a) => {
96                for p in &a.params {
97                    names.push_str(&format_type_for_scan(&p.ty));
98                    names.push(' ');
99                }
100            }
101            _ => {}
102        }
103    }
104    names
105}
106
107fn format_type_for_scan(ty: &ast::TypeExpr) -> String {
108    match &ty.ty {
109        ast::TypeKind::Simple(n) => n.clone(),
110        ast::TypeKind::List(inner) => format!("List<{}>", format_type_for_scan(inner)),
111        ast::TypeKind::Set(inner) => format!("Set<{}>", format_type_for_scan(inner)),
112        ast::TypeKind::Map(k, v) => {
113            format!(
114                "Map<{}, {}>",
115                format_type_for_scan(k),
116                format_type_for_scan(v)
117            )
118        }
119        ast::TypeKind::Union(variants) => variants
120            .iter()
121            .filter_map(|v| match v {
122                ast::TypeKind::Simple(n) => Some(n.clone()),
123                _ => None,
124            })
125            .collect::<Vec<_>>()
126            .join(" "),
127        ast::TypeKind::Parameterized { name, .. } => name.clone(),
128    }
129}
130
131fn generate_entity(out: &mut String, entity: &ast::EntityDecl, lang: &Language) {
132    // Emit union enums for any union-typed fields
133    for field in &entity.fields {
134        if let ast::TypeKind::Union(variants) = &field.ty.ty {
135            let enum_name = format!("{}{}", entity.name, capitalize(&field.name));
136            generate_union_enum(out, &enum_name, variants);
137        }
138    }
139
140    // Doc comment
141    if let Some(doc) = &entity.doc {
142        for line in doc_text(doc).lines() {
143            out.push_str(&format!("/// {line}\n"));
144        }
145    }
146
147    out.push_str("#[derive(Debug, Clone)]\n");
148    out.push_str(&format!("pub struct {} {{\n", entity.name));
149
150    for field in &entity.fields {
151        let ty = if let ast::TypeKind::Union(_) = &field.ty.ty {
152            let enum_name = format!("{}{}", entity.name, capitalize(&field.name));
153            if field.ty.optional {
154                format!("Option<{enum_name}>")
155            } else {
156                enum_name
157            }
158        } else {
159            map_type(&field.ty, lang)
160        };
161        out.push_str(&format!("    pub {}: {},\n", safe_ident(&field.name), ty));
162    }
163
164    out.push_str("}\n\n");
165}
166
167fn generate_union_enum(out: &mut String, name: &str, variants: &[ast::TypeKind]) {
168    out.push_str("#[derive(Debug, Clone, PartialEq, Eq)]\n");
169    out.push_str(&format!("pub enum {name} {{\n"));
170    for v in variants {
171        if let ast::TypeKind::Simple(n) = v {
172            out.push_str(&format!("    {n},\n"));
173        }
174    }
175    out.push_str("}\n\n");
176}
177
178fn generate_action(out: &mut String, action: &ast::ActionDecl, lang: &Language) {
179    // Doc comment
180    if let Some(doc) = &action.doc {
181        for line in doc_text(doc).lines() {
182            out.push_str(&format!("/// {line}\n"));
183        }
184    }
185    out.push_str("///\n");
186
187    // Requires
188    if let Some(req) = &action.requires {
189        out.push_str("/// # Requires\n");
190        for cond in &req.conditions {
191            out.push_str(&format!("/// - `{}`\n", format_expr(cond)));
192        }
193        out.push_str("///\n");
194    }
195
196    // Ensures
197    if let Some(ens) = &action.ensures {
198        out.push_str("/// # Ensures\n");
199        for item in &ens.items {
200            out.push_str(&format!("/// - `{}`\n", format_ensures_item(item)));
201        }
202        out.push_str("///\n");
203    }
204
205    // Properties
206    if let Some(props) = &action.properties {
207        out.push_str("/// # Properties\n");
208        for entry in &props.entries {
209            out.push_str(&format!(
210                "/// - {}: {}\n",
211                entry.key,
212                crate::format_prop_value(&entry.value)
213            ));
214        }
215        out.push_str("///\n");
216    }
217
218    // Function signature
219    let fn_name = to_snake_case(&action.name);
220    let params: Vec<String> = action
221        .params
222        .iter()
223        .map(|p| {
224            let ty = map_type(&p.ty, lang);
225            format!("{}: {ty}", safe_ident(&p.name))
226        })
227        .collect();
228
229    out.push_str(&format!(
230        "pub fn {fn_name}({}) -> Result<(), Box<dyn std::error::Error>> {{\n",
231        params.join(", ")
232    ));
233    out.push_str(&format!("    todo!(\"Implement {fn_name}\")\n"));
234    out.push_str("}\n\n");
235}
236
237fn generate_invariant(out: &mut String, inv: &ast::InvariantDecl) {
238    out.push_str(&format!("// Invariant: {}\n", inv.name));
239    if let Some(doc) = &inv.doc {
240        for line in doc_text(doc).lines() {
241            out.push_str(&format!("// {line}\n"));
242        }
243    }
244    out.push_str(&format!("// {}\n\n", format_expr(&inv.body)));
245}
246
247fn generate_edge_cases(out: &mut String, ec: &ast::EdgeCasesDecl) {
248    out.push_str("// Edge cases:\n");
249    for rule in &ec.rules {
250        out.push_str(&format!(
251            "// when {} => {}()\n",
252            format_expr(&rule.condition),
253            rule.action.name,
254        ));
255    }
256    out.push('\n');
257}
258
/// Return `s` with its first character upper-cased and the rest left untouched.
///
/// Upper-casing uses Unicode case mapping, which may expand one character into
/// several (`ß` -> `SS`). An empty input yields an empty string.
fn capitalize(s: &str) -> String {
    if let Some(first) = s.chars().next() {
        let (_, rest) = s.split_at(first.len_utf8());
        first.to_uppercase().chain(rest.chars()).collect()
    } else {
        String::new()
    }
}