mod ffi;
use std::{
cmp::Ordering,
error::Error as StdError,
fmt,
marker::PhantomData,
ptr::{self, NonNull},
slice,
sync::Arc,
time::Duration,
};
/// Module name seeded into `CheckerOptions::default_module_name` by
/// `CheckerOptions::default`; checks fall back to it when no per-call name is given.
const DEFAULT_CHECK_MODULE_NAME: &str = "main";
/// Module name seeded into `CheckerOptions::default_definitions_module_name` by
/// `CheckerOptions::default`; `Checker::add_definitions` registers definitions under it.
const DEFAULT_DEFINITIONS_MODULE_NAME: &str = "@definitions";
/// Severity of a checker diagnostic.
///
/// Declaration order matters: the derived `Ord` places `Error` before `Warning`,
/// which `diagnostic_sort_key` relies on to list errors ahead of warnings at the
/// same position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
/// A diagnostic that makes `CheckResult::is_ok` return `false` (shim code `0`).
Error,
/// Any non-error diagnostic; every non-zero shim severity code maps here.
Warning,
}
/// One diagnostic produced by a check, copied out of the shim's result.
///
/// Positions are passed through unchanged from the shim. NOTE(review): whether
/// lines/columns are 0- or 1-based is not established in this file — confirm
/// against the shim before presenting them to users.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Diagnostic {
/// Line where the reported range starts.
pub line: u32,
/// Column where the reported range starts.
pub col: u32,
/// Line where the reported range ends.
pub end_line: u32,
/// Column where the reported range ends.
pub end_col: u32,
/// Whether this is an error or a warning.
pub severity: Severity,
/// Human-readable message; invalid UTF-8 from the shim is replaced lossily.
pub message: String,
}
/// Outcome of a single `Checker::check` / `Checker::check_with_options` call.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CheckResult {
/// Diagnostics reported for the checked module; `Checker::check_with_options`
/// returns them sorted by start position, then severity, then message.
pub diagnostics: Vec<Diagnostic>,
/// `true` when the shim's result flagged that the check hit its timeout.
pub timed_out: bool,
/// `true` when the shim's result flagged that the check was cancelled.
pub cancelled: bool,
}
/// One parameter of a script's entrypoint function, as reported by
/// `extract_entrypoint_schema`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EntrypointParam {
/// Parameter name as written in the source.
pub name: String,
/// Type annotation text as written in the source (e.g. `number?`).
pub annotation: String,
/// The shim's optional flag; in this file's tests a `T?` annotation is
/// reported as optional.
pub optional: bool,
}
/// Parameter list of a script's `return function(...) ... end` entrypoint.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct EntrypointSchema {
/// Parameters in the order reported by the shim (declaration order in the tests).
pub params: Vec<EntrypointParam>,
}
impl CheckResult {
pub fn is_ok(&self) -> bool {
!self
.diagnostics
.iter()
.any(|diagnostic| diagnostic.severity == Severity::Error)
}
pub fn errors(&self) -> Vec<&Diagnostic> {
self.diagnostics_with_severity(Severity::Error)
}
pub fn warnings(&self) -> Vec<&Diagnostic> {
self.diagnostics_with_severity(Severity::Warning)
}
fn diagnostics_with_severity(&self, severity: Severity) -> Vec<&Diagnostic> {
self.diagnostics
.iter()
.filter(|diagnostic| diagnostic.severity == severity)
.collect()
}
}
/// Static description of how this wrapper configures the Luau checker.
///
/// The values come from [`checker_policy`] and are fixed at compile time. NOTE(review): they are
/// presumably kept in sync with the native shim's configuration by hand. Nothing in this file
/// checks that, so confirm against the shim when changing either side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CheckerPolicy {
    /// Reported strict-mode flag (`true` for this wrapper).
    pub strict_mode: bool,
    /// Reported type-solver name (`"new"` for this wrapper).
    pub solver: &'static str,
    /// Reported batch/queue-API flag (`false` for this wrapper).
    pub exposes_batch_queue: bool,
}

