forge_core_db/
lib.rs

1use std::{path::PathBuf, str::FromStr, sync::Arc};
2
3use forge_core_utils::assets::asset_dir;
4use sqlx::{
5    Error, Pool, Sqlite, SqlitePool,
6    sqlite::{SqliteConnectOptions, SqliteConnection, SqlitePoolOptions},
7};
8
9pub mod models;
10
11#[derive(Clone)]
12pub struct DBService {
13    pub pool: Pool<Sqlite>,
14}
15
16impl DBService {
17    /// Get the database URL from environment variable or default to asset_dir
18    fn get_database_url() -> String {
19        if let Ok(db_url) = std::env::var("DATABASE_URL") {
20            // If DATABASE_URL is set, use it
21            // Handle both absolute paths and relative paths
22            if db_url.starts_with("sqlite://") {
23                let path_part = db_url.strip_prefix("sqlite://").unwrap();
24                if PathBuf::from(path_part).is_absolute() {
25                    db_url
26                } else {
27                    // Relative path - resolve from current working directory
28                    let abs_path = std::env::current_dir()
29                        .unwrap_or_else(|_| PathBuf::from("."))
30                        .join(path_part);
31                    Self::format_sqlite_url(&abs_path)
32                }
33            } else {
34                db_url
35            }
36        } else {
37            // Default to asset_dir/db.sqlite
38            let db_path = asset_dir().join("db.sqlite");
39            Self::format_sqlite_url(&db_path)
40        }
41    }
42
43    /// Format a path as a proper SQLite URL
44    /// SQLite URL format: sqlite:// + path
45    /// For absolute paths on Unix (starting with /), this results in sqlite:///path (3 slashes)
46    /// For Windows paths, this results in sqlite://C:/path
47    fn format_sqlite_url(path: &PathBuf) -> String {
48        // Ensure the path is absolute
49        let abs_path = if path.is_absolute() {
50            path.clone()
51        } else {
52            std::env::current_dir()
53                .unwrap_or_else(|_| PathBuf::from("."))
54                .join(path)
55        };
56
57        let abs_path_str = abs_path.to_string_lossy();
58
59        // SQLite URL format: sqlite:// followed by the path
60        // For Unix absolute paths (/home/...), this becomes sqlite:///home/...
61        // The third slash is the root directory indicator
62        if abs_path_str.starts_with('/') {
63            // Unix absolute path - sqlite:// + /path = sqlite:///path
64            format!("sqlite://{abs_path_str}")
65        } else if abs_path_str.len() >= 2 && abs_path_str.chars().nth(1) == Some(':') {
66            // Windows absolute path (C:\...) - needs special handling
67            format!("sqlite:///{abs_path_str}")
68        } else {
69            // Fallback - treat as relative (shouldn't happen after is_absolute check)
70            format!("sqlite://{abs_path_str}")
71        }
72    }
73
74    pub async fn new() -> Result<DBService, Error> {
75        let database_url = Self::get_database_url();
76        let options = SqliteConnectOptions::from_str(&database_url)?.create_if_missing(true);
77        let pool = SqlitePool::connect_with(options).await?;
78        sqlx::migrate!("./migrations").run(&pool).await?;
79        Ok(DBService { pool })
80    }
81
82    pub async fn new_with_after_connect<F>(after_connect: F) -> Result<DBService, Error>
83    where
84        F: for<'a> Fn(
85                &'a mut SqliteConnection,
86            ) -> std::pin::Pin<
87                Box<dyn std::future::Future<Output = Result<(), Error>> + Send + 'a>,
88            > + Send
89            + Sync
90            + 'static,
91    {
92        let pool = Self::create_pool(Some(Arc::new(after_connect))).await?;
93        Ok(DBService { pool })
94    }
95
96    async fn create_pool<F>(after_connect: Option<Arc<F>>) -> Result<Pool<Sqlite>, Error>
97    where
98        F: for<'a> Fn(
99                &'a mut SqliteConnection,
100            ) -> std::pin::Pin<
101                Box<dyn std::future::Future<Output = Result<(), Error>> + Send + 'a>,
102            > + Send
103            + Sync
104            + 'static,
105    {
106        let database_url = Self::get_database_url();
107        let options = SqliteConnectOptions::from_str(&database_url)?.create_if_missing(true);
108
109        let pool = if let Some(hook) = after_connect {
110            SqlitePoolOptions::new()
111                .after_connect(move |conn, _meta| {
112                    let hook = hook.clone();
113                    Box::pin(async move {
114                        hook(conn).await?;
115                        Ok(())
116                    })
117                })
118                .connect_with(options)
119                .await?
120        } else {
121            SqlitePool::connect_with(options).await?
122        };
123
124        sqlx::migrate!("./migrations").run(&pool).await?;
125        Ok(pool)
126    }
127}