1use core::fmt::{Debug, Formatter};
2use core::str::FromStr;
3use std::future::Future;
4use std::ops::Deref;
5use std::path::{Path, PathBuf};
6use std::time::Duration;
7
8use ockam_core::errcode::{Kind, Origin};
9use sqlx::any::{install_default_drivers, AnyConnectOptions};
10use sqlx::pool::PoolOptions;
11use sqlx::{Any, ConnectOptions, Pool};
12use sqlx_core::any::AnyConnection;
13use sqlx_core::executor::Executor;
14use sqlx_core::row::Row;
15use tempfile::NamedTempFile;
16use tokio_retry::strategy::{jitter, FixedInterval};
17use tokio_retry::Retry;
18use tracing::debug;
19use tracing::log::LevelFilter;
20
21use crate::database::database_configuration::DatabaseConfiguration;
22use crate::database::migrations::application_migration_set::ApplicationMigrationSet;
23use crate::database::migrations::node_migration_set::NodeMigrationSet;
24use crate::database::migrations::MigrationSet;
25use crate::database::{DatabaseType, MigrationStatus};
26use ockam_core::compat::rand::random_string;
27use ockam_core::compat::sync::Arc;
28use ockam_core::{Error, Result};
29
30#[derive(Clone)]
37pub struct SqlxDatabase {
38 pub pool: Arc<Pool<Any>>,
40 pub configuration: DatabaseConfiguration,
42}
43
44impl Debug for SqlxDatabase {
45 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
46 f.write_str(format!("database options {:?}", self.pool.connect_options()).as_str())
47 }
48}
49
50impl Deref for SqlxDatabase {
51 type Target = Pool<Any>;
52
53 fn deref(&self) -> &Self::Target {
54 &self.pool
55 }
56}
57
58impl SqlxDatabase {
59 pub async fn create(configuration: &DatabaseConfiguration) -> Result<Self> {
61 Self::create_impl(
62 configuration,
63 Some(NodeMigrationSet::new(configuration.database_type())),
64 )
65 .await
66 }
67
68 pub async fn create_application_database(
70 configuration: &DatabaseConfiguration,
71 ) -> Result<Self> {
72 Self::create_impl(
73 configuration,
74 Some(ApplicationMigrationSet::new(configuration.database_type())),
75 )
76 .await
77 }
78
79 pub async fn create_sqlite(path: impl AsRef<Path>) -> Result<Self> {
81 Self::create(&DatabaseConfiguration::sqlite(path)).await
82 }
83
84 pub async fn create_sqlite_no_migration(path: impl AsRef<Path>) -> Result<Self> {
86 Self::create_no_migration(&DatabaseConfiguration::sqlite(path)).await
87 }
88
89 pub async fn create_application_sqlite(path: impl AsRef<Path>) -> Result<Self> {
91 Self::create_application_database(&DatabaseConfiguration::sqlite(path)).await
92 }
93
94 pub async fn create_postgres_no_migration(legacy_sqlite_path: Option<PathBuf>) -> Result<Self> {
96 match DatabaseConfiguration::postgres_with_legacy_sqlite_path(legacy_sqlite_path)? {
97 Some(configuration) => Self::create_no_migration(&configuration).await,
98 None => Err(Error::new(Origin::Core, Kind::NotFound, "There is no postgres database configuration, or it is incomplete. Please run ockam environment to check the database environment variables".to_string())),
99 }
100 }
101
102 pub async fn create_new_postgres() -> Result<Self> {
104 match DatabaseConfiguration::postgres()? {
105 Some(configuration) => {
106 let db = Self::create_no_migration(&configuration).await?;
107 db.drop_all_postgres_tables().await?;
108 SqlxDatabase::create(&configuration).await
109 },
110 None => Err(Error::new(Origin::Core, Kind::NotFound, "There is no postgres database configuration, or it is incomplete. Please run ockam environment to check the database environment variables".to_string())),
111 }
112 }
113
114 pub async fn create_new_application_postgres() -> Result<Self> {
116 match DatabaseConfiguration::postgres()? {
117 Some(configuration) => {
118 let db = Self::create_application_no_migration(&configuration).await?;
119 db.drop_all_postgres_tables().await?;
120 SqlxDatabase::create_application_database(&configuration).await
121 },
122 None => Err(Error::new(Origin::Core, Kind::NotFound, "There is no postgres database configuration, or it is incomplete. Please run ockam environment to check the database environment variables".to_string())),
123 }
124 }
125
126 pub async fn create_with_migration(
128 configuration: &DatabaseConfiguration,
129 migration_set: impl MigrationSet,
130 ) -> Result<Self> {
131 Self::create_impl(configuration, Some(migration_set)).await
132 }
133
134 pub async fn create_no_migration(configuration: &DatabaseConfiguration) -> Result<Self> {
136 Self::create_impl(configuration, None::<NodeMigrationSet>).await
137 }
138
139 pub async fn create_application_no_migration(
141 configuration: &DatabaseConfiguration,
142 ) -> Result<Self> {
143 Self::create_impl(configuration, None::<ApplicationMigrationSet>).await
144 }
145
146 async fn create_impl(
147 configuration: &DatabaseConfiguration,
148 migration_set: Option<impl MigrationSet>,
149 ) -> Result<Self> {
150 debug!("Creating SQLx database using configuration");
151
152 configuration.create_directory_if_necessary()?;
153
154 let retry_strategy = FixedInterval::from_millis(1000)
158 .map(jitter) .take(10); let database = if configuration.database_type() == DatabaseType::Sqlite
163 && configuration.path().is_some()
164 {
165 if let Some(migration_set) = migration_set {
166 let migration_config = configuration.single_connection();
173
174 let database = Retry::spawn(retry_strategy.clone(), || async {
175 match Self::create_at(&migration_config).await {
176 Ok(db) => Ok(db),
177 Err(e) => {
178 println!("{e:?}");
179 Err(e)
180 }
181 }
182 })
183 .await?;
184
185 let migrator = migration_set.create_migrator()?;
186 let status = migrator.migrate(&database.pool).await?;
187 database.close().await;
188 match status {
189 MigrationStatus::UpToDate(_) => (),
190 MigrationStatus::Todo(_, _) => (),
191 MigrationStatus::Failed(version, reason) => Err(Error::new(
192 Origin::Node,
193 Kind::Conflict,
194 format!(
195 "Sql migration previously failed for version {}. Reason: {}",
196 version, reason
197 ),
198 ))?,
199 }
200 };
201
202 Retry::spawn(retry_strategy, || async {
204 match Self::create_at(configuration).await {
205 Ok(db) => Ok(db),
206 Err(e) => {
207 println!("{e:?}");
208 Err(e)
209 }
210 }
211 })
212 .await?
213 } else {
214 let database = Retry::spawn(retry_strategy, || async {
215 match Self::create_at(configuration).await {
216 Ok(db) => Ok(db),
217 Err(e) => {
218 println!("{e:?}");
219 Err(e)
220 }
221 }
222 })
223 .await?;
224
225 let migrate_database = if configuration.database_type() == DatabaseType::Postgres {
229 let database_schema_already_created: bool = sqlx::query("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'identity')")
230 .fetch_one(&*database.pool)
231 .await.into_core()?.get(0);
232 !database_schema_already_created
233 } else {
234 true
235 };
236
237 if migrate_database {
238 if let Some(migration_set) = migration_set {
239 let migrator = migration_set.create_migrator()?;
240 migrator.migrate(&database.pool).await?;
241 }
242 }
243
244 database
245 };
246
247 Ok(database)
248 }
249
250 pub async fn in_memory(usage: &str) -> Result<Self> {
253 Self::in_memory_with_migration(usage, NodeMigrationSet::new(DatabaseType::Sqlite)).await
254 }
255
256 pub async fn application_in_memory(usage: &str) -> Result<Self> {
260 Self::in_memory_with_migration(usage, ApplicationMigrationSet::new(DatabaseType::Sqlite))
261 .await
262 }
263
264 pub async fn in_memory_with_migration(
266 usage: &str,
267 migration_set: impl MigrationSet,
268 ) -> Result<Self> {
269 debug!("create an in memory database for {usage}");
270 let configuration = DatabaseConfiguration::sqlite_in_memory();
271 let pool = Self::create_in_memory_connection_pool().await?;
272 let migrator = migration_set.create_migrator()?;
273 migrator.migrate(&pool).await?;
274 let db = SqlxDatabase {
276 pool: Arc::new(pool),
277 configuration,
278 };
279 Ok(db)
280 }
281
282 pub fn needs_retry(&self) -> bool {
285 matches!(
286 self.configuration,
287 DatabaseConfiguration::SqlitePersistent { .. }
288 )
289 }
290
291 async fn create_at(configuration: &DatabaseConfiguration) -> Result<Self> {
292 let pool = Self::create_connection_pool(configuration).await?;
294 Ok(SqlxDatabase {
295 pool: Arc::new(pool),
296 configuration: configuration.clone(),
297 })
298 }
299
300 pub(crate) async fn create_connection_pool(
301 configuration: &DatabaseConfiguration,
302 ) -> Result<Pool<Any>> {
303 install_default_drivers();
304 let connection_string = configuration.connection_string();
305 debug!("connecting to {connection_string}");
306 let options = AnyConnectOptions::from_str(&connection_string)
307 .map_err(Self::map_sql_err)?
308 .log_statements(LevelFilter::Trace)
309 .log_slow_statements(LevelFilter::Trace, Duration::from_secs(1));
310
311 const MAX_POOL_SIZE: u32 = 16;
314
315 let max_pool_size = match configuration {
316 DatabaseConfiguration::SqlitePersistent {
317 single_connection, ..
318 }
319 | DatabaseConfiguration::SqliteInMemory { single_connection } => {
320 if *single_connection {
321 1
322 } else {
323 MAX_POOL_SIZE
324 }
325 }
326 _ => MAX_POOL_SIZE,
327 };
328
329 let pool_options = PoolOptions::new()
330 .max_connections(max_pool_size)
331 .min_connections(1);
332
333 let pool_options = if configuration.database_type() == DatabaseType::Sqlite {
334 pool_options.after_connect(|connection: &mut AnyConnection, _metadata| {
336 Box::pin(async move {
337 let _ = connection
343 .execute(
344 r#"
345PRAGMA synchronous = EXTRA;
346PRAGMA locking_mode = NORMAL;
347PRAGMA busy_timeout = 10000;
348 "#,
349 )
350 .await
351 .expect("Failed to set SQLite configuration");
352
353 Ok(())
354 })
355 })
356 } else {
357 pool_options
358 };
359
360 let pool = pool_options
361 .connect_with(options)
362 .await
363 .map_err(Self::map_sql_err)?;
364
365 Ok(pool)
366 }
367
368 pub async fn create_sqlite_single_connection_pool(path: impl AsRef<Path>) -> Result<Pool<Any>> {
370 Self::create_connection_pool(&DatabaseConfiguration::sqlite(path).single_connection()).await
371 }
372
373 pub(crate) async fn create_in_memory_connection_pool() -> Result<Pool<Any>> {
374 install_default_drivers();
375 let file_name = random_string();
378 let options = AnyConnectOptions::from_str(
379 format!("sqlite:file:{file_name}?mode=memory&cache=shared").as_str(),
380 )
381 .map_err(Self::map_sql_err)?
382 .log_statements(LevelFilter::Trace)
383 .log_slow_statements(LevelFilter::Trace, Duration::from_secs(1));
384 let pool_options = PoolOptions::new().idle_timeout(None).max_lifetime(None);
385
386 let pool = pool_options
387 .connect_with(options)
388 .await
389 .map_err(Self::map_sql_err)?;
390 Ok(pool)
391 }
392
393 pub fn path(&self) -> Option<PathBuf> {
395 self.configuration.path()
396 }
397
398 #[track_caller]
400 pub fn map_sql_err(err: sqlx::Error) -> Error {
401 Error::new(Origin::Application, Kind::Io, err)
402 }
403
404 #[track_caller]
406 pub fn map_decode_err(err: minicbor::decode::Error) -> Error {
407 Error::new(Origin::Application, Kind::Io, err)
408 }
409
410 pub async fn drop_all_postgres_tables(&self) -> Result<()> {
412 self.clean_postgres_node_tables(Clean::Drop, None).await
413 }
414
415 pub async fn truncate_all_postgres_tables(&self) -> Result<()> {
417 self.clean_postgres_node_tables(Clean::Truncate, None).await
418 }
419
420 pub async fn drop_postgres_node_tables(&self) -> Result<()> {
422 self.clean_postgres_node_tables(Clean::Drop, Some("AND tablename NOT LIKE '%journey%'"))
423 .await
424 }
425
426 pub async fn truncate_postgres_node_tables(&self) -> Result<()> {
428 self.clean_postgres_node_tables(Clean::Truncate, Some("AND tablename NOT LIKE '%journey%'"))
429 .await
430 }
431
432 async fn clean_postgres_node_tables(&self, clean: Clean, filter: Option<&str>) -> Result<()> {
434 match self.configuration.database_type() {
435 DatabaseType::Sqlite => Ok(()),
436 DatabaseType::Postgres => {
437 sqlx::query(
438 format!(r#"DO $$
439 DECLARE
440 r RECORD;
441 BEGIN
442 FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public' {}) LOOP
443 EXECUTE '{} TABLE ' || quote_ident(r.tablename) || ' CASCADE';
444 END LOOP;
445 END $$;"#, filter.unwrap_or(""), clean.as_str(),
446 ).as_str())
447 .execute(&*self.pool)
448 .await
449 .void()
450 }
451 }
452 }
453}
454
/// Kind of cleanup statement applied to a table.
#[derive(Debug, Clone, Copy)]
enum Clean {
    /// Remove the table (`DROP TABLE`).
    Drop,
    /// Remove all rows of the table (`TRUNCATE TABLE`).
    Truncate,
}

