use super::language::CodeLanguage;
use std::collections::HashMap;
use std::sync::Arc;
use tree_sitter::{Node, Parser, Point, Tree};
/// Upper bound on the number of trees `CodeParser::parse` keeps in its cache;
/// once this many entries exist, the whole cache is cleared before the next insert.
const MAX_PARSE_CACHE_ENTRIES: usize = 16;
/// Converts a byte offset in `source` into a zero-based `(line, column)` pair.
///
/// `line` is the number of `'\n'` characters before the offset. `column` is the
/// number of `char`s (Unicode scalar values, not bytes) between the start of that
/// line and the offset. Note that tree-sitter `Point` columns are byte-based, so
/// this column differs from them on lines containing non-ASCII text.
///
/// Offsets past the end of `source` are clamped to `source.len()`. Offsets that land
/// inside a multi-byte UTF-8 character are rounded down to that character's first
/// byte instead of panicking.
pub fn byte_offset_to_position(source: &str, byte_offset: usize) -> (usize, usize) {
    let mut end = byte_offset.min(source.len());
    // Slicing a `str` at a non-char boundary panics; step back to the start of the
    // enclosing character. Offset 0 is always a boundary, so this terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let prefix = &source[..end];
    let line = prefix.bytes().filter(|&b| b == b'\n').count();
    // '\n' is a single byte, so the byte after it is always a valid char boundary.
    let line_start = prefix.rfind('\n').map_or(0, |i| i + 1);
    let column = prefix[line_start..].chars().count();
    (line, column)
}
/// Errors produced while configuring a parser or parsing source text.
#[derive(Debug, Clone)]
pub enum AstError {
    /// tree-sitter refused the language passed to the parser; holds its message.
    ParserInit(String),
    /// The parser did not return a syntax tree.
    ParseFailed,
    /// A language other than the expected one was supplied.
    LanguageMismatch {
        /// Name of the language that was expected.
        expected: String,
        /// Name of the language that was actually supplied.
        got: String,
    },
}

impl std::fmt::Display for AstError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseFailed => f.write_str("Parsing failed"),
            Self::ParserInit(msg) => write!(f, "Parser initialization failed: {}", msg),
            Self::LanguageMismatch { expected, got } => write!(
                f,
                "Language mismatch: expected {}, got {}",
                expected, got
            ),
        }
    }
}

impl std::error::Error for AstError {}
/// The result of parsing one source string: the syntax tree, the text it was
/// parsed from, and a summary of the syntax errors found in it.
pub struct ParsedCode {
    /// The tree-sitter syntax tree.
    pub tree: Tree,
    /// The source text `tree` was parsed from.
    pub source: String,
    /// Name of the language used for parsing.
    pub language_name: String,
    /// `true` when tree-sitter reported an error anywhere in the tree.
    pub has_errors: bool,
    /// One entry per ERROR or MISSING node found in the tree.
    pub error_ranges: Vec<ErrorRange>,
}

/// Location and content of a single ERROR or MISSING node.
#[derive(Debug, Clone)]
pub struct ErrorRange {
    /// Byte offset where the node starts.
    pub start_byte: usize,
    /// Byte offset just past the node's end (exclusive).
    pub end_byte: usize,
    /// `(row, column)` of the start, as reported by tree-sitter.
    pub start_position: (usize, usize),
    /// `(row, column)` of the end, as reported by tree-sitter.
    pub end_position: (usize, usize),
    /// Source text covered by the node (empty for zero-width nodes).
    pub text: String,
    /// The node's kind string, e.g. `"ERROR"`.
    pub kind: String,
}

