1use crate::config::DbkitConfig;
2use crate::DbkitError;
3use deadpool_postgres::{
4 Config as PostgresConfig, ManagerConfig, Pool, PoolError, RecyclingMethod, Runtime,
5};
6use std::time::Duration;
7use tokio_postgres::{NoTls, error::SqlState};
8use tracing::{error, info, warn};
9
10pub struct ConnectionManager {
15 pool: Pool,
16 db_name: String,
17 connection_string: String,
18 config: DbkitConfig,
19}
20
21impl ConnectionManager {
22 pub async fn connect(config: DbkitConfig) -> Result<Self, DbkitError> {
24 let db_name = Self::extract_db_name(&config.url);
25 let connection_string = config.url.clone();
26
27 let mut cfg = PostgresConfig::new();
28 cfg.url = Some(config.url.clone());
29 cfg.pool = Some(deadpool_postgres::PoolConfig {
30 max_size: config.pool_size,
31 timeouts: deadpool_postgres::Timeouts {
32 wait: Some(Duration::from_secs(config.connect_timeout_secs)),
33 create: Some(Duration::from_secs(config.connect_timeout_secs)),
34 recycle: Some(Duration::from_secs(config.connect_timeout_secs)),
35 },
36 ..Default::default()
37 });
38 cfg.manager = Some(ManagerConfig {
39 recycling_method: RecyclingMethod::Fast,
40 });
41
42 let pool = cfg
43 .create_pool(Some(Runtime::Tokio1), NoTls)
44 .map_err(|e| DbkitError::PoolCreation(e.to_string()))?;
45
46 let final_pool = match pool.get().await {
47 Ok(_) => {
48 info!("connected to database '{}'", db_name);
49 pool
50 }
51 Err(PoolError::Backend(e)) => {
52 if let Some(code) = e.code() {
53 if *code == SqlState::INVALID_CATALOG_NAME {
54 if config.auto_create_db {
55 warn!("database '{}' does not exist, creating...", db_name);
56 Self::create_database_if_missing(&config.url, &db_name).await?;
57 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
58 .map_err(|e| DbkitError::PoolCreation(e.to_string()))?
59 } else {
60 return Err(DbkitError::DatabaseCreation {
61 name: db_name,
62 reason: "database does not exist and auto_create_db is disabled"
63 .into(),
64 });
65 }
66 } else if *code == SqlState::INVALID_PASSWORD {
67 error!("authentication failed");
68 return Err(DbkitError::AuthFailed);
69 } else if *code == SqlState::TOO_MANY_CONNECTIONS {
70 return Err(DbkitError::TooManyConnections);
71 } else {
72 return Err(DbkitError::Connection(format!(
73 "code {:?}: {}",
74 code, e
75 )));
76 }
77 } else {
78 return Err(DbkitError::Connection(e.to_string()));
79 }
80 }
81 Err(e) => {
82 return Err(DbkitError::Connection(format!(
83 "could not connect to '{}': {}",
84 db_name, e
85 )));
86 }
87 };
88
89 Ok(Self {
90 pool: final_pool,
91 db_name,
92 connection_string,
93 config,
94 })
95 }
96
97 pub async fn new(url: &str) -> Result<Self, DbkitError> {
101 Self::connect(DbkitConfig::from_url(url)).await
102 }
103
104 pub fn pool(&self) -> &Pool {
106 &self.pool
107 }
108
109 pub async fn get_connection(&self) -> Result<deadpool_postgres::Object, DbkitError> {
111 self.pool
112 .get()
113 .await
114 .map_err(|e| DbkitError::Pool(e.to_string()))
115 }
116
117 pub async fn is_connected(&self) -> bool {
119 self.pool.get().await.is_ok()
120 }
121
122 pub fn db_name(&self) -> &str {
124 &self.db_name
125 }
126
127 pub fn connection_string(&self) -> &str {
129 &self.connection_string
130 }
131
132 pub fn config(&self) -> &DbkitConfig {
134 &self.config
135 }
136
137 pub fn pool_status(&self) -> PoolStatus {
139 let status = self.pool.status();
140 PoolStatus {
141 max_size: status.max_size,
142 size: status.size,
143 available: status.available as usize,
144 waiting: status.waiting,
145 }
146 }
147
148 fn extract_db_name(url: &str) -> String {
149 url.rsplit('/')
150 .next()
151 .unwrap_or("postgres")
152 .split('?')
153 .next()
154 .unwrap_or("postgres")
155 .to_string()
156 }
157
158 async fn create_database_if_missing(url: &str, db_name: &str) -> Result<(), DbkitError> {
159 let base_url = if let Some(pos) = url.rfind('/') {
160 format!("{}postgres", &url[..=pos])
161 } else {
162 return Err(DbkitError::DatabaseCreation {
163 name: db_name.to_string(),
164 reason: "invalid database URL".into(),
165 });
166 };
167
168 let (client, connection) = tokio_postgres::connect(&base_url, NoTls).await?;
169
170 tokio::spawn(async move {
171 if let Err(e) = connection.await {
172 warn!("connection error during DB creation: {}", e);
173 }
174 });
175
176 let exists = client
177 .query_one("SELECT 1 FROM pg_database WHERE datname = $1", &[&db_name])
178 .await
179 .is_ok();
180
181 if !exists {
182 info!("creating database '{}'...", db_name);
183 let create_query = format!("CREATE DATABASE \"{}\"", db_name);
184 client
185 .batch_execute(&create_query)
186 .await
187 .map_err(|e| DbkitError::DatabaseCreation {
188 name: db_name.to_string(),
189 reason: e.to_string(),
190 })?;
191 info!("database '{}' created", db_name);
192 }
193
194 Ok(())
195 }
196}
197
/// Point-in-time snapshot of connection pool usage, as returned by
/// [`ConnectionManager::pool_status`].
///
/// All fields are plain counters, so the snapshot is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Maximum number of connections the pool may hold.
    pub max_size: usize,
    /// Current number of connections in the pool.
    pub size: usize,
    /// Number of connections available to be checked out.
    pub available: usize,
    /// Number of callers waiting for a connection.
    pub waiting: usize,
}
210
211impl std::fmt::Display for PoolStatus {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 write!(
214 f,
215 "pool: {}/{} connections, {} available, {} waiting",
216 self.size, self.max_size, self.available, self.waiting
217 )
218 }
219}