use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use tree_sitter::{Language, Parser, Tree};
use crate::error::TldrError;
use crate::types::Language as TldrLanguage;
use crate::TldrResult;
/// Largest source text, in bytes, that the parsing functions accept (5 MiB).
pub const MAX_PARSE_SIZE: usize = 5 << 20;
/// Grammar dialect for TypeScript/JavaScript sources.
///
/// Files whose extension is `tsx` or `jsx` (any ASCII case) need the TSX grammar so that
/// JSX syntax parses; every other TS/JS source uses the plain TypeScript grammar.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TsDialect {
    /// The language is neither TypeScript nor JavaScript.
    None,
    /// Plain TypeScript grammar.
    Ts,
    /// TSX grammar (TypeScript with JSX).
    Tsx,
}

impl TsDialect {
    /// Chooses the dialect for `lang`, inspecting the extension of `path` when present.
    ///
    /// Languages other than TypeScript/JavaScript always yield [`TsDialect::None`].
    /// TS/JS yields [`TsDialect::Tsx`] for a `tsx`/`jsx` extension and [`TsDialect::Ts`]
    /// otherwise. The `Ts` case includes a missing path or a non-UTF-8 extension.
    pub fn from_path_and_lang(path: Option<&Path>, lang: TldrLanguage) -> Self {
        let is_ts_family = matches!(lang, TldrLanguage::TypeScript | TldrLanguage::JavaScript);
        if !is_ts_family {
            return Self::None;
        }
        let extension = path
            .and_then(|p| p.extension())
            .and_then(|ext| ext.to_str());
        match extension {
            Some(ext) if ext.eq_ignore_ascii_case("tsx") || ext.eq_ignore_ascii_case("jsx") => {
                Self::Tsx
            }
            _ => Self::Ts,
        }
    }
}
/// Cache key for the parser pool: a language together with its TS/TSX dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ParserKey {
    /// Source language.
    pub lang: TldrLanguage,
    /// Grammar dialect (`TsDialect::None` for languages other than TypeScript/JavaScript).
    pub dialect: TsDialect,
}

