use std::collections::HashMap;
use tree_sitter::{Node, Parser};
use crate::error::{CodegraphError, Result};
use crate::graph::types::{
Binding, BindingKind, BindingTarget, ByteSpan, FileFacts, RefRole, Reference, Scope, ScopeId,
ScopeKind, Symbol, SymbolKind, Visibility,
};
use crate::lang::Language;
use crate::symbol::Descriptor;
use super::{
ExtractCtx, Extractor, attach_reference_scopes, definition_bindings, innermost_scope,
make_symbol, node_span, push_scope,
};
/// Stateless extractor for SQL sources; its behavior lives in the
/// `Extractor` implementation in this file.
pub struct SqlExtractor;
impl Extractor for SqlExtractor {
    /// Reports [`Language::Sql`] as the language handled by this extractor.
    fn lang(&self) -> Language {
        Language::Sql
    }

    /// Parses `source` with the SQL grammar and assembles the [`FileFacts`] for `file`.
    ///
    /// The returned facts contain, in this order:
    /// * symbols: the DDL symbols from `collect_symbols`, then the CTE symbols,
    ///   then the module symbol for the file;
    /// * references: everything found by `collect_references`, passed through
    ///   `attach_reference_scopes` together with the collected scopes;
    /// * bindings: the CTE bindings followed by the definition bindings derived
    ///   from the DDL symbols only (CTE and module symbols are not included
    ///   in that derivation).
    ///
    /// `ffi_exports` is always empty. The syntax tree is not checked for error
    /// nodes here, so input that the parser can still turn into a tree yields
    /// `Ok`.
    ///
    /// # Errors
    ///
    /// Returns [`CodegraphError::Parse`] carrying `file` when the grammar cannot
    /// be installed on the parser or when the parser produces no tree.
    fn extract(&self, source: &str, file: &str) -> Result<FileFacts> {
        let parse_error = || CodegraphError::Parse {
            path: file.to_owned(),
        };

        let ts_language = crate::grammar::sql();
        let mut parser = Parser::new();
        parser
            .set_language(&ts_language)
            .map_err(|_| parse_error())?;
        let tree = parser.parse(source, None).ok_or_else(parse_error)?;
        let root = tree.root_node();

        let ctx = ExtractCtx {
            bytes: source.as_bytes(),
            file,
            lang: Language::Sql,
        };

        let mut symbols = collect_symbols(&root, &ctx);
        // Definition bindings are derived before CTE/module symbols are merged in,
        // so they cover the DDL symbols only.
        let def_bindings = definition_bindings(&symbols);
        let cte_symbols = collect_cte_symbols(&root, &ctx);

        let mut references = collect_references(&root, ctx.bytes, ctx.file);
        let scopes = collect_scopes(&root, source.len());
        attach_reference_scopes(&mut references, &scopes);

        let mut bindings = collect_cte_bindings(&root, ctx.bytes, &scopes, &cte_symbols);
        bindings.extend(def_bindings);

        // The CTE symbols are merged only now, after the bindings have been
        // built from them, so the vector can be moved instead of cloned.
        symbols.extend(cte_symbols);
        symbols.push(super::module_symbol(Language::Sql, &[], file, source.len()));

        Ok(FileFacts {
            file: file.to_owned(),
            lang: Language::Sql.as_str().to_owned(),
            symbols,
            references,
            scopes,
            bindings,
            ffi_exports: Vec::new(),
        })
    }
}
/// Runs `collect_symbols_recursive` over the tree rooted at `root` and returns
/// every symbol it produced, in visiting order.
fn collect_symbols(root: &Node, ctx: &ExtractCtx) -> Vec<Symbol> {
    let mut symbols: Vec<Symbol> = Vec::new();
    collect_symbols_recursive(root, ctx, &mut symbols);
    symbols
}
/// Walks the subtree at `node` looking for DDL statements.
///
/// `create_table` nodes are handed to `extract_table`, and `create_view` /
/// `create_materialized_view` nodes to `extract_view`; those nodes are not
/// descended into any further. For every other node the children are visited
/// in order.
fn collect_symbols_recursive(node: &Node, ctx: &ExtractCtx, out: &mut Vec<Symbol>) {
    let kind = node.kind();
    if kind == "create_table" {
        extract_table(node, ctx, out);
        return;
    }
    if matches!(kind, "create_view" | "create_materialized_view") {
        extract_view(node, ctx, out);
        return;
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        collect_symbols_recursive(&child, ctx, out);
    }
}
/// Reads the object name and optional schema qualifier declared by `node`.
///
/// The first `object_reference` child supplies both values: its `name` field
/// is required, its `schema` field is optional. Each text is passed through
/// `super::unquote` before being copied into an owned `String`.
///
/// Returns `None` when `node` has no `object_reference` child or that child
/// has no `name` field.
fn object_name_and_schema<'a>(
    node: &'a Node<'a>,
    bytes: &[u8],
) -> Option<(String, Option<String>)> {
    let reference = first_object_reference(node)?;
    let unquoted_text = |n: &Node| super::unquote(super::node_text(n, bytes)).to_owned();

    let name = unquoted_text(&reference.child_by_field_name("name")?);
    let schema = reference
        .child_by_field_name("schema")
        .map(|schema_node| unquoted_text(&schema_node));
    Some((name, schema))
}
/// Emits the symbols declared by a `create_table` node.
///
/// One `SymbolKind::Table` symbol is pushed for the table itself, followed by
/// one `SymbolKind::Column` symbol per `column_definition` found inside the
/// node's `column_definitions` child. Nothing is pushed when the table name
/// cannot be read; a table without a `column_definitions` child yields only
/// the table symbol. Column definitions lacking a `name` field are skipped.
fn extract_table(node: &Node, ctx: &ExtractCtx, out: &mut Vec<Symbol>) {
    let Some((table_name, schema)) = object_name_and_schema(node, ctx.bytes) else {
        return;
    };
    let schema = schema.as_deref();

    let table_descriptors = build_descriptors(schema, &table_name, None);
    let table_signature = super::one_line_signature(super::node_text(node, ctx.bytes), &['(']);
    out.push(make_symbol(
        ctx,
        node,
        table_name.clone(),
        SymbolKind::Table,
        Visibility::Public,
        table_descriptors,
        table_signature,
    ));

    let mut node_cursor = node.walk();
    let column_list = node
        .children(&mut node_cursor)
        .find(|child| child.kind() == "column_definitions");
    let Some(column_list) = column_list else {
        return;
    };

    let mut column_cursor = column_list.walk();
    let column_nodes = column_list
        .children(&mut column_cursor)
        .filter(|child| child.kind() == "column_definition");
    for column in column_nodes {
        let Some(name_node) = column.child_by_field_name("name") else {
            continue;
        };
        let column_name = super::unquote(super::node_text(&name_node, ctx.bytes)).to_owned();
        // Descriptors borrow the column name, so build them before it is moved.
        let descriptors = build_descriptors(schema, &table_name, Some(&column_name));
        let signature = super::one_line_signature(super::node_text(&column, ctx.bytes), &['(']);
        out.push(make_symbol(
            ctx,
            &column,
            column_name,
            SymbolKind::Column,
            Visibility::Public,
            descriptors,
            signature,
        ));
    }
}
/// Emits a single `SymbolKind::View` symbol for a view-creating node.
///
/// The name and optional schema come from `object_name_and_schema`; when that
/// yields nothing, no symbol is pushed. No column symbols are produced here.
fn extract_view(node: &Node, ctx: &ExtractCtx, out: &mut Vec<Symbol>) {
    if let Some((name, schema)) = object_name_and_schema(node, ctx.bytes) {
        let descriptors = build_descriptors(schema.as_deref(), &name, None);
        let signature = super::one_line_signature(super::node_text(node, ctx.bytes), &['(']);
        let symbol = make_symbol(
            ctx,
            node,
            name,
            SymbolKind::View,
            Visibility::Public,
            descriptors,
            signature,
        );
        out.push(symbol);
    }
}
/// Returns the first direct child of `node` whose kind is `object_reference`,
/// or `None` when there is no such child. Only direct children are inspected.
fn first_object_reference<'a>(node: &'a Node<'a>) -> Option<Node<'a>> {
    let mut cursor = node.walk();
    let found = node
        .children(&mut cursor)
        .find(|child| child.kind() == "object_reference");
    found
}
/// Builds the descriptor path for a table-like object or one of its columns.
///
/// The result is, in order: a `Namespace` for `schema` (when given), a `Type`
/// for `table` (always), and a `Term` for `column` (when given).
fn build_descriptors(schema: Option<&str>, table: &str, column: Option<&str>) -> Vec<Descriptor> {
    let namespace = schema.map(|s| Descriptor::Namespace(s.to_owned()));
    let table_type = std::iter::once(Descriptor::Type(table.to_owned()));
    let term = column.map(|c| Descriptor::Term(c.to_owned()));

    namespace.into_iter().chain(table_type).chain(term).collect()
}
/// Runs `collect_references_recursive` over the tree rooted at `root` and
/// returns the references it produced, in visiting order.
fn collect_references(root: &Node, bytes: &[u8], file: &str) -> Vec<Reference> {
    let mut references: Vec<Reference> = Vec::new();
    collect_references_recursive(root, bytes, file, &mut references);
    references
}
/// Visits `node` and all of its descendants, pushing a `RefRole::TypeRef`
/// reference for every qualifying `object_reference` node (see
/// `object_reference_as_typeref`). Children are always visited, including
/// those of an `object_reference` node itself.
fn collect_references_recursive(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    if let Some(reference) = object_reference_as_typeref(node, bytes, file) {
        out.push(reference);
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        collect_references_recursive(&child, bytes, file, out);
    }
}

