1use std::path::Path;
37use std::str::FromStr;
38
39use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
40use sqlx::SqlitePool;
41use tracing::{debug, info};
42
43use starpod_core::{Result, StarpodError};
44
45pub struct CoreDb {
62 pool: SqlitePool,
63}
64
65impl CoreDb {
66 pub async fn new(db_dir: &Path) -> Result<Self> {
73 std::fs::create_dir_all(db_dir)?;
74
75 let db_path = db_dir.join("core.db");
76
77 match Self::open_and_migrate(&db_path).await {
79 Ok(pool) => {
80 debug!("core.db ready at {}", db_path.display());
81 Ok(Self { pool })
82 }
83 Err(e) => {
84 let msg = e.to_string();
85 let is_schema_mismatch = msg.contains("previously applied but is missing")
86 || msg.contains("checksum mismatch");
87
88 if !is_schema_mismatch {
89 return Err(e);
90 }
91
92 info!("Migration schema changed — recreating core.db");
93 let db_str = db_path.display().to_string();
95 let _ = std::fs::remove_file(&db_path);
96 let _ = std::fs::remove_file(format!("{db_str}-wal"));
97 let _ = std::fs::remove_file(format!("{db_str}-shm"));
98
99 let pool = Self::open_and_migrate(&db_path).await?;
100 debug!("core.db recreated at {}", db_path.display());
101
102 Ok(Self { pool })
103 }
104 }
105 }
106
107 async fn open_and_migrate(db_path: &Path) -> Result<SqlitePool> {
109 let opts =
110 SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", db_path.display()))
111 .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?
112 .pragma("journal_mode", "WAL")
113 .pragma("foreign_keys", "ON")
114 .pragma("busy_timeout", "5000")
115 .pragma("synchronous", "NORMAL");
116
117 let pool = SqlitePoolOptions::new()
118 .max_connections(2)
119 .connect_with(opts)
120 .await
121 .map_err(|e| StarpodError::Database(format!("Failed to open core db: {}", e)))?;
122
123 sqlx::migrate!("./migrations")
124 .run(&pool)
125 .await
126 .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
127
128 Ok(pool)
129 }
130
131 pub async fn in_memory() -> Result<Self> {
136 let opts = SqliteConnectOptions::from_str("sqlite::memory:")
137 .map_err(|e| StarpodError::Database(format!("Invalid memory DB: {}", e)))?
138 .pragma("foreign_keys", "ON");
139
140 let pool = SqlitePoolOptions::new()
141 .max_connections(1)
142 .connect_with(opts)
143 .await
144 .map_err(|e| StarpodError::Database(format!("Failed to open in-memory db: {}", e)))?;
145
146 sqlx::migrate!("./migrations")
147 .run(&pool)
148 .await
149 .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
150
151 Ok(Self { pool })
152 }
153
154 pub fn pool(&self) -> &SqlitePool {
156 &self.pool
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Execute `sql`, panicking if the statement fails.
    async fn exec(pool: &SqlitePool, sql: &str) {
        sqlx::query(sql).execute(pool).await.unwrap();
    }

    /// Row count of `table`.
    ///
    /// The name is spliced into the SQL text, so only pass literal table names.
    async fn count_rows(pool: &SqlitePool, table: &str) -> i64 {
        let sql = format!("SELECT COUNT(*) FROM {table}");
        let (count,): (i64,) = sqlx::query_as(&sql).fetch_one(pool).await.unwrap();
        count
    }

    /// Assert that `sql` is rejected specifically by a foreign-key check.
    ///
    /// A bare `is_err()` would also pass on an unrelated NOT NULL, CHECK or
    /// syntax error.
    async fn assert_fk_violation(pool: &SqlitePool, sql: &str, context: &str) {
        match sqlx::query(sql).execute(pool).await {
            Ok(_) => panic!("{context}: statement unexpectedly succeeded"),
            Err(err) => assert!(
                err.to_string().contains("FOREIGN KEY constraint failed"),
                "{context}: expected a foreign-key violation, got: {err}"
            ),
        }
    }

    #[tokio::test]
    async fn in_memory_creates_all_tables() {
        let db = CoreDb::in_memory().await.unwrap();
        for table in [
            "users",
            "api_keys",
            "telegram_links",
            "auth_audit_log",
            "session_metadata",
            "session_messages",
            "usage_stats",
            "compaction_log",
            "cron_jobs",
            "cron_runs",
        ] {
            assert_eq!(
                count_rows(db.pool(), table).await,
                0,
                "table `{table}` should exist and be empty"
            );
        }
    }

    #[tokio::test]
    async fn on_disk_creates_core_db() {
        let tmp = tempfile::tempdir().unwrap();
        let db = CoreDb::new(tmp.path()).await.unwrap();

        assert!(tmp.path().join("core.db").exists());
        assert_eq!(count_rows(db.pool(), "users").await, 0);
    }

    #[tokio::test]
    async fn on_disk_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let nested = tmp.path().join("deep").join("nested").join("db");
        let _db = CoreDb::new(&nested).await.unwrap();

        assert!(nested.join("core.db").exists());
    }

    #[tokio::test]
    async fn reopen_is_idempotent() {
        let tmp = tempfile::tempdir().unwrap();

        let db1 = CoreDb::new(tmp.path()).await.unwrap();
        exec(
            db1.pool(),
            "INSERT INTO users (id, email, display_name, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'A', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        db1.pool().close().await;
        drop(db1);

        let db2 = CoreDb::new(tmp.path()).await.unwrap();
        let (email,): (String,) = sqlx::query_as("SELECT email FROM users WHERE id = 'u1'")
            .fetch_one(db2.pool())
            .await
            .unwrap();
        assert_eq!(email, "a@b.com");
    }

    #[tokio::test]
    async fn removed_migration_triggers_recreate() {
        let tmp = tempfile::tempdir().unwrap();

        let db1 = CoreDb::new(tmp.path()).await.unwrap();
        exec(
            db1.pool(),
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        // Record an applied migration version that this binary does not ship, so
        // the next open sees a history/code mismatch.
        exec(
            db1.pool(),
            "INSERT INTO _sqlx_migrations (version, description, success, checksum, execution_time) \
             VALUES (99999999999999, 'removed migration', 1, X'00', 0)",
        )
        .await;
        db1.pool().close().await;
        drop(db1);

        let db2 = CoreDb::new(tmp.path()).await.unwrap();
        assert_eq!(
            count_rows(db2.pool(), "users").await,
            0,
            "core.db should have been recreated empty"
        );
    }

    #[tokio::test]
    async fn fk_rejects_invalid_api_key_user() {
        let db = CoreDb::in_memory().await.unwrap();
        assert_fk_violation(
            db.pool(),
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'nonexistent', 'sp_', 'hash', '2024-01-01')",
            "FK should reject api_key with invalid user_id",
        )
        .await;
    }

    #[tokio::test]
    async fn fk_rejects_invalid_telegram_link_user() {
        let db = CoreDb::in_memory().await.unwrap();
        assert_fk_violation(
            db.pool(),
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (123, 'nonexistent', 'bob', '2024-01-01')",
            "FK should reject telegram_link with invalid user_id",
        )
        .await;
    }

    #[tokio::test]
    async fn fk_rejects_invalid_session_message() {
        let db = CoreDb::in_memory().await.unwrap();
        assert_fk_violation(
            db.pool(),
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('nonexistent', 'user', 'hello', '2024-01-01')",
            "FK should reject message with invalid session_id",
        )
        .await;
    }

    #[tokio::test]
    async fn fk_rejects_invalid_cron_run_job() {
        let db = CoreDb::in_memory().await.unwrap();
        // Use the same status value that cascade_delete_cron_job_removes_runs inserts
        // successfully, so the job_id foreign key is the only thing that can reject
        // this row.
        assert_fk_violation(
            db.pool(),
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'nonexistent', 1000, 'success')",
            "FK should reject cron_run with invalid job_id",
        )
        .await;
    }

    #[tokio::test]
    async fn cascade_delete_user_removes_api_keys() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        exec(
            pool,
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'u1', 'sp_', 'hash1', '2024-01-01')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k2', 'u1', 'sp_', 'hash2', '2024-01-01')",
        )
        .await;

        exec(pool, "DELETE FROM users WHERE id = 'u1'").await;

        assert_eq!(count_rows(pool, "api_keys").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_user_removes_telegram_links() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        exec(
            pool,
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (999, 'u1', 'bob', '2024-01-01')",
        )
        .await;

        exec(pool, "DELETE FROM users WHERE id = 'u1'").await;

        assert_eq!(count_rows(pool, "telegram_links").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_session_removes_messages_and_compaction() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        exec(
            pool,
            "INSERT INTO session_metadata (id, created_at, last_message_at) \
             VALUES ('s1', '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('s1', 'user', 'hi', '2024-01-01')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO compaction_log (session_id, timestamp, trigger, pre_tokens, summary) \
             VALUES ('s1', '2024-01-01', 'auto', 1000, 'summary')",
        )
        .await;

        exec(pool, "DELETE FROM session_metadata WHERE id = 's1'").await;

        assert_eq!(count_rows(pool, "session_messages").await, 0);
        assert_eq!(count_rows(pool, "compaction_log").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_cron_job_removes_runs() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        exec(
            pool,
            "INSERT INTO cron_jobs (id, name, prompt, schedule_type, schedule_value, created_at) \
             VALUES ('j1', 'test', 'do stuff', 'interval', '60000', 1000)",
        )
        .await;
        exec(
            pool,
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'j1', 2000, 'success')",
        )
        .await;

        exec(pool, "DELETE FROM cron_jobs WHERE id = 'j1'").await;

        assert_eq!(count_rows(pool, "cron_runs").await, 0);
    }

    #[tokio::test]
    async fn cross_domain_join_sessions_with_usage_by_user() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        exec(
            pool,
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'alice@test.com', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO session_metadata (id, created_at, last_message_at, user_id) \
             VALUES ('s1', '2024-01-01', '2024-01-01', 'u1')",
        )
        .await;
        exec(
            pool,
            "INSERT INTO usage_stats (session_id, turn, input_tokens, output_tokens, cost_usd, timestamp, user_id) \
             VALUES ('s1', 1, 100, 200, 0.01, '2024-01-01', 'u1')",
        )
        .await;

        let (email, total_cost): (String, f64) = sqlx::query_as(
            "SELECT u.email, SUM(us.cost_usd) as total_cost \
             FROM users u \
             JOIN usage_stats us ON us.user_id = u.id \
             GROUP BY u.id",
        )
        .fetch_one(pool)
        .await
        .unwrap();

        assert_eq!(email, "alice@test.com");
        assert!((total_cost - 0.01).abs() < 0.001);
    }

    #[tokio::test]
    async fn pool_clone_shares_state() {
        let db = CoreDb::in_memory().await.unwrap();

        exec(
            db.pool(),
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;

        let pool2 = db.pool().clone();
        assert_eq!(count_rows(&pool2, "users").await, 1);
    }
}