Skip to main content

sqry_lang_elixir/
lib.rs

1//! Elixir language plugin
2//!
3//! Extracts scopes and graph relations using tree-sitter.
4
5pub mod relations;
6
7pub use relations::ElixirGraphBuilder;
8
9use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
10use sqry_core::plugin::{
11    LanguageMetadata, LanguagePlugin,
12    error::{ParseError, ScopeError},
13};
14use std::path::Path;
15use tree_sitter::{Language, Node, Parser, Tree};
16
/// Machine-readable language identifier reported as `id` in the plugin metadata.
const LANGUAGE_ID: &str = "elixir";
/// Human-readable language name reported as `name` in the plugin metadata.
const LANGUAGE_NAME: &str = "Elixir";
/// Value reported as `tree_sitter_version` in the plugin metadata.
const TREE_SITTER_VERSION: &str = "0.23";
20
/// Elixir language plugin implementation.
///
/// Implements [`LanguagePlugin`] for Elixir sources (`ex` / `exs` extensions):
/// parsing goes through tree-sitter, scope extraction walks the resulting
/// tree, and graph construction is exposed through the owned
/// [`ElixirGraphBuilder`].
pub struct ElixirPlugin {
    /// Graph builder returned by `graph_builder()`; default-constructed in `new()`.
    graph_builder: ElixirGraphBuilder,
}
25
26impl ElixirPlugin {
27    /// Creates a new Elixir plugin instance.
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            graph_builder: ElixirGraphBuilder::default(),
32        }
33    }
34}
35
36impl Default for ElixirPlugin {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl LanguagePlugin for ElixirPlugin {
43    fn metadata(&self) -> LanguageMetadata {
44        LanguageMetadata {
45            id: LANGUAGE_ID,
46            name: LANGUAGE_NAME,
47            version: env!("CARGO_PKG_VERSION"),
48            author: "Verivus Pty Ltd",
49            description: "Elixir language support for sqry",
50            tree_sitter_version: TREE_SITTER_VERSION,
51        }
52    }
53
54    fn extensions(&self) -> &'static [&'static str] {
55        &["ex", "exs"]
56    }
57
58    fn language(&self) -> Language {
59        tree_sitter_elixir_sqry::language()
60    }
61
62    fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
63        let mut parser = Parser::new();
64        parser
65            .set_language(&self.language())
66            .map_err(|e| ParseError::LanguageSetFailed(e.to_string()))?;
67
68        parser
69            .parse(content, None)
70            .ok_or(ParseError::TreeSitterFailed)
71    }
72
73    fn extract_scopes(
74        &self,
75        tree: &Tree,
76        content: &[u8],
77        file_path: &Path,
78    ) -> Result<Vec<Scope>, ScopeError> {
79        Ok(Self::extract_elixir_scopes(tree, content, file_path))
80    }
81
82    fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
83        Some(&self.graph_builder)
84    }
85}
86
87impl ElixirPlugin {
88    /// Extract scopes from Elixir source using AST traversal
89    ///
90    /// Elixir represents all macros (defmodule, def, etc.) as `call` nodes,
91    /// so we traverse the AST to find scope-creating constructs.
92    fn extract_elixir_scopes(tree: &Tree, content: &[u8], file_path: &Path) -> Vec<Scope> {
93        let mut scopes = Vec::new();
94        Self::collect_scopes_from_node(tree.root_node(), content, file_path, &mut scopes);
95
96        // Sort by (start_line, start_column) for link_nested_scopes
97        scopes.sort_by_key(|s| (s.start_line, s.start_column));
98
99        link_nested_scopes(&mut scopes);
100        scopes
101    }
102
103    fn collect_scopes_from_node(
104        node: Node<'_>,
105        content: &[u8],
106        file_path: &Path,
107        scopes: &mut Vec<Scope>,
108    ) {
109        if node.kind() == "call" {
110            // Check identifier or target field for macro name
111            let macro_name = node
112                .child_by_field_name("identifier")
113                .or_else(|| node.child_by_field_name("target"))
114                .and_then(|n| n.utf8_text(content).ok());
115
116            if let Some(name) = macro_name {
117                let (scope_type, scope_name) = match name {
118                    "defmodule" | "defprotocol" | "defimpl" => {
119                        let module_name = Self::extract_module_name_for_scope(node, content);
120                        ("module", module_name)
121                    }
122                    "def" | "defp" | "defmacro" | "defmacrop" => {
123                        let func_name = Self::extract_function_name_for_scope(node, content);
124                        ("function", func_name)
125                    }
126                    _ => (name, None),
127                };
128
129                // Only create scope if we have a valid scope type
130                if matches!(scope_type, "module" | "function") {
131                    let scope_name = scope_name.unwrap_or_else(|| "<anonymous>".to_string());
132                    let start = node.start_position();
133                    let end = node.end_position();
134
135                    scopes.push(Scope {
136                        id: ScopeId::new(0), // Will be reassigned by link_nested_scopes
137                        scope_type: scope_type.to_string(),
138                        name: scope_name,
139                        file_path: file_path.to_path_buf(),
140                        start_line: start.row + 1,
141                        start_column: start.column,
142                        end_line: end.row + 1,
143                        end_column: end.column,
144                        parent_id: None,
145                    });
146                }
147            }
148        }
149
150        // Recurse into children
151        let mut cursor = node.walk();
152        for child in node.children(&mut cursor) {
153            if child.is_named() {
154                Self::collect_scopes_from_node(child, content, file_path, scopes);
155            }
156        }
157    }
158
159    fn extract_module_name_for_scope(node: Node<'_>, content: &[u8]) -> Option<String> {
160        // Look in arguments for the module alias
161        let arguments = node.child_by_field_name("arguments").or_else(|| {
162            let mut cursor = node.walk();
163            node.children(&mut cursor).find(|c| c.kind() == "arguments")
164        })?;
165
166        let mut cursor = arguments.walk();
167        arguments
168            .children(&mut cursor)
169            .find(|child| {
170                child.is_named() && matches!(child.kind(), "alias" | "identifier" | "atom")
171            })
172            .and_then(|child| child.utf8_text(content).ok())
173            .map(String::from)
174    }
175
176    fn extract_function_name_for_scope(node: Node<'_>, content: &[u8]) -> Option<String> {
177        // Look in arguments for the function head
178        let arguments = node.child_by_field_name("arguments").or_else(|| {
179            let mut cursor = node.walk();
180            node.children(&mut cursor).find(|c| c.kind() == "arguments")
181        })?;
182
183        let mut cursor = arguments.walk();
184        for child in arguments.children(&mut cursor) {
185            if !child.is_named() {
186                continue;
187            }
188            match child.kind() {
189                "call" => {
190                    // def foo(args) - get function name from call target
191                    if let Some(target) = child.child_by_field_name("target") {
192                        return target.utf8_text(content).ok().map(String::from);
193                    }
194                    // Fallback: try to find identifier in call
195                    let mut inner_cursor = child.walk();
196                    for inner in child.children(&mut inner_cursor) {
197                        if inner.is_named() && inner.kind() == "identifier" {
198                            return inner.utf8_text(content).ok().map(String::from);
199                        }
200                    }
201                }
202                "identifier" => {
203                    // def foo, do: ... - simple identifier
204                    return child.utf8_text(content).ok().map(String::from);
205                }
206                "binary_operator" => {
207                    // def foo(a, b) when ... - guard clause
208                    if let Some(left) = child.child_by_field_name("left") {
209                        if left.kind() == "call" {
210                            if let Some(target) = left.child_by_field_name("target") {
211                                return target.utf8_text(content).ok().map(String::from);
212                            }
213                        } else if left.kind() == "identifier" {
214                            return left.utf8_text(content).ok().map(String::from);
215                        }
216                    }
217                }
218                _ => {}
219            }
220        }
221        None
222    }
223}