use std::path::Path;
use streaming_iterator::StreamingIterator;
use tree_sitter::{Query, QueryCursor};
use crate::dfg::types::DFGInfo;
use crate::error::{Result, BrrrError};
use crate::lang::LanguageRegistry;
use crate::util::format_query_error;
/// Builds a data-flow graph ([`DFGInfo`]) for a single named function, taken
/// either from a file on disk or from an in-memory source string.
///
/// This is a stateless unit struct: it carries no data, and every operation is
/// an associated function (none of them take `self`).
#[allow(dead_code)]
pub struct DfgBuilder;
/// Associated functions that locate a named function in parsed source code and
/// hand it to the language backend to build its data-flow graph.
#[allow(dead_code)]
impl DfgBuilder {
    /// Largest number of net `..` levels a file path may climb before
    /// [`Self::extract_from_file_with_language`] rejects it as path traversal.
    const MAX_PARENT_DEPTH: i32 = 10;

    /// Extracts the data-flow graph of `function` from the file at `file`,
    /// detecting the language from the path via the global language registry.
    ///
    /// # Errors
    ///
    /// Same as [`Self::extract_from_file_with_language`].
    pub fn extract_from_file(file: &str, function: &str) -> Result<DFGInfo> {
        Self::extract_from_file_with_language(file, function, None)
    }

    /// Extracts the data-flow graph of `function` from the file at `file`.
    ///
    /// When `language` is `Some`, it names the language to use. Otherwise the
    /// language is detected from the path.
    ///
    /// # Errors
    ///
    /// - [`BrrrError::PathTraversal`] if `file` contains a NUL byte, or climbs
    ///   more than 10 net `..` levels (see `MAX_PARENT_DEPTH`).
    /// - [`BrrrError::UnsupportedLanguage`] if `language` is unknown, or if no
    ///   language can be detected for the path. In the detection case the
    ///   payload is the file extension, or `"unknown"` if there is none.
    /// - An I/O error (built by `BrrrError::io_with_path`) if the file cannot be read.
    /// - Any error returned by the language's `parser_for_path`.
    /// - [`BrrrError::Parse`] if tree-sitter does not produce a tree.
    /// - [`BrrrError::FunctionNotFound`] if no function or method named `function` exists.
    pub fn extract_from_file_with_language(
        file: &str,
        function: &str,
        language: Option<&str>,
    ) -> Result<DFGInfo> {
        Self::validate_path(file)?;

        let path = Path::new(file);
        let registry = LanguageRegistry::global();
        let lang = match language {
            Some(lang_name) => registry
                .get_by_name(lang_name)
                .ok_or_else(|| BrrrError::UnsupportedLanguage(lang_name.to_string()))?,
            None => registry.detect_language(path).ok_or_else(|| {
                // Report the extension so the caller can see what failed to match.
                BrrrError::UnsupportedLanguage(
                    path.extension()
                        .and_then(|e| e.to_str())
                        .unwrap_or("unknown")
                        .to_string(),
                )
            })?,
        };

        let source = std::fs::read(path).map_err(|e| BrrrError::io_with_path(e, path))?;
        let mut parser = lang.parser_for_path(path)?;
        let tree = parser.parse(&source, None).ok_or_else(|| BrrrError::Parse {
            file: file.to_string(),
            message: "Failed to parse file".to_string(),
        })?;

        let function_node = Self::find_function_node(&tree, &source, lang, function)?;
        lang.build_dfg(function_node, &source)
    }

