1use std::sync::Arc;
4use std::time::Duration;
5
6use bb8::{Pool, PooledConnection};
7use bb8_tiberius::ConnectionManager;
8use tracing::{debug, info};
9
10use crate::config::MssqlConfig;
11use crate::connection::MssqlConnection;
12use crate::error::{MssqlError, MssqlResult};
13
14type TiberiusPool = Pool<ConnectionManager>;
16
17#[derive(Clone)]
19pub struct MssqlPool {
20 inner: TiberiusPool,
21 config: Arc<MssqlConfig>,
22 max_size: usize,
23}
24
25impl MssqlPool {
26 pub async fn new(config: MssqlConfig) -> MssqlResult<Self> {
28 Self::with_pool_config(config, PoolConfig::default()).await
29 }
30
31 pub async fn with_pool_config(
33 config: MssqlConfig,
34 pool_config: PoolConfig,
35 ) -> MssqlResult<Self> {
36 let tiberius_config = config.to_tiberius_config()?;
37
38 let mgr = ConnectionManager::new(tiberius_config);
39
40 let pool = Pool::builder()
41 .max_size(pool_config.max_connections as u32)
42 .min_idle(Some(pool_config.min_connections as u32))
43 .connection_timeout(
44 pool_config
45 .connection_timeout
46 .unwrap_or(Duration::from_secs(30)),
47 )
48 .idle_timeout(pool_config.idle_timeout)
49 .max_lifetime(pool_config.max_lifetime)
50 .build(mgr)
51 .await
52 .map_err(|e| MssqlError::pool(format!("failed to create pool: {}", e)))?;
53
54 info!(
55 host = %config.host,
56 port = %config.port,
57 database = %config.database,
58 max_connections = %pool_config.max_connections,
59 "MSSQL connection pool created"
60 );
61
62 Ok(Self {
63 inner: pool,
64 config: Arc::new(config),
65 max_size: pool_config.max_connections,
66 })
67 }
68
69 pub async fn get(&self) -> MssqlResult<MssqlConnection<'_>> {
71 debug!("Acquiring connection from pool");
72 let client = self.inner.get().await?;
73 Ok(MssqlConnection::new(client))
74 }
75
76 pub async fn get_raw(&self) -> MssqlResult<PooledConnection<'_, ConnectionManager>> {
78 let client = self.inner.get().await?;
79 Ok(client)
80 }
81
82 pub fn status(&self) -> PoolStatus {
84 let state = self.inner.state();
85 PoolStatus {
86 connections: state.connections as usize,
87 idle_connections: state.idle_connections as usize,
88 max_size: self.max_size,
89 }
90 }
91
92 pub fn config(&self) -> &MssqlConfig {
94 &self.config
95 }
96
97 pub async fn is_healthy(&self) -> bool {
99 match self.inner.get().await {
100 Ok(mut client) => {
101 client.simple_query("SELECT 1").await.is_ok()
103 }
104 Err(_) => false,
105 }
106 }
107
108 pub fn builder() -> MssqlPoolBuilder {
110 MssqlPoolBuilder::new()
111 }
112
113 pub async fn warmup(&self, count: usize) -> MssqlResult<()> {
115 info!(count = count, "Warming up MSSQL connection pool");
116
117 let count = count.min(self.max_size);
118 let mut connections = Vec::with_capacity(count);
119
120 for i in 0..count {
121 match self.inner.get().await {
122 Ok(mut conn) => {
123 if let Err(e) = conn.simple_query("SELECT 1").await {
125 debug!(error = %e, "Warmup connection {} failed validation", i);
126 } else {
127 debug!("Warmup connection {} established", i);
128 connections.push(conn);
129 }
130 }
131 Err(e) => {
132 debug!(error = %e, "Failed to establish warmup connection {}", i);
133 }
134 }
135 }
136
137 let established = connections.len();
138 drop(connections);
139
140 info!(
141 established = established,
142 requested = count,
143 "MSSQL connection pool warmup complete"
144 );
145
146 Ok(())
147 }
148}
149
/// Point-in-time snapshot of pool usage, as returned by `MssqlPool::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStatus {
    /// Number of connections the pool currently manages (as reported by `bb8`).
    pub connections: usize,
    /// Number of managed connections that are currently idle.
    pub idle_connections: usize,
    /// Configured maximum pool size.
    pub max_size: usize,
}

impl PoolStatus {
    /// Number of managed connections that are not idle: `connections - idle_connections`.
    ///
    /// Saturates at zero, so a hand-built snapshot with more idle than total
    /// connections cannot underflow.
    pub fn in_use(&self) -> usize {
        self.connections.saturating_sub(self.idle_connections)
    }
}
160
/// Sizing and timeout settings applied when building an `MssqlPool`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Maximum number of connections the pool may hold (`bb8` requires this to be > 0).
    pub max_connections: usize,
    /// Number of idle connections the pool tries to keep open (must not exceed
    /// `max_connections`).
    pub min_connections: usize,
    /// How long to wait when checking out a connection; `None` means 30 seconds.
    pub connection_timeout: Option<Duration>,
    /// How long a connection may sit idle before being closed; `None` disables it.
    pub idle_timeout: Option<Duration>,
    /// Maximum age of a connection before it is recycled; `None` disables it.
    pub max_lifetime: Option<Duration>,
}