impl ParserKey {
    /// Builds a key from its two components.
    pub fn new(lang: TldrLanguage, dialect: TsDialect) -> Self {
        ParserKey { dialect, lang }
    }
}
pub struct ParserPool {
parsers: Mutex<HashMap<ParserKey, Parser>>,
}
impl ParserPool {
pub fn new() -> Self {
Self {
parsers: Mutex::new(HashMap::new()),
}
}
pub fn get_ts_language(lang: TldrLanguage) -> Option<Language> {
match lang {
TldrLanguage::Python => Some(tree_sitter_python::LANGUAGE.into()),
TldrLanguage::TypeScript | TldrLanguage::JavaScript => {
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
}
TldrLanguage::Go => Some(tree_sitter_go::LANGUAGE.into()),
TldrLanguage::Rust => Some(tree_sitter_rust::LANGUAGE.into()),
TldrLanguage::Java => Some(tree_sitter_java::LANGUAGE.into()),
TldrLanguage::C => Some(tree_sitter_c::LANGUAGE.into()),
TldrLanguage::Cpp => Some(tree_sitter_cpp::LANGUAGE.into()),
TldrLanguage::Ruby => Some(tree_sitter_ruby::LANGUAGE.into()),
TldrLanguage::CSharp => Some(tree_sitter_c_sharp::LANGUAGE.into()),
TldrLanguage::Scala => Some(tree_sitter_scala::LANGUAGE.into()),
TldrLanguage::Php => Some(tree_sitter_php::LANGUAGE_PHP.into()),
TldrLanguage::Lua => Some(tree_sitter_lua::LANGUAGE.into()),
TldrLanguage::Luau => Some(tree_sitter_luau::LANGUAGE.into()),
TldrLanguage::Elixir => Some(tree_sitter_elixir::LANGUAGE.into()),
TldrLanguage::Ocaml => Some(tree_sitter_ocaml::LANGUAGE_OCAML.into()),
TldrLanguage::Kotlin => Some(tree_sitter_kotlin_ng::LANGUAGE.into()),
TldrLanguage::Swift => Some(tree_sitter_swift::LANGUAGE.into()),
}
}
fn select_ts_grammar(path: Option<&Path>) -> Language {
match path
.and_then(|p| p.extension())
.and_then(|e| e.to_str())
.map(|e| e.to_ascii_lowercase())
{
Some(ref e) if e == "tsx" || e == "jsx" => tree_sitter_typescript::LANGUAGE_TSX.into(),
_ => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
}
}
fn resolve_grammar(lang: TldrLanguage, path: Option<&Path>) -> Option<Language> {
match lang {
TldrLanguage::TypeScript | TldrLanguage::JavaScript => {
Some(Self::select_ts_grammar(path))
}
_ => Self::get_ts_language(lang),
}
}
pub fn parse(&self, source: &str, lang: TldrLanguage) -> TldrResult<Tree> {
self.parse_with_path(source, lang, None)
}
pub fn parse_with_path(
&self,
source: &str,
lang: TldrLanguage,
path: Option<&Path>,
) -> TldrResult<Tree> {
if source.len() > MAX_PARSE_SIZE {
return Err(TldrError::ParseError {
file: path
.map(|p| p.to_path_buf())
.unwrap_or_else(|| std::path::PathBuf::from("<source>")),
line: None,
message: format!(
"File too large: {} bytes (max {})",
source.len(),
MAX_PARSE_SIZE
),
});
}
let ts_lang = Self::resolve_grammar(lang, path)
.ok_or_else(|| TldrError::UnsupportedLanguage(lang.to_string()))?;
let dialect = TsDialect::from_path_and_lang(path, lang);
let key = ParserKey::new(lang, dialect);
let mut parsers = self.parsers.lock().unwrap();
let parser = parsers.entry(key).or_insert_with(|| {
let mut p = Parser::new();
p.set_language(&ts_lang).expect("Error loading grammar");
p
});
parser
.set_language(&ts_lang)
.map_err(|e| TldrError::ParseError {
file: path
.map(|p| p.to_path_buf())
.unwrap_or_else(|| std::path::PathBuf::from("<source>")),
line: None,
message: format!("Failed to set language: {}", e),
})?;
parser
.parse(source, None)
.ok_or_else(|| TldrError::ParseError {
file: path
.map(|p| p.to_path_buf())
.unwrap_or_else(|| std::path::PathBuf::from("<source>")),
line: None,
message: "Parsing returned None".to_string(),
})
}
pub fn parse_file(&self, path: &std::path::Path) -> TldrResult<(Tree, String, TldrLanguage)> {
let lang = TldrLanguage::from_path(path).ok_or_else(|| {
let ext = path
.extension()
.map(|e| e.to_string_lossy().to_string())
.unwrap_or_else(|| "unknown".to_string());
TldrError::UnsupportedLanguage(ext)
})?;
let bytes = std::fs::read(path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
TldrError::PathNotFound(path.to_path_buf())
} else if e.kind() == std::io::ErrorKind::PermissionDenied {
TldrError::PermissionDenied(path.to_path_buf())
} else {
TldrError::IoError(e)
}
})?;
let source = String::from_utf8_lossy(&bytes).to_string();
let tree = self
.parse_with_path(&source, lang, Some(path))
.map_err(|e| {
if let TldrError::ParseError { line, message, .. } = e {
TldrError::ParseError {
file: path.to_path_buf(),
line,
message,
}
} else {
e
}
})?;
Ok((tree, source, lang))
}
}
impl Default for ParserPool {
fn default() -> Self {
Self::new()
}
}
lazy_static::lazy_static! {
    /// Process-wide parser pool behind the free functions `parse`, `parse_with_path`
    /// and `parse_file`.
    pub static ref PARSER_POOL: Arc<ParserPool> = Arc::new(ParserPool::default());
}
pub fn parse(source: &str, lang: TldrLanguage) -> TldrResult<Tree> {
PARSER_POOL.parse(source, lang)
}
pub fn parse_with_path(source: &str, lang: TldrLanguage, path: Option<&Path>) -> TldrResult<Tree> {
PARSER_POOL.parse_with_path(source, lang, path)
}
pub fn parse_file(path: &std::path::Path) -> TldrResult<(Tree, String, TldrLanguage)> {
PARSER_POOL.parse_file(path)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_python() {
        let source = "def foo(): pass";
        let tree = parse(source, TldrLanguage::Python).unwrap();
        assert_eq!(tree.root_node().kind(), "module");
    }

    #[test]
    fn test_parse_typescript() {
        let source = "function foo() {}";
        let tree = parse(source, TldrLanguage::TypeScript).unwrap();
        assert_eq!(tree.root_node().kind(), "program");
    }

    #[test]
    fn test_parse_go() {
        let source = "package main\nfunc foo() {}";
        let tree = parse(source, TldrLanguage::Go).unwrap();
        assert_eq!(tree.root_node().kind(), "source_file");
    }

    #[test]
    fn test_parse_rust() {
        let source = "fn foo() {}";
        let tree = parse(source, TldrLanguage::Rust).unwrap();
        assert_eq!(tree.root_node().kind(), "source_file");
    }

    #[test]
    fn test_swift_now_supported() {
        let result = parse("let x = 1", TldrLanguage::Swift);
        assert!(
            result.is_ok(),
            "Swift should now parse successfully: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().root_node().kind(), "source_file");
    }

    #[test]
    fn test_parser_reuse() {
        let pool = ParserPool::new();
        for _ in 0..5 {
            let _ = pool.parse("def foo(): pass", TldrLanguage::Python).unwrap();
        }
        let parsers = pool.parsers.lock().unwrap();
        assert_eq!(parsers.len(), 1);
    }

    #[test]
    fn test_concurrent_parses_share_one_cache_entry() {
        let pool = ParserPool::new();
        std::thread::scope(|s| {
            for _ in 0..4 {
                s.spawn(|| {
                    for _ in 0..8 {
                        let tree = pool.parse("def foo(): pass", TldrLanguage::Python).unwrap();
                        assert_eq!(tree.root_node().kind(), "module");
                    }
                });
            }
        });
        let parsers = pool.parsers.lock().unwrap();
        assert_eq!(parsers.len(), 1);
    }

    #[test]
    fn test_oversized_source_is_rejected() {
        let pool = ParserPool::new();
        let big = "a".repeat(MAX_PARSE_SIZE + 1);
        match pool.parse(&big, TldrLanguage::Python) {
            Err(TldrError::ParseError { message, .. }) => {
                assert!(message.contains("File too large"), "unexpected message: {}", message)
            }
            Err(other) => panic!("expected ParseError, got {:?}", other),
            Ok(_) => panic!("oversized source should be rejected"),
        }
    }

    #[test]
    fn test_parse_file_missing_path_reports_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("missing.py");
        let pool = ParserPool::new();
        match pool.parse_file(&missing) {
            Err(TldrError::PathNotFound(p)) => assert_eq!(p, missing),
            Err(other) => panic!("expected PathNotFound, got {:?}", other),
            Ok(_) => panic!("parsing a missing file should fail"),
        }
    }

    #[test]
    fn test_dialect_extension_matching_is_case_insensitive() {
        let ts = TldrLanguage::TypeScript;
        let js = TldrLanguage::JavaScript;
        assert_eq!(
            TsDialect::from_path_and_lang(Some(Path::new("App.TSX")), ts),
            TsDialect::Tsx
        );
        assert_eq!(
            TsDialect::from_path_and_lang(Some(Path::new("App.Jsx")), js),
            TsDialect::Tsx
        );
        assert_eq!(
            TsDialect::from_path_and_lang(Some(Path::new("a.ts")), ts),
            TsDialect::Ts
        );
        assert_eq!(TsDialect::from_path_and_lang(None, js), TsDialect::Ts);
        assert_eq!(
            TsDialect::from_path_and_lang(Some(Path::new("a.tsx")), TldrLanguage::Python),
            TsDialect::None
        );
    }

    /// Counts syntax-error nodes in the subtree rooted at `node`.
    ///
    /// Both `ERROR` nodes and `MISSING` nodes (tokens tree-sitter inserts during error
    /// recovery) are counted. A result of zero therefore means no error recovery happened.
    fn count_error_nodes(node: tree_sitter::Node) -> usize {
        let own = usize::from(node.is_error() || node.is_missing());
        let mut cursor = node.walk();
        let in_children: usize = node.children(&mut cursor).map(count_error_nodes).sum();
        own + in_children
    }

    #[test]
    fn test_parse_file_tsx_uses_tsx_grammar() {
        let dir = tempfile::tempdir().unwrap();
        let tsx_path = dir.path().join("App.tsx");
        std::fs::write(
            &tsx_path,
            r#"export const App = ({ name }: { name: string }) => <div className="a">{name}</div>;
"#,
        )
        .unwrap();
        let pool = ParserPool::new();
        let (tree, _src, lang) = pool.parse_file(&tsx_path).unwrap();
        assert_eq!(lang, TldrLanguage::TypeScript);
        let errors = count_error_nodes(tree.root_node());
        assert_eq!(
            errors, 0,
            "expected zero ERROR/MISSING nodes for .tsx via TSX grammar, got {}",
            errors
        );
        let ts_path = dir.path().join("plain.ts");
        std::fs::write(&ts_path, "export const x: number = 1;\n").unwrap();
        let (tree, _src, lang) = pool.parse_file(&ts_path).unwrap();
        assert_eq!(lang, TldrLanguage::TypeScript);
        assert_eq!(
            count_error_nodes(tree.root_node()),
            0,
            "plain .ts should parse cleanly"
        );
    }

    #[test]
    fn test_parse_file_jsx_uses_tsx_grammar() {
        let dir = tempfile::tempdir().unwrap();
        let jsx_path = dir.path().join("App.jsx");
        std::fs::write(
            &jsx_path,
            "export const App = ({ name }) => <div className=\"a\">{name}</div>;\n",
        )
        .unwrap();
        let pool = ParserPool::new();
        let (tree, _src, lang) = pool.parse_file(&jsx_path).unwrap();
        assert_eq!(lang, TldrLanguage::JavaScript);
        let errors = count_error_nodes(tree.root_node());
        assert_eq!(
            errors, 0,
            "expected zero ERROR/MISSING nodes for .jsx via TSX grammar, got {}",
            errors
        );
    }

    #[test]
    fn test_parse_cache_distinguishes_dialects() {
        let dir = tempfile::tempdir().unwrap();
        let ts_path = dir.path().join("a.ts");
        let tsx_path = dir.path().join("b.tsx");
        std::fs::write(&ts_path, "export const n: number = 1;\n").unwrap();
        std::fs::write(&tsx_path, "export const App = () => <div>{1}</div>;\n").unwrap();
        let pool = ParserPool::new();
        let (t1, _, _) = pool.parse_file(&ts_path).unwrap();
        assert_eq!(count_error_nodes(t1.root_node()), 0, "first .ts failed");
        let (t2, _, _) = pool.parse_file(&tsx_path).unwrap();
        assert_eq!(count_error_nodes(t2.root_node()), 0, ".tsx failed");
        let (t3, _, _) = pool.parse_file(&ts_path).unwrap();
        assert_eq!(
            count_error_nodes(t3.root_node()),
            0,
            "second .ts failed (cache collision between TS and TSX parsers)"
        );
    }

    #[test]
    fn test_legacy_parse_without_path_uses_ts_default() {
        let pool = ParserPool::new();
        let tree = pool
            .parse("export const x: number = 1;", TldrLanguage::TypeScript)
            .unwrap();
        assert_eq!(
            count_error_nodes(tree.root_node()),
            0,
            "plain TS should parse cleanly via path-less API"
        );
        let jsx_src = "const App = () => <div className=\"a\">hi</div>;";
        let tree = pool.parse(jsx_src, TldrLanguage::TypeScript).unwrap();
        assert!(
            count_error_nodes(tree.root_node()) > 0,
            "path-less TS parse of JSX is expected to produce ERROR nodes; \
             if it parses cleanly, the default grammar changed and callers \
             must be audited"
        );
    }

    #[test]
    fn test_all_18_parsers_accept_minimal_valid_snippet() {
        let snippets: &[(TldrLanguage, &str)] = &[
            (TldrLanguage::Python, "def x(): pass"),
            (TldrLanguage::TypeScript, "export const x: number = 1;"),
            (TldrLanguage::JavaScript, "export const x = 1;"),
            (TldrLanguage::Go, "package main\nfunc main() {}"),
            (TldrLanguage::Rust, "pub fn x() {}"),
            (
                TldrLanguage::Java,
                "class X { public static void main(String[] a){} }",
            ),
            (TldrLanguage::C, "int main(){return 0;}"),
            (TldrLanguage::Cpp, "int main(){return 0;}"),
            (TldrLanguage::Ruby, "def x; end"),
            (TldrLanguage::Kotlin, "fun x(){}"),
            (TldrLanguage::Swift, "func x(){}"),
            (TldrLanguage::CSharp, "class X { static void Main(){} }"),
            (
                TldrLanguage::Scala,
                "object X { def main(args: Array[String]): Unit = {} }",
            ),
            (TldrLanguage::Php, "<?php function x(){}"),
            (TldrLanguage::Lua, "function x() end"),
            (TldrLanguage::Luau, "function x() end"),
            (
                TldrLanguage::Elixir,
                "defmodule X do\ndef y(), do: :ok\nend",
            ),
            (TldrLanguage::Ocaml, "let x () = ()"),
        ];
        let pool = ParserPool::new();
        let mut failures: Vec<String> = Vec::new();
        for (lang, src) in snippets {
            match pool.parse(src, *lang) {
                Ok(tree) => {
                    let errs = count_error_nodes(tree.root_node());
                    if errs != 0 {
                        failures.push(format!(
                            "{:?}: {} ERROR/MISSING node(s) on valid snippet: {:?}",
                            lang, errs, src
                        ));
                    }
                }
                Err(e) => {
                    failures.push(format!("{:?}: parse failed: {:?} on {:?}", lang, e, src));
                }
            }
        }
        assert!(
            failures.is_empty(),
            "Parser audit failures (VAL-008): {}",
            failures.join(" | ")
        );
    }
}