use std::{collections::HashMap, fs::File, path::Path, time::Duration};
use duct::cmd;
use fs_err as fs;
use lazy_static::lazy_static;
use regex::Regex;
use sea_orm::{
ActiveModelTrait, ConnectOptions, ConnectionTrait, Database, DatabaseBackend,
DatabaseConnection, DbConn, EntityTrait, IntoActiveModel, Statement,
};
use sea_orm_migration::MigratorTrait;
use tracing::info;
use super::Result as AppResult;
use crate::{
app::{AppContext, Hooks},
config, doctor,
errors::Error,
};
lazy_static! {
    /// Matches the final path segment of a database URI: a `/` followed by a
    /// non-empty run of non-`/` characters anchored at the end of the input.
    /// Capture group 1 holds the segment without the leading `/`.
    pub static ref EXTRACT_DB_NAME: Regex = Regex::new("/([^/]+)$").unwrap();
}
#[derive(Default, Clone)]
pub struct MultiDb {
pub db: HashMap<String, DatabaseConnection>,
}
impl MultiDb {
pub async fn new(dbs_config: HashMap<String, config::Database>) -> AppResult<Self> {
let mut multi_db = Self::default();
for (db_name, db_config) in dbs_config {
multi_db.db.insert(db_name, connect(&db_config).await?);
}
Ok(multi_db)
}
pub fn get(&self, name: &str) -> AppResult<&DatabaseConnection> {
self.db
.get(name)
.map_or_else(|| Err(Error::Message("db not found".to_owned())), Ok)
}
}
#[allow(clippy::match_wildcard_for_single_variants)]
pub async fn verify_access(db: &DatabaseConnection) -> AppResult<()> {
match db {
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
let res = db
.query_all(Statement::from_string(
DatabaseBackend::Postgres,
"SELECT * FROM pg_catalog.pg_tables WHERE tableowner = current_user;",
))
.await?;
if res.is_empty() {
return Err(Error::string(
"current user has no access to tables in the database",
));
}
}
DatabaseConnection::Disconnected => {
return Err(Error::string("connection to database has been closed"));
}
_ => {}
}
Ok(())
}
/// Bring the database in line with the schema flags in `config`.
///
/// - `dangerously_recreate`: rebuild via `reset` (`M::fresh` then `migrate`)
///   and return immediately; the other flags are not consulted.
/// - `auto_migrate`: apply migrations via `migrate`.
/// - `dangerously_truncate`: run the application's `Hooks::truncate`.
///
/// When both `auto_migrate` and `dangerously_truncate` are set, migration runs
/// first.
///
/// # Errors
///
/// Propagates the first error from any step that runs.
pub async fn converge<H: Hooks, M: MigratorTrait>(
    db: &DatabaseConnection,
    config: &config::Database,
) -> AppResult<()> {
    if config.dangerously_recreate {
        info!("recreating schema");
        return Ok(reset::<M>(db).await?);
    }

    let should_migrate = config.auto_migrate;
    let should_truncate = config.dangerously_truncate;

    if should_migrate {
        info!("auto migrating");
        migrate::<M>(db).await?;
    }
    if should_truncate {
        info!("truncating tables");
        H::truncate(db).await?;
    }
    Ok(())
}
/// Open a database connection described by `config`.
///
/// `connect_timeout`, `idle_timeout` and, when present, `acquire_timeout` are
/// read as milliseconds. Pool bounds and SQLx logging come straight from the
/// config.
///
/// # Errors
///
/// Returns the `DbErr` produced by `Database::connect`.
pub async fn connect(config: &config::Database) -> Result<DbConn, sea_orm::DbErr> {
    let connect_timeout = Duration::from_millis(config.connect_timeout);
    let idle_timeout = Duration::from_millis(config.idle_timeout);

    let mut options = ConnectOptions::new(&config.uri);
    options.max_connections(config.max_connections);
    options.min_connections(config.min_connections);
    options.connect_timeout(connect_timeout);
    options.idle_timeout(idle_timeout);
    options.sqlx_logging(config.enable_logging);

    if let Some(acquire_timeout) = config.acquire_timeout.map(Duration::from_millis) {
        options.acquire_timeout(acquire_timeout);
    }

    Database::connect(options).await
}
pub async fn create(db_uri: &str) -> AppResult<()> {
if !db_uri.starts_with("postgres://") {
return Err(Error::string(
"Only Postgres databases are supported for table creation",
));
}
let db_name = EXTRACT_DB_NAME
.captures(db_uri)
.and_then(|cap| cap.get(1).map(|db| db.as_str()))
.ok_or_else(|| {
Error::string(
"The specified table name was not found in the given Postgre database URI",
)
})?;
let conn = EXTRACT_DB_NAME.replace(db_uri, "/postgres").to_string();
let db = Database::connect(conn).await?;
Ok(create_postgres_database(db_name, &db).await?)
}
/// Apply the migrations of migrator `M` to `db` with no step limit (`None`).
///
/// # Errors
///
/// Returns the `DbErr` reported by `MigratorTrait::up`.
pub async fn migrate<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    <M as MigratorTrait>::up(db, None).await
}
/// Report the migration status of migrator `M` against `db` by delegating to
/// `MigratorTrait::status`.
///
/// # Errors
///
/// Returns the `DbErr` reported by `MigratorTrait::status`.
pub async fn status<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    <M as MigratorTrait>::status(db).await
}
/// Rebuild the schema for migrator `M`: run `MigratorTrait::fresh`, then
/// [`migrate`].
///
/// # Errors
///
/// Returns the first `DbErr` from either step.
pub async fn reset<M: MigratorTrait>(db: &DatabaseConnection) -> Result<(), sea_orm::DbErr> {
    <M as MigratorTrait>::fresh(db).await?;
    migrate::<M>(db).await
}
/// Seed the table behind active model `A` from a YAML file.
///
/// The file at `path` must hold a YAML sequence. Each element is converted with
/// `ActiveModelTrait::from_json`, and all rows are inserted with a single
/// `insert_many`. An empty sequence inserts nothing.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be opened or parsed;
/// - an element cannot be converted into `A`;
/// - the insert fails.
#[allow(clippy::type_repetition_in_bounds)]
pub async fn seed<A>(db: &DatabaseConnection, path: &str) -> AppResult<()>
where
    <<A as ActiveModelTrait>::Entity as EntityTrait>::Model: IntoActiveModel<A>,
    for<'de> <<A as ActiveModelTrait>::Entity as EntityTrait>::Model: serde::de::Deserialize<'de>,
    A: sea_orm::ActiveModelTrait + Send + Sync,
    sea_orm::Insert<A>: Send + Sync, {
    let rows: Vec<serde_json::Value> = serde_yaml::from_reader(File::open(path)?)?;
    if rows.is_empty() {
        // Nothing to seed: skip issuing an INSERT that carries no rows.
        return Ok(());
    }

    // Stops at the first row that fails to convert.
    let models = rows
        .into_iter()
        .map(A::from_json)
        .collect::<Result<Vec<A>, _>>()?;

    <A as ActiveModelTrait>::Entity::insert_many(models)
        .exec(db)
        .await?;
    Ok(())
}
pub async fn entities<M: MigratorTrait>(ctx: &AppContext) -> AppResult<String> {
doctor::check_seaorm_cli().to_result()?;
doctor::check_db(&ctx.config.database).await.to_result()?;
let out = cmd!(
"sea-orm-cli",
"generate",
"entity",
"--with-serde",
"both",
"--output-dir",
"src/models/_entities",
"--database-url",
&ctx.config.database.uri
)
.stderr_to_stdout()
.run()
.map_err(|err| {
Error::Message(format!(
"failed to generate entity using sea-orm-cli binary. error details: `{err}`",
))
})?;
fix_entities()?;
Ok(String::from_utf8_lossy(&out.stdout).to_string())
}
/// Post-process the entity files generated by `sea-orm-cli`.
///
/// Scans every regular file in `src/models/_entities` except `mod.rs` and
/// `prelude.rs`. Where a file contains the generated
/// `impl ActiveModelBehavior for ActiveModel {}` line, that line is removed.
///
/// Then, for each such entity whose `src/models/<entity>.rs` does not exist
/// yet:
/// - a stub file with an extendable `ActiveModelBehavior` impl is written there;
/// - `pub mod <entity>;` is appended to `src/models/mod.rs`, unless a
///   `mod <entity>` item is already present.
///
/// # Errors
///
/// Returns an error if any involved file or directory cannot be read or
/// written, or if an entity file name cannot be turned into a module name.
fn fix_entities() -> AppResult<()> {
    let activemodel_exp = "impl ActiveModelBehavior for ActiveModel {}";

    let entity_files = fs::read_dir("src/models/_entities")?
        .flatten()
        .filter(|ent| {
            ent.path().is_file() && ent.file_name() != "mod.rs" && ent.file_name() != "prelude.rs"
        })
        .map(|ent| ent.path())
        .collect::<Vec<_>>();

    let mut cleaned_entities = Vec::new();
    for file in &entity_files {
        let content = fs::read_to_string(file)?;
        if let Some(stripped) = strip_lines_containing(&content, activemodel_exp) {
            fs::write(file, stripped)?;
            cleaned_entities.push(file);
        }
    }

    let mut models_mod = fs::read_to_string("src/models/mod.rs")?;
    for entity_file in cleaned_entities {
        let new_file = Path::new("src/models").join(
            entity_file
                .file_name()
                .ok_or_else(|| Error::string("cannot extract file name"))?,
        );
        if new_file.exists() {
            // Never overwrite a model file the user already owns.
            continue;
        }

        let module = new_file
            .file_stem()
            .ok_or_else(|| Error::string("cannot extract file stem"))?
            .to_str()
            .ok_or_else(|| Error::string("cannot extract file stem"))?;
        fs::write(
            &new_file,
            format!(
                r"use sea_orm::entity::prelude::*;
use super::_entities::{module}::ActiveModel;
impl ActiveModelBehavior for ActiveModel {{
// extend activemodel below (keep comment for generators)
}}
"
            ),
        )?;

        if !declares_module(&models_mod, module) {
            // Avoid gluing the declaration onto an unterminated last line.
            if !models_mod.is_empty() && !models_mod.ends_with('\n') {
                models_mod.push('\n');
            }
            models_mod.push_str(&format!("pub mod {module};\n"));
        }
    }
    fs::write("src/models/mod.rs", models_mod)?;
    Ok(())
}

