Skip to main content

dbkit_rs/
connection.rs

1use crate::config::DbkitConfig;
2use crate::DbkitError;
3use deadpool_postgres::{
4    Config as PostgresConfig, ManagerConfig, Pool, PoolError, RecyclingMethod, Runtime,
5};
6use std::time::Duration;
7use tokio_postgres::{NoTls, error::SqlState};
8use tracing::{error, info, warn};
9
10/// Postgres connection pool with automatic database creation.
11///
12/// If `auto_create_db` is enabled (default) and the target database doesn't
13/// exist, it connects to the `postgres` system database and creates it.
14pub struct ConnectionManager {
15    pool: Pool,
16    db_name: String,
17    connection_string: String,
18    config: DbkitConfig,
19}
20
21impl ConnectionManager {
22    /// Connect using a [`DbkitConfig`].
23    pub async fn connect(config: DbkitConfig) -> Result<Self, DbkitError> {
24        let db_name = Self::extract_db_name(&config.url);
25        let connection_string = config.url.clone();
26
27        let mut cfg = PostgresConfig::new();
28        cfg.url = Some(config.url.clone());
29        cfg.pool = Some(deadpool_postgres::PoolConfig {
30            max_size: config.pool_size,
31            timeouts: deadpool_postgres::Timeouts {
32                wait: Some(Duration::from_secs(config.connect_timeout_secs)),
33                create: Some(Duration::from_secs(config.connect_timeout_secs)),
34                recycle: Some(Duration::from_secs(config.connect_timeout_secs)),
35            },
36            ..Default::default()
37        });
38        cfg.manager = Some(ManagerConfig {
39            recycling_method: RecyclingMethod::Fast,
40        });
41
42        let pool = cfg
43            .create_pool(Some(Runtime::Tokio1), NoTls)
44            .map_err(|e| DbkitError::PoolCreation(e.to_string()))?;
45
46        let final_pool = match pool.get().await {
47            Ok(_) => {
48                info!("connected to database '{}'", db_name);
49                pool
50            }
51            Err(PoolError::Backend(e)) => {
52                if let Some(code) = e.code() {
53                    if *code == SqlState::INVALID_CATALOG_NAME {
54                        if config.auto_create_db {
55                            warn!("database '{}' does not exist, creating...", db_name);
56                            Self::create_database_if_missing(&config.url, &db_name).await?;
57                            cfg.create_pool(Some(Runtime::Tokio1), NoTls)
58                                .map_err(|e| DbkitError::PoolCreation(e.to_string()))?
59                        } else {
60                            return Err(DbkitError::DatabaseCreation {
61                                name: db_name,
62                                reason: "database does not exist and auto_create_db is disabled"
63                                    .into(),
64                            });
65                        }
66                    } else if *code == SqlState::INVALID_PASSWORD {
67                        error!("authentication failed");
68                        return Err(DbkitError::AuthFailed);
69                    } else if *code == SqlState::TOO_MANY_CONNECTIONS {
70                        return Err(DbkitError::TooManyConnections);
71                    } else {
72                        return Err(DbkitError::Connection(format!(
73                            "code {:?}: {}",
74                            code, e
75                        )));
76                    }
77                } else {
78                    return Err(DbkitError::Connection(e.to_string()));
79                }
80            }
81            Err(e) => {
82                return Err(DbkitError::Connection(format!(
83                    "could not connect to '{}': {}",
84                    db_name, e
85                )));
86            }
87        };
88
89        Ok(Self {
90            pool: final_pool,
91            db_name,
92            connection_string,
93            config,
94        })
95    }
96
97    /// Connect using a connection URL with default settings.
98    ///
99    /// Shorthand for `ConnectionManager::connect(DbkitConfig::from_url(url))`.
100    pub async fn new(url: &str) -> Result<Self, DbkitError> {
101        Self::connect(DbkitConfig::from_url(url)).await
102    }
103
104    /// Get the underlying connection pool.
105    pub fn pool(&self) -> &Pool {
106        &self.pool
107    }
108
109    /// Get a connection from the pool.
110    pub async fn get_connection(&self) -> Result<deadpool_postgres::Object, DbkitError> {
111        self.pool
112            .get()
113            .await
114            .map_err(|e| DbkitError::Pool(e.to_string()))
115    }
116
117    /// Check if the database is reachable.
118    pub async fn is_connected(&self) -> bool {
119        self.pool.get().await.is_ok()
120    }
121
122    /// The database name extracted from the connection URL.
123    pub fn db_name(&self) -> &str {
124        &self.db_name
125    }
126
127    /// The full connection string.
128    pub fn connection_string(&self) -> &str {
129        &self.connection_string
130    }
131
132    /// The config used to create this connection.
133    pub fn config(&self) -> &DbkitConfig {
134        &self.config
135    }
136
137    /// Pool health metrics.
138    pub fn pool_status(&self) -> PoolStatus {
139        let status = self.pool.status();
140        PoolStatus {
141            max_size: status.max_size,
142            size: status.size,
143            available: status.available as usize,
144            waiting: status.waiting,
145        }
146    }
147
148    fn extract_db_name(url: &str) -> String {
149        url.rsplit('/')
150            .next()
151            .unwrap_or("postgres")
152            .split('?')
153            .next()
154            .unwrap_or("postgres")
155            .to_string()
156    }
157
158    async fn create_database_if_missing(url: &str, db_name: &str) -> Result<(), DbkitError> {
159        let base_url = if let Some(pos) = url.rfind('/') {
160            format!("{}postgres", &url[..=pos])
161        } else {
162            return Err(DbkitError::DatabaseCreation {
163                name: db_name.to_string(),
164                reason: "invalid database URL".into(),
165            });
166        };
167
168        let (client, connection) = tokio_postgres::connect(&base_url, NoTls).await?;
169
170        tokio::spawn(async move {
171            if let Err(e) = connection.await {
172                warn!("connection error during DB creation: {}", e);
173            }
174        });
175
176        let exists = client
177            .query_one("SELECT 1 FROM pg_database WHERE datname = $1", &[&db_name])
178            .await
179            .is_ok();
180
181        if !exists {
182            info!("creating database '{}'...", db_name);
183            let create_query = format!("CREATE DATABASE \"{}\"", db_name);
184            client
185                .batch_execute(&create_query)
186                .await
187                .map_err(|e| DbkitError::DatabaseCreation {
188                    name: db_name.to_string(),
189                    reason: e.to_string(),
190                })?;
191            info!("database '{}' created", db_name);
192        }
193
194        Ok(())
195    }
196}
197
/// Snapshot of connection pool health.
///
/// A plain value type: it holds four counters copied at the moment it was
/// built and does not track the pool afterwards. It implements `PartialEq`
/// and `Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStatus {
    /// Maximum number of connections in the pool.
    pub max_size: usize,
    /// Current number of connections (active + idle).
    pub size: usize,
    /// Number of idle connections available.
    pub available: usize,
    /// Number of tasks waiting for a connection.
    pub waiting: usize,
}
210
211impl std::fmt::Display for PoolStatus {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        write!(
214            f,
215            "pool: {}/{} connections, {} available, {} waiting",
216            self.size, self.max_size, self.available, self.waiting
217        )
218    }
219}