/// Returns the fixed [`CheckerPolicy`] for this wrapper; usable in `const` contexts.
pub const fn checker_policy() -> CheckerPolicy {
    const POLICY: CheckerPolicy = CheckerPolicy {
        strict_mode: true,
        solver: "new",
        exposes_batch_queue: false,
    };
    POLICY
}
/// Errors returned by the checker wrapper.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The shim returned a null checker handle.
    CreateCheckerFailed,
    /// The shim returned a null cancellation-token handle.
    CreateCancellationTokenFailed,
    /// The shim rejected a definitions file; carries the shim's message.
    Definitions(String),
    /// Entrypoint-schema extraction failed; carries the shim's message.
    EntrypointSchema(String),
    /// An input string is longer than the `u32` length the FFI boundary can carry.
    InputTooLarge {
        /// Which input was rejected (e.g. `"source"`, `"definitions"`).
        kind: &'static str,
        /// Length of the rejected input in bytes.
        len: usize,
    },
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::CreateCheckerFailed => f.write_str("failed to create Luau checker"),
            Error::CreateCancellationTokenFailed => {
                f.write_str("failed to create Luau cancellation token")
            }
            Error::Definitions(detail) => {
                write!(f, "failed to load Luau definitions: {detail}")
            }
            Error::EntrypointSchema(detail) => {
                write!(f, "failed to extract Luau entrypoint schema: {detail}")
            }
            Error::InputTooLarge { kind, len } => write!(
                f,
                "{kind} input is too large for checker FFI boundary ({len} bytes)"
            ),
        }
    }
}

impl StdError for Error {}
/// Per-checker defaults used whenever a call leaves the corresponding option unset.
#[derive(Debug, Clone)]
pub struct CheckerOptions {
    /// Timeout applied when `CheckOptions::timeout` is `None`; `None` here means the shim is
    /// told there is no timeout.
    pub default_timeout: Option<Duration>,
    /// Module name applied when `CheckOptions::module_name` is `None` (defaults to `"main"`).
    pub default_module_name: String,
    /// Module name used by `Checker::add_definitions` (defaults to `"@definitions"`).
    pub default_definitions_module_name: String,
}

impl Default for CheckerOptions {
    /// No timeout, `"main"` as the check module, `"@definitions"` for definitions.
    fn default() -> Self {
        CheckerOptions {
            default_module_name: String::from(DEFAULT_CHECK_MODULE_NAME),
            default_definitions_module_name: String::from(DEFAULT_DEFINITIONS_MODULE_NAME),
            default_timeout: None,
        }
    }
}
/// Per-call overrides for `Checker::check_with_options`; the `Default` value overrides nothing.
#[derive(Debug, Clone, Copy, Default)]
pub struct CheckOptions<'a> {
/// Timeout for this call; `None` falls back to `CheckerOptions::default_timeout`.
pub timeout: Option<Duration>,
/// Module name for this call; `None` falls back to `CheckerOptions::default_module_name`.
pub module_name: Option<&'a str>,
/// Token whose raw handle is forwarded to the shim for this call; `None` passes a null token.
pub cancellation_token: Option<&'a CancellationToken>,
}
#[derive(Clone, Debug)]
pub struct CancellationToken {
inner: Arc<CancellationTokenInner>,
}
#[derive(Debug)]
struct CancellationTokenInner {
raw: NonNull<ffi::LuauCancellationToken>,
}
unsafe impl Send for CancellationTokenInner {}
unsafe impl Sync for CancellationTokenInner {}
impl Drop for CancellationTokenInner {
fn drop(&mut self) {
unsafe { ffi::luau_cancellation_token_free(self.raw.as_ptr()) };
}
}
impl CancellationToken {
pub fn new() -> Result<Self, Error> {
let raw = NonNull::new(unsafe { ffi::luau_cancellation_token_new() })
.ok_or(Error::CreateCancellationTokenFailed)?;
Ok(Self {
inner: Arc::new(CancellationTokenInner { raw }),
})
}
pub fn cancel(&self) {
unsafe { ffi::luau_cancellation_token_cancel(self.inner.raw.as_ptr()) };
}
pub fn reset(&self) {
unsafe { ffi::luau_cancellation_token_reset(self.inner.raw.as_ptr()) };
}
fn raw(&self) -> *mut ffi::LuauCancellationToken {
self.inner.raw.as_ptr()
}
}
/// Owning handle to a native Luau checker created through the shim.
///
/// The handle is freed when the `Checker` is dropped. Every method that drives
/// the native checker takes `&mut self`, and the type is `Send` but not `Sync`.
pub struct Checker {
/// Non-null handle returned by `luau_checker_new`.
inner: NonNull<ffi::LuauChecker>,
/// Defaults used when per-call options leave a value unset.
options: CheckerOptions,
}
// SAFETY: `Checker` exclusively owns its handle and only touches it through
// `&mut self` methods and `Drop`, so moving it to another thread cannot create
// concurrent access. NOTE(review): this also assumes the native checker has no
// thread affinity (e.g. thread-local state) — confirm in the shim.
unsafe impl Send for Checker {}
impl Checker {
    /// Creates a checker configured with [`CheckerOptions::default`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::CreateCheckerFailed`] if the shim fails to create a checker.
    pub fn new() -> Result<Self, Error> {
        Self::with_options(CheckerOptions::default())
    }

