sql_lsp/dialects/
mysql.rs

1use crate::dialect::Dialect;
2use crate::parser::SqlParser;
3use crate::schema::Schema;
4use async_trait::async_trait;
5use tower_lsp::lsp_types::{
6    CompletionItem, CompletionItemKind, Diagnostic, Hover, Location, Position,
7};
8
9pub struct MysqlDialect {
10    parser: std::sync::Mutex<SqlParser>,
11}
12
13impl Default for MysqlDialect {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl MysqlDialect {
20    pub fn new() -> Self {
21        Self {
22            parser: std::sync::Mutex::new(SqlParser::new()),
23        }
24    }
25
26    /// 创建关键字补全项
27    fn create_keyword_item(&self, keyword: &str) -> CompletionItem {
28        CompletionItem {
29            label: keyword.to_string(),
30            kind: Some(CompletionItemKind::KEYWORD),
31            detail: Some(format!("MySQL keyword: {}", keyword)),
32            documentation: None,
33            deprecated: None,
34            preselect: None,
35            sort_text: Some(format!("0{}", keyword)),
36            filter_text: None,
37            insert_text: Some(keyword.to_string()),
38            insert_text_format: None,
39            insert_text_mode: None,
40            text_edit: None,
41            additional_text_edits: None,
42            commit_characters: None,
43            command: None,
44            data: None,
45            tags: None,
46            label_details: None,
47        }
48    }
49
50    /// 创建表补全项
51    fn create_table_item(&self, table: &crate::schema::Table) -> CompletionItem {
52        CompletionItem {
53            label: table.name.clone(),
54            kind: Some(CompletionItemKind::CLASS),
55            detail: Some(format!("Table: {}", table.name)),
56            documentation: table
57                .comment
58                .clone()
59                .map(tower_lsp::lsp_types::Documentation::String),
60            deprecated: None,
61            preselect: None,
62            sort_text: Some(format!("1{}", table.name)),
63            filter_text: None,
64            insert_text: Some(table.name.clone()),
65            insert_text_format: None,
66            insert_text_mode: None,
67            text_edit: None,
68            additional_text_edits: None,
69            commit_characters: None,
70            command: None,
71            data: None,
72            tags: None,
73            label_details: None,
74        }
75    }
76
77    /// 创建列补全项
78    fn create_column_item(
79        &self,
80        column: &crate::schema::Column,
81        table_name: Option<&str>,
82    ) -> CompletionItem {
83        let label = if let Some(table) = table_name {
84            format!("{}.{}", table, column.name)
85        } else {
86            column.name.clone()
87        };
88
89        let detail = if let Some(table) = table_name {
90            format!("Column: {}.{} ({})", table, column.name, column.data_type)
91        } else {
92            format!("Column: {} ({})", column.name, column.data_type)
93        };
94
95        CompletionItem {
96            label,
97            kind: Some(CompletionItemKind::FIELD),
98            detail: Some(detail),
99            documentation: column
100                .comment
101                .clone()
102                .map(tower_lsp::lsp_types::Documentation::String),
103            deprecated: None,
104            preselect: None,
105            sort_text: Some(format!("2{}", column.name)),
106            filter_text: None,
107            insert_text: Some(column.name.clone()),
108            insert_text_format: None,
109            insert_text_mode: None,
110            text_edit: None,
111            additional_text_edits: None,
112            commit_characters: None,
113            command: None,
114            data: None,
115            tags: None,
116            label_details: None,
117        }
118    }
119}
120
121#[async_trait]
122impl Dialect for MysqlDialect {
123    fn name(&self) -> &str {
124        "mysql"
125    }
126
127    async fn parse(&self, sql: &str, _schema: Option<&Schema>) -> Vec<Diagnostic> {
128        // 使用 Tree-sitter 进行容错 SQL 解析
129        let mut parser = self.parser.lock().unwrap();
130        let parse_result = parser.parse(sql);
131        parse_result.diagnostics
132    }
133
134    async fn completion(
135        &self,
136        sql: &str,
137        position: Position,
138        schema: Option<&Schema>,
139    ) -> Vec<CompletionItem> {
140        let mut parser = self.parser.lock().unwrap();
141        let parse_result = parser.parse(sql);
142
143        // 分析补全上下文
144        let context = if let Some(tree) = &parse_result.tree {
145            if let Some(node) = parser.get_node_at_position(tree, position) {
146                parser.analyze_completion_context(node, sql)
147            } else {
148                crate::parser::CompletionContext::Default
149            }
150        } else {
151            crate::parser::CompletionContext::Default
152        };
153
154        let mut items = Vec::new();
155
156        // 根据上下文提供不同的补全
157        match context {
158            crate::parser::CompletionContext::FromClause
159            | crate::parser::CompletionContext::JoinClause => {
160                // FROM/JOIN 子句:只补全表名和 JOIN 相关关键字
161                let join_keywords = vec!["JOIN", "INNER", "LEFT", "RIGHT", "OUTER", "ON"];
162                for keyword in join_keywords {
163                    items.push(self.create_keyword_item(keyword));
164                }
165
166                // 添加表名补全
167                if let Some(schema) = schema {
168                    for table in &schema.tables {
169                        items.push(self.create_table_item(table));
170                    }
171                }
172            }
173
174            crate::parser::CompletionContext::SelectClause => {
175                // SELECT 子句:补全列名和 SELECT 相关关键字
176                let select_keywords = vec!["SELECT", "DISTINCT", "AS", "FROM"];
177                for keyword in select_keywords {
178                    items.push(self.create_keyword_item(keyword));
179                }
180
181                // 添加列名补全
182                if let Some(schema) = schema {
183                    for table in &schema.tables {
184                        for column in &table.columns {
185                            items.push(self.create_column_item(column, Some(&table.name)));
186                        }
187                    }
188                }
189            }
190
191            crate::parser::CompletionContext::WhereClause => {
192                // WHERE 子句:补全列名、操作符、关键字
193                let where_keywords = vec![
194                    "AND", "OR", "NOT", "IN", "LIKE", "BETWEEN", "IS", "NULL", "TRUE", "FALSE",
195                ];
196                for keyword in where_keywords {
197                    items.push(self.create_keyword_item(keyword));
198                }
199
200                // 添加操作符
201                let operators = vec!["=", "<>", "!=", ">", "<", ">=", "<="];
202                for op in operators {
203                    items.push(CompletionItem {
204                        label: op.to_string(),
205                        kind: Some(CompletionItemKind::OPERATOR),
206                        detail: Some(format!("Operator: {}", op)),
207                        documentation: None,
208                        deprecated: None,
209                        preselect: None,
210                        sort_text: Some(format!("1{}", op)),
211                        filter_text: None,
212                        insert_text: Some(op.to_string()),
213                        insert_text_format: None,
214                        insert_text_mode: None,
215                        text_edit: None,
216                        additional_text_edits: None,
217                        commit_characters: None,
218                        command: None,
219                        data: None,
220                        tags: None,
221                        label_details: None,
222                    });
223                }
224
225                // 添加列名补全
226                if let Some(schema) = schema {
227                    for table in &schema.tables {
228                        for column in &table.columns {
229                            items.push(self.create_column_item(column, Some(&table.name)));
230                        }
231                    }
232                }
233            }
234
235            crate::parser::CompletionContext::OrderByClause
236            | crate::parser::CompletionContext::GroupByClause => {
237                // ORDER BY / GROUP BY:补全列名和关键字
238                let keywords = vec!["ASC", "DESC", "BY"];
239                for keyword in keywords {
240                    items.push(self.create_keyword_item(keyword));
241                }
242
243                // 添加列名补全
244                if let Some(schema) = schema {
245                    for table in &schema.tables {
246                        for column in &table.columns {
247                            items.push(self.create_column_item(column, Some(&table.name)));
248                        }
249                    }
250                }
251            }
252
253            crate::parser::CompletionContext::HavingClause => {
254                // HAVING 子句:类似 WHERE,补全列名、操作符、关键字
255                let having_keywords =
256                    vec!["AND", "OR", "NOT", "IN", "LIKE", "BETWEEN", "IS", "NULL"];
257                for keyword in having_keywords {
258                    items.push(self.create_keyword_item(keyword));
259                }
260
261                // 添加聚合函数
262                let aggregate_functions = vec!["COUNT", "SUM", "AVG", "MIN", "MAX"];
263                for func in aggregate_functions {
264                    items.push(self.create_keyword_item(func));
265                }
266
267                // 添加列名补全
268                if let Some(schema) = schema {
269                    for table in &schema.tables {
270                        for column in &table.columns {
271                            items.push(self.create_column_item(column, Some(&table.name)));
272                        }
273                    }
274                }
275            }
276
277            crate::parser::CompletionContext::TableColumn => {
278                // 表名.列名:只补全特定表的列名
279                if let Some(tree) = &parse_result.tree {
280                    if let Some(node) = parser.get_node_at_position(tree, position) {
281                        if let Some(table_name) = parser.get_table_name_for_column(node, sql) {
282                            if let Some(schema) = schema {
283                                if let Some(table) =
284                                    schema.tables.iter().find(|t| t.name == table_name)
285                                {
286                                    for column in &table.columns {
287                                        items.push(self.create_column_item(column, None));
288                                    }
289                                }
290                            }
291                        }
292                    }
293                }
294            }
295
296            crate::parser::CompletionContext::Default => {
297                // 默认:返回所有关键字
298                let keywords = vec![
299                    "SELECT", "FROM", "WHERE", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP",
300                    "ALTER", "TABLE", "INDEX", "DATABASE", "SHOW", "DESCRIBE", "EXPLAIN", "JOIN",
301                    "INNER", "LEFT", "RIGHT", "OUTER", "ON", "GROUP", "BY", "ORDER", "HAVING",
302                    "LIMIT", "OFFSET", "UNION", "ALL", "DISTINCT", "AS", "AND", "OR", "NOT", "IN",
303                    "LIKE", "BETWEEN", "IS", "NULL", "TRUE", "FALSE",
304                ];
305
306                for keyword in keywords {
307                    items.push(self.create_keyword_item(keyword));
308                }
309
310                // 如果提供了 schema,添加表和列补全
311                if let Some(schema) = schema {
312                    for table in &schema.tables {
313                        items.push(self.create_table_item(table));
314                    }
315                }
316            }
317        }
318
319        items
320    }
321
322    async fn hover(&self, sql: &str, position: Position, schema: Option<&Schema>) -> Option<Hover> {
323        let mut parser = self.parser.lock().unwrap();
324        let parse_result = parser.parse(sql);
325
326        // 获取光标位置的节点
327        if let Some(tree) = &parse_result.tree {
328            if let Some(node) = parser.get_node_at_position(tree, position) {
329                let node_text = parser.node_text(node, sql);
330                let node_kind = node.kind();
331                let node_range = parser.node_range(node);
332
333                // 过滤关键字、操作符、分隔符
334                if crate::token::Keywords::is_keyword(&node_text)
335                    || crate::token::Operators::is_operator(&node_text)
336                    || crate::token::Delimiters::is_delimiter(&node_text)
337                {
338                    return None;
339                }
340
341                if let Some(schema) = schema {
342                    // 检查是否是表名
343                    let is_table = node_kind == "table_name"
344                        || node_kind == "table_reference"
345                        || node_kind == "table_identifier"
346                        || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
347
348                    if is_table {
349                        if let Some(table) = schema.tables.iter().find(|t| t.name == node_text) {
350                            let mut info = format!("**Table**: `{}`\n\n", table.name);
351                            if let Some(comment) = &table.comment {
352                                info.push_str(&format!("{}\n\n", comment));
353                            }
354                            info.push_str(&format!("**Columns** ({})\n", table.columns.len()));
355                            for (idx, col) in table.columns.iter().take(10).enumerate() {
356                                info.push_str(&format!(
357                                    "- `{}`: {} {}\n",
358                                    col.name,
359                                    col.data_type,
360                                    if col.nullable { "" } else { "NOT NULL" }
361                                ));
362                                if idx == 9 && table.columns.len() > 10 {
363                                    info.push_str(&format!(
364                                        "- ... and {} more\n",
365                                        table.columns.len() - 10
366                                    ));
367                                    break;
368                                }
369                            }
370
371                            return Some(Hover {
372                                contents: tower_lsp::lsp_types::HoverContents::Markup(
373                                    tower_lsp::lsp_types::MarkupContent {
374                                        kind: tower_lsp::lsp_types::MarkupKind::Markdown,
375                                        value: info,
376                                    },
377                                ),
378                                range: Some(node_range),
379                            });
380                        }
381                    }
382
383                    // 检查是否是列名
384                    let is_column = node_kind == "column_name"
385                        || node_kind == "column_reference"
386                        || node_kind == "column_identifier"
387                        || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
388
389                    if is_column {
390                        // 尝试获取表名(如果是 table.column 格式)
391                        let table_name = parser.get_table_name_for_column(node, sql);
392
393                        for table in &schema.tables {
394                            // 如果有明确的表名,只在该表中查找
395                            if let Some(ref tname) = table_name {
396                                if table.name != *tname {
397                                    continue;
398                                }
399                            }
400
401                            if let Some(column) = table.columns.iter().find(|c| c.name == node_text)
402                            {
403                                let mut info =
404                                    format!("**Column**: `{}.{}`\n\n", table.name, column.name);
405                                info.push_str(&format!("**Type**: `{}`\n", column.data_type));
406                                info.push_str(&format!(
407                                    "**Nullable**: {}\n",
408                                    if column.nullable { "Yes" } else { "No" }
409                                ));
410                                if let Some(comment) = &column.comment {
411                                    info.push_str(&format!("\n{}\n", comment));
412                                }
413
414                                return Some(Hover {
415                                    contents: tower_lsp::lsp_types::HoverContents::Markup(
416                                        tower_lsp::lsp_types::MarkupContent {
417                                            kind: tower_lsp::lsp_types::MarkupKind::Markdown,
418                                            value: info,
419                                        },
420                                    ),
421                                    range: Some(node_range),
422                                });
423                            }
424                        }
425                    }
426
427                    // 检查是否是函数名
428                    if node_kind == "function_name" || node_kind.contains("function") {
429                        if let Some(func) = schema.functions.iter().find(|f| f.name == node_text) {
430                            let mut info = format!("**Function**: `{}`\n\n", func.name);
431                            if let Some(desc) = &func.description {
432                                info.push_str(&format!("{}\n\n", desc));
433                            }
434                            info.push_str(&format!("**Returns**: `{}`\n", func.return_type));
435                            if !func.parameters.is_empty() {
436                                info.push_str("\n**Parameters**:\n");
437                                for param in &func.parameters {
438                                    info.push_str(&format!(
439                                        "- `{}`: `{}`{}\n",
440                                        param.name,
441                                        param.data_type,
442                                        if param.optional { " (optional)" } else { "" }
443                                    ));
444                                }
445                            }
446
447                            return Some(Hover {
448                                contents: tower_lsp::lsp_types::HoverContents::Markup(
449                                    tower_lsp::lsp_types::MarkupContent {
450                                        kind: tower_lsp::lsp_types::MarkupKind::Markdown,
451                                        value: info,
452                                    },
453                                ),
454                                range: Some(node_range),
455                            });
456                        }
457                    }
458                }
459            }
460        }
461
462        None
463    }
464
465    async fn goto_definition(
466        &self,
467        sql: &str,
468        position: Position,
469        schema: Option<&Schema>,
470    ) -> Option<Location> {
471        let mut parser = self.parser.lock().unwrap();
472        let parse_result = parser.parse(sql);
473
474        // 获取光标位置的节点
475        if let Some(tree) = &parse_result.tree {
476            if let Some(node) = parser.get_node_at_position(tree, position) {
477                let node_text = parser.node_text(node, sql);
478                let node_kind = node.kind();
479
480                // 过滤关键字、操作符、分隔符
481                if crate::token::Keywords::is_keyword(&node_text)
482                    || crate::token::Operators::is_operator(&node_text)
483                    || crate::token::Delimiters::is_delimiter(&node_text)
484                {
485                    return None;
486                }
487
488                // 判断是表名还是列名
489                let is_table = node_kind == "table_name"
490                    || node_kind == "table_reference"
491                    || node_kind == "table_identifier"
492                    || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
493
494                let is_column = node_kind == "column_name"
495                    || node_kind == "column_reference"
496                    || node_kind == "column_identifier"
497                    || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
498
499                // 如果是表名,查找表定义
500                if is_table {
501                    if let Some(schema) = schema {
502                        if let Some(table) = schema.tables.iter().find(|t| t.name == node_text) {
503                            // 使用表的源位置(如果有)
504                            let (uri, line) = if let Some((ref source_uri, source_line)) =
505                                table.source_location
506                            {
507                                (
508                                    tower_lsp::lsp_types::Url::parse(source_uri).unwrap_or_else(
509                                        |_| {
510                                            tower_lsp::lsp_types::Url::parse("file:///schema.sql")
511                                                .unwrap()
512                                        },
513                                    ),
514                                    source_line.saturating_sub(1), // 转换为0-indexed
515                                )
516                            } else if let Some(ref schema_uri) = schema.source_uri {
517                                // 回退到 schema 的源文件
518                                (
519                                    tower_lsp::lsp_types::Url::parse(schema_uri).unwrap_or_else(
520                                        |_| {
521                                            tower_lsp::lsp_types::Url::parse("file:///schema.sql")
522                                                .unwrap()
523                                        },
524                                    ),
525                                    0,
526                                )
527                            } else {
528                                // 默认虚拟位置
529                                (
530                                    tower_lsp::lsp_types::Url::parse("file:///schema.sql").unwrap(),
531                                    0,
532                                )
533                            };
534
535                            return Some(Location {
536                                uri,
537                                range: tower_lsp::lsp_types::Range {
538                                    start: tower_lsp::lsp_types::Position { line, character: 0 },
539                                    end: tower_lsp::lsp_types::Position {
540                                        line,
541                                        character: 100,
542                                    },
543                                },
544                            });
545                        }
546                    }
547                }
548
549                // 如果是列名,查找列定义
550                if is_column {
551                    if let Some(schema) = schema {
552                        // 检查是否是 table.column 格式
553                        let (table_name, column_name) =
554                            if let Some(table_name) = parser.get_table_name_for_column(node, sql) {
555                                (Some(table_name), node_text.clone())
556                            } else {
557                                // 查找列所属的表
558                                let tables = parser.extract_tables(tree, sql);
559                                let table_name = tables.first().cloned();
560                                (table_name, node_text.clone())
561                            };
562
563                        // 在 Schema 中查找列
564                        for table in &schema.tables {
565                            if let Some(ref tname) = table_name {
566                                if table.name == *tname
567                                    && table.columns.iter().any(|c| c.name == column_name)
568                                {
569                                    // 返回当前文档中列名第一次出现的位置
570                                    return Some(Location {
571                                        uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
572                                            .unwrap_or_else(|_| {
573                                                tower_lsp::lsp_types::Url::parse("file:///")
574                                                    .unwrap()
575                                            }),
576                                        range: parser.node_range(node),
577                                    });
578                                }
579                            } else if table.columns.iter().any(|c| c.name == column_name) {
580                                // 在所有表中查找列
581                                return Some(Location {
582                                    uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
583                                        .unwrap_or_else(|_| {
584                                            tower_lsp::lsp_types::Url::parse("file:///").unwrap()
585                                        }),
586                                    range: parser.node_range(node),
587                                });
588                            }
589                        }
590                    }
591                }
592            }
593        }
594
595        None
596    }
597
598    async fn references(
599        &self,
600        sql: &str,
601        position: Position,
602        _schema: Option<&Schema>,
603    ) -> Vec<Location> {
604        let mut parser = self.parser.lock().unwrap();
605        let parse_result = parser.parse(sql);
606
607        let mut locations = Vec::new();
608
609        // 获取光标位置的标识符
610        if let Some(tree) = &parse_result.tree {
611            if let Some(node) = parser.get_node_at_position(tree, position) {
612                let identifier = parser.node_text(node, sql);
613                let node_kind = node.kind();
614
615                // 过滤关键字、操作符、分隔符
616                if crate::token::Keywords::is_keyword(&identifier)
617                    || crate::token::Operators::is_operator(&identifier)
618                    || crate::token::Delimiters::is_delimiter(&identifier)
619                {
620                    return locations;
621                }
622
623                // 判断是表名还是列名
624                let is_table = node_kind == "table_name"
625                    || node_kind == "table_reference"
626                    || node_kind == "table_identifier"
627                    || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
628
629                let is_column = node_kind == "column_name"
630                    || node_kind == "column_reference"
631                    || node_kind == "column_identifier"
632                    || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
633
634                if is_table || is_column {
635                    // 在当前文档中查找所有引用
636                    let tokens = parser.tokenize(tree, sql);
637                    let current_uri = tower_lsp::lsp_types::Url::parse("file:///current.sql")
638                        .unwrap_or_else(|_| tower_lsp::lsp_types::Url::parse("file:///").unwrap());
639
640                    for token in tokens {
641                        // 匹配标识符(忽略大小写)
642                        if token.text.eq_ignore_ascii_case(&identifier)
643                            && !crate::token::Keywords::is_keyword(&token.text)
644                            && !crate::token::Operators::is_operator(&token.text)
645                            && !crate::token::Delimiters::is_delimiter(&token.text)
646                        {
647                            // 检查 token 类型,确保是标识符而不是关键字
648                            locations.push(Location {
649                                uri: current_uri.clone(),
650                                range: tower_lsp::lsp_types::Range {
651                                    start: token.position,
652                                    end: tower_lsp::lsp_types::Position {
653                                        line: token.position.line,
654                                        character: token.position.character
655                                            + token.text.len() as u32,
656                                    },
657                                },
658                            });
659                        }
660                    }
661                }
662            }
663        }
664
665        locations
666    }
667
668    async fn format(&self, sql: &str) -> String {
669        use sqlformat::{FormatOptions, Indent, QueryParams};
670        let options = FormatOptions {
671            indent: Indent::Spaces(2),
672            uppercase: true,
673            lines_between_queries: 1,
674        };
675        sqlformat::format(sql, &QueryParams::None, options)
676    }
677
678    async fn validate(&self, sql: &str, schema: Option<&Schema>) -> Vec<Diagnostic> {
679        self.parse(sql, schema).await
680    }
681}