1use std::collections::HashMap;
2use std::ffi::{CStr, CString};
3use std::os::raw::{c_char, c_int, c_uint, c_void};
4
/// A single diagnostic reported by the native analyzer.
///
/// The numeric fields are stored as received from the native diagnostic
/// callback; the message is an owned copy of the callback's C string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Diagnostic {
    /// Severity code reported by the analyzer.
    ///
    /// NOTE(review): the tests in this file accept `0` or `1` as
    /// "error or warning" and expect `0` for a syntax error — confirm the
    /// exact mapping against the native side.
    pub severity: u8,
    /// Line on which the reported span starts.
    ///
    /// NOTE(review): the tests in this file expect `2` for a problem on the
    /// third source line, which suggests zero-based numbering — confirm.
    pub line: u32,
    /// Column at which the reported span starts (unit not visible from here).
    pub col: u32,
    /// Line on which the reported span ends.
    pub end_line: u32,
    /// Column at which the reported span ends.
    pub end_col: u32,
    /// Human-readable description of the problem.
    pub message: String,
}
14
/// Opaque stand-in for the analyzer object owned by the native library.
///
/// Rust only ever handles this type behind a raw pointer. The zero-sized
/// private field keeps it from being constructed outside this module, and the
/// marker field opts the type out of `Send`, `Sync` and `Unpin`, so nothing can
/// be assumed about the native object beyond what the FFI functions promise.
#[repr(C)]
pub struct LuauAnalyzerOpaque {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
19
/// C-ABI function pointer through which one diagnostic is delivered.
///
/// The first argument is the opaque user pointer handed to the native check
/// call; the rest are the severity code, the start and end positions of the
/// span, and the message as a C string.
type DiagnosticCallback = unsafe extern "C" fn(
    user_data: *mut c_void,
    severity: c_int,
    start_line: c_uint,
    start_col: c_uint,
    end_line: c_uint,
    end_col: c_uint,
    message: *const c_char,
);
29
/// C-ABI function pointer used to fetch the source text of a module.
///
/// Receives the opaque user pointer and the module name as a C string, and
/// returns the source as a C string. The wrapper in this file returns null
/// when it has no source for the requested module.
type ReadSourceCallback = unsafe extern "C" fn(
    user_data: *mut c_void,
    module: *const c_char,
) -> *const c_char;
32
/// C-ABI function pointer used to turn a `require` argument into a module name.
///
/// Receives the opaque user pointer, the name of the module containing the
/// `require`, and the required name as written, both as C strings. Returns the
/// resolved module name as a C string; the wrapper in this file returns null
/// when resolution fails.
type ResolveModuleCallback = unsafe extern "C" fn(
    user_data: *mut c_void,
    requiring_module: *const c_char,
    required: *const c_char,
) -> *const c_char;
38
39unsafe extern "C" {
40 fn luau_analyzer_create() -> *mut LuauAnalyzerOpaque;
41 fn luau_analyzer_destroy(analyzer: *mut LuauAnalyzerOpaque);
42 fn luau_analyzer_add_definitions(analyzer: *mut LuauAnalyzerOpaque, source: *const c_char);
43 fn luau_analyzer_check(
44 analyzer: *mut LuauAnalyzerOpaque,
45 module_name: *const c_char,
46 read_callback: Option<ReadSourceCallback>,
47 resolve_callback: Option<ResolveModuleCallback>,
48 diag_callback: Option<DiagnosticCallback>,
49 context: *mut c_void,
50 );
51}
52
53struct CheckContext<'a> {
54 diagnostics: Vec<Diagnostic>,
55 cached_strings: HashMap<String, CString>,
56 resolver: &'a dyn Fn(&str) -> Option<String>,
57 path_resolver: &'a dyn Fn(&str, &str) -> Option<String>,
58}
59
60pub struct NativeAnalyzer {
61 ptr: *mut LuauAnalyzerOpaque,
62}
63
64impl NativeAnalyzer {
65 pub fn new() -> Self {
66 unsafe {
67 Self {
68 ptr: luau_analyzer_create(),
69 }
70 }
71 }
72
73 pub fn add_definitions(&mut self, source: &str) {
74 if let Ok(c_str) = CString::new(source) {
75 unsafe {
76 luau_analyzer_add_definitions(self.ptr, c_str.as_ptr());
77 }
78 }
79 }
80
81 pub fn check<F, P>(
82 &mut self,
83 module_name: &str,
84 resolver: F,
85 path_resolver: P,
86 ) -> Vec<Diagnostic>
87 where
88 F: Fn(&str) -> Option<String>,
89 P: Fn(&str, &str) -> Option<String>,
90 {
91 let mut context = CheckContext {
92 diagnostics: Vec::new(),
93 cached_strings: HashMap::new(),
94 resolver: &resolver,
95 path_resolver: &path_resolver,
96 };
97
98 if let Ok(mod_cstr) = CString::new(module_name) {
99 unsafe extern "C" fn read_callback(
100 ctx_ptr: *mut c_void,
101 mod_name: *const c_char,
102 ) -> *const c_char {
103 let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
104 if mod_name.is_null() {
105 return std::ptr::null();
106 }
107 let name_str = unsafe { CStr::from_ptr(mod_name) }.to_string_lossy();
108 if let Some(c_str) = ctx.cached_strings.get(name_str.as_ref()) {
109 return c_str.as_ptr();
110 }
111 if let Some(src) = (ctx.resolver)(name_str.as_ref())
112 && let Ok(c_str) = CString::new(src)
113 {
114 let ptr = c_str.as_ptr();
115 ctx.cached_strings.insert(name_str.into_owned(), c_str);
116 return ptr;
117 }
118 std::ptr::null()
119 }
120
121 unsafe extern "C" fn resolve_callback(
122 ctx_ptr: *mut c_void,
123 curr_mod: *const c_char,
124 req_name: *const c_char,
125 ) -> *const c_char {
126 let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
127 if curr_mod.is_null() || req_name.is_null() {
128 return std::ptr::null();
129 }
130 let curr_mod_str = unsafe { CStr::from_ptr(curr_mod) }.to_string_lossy();
131 let req_name_str = unsafe { CStr::from_ptr(req_name) }.to_string_lossy();
132
133 let cache_key = format!("RESOLVED:{}:{}", curr_mod_str, req_name_str);
134 if let Some(c_str) = ctx.cached_strings.get(&cache_key) {
135 return c_str.as_ptr();
136 }
137
138 if let Some(resolved) =
139 (ctx.path_resolver)(curr_mod_str.as_ref(), req_name_str.as_ref())
140 && let Ok(c_str) = CString::new(resolved)
141 {
142 let ptr = c_str.as_ptr();
143 ctx.cached_strings.insert(cache_key, c_str);
144 return ptr;
145 }
146 std::ptr::null()
147 }
148
149 unsafe extern "C" fn diag_callback(
150 ctx_ptr: *mut c_void,
151 severity: c_int,
152 line: c_uint,
153 col: c_uint,
154 end_line: c_uint,
155 end_col: c_uint,
156 message: *const c_char,
157 ) {
158 let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
159 let msg_str = if message.is_null() {
160 String::new()
161 } else {
162 unsafe { CStr::from_ptr(message) }
163 .to_string_lossy()
164 .into_owned()
165 };
166 ctx.diagnostics.push(Diagnostic {
167 severity: severity as u8,
168 line,
169 col,
170 end_line,
171 end_col,
172 message: msg_str,
173 });
174 }
175
176 unsafe {
177 let ctx_void = &mut context as *mut CheckContext as *mut c_void;
178 luau_analyzer_check(
179 self.ptr,
180 mod_cstr.as_ptr(),
181 Some(read_callback),
182 Some(resolve_callback),
183 Some(diag_callback),
184 ctx_void,
185 );
186 }
187 }
188
189 context.diagnostics
190 }
191}
192
193impl Default for NativeAnalyzer {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199impl Drop for NativeAnalyzer {
200 fn drop(&mut self) {
201 unsafe {
202 if !self.ptr.is_null() {
203 luau_analyzer_destroy(self.ptr);
204 self.ptr = std::ptr::null_mut();
205 }
206 }
207 }
208}
209
210unsafe impl Send for NativeAnalyzer {}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `module` with a resolver that knows only that module (backed by
    /// `source`) and a path resolver that resolves nothing.
    fn check_single(analyzer: &mut NativeAnalyzer, module: &str, source: &str) -> Vec<Diagnostic> {
        analyzer.check(
            module,
            |name| (name == module).then(|| source.to_string()),
            |_, _| None,
        )
    }

    /// A freshly created analyzer holds a non-null handle.
    #[test]
    fn test_create_analyzer() {
        let analyzer = NativeAnalyzer::new();
        assert!(!analyzer.ptr.is_null());
    }

    /// Well-typed code yields no diagnostics.
    #[test]
    fn test_check_simple_no_errors() {
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = check_single(
            &mut analyzer,
            "main",
            "local _x: number = 10\nlocal _y: number = _x + 5\n",
        );

        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    /// Assigning a string to a `number` local is reported, naming both types.
    #[test]
    fn test_check_type_error() {
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = check_single(&mut analyzer, "main", "local _x: number = 'hello'\n");

        assert!(
            !diagnostics.is_empty(),
            "Expected at least one type error diagnostic"
        );
        let first = &diagnostics[0];
        assert!(
            matches!(first.severity, 0 | 1),
            "Expected error or warning severity, got: {}",
            first.severity
        );
        for expected in ["string", "number"] {
            assert!(
                first.message.contains(expected),
                "Expected message to mention '{}', got: {}",
                expected,
                first.message
            );
        }
    }

    /// An incomplete statement is reported with severity `0`.
    #[test]
    fn test_check_syntax_error() {
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = check_single(&mut analyzer, "main", "local x = \n");

        assert!(!diagnostics.is_empty(), "Expected syntax error diagnostic");
        assert_eq!(diagnostics[0].severity, 0);
    }

    /// A `require` that both resolvers can satisfy type-checks cleanly.
    #[test]
    fn test_check_with_submodule() {
        let mut analyzer = NativeAnalyzer::new();
        let main_source = "local dep = require('dependency')\nlocal _x: number = dep.value\n";
        let dep_source = "local M = {}\nM.value = 42\nreturn M\n";

        let read = |name: &str| {
            if name == "main" {
                Some(main_source.to_string())
            } else if name == "dependency" {
                Some(dep_source.to_string())
            } else {
                None
            }
        };
        let resolve = |current: &str, required: &str| match (current, required) {
            ("main", "dependency") => Some("dependency".to_string()),
            _ => None,
        };

        let diagnostics = analyzer.check("main", read, resolve);

        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    /// One analyzer can check several unrelated modules in a row.
    #[test]
    fn test_multiple_checks_same_analyzer() {
        let mut analyzer = NativeAnalyzer::new();

        let diagnostics1 = check_single(&mut analyzer, "mod1", "local _x: number = 10\n");
        assert!(diagnostics1.is_empty());

        let diagnostics2 = check_single(&mut analyzer, "mod2", "local _y: string = 'hello'\n");
        assert!(diagnostics2.is_empty());
    }

    /// Declarations added via `add_definitions` are visible to checked code.
    #[test]
    fn test_custom_definitions() {
        let mut analyzer = NativeAnalyzer::new();
        analyzer.add_definitions("declare function my_global_helper(val: string): number\n");

        // Calling the declared helper with the right argument type is clean.
        let diagnostics = check_single(
            &mut analyzer,
            "main_correct",
            "--!strict\nlocal _x: number = my_global_helper('test')\n",
        );
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );

        // Calling it with a number instead of a string must be reported.
        let diagnostics2 = check_single(
            &mut analyzer,
            "main_incorrect",
            "--!strict\nlocal _x: number = my_global_helper(123)\n",
        );
        println!("test_custom_definitions diagnostics: {:?}", diagnostics2);
        assert!(
            !diagnostics2.is_empty(),
            "Expected a type error due to parameter type mismatch"
        );
        let msg = &diagnostics2[0].message;
        assert!(
            msg.contains("number") || msg.contains("string"),
            "Got message: {}",
            msg
        );
    }

    /// The first diagnostic points at the offending line.
    #[test]
    fn test_precise_error_coordinates() {
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = check_single(
            &mut analyzer,
            "main_precise",
            "--!strict\nlocal _x: number = 10\nlocal _y: string = 20\n",
        );

        assert!(!diagnostics.is_empty());
        let first = &diagnostics[0];
        assert_eq!(first.line, 2);
        assert!(first.col < 100);
    }

    /// A resolvable path whose source is unavailable still reaches the source
    /// resolver, and the check produces no diagnostics.
    #[test]
    fn test_resolver_returns_none() {
        use std::cell::Cell;

        let mut analyzer = NativeAnalyzer::new();
        let source = "--!strict\nlocal _dep = require('missing_module')\n";
        let resolver_called = Cell::new(false);

        let diagnostics = analyzer.check(
            "main_resolver",
            |name| match name {
                "main_resolver" => Some(source.to_string()),
                "missing_module" => {
                    resolver_called.set(true);
                    None
                }
                _ => None,
            },
            |current, required| match (current, required) {
                ("main_resolver", "missing_module") => Some("missing_module".to_string()),
                _ => None,
            },
        );

        println!("test_resolver_returns_none diagnostics: {:?}", diagnostics);
        assert!(resolver_called.get());
        assert!(diagnostics.is_empty());
    }

    /// An analyzer can be moved into another thread, used there and moved back.
    #[test]
    fn test_multithreaded_analyzer() {
        use std::thread;

        let mut analyzer = NativeAnalyzer::new();
        analyzer.add_definitions("declare function thread_safe_helper(): ()\n");

        let worker = thread::spawn(move || {
            let diagnostics = check_single(&mut analyzer, "main", "thread_safe_helper()\n");
            assert!(diagnostics.is_empty());
            analyzer
        });

        let _analyzer = worker.join().unwrap();
    }

    /// `Default` produces a usable analyzer as well.
    #[test]
    fn test_default_analyzer() {
        let analyzer = NativeAnalyzer::default();
        assert!(!analyzer.ptr.is_null());
    }

    /// `Diagnostic` clones field-for-field and shows its message in `Debug`.
    #[test]
    fn test_diagnostics_clone_and_debug() {
        let original = Diagnostic {
            severity: 0,
            line: 1,
            col: 2,
            end_line: 3,
            end_col: 4,
            message: "Test message".to_string(),
        };
        let copy = original.clone();

        assert_eq!(copy.severity, original.severity);
        assert_eq!(copy.line, original.line);
        assert_eq!(copy.col, original.col);
        assert_eq!(copy.end_line, original.end_line);
        assert_eq!(copy.end_col, original.end_col);
        assert_eq!(copy.message, original.message);

        let rendered = format!("{:?}", original);
        assert!(rendered.contains("Test message"));
    }

    /// `require` chains through nested and parent-relative paths resolve.
    #[test]
    fn test_check_with_nested_relative_modules() {
        let mut analyzer = NativeAnalyzer::new();

        let main_src = "local _bar = require('foo/bar')\n";
        let bar_src = "local _baz = require('../baz')\nlocal M = {}\nreturn M\n";
        let baz_src = "local M = {}\nM.value = 100\nreturn M\n";

        let read = |name: &str| {
            let text = match name {
                "main" => main_src,
                "foo/bar" => bar_src,
                "baz" => baz_src,
                _ => return None,
            };
            Some(text.to_string())
        };
        let resolve = |current: &str, required: &str| match (current, required) {
            ("main", "foo/bar") => Some("foo/bar".to_string()),
            ("foo/bar", "../baz") => Some("baz".to_string()),
            _ => None,
        };

        let diagnostics = analyzer.check("main", read, resolve);

        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }
}