    /// Extracts the data-flow graph of `function` from an in-memory `source`
    /// string written in the language named `language`.
    ///
    /// No path validation is performed, because nothing is read from disk.
    ///
    /// # Errors
    ///
    /// - [`BrrrError::UnsupportedLanguage`] if `language` is unknown.
    /// - Any error returned by the language's `parser`.
    /// - [`BrrrError::Parse`] if tree-sitter does not produce a tree.
    /// - [`BrrrError::FunctionNotFound`] if no function or method named `function` exists.
    pub fn extract_from_source(source: &str, language: &str, function: &str) -> Result<DFGInfo> {
        let registry = LanguageRegistry::global();
        let lang = registry
            .get_by_name(language)
            .ok_or_else(|| BrrrError::UnsupportedLanguage(language.to_string()))?;

        let source_bytes = source.as_bytes();
        let mut parser = lang.parser()?;
        let tree = parser.parse(source_bytes, None).ok_or_else(|| BrrrError::Parse {
            file: "<string>".to_string(),
            message: "Failed to parse source".to_string(),
        })?;

        let function_node = Self::find_function_node(&tree, source_bytes, lang, function)?;
        lang.build_dfg(function_node, source_bytes)
    }

    /// Rejects paths that contain a NUL byte, or that climb too far with `..`.
    ///
    /// The components are walked while keeping a running count: +1 for each
    /// normal component and -1 for each `..`. The path is rejected once the
    /// count drops below `-MAX_PARENT_DEPTH`. This is a coarse guard: root and
    /// prefix components are ignored, so absolute paths are not restricted.
    fn validate_path(file: &str) -> Result<()> {
        if file.contains('\0') {
            return Err(BrrrError::PathTraversal {
                target: "<contains null byte>".to_string(),
                base: "<DFG extraction>".to_string(),
            });
        }

        let mut depth: i32 = 0;
        for component in Path::new(file).components() {
            match component {
                std::path::Component::ParentDir => {
                    depth -= 1;
                    if depth < -Self::MAX_PARENT_DEPTH {
                        return Err(BrrrError::PathTraversal {
                            target: file.to_string(),
                            base: "<DFG extraction>".to_string(),
                        });
                    }
                }
                std::path::Component::Normal(_) => depth += 1,
                _ => {}
            }
        }
        Ok(())
    }

