use crate::sql_type;
use futures::TryStreamExt;
use sqlx;
use sqlx::mysql::MySqlPoolOptions;
use sqlx::Row;
use std::fs;
use std::io::Write;
use std::path::Path;
use std::time::Duration;
/// Marker comment written at the top of `mod.rs` and of every generated table file.
const LICENSE_HEADER: &str = "// code generated by gen-table. DO NOT EDIT!!!\n";
/// Code generator that reads MySQL column metadata and writes one Rust entity
/// file per table, plus a `mod.rs` that declares them.
#[derive(Debug, Default)]
pub struct Engine {
    /// Connection string handed to the MySQL pool.
    dsn: String,
    /// Pool: maximum number of open connections.
    max_connections: u32,
    /// Pool: minimum number of connections kept open.
    min_connections: u32,
    /// Pool: maximum lifetime of a single connection.
    max_lifetime: Duration,
    /// Pool: how long an idle connection may live.
    idle_timeout: Duration,
    /// Used as the pool's acquire timeout.
    connect_timeout: Duration,
    /// Emit a `table_name()` method for each generated entity.
    enable_table_name_fn: bool,
    /// Keep basic scalar/`String` fields non-optional even for nullable columns.
    no_null_field: bool,
    /// Directory receiving the generated files.
    out_dir: String,
    /// Derive serde `Serialize`/`Deserialize` on generated entities.
    is_serde: bool,
}
/// One row of `information_schema.COLUMNS` for a table, as selected by
/// `Engine::get_columns`; each field is filled from the query alias of the
/// same name.
#[derive(Debug, Default)]
pub struct ColumnEntry {
    /// `TABLE_NAME`.
    pub table_name: String,
    /// `COLUMN_NAME`; lowercased by the generator to form the struct field name.
    pub field: String,
    /// `DATA_TYPE`; mapped to a Rust type through `sql_type::get_type`.
    pub data_type: String,
    /// `COLUMN_TYPE`.
    pub field_desc: String,
    /// `COLUMN_KEY`.
    pub field_key: String,
    /// `ORDINAL_POSITION`.
    pub order_by: u64,
    /// `IS_NULLABLE`; the generator treats `"YES"` as nullable.
    pub is_nullable: String,
    /// `CHARACTER_MAXIMUM_LENGTH` (`None` when the column value is NULL).
    pub max_length: Option<i64>,
    /// `NUMERIC_PRECISION` (`None` when the column value is NULL).
    pub numeric_prec: Option<u64>,
    /// `NUMERIC_SCALE` (`None` when the column value is NULL).
    pub numeric_scale: Option<u64>,
    /// `EXTRA`.
    pub extra: String,
    /// `COLUMN_COMMENT`.
    pub field_comment: String,
}
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// Works on `char`s rather than bytes, so an empty string yields `""` and a
/// non-ASCII first character (e.g. `é`) is upper-cased instead of panicking
/// on a byte slice that is not a char boundary.
fn capit(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        // `char::to_uppercase` may expand to several chars (e.g. `ß` -> `SS`),
        // matching what `str::to_uppercase` did on the first character.
        Some(first) => first.to_uppercase().chain(chars).collect(),
        None => String::new(),
    }
}
fn camel_case(s: &str) -> String {
let arr: Vec<&str> = s.split("_").collect();
let mut res: Vec<String> = Vec::new();
for s in arr {
let cur_str = capit(s);
res.push(cur_str);
}
res.join("").to_string()
}
/// Checks the identifier helpers used to name generated entities.
#[test]
fn it_works() {
    assert_eq!(capit("news"), "News");
    assert_eq!(camel_case("news"), "News");
    assert_eq!(camel_case("news_topics"), "NewsTopics");
    assert_eq!(camel_case("a_b_c"), "ABC");
}
impl Engine {
    /// `information_schema.COLUMNS` projections; each alias is the key that
    /// `get_columns` reads into the matching `ColumnEntry` field.
    const QUERY_FIELDS: &'static [&'static str] = &[
        "TABLE_NAME as table_name",
        "COLUMN_NAME as field",
        "DATA_TYPE as data_type",
        "COLUMN_TYPE as field_desc",
        "COLUMN_KEY as field_key",
        "ORDINAL_POSITION as order_by",
        "IS_NULLABLE as is_nullable",
        "CHARACTER_MAXIMUM_LENGTH as max_length",
        "NUMERIC_PRECISION as numeric_prec",
        "NUMERIC_SCALE as numeric_scale",
        "EXTRA as extra",
        "COLUMN_COMMENT as field_comment",
    ];

    /// Mapped Rust types that `with_no_null_field(true)` keeps non-optional
    /// even when the column is nullable.
    const NO_NULL_TYPES: &'static [&'static str] = &["i32", "i64", "f64", "f32", "String"];

    /// Rust keywords that may show up as (lowercased) column names; they are
    /// emitted as raw identifiers (`r#type`) so the generated struct compiles.
    /// `self`, `super` and `crate` are absent because they cannot be raw
    /// identifiers; such column names are still emitted verbatim.
    const RAW_IDENT_KEYWORDS: &'static [&'static str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "do",
        "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in",
        "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
        "return", "static", "struct", "trait", "true", "try", "type", "typeof", "unsafe",
        "unsized", "use", "virtual", "where", "while", "yield",
    ];

    /// Creates an engine for `dsn` that writes into `out_dir`
    /// (`src/model` when `out_dir` is empty).
    ///
    /// Side effects: creates `out_dir` if it is missing and creates/truncates
    /// `out_dir/mod.rs`.
    ///
    /// # Panics
    /// Panics if the directory or `mod.rs` cannot be created.
    pub fn new(dsn: &str, out_dir: &str) -> Self {
        let out_dir = if out_dir.is_empty() { "src/model" } else { out_dir };
        let s = Self {
            dsn: dsn.to_string(),
            enable_table_name_fn: true,
            max_connections: 50,
            min_connections: 6,
            max_lifetime: Duration::from_secs(1800),
            idle_timeout: Duration::from_secs(600),
            connect_timeout: Duration::from_secs(10),
            out_dir: out_dir.to_string(),
            ..Default::default()
        };
        s.create_out_dir();
        s.create_mod_file();
        s
    }

    /// Enables or disables emission of the `table_name()` method.
    pub fn with_enable_tab_name(mut self, enable: bool) -> Self {
        self.enable_table_name_fn = enable;
        self
    }

    /// When `true`, nullable columns mapped to `i32`/`i64`/`f32`/`f64`/`String`
    /// are generated as plain fields instead of `Option<_>`.
    pub fn with_no_null_field(mut self, no_null_field: bool) -> Self {
        self.no_null_field = no_null_field;
        self
    }

    /// Enables or disables serde derives on generated entities.
    pub fn with_serde(mut self, is_serde: bool) -> Self {
        self.is_serde = is_serde;
        self
    }

    /// Generates `<out_dir>/<table>.rs` for every table and appends a
    /// `pub mod <table>;` line per table to `mod.rs`.
    ///
    /// Nothing is written when `tables` is empty or when any table has no
    /// columns (a message is printed instead).
    ///
    /// # Panics
    /// Panics if the pool cannot connect, a metadata query fails, or a file
    /// cannot be written.
    pub async fn gen_code(&mut self, tables: Vec<&str>) {
        if tables.is_empty() {
            println!("No tables require code generation");
            return;
        }
        println!("gen tables:{:?} rust code", tables);
        let pool = self
            .init_pool()
            .await
            .expect("mysql pool connection failed");
        // Load every table's columns once: this validates all tables before any
        // file is touched, and the rows are reused for generation instead of
        // querying information_schema a second time.
        let columns = match self.load_table_columns(&pool, &tables).await {
            Some(columns) => columns,
            None => return,
        };
        let out_dir = Path::new(self.out_dir.as_str());
        // `append(true)` already implies write access.
        let mut mod_file = fs::OpenOptions::new()
            .append(true)
            .open(out_dir.join("mod.rs"))
            .expect("create mod.rs failed");
        mod_file
            .write_all(LICENSE_HEADER.as_bytes())
            .expect("write header failed");
        for (table, records) in tables.iter().zip(columns) {
            println!("gen code for table:{}", table);
            self.gen_table_code(out_dir, &mut mod_file, table, &records);
            println!("gen code for table:{} finish", table);
        }
    }

    /// Fetches the columns of every table in `tables`, in the same order.
    ///
    /// Returns `None` as soon as a table yields no columns (after printing
    /// which one).
    ///
    /// # Panics
    /// Panics if a metadata query fails.
    async fn load_table_columns(
        &self,
        pool: &sqlx::MySqlPool,
        tables: &[&str],
    ) -> Option<Vec<Vec<ColumnEntry>>> {
        let mut all = Vec::with_capacity(tables.len());
        for table in tables {
            let records = self
                .get_columns(pool, table)
                .await
                .expect("get table columns failed");
            if records.is_empty() {
                println!("current table:{} has no fields", table);
                return None;
            }
            all.push(records);
        }
        Some(all)
    }

    /// Creates the output directory (and parents) if it does not exist.
    fn create_out_dir(&self) {
        let out_dir = Path::new(self.out_dir.as_str());
        if !out_dir.is_dir() {
            fs::create_dir_all(out_dir).expect("create out_dir failed");
        }
    }

    /// Creates, or truncates, `<out_dir>/mod.rs`.
    fn create_mod_file(&self) {
        let out_dir = Path::new(self.out_dir.as_str());
        fs::File::create(out_dir.join("mod.rs")).expect("create mod.rs failed");
    }

    /// Reads the column metadata of `table` in the current database.
    ///
    /// Rows are ordered by ordinal position: information_schema gives no
    /// ordering guarantee otherwise, and the generated struct must list its
    /// fields in table order.
    async fn get_columns(
        &self,
        pool: &sqlx::MySqlPool,
        table: &str,
    ) -> Result<Vec<ColumnEntry>, sqlx::Error> {
        let sql = format!(
            "SELECT {} FROM information_schema.COLUMNS \
             WHERE table_schema = DATABASE() AND TABLE_NAME = ? \
             ORDER BY ORDINAL_POSITION",
            Self::QUERY_FIELDS.join(",")
        );
        let mut rows = sqlx::query(&sql).bind(table).fetch(pool);
        let mut records = Vec::new();
        while let Some(row) = rows.try_next().await? {
            records.push(ColumnEntry {
                table_name: row.get("table_name"),
                field: row.get("field"),
                data_type: row.get("data_type"),
                field_desc: row.get("field_desc"),
                field_key: row.get("field_key"),
                order_by: row.get("order_by"),
                is_nullable: row.get("is_nullable"),
                max_length: row.get("max_length"),
                numeric_prec: row.get("numeric_prec"),
                numeric_scale: row.get("numeric_scale"),
                extra: row.get("extra"),
                field_comment: row.get("field_comment"),
            });
        }
        Ok(records)
    }

    /// Builds a MySQL pool from the engine's DSN and pool settings.
    async fn init_pool(&self) -> Result<sqlx::MySqlPool, sqlx::Error> {
        MySqlPoolOptions::new()
            .max_connections(self.max_connections)
            .min_connections(self.min_connections)
            .max_lifetime(self.max_lifetime)
            .idle_timeout(self.idle_timeout)
            .acquire_timeout(self.connect_timeout)
            .connect(&self.dsn)
            .await
    }

    /// Registers `table` in `mod.rs` and writes `<out_dir>/<table>.rs`.
    ///
    /// # Panics
    /// Panics if either file cannot be written.
    fn gen_table_code(
        &self,
        out_dir: &Path,
        mod_file: &mut fs::File,
        table: &str,
        records: &[ColumnEntry],
    ) {
        mod_file
            .write_all(format!("pub mod {};\n", table).as_bytes())
            .expect("write mod.rs failed");
        let path = out_dir.join(format!("{}.rs", table));
        self.write_table_file(&path, table, records)
            .unwrap_or_else(|e| panic!("write {} failed: {}", path.display(), e));
    }

    /// Renders the imports, table-name const, entity struct and (optionally)
    /// the `table_name()` impl for `table` into the file at `path`.
    fn write_table_file(
        &self,
        path: &Path,
        table: &str,
        records: &[ColumnEntry],
    ) -> std::io::Result<()> {
        // Many small writes follow; buffer them instead of one syscall each.
        let mut file = std::io::BufWriter::new(fs::File::create(path)?);
        writeln!(file, "{}// gen code for {} table.", LICENSE_HEADER, table)?;
        if self.check_import_duration(records) {
            file.write_all(b"use std::time::Duration;\n")?;
        }
        if self.is_serde {
            file.write_all(b"use serde::{Deserialize, Serialize};\n\n")?;
        }
        let tab_upper = table.to_uppercase();
        write!(
            file,
            "// {}_TABLE for {} table\nconst {}_TABLE: &str = \"{}\";\n\n",
            tab_upper, table, tab_upper, table,
        )?;
        let table_entity_name = camel_case(table);
        writeln!(file, "// {}Entity for {} table", table_entity_name, table)?;
        let derive_block = if self.is_serde {
            "#[derive(Debug, Default, Serialize, Deserialize)]\n"
        } else {
            "#[derive(Debug, Default)]\n"
        };
        file.write_all(derive_block.as_bytes())?;
        writeln!(file, "pub struct {}Entity {{", table_entity_name)?;
        for record in records {
            let data_type = sql_type::get_type(&record.data_type);
            // `no_null_field` forces the basic types to be non-optional.
            let forced_non_null =
                self.no_null_field && Self::NO_NULL_TYPES.iter().any(|t| data_type == *t);
            let ident = Self::field_ident(&record.field);
            if record.is_nullable == "YES" && !forced_non_null {
                println!(
                    "current field:{} is null able,type:{}",
                    record.field, data_type
                );
                writeln!(file, "\tpub {}: Option<{}>,", ident, data_type)?;
            } else {
                writeln!(file, "\tpub {}: {},", ident, data_type)?;
            }
        }
        file.write_all(b"}\n\n")?;
        if self.enable_table_name_fn {
            file.write_all(self.get_tab_fn_tpl(table).as_bytes())?;
        }
        // BufWriter's Drop silently discards flush errors; surface them here.
        file.flush()
    }

    /// Turns a column name into the generated field identifier: lowercased,
    /// and prefixed with `r#` when it collides with a Rust keyword.
    fn field_ident(column: &str) -> String {
        let name = column.to_lowercase();
        if Self::RAW_IDENT_KEYWORDS.contains(&name.as_str()) {
            format!("r#{}", name)
        } else {
            name
        }
    }

    /// True when any column maps to `Duration`, meaning the generated file
    /// needs `use std::time::Duration;`.
    fn check_import_duration(&self, records: &[ColumnEntry]) -> bool {
        records
            .iter()
            .any(|record| sql_type::get_type(&record.data_type) == "Duration")
    }

    /// Renders the `impl <Entity>Entity { fn table_name() }` block, which
    /// returns the generated `<TABLE>_TABLE` const.
    fn get_tab_fn_tpl(&self, table: &str) -> String {
        format!(
            "// impl table_name method for {entity}Entity\n\
             impl {entity}Entity {{\n\tpub fn table_name(&self) -> String {{\n\t\t{upper}_TABLE.to_string()\n\t}}\n}}",
            entity = camel_case(table),
            upper = table.to_uppercase(),
        )
    }
}