1use super::{NavResult, NodeRef};
7use crate::error::AqlError;
8use crate::types::{NodeKind, RelativePath};
9use std::cell::RefCell;
10
/// A parsed node selector: a node kind plus an optional field predicate.
///
/// Produced by `parse_kind_selector` from `kind`, `kind[field]` or `kind[field=value]`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct KindSelector {
    /// Node kind to match; an empty string accepts any kind.
    kind: String,
    /// Field name whose child must be present, if any.
    field: Option<String>,
    /// Exact source text the `field` child must have, if any.
    value: Option<String>,
}
21
22fn parse_kind_selector(selector: &str) -> KindSelector {
24 let selector = selector.trim();
25 if let Some(bracket_start) = selector.find('[') {
26 let kind = selector[..bracket_start].trim().to_string();
27 let rest = selector[bracket_start + 1..].trim_end_matches(']').trim();
28 if let Some(eq_pos) = rest.find('=') {
29 let field = rest[..eq_pos].trim().to_string();
30 let value = rest[eq_pos + 1..]
31 .trim()
32 .trim_matches('"')
33 .trim_matches('\'')
34 .to_string();
35 KindSelector {
36 kind,
37 field: Some(field),
38 value: Some(value),
39 }
40 } else {
41 KindSelector {
42 kind,
43 field: Some(rest.to_string()),
44 value: None,
45 }
46 }
47 } else {
48 KindSelector {
49 kind: selector.to_string(),
50 field: None,
51 value: None,
52 }
53 }
54}
55
56fn matches_selector(node: &tree_sitter::Node, src: &[u8], selector: &KindSelector) -> bool {
58 if !selector.kind.is_empty() && node.kind() != selector.kind {
59 return false;
60 }
61 match (&selector.field, &selector.value) {
62 (Some(field), Some(value)) => node
63 .child_by_field_name(field.as_str())
64 .is_some_and(|child| child.utf8_text(src).unwrap_or("") == value.as_str()),
65 (Some(field), None) => node.child_by_field_name(field.as_str()).is_some(),
66 _ => true,
67 }
68}
69
70fn node_to_ref(node: &tree_sitter::Node, file: &RelativePath) -> NodeRef {
72 let start = node.start_position();
73 let end = node.end_position();
74 NodeRef {
75 file: file.clone(),
76 start_byte: node.start_byte(),
77 end_byte: node.end_byte(),
78 kind: NodeKind::from(node.kind()),
79 line: start.row + 1,
80 column: start.column,
81 end_line: end.row + 1,
82 end_column: end.column,
83 }
84}
85
86fn build_nav_result(nodes: &[tree_sitter::Node], src: &str, file: &RelativePath) -> NavResult {
88 let refs: Vec<NodeRef> = nodes.iter().map(|n| node_to_ref(n, file)).collect();
89 let source: Vec<String> = nodes
90 .iter()
91 .map(|n| {
92 src.get(n.start_byte()..n.end_byte())
93 .unwrap_or("")
94 .to_string()
95 })
96 .collect();
97 NavResult {
98 nodes: refs,
99 source,
100 }
101}
102
/// Grammar families this module can parse.
///
/// `JavaScript` files are parsed with the TSX grammar, sharing the TSX parser slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    Rust,
    TypeScript,
    Tsx,
    JavaScript,
}
115
116fn detect_language(file: &RelativePath) -> Option<Language> {
117 let path: &str = file.as_ref();
118 match std::path::Path::new(path)
119 .extension()
120 .and_then(|e| e.to_str())
121 {
122 Some("rs") => Some(Language::Rust),
123 Some("ts") | Some("mts") => Some(Language::TypeScript),
124 Some("tsx") | Some("jsx") => Some(Language::Tsx),
125 Some("js") | Some("mjs") => Some(Language::JavaScript),
126 _ => None,
127 }
128}
129
130thread_local! {
131 static RUST_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
132 static TS_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
133 static TSX_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
134}
135
136fn parse_source(source: &str, lang: Language) -> Result<tree_sitter::Tree, AqlError> {
137 let parse_with = |cell: &RefCell<Option<tree_sitter::Parser>>,
138 make_lang: fn() -> tree_sitter::Language|
139 -> Result<tree_sitter::Tree, AqlError> {
140 let mut opt = cell.borrow_mut();
141 let parser = opt.get_or_insert_with(|| {
142 let mut p = tree_sitter::Parser::new();
143 p.set_language(&make_lang())
144 .expect("Failed to set tree-sitter language");
145 p
146 });
147 parser
148 .parse(source, None)
149 .ok_or_else(|| AqlError::new("Failed to parse source"))
150 };
151
152 match lang {
153 Language::Rust => {
154 RUST_PARSER.with(|cell| parse_with(cell, || tree_sitter_rust::LANGUAGE.into()))
155 }
156 Language::TypeScript => TS_PARSER
157 .with(|cell| parse_with(cell, || tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())),
158 Language::Tsx | Language::JavaScript => {
159 TSX_PARSER.with(|cell| parse_with(cell, || tree_sitter_typescript::LANGUAGE_TSX.into()))
160 }
161 }
162}
163
164fn find_node_by_range<'a>(
170 root: tree_sitter::Node<'a>,
171 start_byte: usize,
172 end_byte: usize,
173) -> Option<tree_sitter::Node<'a>> {
174 if root.start_byte() == start_byte && root.end_byte() == end_byte {
175 return Some(root);
176 }
177 let mut cursor = root.walk();
178 for child in root.named_children(&mut cursor) {
179 if child.start_byte() <= start_byte && child.end_byte() >= end_byte {
180 if let Some(found) = find_node_by_range(child, start_byte, end_byte) {
181 return Some(found);
182 }
183 }
184 }
185 if root.start_byte() <= start_byte && root.end_byte() >= end_byte {
187 return Some(root);
188 }
189 None
190}
191
192pub fn select_nodes(
198 source: &str,
199 file: &RelativePath,
200 scope: Option<&NodeRef>,
201 selector: &str,
202) -> Result<NavResult, AqlError> {
203 let lang = detect_language(file)
204 .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
205 let tree = parse_source(source, lang)?;
206 let src = source.as_bytes();
207 let sel = parse_kind_selector(selector);
208
209 let search_root = match scope {
210 Some(node_ref) => {
211 find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
212 .ok_or_else(|| {
213 AqlError::new(format!(
214 "Could not find node at {}..{} in {file}",
215 node_ref.start_byte, node_ref.end_byte
216 ))
217 })?
218 }
219 None => tree.root_node(),
220 };
221
222 let mut matches = Vec::new();
223 collect_matching_descendants(&search_root, src, &sel, &mut matches);
224 Ok(build_nav_result(&matches, source, file))
225}
226
227fn collect_matching_descendants<'a>(
229 node: &tree_sitter::Node<'a>,
230 src: &[u8],
231 selector: &KindSelector,
232 result: &mut Vec<tree_sitter::Node<'a>>,
233) {
234 let mut cursor = node.walk();
235 for child in node.named_children(&mut cursor) {
236 if matches_selector(&child, src, selector) {
237 result.push(child);
238 }
239 collect_matching_descendants(&child, src, selector, result);
240 }
241}
242
243pub fn expand_node(
245 source: &str,
246 file: &RelativePath,
247 node_ref: &NodeRef,
248 selector: Option<&str>,
249) -> Result<NavResult, AqlError> {
250 let lang = detect_language(file)
251 .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
252 let tree = parse_source(source, lang)?;
253 let src = source.as_bytes();
254
255 let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
256 .ok_or_else(|| {
257 AqlError::new(format!(
258 "Could not find node at {}..{} in {file}",
259 node_ref.start_byte, node_ref.end_byte
260 ))
261 })?;
262
263 let sel = selector.map(parse_kind_selector);
264
265 let mut current = target.parent();
266 while let Some(parent) = current {
267 match &sel {
268 Some(s) if !matches_selector(&parent, src, s) => {
269 current = parent.parent();
270 }
271 _ => {
272 return Ok(build_nav_result(&[parent], source, file));
273 }
274 }
275 }
276
277 Ok(NavResult {
278 nodes: vec![],
279 source: vec![],
280 })
281}
282
283pub fn shrink_node(
285 source: &str,
286 file: &RelativePath,
287 node_ref: &NodeRef,
288 selector: Option<&str>,
289) -> Result<NavResult, AqlError> {
290 let lang = detect_language(file)
291 .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
292 let tree = parse_source(source, lang)?;
293 let src = source.as_bytes();
294
295 let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
296 .ok_or_else(|| {
297 AqlError::new(format!(
298 "Could not find node at {}..{} in {file}",
299 node_ref.start_byte, node_ref.end_byte
300 ))
301 })?;
302
303 let mut children = Vec::new();
304 let mut cursor = target.walk();
305 match selector.map(parse_kind_selector) {
306 Some(sel) => {
307 for child in target.named_children(&mut cursor) {
308 if matches_selector(&child, src, &sel) {
309 children.push(child);
310 }
311 }
312 }
313 None => {
314 for child in target.named_children(&mut cursor) {
315 children.push(child);
316 }
317 }
318 }
319
320 Ok(build_nav_result(&children, source, file))
321}
322
323pub fn next_node(
325 source: &str,
326 file: &RelativePath,
327 node_ref: &NodeRef,
328 selector: Option<&str>,
329) -> Result<NavResult, AqlError> {
330 let lang = detect_language(file)
331 .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
332 let tree = parse_source(source, lang)?;
333 let src = source.as_bytes();
334
335 let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
336 .ok_or_else(|| {
337 AqlError::new(format!(
338 "Could not find node at {}..{} in {file}",
339 node_ref.start_byte, node_ref.end_byte
340 ))
341 })?;
342
343 let sel = selector.map(parse_kind_selector);
344 let mut current = target.next_named_sibling();
345 while let Some(sibling) = current {
346 match &sel {
347 Some(s) if !matches_selector(&sibling, src, s) => {
348 current = sibling.next_named_sibling();
349 }
350 _ => {
351 return Ok(build_nav_result(&[sibling], source, file));
352 }
353 }
354 }
355
356 Ok(NavResult {
357 nodes: vec![],
358 source: vec![],
359 })
360}
361
362pub fn prev_node(
364 source: &str,
365 file: &RelativePath,
366 node_ref: &NodeRef,
367 selector: Option<&str>,
368) -> Result<NavResult, AqlError> {
369 let lang = detect_language(file)
370 .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
371 let tree = parse_source(source, lang)?;
372 let src = source.as_bytes();
373
374 let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
375 .ok_or_else(|| {
376 AqlError::new(format!(
377 "Could not find node at {}..{} in {file}",
378 node_ref.start_byte, node_ref.end_byte
379 ))
380 })?;
381
382 let sel = selector.map(parse_kind_selector);
383 let mut current = target.prev_named_sibling();
384 while let Some(sibling) = current {
385 match &sel {
386 Some(s) if !matches_selector(&sibling, src, s) => {
387 current = sibling.prev_named_sibling();
388 }
389 _ => {
390 return Ok(build_nav_result(&[sibling], source, file));
391 }
392 }
393 }
394
395 Ok(NavResult {
396 nodes: vec![],
397 source: vec![],
398 })
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    // TypeScript fixture: two top-level functions, a class with a constructor and two
    // methods, and an exported constant (the last top-level statement).
    const TS_SOURCE: &str = r#"
function greet(name: string): string {
    return `Hello, ${name}!`;
}

async function fetchUser(id: number): Promise<User> {
    const response = await fetch(`/api/users/${id}`);
    return response.json();
}

class UserService {
    private baseUrl: string;

    constructor(baseUrl: string) {
        this.baseUrl = baseUrl;
    }

    async getById(id: number): Promise<User> {
        return fetchUser(id);
    }

    async create(data: UserInput): Promise<User> {
        const response = await fetch(this.baseUrl, {
            method: 'POST',
            body: JSON.stringify(data),
        });
        return response.json();
    }
}

export const MAX_RETRIES = 3;
"#;

    // Rust fixture: a free function, a struct, an enum and an impl block with one method.
    const RUST_SOURCE: &str = r#"
pub fn parse_selector(input: &str) -> Result<SelectorAst, AqlError> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err(AqlError::new("Empty selector"));
    }
    Ok(SelectorAst { compounds: vec![] })
}

