use crate::parser::{MongoToken, MongoTokenKind, SqlToken, SqlTokenKind};
/// A single lexer token from either of the two supported query dialects.
///
/// Wrapping both token types in one enum lets the cursor/completion logic in
/// this module work on one token sequence type regardless of which lexer
/// produced it.
#[derive(Debug, Clone)]
pub enum UnifiedToken {
    /// A token produced by the SQL lexer.
    Sql(SqlToken),
    /// A token produced by the Mongo lexer.
    Mongo(MongoToken),
}
impl UnifiedToken {
    /// Returns `true` when the wrapped token is an identifier in its dialect.
    pub fn is_ident(&self) -> bool {
        match self {
            UnifiedToken::Sql(tok) => matches!(tok.kind, SqlTokenKind::Ident(_)),
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::Ident(_)),
        }
    }

    /// Returns an owned copy of the identifier text, or `None` when the token
    /// is not an identifier.
    pub fn ident_value(&self) -> Option<String> {
        match self {
            UnifiedToken::Sql(tok) => match &tok.kind {
                SqlTokenKind::Ident(name) => Some(name.clone()),
                _ => None,
            },
            UnifiedToken::Mongo(tok) => match &tok.kind {
                MongoTokenKind::Ident(name) => Some(name.clone()),
                _ => None,
            },
        }
    }

    /// Returns `true` when the token is a `Dot` in either dialect.
    pub fn is_dot(&self) -> bool {
        match self {
            UnifiedToken::Sql(tok) => matches!(tok.kind, SqlTokenKind::Dot),
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::Dot),
        }
    }

    /// Returns `true` for a Mongo `LParen` token.
    ///
    /// SQL tokens always report `false` here.
    pub fn is_open_paren(&self) -> bool {
        match self {
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::LParen),
            UnifiedToken::Sql(_) => false,
        }
    }

    /// Returns `true` for a Mongo `RParen` token.
    ///
    /// SQL tokens always report `false` here.
    pub fn is_close_paren(&self) -> bool {
        match self {
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::RParen),
            UnifiedToken::Sql(_) => false,
        }
    }

    /// Returns `true` for the Mongo `Db` token; SQL tokens always report
    /// `false`.
    pub fn is_db(&self) -> bool {
        match self {
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::Db),
            UnifiedToken::Sql(_) => false,
        }
    }

    /// Tests whether this is the SQL keyword token named by `keyword`.
    ///
    /// `keyword` is upper-cased before comparison, so the lookup is
    /// case-insensitive. Names outside the table below, and all Mongo tokens,
    /// yield `false`.
    pub fn is_sql_keyword(&self, keyword: &str) -> bool {
        let kind = match self {
            UnifiedToken::Sql(tok) => &tok.kind,
            UnifiedToken::Mongo(_) => return false,
        };
        let wanted = keyword.to_uppercase();
        match wanted.as_str() {
            "SELECT" => matches!(kind, SqlTokenKind::Select),
            "INSERT" => matches!(kind, SqlTokenKind::Insert),
            "UPDATE" => matches!(kind, SqlTokenKind::Update),
            "DELETE" => matches!(kind, SqlTokenKind::Delete),
            "FROM" => matches!(kind, SqlTokenKind::From),
            "WHERE" => matches!(kind, SqlTokenKind::Where),
            "JOIN" => matches!(kind, SqlTokenKind::Join),
            "INNER" => matches!(kind, SqlTokenKind::Inner),
            "LEFT" => matches!(kind, SqlTokenKind::Left),
            "RIGHT" => matches!(kind, SqlTokenKind::Right),
            "LIMIT" => matches!(kind, SqlTokenKind::Limit),
            "OFFSET" => matches!(kind, SqlTokenKind::Offset),
            "ORDER" => matches!(kind, SqlTokenKind::Order),
            "BY" => matches!(kind, SqlTokenKind::By),
            "GROUP" => matches!(kind, SqlTokenKind::Group),
            _ => false,
        }
    }

    /// Returns `true` when the token is a `Semicolon` in either dialect.
    pub fn is_semicolon(&self) -> bool {
        match self {
            UnifiedToken::Sql(tok) => matches!(tok.kind, SqlTokenKind::Semicolon),
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::Semicolon),
        }
    }

    /// Returns `true` when the token is a `Number` in either dialect.
    pub fn is_number(&self) -> bool {
        match self {
            UnifiedToken::Sql(tok) => matches!(tok.kind, SqlTokenKind::Number(_)),
            UnifiedToken::Mongo(tok) => matches!(tok.kind, MongoTokenKind::Number(_)),
        }
    }

    /// Returns a copy of the token's source span (half-open `start..end`).
    pub fn span(&self) -> std::ops::Range<usize> {
        let source_span = match self {
            UnifiedToken::Sql(tok) => &tok.span,
            UnifiedToken::Mongo(tok) => &tok.span,
        };
        source_span.clone()
    }
}
/// A token sequence paired with a cursor position, used to answer
/// "what is being typed at the cursor" questions.
pub struct TokenStream {
    /// All tokens of the input, wrapped as [`UnifiedToken`]s.
    pub tokens: Vec<UnifiedToken>,
    /// Cursor position; compared directly against token span offsets.
    pub cursor: usize,
    /// Index into `tokens` of the token selected for the cursor when the
    /// stream was constructed.
    pub token_index: usize,
}
impl TokenStream {
    /// Builds a stream from SQL lexer output and locates the token at `cursor`.
    pub fn from_sql(sql_tokens: Vec<SqlToken>, cursor: usize) -> Self {
        Self::from_unified(sql_tokens.into_iter().map(UnifiedToken::Sql).collect(), cursor)
    }

