1use anyhow::Result;
2use r2d2::Pool;
3use r2d2_sqlite::SqliteConnectionManager;
4use rusqlite::Connection;
5use std::path::Path;
6
7pub type DbPool = Pool<SqliteConnectionManager>;
8
9pub struct Database {
10 pool: DbPool,
11}
12
13impl Clone for Database {
14 fn clone(&self) -> Self {
15 Self {
16 pool: self.pool.clone(),
17 }
18 }
19}
20
21impl Database {
22 pub fn open(path: &str) -> Result<Self> {
23 Self::open_with_pool_size(path, 10)
24 }
25
26 pub fn open_with_pool_size(path: &str, pool_size: u32) -> Result<Self> {
27 let path = Path::new(path);
28 if let Some(parent) = path.parent() {
29 if !parent.as_os_str().is_empty() {
30 std::fs::create_dir_all(parent)?;
31 }
32 }
33
34 let manager = SqliteConnectionManager::file(path);
35 let pool = Pool::builder().max_size(pool_size).build(manager)?;
36
37 let conn = pool.get()?;
38 conn.execute_batch(
47 "PRAGMA journal_mode=WAL;
48 PRAGMA foreign_keys=ON;
49 PRAGMA busy_timeout=5000;
50 PRAGMA journal_size_limit=67108864;
51 PRAGMA synchronous=NORMAL;
52 PRAGMA mmap_size=134217728;
53 PRAGMA cache_size=-65536;",
54 )?;
55
56 Ok(Self { pool })
57 }
58
59 pub fn health_check(&self) -> Result<bool> {
62 let conn = self.get()?;
63 let result: i32 = conn.query_row("SELECT 1", [], |row| row.get(0))?;
64 Ok(result == 1)
65 }
66
67 pub fn open_memory(name: &str) -> Result<Self> {
68 let uri = format!("file:{}?mode=memory&cache=shared", name);
69 let manager = SqliteConnectionManager::file(&uri);
70 let pool = Pool::builder()
71 .max_size(5)
72 .connection_timeout(std::time::Duration::from_secs(5))
73 .build(manager)?;
74
75 let conn = pool.get()?;
76 conn.execute_batch("PRAGMA foreign_keys=ON;")?;
77
78 Ok(Self { pool })
79 }
80
81 pub fn get(&self) -> Result<r2d2::PooledConnection<SqliteConnectionManager>> {
82 Ok(self.pool.get()?)
83 }
84
85 pub fn migrate(&self) -> Result<()> {
86 let conn = self.get()?;
87 run_migrations(&conn)?;
88 Ok(())
89 }
90
91 pub fn get_migration_status(&self) -> Result<Vec<(i32, Option<String>)>> {
94 let conn = self.get()?;
95
96 conn.execute_batch(
98 "CREATE TABLE IF NOT EXISTS schema_migrations (
99 version INTEGER PRIMARY KEY,
100 applied_at TEXT DEFAULT CURRENT_TIMESTAMP
101 );",
102 )?;
103
104 let total_migrations = 10;
105 let mut result = Vec::with_capacity(total_migrations);
106
107 for version in 1..=total_migrations as i32 {
108 let applied_at: Option<String> = conn
109 .query_row(
110 "SELECT applied_at FROM schema_migrations WHERE version = ?1",
111 [version],
112 |row| row.get(0),
113 )
114 .ok();
115
116 result.push((version, applied_at));
117 }
118
119 Ok(result)
120 }
121
122 pub fn rollback_migration(&self, version: i32) -> Result<()> {
125 let conn = self.get()?;
126
127 conn.execute_batch("PRAGMA foreign_keys=OFF;")?;
129
130 let rollback_sql = get_rollback_sql(version)?;
131 conn.execute_batch(rollback_sql)?;
132 conn.execute(
133 "DELETE FROM schema_migrations WHERE version = ?1",
134 [version],
135 )?;
136
137 conn.execute_batch("PRAGMA foreign_keys=ON;")?;
139
140 Ok(())
141 }
142}
143
144fn run_migrations(conn: &Connection) -> Result<()> {
145 conn.execute_batch(
146 r#"
147 CREATE TABLE IF NOT EXISTS schema_migrations (
148 version INTEGER PRIMARY KEY,
149 applied_at TEXT DEFAULT CURRENT_TIMESTAMP
150 );
151 "#,
152 )?;
153
154 let current_version: i32 = conn
155 .query_row(
156 "SELECT COALESCE(MAX(version), 0) FROM schema_migrations",
157 [],
158 |row| row.get(0),
159 )
160 .unwrap_or(0);
161
162 let migrations: Vec<(i32, &str)> = vec![
163 (1, include_str!("migrations/001_initial.sql")),
164 (2, include_str!("migrations/002_fts.sql")),
165 (3, include_str!("migrations/003_media_optimization.sql")),
166 (4, include_str!("migrations/004_scheduled_publishing.sql")),
167 (5, include_str!("migrations/005_analytics.sql")),
168 (6, include_str!("migrations/006_content_versions.sql")),
169 (7, include_str!("migrations/007_audit_log.sql")),
170 (8, include_str!("migrations/008_preview_tokens.sql")),
171 (9, include_str!("migrations/009_content_series.sql")),
172 (10, include_str!("migrations/010_api_and_webhooks.sql")),
173 ];
174
175 for (version, sql) in migrations {
176 if version > current_version {
177 tracing::info!("Running migration {}", version);
178 conn.execute_batch(sql)?;
179 conn.execute(
180 "INSERT INTO schema_migrations (version) VALUES (?)",
181 [version],
182 )?;
183 }
184 }
185
186 Ok(())
187}
188
189fn get_rollback_sql(version: i32) -> Result<&'static str> {
190 match version {
191 1 => Ok(include_str!("migrations/001_rollback.sql")),
192 2 => Ok(include_str!("migrations/002_rollback.sql")),
193 3 => Ok(include_str!("migrations/003_rollback.sql")),
194 4 => Ok(include_str!("migrations/004_rollback.sql")),
195 5 => Ok(include_str!("migrations/005_rollback.sql")),
196 6 => Ok(include_str!("migrations/006_rollback.sql")),
197 7 => Ok(include_str!("migrations/007_rollback.sql")),
198 8 => Ok(include_str!("migrations/008_rollback.sql")),
199 9 => Ok(include_str!("migrations/009_rollback.sql")),
200 10 => Ok(include_str!("migrations/010_rollback.sql")),
201 _ => anyhow::bail!("No rollback SQL for migration version {}", version),
202 }
203}