1use std::collections::{HashMap, HashSet};
7
8use super::ast::*;
9
/// Diagnostics collected by [`validate`].
///
/// `errors` lists every rule violation found in the contract; `warnings` lists
/// non-fatal findings (no check in this module currently emits any).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ValidationResult {
    /// Human-readable message for each rule violation.
    pub errors: Vec<String>,
    /// Human-readable message for each non-fatal finding.
    pub warnings: Vec<String>,
}
19
20pub fn validate(contract: &ContractNode) -> ValidationResult {
22 let mut errors = Vec::new();
23 let warnings = Vec::new();
24
25 validate_properties(contract, &mut errors);
26 validate_constructor(contract, &mut errors);
27 validate_methods(contract, &mut errors);
28 check_no_recursion(contract, &mut errors);
29
30 ValidationResult { errors, warnings }
31}
32
33fn is_valid_property_primitive(name: &PrimitiveTypeName) -> bool {
38 match name {
39 PrimitiveTypeName::Bigint
40 | PrimitiveTypeName::Boolean
41 | PrimitiveTypeName::ByteString
42 | PrimitiveTypeName::PubKey
43 | PrimitiveTypeName::Sig
44 | PrimitiveTypeName::Sha256
45 | PrimitiveTypeName::Ripemd160
46 | PrimitiveTypeName::Addr
47 | PrimitiveTypeName::SigHashPreimage
48 | PrimitiveTypeName::RabinSig
49 | PrimitiveTypeName::RabinPubKey
50 | PrimitiveTypeName::Point => true,
51 PrimitiveTypeName::Void => false,
52 }
53}
54
55fn validate_properties(contract: &ContractNode, errors: &mut Vec<String>) {
60 for prop in &contract.properties {
61 validate_property_type(&prop.prop_type, errors);
62 }
63}
64
65fn validate_property_type(type_node: &TypeNode, errors: &mut Vec<String>) {
66 match type_node {
67 TypeNode::Primitive(name) => {
68 if !is_valid_property_primitive(name) {
69 errors.push(format!("Property type '{}' is not valid", name.as_str()));
70 }
71 }
72 TypeNode::FixedArray { element, length } => {
73 if *length == 0 {
74 errors.push("FixedArray length must be a positive integer".to_string());
75 }
76 validate_property_type(element, errors);
77 }
78 TypeNode::Custom(name) => {
79 errors.push(format!(
80 "Unsupported type '{}' in property declaration. Use one of: bigint, boolean, ByteString, PubKey, Sig, Sha256, Ripemd160, Addr, SigHashPreimage, RabinSig, RabinPubKey, or FixedArray<T, N>",
81 name
82 ));
83 }
84 }
85}
86
87fn validate_constructor(contract: &ContractNode, errors: &mut Vec<String>) {
92 let ctor = &contract.constructor;
93 let prop_names: HashSet<String> = contract.properties.iter().map(|p| p.name.clone()).collect();
94
95 if ctor.body.is_empty() {
97 errors.push("Constructor must call super() as its first statement".to_string());
98 return;
99 }
100
101 if !is_super_call(&ctor.body[0]) {
102 errors.push("Constructor must call super() as its first statement".to_string());
103 }
104
105 let mut assigned_props = HashSet::new();
107 for stmt in &ctor.body {
108 if let Statement::Assignment { target, .. } = stmt {
109 if let Expression::PropertyAccess { property } = target {
110 assigned_props.insert(property.clone());
111 }
112 }
113 }
114
115 let props_with_init: HashSet<String> = contract
117 .properties
118 .iter()
119 .filter(|p| p.initializer.is_some())
120 .map(|p| p.name.clone())
121 .collect();
122
123 for prop_name in &prop_names {
124 if !assigned_props.contains(prop_name) && !props_with_init.contains(prop_name) {
125 errors.push(format!(
126 "Property '{}' must be assigned in the constructor",
127 prop_name
128 ));
129 }
130 }
131
132 for param in &ctor.params {
134 if let TypeNode::Custom(ref name) = param.param_type {
135 if name == "unknown" {
136 errors.push(format!(
137 "Constructor parameter '{}' must have a type annotation",
138 param.name
139 ));
140 }
141 }
142 }
143
144 for stmt in &ctor.body {
146 validate_statement(stmt, errors);
147 }
148}
149
150fn is_super_call(stmt: &Statement) -> bool {
151 if let Statement::ExpressionStatement { expression, .. } = stmt {
152 if let Expression::CallExpr { callee, .. } = expression {
153 if let Expression::Identifier { name } = callee.as_ref() {
154 return name == "super";
155 }
156 }
157 }
158 false
159}
160
161fn validate_methods(contract: &ContractNode, errors: &mut Vec<String>) {
166 for method in &contract.methods {
167 validate_method(method, contract, errors);
168 }
169}
170
171fn validate_method(method: &MethodNode, contract: &ContractNode, errors: &mut Vec<String>) {
172 for param in &method.params {
174 if let TypeNode::Custom(ref name) = param.param_type {
175 if name == "unknown" {
176 errors.push(format!(
177 "Parameter '{}' in method '{}' must have a type annotation",
178 param.name, method.name
179 ));
180 }
181 }
182 }
183
184 if method.visibility == Visibility::Public && contract.parent_class == "SmartContract" {
187 if !ends_with_assert(&method.body) {
188 errors.push(format!(
189 "Public method '{}' must end with an assert() call",
190 method.name
191 ));
192 }
193 }
194
195 for stmt in &method.body {
197 validate_statement(stmt, errors);
198 }
199}
200
201fn ends_with_assert(body: &[Statement]) -> bool {
202 if body.is_empty() {
203 return false;
204 }
205
206 let last = &body[body.len() - 1];
207
208 if let Statement::ExpressionStatement { expression, .. } = last {
210 if is_assert_call(expression) {
211 return true;
212 }
213 }
214
215 if let Statement::IfStatement {
217 then_branch,
218 else_branch,
219 ..
220 } = last
221 {
222 let then_ends = ends_with_assert(then_branch);
223 let else_ends = else_branch
224 .as_ref()
225 .map_or(false, |e| ends_with_assert(e));
226 return then_ends && else_ends;
227 }
228
229 false
230}
231
232fn is_assert_call(expr: &Expression) -> bool {
233 if let Expression::CallExpr { callee, .. } = expr {
234 if let Expression::Identifier { name } = callee.as_ref() {
235 return name == "assert";
236 }
237 }
238 false
239}
240
241fn validate_statement(stmt: &Statement, errors: &mut Vec<String>) {
246 match stmt {
247 Statement::VariableDecl { init, .. } => {
248 validate_expression(init, errors);
249 }
250 Statement::Assignment { target, value, .. } => {
251 validate_expression(target, errors);
252 validate_expression(value, errors);
253 }
254 Statement::IfStatement {
255 condition,
256 then_branch,
257 else_branch,
258 ..
259 } => {
260 validate_expression(condition, errors);
261 for s in then_branch {
262 validate_statement(s, errors);
263 }
264 if let Some(else_stmts) = else_branch {
265 for s in else_stmts {
266 validate_statement(s, errors);
267 }
268 }
269 }
270 Statement::ForStatement {
271 condition,
272 init,
273 body,
274 ..
275 } => {
276 validate_expression(condition, errors);
277
278 if let Expression::BinaryExpr { right, .. } = condition {
280 if !is_compile_time_constant(right) {
281 errors.push(
282 "For loop bound must be a compile-time constant (literal or const variable)"
283 .to_string(),
284 );
285 }
286 }
287
288 if let Statement::VariableDecl { init: init_expr, .. } = init.as_ref() {
290 validate_expression(init_expr, errors);
291 }
292
293 for s in body {
295 validate_statement(s, errors);
296 }
297 }
298 Statement::ExpressionStatement { expression, .. } => {
299 validate_expression(expression, errors);
300 }
301 Statement::ReturnStatement { value, .. } => {
302 if let Some(v) = value {
303 validate_expression(v, errors);
304 }
305 }
306 }
307}
308
309fn is_compile_time_constant(expr: &Expression) -> bool {
310 match expr {
311 Expression::BigIntLiteral { .. } => true,
312 Expression::BoolLiteral { .. } => true,
313 Expression::Identifier { .. } => true, Expression::UnaryExpr { op, operand } if *op == UnaryOp::Neg => {
315 is_compile_time_constant(operand)
316 }
317 _ => false,
318 }
319}
320
321fn validate_expression(expr: &Expression, errors: &mut Vec<String>) {
326 match expr {
327 Expression::BinaryExpr { left, right, .. } => {
328 validate_expression(left, errors);
329 validate_expression(right, errors);
330 }
331 Expression::UnaryExpr { operand, .. } => {
332 validate_expression(operand, errors);
333 }
334 Expression::CallExpr { callee, args, .. } => {
335 validate_expression(callee, errors);
336 for arg in args {
337 validate_expression(arg, errors);
338 }
339 }
340 Expression::MemberExpr { object, .. } => {
341 validate_expression(object, errors);
342 }
343 Expression::TernaryExpr {
344 condition,
345 consequent,
346 alternate,
347 } => {
348 validate_expression(condition, errors);
349 validate_expression(consequent, errors);
350 validate_expression(alternate, errors);
351 }
352 Expression::IndexAccess { object, index } => {
353 validate_expression(object, errors);
354 validate_expression(index, errors);
355 }
356 Expression::IncrementExpr { operand, .. } | Expression::DecrementExpr { operand, .. } => {
357 validate_expression(operand, errors);
358 }
359 Expression::Identifier { .. }
361 | Expression::BigIntLiteral { .. }
362 | Expression::BoolLiteral { .. }
363 | Expression::ByteStringLiteral { .. }
364 | Expression::PropertyAccess { .. } => {}
365 }
366}
367
368fn check_no_recursion(contract: &ContractNode, errors: &mut Vec<String>) {
373 let mut call_graph: HashMap<String, HashSet<String>> = HashMap::new();
375 let mut method_names: HashSet<String> = HashSet::new();
376
377 for method in &contract.methods {
378 method_names.insert(method.name.clone());
379 let mut calls = HashSet::new();
380 collect_method_calls(&method.body, &mut calls);
381 call_graph.insert(method.name.clone(), calls);
382 }
383
384 {
386 let mut calls = HashSet::new();
387 collect_method_calls(&contract.constructor.body, &mut calls);
388 call_graph.insert("constructor".to_string(), calls);
389 }
390
391 for method in &contract.methods {
393 let mut visited = HashSet::new();
394 let mut stack = HashSet::new();
395
396 if has_cycle(
397 &method.name,
398 &call_graph,
399 &method_names,
400 &mut visited,
401 &mut stack,
402 ) {
403 errors.push(format!(
404 "Recursion detected: method '{}' calls itself directly or indirectly. Recursion is not allowed in Rúnar contracts.",
405 method.name
406 ));
407 }
408 }
409}
410
411fn collect_method_calls(stmts: &[Statement], calls: &mut HashSet<String>) {
412 for stmt in stmts {
413 collect_method_calls_in_statement(stmt, calls);
414 }
415}
416
417fn collect_method_calls_in_statement(stmt: &Statement, calls: &mut HashSet<String>) {
418 match stmt {
419 Statement::ExpressionStatement { expression, .. } => {
420 collect_method_calls_in_expr(expression, calls);
421 }
422 Statement::VariableDecl { init, .. } => {
423 collect_method_calls_in_expr(init, calls);
424 }
425 Statement::Assignment { target, value, .. } => {
426 collect_method_calls_in_expr(target, calls);
427 collect_method_calls_in_expr(value, calls);
428 }
429 Statement::IfStatement {
430 condition,
431 then_branch,
432 else_branch,
433 ..
434 } => {
435 collect_method_calls_in_expr(condition, calls);
436 collect_method_calls(then_branch, calls);
437 if let Some(else_stmts) = else_branch {
438 collect_method_calls(else_stmts, calls);
439 }
440 }
441 Statement::ForStatement {
442 condition, body, ..
443 } => {
444 collect_method_calls_in_expr(condition, calls);
445 collect_method_calls(body, calls);
446 }
447 Statement::ReturnStatement { value, .. } => {
448 if let Some(v) = value {
449 collect_method_calls_in_expr(v, calls);
450 }
451 }
452 }
453}
454
455fn collect_method_calls_in_expr(expr: &Expression, calls: &mut HashSet<String>) {
456 match expr {
457 Expression::CallExpr { callee, args, .. } => {
458 if let Expression::PropertyAccess { property } = callee.as_ref() {
460 calls.insert(property.clone());
461 }
462 if let Expression::MemberExpr { object, property } = callee.as_ref() {
464 if let Expression::Identifier { name } = object.as_ref() {
465 if name == "this" {
466 calls.insert(property.clone());
467 }
468 }
469 }
470 collect_method_calls_in_expr(callee, calls);
471 for arg in args {
472 collect_method_calls_in_expr(arg, calls);
473 }
474 }
475 Expression::BinaryExpr { left, right, .. } => {
476 collect_method_calls_in_expr(left, calls);
477 collect_method_calls_in_expr(right, calls);
478 }
479 Expression::UnaryExpr { operand, .. } => {
480 collect_method_calls_in_expr(operand, calls);
481 }
482 Expression::MemberExpr { object, .. } => {
483 collect_method_calls_in_expr(object, calls);
484 }
485 Expression::TernaryExpr {
486 condition,
487 consequent,
488 alternate,
489 } => {
490 collect_method_calls_in_expr(condition, calls);
491 collect_method_calls_in_expr(consequent, calls);
492 collect_method_calls_in_expr(alternate, calls);
493 }
494 Expression::IndexAccess { object, index } => {
495 collect_method_calls_in_expr(object, calls);
496 collect_method_calls_in_expr(index, calls);
497 }
498 Expression::IncrementExpr { operand, .. } | Expression::DecrementExpr { operand, .. } => {
499 collect_method_calls_in_expr(operand, calls);
500 }
501 _ => {}
503 }
504}
505
/// Depth-first search reporting whether a call cycle is reachable from `method_name`.
///
/// `visited` holds every node already explored. `stack` holds the nodes on the
/// current DFS path. Only callees listed in `method_names` are followed. When a cycle
/// is found, the function returns immediately and leaves the current path in
/// `stack`. Otherwise the node is popped from `stack` before returning `false`.
fn has_cycle(
    method_name: &str,
    call_graph: &HashMap<String, HashSet<String>>,
    method_names: &HashSet<String>,
    visited: &mut HashSet<String>,
    stack: &mut HashSet<String>,
) -> bool {
    // Back-edge onto the current path: that is a cycle.
    if stack.contains(method_name) {
        return true;
    }
    // Already fully explored on an earlier branch without finding a cycle.
    if !visited.insert(method_name.to_string()) {
        return false;
    }
    stack.insert(method_name.to_string());

    let reaches_back = call_graph
        .get(method_name)
        .into_iter()
        .flatten()
        .filter(|callee| method_names.contains(*callee))
        .any(|callee| has_cycle(callee, call_graph, method_names, visited, stack));

    if reaches_back {
        return true;
    }

    stack.remove(method_name);
    false
}
536
#[cfg(test)]
// Unit tests for `validate`. Each test parses a small TypeScript-syntax contract with
// the real frontend parser and inspects the resulting diagnostics.
mod tests {
    use super::*;
    use crate::frontend::parser::parse_source;

