Skip to main content

grapha_core/
classify.rs

1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use crate::extract::ExtractionResult;
5use crate::graph::{EdgeKind, FlowDirection, Graph, NodeRole, TerminalKind};
6
/// Outcome of a [`Classifier`] recognising a call target.
///
/// In this module `direction` and `operation` are copied onto the classified
/// call edge, and `terminal_kind` becomes the kind of the
/// `NodeRole::Terminal` role assigned to a node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Classification {
    /// Kind of terminal the classified call reaches.
    pub terminal_kind: TerminalKind,
    /// Data-flow direction recorded on the classified call edge.
    pub direction: FlowDirection,
    /// Operation label recorded on the classified call edge (a string chosen by the classifier).
    pub operation: String,
}
13
/// Call-site information handed to a [`Classifier`] alongside the call target.
#[derive(Debug, Clone)]
pub struct ClassifyContext {
    /// Id of the calling (edge source) node.
    pub source_node: String,
    /// File of the calling node; the graph-classification functions in this
    /// module pass an empty path when the caller is not a known node.
    pub file: PathBuf,
    /// Call arguments. NOTE(review): the functions in this module always pass
    /// an empty list — confirm whether any other caller populates it.
    pub arguments: Vec<String>,
}
20
/// Recognises call targets that reach a terminal (e.g. the network) and describes them.
///
/// Implementations must be `Send + Sync`. Return `None` when `call_target` is
/// not recognised, so that [`CompositeClassifier`] moves on to the next classifier.
pub trait Classifier: Send + Sync {
    /// Classifies `call_target` as called from the site described by `context`.
    fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification>;
}
24
25pub struct CompositeClassifier {
26    classifiers: Vec<Box<dyn Classifier>>,
27}
28
29impl CompositeClassifier {
30    pub fn new(classifiers: Vec<Box<dyn Classifier>>) -> Self {
31        Self { classifiers }
32    }
33
34    pub fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification> {
35        self.classifiers
36            .iter()
37            .find_map(|classifier| classifier.classify(call_target, context))
38    }
39}
40
41pub fn classify_graph(graph: &Graph, classifier: &CompositeClassifier) -> Graph {
42    let node_file_map: HashMap<&str, &PathBuf> = graph
43        .nodes
44        .iter()
45        .map(|node| (node.id.as_str(), &node.file))
46        .collect();
47    let node_ids: HashSet<&str> = graph.nodes.iter().map(|node| node.id.as_str()).collect();
48    let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
49
50    let edges = graph
51        .edges
52        .iter()
53        .map(|edge| {
54            if edge.kind != EdgeKind::Calls {
55                return edge.clone();
56            }
57
58            let source_file = node_file_map
59                .get(edge.source.as_str())
60                .cloned()
61                .cloned()
62                .unwrap_or_default();
63            let context = ClassifyContext {
64                source_node: edge.source.clone(),
65                file: source_file,
66                arguments: Vec::new(),
67            };
68
69            let Some(classification) = classifier.classify(&edge.target, &context) else {
70                return edge.clone();
71            };
72
73            let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
74                edge.target.clone()
75            } else {
76                edge.source.clone()
77            };
78            terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
79
80            let mut enriched = edge.clone();
81            enriched.direction = Some(classification.direction);
82            enriched.operation = Some(classification.operation);
83            enriched
84        })
85        .collect();
86
87    let nodes = graph
88        .nodes
89        .iter()
90        .map(|node| {
91            if let Some(kind) = terminal_nodes.get(&node.id) {
92                let mut enriched = node.clone();
93                enriched.role = Some(NodeRole::Terminal { kind: *kind });
94                enriched
95            } else {
96                node.clone()
97            }
98        })
99        .collect();
100
101    Graph {
102        version: graph.version.clone(),
103        nodes,
104        edges,
105    }
106}
107
108pub fn classify_extraction_result(
109    mut result: ExtractionResult,
110    classifier: &CompositeClassifier,
111) -> ExtractionResult {
112    let node_ids: HashSet<&str> = result.nodes.iter().map(|node| node.id.as_str()).collect();
113    let node_file_map: HashMap<&str, &PathBuf> = result
114        .nodes
115        .iter()
116        .map(|node| (node.id.as_str(), &node.file))
117        .collect();
118    let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
119
120    result.edges = result
121        .edges
122        .into_iter()
123        .map(|mut edge| {
124            if edge.kind != EdgeKind::Calls {
125                return edge;
126            }
127
128            let source_file = node_file_map
129                .get(edge.source.as_str())
130                .cloned()
131                .cloned()
132                .unwrap_or_default();
133            let context = ClassifyContext {
134                source_node: edge.source.clone(),
135                file: source_file,
136                arguments: Vec::new(),
137            };
138
139            if let Some(classification) = classifier.classify(&edge.target, &context) {
140                let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
141                    edge.target.clone()
142                } else {
143                    edge.source.clone()
144                };
145                terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
146                edge.direction = Some(classification.direction);
147                edge.operation = Some(classification.operation);
148            }
149
150            edge
151        })
152        .collect();
153
154    for node in &mut result.nodes {
155        if let Some(kind) = terminal_nodes.get(&node.id)
156            && node.role.is_none()
157        {
158            node.role = Some(NodeRole::Terminal { kind: *kind });
159        }
160    }
161
162    result
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::*;
    use std::collections::HashMap;

    /// Test classifier that recognises every call target with a fixed classification.
    struct AlwaysMatch {
        classification: Classification,
    }

    impl Classifier for AlwaysMatch {
        fn classify(
            &self,
            _call_target: &str,
            _context: &ClassifyContext,
        ) -> Option<Classification> {
            Some(self.classification.clone())
        }
    }

    /// Test classifier that never recognises anything.
    struct NeverMatch;

    impl Classifier for NeverMatch {
        fn classify(
            &self,
            _call_target: &str,
            _context: &ClassifyContext,
        ) -> Option<Classification> {
            None
        }
    }

    fn test_context() -> ClassifyContext {
        ClassifyContext {
            source_node: "test::caller".to_string(),
            file: PathBuf::from("test.rs"),
            arguments: vec![],
        }
    }

    /// Network/read classification labelled with `operation`.
    fn network_read(operation: &str) -> Classification {
        Classification {
            terminal_kind: TerminalKind::Network,
            direction: FlowDirection::Read,
            operation: operation.to_string(),
        }
    }

    /// Composite holding a single classifier that matches everything with `operation`.
    fn always_network_read(operation: &str) -> CompositeClassifier {
        CompositeClassifier::new(vec![Box::new(AlwaysMatch {
            classification: network_read(operation),
        })])
    }

    /// Minimal public function node without a role.
    fn function_node(id: &str, name: &str, file: &str) -> Node {
        Node {
            id: id.to_string(),
            kind: NodeKind::Function,
            name: name.to_string(),
            file: PathBuf::from(file),
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        }
    }

    /// Unclassified `Calls` edge from `source` to `target`.
    fn call_edge(source: &str, target: &str) -> Edge {
        Edge {
            source: source.to_string(),
            target: target.to_string(),
            kind: EdgeKind::Calls,
            confidence: 0.9,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        }
    }

    #[test]
    fn composite_returns_first_match() {
        // A non-matching classifier followed by two matching ones: the earlier match must win.
        let classifiers: Vec<Box<dyn Classifier>> = vec![
            Box::new(NeverMatch),
            Box::new(AlwaysMatch {
                classification: network_read("HTTP_GET"),
            }),
            Box::new(AlwaysMatch {
                classification: network_read("HTTP_POST"),
            }),
        ];
        let classifier = CompositeClassifier::new(classifiers);
        let result = classifier
            .classify("something", &test_context())
            .expect("the first AlwaysMatch should classify the call");
        assert_eq!(result.terminal_kind, TerminalKind::Network);
        assert_eq!(result.operation, "HTTP_GET");
    }

    #[test]
    fn composite_returns_none_when_no_match() {
        let classifier = CompositeClassifier::new(vec![Box::new(NeverMatch)]);
        assert!(classifier.classify("something", &test_context()).is_none());
    }

    #[test]
    fn composite_with_no_classifiers_returns_none() {
        let classifier = CompositeClassifier::new(Vec::new());
        assert!(classifier.classify("something", &test_context()).is_none());
    }

    #[test]
    fn classifies_external_call_on_source_node() {
        let graph = Graph {
            version: "0.1.0".to_string(),
            nodes: vec![function_node("src::caller", "caller", "src/main.rs")],
            edges: vec![call_edge("src::caller", "reqwest::get")],
        };

        let enriched = classify_graph(&graph, &always_network_read("HTTP"));
        assert_eq!(
            enriched.nodes[0].role,
            Some(NodeRole::Terminal {
                kind: TerminalKind::Network,
            })
        );
        assert_eq!(enriched.edges[0].direction, Some(FlowDirection::Read));
        assert_eq!(enriched.edges[0].operation.as_deref(), Some("HTTP"));
    }

    #[test]
    fn classifies_internal_call_on_target_node() {
        let graph = Graph {
            version: "0.1.0".to_string(),
            nodes: vec![
                function_node("src::caller", "caller", "src/main.rs"),
                function_node("src::fetch", "fetch", "src/net.rs"),
            ],
            edges: vec![call_edge("src::caller", "src::fetch")],
        };

        let enriched = classify_graph(&graph, &always_network_read("HTTP"));
        assert_eq!(enriched.nodes[0].role, None);
        assert_eq!(
            enriched.nodes[1].role,
            Some(NodeRole::Terminal {
                kind: TerminalKind::Network,
            })
        );
        assert_eq!(enriched.edges[0].direction, Some(FlowDirection::Read));
    }

    #[test]
    fn unclassified_calls_are_left_unchanged() {
        let graph = Graph {
            version: "0.1.0".to_string(),
            nodes: vec![function_node("src::caller", "caller", "src/main.rs")],
            edges: vec![call_edge("src::caller", "reqwest::get")],
        };

        let classifier = CompositeClassifier::new(vec![Box::new(NeverMatch)]);
        let enriched = classify_graph(&graph, &classifier);
        assert_eq!(enriched.nodes[0].role, None);
        assert_eq!(enriched.edges[0].direction, None);
        assert_eq!(enriched.edges[0].operation, None);
    }
}
275}