use crate::error::Result;
use sqlparser::{
ast::Ident,
dialect::GenericDialect,
parser::{Parser, ParserError},
tokenizer::{Token, TokenWithLocation},
};
use std::borrow::Cow;
/// A table reference in which the catalog, schema and table name are all
/// present.
///
/// [`TableReference::resolve`] produces values of this type by filling in
/// any missing components from caller-supplied defaults. Each component is
/// a [`Cow`], so it may either borrow an existing string or own its data.
#[derive(Debug, Clone)]
pub struct ResolvedTableReference<'a> {
    /// Catalog name.
    pub catalog: Cow<'a, str>,
    /// Schema name.
    pub schema: Cow<'a, str>,
    /// Table name.
    pub table: Cow<'a, str>,
}
impl<'a> std::fmt::Display for ResolvedTableReference<'a> {
    /// Writes the reference as `catalog.schema.table`.
    ///
    /// Components are written verbatim: no quoting or escaping is applied,
    /// and formatter flags such as width are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            catalog,
            schema,
            table,
        } = self;
        write!(f, "{catalog}.{schema}.{table}")
    }
}
/// A possibly-qualified reference to a table, made of one to three name
/// components.
///
/// Components are stored as [`Cow`], so a reference can either borrow its
/// strings (e.g. from an [`OwnedTableReference`] or from the input text) or
/// own normalized copies (as produced by [`TableReference::parse_str`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TableReference<'a> {
    /// An unqualified table name, e.g. `table`.
    Bare {
        /// Table name.
        table: Cow<'a, str>,
    },
    /// A table name qualified by a schema, e.g. `schema.table`.
    Partial {
        /// Schema name.
        schema: Cow<'a, str>,
        /// Table name.
        table: Cow<'a, str>,
    },
    /// A fully qualified table name, e.g. `catalog.schema.table`.
    Full {
        /// Catalog name.
        catalog: Cow<'a, str>,
        /// Schema name.
        schema: Cow<'a, str>,
        /// Table name.
        table: Cow<'a, str>,
    },
}
/// An owned counterpart of [`TableReference`].
///
/// Each name component is stored as a [`String`], so values carry no
/// borrowed lifetime. Unlike [`TableReference`], this type also derives
/// [`Hash`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OwnedTableReference {
    /// An unqualified table name, e.g. `table`.
    Bare {
        /// Table name.
        table: String,
    },
    /// A table name qualified by a schema, e.g. `schema.table`.
    Partial {
        /// Schema name.
        schema: String,
        /// Table name.
        table: String,
    },
    /// A fully qualified table name, e.g. `catalog.schema.table`.
    Full {
        /// Catalog name.
        catalog: String,
        /// Schema name.
        schema: String,
        /// Table name.
        table: String,
    },
}
impl OwnedTableReference {
pub fn as_table_reference(&self) -> TableReference<'_> {
match self {
Self::Bare { table } => TableReference::Bare {
table: table.into(),
},
Self::Partial { schema, table } => TableReference::Partial {
schema: schema.into(),
table: table.into(),
},
Self::Full {
catalog,
schema,
table,
} => TableReference::Full {
catalog: catalog.into(),
schema: schema.into(),
table: table.into(),
},
}
}
pub fn table(&self) -> &str {
match self {
Self::Full { table, .. }
| Self::Partial { table, .. }
| Self::Bare { table } => table,
}
}
}
impl std::fmt::Display for OwnedTableReference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OwnedTableReference::Bare { table } => write!(f, "{table}"),
OwnedTableReference::Partial { schema, table } => {
write!(f, "{schema}.{table}")
}
OwnedTableReference::Full {
catalog,
schema,
table,
} => write!(f, "{catalog}.{schema}.{table}"),
}
}
}
impl<'a> From<&'a OwnedTableReference> for TableReference<'a> {
fn from(r: &'a OwnedTableReference) -> Self {
r.as_table_reference()
}
}
impl<'a> TableReference<'a> {
pub fn table(&self) -> &str {
match self {
Self::Full { table, .. }
| Self::Partial { table, .. }
| Self::Bare { table } => table,
}
}
pub fn resolve(
self,
default_catalog: &'a str,
default_schema: &'a str,
) -> ResolvedTableReference<'a> {
match self {
Self::Full {
catalog,
schema,
table,
} => ResolvedTableReference {
catalog,
schema,
table,
},
Self::Partial { schema, table } => ResolvedTableReference {
catalog: default_catalog.into(),
schema,
table,
},
Self::Bare { table } => ResolvedTableReference {
catalog: default_catalog.into(),
schema: default_schema.into(),
table,
},
}
}
pub fn parse_str(s: &'a str) -> Self {
let mut parts = parse_identifiers(s)
.unwrap_or_default()
.into_iter()
.map(|id| match id.quote_style {
Some(_) => id.value,
None => id.value.to_ascii_lowercase(),
})
.collect::<Vec<_>>();
match parts.len() {
1 => Self::Bare {
table: parts.remove(0).into(),
},
2 => Self::Partial {
schema: parts.remove(0).into(),
table: parts.remove(0).into(),
},
3 => Self::Full {
catalog: parts.remove(0).into(),
schema: parts.remove(0).into(),
table: parts.remove(0).into(),
},
_ => Self::Bare { table: s.into() },
}
}
}
/// Splits `s` into its period-separated SQL identifiers, using the
/// [`GenericDialect`] tokenizer.
///
/// Every component must be a single word token, either quoted or unquoted.
/// Tokens are read with `next_token_no_skip`, so whitespace is not skipped
/// and is rejected like any other unexpected token.
///
/// # Errors
///
/// Returns an error (converted from [`ParserError`]) if `s`:
/// - cannot be tokenized;
/// - is empty;
/// - starts with a non-word token;
/// - contains a token other than `.` between identifiers;
/// - has a non-word token after a `.`; or
/// - ends with a trailing `.`.
fn parse_identifiers(s: &str) -> Result<Vec<Ident>> {
    let dialect = GenericDialect;
    let mut parser = Parser::new(&dialect).try_with_sql(s)?;
    let mut idents = vec![];

    // The input must begin with an identifier.
    match parser.next_token_no_skip() {
        Some(TokenWithLocation {
            token: Token::Word(w),
            ..
        }) => idents.push(w.to_ident()),
        Some(TokenWithLocation { token, .. }) => {
            return Err(ParserError::ParserError(format!(
                "Unexpected token in identifier: {token}"
            ))
            .into())
        }
        None => {
            return Err(ParserError::ParserError(
                "Empty input when parsing identifier".to_string(),
            )
            .into())
        }
    };

    // Every further identifier must be introduced by exactly one period.
    while let Some(TokenWithLocation { token, .. }) = parser.next_token_no_skip() {
        match token {
            Token::Period => match parser.next_token_no_skip() {
                Some(TokenWithLocation {
                    token: Token::Word(w),
                    ..
                }) => idents.push(w.to_ident()),
                Some(TokenWithLocation { token, .. }) => {
                    return Err(ParserError::ParserError(format!(
                        "Unexpected token following period in identifier: {token}"
                    ))
                    .into())
                }
                None => {
                    return Err(ParserError::ParserError(
                        "Trailing period in identifier".to_string(),
                    )
                    .into())
                }
            },
            _ => {
                return Err(ParserError::ParserError(format!(
                    "Unexpected token in identifier: {token}"
                ))
                .into())
            }
        }
    }
    Ok(idents)
}
impl<'a> From<&'a str> for TableReference<'a> {
fn from(s: &'a str) -> Self {
Self::parse_str(s)
}
}
impl<'a> From<ResolvedTableReference<'a>> for TableReference<'a> {
fn from(resolved: ResolvedTableReference<'a>) -> Self {
Self::Full {
catalog: resolved.catalog,
schema: resolved.schema,
table: resolved.table,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `parse_identifiers(input)` fails and returns the
    /// `Debug` rendering of the resulting error.
    fn parse_identifiers_err(input: &str) -> String {
        let err = parse_identifiers(input).expect_err("didn't fail to parse");
        format!("{err:?}")
    }

    #[test]
    fn test_parse_identifiers() -> Result<()> {
        let s = "CATALOG.\"F(o)o. \"\"bar\".table";
        let actual = parse_identifiers(s)?;
        let expected = vec![
            Ident {
                value: "CATALOG".to_string(),
                quote_style: None,
            },
            Ident {
                value: "F(o)o. \"bar".to_string(),
                quote_style: Some('"'),
            },
            Ident {
                value: "table".to_string(),
                quote_style: None,
            },
        ];
        assert_eq!(expected, actual);
        Ok(())
    }

    #[test]
    fn test_parse_identifiers_errors() {
        let cases = [
            (
                "",
                "SQL(ParserError(\"Empty input when parsing identifier\"))",
            ),
            (
                "*schema.table",
                "SQL(ParserError(\"Unexpected token in identifier: *\"))",
            ),
            (
                "schema.table*",
                "SQL(ParserError(\"Unexpected token in identifier: *\"))",
            ),
            (
                "schema.table.",
                "SQL(ParserError(\"Trailing period in identifier\"))",
            ),
            (
                "schema.*",
                "SQL(ParserError(\"Unexpected token following period in identifier: *\"))",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(expected, parse_identifiers_err(input), "input: {input:?}");
        }
    }

    #[test]
    fn test_table_reference_from_str_normalizes() {
        let expected = TableReference::Full {
            catalog: Cow::Owned("catalog".to_string()),
            schema: Cow::Owned("FOO\".bar".to_string()),
            table: Cow::Owned("table".to_string()),
        };
        let actual = TableReference::from("catalog.\"FOO\"\".bar\".TABLE");
        assert_eq!(expected, actual);
        let expected = TableReference::Partial {
            schema: Cow::Owned("FOO\".bar".to_string()),
            table: Cow::Owned("table".to_string()),
        };
        let actual = TableReference::from("\"FOO\"\".bar\".TABLE");
        assert_eq!(expected, actual);
        let expected = TableReference::Bare {
            table: Cow::Owned("table".to_string()),
        };
        let actual = TableReference::from("TABLE");
        assert_eq!(expected, actual);
        let expected = TableReference::Bare {
            table: Cow::Owned("TABLE()".to_string()),
        };
        let actual = TableReference::from("TABLE()");
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_table_reference_from_str_too_many_parts_is_bare() {
        // More than three components cannot form a table reference, so the
        // input is kept verbatim (without case folding) as a bare name.
        let expected = TableReference::Bare {
            table: Cow::Borrowed("A.b.c.d"),
        };
        assert_eq!(expected, TableReference::from("A.b.c.d"));
    }

    #[test]
    fn test_table_reference_resolve_uses_defaults() {
        let resolved = TableReference::from("TABLE").resolve("cat", "sch");
        assert_eq!("cat.sch.table", resolved.to_string());
        let resolved = TableReference::from("s.t").resolve("cat", "sch");
        assert_eq!("cat.s.t", resolved.to_string());
        let resolved = TableReference::from("c.s.t").resolve("cat", "sch");
        assert_eq!("c.s.t", resolved.to_string());
    }

    #[test]
    fn test_owned_table_reference_round_trip() {
        let owned = OwnedTableReference::Partial {
            schema: "s".to_string(),
            table: "t".to_string(),
        };
        assert_eq!("s.t", owned.to_string());
        assert_eq!("t", owned.table());
        assert_eq!(TableReference::from("s.t"), TableReference::from(&owned));
    }
}