use crate::model::buffer::Buffer;
use crate::primitives::highlighter::{HighlightSpan, Language};
use crate::primitives::word_navigation::{find_word_end, find_word_start, is_word_char};
use fresh_languages::tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};
use ratatui::style::Color;
use std::ops::Range;
/// Color assigned to highlight spans by a highlighter created with
/// `ReferenceHighlighter::new` (RGB 60, 60, 80).
pub const DEFAULT_HIGHLIGHT_COLOR: Color = Color::Rgb(60, 60, 80);
/// Produces highlight spans for every occurrence of the identifier or word
/// under the cursor.
///
/// Three strategies are available, chosen by `highlight_occurrences` from the
/// state configured by `set_language`: a scope-aware locals query, a plain
/// tree-sitter identifier query, and whole-word text matching.
pub struct ReferenceHighlighter {
/// Color stored in every emitted `HighlightSpan`.
pub highlight_color: Color,
/// Minimum length (compared against the byte length of the word) required
/// before a word is highlighted.
pub min_word_length: usize,
/// When `false`, `highlight_occurrences` returns no spans.
pub enabled: bool,
// Tree-sitter parser; populated by `set_language` only when the identifier
// query compiled for the selected language.
parser: Option<Parser>,
// Compiled `IDENTIFIER_QUERY` for the current language, if it compiled.
identifier_query: Option<Query>,
// Compiled per-language locals query, if one exists and compiled.
locals_query: Option<Query>,
// Capture indices resolved from `locals_query`.
locals_captures: LocalsCaptures,
}
/// Capture indices of the three capture names a locals query may declare.
/// Each field is `None` when the query does not contain that capture name.
#[derive(Default)]
struct LocalsCaptures {
// Index of the `local.scope` capture.
scope: Option<u32>,
// Index of the `local.definition` capture.
definition: Option<u32>,
// Index of the `local.reference` capture.
reference: Option<u32>,
}
/// Tree-sitter query that captures every `identifier` node as `@id`; used for
/// the non-scope-aware tree-sitter strategy.
const IDENTIFIER_QUERY: &str = "(identifier) @id";
/// Returns the locals-query source for `language`, or `None` when no locals
/// query is defined for it.
///
/// JavaScript and TypeScript share one query, as do C and C++.
fn get_locals_query(language: &Language) -> Option<&'static str> {
    let source = match language {
        Language::Rust => RUST_LOCALS_QUERY,
        Language::Python => PYTHON_LOCALS_QUERY,
        Language::JavaScript | Language::TypeScript => JS_LOCALS_QUERY,
        Language::Go => GO_LOCALS_QUERY,
        Language::C | Language::Cpp => C_LOCALS_QUERY,
        // Every other language has no locals query.
        _ => return None,
    };
    Some(source)
}
/// Locals query for Rust: function and closure bodies are scopes, parameter
/// patterns and `let` patterns that are plain identifiers are definitions,
/// and every identifier is a reference.
const RUST_LOCALS_QUERY: &str = r#"
; Scopes
(function_item body: (_) @local.scope)
(closure_expression body: (_) @local.scope)
; Definitions - parameters
((parameter pattern: (identifier) @local.definition))
; Definitions - let bindings
(let_declaration pattern: (identifier) @local.definition)
; References
(identifier) @local.reference
"#;
/// Locals query for Python: functions, classes, lambdas and `for`/`while`/
/// `with` statements are scopes; parameters, assignment targets, `for`
/// targets and `with ... as` targets are definitions; every identifier is a
/// reference.
const PYTHON_LOCALS_QUERY: &str = r#"
; Scopes
(function_definition) @local.scope
(class_definition) @local.scope
(lambda) @local.scope
(for_statement) @local.scope
(while_statement) @local.scope
(with_statement) @local.scope
; Definitions
(parameters (identifier) @local.definition)
(assignment left: (identifier) @local.definition)
(for_statement left: (identifier) @local.definition)
(with_clause (as_pattern (as_pattern_target (identifier) @local.definition)))
; References
(identifier) @local.reference
"#;
/// Locals query shared by JavaScript and TypeScript (see `get_locals_query`):
/// functions, methods, `for`/`for-in` statements and blocks are scopes;
/// formal parameters, variable declarators and `for-in` targets are
/// definitions; every identifier is a reference.
const JS_LOCALS_QUERY: &str = r#"
; Scopes
(function_declaration) @local.scope
(function_expression) @local.scope
(arrow_function) @local.scope
(method_definition) @local.scope
(for_statement) @local.scope
(for_in_statement) @local.scope
(block) @local.scope
; Definitions
(formal_parameters (identifier) @local.definition)
(variable_declarator name: (identifier) @local.definition)
(for_in_statement left: (identifier) @local.definition)
; References
(identifier) @local.reference
"#;
/// Locals query for Go: functions, methods, function literals, blocks and
/// `if`/`for` statements are scopes; parameters, `:=` targets, `var` specs
/// and `range` targets are definitions; every identifier is a reference.
const GO_LOCALS_QUERY: &str = r#"
; Scopes
(function_declaration) @local.scope
(method_declaration) @local.scope
(func_literal) @local.scope
(block) @local.scope
(if_statement) @local.scope
(for_statement) @local.scope
; Definitions
(parameter_declaration (identifier) @local.definition)
(short_var_declaration left: (expression_list (identifier) @local.definition))
(var_spec name: (identifier) @local.definition)
(range_clause left: (expression_list (identifier) @local.definition))
; References
(identifier) @local.reference
"#;
/// Locals query shared by C and C++ (see `get_locals_query`): function
/// definitions, compound statements and `for`/`while`/`if` statements are
/// scopes; parameter, plain and initialized declarators that are identifiers
/// are definitions; every identifier is a reference.
const C_LOCALS_QUERY: &str = r#"
; Scopes
(function_definition) @local.scope
(compound_statement) @local.scope
(for_statement) @local.scope
(while_statement) @local.scope
(if_statement) @local.scope
; Definitions
(parameter_declaration declarator: (identifier) @local.definition)
(declaration declarator: (identifier) @local.definition)
(init_declarator declarator: (identifier) @local.definition)
; References
(identifier) @local.reference
"#;
impl ReferenceHighlighter {
/// Creates an enabled highlighter using [`DEFAULT_HIGHLIGHT_COLOR`] and a
/// minimum word length of 2.
///
/// No language is configured: the parser and both queries start out as
/// `None`, and the capture indices start at their defaults.
pub fn new() -> Self {
    let locals_captures = LocalsCaptures::default();
    Self {
        enabled: true,
        min_word_length: 2,
        highlight_color: DEFAULT_HIGHLIGHT_COLOR,
        locals_captures,
        locals_query: None,
        identifier_query: None,
        parser: None,
    }
}
/// Builder-style setter: returns the highlighter with `highlight_color`
/// replaced by `color`.
pub fn with_color(self, color: Color) -> Self {
    let mut configured = self;
    configured.highlight_color = color;
    configured
}
/// Builder-style setter: returns the highlighter with `min_word_length`
/// replaced by `length`.
pub fn with_min_length(self, length: usize) -> Self {
    let mut configured = self;
    configured.min_word_length = length;
    configured
}
/// Configures the tree-sitter state for `language`, replacing whatever was
/// configured before.
///
/// * If the parser rejects the grammar, all tree-sitter state is cleared.
/// * The locals query (if one is defined for the language) is compiled and
///   its `local.scope` / `local.definition` / `local.reference` capture
///   indices are recorded; on a compile error the locals state is cleared.
/// * The parser and identifier query are stored only when the identifier
///   query compiles; otherwise both are set to `None`.
pub fn set_language(&mut self, language: &Language) {
    let ts_language = match language {
        Language::Rust => fresh_languages::tree_sitter_rust::LANGUAGE.into(),
        Language::Python => fresh_languages::tree_sitter_python::LANGUAGE.into(),
        Language::JavaScript => fresh_languages::tree_sitter_javascript::LANGUAGE.into(),
        Language::TypeScript => {
            fresh_languages::tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()
        }
        Language::Go => fresh_languages::tree_sitter_go::LANGUAGE.into(),
        Language::C => fresh_languages::tree_sitter_c::LANGUAGE.into(),
        Language::Cpp => fresh_languages::tree_sitter_cpp::LANGUAGE.into(),
        Language::Java => fresh_languages::tree_sitter_java::LANGUAGE.into(),
        Language::Php => fresh_languages::tree_sitter_php::LANGUAGE_PHP.into(),
        Language::Ruby => fresh_languages::tree_sitter_ruby::LANGUAGE.into(),
        Language::Bash => fresh_languages::tree_sitter_bash::LANGUAGE.into(),
        Language::Lua => fresh_languages::tree_sitter_lua::LANGUAGE.into(),
        Language::Pascal => fresh_languages::tree_sitter_pascal::LANGUAGE.into(),
        Language::Json => fresh_languages::tree_sitter_json::LANGUAGE.into(),
        Language::HTML => fresh_languages::tree_sitter_html::LANGUAGE.into(),
        Language::CSS => fresh_languages::tree_sitter_css::LANGUAGE.into(),
        Language::CSharp => fresh_languages::tree_sitter_c_sharp::LANGUAGE.into(),
        Language::Odin => fresh_languages::tree_sitter_odin::LANGUAGE.into(),
    };

    let mut parser = Parser::new();
    if parser.set_language(&ts_language).is_err() {
        tracing::warn!("Failed to set language for semantic highlighting parser");
        self.parser = None;
        self.identifier_query = None;
        self.locals_query = None;
        self.locals_captures = LocalsCaptures::default();
        return;
    }

    // Compile the optional locals query and resolve its capture indices.
    let locals = match get_locals_query(language).map(|source| Query::new(&ts_language, source)) {
        Some(Ok(query)) => {
            let mut captures = LocalsCaptures::default();
            for (index, capture_name) in query.capture_names().iter().enumerate() {
                let slot = match *capture_name {
                    "local.scope" => &mut captures.scope,
                    "local.definition" => &mut captures.definition,
                    "local.reference" => &mut captures.reference,
                    _ => continue,
                };
                *slot = Some(index as u32);
            }
            tracing::debug!(
                "Locals query enabled for {:?} (scope-aware highlighting)",
                language
            );
            Some((query, captures))
        }
        Some(Err(e)) => {
            tracing::debug!(
                "Locals query failed for {:?}, falling back to identifier matching: {}",
                language,
                e
            );
            None
        }
        None => None,
    };
    match locals {
        Some((query, captures)) => {
            self.locals_query = Some(query);
            self.locals_captures = captures;
        }
        None => {
            self.locals_query = None;
            self.locals_captures = LocalsCaptures::default();
        }
    }

    // The parser is kept only together with a working identifier query.
    let (parser, identifier_query) = match Query::new(&ts_language, IDENTIFIER_QUERY) {
        Ok(query) => {
            tracing::debug!(
                "Tree-sitter semantic highlighting enabled for {:?}",
                language
            );
            (Some(parser), Some(query))
        }
        Err(e) => {
            tracing::debug!(
                "Identifier query not supported for {:?}, using text matching: {}",
                language,
                e
            );
            (None, None)
        }
    };
    self.parser = parser;
    self.identifier_query = identifier_query;
}
/// Returns `true` when scope-aware (locals) highlighting can actually run.
///
/// That requires a parser, a compiled locals query, and both the
/// `local.definition` and `local.reference` captures. The parser is checked
/// (mirroring `has_tree_sitter`) because the locals strategy returns no spans
/// without one; reporting `false` in that state lets `highlight_occurrences`
/// fall through to the other strategies instead.
pub fn has_locals(&self) -> bool {
    self.parser.is_some()
        && self.locals_query.is_some()
        && self.locals_captures.definition.is_some()
        && self.locals_captures.reference.is_some()
}
/// Returns `true` when both a parser and the compiled identifier query are
/// available, i.e. identifier-based tree-sitter highlighting can run.
pub fn has_tree_sitter(&self) -> bool {
    matches!(
        (&self.parser, &self.identifier_query),
        (Some(_), Some(_))
    )
}
/// Computes highlight spans for the identifier or word at `cursor_position`.
///
/// Returns an empty vector when the highlighter is disabled. Otherwise the
/// strategy is chosen in this order: scope-aware locals highlighting when
/// `has_locals()` holds, identifier-based tree-sitter highlighting when
/// `has_tree_sitter()` holds, and whole-word text matching as the fallback.
/// All arguments are forwarded unchanged to the chosen strategy.
pub fn highlight_occurrences(
    &mut self,
    buffer: &Buffer,
    cursor_position: usize,
    viewport_start: usize,
    viewport_end: usize,
    context_bytes: usize,
) -> Vec<HighlightSpan> {
    if !self.enabled {
        return Vec::new();
    }
    match (self.has_locals(), self.has_tree_sitter()) {
        (true, _) => self.highlight_with_locals(
            buffer,
            cursor_position,
            viewport_start,
            viewport_end,
            context_bytes,
        ),
        (false, true) => self.highlight_with_tree_sitter(
            buffer,
            cursor_position,
            viewport_start,
            viewport_end,
            context_bytes,
        ),
        (false, false) => self.highlight_with_text_matching(
            buffer,
            cursor_position,
            viewport_start,
            viewport_end,
            context_bytes,
        ),
    }
}
/// Scope-aware highlighting driven by the language's locals query.
///
/// Parses the bytes from `viewport_start - context_bytes` to
/// `viewport_end + context_bytes` (both clamped), collects the scope,
/// definition and reference captures, and resolves the identifier under
/// `cursor_position` (range ends inclusive) to the scope that defines it.
///
/// * If a defining scope is found, the matching definitions in that scope
///   and the same-named references inside it are highlighted, except
///   references shadowed by a same-named definition in a scope nested
///   inside the defining scope.
/// * If no definition is found, every same-named reference is highlighted.
///
/// Only spans intersecting `viewport_start..viewport_end` are returned.
/// Returns an empty vector when the required state is missing, when the
/// cursor is not on a captured identifier, or when the name is shorter than
/// `min_word_length`. Falls back to `highlight_with_tree_sitter` when the
/// parser yields no tree.
fn highlight_with_locals(
    &mut self,
    buffer: &Buffer,
    cursor_position: usize,
    viewport_start: usize,
    viewport_end: usize,
    context_bytes: usize,
) -> Vec<HighlightSpan> {
    let parser = match &mut self.parser {
        Some(p) => p,
        None => return Vec::new(),
    };
    let query = match &self.locals_query {
        Some(q) => q,
        None => return Vec::new(),
    };
    let def_idx = match self.locals_captures.definition {
        Some(i) => i,
        None => return Vec::new(),
    };
    let ref_idx = match self.locals_captures.reference {
        Some(i) => i,
        None => return Vec::new(),
    };
    let scope_idx = self.locals_captures.scope;

    // Saturating arithmetic keeps extreme viewport/context values from overflowing.
    let parse_start = viewport_start.saturating_sub(context_bytes);
    let parse_end = viewport_end.saturating_add(context_bytes).min(buffer.len());
    let source = buffer.slice_bytes(parse_start..parse_end);
    let tree = match parser.parse(&source, None) {
        Some(t) => t,
        None => {
            return self.highlight_with_tree_sitter(
                buffer,
                cursor_position,
                viewport_start,
                viewport_end,
                context_bytes,
            );
        }
    };

    let mut query_cursor = QueryCursor::new();
    let mut matches = query_cursor.matches(query, tree.root_node(), source.as_slice());
    // All ranges below are absolute buffer offsets (node offsets + parse_start).
    let mut scopes: Vec<Range<usize>> = Vec::new();
    // (range, name, index into `scopes` or usize::MAX when no scope encloses it)
    let mut definitions: Vec<(Range<usize>, String, usize)> = Vec::new();
    let mut references: Vec<(Range<usize>, String)> = Vec::new();
    while let Some(m) = matches.next() {
        for capture in m.captures {
            let node = capture.node;
            // Captures whose text is not valid UTF-8 are skipped. The text is
            // only copied for definitions/references, never for whole scopes.
            let text = match std::str::from_utf8(&source[node.start_byte()..node.end_byte()]) {
                Ok(s) => s,
                Err(_) => continue,
            };
            let start = parse_start + node.start_byte();
            let end = parse_start + node.end_byte();
            let range = start..end;
            if Some(capture.index) == scope_idx {
                scopes.push(range);
            } else if capture.index == def_idx {
                // The most recently recorded scope that encloses the definition.
                let scope_id = scopes
                    .iter()
                    .rposition(|s| s.start <= start && end <= s.end)
                    .unwrap_or(usize::MAX);
                definitions.push((range, text.to_string(), scope_id));
            } else if capture.index == ref_idx {
                references.push((range, text.to_string()));
            }
        }
    }

    // Prefer a definition under the cursor; otherwise look for a reference.
    let cursor_item = definitions
        .iter()
        .find(|(range, _, _)| cursor_position >= range.start && cursor_position <= range.end)
        .map(|(range, name, scope_id)| (range.clone(), name.clone(), Some(*scope_id)))
        .or_else(|| {
            references
                .iter()
                .find(|(range, _)| {
                    cursor_position >= range.start && cursor_position <= range.end
                })
                .map(|(range, name)| (range.clone(), name.clone(), None))
        });
    let (cursor_range, target_name, cursor_scope_id) = match cursor_item {
        Some(item) => item,
        None => return Vec::new(),
    };
    if target_name.len() < self.min_word_length {
        return Vec::new();
    }

    // For a reference, walk the enclosing scopes from the most recently
    // recorded one backwards until one holds a definition of the same name.
    let definition_scope = cursor_scope_id.or_else(|| {
        scopes
            .iter()
            .enumerate()
            .rev()
            .filter(|(_, s)| s.start <= cursor_range.start && cursor_range.end <= s.end)
            .map(|(scope_id, _)| scope_id)
            .find(|&scope_id| {
                definitions
                    .iter()
                    .any(|(_, name, def_scope)| name == &target_name && *def_scope == scope_id)
            })
    });

    let color = self.highlight_color;
    let in_viewport = |r: &Range<usize>| r.start < viewport_end && r.end > viewport_start;
    let mut highlights = Vec::new();

    let scope_id = match definition_scope {
        Some(id) => id,
        None => {
            // No known definition: highlight every same-named reference.
            for (range, name) in &references {
                if name == &target_name && in_viewport(range) {
                    highlights.push(HighlightSpan {
                        range: range.clone(),
                        color,
                    });
                }
            }
            return highlights;
        }
    };

    for (range, name, def_scope) in &definitions {
        if name == &target_name && *def_scope == scope_id && in_viewport(range) {
            highlights.push(HighlightSpan {
                range: range.clone(),
                color,
            });
        }
    }

    // `None` when the definition has no enclosing scope (scope_id == usize::MAX).
    let scope_range = scopes.get(scope_id).cloned();
    for (range, name) in &references {
        if name != &target_name || !in_viewport(range) {
            continue;
        }
        let ref_in_scope = match &scope_range {
            Some(sr) => range.start >= sr.start && range.end <= sr.end,
            None => true,
        };
        if !ref_in_scope {
            continue;
        }
        let is_shadowed = definitions.iter().any(|(def_range, def_name, def_scope)| {
            if def_name != name || *def_scope == scope_id || def_range.start >= range.start {
                return false;
            }
            scopes.get(*def_scope).is_some_and(|other| {
                let contains_ref = range.start >= other.start && range.end <= other.end;
                // Only a scope nested inside the target scope can shadow the
                // target. A same-named definition in an enclosing scope is
                // itself shadowed by the target and must not hide references.
                let nested_in_target = scope_range
                    .as_ref()
                    .map_or(true, |target| target.start <= other.start && other.end <= target.end);
                contains_ref && nested_in_target
            })
        });
        if !is_shadowed {
            highlights.push(HighlightSpan {
                range: range.clone(),
                color,
            });
        }
    }
    highlights
}
/// Identifier-based highlighting without scope analysis.
///
/// Parses the bytes from `viewport_start - context_bytes` to
/// `viewport_end + context_bytes` (clamped to the buffer) and runs the
/// identifier query. The text of the identifier under `cursor_position`
/// (range ends inclusive; the last match wins) becomes the target, and every
/// captured identifier with the same text that intersects the viewport is
/// highlighted. Identifiers shorter than `min_word_length` or not valid
/// UTF-8 are ignored.
///
/// Returns an empty vector when the parser or query is missing or the cursor
/// is not on an identifier; falls back to text matching when parsing yields
/// no tree.
fn highlight_with_tree_sitter(
    &mut self,
    buffer: &Buffer,
    cursor_position: usize,
    viewport_start: usize,
    viewport_end: usize,
    context_bytes: usize,
) -> Vec<HighlightSpan> {
    let (Some(parser), Some(query)) = (self.parser.as_mut(), self.identifier_query.as_ref())
    else {
        return Vec::new();
    };

    let window_start = viewport_start.saturating_sub(context_bytes);
    let window_end = (viewport_end + context_bytes).min(buffer.len());
    let source = buffer.slice_bytes(window_start..window_end);
    let tree = match parser.parse(&source, None) {
        Some(parsed) => parsed,
        None => {
            tracing::debug!("Tree-sitter parsing failed, falling back to text matching");
            return self.highlight_with_text_matching(
                buffer,
                cursor_position,
                viewport_start,
                viewport_end,
                context_bytes,
            );
        }
    };

    let mut cursor = QueryCursor::new();
    let mut found = cursor.matches(query, tree.root_node(), source.as_slice());
    // Identifiers intersecting the viewport, as absolute buffer ranges.
    let mut visible: Vec<(Range<usize>, String)> = Vec::new();
    let mut target: Option<String> = None;
    while let Some(hit) = found.next() {
        for capture in hit.captures {
            let local = capture.node.start_byte()..capture.node.end_byte();
            let Ok(text) = std::str::from_utf8(&source[local.clone()]) else {
                continue;
            };
            if text.len() < self.min_word_length {
                continue;
            }
            let span = (window_start + local.start)..(window_start + local.end);
            if span.start <= cursor_position && cursor_position <= span.end {
                target = Some(text.to_string());
            }
            if span.start < viewport_end && span.end > viewport_start {
                visible.push((span, text.to_string()));
            }
        }
    }

    let Some(target) = target else {
        return Vec::new();
    };
    let color = self.highlight_color;
    visible
        .into_iter()
        .filter_map(|(range, text)| (text == target).then_some(HighlightSpan { range, color }))
        .collect()
}
/// Whole-word text matching, used when no tree-sitter strategy is available.
///
/// Finds the word at `cursor_position` and highlights every whole-word
/// occurrence of it that intersects `viewport_start..viewport_end`. Returns
/// an empty vector when the cursor is not on a word, the word is not valid
/// UTF-8, or it is shorter than `min_word_length`.
// `context_bytes` is unused here; it is accepted so this strategy takes the
// same arguments as the tree-sitter strategies.
#[allow(unused_variables)]
fn highlight_with_text_matching(
    &self,
    buffer: &Buffer,
    cursor_position: usize,
    viewport_start: usize,
    viewport_end: usize,
    context_bytes: usize,
) -> Vec<HighlightSpan> {
    let Some(word_range) = self.get_word_at_position(buffer, cursor_position) else {
        return Vec::new();
    };
    let word_bytes = buffer.slice_bytes(word_range);
    let Ok(word) = std::str::from_utf8(&word_bytes) else {
        return Vec::new();
    };
    if word.len() < self.min_word_length {
        return Vec::new();
    }
    self.find_occurrences_in_range(buffer, word, viewport_start, viewport_end)
        .into_iter()
        .map(|range| HighlightSpan {
            range,
            color: self.highlight_color,
        })
        .collect()
}
/// Returns the byte range of the word touching `position`, if any.
///
/// The cursor touches a word when the byte at `position` is a word
/// character, or — only when `position` is the end of the buffer — when the
/// byte just before it is one. A cursor on a non-word byte yields `None` even
/// if a word ends immediately before it. Positions past the end of the buffer
/// yield `None`.
///
/// The word bounds come from `find_word_start` / `find_word_end`; an empty
/// range is reported as `None`.
fn get_word_at_position(&self, buffer: &Buffer, position: usize) -> Option<Range<usize>> {
    let buf_len = buffer.len();
    if position > buf_len {
        return None;
    }
    // Byte that decides whether the cursor touches a word: the byte under the
    // cursor, or at end-of-buffer the byte before it (none in an empty buffer).
    let probe = if position < buf_len {
        position
    } else {
        position.checked_sub(1)?
    };
    let on_word = buffer
        .slice_bytes(probe..probe + 1)
        .first()
        .is_some_and(|&b| is_word_char(b));
    if !on_word {
        return None;
    }
    let start = find_word_start(buffer, position);
    let end = find_word_end(buffer, position);
    (start < end).then_some(start..end)
}
/// Largest `end - start` span, in bytes (1 MiB), that
/// `find_occurrences_in_range` will search; larger requests return no
/// occurrences.
const MAX_SEARCH_RANGE: usize = 1024 * 1024;
/// Finds whole-word occurrences of `word` that intersect `start..end`.
///
/// The search window is widened by `word.len()` bytes on both sides (clamped
/// to the buffer) so that occurrences straddling the range edges are found.
/// An occurrence counts as a whole word when the bytes immediately before
/// and after it are absent or not word characters. Matches are leftmost and
/// non-overlapping, and the returned ranges are absolute byte offsets in
/// ascending order.
///
/// The comparison is done on raw bytes rather than on a decoded `&str`, so a
/// window edge that falls inside a multi-byte UTF-8 character does not
/// discard every match.
///
/// Returns an empty vector when `word` is empty, when `end - start` exceeds
/// `MAX_SEARCH_RANGE`, or when the window is shorter than `word`.
fn find_occurrences_in_range(
    &self,
    buffer: &Buffer,
    word: &str,
    start: usize,
    end: usize,
) -> Vec<Range<usize>> {
    let needle = word.as_bytes();
    if needle.is_empty() || end.saturating_sub(start) > Self::MAX_SEARCH_RANGE {
        return Vec::new();
    }
    let buf_len = buffer.len();
    let search_start = start.saturating_sub(needle.len());
    let search_end = end.saturating_add(needle.len()).min(buf_len);
    // Also covers an inverted window (e.g. `start` beyond the buffer).
    if search_end.saturating_sub(search_start) < needle.len() {
        return Vec::new();
    }
    let bytes = buffer.slice_bytes(search_start..search_end);
    let haystack: &[u8] = &bytes;

    let mut occurrences = Vec::new();
    let mut offset = 0;
    while offset + needle.len() <= haystack.len() {
        if &haystack[offset..offset + needle.len()] != needle {
            offset += 1;
            continue;
        }
        let abs_start = search_start + offset;
        let abs_end = abs_start + needle.len();
        // Resume after this match so matches never overlap.
        offset += needle.len();

        let starts_word = abs_start == 0
            || buffer
                .slice_bytes(abs_start - 1..abs_start)
                .first()
                .map_or(true, |&b| !is_word_char(b));
        let ends_word = abs_end >= buf_len
            || buffer
                .slice_bytes(abs_end..abs_end + 1)
                .first()
                .map_or(true, |&b| !is_word_char(b));
        if starts_word && ends_word && abs_start < end && abs_end > start {
            occurrences.push(abs_start..abs_end);
        }
    }
    occurrences
}
}
impl Default for ReferenceHighlighter {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the reference highlighter: word lookup, whole-word search,
// the text-matching strategy, and the tree-sitter / locals strategies for Rust.
#[cfg(test)]
mod tests {
use super::*;
// Word lookup: inside a word, at the start of a word, and on a space.
#[test]
fn test_get_word_at_position() {
let buffer = Buffer::from_str_test("hello world test");
let highlighter = ReferenceHighlighter::new();
let range = highlighter.get_word_at_position(&buffer, 2).unwrap();
assert_eq!(range, 0..5);
let range = highlighter.get_word_at_position(&buffer, 6).unwrap();
assert_eq!(range, 6..11);
// Position 5 is the space after "hello": not on a word.
let range = highlighter.get_word_at_position(&buffer, 5);
assert!(range.is_none());
}
// Every standalone "foo" in the buffer is reported with its byte range.
#[test]
fn test_find_occurrences() {
let buffer = Buffer::from_str_test("foo bar foo baz foo");
let highlighter = ReferenceHighlighter::new();
let occurrences = highlighter.find_occurrences_in_range(&buffer, "foo", 0, buffer.len());
assert_eq!(occurrences.len(), 3);
assert_eq!(occurrences[0], 0..3);
assert_eq!(occurrences[1], 8..11);
assert_eq!(occurrences[2], 16..19);
}
// "foo" inside "foobar" / "foobaz" must not match; only the standalone word does.
#[test]
fn test_whole_word_only() {
let buffer = Buffer::from_str_test("foobar foo foobaz");
let highlighter = ReferenceHighlighter::new();
let occurrences = highlighter.find_occurrences_in_range(&buffer, "foo", 0, buffer.len());
assert_eq!(occurrences.len(), 1);
assert_eq!(occurrences[0], 7..10);
}
// No language is set, so this exercises the text-matching strategy.
#[test]
fn test_highlight_occurrences() {
let buffer = Buffer::from_str_test("let foo = 1;\nlet bar = foo;\nlet baz = foo;");
let mut highlighter = ReferenceHighlighter::new();
let spans = highlighter.highlight_occurrences(&buffer, 4, 0, buffer.len(), 100_000);
assert_eq!(spans.len(), 3);
}
// Single-character words are below the minimum length of 2 and are skipped.
#[test]
fn test_min_word_length() {
let buffer = Buffer::from_str_test("a b c a b c");
let mut highlighter = ReferenceHighlighter::new().with_min_length(2);
let spans = highlighter.highlight_occurrences(&buffer, 0, 0, buffer.len(), 100_000);
assert_eq!(spans.len(), 0);
}
// A disabled highlighter returns no spans.
#[test]
fn test_disabled() {
let buffer = Buffer::from_str_test("hello hello hello");
let mut highlighter = ReferenceHighlighter::new();
highlighter.enabled = false;
let spans = highlighter.highlight_occurrences(&buffer, 0, 0, buffer.len(), 100_000);
assert_eq!(spans.len(), 0);
}
// A cursor at the very end of the buffer still selects the final word.
#[test]
fn test_cursor_at_end_of_buffer() {
let buffer = Buffer::from_str_test("foo bar foo");
let mut highlighter = ReferenceHighlighter::new();
let spans =
highlighter.highlight_occurrences(&buffer, buffer.len(), 0, buffer.len(), 100_000);
assert_eq!(spans.len(), 2);
}
// A cursor on the first byte of a word selects that word.
#[test]
fn test_cursor_on_word() {
let buffer = Buffer::from_str_test("foo bar foo");
let mut highlighter = ReferenceHighlighter::new();
let spans = highlighter.highlight_occurrences(&buffer, 0, 0, buffer.len(), 100_000);
assert_eq!(spans.len(), 2);
}
// Only the occurrence intersecting the viewport 4..12 is returned.
#[test]
fn test_viewport_limiting() {
let buffer = Buffer::from_str_test("foo bar foo baz foo");
let mut highlighter = ReferenceHighlighter::new();
let spans = highlighter.highlight_occurrences(&buffer, 8, 4, 12, 100_000);
assert_eq!(spans.len(), 1);
assert_eq!(spans[0].range, 8..11);
}
// With Rust configured, highlighting goes through a tree-sitter strategy.
#[test]
fn test_tree_sitter_mode() {
use crate::primitives::highlighter::Language;
let buffer = Buffer::from_str_test("fn main() {\n let foo = 1;\n let bar = foo;\n}");
let mut highlighter = ReferenceHighlighter::new();
highlighter.set_language(&Language::Rust);
let spans = highlighter.highlight_occurrences(&buffer, 20, 0, buffer.len(), 100_000);
assert!(spans.len() >= 2);
}
// Same as above, but with `let` statements outside any function.
#[test]
fn test_tree_sitter_identifier_only() {
use crate::primitives::highlighter::Language;
let buffer = Buffer::from_str_test("let foo = 1;\nlet bar = foo;");
let mut highlighter = ReferenceHighlighter::new();
highlighter.set_language(&Language::Rust);
let spans = highlighter.highlight_occurrences(&buffer, 4, 0, buffer.len(), 100_000);
assert!(spans.len() >= 2);
}
// Setting Rust as the language must enable the locals strategy.
#[test]
fn test_locals_mode_enabled() {
use crate::primitives::highlighter::Language;
let mut highlighter = ReferenceHighlighter::new();
highlighter.set_language(&Language::Rust);
assert!(highlighter.has_locals());
}
// Two functions each define their own `foo`; the cursor is on the first one.
// Only a lower bound on the span count is asserted.
#[test]
fn test_scope_aware_highlighting() {
use crate::primitives::highlighter::Language;
let code = r#"
fn first() {
let foo = 1;
println!("{}", foo);
}
fn second() {
let foo = 2;
println!("{}", foo);
}
"#;
let buffer = Buffer::from_str_test(code);
let mut highlighter = ReferenceHighlighter::new();
highlighter.set_language(&Language::Rust);
// `+ 4` skips "let " so the cursor sits on `foo`.
let first_foo_pos = code.find("let foo = 1").unwrap() + 4;
let spans =
highlighter.highlight_occurrences(&buffer, first_foo_pos, 0, buffer.len(), 100_000);
assert!(
spans.len() >= 2,
"Expected at least 2 spans, got {}",
spans.len()
);
}
// An inner block redefines `foo`; the cursor is on the outer definition.
// Only a lower bound on the span count is asserted.
#[test]
fn test_shadowing_in_nested_scope() {
use crate::primitives::highlighter::Language;
let code = r#"
fn main() {
let foo = 1;
{
let foo = 2;
println!("{}", foo);
}
println!("{}", foo);
}
"#;
let buffer = Buffer::from_str_test(code);
let mut highlighter = ReferenceHighlighter::new();
highlighter.set_language(&Language::Rust);
// `+ 4` skips "let " so the cursor sits on the outer `foo`.
let outer_foo_pos = code.find("let foo = 1").unwrap() + 4;
let spans =
highlighter.highlight_occurrences(&buffer, outer_foo_pos, 0, buffer.len(), 100_000);
assert!(
spans.len() >= 2,
"Expected at least 2 spans, got {}",
spans.len()
);
}
// A function parameter and its two uses should all be highlighted.
#[test]
fn test_parameter_highlighting() {
use crate::primitives::highlighter::Language;
let code = r#"
fn greet(name: &str) {
println!("Hello, {}", name);
println!("Goodbye, {}", name);
}
"#;
let buffer = Buffer::from_str_test(code);
let mut highlighter = ReferenceHighlighter::new();
highlighter.set_language(&Language::Rust);
let name_pos = code.find("name: &str").unwrap();
let spans = highlighter.highlight_occurrences(&buffer, name_pos, 0, buffer.len(), 100_000);
assert!(
spans.len() >= 3,
"Expected at least 3 spans, got {}",
spans.len()
);
}
}