1use rusqlite::Connection;
2
3use crate::error::Result;
4
5mod checkpoint;
6mod cursor;
7mod file_log;
8mod journal_store;
9mod metrics;
10mod progression;
11mod run_aggregate;
12mod schema;
13mod shape;
14
15#[allow(unused_imports)]
19pub use checkpoint::ChunkTaskInfo;
20#[allow(unused_imports)]
21pub use file_log::FileRecord;
22#[allow(unused_imports)]
23pub use metrics::ExportMetric;
24pub use metrics::MetricRow;
25#[allow(unused_imports)]
26pub use progression::{Boundary, ExportProgression};
27#[allow(unused_imports)]
28pub use run_aggregate::{RunAggregate, RunAggregateEntry};
29#[allow(unused_imports)]
30pub use schema::{SchemaChange, SchemaColumn, arrow_schema_to_columns, schema_fingerprint};
31#[allow(unused_imports)]
32pub use shape::ShapeWarning;
33
34const STATE_DB_NAME: &str = ".rivet_state.db";
35
36const SCHEMA_VERSION: i64 = MIGRATIONS[MIGRATIONS.len() - 1].0;
38
/// Ordered SQLite schema migrations as `(version, sql)` pairs.
///
/// `migrate` runs, in slice order, every entry whose version is greater than
/// the highest version recorded in `schema_version`. Each SQL batch runs in one
/// transaction together with the insertion of its version row. Versions that
/// are already recorded are never re-run, so editing an existing entry has no
/// effect on databases that already applied it: add a new version instead.
/// The last entry's version is `SCHEMA_VERSION`. `PG_MIGRATIONS` must carry the
/// same version numbers.
const MIGRATIONS: &[(i64, &str)] = &[
 // v1: baseline tables — per-export cursor state, run metrics, last-seen
 // schema, and the manifest of produced files. Boolean-ish metric columns
 // (`validated`, `schema_changed`) are stored as INTEGER in SQLite.
 (
 1,
 "CREATE TABLE IF NOT EXISTS export_state (
 export_name TEXT PRIMARY KEY,
 last_cursor_value TEXT,
 last_run_at TEXT
 );
 CREATE TABLE IF NOT EXISTS export_metrics (
 id INTEGER PRIMARY KEY AUTOINCREMENT,
 export_name TEXT NOT NULL,
 run_at TEXT NOT NULL,
 duration_ms INTEGER NOT NULL,
 total_rows INTEGER NOT NULL,
 peak_rss_mb INTEGER,
 status TEXT NOT NULL,
 error_message TEXT,
 tuning_profile TEXT,
 format TEXT,
 mode TEXT,
 files_produced INTEGER DEFAULT 0,
 bytes_written INTEGER DEFAULT 0,
 retries INTEGER DEFAULT 0,
 validated INTEGER,
 schema_changed INTEGER,
 run_id TEXT
 );
 CREATE TABLE IF NOT EXISTS export_schema (
 export_name TEXT PRIMARY KEY,
 columns_json TEXT NOT NULL,
 updated_at TEXT NOT NULL
 );
 CREATE TABLE IF NOT EXISTS file_manifest (
 id INTEGER PRIMARY KEY AUTOINCREMENT,
 run_id TEXT NOT NULL,
 export_name TEXT NOT NULL,
 file_name TEXT NOT NULL,
 row_count INTEGER NOT NULL,
 bytes INTEGER NOT NULL,
 format TEXT NOT NULL,
 compression TEXT,
 created_at TEXT NOT NULL
 );",
 ),
 // v2: chunked-run checkpointing — one `chunk_run` row per run and one
 // `chunk_task` row per (run_id, chunk_index), with status indexes.
 (
 2,
 "CREATE TABLE IF NOT EXISTS chunk_run (
 run_id TEXT PRIMARY KEY,
 export_name TEXT NOT NULL,
 plan_hash TEXT NOT NULL,
 status TEXT NOT NULL,
 max_chunk_attempts INTEGER NOT NULL DEFAULT 3,
 created_at TEXT NOT NULL,
 updated_at TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
 ON chunk_run(export_name, status);
 CREATE TABLE IF NOT EXISTS chunk_task (
 id INTEGER PRIMARY KEY AUTOINCREMENT,
 run_id TEXT NOT NULL,
 chunk_index INTEGER NOT NULL,
 start_key TEXT NOT NULL,
 end_key TEXT NOT NULL,
 status TEXT NOT NULL,
 attempts INTEGER NOT NULL DEFAULT 0,
 last_error TEXT,
 rows_written INTEGER,
 file_name TEXT,
 updated_at TEXT NOT NULL,
 UNIQUE(run_id, chunk_index)
 );
 CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
 ),
 // v3: index for newest-first manifest lookups per export (renamed in v8).
 (
 3,
 "CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
 ),
 // v4: per-export progression — last committed and last verified position.
 (
 4,
 "CREATE TABLE IF NOT EXISTS export_progression (
 export_name TEXT PRIMARY KEY,
 last_committed_strategy TEXT,
 last_committed_cursor TEXT,
 last_committed_chunk_index INTEGER,
 last_committed_run_id TEXT,
 last_committed_at TEXT,
 last_verified_strategy TEXT,
 last_verified_cursor TEXT,
 last_verified_chunk_index INTEGER,
 last_verified_run_id TEXT,
 last_verified_at TEXT
 );",
 ),
 // v5: one summary row per multi-export invocation.
 (
 5,
 "CREATE TABLE IF NOT EXISTS run_aggregate (
 run_aggregate_id TEXT PRIMARY KEY,
 started_at TEXT NOT NULL,
 finished_at TEXT NOT NULL,
 duration_ms INTEGER NOT NULL,
 config_path TEXT,
 parallel_mode TEXT NOT NULL,
 total_exports INTEGER NOT NULL,
 success_count INTEGER NOT NULL,
 failed_count INTEGER NOT NULL,
 skipped_count INTEGER NOT NULL,
 total_rows INTEGER NOT NULL,
 total_files INTEGER NOT NULL,
 total_bytes INTEGER NOT NULL,
 details_json TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
 ON run_aggregate(finished_at DESC);",
 ),
 // v6: per-(export, column) maximum observed byte length.
 (
 6,
 "CREATE TABLE IF NOT EXISTS export_shape (
 export_name TEXT NOT NULL,
 column_name TEXT NOT NULL,
 max_byte_len INTEGER NOT NULL,
 updated_at TEXT NOT NULL,
 PRIMARY KEY (export_name, column_name)
 );",
 ),
 // v7: per-run journal stored as JSON, indexed newest-first per export.
 (
 7,
 "CREATE TABLE IF NOT EXISTS run_journal (
 run_id TEXT PRIMARY KEY,
 export_name TEXT NOT NULL,
 finished_at TEXT NOT NULL,
 journal_json TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_run_journal_export
 ON run_journal(export_name, finished_at DESC);",
 ),
 // v8: rename `file_manifest` to `file_log` (rows are kept) and replace the
 // v3 index with one carrying the new name.
 (
 8,
 "ALTER TABLE file_manifest RENAME TO file_log;
 DROP INDEX IF EXISTS idx_file_manifest_export;
 CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
 ),
 // v9: additional nullable `export_metrics` columns. SQLite's
 // `ADD COLUMN` fails if the column already exists, so this batch relies on
 // running exactly once.
 (
 9,
 "ALTER TABLE export_metrics ADD COLUMN files_committed INTEGER;
 ALTER TABLE export_metrics ADD COLUMN reconciled INTEGER;
 ALTER TABLE export_metrics ADD COLUMN source_count INTEGER;
 ALTER TABLE export_metrics ADD COLUMN quality_passed INTEGER;
 ALTER TABLE export_metrics ADD COLUMN pg_temp_bytes_delta INTEGER;
 ALTER TABLE export_metrics ADD COLUMN batch_size INTEGER;
 ALTER TABLE export_metrics ADD COLUMN batch_size_memory_mb INTEGER;
 ALTER TABLE export_metrics ADD COLUMN skip_reason TEXT;
 ALTER TABLE export_metrics ADD COLUMN schema_fingerprint TEXT;
 ALTER TABLE export_metrics ADD COLUMN chunk_size INTEGER;
 ALTER TABLE export_metrics ADD COLUMN parallel INTEGER;
 ALTER TABLE export_metrics ADD COLUMN source_type TEXT;
 ALTER TABLE export_metrics ADD COLUMN destination_type TEXT;
 ALTER TABLE export_metrics ADD COLUMN rivet_version TEXT;",
 ),
 // v10: duration of the slowest chunk in a run.
 (
 10,
 "ALTER TABLE export_metrics ADD COLUMN longest_chunk_ms INTEGER;",
 ),
 // v11: per-run metric deltas per export, indexed by run_id.
 (
 11,
 "CREATE TABLE IF NOT EXISTS export_harm (
 id INTEGER PRIMARY KEY AUTOINCREMENT,
 run_id TEXT NOT NULL,
 export_name TEXT NOT NULL,
 metric TEXT NOT NULL,
 delta INTEGER NOT NULL,
 recorded_at TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_export_harm_run ON export_harm(run_id);",
 ),
];
237
/// PostgreSQL counterpart of `MIGRATIONS`, applied by `migrate_pg`.
///
/// Version numbers must match `MIGRATIONS` one-for-one: `migrate_pg` verifies
/// that it ends at `SCHEMA_VERSION`, which is derived from the SQLite table.
/// The schema is the same as the SQLite one, translated to PostgreSQL types:
/// `BIGSERIAL` keys, `BIGINT` integers, and `BOOLEAN` where SQLite stores
/// flags as `INTEGER`. As with SQLite, applied versions are never re-run, so
/// changes must be made by adding a new version.
const PG_MIGRATIONS: &[(i64, &str)] = &[
 // v1: baseline tables (see `MIGRATIONS` v1).
 (
 1,
 "CREATE TABLE IF NOT EXISTS export_state (
 export_name TEXT PRIMARY KEY,
 last_cursor_value TEXT,
 last_run_at TEXT
 );
 CREATE TABLE IF NOT EXISTS export_metrics (
 id BIGSERIAL PRIMARY KEY,
 export_name TEXT NOT NULL,
 run_at TEXT NOT NULL,
 duration_ms BIGINT NOT NULL,
 total_rows BIGINT NOT NULL,
 peak_rss_mb BIGINT,
 status TEXT NOT NULL,
 error_message TEXT,
 tuning_profile TEXT,
 format TEXT,
 mode TEXT,
 files_produced BIGINT DEFAULT 0,
 bytes_written BIGINT DEFAULT 0,
 retries BIGINT DEFAULT 0,
 validated BOOLEAN,
 schema_changed BOOLEAN,
 run_id TEXT
 );
 CREATE TABLE IF NOT EXISTS export_schema (
 export_name TEXT PRIMARY KEY,
 columns_json TEXT NOT NULL,
 updated_at TEXT NOT NULL
 );
 CREATE TABLE IF NOT EXISTS file_manifest (
 id BIGSERIAL PRIMARY KEY,
 run_id TEXT NOT NULL,
 export_name TEXT NOT NULL,
 file_name TEXT NOT NULL,
 row_count BIGINT NOT NULL,
 bytes BIGINT NOT NULL,
 format TEXT NOT NULL,
 compression TEXT,
 created_at TEXT NOT NULL
 );",
 ),
 // v2: chunk_run / chunk_task checkpoint tables.
 (
 2,
 "CREATE TABLE IF NOT EXISTS chunk_run (
 run_id TEXT PRIMARY KEY,
 export_name TEXT NOT NULL,
 plan_hash TEXT NOT NULL,
 status TEXT NOT NULL,
 max_chunk_attempts BIGINT NOT NULL DEFAULT 3,
 created_at TEXT NOT NULL,
 updated_at TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
 ON chunk_run(export_name, status);
 CREATE TABLE IF NOT EXISTS chunk_task (
 id BIGSERIAL PRIMARY KEY,
 run_id TEXT NOT NULL,
 chunk_index BIGINT NOT NULL,
 start_key TEXT NOT NULL,
 end_key TEXT NOT NULL,
 status TEXT NOT NULL,
 attempts BIGINT NOT NULL DEFAULT 0,
 last_error TEXT,
 rows_written BIGINT,
 file_name TEXT,
 updated_at TEXT NOT NULL,
 UNIQUE(run_id, chunk_index)
 );
 CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
 ),
 // v3: manifest index (renamed in v8).
 (
 3,
 "CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
 ),
 // v4: export_progression.
 (
 4,
 "CREATE TABLE IF NOT EXISTS export_progression (
 export_name TEXT PRIMARY KEY,
 last_committed_strategy TEXT,
 last_committed_cursor TEXT,
 last_committed_chunk_index BIGINT,
 last_committed_run_id TEXT,
 last_committed_at TEXT,
 last_verified_strategy TEXT,
 last_verified_cursor TEXT,
 last_verified_chunk_index BIGINT,
 last_verified_run_id TEXT,
 last_verified_at TEXT
 );",
 ),
 // v5: run_aggregate.
 (
 5,
 "CREATE TABLE IF NOT EXISTS run_aggregate (
 run_aggregate_id TEXT PRIMARY KEY,
 started_at TEXT NOT NULL,
 finished_at TEXT NOT NULL,
 duration_ms BIGINT NOT NULL,
 config_path TEXT,
 parallel_mode TEXT NOT NULL,
 total_exports BIGINT NOT NULL,
 success_count BIGINT NOT NULL,
 failed_count BIGINT NOT NULL,
 skipped_count BIGINT NOT NULL,
 total_rows BIGINT NOT NULL,
 total_files BIGINT NOT NULL,
 total_bytes BIGINT NOT NULL,
 details_json TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
 ON run_aggregate(finished_at DESC);",
 ),
 // v6: export_shape.
 (
 6,
 "CREATE TABLE IF NOT EXISTS export_shape (
 export_name TEXT NOT NULL,
 column_name TEXT NOT NULL,
 max_byte_len BIGINT NOT NULL,
 updated_at TEXT NOT NULL,
 PRIMARY KEY (export_name, column_name)
 );",
 ),
 // v7: run_journal.
 (
 7,
 "CREATE TABLE IF NOT EXISTS run_journal (
 run_id TEXT PRIMARY KEY,
 export_name TEXT NOT NULL,
 finished_at TEXT NOT NULL,
 journal_json TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_run_journal_export
 ON run_journal(export_name, finished_at DESC);",
 ),
 // v8: rename file_manifest -> file_log and replace the v3 index.
 (
 8,
 "ALTER TABLE file_manifest RENAME TO file_log;
 DROP INDEX IF EXISTS idx_file_manifest_export;
 CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
 ),
 // v9: additional export_metrics columns (BOOLEAN where SQLite uses INTEGER).
 (
 9,
 "ALTER TABLE export_metrics ADD COLUMN files_committed BIGINT;
 ALTER TABLE export_metrics ADD COLUMN reconciled BOOLEAN;
 ALTER TABLE export_metrics ADD COLUMN source_count BIGINT;
 ALTER TABLE export_metrics ADD COLUMN quality_passed BOOLEAN;
 ALTER TABLE export_metrics ADD COLUMN pg_temp_bytes_delta BIGINT;
 ALTER TABLE export_metrics ADD COLUMN batch_size BIGINT;
 ALTER TABLE export_metrics ADD COLUMN batch_size_memory_mb BIGINT;
 ALTER TABLE export_metrics ADD COLUMN skip_reason TEXT;
 ALTER TABLE export_metrics ADD COLUMN schema_fingerprint TEXT;
 ALTER TABLE export_metrics ADD COLUMN chunk_size BIGINT;
 ALTER TABLE export_metrics ADD COLUMN parallel BIGINT;
 ALTER TABLE export_metrics ADD COLUMN source_type TEXT;
 ALTER TABLE export_metrics ADD COLUMN destination_type TEXT;
 ALTER TABLE export_metrics ADD COLUMN rivet_version TEXT;",
 ),
 // v10: longest_chunk_ms.
 (
 10,
 "ALTER TABLE export_metrics ADD COLUMN longest_chunk_ms BIGINT;",
 ),
 // v11: export_harm.
 (
 11,
 "CREATE TABLE IF NOT EXISTS export_harm (
 id BIGSERIAL PRIMARY KEY,
 run_id TEXT NOT NULL,
 export_name TEXT NOT NULL,
 metric TEXT NOT NULL,
 delta BIGINT NOT NULL,
 recorded_at TEXT NOT NULL
 );
 CREATE INDEX IF NOT EXISTS idx_export_harm_run ON export_harm(run_id);",
 ),
];
421
422pub(super) fn pg_sql(sql: &str) -> String {
427 let bytes = sql.as_bytes();
428 let mut out = String::with_capacity(sql.len());
429 let mut i = 0;
430 while i < bytes.len() {
431 if bytes[i] == b'?' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
432 out.push('$');
433 } else {
434 out.push(bytes[i] as char);
435 }
436 i += 1;
437 }
438 out
439}
440
441pub(super) fn connect_pg(url: &str) -> Result<postgres::Client> {
458 let tls = state_tls_mode_from_url(url).map(|mode| crate::config::TlsConfig {
459 mode,
460 ..crate::config::TlsConfig::default()
461 });
462 crate::source::postgres::connect_client(url, tls.as_ref())
463 .map_err(|e| anyhow::anyhow!("state(pg): connect to '{}': {:#}", redact_pg_url(url), e))
464}
465
466fn state_tls_mode_from_url(url: &str) -> Option<crate::config::TlsMode> {
474 use crate::config::TlsMode;
475 let (_, query) = url.split_once('?')?;
476 let mut mode = None;
477 for pair in query.split('&') {
478 let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
479 if key != "sslmode" {
480 continue;
481 }
482 mode = match value {
483 "require" => Some(TlsMode::Require),
484 "verify-ca" => Some(TlsMode::VerifyCa),
485 "verify-full" => Some(TlsMode::VerifyFull),
486 _ => None,
487 };
488 }
489 mode
490}
491
/// Live connection to the state backend chosen by `StateStore::open`.
pub(super) enum StateConn {
    /// SQLite connection, either file-backed (via `open_connection`) or
    /// in-memory (via `StateStore::open_in_memory`).
    Sqlite(rusqlite::Connection),
    /// PostgreSQL client. The `RefCell` is presumably there so methods taking
    /// `&StateStore` can still borrow the client mutably, since its query
    /// methods take `&mut self` (as `migrate_pg` shows). The `Box` is likely
    /// there to keep the enum small. NOTE(review): confirm against the callers.
    Postgres(Box<std::cell::RefCell<postgres::Client>>),
}
503
/// Owned, clonable description of where a `StateStore` keeps its state,
/// independent of the live connection.
#[derive(Clone)]
pub enum StateRef {
    /// Path of the SQLite database file (`":memory:"` for in-memory stores).
    Sqlite(std::path::PathBuf),
    /// PostgreSQL connection URL exactly as it was passed to the store. It may
    /// embed a password, so pass it through `redact_pg_url` before logging or
    /// displaying it.
    Postgres(String),
}
512
513fn ensure_schema_version_table(conn: &Connection) {
516 let _ = conn.execute_batch(
517 "CREATE TABLE IF NOT EXISTS schema_version (
518 version INTEGER NOT NULL
519 );",
520 );
521}
522
523fn get_current_version(conn: &Connection) -> i64 {
524 conn.query_row(
525 "SELECT COALESCE(MAX(version), 0) FROM schema_version",
526 [],
527 |row| row.get(0),
528 )
529 .unwrap_or(0)
530}
531
532fn migrate(conn: &Connection) -> Result<()> {
533 ensure_schema_version_table(conn);
534
535 let current = get_current_version(conn);
536
537 if current == 0 {
538 let has_export_state: bool = conn
539 .query_row(
540 "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='export_state'",
541 [],
542 |row| row.get(0),
543 )
544 .unwrap_or(false);
545
546 if has_export_state {
547 let metrics_cols = [
548 "files_produced INTEGER DEFAULT 0",
549 "bytes_written INTEGER DEFAULT 0",
550 "retries INTEGER DEFAULT 0",
551 "validated INTEGER",
552 "schema_changed INTEGER",
553 "run_id TEXT",
554 ];
555 for col_def in &metrics_cols {
556 let sql = format!("ALTER TABLE export_metrics ADD COLUMN {}", col_def);
557 let _ = conn.execute(&sql, []);
558 }
559 }
560 }
561
562 for &(ver, sql) in MIGRATIONS {
563 if ver > current {
564 log::debug!("state: applying migration v{}", ver);
565 let atomic_sql = format!(
566 "BEGIN;\n{}\nINSERT INTO schema_version (version) VALUES ({});\nCOMMIT;",
567 sql, ver
568 );
569 conn.execute_batch(&atomic_sql)
570 .map_err(|e| anyhow::anyhow!("state: migration v{} failed: {}", ver, e))?;
571 }
572 }
573
574 let _ = conn.execute(
575 "DELETE FROM schema_version WHERE version < (SELECT MAX(version) FROM schema_version)",
576 [],
577 );
578
579 let final_version = get_current_version(conn);
580 if final_version != SCHEMA_VERSION {
581 anyhow::bail!(
582 "state: migration incomplete — expected schema v{} but reached v{}",
583 SCHEMA_VERSION,
584 final_version
585 );
586 }
587
588 Ok(())
589}
590
591fn migrate_pg(client: &mut postgres::Client) -> Result<()> {
594 client
595 .batch_execute("CREATE TABLE IF NOT EXISTS rivet_schema_version (version BIGINT NOT NULL);")
596 .map_err(|e| anyhow::anyhow!("state(pg): create version table: {:#}", e))?;
597
598 let current: i64 = client
599 .query_one(
600 "SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
601 &[],
602 )
603 .map_err(|e| anyhow::anyhow!("state(pg): read schema version: {:#}", e))?
604 .get(0);
605
606 for &(ver, sql) in PG_MIGRATIONS {
607 if ver > current {
608 log::debug!("state(pg): applying migration v{}", ver);
609 let batch = format!(
610 "BEGIN; {} INSERT INTO rivet_schema_version (version) VALUES ({}); COMMIT;",
611 sql, ver
612 );
613 client
614 .batch_execute(&batch)
615 .map_err(|e| anyhow::anyhow!("state(pg): migration v{} failed: {:#}", ver, e))?;
616 }
617 }
618
619 let _ = client.batch_execute(
621 "DELETE FROM rivet_schema_version \
622 WHERE version < (SELECT MAX(version) FROM rivet_schema_version);",
623 );
624
625 let final_version: i64 = client
627 .query_one(
628 "SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
629 &[],
630 )
631 .map_err(|e| anyhow::anyhow!("state(pg): read final schema version: {:#}", e))?
632 .get(0);
633 if final_version != SCHEMA_VERSION {
634 anyhow::bail!(
635 "state(pg): migration incomplete — expected schema v{} but reached v{}",
636 SCHEMA_VERSION,
637 final_version
638 );
639 }
640
641 Ok(())
642}
643
/// Returns `url` with the password in its userinfo replaced by `***`, for use
/// in logs and error messages.
///
/// The userinfo is taken to end at the last `@` before the query string, so an
/// unescaped `@` inside the password is still covered. If no `@` precedes the
/// query, the last `@` anywhere after the scheme is used, so a password
/// containing `?` is still masked. The password starts after the *first* `:`
/// of the userinfo, because a user name cannot contain a raw `:`. The URL is
/// returned unchanged if it has no `scheme://` prefix, no userinfo, or no
/// password.
fn redact_pg_url(url: &str) -> String {
    let Some(scheme_end) = url.find("://") else {
        return url.to_string();
    };
    let rest = &url[scheme_end + 3..];

    // Prefer an `@` before the query string. Otherwise an `@` in a query
    // parameter (e.g. `application_name=a@b`) would be taken as the end of
    // the userinfo, and the real password could end up in the output.
    let before_query = rest.split('?').next().unwrap_or(rest);
    let Some(at) = before_query.rfind('@').or_else(|| rest.rfind('@')) else {
        return url.to_string();
    };

    let userinfo = &rest[..at];
    let Some(colon) = userinfo.find(':') else {
        return url.to_string();
    };

    format!(
        "{}://{}:***@{}",
        &url[..scheme_end],
        &userinfo[..colon],
        &rest[at + 1..]
    )
}
664
/// Busy timeout, in milliseconds, that `open_connection` applies via
/// `PRAGMA busy_timeout`: how long SQLite waits on a locked database before
/// giving up with `SQLITE_BUSY`.
pub(crate) const SQLITE_BUSY_TIMEOUT_MS: i64 = 10_000;
668
669pub(crate) fn open_connection(db_path: &std::path::Path) -> Result<Connection> {
670 let conn = Connection::open(db_path)?;
671 if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL;") {
672 log::warn!(
673 "state: WAL journal mode unavailable ({}); \
674 running in default mode — concurrent writes may be slower",
675 e
676 );
677 }
678 if let Err(e) = conn.execute_batch(&format!(
679 "PRAGMA busy_timeout = {};",
680 SQLITE_BUSY_TIMEOUT_MS
681 )) {
682 log::warn!(
683 "state: failed to set busy_timeout ({}); \
684 concurrent writers may surface SQLITE_BUSY immediately",
685 e
686 );
687 }
688 Ok(conn)
689}
690
/// Handle to rivet's persistent state: cursors, run metrics, schema history,
/// the file log, chunk checkpoints and related tables. It is backed either by a
/// SQLite file or by a PostgreSQL database.
pub struct StateStore {
    /// Live connection to the chosen backend.
    pub(super) conn: StateConn,
    /// Where the state lives; exposed through `StateStore::state_ref`.
    pub(super) state_ref: StateRef,
}
715
716impl StateStore {
717 pub fn open(config_path: &str) -> Result<Self> {
721 if let Ok(url) = std::env::var("RIVET_STATE_URL")
722 && url.starts_with("postgres")
723 {
724 return Self::open_postgres(&url);
725 }
726 Self::open_sqlite(config_path)
727 }
728
729 fn open_sqlite(config_path: &str) -> Result<Self> {
730 let config_dir = std::path::Path::new(config_path)
731 .parent()
732 .unwrap_or(std::path::Path::new("."));
733 let db_path = config_dir.join(STATE_DB_NAME);
734 let conn = open_connection(&db_path)?;
735 migrate(&conn)?;
736 Ok(Self {
737 conn: StateConn::Sqlite(conn),
738 state_ref: StateRef::Sqlite(db_path),
739 })
740 }
741
742 fn open_postgres(url: &str) -> Result<Self> {
743 let is_local =
744 url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1");
745 if !is_local && state_tls_mode_from_url(url).is_none() {
746 log::warn!(
747 "state(pg): connecting to a remote host without TLS; \
748 add sslmode=require (or verify-ca / verify-full) to RIVET_STATE_URL \
749 to negotiate TLS for production use"
750 );
751 }
752 let mut client = connect_pg(url)?;
753 migrate_pg(&mut client)?;
754 Ok(Self {
755 conn: StateConn::Postgres(Box::new(std::cell::RefCell::new(client))),
756 state_ref: StateRef::Postgres(url.to_string()),
757 })
758 }
759
760 pub fn state_db_path(config_path: &str) -> std::path::PathBuf {
764 let config_dir = std::path::Path::new(config_path)
765 .parent()
766 .unwrap_or(std::path::Path::new("."));
767 config_dir.join(STATE_DB_NAME)
768 }
769
770 pub fn state_ref(&self) -> &StateRef {
772 &self.state_ref
773 }
774
775 #[allow(dead_code)]
777 pub fn open_in_memory() -> Result<Self> {
778 let conn = Connection::open_in_memory()?;
779 migrate(&conn)?;
780 Ok(Self {
781 conn: StateConn::Sqlite(conn),
782 state_ref: StateRef::Sqlite(std::path::PathBuf::from(":memory:")),
783 })
784 }
785
786 #[allow(dead_code)]
789 pub fn open_at_path(db_path: &std::path::Path) -> Result<Self> {
790 let conn = open_connection(db_path)?;
791 migrate(&conn)?;
792 Ok(Self {
793 conn: StateConn::Sqlite(conn),
794 state_ref: StateRef::Sqlite(db_path.to_path_buf()),
795 })
796 }
797}
798
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TlsMode;

    /// Borrows the SQLite connection of a store opened with
    /// `StateStore::open_in_memory`.
    fn sqlite_conn(store: &StateStore) -> &Connection {
        match &store.conn {
            StateConn::Sqlite(c) => c,
            StateConn::Postgres(_) => unreachable!("in-memory stores are always SQLite"),
        }
    }

    /// Returns whether `sqlite_master` lists an object of `kind` ("table",
    /// "index", …) called `name`.
    fn sqlite_object_exists(conn: &Connection, kind: &str, name: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = ?1 AND name = ?2",
            [kind, name],
            |row| row.get(0),
        )
        .unwrap()
    }

    #[test]
    fn fresh_db_reaches_latest_version() {
        let s = StateStore::open_in_memory().unwrap();
        assert_eq!(get_current_version(sqlite_conn(&s)), SCHEMA_VERSION);
    }

    #[test]
    fn migration_is_idempotent() {
        let s = StateStore::open_in_memory().unwrap();
        let c = sqlite_conn(&s);
        migrate(c).unwrap();
        migrate(c).unwrap();
        assert_eq!(get_current_version(c), SCHEMA_VERSION);
    }

    /// A pre-versioning database (tables present, no `schema_version`) must be
    /// upgraded in place: missing `export_metrics` columns back-filled and
    /// every migration applied.
    #[test]
    fn legacy_db_gets_upgraded() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE export_state (
                export_name TEXT PRIMARY KEY,
                last_cursor_value TEXT,
                last_run_at TEXT
            );
            CREATE TABLE export_metrics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                export_name TEXT NOT NULL,
                run_at TEXT NOT NULL,
                duration_ms INTEGER NOT NULL,
                total_rows INTEGER NOT NULL,
                status TEXT NOT NULL
            );",
        )
        .unwrap();

        migrate(&conn).unwrap();
        assert_eq!(get_current_version(&conn), SCHEMA_VERSION);
        assert!(sqlite_object_exists(&conn, "table", "chunk_run"));
        // The legacy back-fill (`run_id`) and later migrations (v9
        // `rivet_version`, v10 `longest_chunk_ms`) must all have landed.
        conn.prepare("SELECT run_id, rivet_version, longest_chunk_ms FROM export_metrics")
            .unwrap();
    }

    #[test]
    fn v8_renames_file_manifest_to_file_log() {
        let s = StateStore::open_in_memory().unwrap();
        let conn = sqlite_conn(&s);
        assert!(
            sqlite_object_exists(conn, "table", "file_log"),
            "v8 must produce a `file_log` table"
        );
        assert!(
            !sqlite_object_exists(conn, "table", "file_manifest"),
            "v8 must remove the old `file_manifest` table"
        );
        assert!(
            sqlite_object_exists(conn, "index", "idx_file_log_export"),
            "v8 must create the renamed index"
        );
    }

    /// A genuine v7 database with rows in `file_manifest` must come out of the
    /// migration with those rows in `file_log`.
    #[test]
    fn v8_upgrades_existing_v7_db_with_data() {
        let conn = Connection::open_in_memory().unwrap();
        // Build a real v7 database: apply only migrations 1..=7 and record v7.
        for &(_, sql) in MIGRATIONS.iter().take_while(|&&(ver, _)| ver <= 7) {
            conn.execute_batch(sql).unwrap();
        }
        ensure_schema_version_table(&conn);
        conn.execute("INSERT INTO schema_version (version) VALUES (7)", [])
            .unwrap();
        assert_eq!(get_current_version(&conn), 7);
        conn.execute(
            "INSERT INTO file_manifest (run_id, export_name, file_name, row_count, bytes, format, created_at)
             VALUES ('r1', 'orders', 'f.parquet', 100, 4096, 'parquet', '2026-05-21T00:00:00Z')",
            [],
        )
        .unwrap();

        migrate(&conn).unwrap();
        assert_eq!(get_current_version(&conn), SCHEMA_VERSION);
        assert!(!sqlite_object_exists(&conn, "table", "file_manifest"));

        let (count, file_name): (i64, String) = conn
            .query_row(
                "SELECT COUNT(*), MAX(file_name) FROM file_log",
                [],
                |r| Ok((r.get::<_, i64>(0)?, r.get::<_, String>(1)?)),
            )
            .unwrap();
        assert_eq!(count, 1, "v8 must carry existing file_manifest rows into file_log");
        assert_eq!(file_name, "f.parquet");
    }

    #[test]
    fn run_aggregate_table_exists_after_migration() {
        let s = StateStore::open_in_memory().unwrap();
        assert!(
            sqlite_object_exists(sqlite_conn(&s), "table", "run_aggregate"),
            "v5 migration must create the run_aggregate table"
        );
    }

    #[test]
    fn pg_sql_converts_placeholders() {
        assert_eq!(
            pg_sql("SELECT ?1, ?2 FROM t WHERE x = ?3"),
            "SELECT $1, $2 FROM t WHERE x = $3"
        );
        assert_eq!(
            pg_sql("INSERT INTO t VALUES (?1, ?2)"),
            "INSERT INTO t VALUES ($1, $2)"
        );
        assert_eq!(pg_sql("no placeholders"), "no placeholders");
        assert_eq!(pg_sql("?10 AND ?11"), "$10 AND $11");
    }

    #[test]
    fn redact_pg_url_removes_password() {
        assert_eq!(
            redact_pg_url("postgresql://rivet:secret123@localhost:5433/rivet_state"),
            "postgresql://rivet:***@localhost:5433/rivet_state"
        );
        assert_eq!(
            redact_pg_url("postgres://admin:p@ssw0rd@db.prod.example.com/state"),
            "postgres://admin:***@db.prod.example.com/state"
        );
    }

    #[test]
    fn redact_pg_url_no_password_unchanged() {
        let url = "postgresql://rivet@localhost/state";
        assert_eq!(redact_pg_url(url), url);
    }

    /// The three enforcing `sslmode` values must map to their TLS modes.
    #[test]
    fn state_sslmode_enforced_values_negotiate_tls() {
        for (url, want) in [
            (
                "postgresql://u:p@db.prod:5432/state?sslmode=require",
                TlsMode::Require,
            ),
            (
                "postgresql://u:p@db.prod/state?sslmode=verify-ca",
                TlsMode::VerifyCa,
            ),
            (
                "postgresql://u:p@db.prod/state?sslmode=verify-full",
                TlsMode::VerifyFull,
            ),
        ] {
            let mode = state_tls_mode_from_url(url);
            assert_eq!(mode, Some(want), "url: {url}");
            assert!(
                mode.unwrap().is_enforced(),
                "{want:?} must enforce TLS (not NoTls)"
            );
        }
    }

    /// Anything that is not an exact, lower-case enforcing value must leave the
    /// connection in plaintext.
    #[test]
    fn state_sslmode_plaintext_values_stay_notls() {
        for url in [
            "postgresql://u:p@localhost/state",
            "postgresql://u:p@localhost/state?sslmode=disable",
            "postgresql://u:p@db/state?sslmode=prefer",
            "postgresql://u:p@db/state?sslmode=allow",
            "postgresql://u:p@db/state?sslmode=REQUIRE",
            "postgresql://u:p@db/state?sslmode=garbage",
            "postgresql://u:p@db/state?sslmode",
            "postgresql://u:p@db/state?sslmode=",
        ] {
            assert_eq!(state_tls_mode_from_url(url), None, "url: {url}");
        }
    }

    /// The key must match exactly, other parameters are ignored, and the last
    /// `sslmode` occurrence wins.
    #[test]
    fn state_sslmode_exact_key_and_last_occurrence_wins() {
        assert_eq!(
            state_tls_mode_from_url("postgresql://u:p@db/state?xsslmode=require"),
            None
        );
        assert_eq!(
            state_tls_mode_from_url(
                "postgresql://u:p@db/state?connect_timeout=10&sslmode=require&application_name=x"
            ),
            Some(TlsMode::Require)
        );
        assert_eq!(
            state_tls_mode_from_url("postgresql://u:p@db/state?sslmode=disable&sslmode=require"),
            Some(TlsMode::Require)
        );
        assert_eq!(
            state_tls_mode_from_url("postgresql://u:p@db/state?sslmode=require&sslmode=disable"),
            None
        );
    }
}
1043}