    /// Builds a stream from Mongo lexer output and locates the token at
    /// `cursor`.
    pub fn from_mongo(mongo_tokens: Vec<MongoToken>, cursor: usize) -> Self {
        Self::from_unified(
            mongo_tokens.into_iter().map(UnifiedToken::Mongo).collect(),
            cursor,
        )
    }

    /// Shared constructor tail: resolves `token_index` for already-wrapped
    /// tokens.
    fn from_unified(tokens: Vec<UnifiedToken>, cursor: usize) -> Self {
        let token_index = Self::find_token_at_cursor(&tokens, cursor);
        Self {
            tokens,
            cursor,
            token_index,
        }
    }

    /// Returns the index of the first token that is not entirely before
    /// `cursor`.
    ///
    /// A token qualifies when the cursor is inside it, at its start (this
    /// also covers empty spans), or before it. The last case matters when the
    /// cursor sits in a gap between tokens or at the end of a token that is
    /// followed by a gap: the token after the gap is selected, so everything
    /// in `tokens[..index]` really precedes the cursor and the token ending
    /// at the cursor is found at `index - 1`. Previously such positions fell
    /// through to the last token of the input.
    ///
    /// Assumes `tokens` are in source order. When every token ends at or
    /// before the cursor, the last index is returned (`0` for an empty
    /// slice).
    fn find_token_at_cursor(tokens: &[UnifiedToken], cursor: usize) -> usize {
        tokens
            .iter()
            .position(|token| {
                let span = token.span();
                cursor < span.end || cursor == span.start
            })
            .unwrap_or_else(|| tokens.len().saturating_sub(1))
    }

    /// Returns the tokens that precede the token selected for the cursor.
    pub fn tokens_before_cursor(&self) -> &[UnifiedToken] {
        // `token_index` is a public field; clamp so an externally modified
        // value cannot make the slice panic.
        let end = self.token_index.min(self.tokens.len());
        &self.tokens[..end]
    }

    /// Returns the token selected for the cursor, if the index is in range.
    pub fn current_token(&self) -> Option<&UnifiedToken> {
        self.tokens.get(self.token_index)
    }

    /// Returns the portion of the identifier that lies before the cursor.
    ///
    /// * Cursor inside (or at either edge of) an identifier token: the
    ///   leading part of that identifier up to the cursor.
    /// * Otherwise, if the previous token is an identifier ending exactly at
    ///   the cursor: that whole identifier.
    /// * Otherwise: an empty string.
    pub fn current_prefix(&self) -> String {
        if let Some(token) = self.current_token() {
            let span = token.span();
            if span.start <= self.cursor && self.cursor <= span.end {
                if let Some(ident) = token.ident_value() {
                    // NOTE(review): this is a span-offset difference used as a
                    // character count. The two agree only if spans count
                    // characters (or the identifier is ASCII) and the
                    // identifier text matches the source text; confirm
                    // against the lexers.
                    let chars_typed = self.cursor - span.start;
                    return ident.chars().take(chars_typed).collect();
                }
            }
        }
        self.previous_token_ending_at_cursor()
            .and_then(|prev| prev.ident_value())
            .unwrap_or_default()
    }

    /// Returns the offset at which a completion should start replacing text.
    ///
    /// This is the start of the identifier the cursor is in or directly
    /// after; when there is no such identifier it is the cursor itself.
    pub fn completion_start(&self) -> usize {
        if let Some(token) = self.current_token() {
            let span = token.span();
            if span.start <= self.cursor && self.cursor <= span.end && token.is_ident() {
                return span.start;
            }
        }
        self.previous_token_ending_at_cursor()
            .filter(|prev| prev.is_ident())
            .map(|prev| prev.span().start)
            .unwrap_or(self.cursor)
    }

