// aptu_coder_core/languages/csharp.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
3
4use tree_sitter::Node;
5
6use super::get_node_text;
7
8/// Extract function name from a C# method or constructor declaration.
9#[must_use]
10pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
11    if node.kind() != "method_declaration" && node.kind() != "constructor_declaration" {
12        return None;
13    }
14    node.child_by_field_name("name")
15        .and_then(|n| get_node_text(&n, source))
16}
17
18/// Find receiver type (enclosing class/struct/interface/record/enum) for a C# method.
19#[must_use]
20pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
21    if node.kind() != "method_declaration" && node.kind() != "constructor_declaration" {
22        return None;
23    }
24
25    // Walk ancestors to find enclosing type declaration
26    let mut current = *node;
27    while let Some(parent) = current.parent() {
28        match parent.kind() {
29            "class_declaration"
30            | "interface_declaration"
31            | "record_declaration"
32            | "struct_declaration"
33            | "enum_declaration" => {
34                // Found the enclosing type, extract its name
35                return parent
36                    .child_by_field_name("name")
37                    .and_then(|n| get_node_text(&n, source));
38            }
39            _ => {
40                current = parent;
41            }
42        }
43    }
44
45    None
46}
47
/// Tree-sitter query for extracting C# elements.
///
/// Capture names:
/// - `@function`: a `method_declaration` or `constructor_declaration`. Its name is
///   captured as `@method_name` or `@ctor_name` respectively.
/// - `@class`: a `class_declaration`, `interface_declaration`, `record_declaration`,
///   `struct_declaration`, or `enum_declaration`. Its name is captured as `@class_name`,
///   `@interface_name`, `@record_name`, `@struct_name`, or `@enum_name` respectively.
pub const ELEMENT_QUERY: &str = r"
(method_declaration name: (identifier) @method_name) @function
(constructor_declaration name: (identifier) @ctor_name) @function
(class_declaration name: (identifier) @class_name) @class
(interface_declaration name: (identifier) @interface_name) @class
(record_declaration name: (identifier) @record_name) @class
(struct_declaration name: (identifier) @struct_name) @class
(enum_declaration name: (identifier) @enum_name) @class
";
59
/// Tree-sitter query for extracting C# method invocations.
///
/// The invoked name is captured as `@call` for two callee shapes:
/// - a member access, e.g. `Baz` in `obj.Baz()`;
/// - a bare identifier, e.g. `Baz` in `Baz()`.
///
/// NOTE(review): both patterns require an `identifier` callee. Generic invocations such
/// as `Foo<T>()` presumably use a `generic_name` callee and are not matched — confirm
/// whether that is intended.
pub const CALL_QUERY: &str = r"
(invocation_expression
  function: (member_access_expression name: (identifier) @call))
(invocation_expression
  function: (identifier) @call)
