use std::borrow::Cow;
#[cfg(test)]
use insta::assert_snapshot;
use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
#[cfg(test)]
use crate::SourceFile;
use rowan::Direction;
use crate::ast;
use crate::ast::AstNode;
use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
use super::support;
/// Lexical category of a literal, together with the token that spells it.
///
/// Produced by `ast::Literal::kind`, which classifies the literal node's
/// first token by its `SyntaxKind`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LitKind {
/// Wraps a `SyntaxKind::BIT_STRING` token.
BitString(SyntaxToken),
/// Wraps a `SyntaxKind::BYTE_STRING` token.
ByteString(SyntaxToken),
/// Wraps a `SyntaxKind::DOLLAR_QUOTED_STRING` token.
DollarQuotedString(SyntaxToken),
/// Wraps a `SyntaxKind::ESC_STRING` token.
EscString(SyntaxToken),
/// Wraps a `SyntaxKind::FLOAT_NUMBER` token.
FloatNumber(SyntaxToken),
/// Wraps a `SyntaxKind::INT_NUMBER` token.
IntNumber(SyntaxToken),
/// Wraps a `SyntaxKind::NULL_KW` token.
Null(SyntaxToken),
/// Wraps a `SyntaxKind::POSITIONAL_PARAM` token.
PositionalParam(SyntaxToken),
/// Wraps a `SyntaxKind::STRING` token.
String(SyntaxToken),
/// Wraps a `SyntaxKind::UNICODE_ESC_STRING` token.
UnicodeEscString(SyntaxToken),
}
impl ast::Literal {
pub fn kind(&self) -> Option<LitKind> {
let token = self.syntax().first_child_or_token()?.into_token()?;
let kind = match token.kind() {
SyntaxKind::STRING => LitKind::String(token),
SyntaxKind::NULL_KW => LitKind::Null(token),
SyntaxKind::FLOAT_NUMBER => LitKind::FloatNumber(token),
SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
SyntaxKind::BIT_STRING => LitKind::BitString(token),
SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
SyntaxKind::ESC_STRING => LitKind::EscString(token),
SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
_ => return None,
};
Some(kind)
}
}
impl ast::Constraint {
    /// The `ConstraintName` child of this constraint, looked up through
    /// `support::child`; `None` when the constraint is unnamed.
    #[inline]
    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
        let parent = self.syntax();
        support::child(parent)
    }
}
/// The operator of a binary expression, as returned by `ast::BinExpr::op`.
///
/// Variants carrying a `SyntaxToken` are recognised from a single operator
/// token (e.g. `PLUS`, `AND_KW`). Variants carrying an `ast::…` type are
/// recognised from a dedicated operator node (e.g. `IS_DISTINCT_FROM`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinOp {
And(SyntaxToken),
AtTimeZone(ast::AtTimeZone),
Caret(SyntaxToken),
Collate(SyntaxToken),
ColonColon(ast::ColonColon),
ColonEq(ast::ColonEq),
CustomOp(ast::CustomOp),
Eq(SyntaxToken),
FatArrow(ast::FatArrow),
Gteq(ast::Gteq),
Ilike(SyntaxToken),
In(SyntaxToken),
Is(SyntaxToken),
IsDistinctFrom(ast::IsDistinctFrom),
IsNot(ast::IsNot),
IsNotDistinctFrom(ast::IsNotDistinctFrom),
/// Wraps an `L_ANGLE` token.
LAngle(SyntaxToken),
Like(SyntaxToken),
Lteq(ast::Lteq),
Minus(SyntaxToken),
Neq(ast::Neq),
Neqb(ast::Neqb),
NotIlike(ast::NotIlike),
NotIn(ast::NotIn),
NotLike(ast::NotLike),
NotSimilarTo(ast::NotSimilarTo),
OperatorCall(ast::OperatorCall),
Or(SyntaxToken),
Overlaps(SyntaxToken),
Percent(SyntaxToken),
Plus(SyntaxToken),
/// Wraps an `R_ANGLE` token.
RAngle(SyntaxToken),
SimilarTo(ast::SimilarTo),
Slash(SyntaxToken),
Star(SyntaxToken),
}
/// The operator of a postfix expression, as returned by `ast::PostfixExpr::op`.
///
/// Token-carrying variants come from a single keyword token; the `ast::…`
/// variants come from a dedicated operator node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PostfixOp {
/// Recognised from an `AT_KW` token.
AtLocal(SyntaxToken),
IsJson(ast::IsJson),
IsJsonArray(ast::IsJsonArray),
IsJsonObject(ast::IsJsonObject),
IsJsonScalar(ast::IsJsonScalar),
IsJsonValue(ast::IsJsonValue),
IsNormalized(ast::IsNormalized),
IsNotJson(ast::IsNotJson),
IsNotJsonArray(ast::IsNotJsonArray),
IsNotJsonObject(ast::IsNotJsonObject),
IsNotJsonScalar(ast::IsNotJsonScalar),
IsNotJsonValue(ast::IsNotJsonValue),
IsNotNormalized(ast::IsNotNormalized),
/// Recognised from an `ISNULL_KW` token.
IsNull(SyntaxToken),
/// Recognised from a `NOTNULL_KW` token.
NotNull(SyntaxToken),
}
impl ast::BinExpr {
#[inline]
pub fn lhs(&self) -> Option<ast::Expr> {
support::children(self.syntax()).next()
}
#[inline]
pub fn rhs(&self) -> Option<ast::Expr> {
support::children(self.syntax()).nth(1)
}
pub fn op(&self) -> Option<BinOp> {
let lhs = self.lhs()?;
for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
match child {
NodeOrToken::Token(token) => {
let op = match token.kind() {
SyntaxKind::AND_KW => BinOp::And(token),
SyntaxKind::CARET => BinOp::Caret(token),
SyntaxKind::COLLATE_KW => BinOp::Collate(token),
SyntaxKind::EQ => BinOp::Eq(token),
SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
SyntaxKind::IN_KW => BinOp::In(token),
SyntaxKind::IS_KW => BinOp::Is(token),
SyntaxKind::L_ANGLE => BinOp::LAngle(token),
SyntaxKind::LIKE_KW => BinOp::Like(token),
SyntaxKind::MINUS => BinOp::Minus(token),
SyntaxKind::OR_KW => BinOp::Or(token),
SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
SyntaxKind::PERCENT => BinOp::Percent(token),
SyntaxKind::PLUS => BinOp::Plus(token),
SyntaxKind::R_ANGLE => BinOp::RAngle(token),
SyntaxKind::SLASH => BinOp::Slash(token),
SyntaxKind::STAR => BinOp::Star(token),
_ => continue,
};
return Some(op);
}
NodeOrToken::Node(node) => {
let op = match node.kind() {
SyntaxKind::AT_TIME_ZONE => {
BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
}
SyntaxKind::COLON_COLON => {
BinOp::ColonColon(ast::ColonColon { syntax: node })
}
SyntaxKind::COLON_EQ => BinOp::ColonEq(ast::ColonEq { syntax: node }),
SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
SyntaxKind::FAT_ARROW => BinOp::FatArrow(ast::FatArrow { syntax: node }),
SyntaxKind::GTEQ => BinOp::Gteq(ast::Gteq { syntax: node }),
SyntaxKind::IS_DISTINCT_FROM => {
BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
}
SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
SyntaxKind::IS_NOT_DISTINCT_FROM => {
BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
}
SyntaxKind::LTEQ => BinOp::Lteq(ast::Lteq { syntax: node }),
SyntaxKind::NEQ => BinOp::Neq(ast::Neq { syntax: node }),
SyntaxKind::NEQB => BinOp::Neqb(ast::Neqb { syntax: node }),
SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
SyntaxKind::NOT_SIMILAR_TO => {
BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
}
SyntaxKind::OPERATOR_CALL => {
BinOp::OperatorCall(ast::OperatorCall { syntax: node })
}
SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
_ => continue,
};
return Some(op);
}
}
}
None
}
}
impl ast::PostfixExpr {
pub fn op(&self) -> Option<PostfixOp> {
let lhs = self.expr()?;
let siblings = lhs.syntax().siblings_with_tokens(Direction::Next).skip(1);
for child in siblings {
match child {
NodeOrToken::Token(token) => {
let op = match token.kind() {
SyntaxKind::AT_KW => PostfixOp::AtLocal(token),
SyntaxKind::ISNULL_KW => PostfixOp::IsNull(token),
SyntaxKind::NOTNULL_KW => PostfixOp::NotNull(token),
_ => continue,
};
return Some(op);
}
NodeOrToken::Node(node) => {
let op = match node.kind() {
SyntaxKind::IS_JSON => PostfixOp::IsJson(ast::IsJson { syntax: node }),
SyntaxKind::IS_JSON_ARRAY => {
PostfixOp::IsJsonArray(ast::IsJsonArray { syntax: node })
}
SyntaxKind::IS_JSON_OBJECT => {
PostfixOp::IsJsonObject(ast::IsJsonObject { syntax: node })
}
SyntaxKind::IS_JSON_SCALAR => {
PostfixOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
}
SyntaxKind::IS_JSON_VALUE => {
PostfixOp::IsJsonValue(ast::IsJsonValue { syntax: node })
}
SyntaxKind::IS_NORMALIZED => {
PostfixOp::IsNormalized(ast::IsNormalized { syntax: node })
}
SyntaxKind::IS_NOT_JSON => {
PostfixOp::IsNotJson(ast::IsNotJson { syntax: node })
}
SyntaxKind::IS_NOT_JSON_ARRAY => {
PostfixOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
}
SyntaxKind::IS_NOT_JSON_OBJECT => {
PostfixOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
}
SyntaxKind::IS_NOT_JSON_SCALAR => {
PostfixOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
}
SyntaxKind::IS_NOT_JSON_VALUE => {
PostfixOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
}
SyntaxKind::IS_NOT_NORMALIZED => {
PostfixOp::IsNotNormalized(ast::IsNotNormalized { syntax: node })
}
_ => continue,
};
return Some(op);
}
}
}
None
}
}
impl ast::FieldExpr {
    /// The expression whose field is accessed, e.g. `foo` in `foo.bar`.
    #[inline]
    pub fn base(&self) -> Option<ast::Expr> {
        let node = self.syntax();
        support::children(node).next()
    }

