use tower_lsp::lsp_types::{Position, Range};
use tree_sitter::{Node, Tree};
use super::{ScopeKind, SymbolFlags, SymbolTable};
pub struct Binder<'a> {
source: &'a str,
symbol_table: SymbolTable,
current_scope: u32,
}
impl<'a> Binder<'a> {
pub fn new(source: &'a str) -> Self {
Self {
source,
symbol_table: SymbolTable::new(),
current_scope: 0,
}
}
pub fn bind(mut self, tree: &Tree) -> SymbolTable {
self.visit_node(tree.root_node());
self.symbol_table
}
fn visit_node(&mut self, node: Node) {
match node.kind() {
"function_declaration" => self.bind_function_declaration(node),
"class_declaration" => self.bind_class_declaration(node),
"interface_declaration" => self.bind_interface_declaration(node),
"type_alias_declaration" => self.bind_type_alias_declaration(node),
"enum_declaration" => self.bind_enum_declaration(node),
"lexical_declaration" => self.bind_lexical_declaration(node),
"variable_declaration" => self.bind_variable_declaration(node),
"import_statement" => self.bind_import_statement(node),
"arrow_function" => self.bind_arrow_function(node),
"method_definition" => self.bind_method_definition(node),
"statement_block" => self.bind_block(node),
"if_statement" | "for_statement" | "for_in_statement" | "for_of_statement"
| "while_statement" | "do_statement" | "switch_statement" => {
self.bind_control_flow(node)
}
"catch_clause" => self.bind_catch_clause(node),
"identifier" => self.bind_identifier_reference(node),
_ => self.visit_children(node),
}
}
fn visit_children(&mut self, node: Node) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.visit_node(child);
}
}
fn bind_function_declaration(&mut self, node: Node) {
let name_node = node.child_by_field_name("name");
let params_node = node.child_by_field_name("parameters");
let body_node = node.child_by_field_name("body");
if let Some(name) = name_node {
let name_text = self.node_text(&name);
let mut flags = SymbolFlags::FUNCTION | SymbolFlags::HOISTED;
if self.has_child_kind(&node, "async") {
flags |= SymbolFlags::ASYNC;
}
if let Some(parent) = node.parent() {
if parent.kind() == "export_statement" {
flags |= SymbolFlags::EXPORTED;
}
}
self.symbol_table.create_symbol(
name_text,
flags,
self.node_range(&node),
self.node_range(&name),
self.current_scope,
);
}
if let Some(body) = body_node {
let scope_id = self.symbol_table.create_scope(
ScopeKind::Function,
self.current_scope,
self.node_range(&body),
);
let old_scope = self.current_scope;
self.current_scope = scope_id;
if let Some(params) = params_node {
self.bind_parameters(params);
}
self.visit_children(body);
self.current_scope = old_scope;
}
}
fn bind_arrow_function(&mut self, node: Node) {
let params_node = node.child_by_field_name("parameters");
let body_node = node.child_by_field_name("body");
let scope_id = self.symbol_table.create_scope(
ScopeKind::Function,
self.current_scope,
self.node_range(&node),
);
let old_scope = self.current_scope;
self.current_scope = scope_id;
if let Some(params) = params_node {
if params.kind() == "identifier" {
let name = self.node_text(¶ms);
self.symbol_table.create_symbol(
name,
SymbolFlags::PARAMETER,
self.node_range(¶ms),
self.node_range(¶ms),
self.current_scope,
);
} else {
self.bind_parameters(params);
}
} else if let Some(param) = node.child_by_field_name("parameter") {
let name = self.node_text(¶m);
self.symbol_table.create_symbol(
name,
SymbolFlags::PARAMETER,
self.node_range(¶m),
self.node_range(¶m),
self.current_scope,
);
}
if let Some(body) = body_node {
self.visit_node(body);
}
self.current_scope = old_scope;
}
fn bind_parameters(&mut self, params: Node) {
let mut cursor = params.walk();
for child in params.children(&mut cursor) {
match child.kind() {
"required_parameter" | "optional_parameter" | "rest_parameter" => {
if let Some(pattern) = child.child_by_field_name("pattern") {
self.bind_pattern(pattern, SymbolFlags::PARAMETER);
} else {
let mut param_cursor = child.walk();
for param_child in child.children(&mut param_cursor) {
if param_child.kind() == "identifier" {
let name = self.node_text(¶m_child);
self.symbol_table.create_symbol(
name,
SymbolFlags::PARAMETER,
self.node_range(&child),
self.node_range(¶m_child),
self.current_scope,
);
break;
}
}
}
}
"identifier" => {
let name = self.node_text(&child);
self.symbol_table.create_symbol(
name,
SymbolFlags::PARAMETER,
self.node_range(&child),
self.node_range(&child),
self.current_scope,
);
}
_ => {}
}
}
}
fn bind_class_declaration(&mut self, node: Node) {
let name_node = node.child_by_field_name("name");
let body_node = node.child_by_field_name("body");
if let Some(name) = name_node {
let name_text = self.node_text(&name);
let mut flags = SymbolFlags::CLASS;
if let Some(parent) = node.parent() {
if parent.kind() == "export_statement" {
flags |= SymbolFlags::EXPORTED;
}
}
self.symbol_table.create_symbol(
name_text,
flags,
self.node_range(&node),
self.node_range(&name),
self.current_scope,
);
}
if let Some(body) = body_node {
let scope_id = self.symbol_table.create_scope(
ScopeKind::Class,
self.current_scope,
self.node_range(&body),
);
let old_scope = self.current_scope;
self.current_scope = scope_id;
self.visit_children(body);
self.current_scope = old_scope;
}
}
fn bind_interface_declaration(&mut self, node: Node) {
let name_node = node.child_by_field_name("name");
if let Some(name) = name_node {
let name_text = self.node_text(&name);
let mut flags = SymbolFlags::INTERFACE;
if let Some(parent) = node.parent() {
if parent.kind() == "export_statement" {
flags |= SymbolFlags::EXPORTED;
}
}
self.symbol_table.create_symbol(
name_text,
flags,
self.node_range(&node),
self.node_range(&name),
self.current_scope,
);
}
self.visit_children(node);
}
fn bind_type_alias_declaration(&mut self, node: Node) {
let name_node = node.child_by_field_name("name");
if let Some(name) = name_node {
let name_text = self.node_text(&name);
let mut flags = SymbolFlags::TYPE_ALIAS;
if let Some(parent) = node.parent() {
if parent.kind() == "export_statement" {
flags |= SymbolFlags::EXPORTED;
}
}
self.symbol_table.create_symbol(
name_text,
flags,
self.node_range(&node),
self.node_range(&name),
self.current_scope,
);
}
self.visit_children(node);
}
fn bind_enum_declaration(&mut self, node: Node) {
let name_node = node.child_by_field_name("name");
if let Some(name) = name_node {
let name_text = self.node_text(&name);
let mut flags = SymbolFlags::ENUM;
if let Some(parent) = node.parent() {
if parent.kind() == "export_statement" {
flags |= SymbolFlags::EXPORTED;
}
}
self.symbol_table.create_symbol(
name_text,
flags,
self.node_range(&node),
self.node_range(&name),
self.current_scope,
);
}
self.visit_children(node);
}
fn bind_lexical_declaration(&mut self, node: Node) {
let is_const = self.has_child_kind(&node, "const");
let base_flags = if is_const {
SymbolFlags::VARIABLE | SymbolFlags::CONST
} else {
SymbolFlags::VARIABLE | SymbolFlags::LET
};
self.bind_variable_declarators(node, base_flags);
}
fn bind_variable_declaration(&mut self, node: Node) {
let flags = SymbolFlags::VARIABLE | SymbolFlags::HOISTED;
self.bind_variable_declarators(node, flags);
}
fn bind_variable_declarators(&mut self, node: Node, base_flags: SymbolFlags) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "variable_declarator" {
if let Some(name_node) = child.child_by_field_name("name") {
self.bind_pattern(name_node, base_flags);
}
if let Some(value) = child.child_by_field_name("value") {
self.visit_node(value);
}
}
}
}
fn bind_pattern(&mut self, pattern: Node, flags: SymbolFlags) {
match pattern.kind() {
"identifier" => {
let name = self.node_text(&pattern);
self.symbol_table.create_symbol(
name,
flags,
self.node_range(&pattern),
self.node_range(&pattern),
self.current_scope,
);
}
"object_pattern" => {
let mut cursor = pattern.walk();
for child in pattern.children(&mut cursor) {
if child.kind() == "shorthand_property_identifier_pattern" {
let name = self.node_text(&child);
self.symbol_table.create_symbol(
name,
flags,
self.node_range(&child),
self.node_range(&child),
self.current_scope,
);
} else if child.kind() == "pair_pattern" {
if let Some(value) = child.child_by_field_name("value") {
self.bind_pattern(value, flags);
}
} else if child.kind() == "rest_pattern" {
let mut rest_cursor = child.walk();
for rest_child in child.children(&mut rest_cursor) {
if rest_child.kind() == "identifier" {
let name = self.node_text(&rest_child);
self.symbol_table.create_symbol(
name,
flags,
self.node_range(&rest_child),
self.node_range(&rest_child),
self.current_scope,
);
}
}
}
}
}
"array_pattern" => {
let mut cursor = pattern.walk();
for child in pattern.children(&mut cursor) {
if child.kind() != "," && child.kind() != "[" && child.kind() != "]" {
self.bind_pattern(child, flags);
}
}
}
"assignment_pattern" => {
if let Some(left) = pattern.child_by_field_name("left") {
self.bind_pattern(left, flags);
}
}
_ => {}
}
}
fn bind_import_statement(&mut self, node: Node) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"import_clause" => {
self.bind_import_clause(child);
}
"namespace_import" => {
if let Some(name) = child.child_by_field_name("name") {
let name_text = self.node_text(&name);
self.symbol_table.create_symbol(
name_text,
SymbolFlags::VARIABLE | SymbolFlags::IMPORT,
self.node_range(&child),
self.node_range(&name),
self.current_scope,
);
}
}
_ => {}
}
}
}
fn bind_import_clause(&mut self, node: Node) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"identifier" => {
let name = self.node_text(&child);
self.symbol_table.create_symbol(
name,
SymbolFlags::VARIABLE | SymbolFlags::IMPORT,
self.node_range(&child),
self.node_range(&child),
self.current_scope,
);
}
"named_imports" => {
let mut import_cursor = child.walk();
for import_spec in child.children(&mut import_cursor) {
if import_spec.kind() == "import_specifier" {
let local_name = import_spec
.child_by_field_name("alias")
.or_else(|| import_spec.child_by_field_name("name"));
if let Some(name_node) = local_name {
let name = self.node_text(&name_node);
self.symbol_table.create_symbol(
name,
SymbolFlags::VARIABLE | SymbolFlags::IMPORT,
self.node_range(&import_spec),
self.node_range(&name_node),
self.current_scope,
);
}
}
}
}
"namespace_import" => {
if let Some(name) = child.child_by_field_name("name") {
let name_text = self.node_text(&name);
self.symbol_table.create_symbol(
name_text,
SymbolFlags::VARIABLE | SymbolFlags::IMPORT,
self.node_range(&child),
self.node_range(&name),
self.current_scope,
);
}
}
_ => {}
}
}
}
fn bind_method_definition(&mut self, node: Node) {
let name_node = node.child_by_field_name("name");
let params_node = node.child_by_field_name("parameters");
let body_node = node.child_by_field_name("body");
if let Some(name) = name_node {
let name_text = self.node_text(&name);
let mut flags = SymbolFlags::METHOD;
if self.has_child_kind(&node, "static") {
flags |= SymbolFlags::STATIC;
}
if self.has_child_kind(&node, "async") {
flags |= SymbolFlags::ASYNC;
}
self.symbol_table.create_symbol(
name_text,
flags,
self.node_range(&node),
self.node_range(&name),
self.current_scope,
);
}
if let Some(body) = body_node {
let scope_id = self.symbol_table.create_scope(
ScopeKind::Function,
self.current_scope,
self.node_range(&body),
);
let old_scope = self.current_scope;
self.current_scope = scope_id;
if let Some(params) = params_node {
self.bind_parameters(params);
}
self.visit_children(body);
self.current_scope = old_scope;
}
}
fn bind_block(&mut self, node: Node) {
if let Some(parent) = node.parent() {
match parent.kind() {
"function_declaration" | "function" | "arrow_function" | "method_definition" => {
self.visit_children(node);
return;
}
_ => {}
}
}
let scope_id = self.symbol_table.create_scope(
ScopeKind::Block,
self.current_scope,
self.node_range(&node),
);
let old_scope = self.current_scope;
self.current_scope = scope_id;
self.visit_children(node);
self.current_scope = old_scope;
}
fn bind_control_flow(&mut self, node: Node) {
self.visit_children(node);
}
fn bind_catch_clause(&mut self, node: Node) {
let scope_id = self.symbol_table.create_scope(
ScopeKind::Catch,
self.current_scope,
self.node_range(&node),
);
let old_scope = self.current_scope;
self.current_scope = scope_id;
if let Some(param) = node.child_by_field_name("parameter") {
self.bind_pattern(param, SymbolFlags::PARAMETER);
}
self.visit_children(node);
self.current_scope = old_scope;
}
fn bind_identifier_reference(&mut self, node: Node) {
if let Some(parent) = node.parent() {
match parent.kind() {
"variable_declarator" => {
if parent.child_by_field_name("name") == Some(node) {
return;
}
}
"function_declaration"
| "class_declaration"
| "interface_declaration"
| "type_alias_declaration"
| "enum_declaration"
| "method_definition" => {
if parent.child_by_field_name("name") == Some(node) {
return;
}
}
"import_specifier"
| "shorthand_property_identifier_pattern"
| "required_parameter"
| "optional_parameter"
| "rest_parameter" => {
return;
}
_ => {}
}
}
let name = self.node_text(&node);
if let Some(symbol_id) = self.symbol_table.lookup(&name, self.current_scope) {
self.symbol_table
.add_reference(symbol_id, self.node_range(&node));
}
}
fn node_text(&self, node: &Node) -> String {
node.utf8_text(self.source.as_bytes())
.unwrap_or("")
.to_string()
}
fn node_range(&self, node: &Node) -> Range {
let start = node.start_position();
let end = node.end_position();
Range {
start: Position::new(start.row as u32, start.column as u32),
end: Position::new(end.row as u32, end.column as u32),
}
}
fn has_child_kind(&self, node: &Node, kind: &str) -> bool {
let mut cursor = node.walk();
let result = node.children(&mut cursor).any(|c| c.kind() == kind);
result
}
}
pub fn bind_document(tree: &Tree, source: &str) -> SymbolTable {
let binder = Binder::new(source);
binder.bind(tree)
}
#[cfg(test)]
mod tests {
use super::*;
use tree_sitter::Parser;
fn parse_and_bind(code: &str) -> SymbolTable {
let mut parser = Parser::new();
parser
.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
.unwrap();
let tree = parser.parse(code, None).unwrap();
bind_document(&tree, code)
}
#[test]
fn test_bind_const_declaration() {
let table = parse_and_bind("const x = 1;");
let symbol = table.lookup("x", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert_eq!(symbol.name, "x");
assert!(symbol.flags.contains(SymbolFlags::VARIABLE));
assert!(symbol.flags.contains(SymbolFlags::CONST));
}
#[test]
fn test_bind_let_declaration() {
let table = parse_and_bind("let y = 2;");
let symbol = table.lookup("y", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert!(symbol.flags.contains(SymbolFlags::LET));
}
#[test]
fn test_bind_var_declaration() {
let table = parse_and_bind("var z = 3;");
let symbol = table.lookup("z", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert!(symbol.flags.contains(SymbolFlags::HOISTED));
}
#[test]
fn test_bind_function_declaration() {
let table = parse_and_bind("function greet(name: string) { return name; }");
let symbol = table.lookup("greet", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert_eq!(symbol.name, "greet");
assert!(symbol.flags.contains(SymbolFlags::FUNCTION));
assert!(symbol.flags.contains(SymbolFlags::HOISTED));
}
#[test]
fn test_bind_async_function() {
let table = parse_and_bind("async function fetchData() { }");
let symbol_id = table.lookup("fetchData", 0).unwrap();
let symbol = table.get_symbol(symbol_id).unwrap();
assert!(symbol.flags.contains(SymbolFlags::ASYNC));
}
#[test]
fn test_bind_class_declaration() {
let table = parse_and_bind("class MyClass { }");
let symbol = table.lookup("MyClass", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert!(symbol.flags.contains(SymbolFlags::CLASS));
}
#[test]
fn test_bind_interface_declaration() {
let table = parse_and_bind("interface User { name: string; }");
let symbol = table.lookup_type("User", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert!(symbol.flags.contains(SymbolFlags::INTERFACE));
}
#[test]
fn test_bind_type_alias_declaration() {
let table = parse_and_bind("type StringOrNumber = string | number;");
let symbol = table.lookup_type("StringOrNumber", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert!(symbol.flags.contains(SymbolFlags::TYPE_ALIAS));
}
#[test]
fn test_bind_enum_declaration() {
let table = parse_and_bind("enum Color { Red, Green, Blue }");
let symbol = table.lookup("Color", 0);
assert!(symbol.is_some());
let symbol = table.get_symbol(symbol.unwrap()).unwrap();
assert!(symbol.flags.contains(SymbolFlags::ENUM));
}
#[test]
fn test_bind_function_parameters() {
let code = "function add(a: number, b: number) { return a + b; }";
let table = parse_and_bind(code);
let scopes: Vec<_> = table.all_scopes().collect();
let function_scope = scopes.iter().find(|s| s.kind == ScopeKind::Function);
assert!(function_scope.is_some());
let function_scope_id = function_scope.unwrap().id;
let a = table.lookup("a", function_scope_id);
let b = table.lookup("b", function_scope_id);
assert!(a.is_some());
assert!(b.is_some());
let a_symbol = table.get_symbol(a.unwrap()).unwrap();
assert!(a_symbol.flags.contains(SymbolFlags::PARAMETER));
}
#[test]
fn test_bind_arrow_function() {
let code = "const fn = (x: number) => x * 2;";
let table = parse_and_bind(code);
let symbol = table.lookup("fn", 0);
assert!(symbol.is_some());
let scopes: Vec<_> = table.all_scopes().collect();
let function_scope = scopes.iter().find(|s| s.kind == ScopeKind::Function);
assert!(function_scope.is_some());
}
#[test]
fn test_bind_class_method() {
let code = r#"
class Calculator {
add(a: number, b: number) {
return a + b;
}
}
"#;
let table = parse_and_bind(code);
let calc = table.lookup("Calculator", 0);
assert!(calc.is_some());
let symbols: Vec<_> = table.all_symbols().collect();
let method = symbols.iter().find(|s| s.name == "add");
assert!(method.is_some());
assert!(method.unwrap().flags.contains(SymbolFlags::METHOD));
}
#[test]
fn test_bind_import_statement() {
let code = r#"import { foo, bar as baz } from 'module';"#;
let table = parse_and_bind(code);
let foo = table.lookup("foo", 0);
assert!(foo.is_some());
let baz = table.lookup("baz", 0);
assert!(baz.is_some());
let foo_symbol = table.get_symbol(foo.unwrap()).unwrap();
assert!(foo_symbol.flags.contains(SymbolFlags::IMPORT));
}
#[test]
fn test_bind_default_import() {
let code = r#"import React from 'react';"#;
let table = parse_and_bind(code);
let react = table.lookup("React", 0);
assert!(react.is_some());
}
#[test]
fn test_bind_namespace_import() {
let code = r#"import * as utils from './utils';"#;
let table = parse_and_bind(code);
let symbols: Vec<_> = table.all_symbols().collect();
let has_utils = table.lookup("utils", 0).is_some();
let has_import_symbol = symbols
.iter()
.any(|s| s.flags.contains(SymbolFlags::IMPORT));
assert!(has_utils || has_import_symbol || symbols.is_empty());
}
#[test]
fn test_bind_destructuring() {
let code = "const { a, b } = obj;";
let table = parse_and_bind(code);
let a = table.lookup("a", 0);
let b = table.lookup("b", 0);
assert!(a.is_some());
assert!(b.is_some());
}
#[test]
fn test_bind_array_destructuring() {
let code = "const [x, y] = arr;";
let table = parse_and_bind(code);
let x = table.lookup("x", 0);
let y = table.lookup("y", 0);
assert!(x.is_some());
assert!(y.is_some());
}
#[test]
fn test_bind_references() {
let code = "const x = 1;\nconst y = x + 2;";
let table = parse_and_bind(code);
let x_id = table.lookup("x", 0).unwrap();
let x_symbol = table.get_symbol(x_id).unwrap();
assert!(!x_symbol.references.is_empty());
}
#[test]
fn test_bind_nested_scopes() {
let code = r#"
function outer() {
const x = 1;
function inner() {
const y = x;
}
}
"#;
let table = parse_and_bind(code);
let scopes: Vec<_> = table.all_scopes().collect();
assert!(scopes.len() >= 3); }
#[test]
fn test_bind_block_scope() {
let code = r#"
if (true) {
const x = 1;
}
"#;
let table = parse_and_bind(code);
let global_x = table.lookup("x", 0);
let scopes: Vec<_> = table.all_scopes().collect();
assert!(scopes.iter().any(|s| s.kind == ScopeKind::Block));
let _ = global_x;
}
#[test]
fn test_bind_catch_clause() {
let code = r#"
try {
throw new Error();
} catch (e) {
console.log(e);
}
"#;
let table = parse_and_bind(code);
let scopes: Vec<_> = table.all_scopes().collect();
assert!(scopes.iter().any(|s| s.kind == ScopeKind::Catch));
}
#[test]
fn test_bind_exported_function() {
let code = "export function hello() { }";
let table = parse_and_bind(code);
let symbol_id = table.lookup("hello", 0).unwrap();
let symbol = table.get_symbol(symbol_id).unwrap();
assert!(symbol.flags.contains(SymbolFlags::EXPORTED));
}
#[test]
fn test_bind_static_method() {
let code = r#"
class MyClass {
static myMethod() { }
}
"#;
let table = parse_and_bind(code);
let symbols: Vec<_> = table.all_symbols().collect();
let method = symbols.iter().find(|s| s.name == "myMethod");
assert!(method.is_some());
assert!(method.unwrap().flags.contains(SymbolFlags::STATIC));
}
#[test]
fn test_bind_multiple_declarations() {
let code = "const a = 1, b = 2, c = 3;";
let table = parse_and_bind(code);
assert!(table.lookup("a", 0).is_some());
assert!(table.lookup("b", 0).is_some());
assert!(table.lookup("c", 0).is_some());
}
#[test]
fn test_bind_empty_code() {
let table = parse_and_bind("");
assert_eq!(table.root_scope_id(), 0);
}
#[test]
fn test_bind_only_comments() {
let table = parse_and_bind("// This is a comment");
assert_eq!(table.root_scope_id(), 0);
}
}