use crate::meta_command::*;
use sqlrite::Connection;
use sqlrite::sql::SQLCommand;
use sqlrite::sql::db::database::Database;
use std::borrow::Cow::{self, Borrowed, Owned};
use std::sync::MutexGuard;
use rustyline::error::ReadlineError;
use rustyline::highlight::{CmdKind, Highlighter, MatchingBracketHighlighter};
use rustyline::hint::{Hinter, HistoryHinter};
use rustyline::validate::Validator;
use rustyline::validate::{ValidationContext, ValidationResult};
use rustyline::{CompletionType, Config, Context, EditMode};
use rustyline_derive::{Completer, Helper};
/// The two families of input the REPL distinguishes between.
///
/// `get_command_type` produces a value of this type: input that starts with
/// a `.` becomes a `MetaCommand`, everything else becomes a `SQLCommand`.
#[derive(Debug, PartialEq)]
pub enum CommandType {
/// A dot-prefixed REPL command (for example `.help`), already parsed by
/// `MetaCommand::new`.
MetaCommand(MetaCommand),
/// Any input that does not start with `.`, already parsed by
/// `SQLCommand::new`.
SQLCommand(SQLCommand),
}
/// Classifies one line of REPL input.
///
/// Input whose first character is `.` is handed to `MetaCommand::new`; all
/// other input is handed to `SQLCommand::new`. The text is passed on
/// unmodified (no trimming happens here).
pub fn get_command_type(command: &str) -> CommandType {
    let text = command.to_owned();

    // Guard clause: anything without the leading dot is SQL.
    if !text.starts_with('.') {
        return CommandType::SQLCommand(SQLCommand::new(text));
    }

    CommandType::MetaCommand(MetaCommand::new(text))
}
/// Helper object handed to rustyline; bundles the pieces that the
/// `Hinter`, `Validator` and `Highlighter` impls in this file rely on.
///
/// `Helper` and `Completer` are derived via `rustyline_derive`.
#[derive(Helper, Completer)]
pub struct REPLHelper {
/// Prompt text returned by `highlight_prompt` when rustyline asks for the
/// default prompt.
pub colored_prompt: String,
/// Hint source that `Hinter::hint` delegates to.
pub hinter: HistoryHinter,
/// Highlighter that `highlight` and `highlight_char` delegate to.
pub highlighter: MatchingBracketHighlighter,
}
impl Default for REPLHelper {
    /// Builds a helper with a fresh history hinter, a fresh bracket
    /// highlighter and an empty coloured prompt.
    fn default() -> Self {
        let colored_prompt = String::new();
        let hinter = HistoryHinter::new();
        let highlighter = MatchingBracketHighlighter::new();

        Self {
            colored_prompt,
            hinter,
            highlighter,
        }
    }
}
impl Hinter for REPLHelper {
    type Hint = String;

    /// Produces the hint for the text currently being edited by forwarding
    /// the request, unchanged, to the wrapped `HistoryHinter`.
    fn hint(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Option<Self::Hint> {
        let history = &self.hinter;
        history.hint(line, pos, ctx)
    }
}
impl Validator for REPLHelper {
    /// Decides whether the text being edited can be submitted.
    ///
    /// * Input starting with `.` (a meta command) is always `Valid`.
    /// * Any other input is `Valid` once its last non-whitespace character
    ///   is `;`, and `Incomplete` otherwise — rustyline then keeps the edit
    ///   open so the statement can continue on the next line.
    ///
    /// This method never returns `Err`.
    fn validate(&self, ctx: &mut ValidationContext) -> Result<ValidationResult, ReadlineError> {
        use ValidationResult::{Incomplete, Valid};

        let input = ctx.input();

        // Ignore whitespace after the terminator: without the trim, typing
        // `SELECT 1; ` (trailing space) would be reported as incomplete and
        // the user would be stuck on a continuation prompt until they typed
        // another `;`.
        let terminated = input.trim_end().ends_with(';');

        let result = if input.starts_with('.') || terminated {
            Valid(None)
        } else {
            Incomplete
        };
        Ok(result)
    }
}
impl Highlighter for REPLHelper {
fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
&'s self,
prompt: &'p str,
default: bool,
) -> Cow<'b, str> {
if default {
Borrowed(&self.colored_prompt)
} else {
Borrowed(prompt)
}
}
fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> {
Owned("\x1b[1m".to_owned() + hint + "\x1b[m")
}
fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
self.highlighter.highlight(line, pos)
}
fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
self.highlighter.highlight_char(line, pos, kind)
}
}
/// Builds the rustyline configuration used by the REPL:
/// `history_ignore_space` enabled, list-style completion, and Emacs
/// key bindings.
pub fn get_config() -> Config {
    let mut builder = Config::builder();
    builder = builder.history_ignore_space(true);
    builder = builder.completion_type(CompletionType::List);
    builder = builder.edit_mode(EditMode::Emacs);
    builder.build()
}
/// The set of connection handles the REPL currently holds, plus which one
/// is active.
///
/// `conns` and `names` are parallel vectors: the methods of this type
/// always push to and clear both together, so `names[i]` is the label of
/// `conns[i]`.
pub struct ReplState {
// Open connections, one per handle.
conns: Vec<Connection>,
// Handle labels ("A", "B", ...), index-aligned with `conns`.
names: Vec<String>,
// Index into `conns` / `names` of the handle commands are routed to.
active: usize,
}
impl ReplState {
    /// Creates a state holding a single handle, named `"A"`, which is also
    /// the active one.
    pub fn new(conn: Connection) -> Self {
        Self {
            conns: vec![conn],
            names: vec![String::from("A")],
            active: 0,
        }
    }

