1use std::fmt::Debug;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, OnceLock};
6use std::time::Duration;
7
8use cdk_common::database::Error;
9use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
10use cdk_sql_common::mint::SQLMintAuthDatabase;
11use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
12use cdk_sql_common::stmt::{Column, Statement};
13use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
14use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
15use native_tls::TlsConnector;
16use postgres_native_tls::MakeTlsConnector;
17use tokio::sync::{Mutex, Notify};
18use tokio::time::timeout;
19use tokio_postgres::{connect, Client, Error as PgError, NoTls};
20
21mod db;
22mod value;
23
/// Postgres implementation of the generic [`DatabasePool`] trait.
///
/// It is the pool type behind [`MintPgDatabase`], [`MintPgAuthDatabase`] and
/// [`WalletPgDatabase`]; each pooled resource is a [`PostgresConnection`].
#[derive(Debug)]
pub struct PgConnectionPool;
27
28#[derive(Clone)]
29pub enum SslMode {
31 NoTls(NoTls),
33 NativeTls(postgres_native_tls::MakeTlsConnector),
35}
// `sslmode=` settings that are searched for in a connection string to decide which
// TLS connector (`SslMode`) a `PgConfig` uses.

/// Verify both the certificate chain and the server hostname.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
/// Verify the certificate chain but not the hostname.
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
/// TLS without certificate or hostname verification.
const SSLMODE_PREFER: &str = "sslmode=prefer";
/// TLS without certificate or hostname verification.
const SSLMODE_ALLOW: &str = "sslmode=allow";
/// TLS without certificate or hostname verification.
const SSLMODE_REQUIRE: &str = "sslmode=require";
41
42impl Default for SslMode {
43 fn default() -> Self {
44 SslMode::NoTls(NoTls {})
45 }
46}
47
48impl Debug for SslMode {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let debug_text = match self {
51 Self::NoTls(_) => "NoTls",
52 Self::NativeTls(_) => "NativeTls",
53 };
54
55 write!(f, "SslMode::{debug_text}")
56 }
57}
58
/// Connection settings for a Postgres-backed database.
///
/// Built from a connection string through `From<&str>`.
#[derive(Clone, Debug)]
pub struct PgConfig {
    /// Connection string passed to `tokio_postgres::connect`, with any `schema=` token removed.
    url: String,
    /// Schema created if missing and set as `search_path` after connecting, when present.
    schema: Option<String>,
    /// TLS connector selected from the connection string's `sslmode` setting.
    tls: SslMode,
}
66
67impl DatabaseConfig for PgConfig {
68 fn default_timeout(&self) -> Duration {
69 Duration::from_secs(10)
70 }
71
72 fn max_size(&self) -> usize {
73 20
74 }
75}
76
77impl PgConfig {
78 fn strip_schema(input: &str) -> (Option<String>, String) {
80 let mut schema: Option<String> = None;
81
82 let mut parts = Vec::new();
84 for token in input.split_whitespace() {
85 if let Some(rest) = token.strip_prefix("schema=") {
86 schema = Some(rest.to_string());
87 } else {
88 parts.push(token);
89 }
90 }
91
92 let cleaned = parts.join(" ");
93 (schema, cleaned)
94 }
95}
96
97impl From<&str> for PgConfig {
98 fn from(conn_str: &str) -> Self {
99 let (schema, conn_str) = Self::strip_schema(conn_str);
100 fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
101 let mut builder = TlsConnector::builder();
102 if accept_invalid_certs {
103 builder.danger_accept_invalid_certs(true);
104 }
105 if accept_invalid_hostnames {
106 builder.danger_accept_invalid_hostnames(true);
107 }
108
109 match builder.build() {
110 Ok(connector) => {
111 let make_tls_connector = MakeTlsConnector::new(connector);
112 SslMode::NativeTls(make_tls_connector)
113 }
114 Err(_) => SslMode::NoTls(NoTls {}),
115 }
116 }
117
118 let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
119 build_tls(false, false)
121 } else if conn_str.contains(SSLMODE_VERIFY_CA) {
122 build_tls(false, true)
124 } else if conn_str.contains(SSLMODE_PREFER)
125 || conn_str.contains(SSLMODE_ALLOW)
126 || conn_str.contains(SSLMODE_REQUIRE)
127 {
128 build_tls(true, true)
130 } else {
131 SslMode::NoTls(NoTls {})
132 };
133
134 PgConfig {
135 url: conn_str.to_owned(),
136 schema,
137 tls,
138 }
139 }
140}
141
142impl DatabasePool for PgConnectionPool {
143 type Config = PgConfig;
144
145 type Connection = PostgresConnection;
146
147 type Error = PgError;
148
149 fn new_resource(
150 config: &Self::Config,
151 stale: Arc<AtomicBool>,
152 timeout: Duration,
153 ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
154 Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
155 }
156}
157
/// One pooled Postgres connection.
///
/// The client is established by a task spawned in `PostgresConnection::new`; these
/// shared slots are how that task reports its outcome to `inner()`.
#[derive(Debug)]
pub struct PostgresConnection {
    /// How long `inner()` waits for the background connect task to report.
    timeout: Duration,
    /// Error recorded by the connect task on failure; `inner()` takes it out.
    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
    /// Connected client, set at most once by the connect task on success.
    result: Arc<OnceLock<Client>>,
    /// Signalled (`notify_waiters`) by the connect task when it has finished.
    notify: Arc<Notify>,
}
166
167impl PostgresConnection {
168 pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
170 let failed = Arc::new(Mutex::new(None));
171 let result = Arc::new(OnceLock::new());
172 let notify = Arc::new(Notify::new());
173 let error_clone = failed.clone();
174 let result_clone = result.clone();
175 let notify_clone = notify.clone();
176
177 async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
178 conn.batch_execute(&format!(
179 r#"
180 CREATE SCHEMA IF NOT EXISTS "{schema}";
181 SET search_path TO "{schema}"
182 "#
183 ))
184 .await
185 .map_err(|e| Error::Database(Box::new(e)))
186 }
187
188 tokio::spawn(async move {
189 match config.tls {
190 SslMode::NoTls(tls) => {
191 let (client, connection) = match connect(&config.url, tls).await {
192 Ok((client, connection)) => (client, connection),
193 Err(err) => {
194 *error_clone.lock().await =
195 Some(cdk_common::database::Error::Database(Box::new(err)));
196 stale.store(false, std::sync::atomic::Ordering::Release);
197 notify_clone.notify_waiters();
198 return;
199 }
200 };
201
202 let stale_for_spawn = stale.clone();
203 tokio::spawn(async move {
204 let _ = connection.await;
205 stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
206 });
207
208 if let Some(schema) = config.schema.as_ref() {
209 if let Err(err) = select_schema(&client, schema).await {
210 *error_clone.lock().await = Some(err);
211 stale.store(false, std::sync::atomic::Ordering::Release);
212 notify_clone.notify_waiters();
213 return;
214 }
215 }
216
217 let _ = result_clone.set(client);
218 notify_clone.notify_waiters();
219 }
220 SslMode::NativeTls(tls) => {
221 let (client, connection) = match connect(&config.url, tls).await {
222 Ok((client, connection)) => (client, connection),
223 Err(err) => {
224 *error_clone.lock().await =
225 Some(cdk_common::database::Error::Database(Box::new(err)));
226 stale.store(false, std::sync::atomic::Ordering::Release);
227 notify_clone.notify_waiters();
228 return;
229 }
230 };
231
232 let stale_for_spawn = stale.clone();
233 tokio::spawn(async move {
234 let _ = connection.await;
235 stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
236 });
237
238 if let Some(schema) = config.schema.as_ref() {
239 if let Err(err) = select_schema(&client, schema).await {
240 *error_clone.lock().await = Some(err);
241 stale.store(true, std::sync::atomic::Ordering::Release);
242 notify_clone.notify_waiters();
243 return;
244 }
245 }
246
247 let _ = result_clone.set(client);
248 notify_clone.notify_waiters();
249 }
250 }
251 });
252
253 Self {
254 error: failed,
255 timeout,
256 result,
257 notify,
258 }
259 }
260
261 async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
264 if let Some(client) = self.result.get() {
265 return Ok(client);
266 }
267
268 if let Some(error) = self.error.lock().await.take() {
269 return Err(error);
270 }
271
272 if timeout(self.timeout, self.notify.notified()).await.is_err() {
273 return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
274 }
275
276 if let Some(client) = self.result.get() {
278 Ok(client)
279 } else if let Some(error) = self.error.lock().await.take() {
280 Err(error)
281 } else {
282 Err(cdk_common::database::Error::Internal(
283 "Failed connection".to_owned(),
284 ))
285 }
286 }
287}
288
/// Transactions on a [`PostgresConnection`] are handled by the shared
/// [`GenericTransactionHandler`].
#[async_trait::async_trait]
impl DatabaseConnector for PostgresConnection {
    type Transaction = GenericTransactionHandler<Self>;
}
293
294#[async_trait::async_trait]
295impl DatabaseExecutor for PostgresConnection {
296 fn name() -> &'static str {
297 "postgres"
298 }
299
300 async fn execute(&self, statement: Statement) -> Result<usize, Error> {
301 pg_execute(self.inner().await?, statement).await
302 }
303
304 async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
305 pg_fetch_one(self.inner().await?, statement).await
306 }
307
308 async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
309 pg_fetch_all(self.inner().await?, statement).await
310 }
311
312 async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
313 pg_pluck(self.inner().await?, statement).await
314 }
315
316 async fn batch(&self, statement: Statement) -> Result<(), Error> {
317 pg_batch(self.inner().await?, statement).await
318 }
319}
320
/// Mint database (`SQLMintDatabase`) backed by the Postgres pool.
pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;

