use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int, c_uint, c_void};
/// A single diagnostic reported by the native analyzer.
///
/// Values are stored as delivered through the diagnostic callback; this type
/// does not interpret them. Equality compares every field, so two diagnostics
/// are equal only when severity, span and message all match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Diagnostic {
    /// Severity code, narrowed from the C `int` supplied by the native side.
    /// NOTE(review): the tests in this file treat `0` and `1` as the
    /// error/warning levels; confirm the exact mapping against the native header.
    pub severity: u8,
    /// Line on which the reported span starts. Appears to be zero-based (the
    /// tests expect an error on the third source line to be reported as `2`).
    pub line: u32,
    /// Column at which the reported span starts.
    pub col: u32,
    /// Line on which the reported span ends.
    pub end_line: u32,
    /// Column at which the reported span ends.
    pub end_col: u32,
    /// Human-readable message; empty when the native side supplied no text.
    pub message: String,
}
/// Opaque stand-in for the analyzer object owned by the native library.
///
/// The type is zero-sized on the Rust side and is only ever used behind raw
/// pointers exchanged with the `luau_analyzer_*` functions.
#[repr(C)]
pub struct LuauAnalyzerOpaque {
    // Private zero-length field: keeps the struct from being built outside
    // this module while giving it a C-compatible layout.
    _private: [u8; 0],
}
/// C callback invoked by the native analyzer once per reported diagnostic.
///
/// `context` is the opaque pointer supplied to `luau_analyzer_check`; the
/// remaining arguments carry the severity, the start/end position and a
/// C-string message (which may be null).
type DiagnosticCallback = unsafe extern "C" fn(
    context: *mut c_void,
    severity: c_int,
    line: c_uint,
    col: c_uint,
    end_line: c_uint,
    end_col: c_uint,
    message: *const c_char,
);
/// C callback through which the native analyzer asks for a module's source.
///
/// Receives the opaque `context` pointer and the module name as a C string;
/// returns the source text as a C string, or null when none is available.
type ReadSourceCallback = unsafe extern "C" fn(
    context: *mut c_void,
    module_name: *const c_char,
) -> *const c_char;
/// C callback through which the native analyzer resolves a `require`d name.
///
/// Receives the opaque `context` pointer, the name of the requiring module
/// and the required name (both C strings); returns the resolved module name
/// as a C string, or null when it cannot be resolved.
type ResolveModuleCallback = unsafe extern "C" fn(
    context: *mut c_void,
    current_module: *const c_char,
    required_name: *const c_char,
) -> *const c_char;
// Entry points exported by the native analyzer library. Everything here is
// `unsafe` to call; the safe surface is `NativeAnalyzer`.
unsafe extern "C" {
    /// Returns a handle to a native analyzer.
    /// NOTE(review): presumably a fresh allocation owned by the caller — confirm.
    fn luau_analyzer_create() -> *mut LuauAnalyzerOpaque;

    /// Releases an analyzer handle.
    /// NOTE(review): presumably must be one obtained from `luau_analyzer_create`.
    fn luau_analyzer_destroy(analyzer: *mut LuauAnalyzerOpaque);

    /// Hands a definitions source text (C string) to the analyzer.
    fn luau_analyzer_add_definitions(analyzer: *mut LuauAnalyzerOpaque, source: *const c_char);

    /// Checks `module_name`, calling back into the supplied function pointers
    /// with `context` passed through as their first argument.
    fn luau_analyzer_check(
        analyzer: *mut LuauAnalyzerOpaque,
        module_name: *const c_char,
        read_callback: Option<ReadSourceCallback>,
        resolve_callback: Option<ResolveModuleCallback>,
        diag_callback: Option<DiagnosticCallback>,
        context: *mut c_void,
    );
}
/// Per-call state shared with the native callbacks through the opaque
/// `context` pointer during a single `NativeAnalyzer::check` invocation.
struct CheckContext<'a> {
    /// Diagnostics accumulated by the diagnostic callback, in arrival order.
    diagnostics: Vec<Diagnostic>,
    /// Owns every C string whose pointer has been handed back to native code,
    /// so those pointers stay valid for as long as this context is alive.
    /// Source texts are keyed by module name; resolved paths by a composite key.
    cached_strings: HashMap<String, CString>,
    /// Maps a module name to its source text, or `None` when unavailable.
    resolver: &'a dyn Fn(&str) -> Option<String>,
    /// Maps `(current module, required name)` to a resolved module name, or
    /// `None` when the requirement cannot be resolved.
    path_resolver: &'a dyn Fn(&str, &str) -> Option<String>,
}
/// Owning wrapper around a native analyzer handle.
///
/// The handle is obtained from `luau_analyzer_create` and released through
/// `luau_analyzer_destroy` when the wrapper is dropped.
pub struct NativeAnalyzer {
    /// Raw handle returned by the native library; may be null.
    ptr: *mut LuauAnalyzerOpaque,
}
impl NativeAnalyzer {
    /// Creates an analyzer by calling the native `luau_analyzer_create`.
    ///
    /// The handle is stored exactly as returned. Should the native side hand
    /// back a null pointer, [`add_definitions`](Self::add_definitions) becomes
    /// a no-op and [`check`](Self::check) returns no diagnostics.
    pub fn new() -> Self {
        // SAFETY: the declared function takes no arguments.
        // NOTE(review): presumably it has no other preconditions — confirm
        // against the native header.
        let ptr = unsafe { luau_analyzer_create() };
        Self { ptr }
    }

