Skip to main content

cdk_postgres/
lib.rs

1//! CDK Postgres
2
3use std::fmt::Debug;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, OnceLock};
6use std::time::Duration;
7
8use cdk_common::database::Error;
9use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
10use cdk_sql_common::mint::SQLMintAuthDatabase;
11use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
12use cdk_sql_common::stmt::{Column, Statement};
13use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
14use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
15use native_tls::TlsConnector;
16use postgres_native_tls::MakeTlsConnector;
17use tokio::sync::{Mutex, Notify};
18use tokio::time::timeout;
19use tokio_postgres::{connect, Client, Error as PgError, NoTls};
20
21mod db;
22mod value;
23
/// Connection pool marker for PostgreSQL.
///
/// The type holds no data of its own; its `DatabasePool` implementation is what pairs
/// `PgConfig` with `PostgresConnection`.
#[derive(Debug)]
pub struct PgConnectionPool;
27
28#[derive(Clone)]
29/// SSL Mode
30pub enum SslMode {
31    /// No TLS
32    NoTls(NoTls),
33    /// Native TLS
34    NativeTls(postgres_native_tls::MakeTlsConnector),
35}
// Connection-string fragments that select the TLS behaviour. `PgConfig::from` looks for
// them as plain substrings of the connection string.

// Modes that keep certificate verification enabled.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";

