use editor_core::delta::TextDelta;
use editor_core::processing::{DocumentProcessor, ProcessingEdit};
use editor_core::{EditorStateManager, FoldRegion, Interval, LineIndex, StyleId, StyleLayerId};
use std::collections::BTreeMap;
use streaming_iterator::StreamingIterator;
use tree_sitter::{InputEdit, Parser, Point, Query, QueryCursor, Tree};
/// Failures raised while loading a grammar, compiling queries, or syncing document text.
#[derive(Debug)]
pub enum TreeSitterError {
    /// Creating a wasm store or loading a wasm grammar failed.
    Wasm(String),
    /// An I/O-related failure, carried as a message.
    Io(String),
    /// The parser rejected a language or a wasm store.
    Language(String),
    /// A tree-sitter query failed to compile.
    Query(String),
    /// A text delta did not line up with the mirrored text, or full text was required but
    /// not supplied.
    DeltaMismatch,
}

impl std::fmt::Display for TreeSitterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TreeSitterError::DeltaMismatch => f.write_str("tree-sitter delta mismatch"),
            TreeSitterError::Wasm(msg) => write!(f, "tree-sitter wasm error: {msg}"),
            TreeSitterError::Io(msg) => write!(f, "tree-sitter io error: {msg}"),
            TreeSitterError::Language(msg) => write!(f, "tree-sitter language error: {msg}"),
            TreeSitterError::Query(msg) => write!(f, "tree-sitter query error: {msg}"),
        }
    }
}

impl std::error::Error for TreeSitterError {}
/// How the most recent sync brought the processor up to date.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TreeSitterUpdateMode {
    /// First parse from full text, with no previous tree.
    Initial,
    /// A text delta was applied to the existing tree; the reparse is deferred until edits
    /// are computed.
    Incremental,
    /// The full text was reparsed without reusing the old tree.
    FullReparse,
    /// The requested version was already handled, so nothing was done.
    Skipped,
}
/// Grammar source for a processor: either a natively linked language or a wasm-compiled
/// grammar that is loaded at runtime.
#[derive(Debug, Clone)]
pub enum TreeSitterLanguage {
    /// A language linked into the binary.
    Native(tree_sitter::Language),
    /// A wasm grammar, described by its id and compiled bytes.
    Wasm {
        /// Name passed to the wasm store when the grammar is loaded.
        language_id: String,
        /// Raw bytes of the compiled wasm grammar.
        wasm_bytes: Vec<u8>,
    },
}

impl TreeSitterLanguage {
    /// Wraps a natively linked `tree_sitter::Language`.
    pub fn native(language: tree_sitter::Language) -> Self {
        TreeSitterLanguage::Native(language)
    }

    /// Describes a wasm grammar by its id and compiled bytes.
    pub fn wasm(language_id: String, wasm_bytes: Vec<u8>) -> Self {
        TreeSitterLanguage::Wasm {
            wasm_bytes,
            language_id,
        }
    }
}
/// Settings for building a tree-sitter processor: grammar, queries, and the mapping from
/// highlight capture names to styles.
#[derive(Debug, Clone)]
pub struct TreeSitterProcessorConfig {
    /// Grammar used for parsing and for compiling the queries.
    pub language: TreeSitterLanguage,
    /// Source of the highlights query.
    pub highlights_query: String,
    /// Optional source of the folds query.
    pub folds_query: Option<String>,
    /// Capture name -> style. Captures without an entry produce no highlight interval.
    pub capture_styles: BTreeMap<String, StyleId>,
    /// Style layer that highlight intervals are written to.
    pub style_layer: StyleLayerId,
    /// Forwarded as `preserve_collapsed` when fold regions are replaced.
    pub preserve_collapsed_folds: bool,
}

impl TreeSitterProcessorConfig {
    /// Creates a config with the given grammar and highlights query.
    ///
    /// The config starts with no folds query, no capture styles, the
    /// `StyleLayerId::TREE_SITTER` layer, and collapsed-fold preservation enabled.
    pub fn new(language: TreeSitterLanguage, highlights_query: impl Into<String>) -> Self {
        let highlights_query: String = highlights_query.into();
        Self {
            language,
            highlights_query,
            folds_query: None,
            capture_styles: BTreeMap::new(),
            style_layer: StyleLayerId::TREE_SITTER,
            preserve_collapsed_folds: true,
        }
    }

