// Source: luau_analyzer_sys/lib.rs

1use std::collections::HashMap;
2use std::ffi::{CStr, CString};
3use std::os::raw::{c_char, c_int, c_uint, c_void};
4
/// A single problem reported by the native Luau analyzer.
///
/// Positions are passed through unchanged from the native side; this crate's tests expect
/// `line` to be 0-based.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Diagnostic {
    /// Severity code: [`Diagnostic::SEVERITY_ERROR`] (0) or [`Diagnostic::SEVERITY_WARNING`]
    /// (1). Other values are not interpreted by this crate.
    pub severity: u8,
    /// Line on which the reported range starts.
    pub line: u32,
    /// Column at which the reported range starts.
    pub col: u32,
    /// Line on which the reported range ends.
    pub end_line: u32,
    /// Column at which the reported range ends.
    pub end_col: u32,
    /// Message text produced by the analyzer (empty if the native side supplied none).
    pub message: String,
}

impl Diagnostic {
    /// `severity` value the native analyzer uses for errors.
    pub const SEVERITY_ERROR: u8 = 0;
    /// `severity` value the native analyzer uses for warnings.
    pub const SEVERITY_WARNING: u8 = 1;

    /// Returns `true` if this diagnostic has error severity.
    pub fn is_error(&self) -> bool {
        self.severity == Self::SEVERITY_ERROR
    }

