use logging_timer::time;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::ops::{Deref, DerefMut};
use std::{cell::RefCell, ops::Index, path::PathBuf};
use tree_sitter::Node as TSNode;
use tree_sitter::Point;
use tree_sitter::Tree as TSTree;
use unicode_segmentation as us;
#[cfg(test)]
use mockall::{automock, predicate::str};
/// Minimal view of a syntax-tree node that exposes only the node's kind.
///
/// `TreeSitterProcessor::should_include_node` takes a `&dyn TSNodeTrait`, and in test builds
/// `automock` generates a mock implementation of this trait.
#[cfg_attr(test, automock)]
trait TSNodeTrait {
/// Returns the kind name of the node.
fn kind(&self) -> &str;
}
/// Options that control how the leaves of a tree-sitter tree are turned into [`Entry`] values.
///
/// With serde the field names are kebab-case, and missing fields fall back to the values of
/// the `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case", default)]
pub struct TreeSitterProcessor {
/// When `true`, every leaf is split into one entry per grapheme; otherwise each leaf
/// becomes a single entry.
pub split_graphemes: bool,
/// When set, leaves whose node kind is contained in this set are dropped.
pub exclude_kinds: Option<HashSet<String>>,
/// When set, leaves whose node kind is *not* contained in this set are dropped.
pub include_kinds: Option<HashSet<String>>,
/// When `true`, whitespace is removed: whitespace-only graphemes are skipped when splitting,
/// and the text of unsplit leaves is trimmed.
pub strip_whitespace: bool,
}
impl Default for TreeSitterProcessor {
fn default() -> Self {
Self {
split_graphemes: true,
exclude_kinds: None,
include_kinds: None,
strip_whitespace: true,
}
}
}
/// Newtype around a tree-sitter node; it implements `TSNodeTrait` by forwarding to the node.
#[derive(Debug)]
struct TSNodeWrapper<'a>(TSNode<'a>);
impl TSNodeTrait for TSNodeWrapper<'_> {
    /// Forwards to the wrapped tree-sitter node's `kind`.
    fn kind(&self) -> &str {
        let Self(node) = self;
        node.kind()
    }
}
impl TreeSitterProcessor {
    /// Converts the leaves of `tree` into a flat list of entries.
    ///
    /// Leaves are collected with `from_ts_tree`, filtered by node kind (see
    /// `should_include_node`), and then either split into one entry per grapheme
    /// (`split_graphemes`) or converted one-to-one with `process_leaf`.
    ///
    /// `text` must be the source text that `tree` was parsed from.
    #[time("info", "ast::{}")]
    pub fn process<'a>(&self, tree: &'a TSTree, text: &'a str) -> Vec<Entry<'a>> {
        let ast_vector = from_ts_tree(tree, text);
        // Leaves are `Copy`, so work on values instead of references from here on.
        let included = ast_vector
            .leaves
            .iter()
            .copied()
            .filter(|leaf| self.should_include_node(&TSNodeWrapper(leaf.reference)));

        if self.split_graphemes {
            included
                .flat_map(|leaf| leaf.split_on_graphemes(self.strip_whitespace))
                .collect()
        } else {
            included.map(|leaf| self.process_leaf(leaf)).collect()
        }
    }