    /// Builder-style setter for the folds query.
    pub fn with_folds_query(self, folds_query: impl Into<String>) -> Self {
        Self {
            folds_query: Some(folds_query.into()),
            ..self
        }
    }

    /// Adds `(capture name, style)` pairs. A later pair for the same name overwrites an
    /// earlier one.
    pub fn with_simple_capture_styles<const N: usize>(
        mut self,
        styles: [(&'static str, StyleId); N],
    ) -> Self {
        self.capture_styles
            .extend(styles.map(|(name, style_id)| (name.to_owned(), style_id)));
        self
    }

    /// Sets whether collapsed folds are preserved when fold regions are replaced.
    pub fn set_preserve_collapsed_folds(&mut self, preserve: bool) {
        self.preserve_collapsed_folds = preserve;
    }

    /// Compiles the highlights query against the configured grammar and returns its capture
    /// names in capture-index order.
    ///
    /// # Errors
    /// - `TreeSitterError::Wasm` if a wasm store cannot be created or the grammar fails to load.
    /// - `TreeSitterError::Query` if the highlights query does not compile.
    pub fn highlights_capture_names(&self) -> Result<Vec<String>, TreeSitterError> {
        let compile = |language: &tree_sitter::Language| {
            Query::new(language, &self.highlights_query)
                .map_err(|e| TreeSitterError::Query(e.to_string()))
        };
        let query = match &self.language {
            TreeSitterLanguage::Native(language) => compile(language)?,
            TreeSitterLanguage::Wasm {
                language_id,
                wasm_bytes,
            } => {
                // A short-lived engine and store are enough to compile the query once.
                let engine = tree_sitter::wasmtime::Engine::default();
                let mut store = tree_sitter::WasmStore::new(&engine)
                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
                let language = store
                    .load_language(language_id, wasm_bytes)
                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
                compile(&language)?
            }
        };
        let names = query
            .capture_names()
            .iter()
            .map(|name| String::from(*name))
            .collect();
        Ok(names)
    }
}
/// Tree-sitter backed document processor.
///
/// Keeps its own mirror of the document text and syntax tree, and turns query matches into
/// highlight intervals and fold regions.
pub struct TreeSitterProcessor {
    /// Configuration the processor was built from.
    config: TreeSitterProcessorConfig,
    /// Parser configured for the configured language.
    parser: Parser,
    /// Compiled highlights query.
    highlight_query: Query,
    /// Style per highlight capture index. `None` for captures without a configured style.
    highlight_capture_styles: Vec<Option<StyleId>>,
    /// Compiled folds query, if a non-blank one was configured.
    fold_query: Option<Query>,
    /// Latest syntax tree. `None` before the first parse or if parsing produced no tree.
    tree: Option<Tree>,
    /// True when the tree was edited incrementally but has not been reparsed yet.
    needs_parse: bool,
    /// Mirror of the document text that the tree describes.
    text: String,
    /// Char/byte/line index over `text`.
    line_index: LineIndex,
    /// Version most recently passed to a successful sync.
    last_synced_version: Option<u64>,
    /// Version for which processing edits were last produced.
    last_processed_version: Option<u64>,
    /// Outcome of the most recent sync or process call.
    last_update_mode: TreeSitterUpdateMode,
    // Stored but never read: it keeps the wasm engine alive for as long as the processor
    // exists, which is why `dead_code` is allowed here.
    #[allow(dead_code)]
    wasm_engine: Option<tree_sitter::wasmtime::Engine>,
}
impl TreeSitterProcessor {
pub fn new(config: TreeSitterProcessorConfig) -> Result<Self, TreeSitterError> {
let mut parser = Parser::new();
let (language, wasm_engine) = match &config.language {
TreeSitterLanguage::Native(language) => {
parser
.set_language(language)
.map_err(|e| TreeSitterError::Language(e.to_string()))?;
(language.clone(), None)
}
TreeSitterLanguage::Wasm {
language_id,
wasm_bytes,
} => {
let engine = tree_sitter::wasmtime::Engine::default();
let store = tree_sitter::WasmStore::new(&engine)
.map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
parser
.set_wasm_store(store)
.map_err(|e| TreeSitterError::Language(e.to_string()))?;
let mut store = tree_sitter::WasmStore::new(&engine)
.map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
let language = store
.load_language(language_id, wasm_bytes)
.map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
parser
.set_language(&language)
.map_err(|e| TreeSitterError::Language(e.to_string()))?;
(language, Some(engine))
}
};
let highlight_query = Query::new(&language, &config.highlights_query)
.map_err(|e| TreeSitterError::Query(e.to_string()))?;
let highlight_capture_styles = highlight_query
.capture_names()
.iter()
.map(|name| config.capture_styles.get(*name).copied())
.collect::<Vec<_>>();
let fold_query = match config.folds_query.as_deref() {
Some(q) if !q.trim().is_empty() => {
Some(Query::new(&language, q).map_err(|e| TreeSitterError::Query(e.to_string()))?)
}
_ => None,
};
Ok(Self {
config,
parser,
highlight_query,
highlight_capture_styles,
fold_query,
tree: None,
needs_parse: false,
text: String::new(),
line_index: LineIndex::new(),
last_synced_version: None,
last_processed_version: None,
last_update_mode: TreeSitterUpdateMode::FullReparse,
wasm_engine,
})
}
pub fn last_update_mode(&self) -> TreeSitterUpdateMode {
self.last_update_mode
}
pub fn expand_selection_syntax(&self, start: usize, end: usize) -> Option<(usize, usize)> {
let tree = self.tree.as_ref()?;
let root = tree.root_node();
let (sel_start, sel_end) = if start <= end {
(start, end)
} else {
(end, start)
};
let start_byte = self.line_index.char_offset_to_byte_offset(sel_start);
let end_byte = self.line_index.char_offset_to_byte_offset(sel_end);
let mut node = root.descendant_for_byte_range(start_byte, end_byte)?;
loop {
let node_start = self
.line_index
.byte_offset_to_char_offset(node.start_byte());
let node_end = self.line_index.byte_offset_to_char_offset(node.end_byte());
if node_start == sel_start && node_end == sel_end {
if let Some(parent) = node.parent() {
node = parent;
continue;
}
return None;
}
return Some((node_start, node_end));
}
}
fn sync_from_text_full(&mut self, text: &str) {
self.text.clear();
self.text.push_str(text);
self.line_index = LineIndex::from_text(&self.text);
}
fn point_for_char_offset(&self, char_offset: usize) -> Point {
let (row, col) = self.line_index.char_offset_to_line_byte_column(char_offset);
Point { row, column: col }
}
fn advance_point(mut point: Point, text: &str) -> Point {
let mut parts = text.split('\n');
let Some(first) = parts.next() else {
return point;
};
point.column = point.column.saturating_add(first.len());
for part in parts {
point.row = point.row.saturating_add(1);
point.column = part.len();
}
point
}
fn apply_text_delta_incremental(&mut self, delta: &TextDelta) -> Result<(), TreeSitterError> {
if self.line_index.char_count() != delta.before_char_count {
return Err(TreeSitterError::DeltaMismatch);
}
if self.tree.is_none() {
return Err(TreeSitterError::DeltaMismatch);
}
for edit in &delta.edits {
let start_char = edit.start;
let deleted_chars = edit.deleted_text.chars().count();
let start_byte = self.line_index.char_offset_to_byte_offset(start_char);
let old_end_byte = start_byte.saturating_add(edit.deleted_text.len());
let new_end_byte = start_byte.saturating_add(edit.inserted_text.len());
let Some(old_slice) = self.text.get(start_byte..old_end_byte) else {
return Err(TreeSitterError::DeltaMismatch);
};
if old_slice != edit.deleted_text {
return Err(TreeSitterError::DeltaMismatch);
}
let start_position = self.point_for_char_offset(start_char);
let old_end_position = Self::advance_point(start_position, &edit.deleted_text);
let new_end_position = Self::advance_point(start_position, &edit.inserted_text);
if let Some(tree) = self.tree.as_mut() {
tree.edit(&InputEdit {
start_byte,
old_end_byte,
new_end_byte,
start_position,
old_end_position,
new_end_position,
});
}
self.text
.replace_range(start_byte..old_end_byte, &edit.inserted_text);
self.line_index.delete(start_char, deleted_chars);
self.line_index.insert(start_char, &edit.inserted_text);
}
if self.line_index.char_count() != delta.after_char_count {
return Err(TreeSitterError::DeltaMismatch);
}
Ok(())
}
fn parse(&mut self) -> Option<Tree> {
self.parser.parse(&self.text, self.tree.as_ref())
}
fn collect_highlight_intervals_in_byte_range(
&self,
tree: &Tree,
byte_range: Option<(usize, usize)>,
) -> Vec<Interval> {
let mut cursor = QueryCursor::new();
if let Some((start, end)) = byte_range {
let start = start.min(self.text.len());
let end = end.min(self.text.len());
if end > start {
cursor.set_byte_range(start..end);
}
}
let root = tree.root_node();
let mut out = Vec::<Interval>::new();
let mut matches = cursor.matches(&self.highlight_query, root, self.text.as_bytes());
while let Some(m) = matches.next() {
for capture in m.captures {
let idx = capture.index as usize;
let Some(style_id) = self.highlight_capture_styles.get(idx).and_then(|x| *x) else {
continue;
};
let node = capture.node;
let start_byte = node.start_byte();
let end_byte = node.end_byte();
if end_byte <= start_byte {
continue;
}
let start = self.line_index.byte_offset_to_char_offset(start_byte);
let end = self.line_index.byte_offset_to_char_offset(end_byte);
if end <= start {
continue;
}
out.push(Interval::new(start, end, style_id));
}
}
out.sort_by_key(|i| (i.start, i.end, i.style_id));
out.dedup_by(|a, b| a.start == b.start && a.end == b.end && a.style_id == b.style_id);
out
}
fn collect_fold_regions_in_byte_range(
&self,
tree: &Tree,
byte_range: Option<(usize, usize)>,
) -> Vec<FoldRegion> {
let Some(query) = self.fold_query.as_ref() else {
return Vec::new();
};
let mut cursor = QueryCursor::new();
if let Some((start, end)) = byte_range {
let start = start.min(self.text.len());
let end = end.min(self.text.len());
if end > start {
cursor.set_byte_range(start..end);
}
}
let root = tree.root_node();
let mut regions = Vec::<FoldRegion>::new();
let mut matches = cursor.matches(query, root, self.text.as_bytes());
while let Some(m) = matches.next() {
for capture in m.captures {
let node = capture.node;
let start_line = node.start_position().row;
let end_line = node.end_position().row;
if end_line > start_line {
regions.push(FoldRegion::new(start_line, end_line));
}
}
}
regions.sort_by_key(|r| (r.start_line, r.end_line));
regions.dedup_by(|a, b| a.start_line == b.start_line && a.end_line == b.end_line);
regions
}
}
impl DocumentProcessor for TreeSitterProcessor {
    type Error = TreeSitterError;

    /// Syncs to `state`'s current version and returns the resulting processing edits.
    ///
    /// When a tree already exists and the state has a last text delta, that delta is tried
    /// first as an incremental update. If the delta does not line up with the mirrored text,
    /// or there is no tree or no delta, the editor's full text is reparsed. Returns no edits
    /// (mode `Skipped`) when this version was already processed.
    fn process(&mut self, state: &EditorStateManager) -> Result<Vec<ProcessingEdit>, Self::Error> {
        let version = state.version();
        if self.last_processed_version == Some(version) {
            self.last_update_mode = TreeSitterUpdateMode::Skipped;
            return Ok(Vec::new());
        }
        if self.tree.is_some() {
            if let Some(delta) = state.last_text_delta() {
                match self.process_text(version, Some(delta), None) {
                    Err(TreeSitterError::DeltaMismatch) => {
                        // Fall through to the full-text path below.
                    }
                    other => return other,
                }
            }
        }
        // Do not pass the delta again. A failed incremental attempt may already have applied
        // some of its edits, and re-applying the delta to that partially updated text could
        // appear to succeed while leaving the tree out of sync.
        let full = state.editor().get_text();
        self.process_text(version, None, Some(&full))
    }
}
impl TreeSitterProcessor {
    /// Brings the mirrored text and syntax tree up to `version`.
    ///
    /// - Already synced to `version`: nothing changes and `Skipped` is returned.
    /// - No tree yet: `full_text` is parsed from scratch (`Initial`).
    /// - `delta` given and it applies cleanly: the tree is edited, and the reparse waits
    ///   for `compute_processing_edits` (`Incremental`).
    /// - Otherwise: `full_text` is reparsed without the old tree (`FullReparse`).
    ///
    /// # Errors
    /// `TreeSitterError::DeltaMismatch` when full text is required but `full_text` is `None`.
    /// A failed incremental attempt may already have modified the mirrored state in that
    /// case, so the next sync should supply full text.
    pub fn sync_to(
        &mut self,
        version: u64,
        delta: Option<&TextDelta>,
        full_text: Option<&str>,
    ) -> Result<TreeSitterUpdateMode, TreeSitterError> {
        if self.last_synced_version == Some(version) {
            self.last_update_mode = TreeSitterUpdateMode::Skipped;
            return Ok(TreeSitterUpdateMode::Skipped);
        }
        let mode = if self.tree.is_none() {
            self.resync_from_full_text(full_text)?;
            TreeSitterUpdateMode::Initial
        } else {
            let applied = match delta {
                Some(delta) => self.apply_text_delta_incremental(delta).is_ok(),
                None => false,
            };
            if applied {
                self.needs_parse = true;
                TreeSitterUpdateMode::Incremental
            } else {
                self.resync_from_full_text(full_text)?;
                TreeSitterUpdateMode::FullReparse
            }
        };
        self.last_synced_version = Some(version);
        self.last_update_mode = mode;
        Ok(mode)
    }

    /// Replaces the mirrored text with `full_text` and parses it without reusing any old
    /// tree. Clears the pending-reparse flag.
    ///
    /// Fails with `TreeSitterError::DeltaMismatch` when `full_text` is `None`.
    fn resync_from_full_text(&mut self, full_text: Option<&str>) -> Result<(), TreeSitterError> {
        let text = full_text.ok_or(TreeSitterError::DeltaMismatch)?;
        self.sync_from_text_full(text);
        self.tree = self.parser.parse(&self.text, None);
        self.needs_parse = false;
        Ok(())
    }

    /// Produces the processing edits for the last synced version.
    ///
    /// Reparses first if an incremental sync left the tree edited but not yet reparsed.
    /// Emits a `ReplaceStyleLayer` for the configured layer and, when a folds query exists,
    /// a `ReplaceFoldingRegions`. When `char_range` is non-empty, it limits which query
    /// matches are collected.
    ///
    /// Returns no edits in any of these cases:
    /// - nothing has been synced yet;
    /// - this version was already processed;
    /// - there is no tree.
    ///
    /// NOTE: a range-limited call also marks the version as processed, so a later call for
    /// the same version returns no edits.
    pub fn compute_processing_edits(
        &mut self,
        char_range: Option<(usize, usize)>,
    ) -> Result<Vec<ProcessingEdit>, TreeSitterError> {
        let version = match self.last_synced_version {
            Some(version) if self.last_processed_version != Some(version) => version,
            _ => return Ok(Vec::new()),
        };
        if std::mem::take(&mut self.needs_parse) {
            self.tree = self.parse();
        }
        let Some(tree) = self.tree.as_ref() else {
            self.last_processed_version = Some(version);
            return Ok(Vec::new());
        };
        let byte_range = char_range
            .map(|(start_char, end_char)| {
                (
                    self.line_index.char_offset_to_byte_offset(start_char),
                    self.line_index.char_offset_to_byte_offset(end_char),
                )
            })
            .filter(|&(start, end)| end > start);
        let intervals = self.collect_highlight_intervals_in_byte_range(tree, byte_range);
        let mut edits = Vec::with_capacity(2);
        edits.push(ProcessingEdit::ReplaceStyleLayer {
            layer: self.config.style_layer,
            intervals,
        });
        if self.fold_query.is_some() {
            edits.push(ProcessingEdit::ReplaceFoldingRegions {
                regions: self.collect_fold_regions_in_byte_range(tree, byte_range),
                preserve_collapsed: self.config.preserve_collapsed_folds,
            });
        }
        self.last_processed_version = Some(version);
        Ok(edits)
    }

    /// Syncs to `version` and then computes the edits for the whole document.
    pub fn process_text(
        &mut self,
        version: u64,
        delta: Option<&TextDelta>,
        full_text: Option<&str>,
    ) -> Result<Vec<ProcessingEdit>, TreeSitterError> {
        self.sync_to(version, delta, full_text)?;
        self.compute_processing_edits(None)
    }
}