1use std::fmt;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, OnceLock};
6use std::time::Duration;
7
8use cdk_common::database::Error;
9use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
10use cdk_sql_common::mint::SQLMintAuthDatabase;
11use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
12use cdk_sql_common::stmt::{Column, Statement};
13use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
14use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
15use native_tls::TlsConnector;
16use postgres_native_tls::MakeTlsConnector;
17use tokio::sync::{Mutex, Notify};
18use tokio::time::timeout;
19use tokio_postgres::{connect, Client, Error as PgError, NoTls};
20
21mod db;
22mod value;
23
/// Stateless marker type that implements [`DatabasePool`] for PostgreSQL.
///
/// It carries no data of its own; `new_resource` turns a [`PgConfig`] into a
/// [`PostgresConnection`].
#[derive(Debug)]
pub struct PgConnectionPool;
27
/// TLS strategy used when opening a PostgreSQL connection.
#[derive(Clone)]
pub enum SslMode {
    /// Connect without TLS, using `tokio_postgres::NoTls`.
    NoTls(NoTls),
    /// Connect through a `native-tls` connector wrapped for `tokio-postgres`.
    NativeTls(postgres_native_tls::MakeTlsConnector),
}
// `sslmode=` fragments recognised inside a connection string (see `ssl_mode_from_url`).

