1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use crate::extract::ExtractionResult;
5use crate::graph::{EdgeKind, FlowDirection, Graph, NodeRole, TerminalKind};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct Classification {
9 pub terminal_kind: TerminalKind,
10 pub direction: FlowDirection,
11 pub operation: String,
12}
13
/// Information about the call site handed to a [`Classifier`] next to the call target.
#[derive(Clone, Debug)]
pub struct ClassifyContext {
    /// Identifier of the node the call originates from (the edge source).
    pub source_node: String,
    /// File of the source node; the enrichment functions in this module pass an empty
    /// path when the source node is not known.
    pub file: PathBuf,
    /// Call arguments; the enrichment functions in this module always pass an empty list.
    pub arguments: Vec<String>,
}
20
21pub trait Classifier: Send + Sync {
22 fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification>;
23}
24
25pub struct CompositeClassifier {
26 classifiers: Vec<Box<dyn Classifier>>,
27}
28
29impl CompositeClassifier {
30 pub fn new(classifiers: Vec<Box<dyn Classifier>>) -> Self {
31 Self { classifiers }
32 }
33
34 pub fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification> {
35 self.classifiers
36 .iter()
37 .find_map(|classifier| classifier.classify(call_target, context))
38 }
39}
40
41pub fn classify_graph(graph: &Graph, classifier: &CompositeClassifier) -> Graph {
42 let node_file_map: HashMap<&str, &PathBuf> = graph
43 .nodes
44 .iter()
45 .map(|node| (node.id.as_str(), &node.file))
46 .collect();
47 let node_ids: HashSet<&str> = graph.nodes.iter().map(|node| node.id.as_str()).collect();
48 let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
49
50 let edges = graph
51 .edges
52 .iter()
53 .map(|edge| {
54 if edge.kind != EdgeKind::Calls {
55 return edge.clone();
56 }
57
58 let source_file = node_file_map
59 .get(edge.source.as_str())
60 .cloned()
61 .cloned()
62 .unwrap_or_default();
63 let context = ClassifyContext {
64 source_node: edge.source.clone(),
65 file: source_file,
66 arguments: Vec::new(),
67 };
68
69 let Some(classification) = classifier.classify(&edge.target, &context) else {
70 return edge.clone();
71 };
72
73 let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
74 edge.target.clone()
75 } else {
76 edge.source.clone()
77 };
78 terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
79
80 let mut enriched = edge.clone();
81 enriched.direction = Some(classification.direction);
82 enriched.operation = Some(classification.operation);
83 enriched
84 })
85 .collect();
86
87 let nodes = graph
88 .nodes
89 .iter()
90 .map(|node| {
91 if let Some(kind) = terminal_nodes.get(&node.id) {
92 let mut enriched = node.clone();
93 enriched.role = Some(NodeRole::Terminal { kind: *kind });
94 enriched
95 } else {
96 node.clone()
97 }
98 })
99 .collect();
100
101 Graph {
102 version: graph.version.clone(),
103 nodes,
104 edges,
105 }
106}
107
108pub fn classify_extraction_result(
109 mut result: ExtractionResult,
110 classifier: &CompositeClassifier,
111) -> ExtractionResult {
112 let node_ids: HashSet<&str> = result.nodes.iter().map(|node| node.id.as_str()).collect();
113 let node_file_map: HashMap<&str, &PathBuf> = result
114 .nodes
115 .iter()
116 .map(|node| (node.id.as_str(), &node.file))
117 .collect();
118 let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
119
120 result.edges = result
121 .edges
122 .into_iter()
123 .map(|mut edge| {
124 if edge.kind != EdgeKind::Calls {
125 return edge;
126 }
127
128 let source_file = node_file_map
129 .get(edge.source.as_str())
130 .cloned()
131 .cloned()
132 .unwrap_or_default();
133 let context = ClassifyContext {
134 source_node: edge.source.clone(),
135 file: source_file,
136 arguments: Vec::new(),
137 };
138
139 if let Some(classification) = classifier.classify(&edge.target, &context) {
140 let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
141 edge.target.clone()
142 } else {
143 edge.source.clone()
144 };
145 terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
146 edge.direction = Some(classification.direction);
147 edge.operation = Some(classification.operation);
148 }
149
150 edge
151 })
152 .collect();
153
154 for node in &mut result.nodes {
155 if let Some(kind) = terminal_nodes.get(&node.id)
156 && node.role.is_none()
157 {
158 node.role = Some(NodeRole::Terminal { kind: *kind });
159 }
160 }
161
162 result
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::*;
    use std::collections::HashMap;

    /// Test double that reports the same classification for every call target.
    struct AlwaysMatch {
        classification: Classification,
    }

    impl Classifier for AlwaysMatch {
        fn classify(&self, _call_target: &str, _context: &ClassifyContext) -> Option<Classification> {
            Some(self.classification.clone())
        }
    }

    /// Test double that never recognises anything.
    struct NeverMatch;

    impl Classifier for NeverMatch {
        fn classify(&self, _call_target: &str, _context: &ClassifyContext) -> Option<Classification> {
            None
        }
    }

    /// Context used by the tests that call the composite directly.
    fn test_context() -> ClassifyContext {
        ClassifyContext {
            source_node: "test::caller".to_string(),
            file: PathBuf::from("test.rs"),
            arguments: vec![],
        }
    }

    /// Composite with a single classifier that always reports a network read
    /// carrying the given operation label.
    fn network_read_classifier(operation: &str) -> CompositeClassifier {
        let matcher = AlwaysMatch {
            classification: Classification {
                terminal_kind: TerminalKind::Network,
                direction: FlowDirection::Read,
                operation: operation.to_string(),
            },
        };
        CompositeClassifier::new(vec![Box::new(matcher)])
    }

    /// The only node of the test graph: a public function without a role.
    fn caller_node() -> Node {
        Node {
            id: "src::caller".to_string(),
            kind: NodeKind::Function,
            name: "caller".to_string(),
            file: PathBuf::from("src/main.rs"),
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        }
    }

    /// A call edge from the caller node to a target that is not a node of the graph.
    fn external_call_edge() -> Edge {
        Edge {
            source: "src::caller".to_string(),
            target: "reqwest::get".to_string(),
            kind: EdgeKind::Calls,
            confidence: 0.9,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        }
    }

    #[test]
    fn composite_returns_first_match() {
        let classifier = network_read_classifier("HTTP_GET");
        let outcome = classifier.classify("something", &test_context());
        assert_eq!(
            outcome.map(|hit| hit.terminal_kind),
            Some(TerminalKind::Network)
        );
    }

    #[test]
    fn composite_returns_none_when_no_match() {
        let classifier = CompositeClassifier::new(vec![Box::new(NeverMatch)]);
        let outcome = classifier.classify("something", &test_context());
        assert!(outcome.is_none());
    }

    #[test]
    fn classifies_external_call_on_source_node() {
        let graph = Graph {
            version: "0.1.0".to_string(),
            nodes: vec![caller_node()],
            edges: vec![external_call_edge()],
        };
        let classifier = network_read_classifier("HTTP");

        let enriched = classify_graph(&graph, &classifier);

        // The target is external, so the calling node becomes the terminal.
        let expected_role = Some(NodeRole::Terminal {
            kind: TerminalKind::Network,
        });
        assert_eq!(enriched.nodes[0].role, expected_role);
        assert_eq!(enriched.edges[0].direction, Some(FlowDirection::Read));
    }
}