Skip to main content

squawk_ide/
column_name.rs

1use squawk_syntax::{
2    SyntaxKind, SyntaxNode,
3    ast::{self, AstNode},
4};
5
6use crate::quote::normalize_identifier;
7
/// The output name of a select target, mirroring how Postgres names result
/// columns.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum ColumnName {
    /// A name taken from an alias, a column reference, a function name, or one
    /// of the fixed names such as `array`, `row` or `exists`.
    Column(String),
    /// A fallback name that we need to propagate through the
    /// expressions/types.
    ///
    /// An enclosing expression may replace a fallback name with a better one,
    /// while a [`ColumnName::Column`] name is kept as-is. We can see this with:
    ///
    /// ```sql
    /// select case when true then 'a' else now()::text end;
    /// -- column named `now`, propagating the function name
    /// -- vs
    /// select case when true then 'a' else 'b' end;
    /// -- column named `case`
    /// ```
    ///
    /// `None` means there is no name at all; it is displayed as `?column?`.
    UnknownColumn(Option<String>),
    /// A `*` target, which expands to several columns and so has no single
    /// name.
    Star,
}
25
26impl ColumnName {
27    // Get the alias, otherwise infer the column name.
28    pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
29        if let Some(as_name) = target.as_name()
30            && let Some(name_node) = as_name.name()
31        {
32            let text = name_node.text();
33            let normalized = normalize_identifier(&text);
34            return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
35        }
36        Self::inferred_from_target(target)
37    }
38
39    // Ignore any aliases, just infer the what the column name.
40    pub(crate) fn inferred_from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
41        if let Some(expr) = target.expr()
42            && let Some(name) = name_from_expr(expr, false)
43        {
44            return Some(name);
45        } else if target.star_token().is_some() {
46            return Some((ColumnName::Star, target.syntax().clone()));
47        }
48        None
49    }
50
51    fn new(name: String, unknown_column: bool) -> ColumnName {
52        if unknown_column {
53            ColumnName::UnknownColumn(Some(name))
54        } else {
55            ColumnName::Column(name)
56        }
57    }
58
59    pub(crate) fn to_string(&self) -> Option<String> {
60        match self {
61            ColumnName::Column(string) => Some(string.to_string()),
62            ColumnName::Star => None,
63            ColumnName::UnknownColumn(c) => {
64                Some(c.clone().unwrap_or_else(|| "?column?".to_string()))
65            }
66        }
67    }
68}
69
70fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
71    match ty {
72        ast::Type::PathType(path_type) => {
73            if let Some(name_ref) = path_type
74                .path()
75                .and_then(|x| x.segment())
76                .and_then(|x| x.name_ref())
77            {
78                return name_from_name_ref(name_ref, true).map(|(column, node)| {
79                    let column = match column {
80                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
81                        _ => column,
82                    };
83                    (column, node)
84                });
85            }
86        }
87        ast::Type::BitType(bit_type) => {
88            let name = if bit_type.varying_token().is_some() {
89                "varbit"
90            } else {
91                "bit"
92            };
93            return Some((
94                ColumnName::new(name.to_string(), unknown_column),
95                bit_type.syntax().clone(),
96            ));
97        }
98        ast::Type::CharType(char_type) => {
99            let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
100            {
101                "varchar"
102            } else {
103                "bpchar"
104            };
105            return Some((
106                ColumnName::new(name.to_string(), unknown_column),
107                char_type.syntax().clone(),
108            ));
109        }
110        ast::Type::DoubleType(double_type) => {
111            return Some((
112                ColumnName::new("float8".to_string(), unknown_column),
113                double_type.syntax().clone(),
114            ));
115        }
116        ast::Type::IntervalType(interval_type) => {
117            return Some((
118                ColumnName::new("interval".to_string(), unknown_column),
119                interval_type.syntax().clone(),
120            ));
121        }
122        ast::Type::TimeType(time_type) => {
123            let mut name = if time_type.timestamp_token().is_some() {
124                "timestamp".to_owned()
125            } else {
126                "time".to_owned()
127            };
128            if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
129                // time -> timetz
130                // timestamp -> timestamptz
131                name.push_str("tz");
132            };
133            return Some((
134                ColumnName::new(name.to_string(), unknown_column),
135                time_type.syntax().clone(),
136            ));
137        }
138        ast::Type::ArrayType(array_type) => {
139            if let Some(inner_ty) = array_type.ty() {
140                return name_from_type(inner_ty, unknown_column);
141            }
142        }
143        // we shouldn't ever hit this since the following isn't valid syntax:
144        // select cast('foo' as t.a%TYPE);
145        ast::Type::PercentType(_) => return None,
146        ast::Type::ExprType(expr_type) => {
147            if let Some(expr) = expr_type.expr() {
148                return name_from_expr(expr, true).map(|(column, node)| {
149                    let column = match column {
150                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
151                        _ => column,
152                    };
153                    (column, node)
154                });
155            }
156        }
157    }
158    None
159}
160
161fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
162    if in_type {
163        for node in name_ref.syntax().children_with_tokens() {
164            match node.kind() {
165                SyntaxKind::BIGINT_KW => {
166                    return Some((
167                        ColumnName::Column("int8".to_owned()),
168                        name_ref.syntax().clone(),
169                    ));
170                }
171                SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
172                    return Some((
173                        ColumnName::Column("int4".to_owned()),
174                        name_ref.syntax().clone(),
175                    ));
176                }
177                SyntaxKind::SMALLINT_KW => {
178                    return Some((
179                        ColumnName::Column("int2".to_owned()),
180                        name_ref.syntax().clone(),
181                    ));
182                }
183                SyntaxKind::REAL_KW => {
184                    return Some((
185                        ColumnName::Column("float4".to_owned()),
186                        name_ref.syntax().clone(),
187                    ));
188                }
189                _ => (),
190            }
191        }
192    }
193    let text = name_ref.text();
194    let normalized = normalize_identifier(&text);
195    return Some((ColumnName::Column(normalized), name_ref.syntax().clone()));
196}
197
/*
TODO: handle multi-argument `unnest`, which returns one column per array
argument, each named `unnest`:

unnest(anyarray, anyarray [, ... ]) → setof anyelement, anyelement [, ... ]

select * from unnest(ARRAY[1,2], ARRAY['foo','bar','baz']) →
 unnest | unnest
--------+--------
      1 | foo
      2 | bar
        | baz
*/
210
211// NOTE: we have to have this in_type param because we parse some casts as exprs
212// instead of types.
213fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
214    let node = expr.syntax().clone();
215    match expr {
216        ast::Expr::ArrayExpr(_) => {
217            return Some((ColumnName::Column("array".to_string()), node));
218        }
219        ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
220            return Some((ColumnName::UnknownColumn(None), node));
221        }
222        ast::Expr::CallExpr(call_expr) => {
223            if let Some(exists_fn) = call_expr.exists_fn() {
224                return Some((
225                    ColumnName::Column("exists".to_string()),
226                    exists_fn.syntax().clone(),
227                ));
228            }
229            if let Some(extract_fn) = call_expr.extract_fn() {
230                return Some((
231                    ColumnName::Column("extract".to_string()),
232                    extract_fn.syntax().clone(),
233                ));
234            }
235            if let Some(json_exists_fn) = call_expr.json_exists_fn() {
236                return Some((
237                    ColumnName::Column("json_exists".to_string()),
238                    json_exists_fn.syntax().clone(),
239                ));
240            }
241            if let Some(json_array_fn) = call_expr.json_array_fn() {
242                return Some((
243                    ColumnName::Column("json_array".to_string()),
244                    json_array_fn.syntax().clone(),
245                ));
246            }
247            if let Some(json_object_fn) = call_expr.json_object_fn() {
248                return Some((
249                    ColumnName::Column("json_object".to_string()),
250                    json_object_fn.syntax().clone(),
251                ));
252            }
253            if let Some(json_object_agg_fn) = call_expr.json_object_agg_fn() {
254                return Some((
255                    ColumnName::Column("json_objectagg".to_string()),
256                    json_object_agg_fn.syntax().clone(),
257                ));
258            }
259            if let Some(json_array_agg_fn) = call_expr.json_array_agg_fn() {
260                return Some((
261                    ColumnName::Column("json_arrayagg".to_string()),
262                    json_array_agg_fn.syntax().clone(),
263                ));
264            }
265            if let Some(json_query_fn) = call_expr.json_query_fn() {
266                return Some((
267                    ColumnName::Column("json_query".to_string()),
268                    json_query_fn.syntax().clone(),
269                ));
270            }
271            if let Some(json_scalar_fn) = call_expr.json_scalar_fn() {
272                return Some((
273                    ColumnName::Column("json_scalar".to_string()),
274                    json_scalar_fn.syntax().clone(),
275                ));
276            }
277            if let Some(json_serialize_fn) = call_expr.json_serialize_fn() {
278                return Some((
279                    ColumnName::Column("json_serialize".to_string()),
280                    json_serialize_fn.syntax().clone(),
281                ));
282            }
283            if let Some(json_value_fn) = call_expr.json_value_fn() {
284                return Some((
285                    ColumnName::Column("json_value".to_string()),
286                    json_value_fn.syntax().clone(),
287                ));
288            }
289            if let Some(json_fn) = call_expr.json_fn() {
290                return Some((
291                    ColumnName::Column("json".to_string()),
292                    json_fn.syntax().clone(),
293                ));
294            }
295            if let Some(substring_fn) = call_expr.substring_fn() {
296                return Some((
297                    ColumnName::Column("substring".to_string()),
298                    substring_fn.syntax().clone(),
299                ));
300            }
301            if let Some(position_fn) = call_expr.position_fn() {
302                return Some((
303                    ColumnName::Column("position".to_string()),
304                    position_fn.syntax().clone(),
305                ));
306            }
307            if let Some(overlay_fn) = call_expr.overlay_fn() {
308                return Some((
309                    ColumnName::Column("overlay".to_string()),
310                    overlay_fn.syntax().clone(),
311                ));
312            }
313            if let Some(trim_fn) = call_expr.trim_fn() {
314                return Some((
315                    ColumnName::Column("trim".to_string()),
316                    trim_fn.syntax().clone(),
317                ));
318            }
319            if let Some(xml_root_fn) = call_expr.xml_root_fn() {
320                return Some((
321                    ColumnName::Column("xml_root".to_string()),
322                    xml_root_fn.syntax().clone(),
323                ));
324            }
325            if let Some(xml_serialize_fn) = call_expr.xml_serialize_fn() {
326                return Some((
327                    ColumnName::Column("xml_serialize".to_string()),
328                    xml_serialize_fn.syntax().clone(),
329                ));
330            }
331            if let Some(xml_element_fn) = call_expr.xml_element_fn() {
332                return Some((
333                    ColumnName::Column("xml_element".to_string()),
334                    xml_element_fn.syntax().clone(),
335                ));
336            }
337            if let Some(xml_forest_fn) = call_expr.xml_forest_fn() {
338                return Some((
339                    ColumnName::Column("xml_forest".to_string()),
340                    xml_forest_fn.syntax().clone(),
341                ));
342            }
343            if let Some(xml_exists_fn) = call_expr.xml_exists_fn() {
344                return Some((
345                    ColumnName::Column("xml_exists".to_string()),
346                    xml_exists_fn.syntax().clone(),
347                ));
348            }
349            if let Some(xml_parse_fn) = call_expr.xml_parse_fn() {
350                return Some((
351                    ColumnName::Column("xml_parse".to_string()),
352                    xml_parse_fn.syntax().clone(),
353                ));
354            }
355            if let Some(xml_pi_fn) = call_expr.xml_pi_fn() {
356                return Some((
357                    ColumnName::Column("xml_pi".to_string()),
358                    xml_pi_fn.syntax().clone(),
359                ));
360            }
361            if let Some(func_name) = call_expr.expr() {
362                match func_name {
363                    ast::Expr::ArrayExpr(_)
364                    | ast::Expr::BetweenExpr(_)
365                    | ast::Expr::ParenExpr(_)
366                    | ast::Expr::BinExpr(_)
367                    | ast::Expr::CallExpr(_)
368                    | ast::Expr::CaseExpr(_)
369                    | ast::Expr::CastExpr(_)
370                    | ast::Expr::Literal(_)
371                    | ast::Expr::PostfixExpr(_)
372                    | ast::Expr::PrefixExpr(_)
373                    | ast::Expr::TupleExpr(_)
374                    | ast::Expr::IndexExpr(_)
375                    | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
376                    ast::Expr::FieldExpr(field_expr) => {
377                        if let Some(name_ref) = field_expr.field() {
378                            return name_from_name_ref(name_ref, in_type);
379                        }
380                    }
381                    ast::Expr::NameRef(name_ref) => {
382                        return name_from_name_ref(name_ref, in_type);
383                    }
384                }
385            }
386        }
387        ast::Expr::CaseExpr(case) => {
388            if let Some(else_clause) = case.else_clause()
389                && let Some(expr) = else_clause.expr()
390                && let Some((column, node)) = name_from_expr(expr, in_type)
391            {
392                if !matches!(column, ColumnName::UnknownColumn(_)) {
393                    return Some((column, node));
394                }
395            }
396            return Some((ColumnName::Column("case".to_string()), node));
397        }
398        ast::Expr::CastExpr(cast_expr) => {
399            let mut unknown_column = false;
400            if let Some(expr) = cast_expr.expr()
401                && let Some((column, node)) = name_from_expr(expr, in_type)
402            {
403                match column {
404                    ColumnName::Column(_) => return Some((column, node)),
405                    ColumnName::UnknownColumn(_) => unknown_column = true,
406                    ColumnName::Star => (),
407                }
408            }
409            if let Some(ty) = cast_expr.ty() {
410                return name_from_type(ty, unknown_column);
411            }
412        }
413        ast::Expr::FieldExpr(field_expr) => {
414            if let Some(name_ref) = field_expr.field() {
415                return name_from_name_ref(name_ref, in_type);
416            }
417        }
418        ast::Expr::IndexExpr(index_expr) => {
419            if let Some(base) = index_expr.base() {
420                return name_from_expr(base, in_type);
421            }
422        }
423        ast::Expr::SliceExpr(slice_expr) => {
424            if let Some(base) = slice_expr.base() {
425                return name_from_expr(base, in_type);
426            }
427        }
428        ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
429            return Some((ColumnName::UnknownColumn(None), node));
430        }
431        ast::Expr::NameRef(name_ref) => {
432            return name_from_name_ref(name_ref, in_type);
433        }
434        ast::Expr::ParenExpr(paren_expr) => {
435            if let Some(expr) = paren_expr.expr() {
436                return name_from_expr(expr, in_type);
437            } else if let Some(select) = paren_expr.select()
438                && let Some(mut targets) = select
439                    .select_clause()
440                    .and_then(|x| x.target_list())
441                    .map(|x| x.targets())
442                && let Some(target) = targets.next()
443            {
444                return ColumnName::from_target(target);
445            }
446        }
447        ast::Expr::TupleExpr(_) => {
448            return Some((ColumnName::Column("row".to_string()), node));
449        }
450    }
451    None
452}
453
454#[test]
455fn examples() {
456    use insta::assert_snapshot;
457
458    // array
459    assert_snapshot!(name("array(select 1)"), @"array");
460    assert_snapshot!(name("array[1, 2, 3]"), @"array");
461
462    // unknown columns
463    assert_snapshot!(name("1 between 0 and 10"), @"?column?");
464    assert_snapshot!(name("1 + 2"), @"?column?");
465    assert_snapshot!(name("42"), @"?column?");
466    assert_snapshot!(name("'string'"), @"?column?");
467    // prefix
468    assert_snapshot!(name("-42"), @"?column?");
469    assert_snapshot!(name("|/ 42"), @"?column?");
470    // postfix
471    assert_snapshot!(name("x is null"), @"?column?");
472    assert_snapshot!(name("x is not null"), @"?column?");
473    // paren expr
474    assert_snapshot!(name("(1 * 2)"), @"?column?");
475    assert_snapshot!(name("(select 1 as a)"), @"a");
476
477    // func
478    assert_snapshot!(name("count(*)"), @"count");
479    assert_snapshot!(name("schema.func_name(1)"), @"func_name");
480
481    // special funcs
482    assert_snapshot!(name("extract(year from now())"), @"extract");
483    assert_snapshot!(name("exists(select 1)"), @"exists");
484    assert_snapshot!(name(r#"json_exists('{"a":1}', '$.a')"#), @"json_exists");
485    assert_snapshot!(name("json_array(1, 2)"), @"json_array");
486    assert_snapshot!(name("json_object('a': 1)"), @"json_object");
487    assert_snapshot!(name("json_objectagg('a': 1)"), @"json_objectagg");
488    assert_snapshot!(name("json_arrayagg(1)"), @"json_arrayagg");
489    assert_snapshot!(name(r#"json_query('{"a":1}', '$.a')"#), @"json_query");
490    assert_snapshot!(name("json_scalar(1)"), @"json_scalar");
491    assert_snapshot!(name(r#"json_serialize('{"a":1}')"#), @"json_serialize");
492    assert_snapshot!(name(r#"json_value('{"a":1}', '$.a')"#), @"json_value");
493    assert_snapshot!(name(r#"json('{"a":1}')"#), @"json");
494    assert_snapshot!(name("substring('hello' from 2 for 3)"), @"substring");
495    assert_snapshot!(name("position('a' in 'abc')"), @"position");
496    assert_snapshot!(name("overlay('hello' placing 'X' from 2)"), @"overlay");
497    assert_snapshot!(name("trim('  hi  ')"), @"trim");
498    assert_snapshot!(name("xmlroot('<a/>', version '1.0')"), @"xml_root");
499    assert_snapshot!(name("xmlserialize(document '<a/>' as text)"), @"xml_serialize");
500    assert_snapshot!(name("xmlelement(name foo, 'bar')"), @"xml_element");
501    assert_snapshot!(name("xmlforest('bar' as foo)"), @"xml_forest");
502    assert_snapshot!(name("xmlexists('//a' passing '<a/>')"), @"xml_exists");
503    assert_snapshot!(name("xmlparse(document '<a/>')"), @"xml_parse");
504    assert_snapshot!(name("xmlpi(name foo, 'bar')"), @"xml_pi");
505
506    // index
507    assert_snapshot!(name("foo[bar]"), @"foo");
508    assert_snapshot!(name("foo[1]"), @"foo");
509
510    // column
511    assert_snapshot!(name("database.schema.table.column"), @"column");
512    assert_snapshot!(name("t.a"), @"a");
513    assert_snapshot!(name("col_name"), @"col_name");
514    assert_snapshot!(name("(c)"), @"c");
515
516    // case
517    assert_snapshot!(name("case when true then 'foo' end"), @"case");
518    assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
519    assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
520    assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");
521
522    // casts
523    assert_snapshot!(name("now()::text"), @"now");
524    assert_snapshot!(name("cast(col_name as text)"), @"col_name");
525    assert_snapshot!(name("col_name::text"), @"col_name");
526    assert_snapshot!(name("col_name::int::text"), @"col_name");
527    assert_snapshot!(name("'1'::bigint"), @"int8");
528    assert_snapshot!(name("'1'::int"), @"int4");
529    assert_snapshot!(name("'1'::smallint"), @"int2");
530    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
531    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
532    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
533    assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
534    assert_snapshot!(name("'{1}'::integer[];"), @"int4");
535    assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
536    assert_snapshot!(name("'1'::bigint::smallint"), @"int2");
537
538    // alias
539    // with quoting
540    assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
541    assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
542    // without quoting
543    assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
544    assert_snapshot!(name(r#"'foo' as foo"#), @"foo");
545
546    // tuple
547    assert_snapshot!(name("(1, 2, 3)"), @"row");
548    assert_snapshot!(name("(1, 2, 3)::address"), @"row");
549
550    // composite type
551    assert_snapshot!(name("(x).city"), @"city");
552
553    // array types
554    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
555    assert_snapshot!(name("cast('{foo}' as text[])"), @"text");
556
557    // bit types
558    assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");
559
560    // char types
561    assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
562    assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
563    assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");
564    assert_snapshot!(name("cast('hello' as character)"), @"bpchar");
565    assert_snapshot!(name("cast('hello' as bpchar)"), @"bpchar");
566
567    assert_snapshot!(name(r#"cast('hello' as "char")"#), @"char");
568
569    // double types
570    assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");
571    // real
572    assert_snapshot!(name("cast(1.5 as real)"), @"float4");
573
574    // interval types
575    assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");
576
577    // percent types
578    assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");
579
580    // time types
581    assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
582    assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
583    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
584    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");
585
586    #[track_caller]
587    fn name(sql: &str) -> String {
588        let sql = "select ".to_string() + sql;
589        let parse = squawk_syntax::SourceFile::parse(&sql);
590        assert_eq!(parse.errors(), vec![]);
591        let file = parse.tree();
592
593        let stmt = file.stmts().next().unwrap();
594        let ast::Stmt::Select(select) = stmt else {
595            unreachable!()
596        };
597
598        let target = select
599            .select_clause()
600            .and_then(|sc| sc.target_list())
601            .and_then(|tl| tl.targets().next())
602            .unwrap();
603
604        ColumnName::from_target(target)
605            .and_then(|x| x.0.to_string())
606            .unwrap()
607    }
608}