// Modes that use TLS but accept invalid certificates and hostnames.
const SSLMODE_REQUIRE: &str = "sslmode=require";
const SSLMODE_PREFER: &str = "sslmode=prefer";
const SSLMODE_ALLOW: &str = "sslmode=allow";
41
42impl Default for SslMode {
43    fn default() -> Self {
44        SslMode::NoTls(NoTls {})
45    }
46}
47
48impl Debug for SslMode {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        let debug_text = match self {
51            Self::NoTls(_) => "NoTls",
52            Self::NativeTls(_) => "NativeTls",
53        };
54
55        write!(f, "SslMode::{debug_text}")
56    }
57}
58
59/// Postgres configuration
60#[derive(Clone, Debug)]
61pub struct PgConfig {
62    url: String,
63    schema: Option<String>,
64    tls: SslMode,
65}
66
67impl DatabaseConfig for PgConfig {
68    fn default_timeout(&self) -> Duration {
69        Duration::from_secs(10)
70    }
71
72    fn max_size(&self) -> usize {
73        20
74    }
75}
76
77impl PgConfig {
78    /// strip schema from the connection string
79    fn strip_schema(input: &str) -> (Option<String>, String) {
80        let mut schema: Option<String> = None;
81
82        // Split by whitespace
83        let mut parts = Vec::new();
84        for token in input.split_whitespace() {
85            if let Some(rest) = token.strip_prefix("schema=") {
86                schema = Some(rest.to_string());
87            } else {
88                parts.push(token);
89            }
90        }
91
92        let cleaned = parts.join(" ");
93        (schema, cleaned)
94    }
95}
96
97impl From<&str> for PgConfig {
98    fn from(conn_str: &str) -> Self {
99        let (schema, conn_str) = Self::strip_schema(conn_str);
100        fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
101            let mut builder = TlsConnector::builder();
102            if accept_invalid_certs {
103                builder.danger_accept_invalid_certs(true);
104            }
105            if accept_invalid_hostnames {
106                builder.danger_accept_invalid_hostnames(true);
107            }
108
109            match builder.build() {
110                Ok(connector) => {
111                    let make_tls_connector = MakeTlsConnector::new(connector);
112                    SslMode::NativeTls(make_tls_connector)
113                }
114                Err(_) => SslMode::NoTls(NoTls {}),
115            }
116        }
117
118        let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
119            // Strict TLS: valid certs and hostnames required
120            build_tls(false, false)
121        } else if conn_str.contains(SSLMODE_VERIFY_CA) {
122            // Verify CA, but allow invalid hostnames
123            build_tls(false, true)
124        } else if conn_str.contains(SSLMODE_PREFER)
125            || conn_str.contains(SSLMODE_ALLOW)
126            || conn_str.contains(SSLMODE_REQUIRE)
127        {
128            // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
129            build_tls(true, true)
130        } else {
131            SslMode::NoTls(NoTls {})
132        };
133
134        PgConfig {
135            url: conn_str.to_owned(),
136            schema,
137            tls,
138        }
139    }
140}
141
142impl DatabasePool for PgConnectionPool {
143    type Config = PgConfig;
144
145    type Connection = PostgresConnection;
146
147    type Error = PgError;
148
149    fn new_resource(
150        config: &Self::Config,
151        stale: Arc<AtomicBool>,
152        timeout: Duration,
153    ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
154        Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
155    }
156}
157
158/// A postgres connection
159#[derive(Debug)]
160pub struct PostgresConnection {
161    timeout: Duration,
162    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
163    result: Arc<OnceLock<Client>>,
164    notify: Arc<Notify>,
165}
166
167impl PostgresConnection {
168    /// Creates a new instance
169    pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
170        let failed = Arc::new(Mutex::new(None));
171        let result = Arc::new(OnceLock::new());
172        let notify = Arc::new(Notify::new());
173        let error_clone = failed.clone();
174        let result_clone = result.clone();
175        let notify_clone = notify.clone();
176
177        async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
178            conn.batch_execute(&format!(
179                r#"
180                    CREATE SCHEMA IF NOT EXISTS "{schema}";
181                    SET search_path TO "{schema}"
182                    "#
183            ))
184            .await
185            .map_err(|e| Error::Database(Box::new(e)))
186        }
187
188        tokio::spawn(async move {
189            match config.tls {
190                SslMode::NoTls(tls) => {
191                    let (client, connection) = match connect(&config.url, tls).await {
192                        Ok((client, connection)) => (client, connection),
193                        Err(err) => {
194                            *error_clone.lock().await =
195                                Some(cdk_common::database::Error::Database(Box::new(err)));
196                            stale.store(false, std::sync::atomic::Ordering::Release);
197                            notify_clone.notify_waiters();
198                            return;
199                        }
200                    };
201
202                    let stale_for_spawn = stale.clone();
203                    tokio::spawn(async move {
204                        let _ = connection.await;
205                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
206                    });
207
208                    if let Some(schema) = config.schema.as_ref() {
209                        if let Err(err) = select_schema(&client, schema).await {
210                            *error_clone.lock().await = Some(err);
211                            stale.store(false, std::sync::atomic::Ordering::Release);
212                            notify_clone.notify_waiters();
213                            return;
214                        }
215                    }
216
217                    let _ = result_clone.set(client);
218                    notify_clone.notify_waiters();
219                }
220                SslMode::NativeTls(tls) => {
221                    let (client, connection) = match connect(&config.url, tls).await {
222                        Ok((client, connection)) => (client, connection),
223                        Err(err) => {
224                            *error_clone.lock().await =
225                                Some(cdk_common::database::Error::Database(Box::new(err)));
226                            stale.store(false, std::sync::atomic::Ordering::Release);
227                            notify_clone.notify_waiters();
228                            return;
229                        }
230                    };
231
232                    let stale_for_spawn = stale.clone();
233                    tokio::spawn(async move {
234                        let _ = connection.await;
235                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
236                    });
237
238                    if let Some(schema) = config.schema.as_ref() {
239                        if let Err(err) = select_schema(&client, schema).await {
240                            *error_clone.lock().await = Some(err);
241                            stale.store(true, std::sync::atomic::Ordering::Release);
242                            notify_clone.notify_waiters();
243                            return;
244                        }
245                    }
246
247                    let _ = result_clone.set(client);
248                    notify_clone.notify_waiters();
249                }
250            }
251        });
252
253        Self {
254            error: failed,
255            timeout,
256            result,
257            notify,
258        }
259    }
260
261    /// Gets the wrapped instance or the connection error. The connection is returned as reference,
262    /// and the actual error is returned once, next times a generic error would be returned
263    async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
264        if let Some(client) = self.result.get() {
265            return Ok(client);
266        }
267
268        if let Some(error) = self.error.lock().await.take() {
269            return Err(error);
270        }
271
272        if timeout(self.timeout, self.notify.notified()).await.is_err() {
273            return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
274        }
275
276        // Check result again
277        if let Some(client) = self.result.get() {
278            Ok(client)
279        } else if let Some(error) = self.error.lock().await.take() {
280            Err(error)
281        } else {
282            Err(cdk_common::database::Error::Internal(
283                "Failed connection".to_owned(),
284            ))
285        }
286    }
287}
288
289#[async_trait::async_trait]
290impl DatabaseConnector for PostgresConnection {
291    type Transaction = GenericTransactionHandler<Self>;
292}
293
294#[async_trait::async_trait]
295impl DatabaseExecutor for PostgresConnection {
296    fn name() -> &'static str {
297        "postgres"
298    }
299
300    async fn execute(&self, statement: Statement) -> Result<usize, Error> {
301        pg_execute(self.inner().await?, statement).await
302    }
303
304    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
305        pg_fetch_one(self.inner().await?, statement).await
306    }
307
308    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
309        pg_fetch_all(self.inner().await?, statement).await
310    }
311
312    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
313        pg_pluck(self.inner().await?, statement).await
314    }
315
316    async fn batch(&self, statement: Statement) -> Result<(), Error> {
317        pg_batch(self.inner().await?, statement).await
318    }
319}
320
321/// Mint DB implementation with PostgreSQL
322pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
323
324/// Mint Auth database with Postgres
325pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
326
327/// Wallet DB implementation with PostgreSQL
328pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
329
330/// Convenience free functions (cannot add inherent impls for a foreign type).
331/// These mirror the Mint patterns and call through to the generic constructors.
332pub async fn new_wallet_pg_database(conn_str: &str) -> Result<WalletPgDatabase, Error> {
333    <SQLWalletDatabase<PgConnectionPool>>::new(conn_str).await
334}
335
#[cfg(test)]
mod test {
    use cdk_common::{mint_db_test, wallet_db_test};

    use super::*;

    /// Connection string used when no environment variable provides one.
    const DEFAULT_DB_URL: &str = "host=localhost user=test password=test dbname=testdb port=5433";

    /// Builds the connection string for one test run, appending `test_id` as the schema.
    fn connection_string(test_id: &str) -> String {
        let base = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
            .unwrap_or_else(|_| DEFAULT_DB_URL.to_owned());

        format!("{base} schema={test_id}")
    }

    /// Provides the mint database for the shared mint test suite.
    async fn provide_mint_db(test_id: String) -> MintPgDatabase {
        let db_url = connection_string(&test_id);

        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_mint_db);

    /// Provides the wallet database for the shared wallet test suite.
    async fn provide_wallet_db(test_id: String) -> WalletPgDatabase {
        let db_url = connection_string(&test_id);

        WalletPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    wallet_db_test!(provide_wallet_db);
}