impl ParsedCode {
    /// Returns the root node of the syntax tree.
    pub fn root(&self) -> Node<'_> {
        self.tree.root_node()
    }

    /// Iterates over the recorded error ranges.
    pub fn errors(&self) -> impl Iterator<Item = &ErrorRange> {
        self.error_ranges.iter()
    }

    /// Number of recorded error ranges.
    pub fn error_count(&self) -> usize {
        self.error_ranges.len()
    }

    /// Returns `true` if `byte_offset` lies within any recorded error range.
    ///
    /// Non-empty ranges are half-open (`start_byte..end_byte`). Zero-width ranges,
    /// which is what MISSING nodes produce, match exactly their own start offset.
    /// A half-open test alone would never match them, silently ignoring the
    /// MISSING nodes that were deliberately recorded as errors.
    pub fn is_in_error(&self, byte_offset: usize) -> bool {
        self.error_ranges.iter().any(|r| {
            if r.start_byte == r.end_byte {
                byte_offset == r.start_byte
            } else {
                (r.start_byte..r.end_byte).contains(&byte_offset)
            }
        })
    }
}
/// An owned snapshot of a syntax-tree node and its entire subtree.
///
/// Unlike a tree-sitter `Node`, this does not borrow the tree, so it can be
/// stored, cloned and inspected after the tree is gone.
#[derive(Debug, Clone)]
pub struct AstNode {
    /// Grammar kind of the node (e.g. a rule name or a literal token).
    pub kind: String,
    /// Byte offset of the first byte of the node.
    pub start_byte: usize,
    /// Byte offset just past the last byte of the node.
    pub end_byte: usize,
    /// `(row, column)` where the node begins.
    pub start_position: (usize, usize),
    /// `(row, column)` where the node ends.
    pub end_position: (usize, usize),
    /// Whether the node is a named (as opposed to anonymous) node.
    pub is_named: bool,
    /// Whether the node is an ERROR node.
    pub is_error: bool,
    /// Whether the node was inserted by the parser as MISSING.
    pub is_missing: bool,
    /// Child nodes, in source order.
    pub children: Vec<AstNode>,
    /// Source text of the node, when recorded (leaf nodes only).
    pub text: Option<String>,
}
impl AstNode {
pub fn from_ts_node(node: Node, source: &str) -> Self {
let mut children = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
children.push(Self::from_ts_node(child, source));
}
let text = if children.is_empty() {
node.utf8_text(source.as_bytes()).ok().map(String::from)
} else {
None
};
let start = node.start_position();
let end = node.end_position();
Self {
kind: node.kind().to_string(),
start_byte: node.start_byte(),
end_byte: node.end_byte(),
start_position: (start.row, start.column),
end_position: (end.row, end.column),
is_named: node.is_named(),
is_error: node.is_error(),
is_missing: node.is_missing(),
children,
text,
}
}
pub fn descendants(&self) -> impl Iterator<Item = &AstNode> {
AstNodeIterator::new(self)
}
pub fn find_by_kind<'a>(&'a self, kind: &'a str) -> impl Iterator<Item = &'a AstNode> {
self.descendants().filter(move |n| n.kind == kind)
}
pub fn find_errors(&self) -> impl Iterator<Item = &AstNode> {
self.descendants().filter(|n| n.is_error || n.is_missing)
}
}
/// Depth-first, pre-order iterator over an [`AstNode`] and all of its descendants.
///
/// Children are pushed in reverse so the leftmost child is popped (and yielded)
/// first, giving a left-to-right traversal.
struct AstNodeIterator<'a> {
    /// Nodes still waiting to be yielded; the top of the stack is yielded next.
    stack: Vec<&'a AstNode>,
}

impl<'a> AstNodeIterator<'a> {
    /// Starts a traversal at `root`, which is the first node yielded.
    fn new(root: &'a AstNode) -> Self {
        let mut stack = Vec::new();
        stack.push(root);
        AstNodeIterator { stack }
    }
}

impl<'a> Iterator for AstNodeIterator<'a> {
    type Item = &'a AstNode;

