Skip to main content

grapha_core/
classify.rs

1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use crate::extract::ExtractionResult;
5use crate::graph::{EdgeKind, FlowDirection, Graph, NodeRole, TerminalKind};
6
/// Result of classifying a single call target.
///
/// The graph passes in this module copy `direction` and `operation` onto the
/// classified call edge and use `terminal_kind` to tag a node as a terminal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Classification {
    /// Terminal category stored in the node's `NodeRole::Terminal` role.
    pub terminal_kind: TerminalKind,
    /// Flow direction written to the classified edge.
    pub direction: FlowDirection,
    /// Operation label written to the classified edge.
    pub operation: String,
}
13
/// Call-site information handed to a [`Classifier`] together with the call target.
#[derive(Debug, Clone)]
pub struct ClassifyContext {
    /// Id of the node the call edge starts from.
    pub source_node: String,
    /// File of the source node. The graph passes in this module supply an empty
    /// path when the source id does not belong to a known node.
    pub file: PathBuf,
    /// Arguments of the call. The graph passes in this module always supply an
    /// empty list.
    pub arguments: Vec<String>,
}
20
/// A strategy that maps a call target to a [`Classification`].
///
/// Implementors must be `Send + Sync`.
pub trait Classifier: Send + Sync {
    /// Classifies `call_target` using the call-site details in `context`.
    ///
    /// Returns `None` when this classifier has no classification for the target;
    /// [`CompositeClassifier`] then moves on to its next classifier.
    fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification>;
}
24
/// An ordered collection of boxed [`Classifier`]s that are consulted one after
/// another.
pub struct CompositeClassifier {
    /// Classifiers in the order in which they are tried.
    classifiers: Vec<Box<dyn Classifier>>,
}
28
29impl CompositeClassifier {
30    pub fn new(classifiers: Vec<Box<dyn Classifier>>) -> Self {
31        Self { classifiers }
32    }
33
34    pub fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification> {
35        self.classifiers
36            .iter()
37            .find_map(|classifier| classifier.classify(call_target, context))
38    }
39}
40
41pub fn classify_graph(graph: &Graph, classifier: &CompositeClassifier) -> Graph {
42    let node_file_map: HashMap<&str, &PathBuf> = graph
43        .nodes
44        .iter()
45        .map(|node| (node.id.as_str(), &node.file))
46        .collect();
47    let node_ids: HashSet<&str> = graph.nodes.iter().map(|node| node.id.as_str()).collect();
48    let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
49
50    let edges = graph
51        .edges
52        .iter()
53        .map(|edge| {
54            if edge.kind != EdgeKind::Calls {
55                return edge.clone();
56            }
57
58            let target_name = edge.target.rsplit("::").next().unwrap_or(&edge.target);
59            let source_file = node_file_map
60                .get(edge.source.as_str())
61                .cloned()
62                .cloned()
63                .unwrap_or_default();
64            let context = ClassifyContext {
65                source_node: edge.source.clone(),
66                file: source_file,
67                arguments: Vec::new(),
68            };
69
70            let Some(classification) = classifier.classify(target_name, &context) else {
71                return edge.clone();
72            };
73
74            let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
75                edge.target.clone()
76            } else {
77                edge.source.clone()
78            };
79            terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
80
81            let mut enriched = edge.clone();
82            enriched.direction = Some(classification.direction);
83            enriched.operation = Some(classification.operation);
84            enriched
85        })
86        .collect();
87
88    let nodes = graph
89        .nodes
90        .iter()
91        .map(|node| {
92            if let Some(kind) = terminal_nodes.get(&node.id) {
93                let mut enriched = node.clone();
94                enriched.role = Some(NodeRole::Terminal { kind: *kind });
95                enriched
96            } else {
97                node.clone()
98            }
99        })
100        .collect();
101
102    Graph {
103        version: graph.version.clone(),
104        nodes,
105        edges,
106    }
107}
108
109pub fn classify_extraction_result(
110    mut result: ExtractionResult,
111    classifier: &CompositeClassifier,
112) -> ExtractionResult {
113    let node_ids: HashSet<&str> = result.nodes.iter().map(|node| node.id.as_str()).collect();
114    let node_file_map: HashMap<&str, &PathBuf> = result
115        .nodes
116        .iter()
117        .map(|node| (node.id.as_str(), &node.file))
118        .collect();
119    let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
120
121    result.edges = result
122        .edges
123        .into_iter()
124        .map(|mut edge| {
125            if edge.kind != EdgeKind::Calls {
126                return edge;
127            }
128
129            let target_name = edge.target.rsplit("::").next().unwrap_or(&edge.target);
130            let source_file = node_file_map
131                .get(edge.source.as_str())
132                .cloned()
133                .cloned()
134                .unwrap_or_default();
135            let context = ClassifyContext {
136                source_node: edge.source.clone(),
137                file: source_file,
138                arguments: Vec::new(),
139            };
140
141            if let Some(classification) = classifier.classify(target_name, &context) {
142                let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
143                    edge.target.clone()
144                } else {
145                    edge.source.clone()
146                };
147                terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
148                edge.direction = Some(classification.direction);
149                edge.operation = Some(classification.operation);
150            }
151
152            edge
153        })
154        .collect();
155
156    for node in &mut result.nodes {
157        if let Some(kind) = terminal_nodes.get(&node.id)
158            && node.role.is_none()
159        {
160            node.role = Some(NodeRole::Terminal { kind: *kind });
161        }
162    }
163
164    result
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::*;
    use std::collections::HashMap;

    /// Test double that answers every call target with the same classification.
    struct AlwaysMatch {
        classification: Classification,
    }

    impl Classifier for AlwaysMatch {
        fn classify(
            &self,
            _call_target: &str,
            _context: &ClassifyContext,
        ) -> Option<Classification> {
            Some(self.classification.clone())
        }
    }

    /// Test double that never classifies anything.
    struct NeverMatch;

    impl Classifier for NeverMatch {
        fn classify(
            &self,
            _call_target: &str,
            _context: &ClassifyContext,
        ) -> Option<Classification> {
            None
        }
    }

    /// Context used by the tests that call the composite classifier directly.
    fn test_context() -> ClassifyContext {
        ClassifyContext {
            source_node: String::from("test::caller"),
            file: PathBuf::from("test.rs"),
            arguments: Vec::new(),
        }
    }

    /// Composite holding a single `AlwaysMatch` that reports a network read
    /// labelled with `operation`.
    fn always_network_read(operation: &str) -> CompositeClassifier {
        CompositeClassifier::new(vec![Box::new(AlwaysMatch {
            classification: Classification {
                terminal_kind: TerminalKind::Network,
                direction: FlowDirection::Read,
                operation: operation.to_string(),
            },
        })])
    }

    /// Graph with one function node that calls a target which is not a node.
    fn graph_with_external_call() -> Graph {
        let caller = Node {
            id: "src::caller".to_string(),
            kind: NodeKind::Function,
            name: "caller".to_string(),
            file: PathBuf::from("src/main.rs"),
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        };
        let call = Edge {
            source: "src::caller".to_string(),
            target: "reqwest::get".to_string(),
            kind: EdgeKind::Calls,
            confidence: 0.9,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        };

        Graph {
            version: "0.1.0".to_string(),
            nodes: vec![caller],
            edges: vec![call],
        }
    }

    #[test]
    fn composite_returns_first_match() {
        let classifier = always_network_read("HTTP_GET");

        let kind = classifier
            .classify("something", &test_context())
            .map(|classification| classification.terminal_kind);

        assert_eq!(kind, Some(TerminalKind::Network));
    }

    #[test]
    fn composite_returns_none_when_no_match() {
        let classifier = CompositeClassifier::new(vec![Box::new(NeverMatch)]);

        let outcome = classifier.classify("something", &test_context());

        assert!(outcome.is_none());
    }

    #[test]
    fn classifies_external_call_on_source_node() {
        let graph = graph_with_external_call();
        let classifier = always_network_read("HTTP");

        let enriched = classify_graph(&graph, &classifier);

        // The call target is not a node of the graph, so the caller is tagged.
        let expected_role = Some(NodeRole::Terminal {
            kind: TerminalKind::Network,
        });
        assert_eq!(enriched.nodes[0].role, expected_role);
        assert_eq!(enriched.edges[0].direction, Some(FlowDirection::Read));
    }
}