Skip to main content

aptu_coder_core/languages/
fortran.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query for extracting Fortran elements (modules, functions, subroutines).
///
/// Captures:
/// - `@function`: every `subroutine_statement` and `function_statement`.
/// - `@class`: the `name` child of a `module_statement` (the module identifier).
/// - `@module_wrapper`: the `module_statement` that holds that name.
///
/// Tree-sitter patterns are not anchored to the root; a pattern matches its node at any
/// depth. The two procedure patterns therefore already cover procedures listed after
/// `CONTAINS` (under `internal_procedures`). Adding module-scoped copies of them would
/// capture each such statement a second time and yield duplicate elements.
///
/// `module_statement` exposes no field for its identifier. The name is therefore matched
/// as a plain `(name)` child pattern rather than through a field.
pub const ELEMENT_QUERY: &str = r"
(subroutine
  (subroutine_statement) @function)

(function
  (function_statement) @function)

(module
  (module_statement
    (name) @class) @module_wrapper)
";
29
/// Tree-sitter query that records Fortran call sites under the `@call` capture.
///
/// Three shapes are matched:
/// - an `identifier` that is a direct child of `subroutine_call` (`CALL name(...)`);
/// - an `identifier` that is a direct child of `call_expression` (function-style calls);
/// - every `type_member` inside a `derived_type_member_expression` (`obj%member`).
///
/// The last pattern places no constraint on the surrounding node, so a `type_member` is
/// captured whether or not it is invoked. NOTE(review): plain component reads such as
/// `obj%field` are presumably reported as calls too. Likewise, Fortran's `a(i)` syntax is
/// shared by calls and array element references. Confirm how tree-sitter-fortran parses
/// both cases before relying on the distinction.
pub const CALL_QUERY: &str = r"
(subroutine_call
  (identifier) @call)

(call_expression
  (identifier) @call)

(derived_type_member_expression
  (type_member) @call)
";
42
/// Tree-sitter query for Fortran type references, captured as `@type_ref`.
///
/// The single pattern matches every `name` node in the tree, with no parent constraint.
/// It therefore captures whatever the grammar represents as `name` (for instance, the
/// module identifier inside `module_statement`), not only type names.
pub const REFERENCE_QUERY: &str = r"
(name) @type_ref
";
47
/// Tree-sitter query for Fortran imports, i.e. `USE` statements.
///
/// Captures the `module_name` child of each `use_statement` as `@import_path`.
/// For example, `USE iso_fortran_env` yields `iso_fortran_env`.
pub const IMPORT_QUERY: &str = r"
(use_statement
  (module_name) @import_path)
";
53
54use tree_sitter::Node;
55
56use crate::languages::get_node_text;
57
58/// Extract inheritance information from a Fortran node.
59/// Fortran does not have classical inheritance; return empty.
60#[must_use]
61pub fn extract_inheritance(_node: &Node, _source: &str) -> Vec<String> {
62    Vec::new()
63}
64
65/// Extract the name of a function or subroutine node.
66/// Both `subroutine_statement` and `function_statement` expose a named field
67/// called `name`. Return the identifier text if present.
68#[must_use]
69pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
70    match node.kind() {
71        "subroutine_statement" | "function_statement" => node
72            .child_by_field_name("name")
73            .and_then(|n| get_node_text(&n, source)),
74        _ => None,
75    }
76}
77
78/// Extract the name identifier from a `module` node.
79///
80/// `module_statement` does not expose a named field for the identifier; the
81/// `name` node is an unnamed child of `module_statement`. This helper
82/// centralises the two-level child walk so callers stay readable.
83fn extract_module_name<'a>(module_node: &tree_sitter::Node<'a>, source: &str) -> Option<String> {
84    let mut cursor = module_node.walk();
85    for child in module_node.children(&mut cursor) {
86        if child.kind() == "module_statement" {
87            let mut stmt_cursor = child.walk();
88            for name_child in child.children(&mut stmt_cursor) {
89                if name_child.kind() == "name" {
90                    return get_node_text(&name_child, source);
91                }
92            }
93        }
94    }
95    None
96}
97
98/// Find the enclosing module name for a given subroutine/function node.
99/// Walk up the parent chain until a `module` node is found and return its name.
100#[must_use]
101pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
102    if !matches!(node.kind(), "subroutine_statement" | "function_statement") {
103        return None;
104    }
105    let mut current = *node;
106    while let Some(parent) = current.parent() {
107        if parent.kind() == "module" {
108            return extract_module_name(&parent, source);
109        }
110        current = parent;
111    }
112    None
113}
114
115/// Find the method name for a subroutine/function defined inside a module.
116/// Returns the function/subroutine identifier if the node is enclosed by a module.
117#[must_use]
118pub fn find_method_for_receiver(
119    node: &Node,
120    source: &str,
121    _depth: Option<usize>,
122) -> Option<String> {
123    if !matches!(node.kind(), "subroutine_statement" | "function_statement") {
124        return None;
125    }
126    // Walk up to see if we are inside a module.
127    let mut current = *node;
128    let mut in_module = false;
129    while let Some(parent) = current.parent() {
130        if parent.kind() == "module" {
131            in_module = true;
132            break;
133        }
134        current = parent;
135    }
136    if !in_module {
137        return None;
138    }
139    node.child_by_field_name("name")
140        .and_then(|n| get_node_text(&n, source))
141}
142
#[cfg(all(test, feature = "lang-fortran"))]
mod tests {
    use super::*;
    use tree_sitter::Parser;
    use tree_sitter::StreamingIterator;