pub struct SelectorAst {
    pub compounds: Vec<CompoundSelector>,
}

pub enum Combinator {
    Child,
    Descendant,
}

impl SelectorAst {
    pub fn is_empty(&self) -> bool {
        self.compounds.is_empty()
    }
}
"#;

    fn ts_file() -> RelativePath {
        RelativePath::from("test.ts")
    }

    fn rs_file() -> RelativePath {
        RelativePath::from("test.rs")
    }

    #[test]
    fn select_finds_function_declarations() {
        let result = select_nodes(TS_SOURCE, &ts_file(), None, "function_declaration").unwrap();

        assert_eq!(result.nodes.len(), 2, "should find 2 function declarations");
        assert!(
            result.source[0].contains("function greet"),
            "first function should be greet"
        );
        assert!(
            result.source[1].contains("async function fetchUser"),
            "second function should be fetchUser"
        );
    }

    #[test]
    fn select_finds_class_declarations() {
        let result = select_nodes(TS_SOURCE, &ts_file(), None, "class_declaration").unwrap();

        assert_eq!(result.nodes.len(), 1, "should find 1 class declaration");
        assert!(
            result.source[0].contains("class UserService"),
            "should be UserService"
        );
    }

    #[test]
    fn select_within_scope() {
        let classes = select_nodes(TS_SOURCE, &ts_file(), None, "class_declaration").unwrap();
        let class_ref = &classes.nodes[0];

        let methods =
            select_nodes(TS_SOURCE, &ts_file(), Some(class_ref), "method_definition").unwrap();

        assert_eq!(
            methods.nodes.len(),
            3,
            "UserService has constructor + 2 methods"
        );
    }

    #[test]
    fn select_with_field_predicate() {
        let result = select_nodes(
            TS_SOURCE,
            &ts_file(),
            None,
            r#"function_declaration[name=greet]"#,
        )
        .unwrap();

        assert_eq!(result.nodes.len(), 1, "should find exactly greet");
        assert!(
            result.source[0].contains("function greet"),
            "should be the greet function"
        );
    }

    #[test]
    fn unsupported_extension_is_an_error() {
        let result = select_nodes("x = 1", &RelativePath::from("script.py"), None, "identifier");

        assert!(result.is_err(), "no grammar is registered for .py files");
    }

    #[test]
    fn expand_returns_parent() {
        let methods = select_nodes(TS_SOURCE, &ts_file(), None, "method_definition").unwrap();
        let method_ref = &methods.nodes[0];

        let result = expand_node(TS_SOURCE, &ts_file(), method_ref, None).unwrap();

        assert_eq!(result.nodes.len(), 1, "should find parent");
        assert_eq!(
            result.nodes[0].kind, "class_body",
            "parent should be class_body"
        );
    }

    #[test]
    fn expand_with_selector_finds_ancestor() {
        let methods = select_nodes(TS_SOURCE, &ts_file(), None, "method_definition").unwrap();
        let method_ref = &methods.nodes[0];

        let result =
            expand_node(TS_SOURCE, &ts_file(), method_ref, Some("class_declaration")).unwrap();

        assert_eq!(result.nodes.len(), 1, "should find class ancestor");
        assert_eq!(
            result.nodes[0].kind, "class_declaration",
            "should be class_declaration"
        );
    }

    #[test]
    fn expand_without_matching_ancestor_is_empty() {
        let methods = select_nodes(TS_SOURCE, &ts_file(), None, "method_definition").unwrap();
        let method_ref = &methods.nodes[0];

        let result =
            expand_node(TS_SOURCE, &ts_file(), method_ref, Some("function_declaration")).unwrap();

        assert!(
            result.nodes.is_empty(),
            "methods are not nested inside a function declaration"
        );
        assert!(result.source.is_empty());
    }

    #[test]
    fn shrink_returns_children() {
        let classes = select_nodes(TS_SOURCE, &ts_file(), None, "class_declaration").unwrap();
        let class_ref = &classes.nodes[0];

        let result = shrink_node(TS_SOURCE, &ts_file(), class_ref, None).unwrap();

        assert!(
            result.nodes.len() >= 2,
            "class should have at least name and body children"
        );
    }

    #[test]
    fn shrink_with_selector() {
        let classes = select_nodes(TS_SOURCE, &ts_file(), None, "class_declaration").unwrap();
        let class_ref = &classes.nodes[0];

        let result = shrink_node(TS_SOURCE, &ts_file(), class_ref, Some("class_body")).unwrap();

        assert_eq!(result.nodes.len(), 1, "should find class_body child");
        assert_eq!(result.nodes[0].kind, "class_body", "should be class_body");
    }

    #[test]
    fn next_returns_sibling() {
        let funcs = select_nodes(TS_SOURCE, &ts_file(), None, "function_declaration").unwrap();
        let first = &funcs.nodes[0];

        let result = next_node(TS_SOURCE, &ts_file(), first, None).unwrap();

        assert_eq!(result.nodes.len(), 1, "should find next sibling");
        assert!(
            result.source[0].contains("fetchUser"),
            "next function should be fetchUser"
        );
    }

    #[test]
    fn next_with_selector_skips_non_matching() {
        let funcs = select_nodes(TS_SOURCE, &ts_file(), None, "function_declaration").unwrap();
        let first = &funcs.nodes[0];

        let result = next_node(TS_SOURCE, &ts_file(), first, Some("class_declaration")).unwrap();

        assert_eq!(result.nodes.len(), 1, "should find class");
        assert_eq!(
            result.nodes[0].kind, "class_declaration",
            "should be class_declaration"
        );
    }

    #[test]
    fn next_after_last_sibling_is_empty() {
        let exports = select_nodes(TS_SOURCE, &ts_file(), None, "export_statement").unwrap();
        assert_eq!(exports.nodes.len(), 1, "fixture has exactly one export");

        let result = next_node(TS_SOURCE, &ts_file(), &exports.nodes[0], None).unwrap();

        assert!(
            result.nodes.is_empty(),
            "the export is the last top-level statement"
        );
    }

    #[test]
    fn prev_returns_sibling() {
        let funcs = select_nodes(TS_SOURCE, &ts_file(), None, "function_declaration").unwrap();
        let second = &funcs.nodes[1];

        let result = prev_node(TS_SOURCE, &ts_file(), second, None).unwrap();

        assert_eq!(result.nodes.len(), 1, "should find prev sibling");
        assert!(
            result.source[0].contains("function greet"),
            "prev function should be greet"
        );
    }

    #[test]
    fn prev_with_field_predicate_skips_non_matching() {
        let classes = select_nodes(TS_SOURCE, &ts_file(), None, "class_declaration").unwrap();
        let class_ref = &classes.nodes[0];

        let result = prev_node(
            TS_SOURCE,
            &ts_file(),
            class_ref,
            Some("function_declaration[name=greet]"),
        )
        .unwrap();

        assert_eq!(result.nodes.len(), 1, "should skip fetchUser and find greet");
        assert!(
            result.source[0].contains("function greet"),
            "match should be the greet function"
        );
    }

    #[test]
    fn select_rust_functions() {
        let result = select_nodes(RUST_SOURCE, &rs_file(), None, "function_item").unwrap();

        assert_eq!(
            result.nodes.len(),
            2,
            "should find parse_selector and is_empty"
        );
    }

    #[test]
    fn select_rust_structs() {
        let result = select_nodes(RUST_SOURCE, &rs_file(), None, "struct_item").unwrap();

        assert_eq!(result.nodes.len(), 1, "should find SelectorAst");
        assert!(
            result.source[0].contains("struct SelectorAst"),
            "should be SelectorAst"
        );
    }

    #[test]
    fn select_rust_method_by_name_and_expand_to_impl() {
        let methods =
            select_nodes(RUST_SOURCE, &rs_file(), None, "function_item[name=is_empty]").unwrap();

        assert_eq!(methods.nodes.len(), 1, "should find only is_empty");
        assert!(
            methods.source[0].contains("fn is_empty"),
            "should be the is_empty method"
        );

        let owner =
            expand_node(RUST_SOURCE, &rs_file(), &methods.nodes[0], Some("impl_item")).unwrap();

        assert_eq!(owner.nodes.len(), 1, "method should sit inside an impl block");
        assert_eq!(owner.nodes[0].kind, "impl_item", "should be impl_item");
    }

    #[test]
    fn empty_result_for_no_matches() {
        let result = select_nodes(TS_SOURCE, &ts_file(), None, "trait_item").unwrap();

        assert_eq!(result.nodes.len(), 0, "TS has no trait_item nodes");
    }

    #[test]
    fn node_ref_byte_ranges_are_precise() {
        let result = select_nodes(TS_SOURCE, &ts_file(), None, "function_declaration").unwrap();
        let node = &result.nodes[0];
        let extracted = &TS_SOURCE[node.start_byte..node.end_byte];

        assert_eq!(
            extracted, result.source[0],
            "byte range should produce identical source text"
        );
    }
}