    /// Converts a whole leaf into a single entry.
    ///
    /// When `strip_whitespace` is set the text is trimmed on both sides. The positions are
    /// always the start and end of the underlying node, regardless of trimming.
    fn process_leaf<'a>(&self, leaf: VectorLeaf<'a>) -> Entry<'a> {
        let text = if self.strip_whitespace {
            Cow::from(leaf.text.trim())
        } else {
            Cow::from(leaf.text)
        };
        let node = leaf.reference;
        Entry {
            reference: node,
            text,
            start_position: node.start_position(),
            // The end of the entry is the end of the node, not its start.
            end_position: node.end_position(),
            kind_id: node.kind_id(),
        }
    }

    /// Decides whether a node passes the configured kind filters.
    ///
    /// A node is rejected when its kind is listed in `exclude_kinds`, or when `include_kinds`
    /// is set and does not list its kind. If both filters are `None`, every node is accepted.
    fn should_include_node(&self, node: &dyn TSNodeTrait) -> bool {
        let explicitly_excluded = self
            .exclude_kinds
            .as_ref()
            .is_some_and(|kinds| kinds.contains(node.kind()));
        let should_exclude = explicitly_excluded
            || self
                .include_kinds
                .as_ref()
                .is_some_and(|kinds| !kinds.contains(node.kind()));
        !should_exclude
    }
}
/// Collects the leaves of `tree` (via `build`, starting at the root node) into a [`Vector`]
/// that also keeps a reference to the source `text`.
#[time("info", "ast::{}")]
fn from_ts_tree<'a>(tree: &'a TSTree, text: &'a str) -> Vector<'a> {
    let collected = RefCell::new(Vec::new());
    build(&collected, tree.root_node(), text);
    let leaves = collected.into_inner();
    Vector {
        leaves,
        source_text: text,
    }
}
/// A leaf of the syntax tree paired with the slice of source text it covers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VectorLeaf<'a> {
/// The tree-sitter node for this leaf.
pub reference: TSNode<'a>,
/// The source text covered by the node's byte range.
pub text: &'a str,
}
/// Serde remote definition for tree-sitter's `Point`; it mirrors the `row` and `column` fields
/// and is referenced from `Entry` through `#[serde(with = "PointWrapper")]`.
#[derive(Serialize, Deserialize)]
#[serde(remote = "Point")]
struct PointWrapper {
/// Row of the point.
pub row: usize,
/// Column of the point.
pub column: usize,
}
/// A piece of text taken from a leaf node, together with its position in the source.
///
/// The `PartialEq` implementation for this type compares only `kind_id` and `text`.
#[derive(Debug, Clone, Serialize)]
pub struct Entry<'node> {
/// The tree-sitter node this entry was produced from; skipped when serializing.
#[serde(skip_serializing)]
pub reference: TSNode<'node>,
/// The text of the entry.
pub text: Cow<'node, str>,
/// Position at which the entry starts; serialized through `PointWrapper`.
#[serde(with = "PointWrapper")]
pub start_position: Point,
/// Position at which the entry ends; serialized through `PointWrapper`.
#[serde(with = "PointWrapper")]
pub end_position: Point,
/// The kind id of the node the entry was produced from.
pub kind_id: u16,
}
impl<'a> VectorLeaf<'a> {
    /// Splits the leaf's text into one [`Entry`] per extended grapheme cluster.
    ///
    /// The text is processed line by line (as split by `str::lines`). Graphemes on the first
    /// line are offset by the node's start column; graphemes on later lines start at column 0.
    /// Columns are advanced by the byte offset and byte length of each grapheme.
    ///
    /// When `strip_whitespace` is `true`, graphemes consisting only of whitespace are skipped.
    fn split_on_graphemes(self, strip_whitespace: bool) -> Vec<Entry<'a>> {
        // These are the same for every grapheme, so look them up once.
        let node_start = self.reference.start_position();
        let kind_id = self.reference.kind_id();
        let mut entries: Vec<Entry<'a>> = Vec::new();