    /// Creates a checker whose per-call defaults come from `options`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CreateCheckerFailed`] if the shim returns a null checker handle.
    pub fn with_options(options: CheckerOptions) -> Result<Self, Error> {
        // SAFETY: `luau_checker_new` takes no arguments. A null result becomes an error; a
        // non-null handle is owned by the returned `Checker` and released in its `Drop`.
        let inner =
            NonNull::new(unsafe { ffi::luau_checker_new() }).ok_or(Error::CreateCheckerFailed)?;
        Ok(Self { inner, options })
    }

    /// Returns the defaults this checker was created with.
    pub fn options(&self) -> &CheckerOptions {
        &self.options
    }

    /// Loads type definitions under [`CheckerOptions::default_definitions_module_name`].
    ///
    /// # Errors
    ///
    /// Same as [`Checker::add_definitions_with_name`].
    pub fn add_definitions(&mut self, defs: &str) -> Result<(), Error> {
        // Borrow the default name straight out of `self.options` instead of cloning it. The
        // shared helper only needs the (Copy) handle, and `&mut self` still guarantees
        // exclusive use of that handle for the duration of the call.
        Self::add_definitions_raw(
            self.inner,
            defs,
            &self.options.default_definitions_module_name,
        )
    }

    /// Loads type definitions from `defs` under the given definitions `module_name`.
    ///
    /// # Errors
    ///
    /// - [`Error::InputTooLarge`] if `defs` or `module_name` exceeds `u32::MAX` bytes.
    /// - [`Error::Definitions`] carrying the shim's message if it rejects the definitions.
    pub fn add_definitions_with_name(
        &mut self,
        defs: &str,
        module_name: &str,
    ) -> Result<(), Error> {
        Self::add_definitions_raw(self.inner, defs, module_name)
    }

    /// Shared body of the `add_definitions*` methods.
    ///
    /// Takes the handle by value so callers can pass a module name borrowed from
    /// `self.options`. It must only be called from a method holding `&mut self` for the
    /// `Checker` that owns `inner`.
    fn add_definitions_raw(
        inner: NonNull<ffi::LuauChecker>,
        defs: &str,
        module_name: &str,
    ) -> Result<(), Error> {
        let defs = FfiStr::new(defs, "definitions")?;
        let module_name = FfiStr::new(module_name, "definition module name")?;
        // SAFETY: `inner` is a live handle used exclusively by the calling `&mut self` method,
        // and both `FfiStr`s borrow strings that outlive this call. The returned string is
        // owned by the guard, which frees it.
        let raw = RawStringGuard::new(unsafe {
            ffi::luau_checker_add_definitions(
                inner.as_ptr(),
                defs.ptr(),
                defs.len(),
                module_name.ptr(),
                module_name.len(),
            )
        });
        // An empty shim string means success; any text is the shim's error message.
        match raw.message() {
            Some(message) => Err(Error::Definitions(message)),
            None => Ok(()),
        }
    }

    /// Type-checks `source` using only this checker's defaults.
    ///
    /// # Errors
    ///
    /// Same as [`Checker::check_with_options`].
    pub fn check(&mut self, source: &str) -> Result<CheckResult, Error> {
        self.check_with_options(source, CheckOptions::default())
    }

