use std::{collections::HashSet, error::Error};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// SQL dialect spoken by an [`Engine`].
///
/// Serialized with serde as the lowercase names `"mysql"`, `"sqlite"` and
/// `"postgresql"`. The common value traits are derived so callers can compare
/// (`==`), copy, hash and debug-print a dialect instead of having to `match` on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Dialect {
    #[serde(rename = "mysql")]
    MySQL,
    #[serde(rename = "sqlite")]
    SQLite,
    #[serde(rename = "postgresql")]
    PostgreSQL,
}
/// Renders the dialect as its lowercase name, the same spelling used by its
/// serde representation.
///
/// Implementing `Display` (rather than `ToString` directly) lets the dialect be
/// used with `{}` in `format!`/`write!`; `to_string()` keeps working through
/// the standard blanket `impl<T: Display> ToString for T`.
impl std::fmt::Display for Dialect {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Dialect::MySQL => "mysql",
            Dialect::SQLite => "sqlite",
            Dialect::PostgreSQL => "postgresql",
        };
        // `pad` (unlike `write_str`) honours width/alignment flags such as `{:>8}`.
        f.pad(name)
    }
}
/// Driver-level abstraction that [`SQLDatabase`] delegates all database work to.
///
/// `#[async_trait]` rewrites the `async fn` methods into boxed futures so the
/// trait can be used as `Box<dyn Engine>`; implementors must be `Send + Sync`.
#[async_trait]
pub trait Engine: Send + Sync {
    /// Identifies the SQL dialect this engine speaks.
    fn dialect(&self) -> Dialect;

    /// Closes the engine. What closing entails (for example releasing a
    /// connection) is implementation-defined.
    fn close(&self) -> Result<(), Box<dyn Error>>;

    /// Lists the names of the tables the engine can see.
    async fn table_names(&self) -> Result<Vec<String>, Box<dyn Error>>;

    /// Returns a textual description of the single table named `table`.
    /// The format (presumably schema/DDL) is implementation-defined.
    async fn table_info(&self, table: &str) -> Result<String, Box<dyn Error>>;

    /// Executes `query` and returns `(column_names, rows)`, with every cell
    /// already rendered as a `String`.
    async fn query(
        &self,
        query: &str,
    ) -> Result<(Vec<String>, Vec<Vec<String>>), Box<dyn Error>>;
}
/// A SQL database handle: an [`Engine`] plus the table set and sampling
/// settings used when describing tables.
pub struct SQLDatabase {
    /// Backend that executes queries and describes tables.
    pub engine: Box<dyn Engine>,
    /// Number of rows fetched per table by `table_info` (via `sample_rows`);
    /// `table_info` skips sampling when this is `<= 0`.
    pub sample_rows_number: i32,
    /// Table names known to this handle; `table_info` falls back to this set
    /// when called with no tables. When built by `SQLDatabaseBuilder`, it is
    /// the engine's table list minus the ignored tables.
    pub all_tables: HashSet<String>,
}
/// Builder for [`SQLDatabase`]; `SQLDatabaseBuilder::new` sets the defaults.
pub struct SQLDatabaseBuilder {
    /// Engine that will back the built database.
    engine: Box<dyn Engine>,
    /// Sample rows per table, copied as-is into the built `SQLDatabase`.
    sample_rows_number: i32,
    /// Table names removed from the engine's table list at build time.
    ignore_tables: HashSet<String>,
}
impl SQLDatabaseBuilder {
    /// Creates a builder around `engine` with 3 sample rows per table and no
    /// ignored tables.
    pub fn new<E>(engine: E) -> Self
    where
        E: Engine + 'static,
    {
        SQLDatabaseBuilder {
            engine: Box::new(engine),
            sample_rows_number: 3,
            ignore_tables: HashSet::new(),
        }
    }

    /// Overrides how many sample rows are appended per table by
    /// `SQLDatabase::table_info`; a value `<= 0` disables sampling there.
    #[must_use]
    pub fn custom_sample_rows_number(mut self, number: i32) -> Self {
        self.sample_rows_number = number;
        self
    }

    /// Sets the table names to exclude from the built database. This replaces
    /// (does not extend) any previously configured set.
    #[must_use]
    pub fn ignore_tables(mut self, ignore_tables: HashSet<String>) -> Self {
        self.ignore_tables = ignore_tables;
        self
    }

    /// Asks the engine for its table names and builds the [`SQLDatabase`],
    /// dropping every table listed in the ignore set.
    ///
    /// # Errors
    /// Returns the engine's error unchanged if listing table names fails.
    pub async fn build(self) -> Result<SQLDatabase, Box<dyn Error>> {
        let SQLDatabaseBuilder {
            engine,
            sample_rows_number,
            ignore_tables,
        } = self;

        let all_tables: HashSet<String> = engine
            .table_names()
            .await?
            .into_iter()
            .filter(|name| !ignore_tables.contains(name))
            .collect();

        Ok(SQLDatabase {
            engine,
            sample_rows_number,
            all_tables,
        })
    }
}
impl SQLDatabase {
pub fn dialect(&self) -> Dialect {
self.engine.dialect()
}
pub fn table_names(&self) -> Vec<String> {
self.all_tables.iter().cloned().collect()
}
pub async fn table_info(&self, tables: &[String]) -> Result<String, Box<dyn Error>> {
let mut tables: HashSet<String> = tables.to_vec().into_iter().collect();
if tables.len() == 0 {
tables = self.all_tables.clone();
}
let mut info = String::new();
for table in tables {
let table_info = self.engine.table_info(&table).await?;
info.push_str(&table_info);
info.push_str("\n\n");
if self.sample_rows_number > 0 {
let sample_rows = self.sample_rows(&table).await?;
info.push_str("/*\n");
info.push_str(&sample_rows);
info.push_str("*/ \n\n");
}
}
Ok(info)
}
pub async fn query(&self, query: &str) -> Result<String, Box<dyn Error>> {
log::debug!("Query: {}", query);
let (cols, results) = self.engine.query(query).await?;
let mut str = cols.join("\t") + "\n";
for row in results {
str += &row.join("\t");
str.push('\n');
}
Ok(str)
}
pub fn close(&self) -> Result<(), Box<dyn Error>> {
self.engine.close()
}
pub async fn sample_rows(&self, table: &str) -> Result<String, Box<dyn Error>> {
let query = format!("SELECT * FROM {} LIMIT {}", table, self.sample_rows_number);
log::debug!("Sample Rows Query: {}", query);
self.query(&query).await
}
}