Skip to main content

runar_compiler_rust/frontend/
validator.rs

//! Pass 2: Validate
//!
//! Validates the Rúnar AST against the language subset constraints.
//! This pass does NOT modify the AST; it only reports errors and warnings.
5
6use std::collections::{HashMap, HashSet};
7
8use super::ast::*;
9
10// ---------------------------------------------------------------------------
11// Public API
12// ---------------------------------------------------------------------------
13
/// Outcome of [`validate`]: the diagnostics collected while checking a contract.
///
/// Validation never modifies the AST; it only fills these two lists.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ValidationResult {
    /// Messages describing violations of the language subset constraints.
    pub errors: Vec<String>,
    /// Non-fatal diagnostics (no validation pass currently produces any).
    pub warnings: Vec<String>,
}
19
20/// Validate a parsed Rúnar AST against the language subset constraints.
21pub fn validate(contract: &ContractNode) -> ValidationResult {
22    let mut errors = Vec::new();
23    let warnings = Vec::new();
24
25    validate_properties(contract, &mut errors);
26    validate_constructor(contract, &mut errors);
27    validate_methods(contract, &mut errors);
28    check_no_recursion(contract, &mut errors);
29
30    ValidationResult { errors, warnings }
31}
32
33// ---------------------------------------------------------------------------
34// Valid primitive types for properties
35// ---------------------------------------------------------------------------
36
37fn is_valid_property_primitive(name: &PrimitiveTypeName) -> bool {
38    match name {
39        PrimitiveTypeName::Bigint
40        | PrimitiveTypeName::Boolean
41        | PrimitiveTypeName::ByteString
42        | PrimitiveTypeName::PubKey
43        | PrimitiveTypeName::Sig
44        | PrimitiveTypeName::Sha256
45        | PrimitiveTypeName::Ripemd160
46        | PrimitiveTypeName::Addr
47        | PrimitiveTypeName::SigHashPreimage
48        | PrimitiveTypeName::RabinSig
49        | PrimitiveTypeName::RabinPubKey
50        | PrimitiveTypeName::Point => true,
51        PrimitiveTypeName::Void => false,
52    }
53}
54
55// ---------------------------------------------------------------------------
56// Property validation
57// ---------------------------------------------------------------------------
58
59fn validate_properties(contract: &ContractNode, errors: &mut Vec<String>) {
60    for prop in &contract.properties {
61        validate_property_type(&prop.prop_type, errors);
62    }
63}
64
65fn validate_property_type(type_node: &TypeNode, errors: &mut Vec<String>) {
66    match type_node {
67        TypeNode::Primitive(name) => {
68            if !is_valid_property_primitive(name) {
69                errors.push(format!("Property type '{}' is not valid", name.as_str()));
70            }
71        }
72        TypeNode::FixedArray { element, length } => {
73            if *length == 0 {
74                errors.push("FixedArray length must be a positive integer".to_string());
75            }
76            validate_property_type(element, errors);
77        }
78        TypeNode::Custom(name) => {
79            errors.push(format!(
80                "Unsupported type '{}' in property declaration. Use one of: bigint, boolean, ByteString, PubKey, Sig, Sha256, Ripemd160, Addr, SigHashPreimage, RabinSig, RabinPubKey, or FixedArray<T, N>",
81                name
82            ));
83        }
84    }
85}
86
87// ---------------------------------------------------------------------------
88// Constructor validation
89// ---------------------------------------------------------------------------
90
91fn validate_constructor(contract: &ContractNode, errors: &mut Vec<String>) {
92    let ctor = &contract.constructor;
93    let prop_names: HashSet<String> = contract.properties.iter().map(|p| p.name.clone()).collect();
94
95    // Check that constructor has a super() call as first statement
96    if ctor.body.is_empty() {
97        errors.push("Constructor must call super() as its first statement".to_string());
98        return;
99    }
100
101    if !is_super_call(&ctor.body[0]) {
102        errors.push("Constructor must call super() as its first statement".to_string());
103    }
104
105    // Check that all properties are assigned in constructor
106    let mut assigned_props = HashSet::new();
107    for stmt in &ctor.body {
108        if let Statement::Assignment { target, .. } = stmt {
109            if let Expression::PropertyAccess { property } = target {
110                assigned_props.insert(property.clone());
111            }
112        }
113    }
114
115    // Properties with initializers don't need constructor assignments
116    let props_with_init: HashSet<String> = contract
117        .properties
118        .iter()
119        .filter(|p| p.initializer.is_some())
120        .map(|p| p.name.clone())
121        .collect();
122
123    for prop_name in &prop_names {
124        if !assigned_props.contains(prop_name) && !props_with_init.contains(prop_name) {
125            errors.push(format!(
126                "Property '{}' must be assigned in the constructor",
127                prop_name
128            ));
129        }
130    }
131
132    // Validate constructor params have type annotations
133    for param in &ctor.params {
134        if let TypeNode::Custom(ref name) = param.param_type {
135            if name == "unknown" {
136                errors.push(format!(
137                    "Constructor parameter '{}' must have a type annotation",
138                    param.name
139                ));
140            }
141        }
142    }
143
144    // Validate statements in constructor body
145    for stmt in &ctor.body {
146        validate_statement(stmt, errors);
147    }
148}
149
150fn is_super_call(stmt: &Statement) -> bool {
151    if let Statement::ExpressionStatement { expression, .. } = stmt {
152        if let Expression::CallExpr { callee, .. } = expression {
153            if let Expression::Identifier { name } = callee.as_ref() {
154                return name == "super";
155            }
156        }
157    }
158    false
159}
160
161// ---------------------------------------------------------------------------
162// Method validation
163// ---------------------------------------------------------------------------
164
165fn validate_methods(contract: &ContractNode, errors: &mut Vec<String>) {
166    for method in &contract.methods {
167        validate_method(method, contract, errors);
168    }
169}
170
171fn validate_method(method: &MethodNode, contract: &ContractNode, errors: &mut Vec<String>) {
172    // All params must have type annotations
173    for param in &method.params {
174        if let TypeNode::Custom(ref name) = param.param_type {
175            if name == "unknown" {
176                errors.push(format!(
177                    "Parameter '{}' in method '{}' must have a type annotation",
178                    param.name, method.name
179                ));
180            }
181        }
182    }
183
184    // Public methods must end with an assert() call (unless StatefulSmartContract,
185    // where the compiler auto-injects the final assert)
186    if method.visibility == Visibility::Public && contract.parent_class == "SmartContract" {
187        if !ends_with_assert(&method.body) {
188            errors.push(format!(
189                "Public method '{}' must end with an assert() call",
190                method.name
191            ));
192        }
193    }
194
195    // Validate all statements in method body
196    for stmt in &method.body {
197        validate_statement(stmt, errors);
198    }
199}
200
201fn ends_with_assert(body: &[Statement]) -> bool {
202    if body.is_empty() {
203        return false;
204    }
205
206    let last = &body[body.len() - 1];
207
208    // Direct assert() call as expression statement
209    if let Statement::ExpressionStatement { expression, .. } = last {
210        if is_assert_call(expression) {
211            return true;
212        }
213    }
214
215    // If/else where both branches end with assert
216    if let Statement::IfStatement {
217        then_branch,
218        else_branch,
219        ..
220    } = last
221    {
222        let then_ends = ends_with_assert(then_branch);
223        let else_ends = else_branch
224            .as_ref()
225            .map_or(false, |e| ends_with_assert(e));
226        return then_ends && else_ends;
227    }
228
229    false
230}
231
232fn is_assert_call(expr: &Expression) -> bool {
233    if let Expression::CallExpr { callee, .. } = expr {
234        if let Expression::Identifier { name } = callee.as_ref() {
235            return name == "assert";
236        }
237    }
238    false
239}
240
241// ---------------------------------------------------------------------------
242// Statement validation
243// ---------------------------------------------------------------------------
244
245fn validate_statement(stmt: &Statement, errors: &mut Vec<String>) {
246    match stmt {
247        Statement::VariableDecl { init, .. } => {
248            validate_expression(init, errors);
249        }
250        Statement::Assignment { target, value, .. } => {
251            validate_expression(target, errors);
252            validate_expression(value, errors);
253        }
254        Statement::IfStatement {
255            condition,
256            then_branch,
257            else_branch,
258            ..
259        } => {
260            validate_expression(condition, errors);
261            for s in then_branch {
262                validate_statement(s, errors);
263            }
264            if let Some(else_stmts) = else_branch {
265                for s in else_stmts {
266                    validate_statement(s, errors);
267                }
268            }
269        }
270        Statement::ForStatement {
271            condition,
272            init,
273            body,
274            ..
275        } => {
276            validate_expression(condition, errors);
277
278            // Check that the loop bound is a compile-time constant
279            if let Expression::BinaryExpr { right, .. } = condition {
280                if !is_compile_time_constant(right) {
281                    errors.push(
282                        "For loop bound must be a compile-time constant (literal or const variable)"
283                            .to_string(),
284                    );
285                }
286            }
287
288            // Validate init
289            if let Statement::VariableDecl { init: init_expr, .. } = init.as_ref() {
290                validate_expression(init_expr, errors);
291            }
292
293            // Validate body
294            for s in body {
295                validate_statement(s, errors);
296            }
297        }
298        Statement::ExpressionStatement { expression, .. } => {
299            validate_expression(expression, errors);
300        }
301        Statement::ReturnStatement { value, .. } => {
302            if let Some(v) = value {
303                validate_expression(v, errors);
304            }
305        }
306    }
307}
308
309fn is_compile_time_constant(expr: &Expression) -> bool {
310    match expr {
311        Expression::BigIntLiteral { .. } => true,
312        Expression::BoolLiteral { .. } => true,
313        Expression::Identifier { .. } => true, // Could be a const
314        Expression::UnaryExpr { op, operand } if *op == UnaryOp::Neg => {
315            is_compile_time_constant(operand)
316        }
317        _ => false,
318    }
319}
320
321// ---------------------------------------------------------------------------
322// Expression validation
323// ---------------------------------------------------------------------------
324
325fn validate_expression(expr: &Expression, errors: &mut Vec<String>) {
326    match expr {
327        Expression::BinaryExpr { left, right, .. } => {
328            validate_expression(left, errors);
329            validate_expression(right, errors);
330        }
331        Expression::UnaryExpr { operand, .. } => {
332            validate_expression(operand, errors);
333        }
334        Expression::CallExpr { callee, args, .. } => {
335            validate_expression(callee, errors);
336            for arg in args {
337                validate_expression(arg, errors);
338            }
339        }
340        Expression::MemberExpr { object, .. } => {
341            validate_expression(object, errors);
342        }
343        Expression::TernaryExpr {
344            condition,
345            consequent,
346            alternate,
347        } => {
348            validate_expression(condition, errors);
349            validate_expression(consequent, errors);
350            validate_expression(alternate, errors);
351        }
352        Expression::IndexAccess { object, index } => {
353            validate_expression(object, errors);
354            validate_expression(index, errors);
355        }
356        Expression::IncrementExpr { operand, .. } | Expression::DecrementExpr { operand, .. } => {
357            validate_expression(operand, errors);
358        }
359        // Leaf nodes -- nothing to validate
360        Expression::Identifier { .. }
361        | Expression::BigIntLiteral { .. }
362        | Expression::BoolLiteral { .. }
363        | Expression::ByteStringLiteral { .. }
364        | Expression::PropertyAccess { .. } => {}
365    }
366}
367
368// ---------------------------------------------------------------------------
369// Recursion detection
370// ---------------------------------------------------------------------------
371
372fn check_no_recursion(contract: &ContractNode, errors: &mut Vec<String>) {
373    // Build call graph: method name -> set of methods it calls
374    let mut call_graph: HashMap<String, HashSet<String>> = HashMap::new();
375    let mut method_names: HashSet<String> = HashSet::new();
376
377    for method in &contract.methods {
378        method_names.insert(method.name.clone());
379        let mut calls = HashSet::new();
380        collect_method_calls(&method.body, &mut calls);
381        call_graph.insert(method.name.clone(), calls);
382    }
383
384    // Also add constructor
385    {
386        let mut calls = HashSet::new();
387        collect_method_calls(&contract.constructor.body, &mut calls);
388        call_graph.insert("constructor".to_string(), calls);
389    }
390
391    // Check for cycles using DFS
392    for method in &contract.methods {
393        let mut visited = HashSet::new();
394        let mut stack = HashSet::new();
395
396        if has_cycle(
397            &method.name,
398            &call_graph,
399            &method_names,
400            &mut visited,
401            &mut stack,
402        ) {
403            errors.push(format!(
404                "Recursion detected: method '{}' calls itself directly or indirectly. Recursion is not allowed in Rúnar contracts.",
405                method.name
406            ));
407        }
408    }
409}
410
411fn collect_method_calls(stmts: &[Statement], calls: &mut HashSet<String>) {
412    for stmt in stmts {
413        collect_method_calls_in_statement(stmt, calls);
414    }
415}
416
417fn collect_method_calls_in_statement(stmt: &Statement, calls: &mut HashSet<String>) {
418    match stmt {
419        Statement::ExpressionStatement { expression, .. } => {
420            collect_method_calls_in_expr(expression, calls);
421        }
422        Statement::VariableDecl { init, .. } => {
423            collect_method_calls_in_expr(init, calls);
424        }
425        Statement::Assignment { target, value, .. } => {
426            collect_method_calls_in_expr(target, calls);
427            collect_method_calls_in_expr(value, calls);
428        }
429        Statement::IfStatement {
430            condition,
431            then_branch,
432            else_branch,
433            ..
434        } => {
435            collect_method_calls_in_expr(condition, calls);
436            collect_method_calls(then_branch, calls);
437            if let Some(else_stmts) = else_branch {
438                collect_method_calls(else_stmts, calls);
439            }
440        }
441        Statement::ForStatement {
442            condition, body, ..
443        } => {
444            collect_method_calls_in_expr(condition, calls);
445            collect_method_calls(body, calls);
446        }
447        Statement::ReturnStatement { value, .. } => {
448            if let Some(v) = value {
449                collect_method_calls_in_expr(v, calls);
450            }
451        }
452    }
453}
454
455fn collect_method_calls_in_expr(expr: &Expression, calls: &mut HashSet<String>) {
456    match expr {
457        Expression::CallExpr { callee, args, .. } => {
458            // Check if callee is `this.methodName` (PropertyAccess variant)
459            if let Expression::PropertyAccess { property } = callee.as_ref() {
460                calls.insert(property.clone());
461            }
462            // Also check `this.method` via MemberExpr
463            if let Expression::MemberExpr { object, property } = callee.as_ref() {
464                if let Expression::Identifier { name } = object.as_ref() {
465                    if name == "this" {
466                        calls.insert(property.clone());
467                    }
468                }
469            }
470            collect_method_calls_in_expr(callee, calls);
471            for arg in args {
472                collect_method_calls_in_expr(arg, calls);
473            }
474        }
475        Expression::BinaryExpr { left, right, .. } => {
476            collect_method_calls_in_expr(left, calls);
477            collect_method_calls_in_expr(right, calls);
478        }
479        Expression::UnaryExpr { operand, .. } => {
480            collect_method_calls_in_expr(operand, calls);
481        }
482        Expression::MemberExpr { object, .. } => {
483            collect_method_calls_in_expr(object, calls);
484        }
485        Expression::TernaryExpr {
486            condition,
487            consequent,
488            alternate,
489        } => {
490            collect_method_calls_in_expr(condition, calls);
491            collect_method_calls_in_expr(consequent, calls);
492            collect_method_calls_in_expr(alternate, calls);
493        }
494        Expression::IndexAccess { object, index } => {
495            collect_method_calls_in_expr(object, calls);
496            collect_method_calls_in_expr(index, calls);
497        }
498        Expression::IncrementExpr { operand, .. } | Expression::DecrementExpr { operand, .. } => {
499            collect_method_calls_in_expr(operand, calls);
500        }
501        // Leaf nodes
502        _ => {}
503    }
504}
505
506fn has_cycle(
507    method_name: &str,
508    call_graph: &HashMap<String, HashSet<String>>,
509    method_names: &HashSet<String>,
510    visited: &mut HashSet<String>,
511    stack: &mut HashSet<String>,
512) -> bool {
513    if stack.contains(method_name) {
514        return true;
515    }
516    if visited.contains(method_name) {
517        return false;
518    }
519
520    visited.insert(method_name.to_string());
521    stack.insert(method_name.to_string());
522
523    if let Some(calls) = call_graph.get(method_name) {
524        for callee in calls {
525            if method_names.contains(callee) {
526                if has_cycle(callee, call_graph, method_names, visited, stack) {
527                    return true;
528                }
529            }
530        }
531    }
532
533    stack.remove(method_name);
534    false
535}
536
537// ---------------------------------------------------------------------------
538// Tests
539// ---------------------------------------------------------------------------
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::parser::parse_source;

    /// Parses a TypeScript-syntax contract, failing the test on any parse error.
    fn parse_contract(source: &str) -> ContractNode {
        let parsed = parse_source(source, Some("test.runar.ts"));
        assert!(
            parsed.errors.is_empty(),
            "parse errors: {:?}",
            parsed.errors
        );
        parsed.contract.expect("expected a contract from parse")
    }

    /// Parses `source`, validates it, and returns only the error messages.
    fn validation_errors(source: &str) -> Vec<String> {
        validate(&parse_contract(source)).errors
    }

    /// True when some error message contains any of `needles`, ignoring case.
    fn any_error_mentions(errors: &[String], needles: &[&str]) -> bool {
        errors.iter().any(|msg| {
            let lowered = msg.to_lowercase();
            needles.iter().any(|needle| lowered.contains(*needle))
        })
    }

    #[test]
    fn test_valid_p2pkh_passes_validation() {
        let errors = validation_errors(
            r#"
import { SmartContract, Addr, PubKey, Sig } from 'runar-lang';

class P2PKH extends SmartContract {
    readonly pubKeyHash: Addr;

    constructor(pubKeyHash: Addr) {
        super(pubKeyHash);
        this.pubKeyHash = pubKeyHash;
    }

    public unlock(sig: Sig, pubKey: PubKey) {
        assert(hash160(pubKey) === this.pubKeyHash);
        assert(checkSig(sig, pubKey));
    }
}
"#,
        );
        assert!(
            errors.is_empty(),
            "expected no validation errors, got: {:?}",
            errors
        );
    }

    #[test]
    fn test_missing_super_in_constructor_produces_error() {
        let errors = validation_errors(
            r#"
import { SmartContract } from 'runar-lang';

class Bad extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        this.x = x;
    }

    public check(v: bigint) {
        assert(v === this.x);
    }
}
"#,
        );
        assert!(
            !errors.is_empty(),
            "expected validation errors for missing super()"
        );
        assert!(
            any_error_mentions(&errors, &["super"]),
            "expected error about super(), got: {:?}",
            errors
        );
    }

    #[test]
    fn test_public_method_not_ending_with_assert_produces_error() {
        let errors = validation_errors(
            r#"
import { SmartContract } from 'runar-lang';

class NoAssert extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint) {
        const sum = v + this.x;
    }
}
"#,
        );
        assert!(
            !errors.is_empty(),
            "expected validation errors for missing assert at end of public method"
        );
        assert!(
            any_error_mentions(&errors, &["assert"]),
            "expected error about missing assert(), got: {:?}",
            errors
        );
    }

    #[test]
    fn test_direct_recursion_produces_error() {
        let errors = validation_errors(
            r#"
import { SmartContract } from 'runar-lang';

class Recursive extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint) {
        this.check(v);
        assert(v === this.x);
    }
}
"#,
        );
        assert!(
            !errors.is_empty(),
            "expected validation errors for recursion"
        );
        assert!(
            any_error_mentions(&errors, &["recursion", "recursive"]),
            "expected error about recursion, got: {:?}",
            errors
        );
    }

    #[test]
    fn test_stateful_contract_passes_validation() {
        // StatefulSmartContract public methods don't need to end with assert
        // because the compiler auto-injects the final assert.
        let errors = validation_errors(
            r#"
import { StatefulSmartContract } from 'runar-lang';

class Counter extends StatefulSmartContract {
    count: bigint;

    constructor(count: bigint) {
        super(count);
        this.count = count;
    }

    public increment() {
        this.count++;
    }
}
"#,
        );
        assert!(
            errors.is_empty(),
            "expected no validation errors for stateful contract, got: {:?}",
            errors
        );
    }
}