use sqlparser::keywords::Keyword;
use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer};
use sqlparser::dialect::GenericDialect;
use super::extraction::is_separator;
/// The category of completion that fits the cursor position, as inferred by
/// `detect_sql_context` from the SQL text that precedes the token being typed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompletionContext {
    /// A column (or expression operand) is expected, e.g. after `SELECT`,
    /// `WHERE`, `ORDER BY`, a comparison operator, or an opening parenthesis.
    Column,
    /// A keyword is expected: an identifier has just been completed and
    /// followed by whitespace (e.g. `SELECT a `).
    Keyword,
    /// A column of the given qualifier is expected because the text ends with
    /// `<qualifier>.`; the qualifier may itself be dotted (e.g. `schema.t`).
    QualifiedColumn(String),
    /// A table name is expected, e.g. after `FROM`, `JOIN`, or `INTO`.
    Table,
    /// No completion category applies.
    None,
}
/// Infers which kind of completion fits after `before_token`, the SQL text that
/// precedes the token currently being typed.
///
/// Text ending in `name.` yields [`CompletionContext::QualifiedColumn`] with that
/// name. Otherwise the text is tokenized with the generic dialect and the last
/// non-whitespace token decides the context. The function also records whether
/// `before_token` ended in whitespace. This lets a finished identifier be told
/// apart from one that is still being typed.
pub(super) fn detect_sql_context(before_token: &str) -> CompletionContext {
    let text = before_token.trim_end();
    if text.is_empty() {
        return CompletionContext::None;
    }

    // `alias.` or `schema.table.`: complete columns of that qualifier. Dots are
    // excluded from the split so that multi-part qualifiers stay intact.
    let qualifier = text
        .strip_suffix('.')
        .and_then(|head| head.rsplit(|c: char| is_separator(c) && c != '.').next())
        .filter(|name| !name.is_empty());
    if let Some(name) = qualifier {
        return CompletionContext::QualifiedColumn(name.to_owned());
    }

    let dialect = GenericDialect {};
    let mut spanned: Vec<TokenWithSpan> = Vec::new();
    // The tokenizer result is intentionally ignored: whatever tokens ended up in
    // the buffer are classified as-is (best-effort on incomplete SQL).
    let _ = Tokenizer::new(&dialect, text).tokenize_with_location_into_buf(&mut spanned);

    // Whitespace tokens carry no meaning for classification.
    let significant: Vec<&Token> = spanned
        .iter()
        .filter_map(|item| (!matches!(item.token, Token::Whitespace(_))).then_some(&item.token))
        .collect();

    let Some(&last) = significant.last() else {
        return CompletionContext::None;
    };
    let ended_with_whitespace = text.len() != before_token.len();
    classify_last_token(last, &significant, ended_with_whitespace)
}
/// Maps the last significant (non-whitespace) token to a completion context.
///
/// * `last` — the final significant token of the input.
/// * `tokens` — all significant tokens, ending with `last`. They are consulted
///   when `last` alone is ambiguous (`BY`, `,`, bare identifiers).
/// * `has_trailing_separator` — whether whitespace followed `last` in the input.
fn classify_last_token(
    last: &Token,
    tokens: &[&Token],
    has_trailing_separator: bool,
) -> CompletionContext {
    match last {
        Token::Word(word) => match word.keyword {
            // Not a keyword: the user has just written a name.
            Keyword::NoKeyword => classify_identifier_context(tokens, has_trailing_separator),
            Keyword::SELECT | Keyword::DISTINCT => CompletionContext::Column,
            Keyword::FROM | Keyword::JOIN | Keyword::INTO => CompletionContext::Table,
            Keyword::WHERE
            | Keyword::AND
            | Keyword::OR
            | Keyword::NOT
            | Keyword::ON
            | Keyword::HAVING
            | Keyword::BETWEEN
            | Keyword::CASE
            | Keyword::WHEN
            | Keyword::THEN
            | Keyword::ELSE
            | Keyword::IN
            | Keyword::LIKE
            | Keyword::IS
            | Keyword::SET => CompletionContext::Column,
            // Only `ORDER BY` / `GROUP BY` introduce a column list.
            Keyword::BY => {
                let introduces_columns = matches!(
                    tokens.iter().rev().nth(1),
                    Some(Token::Word(prev)) if matches!(prev.keyword, Keyword::ORDER | Keyword::GROUP)
                );
                if introduces_columns {
                    CompletionContext::Column
                } else {
                    CompletionContext::None
                }
            }
            Keyword::ASC | Keyword::DESC => CompletionContext::Column,
            // Alias names and numeric arguments: nothing to suggest.
            Keyword::AS | Keyword::LIMIT | Keyword::OFFSET => CompletionContext::None,
            _ => CompletionContext::None,
        },
        // Comparison operators (including `==` and `<=>`) expect an operand.
        Token::Eq
        | Token::DoubleEq
        | Token::Neq
        | Token::Lt
        | Token::Gt
        | Token::LtEq
        | Token::GtEq
        | Token::Spaceship => CompletionContext::Column,
        // Arithmetic and concatenation operators also expect an operand. `*` is
        // deliberately excluded: `SELECT *` is a complete projection.
        Token::Plus | Token::Minus | Token::Div | Token::Mod | Token::StringConcat => {
            CompletionContext::Column
        }
        Token::LParen => CompletionContext::Column,
        Token::Comma => find_clause_for_comma(tokens),
        _ => CompletionContext::None,
    }
}
/// Decides the context after a plain (non-keyword) identifier.
///
/// While the identifier is still being typed (`has_trailing_separator` is false),
/// nothing is suggested. Once whitespace follows it, a keyword is expected. The
/// exception is an identifier directly preceded by `AS` (an alias), after which
/// nothing is suggested.
fn classify_identifier_context(
    tokens: &[&Token],
    has_trailing_separator: bool,
) -> CompletionContext {
    if !has_trailing_separator {
        return CompletionContext::None;
    }

    let is_alias = match tokens {
        [.., Token::Word(prev), _] => matches!(prev.keyword, Keyword::AS),
        _ => false,
    };

    if is_alias {
        CompletionContext::None
    } else {
        CompletionContext::Keyword
    }
}
/// Determines the context after a trailing comma by walking backwards to the
/// clause keyword that owns the comma-separated list.
///
/// `tokens` are the significant tokens and end with the comma itself.
///
/// Parenthesized groups that close before the comma (subqueries, function calls)
/// are skipped. For example, the inner `FROM` in `SELECT a, (SELECT b FROM u),`
/// does not decide the outer clause. An unmatched `(` means the comma sits inside
/// an open group, such as `IN (1,`. That paren is stepped over and the search
/// continues outward.
fn find_clause_for_comma(tokens: &[&Token]) -> CompletionContext {
    // Count of `)` seen so far (walking backwards) whose `(` has not been reached.
    let mut closed_depth: usize = 0;

    for token in tokens.iter().rev().skip(1) {
        match token {
            Token::RParen => closed_depth += 1,
            Token::LParen => closed_depth = closed_depth.saturating_sub(1),
            Token::Word(word) if closed_depth == 0 => match word.keyword {
                Keyword::SELECT | Keyword::DISTINCT => return CompletionContext::Column,
                Keyword::FROM | Keyword::JOIN => return CompletionContext::Table,
                // `SET` lists assignments (`SET a = 1, b = 2`), i.e. columns.
                Keyword::WHERE | Keyword::HAVING | Keyword::ON | Keyword::SET => {
                    return CompletionContext::Column;
                }
                Keyword::BY | Keyword::ORDER | Keyword::GROUP => return CompletionContext::Column,
                // `LIMIT offset, count`: numbers follow, not an earlier clause's items.
                Keyword::LIMIT | Keyword::OFFSET => return CompletionContext::None,
                _ => {}
            },
            _ => {}
        }
    }
    CompletionContext::None
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case::empty("", CompletionContext::None)]
    #[case::whitespace_only(" ", CompletionContext::None)]
    #[case::qualified_column("SELECT t.", CompletionContext::QualifiedColumn("t".to_string()))]
    #[case::qualified_column_with_schema(
        "SELECT schema.t.",
        CompletionContext::QualifiedColumn("schema.t".to_string())
    )]
    #[case::select("SELECT", CompletionContext::Column)]
    #[case::select_distinct("SELECT DISTINCT", CompletionContext::Column)]
    #[case::where_clause("SELECT * FROM t WHERE", CompletionContext::Column)]
    #[case::and("SELECT * FROM t WHERE a = 1 AND", CompletionContext::Column)]
    #[case::or("SELECT * FROM t WHERE a = 1 OR", CompletionContext::Column)]
    #[case::having("SELECT a FROM t GROUP BY a HAVING", CompletionContext::Column)]
    #[case::between("SELECT * FROM t WHERE a BETWEEN", CompletionContext::Column)]
    #[case::case_keyword("SELECT CASE", CompletionContext::Column)]
    #[case::when_keyword("SELECT CASE WHEN", CompletionContext::Column)]
    #[case::identifier_after_when("SELECT CASE WHEN a", CompletionContext::None)]
    #[case::then_keyword("SELECT CASE WHEN a THEN", CompletionContext::Column)]
    #[case::else_keyword("SELECT CASE WHEN a THEN b ELSE", CompletionContext::Column)]
    #[case::in_keyword("SELECT * FROM t WHERE a IN", CompletionContext::Column)]
    #[case::like_keyword("SELECT * FROM t WHERE a LIKE", CompletionContext::Column)]
    #[case::is_keyword("SELECT * FROM t WHERE a IS", CompletionContext::Column)]
    #[case::not_keyword("SELECT * FROM t WHERE NOT", CompletionContext::Column)]
    #[case::set_keyword("UPDATE t SET", CompletionContext::Column)]
    #[case::on_keyword("SELECT * FROM a JOIN b ON", CompletionContext::Column)]
    #[case::from("SELECT * FROM", CompletionContext::Table)]
    #[case::join("SELECT * FROM a JOIN", CompletionContext::Table)]
    #[case::into("INSERT INTO", CompletionContext::Table)]
    #[case::order_by("SELECT * FROM t ORDER BY", CompletionContext::Column)]
    #[case::group_by("SELECT a FROM t GROUP BY", CompletionContext::Column)]
    #[case::bare_by("BY", CompletionContext::None)]
    #[case::asc("SELECT * FROM t ORDER BY a ASC,", CompletionContext::Column)]
    #[case::desc("SELECT * FROM t ORDER BY a DESC", CompletionContext::Column)]
    #[case::as_keyword("SELECT a AS", CompletionContext::None)]
    #[case::limit("SELECT * FROM t LIMIT", CompletionContext::None)]
    #[case::offset("SELECT * FROM t LIMIT 10 OFFSET", CompletionContext::None)]
    #[case::eq("SELECT * FROM t WHERE a =", CompletionContext::Column)]
    #[case::neq("SELECT * FROM t WHERE a !=", CompletionContext::Column)]
    #[case::neq_angle_brackets("SELECT * FROM t WHERE a <>", CompletionContext::Column)]
    #[case::lt("SELECT * FROM t WHERE a <", CompletionContext::Column)]
    #[case::gt("SELECT * FROM t WHERE a >", CompletionContext::Column)]
    #[case::lte("SELECT * FROM t WHERE a <=", CompletionContext::Column)]
    #[case::gte("SELECT * FROM t WHERE a >=", CompletionContext::Column)]
    #[case::lparen("SELECT COUNT(", CompletionContext::Column)]
    #[case::comma_in_select("SELECT a,", CompletionContext::Column)]
    #[case::comma_in_from("SELECT * FROM a,", CompletionContext::Table)]
    #[case::comma_in_where("SELECT * FROM t WHERE a = 1 AND b IN (1,", CompletionContext::Column)]
    #[case::comma_in_order_by("SELECT * FROM t ORDER BY a,", CompletionContext::Column)]
    #[case::plain_identifier("SELECT a", CompletionContext::None)]
    #[case::completed_select_item("SELECT a ", CompletionContext::Keyword)]
    #[case::lone_completed_identifier("foo ", CompletionContext::Keyword)]
    #[case::table_identifier("SELECT * FROM t", CompletionContext::None)]
    #[case::completed_table_identifier("SELECT * FROM t ", CompletionContext::Keyword)]
    #[case::alias_after_as("SELECT a AS alias", CompletionContext::None)]
    #[case::completed_alias_after_as("SELECT a AS a_alias ", CompletionContext::None)]
    #[case::trailing_comment("SELECT /* note */", CompletionContext::Column)]
    fn test_detect_sql_context(#[case] before_token: &str, #[case] expected: CompletionContext) {
        assert_eq!(detect_sql_context(before_token), expected);
    }
}