    fn next(&mut self) -> Option<&'a AstNode> {
        let current = self.stack.pop()?;
        self.stack.extend(current.children.iter().rev());
        Some(current)
    }
}
/// A tree-sitter parser configured for a single [`CodeLanguage`], with a small
/// cache of recently parsed sources.
pub struct CodeParser<L: CodeLanguage> {
    /// The language the parser was configured with.
    language: Arc<L>,
    /// The underlying tree-sitter parser.
    parser: Parser,
    /// `safe_hash(source)` -> `(source, tree)`. The stored source is compared on
    /// lookup, so a hash collision never returns a tree for different text.
    tree_cache: HashMap<u64, (String, Tree)>,
}

impl<L: CodeLanguage> CodeParser<L> {
    /// Creates a parser for `language`.
    ///
    /// # Errors
    ///
    /// Returns [`AstError::ParserInit`] if tree-sitter rejects the language.
    pub fn new(language: Arc<L>) -> Result<Self, AstError> {
        let mut parser = Parser::new();
        parser
            .set_language(&language.tree_sitter_language())
            .map_err(|e| AstError::ParserInit(e.to_string()))?;
        Ok(Self {
            language,
            parser,
            tree_cache: HashMap::new(),
        })
    }

    /// Parses `source` from scratch, reusing a cached tree if this exact source was
    /// parsed recently.
    ///
    /// The cache holds at most [`MAX_PARSE_CACHE_ENTRIES`] trees. When it is full,
    /// it is cleared entirely before the new tree is inserted.
    ///
    /// # Errors
    ///
    /// Returns [`AstError::ParseFailed`] if the parser produces no tree.
    pub fn parse(&mut self, source: &str) -> Result<ParsedCode, AstError> {
        let cache_key = crate::util::hash::safe_hash(source.as_bytes());
        if let Some((cached_source, tree)) = self.tree_cache.get(&cache_key) {
            // Guard against hash collisions: only reuse the tree for identical text.
            if cached_source == source {
                return self.parsed_code_from_tree(tree.clone(), source);
            }
        }
        let parsed = self.parse_with_old_tree(source, None)?;
        if self.tree_cache.len() >= MAX_PARSE_CACHE_ENTRIES {
            self.tree_cache.clear();
        }
        self.tree_cache
            .insert(cache_key, (source.to_string(), parsed.tree.clone()));
        Ok(parsed)
    }

    /// Parses `source`, optionally reusing `old_tree` as the base for an incremental
    /// parse. Bypasses the parse cache.
    ///
    /// # Errors
    ///
    /// Returns [`AstError::ParseFailed`] if the parser produces no tree.
    pub fn parse_with_old_tree(
        &mut self,
        source: &str,
        old_tree: Option<&Tree>,
    ) -> Result<ParsedCode, AstError> {
        let tree = self
            .parser
            .parse(source, old_tree)
            .ok_or(AstError::ParseFailed)?;
        self.parsed_code_from_tree(tree, source)
    }

    /// Wraps `tree` into a [`ParsedCode`], collecting error ranges only when the
    /// root node reports an error.
    fn parsed_code_from_tree(&self, tree: Tree, source: &str) -> Result<ParsedCode, AstError> {
        let has_errors = tree.root_node().has_error();
        let error_ranges = if has_errors {
            self.collect_errors(&tree, source)
        } else {
            Vec::new()
        };
        Ok(ParsedCode {
            tree,
            source: source.to_string(),
            language_name: self.language.name().to_string(),
            has_errors,
            error_ranges,
        })
    }

    /// Applies `edit` to `old_tree` in place, then reparses `source` (the text
    /// *after* the edit) reusing the edited tree. Does not touch the parse cache.
    ///
    /// # Errors
    ///
    /// Returns [`AstError::ParseFailed`] if the parser produces no tree.
    pub fn parse_incremental(
        &mut self,
        source: &str,
        old_tree: &mut Tree,
        edit: &EditInfo,
    ) -> Result<ParsedCode, AstError> {
        old_tree.edit(&edit.to_input_edit());
        self.parse_with_old_tree(source, Some(old_tree))
    }

