1mod preprocess;
7pub mod relations;
8
9pub use relations::HaskellGraphBuilder;
10
11use preprocess::preprocess_content;
12use sqry_core::ast::{Scope, ScopeId, link_nested_scopes};
13use sqry_core::plugin::LanguageMetadata;
14use sqry_core::plugin::LanguagePlugin;
15use sqry_core::plugin::error::ScopeError;
16use std::borrow::Cow;
17use std::path::Path;
18use tree_sitter::{Language, Node, Tree};
19
20const LANGUAGE_ID: &str = "haskell";
21const LANGUAGE_NAME: &str = "Haskell";
22const TREE_SITTER_VERSION: &str = "0.23";
23
24pub struct HaskellPlugin {
26 graph_builder: HaskellGraphBuilder,
27}
28
29impl HaskellPlugin {
30 #[must_use]
32 pub fn new() -> Self {
33 Self {
34 graph_builder: HaskellGraphBuilder::default(),
35 }
36 }
37}
38
39impl Default for HaskellPlugin {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl LanguagePlugin for HaskellPlugin {
46 fn metadata(&self) -> LanguageMetadata {
47 LanguageMetadata {
48 id: LANGUAGE_ID,
49 name: LANGUAGE_NAME,
50 version: env!("CARGO_PKG_VERSION"),
51 author: "Verivus Pty Ltd",
52 description: "Haskell language support for sqry",
53 tree_sitter_version: TREE_SITTER_VERSION,
54 }
55 }
56
57 fn extensions(&self) -> &'static [&'static str] {
58 &["hs", "lhs", "hs-boot"]
59 }
60
61 fn language(&self) -> Language {
62 tree_sitter_haskell::LANGUAGE.into()
63 }
64
65 fn preprocess<'a>(&self, content: &'a [u8]) -> Cow<'a, [u8]> {
66 preprocess_content(content)
67 }
68
69 fn extract_scopes(
70 &self,
71 tree: &Tree,
72 content: &[u8],
73 file_path: &Path,
74 ) -> Result<Vec<Scope>, ScopeError> {
75 let processed = self.preprocess(content);
76 Ok(extract_haskell_scopes(tree, processed.as_ref(), file_path))
77 }
78
79 fn graph_builder(&self) -> Option<&dyn sqry_core::graph::GraphBuilder> {
80 Some(&self.graph_builder)
81 }
82}
83
84fn extract_haskell_scopes(tree: &Tree, content: &[u8], file_path: &Path) -> Vec<Scope> {
86 let mut scopes = Vec::new();
87 let root = tree.root_node();
88
89 let mut root_cursor = root.walk();
90 for child in root.children(&mut root_cursor) {
91 if child.kind() == "header" {
92 if let Some(module_name) = extract_module_name_from_header(child, content) {
93 let start = child.start_position();
94 let end = root.end_position();
95 scopes.push(Scope {
96 id: ScopeId::new(0),
97 scope_type: "module".to_string(),
98 name: module_name,
99 file_path: file_path.to_path_buf(),
100 start_line: start.row + 1,
101 start_column: start.column,
102 end_line: end.row + 1,
103 end_column: end.column,
104 parent_id: None,
105 });
106 }
107 break;
108 }
109 }
110
111 if let Some(decls) = root.child_by_field_name("declarations") {
112 collect_declaration_scopes(decls, content, file_path, &mut scopes);
113 }
114
115 scopes.sort_by_key(|s| (s.start_line, s.start_column));
116 link_nested_scopes(&mut scopes);
117 scopes
118}
119
120fn collect_declaration_scopes(
121 node: Node<'_>,
122 content: &[u8],
123 file_path: &Path,
124 scopes: &mut Vec<Scope>,
125) {
126 let mut cursor = node.walk();
127 for child in node.children(&mut cursor) {
128 let (scope_type, name_field) = match child.kind() {
129 "function" | "bind" => ("function", Some("name")),
130 "data_type" | "newtype" | "type_synomym" => ("type", Some("name")),
131 "class" => ("class", Some("name")),
132 "instance" => ("instance", Some("name")),
133 "pattern_synonym" => ("function", Some("synonym")),
134 _ => continue,
135 };
136
137 let name = name_field
138 .and_then(|field| child.child_by_field_name(field))
139 .and_then(|n| n.utf8_text(content).ok())
140 .map_or_else(|| format!("<{}>", child.kind()), |s| s.trim().to_string());
141
142 let start = child.start_position();
143 let end = child.end_position();
144
145 scopes.push(Scope {
146 id: ScopeId::new(0),
147 scope_type: scope_type.to_string(),
148 name,
149 file_path: file_path.to_path_buf(),
150 start_line: start.row + 1,
151 start_column: start.column,
152 end_line: end.row + 1,
153 end_column: end.column,
154 parent_id: None,
155 });
156 }
157}
158
159fn extract_module_name_from_header(header: Node<'_>, content: &[u8]) -> Option<String> {
160 let mut cursor = header.walk();
161 for child in header.children(&mut cursor) {
162 if matches!(child.kind(), "module" | "module_id")
163 && let Ok(text) = child.utf8_text(content)
164 && text != "module"
165 {
166 return Some(text.to_string());
167 }
168 }
169 header
170 .utf8_text(content)
171 .ok()
172 .and_then(parse_module_name_from_text)
173}
174
175fn parse_module_name_from_text(text: &str) -> Option<String> {
176 let mut tokens = text.split_whitespace();
177 while let Some(token) = tokens.next() {
178 if token == "module"
179 && let Some(name_token) = tokens.next()
180 {
181 let trimmed = name_token.trim_end_matches(['(', ';']);
182 if !trimmed.is_empty() {
183 return Some(trimmed.to_string());
184 }
185 }
186 }
187 None
188}
189
190#[cfg(test)]
191mod tests {
192 use super::*;
193 use sqry_core::plugin::LanguagePlugin;
194 use std::fs;
195 use std::path::PathBuf;
196
197 fn load_fixture(name: &str) -> (Vec<u8>, PathBuf) {
198 let path = PathBuf::from(format!("tests/fixtures/{name}"));
199 let content = fs::read(&path).expect("failed to read fixture");
200 (content, path)
201 }
202
203 fn extract_scopes_from_fixture(plugin: &HaskellPlugin, name: &str) -> Vec<Scope> {
204 let (content, path) = load_fixture(name);
205 let tree = plugin.parse_ast(&content).expect("parse fixture");
206 plugin
207 .extract_scopes(&tree, &content, &path)
208 .expect("extract scopes")
209 }
210
211 fn has_scope(scopes: &[Scope], scope_type: &str, name: &str) -> bool {
212 scopes
213 .iter()
214 .any(|scope| scope.scope_type == scope_type && scope.name == name)
215 }
216
217 #[test]
218 fn extracts_scopes_from_basic_fixture() {
219 let plugin = HaskellPlugin::default();
220 let scopes = extract_scopes_from_fixture(&plugin, "basic.hs");
221
222 assert!(has_scope(&scopes, "module", "Sample"));
223 assert!(has_scope(&scopes, "function", "foo"));
224 assert!(has_scope(&scopes, "function", "bar"));
225 assert!(has_scope(&scopes, "class", "Run"));
226 }
227
228 #[test]
229 fn parses_literate_haskell() {
230 let plugin = HaskellPlugin::default();
231 let scopes = extract_scopes_from_fixture(&plugin, "literate.lhs");
232
233 assert!(has_scope(&scopes, "module", "Literate"));
234 assert!(has_scope(&scopes, "function", "answer"));
235 }
236}