1pub mod relations;
7
8pub use relations::LuaGraphBuilder;
9
10use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
11use sqry_core::plugin::{
12 LanguageMetadata, LanguagePlugin,
13 error::{ParseError, ScopeError},
14};
15use std::path::Path;
16use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator, Tree};
17
/// Identifier this plugin reports in the `id` field of its `LanguageMetadata`.
const LANGUAGE_ID: &str = "lua";
/// Display name this plugin reports in the `name` field of its `LanguageMetadata`.
const LANGUAGE_NAME: &str = "Lua";
/// Version string this plugin reports in the `tree_sitter_version` metadata field.
// NOTE(review): presumably tracks the version of the Lua grammar crate — confirm
// against the crate manifest.
const TREE_SITTER_VERSION: &str = "0.2.0";
21
22pub struct LuaPlugin {
24 graph_builder: LuaGraphBuilder,
25}
26
27impl LuaPlugin {
28 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 graph_builder: LuaGraphBuilder::default(),
33 }
34 }
35}
36
37impl Default for LuaPlugin {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl LanguagePlugin for LuaPlugin {
44 fn metadata(&self) -> LanguageMetadata {
45 LanguageMetadata {
46 id: LANGUAGE_ID,
47 name: LANGUAGE_NAME,
48 version: env!("CARGO_PKG_VERSION"),
49 author: "Verivus Pty Ltd",
50 description: "Lua language support for sqry",
51 tree_sitter_version: TREE_SITTER_VERSION,
52 }
53 }
54
55 fn extensions(&self) -> &'static [&'static str] {
56 &["lua", "rockspec"]
57 }
58
59 fn language(&self) -> Language {
60 tree_sitter_lua::LANGUAGE.into()
61 }
62
63 fn parse_ast(&self, content: &[u8]) -> Result<Tree, ParseError> {
64 let mut parser = Parser::new();
65 parser
66 .set_language(&self.language())
67 .map_err(|e| ParseError::LanguageSetFailed(e.to_string()))?;
68
69 parser
70 .parse(content, None)
71 .ok_or(ParseError::TreeSitterFailed)
72 }
73
74 fn extract_scopes(
75 &self,
76 tree: &Tree,
77 content: &[u8],
78 file_path: &Path,
79 ) -> Result<Vec<Scope>, ScopeError> {
80 Self::extract_lua_scopes(tree, content, file_path)
81 }
82
83 fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
84 Some(&self.graph_builder)
85 }
86}
87
88impl LuaPlugin {
89 fn extract_lua_scopes(
91 tree: &Tree,
92 content: &[u8],
93 file_path: &Path,
94 ) -> Result<Vec<Scope>, ScopeError> {
95 let root_node = tree.root_node();
96 let language = tree_sitter_lua::LANGUAGE.into();
97
98 let scope_query = r"
100; Function declarations (function name() ... end)
101(function_declaration
102 name: [
103 (identifier) @function.name
104 (dot_index_expression) @function.name
105 (method_index_expression) @function.name
106 ]
107) @function.type
108
109; Function definitions in assignments (local f = function() ... end)
110(function_definition) @anonymous_function.type
111";
112
113 let query = Query::new(&language, scope_query)
114 .map_err(|e| ScopeError::QueryCompilationFailed(e.to_string()))?;
115
116 let mut scopes = Vec::new();
117 let mut cursor = QueryCursor::new();
118 let mut query_matches = cursor.matches(&query, root_node, content);
119
120 while let Some(m) = query_matches.next() {
121 let mut scope_type = None;
122 let mut scope_name = None;
123 let mut scope_start = None;
124 let mut scope_end = None;
125
126 for capture in m.captures {
127 let capture_name = query.capture_names()[capture.index as usize];
128 let node = capture.node;
129
130 let capture_ext = std::path::Path::new(capture_name)
131 .extension()
132 .and_then(|ext| ext.to_str());
133
134 if capture_ext.is_some_and(|ext| ext.eq_ignore_ascii_case("type")) {
135 scope_type = Some(capture_name.trim_end_matches(".type").to_string());
136 scope_start = Some(node.start_position());
137 scope_end = Some(node.end_position());
138 } else if capture_ext.is_some_and(|ext| ext.eq_ignore_ascii_case("name")) {
139 scope_name = node
140 .utf8_text(content)
141 .ok()
142 .map(std::string::ToString::to_string);
143 }
144 }
145
146 if scope_type.as_deref() == Some("anonymous_function")
147 && scope_name.is_none()
148 && let Some(start) = scope_start
149 {
150 scope_name = Some(format!("<anonymous:{}:{}>", start.row + 1, start.column));
151 }
152
153 if let (Some(stype), Some(sname), Some(start), Some(end)) =
154 (scope_type, scope_name, scope_start, scope_end)
155 {
156 let normalized_type = match stype.as_str() {
157 "function" | "anonymous_function" => "function",
158 other => other,
159 };
160
161 let scope = Scope {
162 id: ScopeId::new(0),
163 scope_type: normalized_type.to_string(),
164 name: sname,
165 file_path: file_path.to_path_buf(),
166 start_line: start.row + 1,
167 start_column: start.column,
168 end_line: end.row + 1,
169 end_column: end.column,
170 parent_id: None,
171 };
172 scopes.push(scope);
173 }
174 }
175
176 scopes.sort_by_key(|s| (s.start_line, s.start_column));
177 link_nested_scopes(&mut scopes);
178 Ok(scopes)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns `true` when `scopes` contains a scope whose name satisfies
    /// `name_matches` and whose type is `"function"`.
    fn has_function_scope(scopes: &[Scope], name_matches: impl Fn(&str) -> bool) -> bool {
        scopes
            .iter()
            .any(|scope| name_matches(&scope.name) && scope.scope_type == "function")
    }

    /// The plugin reports the expected language id and display name.
    #[test]
    fn test_plugin_metadata() {
        let plugin = LuaPlugin::default();
        let info = plugin.metadata();
        assert_eq!(info.id, "lua");
        assert_eq!(info.name, "Lua");
    }

    /// The plugin declares exactly the `lua` and `rockspec` extensions.
    #[test]
    fn test_extensions() {
        let plugin = LuaPlugin::default();
        let declared = plugin.extensions();
        assert_eq!(declared, &["lua", "rockspec"]);
    }

    /// A trivial function declaration parses successfully.
    #[test]
    fn test_can_parse() {
        let plugin = LuaPlugin::default();
        let source = b"function foo() return 1 end";
        assert!(plugin.parse_ast(source).is_ok());
    }

    /// Plain, dotted and anonymous functions are all extracted as `function` scopes.
    #[test]
    fn test_extract_scopes() {
        let plugin = LuaPlugin::default();
        let content = b"function foo() end\nfunction Module.bar() end\nlocal baz = function() end";
        let file = PathBuf::from("test.lua");

        let tree = plugin.parse_ast(content).expect("parse Lua");
        let scopes = plugin.extract_scopes(&tree, content, &file).unwrap();

        assert!(
            has_function_scope(&scopes, |name| name == "foo"),
            "foo function scope should be extracted"
        );

        assert!(
            has_function_scope(&scopes, |name| name.contains("Module")),
            "Module.bar scope should be extracted"
        );

        assert!(
            has_function_scope(&scopes, |name| name.starts_with("<anonymous:")),
            "anonymous function scope should be extracted"
        );
    }
}