Skip to main content

intent_ir/
lower.rs

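//! AST → IR lowering pass.
//!
//! Converts a parsed intent AST into the typed IR representation.
//! Every lowered item — struct, field, function, parameter, pre/postcondition,
//! property, invariant and edge guard — gets a `SourceTrace` linking back to the
//! originating spec element; expression nodes themselves carry no trace.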
5
6use intent_parser::ast;
7
8use crate::types::*;
9
10/// Lower a parsed intent file into an IR module.
11pub fn lower_file(file: &ast::File) -> Module {
12    let module_name = &file.module.name;
13
14    let mut structs = Vec::new();
15    let mut functions = Vec::new();
16    let mut invariants = Vec::new();
17    let mut edge_guards = Vec::new();
18
19    for item in &file.items {
20        match item {
21            ast::TopLevelItem::Entity(e) => {
22                structs.push(lower_entity(module_name, e));
23            }
24            ast::TopLevelItem::Action(a) => {
25                functions.push(lower_action(module_name, a));
26            }
27            ast::TopLevelItem::Invariant(inv) => {
28                invariants.push(lower_invariant(module_name, inv));
29            }
30            ast::TopLevelItem::EdgeCases(ec) => {
31                for rule in &ec.rules {
32                    edge_guards.push(lower_edge_rule(module_name, rule));
33                }
34            }
35        }
36    }
37
38    Module {
39        name: module_name.clone(),
40        structs,
41        functions,
42        invariants,
43        edge_guards,
44    }
45}
46
47// ── Entity → Struct ─────────────────────────────────────────
48
49fn lower_entity(module: &str, entity: &ast::EntityDecl) -> Struct {
50    let fields = entity
51        .fields
52        .iter()
53        .map(|f| Field {
54            name: f.name.clone(),
55            ty: lower_type(&f.ty),
56            trace: SourceTrace {
57                module: module.to_string(),
58                item: entity.name.clone(),
59                part: format!("field:{}", f.name),
60                span: f.span,
61            },
62        })
63        .collect();
64
65    Struct {
66        name: entity.name.clone(),
67        fields,
68        trace: SourceTrace {
69            module: module.to_string(),
70            item: entity.name.clone(),
71            part: "entity".to_string(),
72            span: entity.span,
73        },
74    }
75}
76
77// ── Action → Function ───────────────────────────────────────
78
79fn lower_action(module: &str, action: &ast::ActionDecl) -> Function {
80    let params = action
81        .params
82        .iter()
83        .map(|p| Param {
84            name: p.name.clone(),
85            ty: lower_type(&p.ty),
86            trace: SourceTrace {
87                module: module.to_string(),
88                item: action.name.clone(),
89                part: format!("param:{}", p.name),
90                span: p.span,
91            },
92        })
93        .collect();
94
95    let preconditions = action
96        .requires
97        .as_ref()
98        .map(|req| {
99            req.conditions
100                .iter()
101                .map(|c| Condition {
102                    expr: lower_expr(c),
103                    trace: SourceTrace {
104                        module: module.to_string(),
105                        item: action.name.clone(),
106                        part: "requires".to_string(),
107                        span: c.span,
108                    },
109                })
110                .collect()
111        })
112        .unwrap_or_default();
113
114    let postconditions = action
115        .ensures
116        .as_ref()
117        .map(|ens| {
118            ens.items
119                .iter()
120                .map(|item| lower_ensures_item(module, &action.name, item))
121                .collect()
122        })
123        .unwrap_or_default();
124
125    let properties = action
126        .properties
127        .as_ref()
128        .map(|props| {
129            props
130                .entries
131                .iter()
132                .map(|e| Property {
133                    key: e.key.clone(),
134                    value: lower_prop_value(&e.value),
135                    trace: SourceTrace {
136                        module: module.to_string(),
137                        item: action.name.clone(),
138                        part: format!("property:{}", e.key),
139                        span: e.span,
140                    },
141                })
142                .collect()
143        })
144        .unwrap_or_default();
145
146    Function {
147        name: action.name.clone(),
148        params,
149        preconditions,
150        postconditions,
151        properties,
152        trace: SourceTrace {
153            module: module.to_string(),
154            item: action.name.clone(),
155            part: "action".to_string(),
156            span: action.span,
157        },
158    }
159}
160
161fn lower_ensures_item(module: &str, action: &str, item: &ast::EnsuresItem) -> Postcondition {
162    match item {
163        ast::EnsuresItem::Expr(e) => Postcondition::Always {
164            expr: lower_expr(e),
165            trace: SourceTrace {
166                module: module.to_string(),
167                item: action.to_string(),
168                part: "ensures".to_string(),
169                span: e.span,
170            },
171        },
172        ast::EnsuresItem::When(w) => Postcondition::When {
173            guard: lower_expr(&w.condition),
174            expr: lower_expr(&w.consequence),
175            trace: SourceTrace {
176                module: module.to_string(),
177                item: action.to_string(),
178                part: "ensures:when".to_string(),
179                span: w.span,
180            },
181        },
182    }
183}
184
185// ── Invariant ───────────────────────────────────────────────
186
187fn lower_invariant(module: &str, inv: &ast::InvariantDecl) -> Invariant {
188    Invariant {
189        name: inv.name.clone(),
190        expr: lower_expr(&inv.body),
191        trace: SourceTrace {
192            module: module.to_string(),
193            item: inv.name.clone(),
194            part: "invariant".to_string(),
195            span: inv.span,
196        },
197    }
198}
199
200// ── Edge rule → EdgeGuard ───────────────────────────────────
201
202fn lower_edge_rule(module: &str, rule: &ast::EdgeRule) -> EdgeGuard {
203    let args = rule
204        .action
205        .args
206        .iter()
207        .map(|arg| match arg {
208            ast::CallArg::Named { key, value, .. } => (key.clone(), lower_expr(value)),
209            ast::CallArg::Positional(e) => (String::new(), lower_expr(e)),
210        })
211        .collect();
212
213    EdgeGuard {
214        condition: lower_expr(&rule.condition),
215        action: rule.action.name.clone(),
216        args,
217        trace: SourceTrace {
218            module: module.to_string(),
219            item: "edge_cases".to_string(),
220            part: format!("when:{}", rule.action.name),
221            span: rule.span,
222        },
223    }
224}
225
226// ── Type lowering ───────────────────────────────────────────
227
228fn lower_type(ty: &ast::TypeExpr) -> IrType {
229    let base = lower_type_kind(&ty.ty);
230    if ty.optional {
231        IrType::Optional(Box::new(base))
232    } else {
233        base
234    }
235}
236
237fn lower_type_kind(kind: &ast::TypeKind) -> IrType {
238    match kind {
239        ast::TypeKind::Simple(name) => {
240            // Recognize known struct-like types vs primitives.
241            // During lowering we treat everything as Named; the verifier
242            // resolves struct references later.
243            IrType::Named(name.clone())
244        }
245        ast::TypeKind::Union(variants) => {
246            let names: Vec<String> = variants
247                .iter()
248                .filter_map(|v| {
249                    if let ast::TypeKind::Simple(name) = v {
250                        Some(name.clone())
251                    } else {
252                        None
253                    }
254                })
255                .collect();
256            IrType::Union(names)
257        }
258        ast::TypeKind::List(inner) => IrType::List(Box::new(lower_type(inner))),
259        ast::TypeKind::Set(inner) => IrType::Set(Box::new(lower_type(inner))),
260        ast::TypeKind::Map(k, v) => IrType::Map(Box::new(lower_type(k)), Box::new(lower_type(v))),
261        ast::TypeKind::Parameterized { name, params } => {
262            if name == "Decimal" {
263                let precision = params
264                    .iter()
265                    .find(|p| p.name == "precision")
266                    .and_then(|p| {
267                        if let ast::Literal::Int(n) = &p.value {
268                            Some(*n as u32)
269                        } else {
270                            None
271                        }
272                    })
273                    .unwrap_or(0);
274                IrType::Decimal(precision)
275            } else {
276                IrType::Named(name.clone())
277            }
278        }
279    }
280}
281
282// ── Expression lowering ─────────────────────────────────────
283
284fn lower_expr(expr: &ast::Expr) -> IrExpr {
285    match &expr.kind {
286        ast::ExprKind::Ident(name) => IrExpr::Var(name.clone()),
287        ast::ExprKind::Literal(lit) => IrExpr::Literal(lower_literal(lit)),
288        ast::ExprKind::FieldAccess { root, fields } => {
289            let mut current = lower_expr(root);
290            for field in fields {
291                current = IrExpr::FieldAccess {
292                    root: Box::new(current),
293                    field: field.clone(),
294                };
295            }
296            current
297        }
298        ast::ExprKind::Compare { left, op, right } => IrExpr::Compare {
299            left: Box::new(lower_expr(left)),
300            op: lower_cmp_op(*op),
301            right: Box::new(lower_expr(right)),
302        },
303        ast::ExprKind::Arithmetic { left, op, right } => IrExpr::Arithmetic {
304            left: Box::new(lower_expr(left)),
305            op: lower_arith_op(*op),
306            right: Box::new(lower_expr(right)),
307        },
308        ast::ExprKind::And(a, b) => IrExpr::And(Box::new(lower_expr(a)), Box::new(lower_expr(b))),
309        ast::ExprKind::Or(a, b) => IrExpr::Or(Box::new(lower_expr(a)), Box::new(lower_expr(b))),
310        ast::ExprKind::Not(inner) => IrExpr::Not(Box::new(lower_expr(inner))),
311        ast::ExprKind::Implies(a, b) => {
312            IrExpr::Implies(Box::new(lower_expr(a)), Box::new(lower_expr(b)))
313        }
314        ast::ExprKind::Old(inner) => IrExpr::Old(Box::new(lower_expr(inner))),
315        ast::ExprKind::Quantifier {
316            kind,
317            binding,
318            ty,
319            body,
320        } => match kind {
321            ast::QuantifierKind::Forall => IrExpr::Forall {
322                binding: binding.clone(),
323                ty: ty.clone(),
324                body: Box::new(lower_expr(body)),
325            },
326            ast::QuantifierKind::Exists => IrExpr::Exists {
327                binding: binding.clone(),
328                ty: ty.clone(),
329                body: Box::new(lower_expr(body)),
330            },
331        },
332        ast::ExprKind::Call { name, args } => IrExpr::Call {
333            name: name.clone(),
334            args: args
335                .iter()
336                .map(|a| match a {
337                    ast::CallArg::Named { value, .. } => lower_expr(value),
338                    ast::CallArg::Positional(e) => lower_expr(e),
339                })
340                .collect(),
341        },
342    }
343}
344
345fn lower_literal(lit: &ast::Literal) -> IrLiteral {
346    match lit {
347        ast::Literal::Null => IrLiteral::Null,
348        ast::Literal::Bool(b) => IrLiteral::Bool(*b),
349        ast::Literal::Int(n) => IrLiteral::Int(*n),
350        ast::Literal::Decimal(s) => IrLiteral::Decimal(s.clone()),
351        ast::Literal::String(s) => IrLiteral::String(s.clone()),
352    }
353}
354
355fn lower_cmp_op(op: ast::CmpOp) -> CmpOp {
356    match op {
357        ast::CmpOp::Eq => CmpOp::Eq,
358        ast::CmpOp::Ne => CmpOp::Ne,
359        ast::CmpOp::Lt => CmpOp::Lt,
360        ast::CmpOp::Gt => CmpOp::Gt,
361        ast::CmpOp::Le => CmpOp::Le,
362        ast::CmpOp::Ge => CmpOp::Ge,
363    }
364}
365
366fn lower_arith_op(op: ast::ArithOp) -> ArithOp {
367    match op {
368        ast::ArithOp::Add => ArithOp::Add,
369        ast::ArithOp::Sub => ArithOp::Sub,
370    }
371}
372
373fn lower_prop_value(val: &ast::PropValue) -> PropertyValue {
374    match val {
375        ast::PropValue::Literal(ast::Literal::Bool(b)) => PropertyValue::Bool(*b),
376        ast::PropValue::Literal(ast::Literal::Int(n)) => PropertyValue::Int(*n),
377        ast::PropValue::Literal(ast::Literal::String(s)) => PropertyValue::String(s.clone()),
378        ast::PropValue::Ident(s) => PropertyValue::Ident(s.clone()),
379        // For complex prop values, fall back to string representation
380        ast::PropValue::Literal(ast::Literal::Null) => PropertyValue::String("null".to_string()),
381        ast::PropValue::Literal(ast::Literal::Decimal(s)) => PropertyValue::String(s.clone()),
382        ast::PropValue::List(_) | ast::PropValue::Object(_) => {
383            PropertyValue::String("<complex>".to_string())
384        }
385    }
386}