1use anyhow::Result;
2use r2d2::Pool;
3use r2d2_sqlite::SqliteConnectionManager;
4use rusqlite::Connection;
5use std::path::Path;
6
7pub type DbPool = Pool<SqliteConnectionManager>;
8
9pub const MIGRATION_COUNT: i32 = 10;
10
11pub struct Database {
12 pool: DbPool,
13}
14
15impl Clone for Database {
16 fn clone(&self) -> Self {
17 Self {
18 pool: self.pool.clone(),
19 }
20 }
21}
22
23impl Database {
24 pub fn open(path: &str) -> Result<Self> {
25 Self::open_with_pool_size(path, 10)
26 }
27
28 pub fn open_with_pool_size(path: &str, pool_size: u32) -> Result<Self> {
29 let path = Path::new(path);
30 if let Some(parent) = path.parent() {
31 if !parent.as_os_str().is_empty() {
32 std::fs::create_dir_all(parent)?;
33 }
34 }
35
36 let manager = SqliteConnectionManager::file(path);
37 let pool = Pool::builder().max_size(pool_size).build(manager)?;
38
39 let conn = pool.get()?;
40 conn.execute_batch(
49 "PRAGMA journal_mode=WAL;
50 PRAGMA foreign_keys=ON;
51 PRAGMA busy_timeout=5000;
52 PRAGMA journal_size_limit=67108864;
53 PRAGMA synchronous=NORMAL;
54 PRAGMA mmap_size=134217728;
55 PRAGMA cache_size=-65536;",
56 )?;
57
58 Ok(Self { pool })
59 }
60
61 pub fn health_check(&self) -> Result<bool> {
64 let conn = self.get()?;
65 let result: i32 = conn.query_row("SELECT 1", [], |row| row.get(0))?;
66 Ok(result == 1)
67 }
68
69 pub fn open_memory(name: &str) -> Result<Self> {
70 let uri = format!("file:{}?mode=memory&cache=shared", name);
71 let manager = SqliteConnectionManager::file(&uri);
72 let pool = Pool::builder()
73 .max_size(5)
74 .connection_timeout(std::time::Duration::from_secs(5))
75 .build(manager)?;
76
77 let conn = pool.get()?;
78 conn.execute_batch("PRAGMA foreign_keys=ON;")?;
79
80 Ok(Self { pool })
81 }
82
83 pub fn get(&self) -> Result<r2d2::PooledConnection<SqliteConnectionManager>> {
84 Ok(self.pool.get()?)
85 }
86
87 pub fn migrate(&self) -> Result<()> {
88 let conn = self.get()?;
89 run_migrations(&conn)?;
90 Ok(())
91 }
92
93 pub fn get_migration_status(&self) -> Result<Vec<(i32, Option<String>)>> {
96 let conn = self.get()?;
97
98 conn.execute_batch(
100 "CREATE TABLE IF NOT EXISTS schema_migrations (
101 version INTEGER PRIMARY KEY,
102 applied_at TEXT DEFAULT CURRENT_TIMESTAMP
103 );",
104 )?;
105
106 let total_migrations = 10;
107 let mut result = Vec::with_capacity(total_migrations);
108
109 for version in 1..=total_migrations as i32 {
110 let applied_at: Option<String> = conn
111 .query_row(
112 "SELECT applied_at FROM schema_migrations WHERE version = ?1",
113 [version],
114 |row| row.get(0),
115 )
116 .ok();
117
118 result.push((version, applied_at));
119 }
120
121 Ok(result)
122 }
123
124 pub fn rollback_migration(&self, version: i32) -> Result<()> {
127 let conn = self.get()?;
128
129 conn.execute_batch("PRAGMA foreign_keys=OFF;")?;
131
132 let rollback_sql = get_rollback_sql(version)?;
133 conn.execute_batch(rollback_sql)?;
134 conn.execute(
135 "DELETE FROM schema_migrations WHERE version = ?1",
136 [version],
137 )?;
138
139 conn.execute_batch("PRAGMA foreign_keys=ON;")?;
141
142 Ok(())
143 }
144}
145
146fn run_migrations(conn: &Connection) -> Result<()> {
147 conn.execute_batch(
148 r#"
149 CREATE TABLE IF NOT EXISTS schema_migrations (
150 version INTEGER PRIMARY KEY,
151 applied_at TEXT DEFAULT CURRENT_TIMESTAMP
152 );
153 "#,
154 )?;
155
156 let current_version: i32 = conn
157 .query_row(
158 "SELECT COALESCE(MAX(version), 0) FROM schema_migrations",
159 [],
160 |row| row.get(0),
161 )
162 .unwrap_or(0);
163
164 let migrations: Vec<(i32, &str)> = vec![
165 (1, include_str!("migrations/001_initial.sql")),
166 (2, include_str!("migrations/002_fts.sql")),
167 (3, include_str!("migrations/003_media_optimization.sql")),
168 (4, include_str!("migrations/004_scheduled_publishing.sql")),
169 (5, include_str!("migrations/005_analytics.sql")),
170 (6, include_str!("migrations/006_content_versions.sql")),
171 (7, include_str!("migrations/007_audit_log.sql")),
172 (8, include_str!("migrations/008_preview_tokens.sql")),
173 (9, include_str!("migrations/009_content_series.sql")),
174 (10, include_str!("migrations/010_api_and_webhooks.sql")),
175 ];
176
177 for (version, sql) in migrations {
178 if version > current_version {
179 tracing::info!("Running migration {}", version);
180 conn.execute_batch(sql)?;
181 conn.execute(
182 "INSERT INTO schema_migrations (version) VALUES (?)",
183 [version],
184 )?;
185 }
186 }
187
188 Ok(())
189}
190
191fn get_rollback_sql(version: i32) -> Result<&'static str> {
192 match version {
193 1 => Ok(include_str!("migrations/001_rollback.sql")),
194 2 => Ok(include_str!("migrations/002_rollback.sql")),
195 3 => Ok(include_str!("migrations/003_rollback.sql")),
196 4 => Ok(include_str!("migrations/004_rollback.sql")),
197 5 => Ok(include_str!("migrations/005_rollback.sql")),
198 6 => Ok(include_str!("migrations/006_rollback.sql")),
199 7 => Ok(include_str!("migrations/007_rollback.sql")),
200 8 => Ok(include_str!("migrations/008_rollback.sql")),
201 9 => Ok(include_str!("migrations/009_rollback.sql")),
202 10 => Ok(include_str!("migrations/010_rollback.sql")),
203 _ => anyhow::bail!("No rollback SQL for migration version {}", version),
204 }
205}