1use std::{path::PathBuf, str::FromStr, sync::Arc};
2
3use forge_core_utils::assets::asset_dir;
4use sqlx::{
5 Error, Pool, Sqlite, SqlitePool,
6 sqlite::{SqliteConnectOptions, SqliteConnection, SqlitePoolOptions},
7};
8
9pub mod models;
10
11#[derive(Clone)]
12pub struct DBService {
13 pub pool: Pool<Sqlite>,
14}
15
16impl DBService {
17 fn get_database_url() -> String {
19 if let Ok(db_url) = std::env::var("DATABASE_URL") {
20 if db_url.starts_with("sqlite://") {
23 let path_part = db_url.strip_prefix("sqlite://").unwrap();
24 if PathBuf::from(path_part).is_absolute() {
25 db_url
26 } else {
27 let abs_path = std::env::current_dir()
29 .unwrap_or_else(|_| PathBuf::from("."))
30 .join(path_part);
31 Self::format_sqlite_url(&abs_path)
32 }
33 } else {
34 db_url
35 }
36 } else {
37 let db_path = asset_dir().join("db.sqlite");
39 Self::format_sqlite_url(&db_path)
40 }
41 }
42
43 fn format_sqlite_url(path: &PathBuf) -> String {
48 let abs_path = if path.is_absolute() {
50 path.clone()
51 } else {
52 std::env::current_dir()
53 .unwrap_or_else(|_| PathBuf::from("."))
54 .join(path)
55 };
56
57 let abs_path_str = abs_path.to_string_lossy();
58
59 if abs_path_str.starts_with('/') {
63 format!("sqlite://{abs_path_str}")
65 } else if abs_path_str.len() >= 2 && abs_path_str.chars().nth(1) == Some(':') {
66 format!("sqlite:///{abs_path_str}")
68 } else {
69 format!("sqlite://{abs_path_str}")
71 }
72 }
73
74 pub async fn new() -> Result<DBService, Error> {
75 let database_url = Self::get_database_url();
76 let options = SqliteConnectOptions::from_str(&database_url)?.create_if_missing(true);
77 let pool = SqlitePool::connect_with(options).await?;
78 sqlx::migrate!("./migrations").run(&pool).await?;
79 Ok(DBService { pool })
80 }
81
82 pub async fn new_with_after_connect<F>(after_connect: F) -> Result<DBService, Error>
83 where
84 F: for<'a> Fn(
85 &'a mut SqliteConnection,
86 ) -> std::pin::Pin<
87 Box<dyn std::future::Future<Output = Result<(), Error>> + Send + 'a>,
88 > + Send
89 + Sync
90 + 'static,
91 {
92 let pool = Self::create_pool(Some(Arc::new(after_connect))).await?;
93 Ok(DBService { pool })
94 }
95
96 async fn create_pool<F>(after_connect: Option<Arc<F>>) -> Result<Pool<Sqlite>, Error>
97 where
98 F: for<'a> Fn(
99 &'a mut SqliteConnection,
100 ) -> std::pin::Pin<
101 Box<dyn std::future::Future<Output = Result<(), Error>> + Send + 'a>,
102 > + Send
103 + Sync
104 + 'static,
105 {
106 let database_url = Self::get_database_url();
107 let options = SqliteConnectOptions::from_str(&database_url)?.create_if_missing(true);
108
109 let pool = if let Some(hook) = after_connect {
110 SqlitePoolOptions::new()
111 .after_connect(move |conn, _meta| {
112 let hook = hook.clone();
113 Box::pin(async move {
114 hook(conn).await?;
115 Ok(())
116 })
117 })
118 .connect_with(options)
119 .await?
120 } else {
121 SqlitePool::connect_with(options).await?
122 };
123
124 sqlx::migrate!("./migrations").run(&pool).await?;
125 Ok(pool)
126 }
127}