    /// The accessed field, e.g. `bar` in `foo.bar`.
    ///
    /// Takes the *last* `ast::NameRef` child, so a name-ref base is not
    /// mistaken for the field.
    #[inline]
    pub fn field(&self) -> Option<ast::NameRef> {
        let node = self.syntax();
        support::children(node).last()
    }
}
impl ast::IndexExpr {
    /// The subscripted expression, e.g. `foo` in `foo[bar]`.
    #[inline]
    pub fn base(&self) -> Option<ast::Expr> {
        support::children(self.syntax()).next()
    }

    /// The subscript, e.g. `bar` in `foo[bar]` (the second `ast::Expr` child).
    #[inline]
    pub fn index(&self) -> Option<ast::Expr> {
        support::children(self.syntax()).nth(1)
    }
}
impl ast::SliceExpr {
    /// The sliced expression, e.g. `x` in `x[1:2]`.
    #[inline]
    pub fn base(&self) -> Option<ast::Expr> {
        support::children(self.syntax()).next()
    }

    /// Lower bound, e.g. `1` in `x[1:2]`.
    ///
    /// Skips the base, then returns the first expression that ends at or
    /// before the `:` token. `None` if the bound is omitted (`x[:3]`) or
    /// there is no `:` token.
    #[inline]
    pub fn start(&self) -> Option<ast::Expr> {
        let colon_start = self.colon_token()?.text_range().start();
        support::children(self.syntax())
            .skip(1)
            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon_start)
    }

    /// Upper bound, e.g. `2` in `x[1:2]`.
    ///
    /// Returns the first expression that starts at or after the end of the
    /// `:` token. `None` if the bound is omitted (`x[2:]`) or there is no
    /// `:` token.
    #[inline]
    pub fn end(&self) -> Option<ast::Expr> {
        let colon_end = self.colon_token()?.text_range().end();
        support::children(self.syntax())
            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon_end)
    }
}
impl ast::RenameColumn {
    /// The first `ast::NameRef` child — presumably the column being renamed.
    // `.next()` instead of `.nth(0)` (clippy::iter_nth_zero).
    #[inline]
    pub fn from(&self) -> Option<ast::NameRef> {
        support::children(&self.syntax).next()
    }

