cdk_postgres/
lib.rs

1use std::fmt::Debug;
2use std::sync::atomic::AtomicBool;
3use std::sync::{Arc, OnceLock};
4use std::time::Duration;
5
6use cdk_common::database::Error;
7use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
8use cdk_sql_common::mint::SQLMintAuthDatabase;
9use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
10use cdk_sql_common::stmt::{Column, Statement};
11use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
12use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
13use native_tls::TlsConnector;
14use postgres_native_tls::MakeTlsConnector;
15use tokio::sync::{Mutex, Notify};
16use tokio::time::timeout;
17use tokio_postgres::{connect, Client, Error as PgError, NoTls};
18
19mod db;
20mod value;
21
/// PostgreSQL connection pool type.
///
/// It is a unit struct carrying no state of its own; its [`DatabasePool`] implementation tells
/// the generic pool how to create [`PostgresConnection`]s from a [`PgConfig`].
#[derive(Debug)]
pub struct PgConnectionPool;
24
25#[derive(Clone)]
26pub enum SslMode {
27    NoTls(NoTls),
28    NativeTls(postgres_native_tls::MakeTlsConnector),
29}
/// Connection-string token that selects a TLS connector verifying certificates and hostnames.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
/// Connection-string token that selects a TLS connector verifying certificates only.
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
/// Connection-string token that selects lenient (unverified) TLS.
const SSLMODE_PREFER: &str = "sslmode=prefer";
/// Connection-string token that selects lenient (unverified) TLS.
const SSLMODE_ALLOW: &str = "sslmode=allow";
/// Connection-string token that selects lenient (unverified) TLS.
const SSLMODE_REQUIRE: &str = "sslmode=require";
35
36impl Default for SslMode {
37    fn default() -> Self {
38        SslMode::NoTls(NoTls {})
39    }
40}
41
42impl Debug for SslMode {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let debug_text = match self {
45            Self::NoTls(_) => "NoTls",
46            Self::NativeTls(_) => "NativeTls",
47        };
48
49        write!(f, "SslMode::{debug_text}")
50    }
51}
52
53/// Postgres configuration
54#[derive(Clone, Debug)]
55pub struct PgConfig {
56    url: String,
57    schema: Option<String>,
58    tls: SslMode,
59}
60
61impl DatabaseConfig for PgConfig {
62    fn default_timeout(&self) -> Duration {
63        Duration::from_secs(10)
64    }
65
66    fn max_size(&self) -> usize {
67        20
68    }
69}
70
71impl PgConfig {
72    /// strip schema from the connection string
73    fn strip_schema(input: &str) -> (Option<String>, String) {
74        let mut schema: Option<String> = None;
75
76        // Split by whitespace
77        let mut parts = Vec::new();
78        for token in input.split_whitespace() {
79            if let Some(rest) = token.strip_prefix("schema=") {
80                schema = Some(rest.to_string());
81            } else {
82                parts.push(token);
83            }
84        }
85
86        let cleaned = parts.join(" ");
87        (schema, cleaned)
88    }
89}
90
91impl From<&str> for PgConfig {
92    fn from(conn_str: &str) -> Self {
93        let (schema, conn_str) = Self::strip_schema(conn_str);
94        fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
95            let mut builder = TlsConnector::builder();
96            if accept_invalid_certs {
97                builder.danger_accept_invalid_certs(true);
98            }
99            if accept_invalid_hostnames {
100                builder.danger_accept_invalid_hostnames(true);
101            }
102
103            match builder.build() {
104                Ok(connector) => {
105                    let make_tls_connector = MakeTlsConnector::new(connector);
106                    SslMode::NativeTls(make_tls_connector)
107                }
108                Err(_) => SslMode::NoTls(NoTls {}),
109            }
110        }
111
112        let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
113            // Strict TLS: valid certs and hostnames required
114            build_tls(false, false)
115        } else if conn_str.contains(SSLMODE_VERIFY_CA) {
116            // Verify CA, but allow invalid hostnames
117            build_tls(false, true)
118        } else if conn_str.contains(SSLMODE_PREFER)
119            || conn_str.contains(SSLMODE_ALLOW)
120            || conn_str.contains(SSLMODE_REQUIRE)
121        {
122            // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
123            build_tls(true, true)
124        } else {
125            SslMode::NoTls(NoTls {})
126        };
127
128        PgConfig {
129            url: conn_str.to_owned(),
130            schema,
131            tls,
132        }
133    }
134}
135
136impl DatabasePool for PgConnectionPool {
137    type Config = PgConfig;
138
139    type Connection = PostgresConnection;
140
141    type Error = PgError;
142
143    fn new_resource(
144        config: &Self::Config,
145        stale: Arc<AtomicBool>,
146        timeout: Duration,
147    ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
148        Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
149    }
150}
151
152/// A postgres connection
153#[derive(Debug)]
154pub struct PostgresConnection {
155    timeout: Duration,
156    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
157    result: Arc<OnceLock<Client>>,
158    notify: Arc<Notify>,
159}
160
161impl PostgresConnection {
162    /// Creates a new instance
163    pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
164        let failed = Arc::new(Mutex::new(None));
165        let result = Arc::new(OnceLock::new());
166        let notify = Arc::new(Notify::new());
167        let error_clone = failed.clone();
168        let result_clone = result.clone();
169        let notify_clone = notify.clone();
170
171        async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
172            conn.batch_execute(&format!(
173                r#"
174                    CREATE SCHEMA IF NOT EXISTS "{schema}";
175                    SET search_path TO "{schema}"
176                    "#
177            ))
178            .await
179            .map_err(|e| Error::Database(Box::new(e)))
180        }
181
182        tokio::spawn(async move {
183            match config.tls {
184                SslMode::NoTls(tls) => {
185                    let (client, connection) = match connect(&config.url, tls).await {
186                        Ok((client, connection)) => (client, connection),
187                        Err(err) => {
188                            *error_clone.lock().await =
189                                Some(cdk_common::database::Error::Database(Box::new(err)));
190                            stale.store(false, std::sync::atomic::Ordering::Release);
191                            notify_clone.notify_waiters();
192                            return;
193                        }
194                    };
195
196                    let stale_for_spawn = stale.clone();
197                    tokio::spawn(async move {
198                        let _ = connection.await;
199                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
200                    });
201
202                    if let Some(schema) = config.schema.as_ref() {
203                        if let Err(err) = select_schema(&client, schema).await {
204                            *error_clone.lock().await = Some(err);
205                            stale.store(false, std::sync::atomic::Ordering::Release);
206                            notify_clone.notify_waiters();
207                            return;
208                        }
209                    }
210
211                    let _ = result_clone.set(client);
212                    notify_clone.notify_waiters();
213                }
214                SslMode::NativeTls(tls) => {
215                    let (client, connection) = match connect(&config.url, tls).await {
216                        Ok((client, connection)) => (client, connection),
217                        Err(err) => {
218                            *error_clone.lock().await =
219                                Some(cdk_common::database::Error::Database(Box::new(err)));
220                            stale.store(false, std::sync::atomic::Ordering::Release);
221                            notify_clone.notify_waiters();
222                            return;
223                        }
224                    };
225
226                    let stale_for_spawn = stale.clone();
227                    tokio::spawn(async move {
228                        let _ = connection.await;
229                        stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
230                    });
231
232                    if let Some(schema) = config.schema.as_ref() {
233                        if let Err(err) = select_schema(&client, schema).await {
234                            *error_clone.lock().await = Some(err);
235                            stale.store(true, std::sync::atomic::Ordering::Release);
236                            notify_clone.notify_waiters();
237                            return;
238                        }
239                    }
240
241                    let _ = result_clone.set(client);
242                    notify_clone.notify_waiters();
243                }
244            }
245        });
246
247        Self {
248            error: failed,
249            timeout,
250            result,
251            notify,
252        }
253    }
254
255    /// Gets the wrapped instance or the connection error. The connection is returned as reference,
256    /// and the actual error is returned once, next times a generic error would be returned
257    async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
258        if let Some(client) = self.result.get() {
259            return Ok(client);
260        }
261
262        if let Some(error) = self.error.lock().await.take() {
263            return Err(error);
264        }
265
266        if timeout(self.timeout, self.notify.notified()).await.is_err() {
267            return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
268        }
269
270        // Check result again
271        if let Some(client) = self.result.get() {
272            Ok(client)
273        } else if let Some(error) = self.error.lock().await.take() {
274            Err(error)
275        } else {
276            Err(cdk_common::database::Error::Internal(
277                "Failed connection".to_owned(),
278            ))
279        }
280    }
281}
282
283#[async_trait::async_trait]
284impl DatabaseConnector for PostgresConnection {
285    type Transaction = GenericTransactionHandler<Self>;
286}
287
288#[async_trait::async_trait]
289impl DatabaseExecutor for PostgresConnection {
290    fn name() -> &'static str {
291        "postgres"
292    }
293
294    async fn execute(&self, statement: Statement) -> Result<usize, Error> {
295        pg_execute(self.inner().await?, statement).await
296    }
297
298    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
299        pg_fetch_one(self.inner().await?, statement).await
300    }
301
302    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
303        pg_fetch_all(self.inner().await?, statement).await
304    }
305
306    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
307        pg_pluck(self.inner().await?, statement).await
308    }
309
310    async fn batch(&self, statement: Statement) -> Result<(), Error> {
311        pg_batch(self.inner().await?, statement).await
312    }
313}
314
315/// Mint DB implementation with PostgreSQL
316pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
317
318/// Mint Auth database with Postgres
319#[cfg(feature = "auth")]
320pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
321
322/// Mint DB implementation with PostgresSQL
323pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
324
#[cfg(test)]
mod test {
    use cdk_common::mint_db_test;

    use super::*;

    /// Opens a mint database whose tables live in a schema named after `test_id`.
    ///
    /// The base connection string comes from `CDK_MINTD_DATABASE_URL`, then `PG_DB_URL`, and
    /// otherwise falls back to a local default.
    async fn provide_db(test_id: String) -> MintPgDatabase {
        let db_url = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
            .unwrap_or_else(|_| {
                "host=localhost user=test password=test dbname=testdb port=5433".to_owned()
            });

        let db_url = format!("{db_url} schema={test_id}");

        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_db);
}