1use crate::languages::get_language_info;
2use crate::types::{
3 AssignmentInfo, CallInfo, ClassInfo, FieldAccessInfo, FunctionInfo, ImportInfo, ReferenceInfo,
4 ReferenceType, SemanticAnalysis,
5};
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::sync::LazyLock;
9use thiserror::Error;
10use tracing::instrument;
11use tree_sitter::{Node, Parser, Query, QueryCursor, StreamingIterator};
12
/// Errors produced while configuring the parser, parsing source text, or compiling
/// tree-sitter queries.
#[derive(Debug, Error)]
pub enum ParserError {
    /// The language has no registered language info, or its queries are not in the
    /// compiled-query cache.
    #[error("Unsupported language: {0}")]
    UnsupportedLanguage(String),
    /// The parser rejected the grammar or produced no tree. This module also uses it when
    /// an `ast_recursion_limit` does not fit in a `u32`.
    #[error("Failed to parse file: {0}")]
    ParseError(String),
    /// Source text was not valid UTF-8.
    /// NOTE(review): not constructed in this module; presumably raised by callers that
    /// decode raw file bytes — confirm.
    #[error("Invalid UTF-8 in file")]
    InvalidUtf8,
    /// A tree-sitter query failed to compile; the message names the query kind and language.
    #[error("Query error: {0}")]
    QueryError(String),
}
24
/// The tree-sitter queries for one language, compiled once and stored in the query cache.
///
/// `element` and `call` always exist. The remaining queries are `None` when the language's
/// `LanguageInfo` does not supply that query string.
struct CompiledQueries {
    /// Definitions; the extractors read its `@function` and `@class` captures.
    element: Query,
    /// Call sites; the extractors read its `@call` capture.
    call: Query,
    /// Import statements; the extractors read its `@import_path` capture.
    import: Option<Query>,
    /// Methods inside impl blocks; the extractors read `@impl_type`, `@method`,
    /// `@method_name` and `@method_params`.
    impl_block: Option<Query>,
    /// Type references; the extractors read its `@type_ref` capture.
    reference: Option<Query>,
    /// Assignments; the extractors read its `@variable` and `@value` captures.
    assignment: Option<Query>,
    /// Field accesses; the extractors read its `@object` and `@field` captures.
    field: Option<Query>,
}
36
37fn build_compiled_queries(
39 lang_info: &crate::languages::LanguageInfo,
40) -> Result<CompiledQueries, ParserError> {
41 let element = Query::new(&lang_info.language, lang_info.element_query).map_err(|e| {
42 ParserError::QueryError(format!(
43 "Failed to compile element query for {}: {}",
44 lang_info.name, e
45 ))
46 })?;
47
48 let call = Query::new(&lang_info.language, lang_info.call_query).map_err(|e| {
49 ParserError::QueryError(format!(
50 "Failed to compile call query for {}: {}",
51 lang_info.name, e
52 ))
53 })?;
54
55 let import = if let Some(import_query_str) = lang_info.import_query {
56 Some(
57 Query::new(&lang_info.language, import_query_str).map_err(|e| {
58 ParserError::QueryError(format!(
59 "Failed to compile import query for {}: {}",
60 lang_info.name, e
61 ))
62 })?,
63 )
64 } else {
65 None
66 };
67
68 let impl_block = if let Some(impl_query_str) = lang_info.impl_query {
69 Some(
70 Query::new(&lang_info.language, impl_query_str).map_err(|e| {
71 ParserError::QueryError(format!(
72 "Failed to compile impl query for {}: {}",
73 lang_info.name, e
74 ))
75 })?,
76 )
77 } else {
78 None
79 };
80
81 let reference = if let Some(ref_query_str) = lang_info.reference_query {
82 Some(Query::new(&lang_info.language, ref_query_str).map_err(|e| {
83 ParserError::QueryError(format!(
84 "Failed to compile reference query for {}: {}",
85 lang_info.name, e
86 ))
87 })?)
88 } else {
89 None
90 };
91
92 let assignment = if let Some(assignment_query_str) = lang_info.assignment_query {
93 Some(
94 Query::new(&lang_info.language, assignment_query_str).map_err(|e| {
95 ParserError::QueryError(format!(
96 "Failed to compile assignment query for {}: {}",
97 lang_info.name, e
98 ))
99 })?,
100 )
101 } else {
102 None
103 };
104
105 let field = if let Some(field_query_str) = lang_info.field_query {
106 Some(
107 Query::new(&lang_info.language, field_query_str).map_err(|e| {
108 ParserError::QueryError(format!(
109 "Failed to compile field query for {}: {}",
110 lang_info.name, e
111 ))
112 })?,
113 )
114 } else {
115 None
116 };
117
118 Ok(CompiledQueries {
119 element,
120 call,
121 import,
122 impl_block,
123 reference,
124 assignment,
125 field,
126 })
127}
128
129fn init_query_cache() -> HashMap<&'static str, CompiledQueries> {
131 let supported_languages = ["rust", "python", "typescript", "tsx", "go", "java"];
132 let mut cache = HashMap::new();
133
134 for lang_name in &supported_languages {
135 if let Some(lang_info) = get_language_info(lang_name) {
136 match build_compiled_queries(&lang_info) {
137 Ok(compiled) => {
138 cache.insert(*lang_name, compiled);
139 }
140 Err(e) => {
141 tracing::error!(
142 "Failed to compile queries for language {}: {}",
143 lang_name,
144 e
145 );
146 }
147 }
148 }
149 }
150
151 cache
152}
153
154static QUERY_CACHE: LazyLock<HashMap<&'static str, CompiledQueries>> =
156 LazyLock::new(init_query_cache);
157
158fn get_compiled_queries(language: &str) -> Result<&'static CompiledQueries, ParserError> {
160 QUERY_CACHE
161 .get(language)
162 .ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))
163}
164
thread_local! {
    /// One reusable tree-sitter parser per thread. Callers in this module set its
    /// language immediately before each parse, so a single parser serves every language.
    static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
}
168
169pub struct ElementExtractor;
171
172impl ElementExtractor {
173 #[instrument(skip_all, fields(language))]
175 pub fn extract_with_depth(source: &str, language: &str) -> Result<(usize, usize), ParserError> {
176 let lang_info = get_language_info(language)
177 .ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))?;
178
179 let tree = PARSER.with(|p| {
180 let mut parser = p.borrow_mut();
181 parser
182 .set_language(&lang_info.language)
183 .map_err(|e| ParserError::ParseError(format!("Failed to set language: {}", e)))?;
184 parser
185 .parse(source, None)
186 .ok_or_else(|| ParserError::ParseError("Failed to parse".to_string()))
187 })?;
188
189 let compiled = get_compiled_queries(language)?;
190
191 let mut cursor = QueryCursor::new();
192 let mut function_count = 0;
193 let mut class_count = 0;
194
195 let mut matches = cursor.matches(&compiled.element, tree.root_node(), source.as_bytes());
196 while let Some(mat) = matches.next() {
197 for capture in mat.captures {
198 let capture_name = compiled.element.capture_names()[capture.index as usize];
199 match capture_name {
200 "function" => function_count += 1,
201 "class" => class_count += 1,
202 _ => {}
203 }
204 }
205 }
206
207 tracing::debug!(language = %language, functions = function_count, classes = class_count, "parse complete");
208
209 Ok((function_count, class_count))
210 }
211}
212
213fn extract_imports_from_node(
217 node: &Node,
218 source: &str,
219 prefix: &str,
220 line: usize,
221 imports: &mut Vec<ImportInfo>,
222) {
223 match node.kind() {
224 "identifier" | "self" | "super" | "crate" => {
226 let name = source[node.start_byte()..node.end_byte()].to_string();
227 imports.push(ImportInfo {
228 module: prefix.to_string(),
229 items: vec![name],
230 line,
231 });
232 }
233 "scoped_identifier" => {
235 let item = node
236 .child_by_field_name("name")
237 .map(|n| source[n.start_byte()..n.end_byte()].to_string())
238 .unwrap_or_default();
239 let module = node
240 .child_by_field_name("path")
241 .map(|p| {
242 let path_text = source[p.start_byte()..p.end_byte()].to_string();
243 if prefix.is_empty() {
244 path_text
245 } else {
246 format!("{}::{}", prefix, path_text)
247 }
248 })
249 .unwrap_or_else(|| prefix.to_string());
250 if !item.is_empty() {
251 imports.push(ImportInfo {
252 module,
253 items: vec![item],
254 line,
255 });
256 }
257 }
258 "scoped_use_list" => {
260 let new_prefix = node
261 .child_by_field_name("path")
262 .map(|p| {
263 let path_text = source[p.start_byte()..p.end_byte()].to_string();
264 if prefix.is_empty() {
265 path_text
266 } else {
267 format!("{}::{}", prefix, path_text)
268 }
269 })
270 .unwrap_or_else(|| prefix.to_string());
271 if let Some(list) = node.child_by_field_name("list") {
272 extract_imports_from_node(&list, source, &new_prefix, line, imports);
273 }
274 }
275 "use_list" => {
277 let mut cursor = node.walk();
278 for child in node.children(&mut cursor) {
279 match child.kind() {
280 "{" | "}" | "," => {}
281 _ => extract_imports_from_node(&child, source, prefix, line, imports),
282 }
283 }
284 }
285 "use_wildcard" => {
287 let text = source[node.start_byte()..node.end_byte()].to_string();
288 let module = if let Some(stripped) = text.strip_suffix("::*") {
289 if prefix.is_empty() {
290 stripped.to_string()
291 } else {
292 format!("{}::{}", prefix, stripped)
293 }
294 } else {
295 prefix.to_string()
296 };
297 imports.push(ImportInfo {
298 module,
299 items: vec!["*".to_string()],
300 line,
301 });
302 }
303 "use_as_clause" => {
305 let alias = node
306 .child_by_field_name("alias")
307 .map(|n| source[n.start_byte()..n.end_byte()].to_string())
308 .unwrap_or_default();
309 let module = if let Some(path_node) = node.child_by_field_name("path") {
310 match path_node.kind() {
311 "scoped_identifier" => path_node
312 .child_by_field_name("path")
313 .map(|p| {
314 let p_text = source[p.start_byte()..p.end_byte()].to_string();
315 if prefix.is_empty() {
316 p_text
317 } else {
318 format!("{}::{}", prefix, p_text)
319 }
320 })
321 .unwrap_or_else(|| prefix.to_string()),
322 _ => prefix.to_string(),
323 }
324 } else {
325 prefix.to_string()
326 };
327 if !alias.is_empty() {
328 imports.push(ImportInfo {
329 module,
330 items: vec![alias],
331 line,
332 });
333 }
334 }
335 _ => {
337 let text = source[node.start_byte()..node.end_byte()]
338 .trim()
339 .to_string();
340 if !text.is_empty() {
341 imports.push(ImportInfo {
342 module: text,
343 items: vec![],
344 line,
345 });
346 }
347 }
348 }
349}
350
351pub struct SemanticExtractor;
352
353impl SemanticExtractor {
354 #[instrument(skip_all, fields(language))]
356 pub fn extract(
357 source: &str,
358 language: &str,
359 ast_recursion_limit: Option<usize>,
360 ) -> Result<SemanticAnalysis, ParserError> {
361 let lang_info = get_language_info(language)
362 .ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))?;
363
364 let tree = PARSER.with(|p| {
365 let mut parser = p.borrow_mut();
366 parser
367 .set_language(&lang_info.language)
368 .map_err(|e| ParserError::ParseError(format!("Failed to set language: {}", e)))?;
369 parser
370 .parse(source, None)
371 .ok_or_else(|| ParserError::ParseError("Failed to parse".to_string()))
372 })?;
373
374 let mut functions = Vec::new();
375 let mut classes = Vec::new();
376 let mut imports = Vec::new();
377 let mut references = Vec::new();
378 let mut call_frequency = HashMap::new();
379 let mut calls = Vec::new();
380 let mut assignments: Vec<AssignmentInfo> = Vec::new();
381 let mut field_accesses: Vec<FieldAccessInfo> = Vec::new();
382
383 let max_depth: Option<u32> = ast_recursion_limit
385 .map(|limit| {
386 u32::try_from(limit).map_err(|_| {
387 ParserError::ParseError(format!(
388 "ast_recursion_limit {} exceeds maximum supported value {}",
389 limit,
390 u32::MAX
391 ))
392 })
393 })
394 .transpose()?;
395
396 let compiled = get_compiled_queries(language)?;
398 let mut cursor = QueryCursor::new();
399 if let Some(depth) = max_depth {
400 cursor.set_max_start_depth(Some(depth));
401 }
402
403 let mut matches = cursor.matches(&compiled.element, tree.root_node(), source.as_bytes());
404 let mut seen_functions = std::collections::HashSet::new();
405
406 while let Some(mat) = matches.next() {
407 for capture in mat.captures {
408 let capture_name = compiled.element.capture_names()[capture.index as usize];
409 let node = capture.node;
410
411 match capture_name {
412 "function" => {
413 if let Some(name_node) = node.child_by_field_name("name") {
414 let name =
415 source[name_node.start_byte()..name_node.end_byte()].to_string();
416 let func_key = (name.clone(), node.start_position().row);
417
418 if !seen_functions.contains(&func_key) {
419 seen_functions.insert(func_key);
420
421 let params = node
422 .child_by_field_name("parameters")
423 .map(|p| source[p.start_byte()..p.end_byte()].to_string())
424 .unwrap_or_default();
425 let return_type = node
426 .child_by_field_name("return_type")
427 .map(|r| source[r.start_byte()..r.end_byte()].to_string());
428
429 functions.push(FunctionInfo {
430 name,
431 line: node.start_position().row + 1,
432 end_line: node.end_position().row + 1,
433 parameters: if params.is_empty() {
434 Vec::new()
435 } else {
436 vec![params]
437 },
438 return_type,
439 });
440 }
441 }
442 }
443 "class" => {
444 if let Some(name_node) = node.child_by_field_name("name") {
445 let name =
446 source[name_node.start_byte()..name_node.end_byte()].to_string();
447 let inherits = if let Some(handler) = lang_info.extract_inheritance {
448 handler(&node, source)
449 } else {
450 Vec::new()
451 };
452 classes.push(ClassInfo {
453 name,
454 line: node.start_position().row + 1,
455 end_line: node.end_position().row + 1,
456 methods: Vec::new(),
457 fields: Vec::new(),
458 inherits,
459 });
460 }
461 }
462 _ => {}
463 }
464 }
465 }
466
467 let mut cursor = QueryCursor::new();
469 if let Some(depth) = max_depth {
470 cursor.set_max_start_depth(Some(depth));
471 }
472
473 let mut matches = cursor.matches(&compiled.call, tree.root_node(), source.as_bytes());
474 while let Some(mat) = matches.next() {
475 for capture in mat.captures {
476 let capture_name = compiled.call.capture_names()[capture.index as usize];
477 if capture_name == "call" {
478 let node = capture.node;
479 let call_name = source[node.start_byte()..node.end_byte()].to_string();
480 *call_frequency.entry(call_name.clone()).or_insert(0) += 1;
481
482 let mut current = node;
484 let mut caller = "<module>".to_string();
485 while let Some(parent) = current.parent() {
486 if parent.kind() == "function_item"
487 && let Some(name_node) = parent.child_by_field_name("name")
488 {
489 caller =
490 source[name_node.start_byte()..name_node.end_byte()].to_string();
491 break;
492 }
493 current = parent;
494 }
495
496 let mut arg_count = None;
498 let mut arg_node = node;
499 while let Some(parent) = arg_node.parent() {
500 if parent.kind() == "call_expression" {
501 if let Some(args) = parent.child_by_field_name("arguments") {
502 arg_count = Some(args.named_child_count());
503 }
504 break;
505 }
506 arg_node = parent;
507 }
508
509 calls.push(CallInfo {
510 caller,
511 callee: call_name,
512 line: node.start_position().row + 1,
513 column: node.start_position().column,
514 arg_count,
515 });
516 }
517 }
518 }
519
520 if let Some(ref import_query) = compiled.import {
522 let mut cursor = QueryCursor::new();
523 if let Some(depth) = max_depth {
524 cursor.set_max_start_depth(Some(depth));
525 }
526
527 let mut matches = cursor.matches(import_query, tree.root_node(), source.as_bytes());
528 while let Some(mat) = matches.next() {
529 for capture in mat.captures {
530 let capture_name = import_query.capture_names()[capture.index as usize];
531 if capture_name == "import_path" {
532 let node = capture.node;
533 let line = node.start_position().row + 1;
534 extract_imports_from_node(&node, source, "", line, &mut imports);
535 }
536 }
537 }
538 }
539
540 if let Some(ref impl_query) = compiled.impl_block {
542 let mut cursor = QueryCursor::new();
543 if let Some(depth) = max_depth {
544 cursor.set_max_start_depth(Some(depth));
545 }
546
547 let mut matches = cursor.matches(impl_query, tree.root_node(), source.as_bytes());
548 while let Some(mat) = matches.next() {
549 let mut impl_type_name = String::new();
550 let mut method_name = String::new();
551 let mut method_line = 0usize;
552 let mut method_end_line = 0usize;
553 let mut method_params = String::new();
554 let mut method_return_type: Option<String> = None;
555
556 for capture in mat.captures {
557 let capture_name = impl_query.capture_names()[capture.index as usize];
558 let node = capture.node;
559 match capture_name {
560 "impl_type" => {
561 impl_type_name = source[node.start_byte()..node.end_byte()].to_string();
562 }
563 "method_name" => {
564 method_name = source[node.start_byte()..node.end_byte()].to_string();
565 }
566 "method_params" => {
567 method_params = source[node.start_byte()..node.end_byte()].to_string();
568 }
569 "method" => {
570 method_line = node.start_position().row + 1;
571 method_end_line = node.end_position().row + 1;
572 method_return_type = node
573 .child_by_field_name("return_type")
574 .map(|r| source[r.start_byte()..r.end_byte()].to_string());
575 }
576 _ => {}
577 }
578 }
579
580 if !impl_type_name.is_empty() && !method_name.is_empty() {
581 let func = FunctionInfo {
582 name: method_name,
583 line: method_line,
584 end_line: method_end_line,
585 parameters: if method_params.is_empty() {
586 Vec::new()
587 } else {
588 vec![method_params]
589 },
590 return_type: method_return_type,
591 };
592 if let Some(class) = classes.iter_mut().find(|c| c.name == impl_type_name) {
593 class.methods.push(func);
594 }
595 }
596 }
597 }
598
599 if let Some(ref ref_query) = compiled.reference {
601 let mut cursor = QueryCursor::new();
602 if let Some(depth) = max_depth {
603 cursor.set_max_start_depth(Some(depth));
604 }
605
606 let mut seen_refs = std::collections::HashSet::new();
607 let mut matches = cursor.matches(ref_query, tree.root_node(), source.as_bytes());
608 while let Some(mat) = matches.next() {
609 for capture in mat.captures {
610 let capture_name = ref_query.capture_names()[capture.index as usize];
611 if capture_name == "type_ref" {
612 let node = capture.node;
613 let type_ref = source[node.start_byte()..node.end_byte()].to_string();
614 if seen_refs.insert(type_ref.clone()) {
615 references.push(ReferenceInfo {
616 symbol: type_ref,
617 reference_type: ReferenceType::Usage,
618 location: String::new(),
620 line: node.start_position().row + 1,
621 });
622 }
623 }
624 }
625 }
626 }
627
628 if let Some(ref assignment_query) = compiled.assignment {
630 let mut cursor = QueryCursor::new();
631 if let Some(depth) = max_depth {
632 cursor.set_max_start_depth(Some(depth));
633 }
634
635 let mut matches = cursor.matches(assignment_query, tree.root_node(), source.as_bytes());
636 while let Some(mat) = matches.next() {
637 let mut variable = String::new();
638 let mut value = String::new();
639 let mut line = 0usize;
640
641 for capture in mat.captures {
642 let capture_name = assignment_query.capture_names()[capture.index as usize];
643 let node = capture.node;
644 match capture_name {
645 "variable" => {
646 variable = source[node.start_byte()..node.end_byte()].to_string();
647 }
648 "value" => {
649 value = source[node.start_byte()..node.end_byte()].to_string();
650 line = node.start_position().row + 1;
651 }
652 _ => {}
653 }
654 }
655
656 if !variable.is_empty() && !value.is_empty() {
657 let mut current = mat.captures[0].node;
658 let mut scope = "global".to_string();
659 while let Some(parent) = current.parent() {
660 if parent.kind() == "function_item"
661 && let Some(name_node) = parent.child_by_field_name("name")
662 {
663 scope =
664 source[name_node.start_byte()..name_node.end_byte()].to_string();
665 break;
666 }
667 current = parent;
668 }
669
670 assignments.push(AssignmentInfo {
671 variable,
672 value,
673 line,
674 scope,
675 });
676 }
677 }
678 }
679
680 if let Some(ref field_query) = compiled.field {
682 let mut cursor = QueryCursor::new();
683 if let Some(depth) = max_depth {
684 cursor.set_max_start_depth(Some(depth));
685 }
686
687 let mut matches = cursor.matches(field_query, tree.root_node(), source.as_bytes());
688 while let Some(mat) = matches.next() {
689 let mut object = String::new();
690 let mut field = String::new();
691 let mut line = 0usize;
692
693 for capture in mat.captures {
694 let capture_name = field_query.capture_names()[capture.index as usize];
695 let node = capture.node;
696 match capture_name {
697 "object" => {
698 object = source[node.start_byte()..node.end_byte()].to_string();
699 }
700 "field" => {
701 field = source[node.start_byte()..node.end_byte()].to_string();
702 line = node.start_position().row + 1;
703 }
704 _ => {}
705 }
706 }
707
708 if !object.is_empty() && !field.is_empty() {
709 let mut current = mat.captures[0].node;
710 let mut scope = "global".to_string();
711 while let Some(parent) = current.parent() {
712 if parent.kind() == "function_item"
713 && let Some(name_node) = parent.child_by_field_name("name")
714 {
715 scope =
716 source[name_node.start_byte()..name_node.end_byte()].to_string();
717 break;
718 }
719 current = parent;
720 }
721
722 field_accesses.push(FieldAccessInfo {
723 object,
724 field,
725 line,
726 scope,
727 });
728 }
729 }
730 }
731
732 tracing::debug!(language = %language, functions = functions.len(), classes = classes.len(), imports = imports.len(), references = references.len(), calls = calls.len(), "extraction complete");
733
734 Ok(SemanticAnalysis {
735 functions,
736 classes,
737 imports,
738 references,
739 call_frequency,
740 calls,
741 assignments,
742 field_accesses,
743 })
744 }
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the Rust semantic extractor with no depth limit, failing the test on error.
    fn analyze_rust(source: &str) -> SemanticAnalysis {
        SemanticExtractor::extract(source, "rust", None)
            .expect("semantic extraction of Rust source should succeed")
    }

    #[test]
    fn test_extract_assignments() {
        let source = r#"
fn main() {
    let x = 42;
    let y = x + 1;
}
"#;
        let analysis = analyze_rust(source);
        let first = analysis
            .assignments
            .first()
            .expect("expected at least one assignment");
        assert_eq!(first.variable, "x");
        assert_eq!(first.value, "42");
        assert_eq!(first.scope, "main");
    }

    #[test]
    fn test_extract_field_accesses() {
        let source = r#"
fn process(user: &User) {
    let name = user.name;
    let age = user.age;
}
"#;
        let analysis = analyze_rust(source);
        let accesses = &analysis.field_accesses;
        let first = accesses.first().expect("expected at least one field access");
        let saw_user_name = accesses
            .iter()
            .any(|access| access.object == "user" && access.field == "name");
        assert!(saw_user_name);
        assert_eq!(first.scope, "process");
    }
}