use crate::LanguageName;
use crate::query_compiler::LanguageInfo;
use crate::{query_compiler, range_union};
#[derive(Debug, Clone)]
pub enum FileParseError {
FailedToAttachLanguage {
language_name: LanguageName,
message: String,
},
InvalidFileRange {
range: tree_sitter::Range,
message: String,
},
}
#[rustfmt::skip]
impl std::fmt::Display for FileParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::FailedToAttachLanguage { language_name, message }
=> write!(f, "language {language_name} incompatible with parser: {message}"),
Self::InvalidFileRange { range, message }
=> write!(f, "tree_sitter rejected range restriction {range:?}: {message}"),
}
}
}
pub fn parse(
source_code: &[u8],
language_name: LanguageName,
language: &tree_sitter::Language,
) -> Result<tree_sitter::Tree, FileParseError> {
parse_ranged(source_code, language_name, language, None)
}
pub fn parse_ranged(
source_code: &[u8],
language_name: LanguageName,
language: &tree_sitter::Language,
range: Option<tree_sitter::Range>,
) -> Result<tree_sitter::Tree, FileParseError> {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(language)
.map_err(|e| FileParseError::FailedToAttachLanguage {
language_name,
message: e.to_string(),
})?;
if let Some(range) = range {
parser
.set_included_ranges(&[range])
.map_err(|e| FileParseError::InvalidFileRange {
range,
message: e.to_string(),
})?;
}
Ok(parser
.parse(source_code, None)
.expect("parse() should have returned a tree if parser.set_language() was called"))
}
#[derive(Debug, Clone)]
pub struct SearchResult {
pub ranges: range_union::RangeUnion,
pub recurse_names: Vec<String>,
pub import_origins: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct InjectionRange {
pub range: tree_sitter::Range,
pub context: range_union::RangeUnion,
pub language_hint: Option<String>,
}
pub fn end_point_to_end_line(p: tree_sitter::Point) -> usize {
if p.column == 0 {
p.row
} else {
p.row.saturating_add(1)
}
}
fn extract_name<'a>(bytes: &'a [u8], language_info: &LanguageInfo) -> &'a str {
let full_name = std::str::from_utf8(bytes).unwrap();
match language_info.name_transform.as_ref() {
None => full_name,
Some(f) => f(full_name),
}
}
pub fn find_names(
source_code: &[u8],
tree: &tree_sitter::Tree,
language_info: &LanguageInfo,
pattern: ®ex::Regex,
) -> Vec<String> {
use tree_sitter::StreamingIterator;
let mut cursor = tree_sitter::QueryCursor::new();
let mut names: std::vec::Vec<String> = std::vec::Vec::new();
let mut matches = cursor.matches(
&language_info.definition_query.query,
tree.root_node(),
source_code,
);
while let Some(query_match) = matches.next() {
names.extend(query_match.captures.iter().filter_map(|capture| {
if capture.index != language_info.definition_query.index_name {
return None;
}
let name = extract_name(&source_code[capture.node.byte_range()], language_info);
if pattern.is_match(name) {
Some(name.to_owned())
} else {
None
}
}));
}
names.dedup(); names.sort();
names.dedup();
names
}
pub fn find_definition(
source_code: &[u8],
tree: &tree_sitter::Tree,
language_info: &LanguageInfo,
pattern: ®ex::Regex,
recurse: bool,
) -> SearchResult {
use tree_sitter::StreamingIterator;
let mut ranges: range_union::RangeUnion = Default::default();
let mut cursor = tree_sitter::QueryCursor::new();
let mut recurse_cursor = tree_sitter::QueryCursor::new();
let mut recurse_names: std::vec::Vec<String> = std::vec::Vec::new();
let mut context_cursor = tree_sitter::QueryCursor::new();
context_cursor.set_max_start_depth(Some(0));
let mut matches = cursor.matches(
&language_info.definition_query.query,
tree.root_node(),
source_code,
);
while let Some(query_match) = matches.next() {
if ! query_match.captures.iter().any(|capture| {
capture.index == language_info.definition_query.index_name
&& pattern.is_match(extract_name(&source_code[capture.node.byte_range()], language_info))
}) {
continue;
}
for capture in query_match
.captures
.iter()
.filter(|capture| capture.index == language_info.definition_query.index_def)
{
let mut node = capture.node;
ranges
.push(node.range().start_point.row..end_point_to_end_line(node.range().end_point));
if recurse {
if let Some(recurse_query) = &language_info.recurse_query {
let mut recurse_matches =
recurse_cursor.matches(&recurse_query.query, node, source_code);
while let Some(recurse_match) = recurse_matches.next() {
for recurse_capture in
recurse_match.captures.iter().filter(|recurse_capture| {
recurse_capture.index == recurse_query.index_name
})
{
let recurse_name = std::str::from_utf8(
&source_code[recurse_capture.node.byte_range()],
)
.unwrap();
recurse_names.push(String::from(recurse_name));
}
}
}
}
while let Some(same_line_ancestor) = node.parent() {
if same_line_ancestor.range().start_point.row == node.range().start_point.row {
node = same_line_ancestor
} else {
break;
}
}
let mut last_ambiguously_attached_sibling_range: Option<std::ops::Range<usize>> = None;
while let Some(sibling) = node.prev_sibling() {
if match std::num::NonZero::new(sibling.kind_id()) {
None => false,
Some(kind_id) => language_info.sibling_node_types.contains(&kind_id),
} {
let new_sibling_range = sibling.range().start_point.row
..end_point_to_end_line(sibling.range().end_point);
if let Some(r) = last_ambiguously_attached_sibling_range {
ranges.push(r);
}
last_ambiguously_attached_sibling_range = Some(new_sibling_range);
node = sibling;
} else {
if let Some(r) = last_ambiguously_attached_sibling_range {
let sibling_end_line = end_point_to_end_line(sibling.range().end_point);
if sibling_end_line < r.end {
ranges.push(sibling_end_line.max(r.start)..r.end);
}
last_ambiguously_attached_sibling_range = None;
}
break;
}
}
if let Some(r) = last_ambiguously_attached_sibling_range {
ranges.push(r);
}
if let Some(parent_query) = &language_info.parent_query {
ranges.extend(AncestorRangeIterator {
node: capture.node,
cursor: &mut context_cursor,
query: parent_query,
source_code,
});
}
}
}
let mut import_origins: Vec<String> = vec![];
if let Some(import_query) = &language_info.import_query {
cursor
.matches(&import_query.query, tree.root_node(), source_code)
.filter(|query_match| {
query_match.captures.iter().any(|capture| {
capture.index == import_query.index_name
&& pattern.is_match(
std::str::from_utf8(&source_code[capture.node.byte_range()]).unwrap(),
)
})
})
.for_each(|query_match| {
import_origins.extend(
query_match
.captures
.iter()
.filter(|capture| capture.index == import_query.index_origin)
.map(|capture| {
std::str::from_utf8(&source_code[capture.node.byte_range()])
.unwrap()
.to_owned()
}),
)
});
}
recurse_names.sort();
recurse_names.dedup();
SearchResult {
ranges,
recurse_names,
import_origins,
}
}
pub fn find_injections(
source_code: &[u8],
tree: &tree_sitter::Tree,
language_info: &LanguageInfo,
pattern: ®ex::Regex,
) -> Vec<InjectionRange> {
use tree_sitter::StreamingIterator;
let mut cursor = tree_sitter::QueryCursor::new();
let mut injections: Vec<InjectionRange> = vec![];
let mut context_cursor = tree_sitter::QueryCursor::new();
context_cursor.set_max_start_depth(Some(0));
if let Some(injection_query) = &language_info.injection_query {
cursor
.matches(&injection_query.query, tree.root_node(), source_code)
.for_each(|query_match| {
let pattern_index = query_match.pattern_index;
let language_hint = match injection_query
.language_hints_by_pattern_index
.get(pattern_index)
{
None => None,
Some(query_compiler::InjectionLanguageHint::Absent) => None,
Some(query_compiler::InjectionLanguageHint::Fixed(s)) => Some(s.as_ref()),
Some(query_compiler::InjectionLanguageHint::Capture(capture_index)) => query_match
.captures
.get(*capture_index)
.and_then(|c| std::str::from_utf8(&source_code[c.node.byte_range()]).ok()),
};
injections.extend(
query_match
.captures
.iter()
.filter(|capture| {
if capture.index != injection_query.index_range {
return false;
}
let Ok(substring) =
std::str::from_utf8(&source_code[capture.node.byte_range()])
else {
return false;
};
pattern.is_match(substring)
})
.map(|capture| InjectionRange {
range: capture.node.range(),
language_hint: language_hint.map(|s| s.to_owned()),
context: match &language_info.parent_query {
Some(query) => AncestorRangeIterator {
node: capture.node,
cursor: &mut context_cursor,
query,
source_code,
}
.into(),
None => Default::default(),
},
}),
)
});
}
injections
}
struct AncestorRangeIterator<'it> {
node: tree_sitter::Node<'it>,
cursor: &'it mut tree_sitter::QueryCursor,
query: &'it query_compiler::ParentQuery,
source_code: &'it [u8],
}
impl Iterator for AncestorRangeIterator<'_> {
type Item = std::ops::Range<usize>;
fn next(&mut self) -> Option<Self::Item> {
use tree_sitter::StreamingIterator;
while let Some(parent) = self.node.parent() {
let mut parent_matches =
self.cursor
.matches(&self.query.query, parent, self.source_code);
let context_start = parent.range().start_point.row;
let mut context_end = parent.range().end_point;
let mut matched = false;
while let Some(parent_match) = parent_matches.next() {
for capture in parent_match
.captures
.iter()
.filter(|c| Some(c.index) == self.query.index_exclude)
{
if let Some(prev) = capture.node.prev_sibling() {
context_end = context_end.min(prev.range().end_point);
}
}
matched = true;
}
self.node = parent;
if matched {
return Some(context_start..end_point_to_end_line(context_end));
}
}
None
}
}