/// Verify the certificate chain and the server hostname.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
/// Verify the certificate chain only.
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
/// Use TLS without certificate or hostname verification.
const SSLMODE_PREFER: &str = "sslmode=prefer";
/// Use TLS without certificate or hostname verification.
const SSLMODE_ALLOW: &str = "sslmode=allow";
/// Use TLS without certificate or hostname verification.
const SSLMODE_REQUIRE: &str = "sslmode=require";
41
42impl Default for SslMode {
43 fn default() -> Self {
44 SslMode::NoTls(NoTls {})
45 }
46}
47
48impl fmt::Debug for SslMode {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 let debug_text = match self {
51 Self::NoTls(_) => "NoTls",
52 Self::NativeTls(_) => "NativeTls",
53 };
54
55 write!(f, "SslMode::{debug_text}")
56 }
57}
58
/// Connection settings used by [`PgConnectionPool`].
///
/// `Debug` is implemented by hand so that the connection string, which may
/// contain a password, is never printed.
#[derive(Clone)]
pub struct PgConfig {
    /// Connection string handed to `tokio_postgres::connect`, with any
    /// `schema=` token already removed.
    url: String,
    /// Schema created on connect and installed as the session `search_path`, if set.
    schema: Option<String>,
    /// TLS connector used for each new connection.
    tls: SslMode,
    /// Value reported by `DatabaseConfig::max_size`.
    max_connections: usize,
    /// Value reported by `DatabaseConfig::default_timeout`.
    connection_timeout: Duration,
}
68
69impl fmt::Debug for PgConfig {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("PgConfig")
72 .field("url", &"[redacted]")
73 .field("schema", &self.schema)
74 .field("tls", &self.tls)
75 .field("max_connections", &self.max_connections)
76 .field("connection_timeout", &self.connection_timeout)
77 .finish()
78 }
79}
80
81impl DatabaseConfig for PgConfig {
82 fn default_timeout(&self) -> Duration {
83 self.connection_timeout
84 }
85
86 fn max_size(&self) -> usize {
87 self.max_connections
88 }
89}
90
/// Pool size used when the caller does not specify `max_connections`.
const DEFAULT_MAX_CONNECTIONS: usize = 20;
/// Connection timeout, in seconds, used when the caller does not specify one.
const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 10;
96
97fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
99 let mut builder = TlsConnector::builder();
100 if accept_invalid_certs {
101 builder.danger_accept_invalid_certs(true);
102 }
103 if accept_invalid_hostnames {
104 builder.danger_accept_invalid_hostnames(true);
105 }
106
107 match builder.build() {
108 Ok(connector) => {
109 let make_tls_connector = MakeTlsConnector::new(connector);
110 SslMode::NativeTls(make_tls_connector)
111 }
112 Err(_) => SslMode::NoTls(NoTls {}),
113 }
114}
115
116fn ssl_mode_from_url(url: &str) -> SslMode {
118 if url.contains(SSLMODE_VERIFY_FULL) {
119 build_tls(false, false)
121 } else if url.contains(SSLMODE_VERIFY_CA) {
122 build_tls(false, true)
124 } else if url.contains(SSLMODE_PREFER)
125 || url.contains(SSLMODE_ALLOW)
126 || url.contains(SSLMODE_REQUIRE)
127 {
128 build_tls(true, true)
130 } else {
131 SslMode::NoTls(NoTls {})
132 }
133}
134
135fn ssl_mode_from_config(tls_mode: Option<&str>, url: &str) -> SslMode {
140 match tls_mode {
141 Some(mode) => match mode.to_lowercase().as_str() {
142 "verify-full" => build_tls(false, false),
143 "verify-ca" => build_tls(false, true),
144 "require" | "prefer" | "allow" => build_tls(true, true),
145 _ => SslMode::NoTls(NoTls {}),
147 },
148 None => ssl_mode_from_url(url),
150 }
151}
152
153impl PgConfig {
154 pub fn new(
161 conn_str: &str,
162 tls_mode: Option<&str>,
163 max_connections: Option<usize>,
164 connection_timeout_secs: Option<u64>,
165 ) -> Self {
166 let (schema, conn_str) = Self::strip_schema(conn_str);
167 let tls = ssl_mode_from_config(tls_mode, &conn_str);
168 PgConfig {
169 url: conn_str,
170 schema,
171 tls,
172 max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
173 connection_timeout: Duration::from_secs(
174 connection_timeout_secs.unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS),
175 ),
176 }
177 }
178
179 fn strip_schema(input: &str) -> (Option<String>, String) {
181 let mut schema: Option<String> = None;
182
183 let mut parts = Vec::new();
185 for token in input.split_whitespace() {
186 if let Some(rest) = token.strip_prefix("schema=") {
187 schema = Some(rest.to_string());
188 } else {
189 parts.push(token);
190 }
191 }
192
193 let cleaned = parts.join(" ");
194 (schema, cleaned)
195 }
196}
197
198impl From<&str> for PgConfig {
199 fn from(conn_str: &str) -> Self {
200 let (schema, conn_str) = Self::strip_schema(conn_str);
201 let tls = ssl_mode_from_url(&conn_str);
202
203 PgConfig {
204 url: conn_str,
205 schema,
206 tls,
207 max_connections: DEFAULT_MAX_CONNECTIONS,
208 connection_timeout: Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS),
209 }
210 }
211}
212
213impl DatabasePool for PgConnectionPool {
214 type Config = PgConfig;
215
216 type Connection = PostgresConnection;
217
218 type Error = PgError;
219
220 fn new_resource(
221 config: &Self::Config,
222 stale: Arc<AtomicBool>,
223 timeout: Duration,
224 ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
225 Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
226 }
227}
228
/// A pooled PostgreSQL connection whose client is established by a background task.
///
/// The task spawned in `PostgresConnection::new` stores either the ready client
/// in `result` or the failure in `error`, then signals `notify`.
#[derive(Debug)]
pub struct PostgresConnection {
    /// How long `inner()` waits for the background connect task.
    timeout: Duration,
    /// Error recorded by the background connect task, if it failed.
    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
    /// The connected client, set once the connect (and optional schema setup) succeeded.
    result: Arc<OnceLock<Client>>,
    /// Signalled by the background task when it finishes, successfully or not.
    notify: Arc<Notify>,
}
237
238impl PostgresConnection {
239 pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
241 let failed = Arc::new(Mutex::new(None));
242 let result = Arc::new(OnceLock::new());
243 let notify = Arc::new(Notify::new());
244 let error_clone = failed.clone();
245 let result_clone = result.clone();
246 let notify_clone = notify.clone();
247
248 async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
249 conn.batch_execute(&format!(
250 r#"
251 CREATE SCHEMA IF NOT EXISTS "{schema}";
252 SET search_path TO "{schema}"
253 "#
254 ))
255 .await
256 .map_err(|e| Error::Database(Box::new(e)))
257 }
258
259 tokio::spawn(async move {
260 match config.tls {
261 SslMode::NoTls(tls) => {
262 let (client, connection) = match connect(&config.url, tls).await {
263 Ok((client, connection)) => (client, connection),
264 Err(err) => {
265 *error_clone.lock().await =
266 Some(cdk_common::database::Error::Database(Box::new(err)));
267 stale.store(true, std::sync::atomic::Ordering::Release);
268 notify_clone.notify_waiters();
269 return;
270 }
271 };
272
273 let stale_for_spawn = stale.clone();
274 tokio::spawn(async move {
275 let _ = connection.await;
276 stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
277 });
278
279 if let Some(schema) = config.schema.as_ref() {
280 if let Err(err) = select_schema(&client, schema).await {
281 *error_clone.lock().await = Some(err);
282 stale.store(true, std::sync::atomic::Ordering::Release);
283 notify_clone.notify_waiters();
284 return;
285 }
286 }
287
288 let _ = result_clone.set(client);
289 notify_clone.notify_waiters();
290 }
291 SslMode::NativeTls(tls) => {
292 let (client, connection) = match connect(&config.url, tls).await {
293 Ok((client, connection)) => (client, connection),
294 Err(err) => {
295 *error_clone.lock().await =
296 Some(cdk_common::database::Error::Database(Box::new(err)));
297 stale.store(true, std::sync::atomic::Ordering::Release);
298 notify_clone.notify_waiters();
299 return;
300 }
301 };
302
303 let stale_for_spawn = stale.clone();
304 tokio::spawn(async move {
305 let _ = connection.await;
306 stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
307 });
308
309 if let Some(schema) = config.schema.as_ref() {
310 if let Err(err) = select_schema(&client, schema).await {
311 *error_clone.lock().await = Some(err);
312 stale.store(true, std::sync::atomic::Ordering::Release);
313 notify_clone.notify_waiters();
314 return;
315 }
316 }
317
318 let _ = result_clone.set(client);
319 notify_clone.notify_waiters();
320 }
321 }
322 });
323
324 Self {
325 error: failed,
326 timeout,
327 result,
328 notify,
329 }
330 }
331
332 async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
335 if let Some(client) = self.result.get() {
336 return Ok(client);
337 }
338
339 if let Some(error) = self.error.lock().await.take() {
340 return Err(error);
341 }
342
343 if timeout(self.timeout, self.notify.notified()).await.is_err() {
344 return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
345 }
346
347 if let Some(client) = self.result.get() {
349 Ok(client)
350 } else if let Some(error) = self.error.lock().await.take() {
351 Err(error)
352 } else {
353 Err(cdk_common::database::Error::Internal(
354 "Failed connection".to_owned(),
355 ))
356 }
357 }
358}
359
360#[async_trait::async_trait]
361impl DatabaseConnector for PostgresConnection {
362 type Transaction = GenericTransactionHandler<Self>;
363}
364
365#[async_trait::async_trait]
366impl DatabaseExecutor for PostgresConnection {
367 fn name() -> &'static str {
368 "postgres"
369 }
370
371 async fn execute(&self, statement: Statement) -> Result<usize, Error> {
372 pg_execute(self.inner().await?, statement).await
373 }
374
375 async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
376 pg_fetch_one(self.inner().await?, statement).await
377 }
378
379 async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
380 pg_fetch_all(self.inner().await?, statement).await
381 }
382
383 async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
384 pg_pluck(self.inner().await?, statement).await
385 }
386
387 async fn batch(&self, statement: Statement) -> Result<(), Error> {
388 pg_batch(self.inner().await?, statement).await
389 }
390}
391
/// `SQLMintDatabase` specialised to [`PgConnectionPool`].
pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;