        for (line_offset, line) in self.text.lines().enumerate() {
            let indices: Vec<(usize, &str)> =
                us::UnicodeSegmentation::grapheme_indices(line, true).collect();
            // `Vec::reserve` takes the number of *additional* elements, not the new total.
            entries.reserve(indices.len());

            let row = node_start.row + line_offset;
            // Only the first line of the leaf begins at the node's own column.
            let column_base = if line_offset == 0 {
                node_start.column
            } else {
                0
            };

            for (idx, grapheme) in indices {
                debug_assert!(!grapheme.is_empty());
                if strip_whitespace && grapheme.chars().all(char::is_whitespace) {
                    continue;
                }
                let start_position = Point {
                    row,
                    column: column_base + idx,
                };
                let end_position = Point {
                    row,
                    column: start_position.column + grapheme.len(),
                };
                let entry = Entry {
                    reference: self.reference,
                    // `grapheme` is exactly `line[idx..idx + grapheme.len()]`.
                    text: Cow::from(grapheme),
                    start_position,
                    end_position,
                    kind_id,
                };
                // Entries must be emitted in non-decreasing source order.
                #[cfg(debug_assertions)]
                if let Some(last_entry) = entries.last() {
                    debug_assert!(
                        last_entry.end_position.row < entry.start_position.row
                            || (last_entry.end_position.row == entry.start_position.row
                                && last_entry.end_position.column <= entry.start_position.column)
                    );
                }
                entries.push(entry);
            }
        }
        entries
    }
}
impl<'a> From<VectorLeaf<'a>> for Entry<'a> {
    /// Converts a leaf into a single entry that covers the whole node: the text is borrowed
    /// unchanged and the positions are the node's start and end positions.
    fn from(leaf: VectorLeaf<'a>) -> Self {
        let node = leaf.reference;
        Self {
            reference: node,
            text: Cow::from(leaf.text),
            start_position: node.start_position(),
            // The end of the entry is the end of the node, not its start.
            end_position: node.end_position(),
            kind_id: node.kind_id(),
        }
    }
}
impl Entry<'_> {
    /// Returns a copy of the position at which this entry starts.
    #[must_use]
    pub fn start_position(&self) -> Point {
        let Self { start_position, .. } = self;
        *start_position
    }

    /// Returns a copy of the position at which this entry ends.
    #[must_use]
    pub fn end_position(&self) -> Point {
        let Self { end_position, .. } = self;
        *end_position
    }
}
impl<'a> From<&'a Vector<'a>> for Vec<Entry<'a>> {
    /// Flattens every leaf of the vector into per-grapheme entries, with whitespace stripping
    /// enabled.
    fn from(ast_vector: &'a Vector<'a>) -> Self {
        let mut entries = Vec::new();
        for leaf in &ast_vector.leaves {
            entries.extend(leaf.split_on_graphemes(true));
        }
        entries
    }
}
/// The leaves collected from a syntax tree, together with the source text they were cut from.
#[derive(Debug)]
pub struct Vector<'a> {
/// The collected leaf nodes, in the order in which they were visited.
pub leaves: Vec<VectorLeaf<'a>>,
/// The full source text that the leaves borrow their text from.
pub source_text: &'a str,
}
// Marker impl: equality for `Entry` is defined by its manual `PartialEq` implementation,
// which compares `kind_id` and `text`.
impl<'a> Eq for Entry<'a> {}
/// Owned bundle of a text, a tree-sitter tree and a path.
///
/// NOTE(review): presumably the source text of a file, the tree parsed from it, and the file's
/// path -- nothing in this part of the file constructs it, so confirm against the callers.
#[derive(Debug)]
pub struct VectorData {
/// The owned text.
pub text: String,
/// The owned tree-sitter tree.
pub tree: TSTree,
/// The owned path.
pub path: PathBuf,
}
impl<'a> Vector<'a> {
    /// Creates a `Vector` from a tree-sitter tree by collecting its leaves with `build`,
    /// starting at the root node. `text` is stored as the vector's source text.
    #[time("info", "ast::{}")]
    pub fn from_ts_tree(tree: &'a TSTree, text: &'a str) -> Self {
        let collected = RefCell::new(Vec::new());
        build(&collected, tree.root_node(), text);
        Self {
            source_text: text,
            leaves: collected.into_inner(),
        }
    }

    /// Returns the number of leaves in the vector.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { leaves, .. } = self;
        leaves.len()
    }

    /// Returns `true` when the vector holds no leaves.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { leaves, .. } = self;
        leaves.is_empty()
    }
}
impl<'a> Index<usize> for Vector<'a> {
    type Output = VectorLeaf<'a>;

    /// Returns the leaf at `index`; panics when `index` is out of bounds, like indexing the
    /// underlying `Vec`.
    fn index(&self, index: usize) -> &Self::Output {
        let Self { leaves, .. } = self;
        &leaves[index]
    }
}
impl Hash for VectorLeaf<'_> {
    /// Feeds the node's kind id and then the leaf's text into the hasher.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { reference, text } = self;
        let kind_id = reference.kind_id();
        kind_id.hash(state);
        text.hash(state);
    }
}
impl<'a> PartialEq for Entry<'a> {
    /// Two entries are equal when both their kind ids and their texts match; the node
    /// reference and the positions are not part of the comparison.
    fn eq(&self, other: &Entry) -> bool {
        let same_kind = self.kind_id == other.kind_id;
        same_kind && self.text == other.text
    }
}
impl<'a> PartialEq for Vector<'a> {
    /// Two vectors are equal when they hold the same number of leaves and the leaves are
    /// pairwise equal; `source_text` is not compared.
    fn eq(&self, other: &Vector) -> bool {
        // `Vec` equality already checks the lengths before comparing element by element.
        self.leaves == other.leaves
    }
}
/// Recursively walks the subtree rooted at `node` and appends its leaves to `vector`.
///
/// A node without children is a leaf. A leaf is recorded only when its byte range is non-empty
/// and its text is not made up solely of `\n` / `\r` characters. Nodes with children are not
/// recorded themselves; their children are visited in order.
///
/// `text` must be the source the tree was parsed from; the leaf text is sliced out of it using
/// the node's byte range.
fn build<'a>(vector: &RefCell<Vec<VectorLeaf<'a>>>, node: tree_sitter::Node<'a>, text: &'a str) {
    if node.child_count() > 0 {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            build(vector, child, text);
        }
        return;
    }

    let byte_range = node.byte_range();
    if byte_range.is_empty() {
        return;
    }
    let node_text: &'a str = &text[byte_range];
    // Skip leaves that contain nothing but line breaks; checked without allocating.
    if node_text.chars().all(|c| matches!(c, '\n' | '\r')) {
        return;
    }
    vector.borrow_mut().push(VectorLeaf {
        reference: node,
        text: node_text,
    });
}
/// A value tagged with the kind of edit it belongs to.
///
/// Both variants carry the same payload type, which can be reached through `AsRef`, `Deref`
/// and `DerefMut` regardless of the variant.
#[derive(Debug, Eq, PartialEq)]
pub enum EditType<T> {
    /// The payload was added.
    Addition(T),
    /// The payload was deleted.
    Deletion(T),
}