    /// A module whose `CONTAINS` section holds a single subroutine `sub1`.
    const MODULE_WITH_SUBROUTINE: &str =
        "MODULE mod1\nCONTAINS\nSUBROUTINE sub1()\nEND SUBROUTINE sub1\nEND MODULE mod1\n";

    /// Pre-order depth-first search for the first node of `kind` at or below `root`.
    fn find_node<'a>(root: tree_sitter::Node<'a>, kind: &str) -> Option<tree_sitter::Node<'a>> {
        if root.kind() == kind {
            return Some(root);
        }
        let mut cursor = root.walk();
        for child in root.children(&mut cursor) {
            if let Some(n) = find_node(child, kind) {
                return Some(n);
            }
        }
        None
    }

    /// Parse `source` with the Fortran grammar.
    ///
    /// Panics if the grammar cannot be loaded or the parser returns no tree; either is a
    /// test-setup failure.
    fn parse_fortran(source: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_fortran::LANGUAGE.into())
            .expect("failed to set Fortran language");
        parser.parse(source, None).expect("failed to parse")
    }

    /// Run `query_str` over `tree` and return `(capture name, captured text)` pairs in
    /// match order.
    fn run_query(tree: &tree_sitter::Tree, source: &str, query_str: &str) -> Vec<(String, String)> {
        let query = tree_sitter::Query::new(&tree_sitter_fortran::LANGUAGE.into(), query_str)
            .expect("invalid query");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut captures = Vec::new();
        let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
        while let Some(m) = matches.next() {
            for c in m.captures {
                let node = c.node;
                let name = query.capture_names()[c.index as usize].to_string();
                let text = node
                    .utf8_text(source.as_bytes())
                    .unwrap_or_default()
                    .to_string();
                captures.push((name, text));
            }
        }
        captures
    }

    #[test]
    fn test_element_query_captures_module() {
        let source = "MODULE foo\nEND MODULE foo\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, ELEMENT_QUERY);
        assert!(caps.iter().any(|(c, t)| c == "class" && t == "foo"));
    }

    #[test]
    fn test_element_query_empty_module() {
        let source = "MODULE foo\nEND MODULE foo\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, ELEMENT_QUERY);
        // An empty module must not produce any @function captures.
        assert!(!caps.iter().any(|(c, _)| c == "function"));
    }

    #[test]
    fn test_element_query_captures_subroutine() {
        let source = "SUBROUTINE bar(x)\n  x = x + 1\nEND SUBROUTINE bar\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, ELEMENT_QUERY);
        assert!(
            caps.iter()
                .any(|(c, t)| c == "function" && t.contains("bar"))
        );
    }

    #[test]
    fn test_element_query_captures_function() {
        let source = "FUNCTION baz(x) RESULT(r)\n  r = x * 2\nEND FUNCTION baz\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, ELEMENT_QUERY);
        assert!(
            caps.iter()
                .any(|(c, t)| c == "function" && t.contains("baz"))
        );
    }

    #[test]
    fn test_element_query_module_contains_subroutine() {
        let source = MODULE_WITH_SUBROUTINE;
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, ELEMENT_QUERY);
        assert!(caps.iter().any(|(c, t)| c == "class" && t == "mod1"));
        assert!(
            caps.iter()
                .any(|(c, t)| c == "function" && t.contains("sub1"))
        );
    }

    #[test]
    fn test_import_query_captures_use_statement() {
        let source = "PROGRAM prog\nUSE iso_fortran_env\nEND PROGRAM prog\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, IMPORT_QUERY);
        assert!(
            caps.iter()
                .any(|(c, t)| c == "import_path" && t == "iso_fortran_env")
        );
    }

    #[test]
    fn test_call_query_direct_call() {
        let source = "CALL compute(x, y)\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, CALL_QUERY);
        assert!(caps.iter().any(|(c, t)| c == "call" && t == "compute"));
    }

    #[test]
    fn test_call_query_derived_type_member() {
        let source = "CALL obj%solve(rhs)\n";
        let tree = parse_fortran(source);
        let caps = run_query(&tree, source, CALL_QUERY);
        assert!(caps.iter().any(|(c, t)| c == "call" && t == "solve"));
    }

    #[test]
    fn test_extract_function_name_subroutine() {
        let source = "SUBROUTINE foo(a)\nEND SUBROUTINE foo\n";
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "subroutine_statement")
            .expect("subroutine_statement not found");
        let name = extract_function_name(&node, source, "fortran").expect("name");
        assert_eq!(name, "foo");
    }

    #[test]
    fn test_extract_function_name_function() {
        let source = "FUNCTION bar(x) RESULT(r)\nEND FUNCTION bar\n";
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "function_statement")
            .expect("function_statement not found");
        let name = extract_function_name(&node, source, "fortran").expect("name");
        assert_eq!(name, "bar");
    }

    #[test]
    fn test_extract_function_name_wrong_node() {
        let source = "MODULE foo\nEND MODULE foo\n";
        let tree = parse_fortran(source);
        let node = tree.root_node();
        let name = extract_function_name(&node, source, "fortran");
        assert!(name.is_none());
    }

    #[test]
    fn test_find_receiver_type_module_scoped() {
        let source = MODULE_WITH_SUBROUTINE;
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "subroutine_statement")
            .expect("subroutine_statement not found");
        let mod_name = find_receiver_type(&node, source).expect("module name");
        assert_eq!(mod_name, "mod1");
    }

    #[test]
    fn test_find_receiver_type_top_level() {
        let source = "SUBROUTINE top()\nEND SUBROUTINE top\n";
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "subroutine_statement")
            .expect("subroutine_statement not found");
        let mod_name = find_receiver_type(&node, source);
        assert!(mod_name.is_none());
    }

    #[test]
    fn test_find_method_for_receiver_in_module() {
        let source = MODULE_WITH_SUBROUTINE;
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "subroutine_statement")
            .expect("subroutine_statement not found");
        let method_name = find_method_for_receiver(&node, source, None).expect("method name");
        assert_eq!(method_name, "sub1");
    }

    #[test]
    fn test_find_method_for_receiver_top_level() {
        let source = "SUBROUTINE top()\nEND SUBROUTINE top\n";
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "subroutine_statement")
            .expect("subroutine_statement not found");
        let method_name = find_method_for_receiver(&node, source, None);
        assert!(method_name.is_none());
    }

    #[test]
    fn test_receiver_helpers_function_in_module() {
        // Functions (not just subroutines) in a CONTAINS section resolve to their module.
        let source =
            "MODULE mod2\nCONTAINS\nFUNCTION f(x) RESULT(r)\n  r = x\nEND FUNCTION f\nEND MODULE mod2\n";
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "function_statement")
            .expect("function_statement not found");
        assert_eq!(find_receiver_type(&node, source).as_deref(), Some("mod2"));
        assert_eq!(
            find_method_for_receiver(&node, source, None).as_deref(),
            Some("f")
        );
    }

    #[test]
    fn test_receiver_helpers_reject_non_procedure_node() {
        // A node inside a module that is not a procedure statement is rejected up front.
        let source = MODULE_WITH_SUBROUTINE;
        let tree = parse_fortran(source);
        let node = find_node(tree.root_node(), "module_statement")
            .expect("module_statement not found");
        assert!(find_receiver_type(&node, source).is_none());
        assert!(find_method_for_receiver(&node, source, None).is_none());
    }

    #[test]
    fn test_extract_inheritance_returns_empty() {
        // Arrange
        let source = "PROGRAM test\nEND PROGRAM test\n";
        let tree = parse_fortran(source);
        let root = tree.root_node();

        // Act
        let result = extract_inheritance(&root, source);

        // Assert
        assert!(result.is_empty());
    }
}