Skip to main content

luaur_analysis/methods/
magic_format_infer.rs

1use crate::enums::value::Value;
2use crate::functions::as_mutable_type_pack_alt_d::as_mutable_type_pack;
3use crate::functions::flatten_type_pack::flatten_type_pack_id;
4use crate::functions::follow_type::follow_type_id;
5use crate::functions::get_type_alt_j::get_type_id;
6use crate::functions::parse_format_string::parse_format_string;
7use crate::functions::should_suppress_errors_type_utils::should_suppress_errors;
8use crate::functions::unwrap_group::unwrap_group;
9use crate::records::builtin_types::BuiltinTypes;
10use crate::records::count_mismatch::CountMismatch;
11use crate::records::error_suppression::ErrorSuppression;
12use crate::records::magic_function_call_context::MagicFunctionCallContext;
13use crate::records::singleton_type::SingletonType;
14use crate::records::string_singleton::StringSingleton;
15use crate::records::type_error::TypeError;
16use alloc::vec::Vec;
17use luaur_ast::records::ast_expr_constant_string::AstExprConstantString;
18use luaur_ast::records::ast_expr_index_name::AstExprIndexName;
19use luaur_ast::records::ast_node::AstNode;
20use luaur_ast::rtti::ast_node_as;
21
22pub fn magic_format_infer(context: &MagicFunctionCallContext) -> bool {
23    let solver = unsafe { context.solver.as_ref() };
24    let arena = unsafe { &mut *solver.arena };
25
26    let iter = unsafe { crate::functions::begin_type_pack::begin(context.arguments) };
27    let end_iter = unsafe { crate::functions::end_type_pack::end(context.arguments) };
28
29    // we'll suppress any errors for `string.format` if the format string is error suppressing.
30    if iter.operator_eq(&end_iter)
31        || should_suppress_errors(solver.normalizer, unsafe {
32            follow_type_id(*iter.operator_deref())
33        }) == ErrorSuppression::from_value(Value::Suppress)
34    {
35        let result_pack = arena
36            .add_type_pack_initializer_list_type_id(
37                &[unsafe { &*solver.builtin_types }.stringType],
38            );
39        let result_mut = as_mutable_type_pack(context.result);
40        unsafe {
41            (*result_mut).ty =
42                crate::type_aliases::type_pack_variant::TypePackVariant::Bound(result_pack);
43        }
44        return true;
45    }
46
47    let mut fmt: *mut AstExprConstantString = core::ptr::null_mut();
48
49    if !context.call_site.as_ptr().is_null() {
50        let call_site = unsafe { &*context.call_site.as_ptr() };
51        if call_site.func.is_null() {
52            return false;
53        }
54
55        let func_node = unsafe { &*call_site.func };
56        if func_node.base.class_index == AstExprIndexName::ClassIndex {
57            let index_expr =
58                unsafe { ast_node_as::<AstExprIndexName>(call_site.func as *mut AstNode) };
59            if !index_expr.is_null() && call_site.self_ {
60                let unwrapped = unwrap_group(unsafe { &mut *index_expr }.expr);
61                fmt = unsafe { ast_node_as::<AstExprConstantString>(unwrapped as *mut AstNode) };
62            }
63        }
64
65        if !call_site.self_ && call_site.args.size > 0 {
66            fmt = unsafe {
67                ast_node_as::<AstExprConstantString>(*call_site.args.data as *mut AstNode)
68            };
69        }
70    }
71
72    let mut format_string: Option<&str> = None;
73
74    if !fmt.is_null() {
75        let fmt_ref = unsafe { &*fmt };
76        let data = fmt_ref.value.data as *const u8;
77        let size = fmt_ref.value.size as usize;
78        format_string = Some(unsafe {
79            core::str::from_utf8_unchecked(core::slice::from_raw_parts(data, size))
80        });
81    } else {
82        let first_arg = unsafe { *iter.operator_deref() };
83        let followed = unsafe { follow_type_id(first_arg) };
84        let singleton = unsafe { get_type_id::<SingletonType>(followed) };
85        if !singleton.is_null() {
86            if let Some(string_singleton) =
87                unsafe { (*singleton).variant.get_if::<StringSingleton>() }
88            {
89                format_string = Some(&string_singleton.value);
90            }
91        }
92    }
93
94    if format_string.is_none() {
95        return false;
96    }
97
98    let format_str = format_string.unwrap();
99    let expected = parse_format_string(
100        core::ptr::NonNull::new(unsafe { &mut *solver.builtin_types }).unwrap(),
101        format_str.as_ptr() as *const core::ffi::c_char,
102        format_str.len(),
103    );
104
105    let (params, tail) = flatten_type_pack_id(context.arguments);
106
107    let param_offset = 1;
108
109    // unify the prefix one argument at a time - needed if any of the involved types are free
110    for i in 0..expected.len() {
111        if i + param_offset >= params.len() {
112            break;
113        }
114        unsafe {
115            (*context.solver.as_ptr()).constraint_solver_unify(
116                context.constraint.as_ptr(),
117                params[i + param_offset],
118                expected[i],
119            );
120        }
121    }
122
123    // if we know the argument count or if we have too many arguments for sure, we can issue an error
124    let num_actual_params = params.len();
125    let num_expected_params = expected.len() + 1; // + 1 for the format string
126
127    if num_expected_params != num_actual_params
128        && (tail.is_none() || num_expected_params < num_actual_params)
129    {
130        let error = TypeError::type_error_location_type_error_data(
131            unsafe { &*context.call_site.as_ptr() }.base.base.location,
132            crate::type_aliases::type_error_data::TypeErrorData::CountMismatch(CountMismatch {
133                expected: num_expected_params,
134                maximum: None,
135                actual: num_actual_params,
136                context: CountMismatch::Arg,
137                is_variadic: tail.is_some(),
138                function: String::new(),
139            }),
140        );
141        unsafe {
142            (*context.solver.as_ptr()).report_error_type_error(error);
143        }
144    }
145
146    // This is invoked at solve time, so we just need to provide a type for the result of :/.format
147    let result_pack = arena
148        .add_type_pack_initializer_list_type_id(&[unsafe { &*solver.builtin_types }.stringType]);
149    let result_mut = as_mutable_type_pack(context.result);
150    unsafe {
151        (*result_mut).ty =
152            crate::type_aliases::type_pack_variant::TypePackVariant::Bound(result_pack);
153    }
154
155    true
156}