    /// Type-checks `source`, with per-call `options` overriding the checker defaults.
    ///
    /// `options.module_name` falls back to [`CheckerOptions::default_module_name`], and
    /// `options.timeout` falls back to [`CheckerOptions::default_timeout`]. A supplied
    /// cancellation token is forwarded as-is. It is not reset here, so call
    /// [`CancellationToken::reset`] before reusing a cancelled token.
    ///
    /// Diagnostics come back sorted by start position, then severity (errors first), then
    /// message. Hitting the timeout or being cancelled is reported through
    /// [`CheckResult::timed_out`] / [`CheckResult::cancelled`], not as an `Err`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InputTooLarge`] if `source` or the module name exceeds `u32::MAX`
    /// bytes.
    pub fn check_with_options(
        &mut self,
        source: &str,
        options: CheckOptions<'_>,
    ) -> Result<CheckResult, Error> {
        let source = FfiStr::new(source, "source")?;
        let module_name = options
            .module_name
            .unwrap_or(self.options.default_module_name.as_str());
        let module_name = FfiStr::new(module_name, "module name")?;
        let timeout = options.timeout.or(self.options.default_timeout);
        let raw_options = ffi::LuauCheckOptions {
            module_name: module_name.ptr(),
            module_name_len: module_name.len(),
            has_timeout: u32::from(timeout.is_some()),
            timeout_seconds: timeout.map_or(0.0, |duration| duration.as_secs_f64()),
            cancellation_token: options
                .cancellation_token
                .map_or(ptr::null_mut(), CancellationToken::raw),
        };
        // SAFETY: `self.inner` is live and used exclusively (`&mut self`). `source`,
        // `module_name`, `raw_options` and the optional token (borrowed for `'_`) all outlive
        // the call. The returned result is owned by the guard, which frees it.
        let raw = RawCheckResultGuard::new(unsafe {
            ffi::luau_checker_check(
                self.inner.as_ptr(),
                source.ptr(),
                source.len(),
                &raw_options,
            )
        });
        let result = raw.as_ref();
        let mut diagnostics = collect_diagnostics(result);
        diagnostics.sort_by(diagnostic_sort_key);
        Ok(CheckResult {
            diagnostics,
            timed_out: result.timed_out != 0,
            cancelled: result.cancelled != 0,
        })
    }
}
pub fn extract_entrypoint_schema(source: &str) -> Result<EntrypointSchema, Error> {
let source = FfiStr::new(source, "source")?;
let raw = unsafe { ffi::luau_extract_entrypoint_schema(source.ptr(), source.len()) };
let raw = RawEntrypointSchemaGuard::new(raw);
if raw.as_ref().error_len != 0 {
return Err(Error::EntrypointSchema(string_from_raw(
raw.as_ref().error,
raw.as_ref().error_len,
)));
}
Ok(EntrypointSchema {
params: collect_entrypoint_params(raw.as_ref()),
})
}
impl Drop for Checker {
    /// Releases the native checker handle.
    fn drop(&mut self) {
        let handle = self.inner.as_ptr();
        // SAFETY: `handle` came from `luau_checker_new` in `Checker::with_options`, is owned
        // solely by this `Checker`, and `drop` runs at most once.
        unsafe { ffi::luau_checker_free(handle) };
    }
}
#[derive(Clone, Copy)]
struct FfiStr<'a> {
ptr: *const u8,
len: u32,
_marker: PhantomData<&'a str>,
}
impl<'a> FfiStr<'a> {
fn new(value: &'a str, kind: &'static str) -> Result<Self, Error> {
let len = u32::try_from(value.len()).map_err(|_| Error::InputTooLarge {
kind,
len: value.len(),
})?;
Ok(Self {
ptr: if len == 0 {
ptr::null()
} else {
value.as_ptr()
},
len,
_marker: PhantomData,
})
}
fn ptr(self) -> *const u8 {
self.ptr
}
fn len(self) -> u32 {
self.len
}
}
/// Owns a shim check result and frees it with `luau_check_result_free` when dropped.
///
/// Data reached through `as_ref` (diagnostic arrays, message bytes) is presumably owned by the
/// result itself. The borrow returned by `as_ref` keeps it from outliving the guard.
struct RawCheckResultGuard {
    raw: ffi::LuauCheckResult,
}

impl RawCheckResultGuard {
    /// Takes ownership of `raw`; it is freed when the guard drops.
    fn new(raw: ffi::LuauCheckResult) -> Self {
        RawCheckResultGuard { raw }
    }

    /// Borrows the owned result for reading.
    fn as_ref(&self) -> &ffi::LuauCheckResult {
        let Self { raw } = self;
        raw
    }
}

impl Drop for RawCheckResultGuard {
    fn drop(&mut self) {
        let owned = self.raw;
        // SAFETY: `owned` is the result handed to `new`; `drop` runs at most once.
        unsafe { ffi::luau_check_result_free(owned) };
    }
}