/// Mint auth database (`SQLMintAuthDatabase`) backed by the Postgres pool.
pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;

/// Wallet database (`SQLWalletDatabase`) backed by the Postgres pool.
pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
329
330pub async fn new_wallet_pg_database(conn_str: &str) -> Result<WalletPgDatabase, Error> {
333 <SQLWalletDatabase<PgConnectionPool>>::new(conn_str).await
334}
335
#[cfg(test)]
mod test {
    use cdk_common::{mint_db_test, wallet_db_test};

    use super::*;

    /// Connection string used when neither `CDK_MINTD_DATABASE_URL` nor `PG_DB_URL` is set.
    const PG_TEST_DEFAULT_DB_URL: &str =
        "host=localhost user=test password=test dbname=testdb port=5433";

    /// Builds the connection string for one test, isolating it in schema `test_id`.
    fn pg_test_db_url(test_id: &str) -> String {
        let base_url = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL"))
            .unwrap_or_else(|_| PG_TEST_DEFAULT_DB_URL.to_owned());

        format!("{base_url} schema={test_id}")
    }

    async fn provide_mint_db(test_id: String) -> MintPgDatabase {
        let db_url = pg_test_db_url(&test_id);
        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_mint_db);

    async fn provide_wallet_db(test_id: String) -> WalletPgDatabase {
        let db_url = pg_test_db_url(&test_id);
        WalletPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    wallet_db_test!(provide_wallet_db);
}