    /// The second `ast::NameRef` child — presumably the new column name.
    #[inline]
    pub fn to(&self) -> Option<ast::NameRef> {
        support::children(&self.syntax).nth(1)
    }
}
impl ast::ForeignKeyConstraint {
    /// The first `ast::ColumnList` child — presumably the referencing columns.
    // `.next()` instead of `.nth(0)` (clippy::iter_nth_zero).
    #[inline]
    pub fn from_columns(&self) -> Option<ast::ColumnList> {
        support::children(&self.syntax).next()
    }

    /// The second `ast::ColumnList` child — presumably the referenced columns.
    #[inline]
    pub fn to_columns(&self) -> Option<ast::ColumnList> {
        support::children(&self.syntax).nth(1)
    }
}
impl ast::BetweenExpr {
    /// The tested value, e.g. `2` in `2 between 1 and 3`.
    // `.next()` instead of `.nth(0)` (clippy::iter_nth_zero).
    #[inline]
    pub fn target(&self) -> Option<ast::Expr> {
        support::children(&self.syntax).next()
    }

    /// The lower bound, e.g. `1` in `2 between 1 and 3`.
    #[inline]
    pub fn start(&self) -> Option<ast::Expr> {
        support::children(&self.syntax).nth(1)
    }

    /// The upper bound, e.g. `3` in `2 between 1 and 3`.
    #[inline]
    pub fn end(&self) -> Option<ast::Expr> {
        support::children(&self.syntax).nth(2)
    }
}
impl ast::WhenClause {
    /// The first `ast::Expr` child: the `WHEN` condition.
    #[inline]
    pub fn condition(&self) -> Option<ast::Expr> {
        support::children(self.syntax()).next()
    }

