Skip to main content

cdk_postgres/
lib.rs

1//! CDK Postgres
2
3use std::fmt;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, OnceLock};
6use std::time::Duration;
7
8use cdk_common::database::Error;
9use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
10use cdk_sql_common::mint::SQLMintAuthDatabase;
11use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
12use cdk_sql_common::stmt::{Column, Statement};
13use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
14use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
15use native_tls::TlsConnector;
16use postgres_native_tls::MakeTlsConnector;
17use tokio::sync::{Mutex, Notify};
18use tokio::time::timeout;
19use tokio_postgres::{connect, Client, Error as PgError, NoTls};
20
21mod db;
22mod value;
23
/// Postgres connection pool.
///
/// A field-less marker type: all pool behaviour comes from its `DatabasePool` implementation,
/// which hands out [`PostgresConnection`]s built from a [`PgConfig`].
#[derive(Debug)]
pub struct PgConnectionPool;
27
28#[derive(Clone)]
29/// SSL Mode
30pub enum SslMode {
31    /// No TLS
32    NoTls(NoTls),
33    /// Native TLS
34    NativeTls(postgres_native_tls::MakeTlsConnector),
35}
// Connection-string fragments that `ssl_mode_from_url` searches for.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
const SSLMODE_PREFER: &str = "sslmode=prefer";
const SSLMODE_ALLOW: &str = "sslmode=allow";
const SSLMODE_REQUIRE: &str = "sslmode=require";
41
42impl Default for SslMode {
43    fn default() -> Self {
44        SslMode::NoTls(NoTls {})
45    }
46}
47
48impl fmt::Debug for SslMode {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        let debug_text = match self {
51            Self::NoTls(_) => "NoTls",
52            Self::NativeTls(_) => "NativeTls",
53        };
54
55        write!(f, "SslMode::{debug_text}")
56    }
57}
58
59/// Postgres configuration
60#[derive(Clone)]
61pub struct PgConfig {
62    url: String,
63    schema: Option<String>,
64    tls: SslMode,
65    max_connections: usize,
66    connection_timeout: Duration,
67}
68
69impl fmt::Debug for PgConfig {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("PgConfig")
72            .field("url", &"[redacted]")
73            .field("schema", &self.schema)
74            .field("tls", &self.tls)
75            .field("max_connections", &self.max_connections)
76            .field("connection_timeout", &self.connection_timeout)
77            .finish()
78    }
79}
80
81impl DatabaseConfig for PgConfig {
82    fn default_timeout(&self) -> Duration {
83        self.connection_timeout
84    }
85
86    fn max_size(&self) -> usize {
87        self.max_connections
88    }
89}
90
/// Pool size used when `PgConfig` is built without an explicit `max_connections`.
const DEFAULT_MAX_CONNECTIONS: usize = 20;

