use crate::cg::{
CallGraph,
CallGraphGeneratorContext,
CallGraphGeneratorInput,
CallGraphGeneratorPipeline,
CallGraphGeneratorStep,
EdgeType,
Node,
NodeType,
Visibility,
EVENT_LISTENER_NODE_NAME, EVM_NODE_NAME,
};
use crate::cg_dot;
use crate::parser::parse_solidity;
use anyhow::Result; use crate::parser::get_solidity_language;
use std::collections::HashMap;
use crate::steps::CallsHandling;
use crate::steps::ContractHandling;
fn find_node<'a>(graph: &'a CallGraph, name: &str, contract: Option<&str>) -> Option<&'a Node> {
graph
.iter_nodes()
.find(|n| n.name == name && n.contract_name.as_deref() == contract)
}
fn assert_visibility(node: &Node, expected: Visibility) {
assert_eq!(
node.visibility, expected,
"Node '{}' should have {:?} visibility, but has {:?}",
node.name, expected, node.visibility
);
}
#[test]
fn test_simple_contract_call() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract Simple {
function foo() public pure {}
function bar() private pure {
foo();
}
constructor() {
foo();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 3, "Should find 3 nodes");
assert_eq!(graph.edges.len(), 2, "Should find 2 edges");
let foo_node = find_node(&graph, "foo", Some("Simple")).expect("foo node not found");
let bar_node = find_node(&graph, "bar", Some("Simple")).expect("bar node not found");
let constructor_node =
find_node(&graph, "Simple", Some("Simple")).expect("constructor node not found");
assert_eq!(foo_node.node_type, NodeType::Function);
assert_eq!(bar_node.node_type, NodeType::Function);
assert_eq!(constructor_node.node_type, NodeType::Constructor);
assert_visibility(&foo_node, Visibility::Public);
assert_visibility(&bar_node, Visibility::Private);
assert_visibility(&constructor_node, Visibility::Public);
assert_eq!(graph.nodes[0].id, foo_node.id);
assert_eq!(graph.nodes[1].id, bar_node.id);
assert_eq!(graph.nodes[2].id, constructor_node.id);
assert_eq!(graph.edges[0].source_node_id, bar_node.id);
assert_eq!(graph.edges[0].target_node_id, foo_node.id);
assert_eq!(graph.edges[0].sequence_number, 1, "bar -> foo sequence"); assert_eq!(graph.edges[1].source_node_id, constructor_node.id);
assert_eq!(graph.edges[1].target_node_id, foo_node.id);
assert_eq!(
graph.edges[1].sequence_number,
1, "constructor -> foo sequence"
);
assert_eq!(graph.iter_nodes().count(), 3);
assert_eq!(graph.iter_edges().count(), 2);
Ok(())
}
#[test]
fn test_delete_keyword() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract DeleteTest {
uint256 public myVar; // State variable
// Function that deletes the state variable
function deleteMyVar() public {
delete myVar; // Delete operation
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
3,
"Should find 3 nodes (state var, delete func, default ctor)"
);
let var_node = find_node(&graph, "myVar", Some("DeleteTest"))
.expect("DeleteTest.myVar node missing");
let delete_func_node = find_node(&graph, "deleteMyVar", Some("DeleteTest"))
.expect("DeleteTest.deleteMyVar node missing");
assert_eq!(var_node.node_type, NodeType::StorageVariable);
assert_eq!(delete_func_node.node_type, NodeType::Function);
assert_eq!(graph.edges.len(), 1, "Should find 1 edge (1 write)");
let write_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == delete_func_node.id
&& e.target_node_id == var_node.id
&& e.edge_type == EdgeType::StorageWrite
})
.expect("StorageWrite edge from deleteMyVar to myVar missing");
assert_eq!(write_edge.sequence_number, 1, "StorageWrite (delete) sequence should be 1");
assert!(write_edge.call_site_span.0 > 0, "Delete call site span start should be > 0");
Ok(())
}
#[test]
fn test_interface_call_no_implementation() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
// Interface definition ONLY
interface IAction {
function performAction() external returns (bool);
}
// Contract that uses the interface type
contract ActionCaller {
IAction public actionContract; // State variable of interface type
constructor(address _actionAddress) {
actionContract = IAction(_actionAddress); // Assume setup elsewhere
}
function triggerAction() public returns (bool) {
// Call the interface method directly on the state variable
// No concrete implementation is known within this source unit.
return actionContract.performAction();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
5, "Should find 5 nodes (interface, iface func, contract func, contract ctor, state var)"
);
let caller_trigger_node = find_node(&graph, "triggerAction", Some("ActionCaller"))
.expect("ActionCaller.triggerAction node missing");
let iface_perform_node = find_node(&graph, "performAction", Some("IAction"))
.expect("IAction.performAction node missing");
let _iface_node =
find_node(&graph, "IAction", Some("IAction")).expect("IAction interface node missing"); let _caller_ctor_node = find_node(&graph, "ActionCaller", Some("ActionCaller")) .expect("ActionCaller constructor node missing");
assert_eq!(
graph.edges.len(),
3, "Should find 3 edges (1 call, 1 read, 1 write)"
);
let call_edge = graph
.edges
.iter()
.find(|e| e.source_node_id == caller_trigger_node.id && e.edge_type == EdgeType::Call) .expect("Call edge from triggerAction not found");
assert_eq!(
call_edge.source_node_id, caller_trigger_node.id,
"Edge source should be ActionCaller.triggerAction"
);
assert_eq!(
call_edge.target_node_id, iface_perform_node.id,
"Edge target should be IAction.performAction (the interface method node)"
);
assert_eq!(
call_edge.edge_type,
EdgeType::Call,
"Edge type should be Call"
);
assert_eq!(
call_edge.sequence_number, 2,
"Edge sequence number should be 2 (read + call)"
);
Ok(())
}
#[test]
fn test_contract_inheritance() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract Base {
uint public baseValue;
function baseFunction() public pure returns (uint) {
return 1;
}
function overriddenFunction() public virtual pure returns (uint) {
return 10;
}
}
contract Derived is Base {
uint public derivedValue;
function derivedFunction() public pure returns (uint) {
// Call a function from the base contract
return baseFunction();
}
// Override a function from the base contract
function overriddenFunction() public pure override returns (uint) {
return 20;
}
function callOverridden() public pure returns (uint) {
// Call the overridden version within Derived
return overriddenFunction();
}
function callBaseOverridden() public pure returns (uint) {
// Explicitly call the base version
return Base.overriddenFunction();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert!(
ctx.contract_inherits.contains_key("Derived"),
"Context should contain inheritance info for Derived"
);
let derived_inherits = ctx
.contract_inherits
.get("Derived")
.expect("Derived inheritance info missing");
assert!(
derived_inherits.contains(&"Base".to_string()),
"Derived should inherit from Base"
);
assert_eq!(
derived_inherits.len(),
1,
"Derived should only inherit from Base directly"
);
let _base_ctor_node =
find_node(&graph, "Base", Some("Base")).expect("Base constructor node missing");
let base_func_node =
find_node(&graph, "baseFunction", Some("Base")).expect("Base.baseFunction node missing");
let base_override_node = find_node(&graph, "overriddenFunction", Some("Base"))
.expect("Base.overriddenFunction node missing");
let _derived_ctor_node =
find_node(&graph, "Derived", Some("Derived")).expect("Derived constructor node missing");
let derived_func_node = find_node(&graph, "derivedFunction", Some("Derived"))
.expect("Derived.derivedFunction node missing");
let derived_override_node = find_node(&graph, "overriddenFunction", Some("Derived"))
.expect("Derived.overriddenFunction node missing");
let derived_call_override_node = find_node(&graph, "callOverridden", Some("Derived"))
.expect("Derived.callOverridden node missing");
let derived_call_base_override_node = find_node(&graph, "callBaseOverridden", Some("Derived"))
.expect("Derived.callBaseOverridden node missing");
assert_eq!(
graph.nodes.len(),
10,
"Should find 10 nodes total (+2 state vars)"
);
assert_eq!(graph.edges.len(), 3, "Should find 3 call edges");
let edge_derived_to_base = graph
.edges
.iter()
.find(|e| e.source_node_id == derived_func_node.id && e.target_node_id == base_func_node.id)
.expect("Edge derivedFunction -> baseFunction missing");
assert_eq!(edge_derived_to_base.edge_type, EdgeType::Call);
assert_eq!(edge_derived_to_base.sequence_number, 1);
let edge_call_to_derived_override = graph
.edges
.iter()
.find(|e| {
e.source_node_id == derived_call_override_node.id
&& e.target_node_id == derived_override_node.id
})
.expect("Edge callOverridden -> Derived.overriddenFunction missing");
assert_eq!(edge_call_to_derived_override.edge_type, EdgeType::Call);
assert_eq!(edge_call_to_derived_override.sequence_number, 1);
let edge_call_to_base_override = graph
.edges
.iter()
.find(|e| {
e.source_node_id == derived_call_base_override_node.id
&& e.target_node_id == base_override_node.id
})
.expect("Edge callBaseOverridden -> Base.overriddenFunction missing");
assert_eq!(edge_call_to_base_override.edge_type, EdgeType::Call);
assert_eq!(edge_call_to_base_override.sequence_number, 1);
Ok(())
}
#[test]
fn test_modifier_call() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract Modifiers {
modifier onlyAdmin() {
checkAdmin();
_;
}
function checkAdmin() internal pure {}
function restricted() public onlyAdmin {}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 4, "Should find 4 nodes");
assert_eq!(graph.edges.len(), 1, "Should find 1 edge");
let mod_node = find_node(&graph, "onlyAdmin", Some("Modifiers")).expect("modifier node");
let check_node = find_node(&graph, "checkAdmin", Some("Modifiers")).expect("checkAdmin node");
let restricted_node =
find_node(&graph, "restricted", Some("Modifiers")).expect("restricted node");
assert_eq!(mod_node.node_type, NodeType::Modifier);
assert_eq!(check_node.node_type, NodeType::Function);
assert_eq!(restricted_node.node_type, NodeType::Function);
assert_eq!(graph.nodes[0].id, mod_node.id);
assert_eq!(graph.nodes[1].id, check_node.id);
assert_eq!(graph.nodes[2].id, restricted_node.id);
assert_eq!(graph.edges[0].source_node_id, mod_node.id);
assert_eq!(graph.edges[0].target_node_id, check_node.id);
assert_eq!(
graph.edges[0].sequence_number, 1,
"onlyAdmin -> checkAdmin sequence"
);
Ok(())
}
#[test]
fn test_free_function_call() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
function helper() pure returns (uint) {
return 1;
}
contract Caller {
function callHelper() public pure returns (uint) {
return helper();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 3, "Should find 3 nodes");
assert_eq!(graph.edges.len(), 1, "Should find 1 edge");
let helper_node = find_node(&graph, "helper", None).expect("helper node");
let caller_node = find_node(&graph, "callHelper", Some("Caller")).expect("callHelper node");
assert_eq!(helper_node.node_type, NodeType::Function);
assert_eq!(helper_node.contract_name, None);
assert_eq!(caller_node.node_type, NodeType::Function);
assert_eq!(caller_node.contract_name, Some("Caller".to_string()));
assert_eq!(graph.nodes[0].id, helper_node.id);
assert_eq!(graph.nodes[1].id, caller_node.id);
assert_eq!(graph.edges[0].source_node_id, caller_node.id);
assert_eq!(graph.edges[0].target_node_id, helper_node.id);
assert_eq!(
graph.edges[0].sequence_number, 1,
"callHelper -> helper sequence"
);
Ok(())
}
#[test]
fn test_no_calls() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract NoCalls {
function a() public pure {}
function b() public pure {}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree, solidity_lang: solidity_lang, };
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 3, "Should find 3 nodes");
assert_eq!(graph.edges.len(), 0, "Should find 0 edges");
Ok(())
}
#[test]
fn test_call_order_within_function() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract CallOrder {
function callee1() public pure {}
function callee2() public pure {}
function caller() public pure {
callee2();
callee1();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang: solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 4, "Should find 4 nodes");
assert_eq!(graph.edges.len(), 2, "Should find 2 edges");
let c1_node = find_node(&graph, "callee1", Some("CallOrder")).unwrap();
let c2_node = find_node(&graph, "callee2", Some("CallOrder")).unwrap();
let caller_node = find_node(&graph, "caller", Some("CallOrder")).unwrap();
assert_eq!(graph.nodes[0].id, c1_node.id);
assert_eq!(graph.nodes[1].id, c2_node.id);
assert_eq!(graph.nodes[2].id, caller_node.id);
assert_eq!(
graph.edges[0].source_node_id, caller_node.id,
"Edge 0 source"
);
assert_eq!(
graph.edges[0].target_node_id, c2_node.id,
"Edge 0 target should be callee2 (first call)"
);
assert_eq!(graph.edges[0].sequence_number, 1, "First call sequence");
assert_eq!(
graph.edges[1].source_node_id, caller_node.id,
"Edge 1 source"
);
assert_eq!(
graph.edges[1].target_node_id, c1_node.id,
"Edge 1 target should be callee1 (second call)"
);
assert_eq!(graph.edges[1].sequence_number, 2, "Second call sequence");
Ok(())
}
#[test]
fn test_empty_source() -> Result<()> {
let source = "pragma solidity ^0.8.0;";
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(), tree: ast.tree, solidity_lang: solidity_lang, };
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 0, "Should find 0 nodes");
assert_eq!(graph.edges.len(), 0, "Should find 0 edges");
Ok(())
}
#[test]
fn test_unresolved_call() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract Unresolved {
function callNonExistent() public {
nonExistent();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
2,
"Should find 2 nodes (caller + default ctor)"
);
assert_eq!(
graph.edges.len(),
0,
"Should find 0 edges (call is unresolved)"
);
Ok(())
}
#[test]
fn test_inter_contract_call() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract Counter {
uint count;
function increment() public {
count += 1;
}
}
contract CounterCaller {
Counter public myCounter;
constructor(address counterAddress) {
myCounter = Counter(counterAddress);
}
function callIncrement() public {
myCounter.increment();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 6, "Should find 6 nodes (2 vars, 2 funcs, 2 ctors)");
let counter_inc_node =
find_node(&graph, "increment", Some("Counter")).expect("Counter.increment node not found");
let caller_ctor_node = find_node(&graph, "CounterCaller", Some("CounterCaller"))
.expect("CounterCaller.constructor node not found");
let caller_call_node = find_node(&graph, "callIncrement", Some("CounterCaller"))
.expect("CounterCaller.callIncrement node not found");
assert_eq!(counter_inc_node.node_type, NodeType::Function);
assert_eq!(caller_ctor_node.node_type, NodeType::Constructor);
assert_eq!(caller_call_node.node_type, NodeType::Function);
assert_eq!(
graph.edges.len(),
6,
"Should find 6 edges (constructor->constructor, constructor->myCounter write, callIncrement->myCounter read, callIncrement->increment call, increment->count read, increment->count write)"
);
let call_inc_edge = graph
.edges
.iter()
.find(|e| e.source_node_id == caller_call_node.id && e.edge_type == EdgeType::Call) .expect("Call edge from callIncrement to Counter.increment not found");
assert_eq!(
call_inc_edge.source_node_id, caller_call_node.id,
"Edge source should be callIncrement"
);
assert_eq!(
call_inc_edge.target_node_id, counter_inc_node.id,
"Edge target should be Counter.increment"
);
assert_eq!(
call_inc_edge.sequence_number, 2,
"callIncrement -> increment sequence should be 2 (read + call)"
);
Ok(())
}
#[test]
fn test_dot_escape_string_via_module() {
assert_eq!(cg_dot::escape_dot_string(""), "");
assert_eq!(cg_dot::escape_dot_string("simple"), "simple");
assert_eq!(
cg_dot::escape_dot_string("with \"quotes\""),
"with \\\"quotes\\\""
);
assert_eq!(cg_dot::escape_dot_string("new\nline"), "new\\nline");
assert_eq!(cg_dot::escape_dot_string("back\\slash"), "back\\\\slash");
assert_eq!(cg_dot::escape_dot_string("<html>"), "\\<html\\>");
assert_eq!(cg_dot::escape_dot_string("{record}"), "\\{record\\}");
}
#[test]
fn test_return_boolean_literal() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract BoolReturn {
function returnsTrue() internal pure returns (bool) {
return true; // Return a boolean literal
}
function callsReturnTrue() public pure returns (bool) {
return returnsTrue();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input.clone(), &mut ctx, &mut graph, &config)?;
graph.add_explicit_return_edges(&input, &ctx)?;
assert_eq!(graph.nodes.len(), 3, "Should find 3 nodes");
assert_eq!(
graph.edges.len(),
2,
"Should find 2 edges (1 call, 1 return)"
);
let returns_true_node =
find_node(&graph, "returnsTrue", Some("BoolReturn")).expect("returnsTrue node");
let calls_return_true_node =
find_node(&graph, "callsReturnTrue", Some("BoolReturn")).expect("callsReturnTrue node");
let return_edge = graph
.iter_edges()
.find(|e| {
e.edge_type == EdgeType::Return
&& e.source_node_id == returns_true_node.id
&& e.target_node_id == calls_return_true_node.id
})
.expect("Return edge from returnsTrue to callsReturnTrue not found");
assert_eq!(
return_edge.returned_value,
Some("true".to_string()),
"Return edge should capture 'true'"
);
Ok(())
}
#[test]
fn test_pipeline_execution() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract SimplePipeline {
function foo() public pure {}
function bar() private pure {
foo(); // Call within the contract
}
constructor() {
// No call in constructor for simplicity in this test
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 3, "Pipeline: Should find 3 nodes");
assert_eq!(graph.edges.len(), 1, "Pipeline: Should find 1 edge");
let foo_node =
find_node(&graph, "foo", Some("SimplePipeline")).expect("Pipeline: foo node not found");
let bar_node =
find_node(&graph, "bar", Some("SimplePipeline")).expect("Pipeline: bar node not found");
let constructor_node = find_node(&graph, "SimplePipeline", Some("SimplePipeline"))
.expect("Pipeline: constructor node not found");
assert_eq!(foo_node.node_type, NodeType::Function);
assert_eq!(bar_node.node_type, NodeType::Function);
assert_eq!(constructor_node.node_type, NodeType::Constructor);
assert_visibility(&foo_node, Visibility::Public);
assert_visibility(&bar_node, Visibility::Private);
assert_visibility(&constructor_node, Visibility::Public);
assert_eq!(graph.nodes[0].id, foo_node.id);
assert_eq!(graph.nodes[1].id, bar_node.id);
assert_eq!(graph.nodes[2].id, constructor_node.id);
assert_eq!(graph.edges[0].source_node_id, bar_node.id);
assert_eq!(graph.edges[0].target_node_id, foo_node.id);
assert_eq!(graph.edges[0].edge_type, EdgeType::Call);
assert_eq!(
graph.edges[0].sequence_number, 1,
"Pipeline: bar -> foo sequence"
);
Ok(())
}
#[test]
fn test_pipeline_step_enable_disable() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract EnableDisableTest {
function target() public pure {}
function caller() public pure {
target();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let config = HashMap::new();
let mut ctx_disabled = CallGraphGeneratorContext::default();
let mut graph_disabled = CallGraph::new();
let mut pipeline_disabled = CallGraphGeneratorPipeline::new();
pipeline_disabled.add_step(Box::new(ContractHandling::default()));
pipeline_disabled.add_step(Box::new(CallsHandling::default()));
pipeline_disabled.disable_step("Calls-Handling");
pipeline_disabled.run(
input.clone(),
&mut ctx_disabled,
&mut graph_disabled,
&config,
)?;
assert_eq!(
graph_disabled.nodes.len(),
3,
"Disabled: Should find 3 nodes (ContractHandling ran)"
);
assert_eq!(
graph_disabled.edges.len(),
0,
"Disabled: Should find 0 edges (CallsHandling disabled)"
);
let mut ctx_enabled = CallGraphGeneratorContext::default();
let mut graph_enabled = CallGraph::new();
let mut pipeline_enabled = CallGraphGeneratorPipeline::new();
pipeline_enabled.add_step(Box::new(ContractHandling::default()));
pipeline_enabled.add_step(Box::new(CallsHandling::default()));
pipeline_enabled.run(input, &mut ctx_enabled, &mut graph_enabled, &config)?;
assert_eq!(graph_enabled.nodes.len(), 3, "Enabled: Should find 3 nodes");
assert_eq!(
graph_enabled.edges.len(),
1,
"Enabled: Should find 1 edge (CallsHandling ran)"
);
let target_node = find_node(&graph_enabled, "target", Some("EnableDisableTest"))
.expect("Enabled: target node");
let caller_node = find_node(&graph_enabled, "caller", Some("EnableDisableTest"))
.expect("Enabled: caller node");
assert_eq!(graph_enabled.edges[0].source_node_id, caller_node.id);
assert_eq!(graph_enabled.edges[0].target_node_id, target_node.id);
Ok(())
}
#[test]
fn test_using_for_directive_extraction() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library SafeMath {
function add(uint256 a, uint256 b) internal pure returns (uint256) {
return a + b;
}
}
contract MyContract {
using SafeMath for uint256; // Contract-level using directive
uint256 public value;
function increment(uint256 _amount) public {
value = value.add(_amount); // Call resolution not tested here
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let contract_handler = ContractHandling::default();
contract_handler.generate(input, &mut ctx, &mut graph)?;
assert_eq!(
ctx.using_for_directives.len(),
1,
"Should find exactly one 'using for' directive entry"
);
let expected_key = (Some("MyContract".to_string()), "uint256".to_string());
let expected_value = vec!["SafeMath".to_string()];
assert!(
ctx.using_for_directives.contains_key(&expected_key),
"Context should contain key for (Some(MyContract), uint256)"
);
assert_eq!(
ctx.using_for_directives.get(&expected_key),
Some(&expected_value),
"The value for the key should be vec![\"SafeMath\"]"
);
assert!(
find_node(&graph, "add", Some("SafeMath")).is_some(),
"Library function node should exist"
);
assert!(
find_node(&graph, "increment", Some("MyContract")).is_some(),
"Contract function node should exist"
);
assert!(
find_node(&graph, "MyContract", Some("MyContract")).is_some(),
"Default constructor node should exist"
);
Ok(())
}
#[test]
fn test_library_definition_and_usage() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library MathUtils {
function isEven(uint256 a) internal pure returns (bool) {
return a % 2 == 0;
}
}
contract ExampleContract {
using MathUtils for uint256; // This line is parsed but not yet used for call resolution by CallsHandling
function checkNumberIsEven(uint256 _num) public pure returns (bool) {
// Call resolution for _num.isEven() requires changes in CallsHandling
return _num.isEven();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
4,
"Should find 4 nodes (library, lib func, contract func, default ctor)"
);
let lib_node = find_node(&graph, "MathUtils", Some("MathUtils"))
.expect("Library node MathUtils not found");
assert_eq!(lib_node.node_type, NodeType::Library);
assert_eq!(
lib_node.contract_name,
Some("MathUtils".to_string()),
"Library node should store its own name as scope"
);
assert_visibility(&lib_node, Visibility::Default);
let lib_func_node = find_node(&graph, "isEven", Some("MathUtils"))
.expect("Library function node MathUtils.isEven not found");
assert_eq!(lib_func_node.node_type, NodeType::Function);
assert_eq!(
lib_func_node.contract_name,
Some("MathUtils".to_string()),
"Library function scope"
);
assert_visibility(&lib_func_node, Visibility::Internal);
let contract_func_node = find_node(&graph, "checkNumberIsEven", Some("ExampleContract"))
.expect("Contract function node ExampleContract.checkNumberIsEven not found");
assert_eq!(contract_func_node.node_type, NodeType::Function);
assert_eq!(
contract_func_node.contract_name,
Some("ExampleContract".to_string()),
"Contract function scope"
);
assert_visibility(&contract_func_node, Visibility::Public);
let constructor_node = find_node(&graph, "ExampleContract", Some("ExampleContract"))
.expect("Default constructor node ExampleContract not found");
assert_eq!(constructor_node.node_type, NodeType::Constructor);
assert_visibility(&constructor_node, Visibility::Public);
assert_eq!(
graph.edges.len(),
0,
"Should find 0 edges currently (Parameter type resolution for 'using for' needed)"
);
Ok(())
}
#[test]
fn test_using_for_call_resolution() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library MathUtils {
function isEven(uint256 a) internal pure returns (bool) {
return a % 2 == 0;
}
}
contract ExampleContract {
using MathUtils for uint256; // Directive to be used by CallsHandling
uint256 number; // State variable to call method on
function checkNumberIsEven() public view returns (bool) {
// This call should be resolved via 'using for'
return number.isEven();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
5, "Should find 5 nodes (library, lib func, contract func, default ctor, state var)"
);
let lib_func_node = find_node(&graph, "isEven", Some("MathUtils"))
.expect("Library function node MathUtils.isEven not found");
assert_eq!(lib_func_node.node_type, NodeType::Function);
assert_visibility(&lib_func_node, Visibility::Internal);
let contract_func_node = find_node(&graph, "checkNumberIsEven", Some("ExampleContract"))
.expect("Contract function node ExampleContract.checkNumberIsEven not found");
assert_eq!(contract_func_node.node_type, NodeType::Function);
assert_visibility(&contract_func_node, Visibility::Public);
assert_eq!(
graph.edges.len(),
2,
"Should find exactly 2 edges (1 read, 1 call)" );
let call_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == contract_func_node.id
&& e.target_node_id == lib_func_node.id
&& e.edge_type == EdgeType::Call
})
.expect("Call edge checkNumberIsEven -> isEven not found");
assert_eq!(
call_edge.source_node_id, contract_func_node.id,
"Call edge source should be checkNumberIsEven"
);
assert_eq!(
call_edge.target_node_id, lib_func_node.id,
"Edge target should be MathUtils.isEven"
);
assert_eq!(
call_edge.edge_type,
EdgeType::Call,
"Edge type should be Call"
);
assert_eq!(
call_edge.sequence_number, 2, "Edge sequence number should be 2"
);
let read_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == contract_func_node.id && e.edge_type == EdgeType::StorageRead
})
.expect("StorageRead edge from checkNumberIsEven not found");
let number_var_node = find_node(&graph, "number", Some("ExampleContract"))
.expect("State variable 'number' node not found");
assert_eq!(read_edge.target_node_id, number_var_node.id);
Ok(())
}
#[test]
fn test_interface_definition() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
interface IERC20 {
function totalSupply() external view returns (uint256);
function balanceOf(address account) external view returns (uint256);
function transfer(address recipient, uint256 amount) external returns (bool);
}
contract TokenImplementation is IERC20 {
uint256 private _totalSupply;
mapping(address => uint256) private _balances;
function totalSupply() external view override returns (uint256) {
return _totalSupply;
}
function balanceOf(address account) external view override returns (uint256) {
return _balances[account];
}
function transfer(address recipient, uint256 amount) external override returns (bool) {
// Implementation details omitted
return true;
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert!(
ctx.all_interfaces.contains_key("IERC20"),
"Interface IERC20 should be captured"
);
let interface_functions = ctx
.interface_functions
.get("IERC20")
.expect("IERC20 functions missing");
assert!(
interface_functions.contains(&"totalSupply".to_string()),
"totalSupply function missing"
);
assert!(
interface_functions.contains(&"balanceOf".to_string()),
"balanceOf function missing"
);
assert!(
interface_functions.contains(&"transfer".to_string()),
"transfer function missing"
);
let interface_node = find_node(&graph, "IERC20", Some("IERC20")).expect("IERC20 node missing");
assert_eq!(
interface_node.node_type,
NodeType::Interface,
"Node type should be Interface"
);
let total_supply_node =
find_node(&graph, "totalSupply", Some("IERC20")).expect("totalSupply node missing");
let balance_of_node =
find_node(&graph, "balanceOf", Some("IERC20")).expect("balanceOf node missing");
let transfer_node =
find_node(&graph, "transfer", Some("IERC20")).expect("transfer node missing");
assert_eq!(total_supply_node.node_type, NodeType::Function);
assert_eq!(balance_of_node.node_type, NodeType::Function);
assert_eq!(transfer_node.node_type, NodeType::Function);
let impl_total_supply = find_node(&graph, "totalSupply", Some("TokenImplementation"))
.expect("Implementation totalSupply missing");
let impl_balance_of = find_node(&graph, "balanceOf", Some("TokenImplementation"))
.expect("Implementation balanceOf missing");
let impl_transfer = find_node(&graph, "transfer", Some("TokenImplementation"))
.expect("Implementation transfer missing");
assert_eq!(impl_total_supply.node_type, NodeType::Function);
assert_eq!(impl_balance_of.node_type, NodeType::Function);
assert_eq!(impl_transfer.node_type, NodeType::Function);
Ok(())
}
#[test]
fn test_interface_inheritance() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
// Base interface
interface IERC20 {
function totalSupply() external view returns (uint256);
function balanceOf(address account) external view returns (uint256);
}
// Extended interface that inherits from base
interface IERC20Extended is IERC20 {
function name() external view returns (string memory);
function symbol() external view returns (string memory);
}
// Contract implementing the extended interface
contract CompleteToken is IERC20Extended {
string private _name;
string private _symbol;
uint256 private _totalSupply;
mapping(address => uint256) private _balances;
constructor(string memory name_, string memory symbol_) {
_name = name_;
_symbol = symbol_;
}
function name() external view override returns (string memory) {
return _name;
}
function symbol() external view override returns (string memory) {
return _symbol;
}
function totalSupply() external view override returns (uint256) {
return _totalSupply;
}
function balanceOf(address account) external view override returns (uint256) {
return _balances[account];
}
}
// Contract implementing only the base interface
contract BasicToken is IERC20 {
uint256 private _totalSupply;
mapping(address => uint256) private _balances;
function totalSupply() external view override returns (uint256) {
return _totalSupply;
}
function balanceOf(address account) external view override returns (uint256) {
return _balances[account];
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert!(
ctx.all_interfaces.contains_key("IERC20"),
"Interface IERC20 should be captured"
);
assert!(
ctx.all_interfaces.contains_key("IERC20Extended"),
"Interface IERC20Extended should be captured"
);
println!("ctx.contract_implements: {:?}", ctx.contract_implements);
println!("ctx.interface_inherits: {:?}", ctx.interface_inherits);
let ierc20_extended_implements = ctx
.interface_inherits
.get("IERC20Extended")
.expect("IERC20Extended inheritance relationship missing");
assert!(
ierc20_extended_implements.contains(&"IERC20".to_string()),
"IERC20Extended should implement IERC20"
);
let complete_token_implements = ctx
.contract_implements
.get("CompleteToken")
.expect("CompleteToken implementation relationship missing");
assert!(
complete_token_implements.contains(&"IERC20Extended".to_string()),
"CompleteToken should implement IERC20Extended"
);
let basic_token_implements = ctx
.contract_implements
.get("BasicToken")
.expect("BasicToken implementation relationship missing");
assert!(
basic_token_implements.contains(&"IERC20".to_string()),
"BasicToken should implement IERC20"
);
let ierc20_functions = ctx
.interface_functions
.get("IERC20")
.expect("IERC20 functions missing");
assert!(
ierc20_functions.contains(&"totalSupply".to_string())
&& ierc20_functions.contains(&"balanceOf".to_string()),
"IERC20 functions incomplete"
);
let ierc20_extended_functions = ctx
.interface_functions
.get("IERC20Extended")
.expect("IERC20Extended functions missing");
assert!(
ierc20_extended_functions.contains(&"name".to_string())
&& ierc20_extended_functions.contains(&"symbol".to_string()),
"IERC20Extended functions incomplete"
);
assert!(
find_node(&graph, "IERC20", Some("IERC20")).is_some(),
"IERC20 node missing"
);
assert!(
find_node(&graph, "IERC20Extended", Some("IERC20Extended")).is_some(),
"IERC20Extended node missing"
);
assert!(
find_node(&graph, "CompleteToken", Some("CompleteToken")).is_some(),
"CompleteToken constructor node missing"
);
assert!(
find_node(&graph, "BasicToken", Some("BasicToken")).is_some(),
"BasicToken constructor node missing"
);
assert!(
find_node(&graph, "totalSupply", Some("IERC20")).is_some(),
"IERC20.totalSupply node missing"
);
assert!(
find_node(&graph, "balanceOf", Some("IERC20")).is_some(),
"IERC20.balanceOf node missing"
);
assert!(
find_node(&graph, "name", Some("IERC20Extended")).is_some(),
"IERC20Extended.name node missing"
);
assert!(
find_node(&graph, "symbol", Some("IERC20Extended")).is_some(),
"IERC20Extended.symbol node missing"
);
assert!(
find_node(&graph, "name", Some("CompleteToken")).is_some(),
"CompleteToken.name node missing"
);
assert!(
find_node(&graph, "symbol", Some("CompleteToken")).is_some(),
"CompleteToken.symbol node missing"
);
assert!(
find_node(&graph, "totalSupply", Some("CompleteToken")).is_some(),
"CompleteToken.totalSupply node missing"
);
assert!(
find_node(&graph, "balanceOf", Some("CompleteToken")).is_some(),
"CompleteToken.balanceOf node missing"
);
assert!(
find_node(&graph, "totalSupply", Some("BasicToken")).is_some(),
"BasicToken.totalSupply node missing"
);
assert!(
find_node(&graph, "balanceOf", Some("BasicToken")).is_some(),
"BasicToken.balanceOf node missing"
);
Ok(())
}
#[test]
fn test_interface_invocation_single_implementation() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
interface ICounter {
function increment() external;
}
contract Counter is ICounter {
uint public count;
function increment() external override {
count += 1;
}
}
contract CounterUser {
ICounter public _counter; // State variable of interface type
constructor(address counterAddress) {
_counter = ICounter(counterAddress); // Assume setup elsewhere
}
function useCounter() public {
_counter.increment(); // Call via interface
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(graph.nodes.len(), 8, "Should find 8 nodes (+2 state vars)");
assert!(
ctx.all_interfaces.contains_key("ICounter"),
"Context should contain ICounter interface"
);
assert!(
ctx.interface_functions
.get("ICounter")
.map_or(false, |funcs| funcs.contains(&"increment".to_string())),
"Context should contain ICounter.increment function"
);
assert!(
ctx.contract_implements
.get("Counter")
.map_or(false, |ifaces| ifaces.contains(&"ICounter".to_string())),
"Context should show Counter implements ICounter"
);
assert!(
ctx.state_var_types
.contains_key(&("CounterUser".to_string(), "_counter".to_string())),
"Context should contain type info for CounterUser._counter"
);
assert_eq!(
ctx.state_var_types
.get(&("CounterUser".to_string(), "_counter".to_string())),
Some(&"ICounter".to_string()),
"CounterUser._counter type should be ICounter"
);
let user_use_node = find_node(&graph, "useCounter", Some("CounterUser"))
.expect("CounterUser.useCounter node missing");
let impl_inc_node =
find_node(&graph, "increment", Some("Counter")).expect("Counter.increment node missing");
let _iface_inc_node =
find_node(&graph, "increment", Some("ICounter")).expect("ICounter.increment node missing");
let interface_call_edge = graph
.edges
.iter()
.find(|e| e.source_node_id == user_use_node.id && e.edge_type == EdgeType::Call) .expect("Interface call edge not found");
assert_eq!(
interface_call_edge.source_node_id, user_use_node.id,
"Edge source should be CounterUser.useCounter"
);
assert_eq!(
interface_call_edge.target_node_id, impl_inc_node.id,
"Edge target should be Counter.increment (the implementation)"
);
assert_eq!(
interface_call_edge.edge_type,
EdgeType::Call,
"Edge type should be Call"
);
assert!(
interface_call_edge.sequence_number > 0,
"Edge sequence number should be positive"
);
assert_eq!(
graph.edges.len(),
6, "Should find 6 edges (2 calls, 2 reads, 2 writes)"
);
Ok(())
}
#[test]
fn test_chained_call_resolution() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library SafeMath {
function add(uint256 a, uint256 b) internal pure returns (uint256) { return a + b; }
function sub(uint256 a, uint256 b) internal pure returns (uint256) { return a - b; }
}
contract ChainedCalls {
using SafeMath for uint256;
uint256 public value;
function complexUpdate(uint256 _add, uint256 _sub) public {
// Chained call: value.add(_add).sub(_sub)
value = value.add(_add).sub(_sub);
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
6, "Should find 6 nodes (library, 2 lib funcs, contract func, default ctor, state var)"
);
let lib_add_node =
find_node(&graph, "add", Some("SafeMath")).expect("SafeMath.add node missing");
let lib_sub_node =
find_node(&graph, "sub", Some("SafeMath")).expect("SafeMath.sub node missing");
let contract_update_node = find_node(&graph, "complexUpdate", Some("ChainedCalls"))
.expect("ChainedCalls.complexUpdate node missing");
assert_eq!(
graph.edges.len(),
4,
"Should find 4 edges for the chained call (1 write + 1 read + 2 calls)" );
let edge_to_add = graph
.edges
.iter()
.find(|e| {
e.source_node_id == contract_update_node.id && e.target_node_id == lib_add_node.id
})
.expect("Edge complexUpdate -> add missing");
assert_eq!(edge_to_add.edge_type, EdgeType::Call);
assert_eq!(
edge_to_add.sequence_number, 2, "Edge complexUpdate -> add sequence should be 2"
);
let edge_to_sub = graph
.edges
.iter()
.find(|e| {
e.source_node_id == contract_update_node.id && e.target_node_id == lib_sub_node.id && e.edge_type == EdgeType::Call })
.expect("Edge complexUpdate -> sub missing");
assert_eq!(edge_to_sub.edge_type, EdgeType::Call);
assert_eq!(
edge_to_sub.sequence_number, 3, "Edge complexUpdate -> sub sequence should be 3"
);
Ok(())
}
#[test]
fn test_explicit_return_edge_generation() -> Result<()> {
let source = r#"
pragma solidity ^0.8.0;
contract ReturnTest {
function returnsValue() internal pure returns (uint) {
return 42; // Explicit return with value
}
function returnsNothingExplicitly() internal pure {
return; // Explicit empty return
}
function noReturnStatement() internal pure {
// No explicit return statement
uint x = 1;
}
function caller() public pure {
uint val = returnsValue();
returnsNothingExplicitly();
noReturnStatement();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree.clone(),
solidity_lang: solidity_lang.clone(),
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config: HashMap<String, String> = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
let caller_node_id = find_node(&graph, "caller", Some("ReturnTest"))
.expect("caller node")
.id;
let returns_value_node_id = find_node(&graph, "returnsValue", Some("ReturnTest"))
.expect("returnsValue node")
.id;
let returns_nothing_node_id = find_node(&graph, "returnsNothingExplicitly", Some("ReturnTest"))
.expect("returnsNothingExplicitly node")
.id;
let no_return_node_id = find_node(&graph, "noReturnStatement", Some("ReturnTest"))
.expect("noReturnStatement node")
.id;
let call_to_value_seq = graph
.iter_edges()
.find(|e| e.source_node_id == caller_node_id && e.target_node_id == returns_value_node_id)
.expect("Call edge to returnsValue")
.sequence_number;
let call_to_nothing_seq = graph
.iter_edges()
.find(|e| e.source_node_id == caller_node_id && e.target_node_id == returns_nothing_node_id)
.expect("Call edge to returnsNothingExplicitly")
.sequence_number;
let input_for_returns = CallGraphGeneratorInput {
source: source.to_string(), tree: ast.tree.clone(), solidity_lang, };
graph.add_explicit_return_edges(&input_for_returns, &ctx)?;
assert_eq!(graph.nodes.len(), 5, "Should find 5 nodes");
assert_eq!(
graph.edges.len(),
5,
"Should find 3 call edges and 2 return edges"
);
let return_from_value_edge = graph
.iter_edges()
.find(|e| {
e.edge_type == EdgeType::Return
&& e.source_node_id == returns_value_node_id
&& e.target_node_id == caller_node_id
})
.expect("Return edge from returnsValue not found");
assert_eq!(
return_from_value_edge.sequence_number, call_to_value_seq,
"Sequence number mismatch for returnsValue return"
);
assert_eq!(
return_from_value_edge.returned_value,
Some("42".to_string()),
"Returned value mismatch for returnsValue"
);
assert!(
return_from_value_edge.return_site_span.is_some(),
"Return site span missing for returnsValue"
);
let return_from_nothing_edge = graph
.iter_edges()
.find(|e| {
e.edge_type == EdgeType::Return
&& e.source_node_id == returns_nothing_node_id
&& e.target_node_id == caller_node_id
})
.expect("Return edge from returnsNothingExplicitly not found");
assert_eq!(
return_from_nothing_edge.sequence_number, call_to_nothing_seq,
"Sequence number mismatch for returnsNothingExplicitly return"
);
assert_eq!(
return_from_nothing_edge.returned_value, None,
"Returned value should be None for empty return"
);
assert!(
return_from_nothing_edge.return_site_span.is_some(),
"Return site span missing for returnsNothingExplicitly"
);
let no_return_edge = graph
.iter_edges()
.find(|e| e.edge_type == EdgeType::Return && e.source_node_id == no_return_node_id);
assert!(
no_return_edge.is_none(),
"Should be no return edge from noReturnStatement"
);
Ok(())
}
#[test]
fn test_direct_library_call() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library Lib {
function doSomething() internal pure returns (uint256) {
return 1;
}
}
contract Caller {
function callLib() public pure returns (uint256) {
// Direct library call
return Lib.doSomething();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
4,
"Should find 4 nodes (library, lib func, contract func, default ctor)"
);
let lib_func_node =
find_node(&graph, "doSomething", Some("Lib")).expect("Lib.doSomething node missing");
let contract_call_node =
find_node(&graph, "callLib", Some("Caller")).expect("Caller.callLib node missing");
assert_eq!(
graph.edges.len(),
1,
"Should find exactly 1 edge (callLib -> doSomething)"
);
let edge = &graph.edges[0];
assert_eq!(
edge.source_node_id, contract_call_node.id,
"Edge source should be Caller.callLib"
);
assert_eq!(
edge.target_node_id, lib_func_node.id,
"Edge target should be Lib.doSomething"
);
assert_eq!(edge.edge_type, EdgeType::Call, "Edge type should be Call");
assert_eq!(edge.sequence_number, 1, "Edge sequence number should be 1");
Ok(())
}
#[test]
fn test_chained_library_call_resolution() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library SafeMath {
function mul(uint256 a, uint256 b) internal pure returns (uint256) { return a * b; }
function sub(uint256 a, uint256 b) internal pure returns (uint256) { return a - b; }
}
contract ChainedLibCalls {
using SafeMath for uint256;
uint256 public balance;
// uint256 public amountIn; // Not strictly needed for this test's focus
function complexUpdate() public {
// Direct chained call: result_of_mul.sub(...)
// Mimics balance.mul(1000).sub(amountIn.mul(3)) structure's outer chain
balance = balance.mul(1000).sub(3);
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
6, "Should find 6 nodes (library, 2 lib funcs, contract func, default ctor, state var)"
);
let lib_mul_node =
find_node(&graph, "mul", Some("SafeMath")).expect("SafeMath.mul node missing");
let lib_sub_node =
find_node(&graph, "sub", Some("SafeMath")).expect("SafeMath.sub node missing");
let contract_update_node = find_node(&graph, "complexUpdate", Some("ChainedLibCalls"))
.expect("ChainedLibCalls.complexUpdate node missing");
assert_eq!(
graph.edges.len(),
4,
"Should find 4 edges for the chained library call (1 write + 1 read + 2 calls)" );
let edge_to_mul = graph
.edges
.iter()
.find(|e| {
e.source_node_id == contract_update_node.id && e.target_node_id == lib_mul_node.id
})
.expect("Edge complexUpdate -> mul missing");
assert_eq!(edge_to_mul.edge_type, EdgeType::Call);
assert_eq!(
edge_to_mul.sequence_number, 2, "Edge complexUpdate -> mul sequence should be 2"
);
let edge_to_sub = graph
.edges
.iter()
.find(|e| {
e.source_node_id == contract_update_node.id && e.target_node_id == lib_sub_node.id && e.edge_type == EdgeType::Call })
.expect("Edge complexUpdate -> sub missing");
assert_eq!(edge_to_sub.edge_type, EdgeType::Call);
assert_eq!(
edge_to_sub.sequence_number, 3, "Edge complexUpdate -> sub sequence should be 3"
);
Ok(())
}
#[test]
fn test_interface_call_resolution_factory_pattern() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
// Interface for the action contract
interface IAction {
function performAction() external returns (bool);
}
// Implementation of the action contract
contract ActionImpl is IAction {
function performAction() external override returns (bool) {
// Implementation logic
return true;
}
}
// Interface for the factory
interface IActionFactory {
function createAction() external returns (IAction);
}
// Factory contract that creates ActionImpl instances
contract ActionFactory is IActionFactory {
function createAction() external override returns (IAction) {
// Creates a new instance of ActionImpl
return new ActionImpl();
}
}
// Contract that uses the factory to get an action contract and call it
contract ActionCaller {
address public factoryAddress; // Store factory address
constructor(address _factoryAddress) {
factoryAddress = _factoryAddress;
// Factory interaction moved to triggerAction
}
function triggerAction() public returns (bool) {
// Chained call: Cast address -> call factory -> call action
// This mimics the IUniswapV2Factory(factory).feeTo() pattern
// 1. IActionFactory(factoryAddress) -> Cast
// 2. .createAction() -> Calls ActionFactory.createAction (returns IAction)
// 3. .performAction() -> Calls ActionImpl.performAction on the returned IAction
return IActionFactory(factoryAddress).createAction().performAction();
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
11, "Should find 11 nodes (interfaces, funcs, contracts, ctors, state var)"
);
assert!(
ctx.all_interfaces.contains_key("IAction"),
"Context should contain IAction interface"
);
assert!(
ctx.all_interfaces.contains_key("IActionFactory"),
"Context should contain IActionFactory interface"
);
assert!(
ctx.interface_functions
.get("IAction")
.map_or(false, |funcs| funcs.contains(&"performAction".to_string())),
"Context should contain IAction.performAction function"
);
assert!(
ctx.interface_functions
.get("IActionFactory")
.map_or(false, |funcs| funcs.contains(&"createAction".to_string())),
"Context should contain IActionFactory.createAction function"
);
assert!(
ctx.contract_implements
.get("ActionImpl")
.map_or(false, |ifaces| ifaces.contains(&"IAction".to_string())),
"Context should show ActionImpl implements IAction"
);
assert!(
ctx.contract_implements
.get("ActionFactory")
.map_or(false, |ifaces| ifaces
.contains(&"IActionFactory".to_string())),
"Context should show ActionFactory implements IActionFactory"
);
assert!(
ctx.state_var_types
.contains_key(&("ActionCaller".to_string(), "factoryAddress".to_string())), "Context should contain type info for ActionCaller.factoryAddress"
);
assert_eq!(
ctx.state_var_types
.get(&("ActionCaller".to_string(), "factoryAddress".to_string())), Some(&"address".to_string()), "ActionCaller.factoryAddress type should be address"
);
let caller_trigger_node = find_node(&graph, "triggerAction", Some("ActionCaller"))
.expect("ActionCaller.triggerAction node missing");
let impl_perform_node = find_node(&graph, "performAction", Some("ActionImpl"))
.expect("ActionImpl.performAction node missing");
let _caller_ctor_node = find_node(&graph, "ActionCaller", Some("ActionCaller")) .expect("ActionCaller constructor node missing");
let factory_create_node = find_node(&graph, "createAction", Some("ActionFactory"))
.expect("ActionFactory.createAction node missing");
let impl_ctor_node = find_node(&graph, "ActionImpl", Some("ActionImpl"))
.expect("ActionImpl constructor node missing");
println!("Edges: {:?}", graph.edges);
assert_eq!(
graph.edges.len(),
5,
"Should find 5 edges (trigger->factory.create, factory.create->impl.ctor, trigger->impl.perform, 1 read, 1 write)"
);
let edge_trigger_to_factory = graph
.edges
.iter()
.find(|e| {
e.source_node_id == caller_trigger_node.id && e.target_node_id == factory_create_node.id
})
.expect("Edge ActionCaller.triggerAction -> ActionFactory.createAction missing");
assert_eq!(edge_trigger_to_factory.edge_type, EdgeType::Call);
let edge_factory_to_impl_ctor = graph
.edges
.iter()
.find(|e| {
e.source_node_id == factory_create_node.id && e.target_node_id == impl_ctor_node.id
})
.expect("Edge ActionFactory.createAction -> ActionImpl.ctor missing");
assert_eq!(edge_factory_to_impl_ctor.edge_type, EdgeType::Call);
let edge_trigger_to_impl = graph
.edges
.iter()
.find(|e| {
e.source_node_id == caller_trigger_node.id && e.target_node_id == impl_perform_node.id
})
.expect("Edge ActionCaller.triggerAction -> ActionImpl.performAction missing");
assert_eq!(
edge_trigger_to_impl.source_node_id, caller_trigger_node.id,
"Edge source should be ActionCaller.triggerAction"
);
assert_eq!(
edge_trigger_to_impl.target_node_id, impl_perform_node.id,
"Edge target should be ActionImpl.performAction (the implementation)"
);
assert_eq!(
edge_trigger_to_impl.edge_type,
EdgeType::Call,
"Edge type should be Call"
);
assert_eq!(
edge_trigger_to_factory.sequence_number, 3, "triggerAction -> createAction sequence should be 3"
);
assert_eq!(
edge_trigger_to_impl.sequence_number, 4, "triggerAction -> performAction sequence should be 4"
);
Ok(())
}
#[test]
fn test_argument_capturing() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract CalleeContract {
function externalTarget(uint amount, string memory message) public pure returns (bool) {
return amount > 0 && bytes(message).length > 0;
}
}
contract CallerContract {
CalleeContract public calleeInstance;
constructor(address _callee) {
calleeInstance = CalleeContract(_callee);
}
function internalTarget(uint value) internal pure returns (uint) {
return value * 2;
}
// Scenario 1: Intra-contract call with arguments
function callInternal(uint data) public pure returns (uint) {
return internalTarget(data + 1); // Argument: data + 1
}
// Scenario 2: Contract-to-contract call with arguments
function callExternal(uint num, string memory text) public returns (bool) {
return calleeInstance.externalTarget(num, text); // Arguments: num, text
}
// Scenario 3: Public function (simulating user call) with arguments
function entryPoint(uint startValue, address recipient) public pure {
// Arguments: startValue, recipient
uint _ = startValue; // Use args to avoid warnings
address _ = recipient;
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
let caller_internal_node = find_node(&graph, "callInternal", Some("CallerContract"))
.expect("CallerContract.callInternal node missing");
let target_internal_node = find_node(&graph, "internalTarget", Some("CallerContract"))
.expect("CallerContract.internalTarget node missing");
let caller_external_node = find_node(&graph, "callExternal", Some("CallerContract"))
.expect("CallerContract.callExternal node missing");
let target_external_node = find_node(&graph, "externalTarget", Some("CalleeContract"))
.expect("CalleeContract.externalTarget node missing");
let entry_point_node = find_node(&graph, "entryPoint", Some("CallerContract"))
.expect("CallerContract.entryPoint node missing");
let _caller_ctor_node = find_node(&graph, "CallerContract", Some("CallerContract")) .expect("CallerContract constructor node missing");
let _callee_ctor_node = find_node(&graph, "CalleeContract", Some("CalleeContract")) .expect("CalleeContract constructor node missing");
let intra_contract_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == caller_internal_node.id
&& e.target_node_id == target_internal_node.id
})
.expect("Edge callInternal -> internalTarget missing");
assert_eq!(
intra_contract_edge.argument_names,
Some(vec!["data + 1".to_string()]),
"Intra-contract call arguments mismatch"
);
assert_eq!(intra_contract_edge.sequence_number, 1, "Intra-contract call sequence should be 1");
let inter_contract_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == caller_external_node.id
&& e.target_node_id == target_external_node.id
})
.expect("Edge callExternal -> externalTarget missing");
assert_eq!(
inter_contract_edge.argument_names,
Some(vec!["num".to_string(), "text".to_string()]),
"Inter-contract call arguments mismatch"
);
assert_eq!(inter_contract_edge.sequence_number, 2, "Inter-contract call sequence should be 2");
assert_eq!(entry_point_node.name, "entryPoint");
assert_eq!(
entry_point_node.contract_name,
Some("CallerContract".to_string())
);
assert_eq!(graph.edges.len(), 4, "Expected 4 edges (2 calls + 1 write + 1 read)");
Ok(())
}
#[test]
fn test_simple_emit_statement() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract EventEmitter {
event ValueChanged(uint256 indexed oldValue, uint256 newValue);
uint256 private _value;
function updateValue(uint256 newValue) public {
uint256 oldValue = _value;
_value = newValue;
emit ValueChanged(oldValue, newValue); // Simple emit
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
5, "Should find 5 nodes (updateValue, default ctor, EVM, EventListener, state var)"
);
let update_value_node = find_node(&graph, "updateValue", Some("EventEmitter"))
.expect("EventEmitter.updateValue node missing");
let evm_node = find_node(&graph, EVM_NODE_NAME, None).expect("EVM node missing");
let listener_node =
find_node(&graph, EVENT_LISTENER_NODE_NAME, None).expect("EventListener node missing");
assert_eq!(evm_node.node_type, NodeType::Evm);
assert_eq!(listener_node.node_type, NodeType::EventListener);
assert_eq!(
graph.edges.len(),
4,
"Should find 4 edges (1 read, 1 write, 2 emit)"
);
let edge_func_to_evm = graph
.edges
.iter()
.find(|e| e.source_node_id == update_value_node.id && e.target_node_id == evm_node.id)
.expect("Edge updateValue -> EVM missing");
assert_eq!(edge_func_to_evm.edge_type, EdgeType::Call);
assert_eq!(
edge_func_to_evm.sequence_number, 3,
"Emit sequence number should be 3"
);
assert_eq!(
edge_func_to_evm.event_name,
Some("ValueChanged".to_string()),
"Event name mismatch on func->EVM edge"
);
assert_eq!(
edge_func_to_evm.argument_names,
Some(vec!["oldValue".to_string(), "newValue".to_string()]),
"Arguments mismatch on func->EVM edge"
);
let edge_evm_to_listener = graph
.edges
.iter()
.find(|e| e.source_node_id == evm_node.id && e.target_node_id == listener_node.id)
.expect("Edge EVM -> EventListener missing");
assert_eq!(edge_evm_to_listener.edge_type, EdgeType::Call);
assert_eq!(
edge_evm_to_listener.sequence_number, 4, "Sequence number for EVM->Listener should be 4"
);
assert_eq!(
edge_evm_to_listener.event_name,
Some("ValueChanged".to_string()),
"Event name mismatch on EVM->Listener edge"
);
assert_eq!(
edge_evm_to_listener.argument_names,
Some(vec!["oldValue".to_string(), "newValue".to_string()]),
"Arguments mismatch on EVM->Listener edge"
);
Ok(())
}
#[test]
fn test_interface_call_resolution_factory_pattern_no_return() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
// Interface for the action contract
interface IAction {
function performAction() external returns (bool);
}
// Implementation of the action contract
contract ActionImpl is IAction {
function performAction() external override returns (bool) {
// Implementation logic
return true;
}
}
// Interface for the factory
interface IActionFactory {
function createAction() external returns (IAction);
}
// Factory contract that creates ActionImpl instances
contract ActionFactory is IActionFactory {
function createAction() external override returns (IAction) {
// Creates a new instance of ActionImpl
return new ActionImpl();
}
}
// Contract that uses the factory to get an action contract and call it
contract ActionCaller {
address public factoryAddress; // Store factory address
constructor(address _factoryAddress) {
factoryAddress = _factoryAddress;
}
function triggerAction() public { // Changed: No return value
// Chained call: Cast address -> call factory -> call action
// Result is not returned, just executed.
IActionFactory(factoryAddress).createAction().performAction(); // Changed: No return statement
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
11, "NoReturn: Should find 11 nodes (interfaces, funcs, contracts, ctors, state var)"
);
let caller_trigger_node = find_node(&graph, "triggerAction", Some("ActionCaller"))
.expect("NoReturn: ActionCaller.triggerAction node missing");
let impl_perform_node = find_node(&graph, "performAction", Some("ActionImpl"))
.expect("NoReturn: ActionImpl.performAction node missing");
let _caller_ctor_node = find_node(&graph, "ActionCaller", Some("ActionCaller")) .expect("NoReturn: ActionCaller constructor node missing");
let factory_create_node = find_node(&graph, "createAction", Some("ActionFactory"))
.expect("NoReturn: ActionFactory.createAction node missing");
let impl_ctor_node = find_node(&graph, "ActionImpl", Some("ActionImpl"))
.expect("NoReturn: ActionImpl constructor node missing");
println!("NoReturn Edges: {:?}", graph.edges);
assert_eq!(
graph.edges.len(),
5,
"NoReturn: Should find 5 edges (trigger->factory.create, factory.create->impl.ctor, trigger->impl.perform, 1 read, 1 write)"
);
let edge_trigger_to_factory = graph
.edges
.iter()
.find(|e| {
e.source_node_id == caller_trigger_node.id && e.target_node_id == factory_create_node.id
})
.expect("NoReturn: Edge ActionCaller.triggerAction -> ActionFactory.createAction missing");
assert_eq!(edge_trigger_to_factory.edge_type, EdgeType::Call);
let edge_factory_to_impl_ctor = graph
.edges
.iter()
.find(|e| {
e.source_node_id == factory_create_node.id && e.target_node_id == impl_ctor_node.id
})
.expect("NoReturn: Edge ActionFactory.createAction -> ActionImpl.ctor missing");
assert_eq!(edge_factory_to_impl_ctor.edge_type, EdgeType::Call);
let edge_trigger_to_impl = graph
.edges
.iter()
.find(|e| {
e.source_node_id == caller_trigger_node.id && e.target_node_id == impl_perform_node.id
})
.expect("NoReturn: Edge ActionCaller.triggerAction -> ActionImpl.performAction missing");
assert_eq!(
edge_trigger_to_impl.source_node_id, caller_trigger_node.id,
"NoReturn: Edge source should be ActionCaller.triggerAction"
);
assert_eq!(
edge_trigger_to_impl.target_node_id, impl_perform_node.id,
"NoReturn: Edge target should be ActionImpl.performAction (the implementation)"
);
assert_eq!(
edge_trigger_to_impl.edge_type,
EdgeType::Call,
"NoReturn: Edge type should be Call"
);
assert_eq!(
edge_trigger_to_factory.sequence_number, 3, "NoReturn: triggerAction -> createAction sequence should be 3"
);
assert_eq!(
edge_trigger_to_impl.sequence_number, 4, "NoReturn: triggerAction -> performAction sequence should be 4"
);
Ok(())
}
#[test]
fn test_storage_read_write() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract StorageAccess {
uint256 public myVariable; // State variable
// Function that reads the state variable
function readVariable() public view returns (uint256) {
return myVariable; // Read operation
}
// Function that writes to the state variable
function writeVariable(uint256 newValue) public {
myVariable = newValue; // Write operation
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
4,
"Should find 4 nodes (state var, read func, write func, default ctor)"
);
let var_node = find_node(&graph, "myVariable", Some("StorageAccess"))
.expect("StorageAccess.myVariable node missing");
let read_func_node = find_node(&graph, "readVariable", Some("StorageAccess"))
.expect("StorageAccess.readVariable node missing");
let write_func_node = find_node(&graph, "writeVariable", Some("StorageAccess"))
.expect("StorageAccess.writeVariable node missing");
assert_eq!(var_node.node_type, NodeType::StorageVariable);
assert_eq!(read_func_node.node_type, NodeType::Function);
assert_eq!(write_func_node.node_type, NodeType::Function);
assert_eq!(graph.edges.len(), 2, "Should find 2 edges (1 read, 1 write)");
let read_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == read_func_node.id
&& e.target_node_id == var_node.id
&& e.edge_type == EdgeType::StorageRead
})
.expect("StorageRead edge from readVariable to myVariable missing");
assert_eq!(read_edge.sequence_number, 1, "StorageRead sequence should be 1");
let write_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == write_func_node.id
&& e.target_node_id == var_node.id
&& e.edge_type == EdgeType::StorageWrite
})
.expect("StorageWrite edge from writeVariable to myVariable missing");
assert_eq!(write_edge.sequence_number, 1, "StorageWrite sequence should be 1");
Ok(())
}
#[test]
fn test_library_call_on_return_value() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
library SafeMath {
function sub(uint256 a, uint256 b) internal pure returns (uint256) {
require(b <= a, "SafeMath: subtraction overflow");
return a - b;
}
}
interface IERC20 {
function balanceOf(address account) external view returns (uint256);
}
contract MyContract {
using SafeMath for uint256;
IERC20 public token; // Assume set elsewhere or in constructor
constructor(address _token) {
token = IERC20(_token);
}
function doMath() public view returns (uint256) {
// 1. Call balanceOf on the interface instance
uint256 balance = token.balanceOf(address(this));
// 2. Call sub (from SafeMath via using for) on the returned uint256
uint256 result = balance.sub(1);
return result;
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input.clone(), &mut ctx, &mut graph, &config)?;
let do_math_node_opt = find_node(&graph, "doMath", Some("MyContract"));
if let Some(do_math_node) = do_math_node_opt {
eprintln!("[DEBUG Mermaid Focus] Edges involving doMath (Node ID {}):", do_math_node.id);
for (idx, edge) in graph.edges.iter().enumerate() {
let _source_node_name = graph.nodes.get(edge.source_node_id).map_or("?".to_string(), |n| format!("{}.{}", n.contract_name.as_deref().unwrap_or("Global"), n.name));
let target_node_name = graph.nodes.get(edge.target_node_id).map_or("?".to_string(), |n| format!("{}.{}", n.contract_name.as_deref().unwrap_or("Global"), n.name));
if edge.source_node_id == do_math_node.id || edge.target_node_id == do_math_node.id {
eprintln!(
"[DEBUG Mermaid Focus] Edge Index {}: {} -> {} (Target: '{}'), Type: {:?}, Seq: {}, Args: {:?}, RetVal: {:?}",
idx, edge.source_node_id, edge.target_node_id, target_node_name, edge.edge_type, edge.sequence_number, edge.argument_names, edge.returned_value
);
}
}
} else {
eprintln!("[DEBUG Mermaid Focus] Could not find doMath node!");
}
graph.add_explicit_return_edges(&input, &ctx)?;
assert_eq!(
graph.nodes.len(),
8, "Should find 8 nodes (lib, lib func, iface, iface func, contract, contract func, ctor, state var, Require)"
);
let do_math_node = find_node(&graph, "doMath", Some("MyContract"))
.expect("MyContract.doMath node missing");
let balance_of_node = find_node(&graph, "balanceOf", Some("IERC20"))
.expect("IERC20.balanceOf node missing");
let sub_node =
find_node(&graph, "sub", Some("SafeMath")).expect("SafeMath.sub node missing");
let _token_var_node = find_node(&graph, "token", Some("MyContract")) .expect("MyContract.token state variable node missing");
let _ctor_node = find_node(&graph, "MyContract", Some("MyContract")) .expect("MyContract constructor node missing");
assert_eq!(
graph.edges.len(),
6, "Should find 6 edges (doMath->balanceOf, doMath->sub, doMath->token(R), ctor->token(W), sub->Require, sub->doMath(Ret))"
);
let edge_to_balance_of = graph
.edges
.iter()
.find(|e| {
e.source_node_id == do_math_node.id
&& e.target_node_id == balance_of_node.id
&& e.edge_type == EdgeType::Call
})
.expect("Edge doMath -> IERC20.balanceOf missing");
let edge_to_sub = graph
.edges
.iter()
.find(|e| {
e.source_node_id == do_math_node.id
&& e.target_node_id == sub_node.id
&& e.edge_type == EdgeType::Call
})
.expect("Edge doMath -> SafeMath.sub missing");
assert_eq!(
edge_to_balance_of.sequence_number, 2,
"balanceOf call should be sequence 2"
);
assert_eq!(edge_to_sub.sequence_number, 3, "sub call should be sequence 3");
Ok(())
}
#[test]
fn test_inherited_storage_access() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract BaseStorage {
uint256 public baseVar; // Inherited variable
}
contract DerivedStorage is BaseStorage {
// Reads the inherited variable
function readBase() public view returns (uint256) {
return baseVar; // Read inherited baseVar
}
// Writes to the inherited variable
function writeBase(uint256 newValue) public {
baseVar = newValue; // Write inherited baseVar
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
5,
"Should find 5 nodes (base var, base ctor, derived read, derived write, derived ctor)"
);
let base_var_node = find_node(&graph, "baseVar", Some("BaseStorage"))
.expect("BaseStorage.baseVar node missing");
let derived_read_node = find_node(&graph, "readBase", Some("DerivedStorage"))
.expect("DerivedStorage.readBase node missing");
let derived_write_node = find_node(&graph, "writeBase", Some("DerivedStorage"))
.expect("DerivedStorage.writeBase node missing");
assert_eq!(base_var_node.node_type, NodeType::StorageVariable);
assert_eq!(derived_read_node.node_type, NodeType::Function);
assert_eq!(derived_write_node.node_type, NodeType::Function);
assert_eq!(graph.edges.len(), 2, "Should find 2 edges (1 read, 1 write)");
let read_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == derived_read_node.id
&& e.target_node_id == base_var_node.id
&& e.edge_type == EdgeType::StorageRead
})
.expect("StorageRead edge from readBase to baseVar missing");
assert_eq!(read_edge.sequence_number, 1, "StorageRead sequence should be 1");
let write_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == derived_write_node.id
&& e.target_node_id == base_var_node.id
&& e.edge_type == EdgeType::StorageWrite
})
.expect("StorageWrite edge from writeBase to baseVar missing");
assert_eq!(write_edge.sequence_number, 1, "StorageWrite sequence should be 1");
Ok(())
}
#[test]
fn test_require_statement() -> Result<()> {
let source = r#"
pragma solidity ^0.8.20;
contract RequireTest {
uint256 public threshold = 10;
function checkValue(uint256 value) public view {
require(value > threshold, "Value must be greater than threshold"); // Require call
// ... rest of the function
}
}
"#;
let ast = parse_solidity(source)?;
let solidity_lang = get_solidity_language();
let input = CallGraphGeneratorInput {
source: source.to_string(),
tree: ast.tree,
solidity_lang,
};
let mut ctx = CallGraphGeneratorContext::default();
let mut graph = CallGraph::new();
let config = HashMap::new();
let mut pipeline = CallGraphGeneratorPipeline::new();
pipeline.add_step(Box::new(ContractHandling::default()));
pipeline.add_step(Box::new(CallsHandling::default()));
pipeline.run(input, &mut ctx, &mut graph, &config)?;
assert_eq!(
graph.nodes.len(),
4,
"Should find 4 nodes (checkValue, default ctor, state var, Require)"
);
let check_value_node = find_node(&graph, "checkValue", Some("RequireTest"))
.expect("RequireTest.checkValue node missing");
let _threshold_node = find_node(&graph, "threshold", Some("RequireTest"))
.expect("RequireTest.threshold node missing");
let require_node = graph.nodes.iter()
.find(|n| n.node_type == NodeType::RequireCondition)
.expect("Require node missing");
assert_eq!(require_node.node_type, NodeType::RequireCondition);
assert_eq!(
graph.edges.len(),
2,
"Should find 2 edges (1 read, 1 require call)"
);
let require_edge = graph
.edges
.iter()
.find(|e| {
e.source_node_id == check_value_node.id && e.target_node_id == require_node.id
})
.expect("Edge checkValue -> Require missing");
assert_eq!(require_edge.edge_type, EdgeType::Require); assert_eq!(
require_edge.sequence_number, 2, "Require call sequence number should be 2"
);
assert_eq!(
require_edge.argument_names,
Some(vec![
"value > threshold".to_string(),
"\"Value must be greater than threshold\"".to_string()
]),
"Require call arguments mismatch"
);
assert!(require_edge.call_site_span.0 > 0, "Require call site span start should be > 0");
Ok(())
}