1use crate::attribute::Attribute;
2
3use super::{
5 ast::{Expr, ExprValue, Node, NodeValue},
6 token::{Token, TokenType},
7 ty::{
8 StmtTermination, StmtType, StmtTypeCell, Type, TypeCell,
9 ARRAY_TYPE_UNKNOWN_SIZE
10 }
11};
12use pulsar_utils::{
13 disjoint_set::{DisjointSets, NodeTrait},
14 environment::Environment,
15 error::{Error, ErrorBuilder, ErrorCode, ErrorManager, Level, Style},
16 id::Gen,
17 loc::{Region, RegionProvider},
18 CheapClone
19};
20use std::{
21 cell::RefCell, collections::VecDeque, fmt::Debug, iter::zip, rc::Rc
22};
23
/// Environment key under which the enclosing function's declared return type
/// is bound while its body is analyzed; `return` statements look it up to
/// constrain the returned value.
// NOTE(review): the leading space presumably keeps this key from colliding
// with any user-written identifier — confirm against the lexer.
const RETURN_ID: &str = " return";
28
/// Wrapper around a [`TypeCell`] that serves as the element type of the
/// disjoint-set structure used during unification.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TypeNode {
    /// The wrapped type cell.
    cell: TypeCell
}
33
34impl TypeNode {
35 pub fn from_currently_stable_cell(cell: TypeCell) -> Self {
36 Self { cell }
37 }
38
39 pub fn get(&self) -> Type {
40 self.cell.clone_out()
41 }
42}
43
// Marker impl with no methods of its own.
// NOTE(review): presumably signals that cloning a `TypeNode` only copies a
// shared handle — confirm against `TypeCell`'s `Clone` implementation.
impl CheapClone for TypeNode {}
45impl Debug for TypeNode {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 write!(f, "{:?}", self.cell.as_ref())
48 }
49}
// Marker impl: `TypeNode` is the element type of the `DisjointSets` built
// during unification.
impl NodeTrait for TypeNode {}
51
/// A constraint collected while visiting the program and solved afterwards
/// by unification.
#[derive(Debug)]
enum TypeConstraint {
    /// Requires `lhs` and `rhs` to unify to the same type.
    Equality {
        /// Left-hand type of the equation.
        lhs: TypeCell,
        /// Right-hand type of the equation.
        rhs: TypeCell,
        /// Source region used when reporting a failure about `lhs`.
        lhs_ctx: Region,
        /// Source region used when reporting a failure about `rhs`.
        rhs_ctx: Region
    }
}
61
/// Walks the AST, collects type constraints, and solves them by unification,
/// recording diagnostics in a shared error manager.
pub struct StaticAnalyzer {
    /// Scoped mapping from names to the type cells bound to them.
    env: Environment<String, TypeCell>,
    /// Constraints gathered during the visit, consumed front-to-back when
    /// unifying.
    constraints: VecDeque<TypeConstraint>,
    /// Shared sink for every error and warning this pass produces.
    error_manager: Rc<RefCell<ErrorManager>>
}
67
68impl StaticAnalyzer {
69 pub fn new(error_manager: Rc<RefCell<ErrorManager>>) -> StaticAnalyzer {
70 StaticAnalyzer {
71 env: Environment::new(),
72 constraints: VecDeque::new(),
73 error_manager
74 }
75 }
76
77 pub fn bind_top_level(&mut self, name: String, ty: Type) {
80 self.env.bind_base(name, TypeCell::new(ty));
81 }
82
83 pub fn infer(&mut self, mut program: Vec<Node>) -> Option<Vec<Node>> {
87 self.constraints.clear();
88
89 self.env.push();
90 for node in &mut program {
91 self.visit_node(node, true)?;
92 }
93 for node in &mut program {
94 self.visit_node(node, false)?;
95 }
96 self.env.pop();
97
98 let substitution = self.unify_constraints()?;
99 for (ty, sub_ty) in &substitution {
100 match *sub_ty.cell.as_ref() {
101 Type::Var(_) => {
102 self.report_ambiguous_type(
103 sub_ty.cell.clone(),
104 "Type variable not resolved (bug?)".into()
105 );
106 return None;
107 }
108 Type::Array(_, ARRAY_TYPE_UNKNOWN_SIZE) => {
109 self.report_ambiguous_type(
110 sub_ty.cell.clone(),
111 "Array size not resolved".into()
112 );
113 return None;
114 }
115 _ => {}
116 }
117 *ty.cell.as_mut() = sub_ty.get();
118 }
119
120 Some(program)
121 }
122
123 fn report(&mut self, error: Error) {
124 self.error_manager.borrow_mut().record(error);
125 }
126
127 fn warn_dead_code(
128 &mut self, func_name: &Token, dead_node: &Node, term_node: &Node
129 ) {
130 self.report(
131 ErrorBuilder::new()
132 .of_style(Style::Primary)
133 .at_level(Level::Warning)
134 .with_code(ErrorCode::StaticAnalysisIssue)
135 .at_region(dead_node)
136 .message("Statement is never reached".into())
137 .build()
138 );
139 self.report(
140 ErrorBuilder::new()
141 .of_style(Style::Secondary)
142 .at_level(Level::Warning)
143 .at_region(term_node)
144 .continues()
145 .explain(format!(
146 "Returned from function `{}` here",
147 func_name.value
148 ))
149 .build()
150 );
151 }
152
153 fn report_missing_return(&mut self, func_name: &Token) {
154 self.report(
155 ErrorBuilder::new()
156 .of_style(Style::Primary)
157 .at_level(Level::Error)
158 .with_code(ErrorCode::InvalidTopLevelConstruct)
159 .at_region(func_name)
160 .message(format!(
161 "Function `{}` does not return from all paths",
162 func_name.value
163 ))
164 .fix("Consider adding a `return` statement at the end of the function".into())
165 .build()
166 );
167 }
168
169 fn report_unbound_name(&mut self, name: &Token) {
170 self.report(
171 ErrorBuilder::new()
172 .of_style(Style::Primary)
173 .at_level(Level::Error)
174 .with_code(ErrorCode::UnboundName)
175 .at_region(name)
176 .message(format!(
177 "Unbound function or variable `{}`",
178 name.value
179 ))
180 .build()
181 );
182 }
183
184 fn report_ambiguous_type(
185 &mut self, ty: TypeCell, explain: String
186 ) {
187 self.report(
188 ErrorBuilder::new()
189 .of_style(Style::Primary)
190 .at_level(Level::Error)
191 .with_code(ErrorCode::StaticAnalysisIssue)
192 .without_loc()
194 .message(format!("Ambiguous type `{}`", ty))
195 .explain(explain)
196 .build()
197 );
198 }
199
200 fn report_failed_purity_derivation(
201 &mut self, pure_token: &Token, name: &Token, impure_node: &Node
202 ) {
203 self.report(
204 ErrorBuilder::new()
205 .of_style(Style::Primary)
206 .at_level(Level::Error)
207 .with_code(ErrorCode::StaticAnalysisIssue)
208 .at_region(impure_node)
209 .message(format!(
210 "Impure statement in `pure` function `{}`",
211 name.value
212 ))
213 .build()
214 );
215 self.report(
216 ErrorBuilder::new()
217 .of_style(Style::Secondary)
218 .at_level(Level::Error)
219 .with_code(ErrorCode::StaticAnalysisIssue)
220 .at_region(pure_token)
221 .continues()
222 .explain("Function declared pure here".into())
223 .fix("Consider marking called functions with `pure`".into())
224 .build()
225 );
226 }
227
228 fn report_called_non_function(&mut self, name: &Token) {
229 self.report(
230 ErrorBuilder::new()
231 .of_style(Style::Primary)
232 .at_level(Level::Error)
233 .with_code(ErrorCode::StaticAnalysisIssue)
234 .at_region(name)
235 .message(format!(
236 "Cannot call non-function value `{}`",
237 name.value
238 ))
239 .build()
240 );
241 }
242
243 fn report_unification_failure(
257 &mut self, lhs: TypeCell, rhs: TypeCell, lhs_ctx: Region,
258 rhs_ctx: Region, fix: Option<String>
259 ) {
260 let mut builder = ErrorBuilder::new()
261 .of_style(Style::Primary)
262 .at_level(Level::Error)
263 .with_code(ErrorCode::StaticAnalysisIssue)
264 .at_region(&lhs_ctx)
265 .message(format!("Failed to unify types `{}` and `{}`", lhs, rhs))
266 .explain(format!("Type inferred here to be `{}`", lhs));
267 if let Some(fix) = fix {
268 builder = builder.fix(fix);
269 }
270 self.report(builder.build());
271 if lhs_ctx != rhs_ctx {
272 self.report(
273 ErrorBuilder::new()
274 .of_style(Style::Secondary)
275 .at_level(Level::Error)
276 .with_code(ErrorCode::StaticAnalysisIssue)
277 .at_region(&rhs_ctx)
278 .continues()
279 .explain(format!("Type inferred here to be `{}`", rhs))
280 .build()
281 );
282 }
283 }
284}
285
286impl StaticAnalyzer {
287 fn new_type_var(&self) -> Type {
288 Type::Var(Gen::next("TypeInferer::get_type_var"))
289 }
290
291 fn add_constraint(
292 &mut self, lhs: TypeCell, rhs: TypeCell, lhs_ctx: Region,
293 rhs_ctx: Region
294 ) {
295 self.constraints.push_back(TypeConstraint::Equality {
296 lhs,
297 rhs,
298 lhs_ctx,
299 rhs_ctx
300 });
301 }
302
303 fn visit_expr(&mut self, expr: &Expr) -> Option<(TypeCell, bool)> {
307 let mut expr_is_pure = true;
308 match &expr.value {
309 ExprValue::ConstantInt(_) => {
310 *expr.ty.as_mut() = Type::Int64;
311 }
312 ExprValue::BoundName(name) => {
313 if let Some(name_ty) = self.env.find(name.value.clone()) {
314 *expr.ty.as_mut() = self.new_type_var();
315 self.add_constraint(
316 expr.ty.clone(),
317 name_ty.clone(),
318 expr.region(),
319 name.region()
320 );
321 } else {
322 self.report_unbound_name(name);
323 return None;
324 }
325 }
326 ExprValue::Call(name, args) => {
327 if let Some(name_ty) = self.env.find(name.value.clone()) {
328 match name_ty.clone_out() {
330 Type::Function {
331 is_pure,
332 args: param_tys,
333 ret
334 } => {
335 expr_is_pure &= is_pure;
336
337 *expr.ty.as_mut() = self.new_type_var();
338 self.add_constraint(
339 expr.ty.clone(),
340 TypeCell::new((*ret).clone()),
341 expr.region(),
342 name.region()
343 );
344
345 for (arg, param_ty) in zip(args, param_tys) {
346 let (arg_ty, arg_is_pure) =
347 self.visit_expr(arg)?;
348 expr_is_pure &= arg_is_pure;
349 self.add_constraint(
350 arg_ty,
351 TypeCell::new(param_ty),
352 arg.region(),
353 name.region() )
356 }
357 }
358 _ => {
359 self.report_called_non_function(name);
360 return None;
361 }
362 }
363 } else {
364 self.report_unbound_name(name);
365 return None;
366 }
367 }
368 ExprValue::MemberAccess(_, _) => todo!(),
369 ExprValue::PostfixBop(array, op1, index, op2)
370 if op1.ty == TokenType::LeftBracket
371 && op2.ty == TokenType::RightBracket =>
372 {
373 *expr.ty.as_mut() = self.new_type_var();
374 let (array_ty, array_is_pure) = self.visit_expr(array)?;
375 let (index_ty, index_is_pure) = self.visit_expr(index)?;
376 expr_is_pure &= array_is_pure && index_is_pure;
377 self.add_constraint(
378 index_ty,
379 Type::int64_singleton(),
380 index.region(),
381 array.region()
382 );
383 self.add_constraint(
384 TypeCell::new(Type::Array(
385 expr.ty.clone(),
386 ARRAY_TYPE_UNKNOWN_SIZE
387 )),
388 array_ty,
389 expr.region(),
390 expr.region()
391 );
392 }
405 ExprValue::PostfixBop(_, _, _, _) => todo!(),
406 ExprValue::ArrayLiteral(elements, should_continue) => {
407 let element_ty_var = self.new_type_var();
408 let element_ty_var_cell = TypeCell::new(element_ty_var);
409 *expr.ty.as_mut() = Type::Array(
410 element_ty_var_cell.clone(),
411 if *should_continue {
412 ARRAY_TYPE_UNKNOWN_SIZE
413 } else {
414 elements
415 .len()
416 .try_into()
417 .unwrap_or_else(|_| panic!("how?"))
418 }
419 );
420 for element in elements {
421 let (element_type, element_is_pure) =
422 self.visit_expr(element)?;
423 expr_is_pure &= element_is_pure;
424 self.add_constraint(
425 element_ty_var_cell.clone(),
426 element_type,
427 expr.region(),
428 element.region()
429 );
430 }
431 }
432 ExprValue::PrefixOp(_, _) => todo!(),
433 ExprValue::InfixBop(lhs, bop, rhs) => {
434 match bop.ty {
437 TokenType::Plus | TokenType::Minus | TokenType::Times => {
438 let (lhs_ty, lhs_is_pure) = self.visit_expr(lhs)?;
439 let (rhs_ty, rhs_is_pure) = self.visit_expr(rhs)?;
440 expr_is_pure &= lhs_is_pure && rhs_is_pure;
441
442 self.add_constraint(
443 expr.ty.clone(),
444 lhs_ty,
445 expr.region(),
446 lhs.region()
447 );
448 self.add_constraint(
449 expr.ty.clone(),
450 rhs_ty,
451 expr.region(),
452 rhs.region()
453 );
454
455 *expr.ty.as_mut() = Type::Int64;
456 }
457 _ => ()
458 }
459 }
460 ExprValue::HardwareMap(map_token, _, f, arr) => {
461 *expr.ty.as_mut() = self.new_type_var();
462 let (arr_ty, arr_is_pure) = self.visit_expr(arr)?;
463 expr_is_pure &= arr_is_pure;
464 if let Some(f_type) = self.env.find(f.value.clone()) {
465 self.add_constraint(
467 f_type.clone(),
468 TypeCell::new(Type::Function {
469 is_pure: true,
470 args: vec![Type::Int64],
471 ret: Box::new(Type::Int64)
472 }),
473 f.region(),
474 map_token.region()
475 );
476 self.add_constraint(
478 arr_ty.clone(),
479 TypeCell::new(Type::Array(
480 Type::int64_singleton(),
481 ARRAY_TYPE_UNKNOWN_SIZE
482 )),
483 arr.region(),
484 map_token.region()
485 );
486 self.add_constraint(
488 expr.ty.clone(),
489 arr_ty,
490 map_token.region(),
491 arr.region()
492 );
493 } else {
494 self.report_unbound_name(f);
495 return None;
496 }
497 }
498 };
499
500 Some((expr.ty.clone(), expr_is_pure))
501 }
502
503 fn visit_node(
504 &mut self, node: &Node, top_level_pass: bool
505 ) -> Option<StmtTypeCell> {
506 match (&node.value, top_level_pass) {
507 (
508 NodeValue::Function {
509 name,
510 params,
511 ret,
512 pure_token,
513 body: _
514 },
515 true
516 ) => {
517 self.env.bind(
519 name.value.clone(),
520 TypeCell::new(Type::Function {
521 is_pure: pure_token.is_some(),
522 args: params
523 .iter()
524 .map(|p| p.1.clone())
525 .collect::<Vec<_>>(),
526 ret: Box::new(ret.clone())
527 })
528 );
529 *node.ty.as_mut() = StmtType::from(
531 StmtTermination::Nonterminal,
532 pure_token.is_some()
533 );
534 }
535 (
536 NodeValue::Function {
537 name,
538 params,
539 ret,
540 pure_token,
541 body
542 },
543 false
544 ) => {
545 self.env.push();
546 self.env.bind(RETURN_ID.into(), TypeCell::new(ret.clone()));
547 for (name, ty) in params {
548 self.env
549 .bind(name.value.clone(), TypeCell::new(ty.clone()));
550 }
551
552 let func_ty = node.ty.clone();
553 let mut warned_dead_code = false;
554 let mut term_node = None;
555 for node in body {
556 let node_ty = self.visit_node(node, false)?;
557 let mut just_found_term = false;
558 if node_ty.as_ref().termination == StmtTermination::Terminal
559 && term_node.is_none()
560 {
561 term_node = Some(node);
562 func_ty.as_mut().termination =
563 StmtTermination::Terminal;
564 just_found_term = true;
565 }
566 if func_ty.as_ref().termination == StmtTermination::Terminal
567 && !warned_dead_code
568 && !just_found_term
569 && !node.attributes.has(Attribute::Generated)
570 {
571 self.warn_dead_code(name, node, term_node.unwrap());
572 warned_dead_code = true;
573 }
574 if !node_ty.as_ref().is_pure && pure_token.is_some() {
575 self.report_failed_purity_derivation(
576 &pure_token.clone().unwrap(),
577 name,
578 node
579 );
580 return None;
581 }
582 }
583 if func_ty.as_ref().termination == StmtTermination::Nonterminal
584 {
585 self.report_missing_return(name);
586 return None;
587 }
588 self.env.pop();
589 }
590 (
591 NodeValue::LetBinding {
592 name,
593 hint: hint_opt,
594 value
595 },
596 false
597 ) => {
598 let (value_ty, expr_is_pure) = self.visit_expr(value)?;
599 if let Some(hint) = hint_opt {
600 self.add_constraint(
601 hint.clone(),
602 value_ty.clone(),
603 name.region(),
604 value.region()
605 );
606 }
607 self.env.bind(name.value.clone(), value_ty);
608 *node.ty.as_mut() =
610 StmtType::from(StmtTermination::Nonterminal, expr_is_pure);
611 }
612 (
613 NodeValue::Return {
614 keyword_token: token,
615 value: value_opt
616 },
617 false
618 ) => {
619 let ((value_ty, value_is_pure), value_start) =
620 if let Some(value) = value_opt {
621 (self.visit_expr(value)?, value.region())
622 } else {
623 ((Type::unit_singleton(), true), token.region())
624 };
625 self.add_constraint(
626 value_ty.clone(),
627 self.env.find(RETURN_ID.into()).unwrap().clone(),
628 value_start,
629 token.region()
630 );
631 *node.ty.as_mut() =
632 StmtType::from(StmtTermination::Terminal, value_is_pure);
633 }
634 _ => {}
635 }
636 Some(node.ty.clone())
637 }
638
639 fn unify(
641 &mut self, dsu: &mut DisjointSets<TypeNode>, lhs: TypeCell,
642 rhs: TypeCell, lhs_ctx: Region, rhs_ctx: Region
643 ) -> Result<(), String> {
644 let lhs_tn = TypeNode::from_currently_stable_cell(lhs.clone());
645 let rhs_tn = TypeNode::from_currently_stable_cell(rhs.clone());
646 dsu.add(lhs_tn.clone());
647 dsu.add(rhs_tn.clone());
648 let lhs_r = dsu
649 .find(lhs_tn)
650 .ok_or_else(|| "dsu find failed".to_string())?;
651 let rhs_r = dsu
652 .find(rhs_tn)
653 .ok_or_else(|| "dsu find failed".to_string())?;
654 if lhs_r != rhs_r {
655 match (lhs_r.get(), rhs_r.get()) {
656 (Type::Var(_), Type::Var(_)) => {
657 dsu.union(lhs_r, rhs_r, true)
658 .ok_or_else(|| "dsu union failed".to_string())?;
659 }
660 (Type::Var(_), _) => {
661 dsu.union(lhs_r, rhs_r, false)
662 .ok_or_else(|| "dsu union failed".to_string())?;
663 }
664 (_, Type::Var(_)) => {
665 dsu.union(rhs_r, lhs_r, false)
666 .ok_or_else(|| "dsu union failed".to_string())?;
667 }
668 (
669 Type::Array(lhs_element_ty, lhs_size),
670 Type::Array(rhs_element_ty, rhs_size)
671 ) => match (lhs_size, rhs_size) {
672 (ARRAY_TYPE_UNKNOWN_SIZE, ARRAY_TYPE_UNKNOWN_SIZE) => {
673 dsu.union(lhs_r, rhs_r, true)
674 .ok_or_else(|| "dsu union failed".to_string())?;
675 self.unify(
676 dsu,
677 lhs_element_ty,
678 rhs_element_ty,
679 lhs_ctx,
680 rhs_ctx
681 )?;
682 }
683 (ARRAY_TYPE_UNKNOWN_SIZE, _) => {
684 dsu.union(lhs_r, rhs_r, false)
685 .ok_or_else(|| "dsu union failed".to_string())?;
686 self.unify(
687 dsu,
688 lhs_element_ty,
689 rhs_element_ty,
690 lhs_ctx,
691 rhs_ctx
692 )?;
693 }
694 (_, ARRAY_TYPE_UNKNOWN_SIZE) => {
695 dsu.union(rhs_r, lhs_r, false)
696 .ok_or_else(|| "dsu union failed".to_string())?;
697 self.unify(
698 dsu,
699 rhs_element_ty,
700 lhs_element_ty,
701 lhs_ctx,
702 rhs_ctx
703 )?;
704 }
705 _ => {
706 if lhs_size != rhs_size {
707 self.report(
708 ErrorBuilder::new()
709 .of_style(Style::Primary)
710 .at_level(Level::Error)
711 .with_code(ErrorCode::StaticAnalysisIssue)
712 .at_region(&lhs_ctx)
713 .message(format!(
714 "Array sizes don't match: {} != {}",
715 lhs_size, rhs_size
716 ))
717 .build()
718 );
719 self.report(
720 ErrorBuilder::new()
721 .of_style(Style::Secondary)
722 .at_level(Level::Error)
723 .with_code(ErrorCode::StaticAnalysisIssue)
724 .at_region(&rhs_ctx)
725 .message("...".into())
726 .explain(format!("Inferred to have size {} here based on environment", rhs_size))
727 .build()
728 );
729
730 return Err("array type error".into());
731 }
732 }
733 },
734 (
735 Type::Function {
736 is_pure: lhs_is_pure,
737 args: lhs_args,
738 ret: lhs_ret
739 },
740 Type::Function {
741 is_pure: rhs_is_pure,
742 args: rhs_args,
743 ret: rhs_ret
744 }
745 ) => {
746 if !lhs_is_pure && rhs_is_pure {
747 self.report_unification_failure(
748 lhs,
749 rhs,
750 lhs_ctx.clone(),
751 lhs_ctx,
752 Some("Try marking the function as `pure`".into())
753 );
754 return Err("unification failure".into());
755 }
756 for (lhs_arg, rhs_arg) in zip(lhs_args, rhs_args) {
757 self.unify(
758 dsu,
759 TypeCell::new(lhs_arg),
760 TypeCell::new(rhs_arg),
761 lhs_ctx.clone(),
762 rhs_ctx.clone()
763 )?;
764 }
765 self.unify(
766 dsu,
767 TypeCell::new(*lhs_ret),
768 TypeCell::new(*rhs_ret),
769 lhs_ctx,
770 rhs_ctx
771 )?;
772 }
773 _ => {
774 self.report_unification_failure(
775 lhs, rhs, lhs_ctx, rhs_ctx, None
776 );
777 return Err("unification failure".into());
778 }
779 }
780 }
781 Ok(())
782 }
783
784 fn unify_constraints(&mut self) -> Option<DisjointSets<TypeNode>> {
785 let mut dsu = DisjointSets::new();
786 while !self.constraints.is_empty() {
787 let constraint = self.constraints.pop_front()?;
788 match constraint {
789 TypeConstraint::Equality {
790 lhs,
791 rhs,
792 lhs_ctx,
793 rhs_ctx
794 } => {
795 let _ = self
796 .unify(&mut dsu, lhs, rhs, lhs_ctx, rhs_ctx)
797 .map_err(|_| {
798 if !self.error_manager.borrow().has_errors() {
799 panic!(
800 "TypeInferer failed without error message"
801 );
802 }
803 });
804 }
805 }
806 }
807 dsu.collapse();
808 Some(dsu)
809 }
810}