Skip to main content

sqry_lang_lua/
lib.rs

1//! Lua language plugin.
2//!
3//! Provides graph-native extraction via `LuaGraphBuilder`, AST parsing,
4//! and scope extraction for Lua source files.
5
6pub mod relations;
7
8pub use relations::LuaGraphBuilder;
9
10use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
11use sqry_core::plugin::{
12    LanguageMetadata, LanguagePlugin,
13    error::{ParseError, ScopeError},
14};
15use std::path::Path;
16use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator, Tree};
17
/// Identifier reported as `LanguageMetadata::id` for this plugin.
const LANGUAGE_ID: &str = "lua";
/// Human-readable name reported as `LanguageMetadata::name`.
const LANGUAGE_NAME: &str = "Lua";
/// Value reported as `LanguageMetadata::tree_sitter_version`.
// NOTE(review): presumably mirrors the tree-sitter Lua grammar crate version — confirm
// against Cargo.toml when the dependency is bumped.
const TREE_SITTER_VERSION: &str = "0.2.0";
21
22/// Lua plugin implementation.
23pub struct LuaPlugin {
24    graph_builder: LuaGraphBuilder,
25}
26
27impl LuaPlugin {
28    /// Creates a new Lua plugin instance.
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            graph_builder: LuaGraphBuilder::default(),
33        }
34    }
35}
36
37impl Default for LuaPlugin {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl LanguagePlugin for LuaPlugin {
44    fn metadata(&self) -> LanguageMetadata {
45        LanguageMetadata {
46            id: LANGUAGE_ID,
47            name: LANGUAGE_NAME,
48            version: env!("CARGO_PKG_VERSION"),
49            author: "Verivus Pty Ltd",
50            description: "Lua language support for sqry",
51            tree_sitter_version: TREE_SITTER_VERSION,
52        }
53    }
54
55    fn extensions(&self) -> &'static [&'static str] {
56        &["lua", "rockspec"]
57    }
58
59    fn language(&self) -> Language {
60        tree_sitter_lua::LANGUAGE.into()
61    }
62
63    fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
64        let mut parser = Parser::new();
65        parser
66            .set_language(&self.language())
67            .map_err(|e| ParseError::LanguageSetFailed(e.to_string()))?;
68
69        parser
70            .parse(content, None)
71            .ok_or(ParseError::TreeSitterFailed)
72    }
73
74    fn extract_scopes(
75        &self,
76        tree: &Tree,
77        content: &[u8],
78        file_path: &Path,
79    ) -> Result<Vec<Scope>, ScopeError> {
80        Self::extract_lua_scopes(tree, content, file_path)
81    }
82
83    fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
84        Some(&self.graph_builder)
85    }
86}
87
88impl LuaPlugin {
89    /// Extract scopes from Lua source using tree-sitter queries.
90    fn extract_lua_scopes(
91        tree: &Tree,
92        content: &[u8],
93        file_path: &Path,
94    ) -> Result<Vec<Scope>, ScopeError> {
95        let root_node = tree.root_node();
96        let language = tree_sitter_lua::LANGUAGE.into();
97
98        // Lua scope query: function definitions (both styles).
99        let scope_query = r"
100; Function declarations (function name() ... end)
101(function_declaration
102  name: [
103    (identifier) @function.name
104    (dot_index_expression) @function.name
105    (method_index_expression) @function.name
106  ]
107) @function.type
108
109; Function definitions in assignments (local f = function() ... end)
110(function_definition) @anonymous_function.type
111";
112
113        let query = Query::new(&language, scope_query)
114            .map_err(|e| ScopeError::QueryCompilationFailed(e.to_string()))?;
115
116        let mut scopes = Vec::new();
117        let mut cursor = QueryCursor::new();
118        let mut query_matches = cursor.matches(&query, root_node, content);
119
120        while let Some(m) = query_matches.next() {
121            let mut scope_type = None;
122            let mut scope_name = None;
123            let mut scope_start = None;
124            let mut scope_end = None;
125
126            for capture in m.captures {
127                let capture_name = query.capture_names()[capture.index as usize];
128                let node = capture.node;
129
130                let capture_ext = std::path::Path::new(capture_name)
131                    .extension()
132                    .and_then(|ext| ext.to_str());
133
134                if capture_ext.is_some_and(|ext| ext.eq_ignore_ascii_case("type")) {
135                    scope_type = Some(capture_name.trim_end_matches(".type").to_string());
136                    scope_start = Some(node.start_position());
137                    scope_end = Some(node.end_position());
138                } else if capture_ext.is_some_and(|ext| ext.eq_ignore_ascii_case("name")) {
139                    scope_name = node
140                        .utf8_text(content)
141                        .ok()
142                        .map(std::string::ToString::to_string);
143                }
144            }
145
146            if scope_type.as_deref() == Some("anonymous_function")
147                && scope_name.is_none()
148                && let Some(start) = scope_start
149            {
150                scope_name = Some(format!("<anonymous:{}:{}>", start.row + 1, start.column));
151            }
152
153            if let (Some(stype), Some(sname), Some(start), Some(end)) =
154                (scope_type, scope_name, scope_start, scope_end)
155            {
156                let normalized_type = match stype.as_str() {
157                    "function" | "anonymous_function" => "function",
158                    other => other,
159                };
160
161                let scope = Scope {
162                    id: ScopeId::new(0),
163                    scope_type: normalized_type.to_string(),
164                    name: sname,
165                    file_path: file_path.to_path_buf(),
166                    start_line: start.row + 1,
167                    start_column: start.column,
168                    end_line: end.row + 1,
169                    end_column: end.column,
170                    parent_id: None,
171                };
172                scopes.push(scope);
173            }
174        }
175
176        scopes.sort_by_key(|s| (s.start_line, s.start_column));
177        link_nested_scopes(&mut scopes);
178        Ok(scopes)
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns `true` when some scope is typed `"function"` and its name
    /// satisfies `name_matches`.
    fn has_function_scope(scopes: &[Scope], name_matches: impl Fn(&str) -> bool) -> bool {
        scopes
            .iter()
            .any(|scope| name_matches(scope.name.as_str()) && scope.scope_type == "function")
    }

    /// The plugin reports the expected id and display name.
    #[test]
    fn test_plugin_metadata() {
        let info = LuaPlugin::default().metadata();
        assert_eq!(info.id, "lua");
        assert_eq!(info.name, "Lua");
    }

    /// The plugin claims both `.lua` and `.rockspec` files.
    #[test]
    fn test_extensions() {
        let lua = LuaPlugin::default();
        let expected: &[&str] = &["lua", "rockspec"];
        assert_eq!(lua.extensions(), expected);
    }

    /// A trivial function declaration parses without error.
    #[test]
    fn test_can_parse() {
        let lua = LuaPlugin::default();
        let source = b"function foo() return 1 end";
        assert!(lua.parse_ast(source).is_ok());
    }

    /// Plain, dotted and anonymous functions each produce a `function` scope.
    #[test]
    fn test_extract_scopes() {
        let lua = LuaPlugin::default();
        let source = b"function foo() end\nfunction Module.bar() end\nlocal baz = function() end";
        let path = PathBuf::from("test.lua");

        let tree = lua.parse_ast(source).expect("parse Lua");
        let scopes = lua.extract_scopes(&tree, source, &path).unwrap();

        assert!(
            has_function_scope(&scopes, |name| name == "foo"),
            "foo function scope should be extracted"
        );

        assert!(
            has_function_scope(&scopes, |name| name.contains("Module")),
            "Module.bar scope should be extracted"
        );

        assert!(
            has_function_scope(&scopes, |name| name.starts_with("<anonymous:")),
            "anonymous function scope should be extracted"
        );
    }
}