squawk_ide/
expand_selection.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/8d75311400a108d7ffe17dc9c38182c566952e6e/crates/ide/src/extend_selection.rs#L1C1-L1C1
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27// NOTE: this is pretty much copied as is from rust analyzer with some
28// simplifications. I imagine there's more we can do to adapt it for SQL.
29
30use rowan::{Direction, NodeOrToken, TextRange, TextSize};
31use squawk_syntax::{
32    SyntaxKind, SyntaxNode, SyntaxToken,
33    ast::{self, AstToken},
34};
35
/// List node kinds whose items get delimiter-aware expansion.
///
/// When the selected node's parent has one of these kinds,
/// `try_extend_selection` first asks `extend_list_item` to grow the selection
/// over an adjacent `COMMA` (and nearby same-line whitespace) before falling
/// back to selecting the parent node.
const DELIMITED_LIST_KINDS: &[SyntaxKind] = &[
    SyntaxKind::ALTER_OPTION_LIST,
    SyntaxKind::ARG_LIST,
    SyntaxKind::ATTRIBUTE_LIST,
    SyntaxKind::COLUMN_LIST,
    SyntaxKind::CONFLICT_INDEX_ITEM_LIST,
    SyntaxKind::CONSTRAINT_EXCLUSION_LIST,
    SyntaxKind::FUNCTION_SIG_LIST,
    SyntaxKind::GROUP_BY_LIST,
    SyntaxKind::JSON_TABLE_COLUMN_LIST,
    SyntaxKind::OPTION_ITEM_LIST,
    SyntaxKind::PARAM_LIST,
    SyntaxKind::PARTITION_ITEM_LIST,
    SyntaxKind::ROW_LIST,
    SyntaxKind::SET_COLUMN_LIST,
    SyntaxKind::SET_EXPR_LIST,
    SyntaxKind::SET_OPTIONS_LIST,
    SyntaxKind::SORT_BY_LIST,
    SyntaxKind::TABLE_ARG_LIST,
    SyntaxKind::TABLE_LIST,
    SyntaxKind::TARGET_LIST,
    SyntaxKind::TRANSACTION_MODE_LIST,
    SyntaxKind::VACUUM_OPTION_LIST,
    SyntaxKind::VARIANT_LIST,
    SyntaxKind::XML_TABLE_COLUMN_LIST,
];
62
63pub fn extend_selection(root: &SyntaxNode, range: TextRange) -> TextRange {
64    try_extend_selection(root, range).unwrap_or(range)
65}
66
67fn try_extend_selection(root: &SyntaxNode, range: TextRange) -> Option<TextRange> {
68    let string_kinds = [
69        SyntaxKind::COMMENT,
70        SyntaxKind::STRING,
71        SyntaxKind::BYTE_STRING,
72        SyntaxKind::BIT_STRING,
73        SyntaxKind::DOLLAR_QUOTED_STRING,
74        SyntaxKind::ESC_STRING,
75    ];
76
77    if range.is_empty() {
78        let offset = range.start();
79        let mut leaves = root.token_at_offset(offset);
80        // Make sure that if we're on the whitespace at the start of a line, we
81        // expand to the node on that line instead of the previous one
82        if leaves.clone().all(|it| it.kind() == SyntaxKind::WHITESPACE) {
83            return Some(extend_ws(root, leaves.next()?, offset));
84        }
85        let leaf_range = match root.token_at_offset(offset) {
86            rowan::TokenAtOffset::None => return None,
87            rowan::TokenAtOffset::Single(l) => {
88                if string_kinds.contains(&l.kind()) {
89                    extend_single_word_in_comment_or_string(&l, offset)
90                        .unwrap_or_else(|| l.text_range())
91                } else {
92                    l.text_range()
93                }
94            }
95            rowan::TokenAtOffset::Between(l, r) => pick_best(l, r).text_range(),
96        };
97        return Some(leaf_range);
98    }
99
100    let node = match root.covering_element(range) {
101        NodeOrToken::Token(token) => {
102            if token.text_range() != range {
103                return Some(token.text_range());
104            }
105            if let Some(comment) = ast::Comment::cast(token.clone())
106                && let Some(range) = extend_comments(comment)
107            {
108                return Some(range);
109            }
110            token.parent()?
111        }
112        NodeOrToken::Node(node) => node,
113    };
114
115    if node.text_range() != range {
116        return Some(node.text_range());
117    }
118
119    let node = shallowest_node(&node);
120
121    if node
122        .parent()
123        .is_some_and(|n| DELIMITED_LIST_KINDS.contains(&n.kind()))
124    {
125        if let Some(range) = extend_list_item(&node) {
126            return Some(range);
127        }
128    }
129
130    node.parent().map(|it| it.text_range())
131}
132
133/// Find the shallowest node with same range, which allows us to traverse siblings.
134fn shallowest_node(node: &SyntaxNode) -> SyntaxNode {
135    node.ancestors()
136        .take_while(|n| n.text_range() == node.text_range())
137        .last()
138        .unwrap()
139}
140
141/// Expand to the current word instead the full text range of the node.
142fn extend_single_word_in_comment_or_string(
143    leaf: &SyntaxToken,
144    offset: TextSize,
145) -> Option<TextRange> {
146    let text: &str = leaf.text();
147    let cursor_position: u32 = (offset - leaf.text_range().start()).into();
148
149    let (before, after) = text.split_at(cursor_position as usize);
150
151    fn non_word_char(c: char) -> bool {
152        !(c.is_alphanumeric() || c == '_')
153    }
154
155    let start_idx = before.rfind(non_word_char)? as u32;
156    let end_idx = after.find(non_word_char).unwrap_or(after.len()) as u32;
157
158    // FIXME: use `ceil_char_boundary` from `std::str` when it gets stable
159    // https://github.com/rust-lang/rust/issues/93743
160    fn ceil_char_boundary(text: &str, index: u32) -> u32 {
161        (index..)
162            .find(|&index| text.is_char_boundary(index as usize))
163            .unwrap_or(text.len() as u32)
164    }
165
166    let from: TextSize = ceil_char_boundary(text, start_idx + 1).into();
167    let to: TextSize = (cursor_position + end_idx).into();
168
169    let range = TextRange::new(from, to);
170    if range.is_empty() {
171        None
172    } else {
173        Some(range + leaf.text_range().start())
174    }
175}
176
177fn extend_comments(comment: ast::Comment) -> Option<TextRange> {
178    let prev = adj_comments(&comment, Direction::Prev);
179    let next = adj_comments(&comment, Direction::Next);
180    if prev != next {
181        Some(TextRange::new(
182            prev.syntax().text_range().start(),
183            next.syntax().text_range().end(),
184        ))
185    } else {
186        None
187    }
188}
189
190fn adj_comments(comment: &ast::Comment, dir: Direction) -> ast::Comment {
191    let mut res = comment.clone();
192    for element in comment.syntax().siblings_with_tokens(dir) {
193        let Some(token) = element.as_token() else {
194            break;
195        };
196        if let Some(c) = ast::Comment::cast(token.clone()) {
197            res = c
198        } else if token.kind() != SyntaxKind::WHITESPACE || token.text().contains("\n\n") {
199            break;
200        }
201    }
202    res
203}
204
205fn extend_ws(root: &SyntaxNode, ws: SyntaxToken, offset: TextSize) -> TextRange {
206    let ws_text = ws.text();
207    let suffix = TextRange::new(offset, ws.text_range().end()) - ws.text_range().start();
208    let prefix = TextRange::new(ws.text_range().start(), offset) - ws.text_range().start();
209    let ws_suffix = &ws_text[suffix];
210    let ws_prefix = &ws_text[prefix];
211    if ws_text.contains('\n')
212        && !ws_suffix.contains('\n')
213        && let Some(node) = ws.next_sibling_or_token()
214    {
215        let start = match ws_prefix.rfind('\n') {
216            Some(idx) => ws.text_range().start() + TextSize::from((idx + 1) as u32),
217            None => node.text_range().start(),
218        };
219        let end = if root.text().char_at(node.text_range().end()) == Some('\n') {
220            node.text_range().end() + TextSize::of('\n')
221        } else {
222            node.text_range().end()
223        };
224        return TextRange::new(start, end);
225    }
226    ws.text_range()
227}
228
229fn pick_best(l: SyntaxToken, r: SyntaxToken) -> SyntaxToken {
230    return if priority(&r) > priority(&l) { r } else { l };
231    fn priority(n: &SyntaxToken) -> usize {
232        match n.kind() {
233            SyntaxKind::WHITESPACE => 0,
234            // TODO: we can probably include more here, rust analyzer includes a
235            // handful of keywords
236            SyntaxKind::IDENT => 2,
237            _ => 1,
238        }
239    }
240}
241
242/// Extend list item selection to include nearby delimiter and whitespace.
243fn extend_list_item(node: &SyntaxNode) -> Option<TextRange> {
244    fn is_single_line_ws(node: &SyntaxToken) -> bool {
245        node.kind() == SyntaxKind::WHITESPACE && !node.text().contains('\n')
246    }
247
248    fn nearby_comma(node: &SyntaxNode, dir: Direction) -> Option<SyntaxToken> {
249        node.siblings_with_tokens(dir)
250            .skip(1)
251            .find(|node| match node {
252                NodeOrToken::Node(_) => true,
253                NodeOrToken::Token(it) => !is_single_line_ws(it),
254            })
255            .and_then(|it| it.into_token())
256            .filter(|node| node.kind() == SyntaxKind::COMMA)
257    }
258
259    if let Some(comma) = nearby_comma(node, Direction::Next) {
260        // Include any following whitespace when delimiter is after list item.
261        let final_node = comma
262            .next_sibling_or_token()
263            .and_then(|n| n.into_token())
264            .filter(is_single_line_ws)
265            .unwrap_or(comma);
266
267        return Some(TextRange::new(
268            node.text_range().start(),
269            final_node.text_range().end(),
270        ));
271    }
272
273    if let Some(comma) = nearby_comma(node, Direction::Prev) {
274        return Some(TextRange::new(
275            comma.text_range().start(),
276            node.text_range().end(),
277        ));
278    }
279
280    None
281}
282
// Snapshot tests for `extend_selection`. Each test lists, in order, the text
// selected after every successive expansion starting from the `$0` marker.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::fixture;
    use insta::assert_debug_snapshot;
    use squawk_syntax::{SourceFile, ast::AstNode};

    // Repeatedly applies `extend_selection`, starting from an empty range at
    // the offset returned by `fixture`, and records the selected text after
    // each step.
    // NOTE(review): `fixture` presumably strips the `$0` cursor marker and
    // returns its offset together with the cleaned SQL — confirm in test_utils.
    fn expand(sql: &str) -> Vec<String> {
        let (offset, sql) = fixture(sql);
        let parse = SourceFile::parse(&sql);
        let file = parse.tree();
        let root = file.syntax();

        let mut range = TextRange::empty(offset);
        let mut results = Vec::new();

        // Stop once the selection no longer changes; the cap of 20 iterations
        // bounds the loop regardless.
        for _ in 0..20 {
            let new_range = extend_selection(root, range);
            if new_range == range {
                break;
            }
            range = new_range;
            results.push(sql[range].to_string());
        }

        results
    }

    #[test]
    fn simple() {
        assert_debug_snapshot!(expand(r#"select $01 + 1"#), @r#"
        [
            "1",
            "1 + 1",
            "select 1 + 1",
        ]
        "#);
    }

    // A cursor inside a string literal first selects the word under it.
    #[test]
    fn word_in_string_string() {
        assert_debug_snapshot!(expand(r"
select 'some stret$0ched out words in a string'
"), @r#"
        [
            "stretched",
            "'some stretched out words in a string'",
            "select 'some stretched out words in a string'",
            "\nselect 'some stretched out words in a string'\n",
        ]
        "#);
    }

    #[test]
    fn string() {
        assert_debug_snapshot!(expand(r"
select b'foo$0 bar'
'buzz';
"), @r#"
        [
            "foo",
            "b'foo bar'",
            "b'foo bar'\n'buzz'",
            "select b'foo bar'\n'buzz'",
            "\nselect b'foo bar'\n'buzz';\n",
        ]
        "#);
    }

    #[test]
    fn dollar_string() {
        assert_debug_snapshot!(expand(r"
select $$foo$0 bar$$;
"), @r#"
        [
            "foo",
            "$$foo bar$$",
            "select $$foo bar$$",
            "\nselect $$foo bar$$;\n",
        ]
        "#);
    }

    // A fully selected comment grows to the run of adjacent comments.
    #[test]
    fn comment_muli_line() {
        assert_debug_snapshot!(expand(r"
-- foo bar
-- buzz$0
-- boo
select 1
"), @r#"
        [
            "-- buzz",
            "-- foo bar\n-- buzz\n-- boo",
            "\n-- foo bar\n-- buzz\n-- boo\nselect 1\n",
        ]
        "#);
    }

    #[test]
    fn comment() {
        assert_debug_snapshot!(expand(r"
-- foo bar$0
select 1
"), @r#"
        [
            "-- foo bar",
            "\n-- foo bar\nselect 1\n",
        ]
        "#);

        assert_debug_snapshot!(expand(r"
/* foo bar$0 */
select 1
"), @r#"
        [
            "bar",
            "/* foo bar */",
            "\n/* foo bar */\nselect 1\n",
        ]
        "#);
    }

    #[test]
    fn create_table_with_comment() {
        assert_debug_snapshot!(expand(r"
-- foo bar buzz
create table t(
  x int$0,
  y text
);
"), @r#"
        [
            "int",
            "x int",
            "x int,",
            "(\n  x int,\n  y text\n)",
            "-- foo bar buzz\ncreate table t(\n  x int,\n  y text\n)",
            "\n-- foo bar buzz\ncreate table t(\n  x int,\n  y text\n);\n",
        ]
        "#);
    }

    // List items grow to include the neighbouring comma: the trailing one when
    // present, otherwise the leading one.
    #[test]
    fn column_list() {
        assert_debug_snapshot!(expand(r#"create table t($0x int)"#), @r#"
        [
            "x",
            "x int",
            "(x int)",
            "create table t(x int)",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"create table t($0x int, y int)"#), @r#"
        [
            "x",
            "x int",
            "x int, ",
            "(x int, y int)",
            "create table t(x int, y int)",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"create table t(x int, $0y int)"#), @r#"
        [
            "y",
            "y int",
            ", y int",
            "(x int, y int)",
            "create table t(x int, y int)",
        ]
        "#);
    }

    // A cursor in leading indentation selects the statement on that line
    // rather than the one before it.
    #[test]
    fn start_of_line_whitespace_select() {
        assert_debug_snapshot!(expand(r#"    
select 1;

$0    select 2;"#), @r#"
        [
            "    select 2",
            "    \nselect 1;\n\n    select 2;",
        ]
        "#);
    }

    #[test]
    fn select_list() {
        assert_debug_snapshot!(expand(r#"select x$0, y from t"#), @r#"
        [
            "x",
            "x, ",
            "x, y",
            "select x, y",
            "select x, y from t",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"select x, y$0 from t"#), @r#"
        [
            "y",
            ", y",
            "x, y",
            "select x, y",
            "select x, y from t",
        ]
        "#);
    }

    #[test]
    fn expand_whitespace() {
        assert_debug_snapshot!(expand(r#"select 1 + 
$0
1;"#), @r#"
        [
            " \n\n",
            "1 + \n\n1",
            "select 1 + \n\n1",
            "select 1 + \n\n1;",
        ]
        "#);
    }

    #[test]
    fn function_args() {
        assert_debug_snapshot!(expand(r#"select f(1$0, 2)"#), @r#"
        [
            "1",
            "1, ",
            "(1, 2)",
            "f(1, 2)",
            "select f(1, 2)",
        ]
        "#);
    }

    // Between an identifier and an operator, the identifier is selected.
    #[test]
    fn prefer_idents() {
        assert_debug_snapshot!(expand(r#"select foo$0+bar"#), @r#"
        [
            "foo",
            "foo+bar",
            "select foo+bar",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"select foo+$0bar"#), @r#"
        [
            "bar",
            "foo+bar",
            "select foo+bar",
        ]
        "#);
    }

    // Guards against new `*_LIST` syntax kinds being added without a decision
    // on how selection expansion should treat them: every kind whose debug
    // name ends in `_LIST` must appear either in `DELIMITED_LIST_KINDS` or in
    // the exclusion list below.
    #[test]
    fn list_variants() {
        // NOTE(review): going by the name, these are lists whose items are
        // separated by whitespace rather than commas — confirm against the
        // grammar.
        let delimited_ws_list_kinds = &[
            SyntaxKind::FUNC_OPTION_LIST,
            SyntaxKind::ROLE_OPTION_LIST,
            SyntaxKind::SEQUENCE_OPTION_LIST,
            SyntaxKind::XML_COLUMN_OPTION_LIST,
            SyntaxKind::WHEN_CLAUSE_LIST,
        ];

        let unhandled_list_kinds = (0..SyntaxKind::__LAST as u16)
            .map(SyntaxKind::from)
            .filter(|kind| {
                format!("{:?}", kind).ends_with("_LIST") && !delimited_ws_list_kinds.contains(kind)
            })
            .filter(|kind| !DELIMITED_LIST_KINDS.contains(kind))
            .collect::<Vec<_>>();

        assert_eq!(
            unhandled_list_kinds,
            vec![],
            "We shouldn't have any unhandled list kinds"
        )
    }
}