use tree_sitter::{Node, Parser};
use crate::error::{CodegraphError, Result};
use crate::graph::types::{
Binding, BindingKind, ByteSpan, FileFacts, RefRole, Reference, Scope, ScopeId, ScopeKind,
Symbol, SymbolKind, Visibility,
};
use crate::lang::Language;
use crate::symbol::Descriptor;
use super::{
ExtractCtx, Extractor, MIN_REF_LEN, attach_reference_scopes, collect_call_references,
definition_bindings, field_text, innermost_scope, make_symbol, node_span, node_text,
one_line_signature, push_binding, push_ref, push_scope,
};
/// Tree-sitter query handed to `collect_call_references` for Ruby call sites.
///
/// Both alternatives match a `call` node whose `method` field is an
/// `identifier` and capture that identifier as `@callee`:
/// - the first (`!receiver`) matches calls without a receiver, e.g. `helper()`;
/// - the second matches calls with any receiver and also captures the receiver
///   node as `@qualifier` (e.g. `Alpha` in `Alpha.compute`).
const CALL_QUERY: &str = r#"
[
(call !receiver method: (identifier) @callee)
(call receiver: (_) @qualifier method: (identifier) @callee)
]
"#;
/// Tree-sitter backed [`Extractor`] for Ruby source files.
pub struct RubyExtractor;

impl Extractor for RubyExtractor {
    fn lang(&self) -> Language {
        Language::Ruby
    }

