1use runmat_parser::{
2 self as parser, BinOp, Expr as AstExpr, Program as AstProgram, Stmt as AstStmt, UnOp,
3};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7pub use runmat_builtins::Type;
9
10#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct VarId(pub usize);
12
13#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
14pub struct HirExpr {
15 pub kind: HirExprKind,
16 pub ty: Type,
17}
18
19#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
20pub enum HirExprKind {
21 Number(String),
22 String(String),
23 Var(VarId),
24 Constant(String), Unary(UnOp, Box<HirExpr>),
26 Binary(Box<HirExpr>, BinOp, Box<HirExpr>),
27 Tensor(Vec<Vec<HirExpr>>),
28 Cell(Vec<Vec<HirExpr>>),
29 Index(Box<HirExpr>, Vec<HirExpr>),
30 IndexCell(Box<HirExpr>, Vec<HirExpr>),
31 Range(Box<HirExpr>, Option<Box<HirExpr>>, Box<HirExpr>),
32 Colon,
33 End,
34 Member(Box<HirExpr>, String),
35 MemberDynamic(Box<HirExpr>, Box<HirExpr>),
36 MethodCall(Box<HirExpr>, String, Vec<HirExpr>),
37 AnonFunc {
38 params: Vec<VarId>,
39 body: Box<HirExpr>,
40 },
41 FuncHandle(String),
42 FuncCall(String, Vec<HirExpr>),
43 MetaClass(String),
44}
45
46#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
47pub enum HirStmt {
48 ExprStmt(HirExpr, bool), Assign(VarId, HirExpr, bool), MultiAssign(Vec<Option<VarId>>, HirExpr, bool),
51 AssignLValue(HirLValue, HirExpr, bool),
52 If {
53 cond: HirExpr,
54 then_body: Vec<HirStmt>,
55 elseif_blocks: Vec<(HirExpr, Vec<HirStmt>)>,
56 else_body: Option<Vec<HirStmt>>,
57 },
58 While {
59 cond: HirExpr,
60 body: Vec<HirStmt>,
61 },
62 For {
63 var: VarId,
64 expr: HirExpr,
65 body: Vec<HirStmt>,
66 },
67 Switch {
68 expr: HirExpr,
69 cases: Vec<(HirExpr, Vec<HirStmt>)>,
70 otherwise: Option<Vec<HirStmt>>,
71 },
72 TryCatch {
73 try_body: Vec<HirStmt>,
74 catch_var: Option<VarId>,
75 catch_body: Vec<HirStmt>,
76 },
77 Global(Vec<(VarId, String)>),
79 Persistent(Vec<(VarId, String)>),
80 Break,
81 Continue,
82 Return,
83 Function {
84 name: String,
85 params: Vec<VarId>,
86 outputs: Vec<VarId>,
87 body: Vec<HirStmt>,
88 has_varargin: bool,
89 has_varargout: bool,
90 },
91 ClassDef {
92 name: String,
93 super_class: Option<String>,
94 members: Vec<HirClassMember>,
95 },
96 Import {
97 path: Vec<String>,
98 wildcard: bool,
99 },
100}
101
102#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
103pub enum HirClassMember {
104 Properties {
105 attributes: Vec<parser::Attr>,
106 names: Vec<String>,
107 },
108 Methods {
109 attributes: Vec<parser::Attr>,
110 body: Vec<HirStmt>,
111 },
112 Events {
113 attributes: Vec<parser::Attr>,
114 names: Vec<String>,
115 },
116 Enumeration {
117 attributes: Vec<parser::Attr>,
118 names: Vec<String>,
119 },
120 Arguments {
121 attributes: Vec<parser::Attr>,
122 names: Vec<String>,
123 },
124}
125
126#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
127pub enum HirLValue {
128 Var(VarId),
129 Member(Box<HirExpr>, String),
130 MemberDynamic(Box<HirExpr>, Box<HirExpr>),
131 Index(Box<HirExpr>, Vec<HirExpr>),
132 IndexCell(Box<HirExpr>, Vec<HirExpr>),
133}
134
135#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
136pub struct HirProgram {
137 pub body: Vec<HirStmt>,
138 #[serde(default)]
139 pub var_types: Vec<Type>,
140}
141
142#[derive(Debug, Clone)]
144pub struct LoweringResult {
145 pub hir: HirProgram,
146 pub variables: HashMap<String, usize>,
147 pub functions: HashMap<String, HirStmt>,
148 pub var_types: Vec<Type>,
149 pub var_names: HashMap<VarId, String>,
150}
151
152pub fn lower(prog: &AstProgram) -> Result<HirProgram, String> {
153 let mut ctx = Ctx::new();
154 let body = ctx.lower_stmts(&prog.body)?;
155 let var_types = ctx.var_types.clone();
156 let hir = HirProgram { body, var_types };
157 let _ = infer_function_output_types(&hir);
159 validate_classdefs(&hir)?;
161 Ok(hir)
162}
163
164pub fn infer_function_output_types(
167 prog: &HirProgram,
168) -> std::collections::HashMap<String, Vec<Type>> {
169 use std::collections::HashMap;
170
171 fn infer_expr_type(
172 expr: &HirExpr,
173 env: &HashMap<VarId, Type>,
174 func_returns: &HashMap<String, Vec<Type>>,
175 ) -> Type {
176 fn unify_tensor(a: &Type, b: &Type) -> Type {
177 match (a, b) {
178 (Type::Tensor { shape: sa }, Type::Tensor { shape: sb }) => match (sa, sb) {
179 (Some(sa), Some(sb)) => {
180 let maxr = sa.len().max(sb.len());
181 let mut out: Vec<Option<usize>> = Vec::with_capacity(maxr);
182 for i in 0..maxr {
183 let da = sa.get(i).cloned().unwrap_or(None);
184 let db = sb.get(i).cloned().unwrap_or(None);
185 let d = match (da, db) {
186 (Some(a), Some(b)) => {
187 if a == b {
188 Some(a)
189 } else if a == 1 {
190 Some(b)
191 } else if b == 1 {
192 Some(a)
193 } else {
194 None
195 }
196 }
197 (Some(a), None) => Some(a),
198 (None, Some(b)) => Some(b),
199 (None, None) => None,
200 };
201 out.push(d);
202 }
203 Type::Tensor { shape: Some(out) }
204 }
205 _ => Type::tensor(),
206 },
207 (Type::Tensor { .. }, _) | (_, Type::Tensor { .. }) => Type::tensor(),
208 _ => Type::tensor(),
209 }
210 }
211 fn index_tensor_shape(
212 base: &Type,
213 idxs: &[HirExpr],
214 env: &HashMap<VarId, Type>,
215 func_returns: &HashMap<String, Vec<Type>>,
216 ) -> Type {
217 let idx_types: Vec<Type> = idxs
219 .iter()
220 .map(|e| infer_expr_type(e, env, func_returns))
221 .collect();
222 match base {
223 Type::Tensor { shape: Some(dims) } => {
224 let rank = dims.len();
225 let mut out: Vec<Option<usize>> = Vec::new();
226 for i in 0..rank {
227 if i < idx_types.len() {
228 match idx_types[i] {
229 Type::Int | Type::Num | Type::Bool => { }
230 _ => {
231 out.push(None);
232 }
233 }
234 } else {
235 out.push(dims[i]);
236 }
237 }
238 if out.is_empty() {
239 Type::Num
240 } else {
241 Type::Tensor { shape: Some(out) }
242 }
243 }
244 Type::Tensor { shape: None } => {
245 let scalar_count = idx_types
247 .iter()
248 .filter(|t| matches!(t, Type::Int | Type::Num | Type::Bool))
249 .count();
250 if scalar_count == idx_types.len() {
251 Type::Num
252 } else {
253 Type::tensor()
254 }
255 }
256 _ => Type::Unknown,
257 }
258 }
259 use HirExprKind as K;
260 match &expr.kind {
261 K::Number(_) => Type::Num,
262 K::String(_) => Type::String,
263 K::Constant(_) => Type::Num,
264 K::Var(id) => env.get(id).cloned().unwrap_or(Type::Unknown),
265 K::Unary(_, e) => infer_expr_type(e, env, func_returns),
266 K::Binary(a, op, b) => {
267 let ta = infer_expr_type(a, env, func_returns);
268 let tb = infer_expr_type(b, env, func_returns);
269 match op {
270 parser::BinOp::Add
271 | parser::BinOp::Sub
272 | parser::BinOp::Mul
273 | parser::BinOp::Div
274 | parser::BinOp::Pow
275 | parser::BinOp::LeftDiv
276 | parser::BinOp::ElemMul
277 | parser::BinOp::ElemDiv
278 | parser::BinOp::ElemPow
279 | parser::BinOp::ElemLeftDiv => {
280 if matches!(ta, Type::Tensor { .. }) || matches!(tb, Type::Tensor { .. }) {
281 unify_tensor(&ta, &tb)
282 } else {
283 Type::Num
284 }
285 }
286 parser::BinOp::Equal
287 | parser::BinOp::NotEqual
288 | parser::BinOp::Less
289 | parser::BinOp::LessEqual
290 | parser::BinOp::Greater
291 | parser::BinOp::GreaterEqual => Type::Bool,
292 parser::BinOp::AndAnd
293 | parser::BinOp::OrOr
294 | parser::BinOp::BitAnd
295 | parser::BinOp::BitOr => Type::Bool,
296 parser::BinOp::Colon => Type::tensor(),
297 }
298 }
299 K::Tensor(rows) => {
300 let r = rows.len();
301 let c = rows.iter().map(|row| row.len()).max().unwrap_or(0);
302 if r > 0 && rows.iter().all(|row| row.len() == c) {
303 Type::tensor_with_shape(vec![r, c])
304 } else {
305 Type::tensor()
306 }
307 }
308 K::Cell(rows) => {
309 let mut elem_ty: Option<Type> = None;
310 let mut len: usize = 0;
311 for row in rows {
312 for e in row {
313 let t = infer_expr_type(e, env, func_returns);
314 elem_ty = Some(match elem_ty {
315 Some(curr) => curr.unify(&t),
316 None => t,
317 });
318 len += 1;
319 }
320 }
321 Type::Cell {
322 element_type: elem_ty.map(Box::new),
323 length: Some(len),
324 }
325 }
326 K::Index(base, idxs) => {
327 let bt = infer_expr_type(base, env, func_returns);
328 index_tensor_shape(&bt, idxs, env, func_returns)
329 }
330 K::IndexCell(base, idxs) => {
331 let bt = infer_expr_type(base, env, func_returns);
332 if let Type::Cell {
333 element_type: Some(t),
334 ..
335 } = bt
336 {
337 let scalar = idxs.len() == 1
338 && matches!(
339 infer_expr_type(&idxs[0], env, func_returns),
340 Type::Int | Type::Num | Type::Bool | Type::Tensor { .. }
341 );
342 if scalar {
343 *t
344 } else {
345 Type::Unknown
346 }
347 } else {
348 Type::Unknown
349 }
350 }
351 K::Range(_, _, _) => Type::tensor(),
352 K::FuncCall(name, _args) => {
353 if let Some(v) = func_returns.get(name) {
354 v.first().cloned().unwrap_or(Type::Unknown)
355 } else {
356 let builtins = runmat_builtins::builtin_functions();
357 if let Some(b) = builtins.iter().find(|b| b.name == *name) {
358 b.return_type.clone()
359 } else {
360 Type::Unknown
361 }
362 }
363 }
364 K::MethodCall(_, _, _) => Type::Unknown,
365 K::Member(base, _) => {
366 let _bt = infer_expr_type(base, env, func_returns);
368 Type::Unknown
369 }
370 K::MemberDynamic(_, _) => Type::Unknown,
371 K::AnonFunc { .. } => Type::Function {
372 params: vec![Type::Unknown],
373 returns: Box::new(Type::Unknown),
374 },
375 K::FuncHandle(_) => Type::Function {
376 params: vec![Type::Unknown],
377 returns: Box::new(Type::Unknown),
378 },
379 K::MetaClass(_) => Type::String,
380 K::End => Type::Unknown,
381 K::Colon => Type::tensor(),
382 }
383 }
384
385 fn join_env(a: &HashMap<VarId, Type>, b: &HashMap<VarId, Type>) -> HashMap<VarId, Type> {
386 let mut out = a.clone();
387 for (k, v) in b {
388 out.entry(*k)
389 .and_modify(|t| *t = t.unify(v))
390 .or_insert_with(|| v.clone());
391 }
392 out
393 }
394
395 #[derive(Clone)]
396 struct Analysis {
397 exits: Vec<HashMap<VarId, Type>>,
398 fallthrough: Option<HashMap<VarId, Type>>,
399 }
400
401 #[allow(clippy::only_used_in_recursion)]
402 #[allow(clippy::type_complexity, clippy::only_used_in_recursion)]
403 fn analyze_stmts(
404 #[allow(clippy::only_used_in_recursion)] _outputs: &[VarId],
405 stmts: &[HirStmt],
406 mut env: HashMap<VarId, Type>,
407 returns: &HashMap<String, Vec<Type>>,
408 func_defs: &HashMap<String, (Vec<VarId>, Vec<VarId>, Vec<HirStmt>)>,
409 ) -> Analysis {
410 let mut exits = Vec::new();
411 let mut i = 0usize;
412 while i < stmts.len() {
413 match &stmts[i] {
414 HirStmt::Assign(var, expr, _) => {
415 let t = infer_expr_type(expr, &env, returns);
416 env.insert(*var, t);
417 }
418 HirStmt::MultiAssign(vars, expr, _) => {
419 if let HirExprKind::FuncCall(ref name, _) = expr.kind {
420 let mut per_out: Vec<Type> = returns.get(name).cloned().unwrap_or_default();
422 let needs_fallback = per_out.is_empty()
424 || per_out.iter().any(|t| matches!(t, Type::Unknown));
425 if needs_fallback {
426 if let Some((params, outs, body)) = func_defs.get(name).cloned() {
427 let mut penv: HashMap<VarId, Type> = HashMap::new();
429 for p in params {
432 penv.insert(p, Type::Num);
433 }
434 let mut out_types: Vec<Type> = vec![Type::Unknown; outs.len()];
436 for s in &body {
437 if let HirStmt::Assign(var, rhs, _) = s {
438 if let Some(pos) = outs.iter().position(|o| o == var) {
439 let t = infer_expr_type(rhs, &penv, returns);
440 out_types[pos] = out_types[pos].unify(&t);
441 }
442 }
443 }
444 if per_out.is_empty() {
445 per_out = out_types;
446 } else {
447 for (i, t) in out_types.into_iter().enumerate() {
448 if matches!(per_out.get(i), Some(Type::Unknown)) {
449 if let Some(slot) = per_out.get_mut(i) {
450 *slot = t;
451 }
452 }
453 }
454 }
455 }
456 }
457 for (i, v) in vars.iter().enumerate() {
458 if let Some(id) = v {
459 env.insert(*id, per_out.get(i).cloned().unwrap_or(Type::Unknown));
460 }
461 }
462 } else {
463 let t = infer_expr_type(expr, &env, returns);
464 for v in vars.iter().flatten() {
465 env.insert(*v, t.clone());
466 }
467 }
468 }
469 HirStmt::ExprStmt(_, _) | HirStmt::Break | HirStmt::Continue => {}
470 HirStmt::Return => {
471 exits.push(env.clone());
472 return Analysis {
473 exits,
474 fallthrough: None,
475 };
476 }
477 HirStmt::If {
478 cond,
479 then_body,
480 elseif_blocks,
481 else_body,
482 } => {
483 fn trim_quotes(s: &str) -> String {
485 let t = s.trim();
486 t.trim_matches('\'').to_string()
487 }
488 fn extract_field_literal(e: &HirExpr) -> Option<String> {
489 match &e.kind {
490 HirExprKind::String(s) => Some(trim_quotes(s)),
491 _ => None,
492 }
493 }
494 fn extract_field_list(e: &HirExpr) -> Vec<String> {
495 match &e.kind {
496 HirExprKind::String(s) => vec![trim_quotes(s)],
497 HirExprKind::Cell(rows) => {
498 let mut out = Vec::new();
499 for row in rows {
500 for it in row {
501 if let Some(v) = extract_field_literal(it) {
502 out.push(v);
503 }
504 }
505 }
506 out
507 }
508 _ => Vec::new(),
509 }
510 }
511 fn collect_assertions(e: &HirExpr, out: &mut Vec<(VarId, String)>) {
512 use HirExprKind as K;
513 match &e.kind {
514 K::Unary(parser::UnOp::Not, _inner) => {
515 }
517 K::Binary(left, parser::BinOp::AndAnd, right)
518 | K::Binary(left, parser::BinOp::BitAnd, right) => {
519 collect_assertions(left, out);
520 collect_assertions(right, out);
521 }
522 K::FuncCall(name, args) => {
523 let lname = name.as_str();
524 if lname.eq_ignore_ascii_case("isfield") && args.len() >= 2 {
525 if let HirExprKind::Var(vid) = args[0].kind {
526 if let Some(f) = extract_field_literal(&args[1]) {
527 out.push((vid, f));
528 }
529 }
530 }
531 if lname.eq_ignore_ascii_case("ismember") && args.len() >= 2 {
533 let mut fields: Vec<String> = Vec::new();
534 let mut target: Option<VarId> = None;
535 if let HirExprKind::FuncCall(ref n0, ref a0) = args[0].kind {
537 if n0.eq_ignore_ascii_case("fieldnames") && a0.len() == 1 {
538 if let HirExprKind::Var(vid) = a0[0].kind {
539 target = Some(vid);
540 }
541 }
542 }
543 if let HirExprKind::FuncCall(ref n1, ref a1) = args[1].kind {
544 if n1.eq_ignore_ascii_case("fieldnames") && a1.len() == 1 {
545 if let HirExprKind::Var(vid) = a1[0].kind {
546 target = Some(vid);
547 }
548 }
549 }
550 if fields.is_empty() {
551 fields.extend(extract_field_list(&args[0]));
552 }
553 if fields.is_empty() {
554 fields.extend(extract_field_list(&args[1]));
555 }
556 if let Some(vid) = target {
557 for f in fields {
558 out.push((vid, f));
559 }
560 }
561 }
562 if (lname.eq_ignore_ascii_case("any")
564 || lname.eq_ignore_ascii_case("all"))
565 && !args.is_empty()
566 {
567 collect_assertions(&args[0], out);
568 }
569 if (lname.eq_ignore_ascii_case("strcmp")
570 || lname.eq_ignore_ascii_case("strcmpi"))
571 && args.len() >= 2
572 {
573 let mut target: Option<VarId> = None;
574 if let HirExprKind::FuncCall(ref n0, ref a0) = args[0].kind {
575 if n0.eq_ignore_ascii_case("fieldnames") && a0.len() == 1 {
576 if let HirExprKind::Var(vid) = a0[0].kind {
577 target = Some(vid);
578 }
579 }
580 }
581 if let HirExprKind::FuncCall(ref n1, ref a1) = args[1].kind {
582 if n1.eq_ignore_ascii_case("fieldnames") && a1.len() == 1 {
583 if let HirExprKind::Var(vid) = a1[0].kind {
584 target = Some(vid);
585 }
586 }
587 }
588 let mut fields = Vec::new();
589 fields.extend(extract_field_list(&args[0]));
590 fields.extend(extract_field_list(&args[1]));
591 if let Some(vid) = target {
592 for f in fields {
593 out.push((vid, f));
594 }
595 }
596 }
597 }
598 _ => {}
599 }
600 }
601 let mut assertions: Vec<(VarId, String)> = Vec::new();
602 collect_assertions(cond, &mut assertions);
603 let mut then_env = env.clone();
604 if !assertions.is_empty() {
605 for (vid, field) in assertions {
606 let mut known = match then_env.get(&vid) {
607 Some(Type::Struct { known_fields }) => known_fields.clone(),
608 _ => Some(Vec::new()),
609 };
610 if let Some(list) = &mut known {
611 if !list.iter().any(|f| f == &field) {
612 list.push(field);
613 list.sort();
614 list.dedup();
615 }
616 }
617 then_env.insert(
618 vid,
619 Type::Struct {
620 known_fields: known,
621 },
622 );
623 }
624 }
625 let then_a = analyze_stmts(_outputs, then_body, then_env, returns, func_defs);
626 let mut out_env = then_a.fallthrough.clone().unwrap_or_else(|| env.clone());
627 let mut all_exits = then_a.exits.clone();
628 for (c, b) in elseif_blocks {
629 let mut elseif_env = env.clone();
630 let mut els_assertions: Vec<(VarId, String)> = Vec::new();
631 collect_assertions(c, &mut els_assertions);
632 if !els_assertions.is_empty() {
633 for (vid, field) in els_assertions {
634 let mut known = match elseif_env.get(&vid) {
635 Some(Type::Struct { known_fields }) => known_fields.clone(),
636 _ => Some(Vec::new()),
637 };
638 if let Some(list) = &mut known {
639 if !list.iter().any(|f| f == &field) {
640 list.push(field);
641 list.sort();
642 list.dedup();
643 }
644 }
645 elseif_env.insert(
646 vid,
647 Type::Struct {
648 known_fields: known,
649 },
650 );
651 }
652 }
653 let a = analyze_stmts(_outputs, b, elseif_env, returns, func_defs);
654 if let Some(f) = a.fallthrough {
655 out_env = join_env(&out_env, &f);
656 }
657 all_exits.extend(a.exits);
658 }
659 if let Some(else_body) = else_body {
660 let a = analyze_stmts(_outputs, else_body, env.clone(), returns, func_defs);
661 if let Some(f) = a.fallthrough {
662 out_env = join_env(&out_env, &f);
663 }
664 all_exits.extend(a.exits);
665 } else {
666 out_env = join_env(&out_env, &env);
667 }
668 env = out_env;
669 exits.extend(all_exits);
670 }
671 HirStmt::While { cond: _, body } => {
672 let a = analyze_stmts(_outputs, body, env.clone(), returns, func_defs);
673 if let Some(f) = a.fallthrough {
674 env = join_env(&env, &f);
675 }
676 exits.extend(a.exits);
677 }
678 HirStmt::For { var, expr, body } => {
679 let t = infer_expr_type(expr, &env, returns);
680 env.insert(*var, t);
681 let a = analyze_stmts(_outputs, body, env.clone(), returns, func_defs);
682 if let Some(f) = a.fallthrough {
683 env = join_env(&env, &f);
684 }
685 exits.extend(a.exits);
686 }
687 HirStmt::Switch {
688 expr: _,
689 cases,
690 otherwise,
691 } => {
692 let mut out_env: Option<HashMap<VarId, Type>> = None;
693 for (_v, b) in cases {
694 let a = analyze_stmts(_outputs, b, env.clone(), returns, func_defs);
695 if let Some(f) = a.fallthrough {
696 out_env = Some(match out_env {
697 Some(curr) => join_env(&curr, &f),
698 None => f,
699 });
700 }
701 exits.extend(a.exits);
702 }
703 if let Some(otherwise) = otherwise {
704 let a = analyze_stmts(_outputs, otherwise, env.clone(), returns, func_defs);
705 if let Some(f) = a.fallthrough {
706 out_env = Some(match out_env {
707 Some(curr) => join_env(&curr, &f),
708 None => f,
709 });
710 }
711 exits.extend(a.exits);
712 } else {
713 out_env = Some(match out_env {
714 Some(curr) => join_env(&curr, &env),
715 None => env.clone(),
716 });
717 }
718 if let Some(f) = out_env {
719 env = f;
720 }
721 }
722 HirStmt::TryCatch {
723 try_body,
724 catch_var: _,
725 catch_body,
726 } => {
727 let a_try = analyze_stmts(_outputs, try_body, env.clone(), returns, func_defs);
728 let a_catch =
729 analyze_stmts(_outputs, catch_body, env.clone(), returns, func_defs);
730 let mut out_env = a_try.fallthrough.clone().unwrap_or_else(|| env.clone());
731 if let Some(f) = a_catch.fallthrough {
732 out_env = join_env(&out_env, &f);
733 }
734 env = out_env;
735 exits.extend(a_try.exits);
736 exits.extend(a_catch.exits);
737 }
738 HirStmt::Global(_) | HirStmt::Persistent(_) => {}
739 HirStmt::Function { .. } => {}
740 HirStmt::ClassDef { .. } => {}
741 HirStmt::AssignLValue(lv, expr, _) => {
742 if let HirLValue::Member(base, field) = lv {
744 if let HirExprKind::Var(vid) = base.kind {
746 let mut known = match env.get(&vid) {
747 Some(Type::Struct { known_fields }) => known_fields.clone(),
748 _ => Some(Vec::new()),
749 };
750 if let Some(list) = &mut known {
751 if !list.iter().any(|f| f == field) {
752 list.push(field.clone());
753 list.sort();
754 list.dedup();
755 }
756 }
757 env.insert(
758 vid,
759 Type::Struct {
760 known_fields: known,
761 },
762 );
763 }
764 }
765 let _ = infer_expr_type(expr, &env, returns);
766 }
767 HirStmt::Import { .. } => {}
768 }
769 i += 1;
770 }
771 Analysis {
772 exits,
773 fallthrough: Some(env),
774 }
775 }
776
777 fn collect_function_names(stmts: &[HirStmt], acc: &mut Vec<String>) {
779 for s in stmts {
780 match s {
781 HirStmt::Function { name, .. } => acc.push(name.clone()),
782 HirStmt::ClassDef { members, .. } => {
783 for m in members {
784 if let HirClassMember::Methods { body, .. } = m {
785 collect_function_names(body, acc);
786 }
787 }
788 }
789 _ => {}
790 }
791 }
792 }
793
794 let mut function_names: Vec<String> = Vec::new();
795 collect_function_names(&prog.body, &mut function_names);
796 let mut returns: HashMap<String, Vec<Type>> = function_names
797 .iter()
798 .map(|n| (n.clone(), Vec::new()))
799 .collect();
800
801 let mut globals: std::collections::HashSet<VarId> = std::collections::HashSet::new();
803 let mut persistents: std::collections::HashSet<VarId> = std::collections::HashSet::new();
804 for stmt in &prog.body {
805 if let HirStmt::Global(vs) = stmt {
806 for (v, _n) in vs {
807 globals.insert(*v);
808 }
809 }
810 }
811 for stmt in &prog.body {
812 if let HirStmt::Persistent(vs) = stmt {
813 for (v, _n) in vs {
814 persistents.insert(*v);
815 }
816 }
817 }
818
819 #[allow(clippy::type_complexity)]
821 let mut func_defs: HashMap<String, (Vec<VarId>, Vec<VarId>, Vec<HirStmt>)> = HashMap::new();
822 for stmt in &prog.body {
823 if let HirStmt::Function {
824 name,
825 params,
826 outputs,
827 body,
828 ..
829 } = stmt
830 {
831 func_defs.insert(
832 name.clone(),
833 (params.clone(), outputs.clone(), body.clone()),
834 );
835 } else if let HirStmt::ClassDef { members, .. } = stmt {
836 for m in members {
837 if let HirClassMember::Methods { body, .. } = m {
838 for s in body {
839 if let HirStmt::Function {
840 name,
841 params,
842 outputs,
843 body,
844 ..
845 } = s
846 {
847 func_defs.insert(
848 name.clone(),
849 (params.clone(), outputs.clone(), body.clone()),
850 );
851 }
852 }
853 }
854 }
855 }
856 }
857
858 for stmt in &prog.body {
860 if let HirStmt::Function {
861 name,
862 outputs,
863 body,
864 ..
865 } = stmt
866 {
867 let mut per_output: Vec<Type> = vec![Type::Unknown; outputs.len()];
868 let analysis = analyze_stmts(outputs, body, HashMap::new(), &returns, &func_defs);
869 let mut accumulate = |env: &HashMap<VarId, Type>| {
870 for (i, out_id) in outputs.iter().enumerate() {
871 if let Some(t) = env.get(out_id) {
872 per_output[i] = per_output[i].unify(t);
873 }
874 }
875 };
876 if let Some(f) = &analysis.fallthrough {
877 accumulate(f);
878 }
879 for e in &analysis.exits {
880 accumulate(e);
881 }
882 returns.insert(name.clone(), per_output);
883 }
884 }
885
886 let mut changed = true;
887 let mut iter = 0usize;
888 let max_iters = 3usize;
889 while changed && iter < max_iters {
890 changed = false;
891 iter += 1;
892 for stmt in &prog.body {
893 match stmt {
894 HirStmt::Function {
895 name,
896 outputs,
897 body,
898 ..
899 } => {
900 let analysis =
901 analyze_stmts(outputs, body, HashMap::new(), &returns, &func_defs);
902 let mut per_output: Vec<Type> = vec![Type::Unknown; outputs.len()];
903 let mut accumulate = |env: &HashMap<VarId, Type>| {
904 for (i, out_id) in outputs.iter().enumerate() {
905 if let Some(t) = env.get(out_id) {
906 per_output[i] = per_output[i].unify(t);
907 }
908 }
909 };
910 for e in &analysis.exits {
911 accumulate(e);
912 }
913 if let Some(f) = &analysis.fallthrough {
914 accumulate(f);
915 }
916 if returns.get(name) != Some(&per_output) {
917 returns.insert(name.clone(), per_output);
918 changed = true;
919 }
920 }
921 HirStmt::ClassDef { members, .. } => {
922 for m in members {
923 if let HirClassMember::Methods { body, .. } = m {
924 for s in body {
925 if let HirStmt::Function {
926 name,
927 outputs,
928 body,
929 ..
930 } = s
931 {
932 let analysis = analyze_stmts(
933 outputs,
934 body,
935 HashMap::new(),
936 &returns,
937 &func_defs,
938 );
939 let mut per_output: Vec<Type> =
940 vec![Type::Unknown; outputs.len()];
941 let mut accumulate = |env: &HashMap<VarId, Type>| {
942 for (i, out_id) in outputs.iter().enumerate() {
943 if let Some(t) = env.get(out_id) {
944 per_output[i] = per_output[i].unify(t);
945 }
946 }
947 };
948 for e in &analysis.exits {
949 accumulate(e);
950 }
951 if let Some(f) = &analysis.fallthrough {
952 accumulate(f);
953 }
954 if returns.get(name) != Some(&per_output) {
955 returns.insert(name.clone(), per_output);
956 changed = true;
957 }
958 }
959 }
960 }
961 }
962 }
963 _ => {}
964 }
965 }
966 }
967
968 returns
969}
970
971#[allow(clippy::type_complexity)]
974pub fn infer_function_variable_types(
975 prog: &HirProgram,
976) -> std::collections::HashMap<String, std::collections::HashMap<VarId, Type>> {
977 use std::collections::HashMap;
978
979 let returns_map = infer_function_output_types(prog);
981
982 #[allow(clippy::type_complexity)]
984 let mut func_defs: HashMap<String, (Vec<VarId>, Vec<VarId>, Vec<HirStmt>)> = HashMap::new();
985 for stmt in &prog.body {
986 if let HirStmt::Function {
987 name,
988 params,
989 outputs,
990 body,
991 ..
992 } = stmt
993 {
994 func_defs.insert(
995 name.clone(),
996 (params.clone(), outputs.clone(), body.clone()),
997 );
998 } else if let HirStmt::ClassDef { members, .. } = stmt {
999 for m in members {
1000 if let HirClassMember::Methods { body, .. } = m {
1001 for s in body {
1002 if let HirStmt::Function {
1003 name,
1004 params,
1005 outputs,
1006 body,
1007 ..
1008 } = s
1009 {
1010 func_defs.insert(
1011 name.clone(),
1012 (params.clone(), outputs.clone(), body.clone()),
1013 );
1014 }
1015 }
1016 }
1017 }
1018 }
1019 }
1020
1021 fn infer_expr_type(
1022 expr: &HirExpr,
1023 env: &HashMap<VarId, Type>,
1024 returns: &HashMap<String, Vec<Type>>,
1025 ) -> Type {
1026 use HirExprKind as K;
1027 match &expr.kind {
1028 K::Number(_) => Type::Num,
1029 K::String(_) => Type::String,
1030 K::Constant(_) => Type::Num,
1031 K::Var(id) => env.get(id).cloned().unwrap_or(Type::Unknown),
1032 K::Unary(_, e) => infer_expr_type(e, env, returns),
1033 K::Binary(a, op, b) => {
1034 let ta = infer_expr_type(a, env, returns);
1035 let tb = infer_expr_type(b, env, returns);
1036 match op {
1037 parser::BinOp::Add
1038 | parser::BinOp::Sub
1039 | parser::BinOp::Mul
1040 | parser::BinOp::Div
1041 | parser::BinOp::Pow
1042 | parser::BinOp::LeftDiv => {
1043 if matches!(ta, Type::Tensor { .. }) || matches!(tb, Type::Tensor { .. }) {
1044 Type::tensor()
1045 } else {
1046 Type::Num
1047 }
1048 }
1049 parser::BinOp::ElemMul
1050 | parser::BinOp::ElemDiv
1051 | parser::BinOp::ElemPow
1052 | parser::BinOp::ElemLeftDiv => {
1053 if matches!(ta, Type::Tensor { .. }) || matches!(tb, Type::Tensor { .. }) {
1054 Type::tensor()
1055 } else {
1056 Type::Num
1057 }
1058 }
1059 parser::BinOp::Equal
1060 | parser::BinOp::NotEqual
1061 | parser::BinOp::Less
1062 | parser::BinOp::LessEqual
1063 | parser::BinOp::Greater
1064 | parser::BinOp::GreaterEqual => Type::Bool,
1065 parser::BinOp::AndAnd
1066 | parser::BinOp::OrOr
1067 | parser::BinOp::BitAnd
1068 | parser::BinOp::BitOr => Type::Bool,
1069 parser::BinOp::Colon => Type::tensor(),
1070 }
1071 }
1072 K::Tensor(rows) => {
1073 let r = rows.len();
1074 let c = rows.iter().map(|row| row.len()).max().unwrap_or(0);
1075 if r > 0 && rows.iter().all(|row| row.len() == c) {
1076 Type::tensor_with_shape(vec![r, c])
1077 } else {
1078 Type::tensor()
1079 }
1080 }
1081 K::Cell(rows) => {
1082 let mut elem_ty: Option<Type> = None;
1083 let mut len: usize = 0;
1084 for row in rows {
1085 for e in row {
1086 let t = infer_expr_type(e, env, returns);
1087 elem_ty = Some(match elem_ty {
1088 Some(curr) => curr.unify(&t),
1089 None => t,
1090 });
1091 len += 1;
1092 }
1093 }
1094 Type::Cell {
1095 element_type: elem_ty.map(Box::new),
1096 length: Some(len),
1097 }
1098 }
1099 K::Index(base, idxs) => {
1100 let bt = infer_expr_type(base, env, returns);
1101 let scalar_indices = idxs.iter().all(|i| {
1102 matches!(
1103 infer_expr_type(i, env, returns),
1104 Type::Int | Type::Num | Type::Bool
1105 )
1106 });
1107 if scalar_indices {
1108 Type::Num
1109 } else {
1110 bt
1111 }
1112 }
1113 K::IndexCell(base, idxs) => {
1114 let bt = infer_expr_type(base, env, returns);
1115 if let Type::Cell {
1116 element_type: Some(t),
1117 ..
1118 } = bt
1119 {
1120 let scalar = idxs.len() == 1
1121 && matches!(
1122 infer_expr_type(&idxs[0], env, returns),
1123 Type::Int | Type::Num | Type::Bool | Type::Tensor { .. }
1124 );
1125 if scalar {
1126 *t
1127 } else {
1128 Type::Unknown
1129 }
1130 } else {
1131 Type::Unknown
1132 }
1133 }
1134 K::Range(_, _, _) => Type::tensor(),
1135 K::FuncCall(name, _args) => returns
1136 .get(name)
1137 .and_then(|v| v.first())
1138 .cloned()
1139 .unwrap_or_else(|| {
1140 if let Some(b) = runmat_builtins::builtin_functions()
1141 .into_iter()
1142 .find(|b| b.name == *name)
1143 {
1144 b.return_type.clone()
1145 } else {
1146 Type::Unknown
1147 }
1148 }),
1149 K::MethodCall(_, _, _) => Type::Unknown,
1150 K::Member(_, _) => Type::Unknown,
1151 K::MemberDynamic(_, _) => Type::Unknown,
1152 K::AnonFunc { .. } => Type::Function {
1153 params: vec![Type::Unknown],
1154 returns: Box::new(Type::Unknown),
1155 },
1156 K::FuncHandle(_) => Type::Function {
1157 params: vec![Type::Unknown],
1158 returns: Box::new(Type::Unknown),
1159 },
1160 K::MetaClass(_) => Type::String,
1161 K::End => Type::Unknown,
1162 K::Colon => Type::tensor(),
1163 }
1164 }
1165
1166 fn join_env(a: &HashMap<VarId, Type>, b: &HashMap<VarId, Type>) -> HashMap<VarId, Type> {
1167 let mut out = a.clone();
1168 for (k, v) in b {
1169 out.entry(*k)
1170 .and_modify(|t| *t = t.unify(v))
1171 .or_insert_with(|| v.clone());
1172 }
1173 out
1174 }
1175
1176 #[derive(Clone)]
1177 struct Analysis {
1178 exits: Vec<HashMap<VarId, Type>>,
1179 fallthrough: Option<HashMap<VarId, Type>>,
1180 }
1181
1182 #[allow(clippy::type_complexity, clippy::only_used_in_recursion)]
1183 fn analyze_stmts(
1184 #[allow(clippy::only_used_in_recursion)] _outputs: &[VarId],
1185 stmts: &[HirStmt],
1186 mut env: HashMap<VarId, Type>,
1187 returns: &HashMap<String, Vec<Type>>,
1188 func_defs: &HashMap<String, (Vec<VarId>, Vec<VarId>, Vec<HirStmt>)>,
1189 ) -> Analysis {
1190 let mut exits = Vec::new();
1191 let mut i = 0usize;
1192 while i < stmts.len() {
1193 match &stmts[i] {
1194 HirStmt::Assign(var, expr, _) => {
1195 let t = infer_expr_type(expr, &env, returns);
1196 env.insert(*var, t);
1197 }
1198 HirStmt::MultiAssign(vars, expr, _) => {
1199 if let HirExprKind::FuncCall(ref name, _) = expr.kind {
1200 let mut per_out: Vec<Type> = returns.get(name).cloned().unwrap_or_default();
1202 let needs_fallback = per_out.is_empty()
1204 || per_out.iter().any(|t| matches!(t, Type::Unknown));
1205 if needs_fallback {
1206 if let Some((params, outs, body)) = func_defs.get(name).cloned() {
1207 let mut penv: HashMap<VarId, Type> = HashMap::new();
1209 for p in params {
1212 penv.insert(p, Type::Num);
1213 }
1214 let mut out_types: Vec<Type> = vec![Type::Unknown; outs.len()];
1216 for s in &body {
1217 if let HirStmt::Assign(var, rhs, _) = s {
1218 if let Some(pos) = outs.iter().position(|o| o == var) {
1219 let t = infer_expr_type(rhs, &penv, returns);
1220 out_types[pos] = out_types[pos].unify(&t);
1221 }
1222 }
1223 }
1224 if per_out.is_empty() {
1225 per_out = out_types;
1226 } else {
1227 for (i, t) in out_types.into_iter().enumerate() {
1228 if matches!(per_out.get(i), Some(Type::Unknown)) {
1229 if let Some(slot) = per_out.get_mut(i) {
1230 *slot = t;
1231 }
1232 }
1233 }
1234 }
1235 }
1236 }
1237 for (i, v) in vars.iter().enumerate() {
1238 if let Some(id) = v {
1239 env.insert(*id, per_out.get(i).cloned().unwrap_or(Type::Unknown));
1240 }
1241 }
1242 } else {
1243 let t = infer_expr_type(expr, &env, returns);
1244 for v in vars.iter().flatten() {
1245 env.insert(*v, t.clone());
1246 }
1247 }
1248 }
1249 HirStmt::ExprStmt(_, _) | HirStmt::Break | HirStmt::Continue => {}
1250 HirStmt::Return => {
1251 exits.push(env.clone());
1252 return Analysis {
1253 exits,
1254 fallthrough: None,
1255 };
1256 }
1257 HirStmt::If {
1258 cond,
1259 then_body,
1260 elseif_blocks,
1261 else_body,
1262 } => {
1263 fn trim_quotes(s: &str) -> String {
1265 let t = s.trim();
1266 t.trim_matches('\'').to_string()
1267 }
1268 fn extract_field_literal(e: &HirExpr) -> Option<String> {
1269 match &e.kind {
1270 HirExprKind::String(s) => Some(trim_quotes(s)),
1271 _ => None,
1272 }
1273 }
1274 fn extract_field_list(e: &HirExpr) -> Vec<String> {
1275 match &e.kind {
1276 HirExprKind::String(s) => vec![trim_quotes(s)],
1277 HirExprKind::Cell(rows) => {
1278 let mut out = Vec::new();
1279 for row in rows {
1280 for it in row {
1281 if let Some(v) = extract_field_literal(it) {
1282 out.push(v);
1283 }
1284 }
1285 }
1286 out
1287 }
1288 _ => Vec::new(),
1289 }
1290 }
1291 fn collect_assertions(e: &HirExpr, out: &mut Vec<(VarId, String)>) {
1292 use HirExprKind as K;
1293 match &e.kind {
1294 K::Unary(parser::UnOp::Not, _inner) => {}
1295 K::Binary(left, parser::BinOp::AndAnd, right)
1296 | K::Binary(left, parser::BinOp::BitAnd, right) => {
1297 collect_assertions(left, out);
1298 collect_assertions(right, out);
1299 }
1300 K::FuncCall(name, args) => {
1301 let lname = name.as_str();
1302 if lname.eq_ignore_ascii_case("isfield") && args.len() >= 2 {
1303 if let HirExprKind::Var(vid) = args[0].kind {
1304 if let Some(f) = extract_field_literal(&args[1]) {
1305 out.push((vid, f));
1306 }
1307 }
1308 }
1309 if lname.eq_ignore_ascii_case("ismember") && args.len() >= 2 {
1311 let mut fields: Vec<String> = Vec::new();
1312 let mut target: Option<VarId> = None;
1313 if let HirExprKind::FuncCall(ref n0, ref a0) = args[0].kind {
1315 if n0.eq_ignore_ascii_case("fieldnames") && a0.len() == 1 {
1316 if let HirExprKind::Var(vid) = a0[0].kind {
1317 target = Some(vid);
1318 }
1319 }
1320 }
1321 if let HirExprKind::FuncCall(ref n1, ref a1) = args[1].kind {
1322 if n1.eq_ignore_ascii_case("fieldnames") && a1.len() == 1 {
1323 if let HirExprKind::Var(vid) = a1[0].kind {
1324 target = Some(vid);
1325 }
1326 }
1327 }
1328 if fields.is_empty() {
1329 fields.extend(extract_field_list(&args[0]));
1330 }
1331 if fields.is_empty() {
1332 fields.extend(extract_field_list(&args[1]));
1333 }
1334 if let Some(vid) = target {
1335 for f in fields {
1336 out.push((vid, f));
1337 }
1338 }
1339 }
1340 if (lname.eq_ignore_ascii_case("any")
1342 || lname.eq_ignore_ascii_case("all"))
1343 && !args.is_empty()
1344 {
1345 collect_assertions(&args[0], out);
1346 }
1347 if (lname.eq_ignore_ascii_case("strcmp")
1348 || lname.eq_ignore_ascii_case("strcmpi"))
1349 && args.len() >= 2
1350 {
1351 let mut target: Option<VarId> = None;
1352 if let HirExprKind::FuncCall(ref n0, ref a0) = args[0].kind {
1353 if n0.eq_ignore_ascii_case("fieldnames") && a0.len() == 1 {
1354 if let HirExprKind::Var(vid) = a0[0].kind {
1355 target = Some(vid);
1356 }
1357 }
1358 }
1359 if let HirExprKind::FuncCall(ref n1, ref a1) = args[1].kind {
1360 if n1.eq_ignore_ascii_case("fieldnames") && a1.len() == 1 {
1361 if let HirExprKind::Var(vid) = a1[0].kind {
1362 target = Some(vid);
1363 }
1364 }
1365 }
1366 let mut fields = Vec::new();
1367 fields.extend(extract_field_list(&args[0]));
1368 fields.extend(extract_field_list(&args[1]));
1369 if let Some(vid) = target {
1370 for f in fields {
1371 out.push((vid, f));
1372 }
1373 }
1374 }
1375 }
1376 _ => {}
1377 }
1378 }
1379 let mut assertions: Vec<(VarId, String)> = Vec::new();
1380 collect_assertions(cond, &mut assertions);
1381 let mut then_env = env.clone();
1382 if !assertions.is_empty() {
1383 for (vid, field) in assertions {
1384 let mut known = match then_env.get(&vid) {
1385 Some(Type::Struct { known_fields }) => known_fields.clone(),
1386 _ => Some(Vec::new()),
1387 };
1388 if let Some(list) = &mut known {
1389 if !list.iter().any(|f| f == &field) {
1390 list.push(field);
1391 list.sort();
1392 list.dedup();
1393 }
1394 }
1395 then_env.insert(
1396 vid,
1397 Type::Struct {
1398 known_fields: known,
1399 },
1400 );
1401 }
1402 }
1403 let then_a = analyze_stmts(_outputs, then_body, then_env, returns, func_defs);
1404 let mut out_env = then_a.fallthrough.clone().unwrap_or_else(|| env.clone());
1405 let mut all_exits = then_a.exits.clone();
1406 for (c, b) in elseif_blocks {
1407 let mut elseif_env = env.clone();
1408 let mut els_assertions: Vec<(VarId, String)> = Vec::new();
1409 collect_assertions(c, &mut els_assertions);
1410 if !els_assertions.is_empty() {
1411 for (vid, field) in els_assertions {
1412 let mut known = match elseif_env.get(&vid) {
1413 Some(Type::Struct { known_fields }) => known_fields.clone(),
1414 _ => Some(Vec::new()),
1415 };
1416 if let Some(list) = &mut known {
1417 if !list.iter().any(|f| f == &field) {
1418 list.push(field);
1419 list.sort();
1420 list.dedup();
1421 }
1422 }
1423 elseif_env.insert(
1424 vid,
1425 Type::Struct {
1426 known_fields: known,
1427 },
1428 );
1429 }
1430 }
1431 let a = analyze_stmts(_outputs, b, elseif_env, returns, func_defs);
1432 if let Some(f) = a.fallthrough {
1433 out_env = join_env(&out_env, &f);
1434 }
1435 all_exits.extend(a.exits);
1436 }
1437 if let Some(else_body) = else_body {
1438 let a = analyze_stmts(_outputs, else_body, env.clone(), returns, func_defs);
1439 if let Some(f) = a.fallthrough {
1440 out_env = join_env(&out_env, &f);
1441 }
1442 all_exits.extend(a.exits);
1443 } else {
1444 out_env = join_env(&out_env, &env);
1445 }
1446 env = out_env;
1447 exits.extend(all_exits);
1448 }
1449 HirStmt::While { body, .. } => {
1450 let a = analyze_stmts(_outputs, body, env.clone(), returns, func_defs);
1451 if let Some(f) = a.fallthrough {
1452 env = join_env(&env, &f);
1453 }
1454 exits.extend(a.exits);
1455 }
1456 HirStmt::For { var, expr, body } => {
1457 let t = infer_expr_type(expr, &env, returns);
1458 env.insert(*var, t);
1459 let a = analyze_stmts(_outputs, body, env.clone(), returns, func_defs);
1460 if let Some(f) = a.fallthrough {
1461 env = join_env(&env, &f);
1462 }
1463 exits.extend(a.exits);
1464 }
1465 HirStmt::Switch {
1466 cases, otherwise, ..
1467 } => {
1468 let mut out_env: Option<HashMap<VarId, Type>> = None;
1469 for (_v, b) in cases {
1470 let a = analyze_stmts(_outputs, b, env.clone(), returns, func_defs);
1471 if let Some(f) = a.fallthrough {
1472 out_env = Some(match out_env {
1473 Some(curr) => join_env(&curr, &f),
1474 None => f,
1475 });
1476 }
1477 exits.extend(a.exits);
1478 }
1479 if let Some(otherwise) = otherwise {
1480 let a = analyze_stmts(_outputs, otherwise, env.clone(), returns, func_defs);
1481 if let Some(f) = a.fallthrough {
1482 out_env = Some(match out_env {
1483 Some(curr) => join_env(&curr, &f),
1484 None => f,
1485 });
1486 }
1487 exits.extend(a.exits);
1488 } else {
1489 out_env = Some(match out_env {
1490 Some(curr) => join_env(&curr, &env),
1491 None => env.clone(),
1492 });
1493 }
1494 if let Some(f) = out_env {
1495 env = f;
1496 }
1497 }
1498 HirStmt::TryCatch {
1499 try_body,
1500 catch_body,
1501 ..
1502 } => {
1503 let a_try = analyze_stmts(_outputs, try_body, env.clone(), returns, func_defs);
1504 let a_catch =
1505 analyze_stmts(_outputs, catch_body, env.clone(), returns, func_defs);
1506 let mut out_env = a_try.fallthrough.clone().unwrap_or_else(|| env.clone());
1507 if let Some(f) = a_catch.fallthrough {
1508 out_env = join_env(&out_env, &f);
1509 }
1510 env = out_env;
1511 exits.extend(a_try.exits);
1512 exits.extend(a_catch.exits);
1513 }
1514 HirStmt::Global(_) | HirStmt::Persistent(_) => {}
1515 HirStmt::Function { .. } => {}
1516 HirStmt::ClassDef { .. } => {}
1517 HirStmt::AssignLValue(_, expr, _) => {
1518 let _ = infer_expr_type(expr, &env, returns);
1519 }
1520 HirStmt::Import { .. } => {}
1521 }
1522 i += 1;
1523 }
1524 Analysis {
1525 exits,
1526 fallthrough: Some(env),
1527 }
1528 }
1529
1530 let mut out: HashMap<String, HashMap<VarId, Type>> = HashMap::new();
1531 for stmt in &prog.body {
1532 match stmt {
1533 HirStmt::Function { name, body, .. } => {
1534 let empty: &[VarId] = &[];
1535 let a = analyze_stmts(empty, body, HashMap::new(), &returns_map, &func_defs);
1536 let mut env = HashMap::new();
1537 for e in &a.exits {
1538 env = join_env(&env, e);
1539 }
1540 if let Some(f) = &a.fallthrough {
1541 env = join_env(&env, f);
1542 }
1543 out.insert(name.clone(), env);
1544 }
1545 HirStmt::ClassDef { members, .. } => {
1546 for m in members {
1547 if let HirClassMember::Methods { body, .. } = m {
1548 for s in body {
1549 if let HirStmt::Function { name, body, .. } = s {
1550 let empty: &[VarId] = &[];
1551 let a = analyze_stmts(
1552 empty,
1553 body,
1554 HashMap::new(),
1555 &returns_map,
1556 &func_defs,
1557 );
1558 let mut env = HashMap::new();
1559 for e in &a.exits {
1560 env = join_env(&env, e);
1561 }
1562 if let Some(f) = &a.fallthrough {
1563 env = join_env(&env, f);
1564 }
1565 out.insert(name.clone(), env);
1566 }
1567 }
1568 }
1569 }
1570 }
1571 _ => {}
1572 }
1573 }
1574 out
1575}
1576
1577pub fn collect_imports(prog: &HirProgram) -> Vec<(Vec<String>, bool)> {
1579 let mut imports = Vec::new();
1580 for stmt in &prog.body {
1581 if let HirStmt::Import { path, wildcard } = stmt {
1582 imports.push((path.clone(), *wildcard));
1583 }
1584 }
1585 imports
1586}
1587
/// A top-level `import` flattened into a form that is easy to compare.
///
/// Built by `normalize_imports` and consumed by `validate_imports`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NormalizedImport {
    /// The import's path segments joined with `.`.
    pub path: String,
    /// `true` when the import was a wildcard import (`path.*`).
    pub wildcard: bool,
    /// For a specific (non-wildcard) import, its last path segment, i.e. the
    /// name the import is checked against for ambiguity; `None` for wildcard
    /// imports.
    pub unqualified: Option<String>,
}
1598
1599pub fn normalize_imports(prog: &HirProgram) -> Vec<NormalizedImport> {
1601 let mut out = Vec::new();
1602 for stmt in &prog.body {
1603 if let HirStmt::Import { path, wildcard } = stmt {
1604 let path_str = path.join(".");
1606 let last = if *wildcard {
1607 None
1608 } else {
1609 path.last().cloned()
1610 };
1611 out.push(NormalizedImport {
1612 path: path_str,
1613 wildcard: *wildcard,
1614 unqualified: last,
1615 });
1616 }
1617 }
1618 out
1619}
1620
1621pub fn validate_imports(prog: &HirProgram) -> Result<(), String> {
1625 use std::collections::{HashMap, HashSet};
1626 let norms = normalize_imports(prog);
1627 let mut seen_exact: HashSet<(String, bool)> = HashSet::new();
1628 for n in &norms {
1629 if !seen_exact.insert((n.path.clone(), n.wildcard)) {
1630 return Err(format!(
1631 "duplicate import '{}{}'",
1632 n.path,
1633 if n.wildcard { ".*" } else { "" }
1634 ));
1635 }
1636 }
1637 let mut by_name: HashMap<String, Vec<String>> = HashMap::new();
1639 for n in &norms {
1640 if !n.wildcard {
1641 if let Some(uq) = &n.unqualified {
1642 by_name.entry(uq.clone()).or_default().push(n.path.clone());
1643 }
1644 }
1645 }
1646 for (uq, sources) in by_name {
1647 if sources.len() > 1 {
1648 return Err(format!(
1649 "ambiguous import for '{}': {}",
1650 uq,
1651 sources.join(", ")
1652 ));
1653 }
1654 }
1655 Ok(())
1656}
1657
1658pub fn validate_classdefs(prog: &HirProgram) -> Result<(), String> {
1662 use std::collections::HashSet;
1663 fn norm_attr_value(v: &str) -> String {
1664 let t = v.trim();
1665 let t = t.trim_matches('\'');
1666 t.to_ascii_lowercase()
1667 }
1668 fn validate_access_value(ctx: &str, v: &str) -> Result<(), String> {
1669 match v {
1670 "public" | "private" => Ok(()),
1671 other => Err(format!(
1672 "invalid access value '{other}' in {ctx} (allowed: public, private)",
1673 )),
1674 }
1675 }
1676 for stmt in &prog.body {
1677 if let HirStmt::ClassDef {
1678 name,
1679 super_class,
1680 members,
1681 } = stmt
1682 {
1683 if let Some(sup) = super_class {
1684 if sup == name {
1685 return Err(format!("Class '{name}' cannot inherit from itself"));
1686 }
1687 }
1688 let mut prop_names: HashSet<String> = HashSet::new();
1689 let mut method_names: HashSet<String> = HashSet::new();
1690 for m in members {
1691 match m {
1692 HirClassMember::Properties {
1693 names: props,
1694 attributes,
1695 } => {
1696 let mut has_static = false;
1698 let mut has_constant = false;
1699 let mut _has_transient = false;
1700 let mut _has_hidden = false;
1701 let mut has_dependent = false;
1702 let mut access_default: Option<String> = None;
1703 let mut get_access: Option<String> = None;
1704 let mut set_access: Option<String> = None;
1705 for a in attributes {
1706 if a.name.eq_ignore_ascii_case("Static") {
1707 has_static = true;
1708 continue;
1709 }
1710 if a.name.eq_ignore_ascii_case("Constant") {
1711 has_constant = true;
1712 continue;
1713 }
1714 if a.name.eq_ignore_ascii_case("Transient") {
1715 _has_transient = true;
1716 continue;
1717 }
1718 if a.name.eq_ignore_ascii_case("Hidden") {
1719 _has_hidden = true;
1720 continue;
1721 }
1722 if a.name.eq_ignore_ascii_case("Dependent") {
1723 has_dependent = true;
1724 continue;
1725 }
1726 if a.name.eq_ignore_ascii_case("Access") {
1727 let v = a.value.as_ref().ok_or_else(|| {
1728 format!(
1729 "Access requires value in class '{name}' properties block",
1730 )
1731 })?;
1732 let v = norm_attr_value(v);
1733 validate_access_value(&format!("class '{name}' properties"), &v)?;
1734 access_default = Some(v);
1735 continue;
1736 }
1737 if a.name.eq_ignore_ascii_case("GetAccess") {
1738 let v = a.value.as_ref().ok_or_else(|| {
1739 format!(
1740 "GetAccess requires value in class '{name}' properties block",
1741 )
1742 })?;
1743 let v = norm_attr_value(v);
1744 validate_access_value(&format!("class '{name}' properties"), &v)?;
1745 get_access = Some(v);
1746 continue;
1747 }
1748 if a.name.eq_ignore_ascii_case("SetAccess") {
1749 let v = a.value.as_ref().ok_or_else(|| {
1750 format!(
1751 "SetAccess requires value in class '{name}' properties block",
1752 )
1753 })?;
1754 let v = norm_attr_value(v);
1755 validate_access_value(&format!("class '{name}' properties"), &v)?;
1756 set_access = Some(v);
1757 continue;
1758 }
1759 }
1760 if has_static && has_dependent {
1761 return Err(format!("class '{name}' properties: attributes 'Static' and 'Dependent' cannot be combined"));
1762 }
1763 if has_constant && has_dependent {
1764 return Err(format!("class '{name}' properties: attributes 'Constant' and 'Dependent' cannot be combined"));
1765 }
1766 let _ = (access_default, get_access, set_access);
1768 for p in props {
1770 if !prop_names.insert(p.clone()) {
1771 return Err(format!("Duplicate property '{p}' in class {name}"));
1772 }
1773 if method_names.contains(p) {
1774 return Err(format!(
1775 "Name '{p}' used for both property and method in class {name}"
1776 ));
1777 }
1778 }
1779 }
1780 HirClassMember::Methods { body, attributes } => {
1781 let mut _has_static = false;
1783 let mut has_abstract = false;
1784 let mut has_sealed = false;
1785 let mut _has_hidden = false;
1786 for a in attributes {
1787 if a.name.eq_ignore_ascii_case("Static") {
1788 _has_static = true;
1789 continue;
1790 }
1791 if a.name.eq_ignore_ascii_case("Abstract") {
1792 has_abstract = true;
1793 continue;
1794 }
1795 if a.name.eq_ignore_ascii_case("Sealed") {
1796 has_sealed = true;
1797 continue;
1798 }
1799 if a.name.eq_ignore_ascii_case("Hidden") {
1800 _has_hidden = true;
1801 continue;
1802 }
1803 if a.name.eq_ignore_ascii_case("Access") {
1804 let v =
1805 a.value.as_ref().ok_or_else(|| {
1806 format!(
1807 "Access requires value in class '{name}' methods block",
1808 )
1809 })?;
1810 let v = norm_attr_value(v);
1811 validate_access_value(&format!("class '{name}' methods"), &v)?;
1812 }
1813 }
1814 if has_abstract && has_sealed {
1815 return Err(format!("class '{name}' methods: attributes 'Abstract' and 'Sealed' cannot be combined"));
1816 }
1817 for s in body {
1819 if let HirStmt::Function { name: fname, .. } = s {
1820 if !method_names.insert(fname.clone()) {
1821 return Err(format!(
1822 "Duplicate method '{fname}' in class {name}"
1823 ));
1824 }
1825 if prop_names.contains(fname) {
1826 return Err(format!("Name '{fname}' used for both property and method in class {name}"));
1827 }
1828 }
1829 }
1830 }
1831 HirClassMember::Events { attributes, names } => {
1832 for ev in names {
1834 if method_names.contains(ev) || prop_names.contains(ev) {
1835 return Err(format!("Name '{ev}' used for event conflicts with existing member in class {name}"));
1836 }
1837 }
1838 let mut seen = std::collections::HashSet::new();
1839 for ev in names {
1840 if !seen.insert(ev) {
1841 return Err(format!("Duplicate event '{ev}' in class {name}"));
1842 }
1843 }
1844 let _ = attributes; }
1846 HirClassMember::Enumeration { attributes, names } => {
1847 for en in names {
1849 if method_names.contains(en) || prop_names.contains(en) {
1850 return Err(format!("Name '{en}' used for enumeration conflicts with existing member in class {name}"));
1851 }
1852 }
1853 let mut seen = std::collections::HashSet::new();
1854 for en in names {
1855 if !seen.insert(en) {
1856 return Err(format!(
1857 "Duplicate enumeration '{en}' in class {name}"
1858 ));
1859 }
1860 }
1861 let _ = attributes;
1862 }
1863 HirClassMember::Arguments { attributes, names } => {
1864 for ar in names {
1866 if method_names.contains(ar) || prop_names.contains(ar) {
1867 return Err(format!("Name '{ar}' used for arguments conflicts with existing member in class {name}"));
1868 }
1869 }
1870 let _ = attributes;
1871 }
1872 }
1873 }
1874 }
1875 }
1876 Ok(())
1877}
1878
1879pub fn lower_with_context(
1881 prog: &AstProgram,
1882 existing_vars: &HashMap<String, usize>,
1883) -> Result<(HirProgram, HashMap<String, usize>), String> {
1884 let empty_functions = HashMap::new();
1885 let result = lower_with_full_context(prog, existing_vars, &empty_functions)?;
1886 Ok((result.hir, result.variables))
1887}
1888
1889pub fn lower_with_full_context(
1891 prog: &AstProgram,
1892 existing_vars: &HashMap<String, usize>,
1893 existing_functions: &HashMap<String, HirStmt>,
1894) -> Result<LoweringResult, String> {
1895 let mut ctx = Ctx::new();
1896
1897 for (name, var_id) in existing_vars {
1899 ctx.scopes[0].bindings.insert(name.clone(), VarId(*var_id));
1900 while ctx.var_types.len() <= *var_id {
1902 ctx.var_types.push(Type::Unknown);
1903 }
1904 while ctx.var_names.len() <= *var_id {
1905 ctx.var_names.push(None);
1906 }
1907 ctx.var_names[*var_id] = Some(name.clone());
1908 if *var_id >= ctx.next_var {
1910 ctx.next_var = var_id + 1;
1911 }
1912 }
1913
1914 for (name, func_stmt) in existing_functions {
1916 ctx.functions.insert(name.clone(), func_stmt.clone());
1917 }
1918
1919 let body = ctx.lower_stmts(&prog.body)?;
1920
1921 let mut all_vars = HashMap::new();
1923 for (name, var_id) in &ctx.scopes[0].bindings {
1924 all_vars.insert(name.clone(), var_id.0);
1925 }
1926
1927 Ok(LoweringResult {
1928 hir: HirProgram {
1929 body,
1930 var_types: ctx.var_types.clone(),
1931 },
1932 variables: all_vars,
1933 functions: ctx.functions,
1934 var_types: ctx.var_types,
1935 var_names: ctx
1936 .var_names
1937 .into_iter()
1938 .enumerate()
1939 .filter_map(|(idx, name)| name.map(|n| (VarId(idx), n)))
1940 .collect(),
1941 })
1942}
1943
1944pub mod remapping {
1947 use super::*;
1948 use std::collections::HashMap;
1949
1950 pub fn remap_function_body(body: &[HirStmt], var_map: &HashMap<VarId, VarId>) -> Vec<HirStmt> {
1952 body.iter().map(|stmt| remap_stmt(stmt, var_map)).collect()
1953 }
1954
1955 pub fn remap_stmt(stmt: &HirStmt, var_map: &HashMap<VarId, VarId>) -> HirStmt {
1957 match stmt {
1958 HirStmt::ExprStmt(expr, suppressed) => {
1959 HirStmt::ExprStmt(remap_expr(expr, var_map), *suppressed)
1960 }
1961 HirStmt::Assign(var_id, expr, suppressed) => {
1962 let new_var_id = var_map.get(var_id).copied().unwrap_or(*var_id);
1963 HirStmt::Assign(new_var_id, remap_expr(expr, var_map), *suppressed)
1964 }
1965 HirStmt::MultiAssign(var_ids, expr, suppressed) => {
1966 let mapped: Vec<Option<VarId>> = var_ids
1967 .iter()
1968 .map(|v| v.and_then(|vv| var_map.get(&vv).copied().or(Some(vv))))
1969 .collect();
1970 HirStmt::MultiAssign(mapped, remap_expr(expr, var_map), *suppressed)
1971 }
1972 HirStmt::AssignLValue(lv, expr, suppressed) => {
1973 let remapped_lv = match lv {
1974 super::HirLValue::Var(v) => {
1975 super::HirLValue::Var(var_map.get(v).copied().unwrap_or(*v))
1976 }
1977 super::HirLValue::Member(b, n) => {
1978 super::HirLValue::Member(Box::new(remap_expr(b, var_map)), n.clone())
1979 }
1980 super::HirLValue::MemberDynamic(b, n) => super::HirLValue::MemberDynamic(
1981 Box::new(remap_expr(b, var_map)),
1982 Box::new(remap_expr(n, var_map)),
1983 ),
1984 super::HirLValue::Index(b, idxs) => super::HirLValue::Index(
1985 Box::new(remap_expr(b, var_map)),
1986 idxs.iter().map(|e| remap_expr(e, var_map)).collect(),
1987 ),
1988 super::HirLValue::IndexCell(b, idxs) => super::HirLValue::IndexCell(
1989 Box::new(remap_expr(b, var_map)),
1990 idxs.iter().map(|e| remap_expr(e, var_map)).collect(),
1991 ),
1992 };
1993 HirStmt::AssignLValue(remapped_lv, remap_expr(expr, var_map), *suppressed)
1994 }
1995 HirStmt::If {
1996 cond,
1997 then_body,
1998 elseif_blocks,
1999 else_body,
2000 } => HirStmt::If {
2001 cond: remap_expr(cond, var_map),
2002 then_body: remap_function_body(then_body, var_map),
2003 elseif_blocks: elseif_blocks
2004 .iter()
2005 .map(|(c, b)| (remap_expr(c, var_map), remap_function_body(b, var_map)))
2006 .collect(),
2007 else_body: else_body.as_ref().map(|b| remap_function_body(b, var_map)),
2008 },
2009 HirStmt::While { cond, body } => HirStmt::While {
2010 cond: remap_expr(cond, var_map),
2011 body: remap_function_body(body, var_map),
2012 },
2013 HirStmt::For { var, expr, body } => {
2014 let new_var = var_map.get(var).copied().unwrap_or(*var);
2015 HirStmt::For {
2016 var: new_var,
2017 expr: remap_expr(expr, var_map),
2018 body: remap_function_body(body, var_map),
2019 }
2020 }
2021 HirStmt::Switch {
2022 expr,
2023 cases,
2024 otherwise,
2025 } => HirStmt::Switch {
2026 expr: remap_expr(expr, var_map),
2027 cases: cases
2028 .iter()
2029 .map(|(c, b)| (remap_expr(c, var_map), remap_function_body(b, var_map)))
2030 .collect(),
2031 otherwise: otherwise.as_ref().map(|b| remap_function_body(b, var_map)),
2032 },
2033 HirStmt::TryCatch {
2034 try_body,
2035 catch_var,
2036 catch_body,
2037 } => HirStmt::TryCatch {
2038 try_body: remap_function_body(try_body, var_map),
2039 catch_var: catch_var
2040 .as_ref()
2041 .map(|v| var_map.get(v).copied().unwrap_or(*v)),
2042 catch_body: remap_function_body(catch_body, var_map),
2043 },
2044 HirStmt::Global(vars) => HirStmt::Global(
2045 vars.iter()
2046 .map(|(v, name)| (var_map.get(v).copied().unwrap_or(*v), name.clone()))
2047 .collect(),
2048 ),
2049 HirStmt::Persistent(vars) => HirStmt::Persistent(
2050 vars.iter()
2051 .map(|(v, name)| (var_map.get(v).copied().unwrap_or(*v), name.clone()))
2052 .collect(),
2053 ),
2054 HirStmt::Break | HirStmt::Continue | HirStmt::Return => stmt.clone(),
2055 HirStmt::Function { .. } => stmt.clone(), HirStmt::ClassDef {
2057 name,
2058 super_class,
2059 members,
2060 } => HirStmt::ClassDef {
2061 name: name.clone(),
2062 super_class: super_class.clone(),
2063 members: members
2064 .iter()
2065 .map(|m| match m {
2066 HirClassMember::Properties { attributes, names } => {
2067 HirClassMember::Properties {
2068 attributes: attributes.clone(),
2069 names: names.clone(),
2070 }
2071 }
2072 HirClassMember::Events { attributes, names } => HirClassMember::Events {
2073 attributes: attributes.clone(),
2074 names: names.clone(),
2075 },
2076 HirClassMember::Enumeration { attributes, names } => {
2077 HirClassMember::Enumeration {
2078 attributes: attributes.clone(),
2079 names: names.clone(),
2080 }
2081 }
2082 HirClassMember::Arguments { attributes, names } => {
2083 HirClassMember::Arguments {
2084 attributes: attributes.clone(),
2085 names: names.clone(),
2086 }
2087 }
2088 HirClassMember::Methods { attributes, body } => HirClassMember::Methods {
2089 attributes: attributes.clone(),
2090 body: remap_function_body(body, var_map),
2091 },
2092 })
2093 .collect(),
2094 },
2095 HirStmt::Import { path, wildcard } => HirStmt::Import {
2096 path: path.clone(),
2097 wildcard: *wildcard,
2098 },
2099 }
2100 }
2101
2102 pub fn remap_expr(expr: &HirExpr, var_map: &HashMap<VarId, VarId>) -> HirExpr {
2104 let new_kind = match &expr.kind {
2105 HirExprKind::Var(var_id) => {
2106 let new_var_id = var_map.get(var_id).copied().unwrap_or(*var_id);
2107 HirExprKind::Var(new_var_id)
2108 }
2109 HirExprKind::Unary(op, e) => HirExprKind::Unary(*op, Box::new(remap_expr(e, var_map))),
2110 HirExprKind::Binary(left, op, right) => HirExprKind::Binary(
2111 Box::new(remap_expr(left, var_map)),
2112 *op,
2113 Box::new(remap_expr(right, var_map)),
2114 ),
2115 HirExprKind::Tensor(rows) => HirExprKind::Tensor(
2116 rows.iter()
2117 .map(|row| row.iter().map(|e| remap_expr(e, var_map)).collect())
2118 .collect(),
2119 ),
2120 HirExprKind::Cell(rows) => HirExprKind::Cell(
2121 rows.iter()
2122 .map(|row| row.iter().map(|e| remap_expr(e, var_map)).collect())
2123 .collect(),
2124 ),
2125 HirExprKind::Index(base, indices) => HirExprKind::Index(
2126 Box::new(remap_expr(base, var_map)),
2127 indices.iter().map(|i| remap_expr(i, var_map)).collect(),
2128 ),
2129 HirExprKind::IndexCell(base, indices) => HirExprKind::IndexCell(
2130 Box::new(remap_expr(base, var_map)),
2131 indices.iter().map(|i| remap_expr(i, var_map)).collect(),
2132 ),
2133 HirExprKind::Range(start, step, end) => HirExprKind::Range(
2134 Box::new(remap_expr(start, var_map)),
2135 step.as_ref().map(|s| Box::new(remap_expr(s, var_map))),
2136 Box::new(remap_expr(end, var_map)),
2137 ),
2138 HirExprKind::Member(base, name) => {
2139 HirExprKind::Member(Box::new(remap_expr(base, var_map)), name.clone())
2140 }
2141 HirExprKind::MemberDynamic(base, name) => HirExprKind::MemberDynamic(
2142 Box::new(remap_expr(base, var_map)),
2143 Box::new(remap_expr(name, var_map)),
2144 ),
2145 HirExprKind::MethodCall(base, name, args) => HirExprKind::MethodCall(
2146 Box::new(remap_expr(base, var_map)),
2147 name.clone(),
2148 args.iter().map(|a| remap_expr(a, var_map)).collect(),
2149 ),
2150 HirExprKind::AnonFunc { params, body } => HirExprKind::AnonFunc {
2151 params: params.clone(),
2152 body: Box::new(remap_expr(body, var_map)),
2153 },
2154 HirExprKind::FuncHandle(name) => HirExprKind::FuncHandle(name.clone()),
2155 HirExprKind::FuncCall(name, args) => HirExprKind::FuncCall(
2156 name.clone(),
2157 args.iter().map(|a| remap_expr(a, var_map)).collect(),
2158 ),
2159 HirExprKind::Number(_)
2160 | HirExprKind::String(_)
2161 | HirExprKind::Constant(_)
2162 | HirExprKind::Colon
2163 | HirExprKind::End
2164 | HirExprKind::MetaClass(_) => expr.kind.clone(),
2165 };
2166 HirExpr {
2167 kind: new_kind,
2168 ty: expr.ty.clone(),
2169 }
2170 }
2171
2172 pub fn collect_function_variables(body: &[HirStmt]) -> std::collections::HashSet<VarId> {
2174 let mut vars = std::collections::HashSet::new();
2175
2176 for stmt in body {
2177 collect_stmt_variables(stmt, &mut vars);
2178 }
2179
2180 vars
2181 }
2182
2183 fn collect_stmt_variables(stmt: &HirStmt, vars: &mut std::collections::HashSet<VarId>) {
2184 match stmt {
2185 HirStmt::ExprStmt(expr, _) => collect_expr_variables(expr, vars),
2186 HirStmt::Assign(var_id, expr, _) => {
2187 vars.insert(*var_id);
2188 collect_expr_variables(expr, vars);
2189 }
2190 HirStmt::MultiAssign(var_ids, expr, _) => {
2191 for v in var_ids.iter().flatten() {
2192 vars.insert(*v);
2193 }
2194 collect_expr_variables(expr, vars);
2195 }
2196 HirStmt::If {
2197 cond,
2198 then_body,
2199 elseif_blocks,
2200 else_body,
2201 } => {
2202 collect_expr_variables(cond, vars);
2203 for stmt in then_body {
2204 collect_stmt_variables(stmt, vars);
2205 }
2206 for (cond_expr, body) in elseif_blocks {
2207 collect_expr_variables(cond_expr, vars);
2208 for stmt in body {
2209 collect_stmt_variables(stmt, vars);
2210 }
2211 }
2212 if let Some(body) = else_body {
2213 for stmt in body {
2214 collect_stmt_variables(stmt, vars);
2215 }
2216 }
2217 }
2218 HirStmt::While { cond, body } => {
2219 collect_expr_variables(cond, vars);
2220 for stmt in body {
2221 collect_stmt_variables(stmt, vars);
2222 }
2223 }
2224 HirStmt::For { var, expr, body } => {
2225 vars.insert(*var);
2226 collect_expr_variables(expr, vars);
2227 for stmt in body {
2228 collect_stmt_variables(stmt, vars);
2229 }
2230 }
2231 HirStmt::Switch {
2232 expr,
2233 cases,
2234 otherwise,
2235 } => {
2236 collect_expr_variables(expr, vars);
2237 for (v, b) in cases {
2238 collect_expr_variables(v, vars);
2239 for s in b {
2240 collect_stmt_variables(s, vars);
2241 }
2242 }
2243 if let Some(b) = otherwise {
2244 for s in b {
2245 collect_stmt_variables(s, vars);
2246 }
2247 }
2248 }
2249 HirStmt::TryCatch {
2250 try_body,
2251 catch_var,
2252 catch_body,
2253 } => {
2254 if let Some(v) = catch_var {
2255 vars.insert(*v);
2256 }
2257 for s in try_body {
2258 collect_stmt_variables(s, vars);
2259 }
2260 for s in catch_body {
2261 collect_stmt_variables(s, vars);
2262 }
2263 }
2264 HirStmt::Global(vs) | HirStmt::Persistent(vs) => {
2265 for (v, _name) in vs {
2266 vars.insert(*v);
2267 }
2268 }
2269 HirStmt::AssignLValue(lv, expr, _) => {
2270 match lv {
2271 HirLValue::Var(v) => {
2272 vars.insert(*v);
2273 }
2274 HirLValue::Member(base, _) => collect_expr_variables(base, vars),
2275 HirLValue::MemberDynamic(base, name) => {
2276 collect_expr_variables(base, vars);
2277 collect_expr_variables(name, vars);
2278 }
2279 HirLValue::Index(base, idxs) | HirLValue::IndexCell(base, idxs) => {
2280 collect_expr_variables(base, vars);
2281 for i in idxs {
2282 collect_expr_variables(i, vars);
2283 }
2284 }
2285 }
2286 collect_expr_variables(expr, vars);
2287 }
2288 HirStmt::Break | HirStmt::Continue | HirStmt::Return => {}
2289 HirStmt::Function { .. } => {} HirStmt::ClassDef { .. } => {}
2291 HirStmt::Import { .. } => {}
2292 }
2293 }
2294
2295 fn collect_expr_variables(expr: &HirExpr, vars: &mut std::collections::HashSet<VarId>) {
2296 match &expr.kind {
2297 HirExprKind::Var(var_id) => {
2298 vars.insert(*var_id);
2299 }
2300 HirExprKind::Unary(_, e) => collect_expr_variables(e, vars),
2301 HirExprKind::Binary(left, _, right) => {
2302 collect_expr_variables(left, vars);
2303 collect_expr_variables(right, vars);
2304 }
2305 HirExprKind::Tensor(rows) => {
2306 for row in rows {
2307 for e in row {
2308 collect_expr_variables(e, vars);
2309 }
2310 }
2311 }
2312 HirExprKind::Cell(rows) => {
2313 for row in rows {
2314 for e in row {
2315 collect_expr_variables(e, vars);
2316 }
2317 }
2318 }
2319 HirExprKind::Index(base, indices) => {
2320 collect_expr_variables(base, vars);
2321 for idx in indices {
2322 collect_expr_variables(idx, vars);
2323 }
2324 }
2325 HirExprKind::IndexCell(base, indices) => {
2326 collect_expr_variables(base, vars);
2327 for idx in indices {
2328 collect_expr_variables(idx, vars);
2329 }
2330 }
2331 HirExprKind::Range(start, step, end) => {
2332 collect_expr_variables(start, vars);
2333 if let Some(step_expr) = step {
2334 collect_expr_variables(step_expr, vars);
2335 }
2336 collect_expr_variables(end, vars);
2337 }
2338 HirExprKind::Member(base, _) => collect_expr_variables(base, vars),
2339 HirExprKind::MemberDynamic(base, name) => {
2340 collect_expr_variables(base, vars);
2341 collect_expr_variables(name, vars);
2342 }
2343 HirExprKind::MethodCall(base, _, args) => {
2344 collect_expr_variables(base, vars);
2345 for a in args {
2346 collect_expr_variables(a, vars);
2347 }
2348 }
2349 HirExprKind::AnonFunc { body, .. } => collect_expr_variables(body, vars),
2350 HirExprKind::FuncHandle(_) => {}
2351 HirExprKind::FuncCall(_, args) => {
2352 for arg in args {
2353 collect_expr_variables(arg, vars);
2354 }
2355 }
2356 HirExprKind::Number(_)
2357 | HirExprKind::String(_)
2358 | HirExprKind::Constant(_)
2359 | HirExprKind::Colon
2360 | HirExprKind::End
2361 | HirExprKind::MetaClass(_) => {}
2362 }
2363 }
2364
2365 pub fn create_function_var_map(params: &[VarId], outputs: &[VarId]) -> HashMap<VarId, VarId> {
2368 let mut var_map = HashMap::new();
2369 let mut local_var_index = 0;
2370
2371 for param_id in params {
2373 var_map.insert(*param_id, VarId(local_var_index));
2374 local_var_index += 1;
2375 }
2376
2377 for output_id in outputs {
2379 if !var_map.contains_key(output_id) {
2380 var_map.insert(*output_id, VarId(local_var_index));
2381 local_var_index += 1;
2382 }
2383 }
2384
2385 var_map
2386 }
2387
2388 pub fn create_complete_function_var_map(
2390 params: &[VarId],
2391 outputs: &[VarId],
2392 body: &[HirStmt],
2393 ) -> HashMap<VarId, VarId> {
2394 let mut var_map = HashMap::new();
2395 let mut local_var_index = 0;
2396
2397 let all_vars = collect_function_variables(body);
2399
2400 for param_id in params {
2402 var_map.insert(*param_id, VarId(local_var_index));
2403 local_var_index += 1;
2404 }
2405
2406 for output_id in outputs {
2408 if !var_map.contains_key(output_id) {
2409 var_map.insert(*output_id, VarId(local_var_index));
2410 local_var_index += 1;
2411 }
2412 }
2413
2414 for var_id in &all_vars {
2416 if !var_map.contains_key(var_id) {
2417 var_map.insert(*var_id, VarId(local_var_index));
2418 local_var_index += 1;
2419 }
2420 }
2421
2422 var_map
2423 }
2424}
2425
2426struct Scope {
2427 parent: Option<usize>,
2428 bindings: HashMap<String, VarId>,
2429}
2430
2431struct Ctx {
2432 scopes: Vec<Scope>,
2433 var_types: Vec<Type>,
2434 next_var: usize,
2435 functions: HashMap<String, HirStmt>, var_names: Vec<Option<String>>,
2437}
2438
2439impl Ctx {
2440 fn new() -> Self {
2441 Self {
2442 scopes: vec![Scope {
2443 parent: None,
2444 bindings: HashMap::new(),
2445 }],
2446 var_types: Vec::new(),
2447 next_var: 0,
2448 functions: HashMap::new(),
2449 var_names: Vec::new(),
2450 }
2451 }
2452
2453 fn push_scope(&mut self) -> usize {
2454 let parent = Some(self.scopes.len() - 1);
2455 self.scopes.push(Scope {
2456 parent,
2457 bindings: HashMap::new(),
2458 });
2459 self.scopes.len() - 1
2460 }
2461
2462 fn pop_scope(&mut self) {
2463 self.scopes.pop();
2464 }
2465
2466 fn define(&mut self, name: String) -> VarId {
2467 let id = VarId(self.next_var);
2468 self.next_var += 1;
2469 let current = self.scopes.len() - 1;
2470 self.scopes[current].bindings.insert(name.clone(), id);
2471 self.var_types.push(Type::Unknown);
2472 self.var_names.push(Some(name));
2473 id
2474 }
2475
2476 fn lookup(&self, name: &str) -> Option<VarId> {
2477 let mut scope_idx = Some(self.scopes.len() - 1);
2478 while let Some(idx) = scope_idx {
2479 if let Some(id) = self.scopes[idx].bindings.get(name) {
2480 return Some(*id);
2481 }
2482 scope_idx = self.scopes[idx].parent;
2483 }
2484 None
2485 }
2486
2487 fn is_constant(&self, name: &str) -> bool {
2488 runmat_builtins::constants().iter().any(|c| c.name == name)
2490 }
2491
2492 fn is_builtin_function(&self, name: &str) -> bool {
2493 runmat_builtins::builtin_functions()
2495 .iter()
2496 .any(|b| b.name == name)
2497 }
2498
2499 fn is_user_defined_function(&self, name: &str) -> bool {
2500 self.functions.contains_key(name)
2501 }
2502
2503 fn is_function(&self, name: &str) -> bool {
2504 self.is_user_defined_function(name) || self.is_builtin_function(name)
2505 }
2506
2507 fn lower_stmts(&mut self, stmts: &[AstStmt]) -> Result<Vec<HirStmt>, String> {
2508 stmts.iter().map(|s| self.lower_stmt(s)).collect()
2509 }
2510
2511 fn lower_stmt(&mut self, stmt: &AstStmt) -> Result<HirStmt, String> {
2512 match stmt {
2513 AstStmt::ExprStmt(e, semicolon_terminated) => Ok(HirStmt::ExprStmt(
2514 self.lower_expr(e)?,
2515 *semicolon_terminated,
2516 )),
2517 AstStmt::Assign(name, expr, semicolon_terminated) => {
2518 let id = match self.lookup(name) {
2519 Some(id) => id,
2520 None => self.define(name.clone()),
2521 };
2522 let value = self.lower_expr(expr)?;
2523 if id.0 < self.var_types.len() {
2524 self.var_types[id.0] = value.ty.clone();
2525 }
2526 Ok(HirStmt::Assign(id, value, *semicolon_terminated))
2527 }
2528 AstStmt::MultiAssign(names, expr, semicolon_terminated) => {
2529 let ids: Vec<Option<VarId>> = names
2530 .iter()
2531 .map(|n| {
2532 if n == "~" {
2533 None
2534 } else {
2535 Some(match self.lookup(n) {
2536 Some(id) => id,
2537 None => self.define(n.clone()),
2538 })
2539 }
2540 })
2541 .collect();
2542 let value = self.lower_expr(expr)?;
2543 Ok(HirStmt::MultiAssign(ids, value, *semicolon_terminated))
2544 }
2545 AstStmt::If {
2546 cond,
2547 then_body,
2548 elseif_blocks,
2549 else_body,
2550 } => {
2551 let cond = self.lower_expr(cond)?;
2552 let then_body = self.lower_stmts(then_body)?;
2553 let mut elseif_vec = Vec::new();
2554 for (c, b) in elseif_blocks {
2555 elseif_vec.push((self.lower_expr(c)?, self.lower_stmts(b)?));
2556 }
2557 let else_body = match else_body {
2558 Some(b) => Some(self.lower_stmts(b)?),
2559 None => None,
2560 };
2561 Ok(HirStmt::If {
2562 cond,
2563 then_body,
2564 elseif_blocks: elseif_vec,
2565 else_body,
2566 })
2567 }
2568 AstStmt::While { cond, body } => Ok(HirStmt::While {
2569 cond: self.lower_expr(cond)?,
2570 body: self.lower_stmts(body)?,
2571 }),
2572 AstStmt::For { var, expr, body } => {
2573 let id = match self.lookup(var) {
2574 Some(id) => id,
2575 None => self.define(var.clone()),
2576 };
2577 let expr = self.lower_expr(expr)?;
2578 let body = self.lower_stmts(body)?;
2579 Ok(HirStmt::For {
2580 var: id,
2581 expr,
2582 body,
2583 })
2584 }
2585 AstStmt::Switch {
2586 expr,
2587 cases,
2588 otherwise,
2589 } => {
2590 let control = self.lower_expr(expr)?;
2591 let mut cases_hir: Vec<(HirExpr, Vec<HirStmt>)> = Vec::new();
2592 for (v, b) in cases {
2593 let ve = self.lower_expr(v)?;
2594 let vb = self.lower_stmts(b)?;
2595 cases_hir.push((ve, vb));
2596 }
2597 let otherwise_hir = otherwise
2598 .as_ref()
2599 .map(|b| self.lower_stmts(b))
2600 .transpose()?;
2601 Ok(HirStmt::Switch {
2602 expr: control,
2603 cases: cases_hir,
2604 otherwise: otherwise_hir,
2605 })
2606 }
2607 AstStmt::TryCatch {
2608 try_body,
2609 catch_var,
2610 catch_body,
2611 } => {
2612 let try_hir = self.lower_stmts(try_body)?;
2613 let catch_var_id = catch_var.as_ref().map(|name| match self.lookup(name) {
2614 Some(id) => id,
2615 None => self.define(name.clone()),
2616 });
2617 let catch_hir = self.lower_stmts(catch_body)?;
2618 Ok(HirStmt::TryCatch {
2619 try_body: try_hir,
2620 catch_var: catch_var_id,
2621 catch_body: catch_hir,
2622 })
2623 }
2624 AstStmt::Global(names) => {
2625 let pairs: Vec<(VarId, String)> = names
2626 .iter()
2627 .map(|n| {
2628 let id = match self.lookup(n) {
2629 Some(id) => id,
2630 None => self.define(n.clone()),
2631 };
2632 (id, n.clone())
2633 })
2634 .collect();
2635 Ok(HirStmt::Global(pairs))
2636 }
2637 AstStmt::Persistent(names) => {
2638 let pairs: Vec<(VarId, String)> = names
2639 .iter()
2640 .map(|n| {
2641 let id = match self.lookup(n) {
2642 Some(id) => id,
2643 None => self.define(n.clone()),
2644 };
2645 (id, n.clone())
2646 })
2647 .collect();
2648 Ok(HirStmt::Persistent(pairs))
2649 }
2650 AstStmt::Break => Ok(HirStmt::Break),
2651 AstStmt::Continue => Ok(HirStmt::Continue),
2652 AstStmt::Return => Ok(HirStmt::Return),
2653 AstStmt::Function {
2654 name,
2655 params,
2656 outputs,
2657 body,
2658 } => {
2659 self.push_scope();
2660 let param_ids: Vec<VarId> = params.iter().map(|p| self.define(p.clone())).collect();
2661 let output_ids: Vec<VarId> =
2662 outputs.iter().map(|o| self.define(o.clone())).collect();
2663 let body_hir = self.lower_stmts(body)?;
2664 self.pop_scope();
2665
2666 let has_varargin = params
2667 .last()
2668 .map(|s| s.as_str() == "varargin")
2669 .unwrap_or(false);
2670 let has_varargout = outputs
2671 .last()
2672 .map(|s| s.as_str() == "varargout")
2673 .unwrap_or(false);
2674
2675 let func_stmt = HirStmt::Function {
2676 name: name.clone(),
2677 params: param_ids,
2678 outputs: output_ids,
2679 body: body_hir,
2680 has_varargin,
2681 has_varargout,
2682 };
2683
2684 self.functions.insert(name.clone(), func_stmt.clone());
2686
2687 Ok(func_stmt)
2688 }
2689 AstStmt::ClassDef {
2690 name,
2691 super_class,
2692 members,
2693 } => {
2694 let members_hir = members
2696 .iter()
2697 .map(|m| match m {
2698 parser::ClassMember::Properties { attributes, names } => {
2699 HirClassMember::Properties {
2700 attributes: attributes.clone(),
2701 names: names.clone(),
2702 }
2703 }
2704 parser::ClassMember::Events { attributes, names } => {
2705 HirClassMember::Events {
2706 attributes: attributes.clone(),
2707 names: names.clone(),
2708 }
2709 }
2710 parser::ClassMember::Enumeration { attributes, names } => {
2711 HirClassMember::Enumeration {
2712 attributes: attributes.clone(),
2713 names: names.clone(),
2714 }
2715 }
2716 parser::ClassMember::Arguments { attributes, names } => {
2717 HirClassMember::Arguments {
2718 attributes: attributes.clone(),
2719 names: names.clone(),
2720 }
2721 }
2722 parser::ClassMember::Methods { attributes, body } => {
2723 match self.lower_stmts(body) {
2724 Ok(s) => HirClassMember::Methods {
2725 attributes: attributes.clone(),
2726 body: s,
2727 },
2728 Err(_) => HirClassMember::Methods {
2729 attributes: attributes.clone(),
2730 body: Vec::new(),
2731 },
2732 }
2733 }
2734 })
2735 .collect();
2736 Ok(HirStmt::ClassDef {
2737 name: name.clone(),
2738 super_class: super_class.clone(),
2739 members: members_hir,
2740 })
2741 }
2742 AstStmt::AssignLValue(lv, rhs, suppressed) => {
2743 let hir_lv = self.lower_lvalue(lv)?;
2745 let value = self.lower_expr(rhs)?;
2746 if let HirLValue::Var(var_id) = hir_lv {
2748 if var_id.0 < self.var_types.len() {
2749 self.var_types[var_id.0] = value.ty.clone();
2750 }
2751 return Ok(HirStmt::Assign(var_id, value, *suppressed));
2752 }
2753 Ok(HirStmt::AssignLValue(hir_lv, value, *suppressed))
2754 }
2755 AstStmt::Import { .. } => {
2756 if let AstStmt::Import { path, wildcard } = stmt {
2758 Ok(HirStmt::Import {
2759 path: path.clone(),
2760 wildcard: *wildcard,
2761 })
2762 } else {
2763 unreachable!()
2764 }
2765 }
2766 }
2767 }
2768
2769 fn lower_expr(&mut self, expr: &AstExpr) -> Result<HirExpr, String> {
2770 use parser::Expr::*;
2771 let (kind, ty) = match expr {
2772 Number(n) => (HirExprKind::Number(n.clone()), Type::Num),
2773 String(s) => (HirExprKind::String(s.clone()), Type::String),
2774 Ident(name) => {
2775 if let Some(id) = self.lookup(name) {
2777 let ty = if id.0 < self.var_types.len() {
2778 self.var_types[id.0].clone()
2779 } else {
2780 Type::Unknown
2781 };
2782 (HirExprKind::Var(id), ty)
2783 } else if self.is_constant(name) {
2784 (HirExprKind::Constant(name.clone()), Type::Num)
2785 } else if self.is_function(name) {
2786 let return_type = self.infer_function_return_type(name, &[]);
2788 (HirExprKind::FuncCall(name.clone(), vec![]), return_type)
2789 } else {
2790 return Err(format!(
2791 "{}: Undefined variable: {name}",
2792 "MATLAB:UndefinedVariable"
2793 ));
2794 }
2795 }
2796 Unary(op, e) => {
2797 let inner = self.lower_expr(e)?;
2798 let ty = inner.ty.clone();
2799 (HirExprKind::Unary(*op, Box::new(inner)), ty)
2800 }
2801 Binary(a, op, b) => {
2802 let left = self.lower_expr(a)?;
2803 let left_ty = left.ty.clone();
2804 let right = self.lower_expr(b)?;
2805 let right_ty = right.ty.clone();
2806 let ty = match op {
2807 BinOp::Add
2808 | BinOp::Sub
2809 | BinOp::Mul
2810 | BinOp::Div
2811 | BinOp::Pow
2812 | BinOp::LeftDiv => {
2813 if matches!(left_ty, Type::Tensor { .. })
2814 || matches!(right_ty, Type::Tensor { .. })
2815 {
2816 Type::tensor()
2817 } else {
2818 Type::Num
2819 }
2820 }
2821 BinOp::ElemMul | BinOp::ElemDiv | BinOp::ElemPow | BinOp::ElemLeftDiv => {
2823 if matches!(left_ty, Type::Tensor { .. })
2824 || matches!(right_ty, Type::Tensor { .. })
2825 {
2826 Type::tensor()
2827 } else {
2828 Type::Num
2829 }
2830 }
2831 BinOp::Equal
2833 | BinOp::NotEqual
2834 | BinOp::Less
2835 | BinOp::LessEqual
2836 | BinOp::Greater
2837 | BinOp::GreaterEqual => Type::Bool,
2838 BinOp::AndAnd | BinOp::OrOr | BinOp::BitAnd | BinOp::BitOr => Type::Bool,
2840 BinOp::Colon => Type::tensor(),
2841 };
2842 (
2843 HirExprKind::Binary(Box::new(left), *op, Box::new(right)),
2844 ty,
2845 )
2846 }
2847 AnonFunc { params, body } => {
2848 let saved_len = self.scopes.len();
2850 self.push_scope();
2851 let mut param_ids: Vec<VarId> = Vec::with_capacity(params.len());
2852 for p in params {
2853 param_ids.push(self.define(p.clone()));
2854 }
2855 let lowered_body = self.lower_expr(body)?;
2856 while self.scopes.len() > saved_len {
2858 self.pop_scope();
2859 }
2860 (
2861 HirExprKind::AnonFunc {
2862 params: param_ids,
2863 body: Box::new(lowered_body),
2864 },
2865 Type::Unknown,
2866 )
2867 }
2868 FuncHandle(name) => (HirExprKind::FuncHandle(name.clone()), Type::Unknown),
2869 FuncCall(name, args) => {
2870 let arg_exprs: Result<Vec<_>, _> =
2871 args.iter().map(|a| self.lower_expr(a)).collect();
2872 let arg_exprs = arg_exprs?;
2873
2874 if let Some(var_id) = self.lookup(name) {
2877 let var_ty = if var_id.0 < self.var_types.len() {
2879 self.var_types[var_id.0].clone()
2880 } else {
2881 Type::Unknown
2882 };
2883 let var_expr = HirExpr {
2884 kind: HirExprKind::Var(var_id),
2885 ty: var_ty,
2886 };
2887 let index_result_type = Type::Num; (
2890 HirExprKind::Index(Box::new(var_expr), arg_exprs),
2891 index_result_type,
2892 )
2893 } else {
2894 let return_type = self.infer_function_return_type(name, &arg_exprs);
2896 (HirExprKind::FuncCall(name.clone(), arg_exprs), return_type)
2897 }
2898 }
2899 Tensor(rows) => {
2900 let mut hir_rows = Vec::new();
2901 for row in rows {
2902 let mut hir_row = Vec::new();
2903 for expr in row {
2904 hir_row.push(self.lower_expr(expr)?);
2905 }
2906 hir_rows.push(hir_row);
2907 }
2908 (HirExprKind::Tensor(hir_rows), Type::tensor())
2909 }
2910 Cell(rows) => {
2911 let mut hir_rows = Vec::new();
2912 for row in rows {
2913 let mut hir_row = Vec::new();
2914 for expr in row {
2915 hir_row.push(self.lower_expr(expr)?);
2916 }
2917 hir_rows.push(hir_row);
2918 }
2919 (HirExprKind::Cell(hir_rows), Type::Unknown)
2920 }
2921 Index(expr, indices) => {
2922 let base = self.lower_expr(expr)?;
2923 let idx_exprs: Result<Vec<_>, _> =
2924 indices.iter().map(|i| self.lower_expr(i)).collect();
2925 let idx_exprs = idx_exprs?;
2926 let ty = base.ty.clone(); (HirExprKind::Index(Box::new(base), idx_exprs), ty)
2928 }
2929 IndexCell(expr, indices) => {
2930 let base = self.lower_expr(expr)?;
2931 let idx_exprs: Result<Vec<_>, _> =
2932 indices.iter().map(|i| self.lower_expr(i)).collect();
2933 let idx_exprs = idx_exprs?;
2934 (
2935 HirExprKind::IndexCell(Box::new(base), idx_exprs),
2936 Type::Unknown,
2937 )
2938 }
2939 Range(start, step, end) => {
2940 let start_hir = self.lower_expr(start)?;
2941 let end_hir = self.lower_expr(end)?;
2942 let step_hir = step.as_ref().map(|s| self.lower_expr(s)).transpose()?;
2943 (
2944 HirExprKind::Range(
2945 Box::new(start_hir),
2946 step_hir.map(Box::new),
2947 Box::new(end_hir),
2948 ),
2949 Type::tensor(),
2950 )
2951 }
2952 Colon => (HirExprKind::Colon, Type::tensor()),
2953 EndKeyword => (HirExprKind::End, Type::Unknown),
2954 Member(base, name) => {
2955 let b = self.lower_expr(base)?;
2956 (
2957 HirExprKind::Member(Box::new(b), name.clone()),
2958 Type::Unknown,
2959 )
2960 }
2961 MemberDynamic(base, name_expr) => {
2962 let b = self.lower_expr(base)?;
2963 let n = self.lower_expr(name_expr)?;
2964 (
2965 HirExprKind::MemberDynamic(Box::new(b), Box::new(n)),
2966 Type::Unknown,
2967 )
2968 }
2969 MethodCall(base, name, args) => {
2970 let b = self.lower_expr(base)?;
2971 let lowered_args: Result<Vec<_>, _> =
2972 args.iter().map(|a| self.lower_expr(a)).collect();
2973 (
2974 HirExprKind::MethodCall(Box::new(b), name.clone(), lowered_args?),
2975 Type::Unknown,
2976 )
2977 }
2978 MetaClass(name) => (HirExprKind::MetaClass(name.clone()), Type::String),
2979 };
2980 Ok(HirExpr { kind, ty })
2981 }
2982
2983 fn lower_lvalue(&mut self, lv: &parser::LValue) -> Result<HirLValue, String> {
2984 use parser::LValue as ALV;
2985 Ok(match lv {
2986 ALV::Var(name) => {
2987 let id = match self.lookup(name) {
2988 Some(id) => id,
2989 None => self.define(name.clone()),
2990 };
2991 HirLValue::Var(id)
2992 }
2993 ALV::Member(base, name) => {
2994 if let parser::Expr::Ident(var_name) = &**base {
2996 let id = match self.lookup(var_name) {
2997 Some(id) => id,
2998 None => self.define(var_name.clone()),
2999 };
3000 let ty = if id.0 < self.var_types.len() {
3001 self.var_types[id.0].clone()
3002 } else {
3003 Type::Unknown
3004 };
3005 let b = HirExpr {
3006 kind: HirExprKind::Var(id),
3007 ty,
3008 };
3009 HirLValue::Member(Box::new(b), name.clone())
3010 } else {
3011 let b = self.lower_expr(base)?;
3012 HirLValue::Member(Box::new(b), name.clone())
3013 }
3014 }
3015 ALV::MemberDynamic(base, name_expr) => {
3016 let b = self.lower_expr(base)?;
3017 let n = self.lower_expr(name_expr)?;
3018 HirLValue::MemberDynamic(Box::new(b), Box::new(n))
3019 }
3020 ALV::Index(base, idxs) => {
3021 let b = self.lower_expr(base)?;
3022 let lowered: Result<Vec<_>, _> = idxs.iter().map(|e| self.lower_expr(e)).collect();
3023 HirLValue::Index(Box::new(b), lowered?)
3024 }
3025 ALV::IndexCell(base, idxs) => {
3026 let b = self.lower_expr(base)?;
3027 let lowered: Result<Vec<_>, _> = idxs.iter().map(|e| self.lower_expr(e)).collect();
3028 HirLValue::IndexCell(Box::new(b), lowered?)
3029 }
3030 })
3031 }
3032
3033 fn infer_function_return_type(&self, func_name: &str, args: &[HirExpr]) -> Type {
3035 if let Some(HirStmt::Function { outputs, body, .. }) = self.functions.get(func_name) {
3037 return self.infer_user_function_return_type(outputs, body, args);
3039 }
3040
3041 let builtin_functions = runmat_builtins::builtin_functions();
3043 for builtin in builtin_functions {
3044 if builtin.name == func_name {
3045 return builtin.return_type.clone();
3046 }
3047 }
3048
3049 Type::Unknown
3051 }
3052
3053 fn infer_user_function_return_type(
3055 &self,
3056 outputs: &[VarId],
3057 body: &[HirStmt],
3058 _args: &[HirExpr],
3059 ) -> Type {
3060 if outputs.is_empty() {
3061 return Type::Void;
3062 }
3063 let result_types = self.infer_outputs_types(outputs, body);
3064 result_types.first().cloned().unwrap_or(Type::Unknown)
3066 }
3067
3068 fn infer_outputs_types(&self, outputs: &[VarId], body: &[HirStmt]) -> Vec<Type> {
3069 use std::collections::HashMap;
3070
3071 #[derive(Clone)]
3072 struct Analysis {
3073 exits: Vec<HashMap<VarId, Type>>, fallthrough: Option<HashMap<VarId, Type>>, }
3076
3077 fn join_type(a: &Type, b: &Type) -> Type {
3078 if a == b {
3079 return a.clone();
3080 }
3081 if matches!(a, Type::Unknown) {
3082 return b.clone();
3083 }
3084 if matches!(b, Type::Unknown) {
3085 return a.clone();
3086 }
3087 Type::Unknown
3088 }
3089
3090 fn join_env(a: &HashMap<VarId, Type>, b: &HashMap<VarId, Type>) -> HashMap<VarId, Type> {
3091 let mut out = a.clone();
3092 for (k, v) in b {
3093 out.entry(*k)
3094 .and_modify(|t| *t = join_type(t, v))
3095 .or_insert_with(|| v.clone());
3096 }
3097 out
3098 }
3099
3100 #[allow(clippy::type_complexity)]
3101 #[allow(clippy::only_used_in_recursion)]
3102 fn analyze_stmts(
3103 #[allow(clippy::only_used_in_recursion)] _outputs: &[VarId],
3104 stmts: &[HirStmt],
3105 mut env: HashMap<VarId, Type>,
3106 ) -> Analysis {
3107 let mut exits = Vec::new();
3108 let mut i = 0usize;
3109 while i < stmts.len() {
3110 match &stmts[i] {
3111 HirStmt::Assign(var, expr, _) => {
3112 env.insert(*var, expr.ty.clone());
3113 }
3114 HirStmt::MultiAssign(vars, expr, _) => {
3115 for v in vars.iter().flatten() {
3116 env.insert(*v, expr.ty.clone());
3117 }
3118 }
3119 HirStmt::ExprStmt(_, _) | HirStmt::Break | HirStmt::Continue => {}
3120 HirStmt::Return => {
3121 exits.push(env.clone());
3122 return Analysis {
3123 exits,
3124 fallthrough: None,
3125 };
3126 }
3127 HirStmt::If {
3128 cond: _,
3129 then_body,
3130 elseif_blocks,
3131 else_body,
3132 } => {
3133 let then_a = analyze_stmts(_outputs, then_body, env.clone());
3134 let mut out_env = then_a.fallthrough.unwrap_or_else(|| env.clone());
3135 let mut all_exits = then_a.exits;
3136 for (c, b) in elseif_blocks {
3137 let _ = c; let a = analyze_stmts(_outputs, b, env.clone());
3139 if let Some(f) = a.fallthrough {
3140 out_env = join_env(&out_env, &f);
3141 }
3142 all_exits.extend(a.exits);
3143 }
3144 if let Some(else_body) = else_body {
3145 let a = analyze_stmts(_outputs, else_body, env.clone());
3146 if let Some(f) = a.fallthrough {
3147 out_env = join_env(&out_env, &f);
3148 }
3149 all_exits.extend(a.exits);
3150 } else {
3151 out_env = join_env(&out_env, &env);
3153 }
3154 env = out_env;
3155 exits.extend(all_exits);
3156 }
3157 HirStmt::While { cond: _, body } => {
3158 let a = analyze_stmts(_outputs, body, env.clone());
3160 if let Some(f) = a.fallthrough {
3161 env = join_env(&env, &f);
3162 }
3163 exits.extend(a.exits);
3164 }
3165 HirStmt::For { var, expr, body } => {
3166 env.insert(*var, expr.ty.clone());
3168 let a = analyze_stmts(_outputs, body, env.clone());
3169 if let Some(f) = a.fallthrough {
3170 env = join_env(&env, &f);
3171 }
3172 exits.extend(a.exits);
3173 }
3174 HirStmt::Switch {
3175 expr: _,
3176 cases,
3177 otherwise,
3178 } => {
3179 let mut out_env: Option<HashMap<VarId, Type>> = None;
3180 for (_v, b) in cases {
3181 let a = analyze_stmts(_outputs, b, env.clone());
3182 if let Some(f) = a.fallthrough {
3183 out_env = Some(match out_env {
3184 Some(curr) => join_env(&curr, &f),
3185 None => f,
3186 });
3187 }
3188 exits.extend(a.exits);
3189 }
3190 if let Some(otherwise) = otherwise {
3191 let a = analyze_stmts(_outputs, otherwise, env.clone());
3192 if let Some(f) = a.fallthrough {
3193 out_env = Some(match out_env {
3194 Some(curr) => join_env(&curr, &f),
3195 None => f,
3196 });
3197 }
3198 exits.extend(a.exits);
3199 } else {
3200 out_env = Some(match out_env {
3201 Some(curr) => join_env(&curr, &env),
3202 None => env.clone(),
3203 });
3204 }
3205 if let Some(f) = out_env {
3206 env = f;
3207 }
3208 }
3209 HirStmt::TryCatch {
3210 try_body,
3211 catch_var: _,
3212 catch_body,
3213 } => {
3214 let a_try = analyze_stmts(_outputs, try_body, env.clone());
3215 let a_catch = analyze_stmts(_outputs, catch_body, env.clone());
3216 let mut out_env = a_try.fallthrough.unwrap_or_else(|| env.clone());
3217 if let Some(f) = a_catch.fallthrough {
3218 out_env = join_env(&out_env, &f);
3219 }
3220 env = out_env;
3221 exits.extend(a_try.exits);
3222 exits.extend(a_catch.exits);
3223 }
3224 HirStmt::Global(_) | HirStmt::Persistent(_) => {}
3225 HirStmt::Function { .. } => {}
3226 HirStmt::ClassDef { .. } => {}
3227 HirStmt::AssignLValue(_, expr, _) => {
3228 let _ = &expr.ty;
3231 }
3232 HirStmt::Import { .. } => {}
3233 }
3234 i += 1;
3235 }
3236 Analysis {
3237 exits,
3238 fallthrough: Some(env),
3239 }
3240 }
3241
3242 let initial_env: HashMap<VarId, Type> = HashMap::new();
3243 let analysis = analyze_stmts(outputs, body, initial_env);
3244 let mut per_output: Vec<Type> = vec![Type::Unknown; outputs.len()];
3245 let mut accumulate = |env: &std::collections::HashMap<VarId, Type>| {
3246 for (i, out) in outputs.iter().enumerate() {
3247 if let Some(t) = env.get(out) {
3248 per_output[i] = join_type(&per_output[i], t);
3249 }
3250 }
3251 };
3252 for e in &analysis.exits {
3253 accumulate(e);
3254 }
3255 if let Some(f) = &analysis.fallthrough {
3256 accumulate(f);
3257 }
3258 per_output
3259 }
3260}