Skip to main content

runar_compiler_rust/frontend/
typecheck.rs

1//! Pass 3: Type-Check
2//!
3//! Type-checks the Rúnar AST. Builds type environments from properties,
4//! constructor parameters, and method parameters, then verifies all
5//! expressions have consistent types.
6
7use std::collections::{HashMap, HashSet};
8
9use super::ast::*;
10
11// ---------------------------------------------------------------------------
12// Public API
13// ---------------------------------------------------------------------------
14
/// Outcome of running [`typecheck`] over a contract.
///
/// Every problem found is recorded as a human-readable message in `errors`,
/// in the order it was discovered. An empty `errors` vector means the
/// contract type-checked cleanly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TypeCheckResult {
    /// Diagnostic messages, one per type error.
    pub errors: Vec<String>,
}
19
20/// Type-check a Rúnar AST. Returns any type errors found.
21pub fn typecheck(contract: &ContractNode) -> TypeCheckResult {
22    let mut errors = Vec::new();
23    let mut checker = TypeChecker::new(contract, &mut errors);
24
25    checker.check_constructor();
26    for method in &contract.methods {
27        checker.check_method(method);
28    }
29
30    TypeCheckResult { errors }
31}
32
33// ---------------------------------------------------------------------------
34// Type representation
35// ---------------------------------------------------------------------------
36
/// Internal type representation: types are compared and printed as plain
/// strings (e.g. `"bigint"`, `"Sig[]"`, or internal markers such as
/// `"<unknown>"`).
type TType = String;