/// Turns a single `object_reference` node into a `TypeRef` reference.
///
/// Returns `None` when `node` is not an `object_reference`, when its direct
/// parent is a `create_table` / `create_view` / `create_materialized_view`
/// node (the reference is then the name being defined, not a use), when it has
/// no `name` field, or when the unquoted name is empty. The `schema` field,
/// when present, becomes the qualifier; `scope` is left as `None`.
fn object_reference_as_typeref(node: &Node, bytes: &[u8], file: &str) -> Option<Reference> {
    if node.kind() != "object_reference" {
        return None;
    }

    let names_a_definition = matches!(
        node.parent().map(|parent| parent.kind()),
        Some("create_table" | "create_view" | "create_materialized_view")
    );
    if names_a_definition {
        return None;
    }

    let name_node = node.child_by_field_name("name")?;
    let name = super::unquote(super::node_text(&name_node, bytes)).to_owned();
    if name.is_empty() {
        return None;
    }

    let qualifier = node
        .child_by_field_name("schema")
        .map(|schema_node| super::unquote(super::node_text(&schema_node, bytes)).to_owned());

    Some(Reference {
        name,
        occ: super::node_occurrence(node, file),
        role: RefRole::TypeRef,
        source_module: None,
        from_path: None,
        qualifier,
        scope: None,
        type_ref_ctx: None,
    })
}
/// Builds the scope list for a file.
///
/// The first scope pushed is a parentless `ScopeKind::Module` scope spanning
/// bytes `0..source_len`. Each direct child of `root` is then walked with
/// `scope_dfs`, using scope id `0` as the parent.
fn collect_scopes(root: &Node, source_len: usize) -> Vec<Scope> {
    let mut scopes = Vec::new();

    let whole_file = ByteSpan {
        start: 0,
        end: source_len,
    };
    push_scope(&mut scopes, None, whole_file, ScopeKind::Module);

    let mut cursor = root.walk();
    root.children(&mut cursor)
        .for_each(|child| scope_dfs(&child, 0, &mut scopes));

    scopes
}
/// Depth-first walk that opens a `ScopeKind::Other` scope for every
/// `statement` and `subquery` node.
///
/// A new scope is pushed with `parent` as its parent and the node's span as
/// its extent, and becomes the parent for that node's descendants. Any other
/// node passes `parent` through to its children unchanged.
fn scope_dfs(node: &Node, parent: ScopeId, scopes: &mut Vec<Scope>) {
    let enclosing = if matches!(node.kind(), "statement" | "subquery") {
        push_scope(scopes, Some(parent), node_span(node), ScopeKind::Other)
    } else {
        parent
    };

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        scope_dfs(&child, enclosing, scopes);
    }
}
/// Returns the first direct child of `cte_node` whose kind is `identifier`,
/// which callers in this file treat as the CTE's name; `None` when there is
/// no such child.
#[inline]
fn cte_identifier_node<'a>(cte_node: &Node<'a>) -> Option<Node<'a>> {
    let mut cursor = cte_node.walk();
    for child in cte_node.children(&mut cursor) {
        if child.kind() == "identifier" {
            return Some(child);
        }
    }
    None
}
/// Runs `collect_cte_symbols_dfs` over the tree rooted at `root` and returns
/// the symbols it produced, in visiting order.
fn collect_cte_symbols(root: &Node, ctx: &ExtractCtx) -> Vec<Symbol> {
    let mut cte_symbols: Vec<Symbol> = Vec::new();
    collect_cte_symbols_dfs(root, ctx, &mut cte_symbols);
    cte_symbols
}
/// Depth-first walk that pushes one `SymbolKind::Other` symbol per `cte` node.
///
/// The symbol is named after the CTE's identifier (unquoted) and carries a
/// single `Descriptor::Type` descriptor and an empty signature. CTEs without
/// an identifier child, or whose unquoted name is empty, produce no symbol.
///
/// The walk continues into the body of each CTE so that a `WITH` clause nested
/// inside a CTE body also gets its symbols; an outer CTE is pushed before any
/// CTE nested within it.
fn collect_cte_symbols_dfs(node: &Node, ctx: &ExtractCtx, out: &mut Vec<Symbol>) {
    if node.kind() == "cte" {
        if let Some(name_node) = cte_identifier_node(node) {
            let name = super::unquote(super::node_text(&name_node, ctx.bytes)).to_owned();
            if !name.is_empty() {
                let descriptors = vec![Descriptor::Type(name.clone())];
                out.push(make_symbol(
                    ctx,
                    node,
                    name,
                    SymbolKind::Other,
                    Visibility::Public,
                    descriptors,
                    String::new(),
                ));
            }
        }
        // Deliberately no early return: the CTE body may itself contain CTEs.
    }
    for child in node.children(&mut node.walk()) {
        collect_cte_symbols_dfs(&child, ctx, out);
    }
}

