use std::collections::HashMap;
use std::path::Path;
use tree_sitter::StreamingIterator;
use super::types::{
capture_name_to_chunk_type, Chunk, ChunkType, ChunkTypeRefs, FunctionCalls, Language,
ParserError,
};
use super::Parser;
use crate::language::InjectionRule;
/// Upper bound on the number of injection ranges `find_injection_ranges` keeps
/// for one tree; anything collected beyond this count is dropped with a warning.
const MAX_INJECTION_RANGES: usize = 1000;
/// Recursion depth at which the `parse_injected_*` methods stop looking for
/// further injections inside an already-injected language.
const MAX_INJECTION_DEPTH: usize = 3;
/// All injection regions found in one host tree that share a target language.
///
/// `ranges` and `container_lines` are filled in lockstep, so entry `i` of one
/// describes the same region as entry `i` of the other.
pub(crate) struct InjectionGroup {
/// Language the regions in `ranges` are parsed as.
pub language: Language,
/// Byte/point ranges of the embedded content; passed to tree-sitter as the
/// parser's included ranges.
pub ranges: Vec<tree_sitter::Range>,
/// 1-based, inclusive `(start_line, end_line)` pairs. Holds the container
/// node's lines, or the content node's own lines when the rule sets
/// `content_scoped_lines`.
pub container_lines: Vec<(u32, u32)>,
}
/// Scans `tree` with every rule in `rules` and returns the embedded-code
/// regions it finds, bucketed by target language.
///
/// Processing order:
/// 1. each rule walks the whole tree and appends its matches;
/// 2. the list is truncated to `MAX_INJECTION_RANGES` (with a warning);
/// 3. *adjacent* entries covering the same byte span are collapsed;
/// 4. entries whose language name does not parse, is disabled, or has no
///    grammar are skipped with a warning;
/// 5. the remaining entries are grouped per language, in first-seen order.
///
/// Returns an empty vector when no rule matched anything.
pub(crate) fn find_injection_ranges(
    tree: &tree_sitter::Tree,
    source: &str,
    rules: &[InjectionRule],
) -> Vec<InjectionGroup> {
    let _span = tracing::debug_span!("find_injection_ranges", rules = rules.len()).entered();

    let root = tree.root_node();
    let mut entries: Vec<(&str, tree_sitter::Range, (u32, u32))> = Vec::new();
    for rule in rules {
        walk_for_containers(root, rule, source, &mut entries);
    }
    if entries.is_empty() {
        return Vec::new();
    }

    let found = entries.len();
    if found > MAX_INJECTION_RANGES {
        tracing::warn!(
            count = found,
            limit = MAX_INJECTION_RANGES,
            "Too many injection ranges, truncating to limit"
        );
        entries.truncate(MAX_INJECTION_RANGES);
    }

    // Only consecutive entries with an identical byte span are merged.
    entries.dedup_by_key(|entry| (entry.1.start_byte, entry.1.end_byte));

    // `slot_of` maps a language to its position in `groups`, so the output
    // keeps the order in which languages were first encountered.
    let mut slot_of: HashMap<Language, usize> = HashMap::new();
    let mut groups: Vec<InjectionGroup> = Vec::new();

    for (lang_name, range, lines) in entries {
        let Ok(language) = lang_name.parse::<Language>() else {
            tracing::warn!(
                language = lang_name,
                "Injection target language '{}' not recognized",
                lang_name
            );
            continue;
        };
        if !language.is_enabled() || language.def().grammar.is_none() {
            tracing::warn!(
                language = lang_name,
                "Injection target language '{}' not available (disabled or no grammar)",
                language
            );
            continue;
        }

        let slot = *slot_of.entry(language).or_insert_with(|| {
            groups.push(InjectionGroup {
                language,
                ranges: Vec::new(),
                container_lines: Vec::new(),
            });
            groups.len() - 1
        });
        let group = &mut groups[slot];
        group.ranges.push(range);
        group.container_lines.push(lines);
    }

    groups
}
/// Moves `cursor` to the next node in pre-order *without* descending into the
/// current node: the next sibling if there is one, otherwise the next sibling
/// of the closest ancestor that has one.
///
/// Returns `false` once the cursor has climbed to the root with nothing left
/// to visit.
fn advance_cursor(cursor: &mut tree_sitter::TreeCursor) -> bool {
    while !cursor.goto_next_sibling() {
        // No sibling at this level: climb one level and try again.
        if !cursor.goto_parent() {
            return false;
        }
    }
    true
}
fn byte_offset_to_point(source: &str, byte: usize) -> tree_sitter::Point {
let byte = byte.min(source.len());
let byte = source.floor_char_boundary(byte);
let before = &source[..byte];
let row = before.as_bytes().iter().filter(|&&b| b == b'\n').count();
let col = before.len() - before.rfind('\n').map(|p| p + 1).unwrap_or(0);
tree_sitter::Point { row, column: col }
}
/// Walks the tree below `root` in pre-order and appends one entry to `entries`
/// for every piece of embedded content that `rule` identifies.
///
/// A node whose kind equals `rule.container_kind` is a container. Its target
/// language is whatever `rule.detect_language` returns for it, falling back to
/// `rule.target_language`; a target of `"_skip"` suppresses the container.
/// Containers are never descended into, so a container nested inside another
/// container of the same kind is not visited.
///
/// Content is located in one of two ways:
/// * `content_kind == "_inner"`: the text between the first `>` and the last
///   `</` of the container's source text;
/// * otherwise: every direct child whose kind equals `content_kind`.
///
/// Empty content spans are ignored. Each entry is
/// `(target language name, content range, 1-based (start, end) lines)`.
fn walk_for_containers(
    root: tree_sitter::Node,
    rule: &InjectionRule,
    source: &str,
    entries: &mut Vec<(&str, tree_sitter::Range, (u32, u32))>,
) {
    let mut cursor = root.walk();
    loop {
        let node = cursor.node();

        if node.kind() != rule.container_kind {
            // Ordinary node: go deeper if possible, otherwise move sideways/up.
            if cursor.goto_first_child() || advance_cursor(&mut cursor) {
                continue;
            }
            return;
        }

        let target = match rule.detect_language {
            Some(detect) => detect(node, source).unwrap_or(rule.target_language),
            None => rule.target_language,
        };

        if target != "_skip" {
            let container_lines = (
                node.start_position().row as u32 + 1,
                node.end_position().row as u32 + 1,
            );

            if rule.content_kind == "_inner" {
                // Treat the container as `<tag ...>content</tag>` and slice the
                // content out textually.
                let text = &source[node.byte_range()];
                let base = node.start_byte();
                if let (Some(open_end), Some(close_start)) = (text.find('>'), text.rfind("</")) {
                    let content_start = base + open_end + 1;
                    let content_end = base + close_start;
                    if content_start < content_end {
                        let range = tree_sitter::Range {
                            start_byte: content_start,
                            end_byte: content_end,
                            start_point: byte_offset_to_point(source, content_start),
                            end_point: byte_offset_to_point(source, content_end),
                        };
                        entries.push((target, range, container_lines));
                    }
                }
            } else {
                let mut child_cursor = node.walk();
                let content_children = node
                    .children(&mut child_cursor)
                    .filter(|child| child.kind() == rule.content_kind);
                for child in content_children {
                    let span = child.byte_range();
                    if span.start >= span.end {
                        continue;
                    }
                    let lines = if rule.content_scoped_lines {
                        (
                            child.start_position().row as u32 + 1,
                            child.end_position().row as u32 + 1,
                        )
                    } else {
                        container_lines
                    };
                    let range = tree_sitter::Range {
                        start_byte: span.start,
                        end_byte: span.end,
                        start_point: child.start_position(),
                        end_point: child.end_position(),
                    };
                    entries.push((target, range, lines));
                }
            }
        }

        // Skip the container's subtree entirely.
        if !advance_cursor(&mut cursor) {
            return;
        }
    }
}
/// Parses only the given `ranges` of `source` with `language`'s grammar.
///
/// Returns `None` when the grammar is unavailable, and logs a warning and
/// returns `None` when the parser rejects the language or the ranges, or when
/// the parse itself yields no tree.
fn build_injection_tree(
    language: Language,
    source: &str,
    ranges: &[tree_sitter::Range],
) -> Option<tree_sitter::Tree> {
    let grammar = language.try_grammar()?;
    let mut parser = tree_sitter::Parser::new();

    parser
        .set_language(&grammar)
        .map_err(|e| {
            tracing::warn!(
                error = ?e,
                %language,
                "Failed to set language for injection"
            )
        })
        .ok()?;

    parser
        .set_included_ranges(ranges)
        .map_err(|e| {
            tracing::warn!(
                error = %e,
                %language,
                "Failed to set included ranges for injection"
            )
        })
        .ok()?;

    match parser.parse(source, None) {
        Some(tree) => Some(tree),
        None => {
            tracing::warn!(%language, "Injection parse returned None");
            None
        }
    }
}
impl Parser {
pub(crate) fn parse_injected_chunks(
&self,
source: &str,
path: &Path,
group: &InjectionGroup,
depth: usize,
) -> Result<Vec<Chunk>, ParserError> {
let inner_language = group.language;
let _span = tracing::info_span!(
"parse_injected_chunks",
language = %inner_language,
range_count = group.ranges.len(),
depth = depth,
path = %path.display()
)
.entered();
let tree = match build_injection_tree(inner_language, source, &group.ranges) {
Some(t) => t,
None => return Ok(vec![]),
};
let query = match self.get_query(inner_language) {
Ok(q) => q,
Err(e) => {
tracing::warn!(
error = %e,
language = %inner_language,
"Failed to get chunk query for injection language"
);
return Ok(vec![]);
}
};
let mut cursor = tree_sitter::QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), source.as_bytes());
let mut chunks = Vec::new();
while let Some(m) = matches.next() {
match self.extract_chunk(source, m, query, inner_language, path) {
Ok(mut chunk) => {
if chunk.content.len() > super::MAX_CHUNK_BYTES {
tracing::debug!(
id = %chunk.id,
bytes = chunk.content.len(),
"Skipping oversized injected chunk"
);
continue;
}
if let Some(post_process) = inner_language.def().post_process_chunk {
if let Some(node) = super::extract_definition_node(m, query) {
if !post_process(&mut chunk.name, &mut chunk.chunk_type, node, source) {
continue;
}
}
}
chunk.language = inner_language;
chunks.push(chunk);
}
Err(e) => {
tracing::warn!(
error = %e,
language = %inner_language,
"Failed to extract injected chunk"
);
}
}
}
if chunks.is_empty() {
tracing::debug!(
language = %inner_language,
"Injection produced no chunks, keeping outer"
);
} else {
tracing::debug!(
language = %inner_language,
count = chunks.len(),
"Injection extracted chunks"
);
}
let inner_rules = inner_language.def().injections;
if !inner_rules.is_empty() && depth < MAX_INJECTION_DEPTH {
let nested_groups = find_injection_ranges(&tree, source, inner_rules);
if !nested_groups.is_empty() {
let _nested_span = tracing::debug_span!(
"recursive_injection",
depth = depth + 1,
language = %inner_language,
groups = nested_groups.len()
)
.entered();
for nested_group in &nested_groups {
let nested_chunks =
self.parse_injected_chunks(source, path, nested_group, depth + 1)?;
if !nested_chunks.is_empty() {
chunks.retain(|c| {
!chunk_within_container(
c.line_start,
c.line_end,
&nested_group.container_lines,
)
});
chunks.extend(nested_chunks);
}
}
}
} else if !inner_rules.is_empty() {
tracing::debug!(
depth = depth,
language = %inner_language,
"Injection depth limit reached, skipping nested rules"
);
}
Ok(chunks)
}
pub(crate) fn parse_injected_relationships(
&self,
source: &str,
group: &InjectionGroup,
depth: usize,
) -> Result<(Vec<FunctionCalls>, Vec<ChunkTypeRefs>), ParserError> {
let inner_language = group.language;
let _span = tracing::info_span!(
"parse_injected_relationships",
language = %inner_language,
range_count = group.ranges.len()
)
.entered();
let tree = match build_injection_tree(inner_language, source, &group.ranges) {
Some(t) => t,
None => return Ok((vec![], vec![])),
};
let chunk_query = match self.get_query(inner_language) {
Ok(q) => q,
Err(e) => {
tracing::warn!(error = %e, "No chunk query for injection language");
return Ok((vec![], vec![]));
}
};
let call_query = match self.get_call_query(inner_language) {
Ok(q) => Some(q),
Err(e) => {
tracing::debug!(
error = %e,
language = %inner_language,
"No call query for injection language, skipping call extraction"
);
None
}
};
let mut cursor = tree_sitter::QueryCursor::new();
let mut matches = cursor.matches(chunk_query, tree.root_node(), source.as_bytes());
let capture_names = chunk_query.capture_names();
let name_idx = chunk_query.capture_index_for_name("name");
let mut call_results = Vec::new();
let mut type_results = Vec::new();
let mut call_cursor = tree_sitter::QueryCursor::new();
let mut calls = Vec::new();
let mut seen = std::collections::HashSet::new();
while let Some(m) = matches.next() {
let func_node = m.captures.iter().find(|c| {
let name = capture_names.get(c.index as usize).copied().unwrap_or("");
capture_name_to_chunk_type(name).is_some()
});
let Some(func_capture) = func_node else {
continue;
};
let node = func_capture.node;
let mut name = name_idx
.and_then(|idx| m.captures.iter().find(|c| c.index == idx))
.map(|c| source[c.node.byte_range()].to_string())
.unwrap_or_else(|| "<anonymous>".to_string());
if let Some(post_process) = inner_language.def().post_process_chunk {
let cap_name = capture_names
.get(func_capture.index as usize)
.copied()
.unwrap_or("");
let mut ct = capture_name_to_chunk_type(cap_name).unwrap_or(ChunkType::Function);
if !post_process(&mut name, &mut ct, node, source) {
continue;
}
}
let line_start = node.start_position().row as u32 + 1;
let byte_range = node.byte_range();
if let Some(call_query) = call_query {
call_cursor.set_byte_range(byte_range.clone());
calls.clear();
let mut call_matches =
call_cursor.matches(call_query, tree.root_node(), source.as_bytes());
while let Some(cm) = call_matches.next() {
for cap in cm.captures {
let callee_name = source[cap.node.byte_range()].to_string();
let call_line = cap.node.start_position().row as u32 + 1;
if !super::calls::should_skip_callee(&callee_name) {
calls.push(super::types::CallSite {
callee_name,
line_number: call_line,
});
}
}
}
seen.clear();
calls.retain(|c| seen.insert(c.callee_name.clone()));
if !calls.is_empty() {
call_results.push(FunctionCalls {
name: name.clone(),
line_start,
calls: std::mem::take(&mut calls),
});
}
}
let mut type_refs = self.extract_types(
source,
&tree,
inner_language,
byte_range.start,
byte_range.end,
);
type_refs.retain(|t| t.type_name != name);
if !type_refs.is_empty() {
type_results.push(ChunkTypeRefs {
name,
line_start,
type_refs,
});
}
}
tracing::debug!(
language = %inner_language,
calls = call_results.len(),
types = type_results.len(),
"Injection extracted relationships"
);
let inner_rules = inner_language.def().injections;
if !inner_rules.is_empty() && depth < MAX_INJECTION_DEPTH {
let nested_groups = find_injection_ranges(&tree, source, inner_rules);
for nested_group in &nested_groups {
let (nested_calls, nested_types) =
self.parse_injected_relationships(source, nested_group, depth + 1)?;
call_results.extend(nested_calls);
type_results.extend(nested_types);
}
}
Ok((call_results, type_results))
}
pub(crate) fn parse_injected_all(
&self,
source: &str,
path: &Path,
group: &InjectionGroup,
depth: usize,
) -> Result<super::ParseAllResult, ParserError> {
let inner_language = group.language;
let _span = tracing::info_span!(
"parse_injected_all",
language = %inner_language,
range_count = group.ranges.len(),
path = %path.display()
)
.entered();
let tree = match build_injection_tree(inner_language, source, &group.ranges) {
Some(t) => t,
None => return Ok((vec![], vec![], vec![])),
};
let chunk_query = match self.get_query(inner_language) {
Ok(q) => q,
Err(e) => {
tracing::warn!(
error = %e,
language = %inner_language,
"Failed to get chunk query for injection language"
);
return Ok((vec![], vec![], vec![]));
}
};
let mut cursor = tree_sitter::QueryCursor::new();
let mut matches = cursor.matches(chunk_query, tree.root_node(), source.as_bytes());
let mut chunks = Vec::new();
while let Some(m) = matches.next() {
match self.extract_chunk(source, m, chunk_query, inner_language, path) {
Ok(mut chunk) => {
if chunk.content.len() > super::MAX_CHUNK_BYTES {
tracing::debug!(
id = %chunk.id,
bytes = chunk.content.len(),
"Skipping oversized injected chunk"
);
continue;
}
if let Some(post_process) = inner_language.def().post_process_chunk {
if let Some(node) = super::extract_definition_node(m, chunk_query) {
if !post_process(&mut chunk.name, &mut chunk.chunk_type, node, source) {
continue;
}
}
}
chunk.language = inner_language;
chunks.push(chunk);
}
Err(e) => {
tracing::warn!(
error = %e,
language = %inner_language,
"Failed to extract injected chunk"
);
}
}
}
if chunks.is_empty() {
tracing::debug!(
language = %inner_language,
"Injection produced no chunks, keeping outer"
);
}
let call_query = match self.get_call_query(inner_language) {
Ok(q) => Some(q),
Err(e) => {
tracing::debug!(
error = %e,
language = %inner_language,
"No call query for injection language, skipping call extraction"
);
None
}
};
let mut cursor2 = tree_sitter::QueryCursor::new();
let mut matches2 = cursor2.matches(chunk_query, tree.root_node(), source.as_bytes());
let capture_names = chunk_query.capture_names();
let name_idx = chunk_query.capture_index_for_name("name");
let mut call_results = Vec::new();
let mut type_results = Vec::new();
let mut call_cursor = tree_sitter::QueryCursor::new();
let mut calls = Vec::new();
let mut seen = std::collections::HashSet::new();
while let Some(m) = matches2.next() {
let func_node = m.captures.iter().find(|c| {
let name = capture_names.get(c.index as usize).copied().unwrap_or("");
capture_name_to_chunk_type(name).is_some()
});
let Some(func_capture) = func_node else {
continue;
};
let node = func_capture.node;
let mut name = name_idx
.and_then(|idx| m.captures.iter().find(|c| c.index == idx))
.map(|c| source[c.node.byte_range()].to_string())
.unwrap_or_else(|| "<anonymous>".to_string());
if let Some(post_process) = inner_language.def().post_process_chunk {
let cap_name = capture_names
.get(func_capture.index as usize)
.copied()
.unwrap_or("");
let mut ct = capture_name_to_chunk_type(cap_name).unwrap_or(ChunkType::Function);
if !post_process(&mut name, &mut ct, node, source) {
continue;
}
}
let line_start = node.start_position().row as u32 + 1;
let byte_range = node.byte_range();
if let Some(cq) = call_query {
call_cursor.set_byte_range(byte_range.clone());
calls.clear();
let mut call_matches = call_cursor.matches(cq, tree.root_node(), source.as_bytes());
while let Some(cm) = call_matches.next() {
for cap in cm.captures {
let callee_name = source[cap.node.byte_range()].to_string();
let call_line = cap.node.start_position().row as u32 + 1;
if !super::calls::should_skip_callee(&callee_name) {
calls.push(super::types::CallSite {
callee_name,
line_number: call_line,
});
}
}
}
seen.clear();
calls.retain(|c| seen.insert(c.callee_name.clone()));
if !calls.is_empty() {
call_results.push(FunctionCalls {
name: name.clone(),
line_start,
calls: std::mem::take(&mut calls),
});
}
}
let mut type_refs = self.extract_types(
source,
&tree,
inner_language,
byte_range.start,
byte_range.end,
);
type_refs.retain(|t| t.type_name != name);
if !type_refs.is_empty() {
type_results.push(ChunkTypeRefs {
name,
line_start,
type_refs,
});
}
}
tracing::debug!(
language = %inner_language,
chunks = chunks.len(),
calls = call_results.len(),
types = type_results.len(),
"Injection extracted all"
);
let inner_rules = inner_language.def().injections;
if !inner_rules.is_empty() && depth < MAX_INJECTION_DEPTH {
let nested_groups = find_injection_ranges(&tree, source, inner_rules);
if !nested_groups.is_empty() {
let _nested_span = tracing::debug_span!(
"recursive_injection_all",
depth = depth + 1,
language = %inner_language,
groups = nested_groups.len()
)
.entered();
for nested_group in &nested_groups {
let (nested_chunks, nested_calls, nested_types) =
self.parse_injected_all(source, path, nested_group, depth + 1)?;
if !nested_chunks.is_empty() {
chunks.retain(|c| {
!chunk_within_container(
c.line_start,
c.line_end,
&nested_group.container_lines,
)
});
chunks.extend(nested_chunks);
}
call_results.extend(nested_calls);
type_results.extend(nested_types);
}
}
} else if !inner_rules.is_empty() {
tracing::debug!(
depth = depth,
language = %inner_language,
"Injection depth limit reached, skipping nested rules"
);
}
Ok((chunks, call_results, type_results))
}
}
/// Reports whether the inclusive line span `chunk_start..=chunk_end` lies
/// entirely inside at least one of the inclusive `(start, end)` spans in
/// `container_lines`. Touching a container's first or last line counts as
/// inside; an empty `container_lines` always yields `false`.
pub(crate) fn chunk_within_container(
    chunk_start: u32,
    chunk_end: u32,
    container_lines: &[(u32, u32)],
) -> bool {
    for &(first_line, last_line) in container_lines {
        if first_line <= chunk_start && chunk_end <= last_line {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary and overlap cases for `chunk_within_container`.
    mod chunk_within_container_tests {
        use super::*;

        /// One container covering lines 3 through 15, shared by most cases.
        const CONTAINER: [(u32, u32); 1] = [(3, 15)];

        #[test]
        fn fully_contained() {
            assert!(chunk_within_container(5, 10, &CONTAINER));
        }

        #[test]
        fn exact_match() {
            assert!(chunk_within_container(3, 15, &CONTAINER));
        }

        #[test]
        fn start_boundary() {
            assert!(chunk_within_container(3, 10, &CONTAINER));
        }

        #[test]
        fn end_boundary() {
            assert!(chunk_within_container(10, 15, &CONTAINER));
        }

        #[test]
        fn not_contained_before() {
            assert!(!chunk_within_container(1, 2, &CONTAINER));
        }

        #[test]
        fn not_contained_after() {
            assert!(!chunk_within_container(16, 20, &CONTAINER));
        }

        #[test]
        fn partial_overlap_start() {
            assert!(!chunk_within_container(1, 5, &CONTAINER));
        }

        #[test]
        fn partial_overlap_end() {
            assert!(!chunk_within_container(10, 20, &CONTAINER));
        }

        #[test]
        fn empty_containers() {
            let none: [(u32, u32); 0] = [];
            assert!(!chunk_within_container(5, 10, &none));
        }

        #[test]
        fn multiple_containers() {
            let containers = [(1, 3), (10, 20), (30, 40)];
            // Inside the middle container.
            assert!(chunk_within_container(12, 18, &containers));
            // Falls in the gap between the first and second containers.
            assert!(!chunk_within_container(5, 8, &containers));
        }

        #[test]
        fn single_line_chunk() {
            assert!(chunk_within_container(5, 5, &CONTAINER));
            assert!(!chunk_within_container(2, 2, &CONTAINER));
        }
    }
}