    /// Parses `source` (the contents of `file`) and collects its symbols,
    /// references, scopes and bindings.
    ///
    /// The definition bindings are derived from the walked definitions *before*
    /// the synthetic file-level module symbol is appended, so that symbol gets
    /// no definition binding of its own.
    ///
    /// # Errors
    /// Returns [`CodegraphError::Parse`] if the Ruby grammar cannot be installed
    /// in the parser or tree-sitter yields no tree. Errors from
    /// `collect_call_references` are propagated unchanged.
    fn extract(&self, source: &str, file: &str) -> Result<FileFacts> {
        let parse_error = || CodegraphError::Parse {
            path: file.to_owned(),
        };

        // --- parse ---------------------------------------------------------
        let grammar = crate::grammar::ruby();
        let mut parser = Parser::new();
        parser.set_language(&grammar).map_err(|_| parse_error())?;
        let tree = parser.parse(source, None).ok_or_else(parse_error)?;
        let root = tree.root_node();
        let bytes = source.as_bytes();

        // --- definitions ---------------------------------------------------
        let ctx = ExtractCtx {
            bytes,
            file,
            lang: Language::Ruby,
        };
        let ns_strings = ruby_namespaces(file);
        let namespaces: Vec<Descriptor> = ns_strings
            .iter()
            .map(|segment| Descriptor::Namespace(segment.clone()))
            .collect();
        let mut symbols = Vec::new();
        walk(&root, &namespaces, &ctx, &mut symbols);
        let def_bindings = definition_bindings(&symbols);
        symbols.push(super::module_symbol(
            Language::Ruby,
            &ns_strings,
            file,
            source.len(),
        ));

        // --- references ----------------------------------------------------
        let mut references = collect_call_references(
            &root,
            &grammar,
            CALL_QUERY,
            Language::Ruby,
            bytes,
            file,
        )?;
        collect_inheritance(&root, bytes, file, &mut references);
        collect_read_references(&root, bytes, file, &mut references);
        collect_write_references(&root, bytes, file, &mut references);

        // --- scopes and bindings -------------------------------------------
        let scopes = collect_scopes(&root, source.len());
        attach_reference_scopes(&mut references, &scopes);
        let mut bindings = collect_bindings(&root, bytes, &scopes);
        bindings.extend(def_bindings);

        Ok(FileFacts {
            file: file.to_owned(),
            lang: Language::Ruby.as_str().to_owned(),
            symbols,
            references,
            scopes,
            bindings,
            ffi_exports: Vec::new(),
        })
    }
}
/// Derives the namespace path segments for a Ruby file from its path.
///
/// The `.rb` extension is dropped. Then at most one leading source root is
/// removed, trying `lib/`, `app/` and `src/` in that order. The remainder is
/// split on `/`, and empty segments are skipped.
///
/// For example, `lib/auth/session.rb` yields `["auth", "session"]`.
fn ruby_namespaces(file: &str) -> Vec<String> {
    let stem = match file.strip_suffix(".rb") {
        Some(without_ext) => without_ext,
        None => file,
    };
    let relative = ["lib/", "app/", "src/"]
        .iter()
        .find_map(|root| stem.strip_prefix(*root))
        .unwrap_or(stem);
    relative
        .split('/')
        .filter(|segment| !segment.is_empty())
        .map(String::from)
        .collect()
}
/// Emits definition symbols found among the direct children of `node`,
/// recursing into type bodies.
///
/// `prefix` is the descriptor path of the enclosing namespace/type. Each emitted
/// symbol's descriptors are `prefix` plus one descriptor for the definition:
///
/// - `class` / `module` → [`Descriptor::Type`]. The symbol is pushed, then its
///   `body` is walked with the extended prefix, so nested definitions are
///   qualified by the type.
/// - `method` / `singleton_method` (`def x` / `def self.x`) →
///   [`Descriptor::Method`] with an empty disambiguator. Method bodies are not
///   walked.
/// - `singleton_class` whose `value` is `self` (`class << self ... end`) → no
///   symbol of its own. Its `body` is walked with the *unchanged* prefix, so the
///   `def`s inside are attributed to the enclosing type, like `def self.x`.
///   `class << other_object` is skipped.
/// - `assignment` whose `left` side is a `constant` → [`Descriptor::Term`] with
///   kind [`SymbolKind::Const`].
///
/// All other children are ignored and not descended into.
fn walk(node: &Node, prefix: &[Descriptor], ctx: &ExtractCtx, out: &mut Vec<Symbol>) {
    for child in node.children(&mut node.walk()) {
        match child.kind() {
            "class" | "module" => {
                let Some(name) = field_text(&child, "name", ctx.bytes) else {
                    continue;
                };
                let kind = if child.kind() == "class" {
                    SymbolKind::Class
                } else {
                    SymbolKind::Module
                };
                let mut descriptors = prefix.to_vec();
                descriptors.push(Descriptor::Type(name.clone()));
                // Push the container before walking its body so it precedes its
                // members in `out`.
                push_symbol(out, ctx, &child, name, kind, descriptors.clone());
                if let Some(body) = child.child_by_field_name("body") {
                    walk(&body, &descriptors, ctx, out);
                }
            }
            "singleton_class" => {
                // Only `class << self` reopens the enclosing type's metaclass;
                // methods in `class << obj` belong to some other object.
                let reopens_self = child
                    .child_by_field_name("value")
                    .is_some_and(|value| value.kind() == "self");
                if !reopens_self {
                    continue;
                }
                if let Some(body) = child.child_by_field_name("body") {
                    walk(&body, prefix, ctx, out);
                }
            }
            "method" | "singleton_method" => {
                let Some(name) = field_text(&child, "name", ctx.bytes) else {
                    continue;
                };
                let mut descriptors = prefix.to_vec();
                descriptors.push(Descriptor::Method {
                    name: name.clone(),
                    disambiguator: String::new(),
                });
                push_symbol(out, ctx, &child, name, SymbolKind::Method, descriptors);
            }
            "assignment" => {
                let constant = child
                    .child_by_field_name("left")
                    .filter(|left| left.kind() == "constant");
                if let Some(left) = constant {
                    let name = node_text(&left, ctx.bytes).to_owned();
                    let mut descriptors = prefix.to_vec();
                    descriptors.push(Descriptor::Term(name.clone()));
                    push_symbol(out, ctx, &child, name, SymbolKind::Const, descriptors);
                }
            }
            _ => {}
        }
    }
}
/// Builds a [`Symbol`] for the definition `node` and appends it to `out`.
///
/// The signature is `node`'s source text condensed by `one_line_signature`,
/// called with an empty second argument. Visibility is always recorded as
/// [`Visibility::Unknown`]; `private` / `protected` sections are not tracked
/// here.
fn push_symbol(
    out: &mut Vec<Symbol>,
    ctx: &ExtractCtx,
    node: &Node,
    name: String,
    kind: SymbolKind,
    descriptors: Vec<Descriptor>,
) {
    let source_text = node_text(node, ctx.bytes);
    let symbol = make_symbol(
        ctx,
        node,
        name,
        kind,
        Visibility::Unknown,
        descriptors,
        one_line_signature(source_text, &[]),
    );
    out.push(symbol);
}
fn collect_inheritance(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
if node.kind() == "class" {
if let Some(superclass_node) = node.child_by_field_name("superclass") {
if let Some(type_node) = superclass_node
.children(&mut superclass_node.walk())
.find(|c| c.is_named())
{
super::push_ref(
out,
super::simple_type_name(node_text(&type_node, bytes), "::"),
&type_node,
file,
RefRole::IsImplementation,
);
}
}
}
for child in node.children(&mut node.walk()) {
collect_inheritance(&child, bytes, file, out);
}
}
/// Returns `true` when the `identifier` `node` sits where a name is declared
/// or used as a label, rather than where a local value is read.
///
/// Non-read positions:
/// - the `method` field of a `call` (`obj.helper`, `helper(1)`); call sites
///   are recorded separately as call references;
/// - the `name` of a `method` / `singleton_method` definition;
/// - any identifier directly inside a parameter list: `method_parameters`,
///   `block_parameters`, `lambda_parameters` (`->(x) { }`), and destructured
///   `(a, b)` parameters (`destructured_parameter`);
/// - the `name` of optional / keyword / splat / block / hash-splat parameters;
/// - the `left` side of a plain `assignment` (`x = 1`), which is a write;
/// - the `pattern` of a `for` loop when it is a single identifier (`for x in xs`);
/// - the variable bound by `rescue => e` (`exception_variable`).
///
/// A node without a parent is treated as a non-read. The left side of a
/// compound assignment (`x += 1`) is intentionally still a read, because the
/// old value is consumed.
fn is_non_read_position(node: &Node) -> bool {
    let Some(parent) = node.parent() else {
        return true;
    };
    let is_field = |field: &str| parent.child_by_field_name(field).as_ref() == Some(node);
    match parent.kind() {
        "call" => is_field("method"),
        "method" | "singleton_method" => is_field("name"),
        "method_parameters"
        | "block_parameters"
        | "lambda_parameters"
        | "destructured_parameter"
        | "exception_variable" => true,
        "optional_parameter"
        | "keyword_parameter"
        | "splat_parameter"
        | "block_parameter"
        | "hash_splat_parameter" => is_field("name"),
        "assignment" => is_field("left"),
        "for" => is_field("pattern"),
        _ => false,
    }
}
/// Walks the subtree at `node` in pre-order and records a [`RefRole::Read`]
/// reference for every `identifier` that is at least `MIN_REF_LEN` bytes long
/// and not in a non-read position (see `is_non_read_position`).
///
/// The walk never descends below an `identifier` node.
fn collect_read_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    if node.kind() != "identifier" {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            collect_read_references(&child, bytes, file, out);
        }
        return;
    }

    let name = node_text(node, bytes);
    let long_enough = name.len() >= MIN_REF_LEN;
    if long_enough && !is_non_read_position(node) {
        push_ref(out, name, node, file, RefRole::Read);
    }
}
/// Walks the subtree at `node` in pre-order and records a [`RefRole::Write`]
/// reference for every local-variable assignment target.
///
/// Both plain assignments (`x = 1`) and compound assignments (`x += 1`,
/// `x ||= default`; tree-sitter kind `operator_assignment`) count as writes.
/// Only a `left` side that is a bare `identifier` of at least `MIN_REF_LEN`
/// bytes is recorded. Other targets, such as instance variables, constants,
/// attribute setters (`obj.x = 1`) or index targets, are skipped.
fn collect_write_references(node: &Node, bytes: &[u8], file: &str, out: &mut Vec<Reference>) {
    if matches!(node.kind(), "assignment" | "operator_assignment") {
        let target = node
            .child_by_field_name("left")
            .filter(|lhs| lhs.kind() == "identifier");
        if let Some(lhs) = target {
            let name = node_text(&lhs, bytes);
            if name.len() >= MIN_REF_LEN {
                push_ref(out, name, &lhs, file, RefRole::Write);
            }
        }
    }
    for child in node.children(&mut node.walk()) {
        collect_write_references(&child, bytes, file, out);
    }
}
fn collect_scopes(root: &Node, source_len: usize) -> Vec<Scope> {
let mut scopes = Vec::new();
push_scope(
&mut scopes,
None,
ByteSpan {
start: 0,
end: source_len,
},
ScopeKind::Module,
);
for child in root.children(&mut root.walk()) {
scope_dfs(&child, 0, &mut scopes);
}
scopes
}
/// Pushes a scope for every scope-introducing node in the subtree at `node`,
/// parented to `parent_id`, and recurses beneath it.
///
/// - `class` / `module` / `singleton_class` → [`ScopeKind::Type`]; only the
///   children of the `body` field are visited.
/// - `method` / `singleton_method` → [`ScopeKind::Function`]. A
///   `body_statement` body is visited child by child. Any other body node
///   (presumably the expression of an endless `def f = expr`) is visited
///   directly.
/// - `block` / `do_block` → [`ScopeKind::Function`]; only the children of the
///   `body` field are visited.
/// - Any other node introduces no scope; all of its children are visited with
///   the same `parent_id`.
fn scope_dfs(node: &Node, parent_id: ScopeId, scopes: &mut Vec<Scope>) {
    let node_kind = node.kind();
    let introduced = match node_kind {
        "class" | "module" | "singleton_class" => Some(ScopeKind::Type),
        "method" | "singleton_method" | "do_block" | "block" => Some(ScopeKind::Function),
        _ => None,
    };

    let Some(scope_kind) = introduced else {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            scope_dfs(&child, parent_id, scopes);
        }
        return;
    };

    let scope_id = push_scope(scopes, Some(parent_id), node_span(node), scope_kind);
    let Some(body) = node.child_by_field_name("body") else {
        return;
    };

    let is_method = matches!(node_kind, "method" | "singleton_method");
    if is_method && body.kind() != "body_statement" {
        scope_dfs(&body, scope_id, scopes);
    } else {
        let mut cursor = body.walk();
        for child in body.children(&mut cursor) {
            scope_dfs(&child, scope_id, scopes);
        }
    }
}
/// Collects parameter and local-variable bindings for the whole tree at `root`.
///
/// `scopes` is expected to be the scope table built for the same tree (by
/// `collect_scopes`). It is passed through to the binding helpers.
fn collect_bindings(root: &Node, bytes: &[u8], scopes: &[Scope]) -> Vec<Binding> {
    let mut bindings = Vec::new();
    collect_bindings_dfs(root, bytes, scopes, &mut bindings);
    bindings
}
/// Records bindings found in the subtree at `node`, visiting nodes in pre-order.
///
/// - `method` / `singleton_method` / `block` / `do_block`: the names in the
///   `parameters` field become [`BindingKind::Param`] bindings
///   (see `collect_params`).
/// - `assignment` with a bare `identifier` on the left becomes a
///   [`BindingKind::Local`] binding, but only when the innermost scope at the
///   identifier is a `Function` or `Block` scope. Assignments at file level or
///   directly in a class body therefore produce no local here.
///
/// The children of every node are visited, including those of the nodes above.
fn collect_bindings_dfs(node: &Node, bytes: &[u8], scopes: &[Scope], out: &mut Vec<Binding>) {
    match node.kind() {
        "method" | "singleton_method" | "block" | "do_block" => {
            if let Some(params) = node.child_by_field_name("parameters") {
                collect_params(&params, bytes, scopes, out);
            }
        }
        "assignment" => {
            if let Some(left) = node.child_by_field_name("left") {
                if left.kind() == "identifier" {
                    let name = node_text(&left, bytes).to_owned();
                    let intro = left.start_byte();
                    if let Some(sid) = innermost_scope(intro, scopes) {
                        if matches!(scopes[sid].kind, ScopeKind::Function | ScopeKind::Block) {
                            push_binding(out, name, intro, BindingKind::Local, scopes);
                        }
                    }
                }
            }
        }
        _ => {}
    }
    for child in node.children(&mut node.walk()) {
        collect_bindings_dfs(&child, bytes, scopes, out);
    }
}
/// Adds a [`BindingKind::Param`] binding for every name declared directly in
/// the parameter-list node `params` (e.g. `method_parameters`,
/// `block_parameters`).
///
/// - A bare `identifier` parameter binds itself.
/// - `optional_parameter`, `keyword_parameter`, `splat_parameter`,
///   `block_parameter` and `hash_splat_parameter` bind their `name` field. If
///   the node has no `name` field, nothing is bound.
/// - A `destructured_parameter` (`|(key, value)|`) is expanded recursively, so
///   every nested name is bound too.
///
/// Each binding is introduced at the start byte of the name node itself.
fn collect_params(params: &Node, bytes: &[u8], scopes: &[Scope], out: &mut Vec<Binding>) {
    for child in params.named_children(&mut params.walk()) {
        let name_node = match child.kind() {
            "identifier" => Some(child),
            "optional_parameter"
            | "keyword_parameter"
            | "splat_parameter"
            | "block_parameter"
            | "hash_splat_parameter" => child.child_by_field_name("name"),
            "destructured_parameter" => {
                collect_params(&child, bytes, scopes, out);
                None
            }
            _ => None,
        };
        if let Some(name_node) = name_node {
            let name = node_text(&name_node, bytes).to_owned();
            let intro = name_node.start_byte();
            push_binding(out, name, intro, BindingKind::Param, scopes);
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the Ruby extractor: each test runs `RubyExtractor::extract`
// over a small Ruby snippet and inspects the resulting `FileFacts`.
use super::*;
// Nested module/class/method/constant definitions get SCIP ids built from the
// file-derived namespace (`lib/` stripped, `.rb` dropped) plus the enclosing
// type path; top-level defs sit directly under the file namespace.
#[test]
fn extracts_nested_defs() {
let src = r#"
module Auth
class Session
MAX = 3
def validate(token)
check(token)
end
def self.create
end
end
end
TOP = 1
def helper
end
"#;
let facts = RubyExtractor.extract(src, "lib/auth/session.rb").unwrap();
let by_name = |n: &str| facts.symbols.iter().find(|s| s.name == n).cloned();
let auth = by_name("Auth").unwrap();
assert_eq!(auth.kind, SymbolKind::Module);
assert_eq!(
auth.id.to_scip_string(),
"codegraph . . . auth/session/Auth#"
);
let session = by_name("Session").unwrap();
assert_eq!(session.kind, SymbolKind::Class);
assert_eq!(
session.id.to_scip_string(),
"codegraph . . . auth/session/Auth#Session#"
);
let max = by_name("MAX").unwrap();
assert_eq!(max.kind, SymbolKind::Const);
assert_eq!(
max.id.to_scip_string(),
"codegraph . . . auth/session/Auth#Session#MAX."
);
let validate = by_name("validate").unwrap();
assert_eq!(validate.kind, SymbolKind::Method);
assert_eq!(
validate.id.to_scip_string(),
"codegraph . . . auth/session/Auth#Session#validate()."
);
// `def self.create` (singleton_method) is attributed to the enclosing class.
let create = by_name("create").unwrap();
assert_eq!(create.kind, SymbolKind::Method);
assert_eq!(
create.id.to_scip_string(),
"codegraph . . . auth/session/Auth#Session#create()."
);
let top = by_name("TOP").unwrap();
assert_eq!(top.kind, SymbolKind::Const);
assert_eq!(top.id.to_scip_string(), "codegraph . . . auth/session/TOP.");
let helper = by_name("helper").unwrap();
assert_eq!(helper.kind, SymbolKind::Method);
assert_eq!(
helper.id.to_scip_string(),
"codegraph . . . auth/session/helper()."
);
assert_eq!(facts.lang, "ruby");
}
// Methods after a bare `private` call are still emitted as symbols.
#[test]
fn emits_methods_regardless_of_visibility() {
let src = r#"
class Svc
def open
end
private
def secret
end
end
"#;
let facts = RubyExtractor.extract(src, "lib/svc.rb").unwrap();
let by_name = |n: &str| facts.symbols.iter().find(|s| s.name == n).cloned();
let open_sym = by_name("open").unwrap();
assert_eq!(open_sym.kind, SymbolKind::Method);
let secret_sym = by_name("secret").unwrap();
assert_eq!(secret_sym.kind, SymbolKind::Method);
}
// Calls inside a method body produce references named after the callee.
#[test]
fn extracts_call_references() {
let src = r#"
def run
validate("t")
process(data)
end
"#;
let facts = RubyExtractor.extract(src, "lib/main.rb").unwrap();
let names: Vec<&str> = facts.references.iter().map(|r| r.name.as_str()).collect();
assert!(names.contains(&"validate"));
assert!(names.contains(&"process"));
}
// The `@qualifier` capture in CALL_QUERY becomes the reference's qualifier;
// receiver-less calls have none.
#[test]
fn qualified_call_captures_receiver_as_qualifier() {
let src = "def run\n Alpha.compute\n helper()\nend\n";
let facts = RubyExtractor.extract(src, "lib/main.rb").unwrap();
let compute = facts
.references
.iter()
.find(|r| r.role == RefRole::Call && r.name == "compute")
.expect("expected a Call ref for 'compute'");
assert_eq!(
compute.qualifier.as_deref(),
Some("Alpha"),
"`Alpha.compute` must be qualified by `Alpha`"
);
let helper = facts
.references
.iter()
.find(|r| r.role == RefRole::Call && r.name == "helper")
.expect("expected a Call ref for 'helper'");
assert_eq!(
helper.qualifier, None,
"a bare `helper()` call must have no qualifier"
);
}
// `class Foo < Bar` yields exactly one IsImplementation ref named `Bar`.
#[test]
fn extracts_simple_inheritance() {
let src = "class Foo < Bar\nend\n";
let facts = RubyExtractor.extract(src, "lib/foo.rb").unwrap();
let inherit: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::IsImplementation)
.map(|r| r.name.as_str())
.collect();
assert_eq!(inherit, vec!["Bar"], "expected [Bar], got {inherit:?}");
}
// A `::`-qualified superclass is reduced to its last segment.
#[test]
fn extracts_qualified_inheritance_simple_name() {
let src = "class Foo < A::Bar\nend\n";
let facts = RubyExtractor.extract(src, "lib/foo.rb").unwrap();
let inherit: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::IsImplementation)
.map(|r| r.name.as_str())
.collect();
assert_eq!(inherit, vec!["Bar"], "expected [Bar], got {inherit:?}");
}
// Modules have no superclass and must not emit inheritance refs.
#[test]
fn module_emits_no_inheritance_refs() {
let src = "module M\nend\n";
let facts = RubyExtractor.extract(src, "lib/m.rb").unwrap();
let inherit: Vec<&str> = facts
.references
.iter()
.filter(|r| r.role == RefRole::IsImplementation)
.map(|r| r.name.as_str())
.collect();
assert!(
inherit.is_empty(),
"expected no Inherit refs, got {inherit:?}"
);
}
// Plain method parameters become Param bindings in the method's Function scope.
#[test]
fn method_params_emit_param_bindings() {
let src = "def greet(name, age)\nend\n";
let facts = RubyExtractor.extract(src, "lib/greet.rb").unwrap();
let fn_scope_id = facts
.scopes
.iter()
.position(|s| s.kind == ScopeKind::Function)
.expect("expected a Function scope");
let mut param_names: Vec<(&str, ScopeId)> = facts
.bindings
.iter()
.filter(|b| b.kind == BindingKind::Param)
.map(|b| (b.name.as_str(), b.scope))
.collect();
param_names.sort_by_key(|(n, _)| *n);
assert_eq!(
param_names,
vec![("age", fn_scope_id), ("name", fn_scope_id)],
"expected Param bindings for name and age, got {param_names:?}"
);
}
// Keyword, splat and block parameters bind their `name` field.
#[test]
fn optional_keyword_splat_block_params() {
let src = "def f(a, b: 1, *c, &d)\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
let params: Vec<&str> = facts
.bindings
.iter()
.filter(|b| b.kind == BindingKind::Param)
.map(|b| b.name.as_str())
.collect();
assert!(params.contains(&"a"), "expected Param 'a', got {params:?}");
assert!(params.contains(&"b"), "expected Param 'b', got {params:?}");
assert!(params.contains(&"c"), "expected Param 'c', got {params:?}");
assert!(params.contains(&"d"), "expected Param 'd', got {params:?}");
}
// `x = 1` inside a method body is a Local binding in the Function scope.
#[test]
fn assignment_local_in_method() {
let src = "def f\n x = 1\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
let x = facts
.bindings
.iter()
.find(|b| b.kind == BindingKind::Local && b.name == "x")
.expect("expected a Local binding for 'x'");
assert_eq!(
facts.scopes[x.scope].kind,
ScopeKind::Function,
"Local 'x' should be in a Function scope"
);
}
// Curly-brace block parameters are Params inside the block's Function scope.
#[test]
fn block_params_emit_param_bindings() {
let src = "[1].each { |item| }\n";
let facts = RubyExtractor.extract(src, "lib/blk.rb").unwrap();
let _fn_scope = facts
.scopes
.iter()
.find(|s| s.kind == ScopeKind::Function)
.expect("expected a Function scope for the block");
let item = facts
.bindings
.iter()
.find(|b| b.kind == BindingKind::Param && b.name == "item")
.expect("expected a Param binding for 'item'");
assert_eq!(
facts.scopes[item.scope].kind,
ScopeKind::Function,
"block param 'item' should be in a Function scope"
);
}
// A class-level constant is a Definition binding (from its symbol), not a Local.
#[test]
fn class_level_constant_is_definition_not_local() {
let src = "class Foo\n BAR = 1\nend\n";
let facts = RubyExtractor.extract(src, "lib/foo.rb").unwrap();
assert!(
!facts
.bindings
.iter()
.any(|b| b.kind == BindingKind::Local && b.name == "BAR"),
"class-level constant 'BAR' must NOT be a Local binding"
);
assert!(
facts
.bindings
.iter()
.any(|b| b.kind == BindingKind::Definition && b.name == "BAR"),
"expected a Definition binding for 'BAR'"
);
}
// `@x = 1` has an instance_variable (not identifier) target, so it binds no local.
#[test]
fn ivar_assignment_is_not_local() {
let src = "def f\n @x = 1\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
assert!(
!facts.bindings.iter().any(|b| b.kind == BindingKind::Local),
"instance variable assignment must NOT produce a Local binding"
);
}
// Scope tree shape: Module (index 0) -> Type (class) -> Function (method).
#[test]
fn nesting_class_method_produces_correct_scopes_and_local() {
let src = "class S\n def h\n x = 1\n end\nend\n";
let facts = RubyExtractor.extract(src, "lib/s.rb").unwrap();
assert_eq!(
facts.scopes[0].kind,
ScopeKind::Module,
"scopes[0] must be Module"
);
let type_scope_id = facts
.scopes
.iter()
.position(|s| s.kind == ScopeKind::Type)
.expect("expected a Type scope for the class");
let fn_scope_id = facts
.scopes
.iter()
.position(|s| s.kind == ScopeKind::Function)
.expect("expected a Function scope for the method");
assert_eq!(
facts.scopes[type_scope_id].parent,
Some(0),
"Type scope parent must be Module (0)"
);
assert_eq!(
facts.scopes[fn_scope_id].parent,
Some(type_scope_id),
"Function scope parent must be the Type scope"
);
let x = facts
.bindings
.iter()
.find(|b| b.kind == BindingKind::Local && b.name == "x")
.expect("expected a Local binding for 'x'");
assert_eq!(
facts.scopes[x.scope].kind,
ScopeKind::Function,
"Local 'x' must be in a Function scope"
);
}
// A call inside `run` is attached to run's Function scope, not the module root.
#[test]
fn same_file_call_ref_has_non_zero_scope() {
let src = "def helper\n 0\nend\ndef run\n helper()\nend\n";
let facts = RubyExtractor.extract(src, "lib/main.rb").unwrap();
assert!(
facts
.bindings
.iter()
.any(|b| b.kind == BindingKind::Definition && b.name == "helper"),
"expected a Definition binding for 'helper'"
);
let helper_ref = facts
.references
.iter()
.find(|r| r.role == RefRole::Call && r.name == "helper")
.expect("expected a Call ref for 'helper'");
let scope_id = helper_ref
.scope
.expect("helper() Call ref must have a scope attached");
assert_ne!(
scope_id, 0,
"helper() Call ref scope must not be the module root"
);
}
// In `src`, the assigned `base` starts at byte 7 and the later use at byte 17,
// so byte 15 separates the declaration from the use site.
#[test]
fn ruby_read_ref_at_use_not_declaration() {
let src = "def f\n base = 1\n base\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
let read_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Read && r.name == "base")
.collect();
assert!(
!read_refs.is_empty(),
"expected at least one Read ref for 'base'; refs: {:?}",
facts
.references
.iter()
.map(|r| (&r.name, r.role))
.collect::<Vec<_>>()
);
let use_ref = read_refs
.iter()
.find(|r| r.occ.byte > 15)
.expect("expected a Read ref for 'base' at the use site (byte > 15)");
assert!(
use_ref.occ.byte > 15,
"Read ref should be at the use site, not the declaration"
);
}
// Reassigning an identifier emits Write refs for the assignment target.
#[test]
fn ruby_write_ref_for_assignment() {
let src = "def f\n cnt = 0\n cnt = 5\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
let write_refs: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Write && r.name == "cnt")
.collect();
assert!(
!write_refs.is_empty(),
"expected at least one Write ref for 'cnt'; refs: {:?}",
facts
.references
.iter()
.map(|r| (&r.name, r.role))
.collect::<Vec<_>>()
);
}
// In `obj.helper`, the method name is a call (not a Read); the receiver is a Read.
#[test]
fn ruby_explicit_call_method_not_also_read() {
let src = "def f\n obj.helper\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
let helper_reads: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Read && r.name == "helper")
.collect();
assert!(
helper_reads.is_empty(),
"call method name 'helper' must NOT be a Read ref; got: {helper_reads:?}"
);
let obj_reads: Vec<_> = facts
.references
.iter()
.filter(|r| r.role == RefRole::Read && r.name == "obj")
.collect();
assert!(
!obj_reads.is_empty(),
"receiver 'obj' should be a Read ref; refs: {:?}",
facts
.references
.iter()
.map(|r| (&r.name, r.role))
.collect::<Vec<_>>()
);
}
// Instance variables are not `identifier` nodes, so they yield no Read/Write refs.
#[test]
fn ruby_ivar_not_a_local_read_or_write() {
let src = "def f\n @val = 1\nend\n";
let facts = RubyExtractor.extract(src, "lib/f.rb").unwrap();
let val_rw: Vec<_> = facts
.references
.iter()
.filter(|r| {
matches!(r.role, RefRole::Read | RefRole::Write)
&& (r.name == "val" || r.name == "@val")
})
.collect();
assert!(
val_rw.is_empty(),
"instance variable '@val' must NOT produce a Read/Write identifier ref; got: {val_rw:?}"
);
}
}