    /// Collects every ERROR or MISSING node in `tree`, in pre-order.
    ///
    /// Uses an explicit `TreeCursor` walk instead of recursion, so deeply nested
    /// input cannot overflow the stack. Subtrees whose root reports
    /// `has_error() == false` are skipped, since no error can live below them.
    /// This is the same signal `parsed_code_from_tree` uses to decide whether to
    /// collect at all.
    fn collect_errors(&self, tree: &Tree, source: &str) -> Vec<ErrorRange> {
        let mut errors = Vec::new();
        let mut cursor = tree.walk();
        loop {
            let node = cursor.node();
            if node.is_error() || node.is_missing() {
                errors.push(Self::error_range(node, source));
            }
            // Descend only into subtrees that actually contain an error.
            if node.has_error() && cursor.goto_first_child() {
                continue;
            }
            // Advance to the next unvisited node: a later sibling, or a sibling of
            // the nearest ancestor that has one. Climbing past the root ends the walk.
            while !cursor.goto_next_sibling() {
                if !cursor.goto_parent() {
                    return errors;
                }
            }
        }
    }

    /// Builds the [`ErrorRange`] record for one ERROR or MISSING node.
    fn error_range(node: Node, source: &str) -> ErrorRange {
        let start = node.start_position();
        let end = node.end_position();
        ErrorRange {
            start_byte: node.start_byte(),
            end_byte: node.end_byte(),
            start_position: (start.row, start.column),
            end_position: (end.row, end.column),
            // Fall back to "" if the node's bytes are not valid UTF-8 in `source`.
            text: node.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
            kind: node.kind().to_string(),
        }
    }

    /// Returns the language this parser was configured with.
    pub fn language(&self) -> &L {
        &self.language
    }
}
/// Describes one text edit in the form tree-sitter's incremental parser needs.
///
/// Byte fields are offsets into the source. Position fields are `(row, column)`
/// pairs that are passed straight through as tree-sitter `Point`s.
#[derive(Debug, Clone)]
pub struct EditInfo {
    /// Byte offset where the edit begins.
    pub start_byte: usize,
    /// End byte of the replaced region in the old text.
    pub old_end_byte: usize,
    /// End byte of the replacement in the new text.
    pub new_end_byte: usize,
    /// `(row, column)` where the edit begins.
    pub start_position: (usize, usize),
    /// `(row, column)` of the old region's end.
    pub old_end_position: (usize, usize),
    /// `(row, column)` of the new region's end.
    pub new_end_position: (usize, usize),
}

impl EditInfo {
    /// Converts this edit into a [`tree_sitter::InputEdit`].
    pub fn to_input_edit(&self) -> tree_sitter::InputEdit {
        let point = |(row, column): (usize, usize)| Point::new(row, column);
        tree_sitter::InputEdit {
            start_byte: self.start_byte,
            old_end_byte: self.old_end_byte,
            new_end_byte: self.new_end_byte,
            start_position: point(self.start_position),
            old_end_position: point(self.old_end_position),
            new_end_position: point(self.new_end_position),
        }
    }

    /// Builds the edit for inserting `inserted_text` at byte `position`, located at
    /// `(row, column)`.
    ///
    /// Columns are measured in bytes (`str::len`). For text without a newline, the
    /// end column is `column` plus the inserted length. Otherwise it is the byte
    /// length of the text after the last `'\n'`, and the row advances by the number
    /// of newlines.
    pub fn insertion(position: usize, row: usize, column: usize, inserted_text: &str) -> Self {
        let new_end_position = match inserted_text.rfind('\n') {
            None => (row, column + inserted_text.len()),
            Some(last_newline) => {
                let newline_count = inserted_text.matches('\n').count();
                (row + newline_count, inserted_text.len() - last_newline - 1)
            }
        };
        let start = (row, column);
        Self {
            start_byte: position,
            old_end_byte: position,
            new_end_byte: position + inserted_text.len(),
            start_position: start,
            old_end_position: start,
            new_end_position,
        }
    }