    /// Parses `source` as `test.runar.ts`, failing the test on any parse error, and
    /// returns the parsed contract.
    fn parse_contract(source: &str) -> ContractNode {
        let result = parse_source(source, Some("test.runar.ts"));
        assert!(
            result.errors.is_empty(),
            "parse errors: {:?}",
            result.errors
        );
        result.contract.expect("expected a contract from parse")
    }

    // A well-formed stateless contract: super() first, property assigned, public
    // method ending in assert(). Expect zero errors.
    #[test]
    fn test_valid_p2pkh_passes_validation() {
        let source = r#"
import { SmartContract, Addr, PubKey, Sig } from 'runar-lang';

class P2PKH extends SmartContract {
    readonly pubKeyHash: Addr;

    constructor(pubKeyHash: Addr) {
        super(pubKeyHash);
        this.pubKeyHash = pubKeyHash;
    }

    public unlock(sig: Sig, pubKey: PubKey) {
        assert(hash160(pubKey) === this.pubKeyHash);
        assert(checkSig(sig, pubKey));
    }
}
"#;
        let contract = parse_contract(source);
        let result = validate(&contract);
        assert!(
            result.errors.is_empty(),
            "expected no validation errors, got: {:?}",
            result.errors
        );
    }

    // The constructor's first statement is an assignment rather than super(...);
    // some error must mention "super".
    #[test]
    fn test_missing_super_in_constructor_produces_error() {
        let source = r#"
import { SmartContract } from 'runar-lang';

class Bad extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        this.x = x;
    }