    /// The second `ast::Expr` child: the `THEN` result.
    #[inline]
    pub fn then(&self) -> Option<ast::Expr> {
        support::children(self.syntax()).nth(1)
    }
}
impl ast::CompoundSelect {
    /// The first `ast::SelectVariant` child (left side of the set operation).
    #[inline]
    pub fn lhs(&self) -> Option<ast::SelectVariant> {
        support::children(self.syntax()).next()
    }

    /// The second `ast::SelectVariant` child (right side of the set operation).
    #[inline]
    pub fn rhs(&self) -> Option<ast::SelectVariant> {
        support::children(self.syntax()).nth(1)
    }
}
impl ast::NameRef {
    /// Text of the first token of this name reference.
    ///
    /// # Panics
    ///
    /// Panics if the node's first child is not a token (see `text_of_first_token`).
    #[inline]
    pub fn text(&self) -> TokenText<'_> {
        let node = self.syntax();
        text_of_first_token(node)
    }
}
impl ast::Name {
    /// Text of the first token of this name.
    ///
    /// # Panics
    ///
    /// Panics if the node's first child is not a token (see `text_of_first_token`).
    #[inline]
    pub fn text(&self) -> TokenText<'_> {
        let node = self.syntax();
        text_of_first_token(node)
    }
}
impl ast::CharType {
    /// Text of the first token of this character type.
    ///
    /// # Panics
    ///
    /// Panics if the node's first child is not a token (see `text_of_first_token`).
    #[inline]
    pub fn text(&self) -> TokenText<'_> {
        let node = self.syntax();
        text_of_first_token(node)
    }
}
impl ast::OpSig {
    /// First `ast::Type` child, e.g. `int4` in `p.+ (int4, int8)`.
    #[inline]
    pub fn lhs(&self) -> Option<ast::Type> {
        let node = self.syntax();
        support::children(node).next()
    }

    /// Second `ast::Type` child, e.g. `int8` in `p.+ (int4, int8)`.
    #[inline]
    pub fn rhs(&self) -> Option<ast::Type> {
        let node = self.syntax();
        support::children(node).nth(1)
    }
}
impl ast::CastSig {
    /// First `ast::Type` child, e.g. `text` in `(text as int)`.
    #[inline]
    pub fn lhs(&self) -> Option<ast::Type> {
        let node = self.syntax();
        support::children(node).next()
    }

