1use crate::error::{ConnectionError, PostGisError, Result};
6use deadpool_postgres::{Config, ManagerConfig, Pool, RecyclingMethod, Runtime};
7use std::time::Duration;
8use tokio_postgres::NoTls;
9use tracing::{debug, warn};
10
11#[derive(Debug, Clone)]
13pub struct ConnectionConfig {
14 pub host: Option<String>,
16 pub port: u16,
18 pub dbname: String,
20 pub user: String,
22 pub password: Option<String>,
24 pub connect_timeout: u64,
26 pub application_name: Option<String>,
28 pub sslmode: SslMode,
30}
31
/// TLS negotiation mode; each variant maps to the libpq `sslmode` keyword
/// returned by [`SslMode::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SslMode {
    /// Rendered as `sslmode=disable`.
    Disable,
    /// Rendered as `sslmode=prefer`; the value used by `ConnectionConfig::default`.
    Prefer,
    /// Rendered as `sslmode=require`.
    Require,
}
42
43impl SslMode {
44 pub const fn as_str(&self) -> &'static str {
46 match self {
47 Self::Disable => "disable",
48 Self::Prefer => "prefer",
49 Self::Require => "require",
50 }
51 }
52}
53
54impl Default for ConnectionConfig {
55 fn default() -> Self {
56 Self {
57 host: Some("localhost".to_string()),
58 port: 5432,
59 dbname: "postgres".to_string(),
60 user: "postgres".to_string(),
61 password: None,
62 connect_timeout: 30,
63 application_name: Some("oxigdal-postgis".to_string()),
64 sslmode: SslMode::Prefer,
65 }
66 }
67}
68
69impl ConnectionConfig {
70 pub fn new(dbname: impl Into<String>) -> Self {
72 Self {
73 dbname: dbname.into(),
74 ..Default::default()
75 }
76 }
77
78 pub fn host(mut self, host: impl Into<String>) -> Self {
80 self.host = Some(host.into());
81 self
82 }
83
84 pub const fn port(mut self, port: u16) -> Self {
86 self.port = port;
87 self
88 }
89
90 pub fn user(mut self, user: impl Into<String>) -> Self {
92 self.user = user.into();
93 self
94 }
95
96 pub fn password(mut self, password: impl Into<String>) -> Self {
98 self.password = Some(password.into());
99 self
100 }
101
102 pub const fn connect_timeout(mut self, seconds: u64) -> Self {
104 self.connect_timeout = seconds;
105 self
106 }
107
108 pub fn application_name(mut self, name: impl Into<String>) -> Self {
110 self.application_name = Some(name.into());
111 self
112 }
113
114 pub const fn sslmode(mut self, mode: SslMode) -> Self {
116 self.sslmode = mode;
117 self
118 }
119
120 pub fn to_connection_string(&self) -> String {
122 let mut parts = Vec::new();
123
124 if let Some(ref host) = self.host {
125 parts.push(format!("host={host}"));
126 }
127
128 parts.push(format!("port={}", self.port));
129 parts.push(format!("dbname={}", self.dbname));
130 parts.push(format!("user={}", self.user));
131
132 if let Some(ref password) = self.password {
133 parts.push(format!("password={password}"));
134 }
135
136 parts.push(format!("connect_timeout={}", self.connect_timeout));
137
138 if let Some(ref app_name) = self.application_name {
139 parts.push(format!("application_name={app_name}"));
140 }
141
142 parts.push(format!("sslmode={}", self.sslmode.as_str()));
143
144 parts.join(" ")
145 }
146
147 pub fn from_connection_string(conn_str: &str) -> Result<Self> {
149 let mut config = Self::default();
150
151 for part in conn_str.split_whitespace() {
152 if let Some((key, value)) = part.split_once('=') {
153 match key {
154 "host" => config.host = Some(value.to_string()),
155 "port" => {
156 config.port = value.parse().map_err(|_| {
157 ConnectionError::InvalidConnectionString {
158 message: format!("Invalid port: {value}"),
159 }
160 })?;
161 }
162 "dbname" => config.dbname = value.to_string(),
163 "user" => config.user = value.to_string(),
164 "password" => config.password = Some(value.to_string()),
165 "connect_timeout" => {
166 config.connect_timeout = value.parse().map_err(|_| {
167 ConnectionError::InvalidConnectionString {
168 message: format!("Invalid connect_timeout: {value}"),
169 }
170 })?;
171 }
172 "application_name" => config.application_name = Some(value.to_string()),
173 "sslmode" => {
174 config.sslmode = match value {
175 "disable" => SslMode::Disable,
176 "prefer" => SslMode::Prefer,
177 "require" => SslMode::Require,
178 _ => {
179 return Err(ConnectionError::InvalidConnectionString {
180 message: format!("Invalid sslmode: {value}"),
181 }
182 .into());
183 }
184 };
185 }
186 _ => {
187 warn!("Unknown connection string parameter: {key}");
188 }
189 }
190 }
191 }
192
193 Ok(config)
194 }
195}
196
/// Sizing, timeout and recycling settings for a [`ConnectionPool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of pooled connections (default: 16).
    pub max_size: usize,
    /// Pool timeout (default: 30 seconds).
    pub timeout: Duration,
    /// Connection recycling strategy, passed to deadpool's `ManagerConfig`
    /// (default: `RecyclingMethod::Fast`).
    pub recycling_method: RecyclingMethod,
}
207
208impl Default for PoolConfig {
209 fn default() -> Self {
210 Self {
211 max_size: 16,
212 timeout: Duration::from_secs(30),
213 recycling_method: RecyclingMethod::Fast,
214 }
215 }
216}
217
218impl PoolConfig {
219 pub fn new() -> Self {
221 Self::default()
222 }
223
224 pub const fn max_size(mut self, size: usize) -> Self {
226 self.max_size = size;
227 self
228 }
229
230 pub const fn timeout(mut self, timeout: Duration) -> Self {
232 self.timeout = timeout;
233 self
234 }
235
236 pub fn recycling_method(mut self, method: RecyclingMethod) -> Self {
238 self.recycling_method = method;
239 self
240 }
241}
242
/// A deadpool-postgres connection pool, together with the [`ConnectionConfig`]
/// it was created from.
pub struct ConnectionPool {
    /// Underlying deadpool-postgres pool that hands out connections.
    pool: Pool,
    /// Configuration the pool was built from; exposed via [`ConnectionPool::config`].
    config: ConnectionConfig,
}
248
249impl ConnectionPool {
250 pub fn new(config: ConnectionConfig) -> Result<Self> {
252 let pool_config = PoolConfig::default();
253 Self::with_pool_config(config, pool_config)
254 }
255
256 pub fn with_pool_config(config: ConnectionConfig, pool_config: PoolConfig) -> Result<Self> {
258 let conn_str = config.to_connection_string();
259 debug!("Creating connection pool with config: {}", conn_str);
260
261 let mut pg_config = Config::new();
262 if let Some(ref host) = config.host {
263 pg_config.host = Some(host.clone());
264 }
265 pg_config.port = Some(config.port);
266 pg_config.dbname = Some(config.dbname.clone());
267 pg_config.user = Some(config.user.clone());
268 pg_config.password = config.password.clone();
269 pg_config.connect_timeout = Some(Duration::from_secs(config.connect_timeout));
270 pg_config.application_name = config.application_name.clone();
271
272 pg_config.manager = Some(ManagerConfig {
273 recycling_method: pool_config.recycling_method,
274 });
275
276 let pool = pg_config
277 .create_pool(Some(Runtime::Tokio1), NoTls)
278 .map_err(|e| ConnectionError::PoolError {
279 message: e.to_string(),
280 })?;
281
282 Ok(Self { pool, config })
283 }
284
285 pub fn from_connection_string(conn_str: &str) -> Result<Self> {
287 let config = ConnectionConfig::from_connection_string(conn_str)?;
288 Self::new(config)
289 }
290
291 pub async fn get(&self) -> Result<deadpool_postgres::Object> {
293 self.pool.get().await.map_err(|e| {
294 ConnectionError::PoolError {
295 message: e.to_string(),
296 }
297 .into()
298 })
299 }
300
301 pub fn status(&self) -> PoolStatus {
303 let status = self.pool.status();
304 PoolStatus {
305 size: status.size,
306 available: status.available,
307 max_size: status.max_size,
308 }
309 }
310
311 pub async fn check_postgis(&self) -> Result<bool> {
313 let client = self.get().await?;
314
315 let query = "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'postgis')";
316 let row = client.query_one(query, &[]).await.map_err(|e| {
317 PostGisError::Query(crate::error::QueryError::ExecutionFailed {
318 message: e.to_string(),
319 })
320 })?;
321
322 let exists: bool = row.get(0);
323 Ok(exists)
324 }
325
326 pub async fn postgis_version(&self) -> Result<String> {
328 let client = self.get().await?;
329
330 let query = "SELECT PostGIS_Version()";
331 let row = client.query_one(query, &[]).await.map_err(|e| {
332 PostGisError::Query(crate::error::QueryError::ExecutionFailed {
333 message: e.to_string(),
334 })
335 })?;
336
337 let version: String = row.get(0);
338 Ok(version)
339 }
340
341 pub async fn health_check(&self) -> Result<HealthCheckResult> {
343 let start = std::time::Instant::now();
344
345 let client = self.get().await?;
347
348 client.query_one("SELECT 1", &[]).await.map_err(|e| {
350 PostGisError::Query(crate::error::QueryError::ExecutionFailed {
351 message: e.to_string(),
352 })
353 })?;
354
355 let latency = start.elapsed();
356
357 let postgis_installed = self.check_postgis().await?;
359 let postgis_version = if postgis_installed {
360 self.postgis_version().await.ok()
361 } else {
362 None
363 };
364
365 Ok(HealthCheckResult {
366 connected: true,
367 latency,
368 pool_status: self.status(),
369 postgis_installed,
370 postgis_version,
371 })
372 }
373
374 pub const fn config(&self) -> &ConnectionConfig {
376 &self.config
377 }
378}
379
/// Snapshot of pool occupancy, copied from deadpool's `Pool::status()`.
///
/// All fields are plain counters, so the snapshot is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Current number of connections in the pool (deadpool `Status::size`).
    pub size: usize,
    /// Connections available for checkout (deadpool `Status::available`).
    pub available: usize,
    /// Maximum pool size (deadpool `Status::max_size`).
    pub max_size: usize,
}
390
/// Outcome of [`ConnectionPool::health_check`].
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
    /// Always `true` when produced by `health_check`, which returns an error
    /// instead when no connection can be obtained or `SELECT 1` fails.
    pub connected: bool,
    /// Time from the start of the check until `SELECT 1` completed, including
    /// checking a connection out of the pool.
    pub latency: Duration,
    /// Pool status captured at the end of the check.
    pub pool_status: PoolStatus,
    /// Whether `pg_extension` contains an entry named `postgis`.
    pub postgis_installed: bool,
    /// Output of `PostGIS_Version()`; `None` if PostGIS is not installed or the
    /// version query failed.
    pub postgis_version: Option<String>,
}
405
406impl HealthCheckResult {
407 pub fn is_healthy(&self) -> bool {
409 self.connected && self.postgis_installed
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config_default() {
        let config = ConnectionConfig::default();
        assert_eq!(config.host, Some("localhost".to_string()));
        assert_eq!(config.port, 5432);
        assert_eq!(config.dbname, "postgres");
        assert_eq!(config.user, "postgres");
        assert_eq!(config.password, None);
        assert_eq!(config.sslmode, SslMode::Prefer);
    }

    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("test_db")
            .host("localhost")
            .port(5433)
            .user("test_user")
            .password("test_pass")
            .connect_timeout(60)
            .application_name("test_app")
            .sslmode(SslMode::Require);

        assert_eq!(config.dbname, "test_db");
        assert_eq!(config.host, Some("localhost".to_string()));
        assert_eq!(config.port, 5433);
        assert_eq!(config.user, "test_user");
        assert_eq!(config.password, Some("test_pass".to_string()));
        assert_eq!(config.connect_timeout, 60);
        assert_eq!(config.application_name, Some("test_app".to_string()));
        assert_eq!(config.sslmode, SslMode::Require);
    }

    #[test]
    fn test_connection_string_generation() {
        let config = ConnectionConfig::new("test_db")
            .host("localhost")
            .user("test_user")
            .password("test_pass");

        let conn_str = config.to_connection_string();
        assert!(conn_str.contains("host=localhost"));
        assert!(conn_str.contains("dbname=test_db"));
        assert!(conn_str.contains("user=test_user"));
        assert!(conn_str.contains("password=test_pass"));
        assert!(conn_str.contains("sslmode=prefer"));
    }

    #[test]
    fn test_connection_string_parsing() {
        let conn_str = "host=localhost port=5432 dbname=test_db user=test_user password=test_pass";
        // `expect` on the Result keeps the parse error in the failure message.
        let config = ConnectionConfig::from_connection_string(conn_str)
            .expect("valid connection string should parse");

        assert_eq!(config.host, Some("localhost".to_string()));
        assert_eq!(config.port, 5432);
        assert_eq!(config.dbname, "test_db");
        assert_eq!(config.user, "test_user");
        assert_eq!(config.password, Some("test_pass".to_string()));
    }

    #[test]
    fn test_connection_string_round_trip() {
        let original = ConnectionConfig::new("gis")
            .host("db.internal")
            .port(6543)
            .user("reader")
            .connect_timeout(5)
            .application_name("loader")
            .sslmode(SslMode::Disable);

        let parsed = ConnectionConfig::from_connection_string(&original.to_connection_string())
            .expect("generated connection string should parse");

        assert_eq!(parsed.host, original.host);
        assert_eq!(parsed.port, original.port);
        assert_eq!(parsed.dbname, original.dbname);
        assert_eq!(parsed.user, original.user);
        assert_eq!(parsed.password, original.password);
        assert_eq!(parsed.connect_timeout, original.connect_timeout);
        assert_eq!(parsed.application_name, original.application_name);
        assert_eq!(parsed.sslmode, original.sslmode);
    }

    #[test]
    fn test_connection_string_rejects_invalid_values() {
        assert!(ConnectionConfig::from_connection_string("port=not_a_port").is_err());
        assert!(ConnectionConfig::from_connection_string("port=70000").is_err());
        assert!(ConnectionConfig::from_connection_string("connect_timeout=-1").is_err());
        assert!(ConnectionConfig::from_connection_string("sslmode=verify-full").is_err());
    }

    #[test]
    fn test_connection_string_ignores_unknown_parameters() {
        let config = ConnectionConfig::from_connection_string("dbname=gis target_session_attrs=any")
            .expect("unknown parameters should only be warned about");
        assert_eq!(config.dbname, "gis");
    }

    #[test]
    fn test_sslmode() {
        assert_eq!(SslMode::Disable.as_str(), "disable");
        assert_eq!(SslMode::Prefer.as_str(), "prefer");
        assert_eq!(SslMode::Require.as_str(), "require");
    }

    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new()
            .max_size(32)
            .timeout(Duration::from_secs(60));

        assert_eq!(config.max_size, 32);
        assert_eq!(config.timeout, Duration::from_secs(60));
    }
}