/// `SQLMintAuthDatabase` specialised to [`PgConnectionPool`].
pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;

/// `SQLWalletDatabase` specialised to [`PgConnectionPool`].
pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
400
401pub async fn new_wallet_pg_database(conn_str: &str) -> Result<WalletPgDatabase, Error> {
404 <SQLWalletDatabase<PgConnectionPool>>::new(conn_str).await
405}
406
#[cfg(test)]
mod test {
    use cdk_common::{mint_db_test, wallet_db_test};

    use super::*;

    /// Connection string used when neither `CDK_MINTD_DATABASE_URL` nor
    /// `PG_DB_URL` is set.
    const FALLBACK_PG_TEST_URL: &str =
        "host=localhost user=cdk_user password=cdk_password dbname=cdk_mint port=5432";

    /// Builds the connection string for `test_id`, isolating each test in its own schema.
    fn pg_test_url(test_id: &str) -> String {
        let base = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL"))
            .unwrap_or_else(|_| FALLBACK_PG_TEST_URL.to_owned());
        format!("{base} schema={test_id}")
    }

    async fn provide_mint_db(test_id: String) -> MintPgDatabase {
        let db_url = pg_test_url(&test_id);
        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_mint_db);

    async fn provide_wallet_db(test_id: String) -> WalletPgDatabase {
        let db_url = pg_test_url(&test_id);
        WalletPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    wallet_db_test!(provide_wallet_db);

    #[tokio::test]
    async fn failed_initial_connect_marks_connection_stale() {
        let stale = Arc::new(AtomicBool::new(false));
        let config = PgConfig::from("host=127.0.0.1 port=1 user=cdk dbname=cdk connect_timeout=1");
        let conn = PostgresConnection::new(config, Duration::from_secs(5), stale.clone());

        assert!(
            conn.inner().await.is_err(),
            "connect to refused port should fail"
        );
        tokio::task::yield_now().await;

        assert!(
            stale.load(std::sync::atomic::Ordering::SeqCst),
            "failed initial connect should mark the pooled connection stale"
        );
    }

    #[test]
    fn pgconfig_debug_does_not_leak_password() {
        let config = PgConfig::from("host=localhost user=u password=hunter2secret dbname=d");
        let rendered = format!("{config:?}");

        assert!(
            !rendered.contains("hunter2secret"),
            "PgConfig Debug leaked the DB password: {rendered}"
        );
    }
}