use std::{collections::HashMap, fs::File, path::Path, sync::OnceLock, time::Duration};
use duct::cmd;
use fs_err as fs;
use regex::Regex;
use sea_orm::{
ActiveModelTrait, ConnectOptions, ConnectionTrait, Database, DatabaseBackend,
DatabaseConnection, DbConn, EntityTrait, IntoActiveModel, Statement,
};
use sea_orm_migration::MigratorTrait;
use tracing::info;
use super::Result as AppResult;
use crate::{
app::{AppContext, Hooks},
config, doctor,
errors::Error,
};
/// Lazily-initialized cache for the regex handed out by `get_extract_db_name`.
pub static EXTRACT_DB_NAME: OnceLock<Regex> = OnceLock::new();
/// Tables excluded from entity generation (joined with `,` and passed to
/// `sea-orm-cli generate entity --ignore-tables`): sea-orm's migration bookkeeping table
/// plus tables that appear to back loco's job queue (Postgres `pg_` / SQLite `sqlt_`
/// prefixed) — presumably internal to the framework, not application models.
const IGNORED_TABLES: &[&str] = &[
"seaql_migrations",
"pg_loco_queue",
"sqlt_loco_queue",
"sqlt_loco_queue_lock",
];
/// Returns the shared regex matching the final `/segment` of a URI.
///
/// Capture group 1 is that segment without the leading slash. `create` uses it both to
/// read the database name and to swap it for `postgres`. The regex is compiled once, on
/// first use.
fn get_extract_db_name() -> &'static Regex {
    EXTRACT_DB_NAME.get_or_init(|| {
        Regex::new(r"/([^/]+)$").expect("the last-path-segment pattern is a valid regex literal")
    })
}
/// A collection of named database connections.
#[derive(Default, Clone, Debug)]
pub struct MultiDb {
/// Open connections keyed by name (`MultiDb::new` uses the keys of its config map).
pub db: HashMap<String, DatabaseConnection>,
}
impl MultiDb {
pub async fn new(dbs_config: HashMap<String, config::Database>) -> AppResult<Self> {
let mut multi_db = Self::default();
for (db_name, db_config) in dbs_config {
multi_db.db.insert(db_name, connect(&db_config).await?);
}
Ok(multi_db)
}
pub fn get(&self, name: &str) -> AppResult<&DatabaseConnection> {
self.db
.get(name)
.map_or_else(|| Err(Error::Message("db not found".to_owned())), Ok)
}
}
#[allow(clippy::match_wildcard_for_single_variants)]
pub async fn verify_access(db: &DatabaseConnection) -> AppResult<()> {
match db {
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
let res = db
.query_all(Statement::from_string(
DatabaseBackend::Postgres,
"SELECT * FROM pg_catalog.pg_tables WHERE tableowner = current_user;",
))
.await?;
if res.is_empty() {
return Err(Error::string(
"current user has no access to tables in the database",
));
}
}
DatabaseConnection::Disconnected => {
return Err(Error::string("connection to database has been closed"));
}
_ => {}
}
Ok(())
}
/// Brings the database in line with the flags in `config`.
///
/// - `dangerously_recreate`: rebuilds the schema with [`reset`], and nothing else runs.
/// - Otherwise, `auto_migrate` applies pending migrations with [`migrate`]. After that,
///   `dangerously_truncate` runs the application's `H::truncate` hook.
///
/// # Errors
///
/// Propagates the first failure from resetting, migrating or truncating.
pub async fn converge<H: Hooks, M: MigratorTrait>(
    db: &DatabaseConnection,
    config: &config::Database,
) -> AppResult<()> {
    if config.dangerously_recreate {
        info!("recreating schema");
        reset::<M>(db).await?;
    } else {
        if config.auto_migrate {
            info!("auto migrating");
            migrate::<M>(db).await?;
        }
        if config.dangerously_truncate {
            info!("truncating tables");
            H::truncate(db).await?;
        }
    }
    Ok(())
}
/// Opens a connection pool described by `config`.
///
/// Pool bounds, timeouts and sqlx logging are taken directly from the config. The timeout
/// values are interpreted as milliseconds, and `acquire_timeout` is applied only when set.
/// For SQLite, a fixed batch of `PRAGMA`s is executed right after connecting.
///
/// NOTE(review): in SQLite, `foreign_keys`, `synchronous`, `mmap_size` and `cache_size` are
/// per-connection settings. Here they are executed once through the pool, so presumably only
/// the pooled connection that runs them is affected; `journal_mode = WAL` persists in the
/// database file. Consider applying them through the driver's connect options — confirm.
///
/// # Errors
///
/// Returns the underlying [`sea_orm::DbErr`] if connecting or running the pragmas fails.
pub async fn connect(config: &config::Database) -> Result<DbConn, sea_orm::DbErr> {
    const SQLITE_PRAGMAS: &str = "\n\
        PRAGMA foreign_keys = ON;\n\
        PRAGMA journal_mode = WAL;\n\
        PRAGMA synchronous = NORMAL;\n\
        PRAGMA mmap_size = 134217728;\n\
        PRAGMA journal_size_limit = 67108864;\n\
        PRAGMA cache_size = 2000;\n";

    let mut options = ConnectOptions::new(&config.uri);
    options.max_connections(config.max_connections);
    options.min_connections(config.min_connections);
    options.connect_timeout(Duration::from_millis(config.connect_timeout));
    options.idle_timeout(Duration::from_millis(config.idle_timeout));
    options.sqlx_logging(config.enable_logging);
    if let Some(millis) = config.acquire_timeout {
        options.acquire_timeout(Duration::from_millis(millis));
    }

    let connection = Database::connect(options).await?;
    if connection.get_database_backend() == DatabaseBackend::Sqlite {
        connection
            .execute(Statement::from_string(
                DatabaseBackend::Sqlite,
                SQLITE_PRAGMAS,
            ))
            .await?;
    }
    Ok(connection)
}
/// Creates the Postgres database named in `db_uri`.
///
/// The database name is the last path segment of the URI; any `?query` suffix is not part
/// of the name. `CREATE DATABASE` is issued over a connection to the `postgres` maintenance
/// database on the same server. That connection reuses the URI's credentials, host and
/// query parameters (e.g. `sslmode`).
///
/// # Errors
///
/// - `db_uri` does not start with `postgres://` or `postgresql://`;
/// - the URI has no path segment to use as the database name;
/// - connecting to the maintenance database or running `CREATE DATABASE` fails.
pub async fn create(db_uri: &str) -> AppResult<()> {
    let is_postgres = ["postgres://", "postgresql://"]
        .iter()
        .any(|scheme| db_uri.starts_with(scheme));
    if !is_postgres {
        return Err(Error::string(
            "Only Postgres databases are supported for database creation",
        ));
    }

    // Split off the query string (kept with its leading `?`) so it is neither mistaken for
    // part of the database name nor lost from the maintenance-connection URI.
    let (base_uri, query) = db_uri
        .find('?')
        .map_or((db_uri, ""), |idx| db_uri.split_at(idx));

    // Without a path after `scheme://`, the regex would pick up the host as the "name".
    let has_path = base_uri
        .split_once("://")
        .is_some_and(|(_, rest)| rest.contains('/'));

    let db_name = has_path
        .then(|| get_extract_db_name().captures(base_uri))
        .flatten()
        .and_then(|cap| cap.get(1))
        .map(|m| m.as_str())
        .ok_or_else(|| {
            Error::string(
                "The specified database name was not found in the given Postgres database URI",
            )
        })?;

    let conn = format!(
        "{}{query}",
        get_extract_db_name().replace(base_uri, "/postgres")
    );
    let db = Database::connect(conn).await?;
    Ok(create_postgres_database(db_name, &db).await?)
}
/// Applies every pending migration of `M` (`M::up` without a step limit).
///
/// # Errors
///
/// Returns the [`sea_orm::DbErr`] raised by the migrator.
pub async fn migrate<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
M::up(db, None).await
}
/// Rolls back the most recent `steps` migrations of `M` (`M::down` with `Some(steps)`).
///
/// # Errors
///
/// Returns the [`sea_orm::DbErr`] raised by the migrator.
pub async fn down<M: MigratorTrait>(
db: &DatabaseConnection,
steps: u32,
) -> Result<(), sea_orm::DbErr> {
M::down(db, Some(steps)).await
}
/// Runs `M::status` against `db`.
///
/// NOTE(review): no status data is returned to the caller. sea-orm-migration presumably
/// reports applied and pending migrations through its own logging — confirm.
///
/// # Errors
///
/// Returns the [`sea_orm::DbErr`] raised by the migrator.
pub async fn status<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
M::status(db).await
}
/// Rebuilds the schema from scratch.
///
/// `M::fresh` runs first. sea-orm-migration documents it as dropping all tables and
/// reapplying every migration. Pending migrations are then applied once more with
/// `M::up(db, None)`.
///
/// # Errors
///
/// Returns the first [`sea_orm::DbErr`] raised by either step.
pub async fn reset<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    M::fresh(db).await?;
    // Expected to be a no-op right after `fresh`; kept as a safety net.
    M::up(db, None).await
}
use sea_orm::EntityName;
use serde_json::Value;
/// Inserts the rows listed in the YAML file at `path` into the table of `A`'s entity.
///
/// The file must hold a sequence of values. Each value is turned into an active model via
/// `A::from_json` and inserted on its own. Afterwards the table's `id` counter is realigned
/// with [`reset_autoincrement`], so later inserts do not collide with the seeded ids.
///
/// # Errors
///
/// Fails when the file cannot be opened or parsed, when a row cannot be converted or
/// inserted, or when resetting the auto-increment counter fails.
// The where-clause has to spell out bounds on the same associated types more than once,
// which trips `type_repetition_in_bounds`.
#[allow(clippy::type_repetition_in_bounds)]
pub async fn seed<A>(db: &DatabaseConnection, path: &str) -> crate::Result<()>
where
    <<A as ActiveModelTrait>::Entity as EntityTrait>::Model: IntoActiveModel<A>,
    for<'de> <<A as ActiveModelTrait>::Entity as EntityTrait>::Model: serde::de::Deserialize<'de>,
    A: ActiveModelTrait + Send + Sync,
    sea_orm::Insert<A>: Send + Sync,
    <A as ActiveModelTrait>::Entity: EntityName,
{
    let reader = File::open(path)?;
    let rows: Vec<Value> = serde_yaml::from_reader(reader)?;

    for json in rows {
        let active_model = A::from_json(json)?;
        A::Entity::insert(active_model).exec(db).await?;
    }

    let table = A::Entity::default().table_name().to_string();
    reset_autoincrement(db.get_database_backend(), &table, db).await
}
/// Realigns the `id` auto-increment counter of `table_name` with the rows it holds.
///
/// Use this after inserting rows with explicit ids (e.g. from seed files), so the next
/// generated id does not collide with an existing one.
///
/// - Postgres: sets the sequence behind `id` so that the next value is `MAX(id) + 1`, or
///   `1` for an empty table.
/// - SQLite: sets the table's `sqlite_sequence` entry to `MAX(id)`.
///
/// The table name is quoted as an SQL identifier, so reserved words (e.g. `user`) and
/// mixed-case names resolve to the same table sea-orm addresses.
///
/// # Errors
///
/// Returns an error for MySQL, which is unsupported, or when the update statement fails.
pub async fn reset_autoincrement(
    db_backend: DatabaseBackend,
    table_name: &str,
    db: &DatabaseConnection,
) -> crate::Result<()> {
    // Identifier quoting: wrap in double quotes, doubling any embedded double quote.
    let ident = format!("\"{}\"", table_name.replace('"', "\"\""));

    let query_str = match db_backend {
        DatabaseBackend::Postgres => {
            // `pg_get_serial_sequence` takes the table as a *string* that is parsed as an
            // identifier, so pass the quoted identifier inside an escaped string literal to
            // keep its case. `false` makes the next `nextval` return exactly the set value.
            let ident_literal = ident.replace('\'', "''");
            format!(
                "SELECT setval(pg_get_serial_sequence('{ident_literal}', 'id'), \
                 COALESCE(MAX(id), 0) + 1, false) FROM {ident}"
            )
        }
        DatabaseBackend::Sqlite => {
            // `sqlite_sequence.name` holds the bare table name, compared as a string.
            let name_literal = table_name.replace('\'', "''");
            format!(
                "UPDATE sqlite_sequence SET seq = (SELECT MAX(id) FROM {ident}) \
                 WHERE name = '{name_literal}'"
            )
        }
        DatabaseBackend::MySql => {
            return Err(Error::Message(
                "Unsupported database backend: MySQL".to_string(),
            ));
        }
    };

    db.execute(Statement::from_sql_and_values(
        db_backend,
        &query_str,
        vec![],
    ))
    .await?;
    Ok(())
}
pub async fn entities<M: MigratorTrait>(ctx: &AppContext) -> AppResult<String> {
doctor::check_seaorm_cli()?.to_result()?;
doctor::check_db(&ctx.config.database).await.to_result()?;
let out = cmd!(
"sea-orm-cli",
"generate",
"entity",
"--with-serde",
"both",
"--output-dir",
"src/models/_entities",
"--database-url",
&ctx.config.database.uri,
"--ignore-tables",
IGNORED_TABLES.join(","),
)
.stderr_to_stdout()
.run()
.map_err(|err| {
Error::Message(format!(
"failed to generate entity using sea-orm-cli binary. error details: `{err}`",
))
})?;
fix_entities()?;
Ok(String::from_utf8_lossy(&out.stdout).to_string())
}
/// Post-processes the entity files generated by `sea-orm-cli` in `src/models/_entities`.
///
/// Some generated files (other than `mod.rs` and `prelude.rs`) contain the empty
/// `impl ActiveModelBehavior for ActiveModel {}`. For each of those, the line is removed,
/// the file's trailing newline is preserved, and then:
///
/// - `src/models/<entity>.rs` is created, if it does not exist yet. It gets an
///   `ActiveModelBehavior` impl whose `before_save` sets `updated_at` on updates.
/// - `pub mod <entity>;` is appended to `src/models/mod.rs`, unless that module is already
///   declared there. This happens only when the companion file was newly created.
///
/// # Errors
///
/// Returns an error on any filesystem failure, or when a file name or stem cannot be
/// extracted as UTF-8.
fn fix_entities() -> AppResult<()> {
    let activemodel_exp = "impl ActiveModelBehavior for ActiveModel {}";
    let dir = fs::read_dir("src/models/_entities")?
        .flatten()
        .filter(|ent| {
            ent.path().is_file() && ent.file_name() != "mod.rs" && ent.file_name() != "prelude.rs"
        })
        .map(|ent| ent.path())
        .collect::<Vec<_>>();

    let mut cleaned_entities = Vec::new();
    for file in &dir {
        let content = fs::read_to_string(file)?;
        if content.contains(activemodel_exp) {
            let mut cleaned = content
                .lines()
                .filter(|line| !line.contains(activemodel_exp))
                .collect::<Vec<_>>()
                .join("\n");
            // `lines()` + `join` drops the final newline; keep the file newline-terminated.
            if content.ends_with('\n') {
                cleaned.push('\n');
            }
            fs::write(file, cleaned)?;
            cleaned_entities.push(file);
        }
    }

    let mut models_mod = fs::read_to_string("src/models/mod.rs")?;
    for entity_file in cleaned_entities {
        let new_file = Path::new("src/models").join(
            entity_file
                .file_name()
                .ok_or_else(|| Error::string("cannot extract file name"))?,
        );
        if !new_file.exists() {
            let module = new_file
                .file_stem()
                .ok_or_else(|| Error::string("cannot extract file stem"))?
                .to_str()
                .ok_or_else(|| Error::string("cannot extract file stem"))?;
            let module_pascal = heck::AsPascalCase(module);
            fs::write(
                &new_file,
                format!(
                    r"use sea_orm::entity::prelude::*;
use super::_entities::{module}::{{ActiveModel, Entity}};
pub type {module_pascal} = Entity;
#[async_trait::async_trait]
impl ActiveModelBehavior for ActiveModel {{
// extend activemodel below (keep comment for generators)
async fn before_save<C>(self, _db: &C, insert: bool) -> std::result::Result<Self, DbErr>
where
C: ConnectionTrait,
{{
if !insert && self.updated_at.is_unchanged() {{
let mut this = self;
this.updated_at = sea_orm::ActiveValue::Set(chrono::Utc::now().into());
Ok(this)
}} else {{
Ok(self)
}}
}}
}}
"
                ),
            )?;
            if !declares_module(&models_mod, module) {
                // Don't glue the new declaration onto an unterminated last line.
                if !models_mod.is_empty() && !models_mod.ends_with('\n') {
                    models_mod.push('\n');
                }
                models_mod.push_str(&format!("pub mod {module};\n"));
            }
        }
    }
    fs::write("src/models/mod.rs", models_mod)?;
    Ok(())
}