/// Owns a shim-allocated string and frees it with `luau_string_free` when dropped.
struct RawStringGuard {
    raw: ffi::LuauString,
}

impl RawStringGuard {
    /// Takes ownership of `raw`; it is freed when the guard drops.
    fn new(raw: ffi::LuauString) -> Self {
        RawStringGuard { raw }
    }

    /// Copies the string out, or returns `None` when the shim returned an empty string.
    ///
    /// Callers such as `Checker::add_definitions_with_name` treat `None` as success.
    fn message(&self) -> Option<String> {
        let text = &self.raw;
        (text.len != 0).then(|| string_from_raw(text.data, text.len))
    }
}

impl Drop for RawStringGuard {
    fn drop(&mut self) {
        let owned = self.raw;
        // SAFETY: `owned` is the string handed to `new`; `drop` runs at most once.
        unsafe { ffi::luau_string_free(owned) };
    }
}

/// Owns a shim entrypoint-schema result and frees it with
/// `luau_entrypoint_schema_result_free` when dropped.
struct RawEntrypointSchemaGuard {
    raw: ffi::LuauEntrypointSchemaResult,
}

impl RawEntrypointSchemaGuard {
    /// Takes ownership of `raw`; it is freed when the guard drops.
    fn new(raw: ffi::LuauEntrypointSchemaResult) -> Self {
        RawEntrypointSchemaGuard { raw }
    }

    /// Borrows the owned result for reading.
    fn as_ref(&self) -> &ffi::LuauEntrypointSchemaResult {
        let Self { raw } = self;
        raw
    }
}

