use crate::base::config::DbConfig;
use crate::base::db_type::DbType;
use crate::base::error::DatabaseError;
use crate::pool::datasource::{get_datasource_name, set_datasource_type};
use dashmap::DashMap;
use tracing::{info, trace, warn};
use std::any::{Any, TypeId};
use std::error::Error;
use std::sync::{Arc, OnceLock};
/// Process-wide registry of database managers.
///
/// Created lazily by `DbManager::init_registry` (which `DbManager::register`
/// calls before registering anything).
static DB_REGISTRY: OnceLock<Arc<DatabaseRegistry>> = OnceLock::new();

/// A `DbManager<M>` behind a type-erased, thread-shareable pointer.
///
/// Erasing to `dyn Any` lets managers with different pool types `M` share one
/// map; readers downcast back to `DbManager<M>` using the `TypeId` of `M`.
type ErasedDbManager = Arc<dyn Any + Send + Sync>;

/// Type-erased storage for every registered `DbManager<M>`.
struct DatabaseRegistry {
    /// All instances, keyed by `(registered name, TypeId of the pool type M)`.
    instances: DashMap<(String, TypeId), ErasedDbManager>,
    /// One default instance per pool type, keyed by `TypeId` of `M`.
    defaults: DashMap<TypeId, ErasedDbManager>,
}
impl DatabaseRegistry {
    /// Creates an empty registry with no instances and no per-type defaults.
    fn new() -> Self {
        Self {
            instances: DashMap::new(),
            defaults: DashMap::new(),
        }
    }

