1use crate::common::safe_slice;
6use crate::ingest::{SymbolFact, SymbolKind};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::PathBuf;
10
/// One occurrence, in a source file, of a name that matches a known symbol.
///
/// Produced by `ReferenceExtractor::extract_references`. The span describes the
/// referencing syntax node (an `identifier` or `scoped_identifier`) as reported by
/// tree-sitter.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReferenceFact {
    /// File in which the reference occurs.
    pub file_path: PathBuf,
    /// Name of the referenced symbol; for a path like `a::b::c` only the final segment `c`.
    pub referenced_symbol: String,
    /// Byte offset of the start of the referencing node.
    pub byte_start: usize,
    /// Byte offset of the end of the referencing node (tree-sitter end offsets are exclusive).
    pub byte_end: usize,
    /// 1-based line of the node start (tree-sitter row + 1).
    pub start_line: usize,
    /// Column of the node start, taken unmodified from tree-sitter (0-based).
    pub start_col: usize,
    /// 1-based line of the node end (tree-sitter row + 1).
    pub end_line: usize,
    /// Column of the node end, taken unmodified from tree-sitter (0-based).
    pub end_col: usize,
}
33
/// A call from one function to another, found at a single call site.
///
/// Produced by `CallExtractor::extract_calls`. The span covers the whole
/// `call_expression` node as reported by tree-sitter.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CallFact {
    /// File containing the call site.
    pub file_path: PathBuf,
    /// Name of the enclosing function that performs the call.
    pub caller: String,
    /// Name of the function or method being called.
    pub callee: String,
    /// Identifier of the caller symbol, if resolved.
    ///
    /// `#[serde(default)]` lets serialized records without this field load as `None`.
    /// `CallExtractor` always leaves it `None`; presumably a later stage fills it in — TODO confirm.
    #[serde(default)]
    pub caller_symbol_id: Option<String>,
    /// Identifier of the callee symbol, if resolved.
    ///
    /// Same defaulting and population notes as `caller_symbol_id`.
    #[serde(default)]
    pub callee_symbol_id: Option<String>,
    /// Byte offset of the start of the call expression.
    pub byte_start: usize,
    /// Byte offset of the end of the call expression (exclusive, per tree-sitter).
    pub byte_end: usize,
    /// 1-based line of the call start (tree-sitter row + 1).
    pub start_line: usize,
    /// Column of the call start, taken unmodified from tree-sitter (0-based).
    pub start_col: usize,
    /// 1-based line of the call end (tree-sitter row + 1).
    pub end_line: usize,
    /// Column of the call end, taken unmodified from tree-sitter (0-based).
    pub end_col: usize,
}
64
/// Finds references to known symbols in Rust source code using tree-sitter.
pub struct ReferenceExtractor {
    /// Parser configured with the tree-sitter Rust grammar in `ReferenceExtractor::new`.
    parser: tree_sitter::Parser,
}
69
70impl ReferenceExtractor {
71 pub fn new() -> anyhow::Result<Self> {
73 let mut parser = tree_sitter::Parser::new();
74 let language = tree_sitter_rust::language();
75 parser.set_language(&language)?;
76
77 Ok(Self { parser })
78 }
79
80 pub fn extract_references(
96 &mut self,
97 file_path: PathBuf,
98 source: &[u8],
99 symbols: &[SymbolFact],
100 ) -> Vec<ReferenceFact> {
101 let tree = match self.parser.parse(source, None) {
102 Some(t) => t,
103 None => return Vec::new(),
104 };
105
106 let root_node = tree.root_node();
107 let mut references = Vec::new();
108
109 self.walk_tree_for_references(&root_node, source, &file_path, symbols, &mut references);
111
112 references
113 }
114
115 fn walk_tree_for_references(
117 &self,
118 node: &tree_sitter::Node,
119 source: &[u8],
120 file_path: &PathBuf,
121 symbols: &[SymbolFact],
122 references: &mut Vec<ReferenceFact>,
123 ) {
124 if let Some(reference) = self.extract_reference(node, source, file_path, symbols) {
126 references.push(reference);
127
128 if node.kind() == "scoped_identifier" {
131 return;
132 }
133 }
134
135 let mut cursor = node.walk();
137 for child in node.children(&mut cursor) {
138 self.walk_tree_for_references(&child, source, file_path, symbols, references);
139 }
140 }
141
142 fn extract_reference(
144 &self,
145 node: &tree_sitter::Node,
146 source: &[u8],
147 file_path: &PathBuf,
148 symbols: &[SymbolFact],
149 ) -> Option<ReferenceFact> {
150 let kind = node.kind();
151
152 match kind {
154 "identifier" => {}
155 "scoped_identifier" => {}
156 _ => return None,
157 }
158
159 let text_bytes = safe_slice(source, node.start_byte(), node.end_byte())?;
161 let text = std::str::from_utf8(text_bytes).ok()?;
162
163 let symbol_name = if kind == "scoped_identifier" {
165 text.split("::").last().unwrap_or(text)
167 } else {
168 text
169 };
170
171 let referenced_symbol = symbols
173 .iter()
174 .find(|s| s.name.as_ref().map(|n| n == symbol_name).unwrap_or(false))?;
175
176 let ref_start = node.start_byte();
178 let ref_end = node.end_byte();
179
180 if referenced_symbol.file_path == *file_path && ref_start < referenced_symbol.byte_end {
183 return None; }
185
186 Some(ReferenceFact {
187 file_path: file_path.clone(),
188 referenced_symbol: symbol_name.to_string(),
189 byte_start: ref_start,
190 byte_end: ref_end,
191 start_line: node.start_position().row + 1,
192 start_col: node.start_position().column,
193 end_line: node.end_position().row + 1,
194 end_col: node.end_position().column,
195 })
196 }
197}
198
199impl Default for ReferenceExtractor {
200 fn default() -> Self {
201 Self::new().expect("Failed to create reference extractor")
202 }
203}
204
205impl crate::ingest::Parser {
207 pub fn extract_references(
209 &mut self,
210 file_path: PathBuf,
211 source: &[u8],
212 symbols: &[SymbolFact],
213 ) -> Vec<ReferenceFact> {
214 let mut extractor = ReferenceExtractor::new().unwrap();
215 extractor.extract_references(file_path, source, symbols)
216 }
217
218 pub fn extract_calls(
234 &mut self,
235 file_path: PathBuf,
236 source: &[u8],
237 symbols: &[SymbolFact],
238 ) -> Vec<CallFact> {
239 let mut extractor = CallExtractor::new().unwrap();
240 extractor.extract_calls(file_path, source, symbols)
241 }
242}
243
/// Finds function/method call edges in Rust source code using tree-sitter.
pub struct CallExtractor {
    /// Parser configured with the tree-sitter Rust grammar in `CallExtractor::new`.
    parser: tree_sitter::Parser,
}
250
251impl CallExtractor {
252 pub fn new() -> anyhow::Result<Self> {
254 let mut parser = tree_sitter::Parser::new();
255 let language = tree_sitter_rust::language();
256 parser.set_language(&language)?;
257
258 Ok(Self { parser })
259 }
260
261 pub fn extract_calls(
269 &mut self,
270 file_path: PathBuf,
271 source: &[u8],
272 symbols: &[SymbolFact],
273 ) -> Vec<CallFact> {
274 let tree = match self.parser.parse(source, None) {
275 Some(t) => t,
276 None => return Vec::new(),
277 };
278
279 let root_node = tree.root_node();
280 let mut calls = Vec::new();
281
282 let symbol_map: HashMap<String, &SymbolFact> = symbols
284 .iter()
285 .filter_map(|s| s.name.as_ref().map(|name| (name.clone(), s)))
286 .collect();
287
288 let functions: Vec<&SymbolFact> = symbols
290 .iter()
291 .filter(|s| s.kind == SymbolKind::Function)
292 .collect();
293
294 self.walk_tree_for_calls(
296 &root_node,
297 source,
298 &file_path,
299 &symbol_map,
300 &functions,
301 &mut calls,
302 );
303
304 calls
305 }
306
307 fn walk_tree_for_calls(
309 &self,
310 node: &tree_sitter::Node,
311 source: &[u8],
312 file_path: &PathBuf,
313 symbol_map: &HashMap<String, &SymbolFact>,
314 _functions: &[&SymbolFact],
315 calls: &mut Vec<CallFact>,
316 ) {
317 self.walk_tree_for_calls_with_caller(node, source, file_path, symbol_map, None, calls);
318 }
319
320 fn walk_tree_for_calls_with_caller(
322 &self,
323 node: &tree_sitter::Node,
324 source: &[u8],
325 file_path: &PathBuf,
326 symbol_map: &HashMap<String, &SymbolFact>,
327 current_caller: Option<&SymbolFact>,
328 calls: &mut Vec<CallFact>,
329 ) {
330 let kind = node.kind();
331
332 let caller: Option<&SymbolFact> = if kind == "function_item" {
334 self.extract_function_name(node, source)
336 .and_then(|name| symbol_map.get(&name).copied())
337 } else {
338 current_caller
339 };
340
341 if kind == "call_expression" {
343 if let Some(caller_fact) = caller {
344 self.extract_calls_in_node(node, source, file_path, caller_fact, symbol_map, calls);
345 }
346 }
347
348 let mut cursor = node.walk();
350 for child in node.children(&mut cursor) {
351 self.walk_tree_for_calls_with_caller(
352 &child, source, file_path, symbol_map, caller, calls,
353 );
354 }
355 }
356
357 fn extract_function_name(&self, node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
359 let mut cursor = node.walk();
360 for child in node.children(&mut cursor) {
361 if child.kind() == "identifier" || child.kind() == "type_identifier" {
362 let name_bytes = safe_slice(source, child.start_byte(), child.end_byte())?;
363 return std::str::from_utf8(name_bytes).ok().map(|s| s.to_string());
364 }
365 }
366 None
367 }
368
369 fn extract_calls_in_node(
371 &self,
372 node: &tree_sitter::Node,
373 source: &[u8],
374 file_path: &PathBuf,
375 caller: &SymbolFact,
376 symbol_map: &HashMap<String, &SymbolFact>,
377 calls: &mut Vec<CallFact>,
378 ) {
379 let kind = node.kind();
381
382 if kind == "call_expression" {
383 if let Some(callee_name) = self.extract_callee_from_call(node, source) {
385 if symbol_map.contains_key(&callee_name) {
387 let node_start = node.start_byte();
388 let node_end = node.end_byte();
389 let call_fact = CallFact {
390 file_path: file_path.clone(),
391 caller: caller.name.clone().unwrap_or_default(),
392 callee: callee_name,
393 caller_symbol_id: None,
394 callee_symbol_id: None,
395 byte_start: node_start,
396 byte_end: node_end,
397 start_line: node.start_position().row + 1,
398 start_col: node.start_position().column,
399 end_line: node.end_position().row + 1,
400 end_col: node.end_position().column,
401 };
402 calls.push(call_fact);
403 }
404 }
405 }
406 }
407
408 fn extract_callee_from_call(&self, node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
410 let mut cursor = node.walk();
412 for child in node.children(&mut cursor) {
413 let kind = child.kind();
414 if kind == "identifier" {
415 let name_bytes = safe_slice(source, child.start_byte(), child.end_byte())?;
416 return std::str::from_utf8(name_bytes).ok().map(|s| s.to_string());
417 }
418 if kind == "field_expression" || kind == "method_expression" {
420 return self.extract_method_name(&child, source);
422 }
423 }
424 None
425 }
426
427 fn extract_method_name(&self, node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
429 let mut cursor = node.walk();
430 for child in node.children(&mut cursor) {
431 if child.kind() == "field_identifier" {
433 let name_bytes = safe_slice(source, child.start_byte(), child.end_byte())?;
434 return std::str::from_utf8(name_bytes).ok().map(|s| s.to_string());
435 }
436 }
437 None
438 }
439}
440
441impl Default for CallExtractor {
442 fn default() -> Self {
443 Self::new().expect("Failed to create call extractor")
444 }
445}