luaur_analysis/methods/
magic_find_infer.rs1use crate::functions::as_mutable_type_pack::as_mutable_type_pack;
2use crate::functions::flatten_type_pack::flatten_type_pack_id;
3use crate::functions::parse_pattern_string::parse_pattern_string;
4use crate::records::builtin_types::BuiltinTypes;
5use crate::records::magic_find::MagicFind;
6use crate::records::magic_function_call_context::MagicFunctionCallContext;
7use crate::records::type_arena::TypeArena;
8use crate::records::union_type::UnionType;
9use crate::type_aliases::type_id::TypeId;
10use crate::type_aliases::type_pack_id::TypePackId;
11use alloc::vec;
12use luaur_ast::records::ast_expr_constant_bool::AstExprConstantBool;
13use luaur_ast::records::ast_expr_constant_string::AstExprConstantString;
14use luaur_ast::records::ast_node::AstNode;
15use luaur_ast::rtti::ast_node_as;
16
17impl MagicFind {
18 pub fn infer(&self, context: &MagicFunctionCallContext) -> bool {
19 let (params, _tail) = flatten_type_pack_id(context.arguments);
20
21 if params.len() < 2 || params.len() > 4 {
22 return false;
23 }
24
25 let solver = unsafe { context.solver.as_ref() };
26 let arena = unsafe { &mut *solver.arena };
27 let builtin_types = unsafe { &*solver.builtin_types };
28 let call_site = unsafe { context.call_site.as_ref() };
29
30 let pattern_index = if call_site.self_ { 0 } else { 1 };
31 let pattern = if call_site.args.size > pattern_index {
32 let expr = unsafe { *call_site.args.data.add(pattern_index) };
33 unsafe { ast_node_as::<AstExprConstantString>(expr as *mut AstNode) }
34 } else {
35 core::ptr::null_mut()
36 };
37
38 if pattern.is_null() {
39 return false;
40 }
41
42 let mut plain = false;
43 let plain_index = if call_site.self_ { 2 } else { 3 };
44 if call_site.args.size > plain_index {
45 let expr = unsafe { *call_site.args.data.add(plain_index) };
46 let bool_expr = unsafe { ast_node_as::<AstExprConstantBool>(expr as *mut AstNode) };
47 if !bool_expr.is_null() {
48 plain = unsafe { &*bool_expr }.value;
49 }
50 }
51
52 let mut return_types: Vec<TypeId> = Vec::new();
53 if !plain {
54 return_types = unsafe {
55 parse_pattern_string(
56 core::ptr::NonNull::new_unchecked(solver.builtin_types),
57 unsafe { &*pattern }.value.data,
58 unsafe { &*pattern }.value.size,
59 )
60 };
61
62 if return_types.is_empty() {
63 return false;
64 }
65 }
66
67 unsafe {
68 (*context.solver.as_ptr()).constraint_solver_unify(
69 context.constraint.as_ptr(),
70 params[0],
71 builtin_types.stringType,
72 );
73 }
74
75 let optional_number = arena.add_type(UnionType {
76 options: vec![builtin_types.nilType, builtin_types.numberType],
77 });
78 let optional_boolean = arena.add_type(UnionType {
79 options: vec![builtin_types.nilType, builtin_types.booleanType],
80 });
81
82 let init_index = if call_site.self_ { 1 } else { 2 };
83 if params.len() >= 3 && call_site.args.size > init_index {
84 unsafe {
85 (*context.solver.as_ptr()).constraint_solver_unify(
86 context.constraint.as_ptr(),
87 params[2],
88 optional_number,
89 );
90 }
91 }
92
93 if params.len() == 4 && call_site.args.size > plain_index {
94 unsafe {
95 (*context.solver.as_ptr()).constraint_solver_unify(
96 context.constraint.as_ptr(),
97 params[3],
98 optional_boolean,
99 );
100 }
101 }
102
103 return_types.insert(0, optional_number);
104 return_types.insert(1, optional_number);
105
106 let return_list =
107 arena.add_type_pack_vector_type_id_optional_type_pack_id(return_types, None);
108 let result_mut = as_mutable_type_pack(context.result);
109 unsafe {
110 (*result_mut).ty =
111 crate::type_aliases::type_pack_variant::TypePackVariant::Bound(return_list);
112 }
113
114 true
115 }
116}