    /// Finds the syntax node of the function named `function_name`.
    ///
    /// Lookup order:
    /// 1. Outermost matches of the language's function query, in document
    ///    order. A match whose byte range overlaps an earlier outermost match is
    ///    not considered here. This collapses the duplicate captures produced
    ///    for decorated definitions.
    /// 2. Methods found inside classes (see [`Self::find_method_node`]).
    /// 3. The first *nested* match with the requested name, such as a function
    ///    defined inside another function. Tier 1 skips these, so without this
    ///    fallback they could never be found.
    ///
    /// A `decorated_definition` result is unwrapped to its inner
    /// `function_definition`.
    ///
    /// # Errors
    ///
    /// - [`BrrrError::TreeSitter`] if the function query fails to compile.
    /// - Otherwise, whatever [`Self::find_method_node`] returns when no tier
    ///   matches.
    fn find_function_node<'a>(
        tree: &'a tree_sitter::Tree,
        source: &'a [u8],
        lang: &dyn crate::lang::Language,
        function_name: &str,
    ) -> Result<tree_sitter::Node<'a>> {
        let query_str = lang.function_query();
        let ts_lang = tree.language();
        let query = Query::new(&ts_lang, query_str).map_err(|e| {
            BrrrError::TreeSitter(format_query_error(lang.name(), "function", query_str, &e))
        })?;
        let function_capture_idx = query.capture_index_for_name("function");
        let name_capture_idx = query.capture_index_for_name("name");

        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), source);

        // Byte ranges of the outermost functions seen so far.
        let mut seen_ranges: Vec<(usize, usize)> = Vec::new();
        // First nested (overlapping) function with the requested name; used only
        // if neither outermost functions nor class methods match.
        let mut nested_match: Option<tree_sitter::Node<'a>> = None;

        while let Some(match_) = matches.next() {
            // Without an explicit @function capture, fall back to the first capture.
            let func_node = match function_capture_idx {
                Some(idx) => match_.captures.iter().find(|c| c.index == idx).map(|c| c.node),
                None => match_.captures.first().map(|c| c.node),
            };
            let name_node = name_capture_idx
                .and_then(|idx| match_.captures.iter().find(|c| c.index == idx))
                .map(|c| c.node);
            let (Some(func_node), Some(name_node)) = (func_node, name_node) else {
                continue;
            };

            let is_target = name_node.utf8_text(source).unwrap_or("") == function_name;
            let (start, end) = (func_node.start_byte(), func_node.end_byte());

            // Half-open interval overlap: adjacent functions do not overlap.
            if seen_ranges.iter().any(|&(s, e)| start < e && s < end) {
                if is_target && nested_match.is_none() {
                    nested_match = Some(func_node);
                }
                continue;
            }
            seen_ranges.push((start, end));

            if is_target {
                return Ok(Self::unwrap_decorated_function(func_node));
            }
        }

        Self::find_method_node(tree, source, lang, function_name).or_else(move |err| {
            nested_match
                .map(Self::unwrap_decorated_function)
                .ok_or(err)
        })
    }

    /// Returns the `function_definition` child of a `decorated_definition`
    /// node (Python's decorator wrapper). Any other node is returned unchanged,
    /// as is a decorated node with no such child.
    fn unwrap_decorated_function(node: tree_sitter::Node) -> tree_sitter::Node {
        if node.kind() != "decorated_definition" {
            return node;
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if child.kind() == "function_definition" {
                return child;
            }
        }
        node
    }

    /// Searches every class matched by the language's class query for a
    /// method named `method_name`, returning the first hit.
    ///
    /// # Errors
    ///
    /// - [`BrrrError::TreeSitter`] if the class query fails to compile.
    /// - [`BrrrError::FunctionNotFound`] if no class contains such a method.
    fn find_method_node<'a>(
        tree: &'a tree_sitter::Tree,
        source: &'a [u8],
        lang: &dyn crate::lang::Language,
        method_name: &str,
    ) -> Result<tree_sitter::Node<'a>> {
        let class_query_str = lang.class_query();
        let ts_lang = tree.language();
        let query = Query::new(&ts_lang, class_query_str).map_err(|e| {
            BrrrError::TreeSitter(format_query_error(lang.name(), "class", class_query_str, &e))
        })?;
        let class_capture_idx = query.capture_index_for_name("class");

        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), source);

        while let Some(match_) = matches.next() {
            // Without an explicit @class capture, fall back to the first capture.
            let class_node = match class_capture_idx {
                Some(idx) => match_.captures.iter().find(|c| c.index == idx).map(|c| c.node),
                None => match_.captures.first().map(|c| c.node),
            };
            let method_node = class_node
                .and_then(|class_node| Self::find_method_in_class(class_node, source, method_name));
            if let Some(method_node) = method_node {
                return Ok(method_node);
            }
        }

        Err(BrrrError::FunctionNotFound(format!(
            "Function '{}' not found in file",
            method_name
        )))
    }

    /// Looks through the direct children of `class_node`'s body for a method
    /// named `method_name`.
    ///
    /// Recognised method kinds are `function_definition`, `method_definition`
    /// and `method_declaration`. For `decorated_definition` children, the inner
    /// `function_definition` is checked and returned.
    fn find_method_in_class<'a>(
        class_node: tree_sitter::Node<'a>,
        source: &[u8],
        method_name: &str,
    ) -> Option<tree_sitter::Node<'a>> {
        let body_node = Self::find_class_body(class_node)?;
        let is_named = |node: tree_sitter::Node<'a>| {
            Self::extract_function_name(node, source).as_deref() == Some(method_name)
        };

        let mut cursor = body_node.walk();
        for child in body_node.children(&mut cursor) {
            match child.kind() {
                "function_definition" | "method_definition" | "method_declaration" => {
                    if is_named(child) {
                        return Some(child);
                    }
                }
                "decorated_definition" => {
                    let mut inner_cursor = child.walk();
                    for inner in child.children(&mut inner_cursor) {
                        if inner.kind() == "function_definition" && is_named(inner) {
                            return Some(inner);
                        }
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Returns the first direct child of `class_node` that is a class body
    /// (`block`, `class_body` or `declaration_list`), if any.
    fn find_class_body(class_node: tree_sitter::Node) -> Option<tree_sitter::Node> {
        let mut cursor = class_node.walk();
        for child in class_node.children(&mut cursor) {
            if matches!(child.kind(), "block" | "class_body" | "declaration_list") {
                return Some(child);
            }
        }
        None
    }

    /// Returns the text of a definition's name.
    ///
    /// The name is the first direct child that is an `identifier`,
    /// `property_identifier` or `field_identifier`. For a `name` child, the
    /// first `identifier` inside it is used instead. Returns `None` if there is
    /// no such child, or if the name text is not valid UTF-8.
    fn extract_function_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            match child.kind() {
                "identifier" | "property_identifier" | "field_identifier" => {
                    return child.utf8_text(source).ok().map(str::to_owned);
                }
                "name" => {
                    let mut inner = child.walk();
                    for inner_child in child.children(&mut inner) {
                        if inner_child.kind() == "identifier" {
                            return inner_child.utf8_text(source).ok().map(str::to_owned);
                        }
                    }
                }
                _ => {}
            }
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `content` to a fresh temporary file whose name ends with `extension`.
    fn create_temp_file(content: &str, extension: &str) -> NamedTempFile {
        let mut file = tempfile::Builder::new()
            .suffix(extension)
            .tempfile()
            .expect("failed to create temp file");
        file.write_all(content.as_bytes())
            .expect("failed to write temp file");
        file
    }

    /// Runs file-based DFG extraction of `function` against a temporary file.
    fn extract_file(file: &NamedTempFile, function: &str) -> Result<DFGInfo> {
        let path = file.path().to_str().expect("temp path is not valid UTF-8");
        DfgBuilder::extract_from_file(path, function)
    }

    #[test]
    fn test_extract_simple_function_dfg() {
        let source = r#"
def example(x, y):
    z = x + y
    return z
"#;
        let file = create_temp_file(source, ".py");
        let dfg = extract_file(&file, "example").expect("should extract `example` DFG");
        assert_eq!(dfg.function_name, "example");
        assert!(dfg.definitions.contains_key("x"));
        assert!(dfg.definitions.contains_key("y"));
        assert!(dfg.definitions.contains_key("z"));
        assert!(dfg.uses.contains_key("x"));
        assert!(dfg.uses.contains_key("y"));
        assert!(dfg.uses.contains_key("z"));
        assert!(!dfg.edges.is_empty());
    }

    #[test]
    fn test_extract_function_with_mutation() {
        let source = r#"
def accumulate(items):
    total = 0
    for item in items:
        total += item
    return total
"#;
        let file = create_temp_file(source, ".py");
        let dfg = extract_file(&file, "accumulate").expect("should extract `accumulate` DFG");
        assert_eq!(dfg.function_name, "accumulate");
        assert!(dfg.definitions.contains_key("total"));
        assert!(dfg.uses.contains_key("total"));
        assert!(dfg.definitions.contains_key("item"));
        assert!(dfg.uses.contains_key("item"));
    }

    #[test]
    fn test_extract_function_with_conditional() {
        let source = r#"
def process(x):
    if x > 0:
        result = x * 2
    else:
        result = 0
    return result
"#;
        let file = create_temp_file(source, ".py");
        let dfg = extract_file(&file, "process").expect("should extract `process` DFG");
        assert!(dfg.definitions.contains_key("result"));
        assert!(dfg.uses.contains_key("x"));
    }

    #[test]
    fn test_extract_decorated_function() {
        let source = r#"
@staticmethod
def my_static(x):
    y = x + 1
    return y
"#;
        let file = create_temp_file(source, ".py");
        let dfg = extract_file(&file, "my_static").expect("should extract `my_static` DFG");
        assert_eq!(dfg.function_name, "my_static");
        assert!(dfg.definitions.contains_key("x"));
        assert!(dfg.definitions.contains_key("y"));
    }

    #[test]
    fn test_extract_class_method() {
        let source = r#"
class Calculator:
    def add(self, a, b):
        result = a + b
        return result
"#;
        let file = create_temp_file(source, ".py");
        let dfg = extract_file(&file, "add").expect("should extract `add` DFG");
        assert_eq!(dfg.function_name, "add");
        assert!(dfg.definitions.contains_key("self"));
        assert!(dfg.definitions.contains_key("a"));
        assert!(dfg.definitions.contains_key("b"));
        assert!(dfg.definitions.contains_key("result"));
        assert!(dfg.uses.contains_key("result"));
    }

    #[test]
    fn test_function_not_found() {
        let source = r#"
def existing_function():
    pass
"#;
        let file = create_temp_file(source, ".py");
        match extract_file(&file, "nonexistent") {
            Err(BrrrError::FunctionNotFound(_)) => {}
            Err(e) => panic!("expected FunctionNotFound, got: {:?}", e),
            Ok(_) => panic!("expected FunctionNotFound, got Ok"),
        }
    }

    #[test]
    fn test_extract_from_source() {
        let source = r#"
def multiply(a, b):
    return a * b
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "multiply")
            .expect("should extract `multiply` DFG");
        assert_eq!(dfg.function_name, "multiply");
        assert!(dfg.definitions.contains_key("a"));
        assert!(dfg.definitions.contains_key("b"));
        assert!(dfg.uses.contains_key("a"));
        assert!(dfg.uses.contains_key("b"));
    }

    #[test]
    fn test_unsupported_language() {
        let file = create_temp_file("content", ".xyz");
        match extract_file(&file, "func") {
            Err(BrrrError::UnsupportedLanguage(_)) => {}
            Err(e) => panic!("expected UnsupportedLanguage, got: {:?}", e),
            Ok(_) => panic!("expected UnsupportedLanguage, got Ok"),
        }
    }

    #[test]
    fn test_backward_slice() {
        let source = r#"
def compute(x):
    a = x + 1
    b = a * 2
    c = b + x
    return c
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "compute")
            .expect("should extract `compute` DFG");
        let slice = dfg.backward_slice(5);
        assert!(!slice.is_empty());
    }

    #[test]
    fn test_forward_slice() {
        let source = r#"
def compute(x):
    a = x + 1
    b = a * 2
    return b
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "compute")
            .expect("should extract `compute` DFG");
        let slice = dfg.forward_slice(2);
        assert!(!slice.is_empty());
    }

    #[test]
    fn test_multiple_assignment_targets() {
        let source = r#"
def swap(pair):
    a, b = pair
    return b, a
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "swap")
            .expect("should extract `swap` DFG");
        assert!(dfg.definitions.contains_key("a"));
        assert!(dfg.definitions.contains_key("b"));
        assert!(dfg.uses.contains_key("a"));
        assert!(dfg.uses.contains_key("b"));
        assert!(dfg.uses.contains_key("pair"));
    }

    #[test]
    fn test_comprehension_variables() {
        let source = r#"
def squared(items):
    result = [x * x for x in items]
    return result
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "squared")
            .expect("should extract `squared` DFG");
        assert!(dfg.definitions.contains_key("x"));
        assert!(dfg.uses.contains_key("x"));
        assert!(dfg.uses.contains_key("items"));
    }

    #[test]
    fn test_with_statement() {
        let source = r#"
def read_file(path):
    with open(path) as f:
        content = f.read()
    return content
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "read_file")
            .expect("should extract `read_file` DFG");
        assert!(dfg.definitions.contains_key("f"));
        assert!(dfg.uses.contains_key("f"));
        assert!(dfg.definitions.contains_key("content"));
    }

    #[test]
    fn test_try_except_variables() {
        let source = r#"
def safe_parse(text):
    try:
        result = int(text)
    except ValueError as e:
        result = 0
    return result
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "safe_parse")
            .expect("should extract `safe_parse` DFG");
        assert!(dfg.definitions.contains_key("e"));
        assert!(dfg.definitions.contains_key("result"));
    }

    #[test]
    fn test_variables_method() {
        let source = r#"
def example(x, y):
    z = x + y
    w = z * 2
    return w
"#;
        let dfg = DfgBuilder::extract_from_source(source, "python", "example")
            .expect("should extract `example` DFG");
        let vars = dfg.variables();
        assert!(vars.contains(&"x"));
        assert!(vars.contains(&"y"));
        assert!(vars.contains(&"z"));
        assert!(vars.contains(&"w"));
    }

    /// Pins the half-open interval overlap predicate used to dedupe function matches.
    #[test]
    fn test_overlap_detection_algorithm() {
        fn overlaps(start: usize, end: usize, s: usize, e: usize) -> bool {
            start < e && s < end
        }
        assert!(overlaps(10, 20, 15, 25), "Partial overlap should be detected");
        assert!(
            overlaps(15, 25, 10, 20),
            "Partial overlap should be detected (reversed)"
        );
        assert!(overlaps(10, 30, 15, 20), "Containment should be detected");
        assert!(
            overlaps(15, 20, 10, 30),
            "Containment should be detected (reversed)"
        );
        assert!(
            !overlaps(10, 20, 20, 30),
            "Adjacent intervals should not overlap"
        );
        assert!(
            !overlaps(10, 20, 25, 30),
            "Disjoint intervals should not overlap"
        );
        assert!(
            !overlaps(25, 30, 10, 20),
            "Disjoint intervals should not overlap (reversed)"
        );
        assert!(overlaps(10, 20, 10, 20), "Same interval should overlap");
        assert!(
            overlaps(10, 20, 19, 25),
            "Should overlap when ranges share interior point"
        );
    }

    #[test]
    fn test_decorated_functions_no_duplicates() {
        let source = r#"
@decorator1
def func1(x):
    y = x + 1
    return y

@decorator2
@decorator3
def func2(x):
    if x > 0:
        result = x
    else:
        result = 0
    return result

def plain_func(x):
    return x * 2
"#;
        let dfg1 = DfgBuilder::extract_from_source(source, "python", "func1")
            .expect("Should extract func1 DFG");
        assert_eq!(dfg1.function_name, "func1");
        assert!(dfg1.definitions.contains_key("x"));
        assert!(dfg1.definitions.contains_key("y"));

        let dfg2 = DfgBuilder::extract_from_source(source, "python", "func2")
            .expect("Should extract func2 DFG");
        assert_eq!(dfg2.function_name, "func2");
        assert!(dfg2.definitions.contains_key("x"));
        assert!(dfg2.definitions.contains_key("result"));

        let dfg3 = DfgBuilder::extract_from_source(source, "python", "plain_func")
            .expect("Should extract plain_func DFG");
        assert_eq!(dfg3.function_name, "plain_func");
    }

    #[test]
    fn test_extract_from_file_rejects_path_traversal() {
        let malicious_path = "../../../../../../../../../../../etc/passwd";
        match DfgBuilder::extract_from_file(malicious_path, "main") {
            Err(BrrrError::PathTraversal { .. })
            | Err(BrrrError::Io(_))
            | Err(BrrrError::UnsupportedLanguage(_)) => {}
            Err(e) => panic!(
                "Expected PathTraversal, Io, or UnsupportedLanguage error, got: {:?}",
                e
            ),
            Ok(_) => panic!("Expected PathTraversal, Io, or UnsupportedLanguage error, got Ok"),
        }
    }

    #[test]
    fn test_extract_from_file_allows_valid_paths() {
        let source = r#"
def example(x, y):
    z = x + y
    return z
"#;
        let file = create_temp_file(source, ".py");
        if let Err(e) = extract_file(&file, "example") {
            panic!("Valid path should be accepted, got: {:?}", e);
        }
    }

    #[test]
    fn test_extract_from_source_bypasses_path_validation() {
        let source = r#"
def test():
    return 42
"#;
        if let Err(e) = DfgBuilder::extract_from_source(source, "python", "test") {
            panic!("extract_from_source should work with string input, got: {:?}", e);
        }
    }
}