1use crate::error::AstError;
4use crate::error::AstResult;
5use crate::types::AstNodeKind;
6use crate::types::ParsedAst;
7use crate::types::SourceLocation;
8use crate::types::Visibility;
9use dashmap::DashMap;
10use std::path::Path;
12use std::path::PathBuf;
13use tree_sitter::Node;
14use tree_sitter::TreeCursor;
15
16#[derive(Debug, Clone)]
18pub struct Symbol {
19 pub name: String,
20 pub kind: SymbolKind,
21 pub location: SourceLocation,
22 pub visibility: Visibility,
23 pub signature: String,
24 pub documentation: Option<String>,
25 pub references: Vec<SourceLocation>,
26 pub definitions: Vec<SourceLocation>,
27 pub call_sites: Vec<SourceLocation>,
28}
29
/// Category of a [`Symbol`].
///
/// The declaration order is significant: the variant's numeric discriminant
/// (`kind as u8`) is embedded in symbol IDs, so variants must not be reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    /// A free function.
    Function,
    /// A function attached to a type.
    Method,
    /// A class definition.
    Class,
    /// A struct definition.
    Struct,
    /// An enum definition.
    Enum,
    /// An interface definition.
    Interface,
    /// A trait definition.
    Trait,
    /// A module or namespace.
    Module,
    /// A variable binding.
    Variable,
    /// A constant.
    Constant,
    /// A type alias or other named type.
    Type,
    /// A property of a class or object.
    Property,
    /// A field of a struct or class.
    Field,
    /// A function or method parameter.
    Parameter,
}
48
49impl SymbolKind {
50 pub const fn from_ast_kind(kind: AstNodeKind) -> Self {
52 match kind {
53 AstNodeKind::Function => Self::Function,
54 AstNodeKind::Class => Self::Class,
55 AstNodeKind::Struct => Self::Struct,
56 AstNodeKind::Enum => Self::Enum,
57 AstNodeKind::Interface => Self::Interface,
58 AstNodeKind::Trait => Self::Trait,
59 AstNodeKind::Module => Self::Module,
60 AstNodeKind::Variable => Self::Variable,
61 AstNodeKind::Constant => Self::Constant,
62 AstNodeKind::Type => Self::Type,
63 _ => Self::Variable,
64 }
65 }
66}
67
68#[derive(Debug)]
70pub struct SemanticIndex {
71 symbols: DashMap<String, Symbol>,
73 file_symbols: DashMap<PathBuf, Vec<String>>,
75 call_graph: DashMap<String, Vec<String>>,
77 inheritance_graph: DashMap<String, Vec<String>>,
79 import_graph: DashMap<PathBuf, Vec<PathBuf>>,
81}
82
83impl SemanticIndex {
84 pub fn new() -> Self {
86 Self {
87 symbols: DashMap::new(),
88 file_symbols: DashMap::new(),
89 call_graph: DashMap::new(),
90 inheritance_graph: DashMap::new(),
91 import_graph: DashMap::new(),
92 }
93 }
94
95 pub fn index_ast(&mut self, path: &Path, ast: &ParsedAst) -> AstResult<()> {
97 let source = ast.source.as_bytes();
98 let mut cursor = ast.tree.root_node().walk();
99 let mut symbols = Vec::new();
100
101 self.extract_symbols(&mut cursor, source, path, &mut symbols)?;
103
104 let symbol_ids: Vec<String> = symbols.iter().map(|s| self.get_symbol_id(s)).collect();
106 self.file_symbols.insert(path.to_path_buf(), symbol_ids);
107
108 for symbol in symbols {
110 let id = self.get_symbol_id(&symbol);
111 self.symbols.insert(id, symbol);
112 }
113
114 self.build_relationships(path, ast)?;
116
117 Ok(())
118 }
119
120 fn extract_symbols(
122 &self,
123 cursor: &mut TreeCursor,
124 source: &[u8],
125 file_path: &Path,
126 symbols: &mut Vec<Symbol>,
127 ) -> AstResult<()> {
128 let node = cursor.node();
129 let node_type = node.kind();
130 let _ast_kind = AstNodeKind::from_node_type(node_type);
131
132 if self.is_symbol_definition(&node) {
134 let symbol = self.create_symbol(&node, source, file_path)?;
135 symbols.push(symbol);
136 }
137
138 if cursor.goto_first_child() {
140 loop {
141 self.extract_symbols(cursor, source, file_path, symbols)?;
142 if !cursor.goto_next_sibling() {
143 break;
144 }
145 }
146 cursor.goto_parent();
147 }
148
149 Ok(())
150 }
151
152 fn is_symbol_definition(&self, node: &Node) -> bool {
154 let node_type = node.kind();
155 matches!(
156 node_type,
157 "function_declaration"
158 | "function_definition"
159 | "function_item"
160 | "method_declaration"
161 | "method_definition"
162 | "class_declaration"
163 | "class_definition"
164 | "struct_item"
165 | "struct_declaration"
166 | "enum_item"
167 | "enum_declaration"
168 | "interface_declaration"
169 | "protocol_declaration"
170 | "trait_item"
171 | "trait_declaration"
172 | "module"
173 | "module_declaration"
174 | "variable_declaration"
175 | "const_item"
176 | "let_declaration"
177 | "type_alias"
178 | "typedef"
179 )
180 }
181
182 fn create_symbol(&self, node: &Node, source: &[u8], file_path: &Path) -> AstResult<Symbol> {
184 let name = self.extract_symbol_name(node, source)?;
185 let kind = SymbolKind::from_ast_kind(AstNodeKind::from_node_type(node.kind()));
186
187 let location = SourceLocation::new(
188 file_path.display().to_string(),
189 node.start_position().row + 1,
190 node.start_position().column + 1,
191 node.end_position().row + 1,
192 node.end_position().column + 1,
193 (node.start_byte(), node.end_byte()),
194 );
195
196 let signature = self.extract_signature(node, source)?;
197 let visibility = self.extract_visibility(node, source);
198 let documentation = self.extract_documentation(node, source);
199
200 Ok(Symbol {
201 name,
202 kind,
203 location: location.clone(),
204 visibility,
205 signature,
206 documentation,
207 references: Vec::new(),
208 definitions: vec![location],
209 call_sites: Vec::new(),
210 })
211 }
212
213 fn extract_symbol_name(&self, node: &Node, source: &[u8]) -> AstResult<String> {
215 for i in 0..node.child_count() {
217 if let Some(child) = node.child(i)
218 && (child.kind() == "identifier" || child.kind() == "name")
219 {
220 let name = std::str::from_utf8(&source[child.byte_range()])
221 .map_err(|e| AstError::ParserError(e.to_string()))?;
222 return Ok(name.to_string());
223 }
224 }
225
226 let text = std::str::from_utf8(&source[node.byte_range()])
228 .map_err(|e| AstError::ParserError(e.to_string()))?;
229
230 let words: Vec<&str> = text.split_whitespace().collect();
232 for word in words {
233 if !Self::is_keyword(word) && word.chars().any(|c| c.is_alphabetic()) {
234 return Ok(word.to_string());
235 }
236 }
237
238 Ok("anonymous".to_string())
239 }
240
/// Reports whether `word` is a declaration keyword or modifier that can
/// precede a symbol's name and therefore must not be mistaken for the name.
///
/// Besides the generic keywords this includes `pub`, `mod`, `func`, `unsafe`
/// and `extern`, which commonly appear in front of a declared name.
fn is_keyword(word: &str) -> bool {
    matches!(
        word,
        // Declaration keywords.
        "fn" | "func"
            | "function"
            | "def"
            | "class"
            | "struct"
            | "enum"
            | "interface"
            | "trait"
            | "impl"
            | "mod"
            | "module"
            | "namespace"
            | "const"
            | "let"
            | "var"
            | "type"
            // Visibility and other modifiers.
            | "pub"
            | "public"
            | "private"
            | "protected"
            | "static"
            | "async"
            | "unsafe"
            | "extern"
            | "export"
            | "import"
    )
}
268
269 fn extract_signature(&self, node: &Node, source: &[u8]) -> AstResult<String> {
271 let mut sig_end = node.end_byte();
273
274 for i in 0..node.child_count() {
275 if let Some(child) = node.child(i)
276 && (child.kind() == "block"
277 || child.kind() == "compound_statement"
278 || child.kind() == "function_body")
279 {
280 sig_end = child.start_byte();
281 break;
282 }
283 }
284
285 let signature = std::str::from_utf8(&source[node.start_byte()..sig_end])
286 .map_err(|e| AstError::ParserError(e.to_string()))?;
287
288 Ok(signature.trim().to_string())
289 }
290
291 fn extract_visibility(&self, node: &Node, source: &[u8]) -> Visibility {
293 let text = std::str::from_utf8(&source[node.byte_range()]).unwrap_or("");
294 Visibility::from_text(text)
295 }
296
297 fn extract_documentation(&self, node: &Node, source: &[u8]) -> Option<String> {
299 if let Some(prev) = node.prev_sibling()
301 && prev.kind().contains("comment")
302 {
303 let doc = std::str::from_utf8(&source[prev.byte_range()]).ok()?;
304 return Some(doc.to_string());
305 }
306 None
307 }
308
309 const fn build_relationships(&mut self, _path: &Path, _ast: &ParsedAst) -> AstResult<()> {
311 Ok(())
315 }
316
317 fn get_symbol_id(&self, symbol: &Symbol) -> String {
319 format!(
320 "{}:{}:{}",
321 symbol.location.file_path, symbol.kind as u8, symbol.name
322 )
323 }
324
325 pub fn search(&self, query: &str) -> Vec<Symbol> {
327 let query_lower = query.to_lowercase();
328 let mut results = Vec::new();
329
330 for entry in self.symbols.iter() {
331 let symbol = entry.value();
332 if symbol.name.to_lowercase().contains(&query_lower) {
333 results.push(symbol.clone());
334 }
335 }
336
337 results
338 }
339
340 pub fn get_call_graph(&self, path: &Path, function_name: &str) -> Vec<Symbol> {
342 let symbol_id = format!(
344 "{}:{}:{}",
345 path.display(),
346 SymbolKind::Function as u8,
347 function_name
348 );
349
350 if let Some(callees) = self.call_graph.get(&symbol_id) {
352 let mut results = Vec::new();
353 for callee_id in callees.value() {
354 if let Some(symbol) = self.symbols.get(callee_id) {
355 results.push(symbol.clone());
356 }
357 }
358 return results;
359 }
360
361 Vec::new()
362 }
363
364 pub fn get_file_symbols(&self, path: &Path) -> Vec<Symbol> {
366 if let Some(symbol_ids) = self.file_symbols.get(&path.to_path_buf()) {
367 let mut results = Vec::new();
368 for id in symbol_ids.value() {
369 if let Some(symbol) = self.symbols.get(id) {
370 results.push(symbol.clone());
371 }
372 }
373 return results;
374 }
375 Vec::new()
376 }
377
378 pub fn clear(&mut self) {
380 self.symbols.clear();
381 self.file_symbols.clear();
382 self.call_graph.clear();
383 self.inheritance_graph.clear();
384 self.import_graph.clear();
385 }
386}
387
388impl Default for SemanticIndex {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::language_registry::Language;
    use crate::language_registry::LanguageRegistry;

    /// Indexes a small Rust snippet, then checks that a case-insensitive
    /// substring search finds `calculate` and that the file reports at least
    /// three symbols.
    #[test]
    fn test_semantic_indexing() {
        let code = r#"
pub fn calculate(x: i32, y: i32) -> i32 {
    add(x, y)
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}

pub struct Calculator {
    value: i32,
}

impl Calculator {
    pub fn new() -> Self {
        Self { value: 0 }
    }
}
"#;

        // Parse the snippet and feed it to a fresh index.
        let registry = LanguageRegistry::new();
        let ast = registry.parse(&Language::Rust, code).unwrap();
        let path = Path::new("test.rs");

        let mut index = SemanticIndex::new();
        index.index_ast(path, &ast).unwrap();

        // Substring search by name.
        let matches = index.search("calc");
        assert!(!matches.is_empty());
        assert!(matches.iter().any(|s| s.name.contains("calculate")));

        // Per-file listing.
        let symbols_in_file = index.get_file_symbols(path);
        assert!(symbols_in_file.len() >= 3);
    }
}