/// Remove every line of `content` that contains `needle`.
///
/// Returns `None` when `needle` does not occur, meaning there is nothing to
/// rewrite. A trailing newline in `content` is preserved.
fn strip_lines_containing(content: &str, needle: &str) -> Option<String> {
    if !content.contains(needle) {
        return None;
    }
    let mut stripped = content
        .lines()
        .filter(|line| !line.contains(needle))
        .collect::<Vec<_>>()
        .join("\n");
    // `lines()` + `join` drops the final line terminator; restore it.
    if content.ends_with('\n') && !stripped.is_empty() {
        stripped.push('\n');
    }
    Some(stripped)
}

/// Whether `mod_rs` has a `mod <module>` item for exactly `module`.
///
/// Compares whole tokens, so `pub mod users;` does not count as declaring
/// `user`.
fn declares_module(mod_rs: &str, module: &str) -> bool {
    mod_rs.lines().any(|line| {
        let tokens: Vec<&str> = line
            .split(|c: char| c.is_whitespace() || c == ';' || c == '{')
            .filter(|token| !token.is_empty())
            .collect();
        tokens
            .windows(2)
            .any(|pair| pair[0] == "mod" && pair[1] == module)
    })
}
/// Delete every row of entity `T` with an unfiltered `delete_many`.
///
/// The unnamed second argument is never read. It only lets callers select `T`
/// by passing a value of that type.
///
/// # Errors
///
/// Returns the `DbErr` reported while executing the delete.
pub async fn truncate_table<T>(db: &DatabaseConnection, _: T) -> Result<(), sea_orm::DbErr>
where
    T: EntityTrait,
{
    T::delete_many().exec(db).await.map(|_deleted| ())
}
/// Run the application's seed hook, forwarding `db` and `path` unchanged to
/// `Hooks::seed`.
///
/// # Errors
///
/// Returns whatever error the hook reports.
pub async fn run_app_seed<H: Hooks>(db: &DatabaseConnection, path: &Path) -> AppResult<()> {
    let seeding = H::seed(db, path);
    seeding.await
}
/// Execute `CREATE DATABASE` for `db_name` over the connection `db`.
///
/// The text placed after `WITH` comes from the `LOCO_POSTGRES_TABLE_OPTIONS`
/// environment variable, defaulting to `ENCODING='UTF8'`. It is inserted into
/// the statement verbatim.
///
/// # Errors
///
/// Returns any error reported by the database while executing the statement.
async fn create_postgres_database(
    db_name: &str,
    db: &DatabaseConnection,
) -> Result<(), sea_orm::DbErr> {
    let with_options = std::env::var("LOCO_POSTGRES_TABLE_OPTIONS")
        .unwrap_or_else(|_| "ENCODING='UTF8'".to_string());

    // Quote the identifier, doubling any embedded `"`. Unquoted identifiers are
    // folded to lower case and cannot contain characters such as `-`, whereas
    // a connection URI names the database exactly.
    let quoted_name = format!("\"{}\"", db_name.replace('"', "\"\""));
    let query = format!("CREATE DATABASE {quoted_name} WITH {with_options}");
    tracing::info!(query, "creating postgres database");

    db.execute(sea_orm::Statement::from_string(
        sea_orm::DatabaseBackend::Postgres,
        query,
    ))
    .await?;
    Ok(())
}