    /// Passes `source` to the native `luau_analyzer_add_definitions`.
    ///
    /// Best-effort: nothing happens when the handle is null, or when `source`
    /// contains an interior NUL byte and therefore cannot become a C string.
    pub fn add_definitions(&mut self, source: &str) {
        if self.ptr.is_null() {
            return;
        }
        let Ok(c_source) = CString::new(source) else {
            return;
        };
        // SAFETY: `self.ptr` is non-null and `c_source` is a valid
        // NUL-terminated string that outlives the call.
        unsafe { luau_analyzer_add_definitions(self.ptr, c_source.as_ptr()) };
    }

    /// Runs the native check for `module_name` and returns what it reported.
    ///
    /// * `resolver` maps a module name to its source text (`None` = unknown).
    /// * `path_resolver` maps `(current module, required name)` to the name of
    ///   the module that should be loaded (`None` = unresolved).
    ///
    /// Results of both closures are cached for the duration of the call, so
    /// each distinct successful lookup invokes the closure once. Diagnostics
    /// are returned in the order the native side delivered them; a severity
    /// outside `0..=255` is stored as `u8::MAX` rather than wrapped.
    ///
    /// Returns an empty vector without calling into native code when the
    /// handle is null or `module_name` contains an interior NUL byte.
    ///
    /// NOTE(review): both closures run inside `extern "C"` callbacks; confirm
    /// that a panic inside them is acceptable, since it cannot unwind through
    /// the native frames.
    pub fn check<F, P>(
        &mut self,
        module_name: &str,
        resolver: F,
        path_resolver: P,
    ) -> Vec<Diagnostic>
    where
        F: Fn(&str) -> Option<String>,
        P: Fn(&str, &str) -> Option<String>,
    {
        /// Supplies the source of `mod_name`, caching it so the returned
        /// pointer stays valid until the enclosing `check` call finishes.
        unsafe extern "C" fn read_callback(
            ctx_ptr: *mut c_void,
            mod_name: *const c_char,
        ) -> *const c_char {
            if ctx_ptr.is_null() || mod_name.is_null() {
                return std::ptr::null();
            }
            // SAFETY: `check` passes a pointer to its live `CheckContext` as
            // the context argument; assumes the native side forwards it
            // unchanged and does not call back concurrently.
            let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
            // SAFETY: non-null was checked above; assumes the native side
            // passes a NUL-terminated string.
            let name = unsafe { CStr::from_ptr(mod_name) }.to_string_lossy();

            if let Some(cached) = ctx.cached_strings.get(name.as_ref()) {
                return cached.as_ptr();
            }
            match (ctx.resolver)(name.as_ref()).and_then(|src| CString::new(src).ok()) {
                Some(c_src) => {
                    // The pointer targets the CString's heap buffer, which
                    // does not move when the CString is moved into the map.
                    let ptr = c_src.as_ptr();
                    ctx.cached_strings.insert(name.into_owned(), c_src);
                    ptr
                }
                None => std::ptr::null(),
            }
        }

        /// Resolves `req_name` as required from `curr_mod`, caching the answer
        /// so the returned pointer stays valid until `check` finishes.
        unsafe extern "C" fn resolve_callback(
            ctx_ptr: *mut c_void,
            curr_mod: *const c_char,
            req_name: *const c_char,
        ) -> *const c_char {
            if ctx_ptr.is_null() || curr_mod.is_null() || req_name.is_null() {
                return std::ptr::null();
            }
            // SAFETY: see `read_callback`.
            let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
            // SAFETY: both pointers were checked for null above; assumes they
            // are NUL-terminated strings.
            let current = unsafe { CStr::from_ptr(curr_mod) }.to_string_lossy();
            let required = unsafe { CStr::from_ptr(req_name) }.to_string_lossy();

            // Text decoded from a C string never contains NUL, so a NUL
            // separator makes the pair unambiguous and keeps these keys
            // disjoint from the plain module-name keys used for sources.
            let cache_key = format!("{}\0{}", current, required);
            if let Some(cached) = ctx.cached_strings.get(&cache_key) {
                return cached.as_ptr();
            }
            match (ctx.path_resolver)(current.as_ref(), required.as_ref())
                .and_then(|resolved| CString::new(resolved).ok())
            {
                Some(c_resolved) => {
                    let ptr = c_resolved.as_ptr();
                    ctx.cached_strings.insert(cache_key, c_resolved);
                    ptr
                }
                None => std::ptr::null(),
            }
        }

        /// Records one diagnostic in the context.
        unsafe extern "C" fn diag_callback(
            ctx_ptr: *mut c_void,
            severity: c_int,
            line: c_uint,
            col: c_uint,
            end_line: c_uint,
            end_col: c_uint,
            message: *const c_char,
        ) {
            if ctx_ptr.is_null() {
                return;
            }
            // SAFETY: see `read_callback`.
            let ctx = unsafe { &mut *(ctx_ptr as *mut CheckContext) };
            let message = if message.is_null() {
                String::new()
            } else {
                // SAFETY: non-null; assumes a NUL-terminated string.
                unsafe { CStr::from_ptr(message) }
                    .to_string_lossy()
                    .into_owned()
            };
            // A bare `as u8` would wrap (256 would become 0); map anything
            // that does not fit to `u8::MAX` instead.
            let severity = if (0..=c_int::from(u8::MAX)).contains(&severity) {
                severity as u8
            } else {
                u8::MAX
            };
            ctx.diagnostics.push(Diagnostic {
                severity,
                line,
                col,
                end_line,
                end_col,
                message,
            });
        }

        if self.ptr.is_null() {
            return Vec::new();
        }
        let Ok(mod_cstr) = CString::new(module_name) else {
            return Vec::new();
        };

        let mut context = CheckContext {
            diagnostics: Vec::new(),
            cached_strings: HashMap::new(),
            resolver: &resolver,
            path_resolver: &path_resolver,
        };
        let ctx_void = &mut context as *mut CheckContext as *mut c_void;
        // SAFETY: the handle is non-null, `mod_cstr` outlives the call, and
        // `ctx_void` points at `context`, which lives until after the call
        // returns and is not otherwise accessed while it runs.
        unsafe {
            luau_analyzer_check(
                self.ptr,
                mod_cstr.as_ptr(),
                Some(read_callback),
                Some(resolve_callback),
                Some(diag_callback),
                ctx_void,
            );
        }
        context.diagnostics
    }
}
impl Default for NativeAnalyzer {
fn default() -> Self {
Self::new()
}
}
/// Releases the native handle when the wrapper goes out of scope.
impl Drop for NativeAnalyzer {
    fn drop(&mut self) {
        // Nothing to release when no handle is held.
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: the handle is non-null and is cleared right after, so it is
        // handed to the native destructor at most once from here.
        unsafe {
            luau_analyzer_destroy(self.ptr);
        }
        self.ptr = std::ptr::null_mut();
    }
}
// The raw-pointer field makes `NativeAnalyzer` `!Send` by default; this impl
// opts back in so an analyzer can be moved to another thread (the tests do so
// via `thread::spawn`).
// NOTE(review): soundness relies on the native analyzer having no
// thread-affine state — confirm against the native library.
unsafe impl Send for NativeAnalyzer {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Source resolver backed by a fixed `(module name, source text)` table.
    fn sources<'a>(table: &'a [(&'a str, &'a str)]) -> impl Fn(&str) -> Option<String> + 'a {
        move |name: &str| {
            table
                .iter()
                .find(|entry| entry.0 == name)
                .map(|entry| entry.1.to_string())
        }
    }

    /// Path resolver backed by a fixed `(current, required, resolved)` table.
    fn requires<'a>(
        edges: &'a [(&'a str, &'a str, &'a str)],
    ) -> impl Fn(&str, &str) -> Option<String> + 'a {
        move |current: &str, required: &str| {
            edges
                .iter()
                .find(|edge| edge.0 == current && edge.1 == required)
                .map(|edge| edge.2.to_string())
        }
    }

    /// Path resolver for modules that are not expected to `require` anything.
    fn no_requires(_current: &str, _required: &str) -> Option<String> {
        None
    }

    /// A freshly created analyzer holds a non-null handle.
    #[test]
    fn test_create_analyzer() {
        assert!(!NativeAnalyzer::new().ptr.is_null());
    }

    /// Well-typed code yields no diagnostics.
    #[test]
    fn test_check_simple_no_errors() {
        let source = "local _x: number = 10\nlocal _y: number = _x + 5\n";
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check("main", sources(&[("main", source)]), no_requires);
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    /// Assigning a string to a `number` local is reported with both type names.
    #[test]
    fn test_check_type_error() {
        let source = "local _x: number = 'hello'\n";
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check("main", sources(&[("main", source)]), no_requires);
        assert!(
            !diagnostics.is_empty(),
            "Expected at least one type error diagnostic"
        );
        let diag = &diagnostics[0];
        assert!(
            matches!(diag.severity, 0 | 1),
            "Expected error or warning severity, got: {}",
            diag.severity
        );
        for expected in ["string", "number"] {
            assert!(
                diag.message.contains(expected),
                "Expected message to mention '{}', got: {}",
                expected,
                diag.message
            );
        }
    }

    /// An incomplete statement is reported with severity 0.
    #[test]
    fn test_check_syntax_error() {
        let source = "local x = \n";
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check("main", sources(&[("main", source)]), no_requires);
        assert!(!diagnostics.is_empty(), "Expected syntax error diagnostic");
        assert_eq!(diagnostics[0].severity, 0);
    }

    /// A required module is located through both resolvers and type-checks.
    #[test]
    fn test_check_with_submodule() {
        let main_source = "local dep = require('dependency')\nlocal _x: number = dep.value\n";
        let dep_source = "local M = {}\nM.value = 42\nreturn M\n";
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check(
            "main",
            sources(&[("main", main_source), ("dependency", dep_source)]),
            requires(&[("main", "dependency", "dependency")]),
        );
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }

    /// One analyzer can check several unrelated modules in sequence.
    #[test]
    fn test_multiple_checks_same_analyzer() {
        let mut analyzer = NativeAnalyzer::new();
        let cases = [
            ("mod1", "local _x: number = 10\n"),
            ("mod2", "local _y: string = 'hello'\n"),
        ];
        for (module, source) in cases {
            let diagnostics = analyzer.check(module, sources(&[(module, source)]), no_requires);
            assert!(diagnostics.is_empty());
        }
    }

    /// Declarations added via `add_definitions` are visible to later checks.
    #[test]
    fn test_custom_definitions() {
        let mut analyzer = NativeAnalyzer::new();
        analyzer.add_definitions("declare function my_global_helper(val: string): number\n");

        let correct_source = "--!strict\nlocal _x: number = my_global_helper('test')\n";
        let diagnostics = analyzer.check(
            "main_correct",
            sources(&[("main_correct", correct_source)]),
            no_requires,
        );
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );

        let incorrect_source = "--!strict\nlocal _x: number = my_global_helper(123)\n";
        let diagnostics2 = analyzer.check(
            "main_incorrect",
            sources(&[("main_incorrect", incorrect_source)]),
            no_requires,
        );
        println!("test_custom_definitions diagnostics: {:?}", diagnostics2);
        assert!(
            !diagnostics2.is_empty(),
            "Expected a type error due to parameter type mismatch"
        );
        let msg = &diagnostics2[0].message;
        assert!(
            msg.contains("number") || msg.contains("string"),
            "Got message: {}",
            msg
        );
    }

    /// The reported position points at the offending (third) source line.
    #[test]
    fn test_precise_error_coordinates() {
        let source = "--!strict\nlocal _x: number = 10\nlocal _y: string = 20\n";
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check(
            "main_precise",
            sources(&[("main_precise", source)]),
            no_requires,
        );
        assert!(!diagnostics.is_empty());
        let diag = &diagnostics[0];
        assert_eq!(diag.line, 2);
        assert!(diag.col < 100);
    }

    /// The source resolver is consulted for a resolved-but-missing module.
    #[test]
    fn test_resolver_returns_none() {
        use std::cell::RefCell;
        use std::rc::Rc;

        let source = "--!strict\nlocal _dep = require('missing_module')\n";
        let resolver_called = Rc::new(RefCell::new(false));
        let flag = Rc::clone(&resolver_called);
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check(
            "main_resolver",
            move |name: &str| match name {
                "main_resolver" => Some(source.to_string()),
                "missing_module" => {
                    *flag.borrow_mut() = true;
                    None
                }
                _ => None,
            },
            requires(&[("main_resolver", "missing_module", "missing_module")]),
        );
        println!("test_resolver_returns_none diagnostics: {:?}", diagnostics);
        assert!(*resolver_called.borrow());
        assert!(diagnostics.is_empty());
    }

    /// An analyzer can be moved into another thread, used there, and returned.
    #[test]
    fn test_multithreaded_analyzer() {
        use std::thread;

        let mut analyzer = NativeAnalyzer::new();
        analyzer.add_definitions("declare function thread_safe_helper(): ()\n");
        let worker = thread::spawn(move || {
            let diagnostics = analyzer.check(
                "main",
                sources(&[("main", "thread_safe_helper()\n")]),
                no_requires,
            );
            assert!(diagnostics.is_empty());
            analyzer
        });
        let _analyzer = worker.join().unwrap();
    }

    /// `Default` produces an analyzer with a non-null handle.
    #[test]
    fn test_default_analyzer() {
        assert!(!NativeAnalyzer::default().ptr.is_null());
    }

    /// `Diagnostic` clones field-for-field and its `Debug` output shows the message.
    #[test]
    fn test_diagnostics_clone_and_debug() {
        let diag = Diagnostic {
            severity: 0,
            line: 1,
            col: 2,
            end_line: 3,
            end_col: 4,
            message: "Test message".to_string(),
        };
        let cloned = diag.clone();
        assert_eq!(cloned.severity, diag.severity);
        assert_eq!(cloned.line, diag.line);
        assert_eq!(cloned.col, diag.col);
        assert_eq!(cloned.end_line, diag.end_line);
        assert_eq!(cloned.end_col, diag.end_col);
        assert_eq!(cloned.message, diag.message);

        let debug_str = format!("{:?}", diag);
        assert!(debug_str.contains("Test message"));
    }

    /// Nested and parent-relative `require` paths are resolved per requiring module.
    #[test]
    fn test_check_with_nested_relative_modules() {
        let main_src = "local _bar = require('foo/bar')\n";
        let bar_src = "local _baz = require('../baz')\nlocal M = {}\nreturn M\n";
        let baz_src = "local M = {}\nM.value = 100\nreturn M\n";
        let mut analyzer = NativeAnalyzer::new();
        let diagnostics = analyzer.check(
            "main",
            sources(&[("main", main_src), ("foo/bar", bar_src), ("baz", baz_src)]),
            requires(&[("main", "foo/bar", "foo/bar"), ("foo/bar", "../baz", "baz")]),
        );
        assert!(
            diagnostics.is_empty(),
            "Expected no diagnostics, got: {:?}",
            diagnostics
        );
    }
}