1use std::fmt::Debug;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, OnceLock};
6use std::time::Duration;
7
8use cdk_common::database::Error;
9use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
10use cdk_sql_common::mint::SQLMintAuthDatabase;
11use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
12use cdk_sql_common::stmt::{Column, Statement};
13use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
14use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
15use native_tls::TlsConnector;
16use postgres_native_tls::MakeTlsConnector;
17use tokio::sync::{Mutex, Notify};
18use tokio::time::timeout;
19use tokio_postgres::{connect, Client, Error as PgError, NoTls};
20
21mod db;
22mod value;
23
/// Marker type that implements [`DatabasePool`] for PostgreSQL.
///
/// It carries no state of its own: the pool machinery calls
/// `PgConnectionPool::new_resource` with a [`PgConfig`] to obtain
/// [`PostgresConnection`] values.
#[derive(Debug)]
pub struct PgConnectionPool;
27
/// TLS connector handed to `tokio_postgres::connect` when a connection is opened.
#[derive(Clone)]
pub enum SslMode {
    /// Connect without TLS, using tokio-postgres' [`NoTls`] connector.
    NoTls(NoTls),
    /// Connect through a connector built on top of `native-tls`.
    NativeTls(postgres_native_tls::MakeTlsConnector),
}
// `sslmode=...` fragments that `ssl_mode_from_url` searches for (as plain
// substrings) inside the connection string to decide which TLS connector to build.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
const SSLMODE_PREFER: &str = "sslmode=prefer";
const SSLMODE_ALLOW: &str = "sslmode=allow";
const SSLMODE_REQUIRE: &str = "sslmode=require";
41
42impl Default for SslMode {
43 fn default() -> Self {
44 SslMode::NoTls(NoTls {})
45 }
46}
47
48impl Debug for SslMode {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let debug_text = match self {
51 Self::NoTls(_) => "NoTls",
52 Self::NativeTls(_) => "NativeTls",
53 };
54
55 write!(f, "SslMode::{debug_text}")
56 }
57}
58
/// Settings needed to open PostgreSQL connections.
#[derive(Clone, Debug)]
pub struct PgConfig {
    /// Connection string with any `schema=` token already removed.
    url: String,
    /// Schema name taken from the `schema=` token of the connection string, if any.
    /// When present, the connection creates the schema if needed and sets it as
    /// the `search_path`.
    schema: Option<String>,
    /// TLS connector used when connecting.
    tls: SslMode,
    /// Value reported to the pool through `DatabaseConfig::max_size`.
    max_connections: usize,
    /// Value reported to the pool through `DatabaseConfig::default_timeout`.
    connection_timeout: Duration,
}
68
/// Exposes the pool-related settings stored in [`PgConfig`].
impl DatabaseConfig for PgConfig {
    /// Returns the configured connection timeout.
    fn default_timeout(&self) -> Duration {
        self.connection_timeout
    }

    /// Returns the configured maximum number of connections.
    fn max_size(&self) -> usize {
        self.max_connections
    }
}
78
/// Maximum number of connections used when the caller does not specify one.
const DEFAULT_MAX_CONNECTIONS: usize = 20;

