use crate::error::{Result, SpliceError};
use std::path::PathBuf;
/// A single name bound inside a scope.
#[derive(Clone, Debug)]
struct ScopedSymbol {
    /// Identifier text of the binding.
    name: String,
    /// Byte offset from which the binding counts as declared; shadowing
    /// lookups at or after this offset match the symbol.
    declaration_pos: usize,
}
/// A half-open byte range of the source together with the names bound in it.
#[derive(Clone, Debug)]
struct Scope {
    /// First byte offset covered by the scope (inclusive).
    start: usize,
    /// Byte offset just past the scope (exclusive).
    end: usize,
    /// Bindings registered directly in this scope.
    symbols: Vec<ScopedSymbol>,
    /// Index of the enclosing scope, if any. Stored but not read by lookups.
    _parent: Option<usize>,
}
/// Flat collection of the scopes recorded for one source buffer.
///
/// Scopes are stored in insertion order and referred to by their index.
#[derive(Clone, Debug)]
pub(crate) struct ScopeMap {
    /// Every recorded scope; a scope's index is its position in this vector.
    scopes: Vec<Scope>,
}
impl ScopeMap {
fn new() -> Self {
Self { scopes: Vec::new() }
}
fn add_scope(&mut self, start: usize, end: usize, parent: Option<usize>) -> usize {
let idx = self.scopes.len();
self.scopes.push(Scope {
start,
end,
symbols: Vec::new(),
_parent: parent,
});
idx
}
fn add_symbol(&mut self, scope_idx: usize, name: String, declaration_pos: usize) {
if let Some(scope) = self.scopes.get_mut(scope_idx) {
scope.symbols.push(ScopedSymbol {
name,
declaration_pos,
});
}
}
pub(crate) fn is_shadowed_at(&self, name: &str, byte_offset: usize) -> bool {
for scope in &self.scopes {
if byte_offset >= scope.start && byte_offset < scope.end {
for symbol in &scope.symbols {
if symbol.name == name && byte_offset >= symbol.declaration_pos {
return true;
}
}
}
}
false
}
}
/// Parses `source` as Rust and returns the scopes and bindings found in it.
///
/// The first scope of the returned map spans the whole buffer; every other
/// scope is discovered by walking the syntax tree.
///
/// # Errors
///
/// Returns `SpliceError::Parse` when the Rust grammar cannot be loaded into
/// the parser or when the parser returns no tree.
pub(crate) fn build_scope_map(source: &[u8]) -> Result<ScopeMap> {
    // Both failure paths report against the same placeholder file name.
    let parse_error = |message: String| SpliceError::Parse {
        file: PathBuf::from("<source>"),
        message,
    };

    let language: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
    let mut parser = tree_sitter::Parser::new();
    parser
        .set_language(&language)
        .map_err(|e| parse_error(format!("Failed to set Rust language: {:?}", e)))?;

    let tree = parser
        .parse(source, None)
        .ok_or_else(|| parse_error("Parse failed - no tree returned".to_string()))?;

    let mut scopes = ScopeMap::new();
    let root_scope = scopes.add_scope(0, source.len(), None);
    build_scopes_recursive(tree.root_node(), source, &mut scopes, root_scope);
    Ok(scopes)
}
/// Recursively records lexical scopes and the names bound in them.
///
/// Visits `node` and its descendants, adding a child scope to `scope_map` for
/// every function body, closure, block and match arm, and registering the
/// names introduced by nested functions, parameters, `let` statements and
/// match-arm patterns. `current_scope` is the index of the scope enclosing
/// `node`.
///
/// Known limitations: a `let` statement registers only the single name
/// returned by `extract_let_binding_name`, and bindings introduced by
/// `if let`, `while let` and `for` are not tracked.
fn build_scopes_recursive(
    node: tree_sitter::Node,
    source: &[u8],
    scope_map: &mut ScopeMap,
    current_scope: usize,
) {
    // Each arm yields the node whose children are visited next and the scope
    // those children belong to.
    let (walk_root, child_scope) = match node.kind() {
        "function_item" => match node.child_by_field_name("body") {
            Some(body) => {
                let scope_idx =
                    scope_map.add_scope(body.start_byte(), body.end_byte(), Some(current_scope));

                // Scope 0 is the file scope; only functions nested inside
                // another scope are registered as shadowing names.
                if current_scope > 0 {
                    let func_name = node
                        .child_by_field_name("name")
                        .and_then(|n| n.utf8_text(source).ok())
                        .map(str::to_string);
                    if let Some(name) = func_name {
                        // Items are in scope throughout the enclosing block,
                        // including before their declaration, so the name is
                        // declared at the start of the enclosing scope.
                        let visible_from = scope_map
                            .scopes
                            .get(current_scope)
                            .map_or_else(|| node.start_byte(), |scope| scope.start);
                        scope_map.add_symbol(current_scope, name, visible_from);
                    }
                }

                if let Some(params) = node.child_by_field_name("parameters") {
                    let mut seen = std::collections::HashSet::new();
                    for name in extract_param_names(params, source, &mut seen) {
                        // Every parameter is live from the first byte of the
                        // body. Offsetting by the parameter index would hide
                        // later parameters from uses near the body start.
                        scope_map.add_symbol(scope_idx, name, body.start_byte());
                    }
                }

                // Walk the body's children directly so the body block does
                // not open a second, duplicate scope.
                (body, scope_idx)
            }
            None => (node, current_scope),
        },
        "closure_expression" => {
            let scope_idx =
                scope_map.add_scope(node.start_byte(), node.end_byte(), Some(current_scope));
            if let Some(params) = node.child_by_field_name("parameters") {
                let mut seen = std::collections::HashSet::new();
                for name in extract_param_names(params, source, &mut seen) {
                    scope_map.add_symbol(scope_idx, name, node.start_byte());
                }
            }
            (node, scope_idx)
        }
        "let_declaration" => {
            if let Some(name) = extract_let_binding_name(node, source) {
                scope_map.add_symbol(current_scope, name, node.start_byte());
            }
            (node, current_scope)
        }
        "match_arm" => {
            // Pattern bindings live only inside their arm, so each arm gets
            // its own scope instead of leaking names into the enclosing one.
            let scope_idx =
                scope_map.add_scope(node.start_byte(), node.end_byte(), Some(current_scope));
            if let Some(pattern) = node.child_by_field_name("pattern") {
                for binding in extract_pattern_bindings(pattern, source) {
                    scope_map.add_symbol(scope_idx, binding, node.start_byte());
                }
            }
            (node, scope_idx)
        }
        "block" => {
            let scope_idx =
                scope_map.add_scope(node.start_byte(), node.end_byte(), Some(current_scope));
            (node, scope_idx)
        }
        _ => (node, current_scope),
    };

    let mut cursor = walk_root.walk();
    for child in walk_root.children(&mut cursor) {
        build_scopes_recursive(child, source, scope_map, child_scope);
    }
}
fn extract_param_names(
node: tree_sitter::Node,
source: &[u8],
_seen: &mut std::collections::HashSet<String>,
) -> Vec<String> {
let mut names = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"parameter" => {
if let Some(name_node) = child.child_by_field_name("name") {
if let Ok(name) = name_node.utf8_text(source) {
names.push(name.to_string());
}
} else {
let mut inner_cursor = child.walk();
for inner_child in child.children(&mut inner_cursor) {
if inner_child.kind() == "identifier" {
if let Ok(name) = inner_child.utf8_text(source) {
names.push(name.to_string());
break;
}
}
}
}
}
"," => continue,
_ => {}
}
}
names
}
fn extract_let_binding_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
if let Some(pattern) = node.child_by_field_name("pattern") {
if pattern.kind() == "identifier" {
if let Ok(name) = pattern.utf8_text(source) {
return Some(name.to_string());
}
}
let mut cursor = pattern.walk();
for child in pattern.children(&mut cursor) {
if child.kind() == "identifier" {
if let Ok(name) = child.utf8_text(source) {
return Some(name.to_string());
}
}
}
}
None
}
/// Collects every name bound by a pattern, in source order.
///
/// `node` may be a pattern node or a wrapper around one, such as the
/// `match_pattern` of a match arm. Nested patterns are searched recursively.
/// The following are skipped because they do not introduce bindings:
/// - the constructor of struct / tuple-struct patterns (their `type` field),
/// - paths, macro invocations and range bounds,
/// - the guard expression of a match arm (its `condition` field), so any
///   bindings made by an `if let` guard are not reported.
///
/// A bare identifier is always reported as a binding. Unit variants and
/// constants such as `None` cannot be told apart from bindings by syntax
/// alone, so they are included as well.
fn extract_pattern_bindings(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
    let mut bindings = Vec::new();
    match node.kind() {
        // `shorthand_field_identifier` is the `x` in `Point { x }`.
        "identifier" | "shorthand_field_identifier" => {
            if let Ok(name) = node.utf8_text(source) {
                bindings.push(name.to_string());
            }
        }
        "scoped_identifier" | "macro_invocation" | "range_pattern" => {}
        _ => {
            let constructor = node.child_by_field_name("type").map(|n| n.id());
            let guard = node.child_by_field_name("condition").map(|n| n.id());
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                let child_id = Some(child.id());
                if child_id == constructor || child_id == guard {
                    continue;
                }
                bindings.extend(extract_pattern_bindings(child, source));
            }
        }
    }
    bindings
}