use crate::common::error::QueryError;
use sqlx::{Pool, MySql};
use sqlx::{pool::PoolOptions, Error, MySqlPool};
use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};
use std::cmp::{max, min};
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::OnceCell;
use std::time::Duration;
/// Process-wide MySQL pool slot: filled at most once by `setup_db_pool` and read by `get_db_pool`.
static DB_POOL: OnceCell<Arc<MySqlPool>> = OnceCell::const_new();
/// Derives pool sizing from the CPU count reported by `num_cpus::get()`.
///
/// Returns `(max_connections, min_connections, warmup)` where:
/// * `max_connections` is twice the CPU count, kept within `10..=50`;
/// * `min_connections` is half the CPU count, kept within `2..=10`;
/// * `warmup` is 20% of `max_connections`, rounded up.
fn connect_limits() -> (u32, u32, u32) {
    let cpus = num_cpus::get() as u32;

    // Raise to the floor first, then cap at the ceiling.
    let upper = min(max(cpus * 2, 10), 50);
    let lower = min(max(cpus / 2, 2), 10);

    let warm_fraction = upper as f32 * 0.2;
    let warm = warm_fraction.ceil() as u32;

    (upper, lower, warm)
}
pub async fn setup_db_pool(pool: Pool<MySql>) -> Result<&'static MySqlPool, Error> {
let pool = Arc::new(pool);
DB_POOL.get_or_try_init(|| async { Ok(pool) }).await
.map(|arc| arc.as_ref())
}
/// Connects to MySQL at `database_url`, builds a pool sized by `connect_limits`, and
/// registers it as the process-wide pool through `setup_db_pool`.
///
/// TLS mode: the mode parsed from the URL by `MySqlConnectOptions::from_str` is kept,
/// except that the legacy substrings `sslmode=disable` and `sslmode=require` still force
/// `Disabled` / `Required` respectively. Previously every other URL was forced to
/// `Preferred`, which silently downgraded stricter modes requested in the URL.
///
/// If a pool has already been registered, the freshly created pool is discarded and the
/// registered one is returned (see `setup_db_pool`).
///
/// # Errors
///
/// Returns the `sqlx` error if the URL cannot be parsed or the initial connection fails.
/// Failures during connection warm-up are ignored.
pub async fn create_db_pool(database_url: &str) -> Result<&'static MySqlPool, Error> {
    let (maxc, minc, warmupc) = connect_limits();

    // `?` performs the same `From` conversion the explicit `map_err` did.
    let mut options = MySqlConnectOptions::from_str(database_url)?;

    // Only apply the legacy overrides; otherwise keep whatever mode the URL parser chose
    // so that stricter settings (e.g. certificate verification) are not downgraded.
    if database_url.contains("sslmode=disable") {
        options = options.ssl_mode(MySqlSslMode::Disabled);
    } else if database_url.contains("sslmode=require") {
        options = options.ssl_mode(MySqlSslMode::Required);
    }

    let pool = PoolOptions::new()
        .max_connections(maxc)
        .min_connections(minc)
        .acquire_timeout(Duration::from_secs(5))
        .test_before_acquire(false)
        .idle_timeout(Duration::from_secs(300))
        .max_lifetime(Duration::from_secs(1800))
        .connect_with(options)
        .await?;

    // Warm-up is best effort: the pool is already usable, so a warm-up failure must not
    // prevent it from being registered.
    let _ = warmup_connect(&pool, warmupc).await;

    setup_db_pool(pool).await
}
/// Pre-opens up to `warmup_num` connections on `pool`.
///
/// All acquired connections are held until the loop finishes and then released together.
/// Releasing each one before acquiring the next (as was done before) lets the pool hand
/// back the same idle connection every time, so no additional connections get opened.
///
/// Callers should keep `warmup_num` at or below the pool's maximum size, because every
/// connection stays checked out until the last one has been acquired.
///
/// # Errors
///
/// Returns the first `acquire` error; connections acquired so far are released on return.
async fn warmup_connect(pool: &MySqlPool, warmup_num: u32) -> Result<(), Error> {
    let mut held = Vec::with_capacity(warmup_num as usize);
    for _ in 0..warmup_num {
        held.push(pool.acquire().await?);
    }
    // Return every warmed connection to the pool's idle set at once.
    drop(held);
    Ok(())
}
/// Returns a shared handle to the process-wide pool stored in `DB_POOL`.
///
/// # Errors
///
/// Returns `QueryError::DBPoolNotInitialized` (converted into the `Error` type) when no
/// pool has been registered yet.
pub fn get_db_pool() -> Result<Arc<MySqlPool>, Error> {
    match DB_POOL.get() {
        Some(shared) => Ok(Arc::clone(shared)),
        None => Err(QueryError::DBPoolNotInitialized.into()),
    }
}