/// Returns `true` when some line of `source` declares `mod <module>;`, bare or with a
/// visibility qualifier such as `pub`. A trailing `//` comment on that line is ignored.
///
/// Whole declarations are matched, not the substring `mod <module>`. With a substring check,
/// an existing `pub mod users_roles;` would hide a missing `pub mod users;`.
fn declares_module(source: &str, module: &str) -> bool {
    let decl = format!("mod {module};");
    source.lines().any(|line| {
        let code = line.split_once("//").map_or(line, |(code, _)| code).trim();
        code == decl
            || code
                .strip_suffix(decl.as_str())
                .is_some_and(|prefix| prefix.ends_with(char::is_whitespace))
    })
}
/// Deletes every row of entity `T` with an unfiltered `delete_many`. This is a `DELETE`,
/// not an SQL `TRUNCATE`.
///
/// The second argument only selects `T` at the call site; its value is ignored.
///
/// # Errors
///
/// Returns the [`sea_orm::DbErr`] raised by the delete.
pub async fn truncate_table<T>(db: &DatabaseConnection, _: T) -> Result<(), sea_orm::DbErr>
where
    T: EntityTrait,
{
    let delete_all = T::delete_many();
    delete_all.exec(db).await.map(|_deleted| ())
}
/// Seeds the database by delegating to the application's `H::seed` hook with `path`.
///
/// NOTE(review): the `Hooks` implementation decides whether `path` is a directory of seed
/// files or a single file — confirm against the app's hook.
pub async fn run_app_seed<H: Hooks>(db: &DatabaseConnection, path: &Path) -> AppResult<()> {
H::seed(db, path).await
}
/// Issues `CREATE DATABASE` for `db_name` over `db`.
///
/// `db` should be a connection that is not to `db_name` itself; `create` uses the `postgres`
/// maintenance database.
///
/// The `WITH ...` clause defaults to `ENCODING='UTF8'`. It can be replaced through the
/// `LOCO_POSTGRES_DB_OPTIONS` environment variable, whose value is spliced into the SQL
/// verbatim and therefore must be trusted.
///
/// The database name is emitted as a quoted identifier. Names that are not plain lowercase
/// identifiers, such as ones containing `-` or uppercase letters (which connection URIs allow),
/// are then created exactly as written instead of failing or being case-folded.
///
/// # Errors
///
/// Returns the [`sea_orm::DbErr`] raised by the `CREATE DATABASE` statement.
async fn create_postgres_database(
    db_name: &str,
    db: &DatabaseConnection,
) -> Result<(), sea_orm::DbErr> {
    let with_options =
        std::env::var("LOCO_POSTGRES_DB_OPTIONS").unwrap_or_else(|_| "ENCODING='UTF8'".to_string());
    // Identifier quoting: wrap in double quotes, doubling any embedded double quote.
    let quoted_name = format!("\"{}\"", db_name.replace('"', "\"\""));
    let query = format!("CREATE DATABASE {quoted_name} WITH {with_options}");
    tracing::info!(query, "creating postgres database");
    db.execute(sea_orm::Statement::from_string(
        sea_orm::DatabaseBackend::Postgres,
        query,
    ))
    .await?;
    Ok(())
}