    /// Stores `instance` under `(name, TypeId::of::<M>())`, replacing any value
    /// already stored under that key.
    ///
    /// If no default has been recorded yet for pool type `M`, `instance` also
    /// becomes that type's default.
    fn insert<M: Send + Sync + 'static>(&self, name: String, instance: Arc<DbManager<M>>) {
        let type_id = TypeId::of::<M>();
        let key = (name, type_id);
        // Registration is a routine event, so it is logged at trace level
        // rather than warn.
        trace!("Inserting instance with key: {:?}", key);
        self.instances
            .insert(key, Arc::clone(&instance) as Arc<dyn Any + Send + Sync>);
        self.defaults
            .entry(type_id)
            .or_insert(instance as Arc<dyn Any + Send + Sync>);
    }

    /// Looks up the instance registered as `name` for pool type `M`.
    fn get_instance<M: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<DbManager<M>>> {
        let key = (name.to_string(), TypeId::of::<M>());
        trace!("Getting instance with key: {:?}", key);
        self.instances
            .get(&key)
            .and_then(|entry| entry.value().clone().downcast::<DbManager<M>>().ok())
    }

    /// Returns the default instance recorded for pool type `M`, if any.
    fn get_default<M: Send + Sync + 'static>(&self) -> Option<Arc<DbManager<M>>> {
        self.defaults
            .get(&TypeId::of::<M>())
            .and_then(|entry| entry.value().clone().downcast::<DbManager<M>>().ok())
    }

    /// Returns `true` if an instance of pool type `M` is registered as `name`.
    fn contains<M: Send + Sync + 'static>(&self, name: &str) -> bool {
        let key = (name.to_string(), TypeId::of::<M>());
        self.instances.contains_key(&key)
    }

    /// Removes and returns the instance registered as `name` for pool type `M`.
    ///
    /// If the removed instance was the default for `M`, another remaining
    /// instance of `M` (if any) becomes the default.
    fn remove<M: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<DbManager<M>>> {
        let type_id = TypeId::of::<M>();
        let (_, removed) = self.instances.remove(&(name.to_string(), type_id))?;

        // DashMap may deadlock if a map is mutated while a guard from `get()`
        // on the same map is still alive. `remove_if` runs the predicate and
        // the removal under a single lock, and no guard outlives this call.
        let default_was_removed = self
            .defaults
            .remove_if(&type_id, |_, current| {
                current
                    .downcast_ref::<DbManager<M>>()
                    .is_some_and(|manager| manager.get_name() == name)
            })
            .is_some();

        if default_was_removed {
            if let Some(replacement) = self.get_first_instance::<M>() {
                // `or_insert` keeps a default that a concurrent `insert` may
                // have recorded in the meantime.
                self.defaults
                    .entry(type_id)
                    .or_insert(replacement as Arc<dyn Any + Send + Sync>);
            }
        }

        removed.downcast::<DbManager<M>>().ok()
    }

    /// Returns some instance registered for pool type `M`.
    ///
    /// This is whichever matching entry DashMap iteration yields first. DashMap
    /// does not define an iteration order, so treat the choice as arbitrary.
    fn get_first_instance<M: Send + Sync + 'static>(&self) -> Option<Arc<DbManager<M>>> {
        let type_id = TypeId::of::<M>();
        self.instances
            .iter()
            .find(|entry| entry.key().1 == type_id)
            .and_then(|entry| entry.value().clone().downcast::<DbManager<M>>().ok())
    }

    /// Names of all instances registered for pool type `M`, in unspecified order.
    fn list_instances<M: Send + Sync + 'static>(&self) -> Vec<String> {
        let type_id = TypeId::of::<M>();
        self.instances
            .iter()
            .filter(|entry| entry.key().1 == type_id)
            .map(|entry| entry.key().0.clone())
            .collect()
    }

    /// Total number of registered instances across *all* pool types.
    fn instance_count(&self) -> usize {
        self.instances.len()
    }
}
/// A named database pool of type `M`, paired with the [`DbType`] it was
/// configured with.
///
/// Values are built by `DbManager::register` and shared as `Arc<DbManager<M>>`
/// through the process-wide registry.
pub struct DbManager<M>
where
    M: Send + Sync + 'static,
{
    /// The pool value produced by the registration factory.
    pool: M,
    /// Name the instance was registered under (taken from `DbConfig::name`).
    name: String,
    /// Backend type taken from the configuration.
    db_type: DbType,
}
impl<M: Send + Sync + 'static> DbManager<M> {
    /// Creates the process-wide registry if it does not exist yet.
    ///
    /// Repeated calls leave the existing registry untouched.
    pub fn init_registry() {
        DB_REGISTRY.get_or_init(|| Arc::new(DatabaseRegistry::new()));
    }

    /// Builds a pool with `factory` and registers it under `config.name`.
    ///
    /// The registry is created on first use. The first instance registered for
    /// pool type `M` also becomes that type's default (see `DbManager::default`).
    ///
    /// # Errors
    ///
    /// * [`DatabaseError::InstanceAlreadyExistsError`] if an instance of pool
    ///   type `M` is already registered under `config.name`.
    /// * Any error returned by `factory`.
    pub fn register<F>(config: &DbConfig, factory: F) -> Result<Arc<Self>, DatabaseError>
    where
        F: Fn(&DbConfig) -> Result<M, DatabaseError> + Sync + Send + 'static,
    {
        // `get_or_init` returns the registry directly, so there is no separate
        // `get()` + `unwrap()` after initialisation.
        let registry = DB_REGISTRY.get_or_init(|| Arc::new(DatabaseRegistry::new()));
        // NOTE(review): this check and the `insert` below are not atomic. Two
        // concurrent registrations of the same name can both pass the check,
        // and the later insert replaces the earlier instance.
        if registry.contains::<M>(&config.name) {
            return Err(DatabaseError::InstanceAlreadyExistsError(config.name.clone()));
        }
        // NOTE(review): the datasource type is recorded before `factory` runs
        // and is not rolled back if the factory fails. Confirm whether the
        // factory relies on it being set first.
        set_datasource_type(config.name.clone(), config.db_type);
        let pool = factory(config)?;
        let instance = Arc::new(Self {
            pool,
            name: config.name.clone(),
            db_type: config.db_type,
        });
        registry.insert(config.name.clone(), Arc::clone(&instance));
        Ok(instance)
    }

    /// Registers every config in order with a clone of `factory`.
    ///
    /// Stops at the first failure. Instances registered before the failure
    /// remain registered.
    ///
    /// # Errors
    ///
    /// The first error returned by [`Self::register`], boxed.
    pub fn register_batch<F>(
        configs: Vec<DbConfig>,
        factory: F,
    ) -> Result<Vec<Arc<Self>>, Box<dyn Error>>
    where
        F: Fn(&DbConfig) -> Result<M, DatabaseError> + Sync + Send + 'static + Clone,
    {
        let mut instances = Vec::with_capacity(configs.len());
        for config in configs {
            instances.push(Self::register(&config, factory.clone())?);
        }
        Ok(instances)
    }

    /// Removes the instance of pool type `M` registered as `name`.
    ///
    /// Returns `Ok(None)` when no such instance exists.
    ///
    /// # Errors
    ///
    /// Fails if the registry has never been initialised.
    pub fn unregister(name: &str) -> Result<Option<Arc<Self>>, Box<dyn Error>> {
        let registry = DB_REGISTRY
            .get()
            .ok_or("Database registry not initialized")?;
        let removed = registry.remove::<M>(name);
        if removed.is_some() {
            info!("Database instance '{}' unregistered successfully", name);
        }
        Ok(removed)
    }

    /// Looks up the instance of pool type `M` registered as `name`.
    ///
    /// # Errors
    ///
    /// [`DatabaseError::NotFoundError`] if the registry is not initialised or
    /// no such instance exists.
    pub fn get_instance(name: &str) -> Result<Arc<Self>, DatabaseError> {
        // `ok_or_else` builds the error values only when they are actually
        // needed, avoiding allocations on the success path.
        let registry = DB_REGISTRY.get().ok_or_else(|| {
            DatabaseError::NotFoundError("Database registry not initialized".to_string())
        })?;
        registry.get_instance::<M>(name).ok_or_else(|| {
            DatabaseError::NotFoundError(format!("Database instance '{}' not found", name))
        })
    }

    /// Returns the instance whose name matches `get_datasource_name()`, if any.
    ///
    /// NOTE(review): presumably this is the datasource selected for the current
    /// context. Confirm against `get_datasource_name`.
    pub fn get_current() -> Option<Arc<Self>> {
        let registry = DB_REGISTRY.get()?;
        let name = get_datasource_name();
        registry.get_instance::<M>(&name)
    }

    /// Returns the default instance for pool type `M`.
    ///
    /// This is initially the first instance registered for `M`.
    pub fn default() -> Option<Arc<Self>> {
        DB_REGISTRY.get()?.get_default::<M>()
    }

    /// Returns `true` if an instance of pool type `M` is registered as `name`.
    pub fn exists(name: &str) -> bool {
        DB_REGISTRY
            .get()
            .map_or(false, |registry| registry.contains::<M>(name))
    }

    /// Names of all registered instances of pool type `M`.
    ///
    /// Empty when the registry is not initialised.
    pub fn list_instances() -> Vec<String> {
        DB_REGISTRY
            .get()
            .map_or_else(Vec::new, |registry| registry.list_instances::<M>())
    }

    /// Number of registered instances.
    ///
    /// NOTE(review): unlike `list_instances`/`exists`, this counts instances of
    /// *every* pool type, not just `M`. Confirm that this is intended.
    pub fn count() -> usize {
        DB_REGISTRY
            .get()
            .map_or(0, |registry| registry.instance_count())
    }

    /// Borrows the underlying pool.
    #[inline]
    pub fn get_pool(&self) -> &M {
        &self.pool
    }

    /// Backend type this instance was configured with.
    #[inline]
    pub fn get_db_type(&self) -> DbType {
        self.db_type
    }

    /// Name this instance was registered under.
    #[inline]
    pub fn get_name(&self) -> &str {
        &self.name
    }
}
/// Bulk operations over every registered `DbManager<M>` of one pool type.
pub trait DatabaseManagerExt<M>
where
    M: Send + Sync + 'static,
{
    /// Applies `f` to each registered instance of pool type `M` and collects
    /// the results.
    fn with_all_instances<Func, Out>(f: Func) -> Vec<Out>
    where
        Func: Fn(Arc<DbManager<M>>) -> Out + Send + Sync;
}
impl<M: Send + Sync + 'static> DatabaseManagerExt<M> for DbManager<M> {
    /// Applies `f` to every registered instance of pool type `M` and returns
    /// the results.
    ///
    /// Returns an empty `Vec` if the registry was never initialised.
    ///
    /// The matching instances are snapshotted first. `f` runs only after the
    /// registry iterator, and the shard read locks it holds, has been dropped.
    /// This lets `f` register or unregister instances without deadlocking on
    /// the registry's `DashMap`.
    fn with_all_instances<F, R>(f: F) -> Vec<R>
    where
        F: Fn(Arc<DbManager<M>>) -> R + Send + Sync,
    {
        let Some(registry) = DB_REGISTRY.get() else {
            return Vec::new();
        };
        let type_id = TypeId::of::<M>();
        let snapshot: Vec<Arc<DbManager<M>>> = registry
            .instances
            .iter()
            .filter(|entry| entry.key().1 == type_id)
            .filter_map(|entry| entry.value().clone().downcast::<DbManager<M>>().ok())
            .collect();
        snapshot.into_iter().map(f).collect()
    }
}
/// Hook for registering a database described by a [`DbConfig`].
pub trait DbRegister {
    /// Registers the database described by `config` (behaviour is defined by
    /// the implementor).
    fn register_db(&self, config: &DbConfig) -> Result<(), DatabaseError>;

    /// Validates that `config` carries the fields needed for registration.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::ConfigNotFoundError`] for the first problem
    /// found, checked in this order:
    /// 1. `url` is missing;
    /// 2. `db_type` is [`DbType::Other`];
    /// 3. `username` is missing;
    /// 4. `password` is missing.
    fn check_config(&self, config: &DbConfig) -> Result<(), DatabaseError> {
        if config.url.is_none() {
            // Name the missing field, consistent with the other messages below.
            return Err(DatabaseError::ConfigNotFoundError(
                "Database url is missing".to_string(),
            ));
        }
        if config.db_type == DbType::Other {
            return Err(DatabaseError::ConfigNotFoundError(
                "Database type is missing".to_string(),
            ));
        }
        if config.username.is_none() {
            return Err(DatabaseError::ConfigNotFoundError(
                "Database username is missing".to_string(),
            ));
        }
        if config.password.is_none() {
            return Err(DatabaseError::ConfigNotFoundError(
                "Database password is missing".to_string(),
            ));
        }
        Ok(())
    }
}