1use crate::dialect::Dialect;
2use crate::dialects::DialectRegistry;
3use crate::schema::{Schema, SchemaId, SchemaManager};
4use dashmap::DashMap;
5use std::sync::Arc;
6use tower_lsp::jsonrpc::Result;
7use tower_lsp::lsp_types::*;
8use tower_lsp::{Client, LanguageServer};
9
10#[derive(Clone)]
12struct DocumentManager {
13 documents: Arc<DashMap<String, String>>,
14}
15
16impl DocumentManager {
17 fn new() -> Self {
18 Self {
19 documents: Arc::new(DashMap::new()),
20 }
21 }
22
23 fn update(&self, uri: String, text: String) {
24 self.documents.insert(uri, text);
25 }
26
27 fn get(&self, uri: &str) -> Option<String> {
28 self.documents.get(uri).map(|v| v.clone())
29 }
30
31 fn remove(&self, uri: &str) {
32 self.documents.remove(uri);
33 }
34}
35
36pub struct SqlLspServer {
38 client: Client,
39 dialect_registry: Arc<DialectRegistry>,
41 schema_manager: Arc<SchemaManager>,
43 file_dialects: Arc<DashMap<String, String>>,
45 file_schemas: Arc<DashMap<String, SchemaId>>,
47 document_manager: DocumentManager,
49}
50
51impl SqlLspServer {
52 pub fn new(client: Client) -> Self {
53 tracing::info!("Creating new SQL LSP server instance");
54 Self {
55 client,
56 dialect_registry: Arc::new(DialectRegistry::new()),
57 schema_manager: Arc::new(SchemaManager::new()),
58 file_dialects: Arc::new(DashMap::new()),
59 file_schemas: Arc::new(DashMap::new()),
60 document_manager: DocumentManager::new(),
61 }
62 }
63
64 fn get_dialect_for_file(&self, uri: &str) -> Option<Arc<dyn Dialect>> {
66 self.file_dialects
67 .get(uri)
68 .and_then(|dialect_name| self.dialect_registry.get_by_name(dialect_name.value()))
69 }
70
71 fn get_schema_for_file(&self, uri: &str) -> Option<Schema> {
74 if let Some(schema_id) = self.file_schemas.get(uri) {
76 return self.schema_manager.get(*schema_id.value());
77 }
78
79 if let Some(text) = self.document_manager.get(uri) {
81 use crate::parser::SqlParser;
82 let mut parser = SqlParser::new();
83 let parse_result = parser.parse(&text);
84
85 if let Some(tree) = parse_result.tree {
86 let tables = parser.extract_tables(&tree, &text);
87
88 if !tables.is_empty() {
89 let best_match = self
91 .schema_manager
92 .list_ids()
93 .iter()
94 .filter_map(|&schema_id| {
95 let schema = self.schema_manager.get(schema_id)?;
96 let score = self.calculate_schema_match_score(&tables, &schema);
97 if score > 0 {
98 Some((schema_id, score))
99 } else {
100 None
101 }
102 })
103 .max_by_key(|(_, score)| *score);
104
105 if let Some((schema_id, _score)) = best_match {
106 self.file_schemas.insert(uri.to_string(), schema_id);
108 return self.schema_manager.get(schema_id);
109 }
110 }
111 }
112 }
113
114 None
115 }
116
117 fn calculate_schema_match_score(&self, tables: &[String], schema: &Schema) -> i32 {
120 let mut score = 0;
121
122 for table_name in tables {
123 if schema.tables.iter().any(|t| t.name == *table_name) {
125 score += 10;
126 } else {
127 for schema_table in &schema.tables {
129 if schema_table.name.contains(table_name)
130 || table_name.contains(&schema_table.name)
131 {
132 score += 5;
133 break; }
135 }
136 }
137 }
138
139 let matched_count = tables
141 .iter()
142 .filter(|table_name| schema.tables.iter().any(|t| t.name == **table_name))
143 .count();
144
145 if matched_count > 1 {
146 score += matched_count as i32 * 2; }
148
149 score
150 }
151
152 fn position_to_offset(&self, text: &str, position: tower_lsp::lsp_types::Position) -> usize {
154 let mut offset = 0;
155 for (line_idx, line) in text.lines().enumerate() {
156 if line_idx < position.line as usize {
157 offset += line.len() + 1; } else {
159 offset += position.character.min(line.len() as u32) as usize;
160 break;
161 }
162 }
163 offset.min(text.len())
164 }
165}
166
167#[tower_lsp::async_trait]
168impl LanguageServer for SqlLspServer {
169 async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
170 Ok(InitializeResult {
171 server_info: Some(ServerInfo {
172 name: "sql-lsp".to_string(),
173 version: Some("0.1.0".to_string()),
174 }),
175 capabilities: ServerCapabilities {
176 text_document_sync: Some(TextDocumentSyncCapability::Kind(
177 TextDocumentSyncKind::INCREMENTAL,
178 )),
179 completion_provider: Some(CompletionOptions {
180 resolve_provider: Some(false),
181 trigger_characters: Some(vec![
182 ".".to_string(),
183 " ".to_string(),
184 "(".to_string(),
185 ]),
186 ..Default::default()
187 }),
188 hover_provider: Some(HoverProviderCapability::Simple(true)),
189 definition_provider: Some(OneOf::Left(true)),
190 references_provider: Some(OneOf::Left(true)),
191 document_formatting_provider: Some(OneOf::Left(true)),
192 diagnostic_provider: Some(DiagnosticServerCapabilities::Options(
193 DiagnosticOptions {
194 identifier: Some("sql-lsp".to_string()),
195 inter_file_dependencies: true,
196 workspace_diagnostics: false,
197 ..Default::default()
198 },
199 )),
200 ..Default::default()
201 },
202 })
203 }
204
205 async fn initialized(&self, _: InitializedParams) {
206 tracing::info!("SQL LSP server initialized and ready");
207 self.client
208 .log_message(MessageType::INFO, "SQL LSP server initialized")
209 .await;
210 }
211
212 async fn shutdown(&self) -> Result<()> {
213 Ok(())
214 }
215
216 async fn did_change_configuration(&self, params: DidChangeConfigurationParams) {
217 tracing::debug!("Received configuration change");
218 if let Some(settings) = params.settings.as_object() {
220 if let Some(schemas_value) = settings.get("schemas") {
222 if let Ok(schemas) =
223 serde_json::from_value::<Vec<crate::schema::Schema>>(schemas_value.clone())
224 {
225 self.schema_manager.clear();
227 let count = schemas.len();
228 for schema in schemas {
229 self.schema_manager.register(schema);
230 }
231 self.client
232 .log_message(MessageType::INFO, format!("Updated {} schemas", count))
233 .await;
234 } else {
235 self.client
236 .log_message(
237 MessageType::WARNING,
238 "Failed to parse schemas configuration",
239 )
240 .await;
241 }
242 }
243
244 if let Some(file_schemas_value) = settings.get("fileSchemas") {
246 if let Some(file_schemas_obj) = file_schemas_value.as_object() {
247 for (uri, schema_id_str) in file_schemas_obj {
248 if let Some(id_str) = schema_id_str.as_str() {
249 if let Ok(schema_id) = id_str.parse::<crate::schema::SchemaId>() {
250 self.file_schemas.insert(uri.clone(), schema_id);
251 }
252 }
253 }
254 self.client
255 .log_message(MessageType::INFO, "Updated file-schema mappings")
256 .await;
257 }
258 }
259 }
260 }
261
262 async fn did_open(&self, params: DidOpenTextDocumentParams) {
263 let uri = params.text_document.uri.to_string();
264 let text = params.text_document.text.clone();
265 let language_id = params.text_document.language_id.clone();
266
267 self.document_manager.update(uri.clone(), text.clone());
269
270 let dialect_name = infer_dialect_from_uri_and_language(&uri, &language_id);
273 self.file_dialects.insert(uri.clone(), dialect_name.clone());
274
275 if let Some(dialect) = self.get_dialect_for_file(&uri) {
277 let schema = self.get_schema_for_file(&uri);
278 let diagnostics = dialect.parse(&text, schema.as_ref()).await;
279 self.client
280 .publish_diagnostics(params.text_document.uri, diagnostics, None)
281 .await;
282 }
283 }
284
285 async fn did_change(&self, params: DidChangeTextDocumentParams) {
286 let uri = params.text_document.uri.to_string();
287
288 for change in params.content_changes {
290 if let Some(range) = change.range {
291 if let Some(mut current_text) = self.document_manager.get(&uri) {
293 let start_offset = self.position_to_offset(¤t_text, range.start);
295 let end_offset = self.position_to_offset(¤t_text, range.end);
296
297 current_text.replace_range(start_offset..end_offset, &change.text);
299 self.document_manager
300 .update(uri.clone(), current_text.clone());
301
302 if let Some(dialect) = self.get_dialect_for_file(&uri) {
304 let schema = self.get_schema_for_file(&uri);
305 let diagnostics = dialect.parse(¤t_text, schema.as_ref()).await;
306 self.client
307 .publish_diagnostics(
308 params.text_document.uri.clone(),
309 diagnostics,
310 None,
311 )
312 .await;
313 }
314 }
315 } else {
316 let text = change.text.clone();
318 self.document_manager.update(uri.clone(), text.clone());
319
320 if let Some(dialect) = self.get_dialect_for_file(&uri) {
321 let schema = self.get_schema_for_file(&uri);
322 let diagnostics = dialect.parse(&text, schema.as_ref()).await;
323 self.client
324 .publish_diagnostics(params.text_document.uri.clone(), diagnostics, None)
325 .await;
326 }
327 }
328 }
329 }
330
331 async fn did_close(&self, params: DidCloseTextDocumentParams) {
332 let uri = params.text_document.uri.to_string();
333 self.document_manager.remove(&uri);
335 }
336
337 async fn completion(&self, params: CompletionParams) -> Result<Option<CompletionResponse>> {
338 let uri = params.text_document_position.text_document.uri.to_string();
339 let position = params.text_document_position.position;
340
341 let text = self.document_manager.get(&uri).unwrap_or_default();
342
343 if let Some(dialect) = self.get_dialect_for_file(&uri) {
344 let schema = self.get_schema_for_file(&uri);
345 let items = dialect.completion(&text, position, schema.as_ref()).await;
346 return Ok(Some(CompletionResponse::Array(items)));
347 }
348
349 Ok(None)
350 }
351
352 async fn hover(&self, params: HoverParams) -> Result<Option<Hover>> {
353 let uri = params
354 .text_document_position_params
355 .text_document
356 .uri
357 .to_string();
358 let position = params.text_document_position_params.position;
359
360 let text = self.document_manager.get(&uri).unwrap_or_default();
361
362 if let Some(dialect) = self.get_dialect_for_file(&uri) {
363 let schema = self.get_schema_for_file(&uri);
364 return Ok(dialect.hover(&text, position, schema.as_ref()).await);
365 }
366
367 Ok(None)
368 }
369
370 async fn goto_definition(
371 &self,
372 params: GotoDefinitionParams,
373 ) -> Result<Option<GotoDefinitionResponse>> {
374 let uri = params
375 .text_document_position_params
376 .text_document
377 .uri
378 .to_string();
379 let position = params.text_document_position_params.position;
380
381 let text = self.document_manager.get(&uri).unwrap_or_default();
382
383 if let Some(dialect) = self.get_dialect_for_file(&uri) {
384 let schema = self.get_schema_for_file(&uri);
385 if let Some(location) = dialect
386 .goto_definition(&text, position, schema.as_ref())
387 .await
388 {
389 return Ok(Some(GotoDefinitionResponse::Scalar(location)));
390 }
391 }
392
393 Ok(None)
394 }
395
396 async fn references(&self, params: ReferenceParams) -> Result<Option<Vec<Location>>> {
397 let uri = params.text_document_position.text_document.uri.to_string();
398 let position = params.text_document_position.position;
399
400 let text = self.document_manager.get(&uri).unwrap_or_default();
401
402 if let Some(dialect) = self.get_dialect_for_file(&uri) {
403 let schema = self.get_schema_for_file(&uri);
404 let locations = dialect.references(&text, position, schema.as_ref()).await;
405 return Ok(Some(locations));
406 }
407
408 Ok(None)
409 }
410
411 async fn formatting(&self, params: DocumentFormattingParams) -> Result<Option<Vec<TextEdit>>> {
412 let uri = params.text_document.uri.to_string();
413 let text = self.document_manager.get(&uri).unwrap_or_default();
414
415 if let Some(dialect) = self.get_dialect_for_file(&uri) {
416 let formatted = dialect.format(&text).await;
417 let line_count = if text.is_empty() {
418 0
419 } else {
420 text.lines().count() as u32
421 };
422 let range = Range {
423 start: Position {
424 line: 0,
425 character: 0,
426 },
427 end: Position {
428 line: line_count.saturating_sub(1),
429 character: 0,
430 },
431 };
432 return Ok(Some(vec![TextEdit {
433 range,
434 new_text: formatted,
435 }]));
436 }
437
438 Ok(None)
439 }
440}
441
/// Chooses a dialect name for a document from its URI and LSP language id.
///
/// File-name suffixes are checked first (case-insensitively), e.g.
/// `.mysql.sql`, `.pgsql`, `.hql`, `.es.json`. If no suffix matches, the
/// client's `language_id` is consulted. Anything unrecognised falls back to
/// `"mysql"`.
fn infer_dialect_from_uri_and_language(uri: &str, language_id: &str) -> String {
    // (suffixes, dialect) rules checked in order; the first match wins.
    const SUFFIX_RULES: &[(&[&str], &str)] = &[
        (&[".mysql.sql", ".mysql"], "mysql"),
        (&[".postgres.sql", ".pgsql"], "postgres"),
        (&[".hive.sql", ".hql"], "hive"),
        (&[".es.eql", ".eql"], "elasticsearch-eql"),
        (&[".es.dsl", ".es.json", ".elasticsearch"], "elasticsearch-dsl"),
        (&[".ch.sql", ".clickhouse"], "clickhouse"),
        (&[".redis.sql", ".redis"], "redis"),
    ];

    let uri_lower = uri.to_lowercase();
    if let Some((_, dialect)) = SUFFIX_RULES
        .iter()
        .find(|(suffixes, _)| suffixes.iter().any(|suffix| uri_lower.ends_with(*suffix)))
    {
        return (*dialect).to_string();
    }

    let dialect = match language_id.to_lowercase().as_str() {
        "mysql" | "mysql-sql" => "mysql",
        "postgresql" | "postgres" | "pgsql" => "postgres",
        "hive" | "hql" => "hive",
        "elasticsearch-eql" | "eql" => "elasticsearch-eql",
        // Explicit DSL language ids always mean Elasticsearch DSL.
        "elasticsearch-dsl" | "es-dsl" => "elasticsearch-dsl",
        // Generic JSON is treated as Elasticsearch DSL only when the path says so.
        "json" if uri_lower.contains("elasticsearch") => "elasticsearch-dsl",
        "clickhouse" | "ch" => "clickhouse",
        "redis" => "redis",
        _ => "mysql",
    };
    dialect.to_string()
}