use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
use anyhow::anyhow;
use anyhow::Result;
use async_recursion::async_recursion;
use chrono::NaiveDateTime;
use databend_common_ast::parser::all_reserved_keywords;
use databend_common_ast::parser::token::TokenKind;
use databend_common_ast::parser::token::Tokenizer;
use databend_driver::ServerStats;
use databend_driver::{Client, Connection};
use once_cell::sync::Lazy;
use rustyline::config::Builder;
use rustyline::error::ReadlineError;
use rustyline::history::DefaultHistory;
use rustyline::{CompletionType, Editor};
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use tokio::fs::{remove_file, File};
use tokio::io::AsyncWriteExt;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use crate::config::Settings;
use crate::config::TimeOption;
use crate::display::INTERRUPTED_MESSAGE;
use crate::display::{format_write_progress, ChunkDisplay, FormatDisplay};
use crate::helper::CliHelper;
use crate::web::find_available_port;
use crate::web::start_server;
use crate::VERSION;
/// Query used in REPL mode to seed auto completion: one `(name, type)` row per
/// function (`f`), database (`d`), table (`t`) and column (`c`), capped at 10000 rows.
static PROMPT_SQL: &str = concat!(
    "select name, 'f' as type from system.functions",
    " union all select name, 'd' as type from system.databases",
    " union all select name, 't' as type from system.tables",
    " union all select name, 'c' as type from system.columns",
    " limit 10000",
);
/// Short version string used in the client name (`bendsql/<VERSION_SHORT>`):
/// `<crate version>-<suffix>`, where the suffix is `BENDSQL_BUILD_INFO` if set at build
/// time, otherwise `VERGEN_GIT_SHA`, otherwise `dev`.
static VERSION_SHORT: Lazy<String> = Lazy::new(|| {
    let crate_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
    let suffix = option_env!("BENDSQL_BUILD_INFO")
        .or(option_env!("VERGEN_GIT_SHA"))
        .unwrap_or("dev");
    format!("{}-{}", crate_version, suffix)
});
const ALTER_USER_PASSWORD_TOKENS: [TokenKind; 6] = [
TokenKind::USER,
TokenKind::USER,
TokenKind::LParen,
TokenKind::RParen,
TokenKind::IDENTIFIED,
TokenKind::BY,
];
/// A BendSQL client session: one server connection plus the CLI state around it
/// (settings, partially typed statement, auto-complete store, web server task).
pub struct Session {
    /// Client used to open (and, on reconnect, re-open) the connection.
    client: Client,
    /// Current server connection; replaced by `reconnect`.
    pub conn: Box<dyn Connection>,
    /// Whether the session runs as an interactive REPL.
    is_repl: bool,
    /// CLI settings; `bind_port` is overwritten with the web server port in REPL mode.
    settings: Settings,
    /// Buffered statement text that has not yet been terminated by `;`.
    query: String,
    /// Handle of the background web server task (REPL only); aborted on drop.
    server_handle: Option<JoinHandle<std::io::Result<()>>>,
    /// Auto-complete keyword store handed to the line editor helper (REPL only).
    keywords: Option<Arc<sled::Db>>,
    /// Set by the Ctrl-C handler; reset at the start of every `handle_query`.
    interrupted: Arc<AtomicBool>,
}
impl Session {
pub async fn try_new(dsn: String, mut settings: Settings, is_repl: bool) -> Result<Self> {
let client = Client::new(dsn).with_name(format!("bendsql/{}", VERSION_SHORT.as_str()));
let conn = client.get_conn().await?;
let info = conn.info().await;
let mut keywords: Option<Arc<sled::Db>> = None;
if is_repl {
println!("Welcome to BendSQL {}.", VERSION.as_str());
match info.warehouse {
Some(ref warehouse) => {
println!(
"Connecting to {}:{} with warehouse {} as user {}",
info.host, info.port, warehouse, info.user
);
}
None => {
println!(
"Connecting to {}:{} as user {}.",
info.host, info.port, info.user
);
}
}
let version = match conn.version().await {
Ok(version) => version,
Err(err) => {
match err {
databend_driver::Error::Api(databend_client::Error::AuthFailure(_)) => {
return Err(err.into());
}
databend_driver::Error::Arrow(arrow::error::ArrowError::IpcError(
ref ipc_err,
)) => {
if ipc_err.contains("Unauthenticated")
|| ipc_err.contains("Connection refused")
{
return Err(err.into());
}
}
databend_driver::Error::Api(databend_client::Error::Request(
ref resp_err,
)) => {
if resp_err.contains("error sending request for url") {
return Err(err.into());
}
}
_ => {}
}
"".to_string()
}
};
println!("Connected to {}", version);
let config = sled::Config::new().temporary(true);
let db = config.open()?;
{
let keywords = all_reserved_keywords();
let mut batch = sled::Batch::default();
for word in keywords {
batch.insert(word.to_ascii_lowercase().as_str(), "k")
}
db.apply_batch(batch)?;
}
if !settings.no_auto_complete {
let rows = conn.query_iter(PROMPT_SQL).await;
match rows {
Ok(mut rows) => {
let mut count = 0;
let mut batch = sled::Batch::default();
while let Some(Ok(row)) = rows.next().await {
let (w, t): (String, String) = row.try_into().unwrap();
batch.insert(w.as_str(), t.as_str());
count += 1;
if count % 1000 == 0 {
db.apply_batch(batch)?;
batch = sled::Batch::default();
}
}
db.apply_batch(batch)?;
println!("Loaded {} auto complete keywords from server.", db.len());
}
Err(e) => {
eprintln!("WARN: loading auto complete keywords failed: {}", e);
}
}
}
keywords = Some(Arc::new(db));
}
let server_handle = if is_repl {
let port = find_available_port(settings.bind_port).await;
let addr = settings.bind_address.clone();
let server_handle = tokio::spawn(async move { start_server(&addr, port).await });
println!("Started web server at {}:{}", settings.bind_address, port);
settings.bind_port = port;
Some(server_handle)
} else {
None
};
let interrupted = Arc::new(AtomicBool::new(false));
let interrupted_clone = interrupted.clone();
if is_repl {
println!();
ctrlc::set_handler(move || {
interrupted_clone.store(true, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
}
Ok(Self {
client,
conn,
is_repl,
settings,
query: String::new(),
keywords,
server_handle,
interrupted,
})
}
async fn prompt(&self) -> String {
if !self.query.trim().is_empty() {
"> ".to_owned()
} else {
let info = self.conn.info().await;
let mut prompt = self.settings.prompt.clone();
prompt = prompt.replace("{host}", &info.host);
prompt = prompt.replace("{user}", &info.user);
prompt = prompt.replace("{port}", &info.port.to_string());
if let Some(database) = &info.database {
prompt = prompt.replace("{database}", database);
} else {
prompt = prompt.replace("{database}", "default");
}
if let Some(warehouse) = &info.warehouse {
prompt = prompt.replace("{warehouse}", &format!("({})", warehouse));
} else {
prompt = prompt.replace("{warehouse}", &format!("{}:{}", info.host, info.port));
}
format!("{} ", prompt.trim_end())
}
}
pub async fn check(&mut self) -> Result<()> {
{
println!("BendSQL {}", VERSION.as_str());
}
{
let info = self.conn.info().await;
println!(
"Checking Databend Query server via {} at {}:{} as user {}.",
info.handler, info.host, info.port, info.user
);
if let Some(warehouse) = &info.warehouse {
println!("Using Databend Cloud warehouse: {}", warehouse);
}
if let Some(database) = &info.database {
println!("Current database: {}", database);
} else {
println!("Current database: default");
}
}
{
let version = self.conn.version().await.unwrap_or_default();
println!("Server version: {}", version);
}
match self.conn.query_iter("call admin$license_info()").await {
Ok(mut rows) => {
let row = rows.next().await.unwrap()?;
let linfo: (String, String, String, NaiveDateTime, NaiveDateTime, String) = row
.try_into()
.map_err(|e| anyhow!("parse license info failed: {}", e))?;
if chrono::Utc::now().naive_utc() > linfo.4 {
eprintln!("-> WARN: License expired at {}", linfo.4);
} else {
println!(
"License({}) issued by {} for {} from {} to {}",
linfo.1, linfo.0, linfo.2, linfo.3, linfo.4
);
}
}
Err(_) => {
eprintln!("-> WARN: License not available, only community features enabled.");
}
}
{
let stage_file = "@~/bendsql/.check";
match self.conn.get_presigned_url("UPLOAD", stage_file).await {
Err(_) => {
eprintln!("-> WARN: Backend storage dose not support presigned url.");
eprintln!(" Loading data from local file may not work as expected.");
eprintln!(" Be aware of data transfer cost with arg `presign=off`.");
}
Ok(resp) => {
let now_utc = chrono::Utc::now();
let data = now_utc.to_rfc3339().as_bytes().to_vec();
let size = data.len() as u64;
let reader = Box::new(std::io::Cursor::new(data));
match self.conn.upload_to_stage(stage_file, reader, size).await {
Err(e) => {
eprintln!("-> ERR: Backend storage upload not working as expected.");
eprintln!(" {}", e);
}
Ok(()) => {
let u = url::Url::parse(&resp.url)?;
let host = u.host_str().unwrap_or("unknown");
println!("Backend storage OK: {}", host);
}
};
}
}
}
Ok(())
}
pub async fn handle_repl(&mut self) {
let config = Builder::new()
.completion_prompt_limit(10)
.completion_type(CompletionType::List)
.build();
let mut rl = Editor::<CliHelper, DefaultHistory>::with_config(config).unwrap();
rl.set_helper(Some(CliHelper::new(self.keywords.clone())));
rl.load_history(&get_history_path()).ok();
'F: loop {
match rl.readline(&self.prompt().await) {
Ok(line) => {
let queries = self.append_query(&line);
for query in queries {
let _ = rl.add_history_entry(&query);
match self.handle_query(true, &query).await {
Ok(None) => {
break 'F;
}
Ok(Some(_)) => {}
Err(e) => {
if e.to_string().contains("Unauthenticated") {
if let Err(e) = self.reconnect().await {
eprintln!("reconnect error: {}", e);
} else if let Err(e) = self.handle_query(true, &query).await {
eprintln!("error: {}", e);
}
} else {
eprintln!("error: {}", e);
if e.to_string().contains(INTERRUPTED_MESSAGE) {
if let Some(query_id) = self.conn.last_query_id() {
println!("killing query: {}", query_id);
let _ = self.conn.kill_query(&query_id).await;
}
}
self.query.clear();
break;
}
}
}
}
}
Err(e) => match e {
ReadlineError::Io(err) => {
eprintln!("io err: {err}");
}
ReadlineError::Interrupted => {
self.query.clear();
println!("^C");
}
ReadlineError::Eof => {
break;
}
_ => {}
},
}
}
if let Err(e) = self.conn.close().await {
println!("got error when closing session: {}", e);
}
println!("Bye~");
let _ = rl.save_history(&get_history_path());
}
pub async fn handle_reader<R: BufRead>(&mut self, r: R) -> Result<()> {
let start = Instant::now();
let mut lines = r.lines();
let mut stats: Option<ServerStats> = None;
loop {
match lines.next() {
Some(Ok(line)) => {
let queries = self.append_query(&line);
for query in queries {
stats = self.handle_query(false, &query).await?;
}
}
Some(Err(e)) => {
return Err(anyhow!("read lines err: {}", e.to_string()));
}
None => break,
}
}
let query = self.query.trim().to_owned();
if !query.is_empty() {
self.query.clear();
stats = self.handle_query(false, &query).await?;
}
match self.settings.time {
None => {}
Some(TimeOption::Local) => {
println!("{:.3}", start.elapsed().as_secs_f64());
}
Some(TimeOption::Server) => {
let server_time_ms = match stats {
None => 0.0,
Some(ss) => ss.running_time_ms,
};
println!("{:.3}", server_time_ms / 1000.0);
}
}
self.conn.close().await.ok();
Ok(())
}
pub fn append_query(&mut self, line: &str) -> Vec<String> {
if line.is_empty() {
return vec![];
}
if self.query.is_empty()
&& (line.starts_with('!')
|| line == "exit"
|| line == "quit"
|| line.to_uppercase().starts_with("PUT"))
{
return vec![line.to_owned()];
}
if !self.settings.multi_line {
if line.starts_with("--") {
return vec![];
} else {
return vec![line.to_owned()];
}
}
let mut queries = Vec::new();
if !self.query.is_empty() {
self.query.push('\n');
}
self.query.push_str(line);
'Parser: loop {
let mut is_valid = true;
let tokenizer = Tokenizer::new(&self.query);
for token in tokenizer {
match token {
Ok(token) => {
if let TokenKind::SemiColon = token.kind {
let (sql, remain) = self.query.split_at(token.span.end as usize);
if is_valid && !sql.is_empty() {
queries.push(sql.to_string());
}
self.query = remain.to_string();
continue 'Parser;
}
}
Err(_) => {
is_valid = false;
continue;
}
}
}
break;
}
queries
}
#[async_recursion]
pub async fn handle_query(
&mut self,
is_repl: bool,
query: &str,
) -> Result<Option<ServerStats>> {
let query = query.trim_end_matches(';').trim();
self.interrupted.store(false, Ordering::SeqCst);
if is_repl {
if query.starts_with('!') {
return self.handle_commands(query).await;
}
if query == "exit" || query == "quit" {
return Ok(None);
}
}
let start = Instant::now();
let kind = QueryKind::from(query);
match kind {
QueryKind::AlterUserPassword => {
let _ = self.conn.exec(query).await?;
Ok(None)
}
other => {
let replace_newline = !if self.settings.replace_newline {
false
} else {
replace_newline_in_box_display(query)
};
let data = match other {
QueryKind::Put => {
let args: Vec<String> = get_put_get_args(query);
if args.len() != 3 {
eprintln!("put args are invalid, must be 2 argruments");
return Ok(Some(ServerStats::default()));
}
self.conn.put_files(&args[1], &args[2]).await?
}
QueryKind::Get => {
let args: Vec<String> = get_put_get_args(query);
if args.len() != 3 {
eprintln!("put args are invalid, must be 2 argruments");
return Ok(Some(ServerStats::default()));
}
self.conn.get_files(&args[1], &args[2]).await?
}
_ => self.conn.query_iter_ext(query).await?,
};
let mut displayer = FormatDisplay::new(
&self.settings,
query,
replace_newline,
start,
data,
self.interrupted.clone(),
);
let stats = displayer.display().await?;
Ok(Some(stats))
}
}
}
#[async_recursion]
pub async fn handle_commands(&mut self, query: &str) -> Result<Option<ServerStats>> {
match query {
"!exit" | "!quit" => {
return Ok(None);
}
"!configs" => {
println!("{:#?}", self.settings);
}
other => {
if other.starts_with("!set") {
let query = query[4..].split_whitespace().collect::<Vec<_>>();
if query.len() != 2 {
return Err(anyhow!(
"Control command error, must be syntax of `.cmd_name cmd_value`."
));
}
self.settings.inject_ctrl_cmd(query[0], query[1])?;
} else if other.starts_with("!source") {
let query = query[7..].trim();
let path = Path::new(query);
if !path.exists() {
return Err(anyhow!("File not found: {}", query));
}
let file = std::fs::File::open(path)?;
let reader = std::io::BufReader::new(file);
self.handle_reader(reader).await?;
} else {
return Err(anyhow!("Unknown commands: {}", other));
}
}
}
Ok(Some(ServerStats::default()))
}
pub async fn stream_load_stdin(
&mut self,
query: &str,
options: BTreeMap<&str, &str>,
) -> Result<()> {
let dir = std::env::temp_dir();
let mut lines = std::io::stdin().lock().lines();
let now = chrono::Utc::now().timestamp_nanos_opt().ok_or_else(|| {
anyhow!("Failed to get timestamp, please check your system time is correct and retry.")
})?;
let tmp_file = dir.join(format!("bendsql_{}", now));
{
let mut file = File::create(&tmp_file).await?;
loop {
match lines.next() {
Some(Ok(line)) => {
file.write_all(line.as_bytes()).await?;
file.write_all(b"\n").await?;
}
Some(Err(e)) => {
return Err(anyhow!("stream load stdin err: {}", e.to_string()));
}
None => break,
}
}
file.flush().await?;
}
self.stream_load_file(query, &tmp_file, options).await?;
remove_file(tmp_file).await?;
Ok(())
}
pub async fn stream_load_file(
&mut self,
query: &str,
file_path: &Path,
options: BTreeMap<&str, &str>,
) -> Result<()> {
let start = Instant::now();
let file = File::open(file_path).await?;
let metadata = file.metadata().await?;
let ss = self
.conn
.load_data(query, Box::new(file), metadata.len(), Some(options), None)
.await?;
if self.settings.show_progress {
eprintln!(
"==> stream loaded {}:\n {}",
file_path.display(),
format_write_progress(&ss, start.elapsed().as_secs_f64())
);
}
Ok(())
}
async fn reconnect(&mut self) -> Result<()> {
self.conn = self.client.get_conn().await?;
if self.is_repl {
let info = self.conn.info().await;
eprintln!(
"reconnecting to {}:{} as user {}.",
info.host, info.port, info.user
);
let version = self.conn.version().await.unwrap_or_default();
eprintln!("connected to {}", version);
eprintln!();
}
Ok(())
}
}
/// Path of the REPL history file: `$HOME/.bendsql_history`, or `./.bendsql_history`
/// when `HOME` is unset or not valid Unicode.
fn get_history_path() -> String {
    let home_dir = match std::env::var("HOME") {
        Ok(dir) => dir,
        Err(_) => String::from("."),
    };
    home_dir + "/.bendsql_history"
}
/// Coarse classification of a SQL statement, derived from its leading token(s) by the
/// `From<&str>` impl below. It selects how `Session::handle_query` executes the statement.
#[derive(PartialEq, Eq, Debug)]
pub enum QueryKind {
    /// Any statement not matched by another variant (the default).
    Query,
    /// `DELETE`, `UPDATE`, `INSERT`, `CREATE`, `DROP`, `OPTIMIZE`, or an `ALTER` that is
    /// not a password change.
    Update,
    /// `EXPLAIN ...` whose text does not contain `graphical`.
    Explain,
    /// `PUT <src> <dst>`: upload local files.
    Put,
    /// `GET <src> <dst>`: download files.
    Get,
    /// `ALTER USER USER() IDENTIFIED BY ...`.
    AlterUserPassword,
    /// `EXPLAIN ...` whose text contains `graphical` (case-insensitive).
    Graphical,
    /// `SHOW CREATE ...`.
    ShowCreate,
}
impl From<&str> for QueryKind {
fn from(query: &str) -> Self {
let mut tz = Tokenizer::new(query);
match tz.next() {
Some(Ok(t)) => match t.kind {
TokenKind::EXPLAIN => {
if query.to_lowercase().contains("graphical") {
QueryKind::Graphical
} else {
QueryKind::Explain
}
}
TokenKind::SHOW => match tz.next() {
Some(Ok(t)) if t.kind == TokenKind::CREATE => QueryKind::ShowCreate,
_ => QueryKind::Query,
},
TokenKind::PUT => QueryKind::Put,
TokenKind::GET => QueryKind::Get,
TokenKind::ALTER => {
let mut tzs = vec![];
while let Some(Ok(t)) = tz.next() {
tzs.push(t.kind);
if tzs.len() == ALTER_USER_PASSWORD_TOKENS.len() {
break;
}
}
if tzs == ALTER_USER_PASSWORD_TOKENS {
QueryKind::AlterUserPassword
} else {
QueryKind::Update
}
}
TokenKind::DELETE
| TokenKind::UPDATE
| TokenKind::INSERT
| TokenKind::CREATE
| TokenKind::DROP
| TokenKind::OPTIMIZE => QueryKind::Update,
_ => QueryKind::Query,
},
_ => QueryKind::Query,
}
}
}
/// Splits a `PUT`/`GET` statement into its whitespace-separated words, including the
/// leading keyword. For example, `PUT src dst` becomes `["PUT", "src", "dst"]`.
fn get_put_get_args(query: &str) -> Vec<String> {
    let mut words = Vec::new();
    for word in query.split_ascii_whitespace() {
        words.push(String::from(word));
    }
    words
}
/// Whether newlines inside values may be escaped in box display for `query`.
///
/// Returns `false` for `EXPLAIN ...` and `SHOW CREATE ...`, and `true` for everything
/// else, including empty or untokenizable input.
fn replace_newline_in_box_display(query: &str) -> bool {
    let mut tokens = Tokenizer::new(query);
    let first = match tokens.next() {
        Some(Ok(token)) => token,
        _ => return true,
    };
    match first.kind {
        TokenKind::EXPLAIN => false,
        TokenKind::SHOW => match tokens.next() {
            Some(Ok(second)) => second.kind != TokenKind::CREATE,
            _ => true,
        },
        _ => true,
    }
}
impl Drop for Session {
    /// Aborts the background web-server task, if one was started.
    fn drop(&mut self) {
        let web_server = self.server_handle.take();
        if let Some(task) = web_server {
            task.abort();
        }
    }
}