#![cfg_attr(feature = "strict_docs", allow(missing_docs))]
use crate::compress::CompressedParseTable;
/// Checks a [`TSLanguage`] descriptor for internal consistency and for
/// agreement with a [`CompressedParseTable`].
///
/// Both inputs are only borrowed (for `'a`); the validator owns no data.
pub struct LanguageValidator<'a> {
/// Language descriptor being validated.
language: &'a TSLanguage,
/// Parse tables whose symbol and state counts the language must match.
tables: &'a CompressedParseTable,
}
/// C-layout (`#[repr(C)]`) language descriptor inspected by
/// [`LanguageValidator`].
///
/// The struct is plain data: counts, raw pointers to tables, and optional
/// `extern "C"` lexer callbacks. Nothing in this file allocates or frees the
/// pointed-to memory; the validator only reads through some of the pointers.
///
/// NOTE(review): the field names suggest this mirrors tree-sitter's C
/// `TSLanguage`. If so, field order and types must match the C header for the
/// ABI version in use — confirm before passing values across an FFI boundary.
#[repr(C)]
pub struct TSLanguage {
/// ABI version; `LanguageValidator::validate` accepts only `15`.
pub version: u32,
/// Number of symbols. Compared with the parse tables' symbol count and used
/// to bound reads of `symbol_metadata`.
pub symbol_count: u32,
pub alias_count: u32,
pub token_count: u32,
pub external_token_count: u32,
/// Number of parse states. Compared with the parse tables' state count.
pub state_count: u32,
pub large_state_count: u32,
pub production_id_count: u32,
/// Number of fields. When non-zero, `field_names` must be non-null and is
/// read as `field_count + 1` entries.
pub field_count: u32,
pub max_alias_sequence_length: u16,
/// The validator requires at least one of `parse_table` and
/// `small_parse_table` to be non-null.
pub parse_table: *const u16,
pub small_parse_table: *const u16,
pub small_parse_table_map: *const u32,
pub parse_actions: *const TSParseActionEntry,
/// Must be non-null to pass validation.
pub symbol_names: *const *const i8,
/// Array of C-string pointers checked for strictly ascending order.
///
/// NOTE(review): the element type is `i8`, while `CStr::from_ptr` takes
/// `*const c_char`, which is `u8` on some targets — confirm portability.
pub field_names: *const *const i8,
pub field_map_slices: *const TSFieldMapSlice,
pub field_map_entries: *const TSFieldMapEntry,
/// Must be non-null to pass validation; entry 0 is checked to be neither
/// visible nor named.
pub symbol_metadata: *const TSSymbolMetadata,
pub public_symbol_map: *const TSSymbol,
pub alias_map: *const u16,
pub alias_sequences: *const TSSymbol,
pub lex_modes: *const TSLexMode,
pub lex_fn: Option<unsafe extern "C" fn(*mut TSLexer, TSStateId) -> bool>,
pub keyword_lex_fn: Option<unsafe extern "C" fn(*mut TSLexer, TSStateId) -> bool>,
pub keyword_capture_token: TSSymbol,
pub external_scanner_data: TSExternalScannerData,
pub primary_state_ids: *const TSStateId,
}
/// C-layout element type of [`TSLanguage::parse_actions`]: a single `u32`.
///
/// No code in this file reads these entries.
/// NOTE(review): layout has not been checked against any C definition —
/// confirm before dereferencing `parse_actions`.
#[repr(C)]
pub struct TSParseActionEntry {
/// Raw action value; its encoding is not interpreted in this file.
pub action: u32,
}
/// C-layout element type of [`TSLanguage::field_map_slices`]: a
/// `(start, length)` pair of 16-bit values.
///
/// NOTE(review): presumably indexes a run inside `field_map_entries` — confirm
/// against the producer of these tables.
#[repr(C)]
pub struct TSFieldMapSlice {
/// Start position of the slice.
pub start: u16,
/// Number of elements in the slice.
pub length: u16,
}
/// C-layout element type of [`TSLanguage::field_map_entries`].
///
/// No code in this file reads these entries.
#[repr(C)]
pub struct TSFieldMapEntry {
/// Identifier of the field this entry refers to.
pub field_id: u16,
/// Index of the child the field applies to.
pub child_index: u8,
/// Inheritance flag; its exact meaning is not established by this file.
pub inherited: bool,
}
/// C-layout element type of [`TSLanguage::symbol_metadata`]: two flags per
/// symbol.
///
/// The validator requires entry 0 to have both flags cleared.
///
/// NOTE(review): the element size (two `bool`s) determines how the metadata
/// pointer is indexed; confirm it matches the producer's per-symbol record
/// size, otherwise every entry after the first is misread.
#[repr(C)]
pub struct TSSymbolMetadata {
/// Whether the symbol is visible.
pub visible: bool,
/// Whether the symbol is named.
pub named: bool,
}
/// C-layout element type of [`TSLanguage::lex_modes`]: a single `u8` id.
///
/// No code in this file reads these entries.
/// NOTE(review): layout has not been checked against any C definition —
/// confirm before dereferencing `lex_modes`.
#[repr(C)]
pub struct TSLexMode {
/// Identifier of the lex mode.
pub lex_mode_id: u8,
}
/// C-layout bundle of external-scanner tables and callbacks, embedded by value
/// in [`TSLanguage::external_scanner_data`].
///
/// Every callback is optional (`None` when absent) and, when present, is an
/// `unsafe extern "C"` function operating on an opaque `*mut c_void` payload.
/// Nothing in this file invokes these callbacks.
#[repr(C)]
pub struct TSExternalScannerData {
pub states: *const bool,
pub symbol_map: *const TSSymbol,
/// Returns an opaque payload pointer.
pub create: Option<unsafe extern "C" fn() -> *mut std::ffi::c_void>,
/// Receives the opaque payload pointer.
pub destroy: Option<unsafe extern "C" fn(*mut std::ffi::c_void)>,
/// Receives the payload, a lexer, and a `*const bool` array; returns `bool`.
pub scan:
Option<unsafe extern "C" fn(*mut std::ffi::c_void, *mut TSLexer, *const bool) -> bool>,
/// Receives the payload and a byte buffer; returns a `u32`
/// (presumably the number of bytes written — confirm).
pub serialize: Option<unsafe extern "C" fn(*mut std::ffi::c_void, *mut u8) -> u32>,
/// Receives the payload, a byte buffer, and a `u32`
/// (presumably the buffer length — confirm).
pub deserialize: Option<unsafe extern "C" fn(*mut std::ffi::c_void, *const u8, u32)>,
}
/// C-layout lexer handle passed by pointer to the lexing callbacks declared on
/// [`TSLanguage`] and [`TSExternalScannerData`].
///
/// It holds two data fields followed by optional `extern "C"` callbacks that
/// each take the lexer itself as their first argument. Nothing in this file
/// constructs or calls into a `TSLexer`.
#[repr(C)]
pub struct TSLexer {
/// Current lookahead value.
pub lookahead: i32,
/// Symbol recorded as the lexing result.
pub result_symbol: TSSymbol,
/// Callback taking the lexer and a `bool` flag.
pub advance: Option<unsafe extern "C" fn(*mut TSLexer, bool)>,
pub mark_end: Option<unsafe extern "C" fn(*mut TSLexer)>,
pub get_column: Option<unsafe extern "C" fn(*mut TSLexer) -> u32>,
pub is_at_included_range_start: Option<unsafe extern "C" fn(*mut TSLexer) -> bool>,
pub eof: Option<unsafe extern "C" fn(*mut TSLexer) -> bool>,
}
/// Identifier of a grammar symbol (16-bit).
pub type TSSymbol = u16;
/// Identifier of a parse state (16-bit).
pub type TSStateId = u16;

