Skip to main content

oxigdal_postgis/
connection.rs

1//! Database connection management for PostGIS
2//!
3//! This module provides connection pooling and management for PostgreSQL/PostGIS databases.
4
5use crate::error::{ConnectionError, PostGisError, Result};
6use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
7use std::time::Duration;
8use tokio_postgres::NoTls;
9use tracing::{debug, warn};
10
11/// PostgreSQL connection configuration
12#[derive(Debug, Clone)]
13pub struct ConnectionConfig {
14    /// Database host
15    pub host: Option<String>,
16    /// Database port
17    pub port: u16,
18    /// Database name
19    pub dbname: String,
20    /// Username
21    pub user: String,
22    /// Password
23    pub password: Option<String>,
24    /// Connection timeout in seconds
25    pub connect_timeout: u64,
26    /// Application name
27    pub application_name: Option<String>,
28    /// SSL mode
29    pub sslmode: SslMode,
30}
31
/// TLS negotiation policy used when connecting to PostgreSQL.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SslMode {
    /// Never use SSL
    Disable,
    /// Use SSL when the server offers it
    Prefer,
    /// Insist on SSL
    Require,
}
42
43impl SslMode {
44    /// Converts to PostgreSQL sslmode string
45    pub const fn as_str(&self) -> &'static str {
46        match self {
47            Self::Disable => "disable",
48            Self::Prefer => "prefer",
49            Self::Require => "require",
50        }
51    }
52}
53
54impl Default for ConnectionConfig {
55    fn default() -> Self {
56        Self {
57            host: Some("localhost".to_string()),
58            port: 5432,
59            dbname: "postgres".to_string(),
60            user: "postgres".to_string(),
61            password: None,
62            connect_timeout: 30,
63            application_name: Some("oxigdal-postgis".to_string()),
64            sslmode: SslMode::Prefer,
65        }
66    }
67}
68
69impl ConnectionConfig {
70    /// Creates a new connection configuration
71    pub fn new(dbname: impl Into<String>) -> Self {
72        Self {
73            dbname: dbname.into(),
74            ..Default::default()
75        }
76    }
77
78    /// Sets the host
79    pub fn host(mut self, host: impl Into<String>) -> Self {
80        self.host = Some(host.into());
81        self
82    }
83
84    /// Sets the port
85    pub const fn port(mut self, port: u16) -> Self {
86        self.port = port;
87        self
88    }
89
90    /// Sets the user
91    pub fn user(mut self, user: impl Into<String>) -> Self {
92        self.user = user.into();
93        self
94    }
95
96    /// Sets the password
97    pub fn password(mut self, password: impl Into<String>) -> Self {
98        self.password = Some(password.into());
99        self
100    }
101
102    /// Sets the connection timeout
103    pub const fn connect_timeout(mut self, seconds: u64) -> Self {
104        self.connect_timeout = seconds;
105        self
106    }
107
108    /// Sets the application name
109    pub fn application_name(mut self, name: impl Into<String>) -> Self {
110        self.application_name = Some(name.into());
111        self
112    }
113
114    /// Sets the SSL mode
115    pub const fn sslmode(mut self, mode: SslMode) -> Self {
116        self.sslmode = mode;
117        self
118    }
119
120    /// Builds a connection string
121    pub fn to_connection_string(&self) -> String {
122        let mut parts = Vec::new();
123
124        if let Some(ref host) = self.host {
125            parts.push(format!("host={host}"));
126        }
127
128        parts.push(format!("port={}", self.port));
129        parts.push(format!("dbname={}", self.dbname));
130        parts.push(format!("user={}", self.user));
131
132        if let Some(ref password) = self.password {
133            parts.push(format!("password={password}"));
134        }
135
136        parts.push(format!("connect_timeout={}", self.connect_timeout));
137
138        if let Some(ref app_name) = self.application_name {
139            parts.push(format!("application_name={app_name}"));
140        }
141
142        parts.push(format!("sslmode={}", self.sslmode.as_str()));
143
144        parts.join(" ")
145    }
146
147    /// Parses a connection string into configuration
148    pub fn from_connection_string(conn_str: &str) -> Result<Self> {
149        let mut config = Self::default();
150
151        for part in conn_str.split_whitespace() {
152            if let Some((key, value)) = part.split_once('=') {
153                match key {
154                    "host" => config.host = Some(value.to_string()),
155                    "port" => {
156                        config.port = value.parse().map_err(|_| {
157                            ConnectionError::InvalidConnectionString {
158                                message: format!("Invalid port: {value}"),
159                            }
160                        })?;
161                    }
162                    "dbname" => config.dbname = value.to_string(),
163                    "user" => config.user = value.to_string(),
164                    "password" => config.password = Some(value.to_string()),
165                    "connect_timeout" => {
166                        config.connect_timeout = value.parse().map_err(|_| {
167                            ConnectionError::InvalidConnectionString {
168                                message: format!("Invalid connect_timeout: {value}"),
169                            }
170                        })?;
171                    }
172                    "application_name" => config.application_name = Some(value.to_string()),
173                    "sslmode" => {
174                        config.sslmode = match value {
175                            "disable" => SslMode::Disable,
176                            "prefer" => SslMode::Prefer,
177                            "require" => SslMode::Require,
178                            _ => {
179                                return Err(ConnectionError::InvalidConnectionString {
180                                    message: format!("Invalid sslmode: {value}"),
181                                }
182                                .into());
183                            }
184                        };
185                    }
186                    _ => {
187                        warn!("Unknown connection string parameter: {key}");
188                    }
189                }
190            }
191        }
192
193        Ok(config)
194    }
195}
196
197/// Connection pool configuration
198#[derive(Debug, Clone)]
199pub struct PoolConfig {
200    /// Maximum pool size
201    pub max_size: usize,
202    /// Connection timeout
203    pub timeout: Duration,
204    /// Recycling method
205    pub recycling_method: RecyclingMethod,
206}
207
208impl Default for PoolConfig {
209    fn default() -> Self {
210        Self {
211            max_size: 16,
212            timeout: Duration::from_secs(30),
213            recycling_method: RecyclingMethod::Fast,
214        }
215    }
216}
217
218impl PoolConfig {
219    /// Creates a new pool configuration
220    pub fn new() -> Self {
221        Self::default()
222    }
223
224    /// Sets the maximum pool size
225    pub const fn max_size(mut self, size: usize) -> Self {
226        self.max_size = size;
227        self
228    }
229
230    /// Sets the connection timeout
231    pub const fn timeout(mut self, timeout: Duration) -> Self {
232        self.timeout = timeout;
233        self
234    }
235
236    /// Sets the recycling method
237    pub fn recycling_method(mut self, method: RecyclingMethod) -> Self {
238        self.recycling_method = method;
239        self
240    }
241}
242
243/// Connection pool for PostgreSQL/PostGIS
244pub struct ConnectionPool {
245    pool: Pool,
246    config: ConnectionConfig,
247}
248
249impl ConnectionPool {
250    /// Creates a new connection pool
251    pub fn new(config: ConnectionConfig) -> Result<Self> {
252        let pool_config = PoolConfig::default();
253        Self::with_pool_config(config, pool_config)
254    }
255
256    /// Creates a new connection pool with custom pool configuration
257    pub fn with_pool_config(config: ConnectionConfig, pool_config: PoolConfig) -> Result<Self> {
258        let conn_str = config.to_connection_string();
259        debug!("Creating connection pool with config: {}", conn_str);
260
261        let mut pg_config = Config::new();
262        if let Some(ref host) = config.host {
263            pg_config.host = Some(host.clone());
264        }
265        pg_config.port = Some(config.port);
266        pg_config.dbname = Some(config.dbname.clone());
267        pg_config.user = Some(config.user.clone());
268        pg_config.password = config.password.clone();
269        pg_config.connect_timeout = Some(Duration::from_secs(config.connect_timeout));
270        pg_config.application_name = config.application_name.clone();
271
272        pg_config.manager = Some(ManagerConfig {
273            recycling_method: pool_config.recycling_method,
274        });
275
276        let pool = pg_config
277            .create_pool(Some(Runtime::Tokio1), NoTls)
278            .map_err(|e| ConnectionError::PoolError {
279                message: e.to_string(),
280            })?;
281
282        Ok(Self { pool, config })
283    }
284
285    /// Creates a connection pool from a connection string
286    pub fn from_connection_string(conn_str: &str) -> Result<Self> {
287        let config = ConnectionConfig::from_connection_string(conn_str)?;
288        Self::new(config)
289    }
290
291    /// Gets a connection from the pool
292    pub async fn get(&self) -> Result<deadpool_postgres::Object> {
293        self.pool.get().await.map_err(|e| {
294            ConnectionError::PoolError {
295                message: e.to_string(),
296            }
297            .into()
298        })
299    }
300
301    /// Gets the pool status
302    pub fn status(&self) -> PoolStatus {
303        let status = self.pool.status();
304        PoolStatus {
305            size: status.size,
306            available: status.available,
307            max_size: status.max_size,
308        }
309    }
310
311    /// Checks if PostGIS extension is installed
312    pub async fn check_postgis(&self) -> Result<bool> {
313        let client = self.get().await?;
314
315        let query = "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'postgis')";
316        let row = client.query_one(query, &[]).await.map_err(|e| {
317            PostGisError::Query(crate::error::QueryError::ExecutionFailed {
318                message: e.to_string(),
319            })
320        })?;
321
322        let exists: bool = row.get(0);
323        Ok(exists)
324    }
325
326    /// Checks the PostGIS version
327    pub async fn postgis_version(&self) -> Result<String> {
328        let client = self.get().await?;
329
330        let query = "SELECT PostGIS_Version()";
331        let row = client.query_one(query, &[]).await.map_err(|e| {
332            PostGisError::Query(crate::error::QueryError::ExecutionFailed {
333                message: e.to_string(),
334            })
335        })?;
336
337        let version: String = row.get(0);
338        Ok(version)
339    }
340
341    /// Performs a health check on the connection pool
342    pub async fn health_check(&self) -> Result<HealthCheckResult> {
343        let start = std::time::Instant::now();
344
345        // Try to get a connection
346        let client = self.get().await?;
347
348        // Execute a simple query
349        client.query_one("SELECT 1", &[]).await.map_err(|e| {
350            PostGisError::Query(crate::error::QueryError::ExecutionFailed {
351                message: e.to_string(),
352            })
353        })?;
354
355        let latency = start.elapsed();
356
357        // Check PostGIS
358        let postgis_installed = self.check_postgis().await?;
359        let postgis_version = if postgis_installed {
360            self.postgis_version().await.ok()
361        } else {
362            None
363        };
364
365        Ok(HealthCheckResult {
366            connected: true,
367            latency,
368            pool_status: self.status(),
369            postgis_installed,
370            postgis_version,
371        })
372    }
373
374    /// Returns the connection configuration
375    pub const fn config(&self) -> &ConnectionConfig {
376        &self.config
377    }
378}
379
/// Snapshot of the pool's connection counts.
#[derive(Clone, Debug)]
pub struct PoolStatus {
    /// Connections currently held by the pool
    pub size: usize,
    /// Connections ready to be handed out
    pub available: usize,
    /// Configured upper bound on connections
    pub max_size: usize,
}
390
391/// Health check result
392#[derive(Debug, Clone)]
393pub struct HealthCheckResult {
394    /// Whether the connection is established
395    pub connected: bool,
396    /// Connection latency
397    pub latency: Duration,
398    /// Pool status
399    pub pool_status: PoolStatus,
400    /// Whether PostGIS is installed
401    pub postgis_installed: bool,
402    /// PostGIS version (if installed)
403    pub postgis_version: Option<String>,
404}
405
406impl HealthCheckResult {
407    /// Returns true if the connection is healthy
408    pub fn is_healthy(&self) -> bool {
409        self.connected && self.postgis_installed
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config_default() {
        let config = ConnectionConfig::default();
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.port, 5432);
        assert_eq!(config.dbname, "postgres");
        assert_eq!(config.user, "postgres");
        assert_eq!(config.password, None);
        assert_eq!(config.sslmode, SslMode::Prefer);
    }

    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("test_db")
            .host("localhost")
            .port(5433)
            .user("test_user")
            .password("test_pass")
            .connect_timeout(60)
            .application_name("test_app")
            .sslmode(SslMode::Require);

        assert_eq!(config.dbname, "test_db");
        assert_eq!(config.host, Some("localhost".to_string()));
        assert_eq!(config.port, 5433);
        assert_eq!(config.user, "test_user");
        assert_eq!(config.password, Some("test_pass".to_string()));
        assert_eq!(config.connect_timeout, 60);
        assert_eq!(config.application_name, Some("test_app".to_string()));
        assert_eq!(config.sslmode, SslMode::Require);
    }

    #[test]
    fn test_connection_string_generation() {
        let config = ConnectionConfig::new("test_db")
            .host("localhost")
            .user("test_user")
            .password("test_pass");

        let conn_str = config.to_connection_string();
        assert!(conn_str.contains("host=localhost"));
        assert!(conn_str.contains("dbname=test_db"));
        assert!(conn_str.contains("user=test_user"));
        assert!(conn_str.contains("password=test_pass"));
    }

    #[test]
    fn test_connection_string_parsing() {
        let conn_str = "host=localhost port=5432 dbname=test_db user=test_user password=test_pass";
        let config =
            ConnectionConfig::from_connection_string(conn_str).expect("config parsing failed");

        assert_eq!(config.host, Some("localhost".to_string()));
        assert_eq!(config.port, 5432);
        assert_eq!(config.dbname, "test_db");
        assert_eq!(config.user, "test_user");
        assert_eq!(config.password, Some("test_pass".to_string()));
    }

    #[test]
    fn test_connection_string_round_trip() {
        let original = ConnectionConfig::new("geo_db")
            .host("db.example.com")
            .port(6543)
            .user("gis")
            .password("s3cret")
            .connect_timeout(12)
            .application_name("roundtrip")
            .sslmode(SslMode::Disable);

        let parsed = ConnectionConfig::from_connection_string(&original.to_connection_string())
            .expect("round-trip parsing failed");

        assert_eq!(parsed.host, original.host);
        assert_eq!(parsed.port, original.port);
        assert_eq!(parsed.dbname, original.dbname);
        assert_eq!(parsed.user, original.user);
        assert_eq!(parsed.password, original.password);
        assert_eq!(parsed.connect_timeout, original.connect_timeout);
        assert_eq!(parsed.application_name, original.application_name);
        assert_eq!(parsed.sslmode, original.sslmode);
    }

    #[test]
    fn test_connection_string_optional_fields() {
        let config = ConnectionConfig::from_connection_string(
            "dbname=gis connect_timeout=5 application_name=loader sslmode=disable",
        )
        .expect("config parsing failed");

        assert_eq!(config.dbname, "gis");
        assert_eq!(config.connect_timeout, 5);
        assert_eq!(config.application_name.as_deref(), Some("loader"));
        assert_eq!(config.sslmode, SslMode::Disable);
        // Keys that were not given keep their defaults.
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "postgres");
    }

    #[test]
    fn test_connection_string_unknown_key_is_ignored() {
        let config = ConnectionConfig::from_connection_string("dbname=gis frobnicate=yes")
            .expect("unknown keys must not fail parsing");
        assert_eq!(config.dbname, "gis");
    }

    #[test]
    fn test_connection_string_rejects_invalid_values() {
        assert!(ConnectionConfig::from_connection_string("port=not_a_port").is_err());
        assert!(ConnectionConfig::from_connection_string("port=70000").is_err());
        assert!(ConnectionConfig::from_connection_string("connect_timeout=-1").is_err());
        assert!(ConnectionConfig::from_connection_string("sslmode=verify-full").is_err());
    }

    #[test]
    fn test_sslmode() {
        assert_eq!(SslMode::Disable.as_str(), "disable");
        assert_eq!(SslMode::Prefer.as_str(), "prefer");
        assert_eq!(SslMode::Require.as_str(), "require");
    }

    #[test]
    fn test_pool_config() {
        let defaults = PoolConfig::new();
        assert_eq!(defaults.max_size, 16);
        assert_eq!(defaults.timeout, Duration::from_secs(30));

        let config = PoolConfig::new()
            .max_size(32)
            .timeout(Duration::from_secs(60));

        assert_eq!(config.max_size, 32);
        assert_eq!(config.timeout, Duration::from_secs(60));
    }
}