// cdk_postgres/lib.rs

1//! CDK Postgres
2
3use std::fmt::Debug;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, OnceLock};
6use std::time::Duration;
7
8use cdk_common::database::Error;
9use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
10use cdk_sql_common::mint::SQLMintAuthDatabase;
11use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
12use cdk_sql_common::stmt::{Column, Statement};
13use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
14use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
15use native_tls::TlsConnector;
16use postgres_native_tls::MakeTlsConnector;
17use tokio::sync::{Mutex, Notify};
18use tokio::time::timeout;
19use tokio_postgres::{connect, Client, Error as PgError, NoTls};
20
21mod db;
22mod value;
23
/// Postgres connection pool.
///
/// Stateless marker type; the pooling behaviour comes from its
/// [`DatabasePool`] implementation.
#[derive(Debug)]
pub struct PgConnectionPool;
27
28#[derive(Clone)]
29/// SSL Mode
30pub enum SslMode {
31    /// No TLS
32    NoTls(NoTls),
33    /// Native TLS
34    NativeTls(postgres_native_tls::MakeTlsConnector),
35}
/// URL fragment selecting strict TLS (certificate and hostname checked).
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
/// URL fragment selecting TLS with certificate check but no hostname check.
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
/// URL fragment selecting lenient TLS (`prefer`).
const SSLMODE_PREFER: &str = "sslmode=prefer";
/// URL fragment selecting lenient TLS (`allow`).
const SSLMODE_ALLOW: &str = "sslmode=allow";
/// URL fragment selecting lenient TLS (`require`).
const SSLMODE_REQUIRE: &str = "sslmode=require";
41
42impl Default for SslMode {
43    fn default() -> Self {
44        SslMode::NoTls(NoTls {})
45    }
46}
47
48impl Debug for SslMode {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        let debug_text = match self {
51            Self::NoTls(_) => "NoTls",
52            Self::NativeTls(_) => "NativeTls",
53        };
54
55        write!(f, "SslMode::{debug_text}")
56    }
57}
58
59/// Postgres configuration
60#[derive(Clone, Debug)]
61pub struct PgConfig {
62    url: String,
63    schema: Option<String>,
64    tls: SslMode,
65    max_connections: usize,
66    connection_timeout: Duration,
67}
68
69impl DatabaseConfig for PgConfig {
70    fn default_timeout(&self) -> Duration {
71        self.connection_timeout
72    }
73
74    fn max_size(&self) -> usize {
75        self.max_connections
76    }
77}
78
/// Pool size used when the caller does not provide `max_connections`
const DEFAULT_MAX_CONNECTIONS: usize = 20;