    /// Name of the currently active handle.
    pub fn active_name(&self) -> &str {
        self.names[self.active].as_str()
    }

    /// Lists every handle, in creation order, as `(name, flag)` where the
    /// flag is whatever `concurrent_tx_is_open()` reports for that
    /// connection.
    pub fn handles_summary(&self) -> Vec<(String, bool)> {
        let labels = self.names.iter().cloned();
        let tx_flags = self.conns.iter().map(|conn| conn.concurrent_tx_is_open());
        labels.zip(tx_flags).collect()
    }

    /// Returns the guard obtained from the active connection's
    /// `database()` call.
    pub fn lock_active(&self) -> MutexGuard<'_, Database> {
        let current = &self.conns[self.active];
        current.database()
    }

    /// Mutable access to the active connection.
    pub fn active_conn_mut(&mut self) -> &mut Connection {
        let slot = self.active;
        &mut self.conns[slot]
    }

    /// Adds a new handle obtained by calling `connect()` on the active
    /// connection, makes it the active handle, and returns its name.
    ///
    /// The name is derived from the number of handles that existed before
    /// the call ("B" for the second handle, "C" for the third, ...).
    pub fn spawn_sibling(&mut self) -> String {
        let sibling = self.conns[self.active].connect();

        // The new handle lands at the current length, which is also the
        // index its name is generated from.
        let slot = self.conns.len();
        let label = next_handle_name(slot);

        self.conns.push(sibling);
        self.names.push(label.clone());
        self.active = slot;
        label
    }

    /// Switches the active handle to the one named `target`.
    ///
    /// `target` is ASCII-uppercased before the lookup. On success the
    /// stored name is returned; otherwise the error string lists the
    /// handles that do exist.
    pub fn use_handle(&mut self, target: &str) -> Result<String, String> {
        let wanted = target.to_ascii_uppercase();
        let found = self.names.iter().position(|label| label == &wanted);

        match found {
            Some(idx) => {
                self.active = idx;
                Ok(self.names[idx].clone())
            }
            None => {
                let valid = self.names.join(", ");
                Err(format!(
                    "no handle named '{target}'; current handles: {valid}"
                ))
            }
        }
    }

    /// Number of handles currently held.
    pub fn handle_count(&self) -> usize {
        self.conns.len()
    }

    /// Drops every handle except the active one and renames the survivor
    /// to `"A"`. Does nothing when only one handle exists.
    pub fn collapse_to_active(&mut self) {
        if self.conns.len() != 1 {
            // Pull the survivor out first so that clearing the vector
            // drops only the other connections.
            let survivor = self.conns.swap_remove(self.active);
            self.conns.clear();
            self.conns.push(survivor);

            self.names.clear();
            self.names.push(String::from("A"));

            self.active = 0;
        }
    }
}
/// Converts a zero-based handle index into a spreadsheet-column-style
/// label: `0 -> "A"`, `25 -> "Z"`, `26 -> "AA"`, `27 -> "AB"`, and so on.
fn next_handle_name(index: usize) -> String {
    // Letters are produced least-significant first and reversed at the end.
    let mut letters: Vec<char> = Vec::new();
    let mut remaining = index;

    loop {
        let digit = (remaining % 26) as u8;
        letters.push(char::from(b'A' + digit));
        if remaining < 26 {
            break;
        }
        // The `- 1` is what makes "AA" follow "Z": there is no zero digit
        // in this numbering scheme.
        remaining = remaining / 26 - 1;
    }

    letters.iter().rev().collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Dot-prefixed input must be classified as a meta command.
    #[test]
    fn get_command_type_meta_command_test() {
        let parsed = get_command_type(".help");
        assert_eq!(parsed, CommandType::MetaCommand(MetaCommand::Help));
    }

    // Input without the leading dot must be classified as SQL.
    #[test]
    fn get_command_type_sql_command_test() {
        let statement = "SELECT * from users;";
        let parsed = get_command_type(statement);
        let wanted = CommandType::SQLCommand(SQLCommand::Unknown(statement.to_owned()));
        assert_eq!(parsed, wanted);
    }
}