1use std::collections::{HashMap, HashSet};
2use std::path::PathBuf;
3
4use crate::extract::ExtractionResult;
5use crate::graph::{EdgeKind, FlowDirection, Graph, NodeRole, TerminalKind};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct Classification {
9 pub terminal_kind: TerminalKind,
10 pub direction: FlowDirection,
11 pub operation: String,
12}
13
/// Information handed to a [`Classifier`] alongside the call target.
#[derive(Debug, Clone)]
pub struct ClassifyContext {
    /// Id of the node the call edge originates from.
    pub source_node: String,
    /// File of the source node; the graph-level helpers in this module use an
    /// empty path when the source node is unknown.
    pub file: PathBuf,
    /// Argument strings for the call. The graph-level helpers in this module
    /// always pass an empty list.
    pub arguments: Vec<String>,
}
20
21pub trait Classifier: Send + Sync {
22 fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification>;
23}
24
25pub struct CompositeClassifier {
26 classifiers: Vec<Box<dyn Classifier>>,
27}
28
29impl CompositeClassifier {
30 pub fn new(classifiers: Vec<Box<dyn Classifier>>) -> Self {
31 Self { classifiers }
32 }
33
34 pub fn classify(&self, call_target: &str, context: &ClassifyContext) -> Option<Classification> {
35 self.classifiers
36 .iter()
37 .find_map(|classifier| classifier.classify(call_target, context))
38 }
39}
40
41pub fn classify_graph(graph: &Graph, classifier: &CompositeClassifier) -> Graph {
42 let node_file_map: HashMap<&str, &PathBuf> = graph
43 .nodes
44 .iter()
45 .map(|node| (node.id.as_str(), &node.file))
46 .collect();
47 let node_ids: HashSet<&str> = graph.nodes.iter().map(|node| node.id.as_str()).collect();
48 let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
49
50 let edges = graph
51 .edges
52 .iter()
53 .map(|edge| {
54 if edge.kind != EdgeKind::Calls {
55 return edge.clone();
56 }
57
58 let target_name = edge.target.rsplit("::").next().unwrap_or(&edge.target);
59 let source_file = node_file_map
60 .get(edge.source.as_str())
61 .cloned()
62 .cloned()
63 .unwrap_or_default();
64 let context = ClassifyContext {
65 source_node: edge.source.clone(),
66 file: source_file,
67 arguments: Vec::new(),
68 };
69
70 let Some(classification) = classifier.classify(target_name, &context) else {
71 return edge.clone();
72 };
73
74 let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
75 edge.target.clone()
76 } else {
77 edge.source.clone()
78 };
79 terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
80
81 let mut enriched = edge.clone();
82 enriched.direction = Some(classification.direction);
83 enriched.operation = Some(classification.operation);
84 enriched
85 })
86 .collect();
87
88 let nodes = graph
89 .nodes
90 .iter()
91 .map(|node| {
92 if let Some(kind) = terminal_nodes.get(&node.id) {
93 let mut enriched = node.clone();
94 enriched.role = Some(NodeRole::Terminal { kind: *kind });
95 enriched
96 } else {
97 node.clone()
98 }
99 })
100 .collect();
101
102 Graph {
103 version: graph.version.clone(),
104 nodes,
105 edges,
106 }
107}
108
109pub fn classify_extraction_result(
110 mut result: ExtractionResult,
111 classifier: &CompositeClassifier,
112) -> ExtractionResult {
113 let node_ids: HashSet<&str> = result.nodes.iter().map(|node| node.id.as_str()).collect();
114 let node_file_map: HashMap<&str, &PathBuf> = result
115 .nodes
116 .iter()
117 .map(|node| (node.id.as_str(), &node.file))
118 .collect();
119 let mut terminal_nodes: HashMap<String, TerminalKind> = HashMap::new();
120
121 result.edges = result
122 .edges
123 .into_iter()
124 .map(|mut edge| {
125 if edge.kind != EdgeKind::Calls {
126 return edge;
127 }
128
129 let target_name = edge.target.rsplit("::").next().unwrap_or(&edge.target);
130 let source_file = node_file_map
131 .get(edge.source.as_str())
132 .cloned()
133 .cloned()
134 .unwrap_or_default();
135 let context = ClassifyContext {
136 source_node: edge.source.clone(),
137 file: source_file,
138 arguments: Vec::new(),
139 };
140
141 if let Some(classification) = classifier.classify(target_name, &context) {
142 let terminal_node_id = if node_ids.contains(edge.target.as_str()) {
143 edge.target.clone()
144 } else {
145 edge.source.clone()
146 };
147 terminal_nodes.insert(terminal_node_id, classification.terminal_kind);
148 edge.direction = Some(classification.direction);
149 edge.operation = Some(classification.operation);
150 }
151
152 edge
153 })
154 .collect();
155
156 for node in &mut result.nodes {
157 if let Some(kind) = terminal_nodes.get(&node.id)
158 && node.role.is_none()
159 {
160 node.role = Some(NodeRole::Terminal { kind: *kind });
161 }
162 }
163
164 result
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::*;
    use std::collections::HashMap;

    /// Classifier stub that answers every call with a fixed classification.
    struct AlwaysMatch {
        classification: Classification,
    }

    impl Classifier for AlwaysMatch {
        fn classify(&self, _target: &str, _ctx: &ClassifyContext) -> Option<Classification> {
            Some(self.classification.clone())
        }
    }

    /// Classifier stub that never recognizes anything.
    struct NeverMatch;

    impl Classifier for NeverMatch {
        fn classify(&self, _target: &str, _ctx: &ClassifyContext) -> Option<Classification> {
            None
        }
    }

    /// Fixed context for tests that only exercise classifier dispatch.
    fn test_context() -> ClassifyContext {
        ClassifyContext {
            source_node: "test::caller".to_string(),
            file: PathBuf::from("test.rs"),
            arguments: vec![],
        }
    }

    /// Composite holding one `AlwaysMatch` that reports a network read
    /// labelled with `operation`.
    fn always_network_read(operation: &str) -> CompositeClassifier {
        CompositeClassifier::new(vec![Box::new(AlwaysMatch {
            classification: Classification {
                terminal_kind: TerminalKind::Network,
                direction: FlowDirection::Read,
                operation: operation.to_string(),
            },
        })])
    }

    #[test]
    fn composite_returns_first_match() {
        let composite = always_network_read("HTTP_GET");
        let outcome = composite.classify("something", &test_context());
        assert_eq!(
            outcome.map(|found| found.terminal_kind),
            Some(TerminalKind::Network)
        );
    }

    #[test]
    fn composite_returns_none_when_no_match() {
        let composite = CompositeClassifier::new(vec![Box::new(NeverMatch)]);
        let outcome = composite.classify("something", &test_context());
        assert!(outcome.is_none());
    }

    #[test]
    fn classifies_external_call_on_source_node() {
        // The only node in the graph is the caller; the call target is external.
        let caller = Node {
            id: "src::caller".to_string(),
            kind: NodeKind::Function,
            name: "caller".to_string(),
            file: PathBuf::from("src/main.rs"),
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        };
        let external_call = Edge {
            source: "src::caller".to_string(),
            target: "reqwest::get".to_string(),
            kind: EdgeKind::Calls,
            confidence: 0.9,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        };
        let graph = Graph {
            version: "0.1.0".to_string(),
            nodes: vec![caller],
            edges: vec![external_call],
        };

        let enriched = classify_graph(&graph, &always_network_read("HTTP"));

        // The target is not a node, so the caller itself becomes the terminal.
        let expected_role = Some(NodeRole::Terminal {
            kind: TerminalKind::Network,
        });
        assert_eq!(enriched.nodes[0].role, expected_role);
        assert_eq!(enriched.edges[0].direction, Some(FlowDirection::Read));
    }
}