sqruff_lsp/
lib.rs

1use ahash::AHashMap;
2use lsp_server::{Connection, Message, Request, RequestId, Response};
3use lsp_types::notification::{
4    DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument,
5    Notification, PublishDiagnostics,
6};
7use lsp_types::request::{Formatting, Request as _};
8use lsp_types::{
9    Diagnostic, DiagnosticSeverity, DidChangeTextDocumentParams, DidCloseTextDocumentParams,
10    DidOpenTextDocumentParams, DidSaveTextDocumentParams, DocumentFormattingParams,
11    InitializeParams, InitializeResult, NumberOrString, OneOf, Position, PublishDiagnosticsParams,
12    Registration, ServerCapabilities, TextDocumentIdentifier, TextDocumentItem,
13    TextDocumentSyncCapability, TextDocumentSyncKind, Uri, VersionedTextDocumentIdentifier,
14};
15use serde_json::Value;
16use sqruff_lib::core::config::FluffConfig;
17use sqruff_lib::core::linter::core::Linter;
18use wasm_bindgen::prelude::*;
19
20#[cfg(not(target_arch = "wasm32"))]
21fn load_config() -> FluffConfig {
22    FluffConfig::from_root(None, false, None).unwrap_or_default()
23}
24
25#[cfg(target_arch = "wasm32")]
26fn load_config() -> FluffConfig {
27    FluffConfig::default()
28}
29
30fn server_initialize_result() -> InitializeResult {
31    InitializeResult {
32        capabilities: ServerCapabilities {
33            text_document_sync: TextDocumentSyncCapability::Kind(TextDocumentSyncKind::FULL).into(),
34            document_formatting_provider: OneOf::Left(true).into(),
35            ..Default::default()
36        },
37        server_info: None,
38    }
39}
40
41pub struct LanguageServer {
42    linter: Linter,
43    send_diagnostics_callback: Box<dyn Fn(PublishDiagnosticsParams)>,
44    documents: AHashMap<Uri, String>,
45}
46
47#[wasm_bindgen]
48pub struct Wasm(LanguageServer);
49
50#[wasm_bindgen]
51impl Wasm {
52    #[wasm_bindgen(constructor)]
53    pub fn new(send_diagnostics_callback: js_sys::Function) -> Self {
54        console_error_panic_hook::set_once();
55
56        let send_diagnostics_callback = Box::leak(Box::new(send_diagnostics_callback));
57
58        Self(LanguageServer::new(|diagnostics| {
59            let diagnostics = serde_wasm_bindgen::to_value(&diagnostics).unwrap();
60            send_diagnostics_callback
61                .call1(&JsValue::null(), &diagnostics)
62                .unwrap();
63        }))
64    }
65
66    #[wasm_bindgen(js_name = saveRegistrationOptions)]
67    pub fn save_registration_options() -> JsValue {
68        serde_wasm_bindgen::to_value(&save_registration_options()).unwrap()
69    }
70
71    #[wasm_bindgen(js_name = updateConfig)]
72    pub fn update_config(&mut self, source: &str) {
73        *self.0.linter.config_mut() = FluffConfig::from_source(source, None);
74        self.0.recheck_files();
75    }
76
77    #[wasm_bindgen(js_name = onInitialize)]
78    pub fn on_initialize(&self) -> JsValue {
79        serde_wasm_bindgen::to_value(&server_initialize_result()).unwrap()
80    }
81
82    #[wasm_bindgen(js_name = onNotification)]
83    pub fn on_notification(&mut self, method: &str, params: JsValue) {
84        self.0
85            .on_notification(method, serde_wasm_bindgen::from_value(params).unwrap())
86    }
87
88    #[wasm_bindgen]
89    pub fn format(&mut self, uri: JsValue) -> JsValue {
90        let uri = serde_wasm_bindgen::from_value(uri).unwrap();
91        let edits = self.0.format(uri);
92        serde_wasm_bindgen::to_value(&edits).unwrap()
93    }
94
95    #[wasm_bindgen(js_name = formatSource)]
96    pub fn format_source(&mut self, source: &str) -> String {
97        self.0.format_source(source)
98    }
99}
100
101impl LanguageServer {
102    pub fn new(send_diagnostics_callback: impl Fn(PublishDiagnosticsParams) + 'static) -> Self {
103        Self {
104            linter: Linter::new(load_config(), None, None, false),
105            send_diagnostics_callback: Box::new(send_diagnostics_callback),
106            documents: AHashMap::new(),
107        }
108    }
109
110    fn on_request(&mut self, id: RequestId, method: &str, params: Value) -> Option<Response> {
111        match method {
112            Formatting::METHOD => {
113                let DocumentFormattingParams {
114                    text_document: TextDocumentIdentifier { uri },
115                    ..
116                } = serde_json::from_value(params).unwrap();
117
118                let edits = self.format(uri);
119                Some(Response::new_ok(id, edits))
120            }
121            _ => None,
122        }
123    }
124
125    fn format(&mut self, uri: Uri) -> Vec<lsp_types::TextEdit> {
126        let text = self.documents.get(&uri).cloned().unwrap();
127        let new_text = self.format_source(&text);
128        self.documents.insert(uri.clone(), new_text.clone());
129        Self::build_edits(new_text)
130    }
131
132    fn format_source(&mut self, source: &str) -> String {
133        let tree = self.linter.lint_string(source, None, true);
134        tree.fix_string()
135    }
136
137    fn build_edits(new_text: String) -> Vec<lsp_types::TextEdit> {
138        let start_position = Position {
139            line: 0,
140            character: 0,
141        };
142        let end_position = Position {
143            line: new_text.lines().count() as u32,
144            character: new_text.chars().count() as u32,
145        };
146
147        vec![lsp_types::TextEdit {
148            range: lsp_types::Range::new(start_position, end_position),
149            new_text,
150        }]
151    }
152
153    pub fn on_notification(&mut self, method: &str, params: Value) {
154        match method {
155            DidOpenTextDocument::METHOD => {
156                let params: DidOpenTextDocumentParams = serde_json::from_value(params).unwrap();
157                let TextDocumentItem {
158                    uri,
159                    language_id: _,
160                    version: _,
161                    text,
162                } = params.text_document;
163
164                self.check_file(uri.clone(), &text);
165                self.documents.insert(uri, text);
166            }
167            DidChangeTextDocument::METHOD => {
168                let params: DidChangeTextDocumentParams = serde_json::from_value(params).unwrap();
169
170                let content = params.content_changes[0].text.clone();
171                let VersionedTextDocumentIdentifier { uri, version: _ } = params.text_document;
172
173                self.check_file(uri.clone(), &content);
174                self.documents.insert(uri, content);
175            }
176            DidCloseTextDocument::METHOD => {
177                let params: DidCloseTextDocumentParams = serde_json::from_value(params).unwrap();
178                self.documents.remove(&params.text_document.uri);
179            }
180            DidSaveTextDocument::METHOD => {
181                let params: DidSaveTextDocumentParams = serde_json::from_value(params).unwrap();
182                let uri = params.text_document.uri.as_str();
183
184                if uri.ends_with(".sqlfluff") || uri.ends_with(".sqruff") {
185                    *self.linter.config_mut() = load_config();
186
187                    self.recheck_files();
188                }
189            }
190            _ => {}
191        }
192    }
193
194    fn recheck_files(&mut self) {
195        for (uri, text) in self.documents.iter() {
196            self.check_file(uri.clone(), text);
197        }
198    }
199
200    fn check_file(&self, uri: Uri, text: &str) {
201        let result = self.linter.lint_string(text, None, false);
202
203        let diagnostics = result
204            .into_violations()
205            .into_iter()
206            .map(|violation| {
207                let range = {
208                    let pos = Position::new(
209                        (violation.line_no as u32).saturating_sub(1),
210                        (violation.line_pos as u32).saturating_sub(1),
211                    );
212                    lsp_types::Range::new(pos, pos)
213                };
214
215                let code = violation
216                    .rule
217                    .map(|rule| NumberOrString::String(rule.code.to_string()));
218
219                Diagnostic::new(
220                    range,
221                    DiagnosticSeverity::WARNING.into(),
222                    code,
223                    Some("sqruff".to_string()),
224                    violation.description,
225                    None,
226                    None,
227                )
228            })
229            .collect();
230
231        let diagnostics = PublishDiagnosticsParams::new(uri.clone(), diagnostics, None);
232        (self.send_diagnostics_callback)(diagnostics);
233    }
234}
235
236pub fn run() {
237    let (connection, io_threads) = Connection::stdio();
238    let (id, params) = connection.initialize_start().unwrap();
239
240    let init_param: InitializeParams = serde_json::from_value(params).unwrap();
241    let initialize_result = serde_json::to_value(server_initialize_result()).unwrap();
242    connection.initialize_finish(id, initialize_result).unwrap();
243
244    main_loop(connection, init_param);
245
246    io_threads.join().unwrap();
247}
248
249fn main_loop(connection: Connection, _init_param: InitializeParams) {
250    let sender = connection.sender.clone();
251    let mut lsp = LanguageServer::new(move |diagnostics| {
252        let notification = new_notification::<PublishDiagnostics>(diagnostics);
253        sender.send(Message::Notification(notification)).unwrap();
254    });
255
256    let params = save_registration_options();
257    connection
258        .sender
259        .send(Message::Request(Request::new(
260            "textDocument-didSave".to_owned().into(),
261            "client/registerCapability".to_owned(),
262            params,
263        )))
264        .unwrap();
265
266    for message in &connection.receiver {
267        match message {
268            Message::Request(request) => {
269                if connection.handle_shutdown(&request).unwrap() {
270                    return;
271                }
272
273                if let Some(response) = lsp.on_request(request.id, &request.method, request.params)
274                {
275                    connection.sender.send(Message::Response(response)).unwrap();
276                }
277            }
278            Message::Response(_) => {}
279            Message::Notification(notification) => {
280                lsp.on_notification(&notification.method, notification.params);
281            }
282        }
283    }
284}
285
286pub fn save_registration_options() -> lsp_types::RegistrationParams {
287    let save_registration_options = lsp_types::TextDocumentSaveRegistrationOptions {
288        include_text: false.into(),
289        text_document_registration_options: lsp_types::TextDocumentRegistrationOptions {
290            document_selector: Some(vec![
291                lsp_types::DocumentFilter {
292                    language: None,
293                    scheme: None,
294                    pattern: Some("**/.sqlfluff".into()),
295                },
296                lsp_types::DocumentFilter {
297                    language: None,
298                    scheme: None,
299                    pattern: Some("**/.sqruff".into()),
300                },
301            ]),
302        },
303    };
304
305    lsp_types::RegistrationParams {
306        registrations: vec![Registration {
307            id: "textDocument/didSave".into(),
308            method: "textDocument/didSave".into(),
309            register_options: serde_json::to_value(save_registration_options)
310                .unwrap()
311                .into(),
312        }],
313    }
314}
315
316fn new_notification<T>(params: T::Params) -> lsp_server::Notification
317where
318    T: Notification,
319{
320    lsp_server::Notification {
321        method: T::METHOD.to_owned(),
322        params: serde_json::to_value(&params).unwrap(),
323    }
324}