Skip to main content

luaur_analyze_cli/functions/
main.rs

1use crate::enums::report_format::ReportFormat;
2use crate::functions::assertion_handler::assertion_handler;
3use crate::functions::display_help::display_help;
4use crate::functions::report::report;
5use crate::functions::report_module_result::report_module_result;
6use crate::records::cli_config_resolver::CliConfigResolver;
7use crate::records::cli_file_resolver::CliFileResolver;
8use crate::records::task_scheduler::{Task, TaskScheduler};
9use alloc::boxed::Box;
10use alloc::string::String;
11use alloc::vec::Vec;
12use core::ffi::c_char;
13use luaur_analysis::enums::solver_mode::SolverMode;
14use luaur_analysis::functions::freeze::freeze;
15use luaur_analysis::functions::register_builtin_globals::register_builtin_globals;
16use luaur_analysis::functions::to_string_error_alt_k::to_string_type_error_type_error_to_string_options;
17use luaur_analysis::records::frontend::Frontend;
18use luaur_analysis::records::frontend_options::FrontendOptions;
19use luaur_analysis::records::internal_compiler_error::InternalCompilerError;
20use luaur_analysis::records::internal_error::InternalError;
21use luaur_analysis::records::type_error::TypeError;
22use luaur_analysis::records::type_error_to_string_options::TypeErrorToStringOptions;
23use luaur_analysis::type_aliases::module_name_file_resolver::ModuleName;
24use luaur_ast::enums::mode::Mode;
25use luaur_ast::records::location::Location;
26use luaur_cli_lib::functions::get_source_files::get_source_files;
27use luaur_cli_lib::functions::set_luau_flags_default::set_luau_flags_default;
28use luaur_cli_lib::functions::set_luau_flags_flags_alt_b::set_luau_flags_c_char;
29
/// A queued frontend task wrapped so it can be handed to the worker threads of the
/// [`TaskScheduler`].
///
/// The C++ `TaskScheduler` stores `std::function<void()>` and has no notion of `Send`. The boxed
/// closures handed to the executor by `Frontend::check_queued_modules` capture shared frontend
/// state and are therefore not `Send` by type. They must nevertheless run on the scheduler's
/// worker threads, just as they do in C++.
struct SendTask(Box<dyn Fn()>);

