use std::{collections::HashMap, fs::File, io::Write, path::Path, sync::OnceLock, time::Duration};
use chrono::{DateTime, Utc};
use duct::cmd;
use fs_err::{self as fs, create_dir_all};
use regex::Regex;
use sea_orm::{
ActiveModelTrait, ConnectOptions, ConnectionTrait, Database, DatabaseBackend,
DatabaseConnection, DbBackend, DbConn, DbErr, EntityTrait, IntoActiveModel, Statement,
};
use sea_orm_migration::MigratorTrait;
use tracing::info;
use super::Result as AppResult;
use crate::{
app::{AppContext, Hooks},
config, doctor, env_vars,
errors::Error,
};
/// Lazily-compiled regex that captures the final path segment of a
/// connection URI — i.e. the database name in `postgres://host/mydb`.
pub static EXTRACT_DB_NAME: OnceLock<Regex> = OnceLock::new();

/// Framework-internal tables that must be excluded from user-facing
/// operations (entity generation and table dumps): the sea-orm migration
/// bookkeeping table and the Postgres/SQLite background-queue tables.
const IGNORED_TABLES: &[&str] = &[
    "seaql_migrations",
    "pg_loco_queue",
    "sqlt_loco_queue",
    "sqlt_loco_queue_lock",
];