// Names of the primitive types the checker produces directly.
const BIGINT: &str = "bigint";
const BOOLEAN: &str = "boolean";
const BYTESTRING: &str = "ByteString";
const VOID: &str = "void";
44
45// ---------------------------------------------------------------------------
46// Built-in function signatures
47// ---------------------------------------------------------------------------
48
/// Signature of a built-in function: the expected type of each positional
/// argument, plus the type the call evaluates to.
struct FuncSig {
    /// Expected argument types, in call order.
    params: Vec<&'static str>,
    /// Type of the call expression's result.
    return_type: &'static str,
}
53
54fn builtin_functions() -> HashMap<&'static str, FuncSig> {
55    let mut m = HashMap::new();
56
57    m.insert("sha256", FuncSig { params: vec!["ByteString"], return_type: "Sha256" });
58    m.insert("ripemd160", FuncSig { params: vec!["ByteString"], return_type: "Ripemd160" });
59    m.insert("hash160", FuncSig { params: vec!["ByteString"], return_type: "Ripemd160" });
60    m.insert("hash256", FuncSig { params: vec!["ByteString"], return_type: "Sha256" });
61    m.insert("checkSig", FuncSig { params: vec!["Sig", "PubKey"], return_type: "boolean" });
62    m.insert("checkMultiSig", FuncSig { params: vec!["Sig[]", "PubKey[]"], return_type: "boolean" });
63    m.insert("assert", FuncSig { params: vec!["boolean"], return_type: "void" });
64    m.insert("len", FuncSig { params: vec!["ByteString"], return_type: "bigint" });
65    m.insert("cat", FuncSig { params: vec!["ByteString", "ByteString"], return_type: "ByteString" });
66    m.insert("substr", FuncSig { params: vec!["ByteString", "bigint", "bigint"], return_type: "ByteString" });
67    m.insert("num2bin", FuncSig { params: vec!["bigint", "bigint"], return_type: "ByteString" });
68    m.insert("bin2num", FuncSig { params: vec!["ByteString"], return_type: "bigint" });
69    m.insert("checkPreimage", FuncSig { params: vec!["SigHashPreimage"], return_type: "boolean" });
70    m.insert("verifyRabinSig", FuncSig { params: vec!["ByteString", "RabinSig", "ByteString", "RabinPubKey"], return_type: "boolean" });
71    m.insert("verifyWOTS", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
72    m.insert("verifySLHDSA_SHA2_128s", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
73    m.insert("verifySLHDSA_SHA2_128f", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
74    m.insert("verifySLHDSA_SHA2_192s", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
75    m.insert("verifySLHDSA_SHA2_192f", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
76    m.insert("verifySLHDSA_SHA2_256s", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
77    m.insert("verifySLHDSA_SHA2_256f", FuncSig { params: vec!["ByteString", "ByteString", "ByteString"], return_type: "boolean" });
78    m.insert("ecAdd", FuncSig { params: vec!["Point", "Point"], return_type: "Point" });
79    m.insert("ecMul", FuncSig { params: vec!["Point", "bigint"], return_type: "Point" });
80    m.insert("ecMulGen", FuncSig { params: vec!["bigint"], return_type: "Point" });
81    m.insert("ecNegate", FuncSig { params: vec!["Point"], return_type: "Point" });
82    m.insert("ecOnCurve", FuncSig { params: vec!["Point"], return_type: "boolean" });
83    m.insert("ecModReduce", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
84    m.insert("ecEncodeCompressed", FuncSig { params: vec!["Point"], return_type: "ByteString" });
85    m.insert("ecMakePoint", FuncSig { params: vec!["bigint", "bigint"], return_type: "Point" });
86    m.insert("ecPointX", FuncSig { params: vec!["Point"], return_type: "bigint" });
87    m.insert("ecPointY", FuncSig { params: vec!["Point"], return_type: "bigint" });
88    m.insert("abs", FuncSig { params: vec!["bigint"], return_type: "bigint" });
89    m.insert("min", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
90    m.insert("max", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
91    m.insert("within", FuncSig { params: vec!["bigint", "bigint", "bigint"], return_type: "boolean" });
92    m.insert("reverseBytes", FuncSig { params: vec!["ByteString"], return_type: "ByteString" });
93    m.insert("left", FuncSig { params: vec!["ByteString", "bigint"], return_type: "ByteString" });
94    m.insert("right", FuncSig { params: vec!["ByteString", "bigint"], return_type: "ByteString" });
95    m.insert("int2str", FuncSig { params: vec!["bigint", "bigint"], return_type: "ByteString" });
96    m.insert("toByteString", FuncSig { params: vec!["ByteString"], return_type: "ByteString" });
97    m.insert("exit", FuncSig { params: vec!["boolean"], return_type: "void" });
98    m.insert("pack", FuncSig { params: vec!["bigint"], return_type: "ByteString" });
99    m.insert("unpack", FuncSig { params: vec!["ByteString"], return_type: "bigint" });
100    m.insert("safediv", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
101    m.insert("safemod", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
102    m.insert("clamp", FuncSig { params: vec!["bigint", "bigint", "bigint"], return_type: "bigint" });
103    m.insert("sign", FuncSig { params: vec!["bigint"], return_type: "bigint" });
104    m.insert("pow", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
105    m.insert("mulDiv", FuncSig { params: vec!["bigint", "bigint", "bigint"], return_type: "bigint" });
106    m.insert("percentOf", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
107    m.insert("sqrt", FuncSig { params: vec!["bigint"], return_type: "bigint" });
108    m.insert("gcd", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
109    m.insert("divmod", FuncSig { params: vec!["bigint", "bigint"], return_type: "bigint" });
110    m.insert("log2", FuncSig { params: vec!["bigint"], return_type: "bigint" });
111    m.insert("bool", FuncSig { params: vec!["bigint"], return_type: "boolean" });
112
113    // Preimage extractors
114    m.insert("extractVersion", FuncSig { params: vec!["SigHashPreimage"], return_type: "bigint" });
115    m.insert("extractHashPrevouts", FuncSig { params: vec!["SigHashPreimage"], return_type: "Sha256" });
116    m.insert("extractHashSequence", FuncSig { params: vec!["SigHashPreimage"], return_type: "Sha256" });
117    m.insert("extractOutpoint", FuncSig { params: vec!["SigHashPreimage"], return_type: "ByteString" });
118    m.insert("extractInputIndex", FuncSig { params: vec!["SigHashPreimage"], return_type: "bigint" });
119    m.insert("extractScriptCode", FuncSig { params: vec!["SigHashPreimage"], return_type: "ByteString" });
120    m.insert("extractAmount", FuncSig { params: vec!["SigHashPreimage"], return_type: "bigint" });
121    m.insert("extractSequence", FuncSig { params: vec!["SigHashPreimage"], return_type: "bigint" });
122    m.insert("extractOutputHash", FuncSig { params: vec!["SigHashPreimage"], return_type: "Sha256" });
123    m.insert("extractOutputs", FuncSig { params: vec!["SigHashPreimage"], return_type: "Sha256" });
124    m.insert("extractLocktime", FuncSig { params: vec!["SigHashPreimage"], return_type: "bigint" });
125    m.insert("extractSigHashType", FuncSig { params: vec!["SigHashPreimage"], return_type: "bigint" });
126
127    m
128}
129
130// ---------------------------------------------------------------------------
131// Subtyping
132// ---------------------------------------------------------------------------
133
/// Returns `true` for `ByteString` itself and for every type stored as a byte
/// string on the stack: keys, signatures, hashes, addresses, sighash
/// preimages and curve points.
fn is_bytestring_subtype(t: &str) -> bool {
    const BYTE_STRING_FAMILY: [&str; 8] = [
        "ByteString",
        "PubKey",
        "Sig",
        "Sha256",
        "Ripemd160",
        "Addr",
        "SigHashPreimage",
        "Point",
    ];
    BYTE_STRING_FAMILY.contains(&t)
}
141
/// Returns `true` for `bigint` and for the Rabin types, all of which are
/// represented as integers on the stack.
fn is_bigint_subtype(t: &str) -> bool {
    ["bigint", "RabinSig", "RabinPubKey"].contains(&t)
}
146
147fn is_subtype(actual: &str, expected: &str) -> bool {
148    if actual == expected {
149        return true;
150    }
151
152    // ByteString subtypes
153    if expected == "ByteString" && is_bytestring_subtype(actual) {
154        return true;
155    }
156    if actual == "ByteString" && is_bytestring_subtype(expected) {
157        return true;
158    }
159
160    // Both in the ByteString family -> compatible (e.g. Addr and Ripemd160)
161    if is_bytestring_subtype(actual) && is_bytestring_subtype(expected) {
162        return true;
163    }
164
165    // bigint subtypes
166    if expected == "bigint" && is_bigint_subtype(actual) {
167        return true;
168    }
169    if actual == "bigint" && is_bigint_subtype(expected) {
170        return true;
171    }
172
173    // Both in the bigint family -> compatible
174    if is_bigint_subtype(actual) && is_bigint_subtype(expected) {
175        return true;
176    }
177
178    // Array subtyping
179    if expected.ends_with("[]") && actual.ends_with("[]") {
180        return is_subtype(
181            &actual[..actual.len() - 2],
182            &expected[..expected.len() - 2],
183        );
184    }
185
186    false
187}
188
189fn is_bigint_family(t: &str) -> bool {
190    is_bigint_subtype(t)
191}
192
193// ---------------------------------------------------------------------------
194// Type environment
195// ---------------------------------------------------------------------------
196
197struct TypeEnv {
198    scopes: Vec<HashMap<String, TType>>,
199}
200
201impl TypeEnv {
202    fn new() -> Self {
203        TypeEnv {
204            scopes: vec![HashMap::new()],
205        }
206    }
207
208    fn push_scope(&mut self) {
209        self.scopes.push(HashMap::new());
210    }
211
212    fn pop_scope(&mut self) {
213        self.scopes.pop();
214    }
215
216    fn define(&mut self, name: &str, t: TType) {
217        if let Some(top) = self.scopes.last_mut() {
218            top.insert(name.to_string(), t);
219        }
220    }
221
222    fn lookup(&self, name: &str) -> Option<&TType> {
223        for scope in self.scopes.iter().rev() {
224            if let Some(t) = scope.get(name) {
225                return Some(t);
226            }
227        }
228        None
229    }
230}
231
232// ---------------------------------------------------------------------------
233// Type checker
234// ---------------------------------------------------------------------------
235
/// Returns `true` for types whose values may be consumed at most once:
/// signatures and sighash preimages.
fn is_affine_type(t: &str) -> bool {
    ["Sig", "SigHashPreimage"].contains(&t)
}
240
/// For a built-in that consumes affine values, returns the positions of the
/// arguments it consumes. Returns `None` for every other function.
fn consuming_param_indices(func_name: &str) -> Option<&'static [usize]> {
    // Every current consumer takes its affine value as the first argument.
    const FIRST_ARGUMENT: &[usize] = &[0];
    match func_name {
        "checkSig" | "checkMultiSig" | "checkPreimage" => Some(FIRST_ARGUMENT),
        _ => None,
    }
}
251
252struct TypeChecker<'a> {
253    contract: &'a ContractNode,
254    errors: &'a mut Vec<String>,
255    prop_types: HashMap<String, TType>,
256    method_sigs: HashMap<String, (Vec<TType>, TType)>,
257    builtins: HashMap<&'static str, FuncSig>,
258    consumed_values: HashSet<String>,
259}
260
261impl<'a> TypeChecker<'a> {
262    fn new(contract: &'a ContractNode, errors: &'a mut Vec<String>) -> Self {
263        let mut prop_types = HashMap::new();
264        for prop in &contract.properties {
265            prop_types.insert(prop.name.clone(), type_node_to_ttype(&prop.prop_type));
266        }
267
268        // For StatefulSmartContract, add the implicit txPreimage property
269        if contract.parent_class == "StatefulSmartContract" {
270            prop_types.insert("txPreimage".to_string(), "SigHashPreimage".to_string());
271        }
272
273        let mut method_sigs = HashMap::new();
274        for method in &contract.methods {
275            let params: Vec<TType> = method
276                .params
277                .iter()
278                .map(|p| type_node_to_ttype(&p.param_type))
279                .collect();
280            let return_type = if method.visibility == Visibility::Public {
281                VOID.to_string()
282            } else {
283                infer_method_return_type(method)
284            };
285            method_sigs.insert(method.name.clone(), (params, return_type));
286        }
287
288        TypeChecker {
289            contract,
290            errors,
291            prop_types,
292            method_sigs,
293            builtins: builtin_functions(),
294            consumed_values: HashSet::new(),
295        }
296    }
297
298    fn check_constructor(&mut self) {
299        let ctor = &self.contract.constructor;
300        let mut env = TypeEnv::new();
301
302        // Reset affine tracking for this scope
303        self.consumed_values.clear();
304
305        // Add constructor params to env
306        for param in &ctor.params {
307            env.define(&param.name, type_node_to_ttype(&param.param_type));
308        }
309
310        // Add properties to env
311        for prop in &self.contract.properties {
312            env.define(&prop.name, type_node_to_ttype(&prop.prop_type));
313        }
314
315        self.check_statements(&ctor.body, &mut env);
316    }
317
318    fn check_method(&mut self, method: &MethodNode) {
319        let mut env = TypeEnv::new();
320
321        // Reset affine tracking for this method
322        self.consumed_values.clear();
323
324        // Add method params to env
325        for param in &method.params {
326            env.define(&param.name, type_node_to_ttype(&param.param_type));
327        }
328
329        self.check_statements(&method.body, &mut env);
330    }
331
332    fn check_statements(&mut self, stmts: &[Statement], env: &mut TypeEnv) {
333        for stmt in stmts {
334            self.check_statement(stmt, env);
335        }
336    }
337
338    fn check_statement(&mut self, stmt: &Statement, env: &mut TypeEnv) {
339        match stmt {
340            Statement::VariableDecl {
341                name,
342                var_type,
343                init,
344                ..
345            } => {
346                let init_type = self.infer_expr_type(init, env);
347                if let Some(declared) = var_type {
348                    let declared_type = type_node_to_ttype(declared);
349                    if !is_subtype(&init_type, &declared_type) {
350                        self.errors.push(format!(
351                            "Type '{}' is not assignable to type '{}'",
352                            init_type, declared_type
353                        ));
354                    }
355                    env.define(name, declared_type);
356                } else {
357                    env.define(name, init_type);
358                }
359            }
360
361            Statement::Assignment { target, value, .. } => {
362                let target_type = self.infer_expr_type(target, env);
363                let value_type = self.infer_expr_type(value, env);
364                if !is_subtype(&value_type, &target_type) {
365                    self.errors.push(format!(
366                        "Type '{}' is not assignable to type '{}'",
367                        value_type, target_type
368                    ));
369                }
370            }
371
372            Statement::IfStatement {
373                condition,
374                then_branch,
375                else_branch,
376                ..
377            } => {
378                let cond_type = self.infer_expr_type(condition, env);
379                if cond_type != BOOLEAN {
380                    self.errors.push(format!(
381                        "If condition must be boolean, got '{}'",
382                        cond_type
383                    ));
384                }
385                env.push_scope();
386                self.check_statements(then_branch, env);
387                env.pop_scope();
388                if let Some(else_stmts) = else_branch {
389                    env.push_scope();
390                    self.check_statements(else_stmts, env);
391                    env.pop_scope();
392                }
393            }
394
395            Statement::ForStatement {
396                init,
397                condition,
398                body,
399                ..
400            } => {
401                env.push_scope();
402                self.check_statement(init, env);
403                let cond_type = self.infer_expr_type(condition, env);
404                if cond_type != BOOLEAN {
405                    self.errors.push(format!(
406                        "For loop condition must be boolean, got '{}'",
407                        cond_type
408                    ));
409                }
410                self.check_statements(body, env);
411                env.pop_scope();
412            }
413
414            Statement::ExpressionStatement { expression, .. } => {
415                self.infer_expr_type(expression, env);
416            }
417
418            Statement::ReturnStatement { value, .. } => {
419                if let Some(v) = value {
420                    self.infer_expr_type(v, env);
421                }
422            }
423        }
424    }
425
426    /// Infer the type of an expression.
427    fn infer_expr_type(&mut self, expr: &Expression, env: &mut TypeEnv) -> TType {
428        match expr {
429            Expression::BigIntLiteral { .. } => BIGINT.to_string(),
430
431            Expression::BoolLiteral { .. } => BOOLEAN.to_string(),
432
433            Expression::ByteStringLiteral { .. } => BYTESTRING.to_string(),
434
435            Expression::Identifier { name } => {
436                if name == "this" {
437                    return "<this>".to_string();
438                }
439                if name == "super" {
440                    return "<super>".to_string();
441                }
442                if name == "true" || name == "false" {
443                    return BOOLEAN.to_string();
444                }
445
446                if let Some(t) = env.lookup(name) {
447                    return t.clone();
448                }
449
450                // Check if it's a builtin function name
451                if self.builtins.contains_key(name.as_str()) {
452                    return "<builtin>".to_string();
453                }
454
455                "<unknown>".to_string()
456            }
457
458            Expression::PropertyAccess { property } => {
459                if let Some(t) = self.prop_types.get(property) {
460                    return t.clone();
461                }
462
463                self.errors.push(format!(
464                    "Property '{}' does not exist on the contract",
465                    property
466                ));
467                "<unknown>".to_string()
468            }
469
470            Expression::MemberExpr { object, property } => {
471                let obj_type = self.infer_expr_type(object, env);
472
473                if obj_type == "<this>" {
474                    // Check if it's a property
475                    if let Some(t) = self.prop_types.get(property) {
476                        return t.clone();
477                    }
478                    // Check if it's a method
479                    if self.method_sigs.contains_key(property) {
480                        return "<method>".to_string();
481                    }
482                    // Special: getStateScript
483                    if property == "getStateScript" {
484                        return "<method>".to_string();
485                    }
486
487                    self.errors.push(format!(
488                        "Property or method '{}' does not exist on the contract",
489                        property
490                    ));
491                    return "<unknown>".to_string();
492                }
493
494                // SigHash.ALL, SigHash.FORKID, etc.
495                if let Expression::Identifier { name } = object.as_ref() {
496                    if name == "SigHash" {
497                        return BIGINT.to_string();
498                    }
499                }
500
501                "<unknown>".to_string()
502            }
503
504            Expression::BinaryExpr { op, left, right } => {
505                self.check_binary_expr(op, left, right, env)
506            }
507
508            Expression::UnaryExpr { op, operand } => self.check_unary_expr(op, operand, env),
509
510            Expression::CallExpr { callee, args } => self.check_call_expr(callee, args, env),
511
512            Expression::TernaryExpr {
513                condition,
514                consequent,
515                alternate,
516            } => {
517                let cond_type = self.infer_expr_type(condition, env);
518                if cond_type != BOOLEAN {
519                    self.errors.push(format!(
520                        "Ternary condition must be boolean, got '{}'",
521                        cond_type
522                    ));
523                }
524                let cons_type = self.infer_expr_type(consequent, env);
525                let alt_type = self.infer_expr_type(alternate, env);
526
527                if cons_type != alt_type {
528                    if is_subtype(&alt_type, &cons_type) {
529                        return cons_type;
530                    }
531                    if is_subtype(&cons_type, &alt_type) {
532                        return alt_type;
533                    }
534                    self.errors.push(format!(
535                        "Ternary branches have incompatible types: '{}' and '{}'",
536                        cons_type, alt_type
537                    ));
538                }
539                cons_type
540            }
541
542            Expression::IndexAccess { object, index } => {
543                let obj_type = self.infer_expr_type(object, env);
544                let index_type = self.infer_expr_type(index, env);
545
546                if !is_bigint_family(&index_type) {
547                    self.errors.push(format!(
548                        "Array index must be bigint, got '{}'",
549                        index_type
550                    ));
551                }
552
553                if obj_type.ends_with("[]") {
554                    return obj_type[..obj_type.len() - 2].to_string();
555                }
556
557                "<unknown>".to_string()
558            }
559
560            Expression::IncrementExpr { operand, .. }
561            | Expression::DecrementExpr { operand, .. } => {
562                let operand_type = self.infer_expr_type(operand, env);
563                if !is_bigint_family(&operand_type) {
564                    let op_str = if matches!(expr, Expression::IncrementExpr { .. }) {
565                        "++"
566                    } else {
567                        "--"
568                    };
569                    self.errors.push(format!(
570                        "{} operator requires bigint, got '{}'",
571                        op_str, operand_type
572                    ));
573                }
574                BIGINT.to_string()
575            }
576        }
577    }
578
579    fn check_binary_expr(
580        &mut self,
581        op: &BinaryOp,
582        left: &Expression,
583        right: &Expression,
584        env: &mut TypeEnv,
585    ) -> TType {
586        let left_type = self.infer_expr_type(left, env);
587        let right_type = self.infer_expr_type(right, env);
588
589        match op {
590            // Arithmetic: bigint x bigint -> bigint
591            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => {
592                if !is_bigint_family(&left_type) {
593                    self.errors.push(format!(
594                        "Left operand of '{}' must be bigint, got '{}'",
595                        op.as_str(),
596                        left_type
597                    ));
598                }
599                if !is_bigint_family(&right_type) {
600                    self.errors.push(format!(
601                        "Right operand of '{}' must be bigint, got '{}'",
602                        op.as_str(),
603                        right_type
604                    ));
605                }
606                BIGINT.to_string()
607            }
608
609            // Comparison: bigint x bigint -> boolean
610            BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge => {
611                if !is_bigint_family(&left_type) {
612                    self.errors.push(format!(
613                        "Left operand of '{}' must be bigint, got '{}'",
614                        op.as_str(),
615                        left_type
616                    ));
617                }
618                if !is_bigint_family(&right_type) {
619                    self.errors.push(format!(
620                        "Right operand of '{}' must be bigint, got '{}'",
621                        op.as_str(),
622                        right_type
623                    ));
624                }
625                BOOLEAN.to_string()
626            }
627
628            // Equality: T x T -> boolean
629            BinaryOp::StrictEq | BinaryOp::StrictNe => {
630                if !is_subtype(&left_type, &right_type)
631                    && !is_subtype(&right_type, &left_type)
632                {
633                    if left_type != "<unknown>" && right_type != "<unknown>" {
634                        self.errors.push(format!(
635                            "Cannot compare '{}' and '{}' with '{}'",
636                            left_type,
637                            right_type,
638                            op.as_str()
639                        ));
640                    }
641                }
642                BOOLEAN.to_string()
643            }
644
645            // Logical: boolean x boolean -> boolean
646            BinaryOp::And | BinaryOp::Or => {
647                if left_type != BOOLEAN && left_type != "<unknown>" {
648                    self.errors.push(format!(
649                        "Left operand of '{}' must be boolean, got '{}'",
650                        op.as_str(),
651                        left_type
652                    ));
653                }
654                if right_type != BOOLEAN && right_type != "<unknown>" {
655                    self.errors.push(format!(
656                        "Right operand of '{}' must be boolean, got '{}'",
657                        op.as_str(),
658                        right_type
659                    ));
660                }
661                BOOLEAN.to_string()
662            }
663
664            // Bitwise / shift: bigint x bigint -> bigint
665            BinaryOp::BitAnd | BinaryOp::BitOr | BinaryOp::BitXor | BinaryOp::Shl | BinaryOp::Shr => {
666                if !is_bigint_family(&left_type) {
667                    self.errors.push(format!(
668                        "Left operand of '{}' must be bigint, got '{}'",
669                        op.as_str(),
670                        left_type
671                    ));
672                }
673                if !is_bigint_family(&right_type) {
674                    self.errors.push(format!(
675                        "Right operand of '{}' must be bigint, got '{}'",
676                        op.as_str(),
677                        right_type
678                    ));
679                }
680                BIGINT.to_string()
681            }
682        }
683    }
684
685    fn check_unary_expr(
686        &mut self,
687        op: &UnaryOp,
688        operand: &Expression,
689        env: &mut TypeEnv,
690    ) -> TType {
691        let operand_type = self.infer_expr_type(operand, env);
692
693        match op {
694            UnaryOp::Not => {
695                if operand_type != BOOLEAN && operand_type != "<unknown>" {
696                    self.errors.push(format!(
697                        "Operand of '!' must be boolean, got '{}'",
698                        operand_type
699                    ));
700                }
701                BOOLEAN.to_string()
702            }
703            UnaryOp::Neg => {
704                if !is_bigint_family(&operand_type) {
705                    self.errors.push(format!(
706                        "Operand of unary '-' must be bigint, got '{}'",
707                        operand_type
708                    ));
709                }
710                BIGINT.to_string()
711            }
712            UnaryOp::BitNot => {
713                if !is_bigint_family(&operand_type) {
714                    self.errors.push(format!(
715                        "Operand of '~' must be bigint, got '{}'",
716                        operand_type
717                    ));
718                }
719                BIGINT.to_string()
720            }
721        }
722    }
723
724    fn check_call_expr(
725        &mut self,
726        callee: &Expression,
727        args: &[Expression],
728        env: &mut TypeEnv,
729    ) -> TType {
730        // super() call in constructor
731        if let Expression::Identifier { name } = callee {
732            if name == "super" {
733                for arg in args {
734                    self.infer_expr_type(arg, env);
735                }
736                return VOID.to_string();
737            }
738        }
739
740        // Direct builtin call: assert(...), checkSig(...), sha256(...), etc.
741        if let Expression::Identifier { name } = callee {
742            if let Some(sig) = self.builtins.get(name.as_str()) {
743                let sig_params = sig.params.clone();
744                let sig_return_type = sig.return_type;
745                return self.check_call_args(name, &sig_params, sig_return_type, args, env);
746            }
747
748            // Check if it's a known contract method
749            if let Some((params, return_type)) = self.method_sigs.get(name).cloned() {
750                let param_strs: Vec<&str> = params.iter().map(|s| s.as_str()).collect();
751                return self.check_call_args(name, &param_strs, &return_type, args, env);
752            }
753
754            // Check if it's a local variable
755            if env.lookup(name).is_some() {
756                for arg in args {
757                    self.infer_expr_type(arg, env);
758                }
759                return "<unknown>".to_string();
760            }
761
762            self.errors.push(format!(
763                "unknown function '{}' — only Rúnar built-in functions and contract methods are allowed",
764                name
765            ));
766            for arg in args {
767                self.infer_expr_type(arg, env);
768            }
769            return "<unknown>".to_string();
770        }
771
772        // this.method(...) via PropertyAccess
773        if let Expression::PropertyAccess { property } = callee {
774            if property == "getStateScript" {
775                if !args.is_empty() {
776                    self.errors
777                        .push("getStateScript() takes no arguments".to_string());
778                }
779                return BYTESTRING.to_string();
780            }
781
782            if property == "addOutput" {
783                for arg in args {
784                    self.infer_expr_type(arg, env);
785                }
786                return VOID.to_string();
787            }
788
789            // Check contract method signatures
790            if let Some((params, return_type)) = self.method_sigs.get(property).cloned() {
791                let param_strs: Vec<&str> = params.iter().map(|s| s.as_str()).collect();
792                return self.check_call_args(property, &param_strs, &return_type, args, env);
793            }
794
795            self.errors.push(format!(
796                "unknown method 'self.{}' — only Rúnar built-in methods and contract methods are allowed",
797                property
798            ));
799            for arg in args {
800                self.infer_expr_type(arg, env);
801            }
802            return "<unknown>".to_string();
803        }
804
805        // member_expr call: obj.method(...)
806        if let Expression::MemberExpr { object, property } = callee {
807            // .clone() is a Rust idiom — allow it as a no-op
808            if property == "clone" {
809                return self.infer_expr_type(object, env);
810            }
811
812            let obj_type = self.infer_expr_type(object, env);
813
814            if obj_type == "<this>"
815                || matches!(object.as_ref(), Expression::Identifier { name } if name == "this")
816            {
817                if property == "getStateScript" {
818                    return BYTESTRING.to_string();
819                }
820
821                if let Some((params, return_type)) = self.method_sigs.get(property).cloned() {
822                    let param_strs: Vec<&str> = params.iter().map(|s| s.as_str()).collect();
823                    return self.check_call_args(
824                        property,
825                        &param_strs,
826                        &return_type,
827                        args,
828                        env,
829                    );
830                }
831            }
832
833            // Not this.method — reject (e.g. std::process::exit)
834            let obj_name = match object.as_ref() {
835                Expression::Identifier { name } => name.clone(),
836                _ => "<expr>".to_string(),
837            };
838            self.errors.push(format!(
839                "unknown function '{}.{}' — only Rúnar built-in functions and contract methods are allowed",
840                obj_name, property
841            ));
842            for arg in args {
843                self.infer_expr_type(arg, env);
844            }
845            return "<unknown>".to_string();
846        }
847
848        // Fallback — unknown callee shape
849        self.errors.push(
850            "unsupported function call expression — only Rúnar built-in functions and contract methods are allowed".to_string()
851        );
852        self.infer_expr_type(callee, env);
853        for arg in args {
854            self.infer_expr_type(arg, env);
855        }
856        "<unknown>".to_string()
857    }
858
859    fn check_call_args(
860        &mut self,
861        func_name: &str,
862        sig_params: &[&str],
863        return_type: &str,
864        args: &[Expression],
865        env: &mut TypeEnv,
866    ) -> TType {
867        // Special case: assert can take 1 or 2 args
868        if func_name == "assert" {
869            if args.is_empty() || args.len() > 2 {
870                self.errors.push(format!(
871                    "assert() expects 1 or 2 arguments, got {}",
872                    args.len()
873                ));
874            }
875            if !args.is_empty() {
876                let cond_type = self.infer_expr_type(&args[0], env);
877                if cond_type != BOOLEAN && cond_type != "<unknown>" {
878                    self.errors.push(format!(
879                        "assert() condition must be boolean, got '{}'",
880                        cond_type
881                    ));
882                }
883            }
884            if args.len() >= 2 {
885                self.infer_expr_type(&args[1], env);
886            }
887            return return_type.to_string();
888        }
889
890        // Special case: checkMultiSig
891        if func_name == "checkMultiSig" {
892            if args.len() != 2 {
893                self.errors.push(format!(
894                    "checkMultiSig() expects 2 arguments, got {}",
895                    args.len()
896                ));
897            }
898            for arg in args {
899                self.infer_expr_type(arg, env);
900            }
901            self.check_affine_consumption(func_name, args, env);
902            return return_type.to_string();
903        }
904
905        // Standard argument count check
906        if args.len() != sig_params.len() {
907            self.errors.push(format!(
908                "{}() expects {} argument(s), got {}",
909                func_name,
910                sig_params.len(),
911                args.len()
912            ));
913        }
914
915        let count = args.len().min(sig_params.len());
916        for i in 0..count {
917            let arg_type = self.infer_expr_type(&args[i], env);
918            let expected = sig_params[i];
919
920            if !is_subtype(&arg_type, expected) && arg_type != "<unknown>" {
921                self.errors.push(format!(
922                    "Argument {} of {}(): expected '{}', got '{}'",
923                    i + 1,
924                    func_name,
925                    expected,
926                    arg_type
927                ));
928            }
929        }
930
931        // Infer remaining args even if count mismatches
932        for i in count..args.len() {
933            self.infer_expr_type(&args[i], env);
934        }
935
936        // Affine type enforcement
937        self.check_affine_consumption(func_name, args, env);
938
939        return_type.to_string()
940    }
941
942    /// Check affine type constraints: Sig and SigHashPreimage values may
943    /// only be consumed once by a consuming function.
944    fn check_affine_consumption(
945        &mut self,
946        func_name: &str,
947        args: &[Expression],
948        env: &mut TypeEnv,
949    ) {
950        let indices = match consuming_param_indices(func_name) {
951            Some(indices) => indices,
952            None => return,
953        };
954
955        for &param_index in indices {
956            if param_index >= args.len() {
957                continue;
958            }
959
960            let arg = &args[param_index];
961            if let Expression::Identifier { name } = arg {
962                if let Some(arg_type) = env.lookup(name) {
963                    let arg_type = arg_type.clone();
964                    if !is_affine_type(&arg_type) {
965                        continue;
966                    }
967
968                    if self.consumed_values.contains(name) {
969                        self.errors.push(format!(
970                            "affine value '{}' has already been consumed",
971                            name
972                        ));
973                    } else {
974                        self.consumed_values.insert(name.clone());
975                    }
976                }
977            }
978        }
979    }
980}
981
982// ---------------------------------------------------------------------------
983// Helpers
984// ---------------------------------------------------------------------------
985
986// ---------------------------------------------------------------------------
987// Private method return type inference
988// ---------------------------------------------------------------------------
989
990/// Infer a private method's return type by walking all return statements
991/// and inferring the type of their expressions. Returns "void" if no
992/// return statements with values are found.
993fn infer_method_return_type(method: &MethodNode) -> TType {
994    let return_types = collect_return_types(&method.body);
995
996    if return_types.is_empty() {
997        return VOID.to_string();
998    }
999
1000    let first = &return_types[0];
1001    let all_same = return_types.iter().all(|t| t == first);
1002    if all_same {
1003        return first.clone();
1004    }
1005
1006    // Check if all are in the bigint family
1007    if return_types.iter().all(|t| is_bigint_subtype(t)) {
1008        return BIGINT.to_string();
1009    }
1010
1011    // Check if all are in the ByteString family
1012    if return_types.iter().all(|t| is_bytestring_subtype(t)) {
1013        return BYTESTRING.to_string();
1014    }
1015
1016    // Check if all are boolean
1017    if return_types.iter().all(|t| t == BOOLEAN) {
1018        return BOOLEAN.to_string();
1019    }
1020
1021    // Mixed types -- return the first as a best effort
1022    first.clone()
1023}
1024
1025/// Recursively collect inferred types from return statements.
1026fn collect_return_types(stmts: &[Statement]) -> Vec<TType> {
1027    let mut types = Vec::new();
1028    for stmt in stmts {
1029        match stmt {
1030            Statement::ReturnStatement { value, .. } => {
1031                if let Some(v) = value {
1032                    types.push(infer_expr_type_static(v));
1033                }
1034            }
1035            Statement::IfStatement {
1036                then_branch,
1037                else_branch,
1038                ..
1039            } => {
1040                types.extend(collect_return_types(then_branch));
1041                if let Some(else_stmts) = else_branch {
1042                    types.extend(collect_return_types(else_stmts));
1043                }
1044            }
1045            Statement::ForStatement { body, .. } => {
1046                types.extend(collect_return_types(body));
1047            }
1048            _ => {}
1049        }
1050    }
1051    types
1052}
1053
1054/// Lightweight static expression type inference without a type environment.
1055/// Used for inferring return types of private methods before the full
1056/// type-check pass runs.
1057fn infer_expr_type_static(expr: &Expression) -> TType {
1058    match expr {
1059        Expression::BigIntLiteral { .. } => BIGINT.to_string(),
1060        Expression::BoolLiteral { .. } => BOOLEAN.to_string(),
1061        Expression::ByteStringLiteral { .. } => BYTESTRING.to_string(),
1062        Expression::Identifier { name } => {
1063            if name == "true" || name == "false" {
1064                BOOLEAN.to_string()
1065            } else {
1066                "<unknown>".to_string()
1067            }
1068        }
1069        Expression::BinaryExpr { op, .. } => match op {
1070            BinaryOp::Add
1071            | BinaryOp::Sub
1072            | BinaryOp::Mul
1073            | BinaryOp::Div
1074            | BinaryOp::Mod
1075            | BinaryOp::BitAnd
1076            | BinaryOp::BitOr
1077            | BinaryOp::BitXor
1078            | BinaryOp::Shl
1079            | BinaryOp::Shr => BIGINT.to_string(),
1080            _ => BOOLEAN.to_string(),
1081        },
1082        Expression::UnaryExpr { op, .. } => match op {
1083            UnaryOp::Not => BOOLEAN.to_string(),
1084            _ => BIGINT.to_string(),
1085        },
1086        Expression::CallExpr { callee, .. } => {
1087            let builtins = builtin_functions();
1088            if let Expression::Identifier { name } = callee.as_ref() {
1089                if let Some(sig) = builtins.get(name.as_str()) {
1090                    return sig.return_type.to_string();
1091                }
1092            }
1093            if let Expression::PropertyAccess { property } = callee.as_ref() {
1094                if let Some(sig) = builtins.get(property.as_str()) {
1095                    return sig.return_type.to_string();
1096                }
1097            }
1098            "<unknown>".to_string()
1099        }
1100        Expression::TernaryExpr {
1101            consequent,
1102            alternate,
1103            ..
1104        } => {
1105            let cons_type = infer_expr_type_static(consequent);
1106            if cons_type != "<unknown>" {
1107                cons_type
1108            } else {
1109                infer_expr_type_static(alternate)
1110            }
1111        }
1112        Expression::IncrementExpr { .. } | Expression::DecrementExpr { .. } => {
1113            BIGINT.to_string()
1114        }
1115        _ => "<unknown>".to_string(),
1116    }
1117}
1118
1119fn type_node_to_ttype(node: &TypeNode) -> TType {
1120    match node {
1121        TypeNode::Primitive(name) => name.as_str().to_string(),
1122        TypeNode::FixedArray { element, .. } => {
1123            format!("{}[]", type_node_to_ttype(element))
1124        }
1125        TypeNode::Custom(name) => name.clone(),
1126    }
1127}
1128
1129// ---------------------------------------------------------------------------
1130// Tests
1131// ---------------------------------------------------------------------------
1132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::parser::parse_source;
    use crate::frontend::validator;

    /// Parse and validate a TypeScript-syntax Rúnar source, panicking if either
    /// stage reports errors, and return the resulting contract.
    fn parse_and_validate(source: &str) -> ContractNode {
        let parsed = parse_source(source, Some("test.runar.ts"));
        assert!(parsed.errors.is_empty(), "parse errors: {:?}", parsed.errors);

        let contract = parsed.contract.expect("expected a contract from parse");
        let validation = validator::validate(&contract);
        assert!(
            validation.errors.is_empty(),
            "validation errors: {:?}",
            validation.errors
        );
        contract
    }

    /// Run parse → validate → typecheck and return the type errors.
    fn typecheck_errors(source: &str) -> Vec<String> {
        let contract = parse_and_validate(source);
        typecheck(&contract).errors
    }

    /// Assert that `source` type-checks with no errors. `context` is the
    /// failure-message prefix.
    fn assert_typechecks(source: &str, context: &str) {
        let errors = typecheck_errors(source);
        assert!(errors.is_empty(), "{}, got: {:?}", context, errors);
    }

    /// Assert that type-checking `source` produces at least one error, and
    /// that at least one error satisfies `predicate`.
    fn assert_type_error(
        source: &str,
        when_empty: &str,
        predicate: impl Fn(&str) -> bool,
        when_unmatched: &str,
    ) {
        let errors = typecheck_errors(source);
        assert!(!errors.is_empty(), "{}", when_empty);
        assert!(
            errors.iter().any(|e| predicate(e.as_str())),
            "{}, got: {:?}",
            when_unmatched,
            errors
        );
    }

    #[test]
    fn test_valid_p2pkh_passes_typecheck() {
        let source = r#"
import { SmartContract, Addr, PubKey, Sig } from 'runar-lang';

class P2PKH extends SmartContract {
    readonly pubKeyHash: Addr;

    constructor(pubKeyHash: Addr) {
        super(pubKeyHash);
        this.pubKeyHash = pubKeyHash;
    }

    public unlock(sig: Sig, pubKey: PubKey) {
        assert(hash160(pubKey) === this.pubKeyHash);
        assert(checkSig(sig, pubKey));
    }
}
"#;
        assert_typechecks(source, "expected no typecheck errors");
    }

    #[test]
    fn test_unknown_function_call_produces_error() {
        let source = r#"
import { SmartContract } from 'runar-lang';

class Bad extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint) {
        const y = Math.floor(v);
        assert(y === this.x);
    }
}
"#;
        assert_type_error(
            source,
            "expected typecheck errors for unknown function Math.floor",
            |e| e.to_lowercase().contains("unknown"),
            "expected error about unknown function",
        );
    }

    #[test]
    fn test_builtin_with_wrong_arg_count_produces_error() {
        let source = r#"
import { SmartContract, PubKey, Sig } from 'runar-lang';

class Bad extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint) {
        assert(min(v));
    }
}
"#;
        assert_type_error(
            source,
            "expected typecheck errors for wrong arg count",
            |e| e.contains("expects") && e.contains("argument"),
            "expected error about wrong argument count",
        );
    }

    #[test]
    fn test_arithmetic_on_boolean_produces_error() {
        let source = r#"
import { SmartContract } from 'runar-lang';

class Bad extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint, flag: boolean) {
        const sum = v + flag;
        assert(sum === this.x);
    }
}
"#;
        assert_type_error(
            source,
            "expected typecheck errors for arithmetic on boolean",
            |e| e.contains("bigint") || e.contains("boolean"),
            "expected type mismatch error",
        );
    }

    #[test]
    fn test_valid_stateful_contract_passes_typecheck() {
        let source = r#"
import { StatefulSmartContract } from 'runar-lang';

class Counter extends StatefulSmartContract {
    count: bigint;

    constructor(count: bigint) {
        super(count);
        this.count = count;
    }

    public increment() {
        this.count++;
    }
}
"#;
        assert_typechecks(
            source,
            "expected no typecheck errors for stateful contract",
        );
    }
}