impl Drop for RawEntrypointSchemaGuard {
    fn drop(&mut self) {
        let owned = self.raw;
        // SAFETY: `owned` is the result handed to `new`; `drop` runs at most once.
        unsafe { ffi::luau_entrypoint_schema_result_free(owned) };
    }
}
/// Copies `len` bytes at `ptr` into an owned `String`, replacing invalid UTF-8 with U+FFFD.
///
/// A null `ptr` or a zero `len` yields an empty string. A non-null pointer is trusted to reference
/// at least `len` readable bytes. In this file it is only fed pointer/length pairs read from shim
/// results that are still kept alive by their guards.
fn string_from_raw(ptr: *const u8, len: u32) -> String {
    match (ptr.is_null(), len) {
        (true, _) | (false, 0) => String::new(),
        (false, count) => {
            // SAFETY: `ptr` is non-null and, per the contract above, addresses `count` readable
            // bytes. `u32` -> `usize` is lossless on 32- and 64-bit targets.
            let bytes = unsafe { slice::from_raw_parts(ptr, count as usize) };
            String::from_utf8_lossy(bytes).into_owned()
        }
    }
}
impl Severity {
fn from_ffi(code: u32) -> Self {
match code {
0 => Self::Error,
_ => Self::Warning,
}
}
}
/// Copies every diagnostic in a shim check result into owned [`Diagnostic`] values.
///
/// The output keeps the shim's array order; `Checker::check_with_options` sorts it afterwards.
fn collect_diagnostics(raw: &ffi::LuauCheckResult) -> Vec<Diagnostic> {
    // SAFETY: `diagnostics`/`diagnostic_count` describe the array inside `raw`, which the caller
    // keeps alive (through its guard) for the duration of this borrow.
    let entries = unsafe { raw_slice(raw.diagnostics, raw.diagnostic_count) };
    entries
        .iter()
        .map(|entry| {
            let message = string_from_raw(entry.message, entry.message_len);
            Diagnostic {
                line: entry.line,
                col: entry.col,
                end_line: entry.end_line,
                end_col: entry.end_col,
                severity: Severity::from_ffi(entry.severity),
                message,
            }
        })
        .collect()
}
/// Copies every parameter in a shim entrypoint-schema result into owned [`EntrypointParam`]s,
/// preserving the shim's order.
fn collect_entrypoint_params(raw: &ffi::LuauEntrypointSchemaResult) -> Vec<EntrypointParam> {
    // SAFETY: `params`/`param_count` describe the array inside `raw`, which the caller keeps
    // alive (through its guard) for the duration of this borrow.
    let entries = unsafe { raw_slice(raw.params, raw.param_count) };
    entries
        .iter()
        .map(|entry| {
            let name = string_from_raw(entry.name, entry.name_len);
            let annotation = string_from_raw(entry.annotation, entry.annotation_len);
            EntrypointParam {
                name,
                annotation,
                optional: entry.optional != 0,
            }
        })
        .collect()
}
/// Borrows a shim-owned array as a Rust slice.
///
/// A zero `len` always yields an empty slice, so a null or dangling pointer is never handed to
/// `slice::from_raw_parts`. A null `ptr` with a non-zero `len` violates the shim contract. It
/// trips a `debug_assert!` in debug builds and degrades to an empty slice in release builds,
/// instead of being undefined behaviour.
///
/// # Safety
///
/// When `ptr` is non-null and `len` is non-zero, `ptr` must point to `len` initialised, properly
/// aligned `T` values. Those values must stay valid and unmodified for the caller-chosen lifetime
/// `'a`, in practice until the owning shim result is freed.
unsafe fn raw_slice<'a, T>(ptr: *const T, len: u32) -> &'a [T] {
    if len == 0 {
        return &[];
    }
    debug_assert!(!ptr.is_null(), "non-empty shim slice must not be null");
    if ptr.is_null() {
        // Release builds: treat a broken shim result as "no elements" rather than invoking UB.
        return &[];
    }
    // SAFETY: `ptr` is non-null and, per this function's contract, addresses `len` initialised
    // `T`s valid for `'a`. `u32` -> `usize` is lossless on 32- and 64-bit targets.
    unsafe { slice::from_raw_parts(ptr, len as usize) }
}
/// Comparator used to sort diagnostics for presentation.
///
/// Orders by start line, start column, severity (errors before warnings), then message.
/// End line and end column are appended as final tiebreakers. This makes the order total, so
/// the sorted output no longer depends on the shim's emission order for otherwise-equal
/// entries. Comparisons are chained lazily, so the string comparison only runs when all
/// earlier keys tie.
fn diagnostic_sort_key(left: &Diagnostic, right: &Diagnostic) -> Ordering {
    left.line
        .cmp(&right.line)
        .then_with(|| left.col.cmp(&right.col))
        .then_with(|| left.severity.cmp(&right.severity))
        .then_with(|| left.message.cmp(&right.message))
        .then_with(|| left.end_line.cmp(&right.end_line))
        .then_with(|| left.end_col.cmp(&right.end_col))
}
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use super::{
        CheckOptions, CheckResult, CheckerOptions, Diagnostic, Error, FfiStr, Severity,
        checker_policy, diagnostic_sort_key, extract_entrypoint_schema, raw_slice,
        string_from_raw,
    };

    /// Builds a one-line diagnostic for ordering tests.
    fn diagnostic(line: u32, col: u32, severity: Severity, message: &str) -> Diagnostic {
        Diagnostic {
            line,
            col,
            end_line: line,
            end_col: col + 1,
            severity,
            message: message.to_owned(),
        }
    }

    #[test]
    fn check_result_ok_with_warnings() {
        let result = CheckResult {
            diagnostics: vec![Diagnostic {
                line: 0,
                col: 0,
                end_line: 0,
                end_col: 1,
                severity: Severity::Warning,
                message: "unused local".to_owned(),
            }],
            timed_out: false,
            cancelled: false,
        };
        assert!(result.is_ok());
        assert_eq!(1, result.warnings().len());
        assert_eq!(0, result.errors().len());
    }

    #[test]
    fn check_result_not_ok_with_error() {
        let result = CheckResult {
            diagnostics: vec![Diagnostic {
                line: 1,
                col: 1,
                end_line: 1,
                end_col: 5,
                severity: Severity::Error,
                message: "type mismatch".to_owned(),
            }],
            timed_out: false,
            cancelled: false,
        };
        assert!(!result.is_ok());
        assert_eq!(0, result.warnings().len());
        assert_eq!(1, result.errors().len());
    }

    #[test]
    fn policy_is_strict_new_solver_and_queue_free() {
        let policy = checker_policy();
        assert!(policy.strict_mode);
        assert_eq!("new", policy.solver);
        assert!(!policy.exposes_batch_queue);
    }

    #[test]
    fn checker_options_defaults_are_stable() {
        let options = CheckerOptions::default();
        assert_eq!("main", options.default_module_name);
        assert_eq!("@definitions", options.default_definitions_module_name);
        assert!(options.default_timeout.is_none());
    }

    #[test]
    fn check_options_default_defers_to_checker() {
        let options = CheckOptions::default();
        assert!(options.timeout.is_none());
        assert!(options.module_name.is_none());
        assert!(options.cancellation_token.is_none());
    }

    #[test]
    fn input_too_large_error_mentions_kind_and_length() {
        let error = Error::InputTooLarge {
            kind: "source",
            len: 7,
        };
        assert_eq!(
            "source input is too large for checker FFI boundary (7 bytes)",
            error.to_string()
        );
    }

    #[test]
    fn severity_codes_map_zero_to_error_and_everything_else_to_warning() {
        assert_eq!(Severity::Error, Severity::from_ffi(0));
        assert_eq!(Severity::Warning, Severity::from_ffi(1));
        assert_eq!(Severity::Warning, Severity::from_ffi(u32::MAX));
    }

    #[test]
    fn diagnostics_sort_by_position_then_severity_then_message() {
        let mut diagnostics = vec![
            diagnostic(2, 0, Severity::Error, "b"),
            diagnostic(1, 5, Severity::Warning, "a"),
            diagnostic(1, 5, Severity::Error, "z"),
            diagnostic(1, 5, Severity::Error, "y"),
            diagnostic(1, 0, Severity::Warning, "c"),
        ];
        diagnostics.sort_by(diagnostic_sort_key);
        let order: Vec<(u32, u32, &str)> = diagnostics
            .iter()
            .map(|d| (d.line, d.col, d.message.as_str()))
            .collect();
        assert_eq!(
            order,
            [(1, 0, "c"), (1, 5, "y"), (1, 5, "z"), (1, 5, "a"), (2, 0, "b")]
        );
        assert_eq!(
            Ordering::Less,
            diagnostic_sort_key(&diagnostics[0], &diagnostics[1])
        );
        assert_eq!(
            Ordering::Equal,
            diagnostic_sort_key(&diagnostics[2], &diagnostics[2])
        );
    }

    #[test]
    fn string_from_raw_handles_null_empty_and_invalid_utf8() {
        assert_eq!("", string_from_raw(std::ptr::null(), 3));
        let text = "héllo";
        assert_eq!("", string_from_raw(text.as_ptr(), 0));
        let len = u32::try_from(text.len()).expect("short test string");
        assert_eq!(text, string_from_raw(text.as_ptr(), len));
        let invalid = [b'a', 0xFF, b'b'];
        assert_eq!("a\u{FFFD}b", string_from_raw(invalid.as_ptr(), 3));
    }

    #[test]
    fn raw_slice_borrows_or_returns_empty() {
        let empty: &[u32] = unsafe { raw_slice(std::ptr::null(), 0) };
        assert!(empty.is_empty());
        let values = [1u32, 2, 3];
        let borrowed = unsafe { raw_slice(values.as_ptr(), 3) };
        assert_eq!(&[1, 2, 3], borrowed);
    }

    #[test]
    fn ffi_str_passes_empty_input_as_null() {
        let empty = FfiStr::new("", "source").expect("empty input fits");
        assert!(empty.ptr().is_null());
        assert_eq!(0, empty.len());
        let text = "return 1";
        let lowered = FfiStr::new(text, "source").expect("small input fits");
        assert_eq!(text.as_ptr(), lowered.ptr());
        assert_eq!(8, lowered.len());
    }

    #[test]
    fn extract_entrypoint_schema_reads_params() {
        let schema = extract_entrypoint_schema(
            r#"
return function(target: Node, count: number?, payload: JsonValue)
return nil
end
"#,
        )
        .expect("schema");
        assert_eq!(3, schema.params.len());
        assert_eq!("target", schema.params[0].name);
        assert_eq!("Node", schema.params[0].annotation);
        assert!(!schema.params[0].optional);
        assert_eq!("count", schema.params[1].name);
        assert_eq!("number?", schema.params[1].annotation);
        assert!(schema.params[1].optional);
        assert_eq!("payload", schema.params[2].name);
        assert_eq!("JsonValue", schema.params[2].annotation);
        assert!(!schema.params[2].optional);
    }

    #[test]
    fn extract_entrypoint_schema_rejects_indirect_return() {
        let error = extract_entrypoint_schema(
            r#"
local main = function(target: Node)
return nil
end
return main
"#,
        )
        .expect_err("schema should fail");
        assert!(
            error
                .to_string()
                .contains("script must use a direct `return function(...) ... end` entrypoint"),
            "{error}"
        );
    }
}