/// Connection timeout, in seconds, used when the caller does not specify one.
const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 10;
84
85fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
87 let mut builder = TlsConnector::builder();
88 if accept_invalid_certs {
89 builder.danger_accept_invalid_certs(true);
90 }
91 if accept_invalid_hostnames {
92 builder.danger_accept_invalid_hostnames(true);
93 }
94
95 match builder.build() {
96 Ok(connector) => {
97 let make_tls_connector = MakeTlsConnector::new(connector);
98 SslMode::NativeTls(make_tls_connector)
99 }
100 Err(_) => SslMode::NoTls(NoTls {}),
101 }
102}
103
104fn ssl_mode_from_url(url: &str) -> SslMode {
106 if url.contains(SSLMODE_VERIFY_FULL) {
107 build_tls(false, false)
109 } else if url.contains(SSLMODE_VERIFY_CA) {
110 build_tls(false, true)
112 } else if url.contains(SSLMODE_PREFER)
113 || url.contains(SSLMODE_ALLOW)
114 || url.contains(SSLMODE_REQUIRE)
115 {
116 build_tls(true, true)
118 } else {
119 SslMode::NoTls(NoTls {})
120 }
121}
122
123fn ssl_mode_from_config(tls_mode: Option<&str>, url: &str) -> SslMode {
128 match tls_mode {
129 Some(mode) => match mode.to_lowercase().as_str() {
130 "verify-full" => build_tls(false, false),
131 "verify-ca" => build_tls(false, true),
132 "require" | "prefer" | "allow" => build_tls(true, true),
133 _ => SslMode::NoTls(NoTls {}),
135 },
136 None => ssl_mode_from_url(url),
138 }
139}
140
141impl PgConfig {
142 pub fn new(
149 conn_str: &str,
150 tls_mode: Option<&str>,
151 max_connections: Option<usize>,
152 connection_timeout_secs: Option<u64>,
153 ) -> Self {
154 let (schema, conn_str) = Self::strip_schema(conn_str);
155 let tls = ssl_mode_from_config(tls_mode, &conn_str);
156 PgConfig {
157 url: conn_str,
158 schema,
159 tls,
160 max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
161 connection_timeout: Duration::from_secs(
162 connection_timeout_secs.unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS),
163 ),
164 }
165 }
166
167 fn strip_schema(input: &str) -> (Option<String>, String) {
169 let mut schema: Option<String> = None;
170
171 let mut parts = Vec::new();
173 for token in input.split_whitespace() {
174 if let Some(rest) = token.strip_prefix("schema=") {
175 schema = Some(rest.to_string());
176 } else {
177 parts.push(token);
178 }
179 }
180
181 let cleaned = parts.join(" ");
182 (schema, cleaned)
183 }
184}
185
186impl From<&str> for PgConfig {
187 fn from(conn_str: &str) -> Self {
188 let (schema, conn_str) = Self::strip_schema(conn_str);
189 let tls = ssl_mode_from_url(&conn_str);
190
191 PgConfig {
192 url: conn_str,
193 schema,
194 tls,
195 max_connections: DEFAULT_MAX_CONNECTIONS,
196 connection_timeout: Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS),
197 }
198 }
199}
200
201impl DatabasePool for PgConnectionPool {
202 type Config = PgConfig;
203
204 type Connection = PostgresConnection;
205
206 type Error = PgError;
207
208 fn new_resource(
209 config: &Self::Config,
210 stale: Arc<AtomicBool>,
211 timeout: Duration,
212 ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
213 Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
214 }
215}
216
/// Handle to a PostgreSQL connection that is established on a background task.
#[derive(Debug)]
pub struct PostgresConnection {
    /// Maximum time to wait for the background task to report its outcome.
    timeout: Duration,
    /// Error recorded by the background task if the connection attempt failed.
    /// The value is taken out by whoever reads it first.
    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
    /// The client, set once by the background task after a successful connection.
    result: Arc<OnceLock<Client>>,
    /// Signalled by the background task when the connection attempt has finished.
    notify: Arc<Notify>,
}
225
226impl PostgresConnection {
227 pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
229 let failed = Arc::new(Mutex::new(None));
230 let result = Arc::new(OnceLock::new());
231 let notify = Arc::new(Notify::new());
232 let error_clone = failed.clone();
233 let result_clone = result.clone();
234 let notify_clone = notify.clone();
235
236 async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
237 conn.batch_execute(&format!(
238 r#"
239 CREATE SCHEMA IF NOT EXISTS "{schema}";
240 SET search_path TO "{schema}"
241 "#
242 ))
243 .await
244 .map_err(|e| Error::Database(Box::new(e)))
245 }
246
247 tokio::spawn(async move {
248 match config.tls {
249 SslMode::NoTls(tls) => {
250 let (client, connection) = match connect(&config.url, tls).await {
251 Ok((client, connection)) => (client, connection),
252 Err(err) => {
253 *error_clone.lock().await =
254 Some(cdk_common::database::Error::Database(Box::new(err)));
255 stale.store(false, std::sync::atomic::Ordering::Release);
256 notify_clone.notify_waiters();
257 return;
258 }
259 };
260
261 let stale_for_spawn = stale.clone();
262 tokio::spawn(async move {
263 let _ = connection.await;
264 stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
265 });
266
267 if let Some(schema) = config.schema.as_ref() {
268 if let Err(err) = select_schema(&client, schema).await {
269 *error_clone.lock().await = Some(err);
270 stale.store(false, std::sync::atomic::Ordering::Release);
271 notify_clone.notify_waiters();
272 return;
273 }
274 }
275
276 let _ = result_clone.set(client);
277 notify_clone.notify_waiters();
278 }
279 SslMode::NativeTls(tls) => {
280 let (client, connection) = match connect(&config.url, tls).await {
281 Ok((client, connection)) => (client, connection),
282 Err(err) => {
283 *error_clone.lock().await =
284 Some(cdk_common::database::Error::Database(Box::new(err)));
285 stale.store(false, std::sync::atomic::Ordering::Release);
286 notify_clone.notify_waiters();
287 return;
288 }
289 };
290
291 let stale_for_spawn = stale.clone();
292 tokio::spawn(async move {
293 let _ = connection.await;
294 stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
295 });
296
297 if let Some(schema) = config.schema.as_ref() {
298 if let Err(err) = select_schema(&client, schema).await {
299 *error_clone.lock().await = Some(err);
300 stale.store(true, std::sync::atomic::Ordering::Release);
301 notify_clone.notify_waiters();
302 return;
303 }
304 }
305
306 let _ = result_clone.set(client);
307 notify_clone.notify_waiters();
308 }
309 }
310 });
311
312 Self {
313 error: failed,
314 timeout,
315 result,
316 notify,
317 }
318 }
319
320 async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
323 if let Some(client) = self.result.get() {
324 return Ok(client);
325 }
326
327 if let Some(error) = self.error.lock().await.take() {
328 return Err(error);
329 }
330
331 if timeout(self.timeout, self.notify.notified()).await.is_err() {
332 return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
333 }
334
335 if let Some(client) = self.result.get() {
337 Ok(client)
338 } else if let Some(error) = self.error.lock().await.take() {
339 Err(error)
340 } else {
341 Err(cdk_common::database::Error::Internal(
342 "Failed connection".to_owned(),
343 ))
344 }
345 }
346}
347
/// Transactions on a [`PostgresConnection`] use the generic transaction
/// handler provided by `cdk_sql_common`.
#[async_trait::async_trait]
impl DatabaseConnector for PostgresConnection {
    type Transaction = GenericTransactionHandler<Self>;
}
352
353#[async_trait::async_trait]
354impl DatabaseExecutor for PostgresConnection {
355 fn name() -> &'static str {
356 "postgres"
357 }
358
359 async fn execute(&self, statement: Statement) -> Result<usize, Error> {
360 pg_execute(self.inner().await?, statement).await
361 }
362
363 async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
364 pg_fetch_one(self.inner().await?, statement).await
365 }
366
367 async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
368 pg_fetch_all(self.inner().await?, statement).await
369 }
370
371 async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
372 pg_pluck(self.inner().await?, statement).await
373 }
374
375 async fn batch(&self, statement: Statement) -> Result<(), Error> {
376 pg_batch(self.inner().await?, statement).await
377 }
378}
379
/// Mint database backed by the PostgreSQL connection pool.
pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;