/// Returns the shared [`EXTRACT_DB_NAME`] regex, compiling it on first use.
fn get_extract_db_name() -> &'static Regex {
    // The pattern is a literal and therefore cannot fail to compile.
    EXTRACT_DB_NAME.get_or_init(|| Regex::new(r"/([^/]+)$").unwrap())
}
/// A registry of named database connections, for applications that talk to
/// more than one database.
#[derive(Default, Clone, Debug)]
pub struct MultiDb {
    // Maps the configured database name to its live connection.
    pub db: HashMap<String, DatabaseConnection>,
}
impl MultiDb {
pub async fn new(dbs_config: HashMap<String, config::Database>) -> AppResult<Self> {
let mut multi_db = Self::default();
for (db_name, db_config) in dbs_config {
multi_db.db.insert(db_name, connect(&db_config).await?);
}
Ok(multi_db)
}
pub fn get(&self, name: &str) -> AppResult<&DatabaseConnection> {
self.db
.get(name)
.map_or_else(|| Err(Error::Message("db not found".to_owned())), Ok)
}
}
/// Verifies that the given connection is actually usable.
///
/// For Postgres this checks that the current role owns at least one table;
/// a `Disconnected` handle is always rejected. All other backends/variants
/// are accepted without further checks.
///
/// # Errors
///
/// Returns an error when the Postgres user owns no tables, when the
/// connection has been closed, or when the underlying query fails.
#[allow(clippy::match_wildcard_for_single_variants)]
pub async fn verify_access(db: &DatabaseConnection) -> AppResult<()> {
    match db {
        DatabaseConnection::SqlxPostgresPoolConnection(_) => {
            // Owning zero tables usually indicates the wrong role or an
            // empty/foreign schema — surface this early rather than failing
            // later with an opaque permission error.
            let res = db
                .query_all(Statement::from_string(
                    DatabaseBackend::Postgres,
                    "SELECT * FROM pg_catalog.pg_tables WHERE tableowner = current_user;",
                ))
                .await?;
            if res.is_empty() {
                return Err(Error::string(
                    "current user has no access to tables in the database",
                ));
            }
        }
        DatabaseConnection::Disconnected => {
            return Err(Error::string("connection to database has been closed"));
        }
        // SQLite and the remaining pool variants require no ownership check.
        _ => {}
    }
    Ok(())
}
/// Applies the configured database convergence policy at application start.
///
/// Depending on `config`, this either recreates the schema from scratch
/// (`dangerously_recreate`), or runs pending migrations (`auto_migrate`)
/// and optionally truncates tables via the app's [`Hooks::truncate`]
/// (`dangerously_truncate`). Recreation short-circuits: when it runs,
/// migration and truncation are skipped.
///
/// # Errors
///
/// Returns an error if any migration, reset, or truncate step fails.
pub async fn converge<H: Hooks, M: MigratorTrait>(
    ctx: &AppContext,
    config: &config::Database,
) -> AppResult<()> {
    if config.dangerously_recreate {
        info!("recreating schema");
        reset::<M>(&ctx.db).await?;
        // Reset already leaves the schema fully migrated; nothing else to do.
        return Ok(());
    }
    if config.auto_migrate {
        info!("auto migrating");
        migrate::<M>(&ctx.db).await?;
    }
    if config.dangerously_truncate {
        info!("truncating tables");
        // Delegated to the app so it controls which tables are cleared.
        H::truncate(ctx).await?;
    }
    Ok(())
}
/// Establishes a database connection using the pool and timeout settings
/// from `config`.
///
/// For SQLite connections, a set of PRAGMAs is applied to enable foreign
/// keys and tune journaling/caching for typical web workloads.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] when the connection cannot be established
/// or the SQLite PRAGMA statement fails.
pub async fn connect(config: &config::Database) -> Result<DbConn, sea_orm::DbErr> {
    let mut opt = ConnectOptions::new(&config.uri);
    // Timeouts in the config are expressed in milliseconds.
    opt.max_connections(config.max_connections)
        .min_connections(config.min_connections)
        .connect_timeout(Duration::from_millis(config.connect_timeout))
        .idle_timeout(Duration::from_millis(config.idle_timeout))
        .sqlx_logging(config.enable_logging);
    if let Some(acquire_timeout) = config.acquire_timeout {
        opt.acquire_timeout(Duration::from_millis(acquire_timeout));
    }
    let db = Database::connect(opt).await?;
    if db.get_database_backend() == DatabaseBackend::Sqlite {
        // WAL + NORMAL sync is the common durability/throughput trade-off for
        // SQLite behind a web server; foreign_keys is off by default in SQLite
        // and must be enabled per connection.
        db.execute(Statement::from_string(
            DatabaseBackend::Sqlite,
            "
PRAGMA foreign_keys = ON;
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA mmap_size = 134217728;
PRAGMA journal_size_limit = 67108864;
PRAGMA cache_size = 2000;
",
        ))
        .await?;
    }
    Ok(db)
}
/// Creates the database named in `db_uri` (Postgres only).
///
/// The database name is taken from the last path segment of the URI; the
/// actual `CREATE DATABASE` is issued over a connection to the maintenance
/// `postgres` database on the same server.
///
/// # Errors
///
/// Returns an error when the URI is not a `postgres://` URI, when no
/// database name can be extracted from it, or when the connection or the
/// `CREATE DATABASE` statement fails.
pub async fn create(db_uri: &str) -> AppResult<()> {
    if !db_uri.starts_with("postgres://") {
        // Message fixed: this function creates databases, not tables.
        return Err(Error::string(
            "Only Postgres databases are supported for database creation",
        ));
    }
    let db_name = get_extract_db_name()
        .captures(db_uri)
        .and_then(|cap| cap.get(1).map(|db| db.as_str()))
        .ok_or_else(|| {
            // Message fixed: it is the database name (not a table name) that
            // is extracted, and "Postgre" was a typo.
            Error::string(
                "The specified database name was not found in the given Postgres database URI",
            )
        })?;
    // Swap the target database for the maintenance `postgres` database so we
    // can connect before the target exists.
    let conn = get_extract_db_name()
        .replace(db_uri, "/postgres")
        .to_string();
    let db = Database::connect(conn).await?;
    Ok(create_postgres_database(db_name, &db).await?)
}
/// Applies all pending migrations.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] if any migration fails.
pub async fn migrate<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    // `None` means "apply every pending migration".
    M::up(db, None).await
}
/// Rolls back the most recent `steps` migrations.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] if any rollback step fails.
pub async fn down<M: MigratorTrait>(
    db: &DatabaseConnection,
    steps: u32,
) -> Result<(), sea_orm::DbErr> {
    M::down(db, Some(steps)).await
}
/// Prints the applied/pending status of all migrations.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] if the status check fails.
pub async fn status<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    M::status(db).await
}
/// Drops all tables and rebuilds the schema from migrations.
///
/// NOTE(review): `MigratorTrait::fresh` already reapplies all migrations
/// after dropping, so the follow-up `migrate` call looks redundant — confirm
/// before removing, in case it exists to cover partially-applied states.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] if dropping or migrating fails.
pub async fn reset<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    M::fresh(db).await?;
    migrate::<M>(db).await
}
use sea_orm::EntityName;
use serde_json::{json, Value};
/// Seeds the entity's table from a YAML file of rows.
///
/// The file at `path` is parsed as a YAML sequence; each entry is converted
/// into the entity's active model and inserted one at a time. Afterwards the
/// table's auto-increment sequence is realigned so future inserts do not
/// collide with explicitly-seeded ids.
///
/// # Errors
///
/// Returns an error when the file cannot be opened or parsed, when a row
/// fails to deserialize or insert, or when the sequence reset fails.
#[allow(clippy::type_repetition_in_bounds)]
pub async fn seed<A>(db: &DatabaseConnection, path: &str) -> crate::Result<()>
where
    <<A as ActiveModelTrait>::Entity as EntityTrait>::Model: IntoActiveModel<A>,
    for<'de> <<A as ActiveModelTrait>::Entity as EntityTrait>::Model: serde::de::Deserialize<'de>,
    A: ActiveModelTrait + Send + Sync,
    sea_orm::Insert<A>: Send + Sync,
    <A as ActiveModelTrait>::Entity: EntityName,
{
    let seed_data: Vec<Value> = serde_yaml::from_reader(File::open(path)?)?;
    // Insert row-by-row so a malformed entry fails with its own error rather
    // than aborting a single bulk statement opaquely.
    for row in seed_data {
        let model = A::from_json(row)?;
        A::Entity::insert(model).exec(db).await?;
    }
    let table_name = A::Entity::default().table_name().to_string();
    let db_backend = db.get_database_backend();
    // Seeded rows may carry explicit ids; bump the sequence past them.
    reset_autoincrement(db_backend, &table_name, db).await?;
    Ok(())
}
/// Reports whether `table_name` has a column named `id`.
///
/// Postgres is answered via `information_schema.columns`, SQLite via
/// `pragma_table_info`; MySQL is unsupported.
///
/// # Errors
///
/// Returns an error for the MySQL backend or when the lookup query fails.
async fn has_id_column(
    db: &DatabaseConnection,
    db_backend: &DatabaseBackend,
    table_name: &str,
) -> crate::Result<bool> {
    // NOTE(review): `table_name` is interpolated into SQL unescaped — assumed
    // to come from entity definitions, not user input; confirm at call sites.
    let result = match db_backend {
        DatabaseBackend::Postgres => {
            let query = format!(
                "SELECT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = '{table_name}'
AND column_name = 'id'
)"
            );
            let result = db
                .query_one(Statement::from_string(DatabaseBackend::Postgres, query))
                .await?;
            // A failed decode is treated the same as "no id column".
            result.is_some_and(|row| row.try_get::<bool>("", "exists").unwrap_or(false))
        }
        DatabaseBackend::Sqlite => {
            let query = format!(
                "SELECT COUNT(*) as count
FROM pragma_table_info('{table_name}')
WHERE name = 'id'"
            );
            let result = db
                .query_one(Statement::from_string(DatabaseBackend::Sqlite, query))
                .await?;
            result.is_some_and(|row| row.try_get::<i32>("", "count").unwrap_or(0) > 0)
        }
        DatabaseBackend::MySql => {
            return Err(Error::Message(
                "Unsupported database backend: MySQL".to_string(),
            ))
        }
    };
    Ok(result)
}
async fn is_auto_increment(
db: &DatabaseConnection,
db_backend: &DatabaseBackend,
table_name: &str,
) -> crate::Result<bool> {
let result = match db_backend {
DatabaseBackend::Postgres => {
let query = format!(
"SELECT pg_get_serial_sequence('{table_name}', 'id') IS NOT NULL as is_serial"
);
let result = db
.query_one(Statement::from_string(DatabaseBackend::Postgres, query))
.await?;
result.is_some_and(|row| row.try_get::<bool>("", "is_serial").unwrap_or(false))
}
DatabaseBackend::Sqlite => {
let query =
format!("SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'");
let result = db
.query_one(Statement::from_string(DatabaseBackend::Sqlite, query))
.await?;
result.is_some_and(|row| {
row.try_get::<String>("", "sql")
.is_ok_and(|sql| sql.to_lowercase().contains("autoincrement"))
})
}
DatabaseBackend::MySql => {
return Err(Error::Message(
"Unsupported database backend: MySQL".to_string(),
))
}
};
Ok(result)
}
/// Realigns a table's auto-increment sequence with its current `MAX(id)`.
///
/// Used after seeding rows with explicit ids, so that subsequent inserts do
/// not collide. Tables without an `id` column or without auto-increment
/// semantics are silently skipped.
///
/// # Errors
///
/// Returns an error for the MySQL backend or when any of the inspection or
/// update statements fails.
pub async fn reset_autoincrement(
    db_backend: DatabaseBackend,
    table_name: &str,
    db: &DatabaseConnection,
) -> crate::Result<()> {
    // Guard clauses: only auto-incrementing `id` columns need fixing.
    if !has_id_column(db, &db_backend, table_name).await? {
        return Ok(());
    }
    if !is_auto_increment(db, &db_backend, table_name).await? {
        return Ok(());
    }
    match db_backend {
        DatabaseBackend::Postgres => {
            // `false` as the third setval arg: next nextval() returns this
            // value itself (MAX(id) + 1).
            let query_str = format!(
                "SELECT setval(pg_get_serial_sequence('{table_name}', 'id'), COALESCE(MAX(id), 0) \
                 + 1, false) FROM {table_name}"
            );
            // No bind parameters, so `from_string` fits better than
            // `from_sql_and_values` with an empty values vec.
            db.execute(Statement::from_string(DatabaseBackend::Postgres, query_str))
                .await?;
        }
        DatabaseBackend::Sqlite => {
            let query_str = format!(
                "UPDATE sqlite_sequence SET seq = (SELECT MAX(id) FROM {table_name}) WHERE name = \
                 '{table_name}'"
            );
            db.execute(Statement::from_string(DatabaseBackend::Sqlite, query_str))
                .await?;
        }
        DatabaseBackend::MySql => {
            return Err(Error::Message(
                "Unsupported database backend: MySQL".to_string(),
            ))
        }
    }
    Ok(())
}
/// Regenerates entity models under `src/models/_entities` by shelling out to
/// the `sea-orm-cli` binary, then post-processes the generated files.
///
/// Preconditions (checked via `doctor`): the CLI binary is installed and the
/// configured database is reachable. Framework-internal tables are excluded
/// from generation.
///
/// # Errors
///
/// Returns an error when the doctor checks fail, the CLI invocation fails,
/// or the generated files cannot be post-processed.
pub async fn entities<M: MigratorTrait>(ctx: &AppContext) -> AppResult<String> {
    doctor::check_seaorm_cli()?.to_result()?;
    doctor::check_db(&ctx.config.database).await.to_result()?;
    let out = cmd!(
        "sea-orm-cli",
        "generate",
        "entity",
        "--with-serde",
        "both",
        "--output-dir",
        "src/models/_entities",
        "--database-url",
        &ctx.config.database.uri,
        "--ignore-tables",
        IGNORED_TABLES.join(","),
    )
    // Merge stderr into stdout so CLI diagnostics appear in the returned log.
    .stderr_to_stdout()
    .run()
    .map_err(|err| {
        Error::Message(format!(
            "failed to generate entity using sea-orm-cli binary. error details: `{err}`",
        ))
    })?;
    // Strip the default ActiveModelBehavior impls and scaffold model modules.
    fix_entities()?;
    Ok(String::from_utf8_lossy(&out.stdout).to_string())
}
/// Post-processes freshly generated entity files.
///
/// For every generated entity (excluding `mod.rs`/`prelude.rs`):
/// 1. removes the default `impl ActiveModelBehavior for ActiveModel {}`
///    line from the `_entities` file, and
/// 2. if no hand-written model file exists yet under `src/models/`,
///    scaffolds one with a timestamp-updating `ActiveModelBehavior` and
///    empty extension impls, registering its module in `src/models/mod.rs`.
///
/// Existing model files are never overwritten, so user code is preserved
/// across regenerations.
///
/// # Errors
///
/// Returns an error on any filesystem read/write failure or when a file
/// name cannot be decoded.
fn fix_entities() -> AppResult<()> {
    let dir = fs::read_dir("src/models/_entities")?
        .flatten()
        .filter(|ent| {
            ent.path().is_file() && ent.file_name() != "mod.rs" && ent.file_name() != "prelude.rs"
        })
        .map(|ent| ent.path())
        .collect::<Vec<_>>();
    // The exact line sea-orm-cli emits for the default (empty) behavior impl.
    let activemodel_exp = "impl ActiveModelBehavior for ActiveModel {}";
    let mut cleaned_entities = Vec::new();
    for file in &dir {
        let content = fs::read_to_string(file)?;
        if content.contains(activemodel_exp) {
            // Drop the default impl line; the scaffolded model file supplies
            // its own ActiveModelBehavior instead.
            let content = content
                .lines()
                .filter(|line| !line.contains(activemodel_exp))
                .collect::<Vec<_>>()
                .join("\n");
            fs::write(file, content)?;
            cleaned_entities.push(file);
        }
    }
    let mut models_mod = fs::read_to_string("src/models/mod.rs")?;
    for entity_file in cleaned_entities {
        let new_file = Path::new("src/models").join(
            entity_file
                .file_name()
                .ok_or_else(|| Error::string("cannot extract file name"))?,
        );
        // Only scaffold when the user has no model file yet.
        if !new_file.exists() {
            let module = new_file
                .file_stem()
                .ok_or_else(|| Error::string("cannot extract file stem"))?
                .to_str()
                .ok_or_else(|| Error::string("cannot extract file stem"))?;
            let module_pascal = heck::AsPascalCase(module);
            fs::write(
                &new_file,
                format!(
                    r"use sea_orm::entity::prelude::*;
pub use super::_entities::{module}::{{ActiveModel, Model, Entity}};
pub type {module_pascal} = Entity;
#[async_trait::async_trait]
impl ActiveModelBehavior for ActiveModel {{
async fn before_save<C>(self, _db: &C, insert: bool) -> std::result::Result<Self, DbErr>
where
C: ConnectionTrait,
{{
if !insert && self.updated_at.is_unchanged() {{
let mut this = self;
this.updated_at = sea_orm::ActiveValue::Set(chrono::Utc::now().into());
Ok(this)
}} else {{
Ok(self)
}}
}}
}}
// implement your read-oriented logic here
impl Model {{}}
// implement your write-oriented logic here
impl ActiveModel {{}}
// implement your custom finders, selectors oriented logic here
impl Entity {{}}
"
                ),
            )?;
            // NOTE(review): substring check — `mod user` also matches inside
            // `mod users`, which could skip a needed declaration; confirm.
            if !models_mod.contains(&format!("mod {module}")) {
                models_mod.push_str(&format!("pub mod {module};\n"));
            }
        }
    }
    fs::write("src/models/mod.rs", models_mod)?;
    Ok(())
}
/// Removes all rows from the entity's table.
///
/// Despite the name, this issues a `DELETE` (via `delete_many`) rather than
/// a SQL `TRUNCATE`, so it works uniformly across backends. The `_: T`
/// parameter only carries the entity type for inference at call sites.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] if the delete fails.
pub async fn truncate_table<T>(db: &DatabaseConnection, _: T) -> Result<(), sea_orm::DbErr>
where
    T: EntityTrait,
{
    T::delete_many().exec(db).await?;
    Ok(())
}
/// Runs the application's seed hook against the seed files at `path`.
///
/// # Errors
///
/// Propagates any error returned by the app's [`Hooks::seed`].
pub async fn run_app_seed<H: Hooks>(ctx: &AppContext, path: &Path) -> AppResult<()> {
    H::seed(ctx, path).await
}
/// Issues `CREATE DATABASE {db_name}` on the given Postgres connection.
///
/// Creation options default to `ENCODING='UTF8'` and can be overridden via
/// the `POSTGRES_DB_OPTIONS` environment variable.
///
/// NOTE(review): `db_name` is interpolated into the statement as a bare
/// identifier (no quoting/escaping) — assumed to originate from a trusted
/// connection URI, not user input; confirm at call sites.
///
/// # Errors
///
/// Returns a [`sea_orm::DbErr`] if the statement fails (e.g. the database
/// already exists or the role lacks CREATEDB).
async fn create_postgres_database(
    db_name: &str,
    db: &DatabaseConnection,
) -> Result<(), sea_orm::DbErr> {
    let with_options = env_vars::get_or_default(env_vars::POSTGRES_DB_OPTIONS, "ENCODING='UTF8'");
    let query = format!("CREATE DATABASE {db_name} WITH {with_options}");
    tracing::info!(query, "creating postgres database");
    db.execute(sea_orm::Statement::from_string(
        sea_orm::DatabaseBackend::Postgres,
        query,
    ))
    .await?;
    Ok(())
}
/// Lists the application's tables, excluding framework-internal ones
/// ([`IGNORED_TABLES`]).
///
/// Postgres is answered from `information_schema.tables` (public schema),
/// SQLite from `sqlite_master`; MySQL is unsupported.
///
/// # Errors
///
/// Returns an error for the MySQL backend or when the listing query fails.
pub async fn get_tables(db: &DatabaseConnection) -> AppResult<Vec<String>> {
    // Resolve the backend once; the original re-queried it for every row
    // inside the filter closure.
    let backend = db.get_database_backend();
    let (query, col) = match backend {
        DatabaseBackend::MySql => {
            return Err(Error::Message(
                "Unsupported database backend: MySQL".to_string(),
            ))
        }
        DatabaseBackend::Postgres => (
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'",
            "table_name",
        ),
        DatabaseBackend::Sqlite => (
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
            "name",
        ),
    };
    let result = db
        .query_all(Statement::from_string(backend, query.to_string()))
        .await?;
    Ok(result
        .into_iter()
        .filter_map(|row| {
            // Undecodable rows are skipped, as before.
            row.try_get::<String>("", col)
                .ok()
                .filter(|table_name| !IGNORED_TABLES.contains(&table_name.as_str()))
        })
        .collect())
}
pub async fn dump_tables(
db: &DatabaseConnection,
to: &Path,
only_tables: Option<Vec<String>>,
) -> AppResult<()> {
tracing::debug!("getting tables from the database");
let tables = get_tables(db).await?;
tracing::info!(tables = ?tables, "found tables");
for table in tables {
if let Some(ref only_tables) = only_tables {
if !only_tables.contains(&table) {
tracing::info!(table, "skipping table as it is not in the specified list");
continue;
}
}
tracing::info!(table, "get table data");
let data_result = db
.query_all(Statement::from_string(
db.get_database_backend(),
format!(r#"SELECT * FROM "{table}""#),
))
.await?;
tracing::info!(
table,
rows_fetched = data_result.len(),
"fetched rows from table"
);
let mut table_data: Vec<HashMap<String, serde_json::Value>> = Vec::new();
if !to.exists() {
tracing::info!("the specified dump folder does not exist. creating the folder now");
create_dir_all(to)?;
}
for row in data_result {
let mut row_data: HashMap<String, serde_json::Value> = HashMap::new();
for col_name in row.column_names() {
let value_result = row
.try_get::<String>("", &col_name)
.map(serde_json::Value::String)
.or_else(|_| {
row.try_get::<i8>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<i16>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<i32>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<i64>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<f32>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<f64>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<uuid::Uuid>("", &col_name)
.map(|v| serde_json::Value::String(v.to_string()))
})
.or_else(|_| {
row.try_get::<DateTime<Utc>>("", &col_name)
.map(|v| serde_json::Value::String(v.to_rfc3339()))
})
.or_else(|_| {
row.try_get::<serde_json::Value>("", &col_name)
.map(serde_json::Value::from)
})
.or_else(|_| {
row.try_get::<bool>("", &col_name)
.map(serde_json::Value::Bool)
})
.ok();
if let Some(value) = value_result {
row_data.insert(col_name, value);
}
}
table_data.push(row_data);
}
let data = serde_yaml::to_string(&table_data)?;
let file_db_content_path = to.join(format!("{table}.yaml"));
let mut file = File::create(&file_db_content_path)?;
file.write_all(data.as_bytes())?;
tracing::info!(table, file_db_content_path = %file_db_content_path.display(), "table data written to YAML file");
}
tracing::info!("dumping tables process completed successfully");
Ok(())
}
/// Writes a JSON description of the current database schema to `fname`.
///
/// Postgres and MySQL produce one entry per column (table/column/type);
/// SQLite produces one entry per table with its raw `CREATE TABLE` SQL,
/// since SQLite exposes no column-level information schema.
///
/// # Errors
///
/// Returns an error when the schema query, row decoding, JSON serialization,
/// or file write fails.
pub async fn dump_schema(ctx: &AppContext, fname: &str) -> crate::Result<()> {
    let db = &ctx.db;
    let schema_info = match db.get_database_backend() {
        DbBackend::Postgres => {
            let query = r"
SELECT table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema = 'public'
ORDER BY table_name, ordinal_position;
";
            let stmt = Statement::from_string(DbBackend::Postgres, query.to_owned());
            let rows = db.query_all(stmt).await?;
            rows.into_iter()
                .map(|row| {
                    Ok(json!({
                        "table": row.try_get::<String>("", "table_name")?,
                        "column": row.try_get::<String>("", "column_name")?,
                        "type": row.try_get::<String>("", "data_type")?,
                    }))
                })
                .collect::<Result<Vec<serde_json::Value>, DbErr>>()?
        }
        DbBackend::MySql => {
            // MySQL's information schema uses upper-case column names.
            let query = r"
SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
ORDER BY TABLE_NAME, ORDINAL_POSITION;
";
            let stmt = Statement::from_string(DbBackend::MySql, query.to_owned());
            let rows = db.query_all(stmt).await?;
            rows.into_iter()
                .map(|row| {
                    Ok(json!({
                        "table": row.try_get::<String>("", "TABLE_NAME")?,
                        "column": row.try_get::<String>("", "COLUMN_NAME")?,
                        "type": row.try_get::<String>("", "COLUMN_TYPE")?,
                    }))
                })
                .collect::<Result<Vec<serde_json::Value>, DbErr>>()?
        }
        DbBackend::Sqlite => {
            let query = r"
SELECT name AS table_name, sql AS table_sql
FROM sqlite_master
WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
ORDER BY name;
";
            let stmt = Statement::from_string(DbBackend::Sqlite, query.to_owned());
            let rows = db.query_all(stmt).await?;
            rows.into_iter()
                .map(|row| {
                    Ok(json!({
                        "table": row.try_get::<String>("", "table_name")?,
                        "sql": row.try_get::<String>("", "table_sql")?,
                    }))
                })
                .collect::<Result<Vec<serde_json::Value>, DbErr>>()?
        }
    };
    let schema_json = serde_json::to_string_pretty(&schema_info)?;
    std::fs::write(fname, schema_json)?;
    Ok(())
}