#![cfg_attr(feature = "strict_docs", allow(missing_docs))]
use crate::abi::TREE_SITTER_LANGUAGE_VERSION;
use adze_glr_core::ParseTable;
use adze_ir::Grammar;
use proc_macro2::TokenStream;
use quote::quote;
// Release-build variant of `debug_trace!`: accepts any token sequence and
// expands to nothing, so trace call sites produce no code when
// `debug_assertions` is disabled.
#[cfg(not(debug_assertions))]
macro_rules! debug_trace {
($($arg:tt)*) => {};
}
// Debug-build variant of `debug_trace!`: forwards its arguments to
// `eprintln!`, but only when the `RUST_LOG` environment variable is set to a
// value containing the substring "debug". An unset or unreadable variable is
// treated as "tracing off".
#[cfg(debug_assertions)]
macro_rules! debug_trace {
    ($($arg:tt)*) => {
        if std::env::var("RUST_LOG")
            .map(|filter| filter.contains("debug"))
            .unwrap_or(false)
        {
            eprintln!($($arg)*);
        }
    };
}
/// Emits the Rust source (as a `TokenStream`) for a static `TSLanguage`
/// definition, built from a grammar and its parse table.
///
/// Both inputs are borrowed for the lifetime `'a`; the generator owns no data.
pub struct LanguageGenerator<'a> {
/// Grammar supplying the language name, tokens, externals, rules and fields.
grammar: &'a Grammar,
/// Parse table supplying state/symbol counts, the symbol index order and the
/// action table.
parse_table: &'a ParseTable,
}
impl<'a> LanguageGenerator<'a> {
/// Creates a generator that borrows `grammar` and `parse_table`.
///
/// No work is done here; all tables are computed when [`Self::generate`] runs.
pub fn new(grammar: &'a Grammar, parse_table: &'a ParseTable) -> Self {
Self {
grammar,
parse_table,
}
}
/// Generates the token stream defining the language's static tables, the
/// `LANGUAGE: TSLanguage` static, a safe `language()` accessor and the
/// `extern "C"` entry point named `tree_sitter_<grammar name>`.
///
/// The emitted `PUBLIC_SYMBOL_MAP` and `PRIMARY_STATE_IDS` are identity
/// mappings; the field map and alias tables are emitted empty.
///
/// # Panics
///
/// `format_ident!` panics if `tree_sitter_<grammar name>` is not a valid Rust
/// identifier (for example when the grammar name contains `-`).
pub fn generate(&self) -> TokenStream {
    let language_name = &self.grammar.name;
    let language_fn_ident = quote::format_ident!("tree_sitter_{}", language_name);

    let symbol_names = self.generate_symbol_names();
    let field_names = self.generate_field_names();
    let symbol_metadata = self.generate_symbol_metadata();
    let parse_actions = self.generate_parse_actions();
    let lex_modes = self.generate_lex_modes();
    let (compressed_table, small_table_map) = self.generate_compressed_tables();

    let symbol_name_indices: Vec<usize> = (0..symbol_names.len()).collect();
    let field_name_indices: Vec<usize> = (0..field_names.len()).collect();
    // `PRIMARY_STATE_IDS` is indexed by parser state, so it needs one entry per
    // state rather than one per symbol.
    let state_indices: Vec<usize> = (0..self.parse_table.state_count).collect();

    // Every count is interpolated as a `u32`-suffixed literal so that it matches
    // the sibling count fields of the generated `TSLanguage` initializer.
    let symbol_count = self.count_symbols() as u32;
    let token_count = self.parse_table.token_count as u32;
    let field_count = self.grammar.fields.len() as u32;
    let state_count = self.parse_table.state_count as u32;
    let external_token_count = self.parse_table.external_token_count as u32;
    let large_state_count = self.determine_large_state_count() as u32;
    let production_id_count = self.count_production_ids() as u32;

    // The sparse section starts right after the dense rows. Each dense row
    // written by `generate_compressed_tables` is `parse_table.symbol_count`
    // entries wide, so the offset must use that same width.
    let small_table_offset: usize =
        self.determine_large_state_count() * self.parse_table.symbol_count;

    quote! {
        use adze::tree_sitter as ts;
        use crate::abi::{TSLanguage, TSSymbol, TSStateId, TSLexState, TSParseAction, ExternalScanner};
        const TREE_SITTER_LANGUAGE_VERSION: u32 = 15;
        const EXTERNAL_TOKEN_COUNT: u32 = #external_token_count;
        static SYMBOL_NAMES: &[&str] = &[#(#symbol_names),*];
        static SYMBOL_NAMES_PTRS: &[*const u8] = &[
            #(SYMBOL_NAMES[#symbol_name_indices].as_ptr()),*
        ];
        static FIELD_NAMES: &[&str] = &[#(#field_names),*];
        static FIELD_NAMES_PTRS: &[*const u8] = &[
            #(FIELD_NAMES[#field_name_indices].as_ptr()),*
        ];
        static SYMBOL_METADATA: &[u8] = &[#(#symbol_metadata),*];
        static PARSE_ACTIONS: &[TSParseAction] = &[#(#parse_actions),*];
        static LEX_MODES: &[TSLexState] = &[#(#lex_modes),*];
        static PARSE_TABLE: &[u16] = &[#(#compressed_table),*];
        static SMALL_PARSE_TABLE_MAP: &[u32] = &[#(#small_table_map),*];
        static FIELD_MAP_SLICES: &[u16] = &[];
        static FIELD_MAP_ENTRIES: &[u16] = &[];
        static PUBLIC_SYMBOL_MAP: &[TSSymbol] = &[
            #(TSSymbol(#symbol_name_indices as u16)),*
        ];
        static PRIMARY_STATE_IDS: &[TSStateId] = &[
            #(TSStateId(#state_indices as u16)),*
        ];
        static EXTERNAL_SCANNER: ExternalScanner = ExternalScanner::default();
        static LANGUAGE: TSLanguage = TSLanguage {
            version: #TREE_SITTER_LANGUAGE_VERSION,
            symbol_count: #symbol_count,
            alias_count: 0,
            token_count: #token_count,
            external_token_count: EXTERNAL_TOKEN_COUNT,
            state_count: #state_count,
            large_state_count: #large_state_count,
            production_id_count: #production_id_count,
            field_count: #field_count,
            max_alias_sequence_length: 0,
            parse_table: PARSE_TABLE.as_ptr(),
            small_parse_table: PARSE_TABLE.as_ptr().wrapping_add(#small_table_offset),
            small_parse_table_map: SMALL_PARSE_TABLE_MAP.as_ptr(),
            parse_actions: PARSE_ACTIONS.as_ptr(),
            symbol_names: SYMBOL_NAMES_PTRS.as_ptr(),
            field_names: FIELD_NAMES_PTRS.as_ptr(),
            field_map_slices: FIELD_MAP_SLICES.as_ptr(),
            field_map_entries: FIELD_MAP_ENTRIES.as_ptr(),
            symbol_metadata: SYMBOL_METADATA.as_ptr(),
            public_symbol_map: PUBLIC_SYMBOL_MAP.as_ptr(),
            alias_map: std::ptr::null(),
            alias_sequences: std::ptr::null(),
            lex_modes: LEX_MODES.as_ptr(),
            lex_fn: None,
            keyword_lex_fn: None,
            keyword_capture_token: TSSymbol(0),
            external_scanner: EXTERNAL_SCANNER,
            primary_state_ids: PRIMARY_STATE_IDS.as_ptr(),
        };
        pub fn language() -> ts::Language {
            unsafe {
                ts::Language::from_raw(&LANGUAGE as *const TSLanguage as *const _)
            }
        }
        #[unsafe(no_mangle)]
        pub extern "C" fn #language_fn_ident() -> ts::Language {
            language()
        }
    }
}
/// Produces one display name per entry of `parse_table.index_to_symbol`, in
/// index order.
///
/// Index 0 is always named `"end"`. Every other symbol is resolved by trying,
/// in order: the grammar's tokens, its externals, its rule names, and finally
/// the synthetic fallback `rule_<id>`.
fn generate_symbol_names(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.parse_table.symbol_count);
    names.extend(
        self.parse_table
            .index_to_symbol
            .iter()
            .enumerate()
            .map(|(index, id)| {
                // Slot 0 is reserved and never looked up in the grammar.
                if index == 0 {
                    return "end".to_string();
                }
                let resolved = match self.grammar.tokens.get(id) {
                    Some(token) => token.name.clone(),
                    None => match self
                        .grammar
                        .externals
                        .iter()
                        .find(|ext| ext.symbol_id == *id)
                    {
                        Some(ext) => ext.name.clone(),
                        None => self
                            .grammar
                            .rule_names
                            .get(id)
                            .cloned()
                            .unwrap_or_else(|| format!("rule_{}", id.0)),
                    },
                };
                debug_trace!(
                    "DEBUG: Symbol index {} -> ID {} (name {})",
                    index,
                    id.0,
                    resolved
                );
                resolved
            }),
    );
    names
}
/// Collects the grammar's field names, in the iteration order of
/// `grammar.fields`; the field ids themselves are ignored.
fn generate_field_names(&self) -> Vec<String> {
    self.grammar
        .fields
        .iter()
        .map(|(_, field_name)| field_name.clone())
        .collect()
}
/// Builds the per-symbol metadata bytes: `count_symbols()` entries, each set
/// to `0b11`.
///
/// NOTE(review): the two set bits presumably mean "visible" and "named" in the
/// ABI's metadata encoding — confirm against `crate::abi`.
fn generate_symbol_metadata(&self) -> Vec<u8> {
    vec![0b11u8; self.count_symbols()]
}
/// Returns the parse-action table as token streams.
///
/// Currently this is a single all-zero `TSParseAction` entry; no actions are
/// derived from the parse table here.
fn generate_parse_actions(&self) -> Vec<TokenStream> {
    let zeroed_entry = quote! {
        TSParseAction {
            action_type: 0,
            extra: 0,
            child_count: 0,
            dynamic_precedence: 0,
            symbol: TSSymbol(0),
        }
    };
    vec![zeroed_entry]
}
/// Emits one `TSLexState` initializer per parser state.
///
/// Each state uses its own index as `lex_state`, and `external_lex_state` is
/// always 0.
fn generate_lex_modes(&self) -> Vec<TokenStream> {
    (0..self.parse_table.state_count)
        .map(|state| {
            quote! {
                TSLexState {
                    lex_state: #state as u16,
                    external_lex_state: 0,
                }
            }
        })
        .collect()
}
/// Builds the flattened parse table and the offset map for its sparse part.
///
/// Returns `(table, offsets)` where:
/// * `table` starts with one dense row per "large" state
///   (`parse_table.symbol_count` encoded actions each), followed by the sparse
///   rows of all remaining states;
/// * each sparse row is `[entry_count, symbol, action, symbol, action, ...]`
///   and lists only non-error actions;
/// * `offsets[i]` is the start of the i-th sparse row, measured from the
///   beginning of the sparse section. If there are no sparse rows, `offsets`
///   holds a single `0` so the emitted array is never empty.
fn generate_compressed_tables(&self) -> (Vec<u16>, Vec<u32>) {
    let dense_states = self.determine_large_state_count();
    let symbol_count = self.parse_table.symbol_count;
    let state_count = self.parse_table.state_count;

    // Dense section: every symbol column is written, errors included.
    let mut table: Vec<u16> = (0..dense_states)
        .flat_map(|state| {
            (0..symbol_count)
                .map(move |symbol| self.encode_action(self.get_action(state, symbol)))
        })
        .collect();

    // Sparse section: per state, a count followed by (symbol, action) pairs.
    let mut offsets: Vec<u32> = Vec::new();
    let mut sparse_rows: Vec<u16> = Vec::new();
    for state in dense_states..state_count {
        offsets.push(sparse_rows.len() as u32);
        let entries: Vec<(usize, u16)> = (0..symbol_count)
            .map(|symbol| (symbol, self.get_action(state, symbol)))
            .filter(|&(_, action)| !self.is_error_action(action))
            .collect();
        sparse_rows.push(entries.len() as u16);
        for (symbol, action) in entries {
            sparse_rows.push(symbol as u16);
            sparse_rows.push(self.encode_action(action));
        }
    }
    if offsets.is_empty() {
        offsets.push(0);
    }

    table.extend(sparse_rows);
    (table, offsets)
}
/// Number of states emitted as dense rows in the parse table.
///
/// Always 0 at present, so `generate_compressed_tables` writes every state in
/// the sparse format.
fn determine_large_state_count(&self) -> usize {
0
}
fn get_action(&self, state: usize, symbol: usize) -> u16 {
if state < self.parse_table.action_table.len()
&& symbol < self.parse_table.action_table[state].len()
{
let action_cell = &self.parse_table.action_table[state][symbol];
if action_cell.is_empty() {
0xFFFE } else {
let action = &action_cell[0];
match action {
adze_glr_core::Action::Shift(s) => s.0,
adze_glr_core::Action::Reduce(r) => 0x8000 | (r.0 + 1),
adze_glr_core::Action::Accept => 0xFFFF,
adze_glr_core::Action::Error => 0xFFFE,
adze_glr_core::Action::Recover => 0xFFFD, adze_glr_core::Action::Fork(_) => 0xFFFE, _ => 0xFFFE, }
}
} else {
0xFFFE }
}
/// Final encoding step for an action value produced by `get_action`.
///
/// Currently the identity function: the value is returned unchanged.
fn encode_action(&self, action: u16) -> u16 {
action
}
/// Returns `true` when `action` is `0xFFFE`, the code `get_action` yields for
/// error, empty and out-of-range cells.
fn is_error_action(&self, action: u16) -> bool {
action == 0xFFFE
}
/// Returns the symbol count reported to the generated language: the number of
/// grammar tokens plus the number of entries in `grammar.rules`, plus one.
///
/// NOTE(review): the extra 1 presumably accounts for the reserved `"end"`
/// symbol that `generate_symbol_names` places at index 0 — confirm.
/// NOTE(review): `grammar.externals` is not counted here, although
/// `generate_symbol_names` resolves external symbols — verify this is intended.
fn count_symbols(&self) -> usize {
    let token_total = self.grammar.tokens.len();
    let rule_total = self.grammar.rules.len();
    1 + token_total + rule_total
}
/// Returns the number of production-id slots needed: the largest
/// `production_id` found on any rule, plus one.
///
/// With no rules at all the result is 1 (ids are treated as starting at 0).
fn count_production_ids(&self) -> usize {
    let highest_id = self
        .grammar
        .rules
        .iter()
        .flat_map(|(_, rules)| rules)
        .map(|rule| rule.production_id.0)
        .max()
        .unwrap_or(0);
    // Widen before incrementing: adding 1 in the id's own integer type would
    // overflow (panic in debug builds, wrap to 0 in release) when the largest
    // id equals that type's maximum value.
    highest_id as usize + 1
}
/// Public wrapper around the private `generate_symbol_metadata`; returns the
/// same per-symbol metadata bytes.
pub fn generate_symbol_metadata_public(&self) -> Vec<u8> {
self.generate_symbol_metadata()
}
/// Public wrapper around the private `count_production_ids`; returns the same
/// production-id count.
pub fn count_production_ids_public(&self) -> usize {
self.count_production_ids()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use adze_ir::*;

    /// Smoke test: generating a language for a one-token grammar yields output
    /// that mentions `TSLanguage` and the `tree_sitter_test` entry point.
    #[test]
    fn test_language_generation() {
        let mut grammar = Grammar::new("test".to_string());
        grammar.tokens.insert(
            SymbolId(1),
            Token {
                name: "number".to_string(),
                pattern: TokenPattern::Regex(r"\d+".to_string()),
                fragile: false,
            },
        );
        let table = crate::empty_table!(states: 10, terms: 4, nonterms: 0);

        let rendered = LanguageGenerator::new(&grammar, &table)
            .generate()
            .to_string();

        for needle in ["TSLanguage", "tree_sitter_test"] {
            assert!(rendered.contains(needle));
        }
    }
}