    /// Returns the token just before the current one, but only when its span
    /// ends exactly at the cursor.
    fn previous_token_ending_at_cursor(&self) -> Option<&UnifiedToken> {
        let prev_index = self.token_index.checked_sub(1)?;
        self.tokens
            .get(prev_index)
            .filter(|prev| prev.span().end == self.cursor)
    }

    /// Returns `true` when the stream holds at most one token.
    ///
    /// NOTE(review): the `<= 1` threshold presumably accounts for a trailing
    /// end-of-input token emitted by the lexers; confirm against them.
    // Kept for API completeness even though nothing calls it outside tests.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.tokens.len() <= 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{MongoLexer, SqlLexer};

    /// A plain Mongo word is an identifier and nothing else.
    #[test]
    fn test_unified_token_mongo_ident() {
        let lexed = MongoLexer::tokenize("users");
        let first = UnifiedToken::Mongo(lexed[0].clone());

        assert!(first.is_ident());
        assert_eq!(first.ident_value(), Some(String::from("users")));
        assert!(!first.is_dot());
        assert!(!first.is_db());
    }

    /// `db` is lexed as its own token kind, not as an identifier.
    #[test]
    fn test_unified_token_mongo_db() {
        let lexed = MongoLexer::tokenize("db");
        let first = UnifiedToken::Mongo(lexed[0].clone());

        assert!(!first.is_ident());
        assert!(first.is_db());
    }

    /// A plain SQL word is an identifier carrying its text.
    #[test]
    fn test_unified_token_sql_ident() {
        let lexed = SqlLexer::tokenize("users");
        let first = UnifiedToken::Sql(lexed[0].clone());

        assert!(first.is_ident());
        assert_eq!(first.ident_value(), Some(String::from("users")));
    }

    /// Keyword tokens match only their own keyword name.
    #[test]
    fn test_unified_token_sql_keyword() {
        let lexed = SqlLexer::tokenize("SELECT * FROM users");
        let select_kw = UnifiedToken::Sql(lexed[0].clone());
        let from_kw = UnifiedToken::Sql(lexed[2].clone());

        assert!(select_kw.is_sql_keyword("SELECT"));
        assert!(!select_kw.is_sql_keyword("FROM"));
        assert!(from_kw.is_sql_keyword("FROM"));
    }

    /// Building from Mongo tokens keeps every token and the given cursor.
    #[test]
    fn test_token_stream_mongo() {
        let lexed = MongoLexer::tokenize("db.users");
        let stream = TokenStream::from_mongo(lexed, 8);

        // `db`, `.`, `users` and one more token (presumably an end-of-input
        // marker; confirm against the lexer).
        assert_eq!(stream.tokens.len(), 4);
        assert_eq!(stream.cursor, 8);
    }

    /// Building from SQL tokens yields a non-empty stream with the given
    /// cursor.
    #[test]
    fn test_token_stream_sql() {
        let lexed = SqlLexer::tokenize("SELECT * FROM users");
        let stream = TokenStream::from_sql(lexed, 19);

        assert!(!stream.is_empty());
        assert_eq!(stream.cursor, 19);
    }

    /// With the cursor at the start of `users`, two tokens precede it.
    #[test]
    fn test_tokens_before_cursor() {
        let lexed = MongoLexer::tokenize("db.users");
        let stream = TokenStream::from_mongo(lexed, 3);

        let preceding = stream.tokens_before_cursor();
        assert_eq!(preceding.len(), 2);
    }

    /// A cursor one position into `us` yields the prefix `u`.
    #[test]
    fn test_current_prefix() {
        let lexed = MongoLexer::tokenize("db.us");
        let stream = TokenStream::from_mongo(lexed, 4);

        assert_eq!(stream.current_prefix(), "u");
    }

    /// A cursor inside `users` reports the start of that identifier.
    #[test]
    fn test_completion_start() {
        let lexed = MongoLexer::tokenize("db.users");
        let stream = TokenStream::from_mongo(lexed, 6);

        assert_eq!(stream.completion_start(), 3);
    }

    /// Empty input produces a stream that counts as empty.
    #[test]
    fn test_empty_stream() {
        let lexed = MongoLexer::tokenize("");
        let stream = TokenStream::from_mongo(lexed, 0);

        assert!(stream.is_empty());
    }
}