/// Connection timeout, in whole seconds, used when none is configured.
const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 10;
96
97/// Build a TLS connector with the given certificate/hostname validation settings.
98fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
99    let mut builder = TlsConnector::builder();
100    if accept_invalid_certs {
101        builder.danger_accept_invalid_certs(true);
102    }
103    if accept_invalid_hostnames {
104        builder.danger_accept_invalid_hostnames(true);
105    }
106
107    match builder.build() {
108        Ok(connector) => {
109            let make_tls_connector = MakeTlsConnector::new(connector);
110            SslMode::NativeTls(make_tls_connector)
111        }
112        Err(_) => SslMode::NoTls(NoTls {}),
113    }
114}
115
116/// Determine TLS mode from the `sslmode=` parameter in a connection URL.
117fn ssl_mode_from_url(url: &str) -> SslMode {
118    if url.contains(SSLMODE_VERIFY_FULL) {
119        // Strict TLS: valid certs and hostnames required
120        build_tls(false, false)
121    } else if url.contains(SSLMODE_VERIFY_CA) {
122        // Verify CA, but allow invalid hostnames
123        build_tls(false, true)
124    } else if url.contains(SSLMODE_PREFER)
125        || url.contains(SSLMODE_ALLOW)
126        || url.contains(SSLMODE_REQUIRE)
127    {
128        // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
129        build_tls(true, true)
130    } else {
131        SslMode::NoTls(NoTls {})
132    }
133}
134
135/// Resolve TLS mode from an explicit `tls_mode` string (from config/env), such
136/// as `"disable"`, `"prefer"`, `"require"`, `"verify-ca"`, or `"verify-full"`.
137///
138/// If the value is `None`, falls back to parsing `sslmode=` from the URL.
139fn ssl_mode_from_config(tls_mode: Option<&str>, url: &str) -> SslMode {
140    match tls_mode {
141        Some(mode) => match mode.to_lowercase().as_str() {
142            "verify-full" => build_tls(false, false),
143            "verify-ca" => build_tls(false, true),
144            "require" | "prefer" | "allow" => build_tls(true, true),
145            // "disable" or any unrecognised value → no TLS
146            _ => SslMode::NoTls(NoTls {}),
147        },
148        // No explicit tls_mode: fall back to URL-based detection
149        None => ssl_mode_from_url(url),
150    }
151}
152
153impl PgConfig {
154    /// Create a new `PgConfig` with explicit TLS mode, pool size, and timeout.
155    ///
156    /// `tls_mode` accepts the same strings as the configuration file:
157    /// `"disable"`, `"prefer"`, `"allow"`, `"require"`, `"verify-ca"`,
158    /// `"verify-full"`.  When `None`, the TLS mode is inferred from
159    /// `sslmode=` in the connection URL (matching the old behaviour).
160    pub fn new(
161        conn_str: &str,
162        tls_mode: Option<&str>,
163        max_connections: Option<usize>,
164        connection_timeout_secs: Option<u64>,
165    ) -> Self {
166        let (schema, conn_str) = Self::strip_schema(conn_str);
167        let tls = ssl_mode_from_config(tls_mode, &conn_str);
168        PgConfig {
169            url: conn_str,
170            schema,
171            tls,
172            max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
173            connection_timeout: Duration::from_secs(
174                connection_timeout_secs.unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS),
175            ),
176        }
177    }
178
179    /// strip schema from the connection string
180    fn strip_schema(input: &str) -> (Option<String>, String) {
181        let mut schema: Option<String> = None;
182
183        // Split by whitespace
184        let mut parts = Vec::new();
185        for token in input.split_whitespace() {
186            if let Some(rest) = token.strip_prefix("schema=") {
187                schema = Some(rest.to_string());
188            } else {
189                parts.push(token);
190            }
191        }
192
193        let cleaned = parts.join(" ");
194        (schema, cleaned)
195    }
196}
197
198impl From<&str> for PgConfig {
199    fn from(conn_str: &str) -> Self {
200        let (schema, conn_str) = Self::strip_schema(conn_str);
201        let tls = ssl_mode_from_url(&conn_str);
202
203        PgConfig {
204            url: conn_str,
205            schema,
206            tls,
207            max_connections: DEFAULT_MAX_CONNECTIONS,
208            connection_timeout: Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS),
209        }
210    }
211}
212
213impl DatabasePool for PgConnectionPool {
214    type Config = PgConfig;
215
216    type Connection = PostgresConnection;
217
218    type Error = PgError;
219
220    fn new_resource(
221        config: &Self::Config,
222        stale: Arc<AtomicBool>,
223        timeout: Duration,
224    ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
225        Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
226    }
227}
228
229/// A postgres connection
230#[derive(Debug)]
231pub struct PostgresConnection {
232    timeout: Duration,
233    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
234    result: Arc<OnceLock<Client>>,
235    notify: Arc<Notify>,
236}
237
238impl PostgresConnection {
239    /// Creates a new instance
240    pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
241        let failed = Arc::new(Mutex::new(None));
242        let result = Arc::new(OnceLock::new());
243        let notify = Arc::new(Notify::new());
244        let error_clone = failed.clone();
245        let result_clone = result.clone();
246        let notify_clone = notify.clone();
247
248        async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
249            conn.batch_execute(&format!(
250                r#"
251                    CREATE SCHEMA IF NOT EXISTS "{schema}";
252                    SET search_path TO "{schema}"
253                    "#
254            ))
255            .await
256            .map_err(|e| Error::Database(Box::new(e)))
257        }
258
259        tokio::spawn(async move {
260            match config.tls {
261                SslMode::NoTls(tls) => {
262                    let (client, connection) = match connect(&config.url, tls).await {
263                        Ok((client, connection)) => (client, connection),
264                        Err(err) => {
265                            *error_clone.lock().await =
266                                Some(cdk_common::database::Error::Database(Box::new(err)));
267                            stale.store(true, std::sync::atomic::Ordering::Release);
268                            notify_clone.notify_waiters();
269                            return;
270                        }
271                    };
272
273                    let stale_for_spawn = stale.clone();
274                    tokio::spawn(async move {
275                        let _ = connection.await;
276                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
277                    });
278
279                    if let Some(schema) = config.schema.as_ref() {
280                        if let Err(err) = select_schema(&client, schema).await {
281                            *error_clone.lock().await = Some(err);
282                            stale.store(true, std::sync::atomic::Ordering::Release);
283                            notify_clone.notify_waiters();
284                            return;
285                        }
286                    }
287
288                    let _ = result_clone.set(client);
289                    notify_clone.notify_waiters();
290                }
291                SslMode::NativeTls(tls) => {
292                    let (client, connection) = match connect(&config.url, tls).await {
293                        Ok((client, connection)) => (client, connection),
294                        Err(err) => {
295                            *error_clone.lock().await =
296                                Some(cdk_common::database::Error::Database(Box::new(err)));
297                            stale.store(true, std::sync::atomic::Ordering::Release);
298                            notify_clone.notify_waiters();
299                            return;
300                        }
301                    };
302
303                    let stale_for_spawn = stale.clone();
304                    tokio::spawn(async move {
305                        let _ = connection.await;
306                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
307                    });
308
309                    if let Some(schema) = config.schema.as_ref() {
310                        if let Err(err) = select_schema(&client, schema).await {
311                            *error_clone.lock().await = Some(err);
312                            stale.store(true, std::sync::atomic::Ordering::Release);
313                            notify_clone.notify_waiters();
314                            return;
315                        }
316                    }
317
318                    let _ = result_clone.set(client);
319                    notify_clone.notify_waiters();
320                }
321            }
322        });
323
324        Self {
325            error: failed,
326            timeout,
327            result,
328            notify,
329        }
330    }
331
332    /// Gets the wrapped instance or the connection error. The connection is returned as reference,
333    /// and the actual error is returned once, next times a generic error would be returned
334    async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
335        if let Some(client) = self.result.get() {
336            return Ok(client);
337        }
338
339        if let Some(error) = self.error.lock().await.take() {
340            return Err(error);
341        }
342
343        if timeout(self.timeout, self.notify.notified()).await.is_err() {
344            return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
345        }
346
347        // Check result again
348        if let Some(client) = self.result.get() {
349            Ok(client)
350        } else if let Some(error) = self.error.lock().await.take() {
351            Err(error)
352        } else {
353            Err(cdk_common::database::Error::Internal(
354                "Failed connection".to_owned(),
355            ))
356        }
357    }
358}
359
360#[async_trait::async_trait]
361impl DatabaseConnector for PostgresConnection {
362    type Transaction = GenericTransactionHandler<Self>;
363}
364
365#[async_trait::async_trait]
366impl DatabaseExecutor for PostgresConnection {
367    fn name() -> &'static str {
368        "postgres"
369    }
370
371    async fn execute(&self, statement: Statement) -> Result<usize, Error> {
372        pg_execute(self.inner().await?, statement).await
373    }
374
375    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
376        pg_fetch_one(self.inner().await?, statement).await
377    }
378
379    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
380        pg_fetch_all(self.inner().await?, statement).await
381    }
382
383    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
384        pg_pluck(self.inner().await?, statement).await
385    }
386
387    async fn batch(&self, statement: Statement) -> Result<(), Error> {
388        pg_batch(self.inner().await?, statement).await
389    }
390}
391
392/// Mint DB implementation with PostgreSQL
393pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
394
395/// Mint Auth database with Postgres
396pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
397
398/// Wallet DB implementation with PostgreSQL
399pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
400
401/// Convenience free functions (cannot add inherent impls for a foreign type).
402/// These mirror the Mint patterns and call through to the generic constructors.
403pub async fn new_wallet_pg_database(conn_str: &str) -> Result<WalletPgDatabase, Error> {
404    <SQLWalletDatabase<PgConnectionPool>>::new(conn_str).await
405}
406
#[cfg(test)]
mod test {
    use cdk_common::{mint_db_test, wallet_db_test};

