squawk_ide/
column_name.rs

1use squawk_syntax::{
2    SyntaxKind, SyntaxNode,
3    ast::{self, AstNode},
4};
5
6use crate::quote::normalize_identifier;
7
/// The name a select target contributes to the output columns.
///
/// The variant records how the name was obtained, so callers combining
/// names from nested expressions can tell a definite name from a fallback.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ColumnName {
    /// A definite name, e.g. an explicit alias, a column reference, a
    /// function name, or a keyword-derived name such as `array` or `row`.
    Column(String),
    /// There's a fallback mechanism that we need to propagate through the
    /// expressions/types.
    ///
    /// We can see this with:
    /// ```sql
    /// select case when true then 'a' else now()::text end;
    /// -- column named `now`, propagating the function name
    /// -- vs
    /// select case when true then 'a' else 'b' end;
    /// -- column named `case`
    /// ```
    ///
    /// `Some(name)` carries a fallback name (for example a cast's type name)
    /// that an enclosing expression may replace; `None` means no name could be
    /// derived, which is rendered as `?column?`.
    UnknownColumn(Option<String>),
    /// A `*` target. It has no single name, so rendering it yields `None`.
    Star,
}
25
26impl ColumnName {
27    // Get the alias, otherwise infer the column name.
28    pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
29        if let Some(as_name) = target.as_name()
30            && let Some(name_node) = as_name.name()
31        {
32            let text = name_node.text();
33            let normalized = normalize_identifier(&text);
34            return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
35        }
36        Self::inferred_from_target(target)
37    }
38
39    // Ignore any aliases, just infer the what the column name.
40    pub(crate) fn inferred_from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
41        if let Some(expr) = target.expr()
42            && let Some(name) = name_from_expr(expr, false)
43        {
44            return Some(name);
45        } else if target.star_token().is_some() {
46            return Some((ColumnName::Star, target.syntax().clone()));
47        }
48        None
49    }
50
51    fn new(name: String, unknown_column: bool) -> ColumnName {
52        if unknown_column {
53            ColumnName::UnknownColumn(Some(name))
54        } else {
55            ColumnName::Column(name)
56        }
57    }
58
59    pub(crate) fn to_string(&self) -> Option<String> {
60        match self {
61            ColumnName::Column(string) => Some(string.to_string()),
62            ColumnName::Star => None,
63            ColumnName::UnknownColumn(c) => {
64                Some(c.clone().unwrap_or_else(|| "?column?".to_string()))
65            }
66        }
67    }
68}
69
70fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
71    match ty {
72        ast::Type::PathType(path_type) => {
73            if let Some(name_ref) = path_type
74                .path()
75                .and_then(|x| x.segment())
76                .and_then(|x| x.name_ref())
77            {
78                return name_from_name_ref(name_ref, true).map(|(column, node)| {
79                    let column = match column {
80                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
81                        _ => column,
82                    };
83                    (column, node)
84                });
85            }
86        }
87        ast::Type::BitType(bit_type) => {
88            let name = if bit_type.varying_token().is_some() {
89                "varbit"
90            } else {
91                "bit"
92            };
93            return Some((
94                ColumnName::new(name.to_string(), unknown_column),
95                bit_type.syntax().clone(),
96            ));
97        }
98        ast::Type::CharType(char_type) => {
99            let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
100            {
101                "varchar"
102            } else {
103                "bpchar"
104            };
105            return Some((
106                ColumnName::new(name.to_string(), unknown_column),
107                char_type.syntax().clone(),
108            ));
109        }
110        ast::Type::DoubleType(double_type) => {
111            return Some((
112                ColumnName::new("float8".to_string(), unknown_column),
113                double_type.syntax().clone(),
114            ));
115        }
116        ast::Type::IntervalType(interval_type) => {
117            return Some((
118                ColumnName::new("interval".to_string(), unknown_column),
119                interval_type.syntax().clone(),
120            ));
121        }
122        ast::Type::TimeType(time_type) => {
123            let mut name = if time_type.timestamp_token().is_some() {
124                "timestamp".to_owned()
125            } else {
126                "time".to_owned()
127            };
128            if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
129                // time -> timetz
130                // timestamp -> timestamptz
131                name.push_str("tz");
132            };
133            return Some((
134                ColumnName::new(name.to_string(), unknown_column),
135                time_type.syntax().clone(),
136            ));
137        }
138        ast::Type::ArrayType(array_type) => {
139            if let Some(inner_ty) = array_type.ty() {
140                return name_from_type(inner_ty, unknown_column);
141            }
142        }
143        // we shouldn't ever hit this since the following isn't valid syntax:
144        // select cast('foo' as t.a%TYPE);
145        ast::Type::PercentType(_) => return None,
146        ast::Type::ExprType(expr_type) => {
147            if let Some(expr) = expr_type.expr() {
148                return name_from_expr(expr, true).map(|(column, node)| {
149                    let column = match column {
150                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
151                        _ => column,
152                    };
153                    (column, node)
154                });
155            }
156        }
157    }
158    None
159}
160
161fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
162    if in_type {
163        for node in name_ref.syntax().children_with_tokens() {
164            match node.kind() {
165                SyntaxKind::BIGINT_KW => {
166                    return Some((
167                        ColumnName::Column("int8".to_owned()),
168                        name_ref.syntax().clone(),
169                    ));
170                }
171                SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
172                    return Some((
173                        ColumnName::Column("int4".to_owned()),
174                        name_ref.syntax().clone(),
175                    ));
176                }
177                SyntaxKind::SMALLINT_KW => {
178                    return Some((
179                        ColumnName::Column("int2".to_owned()),
180                        name_ref.syntax().clone(),
181                    ));
182                }
183                _ => (),
184            }
185        }
186    }
187    let text = name_ref.text();
188    let normalized = normalize_identifier(&text);
189    return Some((ColumnName::Column(normalized), name_ref.syntax().clone()));
190}
191
/*
TODO: multi-argument `unnest` returns one output column per array argument,
each named `unnest`:

unnest(anyarray, anyarray [, ... ]) → setof anyelement, anyelement [, ... ]

select * from unnest(ARRAY[1,2], ARRAY['foo','bar','baz']) →
 unnest | unnest
--------+--------
      1 | foo
      2 | bar
        | baz
*/
204
205// NOTE: we have to have this in_type param because we parse some casts as exprs
206// instead of types.
207fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
208    let node = expr.syntax().clone();
209    match expr {
210        ast::Expr::ArrayExpr(_) => {
211            return Some((ColumnName::Column("array".to_string()), node));
212        }
213        ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
214            return Some((ColumnName::UnknownColumn(None), node));
215        }
216        ast::Expr::CallExpr(call_expr) => {
217            if let Some(func_name) = call_expr.expr() {
218                match func_name {
219                    ast::Expr::ArrayExpr(_)
220                    | ast::Expr::BetweenExpr(_)
221                    | ast::Expr::ParenExpr(_)
222                    | ast::Expr::BinExpr(_)
223                    | ast::Expr::CallExpr(_)
224                    | ast::Expr::CaseExpr(_)
225                    | ast::Expr::CastExpr(_)
226                    | ast::Expr::Literal(_)
227                    | ast::Expr::PostfixExpr(_)
228                    | ast::Expr::PrefixExpr(_)
229                    | ast::Expr::TupleExpr(_)
230                    | ast::Expr::IndexExpr(_)
231                    | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
232                    ast::Expr::FieldExpr(field_expr) => {
233                        if let Some(name_ref) = field_expr.field() {
234                            return name_from_name_ref(name_ref, in_type);
235                        }
236                    }
237                    ast::Expr::NameRef(name_ref) => {
238                        return name_from_name_ref(name_ref, in_type);
239                    }
240                }
241            }
242        }
243        ast::Expr::CaseExpr(case) => {
244            if let Some(else_clause) = case.else_clause()
245                && let Some(expr) = else_clause.expr()
246                && let Some((column, node)) = name_from_expr(expr, in_type)
247            {
248                if !matches!(column, ColumnName::UnknownColumn(_)) {
249                    return Some((column, node));
250                }
251            }
252            return Some((ColumnName::Column("case".to_string()), node));
253        }
254        ast::Expr::CastExpr(cast_expr) => {
255            let mut unknown_column = false;
256            if let Some(expr) = cast_expr.expr()
257                && let Some((column, node)) = name_from_expr(expr, in_type)
258            {
259                match column {
260                    ColumnName::Column(_) => return Some((column, node)),
261                    ColumnName::UnknownColumn(_) => unknown_column = true,
262                    ColumnName::Star => (),
263                }
264            }
265            if let Some(ty) = cast_expr.ty() {
266                return name_from_type(ty, unknown_column);
267            }
268        }
269        ast::Expr::FieldExpr(field_expr) => {
270            if let Some(name_ref) = field_expr.field() {
271                return name_from_name_ref(name_ref, in_type);
272            }
273        }
274        ast::Expr::IndexExpr(index_expr) => {
275            if let Some(base) = index_expr.base() {
276                return name_from_expr(base, in_type);
277            }
278        }
279        ast::Expr::SliceExpr(slice_expr) => {
280            if let Some(base) = slice_expr.base() {
281                return name_from_expr(base, in_type);
282            }
283        }
284        ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
285            return Some((ColumnName::UnknownColumn(None), node));
286        }
287        ast::Expr::NameRef(name_ref) => {
288            return name_from_name_ref(name_ref, in_type);
289        }
290        ast::Expr::ParenExpr(paren_expr) => {
291            if let Some(expr) = paren_expr.expr() {
292                return name_from_expr(expr, in_type);
293            } else if let Some(select) = paren_expr.select()
294                && let Some(mut targets) = select
295                    .select_clause()
296                    .and_then(|x| x.target_list())
297                    .map(|x| x.targets())
298                && let Some(target) = targets.next()
299            {
300                return ColumnName::from_target(target);
301            }
302        }
303        ast::Expr::TupleExpr(_) => {
304            return Some((ColumnName::Column("row".to_string()), node));
305        }
306    }
307    None
308}
309
/// Snapshot tests for column-name inference: each case is parsed as
/// `select <expr>` and the inferred name of its first target is compared with
/// the inline snapshot.
#[test]
fn examples() {
    use insta::assert_snapshot;

    // array
    assert_snapshot!(column("array(select 1)"), @"array");
    assert_snapshot!(column("array[1, 2, 3]"), @"array");

    // unknown columns
    assert_snapshot!(column("1 between 0 and 10"), @"?column?");
    assert_snapshot!(column("1 + 2"), @"?column?");
    assert_snapshot!(column("42"), @"?column?");
    assert_snapshot!(column("'string'"), @"?column?");
    // prefix
    assert_snapshot!(column("-42"), @"?column?");
    assert_snapshot!(column("|/ 42"), @"?column?");
    // postfix
    assert_snapshot!(column("x is null"), @"?column?");
    assert_snapshot!(column("x is not null"), @"?column?");
    // paren expr
    assert_snapshot!(column("(1 * 2)"), @"?column?");
    assert_snapshot!(column("(select 1 as a)"), @"a");

    // func
    assert_snapshot!(column("count(*)"), @"count");
    assert_snapshot!(column("schema.func_name(1)"), @"func_name");

    // index
    assert_snapshot!(column("foo[bar]"), @"foo");
    assert_snapshot!(column("foo[1]"), @"foo");

    // column
    assert_snapshot!(column("database.schema.table.column"), @"column");
    assert_snapshot!(column("t.a"), @"a");
    assert_snapshot!(column("col_name"), @"col_name");
    assert_snapshot!(column("(c)"), @"c");

    // case
    assert_snapshot!(column("case when true then 'foo' end"), @"case");
    assert_snapshot!(column("case when true then 'foo' else now()::text end"), @"now");
    assert_snapshot!(column("case when true then 'foo' else 'bar' end"), @"case");
    assert_snapshot!(column("case when true then 'foo' else '1'::bigint::text end"), @"case");

    // casts
    assert_snapshot!(column("now()::text"), @"now");
    assert_snapshot!(column("cast(col_name as text)"), @"col_name");
    assert_snapshot!(column("col_name::text"), @"col_name");
    assert_snapshot!(column("col_name::int::text"), @"col_name");
    assert_snapshot!(column("'1'::bigint"), @"int8");
    assert_snapshot!(column("'1'::int"), @"int4");
    assert_snapshot!(column("'1'::smallint"), @"int2");
    assert_snapshot!(column("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
    assert_snapshot!(column("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
    assert_snapshot!(column("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
    assert_snapshot!(column("pg_catalog.varchar(100) '{1}'"), @"varchar");
    assert_snapshot!(column("'{1}'::integer[];"), @"int4");
    assert_snapshot!(column("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
    assert_snapshot!(column("'1'::bigint::smallint"), @"int2");

    // alias
    // with quoting
    assert_snapshot!(column(r#"'foo' as "FOO""#), @"FOO");
    assert_snapshot!(column(r#"'foo' as "foo""#), @"foo");
    // without quoting
    assert_snapshot!(column(r#"'foo' as FOO"#), @"foo");
    assert_snapshot!(column(r#"'foo' as foo"#), @"foo");

    // tuple
    assert_snapshot!(column("(1, 2, 3)"), @"row");
    assert_snapshot!(column("(1, 2, 3)::address"), @"row");

    // composite type
    assert_snapshot!(column("(x).city"), @"city");

    // array types
    assert_snapshot!(column("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
    assert_snapshot!(column("cast('{foo}' as text[])"), @"text");

    // bit types
    assert_snapshot!(column("cast('1010' as bit varying(10))"), @"varbit");

    // char types
    assert_snapshot!(column("cast('hello' as character varying(10))"), @"varchar");
    assert_snapshot!(column("cast('hello' as char varying(5))"), @"varchar");
    assert_snapshot!(column("cast('hello' as char(5))"), @"bpchar");

    // double types
    assert_snapshot!(column("cast(1.5 as double precision)"), @"float8");

    // interval types
    assert_snapshot!(column("cast('1 hour' as interval hour to minute)"), @"interval");

    // percent types
    assert_snapshot!(column("cast(foo as schema.%TYPE)"), @"foo");

    // time types
    assert_snapshot!(column("cast('12:00:00' as time(6) without time zone)"), @"time");
    assert_snapshot!(column("cast('12:00:00' as time(6) with time zone)"), @"timetz");
    assert_snapshot!(column("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
    assert_snapshot!(column("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");

    /// Parses `select <sql>` (which must parse without errors) and returns the
    /// rendered column name of its first target.
    #[track_caller]
    fn column(sql: &str) -> String {
        let source = format!("select {sql}");
        let parse = squawk_syntax::SourceFile::parse(&source);
        assert_eq!(parse.errors(), vec![]);
        let file = parse.tree();

        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };

        let target = select
            .select_clause()
            .and_then(|clause| clause.target_list())
            .and_then(|list| list.targets().next())
            .unwrap();

        let (column_name, _node) = ColumnName::from_target(target).unwrap();
        column_name.to_string().unwrap()
    }
}