use crate::cg::{CallGraph, EdgeType, Node, NodeType};
use crate::reachability::{NodeId, ReachabilityAnalyzer};
use std::collections::{HashMap, HashSet};
/// Storage variables accessed on behalf of one analysed node, as populated by
/// [`analyze_storage_access`].
///
/// Both sets only ever contain ids of `NodeType::StorageVariable` nodes; an id may
/// appear in both sets when the variable is read and written.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct StorageAccessSummary {
    /// Ids of storage-variable nodes reached through a `StorageRead` edge.
    pub reads: HashSet<NodeId>,
    /// Ids of storage-variable nodes reached through a `StorageWrite` edge.
    pub writes: HashSet<NodeId>,
}
/// Collects the storage variables read and written for each entry point of `graph`.
///
/// Graph traversal is delegated to `ReachabilityAnalyzer::analyze_entry_points`. This
/// function only supplies:
/// * which nodes count as function-like (functions, modifiers and constructors), and
/// * how a single function-like node contributes to a summary: every outgoing
///   `StorageRead` / `StorageWrite` edge whose target is a `StorageVariable` node adds
///   that variable's id to `reads` / `writes` respectively.
///
/// Judging by the tests, the returned map is keyed by entry-point node id. NOTE(review):
/// the exact entry-point and aggregation semantics live in `ReachabilityAnalyzer` —
/// confirm there.
pub fn analyze_storage_access(graph: &CallGraph) -> HashMap<NodeId, StorageAccessSummary> {
    // Node kinds the analyzer should treat as function-like.
    fn is_callable(node: &Node) -> bool {
        matches!(
            node.node_type,
            NodeType::Function | NodeType::Modifier | NodeType::Constructor
        )
    }

    // Records into `summary` the storage variables that `func` itself reads or writes,
    // i.e. the targets of storage edges leaving `func`. Scans the full edge list once
    // per call.
    fn record_direct_accesses(func: &Node, summary: &mut StorageAccessSummary, cg: &CallGraph) {
        for edge in &cg.edges {
            if edge.source_node_id != func.id {
                continue;
            }
            let target = match cg.nodes.get(edge.target_node_id) {
                Some(node) if node.node_type == NodeType::StorageVariable => node,
                _ => continue,
            };
            let bucket = match edge.edge_type {
                EdgeType::StorageRead => &mut summary.reads,
                EdgeType::StorageWrite => &mut summary.writes,
                _ => continue,
            };
            bucket.insert(target.id);
        }
    }

    ReachabilityAnalyzer::new().analyze_entry_points(
        graph,
        &is_callable,
        &record_direct_accesses,
        StorageAccessSummary::default,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reachability::tests::create_test_graph_for_reachability;
    use std::collections::HashSet;

    /// Returns the id of the fixture node called `name`.
    ///
    /// Panics with the missing name, rather than a bare `unwrap` failure, so a broken
    /// fixture is diagnosable from the test output.
    fn node_id(graph: &CallGraph, name: &str) -> NodeId {
        graph
            .iter_nodes()
            .find(|n| n.name == name)
            .unwrap_or_else(|| panic!("fixture graph has no node named `{}`", name))
            .id
    }

    /// Builds the expected id set from a slice of ids.
    fn id_set(ids: &[NodeId]) -> HashSet<NodeId> {
        ids.iter().copied().collect()
    }

    #[test]
    fn test_analyze_storage_access_basic() {
        let graph = create_test_graph_for_reachability();
        let results = analyze_storage_access(&graph);
        assert_eq!(results.len(), 2, "Expected 2 entry points (a_pub_func, b_pub_func)");

        let a_pub_func_id = node_id(&graph, "a_pub_func");
        let b_pub_func_id = node_id(&graph, "b_pub_func");
        let storage_var1_id = node_id(&graph, "var1");
        let storage_var2_id = node_id(&graph, "var2");
        let storage_var3_id = node_id(&graph, "var3");

        let summary_a = results.get(&a_pub_func_id).expect("Summary for a_pub_func missing");
        assert_eq!(
            summary_a.reads,
            id_set(&[storage_var1_id, storage_var2_id]),
            "Mismatch in reads for a_pub_func"
        );
        assert_eq!(
            summary_a.writes,
            id_set(&[storage_var1_id, storage_var3_id]),
            "Mismatch in writes for a_pub_func"
        );

        let summary_b = results.get(&b_pub_func_id).expect("Summary for b_pub_func missing");
        assert_eq!(
            summary_b.reads,
            id_set(&[storage_var2_id]),
            "Mismatch in reads for b_pub_func"
        );
        assert_eq!(
            summary_b.writes,
            id_set(&[storage_var2_id, storage_var3_id]),
            "Mismatch in writes for b_pub_func"
        );
    }
}