Skip to main content

prax_mssql/
pool.rs

1//! Connection pool for Microsoft SQL Server.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use bb8::{Pool, PooledConnection};
7use bb8_tiberius::ConnectionManager;
8use tracing::{debug, info};
9
10use crate::config::MssqlConfig;
11use crate::connection::MssqlConnection;
12use crate::error::{MssqlError, MssqlResult};
13
14/// Type alias for the BB8 pool with Tiberius.
15type TiberiusPool = Pool<ConnectionManager>;
16
17/// A connection pool for Microsoft SQL Server.
18#[derive(Clone)]
19pub struct MssqlPool {
20    inner: TiberiusPool,
21    config: Arc<MssqlConfig>,
22    max_size: usize,
23}
24
25impl MssqlPool {
26    /// Create a new connection pool from configuration.
27    pub async fn new(config: MssqlConfig) -> MssqlResult<Self> {
28        Self::with_pool_config(config, PoolConfig::default()).await
29    }
30
31    /// Create a new connection pool with custom pool configuration.
32    pub async fn with_pool_config(
33        config: MssqlConfig,
34        pool_config: PoolConfig,
35    ) -> MssqlResult<Self> {
36        let tiberius_config = config.to_tiberius_config()?;
37
38        let mgr = ConnectionManager::new(tiberius_config);
39
40        let pool = Pool::builder()
41            .max_size(pool_config.max_connections as u32)
42            .min_idle(Some(pool_config.min_connections as u32))
43            .connection_timeout(
44                pool_config
45                    .connection_timeout
46                    .unwrap_or(Duration::from_secs(30)),
47            )
48            .idle_timeout(pool_config.idle_timeout)
49            .max_lifetime(pool_config.max_lifetime)
50            .build(mgr)
51            .await
52            .map_err(|e| MssqlError::pool(format!("failed to create pool: {}", e)))?;
53
54        info!(
55            host = %config.host,
56            port = %config.port,
57            database = %config.database,
58            max_connections = %pool_config.max_connections,
59            "MSSQL connection pool created"
60        );
61
62        Ok(Self {
63            inner: pool,
64            config: Arc::new(config),
65            max_size: pool_config.max_connections,
66        })
67    }
68
69    /// Get a connection from the pool.
70    pub async fn get(&self) -> MssqlResult<MssqlConnection<'_>> {
71        debug!("Acquiring connection from pool");
72        let client = self.inner.get().await?;
73        Ok(MssqlConnection::new(client))
74    }
75
76    /// Get a raw pooled connection (for advanced use).
77    pub async fn get_raw(&self) -> MssqlResult<PooledConnection<'_, ConnectionManager>> {
78        let client = self.inner.get().await?;
79        Ok(client)
80    }
81
82    /// Get the current pool status.
83    pub fn status(&self) -> PoolStatus {
84        let state = self.inner.state();
85        PoolStatus {
86            connections: state.connections as usize,
87            idle_connections: state.idle_connections as usize,
88            max_size: self.max_size,
89        }
90    }
91
92    /// Get the pool configuration.
93    pub fn config(&self) -> &MssqlConfig {
94        &self.config
95    }
96
97    /// Check if the pool is healthy by attempting to get a connection.
98    pub async fn is_healthy(&self) -> bool {
99        match self.inner.get().await {
100            Ok(mut client) => {
101                // Try a simple query to verify the connection is actually working
102                client.simple_query("SELECT 1").await.is_ok()
103            }
104            Err(_) => false,
105        }
106    }
107
108    /// Create a builder for configuring the pool.
109    pub fn builder() -> MssqlPoolBuilder {
110        MssqlPoolBuilder::new()
111    }
112
113    /// Warm up the connection pool by pre-establishing connections.
114    pub async fn warmup(&self, count: usize) -> MssqlResult<()> {
115        info!(count = count, "Warming up MSSQL connection pool");
116
117        let count = count.min(self.max_size);
118        let mut connections = Vec::with_capacity(count);
119
120        for i in 0..count {
121            match self.inner.get().await {
122                Ok(mut conn) => {
123                    // Validate the connection with a simple query
124                    if let Err(e) = conn.simple_query("SELECT 1").await {
125                        debug!(error = %e, "Warmup connection {} failed validation", i);
126                    } else {
127                        debug!("Warmup connection {} established", i);
128                        connections.push(conn);
129                    }
130                }
131                Err(e) => {
132                    debug!(error = %e, "Failed to establish warmup connection {}", i);
133                }
134            }
135        }
136
137        let established = connections.len();
138        drop(connections);
139
140        info!(
141            established = established,
142            requested = count,
143            "MSSQL connection pool warmup complete"
144        );
145
146        Ok(())
147    }
148}
149
/// Pool status information: a snapshot of connection counts, as returned by
/// `MssqlPool::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Current number of connections (including idle).
    pub connections: usize,
    /// Number of idle connections.
    pub idle_connections: usize,
    /// Maximum size of the pool.
    pub max_size: usize,
}