impl<T> AsRef<T> for EditType<T> {
    /// Borrows the payload of either variant.
    fn as_ref(&self) -> &T {
        let (Self::Addition(inner) | Self::Deletion(inner)) = self;
        inner
    }
}

impl<T> Deref for EditType<T> {
    type Target = T;

    /// Dereferences to the payload of either variant.
    fn deref(&self) -> &Self::Target {
        let (Self::Addition(inner) | Self::Deletion(inner)) = self;
        inner
    }
}

impl<T> DerefMut for EditType<T> {
    /// Mutably dereferences to the payload of either variant.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let (Self::Addition(inner) | Self::Deletion(inner)) = self;
        inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GrammarConfig;
    use tree_sitter::Parser;

    #[cfg(feature = "static-grammar-libs")]
    use crate::parse::generate_language;

    /// Runs `should_include_node` against a mocked node of kind `comment` for every
    /// combination of exclude/include filters that matters.
    #[test]
    fn test_should_filter_node() {
        let mut mock_node = MockTSNodeTrait::new();
        mock_node.expect_kind().return_const("comment".to_owned());

        let comment_only = || HashSet::from(["comment".to_string()]);
        let unrelated_kinds = || {
            HashSet::from([
                "some_other_type".to_string(),
                "yet another type".to_string(),
            ])
        };

        // (exclude_kinds, include_kinds, expected result)
        let cases: [(Option<HashSet<String>>, Option<HashSet<String>>, bool); 5] = [
            // Excluded outright.
            (Some(comment_only()), None, false),
            // Exclusion wins even when the kind is also included.
            (Some(comment_only()), Some(comment_only()), false),
            // An include list that does not mention the kind rejects it.
            (None, Some(unrelated_kinds()), false),
            // An include list that mentions the kind accepts it.
            (None, Some(comment_only()), true),
            // No filters at all: everything is accepted.
            (None, None, true),
        ];

        for (exclude_kinds, include_kinds, expected) in cases {
            let processor = TreeSitterProcessor {
                split_graphemes: false,
                exclude_kinds,
                include_kinds,
                ..Default::default()
            };
            assert_eq!(processor.should_include_node(&mock_node), expected);
        }
    }

    /// Two sources that differ only in line breaks must produce equal entries when whitespace
    /// is stripped, and different entries when it is kept.
    #[cfg(feature = "static-grammar-libs")]
    #[test]
    fn test_strip_whitespace() {
        let language = generate_language("python", &GrammarConfig::default()).unwrap();
        let mut parser = Parser::new();
        parser.set_language(&language).unwrap();

        let text_a = "'''# A heading\nThis has no diff.'''";
        let text_b = "'''# A heading\nThis\nhas\r\nno diff.'''";
        let tree_a = parser.parse(text_a, None).unwrap();
        let tree_b = parser.parse(text_b, None).unwrap();

        for strip_whitespace in [true, false] {
            let processor = TreeSitterProcessor {
                strip_whitespace,
                ..Default::default()
            };
            let entries_a = processor.process(&tree_a, text_a);
            let entries_b = processor.process(&tree_b, text_b);
            if strip_whitespace {
                assert_eq!(entries_a, entries_b);
            } else {
                assert_ne!(entries_a, entries_b);
            }
        }
    }
}