pub mod classifier;
use crate::imports::FileImports;
use rma_common::Language;
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
pub use classifier::FunctionClassifier;
/// Security-relevant classification of a single function, produced by
/// `FunctionClassifier` and attached to every `FunctionDef` in the call graph.
#[derive(Debug, Clone, Default)]
pub struct FunctionClassification {
    // True when the function introduces untrusted data (see `source_kind`).
    pub is_source: bool,
    // What kind of taint source this is, when `is_source` is set.
    pub source_kind: Option<SourceClassification>,
    // True when the function contains at least one dangerous sink call.
    pub contains_sinks: bool,
    // Sink categories observed inside the function body.
    pub sink_kinds: Vec<SinkClassification>,
    // True when the function invokes known sanitizer routines.
    pub calls_sanitizers: bool,
    // Names of the sanitizers this function applies.
    pub sanitizes: Vec<String>,
    // Classifier confidence score; being an f32, it is deliberately excluded
    // from the manual `PartialEq`/`Hash` impls below.
    pub confidence: f32,
}
/// Equality for classifications compares every discrete field and ignores
/// the floating-point `confidence` score, which keeps the manual `Eq` (and
/// the matching `Hash` impl) sound despite the f32 field.
impl PartialEq for FunctionClassification {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (
            self.is_source,
            &self.source_kind,
            self.contains_sinks,
            &self.sink_kinds,
            self.calls_sanitizers,
            &self.sanitizes,
        );
        let rhs = (
            other.is_source,
            &other.source_kind,
            other.contains_sinks,
            &other.sink_kinds,
            other.calls_sanitizers,
            &other.sanitizes,
        );
        lhs == rhs
    }
}
impl Eq for FunctionClassification {}
/// Hashing mirrors `PartialEq`: every discrete field participates and the
/// f32 `confidence` is skipped, so equal values hash identically.
impl std::hash::Hash for FunctionClassification {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hashing a tuple feeds each element to the hasher in order, which
        // is byte-for-byte what hashing the fields one by one produces.
        (
            self.is_source,
            &self.source_kind,
            self.contains_sinks,
            &self.sink_kinds,
            self.calls_sanitizers,
            &self.sanitizes,
        )
            .hash(state);
    }
}
/// Where untrusted data can enter the program.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SourceClassification {
    // An HTTP request handler / route entry point.
    HttpHandler,
    // Data read from an HTTP request (params, headers, body).
    HttpInput,
    // Data read from the filesystem.
    FileInput,
    // Data read from environment variables.
    EnvironmentVariable,
    // Values returned from a database query.
    DatabaseResult,
    // Data received from a message/queue channel.
    MessageInput,
    // Command-line arguments.
    CommandLineArgs,
    // Any other source, described free-form.
    Other(String),
}
/// The vulnerability class a dangerous sink may enable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SinkClassification {
    SqlInjection,
    CommandInjection,
    PathTraversal,
    CrossSiteScripting,
    Deserialization,
    LdapInjection,
    TemplateInjection,
    XmlInjection,
    LogInjection,
    OpenRedirect,
    // Fallback used when a specific classification could not be validated
    // (see `validate_sink_classification`).
    GenericInjection,
    // Any other sink, described free-form.
    Other(String),
}
/// The concrete kind of evidence backing a sink classification, roughly
/// ordered by strength (resolved callee strongest, bare pattern weakest).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SinkEvidenceKind {
    // The resolved callee's qualified name matched a known sink API.
    CalleeEvidence {
        qualified_name: String,
    },
    // The file imports a module/package known to expose the sink API.
    ImportEvidence {
        import_path: String,
    },
    // A type involved at the call site matched a known sink type.
    TypeEvidence {
        type_name: String,
    },
    // Only a textual pattern matched; weakest form of evidence.
    PatternOnly {
        pattern: String,
    },
    // No supporting evidence at all.
    None,
}
/// Evidence attached to a sink classification: what kind of signal was
/// found, how confident we are in it, and a human-readable description.
#[derive(Debug, Clone, PartialEq)]
pub struct SinkEvidence {
    // The category of signal (callee / import / type / pattern / none).
    pub kind: SinkEvidenceKind,
    // Confidence associated with the evidence kind (fixed per constructor).
    pub confidence: f32,
    // Human-readable description, e.g. "callee: db.Exec".
    pub details: String,
}
impl SinkEvidence {
    /// Strong evidence: the resolved callee's qualified name matched a
    /// known sink API. Confidence 0.95.
    pub fn from_callee(qualified_name: impl Into<String>) -> Self {
        let qualified_name = qualified_name.into();
        Self {
            confidence: 0.95,
            details: format!("callee: {}", qualified_name),
            kind: SinkEvidenceKind::CalleeEvidence { qualified_name },
        }
    }
    /// Strong evidence: the file imports a package exposing the sink API.
    /// Confidence 0.8.
    pub fn from_import(import_path: impl Into<String>) -> Self {
        let import_path = import_path.into();
        Self {
            confidence: 0.8,
            details: format!("imports: {}", import_path),
            kind: SinkEvidenceKind::ImportEvidence { import_path },
        }
    }
    /// Strong evidence: a type at the call site matched a known sink type.
    /// Confidence 0.85.
    pub fn from_type(type_name: impl Into<String>) -> Self {
        let type_name = type_name.into();
        Self {
            confidence: 0.85,
            details: format!("type: {}", type_name),
            kind: SinkEvidenceKind::TypeEvidence { type_name },
        }
    }
    /// Weak evidence: only a textual pattern matched. Confidence 0.3.
    pub fn from_pattern(pattern: impl Into<String>) -> Self {
        let pattern = pattern.into();
        Self {
            confidence: 0.3,
            details: format!("pattern: {}", pattern),
            kind: SinkEvidenceKind::PatternOnly { pattern },
        }
    }
    /// No evidence at all. Confidence 0.0.
    pub fn none() -> Self {
        Self {
            kind: SinkEvidenceKind::None,
            confidence: 0.0,
            details: "no evidence".to_string(),
        }
    }
    /// True for callee/type/import evidence — i.e. anything stronger than
    /// a bare pattern match or no evidence.
    pub fn is_strong(&self) -> bool {
        !matches!(
            self.kind,
            SinkEvidenceKind::PatternOnly { .. } | SinkEvidenceKind::None
        )
    }
}
impl std::fmt::Display for SourceClassification {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SourceClassification::HttpHandler => write!(f, "HTTP Handler"),
SourceClassification::HttpInput => write!(f, "HTTP Input"),
SourceClassification::FileInput => write!(f, "File Input"),
SourceClassification::EnvironmentVariable => write!(f, "Environment Variable"),
SourceClassification::DatabaseResult => write!(f, "Database Result"),
SourceClassification::MessageInput => write!(f, "Message Input"),
SourceClassification::CommandLineArgs => write!(f, "Command Line Args"),
SourceClassification::Other(s) => write!(f, "{}", s),
}
}
}
impl std::fmt::Display for SinkClassification {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SinkClassification::SqlInjection => write!(f, "SQL Injection"),
SinkClassification::CommandInjection => write!(f, "Command Injection"),
SinkClassification::PathTraversal => write!(f, "Path Traversal"),
SinkClassification::CrossSiteScripting => write!(f, "XSS"),
SinkClassification::Deserialization => write!(f, "Deserialization"),
SinkClassification::LdapInjection => write!(f, "LDAP Injection"),
SinkClassification::TemplateInjection => write!(f, "Template Injection"),
SinkClassification::XmlInjection => write!(f, "XML Injection"),
SinkClassification::LogInjection => write!(f, "Log Injection"),
SinkClassification::OpenRedirect => write!(f, "Open Redirect"),
SinkClassification::GenericInjection => write!(f, "Injection"),
SinkClassification::Other(s) => write!(f, "{}", s),
}
}
}
/// Validate a suspected Go SQL-injection sink against the file's imports.
///
/// Returns import-backed evidence when the file pulls in a known SQL
/// package AND the call text contains one of the query/exec entry points;
/// otherwise falls back to weak pattern-only evidence.
pub fn validate_go_sql_sink(file_content: &str, sink_call: &str) -> SinkEvidence {
    // Known Go SQL packages (stdlib plus popular ORMs/wrappers), matched as
    // quoted import strings.
    const SQL_IMPORTS: [&str; 4] = [
        "\"database/sql\"",
        "\"github.com/jmoiron/sqlx\"",
        "\"gorm.io/gorm\"",
        "\"github.com/jinzhu/gorm\"",
    ];
    // Query/exec method names; matched by substring, so "Query" also covers
    // its *Context/*Row variants listed here for clarity.
    const SQL_METHODS: [&str; 9] = [
        "Query",
        "QueryContext",
        "QueryRow",
        "QueryRowContext",
        "Exec",
        "ExecContext",
        "Prepare",
        "PrepareContext",
        "Raw",
    ];
    let has_sql_import = SQL_IMPORTS.iter().any(|imp| file_content.contains(imp));
    if has_sql_import && SQL_METHODS.iter().any(|m| sink_call.contains(m)) {
        SinkEvidence::from_import("database/sql")
    } else {
        SinkEvidence::from_pattern(sink_call)
    }
}
/// Validate a suspected Go XSS sink using the file's template imports.
///
/// Template execution yields import evidence; JSON encoding and logging
/// calls are ruled out entirely (not HTML output); everything else keeps
/// weak pattern-only evidence. Check order matters: JSON/log exclusions
/// only apply after the template checks fail.
pub fn validate_go_xss_sink(file_content: &str, sink_call: &str) -> SinkEvidence {
    let imports = |pkg: &str| file_content.contains(pkg);
    let calls = |needle: &str| sink_call.contains(needle);
    // html/template auto-escapes, but Execute on untrusted data is still
    // the canonical XSS-relevant call site.
    if imports("\"html/template\"") && (calls("Execute") || calls("ExecuteTemplate")) {
        return SinkEvidence::from_import("html/template");
    }
    if imports("\"text/template\"") && calls("Execute") {
        return SinkEvidence::from_import("text/template (warning: no auto-escaping)");
    }
    // JSON encoding is not HTML output — no XSS evidence.
    if imports("\"encoding/json\"") && (calls("Encode") || calls("Marshal")) {
        return SinkEvidence::none();
    }
    // Logging/tracing output is not rendered as HTML — no XSS evidence.
    if calls("log") || calls("trace") || calls("debug") {
        return SinkEvidence::none();
    }
    SinkEvidence::from_pattern(sink_call)
}
/// Re-validate a pattern-derived sink classification against the file's
/// actual content, returning the (possibly downgraded) classification
/// together with the supporting evidence.
///
/// Downgrade rules (Go only today): a SQLi/XSS hit without strong evidence
/// becomes `GenericInjection`; an XSS hit positively ruled out (JSON/log
/// output) becomes `Other("non-html-output")`.
pub fn validate_sink_classification(
    classification: SinkClassification,
    language: Language,
    file_content: &str,
    sink_call: &str,
) -> (SinkClassification, SinkEvidence) {
    // Only Go has language-specific validators; every other language keeps
    // its original classification with weak pattern-only evidence.
    if !matches!(language, Language::Go) {
        return (classification, SinkEvidence::from_pattern(sink_call));
    }
    match classification {
        SinkClassification::SqlInjection => {
            let evidence = validate_go_sql_sink(file_content, sink_call);
            if evidence.is_strong() {
                (SinkClassification::SqlInjection, evidence)
            } else {
                (SinkClassification::GenericInjection, evidence)
            }
        }
        SinkClassification::CrossSiteScripting => {
            let evidence = validate_go_xss_sink(file_content, sink_call);
            if evidence.is_strong() {
                (SinkClassification::CrossSiteScripting, evidence)
            } else if matches!(evidence.kind, SinkEvidenceKind::None) {
                (
                    SinkClassification::Other("non-html-output".to_string()),
                    evidence,
                )
            } else {
                (SinkClassification::GenericInjection, evidence)
            }
        }
        other => (other, SinkEvidence::from_pattern(sink_call)),
    }
}
/// A single function definition discovered in the codebase.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionDef {
    // Function name as written in source.
    pub name: String,
    // File the function is defined in.
    pub file: PathBuf,
    // 1-based line of the definition.
    pub line: usize,
    // Whether the function is visible outside its file/module
    // (see `is_function_exported`).
    pub is_exported: bool,
    // Source language of the defining file.
    pub language: Language,
    // Security classification (source / sink / sanitizer info).
    pub classification: FunctionClassification,
}
/// One call observed in a file, before and after callee resolution.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallSite {
    // Name of the function being called, as written at the call site.
    pub callee_name: String,
    // File containing the call.
    pub caller_file: PathBuf,
    // Enclosing function name, or `None` for module-level code.
    pub caller_function: Option<String>,
    // 1-based line of the call.
    pub line: usize,
    // File the callee resolved to; `None` until/unless resolution succeeds.
    pub resolved_target: Option<PathBuf>,
}
/// A resolved caller→callee edge in the call graph.
#[derive(Debug, Clone)]
pub struct CallEdge {
    // The calling function (a synthetic "<module>" def for top-level calls —
    // see `CallGraphBuilder::build`).
    pub caller: FunctionDef,
    // The called function.
    pub callee: FunctionDef,
    // The call site that produced this edge.
    pub call_site: CallSite,
    // True when caller and callee live in different files.
    pub is_cross_file: bool,
}
/// A potential source→sink taint flow discovered in the call graph.
#[derive(Debug, Clone)]
pub struct TaintFlow {
    // The function where tainted data enters.
    pub source: FunctionDef,
    // The function containing the dangerous sink.
    pub sink: FunctionDef,
    // Chain of callers along the way (as produced by `CallGraph::find_path`).
    pub path: Vec<FunctionDef>,
    // Heuristic confidence (see `calculate_flow_confidence`), capped at 1.0.
    pub confidence: f32,
}
impl TaintFlow {
pub fn sink_type(&self) -> Option<&SinkClassification> {
self.sink.classification.sink_kinds.first()
}
pub fn source_type(&self) -> Option<&SourceClassification> {
self.source.classification.source_kind.as_ref()
}
pub fn format_path(&self) -> String {
let mut parts = Vec::new();
let source_file = self
.source
.file
.file_name()
.and_then(|f| f.to_str())
.unwrap_or("?");
parts.push(format!(
"{} ({}:{})",
self.source.name, source_file, self.source.line
));
let mut last_file = &self.source.file;
for func in &self.path {
let file = func
.file
.file_name()
.and_then(|f| f.to_str())
.unwrap_or("?");
if &func.file != last_file {
parts.push(format!("[{}] {} ({}:{})", file, func.name, file, func.line));
} else {
parts.push(format!("{} ({}:{})", func.name, file, func.line));
}
last_file = &func.file;
}
let sink_file = self
.sink
.file
.file_name()
.and_then(|f| f.to_str())
.unwrap_or("?");
if &self.sink.file != last_file {
parts.push(format!(
"[{}] {} ({}:{})",
sink_file, self.sink.name, sink_file, self.sink.line
));
} else {
parts.push(format!(
"{} ({}:{})",
self.sink.name, sink_file, self.sink.line
));
}
parts.join(" -> ")
}
}
/// Heuristic confidence for a source→sink flow.
///
/// Starts at 0.5 and adds:
/// - +0.2 when the source is HTTP-facing (handler or request input),
/// - +0.2 when the sink carries a critical sink kind (SQL injection,
///   command injection, deserialization),
/// - up to +0.5 from the average of the two classifiers' own confidences
///   (their sum divided by 4).
///
/// The result is clamped to 1.0.
fn calculate_flow_confidence(source: &FunctionDef, sink: &FunctionDef) -> f32 {
    let mut confidence = 0.5;
    if matches!(
        source.classification.source_kind,
        Some(SourceClassification::HttpHandler) | Some(SourceClassification::HttpInput)
    ) {
        confidence += 0.2;
    }
    // BUG FIX: the critical-sink bonus must inspect the SINK's
    // classification; the original read `source.classification.sink_kinds`,
    // keying the bonus to the wrong function.
    let has_critical_sink = sink.classification.sink_kinds.iter().any(|s| {
        matches!(
            s,
            SinkClassification::SqlInjection
                | SinkClassification::CommandInjection
                | SinkClassification::Deserialization
        )
    });
    if has_critical_sink {
        confidence += 0.2;
    }
    confidence += (source.classification.confidence + sink.classification.confidence) / 4.0;
    confidence.min(1.0)
}
use crate::flow::events::{EventBinding, EventRegistry, EventSite};
/// A project-wide call graph: function definitions, resolved call edges
/// (indexed in both directions), unresolved call sites, and event
/// emitter/listener bindings.
#[derive(Debug, Default, Clone)]
pub struct CallGraph {
    // All known functions, keyed by (defining file, function name).
    functions: HashMap<(PathBuf, String), FunctionDef>,
    // Secondary index: every definition sharing a given name.
    functions_by_name: HashMap<String, Vec<FunctionDef>>,
    // Outgoing edges per (file, function) caller.
    caller_to_callees: HashMap<(PathBuf, String), Vec<CallEdge>>,
    // Incoming edges per (file, function) callee.
    callee_to_callers: HashMap<(PathBuf, String), Vec<CallEdge>>,
    // Every call site that resolved to a known definition.
    call_sites: Vec<CallSite>,
    // Call sites whose callee could not be resolved to any definition.
    unresolved_calls: Vec<CallSite>,
    // Event name → emit/listen sites, for event-driven flow tracking.
    event_bindings: HashMap<String, EventBinding>,
}
impl CallGraph {
    /// Create an empty graph.
    pub fn new() -> Self {
        Self::default()
    }
    /// Iterate over every known function definition.
    pub fn functions(&self) -> impl Iterator<Item = &FunctionDef> {
        self.functions.values()
    }
    /// Look up a function by its defining file and name.
    pub fn get_function(&self, file: &Path, name: &str) -> Option<&FunctionDef> {
        self.functions.get(&(file.to_path_buf(), name.to_string()))
    }
    /// All definitions sharing `name`, across every file (empty if none).
    pub fn get_functions_by_name(&self, name: &str) -> &[FunctionDef] {
        self.functions_by_name
            .get(name)
            .map(|v| v.as_slice())
            .unwrap_or(&[])
    }
    /// Incoming call edges for the given function.
    pub fn callers_of(&self, file: &Path, name: &str) -> Vec<&CallEdge> {
        self.callee_to_callers
            .get(&(file.to_path_buf(), name.to_string()))
            .map(|v| v.iter().collect())
            .unwrap_or_default()
    }
    /// Outgoing call edges for the given function.
    pub fn callees_of(&self, file: &Path, name: &str) -> Vec<&CallEdge> {
        self.caller_to_callees
            .get(&(file.to_path_buf(), name.to_string()))
            .map(|v| v.iter().collect())
            .unwrap_or_default()
    }
    /// Depth-first reachability test over call edges. A function is
    /// trivially reachable from itself (checked before expansion).
    pub fn is_reachable(
        &self,
        from_file: &Path,
        from_name: &str,
        to_file: &Path,
        to_name: &str,
    ) -> bool {
        let mut visited = HashSet::new();
        let mut stack = vec![(from_file.to_path_buf(), from_name.to_string())];
        while let Some((file, name)) = stack.pop() {
            if file == to_file && name == to_name {
                return true;
            }
            // `insert` returning false means this node was already expanded.
            if !visited.insert((file.clone(), name.clone())) {
                continue;
            }
            for edge in self.callees_of(&file, &name) {
                stack.push((edge.callee.file.clone(), edge.callee.name.clone()));
            }
        }
        false
    }
    /// Every resolved edge whose caller and callee live in different files.
    pub fn cross_file_edges(&self) -> Vec<&CallEdge> {
        self.caller_to_callees
            .values()
            .flatten()
            .filter(|e| e.is_cross_file)
            .collect()
    }
    /// Every resolved call edge in the graph.
    pub fn all_edges(&self) -> Vec<&CallEdge> {
        self.caller_to_callees.values().flatten().collect()
    }
    /// Call sites whose callee could not be resolved to any definition.
    pub fn unresolved_calls(&self) -> &[CallSite] {
        &self.unresolved_calls
    }
    /// Number of known function definitions.
    pub fn function_count(&self) -> usize {
        self.functions.len()
    }
    /// Number of resolved call edges.
    pub fn edge_count(&self) -> usize {
        self.caller_to_callees.values().map(|v| v.len()).sum()
    }
    /// Functions classified as taint sources.
    pub fn source_functions(&self) -> Vec<&FunctionDef> {
        self.functions
            .values()
            .filter(|f| f.classification.is_source)
            .collect()
    }
    /// Functions classified as containing dangerous sinks.
    pub fn sink_functions(&self) -> Vec<&FunctionDef> {
        self.functions
            .values()
            .filter(|f| f.classification.contains_sinks)
            .collect()
    }
    /// Functions classified as calling sanitizers.
    pub fn sanitizer_functions(&self) -> Vec<&FunctionDef> {
        self.functions
            .values()
            .filter(|f| f.classification.calls_sanitizers)
            .collect()
    }
    /// Heuristic: does any function visited while walking from `from`
    /// (before the target is popped) call a sanitizer?
    ///
    /// NOTE(review): this is a reachability check, not a per-path check —
    /// a sanitizer on ANY branch reachable from the source returns true,
    /// even if the actual source→sink route bypasses it. Also, popping the
    /// target off the DFS stack ends the whole search with `false`, and
    /// stack order decides what gets visited first. Confirm this
    /// over-approximation is intended.
    pub fn has_sanitizer_on_path(
        &self,
        from_file: &Path,
        from_name: &str,
        to_file: &Path,
        to_name: &str,
    ) -> bool {
        let mut visited = HashSet::new();
        let mut stack = vec![(from_file.to_path_buf(), from_name.to_string())];
        while let Some((file, name)) = stack.pop() {
            // Reached the target: stop and report no sanitizer found.
            if file == to_file && name == to_name {
                return false;
            }
            if !visited.insert((file.clone(), name.clone())) {
                continue;
            }
            if let Some(func) = self.get_function(&file, &name)
                && func.classification.calls_sanitizers
            {
                return true;
            }
            for edge in self.callees_of(&file, &name) {
                stack.push((edge.callee.file.clone(), edge.callee.name.clone()));
            }
        }
        false
    }
    /// Enumerate cross-file source→sink flows with no sanitizer on the way.
    /// Same-file source/sink pairs are skipped.
    pub fn find_taint_flows(&self) -> Vec<TaintFlow> {
        let sources = self.source_functions();
        let sinks = self.sink_functions();
        let mut flows = Vec::new();
        for source in &sources {
            for sink in &sinks {
                // Intra-file flows are out of scope for this pass.
                if source.file == sink.file {
                    continue;
                }
                if let Some(path) =
                    self.find_path(&source.file, &source.name, &sink.file, &sink.name)
                {
                    let has_sanitizer = self.has_sanitizer_on_path(
                        &source.file,
                        &source.name,
                        &sink.file,
                        &sink.name,
                    );
                    if !has_sanitizer {
                        flows.push(TaintFlow {
                            source: (*source).clone(),
                            sink: (*sink).clone(),
                            path,
                            confidence: calculate_flow_confidence(source, sink),
                        });
                    }
                }
            }
        }
        flows
    }
    /// Breadth-first search from `from` to `to`, bounded at 15 hops.
    ///
    /// The returned path is the chain of CALLER definitions: the starting
    /// function appears as the first element once the walk leaves it, and
    /// the target itself is excluded. `Some(vec![])` means source == target.
    ///
    /// NOTE(review): because the start node is included, consumers like
    /// `TaintFlow::format_path` (which also prints the source separately)
    /// render the source twice — confirm that is intended.
    pub fn find_path(
        &self,
        from_file: &Path,
        from_name: &str,
        to_file: &Path,
        to_name: &str,
    ) -> Option<Vec<FunctionDef>> {
        use std::collections::VecDeque;
        let mut visited = HashSet::new();
        let mut queue: VecDeque<(PathBuf, String, Vec<FunctionDef>)> = VecDeque::new();
        queue.push_back((from_file.to_path_buf(), from_name.to_string(), vec![]));
        visited.insert((from_file.to_path_buf(), from_name.to_string()));
        // BFS level tracking: count how many queued nodes belong to the
        // current depth so `depth` can be bumped when the level drains.
        let max_depth = 15;
        let mut depth = 0;
        let mut nodes_at_depth = 1;
        let mut nodes_next_depth = 0;
        while let Some((file, name, path)) = queue.pop_front() {
            if depth > max_depth {
                break;
            }
            if file == to_file && name == to_name {
                return Some(path);
            }
            for edge in self.callees_of(&file, &name) {
                let key = (edge.callee.file.clone(), edge.callee.name.clone());
                if !visited.contains(&key) {
                    visited.insert(key);
                    // The CALLER is appended, so the path records how we got
                    // to each frontier node rather than the node itself.
                    let mut new_path = path.clone();
                    new_path.push(edge.caller.clone());
                    queue.push_back((edge.callee.file.clone(), edge.callee.name.clone(), new_path));
                    nodes_next_depth += 1;
                }
            }
            nodes_at_depth -= 1;
            if nodes_at_depth == 0 {
                depth += 1;
                nodes_at_depth = nodes_next_depth;
                nodes_next_depth = 0;
            }
        }
        None
    }
    /// Re-run the classifier over all parsed files and refresh each stored
    /// function's classification in place. Functions absent from the
    /// classifier's output keep their previous classification.
    pub fn update_classifications(
        &mut self,
        classifier: &FunctionClassifier,
        parsed_files: &[rma_parser::ParsedFile],
    ) {
        let all_classifications = classifier.classify_files_parallel(parsed_files);
        for ((file, name), func_def) in self.functions.iter_mut() {
            if let Some(classification) = all_classifications.get(&(file.clone(), name.clone())) {
                func_def.classification = classification.clone();
            }
        }
    }
    /// Listener sites registered for `event_name` (empty if unknown).
    pub fn listeners_of(&self, event_name: &str) -> Vec<&EventSite> {
        self.event_bindings
            .get(event_name)
            .map(|b| b.listen_sites.iter().collect())
            .unwrap_or_default()
    }
    /// Emit sites registered for `event_name` (empty if unknown).
    pub fn emitters_of(&self, event_name: &str) -> Vec<&EventSite> {
        self.event_bindings
            .get(event_name)
            .map(|b| b.emit_sites.iter().collect())
            .unwrap_or_default()
    }
    /// The full binding record for `event_name`, if registered.
    pub fn get_event_binding(&self, event_name: &str) -> Option<&EventBinding> {
        self.event_bindings.get(event_name)
    }
    /// Iterate over every known event name.
    pub fn event_names(&self) -> impl Iterator<Item = &String> {
        self.event_bindings.keys()
    }
    /// Iterate over every event binding.
    pub fn all_event_bindings(&self) -> impl Iterator<Item = &EventBinding> {
        self.event_bindings.values()
    }
    /// True when at least one emit site is registered for `event_name`.
    pub fn has_event_emitters(&self, event_name: &str) -> bool {
        self.event_bindings
            .get(event_name)
            .map(|b| !b.emit_sites.is_empty())
            .unwrap_or(false)
    }
    /// True when at least one listener is registered for `event_name`.
    pub fn has_event_listeners(&self, event_name: &str) -> bool {
        self.event_bindings
            .get(event_name)
            .map(|b| !b.listen_sites.is_empty())
            .unwrap_or(false)
    }
    /// Insert (or replace) a whole binding for `event_name`.
    pub fn add_event_binding(&mut self, event_name: String, binding: EventBinding) {
        self.event_bindings.insert(event_name, binding);
    }
    /// Record an emit site, creating the binding on first use.
    pub fn register_event_emit(&mut self, event_name: &str, site: EventSite) {
        self.event_bindings
            .entry(event_name.to_string())
            .or_insert_with(|| EventBinding::new(event_name.to_string()))
            .add_emit_site(site);
    }
    /// Record a listen site, creating the binding on first use.
    pub fn register_event_listen(&mut self, event_name: &str, site: EventSite) {
        self.event_bindings
            .entry(event_name.to_string())
            .or_insert_with(|| EventBinding::new(event_name.to_string()))
            .add_listen_site(site);
    }
    /// Merge every binding from an `EventRegistry` into this graph,
    /// appending each emit/listen site via the binding's add methods.
    pub fn merge_event_registry(&mut self, registry: EventRegistry) {
        for binding in registry.all_bindings() {
            let entry = self
                .event_bindings
                .entry(binding.event_name.clone())
                .or_insert_with(|| EventBinding::new(binding.event_name.clone()));
            for site in &binding.emit_sites {
                entry.add_emit_site(site.clone());
            }
            for site in &binding.listen_sites {
                entry.add_listen_site(site.clone());
            }
        }
    }
}
/// Accumulates per-file facts (definitions, call sites, imports) and then
/// resolves them into a `CallGraph` via `build`.
#[derive(Debug, Default)]
pub struct CallGraphBuilder {
    // Definitions keyed by (file, name).
    functions: HashMap<(PathBuf, String), FunctionDef>,
    // All observed call sites, not yet resolved.
    call_sites: Vec<CallSite>,
    // Resolved imports per file, used to resolve cross-file calls.
    imports_by_file: HashMap<PathBuf, FileImports>,
}
impl CallGraphBuilder {
    /// Create an empty builder.
    pub fn new() -> Self {
        Self::default()
    }
    /// Register one file's extracted facts.
    ///
    /// * `functions` — `(name, 1-based line, is_exported)` per definition.
    /// * `calls` — `(callee_name, line, enclosing_function)` per call site;
    ///   `None` means the call occurs at module level.
    /// * `imports` — the file's resolved imports, used later by
    ///   `resolve_call`.
    pub fn add_file(
        &mut self,
        file_path: &Path,
        language: Language,
        functions: Vec<(String, usize, bool)>,
        calls: Vec<(String, usize, Option<String>)>,
        imports: FileImports,
    ) {
        for (name, line, is_exported) in functions {
            let def = FunctionDef {
                name: name.clone(),
                file: file_path.to_path_buf(),
                line,
                is_exported,
                language,
                // Filled in later via `CallGraph::update_classifications`.
                classification: FunctionClassification::default(),
            };
            self.functions.insert((file_path.to_path_buf(), name), def);
        }
        for (callee_name, line, caller_function) in calls {
            self.call_sites.push(CallSite {
                callee_name,
                caller_file: file_path.to_path_buf(),
                caller_function,
                line,
                resolved_target: None,
            });
        }
        self.imports_by_file
            .insert(file_path.to_path_buf(), imports);
    }
    /// Consume the builder, resolve every recorded call site, and assemble
    /// the final `CallGraph`. Unresolvable sites land in
    /// `CallGraph::unresolved_calls` instead of producing edges.
    pub fn build(mut self) -> CallGraph {
        let mut graph = CallGraph {
            functions: self.functions.clone(),
            functions_by_name: HashMap::new(),
            caller_to_callees: HashMap::new(),
            callee_to_callers: HashMap::new(),
            call_sites: Vec::new(),
            unresolved_calls: Vec::new(),
            event_bindings: HashMap::new(),
        };
        // Populate the by-name secondary index.
        for ((_, name), def) in &self.functions {
            graph
                .functions_by_name
                .entry(name.clone())
                .or_default()
                .push(def.clone());
        }
        let call_sites = std::mem::take(&mut self.call_sites);
        for mut call_site in call_sites {
            let resolved = self.resolve_call(&call_site);
            match resolved {
                Some(callee_def) => {
                    call_site.resolved_target = Some(callee_def.file.clone());
                    // Look up the caller's definition when the call happened
                    // inside a known function.
                    let caller_def = if let Some(ref caller_name) = call_site.caller_function {
                        self.functions
                            .get(&(call_site.caller_file.clone(), caller_name.clone()))
                            .cloned()
                    } else {
                        None
                    };
                    // Module-level (or unknown-caller) calls get a synthetic
                    // "<module>" definition so every edge has a caller node.
                    let caller_def = caller_def.unwrap_or_else(|| FunctionDef {
                        name: call_site
                            .caller_function
                            .clone()
                            .unwrap_or_else(|| "<module>".to_string()),
                        file: call_site.caller_file.clone(),
                        line: call_site.line,
                        is_exported: false,
                        language: Language::Unknown,
                        classification: FunctionClassification::default(),
                    });
                    let is_cross_file = caller_def.file != callee_def.file;
                    let edge = CallEdge {
                        caller: caller_def.clone(),
                        callee: callee_def.clone(),
                        call_site: call_site.clone(),
                        is_cross_file,
                    };
                    // Index the edge in both directions.
                    graph
                        .caller_to_callees
                        .entry((caller_def.file.clone(), caller_def.name.clone()))
                        .or_default()
                        .push(edge.clone());
                    graph
                        .callee_to_callers
                        .entry((callee_def.file.clone(), callee_def.name.clone()))
                        .or_default()
                        .push(edge);
                    graph.call_sites.push(call_site);
                }
                None => {
                    graph.unresolved_calls.push(call_site);
                }
            }
        }
        graph
    }
    /// Resolve a call site to a definition using, in order:
    /// 1. a same-file function with the callee's name,
    /// 2. the caller file's imports (local name → exporting file + name),
    /// 3. any function anywhere with a matching name.
    ///
    /// NOTE(review): step 3 iterates a HashMap, so when several files define
    /// the same name the winner is arbitrary and may vary between runs —
    /// confirm whether deterministic tie-breaking is needed.
    fn resolve_call(&self, call_site: &CallSite) -> Option<FunctionDef> {
        if let Some(def) = self
            .functions
            .get(&(call_site.caller_file.clone(), call_site.callee_name.clone()))
        {
            return Some(def.clone());
        }
        if let Some(imports) = self.imports_by_file.get(&call_site.caller_file) {
            for import in &imports.imports {
                if import.local_name == call_site.callee_name {
                    if let Some(def) = self
                        .functions
                        .get(&(import.source_file.clone(), import.exported_name.clone()))
                    {
                        return Some(def.clone());
                    }
                }
            }
        }
        if let Some(defs) = self
            .functions
            .iter()
            .filter(|((_, name), _)| name == &call_site.callee_name)
            .map(|(_, def)| def)
            .next()
        {
            return Some(defs.clone());
        }
        None
    }
}
/// Walk a parsed syntax tree and collect every function-like definition as
/// `(name, 1-based line, is_exported)` tuples.
pub fn extract_function_definitions(
    tree: &tree_sitter::Tree,
    source: &[u8],
    language: Language,
) -> Vec<(String, usize, bool)> {
    let mut defs = Vec::new();
    extract_functions_recursive(tree.root_node(), source, language, &mut defs);
    defs
}
/// Depth-first traversal appending `(name, line, is_exported)` for each
/// node the given language's grammar treats as a function definition.
fn extract_functions_recursive(
    node: tree_sitter::Node,
    source: &[u8],
    language: Language,
    functions: &mut Vec<(String, usize, bool)>,
) {
    let kind = node.kind();
    // Per-language node kinds that introduce a function definition.
    let is_function = match language {
        Language::JavaScript | Language::TypeScript => matches!(
            kind,
            "function_declaration" | "function_expression" | "arrow_function" | "method_definition"
        ),
        Language::Python | Language::Bash | Language::Solidity => kind == "function_definition",
        Language::Rust => kind == "function_item",
        Language::Go => matches!(kind, "function_declaration" | "method_declaration"),
        Language::Java | Language::CSharp => kind == "method_declaration",
        Language::Php => matches!(kind, "function_definition" | "method_declaration"),
        Language::Kotlin => matches!(kind, "function_declaration" | "anonymous_function"),
        Language::Scala => matches!(kind, "function_definition" | "function_declaration"),
        Language::Swift => kind == "function_declaration",
        // Elixir `def` forms parse as `call` nodes; the name is recovered
        // in `extract_function_name`.
        Language::Elixir => kind == "call",
        Language::OCaml => kind == "let_binding",
        _ => false,
    };
    if is_function {
        // Nameless definitions (e.g. anonymous functions without a usable
        // binding) are skipped entirely.
        if let Some(name) = extract_function_name(node, source, language) {
            functions.push((
                name,
                node.start_position().row + 1, // tree-sitter rows are 0-based
                is_function_exported(node, source, language),
            ));
        }
    }
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_functions_recursive(child, source, language, functions);
    }
}
/// Extract the declared name of a function-definition node, per language
/// grammar. Returns `None` when the node carries no usable name.
fn extract_function_name(
    node: tree_sitter::Node,
    source: &[u8],
    language: Language,
) -> Option<String> {
    match language {
        Language::JavaScript | Language::TypeScript => {
            // Named functions/methods expose a `name` field directly.
            if let Some(name_node) = node.child_by_field_name("name") {
                return name_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            // Arrow functions are anonymous; borrow the name of the variable
            // they are assigned to (`const foo = () => ...`).
            if node.kind() == "arrow_function"
                && let Some(parent) = node.parent()
                && parent.kind() == "variable_declarator"
                && let Some(name_node) = parent.child_by_field_name("name")
            {
                return name_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            None
        }
        // These grammars all expose the identifier via a `name` field.
        Language::Python
        | Language::Rust
        | Language::Go
        | Language::Java
        | Language::Php
        | Language::CSharp
        | Language::Kotlin
        | Language::Scala
        | Language::Swift
        | Language::Bash
        | Language::Solidity
        | Language::OCaml => node
            .child_by_field_name("name")
            .and_then(|n| n.utf8_text(source).ok())
            .map(|s| s.to_string()),
        Language::Elixir => {
            // Elixir definitions arrive here as `call` nodes; the first
            // named argument is taken as the name (assumes the
            // tree-sitter-elixir shape for `def foo(...)` — confirm).
            if let Some(args) = node.child_by_field_name("arguments")
                && let Some(first) = args.named_child(0)
            {
                return first.utf8_text(source).ok().map(|s| s.to_string());
            }
            None
        }
        _ => None,
    }
}
/// Decide whether a function definition is visible outside its file/module,
/// using per-language visibility conventions. Defaults to `false` for
/// unknown languages.
fn is_function_exported(node: tree_sitter::Node, source: &[u8], language: Language) -> bool {
    match language {
        Language::JavaScript | Language::TypeScript => {
            // Exported iff directly wrapped in an `export` statement.
            if let Some(parent) = node.parent()
                && parent.kind() == "export_statement"
            {
                return true;
            }
            false
        }
        Language::Python => {
            // Python convention: a leading underscore marks it private.
            if let Some(name_node) = node.child_by_field_name("name")
                && let Ok(name) = name_node.utf8_text(source)
            {
                return !name.starts_with('_');
            }
            false
        }
        Language::Rust => {
            // Any `pub` variant (pub, pub(crate), ...) counts as exported.
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "visibility_modifier"
                    && let Ok(text) = child.utf8_text(source)
                {
                    return text.starts_with("pub");
                }
            }
            false
        }
        Language::Go => {
            // Go exports identifiers beginning with an uppercase letter.
            if let Some(name_node) = node.child_by_field_name("name")
                && let Ok(name) = name_node.utf8_text(source)
            {
                return name.chars().next().is_some_and(|c| c.is_uppercase());
            }
            false
        }
        Language::Java | Language::CSharp => {
            // Look for an explicit `public` in the modifier list; a missing
            // list is treated as not exported.
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "modifiers"
                    && let Ok(text) = child.utf8_text(source)
                {
                    return text.contains("public");
                }
            }
            false
        }
        Language::Php => {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "visibility_modifier"
                    && let Ok(text) = child.utf8_text(source)
                {
                    return text == "public";
                }
            }
            // No visibility modifier: free functions count as exported.
            node.kind() == "function_definition"
        }
        Language::Kotlin => {
            // Public by default; only `private`/`internal` hide a function.
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "modifiers"
                    && let Ok(text) = child.utf8_text(source)
                    && (text.contains("private") || text.contains("internal"))
                {
                    return false;
                }
            }
            true
        }
        // Scala members are treated as always exported here.
        Language::Scala => true,
        Language::Swift => {
            // Only explicit `public` or `open` counts as exported.
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "modifiers"
                    && let Ok(text) = child.utf8_text(source)
                {
                    return text.contains("public") || text.contains("open");
                }
            }
            false
        }
        Language::Solidity => {
            // Exported iff declared `public` or `external`.
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if child.kind() == "visibility"
                    && let Ok(text) = child.utf8_text(source)
                {
                    return text == "public" || text == "external";
                }
            }
            false
        }
        // No visibility concept modeled; treat everything as exported.
        Language::Bash | Language::Elixir | Language::OCaml => true,
        _ => false,
    }
}
/// Walk a parsed syntax tree and collect every call as
/// `(callee_name, 1-based line, enclosing_function)` tuples; the enclosing
/// function is `None` for module-level calls.
pub fn extract_function_calls(
    tree: &tree_sitter::Tree,
    source: &[u8],
    language: Language,
) -> Vec<(String, usize, Option<String>)> {
    let mut collected = Vec::new();
    extract_calls_recursive(tree.root_node(), source, language, &mut collected, None);
    collected
}
/// Depth-first traversal collecting `(callee, line, enclosing_function)`
/// for every call node, threading the nearest enclosing function name down
/// the tree as context.
fn extract_calls_recursive(
    node: tree_sitter::Node,
    source: &[u8],
    language: Language,
    calls: &mut Vec<(String, usize, Option<String>)>,
    current_function: Option<String>,
) {
    let kind = node.kind();
    // Does this node open a new enclosing-function scope for this language?
    // NOTE: unlike `extract_functions_recursive`, JS/TS `arrow_function`
    // does NOT start a new scope here.
    let starts_scope = match language {
        Language::JavaScript | Language::TypeScript => matches!(
            kind,
            "function_declaration" | "function_expression" | "method_definition"
        ),
        Language::Python | Language::Bash | Language::Solidity => kind == "function_definition",
        Language::Rust => kind == "function_item",
        Language::Go => matches!(kind, "function_declaration" | "method_declaration"),
        Language::Java | Language::CSharp => kind == "method_declaration",
        Language::Php => matches!(kind, "function_definition" | "method_declaration"),
        Language::Kotlin | Language::Scala | Language::Swift => {
            matches!(kind, "function_declaration" | "function_definition")
        }
        Language::Elixir => kind == "call",
        Language::OCaml => kind == "let_binding",
        _ => false,
    };
    // A scope-opening node without a recoverable name keeps the outer
    // context, matching `new_function.or(current_function)` semantics.
    let func_context = if starts_scope {
        extract_function_name(node, source, language).or(current_function)
    } else {
        current_function
    };
    if matches!(
        kind,
        "call_expression" | "member_expression" | "method_invocation"
    ) && let Some(callee) = extract_callee_name(node, source, language)
    {
        calls.push((callee, node.start_position().row + 1, func_context.clone()));
    }
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_calls_recursive(child, source, language, calls, func_context.clone());
    }
}
/// Extract the name of the function being invoked at a call node, per
/// language grammar. Method calls yield just the method/property name (the
/// receiver is dropped). Returns `None` when no name can be recovered.
fn extract_callee_name(
    node: tree_sitter::Node,
    source: &[u8],
    language: Language,
) -> Option<String> {
    match language {
        Language::JavaScript | Language::TypeScript => {
            if let Some(func_node) = node.child_by_field_name("function") {
                match func_node.kind() {
                    // Plain call: `foo(...)`.
                    "identifier" => {
                        return func_node.utf8_text(source).ok().map(|s| s.to_string());
                    }
                    // Method call: `obj.foo(...)` — keep the property name.
                    "member_expression" => {
                        if let Some(prop) = func_node.child_by_field_name("property") {
                            return prop.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        Language::Python => {
            if let Some(func_node) = node.child_by_field_name("function") {
                match func_node.kind() {
                    // Plain call: `foo(...)`.
                    "identifier" => {
                        return func_node.utf8_text(source).ok().map(|s| s.to_string());
                    }
                    // Method call: `obj.method(...)` — take the attribute.
                    "attribute" => {
                        if let Some(attr) = func_node.child_by_field_name("attribute") {
                            return attr.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        Language::Rust => {
            if let Some(func_node) = node.child_by_field_name("function") {
                match func_node.kind() {
                    "identifier" => {
                        return func_node.utf8_text(source).ok().map(|s| s.to_string());
                    }
                    // `path::func(...)` / `value.method(...)` — try the
                    // `name` field first, then `field`.
                    "scoped_identifier" | "field_expression" => {
                        if let Some(name) = func_node.child_by_field_name("name") {
                            return name.utf8_text(source).ok().map(|s| s.to_string());
                        }
                        if let Some(field) = func_node.child_by_field_name("field") {
                            return field.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        Language::Go | Language::Java => {
            // Method invocations carry a `name` field directly; fall back
            // to a bare identifier under `function` for plain calls.
            if let Some(name_node) = node.child_by_field_name("name") {
                return name_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            if let Some(func_node) = node.child_by_field_name("function")
                && func_node.kind() == "identifier"
            {
                return func_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            None
        }
        Language::Php => {
            if let Some(func_node) = node.child_by_field_name("function") {
                match func_node.kind() {
                    "name" => return func_node.utf8_text(source).ok().map(|s| s.to_string()),
                    // `$obj->method(...)` — keep the member name.
                    "member_access_expression" => {
                        if let Some(name) = func_node.child_by_field_name("name") {
                            return name.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            // Some call node shapes expose `name` directly on the call.
            if let Some(name_node) = node.child_by_field_name("name") {
                return name_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            None
        }
        Language::CSharp => {
            // The callee sits under the `expression` field of the call node.
            if let Some(func_node) = node.child_by_field_name("expression") {
                match func_node.kind() {
                    "identifier" => return func_node.utf8_text(source).ok().map(|s| s.to_string()),
                    "member_access_expression" => {
                        if let Some(name) = func_node.child_by_field_name("name") {
                            return name.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        Language::Kotlin | Language::Scala | Language::Swift => {
            if let Some(func_node) = node.child_by_field_name("function") {
                match func_node.kind() {
                    "simple_identifier" | "identifier" => {
                        return func_node.utf8_text(source).ok().map(|s| s.to_string());
                    }
                    // Dotted call — prefer the `name` field, otherwise fall
                    // back to the last named child (the member after the dot).
                    "navigation_expression" | "field_expression" => {
                        if let Some(name) = func_node.child_by_field_name("name").or_else(|| {
                            func_node.named_child(func_node.named_child_count().saturating_sub(1))
                        }) {
                            return name.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        Language::Bash => {
            if let Some(name_node) = node.child_by_field_name("name") {
                return name_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            // Fall back to the first named child (the command word).
            node.named_child(0)
                .and_then(|n| n.utf8_text(source).ok())
                .map(|s| s.to_string())
        }
        Language::Elixir => {
            // `call` nodes name the invoked function in the `target` field.
            if let Some(target) = node.child_by_field_name("target") {
                return target.utf8_text(source).ok().map(|s| s.to_string());
            }
            None
        }
        Language::Solidity => {
            if let Some(func_node) = node.child_by_field_name("function") {
                match func_node.kind() {
                    "identifier" => return func_node.utf8_text(source).ok().map(|s| s.to_string()),
                    // `obj.fn(...)` — keep the property name.
                    "member_expression" => {
                        if let Some(prop) = func_node.child_by_field_name("property") {
                            return prop.utf8_text(source).ok().map(|s| s.to_string());
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        Language::OCaml => {
            // Application: take the whole function expression's text.
            if let Some(func_node) = node.child_by_field_name("function") {
                return func_node.utf8_text(source).ok().map(|s| s.to_string());
            }
            None
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::imports::ResolvedImport;
    /// handler.js imports `sanitize` from utils.js and calls it from
    /// `handleRequest`: the builder should resolve exactly one edge, and
    /// that edge must be marked cross-file.
    #[test]
    fn test_call_graph_builder() {
        let mut builder = CallGraphBuilder::new();
        builder.add_file(
            Path::new("/project/src/utils.js"),
            Language::JavaScript,
            vec![("sanitize".to_string(), 1, true)],
            vec![],
            FileImports::default(),
        );
        // handler.js resolves `sanitize` to utils.js via its import table.
        let mut imports = FileImports::default();
        imports.imports.push(ResolvedImport {
            local_name: "sanitize".to_string(),
            source_file: PathBuf::from("/project/src/utils.js"),
            exported_name: "sanitize".to_string(),
            kind: crate::imports::ImportKind::Named,
            specifier: "./utils".to_string(),
            line: 1,
        });
        builder.add_file(
            Path::new("/project/src/handler.js"),
            Language::JavaScript,
            vec![("handleRequest".to_string(), 5, true)],
            vec![(
                "sanitize".to_string(),
                10,
                Some("handleRequest".to_string()),
            )],
            imports,
        );
        let graph = builder.build();
        assert_eq!(graph.function_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        let edges = graph.cross_file_edges();
        assert_eq!(edges.len(), 1);
        assert!(edges[0].is_cross_file);
    }
    /// funcA -> funcB -> funcC across three files: reachability must be
    /// transitive in the call direction and absent in reverse.
    #[test]
    fn test_reachability() {
        let mut builder = CallGraphBuilder::new();
        builder.add_file(
            Path::new("/a.js"),
            Language::JavaScript,
            vec![("funcA".to_string(), 1, true)],
            vec![("funcB".to_string(), 2, Some("funcA".to_string()))],
            FileImports::default(),
        );
        builder.add_file(
            Path::new("/b.js"),
            Language::JavaScript,
            vec![("funcB".to_string(), 1, true)],
            vec![("funcC".to_string(), 2, Some("funcB".to_string()))],
            FileImports::default(),
        );
        builder.add_file(
            Path::new("/c.js"),
            Language::JavaScript,
            vec![("funcC".to_string(), 1, true)],
            vec![],
            FileImports::default(),
        );
        let graph = builder.build();
        assert!(graph.is_reachable(Path::new("/a.js"), "funcA", Path::new("/c.js"), "funcC"));
        assert!(!graph.is_reachable(Path::new("/c.js"), "funcC", Path::new("/a.js"), "funcA"));
    }
}