impl Default for PoolConfig {
    /// Defaults:
    /// - 10 max / 1 min connections
    /// - 30 s connection timeout
    /// - 10 min idle timeout
    /// - 30 min max lifetime
    fn default() -> Self {
        Self {
            max_connections: 10,
            min_connections: 1,
            connection_timeout: Some(Duration::from_secs(30)),
            idle_timeout: Some(Duration::from_secs(600)),
            max_lifetime: Some(Duration::from_secs(1800)),
        }
    }
}
187
188#[derive(Debug, Default)]
190pub struct MssqlPoolBuilder {
191 config: Option<MssqlConfig>,
192 connection_string: Option<String>,
193 pool_config: PoolConfig,
194}
195
196impl MssqlPoolBuilder {
197 pub fn new() -> Self {
199 Self {
200 config: None,
201 connection_string: None,
202 pool_config: PoolConfig::default(),
203 }
204 }
205
206 pub fn connection_string(mut self, conn_str: impl Into<String>) -> Self {
208 self.connection_string = Some(conn_str.into());
209 self
210 }
211
212 pub fn config(mut self, config: MssqlConfig) -> Self {
214 self.config = Some(config);
215 self
216 }
217
218 pub fn host(mut self, host: impl Into<String>) -> Self {
220 let config = self.config.get_or_insert_with(MssqlConfig::default);
221 config.host = host.into();
222 self
223 }
224
225 pub fn port(mut self, port: u16) -> Self {
227 let config = self.config.get_or_insert_with(MssqlConfig::default);
228 config.port = port;
229 self
230 }
231
232 pub fn database(mut self, database: impl Into<String>) -> Self {
234 let config = self.config.get_or_insert_with(MssqlConfig::default);
235 config.database = database.into();
236 self
237 }
238
239 pub fn username(mut self, username: impl Into<String>) -> Self {
241 let config = self.config.get_or_insert_with(MssqlConfig::default);
242 config.username = Some(username.into());
243 self
244 }
245
246 pub fn password(mut self, password: impl Into<String>) -> Self {
248 let config = self.config.get_or_insert_with(MssqlConfig::default);
249 config.password = Some(password.into());
250 self
251 }
252
253 pub fn max_connections(mut self, n: usize) -> Self {
255 self.pool_config.max_connections = n;
256 self
257 }
258
259 pub fn min_connections(mut self, n: usize) -> Self {
261 self.pool_config.min_connections = n;
262 self
263 }
264
265 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
267 self.pool_config.connection_timeout = Some(timeout);
268 self
269 }
270
271 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
273 self.pool_config.idle_timeout = Some(timeout);
274 self
275 }
276
277 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
279 self.pool_config.max_lifetime = Some(lifetime);
280 self
281 }
282
283 pub fn trust_cert(mut self, trust: bool) -> Self {
285 let config = self.config.get_or_insert_with(MssqlConfig::default);
286 config.trust_cert = trust;
287 self
288 }
289
290 pub async fn build(self) -> MssqlResult<MssqlPool> {
292 let config = if let Some(config) = self.config {
293 config
294 } else if let Some(conn_str) = self.connection_string {
295 MssqlConfig::from_connection_string(conn_str)?
296 } else {
297 return Err(MssqlError::config(
298 "no connection string or config provided",
299 ));
300 };
301
302 MssqlPool::with_pool_config(config, self.pool_config).await
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.connection_timeout, Some(Duration::from_secs(30)));
        assert_eq!(config.idle_timeout, Some(Duration::from_secs(600)));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(1800)));
    }

    #[test]
    fn test_pool_builder() {
        let builder = MssqlPoolBuilder::new()
            .host("localhost")
            .database("test")
            .username("sa")
            .password("password")
            .max_connections(20);

        assert_eq!(builder.pool_config.max_connections, 20);
        assert!(builder.config.is_some());
    }

    #[test]
    fn test_pool_builder_pool_settings() {
        let builder = MssqlPoolBuilder::new()
            .min_connections(2)
            .connection_timeout(Duration::from_secs(5))
            .idle_timeout(Duration::from_secs(60))
            .max_lifetime(Duration::from_secs(120));

        assert_eq!(builder.pool_config.min_connections, 2);
        assert_eq!(builder.pool_config.connection_timeout, Some(Duration::from_secs(5)));
        assert_eq!(builder.pool_config.idle_timeout, Some(Duration::from_secs(60)));
        assert_eq!(builder.pool_config.max_lifetime, Some(Duration::from_secs(120)));
        // Pool-only settings must not create a connection source.
        assert!(builder.config.is_none());
        assert!(builder.connection_string.is_none());
    }

    #[test]
    fn test_pool_builder_connection_string() {
        let builder = MssqlPoolBuilder::new().connection_string("Server=localhost");

        assert_eq!(builder.connection_string.as_deref(), Some("Server=localhost"));
        assert!(builder.config.is_none());
    }
}