    public check(v: bigint) {
        assert(v === this.x);
    }
}
"#;
        let contract = parse_contract(source);
        let result = validate(&contract);
        assert!(
            !result.errors.is_empty(),
            "expected validation errors for missing super()"
        );
        let has_super_error = result
            .errors
            .iter()
            .any(|e| e.to_lowercase().contains("super"));
        assert!(
            has_super_error,
            "expected error about super(), got: {:?}",
            result.errors
        );
    }

    // A public method on a SmartContract whose last statement is a declaration, not
    // assert(); some error must mention "assert".
    #[test]
    fn test_public_method_not_ending_with_assert_produces_error() {
        let source = r#"
import { SmartContract } from 'runar-lang';

class NoAssert extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint) {
        const sum = v + this.x;
    }
}
"#;
        let contract = parse_contract(source);
        let result = validate(&contract);
        assert!(
            !result.errors.is_empty(),
            "expected validation errors for missing assert at end of public method"
        );
        let has_assert_error = result
            .errors
            .iter()
            .any(|e| e.to_lowercase().contains("assert"));
        assert!(
            has_assert_error,
            "expected error about missing assert(), got: {:?}",
            result.errors
        );
    }

    // `check` calls `this.check(...)`; the recursion check must report it.
    #[test]
    fn test_direct_recursion_produces_error() {
        let source = r#"
import { SmartContract } from 'runar-lang';

class Recursive extends SmartContract {
    readonly x: bigint;

    constructor(x: bigint) {
        super(x);
        this.x = x;
    }

    public check(v: bigint) {
        this.check(v);
        assert(v === this.x);
    }
}
"#;
        let contract = parse_contract(source);
        let result = validate(&contract);
        assert!(
            !result.errors.is_empty(),
            "expected validation errors for recursion"
        );
        let has_recursion_error = result
            .errors
            .iter()
            .any(|e| e.to_lowercase().contains("recursion") || e.to_lowercase().contains("recursive"));
        assert!(
            has_recursion_error,
            "expected error about recursion, got: {:?}",
            result.errors
        );
    }

    // A StatefulSmartContract public method that does not end in assert() is
    // accepted, because the trailing-assert rule applies only to SmartContract.
    #[test]
    fn test_stateful_contract_passes_validation() {
        let source = r#"
import { StatefulSmartContract } from 'runar-lang';

class Counter extends StatefulSmartContract {
    count: bigint;

    constructor(count: bigint) {
        super(count);
        this.count = count;
    }

    public increment() {
        this.count++;
    }
}
"#;
        let contract = parse_contract(source);
        let result = validate(&contract);
        assert!(
            result.errors.is_empty(),
            "expected no validation errors for stateful contract, got: {:?}",
            result.errors
        );
    }
}