    /// Second `ast::Type` child, e.g. `int` in `(text as int)`.
    #[inline]
    pub fn rhs(&self) -> Option<ast::Type> {
        let node = self.syntax();
        support::children(node).nth(1)
    }
}
/// Returns the text of `node`'s first child, which must be a token.
///
/// If `node.green()` yields borrowed green data, the returned `TokenText`
/// borrows the token's text. If it yields an owned green node, the first
/// token is cloned into an owned `TokenText`.
///
/// # Panics
///
/// Panics if `node` has no children, or if its first child is a node rather
/// than a token.
pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
    fn leading_token(data: &GreenNodeData) -> &GreenTokenData {
        let first_child = data.children().next();
        first_child.and_then(|child| child.into_token()).unwrap()
    }

    match node.green() {
        Cow::Owned(owned) => TokenText::owned(leading_token(&owned).to_owned()),
        Cow::Borrowed(data) => TokenText::borrowed(leading_token(data).text()),
    }
}
impl ast::WithQuery {
    /// The `WithClause` child of this query, looked up through `support::child`;
    /// `None` when the query has no such clause.
    #[inline]
    pub fn with_clause(&self) -> Option<ast::WithClause> {
        let parent = self.syntax();
        support::child(parent)
    }
}
// Empty trait impls: the traits below have no required items here, so each
// line only opts the node type into the trait's provided methods (defined
// alongside the trait elsewhere).
impl ast::HasParamList for ast::FunctionSig {}
impl ast::HasParamList for ast::Aggregate {}
impl ast::NameLike for ast::Name {}
impl ast::NameLike for ast::NameRef {}
impl ast::HasWithClause for ast::Select {}
impl ast::HasWithClause for ast::SelectInto {}
impl ast::HasWithClause for ast::Insert {}
impl ast::HasWithClause for ast::Update {}
impl ast::HasWithClause for ast::Delete {}
impl ast::HasCreateTable for ast::CreateTable {}
impl ast::HasCreateTable for ast::CreateForeignTable {}
impl ast::HasCreateTable for ast::CreateTableLike {}
/// `IndexExpr::base` / `IndexExpr::index` split `foo[bar]` into its parts.
#[test]
fn index_expr() {
    let parse = SourceFile::parse("\nselect foo[bar];\n");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let Some(ast::Expr::IndexExpr(indexed)) = first_target.expr() else {
        unreachable!()
    };

    assert_eq!(indexed.base().unwrap().syntax().text(), "foo");
    assert_eq!(indexed.index().unwrap().syntax().text(), "bar");
}
/// `SliceExpr` accessors for every combination of present/omitted bounds.
// Relies on the file-level `#[cfg(test)] use insta::assert_snapshot;` rather
// than re-importing it locally.
#[test]
fn slice_expr() {
    let parse = SourceFile::parse("\nselect x[1:2], x[2:], x[:3], x[:];\n");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();
    // Pulls the next select target and insists it is a slice expression.
    let mut next_slice = || {
        let Some(ast::Expr::SliceExpr(slice)) = targets.next().unwrap().expr() else {
            unreachable!()
        };
        slice
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Upper bound omitted.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Lower bound omitted.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Both bounds omitted.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
/// `FieldExpr::base` / `FieldExpr::field` split `foo.bar` into its parts.
#[test]
fn field_expr() {
    let parse = SourceFile::parse("\nselect foo.bar;\n");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let Some(ast::Expr::FieldExpr(access)) = first_target.expr() else {
        unreachable!()
    };

    assert_eq!(access.base().unwrap().syntax().text(), "foo");
    assert_eq!(access.field().unwrap().syntax().text(), "bar");
}
/// `BetweenExpr` exposes the tested value and both bounds positionally.
#[test]
fn between_expr() {
    let parse = SourceFile::parse("\nselect 2 between 1 and 3;\n");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let Some(ast::Expr::BetweenExpr(between)) = first_target.expr() else {
        unreachable!()
    };

    assert_eq!(between.target().unwrap().syntax().text(), "2");
    assert_eq!(between.start().unwrap().syntax().text(), "1");
    assert_eq!(between.end().unwrap().syntax().text(), "3");
}
/// `CastExpr::expr` / `CastExpr::ty` across every cast spelling:
/// `cast(.. as ..)`, prefix type literals, and `::` casts.
// Relies on the file-level `#[cfg(test)] use insta::assert_snapshot;`.
#[test]
fn cast_expr() {
    // Parses `sql` and returns the first select target, which must be a cast.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };
        let target_list = select.select_clause().unwrap().target_list().unwrap();
        let first_target = target_list.targets().next().unwrap();
        // `expr()` already yields an owned node, so no clone is needed.
        let Some(ast::Expr::CastExpr(cast)) = first_target.expr() else {
            unreachable!()
        };
        cast
    }

    // The `unwrap()` calls below double as presence checks.
    let cast = extract_expr("select cast('123' as int)");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");
}
/// `OpSig::lhs` / `OpSig::rhs` return the two operand types of an operator
/// signature.
#[test]
fn op_sig() {
    let parse = SourceFile::parse("\nalter operator p.+ (int4, int8)\nowner to u;\n");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let Some(ast::Stmt::AlterOperator(alter_operator)) = file.stmts().next() else {
        unreachable!()
    };
    let signature = alter_operator.op_sig().unwrap();
    let left_type = signature.lhs().unwrap();
    let right_type = signature.rhs().unwrap();

    assert_snapshot!(left_type.syntax().text(), @"int4");
    assert_snapshot!(right_type.syntax().text(), @"int8");
}
/// `CastSig::lhs` / `CastSig::rhs` return the source and target types of a
/// cast signature.
#[test]
fn cast_sig() {
    let parse = SourceFile::parse("\ndrop cast (text as int);\n");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let Some(ast::Stmt::DropCast(drop_cast)) = file.stmts().next() else {
        unreachable!()
    };
    let signature = drop_cast.cast_sig().unwrap();
    let source_type = signature.lhs().unwrap();
    let target_type = signature.rhs().unwrap();

    assert_snapshot!(source_type.syntax().text(), @"text");
    assert_snapshot!(target_type.syntax().text(), @"int");
}