use std::sync::atomic::{AtomicUsize, Ordering};
use crate::types::{CstNode, Language, ListOrdering};
/// Value of the first id handed out at start-up and after [`reset_ids`].
const FIRST_ID: usize = 1;

/// Shared counter holding the id that the next [`fresh_id`] call will return.
static NEXT_ID: AtomicUsize = AtomicUsize::new(FIRST_ID);

/// Hands out the counter's current value and advances the counter by one.
fn fresh_id() -> usize {
    // `fetch_add` yields the value stored *before* the increment, so the
    // very first call returns `FIRST_ID`.
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}

/// Rewinds the id counter so that the next id handed out is `1` again.
///
/// Ids that were issued before the reset will be issued again afterwards.
pub fn reset_ids() {
    NEXT_ID.store(FIRST_ID, Ordering::Relaxed);
}
/// Parses `source` with the tree-sitter grammar selected by `lang` and
/// converts the resulting syntax tree, starting at its root node, into a
/// [`CstNode`].
///
/// # Errors
///
/// * Any error returned by `get_tree_sitter_language` is passed through.
/// * [`ParseError::LanguageError`] if the parser refuses the grammar; the
///   payload is the underlying error rendered as text.
/// * [`ParseError::ParseFailed`] if the parser produces no tree.
pub fn parse_to_cst(source: &str, lang: Language) -> Result<CstNode, ParseError> {
    let grammar = get_tree_sitter_language(lang)?;

    let mut parser = tree_sitter::Parser::new();
    if let Err(err) = parser.set_language(&grammar) {
        return Err(ParseError::LanguageError(err.to_string()));
    }

    match parser.parse(source, None) {
        Some(tree) => Ok(ts_node_to_cst(&tree.root_node(), source.as_bytes())),
        None => Err(ParseError::ParseFailed),
    }
}
/// Recursively converts a tree-sitter node into a [`CstNode`].
///
/// * A node without children becomes a `Leaf` carrying its source text
///   (an empty string if that text cannot be read as UTF-8).
/// * A node whose kind satisfies `is_list_node`, or which ends up with more
///   than three converted children, becomes a `List` whose ordering comes
///   from `classify_ordering`.
/// * Every other node becomes a `Constructed` node.
///
/// The node's id is requested from `fresh_id` before its children are
/// visited, so a parent's id is allocated before those of its descendants.
fn ts_node_to_cst(node: &tree_sitter::Node, source: &[u8]) -> CstNode {
    let id = fresh_id();
    let kind = node.kind().to_string();

    if node.child_count() == 0 {
        let value = match node.utf8_text(source) {
            Ok(text) => text.to_string(),
            Err(_) => String::new(),
        };
        return CstNode::Leaf { id, kind, value };
    }

    // Indices for which tree-sitter yields no child are skipped.
    let mut children: Vec<CstNode> = Vec::new();
    for index in 0..node.child_count() {
        if let Some(child) = node.child(index) {
            children.push(ts_node_to_cst(&child, source));
        }
    }

    let treat_as_list = is_list_node(&kind) || children.len() > 3;
    if !treat_as_list {
        return CstNode::Constructed { id, kind, children };
    }

    let ordering = classify_ordering(&kind);
    CstNode::List {
        id,
        kind,
        ordering,
        children,
    }
}
/// Decides whether the children of a node of the given `kind` are treated as
/// order-insensitive.
///
/// Returns [`ListOrdering::Unordered`] for the kinds listed in the table
/// below and [`ListOrdering::Ordered`] for every other kind.
fn classify_ordering(kind: &str) -> ListOrdering {
    const UNORDERED_KINDS: &[&str] = &[
        // Import-style containers.
        "use_declaration_list",
        "import_list",
        "import_statement",
        "imports",
        "import_header",
        // Declaration bodies.
        "class_body",
        "enum_body",
        "interface_body",
        "declaration_list",
        "companion_object",
        "enum_class_body",
        "object_declaration",
        // Table-like containers.
        "table",
        "inline_table",
        "document",
        "table_array_element",
        // Mappings.
        "block_mapping",
        "flow_mapping",
    ];

    if UNORDERED_KINDS.contains(&kind) {
        ListOrdering::Unordered
    } else {
        ListOrdering::Ordered
    }
}
/// Reports whether a node of the given `kind` should be modelled as a list.
///
/// A kind qualifies when it contains one of a few tell-tale fragments
/// anywhere in its name, or when it is exactly one of the listed
/// container kinds. Matching is case-sensitive.
fn is_list_node(kind: &str) -> bool {
    // Substring rules. These also cover every kind ending in `_list` as well
    // as `block_node`, `block_mapping` and `block_sequence`.
    const FRAGMENTS: &[&str] = &[
        "block",
        "body",
        "list",
        "statements",
        "arguments",
        "parameters",
    ];
    // Kinds that only count when the whole name matches.
    const EXACT_KINDS: &[&str] = &[
        "program",
        "source_file",
        "module",
        "translation_unit",
        "document",
        "table",
        "array",
        "inline_table",
        "stream",
        "flow_mapping",
        "flow_sequence",
    ];

    FRAGMENTS.iter().any(|&fragment| kind.contains(fragment)) || EXACT_KINDS.contains(&kind)
}
/// Looks up the tree-sitter grammar that corresponds to `lang`.
///
/// Every arm of the match below yields a grammar, so this function returns
/// `Ok` for each variant it lists; the `Result` in the signature is what
/// callers currently propagate with `?`.
fn get_tree_sitter_language(lang: Language) -> Result<tree_sitter::Language, ParseError> {
    let grammar_fn = match lang {
        Language::Rust => tree_sitter_rust::LANGUAGE,
        Language::JavaScript => tree_sitter_javascript::LANGUAGE,
        // The TypeScript crate exposes its grammar under a dedicated name.
        Language::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
        Language::Python => tree_sitter_python::LANGUAGE,
        Language::Java => tree_sitter_java::LANGUAGE,
        Language::Go => tree_sitter_go::LANGUAGE,
        Language::C => tree_sitter_c::LANGUAGE,
        Language::Cpp => tree_sitter_cpp::LANGUAGE,
        Language::Kotlin => tree_sitter_kotlin_ng::LANGUAGE,
        Language::Toml => tree_sitter_toml_ng::LANGUAGE,
        Language::Yaml => tree_sitter_yaml::LANGUAGE,
    };

    let grammar: tree_sitter::Language = grammar_fn.into();
    Ok(grammar)
}
/// Errors reported while turning source text into a CST.
///
/// Besides `Debug`, the type is `Clone` and comparable (`PartialEq`/`Eq`) so
/// that callers and tests can duplicate errors and check them with
/// `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The parser rejected the requested grammar; the payload is the
    /// underlying error rendered as text.
    LanguageError(String),
    /// The parser did not produce a syntax tree for the input.
    ParseFailed,
}
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ParseError::LanguageError(s) => write!(f, "language error: {}", s),
ParseError::ParseFailed => write!(f, "parse failed"),
}
}
}
impl std::error::Error for ParseError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A small Rust program parses into a non-empty `source_file` root.
    #[test]
    fn test_parse_rust() {
        let cst = parse_to_cst("fn main() { let x = 1; }", Language::Rust).unwrap();
        assert_eq!(cst.kind(), "source_file");
        assert!(!cst.children().is_empty());
    }

    /// A small JavaScript function parses into a non-empty `program` root.
    #[test]
    fn test_parse_javascript() {
        let cst = parse_to_cst("function foo() { return 42; }", Language::JavaScript).unwrap();
        assert_eq!(cst.kind(), "program");
        assert!(!cst.children().is_empty());
    }

    /// Kotlin parses into `source_file` and its keywords survive reconstruction.
    #[test]
    fn test_parse_kotlin() {
        let cst = parse_to_cst("fun main() { val x = 1 }", Language::Kotlin).unwrap();
        assert_eq!(cst.kind(), "source_file");
        assert!(!cst.children().is_empty());

        let rendered = cst.to_source();
        for needle in ["fun", "val"] {
            assert!(rendered.contains(needle));
        }
    }

    /// TOML parses into `document` and its keys survive reconstruction.
    #[test]
    fn test_parse_toml() {
        let input = "[package]\nname = \"merge-engine\"\nversion = \"0.1.0\"\n";
        let cst = parse_to_cst(input, Language::Toml).unwrap();
        assert_eq!(cst.kind(), "document");
        assert!(!cst.children().is_empty());

        let rendered = cst.to_source();
        for needle in ["package", "name"] {
            assert!(rendered.contains(needle));
        }
    }

    /// YAML parses into `stream` and nested values survive reconstruction.
    #[test]
    fn test_parse_yaml() {
        let input = "name: test\non:\n push:\n branches: [main]\njobs:\n build:\n runs-on: ubuntu-latest\n";
        let cst = parse_to_cst(input, Language::Yaml).unwrap();
        assert_eq!(cst.kind(), "stream");

        let rendered = cst.to_source();
        for needle in ["ubuntu-latest", "branches"] {
            assert!(rendered.contains(needle));
        }
    }

    /// Leaf tokens of a one-line JavaScript statement survive reconstruction.
    #[test]
    fn test_leaf_reconstruction() {
        let rendered = parse_to_cst("let x = 1;", Language::JavaScript)
            .unwrap()
            .to_source();
        for needle in ["let", "x", "1"] {
            assert!(rendered.contains(needle));
        }
    }
}