/// Produces a `BindingKind::Definition` binding for every `cte` node under
/// `root` whose name matches one of `cte_symbols`.
///
/// Symbols are looked up by name only; when several symbols share a name, the
/// last one in `cte_symbols` is the one every binding of that name targets.
fn collect_cte_bindings(
    root: &Node,
    bytes: &[u8],
    scopes: &[Scope],
    cte_symbols: &[Symbol],
) -> Vec<Binding> {
    let by_name: HashMap<&str, &Symbol> =
        cte_symbols.iter().map(|s| (s.name.as_str(), s)).collect();
    let mut out = Vec::new();
    collect_cte_bindings_dfs(root, bytes, scopes, &by_name, &mut out);
    out
}

/// Depth-first walk behind `collect_cte_bindings`.
///
/// For each `cte` node with an identifier child whose unquoted text is a key
/// of `by_name`, a binding is pushed that targets the mapped symbol's id. The
/// binding's `intro` is the identifier's start byte, and its scope is the
/// result of `innermost_scope` for that byte, falling back to scope `0` when
/// that returns `None`.
///
/// Like the symbol walk, this continues into each CTE body so that nested
/// `WITH` clauses receive bindings too.
fn collect_cte_bindings_dfs<'a>(
    node: &Node,
    bytes: &[u8],
    scopes: &[Scope],
    by_name: &HashMap<&'a str, &'a Symbol>,
    out: &mut Vec<Binding>,
) {
    if node.kind() == "cte" {
        if let Some(name_node) = cte_identifier_node(node) {
            let name = super::unquote(super::node_text(&name_node, bytes));
            if let Some(sym) = by_name.get(name) {
                let intro = name_node.start_byte();
                let scope = innermost_scope(intro, scopes).unwrap_or(0);
                out.push(Binding {
                    scope,
                    name: name.to_owned(),
                    intro,
                    kind: BindingKind::Definition,
                    target: BindingTarget::Def(sym.id.clone()),
                });
            }
        }
        // Deliberately no early return: the CTE body may itself contain CTEs.
    }
    for child in node.children(&mut node.walk()) {
        collect_cte_bindings_dfs(&child, bytes, scopes, by_name, out);
    }
}
// Unit tests for `SqlExtractor`: DDL symbols, references, scopes and CTE bindings.
#[cfg(test)]
mod tests {
use super::*;
use crate::extract::extract_path;
/// Renders a symbol's id as its SCIP string.
fn scip(sym: &Symbol) -> String {
sym.id.to_scip_string()
}
/// Returns the first symbol whose `name` equals `name`, if any.
fn find_by_name<'a>(symbols: &'a [Symbol], name: &str) -> Option<&'a Symbol> {
symbols.iter().find(|s| s.name == name)
}
/// A simple CREATE TABLE yields a module symbol named after the file stem and
/// no TypeRef reference for the table being defined.
#[test]
fn sql_stub_parses_and_emits_module_symbol() {
let src = "CREATE TABLE users (id INT, name TEXT);";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
assert_eq!(facts.lang, "sql");
assert!(
!facts.symbols.is_empty(),
"expected at least the module symbol, got {:?}",
facts.symbols
);
let mod_sym = facts.symbols.iter().find(|s| s.name == "schema").unwrap();
assert!(
mod_sym.id.to_scip_string().contains("schema"),
"module symbol SCIP string should contain the file stem; got: {}",
mod_sym.id.to_scip_string()
);
assert!(
facts
.references
.iter()
.all(|r| r.role != RefRole::TypeRef || r.name != "users"),
"pure DDL should not emit a TypeRef reference for the table being created"
);
}
/// `extract_path` reports language "sql" for a `.sql` path.
#[test]
fn dispatch_routes_sql_extension() {
let src = "CREATE TABLE orders (id INT);";
let facts = extract_path("db/orders.sql", src).unwrap();
assert_eq!(facts.lang, "sql");
}
/// CREATE TABLE produces a Table symbol plus one Column symbol per column,
/// with SCIP strings ending in `table#` and `table#column.` respectively.
#[test]
fn create_table_emits_table_and_column_symbols() {
let src = "CREATE TABLE users (id INT, email TEXT);";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let table = find_by_name(&facts.symbols, "users").expect("expected 'users' table symbol");
assert_eq!(table.kind, SymbolKind::Table);
assert!(
scip(table).ends_with("users#"),
"table SCIP should end with 'users#', got: {}",
scip(table)
);
let id_col = find_by_name(&facts.symbols, "id").expect("expected 'id' column symbol");
assert_eq!(id_col.kind, SymbolKind::Column);
assert!(
scip(id_col).ends_with("users#id."),
"id column SCIP should end with 'users#id.', got: {}",
scip(id_col)
);
let email_col =
find_by_name(&facts.symbols, "email").expect("expected 'email' column symbol");
assert_eq!(email_col.kind, SymbolKind::Column);
assert!(
scip(email_col).ends_with("users#email."),
"email column SCIP should end with 'users#email.', got: {}",
scip(email_col)
);
}
/// A schema qualifier shows up as a leading `schema/` segment in the SCIP
/// strings of both the table and its columns.
#[test]
fn schema_qualified_table_and_column() {
let src = "CREATE TABLE app.users (id INT);";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let table = find_by_name(&facts.symbols, "users").expect("expected 'users' symbol");
assert_eq!(table.kind, SymbolKind::Table);
assert!(
scip(table).ends_with("app/users#"),
"table SCIP should end with 'app/users#', got: {}",
scip(table)
);
let id_col = find_by_name(&facts.symbols, "id").expect("expected 'id' column symbol");
assert_eq!(id_col.kind, SymbolKind::Column);
assert!(
scip(id_col).ends_with("app/users#id."),
"id column SCIP should end with 'app/users#id.', got: {}",
scip(id_col)
);
}
/// CREATE VIEW produces a View symbol and no Column symbols.
#[test]
fn create_view_emits_view_symbol_no_columns() {
let src = "CREATE VIEW active_users AS SELECT * FROM users;";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let view =
find_by_name(&facts.symbols, "active_users").expect("expected 'active_users' symbol");
assert_eq!(view.kind, SymbolKind::View);
assert!(
scip(view).ends_with("active_users#"),
"view SCIP should end with 'active_users#', got: {}",
scip(view)
);
let col_count = facts
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Column)
.count();
assert_eq!(col_count, 0, "views should produce no column symbols");
}
/// Double quotes around a table name are stripped from the symbol name.
#[test]
fn double_quoted_table_name_strips_quotes() {
let src = r#"CREATE TABLE "my table" (id INT);"#;
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let table = find_by_name(&facts.symbols, "my table")
.expect("expected 'my table' symbol (unquoted)");
assert_eq!(table.kind, SymbolKind::Table);
assert!(
scip(table).contains("my table"),
"SCIP should contain the bare name 'my table', got: {}",
scip(table)
);
}
/// CREATE TABLE ... AS SELECT yields the table symbol but no column symbols.
#[test]
fn ctas_does_not_panic_and_emits_table_symbol_only() {
let src = "CREATE TABLE summary AS SELECT id, name FROM users;";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let table =
find_by_name(&facts.symbols, "summary").expect("expected 'summary' table symbol");
assert_eq!(table.kind, SymbolKind::Table);
let col_count = facts
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Column)
.count();
assert_eq!(col_count, 0, "CTAS table should produce no column symbols");
}
/// Empty input still yields the module symbol and no DDL symbols.
#[test]
fn empty_sql_does_not_panic_and_returns_module_symbol() {
let facts = SqlExtractor.extract("", "db/empty.sql").unwrap();
assert!(
facts.symbols.iter().any(|s| s.kind == SymbolKind::Module),
"empty SQL should still produce the module symbol"
);
assert!(
!facts.symbols.iter().any(|s| matches!(
s.kind,
SymbolKind::Table | SymbolKind::View | SymbolKind::Column
)),
"empty SQL should produce no DDL symbols"
);
}
/// Input that is not valid SQL still returns `Ok` with the module symbol.
#[test]
fn malformed_sql_does_not_panic() {
let facts = SqlExtractor
.extract("THIS IS NOT VALID SQL !!!", "db/bad.sql")
.unwrap();
assert!(
facts.symbols.iter().any(|s| s.kind == SymbolKind::Module),
"malformed SQL should still return Ok with the module symbol"
);
}
/// `FROM users` yields exactly one unqualified TypeRef reference to `users`.
#[test]
fn select_from_emits_typeref_reference() {
let src = "SELECT * FROM users;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
let refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::TypeRef && r.name == "users")
.collect();
assert_eq!(
refs.len(),
1,
"expected exactly one TypeRef ref named 'users', got: {:?}",
facts.references
);
assert_eq!(
refs[0].qualifier, None,
"unqualified ref should have no qualifier"
);
}
/// `FROM app.users` yields one TypeRef reference whose qualifier is `app`.
#[test]
fn select_from_schema_qualified_emits_qualifier() {
let src = "SELECT * FROM app.users;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
let refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::TypeRef && r.name == "users")
.collect();
assert_eq!(
refs.len(),
1,
"expected one TypeRef ref named 'users', got: {:?}",
facts.references
);
assert_eq!(
refs[0].qualifier,
Some("app".to_owned()),
"schema-qualified ref should carry qualifier 'app'"
);
}
/// A JOIN target is reported as a TypeRef reference; only "at least one" is
/// asserted, so additional references to the same name are tolerated.
#[test]
fn join_emits_typeref_reference() {
let src = "SELECT * FROM orders JOIN users ON orders.user_id = users.id;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
let users_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::TypeRef && r.name == "users")
.collect();
assert!(
!users_refs.is_empty(),
"expected at least one TypeRef ref for 'users' (JOIN target), got: {:?}",
facts.references
);
}
/// A column-level `REFERENCES users(id)` yields exactly one TypeRef to `users`.
#[test]
fn foreign_key_references_emits_typeref() {
let src = "CREATE TABLE orders (id INT, user_id INT REFERENCES users(id));";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let fk_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::TypeRef && r.name == "users")
.collect();
assert_eq!(
fk_refs.len(),
1,
"expected one TypeRef ref for FK REFERENCES 'users', got: {:?}",
facts.references
);
}
/// The name being defined by CREATE TABLE is not reported as a reference.
#[test]
fn pure_ddl_no_typeref_for_definition_name() {
let src = "CREATE TABLE users (id INT);";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let typeref_users: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::TypeRef && r.name == "users")
.collect();
assert!(
typeref_users.is_empty(),
"pure DDL CREATE TABLE should NOT emit a TypeRef ref for 'users' (it's the definition name), \
got: {:?}",
typeref_users
);
}
/// A statement with a CTE produces at least one `Other` scope, and every
/// `Other` scope has a parent.
#[test]
fn cte_statement_opens_other_scope() {
let src = "WITH r AS (SELECT 1) SELECT * FROM r;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
assert!(
facts.scopes.len() >= 2,
"expected at least two scopes (Module + statement), got: {:?}",
facts.scopes
);
let has_other = facts.scopes.iter().any(|s| s.kind == ScopeKind::Other);
assert!(has_other, "expected at least one ScopeKind::Other scope");
for scope in &facts.scopes {
if scope.kind == ScopeKind::Other {
assert!(
scope.parent.is_some(),
"Other scope should have a parent, got: {:?}",
scope
);
}
}
}
/// The CTE name is bound as a Definition inside an `Other` scope rather than
/// at the file-root scope.
#[test]
fn cte_name_gets_definition_binding_in_statement_scope() {
let src = "WITH revenue AS (SELECT amount FROM sales) SELECT * FROM revenue;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
let binding = facts
.bindings
.iter()
.find(|b| b.name == "revenue" && b.kind == BindingKind::Definition);
let binding = binding.expect("expected a Definition binding for 'revenue'");
assert_ne!(
binding.scope, 0,
"CTE binding should be in a statement scope, not the file root (scope 0)"
);
assert_eq!(
facts.scopes[binding.scope].kind,
ScopeKind::Other,
"CTE binding scope should be ScopeKind::Other, got: {:?}",
facts.scopes[binding.scope]
);
}
/// The `FROM revenue` reference to a CTE has its scope populated.
#[test]
fn cte_from_ref_has_scope_set() {
let src = "WITH revenue AS (SELECT amount FROM sales) SELECT * FROM revenue;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
let revenue_ref = facts
.references
.iter()
.find(|r| r.role == RefRole::TypeRef && r.name == "revenue");
let revenue_ref =
revenue_ref.expect("expected a TypeRef reference named 'revenue' (FROM revenue)");
assert!(
revenue_ref.scope.is_some(),
"FROM revenue reference should have scope set, got None"
);
}
/// A CTE definition produces a symbol of kind `Other` on line 1.
#[test]
fn cte_emits_symbol() {
let src = "WITH revenue AS (SELECT amount FROM sales) SELECT * FROM revenue;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
let sym = find_by_name(&facts.symbols, "revenue")
.expect("expected a Symbol named 'revenue' for the CTE definition");
assert_eq!(
sym.kind,
SymbolKind::Other,
"CTE symbol kind should be Other"
);
assert_eq!(sym.line, 1, "CTE symbol should be on line 1");
}
/// A plain SELECT still yields scopes, and its table reference has a scope.
#[test]
fn plain_select_scopes_and_ref_scope() {
let src = "SELECT * FROM users;";
let facts = SqlExtractor.extract(src, "db/query.sql").unwrap();
assert!(
!facts.scopes.is_empty(),
"plain SELECT should still produce scopes"
);
let users_ref = facts
.references
.iter()
.find(|r| r.role == RefRole::TypeRef && r.name == "users")
.expect("expected TypeRef ref for 'users'");
assert!(
users_ref.scope.is_some(),
"plain SELECT FROM ref should have scope set"
);
}
/// A DDL-defined table is bound as a Definition at the file-root scope (0).
#[test]
fn ddl_gets_definition_binding_at_scope_0() {
let src = "CREATE TABLE orders (id INT);";
let facts = SqlExtractor.extract(src, "db/schema.sql").unwrap();
let binding = facts
.bindings
.iter()
.find(|b| b.name == "orders" && b.kind == BindingKind::Definition);
let binding = binding.expect("expected a Definition binding for 'orders'");
assert_eq!(
binding.scope, 0,
"DDL Definition binding should be at file-root scope (0)"
);
}
}