    /// Builds the edit for deleting bytes `start_byte..end_byte`, spanning positions
    /// `start_pos` to `end_pos`. The new text ends where the deletion started.
    pub fn deletion(
        start_byte: usize,
        end_byte: usize,
        start_pos: (usize, usize),
        end_pos: (usize, usize),
    ) -> Self {
        Self {
            start_byte,
            new_end_byte: start_byte,
            old_end_byte: end_byte,
            start_position: start_pos,
            new_end_position: start_pos,
            old_end_position: end_pos,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a childless, named, non-error node on line 0 spanning `start..end`.
    fn leaf(kind: &str, start: usize, end: usize, text: &str) -> AstNode {
        AstNode {
            kind: kind.to_string(),
            start_byte: start,
            end_byte: end,
            start_position: (0, start),
            end_position: (0, end),
            is_named: true,
            is_error: false,
            is_missing: false,
            text: Some(text.to_string()),
            children: vec![],
        }
    }

    /// Builds the `root -> [child1, child2]` tree shared by the traversal tests.
    fn sample_tree() -> AstNode {
        AstNode {
            kind: "root".to_string(),
            start_byte: 0,
            end_byte: 10,
            start_position: (0, 0),
            end_position: (0, 10),
            is_named: true,
            is_error: false,
            is_missing: false,
            text: None,
            children: vec![leaf("child1", 0, 5, "hello"), leaf("child2", 5, 10, "world")],
        }
    }

    #[test]
    fn test_ast_node_descendants() {
        let root = sample_tree();
        let kinds: Vec<&str> = root.descendants().map(|n| n.kind.as_str()).collect();
        assert_eq!(kinds, vec!["root", "child1", "child2"]);
    }

    #[test]
    fn test_ast_node_descendants_is_preorder() {
        let mut root = sample_tree();
        root.children[0].children.push(leaf("grandchild", 0, 2, "he"));
        let kinds: Vec<&str> = root.descendants().map(|n| n.kind.as_str()).collect();
        assert_eq!(kinds, vec!["root", "child1", "grandchild", "child2"]);
    }

    #[test]
    fn test_find_by_kind_and_errors() {
        let mut root = sample_tree();
        root.children[1].is_missing = true;
        let found: Vec<usize> = root.find_by_kind("child1").map(|n| n.start_byte).collect();
        assert_eq!(found, vec![0]);
        let errors: Vec<&str> = root.find_errors().map(|n| n.kind.as_str()).collect();
        assert_eq!(errors, vec!["child2"]);
    }

    #[test]
    fn test_edit_info_insertion() {
        let edit = EditInfo::insertion(5, 0, 5, "hello");
        assert_eq!(edit.start_byte, 5);
        assert_eq!(edit.old_end_byte, 5);
        assert_eq!(edit.new_end_byte, 10);
        assert_eq!(edit.new_end_position, (0, 10));
    }

    #[test]
    fn test_edit_info_multiline_insertion() {
        let edit = EditInfo::insertion(10, 1, 4, "x\nyz");
        assert_eq!(edit.new_end_byte, 14);
        assert_eq!(edit.start_position, (1, 4));
        assert_eq!(edit.old_end_position, (1, 4));
        assert_eq!(edit.new_end_position, (2, 2));
    }

    #[test]
    fn test_edit_info_deletion() {
        let edit = EditInfo::deletion(2, 5, (0, 2), (0, 5));
        assert_eq!(edit.start_byte, 2);
        assert_eq!(edit.old_end_byte, 5);
        assert_eq!(edit.new_end_byte, 2);
        assert_eq!(edit.old_end_position, (0, 5));
        assert_eq!(edit.new_end_position, (0, 2));
    }

    #[test]
    fn test_byte_offset_to_position() {
        assert_eq!(byte_offset_to_position("ab\ncd", 4), (1, 1));
        assert_eq!(byte_offset_to_position("abc", 100), (0, 3));
        // Columns count chars, not bytes: 'é' is two bytes but one column.
        assert_eq!(byte_offset_to_position("x\nhé", 5), (1, 2));
    }
}