/// A single problem reported by language validation.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so individual
/// findings can be printed for users or propagated with `?` into
/// `Box<dyn Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The language's `version` differs from the one the validator accepts.
    InvalidVersion { expected: u32, actual: u32 },
    /// The language and the parse tables disagree on the number of symbols.
    SymbolCountMismatch { language: u32, tables: u32 },
    /// The language and the parse tables disagree on the number of states.
    StateCountMismatch { language: u32, tables: u32 },
    /// A required pointer is null; the payload names the offending field(s).
    NullPointer(&'static str),
    /// Adjacent field names are not in strictly ascending order.
    FieldNamesNotSorted,
    /// A symbol's metadata flags violate an invariant described by `reason`.
    InvalidSymbolMetadata { symbol: TSSymbol, reason: String },
    /// A table has a different number of entries than expected.
    /// (Not produced by any code visible in this file.)
    TableDimensionMismatch { expected: usize, actual: usize },
    /// A production id is outside the permitted range.
    /// (Not produced by any code visible in this file.)
    InvalidProductionId { id: u32, max: u32 },
    /// A field id is outside the permitted range.
    /// (Not produced by any code visible in this file.)
    InvalidFieldMapping { field_id: u16, max: u16 },
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidVersion { expected, actual } => write!(
                f,
                "unsupported language version {} (expected {})",
                actual, expected
            ),
            Self::SymbolCountMismatch { language, tables } => write!(
                f,
                "symbol count mismatch: language has {}, parse tables have {}",
                language, tables
            ),
            Self::StateCountMismatch { language, tables } => write!(
                f,
                "state count mismatch: language has {}, parse tables have {}",
                language, tables
            ),
            Self::NullPointer(what) => write!(f, "required pointer is null: {}", what),
            Self::FieldNamesNotSorted => {
                f.write_str("field names are not in strictly ascending order")
            }
            Self::InvalidSymbolMetadata { symbol, reason } => {
                write!(f, "invalid metadata for symbol {}: {}", symbol, reason)
            }
            Self::TableDimensionMismatch { expected, actual } => write!(
                f,
                "parse table dimension mismatch: expected {} entries, found {}",
                expected, actual
            ),
            Self::InvalidProductionId { id, max } => {
                write!(f, "production id {} is out of range (max {})", id, max)
            }
            Self::InvalidFieldMapping { field_id, max } => {
                write!(f, "field id {} is out of range (max {})", field_id, max)
            }
        }
    }
}