impl Clean {
    /// Returns the SQL keyword for this cleanup kind.
    ///
    /// The keyword is a string literal, so the returned reference is
    /// `'static` and not tied to the borrow of `self`.
    fn as_str(&self) -> &'static str {
        match self {
            Clean::Drop => "DROP",
            Clean::Truncate => "TRUNCATE",
        }
    }
}
468
469pub async fn with_sqlite_dbs<F, Fut>(f: F) -> Result<()>
471where
472 F: Fn(SqlxDatabase) -> Fut + Send + Sync + 'static,
473 Fut: Future<Output = Result<()>> + Send + 'static,
474{
475 let db = SqlxDatabase::in_memory("test").await?;
476 rethrow("SQLite in memory", f(db)).await?;
477
478 let db_file = NamedTempFile::new().unwrap();
479 let db = SqlxDatabase::create_sqlite(db_file.path()).await?;
480 rethrow("SQLite on disk", f(db)).await?;
481 Ok(())
482}
483
484pub async fn with_dbs<F, Fut>(f: F) -> Result<()>
486where
487 F: Fn(SqlxDatabase) -> Fut + Send + Sync + 'static,
488 Fut: Future<Output = Result<()>> + Send + 'static,
489{
490 let db = SqlxDatabase::in_memory("test").await?;
491 rethrow("SQLite in memory", f(db)).await?;
492
493 let db_file = NamedTempFile::new().unwrap();
494 let db = SqlxDatabase::create_sqlite(db_file.path()).await?;
495 rethrow("SQLite on disk", f(db)).await?;
496
497 with_postgres(f).await?;
499 Ok(())
500}
501
502pub async fn with_postgres<F, Fut>(f: F) -> Result<()>
504where
505 F: Fn(SqlxDatabase) -> Fut + Send + Sync + 'static,
506 Fut: Future<Output = Result<()>> + Send + 'static,
507{
508 if let Ok(db) = SqlxDatabase::create_new_postgres().await {
510 db.truncate_all_postgres_tables().await?;
511 rethrow("Postgres local", f(db.clone())).await?;
512 };
513 Ok(())
514}
515
516pub async fn skip_if_postgres<F, Fut, R>(f: F) -> std::result::Result<(), R>
518where
519 F: Fn() -> Fut + Send + Sync + 'static,
520 Fut: Future<Output = std::result::Result<(), R>> + Send + 'static,
521 R: From<Error>,
522{
523 if DatabaseConfiguration::postgres()?.is_none() {
525 f().await?
526 };
527 Ok(())
528}
529
530pub async fn with_application_dbs<F, Fut>(f: F) -> Result<()>
533where
534 F: Fn(SqlxDatabase) -> Fut + Send + Sync + 'static,
535 Fut: Future<Output = Result<()>> + Send + 'static,
536{
537 let db = SqlxDatabase::application_in_memory("test").await?;
538 rethrow("SQLite in memory", f(db)).await?;
539
540 let db_file = NamedTempFile::new().unwrap();
541 let db = SqlxDatabase::create_application_sqlite(db_file.path()).await?;
542 rethrow("SQLite on disk", f(db)).await?;
543
544 if let Ok(db) = SqlxDatabase::create_new_application_postgres().await {
546 rethrow("Postgres local", f(db.clone())).await?;
547 db.drop_all_postgres_tables().await?;
548 }
549 Ok(())
550}
551
552async fn rethrow<Fut>(database_type: &str, f: Fut) -> Result<()>
554where
555 Fut: Future<Output = Result<()>> + Send + 'static,
556{
557 f.await.map_err(|e| {
558 Error::new(
559 Origin::Core,
560 Kind::Invalid,
561 format!("{database_type}: {e:?}"),
562 )
563 })
564}
565
566pub trait FromSqlxError<T> {
568 fn into_core(self) -> Result<T>;
570}
571
572impl<T> FromSqlxError<T> for core::result::Result<T, sqlx::error::Error> {
573 #[track_caller]
574 fn into_core(self) -> Result<T> {
575 match self {
576 Ok(r) => Ok(r),
577 Err(err) => {
578 let err = Error::new(Origin::Api, Kind::Internal, err.to_string());
579 Err(err)
580 }
581 }
582 }
583}
584
585impl<T> FromSqlxError<T> for core::result::Result<T, sqlx::migrate::MigrateError> {
586 #[track_caller]
587 fn into_core(self) -> Result<T> {
588 match self {
589 Ok(r) => Ok(r),
590 Err(err) => Err(Error::new(
591 Origin::Application,
592 Kind::Io,
593 format!("migration error {err}"),
594 )),
595 }
596 }
597}
598
599pub trait ToVoid<T> {
601 fn void(self) -> Result<()>;
603}
604
605impl<T> ToVoid<T> for core::result::Result<T, sqlx::error::Error> {
606 #[track_caller]
607 fn void(self) -> Result<()> {
608 self.map(|_| ()).into_core()
609 }
610}
611
612pub fn create_temp_db_file() -> Result<PathBuf> {
614 let (_, path) = NamedTempFile::new()
615 .map_err(|e| Error::new(Origin::Core, Kind::Io, format!("{e:?}")))?
616 .keep()
617 .map_err(|e| Error::new(Origin::Core, Kind::Io, format!("{e:?}")))?;
618 Ok(path)
619}
620
621#[cfg(test)]
622#[allow(missing_docs)]
623pub mod tests {
624 use super::*;
625 use crate::database::Boolean;
626 use sqlx::any::AnyQueryResult;
627 use sqlx::FromRow;
628
629 #[tokio::test]
632 async fn test_create_sqlite_database() -> Result<()> {
633 let db_file = NamedTempFile::new().unwrap();
634 let db = SqlxDatabase::create_sqlite(db_file.path()).await?;
635
636 let inserted = insert_identity(&db).await.unwrap();
637
638 assert_eq!(inserted.rows_affected(), 1);
639 Ok(())
640 }
641
642 #[tokio::test]
644 async fn test_create_postgres_database() -> Result<()> {
645 if let Some(configuration) = DatabaseConfiguration::postgres()? {
646 let db = SqlxDatabase::create_no_migration(&configuration).await?;
647 db.drop_all_postgres_tables().await?;
648
649 let db = SqlxDatabase::create(&configuration).await?;
650 let inserted = insert_identity(&db).await.unwrap();
651 assert_eq!(inserted.rows_affected(), 1);
652 }
653 Ok(())
654 }
655
656 #[tokio::test]
658 async fn test_query() -> Result<()> {
659 with_dbs(|db| async move {
660 insert_identity(&db).await.unwrap();
661
662 let result: Option<IdentifierRow> =
664 sqlx::query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier = $1")
665 .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
666 .fetch_optional(&*db.pool)
667 .await
668 .unwrap();
669 assert_eq!(
670 result,
671 Some(IdentifierRow {
672 identifier: "Ifa804b7fca12a19eed206ae180b5b576860ae651".into(),
673 name: "identity-1".to_string(),
674 vault_name: "vault-1".to_string(),
675 is_default: Boolean::new(true),
678 })
679 );
680
681 let result: Option<IdentifierRow> =
683 sqlx::query_as("SELECT identifier FROM named_identity WHERE identifier = $1")
684 .bind("x")
685 .fetch_optional(&*db.pool)
686 .await
687 .unwrap();
688 assert_eq!(result, None);
689 Ok(())
690 }).await
691 }
692
693 #[tokio::test]
694 async fn test_create_pool_with_relative_and_absolute_paths() -> Result<()> {
695 install_default_drivers();
696 let relative = Path::new("relative");
697 let connection_string = DatabaseConfiguration::sqlite(relative).connection_string();
698 let options =
699 AnyConnectOptions::from_str(&connection_string).map_err(SqlxDatabase::map_sql_err)?;
700
701 let pool = Pool::<Any>::connect_with(options)
702 .await
703 .map_err(SqlxDatabase::map_sql_err);
704 assert!(pool.is_ok());
705
706 let absolute = std::fs::canonicalize(relative).unwrap();
707 let connection_string = DatabaseConfiguration::sqlite(&absolute).connection_string();
708 let options =
709 AnyConnectOptions::from_str(&connection_string).map_err(SqlxDatabase::map_sql_err)?;
710
711 let pool = Pool::<Any>::connect_with(options)
712 .await
713 .map_err(SqlxDatabase::map_sql_err);
714 assert!(pool.is_ok());
715
716 let _ = std::fs::remove_file(absolute);
717
718 Ok(())
719 }
720
721 async fn insert_identity(db: &SqlxDatabase) -> Result<AnyQueryResult> {
723 sqlx::query("INSERT INTO named_identity (identifier, name, vault_name, is_default) VALUES ($1, $2, $3, $4)")
724 .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
725 .bind("identity-1")
726 .bind("vault-1")
727 .bind(true)
728 .execute(&*db.pool)
729 .await
730 .into_core()
731 }
732
733 #[derive(FromRow, PartialEq, Eq, Debug)]
734 struct IdentifierRow {
735 identifier: String,
736 name: String,
737 vault_name: String,
738 is_default: Boolean,
739 }
740}