burncloud_database_client/
pool.rs

1use burncloud_database_core::error::{DatabaseResult, DatabaseError};
2use burncloud_database_core::DatabaseConfig;
3use crate::{DatabaseClient, DatabaseClientFactory};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8pub struct DatabasePool {
9    clients: Arc<RwLock<HashMap<String, Arc<DatabaseClient>>>>,
10    config: DatabaseConfig,
11    max_connections: usize,
12}
13
14impl DatabasePool {
15    pub fn new(config: DatabaseConfig, max_connections: usize) -> Self {
16        Self {
17            clients: Arc::new(RwLock::new(HashMap::new())),
18            config,
19            max_connections,
20        }
21    }
22
23    pub async fn get_client(&self, client_id: Option<&str>) -> DatabaseResult<Arc<DatabaseClient>> {
24        let key = client_id.unwrap_or("default").to_string();
25
26        {
27            let clients = self.clients.read().await;
28            if let Some(client) = clients.get(&key) {
29                if client.is_connected().await {
30                    return Ok(client.clone());
31                }
32            }
33        }
34
35        let mut clients = self.clients.write().await;
36
37        if clients.len() >= self.max_connections {
38            return Err(DatabaseError::ConnectionFailed(
39                "Connection pool limit reached".to_string()
40            ));
41        }
42
43        let connection = DatabaseClientFactory::create_connection(&self.config)?;
44        let query_executor = DatabaseClientFactory::create_query_executor(&self.config)?;
45
46        let client = Arc::new(DatabaseClient::new(
47            connection,
48            query_executor,
49        ));
50
51        client.connect().await?;
52        clients.insert(key.clone(), client.clone());
53
54        Ok(client)
55    }
56
57    pub async fn remove_client(&self, client_id: &str) -> DatabaseResult<()> {
58        let mut clients = self.clients.write().await;
59        if let Some(client) = clients.remove(client_id) {
60            client.disconnect().await?;
61        }
62        Ok(())
63    }
64
65    pub async fn disconnect_all(&self) -> DatabaseResult<()> {
66        let mut clients = self.clients.write().await;
67        for (_, client) in clients.drain() {
68            let _ = client.disconnect().await;
69        }
70        Ok(())
71    }
72
73    pub async fn get_pool_size(&self) -> usize {
74        let clients = self.clients.read().await;
75        clients.len()
76    }
77}