/// Connection timeout, in seconds, used when the caller does not provide one
const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 10;
84
85/// Build a TLS connector with the given certificate/hostname validation settings.
86fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
87    let mut builder = TlsConnector::builder();
88    if accept_invalid_certs {
89        builder.danger_accept_invalid_certs(true);
90    }
91    if accept_invalid_hostnames {
92        builder.danger_accept_invalid_hostnames(true);
93    }
94
95    match builder.build() {
96        Ok(connector) => {
97            let make_tls_connector = MakeTlsConnector::new(connector);
98            SslMode::NativeTls(make_tls_connector)
99        }
100        Err(_) => SslMode::NoTls(NoTls {}),
101    }
102}
103
104/// Determine TLS mode from the `sslmode=` parameter in a connection URL.
105fn ssl_mode_from_url(url: &str) -> SslMode {
106    if url.contains(SSLMODE_VERIFY_FULL) {
107        // Strict TLS: valid certs and hostnames required
108        build_tls(false, false)
109    } else if url.contains(SSLMODE_VERIFY_CA) {
110        // Verify CA, but allow invalid hostnames
111        build_tls(false, true)
112    } else if url.contains(SSLMODE_PREFER)
113        || url.contains(SSLMODE_ALLOW)
114        || url.contains(SSLMODE_REQUIRE)
115    {
116        // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
117        build_tls(true, true)
118    } else {
119        SslMode::NoTls(NoTls {})
120    }
121}
122
123/// Resolve TLS mode from an explicit `tls_mode` string (from config/env), such
124/// as `"disable"`, `"prefer"`, `"require"`, `"verify-ca"`, or `"verify-full"`.
125///
126/// If the value is `None`, falls back to parsing `sslmode=` from the URL.
127fn ssl_mode_from_config(tls_mode: Option<&str>, url: &str) -> SslMode {
128    match tls_mode {
129        Some(mode) => match mode.to_lowercase().as_str() {
130            "verify-full" => build_tls(false, false),
131            "verify-ca" => build_tls(false, true),
132            "require" | "prefer" | "allow" => build_tls(true, true),
133            // "disable" or any unrecognised value → no TLS
134            _ => SslMode::NoTls(NoTls {}),
135        },
136        // No explicit tls_mode: fall back to URL-based detection
137        None => ssl_mode_from_url(url),
138    }
139}
140
141impl PgConfig {
142    /// Create a new `PgConfig` with explicit TLS mode, pool size, and timeout.
143    ///
144    /// `tls_mode` accepts the same strings as the configuration file:
145    /// `"disable"`, `"prefer"`, `"allow"`, `"require"`, `"verify-ca"`,
146    /// `"verify-full"`.  When `None`, the TLS mode is inferred from
147    /// `sslmode=` in the connection URL (matching the old behaviour).
148    pub fn new(
149        conn_str: &str,
150        tls_mode: Option<&str>,
151        max_connections: Option<usize>,
152        connection_timeout_secs: Option<u64>,
153    ) -> Self {
154        let (schema, conn_str) = Self::strip_schema(conn_str);
155        let tls = ssl_mode_from_config(tls_mode, &conn_str);
156        PgConfig {
157            url: conn_str,
158            schema,
159            tls,
160            max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
161            connection_timeout: Duration::from_secs(
162                connection_timeout_secs.unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS),
163            ),
164        }
165    }
166
167    /// strip schema from the connection string
168    fn strip_schema(input: &str) -> (Option<String>, String) {
169        let mut schema: Option<String> = None;
170
171        // Split by whitespace
172        let mut parts = Vec::new();
173        for token in input.split_whitespace() {
174            if let Some(rest) = token.strip_prefix("schema=") {
175                schema = Some(rest.to_string());
176            } else {
177                parts.push(token);
178            }
179        }
180
181        let cleaned = parts.join(" ");
182        (schema, cleaned)
183    }
184}
185
186impl From<&str> for PgConfig {
187    fn from(conn_str: &str) -> Self {
188        let (schema, conn_str) = Self::strip_schema(conn_str);
189        let tls = ssl_mode_from_url(&conn_str);
190
191        PgConfig {
192            url: conn_str,
193            schema,
194            tls,
195            max_connections: DEFAULT_MAX_CONNECTIONS,
196            connection_timeout: Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS),
197        }
198    }
199}
200
201impl DatabasePool for PgConnectionPool {
202    type Config = PgConfig;
203
204    type Connection = PostgresConnection;
205
206    type Error = PgError;
207
208    fn new_resource(
209        config: &Self::Config,
210        stale: Arc<AtomicBool>,
211        timeout: Duration,
212    ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
213        Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
214    }
215}
216
217/// A postgres connection
218#[derive(Debug)]
219pub struct PostgresConnection {
220    timeout: Duration,
221    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
222    result: Arc<OnceLock<Client>>,
223    notify: Arc<Notify>,
224}
225
226impl PostgresConnection {
227    /// Creates a new instance
228    pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
229        let failed = Arc::new(Mutex::new(None));
230        let result = Arc::new(OnceLock::new());
231        let notify = Arc::new(Notify::new());
232        let error_clone = failed.clone();
233        let result_clone = result.clone();
234        let notify_clone = notify.clone();
235
236        async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
237            conn.batch_execute(&format!(
238                r#"
239                    CREATE SCHEMA IF NOT EXISTS "{schema}";
240                    SET search_path TO "{schema}"
241                    "#
242            ))
243            .await
244            .map_err(|e| Error::Database(Box::new(e)))
245        }
246
247        tokio::spawn(async move {
248            match config.tls {
249                SslMode::NoTls(tls) => {
250                    let (client, connection) = match connect(&config.url, tls).await {
251                        Ok((client, connection)) => (client, connection),
252                        Err(err) => {
253                            *error_clone.lock().await =
254                                Some(cdk_common::database::Error::Database(Box::new(err)));
255                            stale.store(false, std::sync::atomic::Ordering::Release);
256                            notify_clone.notify_waiters();
257                            return;
258                        }
259                    };
260
261                    let stale_for_spawn = stale.clone();
262                    tokio::spawn(async move {
263                        let _ = connection.await;
264                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
265                    });
266
267                    if let Some(schema) = config.schema.as_ref() {
268                        if let Err(err) = select_schema(&client, schema).await {
269                            *error_clone.lock().await = Some(err);
270                            stale.store(false, std::sync::atomic::Ordering::Release);
271                            notify_clone.notify_waiters();
272                            return;
273                        }
274                    }
275
276                    let _ = result_clone.set(client);
277                    notify_clone.notify_waiters();
278                }
279                SslMode::NativeTls(tls) => {
280                    let (client, connection) = match connect(&config.url, tls).await {
281                        Ok((client, connection)) => (client, connection),
282                        Err(err) => {
283                            *error_clone.lock().await =
284                                Some(cdk_common::database::Error::Database(Box::new(err)));
285                            stale.store(false, std::sync::atomic::Ordering::Release);
286                            notify_clone.notify_waiters();
287                            return;
288                        }
289                    };
290
291                    let stale_for_spawn = stale.clone();
292                    tokio::spawn(async move {
293                        let _ = connection.await;
294                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
295                    });
296
297                    if let Some(schema) = config.schema.as_ref() {
298                        if let Err(err) = select_schema(&client, schema).await {
299                            *error_clone.lock().await = Some(err);
300                            stale.store(true, std::sync::atomic::Ordering::Release);
301                            notify_clone.notify_waiters();
302                            return;
303                        }
304                    }
305
306                    let _ = result_clone.set(client);
307                    notify_clone.notify_waiters();
308                }
309            }
310        });
311
312        Self {
313            error: failed,
314            timeout,
315            result,
316            notify,
317        }
318    }
319
320    /// Gets the wrapped instance or the connection error. The connection is returned as reference,
321    /// and the actual error is returned once, next times a generic error would be returned
322    async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
323        if let Some(client) = self.result.get() {
324            return Ok(client);
325        }
326
327        if let Some(error) = self.error.lock().await.take() {
328            return Err(error);
329        }
330
331        if timeout(self.timeout, self.notify.notified()).await.is_err() {
332            return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
333        }
334
335        // Check result again
336        if let Some(client) = self.result.get() {
337            Ok(client)
338        } else if let Some(error) = self.error.lock().await.take() {
339            Err(error)
340        } else {
341            Err(cdk_common::database::Error::Internal(
342                "Failed connection".to_owned(),
343            ))
344        }
345    }
346}
347
348#[async_trait::async_trait]
349impl DatabaseConnector for PostgresConnection {
350    type Transaction = GenericTransactionHandler<Self>;
351}
352
353#[async_trait::async_trait]
354impl DatabaseExecutor for PostgresConnection {
355    fn name() -> &'static str {
356        "postgres"
357    }
358
359    async fn execute(&self, statement: Statement) -> Result<usize, Error> {
360        pg_execute(self.inner().await?, statement).await
361    }
362
363    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
364        pg_fetch_one(self.inner().await?, statement).await
365    }
366
367    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
368        pg_fetch_all(self.inner().await?, statement).await
369    }
370
371    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
372        pg_pluck(self.inner().await?, statement).await
373    }
374
375    async fn batch(&self, statement: Statement) -> Result<(), Error> {
376        pg_batch(self.inner().await?, statement).await
377    }
378}
379
380/// Mint DB implementation with PostgreSQL
381pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
382
383/// Mint Auth database with Postgres
384pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
385
386/// Wallet DB implementation with PostgreSQL
387pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
388
389/// Convenience free functions (cannot add inherent impls for a foreign type).
390/// These mirror the Mint patterns and call through to the generic constructors.
391pub async fn new_wallet_pg_database(conn_str: &str) -> Result<WalletPgDatabase, Error> {
392    <SQLWalletDatabase<PgConnectionPool>>::new(conn_str).await
393}
394
#[cfg(test)]
mod test {
    use cdk_common::{mint_db_test, wallet_db_test};

    use super::*;

    /// Connection string used when no environment variable is set.
    const FALLBACK_DB_URL: &str =
        "host=localhost user=cdk_user password=cdk_password dbname=cdk_mint port=5432";

    /// Builds the connection string for one test: the base URL comes from
    /// `CDK_MINTD_DATABASE_URL`, then `PG_DB_URL` (fallback for compatibility),
    /// then the built-in default; `test_id` is appended as the schema.
    fn connection_string_for(test_id: &str) -> String {
        let base = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL"))
            .unwrap_or_else(|_| FALLBACK_DB_URL.to_owned());

        format!("{base} schema={test_id}")
    }

    /// Opens a mint database in a schema named after the test.
    async fn provide_mint_db(test_id: String) -> MintPgDatabase {
        let db_url = connection_string_for(&test_id);

        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_mint_db);

    /// Opens a wallet database in a schema named after the test.
    async fn provide_wallet_db(test_id: String) -> WalletPgDatabase {
        let db_url = connection_string_for(&test_id);

        WalletPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    wallet_db_test!(provide_wallet_db);
}