impl std::error::Error for ValidationError {}
impl<'a> LanguageValidator<'a> {
pub fn new(language: &'a TSLanguage, tables: &'a CompressedParseTable) -> Self {
Self { language, tables }
}
#[must_use = "validation result must be checked"]
pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new();
if self.language.version != 15 {
errors.push(ValidationError::InvalidVersion {
expected: 15,
actual: self.language.version,
});
}
self.validate_counts(&mut errors);
self.validate_pointers(&mut errors);
self.validate_symbol_metadata(&mut errors);
self.validate_field_names(&mut errors);
self.validate_table_dimensions(&mut errors);
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
fn validate_counts(&self, errors: &mut Vec<ValidationError>) {
let table_symbol_count = self.tables.symbol_count();
if self.language.symbol_count != table_symbol_count as u32 {
errors.push(ValidationError::SymbolCountMismatch {
language: self.language.symbol_count,
tables: table_symbol_count as u32,
});
}
let table_state_count = self.tables.state_count();
if self.language.state_count != table_state_count as u32 {
errors.push(ValidationError::StateCountMismatch {
language: self.language.state_count,
tables: table_state_count as u32,
});
}
}
fn validate_pointers(&self, errors: &mut Vec<ValidationError>) {
if self.language.parse_table.is_null() && self.language.small_parse_table.is_null() {
errors.push(ValidationError::NullPointer(
"parse_table or small_parse_table",
));
}
if self.language.symbol_names.is_null() {
errors.push(ValidationError::NullPointer("symbol_names"));
}
if self.language.symbol_metadata.is_null() {
errors.push(ValidationError::NullPointer("symbol_metadata"));
}
if self.language.field_count > 0 && self.language.field_names.is_null() {
errors.push(ValidationError::NullPointer("field_names"));
}
}
fn validate_symbol_metadata(&self, errors: &mut Vec<ValidationError>) {
if self.language.symbol_metadata.is_null() {
return;
}
unsafe {
let metadata_slice = std::slice::from_raw_parts(
self.language.symbol_metadata,
self.language.symbol_count as usize,
);
if metadata_slice[0].visible || metadata_slice[0].named {
errors.push(ValidationError::InvalidSymbolMetadata {
symbol: 0,
reason: "EOF symbol must be invisible and unnamed".to_string(),
});
}
}
}
fn validate_field_names(&self, errors: &mut Vec<ValidationError>) {
if self.language.field_count == 0 || self.language.field_names.is_null() {
return;
}
unsafe {
let field_names = std::slice::from_raw_parts(
self.language.field_names,
self.language.field_count as usize + 1, );
for i in 2..field_names.len() {
let prev = std::ffi::CStr::from_ptr(field_names[i - 1]);
let curr = std::ffi::CStr::from_ptr(field_names[i]);
if prev >= curr {
errors.push(ValidationError::FieldNamesNotSorted);
break;
}
}
}
}
#[allow(clippy::ptr_arg)]
fn validate_table_dimensions(&self, _errors: &mut Vec<ValidationError>) {
if !self.language.small_parse_table.is_null() {
let _expected_entries =
self.language.state_count as usize * self.language.symbol_count as usize;
} else if !self.language.parse_table.is_null() {
}
}
}
/// Builds a [`TSLanguage`] fixture for unit tests.
///
/// The fixture reports version 15 with 10 symbols, 5 tokens, and 20 states;
/// every other count is zero, every pointer is null, and every callback is
/// absent.
#[cfg(test)]
pub fn create_test_language() -> TSLanguage {
    use std::ptr;

    // No external scanner: null tables and no callbacks.
    let external_scanner_data = TSExternalScannerData {
        states: ptr::null(),
        symbol_map: ptr::null(),
        create: None,
        destroy: None,
        scan: None,
        serialize: None,
        deserialize: None,
    };

    TSLanguage {
        // Counts.
        version: 15,
        symbol_count: 10,
        alias_count: 0,
        token_count: 5,
        external_token_count: 0,
        state_count: 20,
        large_state_count: 0,
        production_id_count: 0,
        field_count: 0,
        max_alias_sequence_length: 0,
        // Tables: all unset.
        parse_table: ptr::null(),
        small_parse_table: ptr::null(),
        small_parse_table_map: ptr::null(),
        parse_actions: ptr::null(),
        symbol_names: ptr::null(),
        field_names: ptr::null(),
        field_map_slices: ptr::null(),
        field_map_entries: ptr::null(),
        symbol_metadata: ptr::null(),
        public_symbol_map: ptr::null(),
        alias_map: ptr::null(),
        alias_sequences: ptr::null(),
        lex_modes: ptr::null(),
        // Lexing hooks: none.
        lex_fn: None,
        keyword_lex_fn: None,
        keyword_capture_token: 0,
        external_scanner_data,
        primary_state_ids: ptr::null(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `language` against the `(10, 20)` test table shared by the
    /// tests in this module.
    fn run_validation(language: &TSLanguage) -> Result<(), Vec<ValidationError>> {
        let tables = CompressedParseTable::new_for_testing(10, 20);
        LanguageValidator::new(language, &tables).validate()
    }

    /// A language whose version is not 15 is rejected with `InvalidVersion`.
    #[test]
    fn test_version_validation() {
        let language = TSLanguage {
            version: 14,
            ..create_test_language()
        };

        let outcome = run_validation(&language);
        assert!(outcome.is_err());

        let reported_version_error = outcome
            .unwrap_err()
            .iter()
            .any(|error| matches!(error, ValidationError::InvalidVersion { .. }));
        assert!(reported_version_error);
    }

    /// The all-null fixture is rejected with at least one `NullPointer`.
    #[test]
    fn test_null_pointer_validation() {
        let language = create_test_language();

        let outcome = run_validation(&language);
        assert!(outcome.is_err());

        let reported_null_pointer = outcome
            .unwrap_err()
            .iter()
            .any(|error| matches!(error, ValidationError::NullPointer(_)));
        assert!(reported_null_pointer);
    }
}