";
67
/// Tree-sitter query for extracting C# type references.
///
/// Every match is captured as `@type_ref`:
/// - bare base types in a `base_list`;
/// - generic base types in a `base_list`, where only the leading identifier is captured
///   (for `Base<T>`, only `Base`);
/// - bare identifiers inside a `type_argument_list` (e.g. `Foo` in `List<Foo>`);
/// - the identifier of each `type_parameter` in a `type_parameter_list`.
///
/// NOTE(review): the last pattern captures *declared* type parameters (e.g. `T` in
/// `class Generic<T>`), which are definitions rather than references — confirm this is
/// intended.
pub const REFERENCE_QUERY: &str = r"
(base_list (identifier) @type_ref)
(base_list (generic_name (identifier) @type_ref))
(type_argument_list (identifier) @type_ref)
(type_parameter_list (type_parameter (identifier) @type_ref))
";
75
/// Tree-sitter query for extracting C# `using` directives.
///
/// The whole `using_directive` node is captured as `@import_path`. The capture text is
/// therefore the full directive (e.g. `using System;`), not just the namespace.
///
/// All `using` forms (namespace, `using static`, and `using alias = ...`) are
/// represented by a single `using_directive` node kind. There are no separate
/// `using_static_directive` or `using_alias_directive` node kinds, so one pattern
/// captures everything.
pub const IMPORT_QUERY: &str = r"
(using_directive) @import_path
";
85
/// Tree-sitter query for extracting definition and use sites.
///
/// Captures:
/// - `@write.var`: the name declared by a `variable_declarator` (e.g. `b` in `int b = 3;`).
/// - `@write.assign`: an `identifier` on the left of an `assignment_expression`. Targets
///   that are not a bare identifier (e.g. `this.x = ...`) are not matched.
/// - `@read.usage`: every `identifier` node, unconditionally.
///
/// NOTE(review): because `@read.usage` is unrestricted, identifiers captured by the
/// write patterns also match it. Presumably the consumer resolves that overlap — confirm.
pub const DEFUSE_QUERY: &str = r"
(variable_declarator name: (identifier) @write.var)
(assignment_expression left: (identifier) @write.assign)
(identifier) @read.usage
";
92
93/// Extract base class and interface names from a C# class, interface, or record node.
94///
95/// The parser calls this with the class/interface/record declaration node itself.
96/// We locate the `base_list` child and extract each base type name.
97#[must_use]
98pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
99    let mut bases = Vec::new();
100
101    // base_list is an unnamed child of class_declaration/interface_declaration/record_declaration
102    for i in 0..node.child_count() {
103        if let Some(child) = node.child(u32::try_from(i).unwrap_or(u32::MAX))
104            && child.kind() == "base_list"
105        {
106            bases.extend(extract_base_list(&child, source));
107            break;
108        }
109    }
110
111    bases
112}
113
114/// Extract base type names from a `base_list` node.
115fn extract_base_list(node: &Node, source: &str) -> Vec<String> {
116    let mut bases = Vec::new();
117
118    for i in 0..node.named_child_count() {
119        if let Some(child) = node.named_child(u32::try_from(i).unwrap_or(u32::MAX)) {
120            match child.kind() {
121                "identifier" => {
122                    let end = child.end_byte();
123                    if end <= source.len() {
124                        bases.push(source[child.start_byte()..end].to_string());
125                    }
126                }
127                "generic_name" => {
128                    // First named child of generic_name is the identifier.
129                    if let Some(id) = child.named_child(0)
130                        && id.kind() == "identifier"
131                    {
132                        let end = id.end_byte();
133                        if end <= source.len() {
134                            bases.push(source[id.start_byte()..end].to_string());
135                        }
136                    }
137                }
138                _ => {}
139            }
140        }
141    }
142
143    bases
144}
145
146/// Return the method or constructor name when `node` is a `method_declaration`
147/// or `constructor_declaration` that is nested inside a class, interface, or
148/// record body.
149///
150/// This follows the same contract as the Rust, Go, and C++ handlers: return
151/// the **method name** (the `name` field of the declaration node), or `None`
152/// when the node is not a class-level method.
153#[must_use]
154pub fn find_method_for_receiver(
155    node: &Node,
156    source: &str,
157    _depth: Option<usize>,
158) -> Option<String> {
159    if node.kind() != "method_declaration" && node.kind() != "constructor_declaration" {
160        return None;
161    }
162
163    // Only return a name when the node is nested inside a type body.
164    let mut current = *node;
165    let mut in_type_body = false;
166    while let Some(parent) = current.parent() {
167        match parent.kind() {
168            "class_declaration"
169            | "interface_declaration"
170            | "record_declaration"
171            | "struct_declaration"
172            | "enum_declaration" => {
173                in_type_body = true;
174                break;
175            }
176            _ => {
177                current = parent;
178            }
179        }
180    }
181
182    if !in_type_body {
183        return None;
184    }
185
186    node.child_by_field_name("name")
187        .and_then(|n| get_node_text(&n, source))
188}
189
#[cfg(all(test, feature = "lang-csharp"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::Parser;

    fn parse_csharp(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_c_sharp::LANGUAGE.into())
            .expect("Error loading C# language");
        parser.parse(src, None).expect("Failed to parse C#")
    }

    /// Every node in `tree`, visited depth-first with an explicit stack.
    fn all_nodes(tree: &tree_sitter::Tree) -> Vec<Node<'_>> {
        let mut visited = Vec::new();
        let mut stack = vec![tree.root_node()];
        while let Some(node) = stack.pop() {
            visited.push(node);
            for i in 0..node.child_count() {
                if let Some(child) = node.child(u32::try_from(i).unwrap_or(u32::MAX)) {
                    stack.push(child);
                }
            }
        }
        visited
    }

    /// First node of kind `kind` in `tree`; panics (failing the test) if there is none.
    fn find_first<'t>(tree: &'t tree_sitter::Tree, kind: &str) -> Node<'t> {
        all_nodes(tree)
            .into_iter()
            .find(|node| node.kind() == kind)
            .unwrap_or_else(|| panic!("expected a `{kind}` node"))
    }

    /// Text of the `name` field of every node of kind `kind` in `tree`.
    fn names_of_kind(tree: &tree_sitter::Tree, src: &str, kind: &str) -> Vec<String> {
        all_nodes(tree)
            .into_iter()
            .filter(|node| node.kind() == kind)
            .filter_map(|node| node.child_by_field_name("name"))
            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
            .collect()
    }

    #[test]
    fn test_extract_function_name() {
        // Arrange: method inside a class
        let src = "class C { void foo() {} }";
        let tree = parse_csharp(src);
        let method = find_first(&tree, "method_declaration");

        // Act
        let result = extract_function_name(&method, src, "csharp");

        // Assert
        assert_eq!(result, Some("foo".to_string()));
    }

    #[test]
    fn test_extract_function_name_constructor() {
        // Arrange
        let src = "class Foo { public Foo() {} }";
        let tree = parse_csharp(src);
        let ctor = find_first(&tree, "constructor_declaration");

        // Act
        let result = extract_function_name(&ctor, src, "csharp");

        // Assert
        assert_eq!(result, Some("Foo".to_string()));
    }

    #[test]
    fn test_extract_function_name_rejects_non_method() {
        // Arrange: a class node is not a method or constructor
        let src = "class C { void foo() {} }";
        let tree = parse_csharp(src);
        let class = find_first(&tree, "class_declaration");

        // Act
        let result = extract_function_name(&class, src, "csharp");

        // Assert
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_receiver_type() {
        // Arrange: method inside a class
        let src = "class MyClass { void bar() {} }";
        let tree = parse_csharp(src);
        let method = find_first(&tree, "method_declaration");

        // Act
        let result = find_receiver_type(&method, src);

        // Assert
        assert_eq!(result, Some("MyClass".to_string()));
    }

    #[test]
    fn test_find_receiver_type_struct() {
        // Arrange: method inside a struct
        let src = "struct Point { void Move() {} }";
        let tree = parse_csharp(src);
        let method = find_first(&tree, "method_declaration");

        // Act
        let result = find_receiver_type(&method, src);

        // Assert
        assert_eq!(result, Some("Point".to_string()));
    }

    #[test]
    fn test_csharp_method_in_class() {
        // Arrange
        let src = "class Foo { void Bar() { Baz(); } void Baz() {} }";
        let tree = parse_csharp(src);

        // Act -- read the `name` field of each `method_declaration` node directly
        let mut methods = names_of_kind(&tree, src, "method_declaration");
        methods.sort();

        // Assert
        assert_eq!(methods, vec!["Bar", "Baz"]);
    }

    #[test]
    fn test_csharp_constructor() {
        // Arrange
        let src = "class Foo { public Foo() {} }";
        let tree = parse_csharp(src);

        // Act
        let ctors = names_of_kind(&tree, src, "constructor_declaration");

        // Assert
        assert_eq!(ctors, vec!["Foo"]);
    }

    #[test]
    fn test_csharp_interface() {
        // Arrange
        let src = "interface IBar { void Do(); }";
        let tree = parse_csharp(src);

        // Act
        let interfaces = names_of_kind(&tree, src, "interface_declaration");

        // Assert
        assert_eq!(interfaces, vec!["IBar"]);
    }

    #[test]
    fn test_csharp_using_directive() {
        // Arrange
        let src = "using System;";
        let tree = parse_csharp(src);

        // Act
        let imports: Vec<String> = all_nodes(&tree)
            .into_iter()
            .filter(|node| node.kind() == "using_directive")
            .map(|node| src[node.start_byte()..node.end_byte()].to_string())
            .collect();

        // Assert
        assert_eq!(imports, vec!["using System;"]);
    }

    #[test]
    fn test_csharp_async_method() {
        // Arrange -- async modifier is a sibling of the return type; name field unchanged
        let src = "class C { async Task Foo() { await Bar(); } Task Bar() { return Task.CompletedTask; } }";
        let tree = parse_csharp(src);

        // Act
        let methods = names_of_kind(&tree, src, "method_declaration");

        // Assert -- Foo must be extracted even with async modifier
        assert!(methods.contains(&"Foo".to_string()));
    }

    #[test]
    fn test_csharp_generic_class() {
        // Arrange -- type_parameter_list is a child of class_declaration; class name unchanged
        let src = "class Generic<T> { T value; }";
        let tree = parse_csharp(src);

        // Act
        let classes = names_of_kind(&tree, src, "class_declaration");

        // Assert -- generic name captured without type parameters, consistent with Go
        assert_eq!(classes, vec!["Generic"]);
    }

    #[test]
    fn test_csharp_inheritance_extraction() {
        // Arrange -- the parser passes the class_declaration node, not the base_list
        let src = "class Dog : Animal, ICanRun {}";
        let tree = parse_csharp(src);
        let class = find_first(&tree, "class_declaration");

        // Act
        let bases = extract_inheritance(&class, src);

        // Assert
        assert_eq!(bases, vec!["Animal", "ICanRun"]);
    }

    #[test]
    fn test_csharp_inheritance_generic_base() {
        // Arrange -- a generic base contributes only its leading identifier
        let src = "class Repo : Base<int>, IDisposable {}";
        let tree = parse_csharp(src);
        let class = find_first(&tree, "class_declaration");

        // Act
        let bases = extract_inheritance(&class, src);

        // Assert
        assert_eq!(bases, vec!["Base", "IDisposable"]);
    }

    #[test]
    fn test_csharp_find_method_for_receiver() {
        // Arrange
        let src = "class MyClass { void MyMethod() {} }";
        let tree = parse_csharp(src);
        let method = find_first(&tree, "method_declaration");

        // Act
        let name = find_method_for_receiver(&method, src, None);

        // Assert -- returns the method name, not the enclosing type name
        assert_eq!(name, Some("MyMethod".to_string()));
    }

    #[test]
    fn test_defuse_query_write_site() {
        // Arrange
        let src = "class C { void M() { int b = 3; } }\n";
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "csharp", "b", "test.cs", None);
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}