use std::collections::HashMap;
use std::sync::OnceLock;
use parking_lot::RwLock;
use crate::{Error, Result};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
Rust,
Python,
TypeScript,
Tsx,
JavaScript,
Go,
Markdown,
Sql,
}
impl Language {
pub fn from_lang_str(s: &str) -> Option<Self> {
match s.trim().to_ascii_lowercase().as_str() {
"rust" | "rs" => Some(Language::Rust),
"python" | "py" => Some(Language::Python),
"typescript" | "ts" => Some(Language::TypeScript),
"tsx" => Some(Language::Tsx),
"javascript" | "js" | "mjs" | "cjs" => Some(Language::JavaScript),
"go" => Some(Language::Go),
"markdown" | "md" => Some(Language::Markdown),
"sql" => Some(Language::Sql),
_ => None,
}
}
pub fn as_str(self) -> &'static str {
match self {
Language::Rust => "rust",
Language::Python => "python",
Language::TypeScript => "typescript",
Language::Tsx => "tsx",
Language::JavaScript => "javascript",
Language::Go => "go",
Language::Markdown => "markdown",
Language::Sql => "sql",
}
}
}
fn registry() -> &'static RwLock<HashMap<String, tree_sitter::Language>> {
static R: OnceLock<RwLock<HashMap<String, tree_sitter::Language>>> = OnceLock::new();
R.get_or_init(|| RwLock::new(HashMap::new()))
}
pub fn register_grammar(
name: impl Into<String>,
grammar: tree_sitter::Language,
) -> Option<tree_sitter::Language> {
let mut m = registry().write();
m.insert(name.into(), grammar)
}
pub fn unregister_grammar(name: &str) -> Option<tree_sitter::Language> {
let mut m = registry().write();
m.remove(name)
}
pub fn registered_grammars() -> Vec<String> {
let m = registry().read();
let mut v: Vec<String> = m.keys().cloned().collect();
v.sort();
v
}
fn registered_grammar(name: &str) -> Option<tree_sitter::Language> {
let m = registry().read();
m.get(name).cloned()
}
pub fn parse(lang: Language, source: &str) -> Result<tree_sitter::Tree> {
parse_with_cached(lang.as_str(), &grammar_for(lang), source)
}
pub fn parse_by_name(lang_name: &str, source: &str) -> Result<tree_sitter::Tree> {
if let Some(g) = registered_grammar(lang_name) {
let key = format!("dyn:{lang_name}");
return parse_with_cached(&key, &g, source);
}
if let Some(builtin) = Language::from_lang_str(lang_name) {
return parse(builtin, source);
}
Err(Error::query_execution(format!(
"no tree-sitter grammar registered for language '{lang_name}' \
(try register_grammar(name, lang))"
)))
}
fn parse_with_cached(
cache_key: &str,
ts_lang: &tree_sitter::Language,
source: &str,
) -> Result<tree_sitter::Tree> {
use std::cell::RefCell;
use std::collections::HashMap;
thread_local! {
static PARSERS: RefCell<HashMap<String, tree_sitter::Parser>> = RefCell::new(HashMap::new());
}
PARSERS.with(|cell| {
let mut map = cell.borrow_mut();
if !map.contains_key(cache_key) {
let mut p = tree_sitter::Parser::new();
p.set_language(ts_lang).map_err(|e| {
Error::query_execution(format!("tree-sitter set_language failed: {e}"))
})?;
map.insert(cache_key.to_string(), p);
}
let parser = map
.get_mut(cache_key)
.ok_or_else(|| Error::internal("Parser cache entry vanished"))?;
parser
.parse(source, None)
.ok_or_else(|| Error::query_execution("tree-sitter parse returned None"))
})
}
fn grammar_for(lang: Language) -> tree_sitter::Language {
match lang {
Language::Rust => tree_sitter_rust::LANGUAGE.into(),
Language::Python => tree_sitter_python::LANGUAGE.into(),
Language::TypeScript | Language::JavaScript => {
tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()
}
Language::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
Language::Go => tree_sitter_go::LANGUAGE.into(),
Language::Markdown => tree_sitter_md::LANGUAGE.into(),
Language::Sql => tree_sitter_sequel::LANGUAGE.into(),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rust_parses_smoke() {
let src = "fn main() { println!(\"hi\"); }";
let tree = parse(Language::Rust, src).expect("parse");
assert!(tree.root_node().kind() == "source_file");
}
#[test]
fn python_parses_smoke() {
let src = "def main():\n print('hi')\n";
let tree = parse(Language::Python, src).expect("parse");
assert!(tree.root_node().kind() == "module");
}
#[test]
fn unknown_language_str_returns_none() {
assert!(Language::from_lang_str("cobol").is_none());
assert_eq!(Language::from_lang_str("RS"), Some(Language::Rust));
}
#[test]
fn parse_by_name_falls_back_to_builtin() {
let src = "fn main() {}";
let tree = parse_by_name("rust", src).expect("rust builtin");
assert_eq!(tree.root_node().kind(), "source_file");
}
#[test]
fn parse_by_name_uses_registry_first() {
let lang_name = "rust_alias_for_test";
let prior = register_grammar(lang_name, tree_sitter_rust::LANGUAGE.into());
assert!(prior.is_none());
let tree = parse_by_name(lang_name, "fn main() {}").expect("aliased grammar");
assert_eq!(tree.root_node().kind(), "source_file");
let removed = unregister_grammar(lang_name);
assert!(removed.is_some());
}
#[test]
fn parse_by_name_unknown_errors() {
let err = parse_by_name("definitely_unknown_grammar", "...").expect_err("must error");
let msg = err.to_string();
assert!(msg.contains("no tree-sitter grammar registered"), "got: {msg}");
}
#[test]
fn registry_overrides_builtin() {
let prior = register_grammar("rust", tree_sitter_python::LANGUAGE.into());
let tree = parse_by_name("rust", "def x():\n pass\n").expect("registered overrides");
assert_eq!(tree.root_node().kind(), "module");
if let Some(p) = prior {
register_grammar("rust", p);
} else {
unregister_grammar("rust");
}
}
#[test]
fn registered_grammars_lists_entries() {
register_grammar("test_listing_a", tree_sitter_rust::LANGUAGE.into());
register_grammar("test_listing_b", tree_sitter_python::LANGUAGE.into());
let names = registered_grammars();
assert!(names.contains(&"test_listing_a".to_string()));
assert!(names.contains(&"test_listing_b".to_string()));
unregister_grammar("test_listing_a");
unregister_grammar("test_listing_b");
}
}