    /// Returns `true` if this diagnostic has warning severity.
    pub fn is_warning(&self) -> bool {
        self.severity == Self::SEVERITY_WARNING
    }
}
14
/// Opaque stand-in for the native analyzer object; only ever handled behind a raw pointer.
///
/// It is zero-sized and cannot be constructed outside this module. The marker field follows
/// the Rustonomicon's opaque-FFI-type pattern: it opts the type out of the auto traits
/// `Send`, `Sync` and `Unpin`, because nothing is known on the Rust side about the native
/// object's thread-safety or whether it may be moved.
#[repr(C)]
pub struct LuauAnalyzerOpaque {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
19
/// Native -> Rust: reports one diagnostic.
///
/// Arguments are `(user_data, severity, line, col, end_line, end_col, message)`, where
/// `user_data` is the opaque pointer given to `luau_analyzer_check`. `message` may be null.
type DiagnosticCallback = unsafe extern "C" fn(
    user_data: *mut c_void,
    severity_code: c_int,
    start_line: c_uint,
    start_col: c_uint,
    stop_line: c_uint,
    stop_col: c_uint,
    text: *const c_char,
);

/// Native -> Rust: asks for the source text of a module by name.
///
/// Returns a NUL-terminated string, or null when the module has no source.
type ReadSourceCallback =
    unsafe extern "C" fn(user_data: *mut c_void, name: *const c_char) -> *const c_char;

/// Native -> Rust: resolves a `require` target.
///
/// Receives the module doing the requiring and the required string, and returns the
/// NUL-terminated module name to load, or null when it cannot be resolved.
type ResolveModuleCallback = unsafe extern "C" fn(
    user_data: *mut c_void,
    requiring_module: *const c_char,
    requested: *const c_char,
) -> *const c_char;
38
39unsafe extern "C" {
40    fn luau_analyzer_create() -> *mut LuauAnalyzerOpaque;
41    fn luau_analyzer_destroy(analyzer: *mut LuauAnalyzerOpaque);
42    fn luau_analyzer_add_definitions(analyzer: *mut LuauAnalyzerOpaque, source: *const c_char);
43    fn luau_analyzer_check(
44        analyzer: *mut LuauAnalyzerOpaque,
45        module_name: *const c_char,
46        read_callback: Option<ReadSourceCallback>,
47        resolve_callback: Option<ResolveModuleCallback>,
48        diag_callback: Option<DiagnosticCallback>,
49        context: *mut c_void,
50    );
51}
52
53struct CheckContext<'a> {
54    diagnostics: Vec<Diagnostic>,
55    cached_strings: HashMap<String, CString>,
56    resolver: &'a dyn Fn(&str) -> Option<String>,
57    path_resolver: &'a dyn Fn(&str, &str) -> Option<String>,
58}
59
60pub struct NativeAnalyzer {
61    ptr: *mut LuauAnalyzerOpaque,
62}
63
64impl NativeAnalyzer {
65    pub fn new() -> Self {
66        unsafe {
67            Self {
68                ptr: luau_analyzer_create(),
69            }
70        }
71    }
72
73    pub fn add_definitions(&mut self, source: &str) {
74        if let Ok(c_str) = CString::new(source) {
75            unsafe {
76                luau_analyzer_add_definitions(self.ptr, c_str.as_ptr());
77            }
78        }
79    }
80
81    pub fn check<F, P>(
82        &mut self,
83        module_name: &str,
84        resolver: F,
85        path_resolver: P,
86    ) -> Vec<Diagnostic>
87    where
88        F: Fn(&str) -> Option<String>,
89        P: Fn(&str, &str) -> Option<String>,
90    {
91        let mut context = CheckContext {
92            diagnostics: Vec::new(),
93            cached_strings: HashMap::new(),
94            resolver: &resolver,
95            path_resolver: &path_resolver,
96        };
97
98        if let Ok(mod_cstr) = CString::new(module_name) {
99            unsafe extern "C" fn read_callback(
100                ctx_ptr: *mut c_void,
101                mod_name: *const c_char,
102            ) -> *const c_char {
103                let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
104                if mod_name.is_null() {
105                    return std::ptr::null();
106                }
107                let name_str = unsafe { CStr::from_ptr(mod_name) }.to_string_lossy();
108                if let Some(c_str) = ctx.cached_strings.get(name_str.as_ref()) {
109                    return c_str.as_ptr();
110                }
111                if let Some(src) = (ctx.resolver)(name_str.as_ref())
112                    && let Ok(c_str) = CString::new(src)
113                {
114                    let ptr = c_str.as_ptr();
115                    ctx.cached_strings.insert(name_str.into_owned(), c_str);
116                    return ptr;
117                }
118                std::ptr::null()
119            }
120
121            unsafe extern "C" fn resolve_callback(
122                ctx_ptr: *mut c_void,
123                curr_mod: *const c_char,
124                req_name: *const c_char,
125            ) -> *const c_char {
126                let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
127                if curr_mod.is_null() || req_name.is_null() {
128                    return std::ptr::null();
129                }
130                let curr_mod_str = unsafe { CStr::from_ptr(curr_mod) }.to_string_lossy();
131                let req_name_str = unsafe { CStr::from_ptr(req_name) }.to_string_lossy();
132
133                let cache_key = format!("RESOLVED:{}:{}", curr_mod_str, req_name_str);
134                if let Some(c_str) = ctx.cached_strings.get(&cache_key) {
135                    return c_str.as_ptr();
136                }
137
138                if let Some(resolved) =
139                    (ctx.path_resolver)(curr_mod_str.as_ref(), req_name_str.as_ref())
140                    && let Ok(c_str) = CString::new(resolved)
141                {
142                    let ptr = c_str.as_ptr();
143                    ctx.cached_strings.insert(cache_key, c_str);
144                    return ptr;
145                }
146                std::ptr::null()
147            }
148
149            unsafe extern "C" fn diag_callback(
150                ctx_ptr: *mut c_void,
151                severity: c_int,
152                line: c_uint,
153                col: c_uint,
154                end_line: c_uint,
155                end_col: c_uint,
156                message: *const c_char,
157            ) {
158                let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
159                let msg_str = if message.is_null() {
160                    String::new()
161                } else {
162                    unsafe { CStr::from_ptr(message) }
163                        .to_string_lossy()
164                        .into_owned()
165                };
166                ctx.diagnostics.push(Diagnostic {
167                    severity: severity as u8,
168                    line,
169                    col,
170                    end_line,
171                    end_col,
172                    message: msg_str,
173                });
174            }
175
176            unsafe {
177                let ctx_void = &mut context as *mut CheckContext as *mut c_void;
178                luau_analyzer_check(
179                    self.ptr,
180                    mod_cstr.as_ptr(),
181                    Some(read_callback),
182                    Some(resolve_callback),
183                    Some(diag_callback),
184                    ctx_void,
185                );
186            }
187        }
188
189        context.diagnostics
190    }
191}
192
193impl Default for NativeAnalyzer {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199impl Drop for NativeAnalyzer {
200    fn drop(&mut self) {
201        unsafe {
202            if !self.ptr.is_null() {
203                luau_analyzer_destroy(self.ptr);
204                self.ptr = std::ptr::null_mut();
205            }
206        }
207    }
208}
209
210unsafe impl Send for NativeAnalyzer {}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::thread;

    /// Checks `module` with a resolver that knows only `source` for that exact module name,
    /// and a path resolver that resolves nothing.
    fn check_single(analyzer: &mut NativeAnalyzer, module: &str, source: &str) -> Vec<Diagnostic> {
        analyzer.check(
            module,
            |name| (name == module).then(|| source.to_string()),
            |_, _| None,
        )
    }

    #[test]
    fn test_create_analyzer() {
        let analyzer = NativeAnalyzer::new();
        assert!(!analyzer.ptr.is_null());
    }

    #[test]
    fn test_check_simple_no_errors() {
        let mut analyzer = NativeAnalyzer::new();
        let source = "local _x: number = 10\nlocal _y: number = _x + 5\n";

        let diagnostics = check_single(&mut analyzer, "main", source);

        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    #[test]
    fn test_check_type_error() {
        let mut analyzer = NativeAnalyzer::new();
        // Intentional type error: assigning a string to a number variable.
        let source = "local _x: number = 'hello'\n";

        let diagnostics = check_single(&mut analyzer, "main", source);

        assert!(
            !diagnostics.is_empty(),
            "Expected at least one type error diagnostic"
        );
        let diag = &diagnostics[0];
        assert!(
            diag.severity == 0 || diag.severity == 1,
            "Expected error or warning severity, got: {}",
            diag.severity
        );
        assert!(
            diag.message.contains("string"),
            "Expected message to mention 'string', got: {}",
            diag.message
        );
        assert!(
            diag.message.contains("number"),
            "Expected message to mention 'number', got: {}",
            diag.message
        );
    }

    #[test]
    fn test_check_syntax_error() {
        let mut analyzer = NativeAnalyzer::new();
        // Syntax error: the assignment has no right-hand side.
        let source = "local x = \n";

        let diagnostics = check_single(&mut analyzer, "main", source);

        assert!(!diagnostics.is_empty(), "Expected syntax error diagnostic");
        // Syntax errors are expected to be reported as errors (severity 0).
        assert_eq!(diagnostics[0].severity, 0, "got: {:?}", diagnostics[0]);
    }

    #[test]
    fn test_check_with_submodule() {
        let mut analyzer = NativeAnalyzer::new();
        let main_source = "local dep = require('dependency')\nlocal _x: number = dep.value\n";
        let dep_source = "local M = {}\nM.value = 42\nreturn M\n";

        let diagnostics = analyzer.check(
            "main",
            |name| match name {
                "main" => Some(main_source.to_string()),
                "dependency" => Some(dep_source.to_string()),
                _ => None,
            },
            |current, required| {
                (current == "main" && required == "dependency").then(|| "dependency".to_string())
            },
        );

        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    #[test]
    fn test_multiple_checks_same_analyzer() {
        let mut analyzer = NativeAnalyzer::new();

        let first = check_single(&mut analyzer, "mod1", "local _x: number = 10\n");
        assert!(first.is_empty(), "mod1: expected no diagnostics, got: {:?}", first);

        let second = check_single(&mut analyzer, "mod2", "local _y: string = 'hello'\n");
        assert!(second.is_empty(), "mod2: expected no diagnostics, got: {:?}", second);
    }

    #[test]
    fn test_custom_definitions() {
        let mut analyzer = NativeAnalyzer::new();
        // Register a custom global function `my_global_helper`.
        analyzer.add_definitions("declare function my_global_helper(val: string): number\n");

        // Correct use of the custom global.
        let correct_source = "--!strict\nlocal _x: number = my_global_helper('test')\n";
        let diagnostics = check_single(&mut analyzer, "main_correct", correct_source);
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );

        // Incorrect use: passes a number where a string is declared.
        let incorrect_source = "--!strict\nlocal _x: number = my_global_helper(123)\n";
        let diagnostics2 = check_single(&mut analyzer, "main_incorrect", incorrect_source);
        assert!(
            !diagnostics2.is_empty(),
            "Expected a type error due to parameter type mismatch"
        );
        let msg = &diagnostics2[0].message;
        assert!(
            msg.contains("number") || msg.contains("string"),
            "Got message: {} (all diagnostics: {:?})",
            msg,
            diagnostics2
        );
    }

    #[test]
    fn test_precise_error_coordinates() {
        let mut analyzer = NativeAnalyzer::new();
        // The error is on line 3 (0-indexed line 2), at the literal `20`:
        //   Line 1: --!strict
        //   Line 2: local _x: number = 10
        //   Line 3: local _y: string = 20
        let source = "--!strict\nlocal _x: number = 10\nlocal _y: string = 20\n";

        let diagnostics = check_single(&mut analyzer, "main_precise", source);

        assert!(!diagnostics.is_empty(), "Expected a type error diagnostic");
        let diag = &diagnostics[0];
        // Luau error locations use 0-based line numbers, so line 3 is index 2.
        assert_eq!(diag.line, 2, "got: {:?}", diag);
        assert!(diag.col < 100, "got: {:?}", diag);
    }

    #[test]
    fn test_resolver_returns_none() {
        let mut analyzer = NativeAnalyzer::new();
        let source = "--!strict\nlocal _dep = require('missing_module')\n";
        let resolver_called = Cell::new(false);

        let diagnostics = analyzer.check(
            "main_resolver",
            |name| match name {
                "main_resolver" => Some(source.to_string()),
                "missing_module" => {
                    resolver_called.set(true);
                    None // the module fails to load
                }
                _ => None,
            },
            |current, required| {
                (current == "main_resolver" && required == "missing_module")
                    .then(|| "missing_module".to_string())
            },
        );

        // The source resolver must have been asked for the required module.
        assert!(
            resolver_called.get(),
            "resolver was never asked for 'missing_module'"
        );
        // Errors about the unloadable required module are expected not to be reported
        // for the entry module.
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    #[test]
    fn test_multithreaded_analyzer() {
        let mut analyzer = NativeAnalyzer::new();
        analyzer.add_definitions("declare function thread_safe_helper(): ()\n");

        // NativeAnalyzer is Send, so it can move to another thread and back.
        let handle = thread::spawn(move || {
            let diagnostics = check_single(&mut analyzer, "main", "thread_safe_helper()\n");
            assert!(
                diagnostics.is_empty(),
                "Expected no diagnostics, got: {:?}",
                diagnostics
            );
            analyzer
        });

        let _analyzer = handle.join().expect("analyzer thread panicked");
    }

    #[test]
    fn test_default_analyzer() {
        let analyzer = NativeAnalyzer::default();
        assert!(!analyzer.ptr.is_null());
    }

    #[test]
    fn test_diagnostics_clone_and_debug() {
        let diag = Diagnostic {
            severity: 0,
            line: 1,
            col: 2,
            end_line: 3,
            end_col: 4,
            message: "Test message".to_string(),
        };

        let cloned = diag.clone();
        assert_eq!(cloned.severity, diag.severity);
        assert_eq!(cloned.line, diag.line);
        assert_eq!(cloned.col, diag.col);
        assert_eq!(cloned.end_line, diag.end_line);
        assert_eq!(cloned.end_col, diag.end_col);
        assert_eq!(cloned.message, diag.message);

        let debug_str = format!("{:?}", diag);
        assert!(debug_str.contains("Test message"));
    }

    #[test]
    fn test_check_with_nested_relative_modules() {
        let mut analyzer = NativeAnalyzer::new();

        // main requires foo/bar, which in turn requires ../baz.
        let main_src = "local _bar = require('foo/bar')\n";
        let bar_src = "local _baz = require('../baz')\nlocal M = {}\nreturn M\n";
        let baz_src = "local M = {}\nM.value = 100\nreturn M\n";

        let diagnostics = analyzer.check(
            "main",
            |name| match name {
                "main" => Some(main_src.to_string()),
                "foo/bar" => Some(bar_src.to_string()),
                "baz" => Some(baz_src.to_string()),
                _ => None,
            },
            |current, required| match (current, required) {
                ("main", "foo/bar") => Some("foo/bar".to_string()),
                // "../baz" relative to "foo/bar" is the top-level module "baz".
                ("foo/bar", "../baz") => Some("baz".to_string()),
                _ => None,
            },
        );

        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }
}