burncloud_database_client/
pool.rs1use burncloud_database_core::error::{DatabaseResult, DatabaseError};
2use burncloud_database_core::DatabaseConfig;
3use crate::{DatabaseClient, DatabaseClientFactory};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8pub struct DatabasePool {
9 clients: Arc<RwLock<HashMap<String, Arc<DatabaseClient>>>>,
10 config: DatabaseConfig,
11 max_connections: usize,
12}
13
14impl DatabasePool {
15 pub fn new(config: DatabaseConfig, max_connections: usize) -> Self {
16 Self {
17 clients: Arc::new(RwLock::new(HashMap::new())),
18 config,
19 max_connections,
20 }
21 }
22
23 pub async fn get_client(&self, client_id: Option<&str>) -> DatabaseResult<Arc<DatabaseClient>> {
24 let key = client_id.unwrap_or("default").to_string();
25
26 {
27 let clients = self.clients.read().await;
28 if let Some(client) = clients.get(&key) {
29 if client.is_connected().await {
30 return Ok(client.clone());
31 }
32 }
33 }
34
35 let mut clients = self.clients.write().await;
36
37 if clients.len() >= self.max_connections {
38 return Err(DatabaseError::ConnectionFailed(
39 "Connection pool limit reached".to_string()
40 ));
41 }
42
43 let connection = DatabaseClientFactory::create_connection(&self.config)?;
44 let query_executor = DatabaseClientFactory::create_query_executor(&self.config)?;
45
46 let client = Arc::new(DatabaseClient::new(
47 connection,
48 query_executor,
49 ));
50
51 client.connect().await?;
52 clients.insert(key.clone(), client.clone());
53
54 Ok(client)
55 }
56
57 pub async fn remove_client(&self, client_id: &str) -> DatabaseResult<()> {
58 let mut clients = self.clients.write().await;
59 if let Some(client) = clients.remove(client_id) {
60 client.disconnect().await?;
61 }
62 Ok(())
63 }
64
65 pub async fn disconnect_all(&self) -> DatabaseResult<()> {
66 let mut clients = self.clients.write().await;
67 for (_, client) in clients.drain() {
68 let _ = client.disconnect().await;
69 }
70 Ok(())
71 }
72
73 pub async fn get_pool_size(&self) -> usize {
74 let clients = self.clients.read().await;
75 clients.len()
76 }
77}