    use super::*;

    /// Connection string for a test database isolated in schema `test_id`.
    ///
    /// Reads `CDK_MINTD_DATABASE_URL`, then `PG_DB_URL`, then falls back to a local default.
    fn pg_test_connection_string(test_id: &str) -> String {
        let base = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
            .unwrap_or_else(|_| {
                String::from(
                    "host=localhost user=cdk_user password=cdk_password dbname=cdk_mint port=5432",
                )
            });

        format!("{base} schema={test_id}")
    }

    async fn provide_mint_db(test_id: String) -> MintPgDatabase {
        let db_url = pg_test_connection_string(&test_id);
        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_mint_db);

    async fn provide_wallet_db(test_id: String) -> WalletPgDatabase {
        let db_url = pg_test_connection_string(&test_id);
        WalletPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    wallet_db_test!(provide_wallet_db);

    #[tokio::test]
    async fn failed_initial_connect_marks_connection_stale() {
        // Port 1 is expected to refuse the connection, so the background connect fails.
        let stale = Arc::new(AtomicBool::new(false));
        let config = PgConfig::from("host=127.0.0.1 port=1 user=cdk dbname=cdk connect_timeout=1");
        let conn = PostgresConnection::new(config, Duration::from_secs(5), Arc::clone(&stale));

        let connect_failed = conn.inner().await.is_err();
        assert!(connect_failed, "connect to refused port should fail");
        tokio::task::yield_now().await;

        let is_stale = stale.load(std::sync::atomic::Ordering::SeqCst);
        assert!(
            is_stale,
            "failed initial connect should mark the pooled connection stale"
        );
    }

    #[test]
    fn pgconfig_debug_does_not_leak_password() {
        let rendered = format!(
            "{:?}",
            PgConfig::from("host=localhost user=u password=hunter2secret dbname=d")
        );

        assert!(
            !rendered.contains("hunter2secret"),
            "PgConfig Debug leaked the DB password: {rendered}"
        );
    }
}