use anyhow::{Context, Result};
use tree_sitter::{
InputEdit, Language, Parser, Point, Query, QueryCursor, QueryPredicateArg, StreamingIterator,
Tree,
};
/// Parses one buffer with tree-sitter and answers highlight, indent and
/// text-object queries against the resulting syntax tree.
pub struct Highlighter {
    /// Parser configured with the language passed to `new`.
    parser: Parser,
    /// Compiled highlights query, run by `captures_in_rows`.
    query: Query,
    /// Optional textobjects query, run by `find_text_object`.
    textobjects: Option<Query>,
    /// Capture names of `textobjects`, indexed by capture index (empty when absent).
    textobject_capture_names: Vec<String>,
    /// Optional indents query; `None` when not supplied or when it failed to compile.
    indents: Option<Query>,
    /// Capture names of `indents`, indexed by capture index (empty when absent).
    indent_capture_names: Vec<String>,
    /// Syntax tree for `source`; `None` before the first `refresh` or when
    /// the parser returned no tree.
    tree: Option<Tree>,
    /// Text that `tree` was parsed from; `refresh` diffs it against the new text.
    source: String,
    /// Version passed to the last `refresh`; used to skip re-parsing the same version.
    parsed_version: Option<u64>,
    /// Capture names of `query`, indexed by capture index.
    capture_names: Vec<String>,
    /// Non-fatal problems collected while constructing the highlighter
    /// (for example an indents query that failed to compile).
    pub warnings: Vec<String>,
}
impl Highlighter {
pub(super) fn new(
language: Language,
highlights_src: &str,
textobjects_src: Option<&str>,
indents_src: Option<&str>,
) -> Result<Self> {
let mut parser = Parser::new();
parser
.set_language(&language)
.context("setting parser language (ABI mismatch?)")?;
let query = Query::new(&language, highlights_src).context("compiling highlights query")?;
let capture_names = query
.capture_names()
.iter()
.map(|s| s.to_string())
.collect();
let (textobjects, textobject_capture_names) = match textobjects_src {
Some(src) => {
let q = Query::new(&language, src).context("compiling textobjects query")?;
let names = q.capture_names().iter().map(|s| s.to_string()).collect();
(Some(q), names)
}
None => (None, Vec::new()),
};
let mut warnings = Vec::new();
let (indents, indent_capture_names) = match indents_src {
Some(src) => match Query::new(&language, src) {
Ok(q) => {
let names = q.capture_names().iter().map(|s| s.to_string()).collect();
(Some(q), names)
}
Err(e) => {
warnings.push(format!(
"indents.scm compile failed, auto-indent disabled: {e}"
));
(None, Vec::new())
}
},
None => (None, Vec::new()),
};
Ok(Self {
parser,
query,
textobjects,
textobject_capture_names,
indents,
indent_capture_names,
tree: None,
source: String::new(),
parsed_version: None,
capture_names,
warnings,
})
}
pub fn refresh(&mut self, source: &str, version: u64) {
if self.parsed_version == Some(version) {
return;
}
let old_tree = match self.tree.as_mut() {
Some(tree) if !self.source.is_empty() => {
let edit = compute_input_edit(&self.source, source);
tree.edit(&edit);
Some(&*tree)
}
_ => None,
};
self.tree = self.parser.parse(source, old_tree);
self.source = source.to_string();
self.parsed_version = Some(version);
}
pub fn captures_in_rows(&self, start_row: usize, end_row: usize) -> Vec<Capture> {
let Some(tree) = &self.tree else {
return Vec::new();
};
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&self.query, tree.root_node(), self.source.as_bytes());
let mut out = Vec::new();
while let Some(m) = matches.next() {
for cap in m.captures {
let node = cap.node;
let start = node.start_position();
let end = node.end_position();
if end.row < start_row || start.row > end_row {
continue;
}
let name = self
.capture_names
.get(cap.index as usize)
.cloned()
.unwrap_or_default();
out.push(Capture {
start_row: start.row,
start_col: byte_to_char_col(&self.source, start.row, start.column),
end_row: end.row,
end_col: byte_to_char_col(&self.source, end.row, end.column),
name,
});
}
}
out.sort_by_key(|c| (c.start_row, c.start_col));
out
}
pub fn indent_begins_at(&self, row: usize) -> bool {
let Some(tree) = self.tree.as_ref() else {
return false;
};
let Some(query) = self.indents.as_ref() else {
return false;
};
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), self.source.as_bytes());
while let Some(m) = matches.next() {
for cap in m.captures {
let name = self
.indent_capture_names
.get(cap.index as usize)
.map(String::as_str)
.unwrap_or("");
if name != "indent.begin" {
continue;
}
let node = cap.node;
let start_row = node.start_position().row;
let end_row = node.end_position().row;
if start_row != row {
continue;
}
if end_row > row {
return true;
}
if let Some(body) = node.child_by_field_name("body")
&& body.start_byte() == body.end_byte()
{
return true;
}
}
}
false
}
pub fn find_text_object(
&self,
target: &str,
cursor_row: usize,
cursor_col_chars: usize,
) -> Option<(usize, usize, usize, usize)> {
let tree = self.tree.as_ref()?;
let query = self.textobjects.as_ref()?;
let cursor_pt = (
cursor_row,
char_to_byte_col(&self.source, cursor_row, cursor_col_chars),
);
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), self.source.as_bytes());
let mut best: Option<Candidate> = None;
while let Some(m) = matches.next() {
for cap in m.captures {
let name = self
.textobject_capture_names
.get(cap.index as usize)
.map(String::as_str)
.unwrap_or("");
if name != target {
continue;
}
let node = cap.node;
consider(
&mut best,
node.start_byte()..node.end_byte(),
point(node.start_position()),
point(node.end_position()),
cursor_pt,
);
}
for pred in query.general_predicates(m.pattern_index) {
if pred.operator.as_ref() != "make-range!" {
continue;
}
let (name, start_idx, end_idx) = match pred.args.as_ref() {
[
QueryPredicateArg::String(n),
QueryPredicateArg::Capture(s),
QueryPredicateArg::Capture(e),
] => (n.as_ref(), *s, *e),
_ => continue,
};
if name != target {
continue;
}
let mut span_start: Option<tree_sitter::Node> = None;
let mut span_end: Option<tree_sitter::Node> = None;
for cap in m.captures {
if cap.index == start_idx {
span_start = match span_start {
None => Some(cap.node),
Some(prev) if cap.node.start_byte() < prev.start_byte() => {
Some(cap.node)
}
other => other,
};
}
if cap.index == end_idx {
span_end = match span_end {
None => Some(cap.node),
Some(prev) if cap.node.end_byte() > prev.end_byte() => Some(cap.node),
other => other,
};
}
}
if let (Some(s), Some(e)) = (span_start, span_end) {
consider(
&mut best,
s.start_byte()..e.end_byte(),
point(s.start_position()),
point(e.end_position()),
cursor_pt,
);
}
}
}
let c = best?;
Some((
c.start.0,
byte_to_char_col(&self.source, c.start.0, c.start.1),
c.end.0,
byte_to_char_col(&self.source, c.end.0, c.end.1),
))
}
}
/// A text-object span containing the cursor, kept while searching for the
/// tightest match.
struct Candidate {
    /// Start of the span as `(row, byte column)`.
    start: (usize, usize),
    /// End of the span as `(row, byte column)`; the cursor must lie strictly before it.
    end: (usize, usize),
    /// Byte range in the source; its length decides which candidate is tighter.
    bytes: std::ops::Range<usize>,
}
fn point(p: tree_sitter::Point) -> (usize, usize) {
(p.row, p.column)
}
/// Derives a single [`InputEdit`] that turns `old` into `new`.
///
/// The edited region lies between the longest shared byte prefix and the
/// longest shared byte suffix. The suffix is capped so it never overlaps the
/// prefix. Identical inputs give an empty edit positioned at the end of the
/// text.
fn compute_input_edit(old: &str, new: &str) -> InputEdit {
    let (before, after) = (old.as_bytes(), new.as_bytes());

    let shared_head = before
        .iter()
        .zip(after)
        .take_while(|(x, y)| x == y)
        .count();

    // Bytes already claimed by the prefix cannot also belong to the suffix.
    let tail_budget = before
        .len()
        .min(after.len())
        .saturating_sub(shared_head);
    let shared_tail = before
        .iter()
        .rev()
        .zip(after.iter().rev())
        .take(tail_budget)
        .take_while(|(x, y)| x == y)
        .count();

    let old_end_byte = before.len() - shared_tail;
    let new_end_byte = after.len() - shared_tail;
    InputEdit {
        start_byte: shared_head,
        old_end_byte,
        new_end_byte,
        start_position: byte_to_point(before, shared_head),
        old_end_position: byte_to_point(before, old_end_byte),
        new_end_position: byte_to_point(after, new_end_byte),
    }
}
/// Converts a byte offset into a tree-sitter [`Point`].
///
/// The row is the number of `\n` bytes before `offset`. The column is the
/// number of bytes since the last `\n`, or since the start when there is
/// none. Offsets past the end are clamped to `bytes.len()`.
fn byte_to_point(bytes: &[u8], offset: usize) -> Point {
    let head = &bytes[..offset.min(bytes.len())];
    let row = head.iter().filter(|&&b| b == b'\n').count();
    let column = match head.iter().rposition(|&b| b == b'\n') {
        Some(newline) => head.len() - (newline + 1),
        None => head.len(),
    };
    Point { row, column }
}
fn consider(
best: &mut Option<Candidate>,
bytes: std::ops::Range<usize>,
start: (usize, usize),
end: (usize, usize),
cursor: (usize, usize),
) {
if !(start <= cursor && cursor < end) {
return;
}
let len = bytes.end - bytes.start;
let take = match best {
None => true,
Some(c) => len < c.bytes.end - c.bytes.start,
};
if take {
*best = Some(Candidate { bytes, start, end });
}
}
/// Converts a byte column on line `row` of `source` into a character
/// (Unicode scalar value) column.
///
/// Lines are split like [`str::lines`]: on `\n`, with a trailing `\r`
/// dropped. A missing row behaves like an empty line and yields 0. A byte
/// column past the end of the line gives the line's character count. A byte
/// column inside a multi-byte character rounds down to that character's
/// column. This avoids the panic that slicing `line[..byte_col]` would raise
/// on a non-boundary index.
fn byte_to_char_col(source: &str, row: usize, byte_col: usize) -> usize {
    let line = source.lines().nth(row).unwrap_or("");
    // Count only the characters that end at or before `byte_col`.
    line.char_indices()
        .take_while(|&(start, ch)| start + ch.len_utf8() <= byte_col)
        .count()
}
/// Converts a character column on line `row` of `source` into a byte column.
///
/// Lines are split like [`str::lines`]. A missing row yields 0, and a
/// character column past the end of the line yields the line's byte length.
fn char_to_byte_col(source: &str, row: usize, char_col: usize) -> usize {
    let Some(line) = source.lines().nth(row) else {
        return 0;
    };
    match line.char_indices().nth(char_col) {
        Some((byte_offset, _)) => byte_offset,
        None => line.len(),
    }
}
/// One highlight capture, positioned with character (not byte) columns.
///
/// Equality is derived so callers and tests can compare capture lists
/// directly. Field order is significant for the derived `Debug` output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Capture {
    /// Row of the captured node's start position.
    pub start_row: usize,
    /// Character column of the captured node's start position.
    pub start_col: usize,
    /// Row of the captured node's end position.
    pub end_row: usize,
    /// Character column of the captured node's end position.
    pub end_col: usize,
    /// Capture name from the highlights query (empty if the index was unknown).
    pub name: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_to_char_handles_ascii() {
        let src = "let x = 1\nprintln!(\"hi\")";
        assert_eq!(byte_to_char_col(src, 0, 0), 0);
        assert_eq!(byte_to_char_col(src, 0, 4), 4);
        assert_eq!(byte_to_char_col(src, 1, 9), 9);
    }

    #[test]
    fn byte_to_char_handles_multibyte() {
        let src = "あ x";
        assert_eq!(byte_to_char_col(src, 0, 3), 1);
        assert_eq!(byte_to_char_col(src, 0, 5), 3);
    }

    #[test]
    fn byte_to_char_ignores_crlf_terminator() {
        // `str::lines` drops the '\r', so columns past it clamp to the line.
        let src = "ab\r\ncd";
        assert_eq!(byte_to_char_col(src, 0, 2), 2);
        assert_eq!(byte_to_char_col(src, 0, 3), 2);
        assert_eq!(byte_to_char_col(src, 1, 1), 1);
    }

    #[test]
    fn char_to_byte_handles_multibyte() {
        let src = "あ x";
        assert_eq!(char_to_byte_col(src, 0, 1), 3);
        assert_eq!(char_to_byte_col(src, 0, 2), 4);
        assert_eq!(char_to_byte_col(src, 0, 10), 5);
        assert_eq!(char_to_byte_col(src, 3, 1), 0);
    }

    #[test]
    fn input_edit_single_byte_insertion() {
        let edit = compute_input_edit("abc", "abXc");
        assert_eq!(edit.start_byte, 2);
        assert_eq!(edit.old_end_byte, 2);
        assert_eq!(edit.new_end_byte, 3);
        assert_eq!(edit.start_position, Point { row: 0, column: 2 });
        assert_eq!(edit.new_end_position, Point { row: 0, column: 3 });
    }

    #[test]
    fn input_edit_no_change_is_noop_range() {
        let edit = compute_input_edit("hello", "hello");
        assert_eq!(edit.start_byte, 5);
        assert_eq!(edit.old_end_byte, 5);
        assert_eq!(edit.new_end_byte, 5);
    }

    #[test]
    fn input_edit_multi_line_replacement() {
        // Build the texts from parts so every expected offset follows from
        // the literals instead of hand-counted byte positions.
        let indent = "  ";
        let head = format!("fn a() {{\n{indent}");
        let tail = "\n}\n";
        let old = format!("{head}1{tail}");
        let new = format!("{head}42{tail}");
        let edit = compute_input_edit(&old, &new);
        assert_eq!(edit.start_byte, head.len());
        assert_eq!(edit.old_end_byte, old.len() - tail.len());
        assert_eq!(edit.new_end_byte, new.len() - tail.len());
        assert_eq!(
            edit.start_position,
            Point {
                row: 1,
                column: indent.len()
            }
        );
        assert_eq!(
            edit.old_end_position,
            Point {
                row: 1,
                column: indent.len() + 1
            }
        );
        assert_eq!(
            edit.new_end_position,
            Point {
                row: 1,
                column: indent.len() + 2
            }
        );
    }

    #[test]
    fn input_edit_line_deletion() {
        let edit = compute_input_edit("a\nb\nc", "a\nc");
        assert_eq!(edit.start_byte, 2);
        assert_eq!(edit.old_end_byte, 4);
        assert_eq!(edit.new_end_byte, 2);
        assert_eq!(edit.start_position, Point { row: 1, column: 0 });
        assert_eq!(edit.old_end_position, Point { row: 2, column: 0 });
        assert_eq!(edit.new_end_position, Point { row: 1, column: 0 });
    }

    #[test]
    fn input_edit_suffix_does_not_overlap_prefix() {
        // "aa" -> "aaa": the whole old text is a shared prefix, so no bytes
        // are left for the suffix and the edit is a pure insertion at the end.
        let edit = compute_input_edit("aa", "aaa");
        assert_eq!(edit.start_byte, 2);
        assert_eq!(edit.old_end_byte, 2);
        assert_eq!(edit.new_end_byte, 3);
    }

    #[test]
    fn input_edit_full_replacement() {
        let edit = compute_input_edit("abc", "xyz");
        assert_eq!(edit.start_byte, 0);
        assert_eq!(edit.old_end_byte, 3);
        assert_eq!(edit.new_end_byte, 3);
    }
}