// SAFETY: this is NOT checked by the type system. `Send` is asserted so that frontend tasks can
// cross to the scheduler's worker threads, mirroring the C++ original. Soundness relies on:
// (1) the state captured by each task tolerating access from worker threads, and
// (2) the scheduler being dropped before that state (in `run`, the scheduler is dropped inside
//     the checking closure, while `frontend` lives until the end of `run`).
// NOTE(review): a task capturing an `Rc`/`RefCell`/`Cell` would make this a data race — confirm
// the frontend's task closures only touch thread-safe state.
unsafe impl Send for SendTask {}
35
36/// C++ `int main(int argc, char** argv)` (`CLI/src/Analyze.cpp:394-542`).
37pub fn main() {
38    std::process::exit(run());
39}
40
41fn run() -> i32 {
42    // Build an owned argv of NUL-terminated C strings so the FileUtils/Flags ports
43    // (which take `int argc, char** argv`) can be called faithfully.
44    let owned_args: Vec<std::ffi::CString> = std::env::args()
45        .map(|a| std::ffi::CString::new(a).unwrap_or_else(|_| std::ffi::CString::new("").unwrap()))
46        .collect();
47    let mut argv: Vec<*mut c_char> = owned_args
48        .iter()
49        .map(|c| c.as_ptr() as *mut c_char)
50        .collect();
51    argv.push(core::ptr::null_mut());
52    let argc = owned_args.len() as i32;
53
54    // Luau::assertHandler() = assertionHandler;
55    *luaur_common::functions::assert_handler::assert_handler() = Some(assertion_handler);
56
57    // setLuauFlagsDefault();
58    set_luau_flags_default();
59
60    // if (argc >= 2 && strcmp(argv[1], "--help") == 0) { displayHelp(argv[0]); return 0; }
61    let args: Vec<String> = std::env::args().collect();
62    if args.len() >= 2 && args[1] == "--help" {
63        display_help(&args[0]);
64        return 0;
65    }
66
67    let mut format = ReportFormat::Default;
68    let mut mode = Mode::Nonstrict;
69    let mut annotate = false;
70    let mut thread_count: i32 = 0;
71    let mut base_path = String::new();
72    let mut solver_mode = SolverMode::New;
73
74    // for (int i = 1; i < argc; ++i)
75    for arg in args.iter().skip(1) {
76        if !arg.starts_with('-') {
77            continue;
78        }
79
80        if arg == "--formatter=plain" {
81            format = ReportFormat::Luacheck;
82        } else if arg == "--formatter=gnu" {
83            format = ReportFormat::Gnu;
84        } else if arg == "--mode=strict" {
85            mode = Mode::Strict;
86        } else if arg == "--annotate" {
87            annotate = true;
88        } else if arg == "--timetrace" {
89            luaur_common::FFlag::DebugLuauTimeTracing.set(true);
90        } else if let Some(rest) = arg.strip_prefix("--fflags=") {
91            let c = std::ffi::CString::new(rest)
92                .unwrap_or_else(|_| std::ffi::CString::new("").unwrap());
93            set_luau_flags_c_char(c.as_ptr());
94        } else if let Some(rest) = arg.strip_prefix("-j") {
95            thread_count = rest.parse::<i32>().unwrap_or(0);
96        } else if let Some(rest) = arg.strip_prefix("--logbase=") {
97            base_path = String::from(rest);
98        } else if arg == "--solver=old" {
99            solver_mode = SolverMode::Old;
100        }
101    }
102
103    // The Rust build does not define LUAU_ENABLE_TIME_TRACE; mirror the C++ guard.
104    if luaur_common::FFlag::DebugLuauTimeTracing.get() {
105        eprintln!(
106            "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled"
107        );
108        return 1;
109    }
110
111    // FrontendOptions frontendOptions; retainFullTypeGraphs = annotate; runLintChecks = true;
112    let mut frontend_options = FrontendOptions::default();
113    frontend_options.retain_full_type_graphs = annotate;
114    frontend_options.run_lint_checks = true;
115
116    // CliFileResolver fileResolver; CliConfigResolver configResolver(mode);
117    let mut file_resolver = CliFileResolver::new();
118    let mut config_resolver = CliConfigResolver::cli_config_resolver(mode);
119
120    // Frontend frontend(solverMode, &fileResolver, &configResolver, frontendOptions);
121    let mut frontend =
122        Frontend::frontend_solver_mode_file_resolver_config_resolver_frontend_options(
123            solver_mode,
124            &mut file_resolver.base,
125            &mut config_resolver.base,
126            frontend_options,
127        );
128    // Re-establish the resolver pointers and the self-referential pointers now that
129    // `frontend` lives at a stable address (mirrors the project's wiring convention).
130    frontend.file_resolver = &mut file_resolver.base;
131    frontend.config_resolver = &mut config_resolver.base;
132    unsafe {
133        frontend.wire_self_pointers();
134    }
135
136    // if (FFlag::DebugLuauLogSolverToJsonFile) { frontend.writeJsonLog = ...; }
137    if luaur_common::FFlag::DebugLuauLogSolverToJsonFile.get() {
138        let base_path = base_path.clone();
139        frontend.write_json_log = Some(alloc::rc::Rc::new(
140            move |module_name: &ModuleName, log: String| {
141                let mut path = alloc::format!("{}.log.json", module_name);
142                if let Some(pos) = module_name.rfind('/') {
143                    path = String::from(&module_name[pos + 1..]);
144                }
145                if !base_path.is_empty() {
146                    path = luaur_cli_lib::functions::join_paths_file_utils_alt_b::join_paths_string_view_string_view(&base_path, &path);
147                }
148                if std::fs::write(&path, alloc::format!("{}\n", log)).is_ok() {
149                    println!("Wrote JSON log to {}", path);
150                }
151            },
152        ));
153    }
154
155    // registerBuiltinGlobals(frontend, frontend.globals);
156    // freeze(frontend.globals.globalTypes);
157    unsafe {
158        let frontend_ptr: *mut Frontend = &mut frontend;
159        register_builtin_globals(&mut *frontend_ptr, &mut (*frontend_ptr).globals, false);
160        freeze((*frontend_ptr).globals.global_types_mut());
161    }
162
163    // std::vector<std::string> files = getSourceFiles(argc, argv);
164    let files = get_source_files(argc, argv.as_mut_ptr());
165
166    // for (const std::string& path : files) frontend.queueModuleCheck(path);
167    frontend.queue_module_check_vector_module_name(&files);
168
169    let mut checked_modules: Vec<ModuleName>;
170
171    // if (threadCount <= 0) threadCount = std::min(getThreadCount(), 8u);
172    if thread_count <= 0 {
173        thread_count = core::cmp::min(TaskScheduler::get_thread_count(), 8) as i32;
174    }
175
176    // try { TaskScheduler scheduler(threadCount); checkedModules = frontend.checkQueuedModules(...); }
177    let frontend_ptr: *mut Frontend = &mut frontend;
178    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
179        let scheduler = TaskScheduler::task_scheduler_task_scheduler(thread_count as u32);
180        let scheduler_ptr: *const TaskScheduler = &scheduler;
181
182        // The executor pushes each task onto the scheduler queue, matching:
183        //   [&](std::vector<std::function<void()>> tasks) { for (auto& t : tasks) scheduler.push(std::move(t)); }
184        let execute_tasks: Box<dyn Fn(Vec<Box<dyn Fn()>>)> = Box::new(move |tasks| {
185            for task in tasks {
186                let send_task = SendTask(task);
187                let boxed: Task = Some(Box::new(move || {
188                    // Move the whole `SendTask` (which is `Send`) into the closure so
189                    // the auto-trait analysis sees `Send`, rather than capturing only
190                    // the inner non-`Send` `Box<dyn Fn()>` (edition-2021 disjoint capture).
191                    let send_task = send_task;
192                    (send_task.0)();
193                }));
194                crate::methods::task_scheduler_push::task_scheduler_push(
195                    unsafe { &*scheduler_ptr },
196                    boxed,
197                );
198            }
199        });
200
201        let progress: Box<dyn Fn(usize, usize) -> bool> = Box::new(|_done, _total| true);
202
203        let modules =
204            unsafe { (*frontend_ptr).check_queued_modules(None, execute_tasks, progress) };
205
206        // scheduler is dropped here (joins workers), matching the C++ block scope.
207        drop(scheduler);
208        modules
209    }));
210
211    match result {
212        Ok(modules) => checked_modules = modules,
213        Err(payload) => {
214            // catch (const InternalCompilerError& ice)
215            let ice: InternalCompilerError =
216                if let Some(e) = payload.downcast_ref::<InternalCompilerError>() {
217                    e.clone()
218                } else if let Some(e) = payload
219                    .downcast_ref::<luaur_analysis::records::time_limit_error::TimeLimitError>(
220                ) {
221                    e.base.clone()
222                } else if let Some(e) = payload
223                    .downcast_ref::<luaur_analysis::records::user_cancel_error::UserCancelError>(
224                ) {
225                    e.base.clone()
226                } else {
227                    std::panic::resume_unwind(payload);
228                };
229
230            let location = ice.location.unwrap_or_else(Location::default);
231            let module_name = ice
232                .module_name
233                .clone()
234                .unwrap_or_else(|| String::from("<unknown module>"));
235            let human_readable_name = unsafe {
236                luaur_analysis::records::file_resolver::FileResolver::get_human_readable_module_name(
237                    frontend.file_resolver,
238                    &module_name,
239                )
240            };
241
242            let error = TypeError::type_error_location_module_name_type_error_data(
243                location.clone(),
244                module_name,
245                InternalError::new(ice.message.clone()).into(),
246            );
247
248            let message = to_string_type_error_type_error_to_string_options(
249                &error,
250                TypeErrorToStringOptions {
251                    file_resolver: frontend.file_resolver,
252                },
253            );
254            report(
255                format,
256                &human_readable_name,
257                &location,
258                "InternalCompilerError",
259                &message,
260            );
261            return 1;
262        }
263    }
264
265    let mut failed = 0i32;
266
267    // for (const ModuleName& name : checkedModules) failed += !reportModuleResult(...);
268    let names = core::mem::take(&mut checked_modules);
269    for name in &names {
270        if !report_module_result(&mut frontend, name, format, annotate) {
271            failed += 1;
272        }
273    }
274
275    // if (!configResolver.configErrors.empty()) { ... }
276    if !config_resolver.config_errors.is_empty() {
277        failed += config_resolver.config_errors.len() as i32;
278
279        for (path, error) in &config_resolver.config_errors {
280            eprintln!("{}: {}", path, error);
281        }
282    }
283
284    if format == ReportFormat::Luacheck {
285        0
286    } else if failed != 0 {
287        1
288    } else {
289        0
290    }
291}