/// Mint auth database backed by the PostgreSQL connection pool.
pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;

/// Wallet database backed by the PostgreSQL connection pool.
pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
388
389pub async fn new_wallet_pg_database(conn_str: &str) -> Result<WalletPgDatabase, Error> {
392 <SQLWalletDatabase<PgConnectionPool>>::new(conn_str).await
393}
394
#[cfg(test)]
mod test {
    use cdk_common::{mint_db_test, wallet_db_test};

    use super::*;

    /// Connection string used when neither environment variable is set.
    const FALLBACK_DB_URL: &str =
        "host=localhost user=cdk_user password=cdk_password dbname=cdk_mint port=5432";

    /// Builds the connection string for one test.
    ///
    /// The base URL is read from `CDK_MINTD_DATABASE_URL`, then `PG_DB_URL`,
    /// then the local fallback; `test_id` is appended as the schema name.
    fn schema_scoped_url(test_id: &str) -> String {
        let base_url = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL"))
            .unwrap_or_else(|_| FALLBACK_DB_URL.to_owned());

        format!("{base_url} schema={test_id}")
    }

    /// Provides the mint database used by the shared mint test-suite.
    async fn provide_mint_db(test_id: String) -> MintPgDatabase {
        let db_url = schema_scoped_url(&test_id);

        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_mint_db);

    /// Provides the wallet database used by the shared wallet test-suite.
    async fn provide_wallet_db(test_id: String) -> WalletPgDatabase {
        let db_url = schema_scoped_url(&test_id);

        WalletPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    wallet_db_test!(provide_wallet_db);
}