impl PoolStatus {
    /// Number of connections that are not idle (`connections - idle_connections`).
    ///
    /// Saturates at zero instead of underflowing if the counts are inconsistent.
    #[must_use]
    pub fn in_use(&self) -> usize {
        self.connections.saturating_sub(self.idle_connections)
    }
}
160
/// Configuration for the connection pool.
///
/// Consumed by `MssqlPool::with_pool_config`, which forwards each field to the
/// corresponding setting of the underlying bb8 pool builder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Maximum number of connections in the pool (bb8 `max_size`).
    pub max_connections: usize,
    /// Minimum number of idle connections to keep (bb8 `min_idle`).
    pub min_connections: usize,
    /// Maximum time to wait for a connection; `None` falls back to 30 seconds.
    pub connection_timeout: Option<Duration>,
    /// Maximum idle time before a connection is closed (bb8 `idle_timeout`).
    pub idle_timeout: Option<Duration>,
    /// Maximum lifetime of a connection (bb8 `max_lifetime`).
    pub max_lifetime: Option<Duration>,
}

impl Default for PoolConfig {
    /// At most 10 connections, 1 kept idle, a 30 s acquire timeout, a 10 min idle
    /// timeout and a 30 min maximum lifetime.
    fn default() -> Self {
        Self {
            max_connections: 10,
            min_connections: 1,
            connection_timeout: Some(Duration::from_secs(30)),
            idle_timeout: Some(Duration::from_secs(10 * 60)),
            max_lifetime: Some(Duration::from_secs(30 * 60)),
        }
    }
}
187
188/// Builder for creating a connection pool.
189#[derive(Debug, Default)]
190pub struct MssqlPoolBuilder {
191    config: Option<MssqlConfig>,
192    connection_string: Option<String>,
193    pool_config: PoolConfig,
194}
195
196impl MssqlPoolBuilder {
197    /// Create a new pool builder.
198    pub fn new() -> Self {
199        Self {
200            config: None,
201            connection_string: None,
202            pool_config: PoolConfig::default(),
203        }
204    }
205
206    /// Set the connection string.
207    pub fn connection_string(mut self, conn_str: impl Into<String>) -> Self {
208        self.connection_string = Some(conn_str.into());
209        self
210    }
211
212    /// Set the configuration.
213    pub fn config(mut self, config: MssqlConfig) -> Self {
214        self.config = Some(config);
215        self
216    }
217
218    /// Set the server host.
219    pub fn host(mut self, host: impl Into<String>) -> Self {
220        let config = self.config.get_or_insert_with(MssqlConfig::default);
221        config.host = host.into();
222        self
223    }
224
225    /// Set the server port.
226    pub fn port(mut self, port: u16) -> Self {
227        let config = self.config.get_or_insert_with(MssqlConfig::default);
228        config.port = port;
229        self
230    }
231
232    /// Set the database name.
233    pub fn database(mut self, database: impl Into<String>) -> Self {
234        let config = self.config.get_or_insert_with(MssqlConfig::default);
235        config.database = database.into();
236        self
237    }
238
239    /// Set the username.
240    pub fn username(mut self, username: impl Into<String>) -> Self {
241        let config = self.config.get_or_insert_with(MssqlConfig::default);
242        config.username = Some(username.into());
243        self
244    }
245
246    /// Set the password.
247    pub fn password(mut self, password: impl Into<String>) -> Self {
248        let config = self.config.get_or_insert_with(MssqlConfig::default);
249        config.password = Some(password.into());
250        self
251    }
252
253    /// Set the maximum number of connections.
254    pub fn max_connections(mut self, n: usize) -> Self {
255        self.pool_config.max_connections = n;
256        self
257    }
258
259    /// Set the minimum number of idle connections.
260    pub fn min_connections(mut self, n: usize) -> Self {
261        self.pool_config.min_connections = n;
262        self
263    }
264
265    /// Set the connection timeout.
266    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
267        self.pool_config.connection_timeout = Some(timeout);
268        self
269    }
270
271    /// Set the idle timeout.
272    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
273        self.pool_config.idle_timeout = Some(timeout);
274        self
275    }
276
277    /// Set the maximum connection lifetime.
278    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
279        self.pool_config.max_lifetime = Some(lifetime);
280        self
281    }
282
283    /// Trust the server certificate.
284    pub fn trust_cert(mut self, trust: bool) -> Self {
285        let config = self.config.get_or_insert_with(MssqlConfig::default);
286        config.trust_cert = trust;
287        self
288    }
289
290    /// Build the connection pool.
291    pub async fn build(self) -> MssqlResult<MssqlPool> {
292        let config = if let Some(config) = self.config {
293            config
294        } else if let Some(conn_str) = self.connection_string {
295            MssqlConfig::from_connection_string(conn_str)?
296        } else {
297            return Err(MssqlError::config(
298                "no connection string or config provided",
299            ));
300        };
301
302        MssqlPool::with_pool_config(config, self.pool_config).await
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.connection_timeout, Some(Duration::from_secs(30)));
        assert_eq!(config.idle_timeout, Some(Duration::from_secs(600)));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(1800)));
    }

    #[test]
    fn test_pool_builder() {
        let builder = MssqlPoolBuilder::new()
            .host("localhost")
            .database("test")
            .username("sa")
            .password("password")
            .max_connections(20);

        assert_eq!(builder.pool_config.max_connections, 20);
        assert!(builder.config.is_some());
    }

    #[test]
    fn test_pool_builder_pool_settings_do_not_create_config() {
        let builder = MssqlPoolBuilder::new()
            .min_connections(2)
            .connection_timeout(Duration::from_secs(5))
            .idle_timeout(Duration::from_secs(60))
            .max_lifetime(Duration::from_secs(120));

        assert_eq!(builder.pool_config.min_connections, 2);
        assert_eq!(
            builder.pool_config.connection_timeout,
            Some(Duration::from_secs(5))
        );
        assert_eq!(builder.pool_config.idle_timeout, Some(Duration::from_secs(60)));
        assert_eq!(builder.pool_config.max_lifetime, Some(Duration::from_secs(120)));
        assert!(builder.config.is_none());
        assert!(builder.connection_string.is_none());
    }

    #[test]
    fn test_pool_builder_connection_string_is_stored() {
        let builder = MssqlPoolBuilder::new().connection_string("Server=localhost");

        assert_eq!(
            builder.connection_string.as_deref(),
            Some("Server=localhost")
        );
        assert!(builder.config.is_none());
    }

    #[test]
    fn test_pool_builder_setters_populate_config() {
        let builder = MssqlPoolBuilder::new()
            .host("db.example")
            .port(1434)
            .username("sa")
            .trust_cert(true);

        let config = builder.config.expect("connection setters create a config");
        assert_eq!(config.host, "db.example");
        assert_eq!(config.port, 1434);
        assert_eq!(config.username.as_deref(), Some("sa"));
        assert!(config.trust_cert);
    }
}