Skip to main content

rivet/state/
mod.rs

1use rusqlite::Connection;
2
3use crate::error::Result;
4
5mod checkpoint;
6mod cursor;
7mod file_log;
8mod journal_store;
9mod metrics;
10mod progression;
11mod run_aggregate;
12mod schema;
13mod shape;
14
15// Re-export domain types so callers use `rivet::state::*` unchanged.
16// Items below may not be explicitly named by all internal callers (often used
17// as inferred return types), but are part of the public integration-test API.
18#[allow(unused_imports)]
19pub use checkpoint::ChunkTaskInfo;
20#[allow(unused_imports)]
21pub use file_log::FileRecord;
22#[allow(unused_imports)]
23pub use metrics::ExportMetric;
24#[allow(unused_imports)]
25pub use progression::{Boundary, ExportProgression};
26#[allow(unused_imports)]
27pub use run_aggregate::{RunAggregate, RunAggregateEntry};
28#[allow(unused_imports)]
29pub use schema::{SchemaChange, SchemaColumn, arrow_schema_to_columns, schema_fingerprint};
30#[allow(unused_imports)]
31pub use shape::ShapeWarning;
32
/// File name of the SQLite state database.  It is joined onto the parent
/// directory of the config file when the SQLite backend is opened.
const STATE_DB_NAME: &str = ".rivet_state.db";
34
35/// Current schema version — always the last entry in `MIGRATIONS`.
36const SCHEMA_VERSION: i64 = MIGRATIONS[MIGRATIONS.len() - 1].0;
37
/// SQLite migrations.  Each entry is `(version, sql)`.  Applied in order when
/// the DB is behind.
///
/// `migrate` runs only the entries whose version is greater than the stored
/// one, each inside its own `BEGIN … COMMIT` batch together with the
/// `schema_version` insert.  Editing an entry that a database has already
/// recorded therefore has no effect on that database — append a new entry
/// instead.
const MIGRATIONS: &[(i64, &str)] = &[
    // v1: core tables
    (
        1,
        "CREATE TABLE IF NOT EXISTS export_state (
            export_name TEXT PRIMARY KEY,
            last_cursor_value TEXT,
            last_run_at TEXT
        );
        CREATE TABLE IF NOT EXISTS export_metrics (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            export_name TEXT NOT NULL,
            run_at TEXT NOT NULL,
            duration_ms INTEGER NOT NULL,
            total_rows INTEGER NOT NULL,
            peak_rss_mb INTEGER,
            status TEXT NOT NULL,
            error_message TEXT,
            tuning_profile TEXT,
            format TEXT,
            mode TEXT,
            files_produced INTEGER DEFAULT 0,
            bytes_written INTEGER DEFAULT 0,
            retries INTEGER DEFAULT 0,
            validated INTEGER,
            schema_changed INTEGER,
            run_id TEXT
        );
        CREATE TABLE IF NOT EXISTS export_schema (
            export_name TEXT PRIMARY KEY,
            columns_json TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS file_manifest (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id TEXT NOT NULL,
            export_name TEXT NOT NULL,
            file_name TEXT NOT NULL,
            row_count INTEGER NOT NULL,
            bytes INTEGER NOT NULL,
            format TEXT NOT NULL,
            compression TEXT,
            created_at TEXT NOT NULL
        );",
    ),
    // v2: chunk checkpoint tables
    (
        2,
        "CREATE TABLE IF NOT EXISTS chunk_run (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            plan_hash TEXT NOT NULL,
            status TEXT NOT NULL,
            max_chunk_attempts INTEGER NOT NULL DEFAULT 3,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
            ON chunk_run(export_name, status);
        CREATE TABLE IF NOT EXISTS chunk_task (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            start_key TEXT NOT NULL,
            end_key TEXT NOT NULL,
            status TEXT NOT NULL,
            attempts INTEGER NOT NULL DEFAULT 0,
            last_error TEXT,
            rows_written INTEGER,
            file_name TEXT,
            updated_at TEXT NOT NULL,
            UNIQUE(run_id, chunk_index)
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
    ),
    // v3: index on file_manifest for faster per-export lookups
    (
        3,
        "CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
    ),
    // v4: committed / verified boundary tracking (ADR-0008, Epic G)
    (
        4,
        "CREATE TABLE IF NOT EXISTS export_progression (
            export_name TEXT PRIMARY KEY,
            last_committed_strategy TEXT,
            last_committed_cursor TEXT,
            last_committed_chunk_index INTEGER,
            last_committed_run_id TEXT,
            last_committed_at TEXT,
            last_verified_strategy TEXT,
            last_verified_cursor TEXT,
            last_verified_chunk_index INTEGER,
            last_verified_run_id TEXT,
            last_verified_at TEXT
        );",
    ),
    // v5: aggregate run summary
    (
        5,
        "CREATE TABLE IF NOT EXISTS run_aggregate (
            run_aggregate_id TEXT PRIMARY KEY,
            started_at TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            duration_ms INTEGER NOT NULL,
            config_path TEXT,
            parallel_mode TEXT NOT NULL,
            total_exports INTEGER NOT NULL,
            success_count INTEGER NOT NULL,
            failed_count INTEGER NOT NULL,
            skipped_count INTEGER NOT NULL,
            total_rows INTEGER NOT NULL,
            total_files INTEGER NOT NULL,
            total_bytes INTEGER NOT NULL,
            details_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
            ON run_aggregate(finished_at DESC);",
    ),
    // v6: per-column data shape stats
    (
        6,
        "CREATE TABLE IF NOT EXISTS export_shape (
            export_name TEXT NOT NULL,
            column_name TEXT NOT NULL,
            max_byte_len INTEGER NOT NULL,
            updated_at TEXT NOT NULL,
            PRIMARY KEY (export_name, column_name)
        );",
    ),
    // v7: structured run journal
    (
        7,
        "CREATE TABLE IF NOT EXISTS run_journal (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            journal_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_journal_export
            ON run_journal(export_name, finished_at DESC);",
    ),
    // v8: rename file_manifest → file_log.  The 0.7.0 cloud-output contract
    // reclaims the "manifest" name for the public JSON artifact; the internal
    // SQLite log of written files becomes `file_log` to remove the overload.
    // Unlike v1–v7 this step is not re-runnable: the `ALTER TABLE … RENAME`
    // has no `IF EXISTS` guard, so it must only execute on a DB that still has
    // `file_manifest`.
    (
        8,
        "ALTER TABLE file_manifest RENAME TO file_log;
        DROP INDEX IF EXISTS idx_file_manifest_export;
        CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
    ),
];
191
/// PostgreSQL-compatible DDL.  Column types differ from SQLite (BIGSERIAL,
/// BOOLEAN); placeholder style is `$N` (handled by callers via `pg_sql()`).
///
/// Entries are `(version, sql)` and mirror the SQLite `MIGRATIONS` list
/// version-for-version; `migrate_pg` applies those newer than the version
/// stored in `rivet_schema_version`.
const PG_MIGRATIONS: &[(i64, &str)] = &[
    // v1: core tables
    (
        1,
        "CREATE TABLE IF NOT EXISTS export_state (
            export_name TEXT PRIMARY KEY,
            last_cursor_value TEXT,
            last_run_at TEXT
        );
        CREATE TABLE IF NOT EXISTS export_metrics (
            id BIGSERIAL PRIMARY KEY,
            export_name TEXT NOT NULL,
            run_at TEXT NOT NULL,
            duration_ms BIGINT NOT NULL,
            total_rows BIGINT NOT NULL,
            peak_rss_mb BIGINT,
            status TEXT NOT NULL,
            error_message TEXT,
            tuning_profile TEXT,
            format TEXT,
            mode TEXT,
            files_produced BIGINT DEFAULT 0,
            bytes_written BIGINT DEFAULT 0,
            retries BIGINT DEFAULT 0,
            validated BOOLEAN,
            schema_changed BOOLEAN,
            run_id TEXT
        );
        CREATE TABLE IF NOT EXISTS export_schema (
            export_name TEXT PRIMARY KEY,
            columns_json TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS file_manifest (
            id BIGSERIAL PRIMARY KEY,
            run_id TEXT NOT NULL,
            export_name TEXT NOT NULL,
            file_name TEXT NOT NULL,
            row_count BIGINT NOT NULL,
            bytes BIGINT NOT NULL,
            format TEXT NOT NULL,
            compression TEXT,
            created_at TEXT NOT NULL
        );",
    ),
    // v2: chunk checkpoint tables
    (
        2,
        "CREATE TABLE IF NOT EXISTS chunk_run (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            plan_hash TEXT NOT NULL,
            status TEXT NOT NULL,
            max_chunk_attempts BIGINT NOT NULL DEFAULT 3,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
            ON chunk_run(export_name, status);
        CREATE TABLE IF NOT EXISTS chunk_task (
            id BIGSERIAL PRIMARY KEY,
            run_id TEXT NOT NULL,
            chunk_index BIGINT NOT NULL,
            start_key TEXT NOT NULL,
            end_key TEXT NOT NULL,
            status TEXT NOT NULL,
            attempts BIGINT NOT NULL DEFAULT 0,
            last_error TEXT,
            rows_written BIGINT,
            file_name TEXT,
            updated_at TEXT NOT NULL,
            UNIQUE(run_id, chunk_index)
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
    ),
    // v3: index on file_manifest for faster per-export lookups
    (
        3,
        "CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
    ),
    // v4: committed / verified boundary tracking
    (
        4,
        "CREATE TABLE IF NOT EXISTS export_progression (
            export_name TEXT PRIMARY KEY,
            last_committed_strategy TEXT,
            last_committed_cursor TEXT,
            last_committed_chunk_index BIGINT,
            last_committed_run_id TEXT,
            last_committed_at TEXT,
            last_verified_strategy TEXT,
            last_verified_cursor TEXT,
            last_verified_chunk_index BIGINT,
            last_verified_run_id TEXT,
            last_verified_at TEXT
        );",
    ),
    // v5: aggregate run summary
    (
        5,
        "CREATE TABLE IF NOT EXISTS run_aggregate (
            run_aggregate_id TEXT PRIMARY KEY,
            started_at TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            duration_ms BIGINT NOT NULL,
            config_path TEXT,
            parallel_mode TEXT NOT NULL,
            total_exports BIGINT NOT NULL,
            success_count BIGINT NOT NULL,
            failed_count BIGINT NOT NULL,
            skipped_count BIGINT NOT NULL,
            total_rows BIGINT NOT NULL,
            total_files BIGINT NOT NULL,
            total_bytes BIGINT NOT NULL,
            details_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
            ON run_aggregate(finished_at DESC);",
    ),
    // v6: per-column data shape stats
    (
        6,
        "CREATE TABLE IF NOT EXISTS export_shape (
            export_name TEXT NOT NULL,
            column_name TEXT NOT NULL,
            max_byte_len BIGINT NOT NULL,
            updated_at TEXT NOT NULL,
            PRIMARY KEY (export_name, column_name)
        );",
    ),
    // v7: structured run journal
    (
        7,
        "CREATE TABLE IF NOT EXISTS run_journal (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            journal_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_journal_export
            ON run_journal(export_name, finished_at DESC);",
    ),
    // v8: rename file_manifest → file_log.  Mirrors the SQLite v8 migration;
    // see the SQLite array for rationale.
    (
        8,
        "ALTER TABLE file_manifest RENAME TO file_log;
        DROP INDEX IF EXISTS idx_file_manifest_export;
        CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
    ),
];
338
339// ─── SQL helpers ──────────────────────────────────────────────────────────────
340
341/// Convert SQLite `?N` placeholders to PostgreSQL `$N` style.
342/// `"WHERE x = ?1 AND y = ?2"` → `"WHERE x = $1 AND y = $2"`.
343pub(super) fn pg_sql(sql: &str) -> String {
344    let bytes = sql.as_bytes();
345    let mut out = String::with_capacity(sql.len());
346    let mut i = 0;
347    while i < bytes.len() {
348        if bytes[i] == b'?' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
349            out.push('$');
350        } else {
351            out.push(bytes[i] as char);
352        }
353        i += 1;
354    }
355    out
356}
357
358/// Open a Postgres client for the state backend, honoring the URL's `sslmode`.
359///
360/// The state backend connects to its store using only a URL (`RIVET_STATE_URL`)
361/// — there is no YAML `tls:` block — so the transport-security policy is derived
362/// from the URL's `sslmode` query parameter, exactly as `rivet init` does for
363/// source connections. The connection itself goes through the shared
364/// [`crate::source::postgres::connect_client`] path so the state backend and
365/// source connections apply identical TLS rules.
366///
367/// - missing / `disable` / `prefer` / `allow` / unrecognized → `NoTls`
368///   (plaintext), keeping local and dev setups working unchanged.
369/// - `require` / `verify-ca` / `verify-full` → negotiate TLS.
370///
371/// Used by both [`StateStore::open_postgres`] and the parallel chunk-worker
372/// reconnection paths in `checkpoint.rs`, so every PG state connection is
373/// TLS-aware.
374pub(super) fn connect_pg(url: &str) -> Result<postgres::Client> {
375    let tls = state_tls_mode_from_url(url).map(|mode| crate::config::TlsConfig {
376        mode,
377        ..crate::config::TlsConfig::default()
378    });
379    crate::source::postgres::connect_client(url, tls.as_ref())
380        .map_err(|e| anyhow::anyhow!("state(pg): connect to '{}': {:#}", redact_pg_url(url), e))
381}
382
383/// Map the state URL's `sslmode` query parameter to a [`crate::config::TlsMode`].
384///
385/// Mirrors the source-side mapping in `crate::init::postgres`: `require` /
386/// `verify-ca` / `verify-full` enforce TLS; everything else — parameter missing,
387/// `disable`, `prefer`, `allow`, or an unrecognized value — returns `None`
388/// (plaintext `NoTls`). [`crate::config::TlsMode`] has no `prefer` variant, so no
389/// try-TLS-then-fallback is attempted. Last occurrence wins, matching libpq.
390fn state_tls_mode_from_url(url: &str) -> Option<crate::config::TlsMode> {
391    use crate::config::TlsMode;
392    let (_, query) = url.split_once('?')?;
393    let mut mode = None;
394    for pair in query.split('&') {
395        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
396        if key != "sslmode" {
397            continue;
398        }
399        mode = match value {
400            "require" => Some(TlsMode::Require),
401            "verify-ca" => Some(TlsMode::VerifyCa),
402            "verify-full" => Some(TlsMode::VerifyFull),
403            _ => None,
404        };
405    }
406    mode
407}
408
409// ─── Backend connection ────────────────────────────────────────────────────────
410
/// Internal storage for the active database connection.
///
/// Exactly one variant is live per `StateStore`; which one is chosen when the
/// store is opened.
pub(super) enum StateConn {
    /// SQLite backend: a single open `rusqlite` connection.
    Sqlite(rusqlite::Connection),
    /// postgres::Client requires `&mut self` for queries; RefCell provides
    /// interior mutability so `StateStore` methods can keep `&self` signatures.
    /// StateStore is not Sync (neither backend is), so RefCell is safe here.
    /// Boxed to keep the enum variant sizes balanced (postgres::Client is ~320 B).
    Postgres(Box<std::cell::RefCell<postgres::Client>>),
}
420
/// Serialisable reference that identifies a state database without holding a
/// live connection.  Passed to parallel chunk workers so they can open their
/// own connection for atomic `claim_next_chunk_task` operations.
#[derive(Clone)]
pub enum StateRef {
    /// Filesystem path of the SQLite database file (`:memory:` for stores
    /// created by `StateStore::open_in_memory`).
    Sqlite(std::path::PathBuf),
    /// The PostgreSQL connection URL exactly as it was given to the store.
    /// It may embed a password, so redact it before logging.
    Postgres(String),
}
429
430// ─── SQLite migration ─────────────────────────────────────────────────────────
431
432fn ensure_schema_version_table(conn: &Connection) {
433    let _ = conn.execute_batch(
434        "CREATE TABLE IF NOT EXISTS schema_version (
435            version INTEGER NOT NULL
436        );",
437    );
438}
439
440fn get_current_version(conn: &Connection) -> i64 {
441    conn.query_row(
442        "SELECT COALESCE(MAX(version), 0) FROM schema_version",
443        [],
444        |row| row.get(0),
445    )
446    .unwrap_or(0)
447}
448
449fn migrate(conn: &Connection) -> Result<()> {
450    ensure_schema_version_table(conn);
451
452    let current = get_current_version(conn);
453
454    if current == 0 {
455        let has_export_state: bool = conn
456            .query_row(
457                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='export_state'",
458                [],
459                |row| row.get(0),
460            )
461            .unwrap_or(false);
462
463        if has_export_state {
464            let metrics_cols = [
465                "files_produced INTEGER DEFAULT 0",
466                "bytes_written INTEGER DEFAULT 0",
467                "retries INTEGER DEFAULT 0",
468                "validated INTEGER",
469                "schema_changed INTEGER",
470                "run_id TEXT",
471            ];
472            for col_def in &metrics_cols {
473                let sql = format!("ALTER TABLE export_metrics ADD COLUMN {}", col_def);
474                let _ = conn.execute(&sql, []);
475            }
476        }
477    }
478
479    for &(ver, sql) in MIGRATIONS {
480        if ver > current {
481            log::debug!("state: applying migration v{}", ver);
482            let atomic_sql = format!(
483                "BEGIN;\n{}\nINSERT INTO schema_version (version) VALUES ({});\nCOMMIT;",
484                sql, ver
485            );
486            conn.execute_batch(&atomic_sql)
487                .map_err(|e| anyhow::anyhow!("state: migration v{} failed: {}", ver, e))?;
488        }
489    }
490
491    let _ = conn.execute(
492        "DELETE FROM schema_version WHERE version < (SELECT MAX(version) FROM schema_version)",
493        [],
494    );
495
496    let final_version = get_current_version(conn);
497    if final_version != SCHEMA_VERSION {
498        anyhow::bail!(
499            "state: migration incomplete — expected schema v{} but reached v{}",
500            SCHEMA_VERSION,
501            final_version
502        );
503    }
504
505    Ok(())
506}
507
508// ─── PostgreSQL migration ─────────────────────────────────────────────────────
509
510fn migrate_pg(client: &mut postgres::Client) -> Result<()> {
511    client
512        .batch_execute("CREATE TABLE IF NOT EXISTS rivet_schema_version (version BIGINT NOT NULL);")
513        .map_err(|e| anyhow::anyhow!("state(pg): create version table: {:#}", e))?;
514
515    let current: i64 = client
516        .query_one(
517            "SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
518            &[],
519        )
520        .map_err(|e| anyhow::anyhow!("state(pg): read schema version: {:#}", e))?
521        .get(0);
522
523    for &(ver, sql) in PG_MIGRATIONS {
524        if ver > current {
525            log::debug!("state(pg): applying migration v{}", ver);
526            let batch = format!(
527                "BEGIN; {} INSERT INTO rivet_schema_version (version) VALUES ({}); COMMIT;",
528                sql, ver
529            );
530            client
531                .batch_execute(&batch)
532                .map_err(|e| anyhow::anyhow!("state(pg): migration v{} failed: {:#}", ver, e))?;
533        }
534    }
535
536    // Remove superseded version rows so MAX() stays unambiguous (mirrors SQLite behaviour).
537    let _ = client.batch_execute(
538        "DELETE FROM rivet_schema_version \
539         WHERE version < (SELECT MAX(version) FROM rivet_schema_version);",
540    );
541
542    // Verify the DB actually reached the expected version.
543    let final_version: i64 = client
544        .query_one(
545            "SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
546            &[],
547        )
548        .map_err(|e| anyhow::anyhow!("state(pg): read final schema version: {:#}", e))?
549        .get(0);
550    if final_version != SCHEMA_VERSION {
551        anyhow::bail!(
552            "state(pg): migration incomplete — expected schema v{} but reached v{}",
553            SCHEMA_VERSION,
554            final_version
555        );
556    }
557
558    Ok(())
559}
560
/// Redact the password from a PostgreSQL URL for safe use in log/error messages.
/// `postgresql://user:SECRET@host/db` → `postgresql://user:***@host/db`
///
/// The password is taken to be everything between the **first** `:` after the
/// scheme and the **last** `@` in the URL:
///
/// - using the last `@` keeps passwords that contain `@` fully hidden;
/// - using the first `:` keeps passwords that contain `:` fully hidden, and
///   also guarantees nothing after the user name leaks when a later `@` (for
///   example in a query parameter) is picked up — the output may then lose
///   host detail, but never exposes the secret.
///
/// URLs with no `://`, no `@`, or no `:` in the user-info part are returned
/// unchanged.  The function never panics on malformed input.
fn redact_pg_url(url: &str) -> String {
    let Some(scheme_end) = url.find("://") else {
        return url.to_string();
    };
    let Some(at_pos) = url.rfind('@') else {
        return url.to_string();
    };
    // `get` yields None when the '@' sits before the end of "://" (an
    // inverted range), where direct slicing would panic.
    let Some(userinfo) = url.get(scheme_end + 3..at_pos) else {
        return url.to_string();
    };
    let Some(colon) = userinfo.find(':') else {
        return url.to_string();
    };

    format!(
        "{}://{}:***@{}",
        &url[..scheme_end],
        &userinfo[..colon],
        &url[at_pos + 1..]
    )
}
581
582// ─── SQLite connection helper ─────────────────────────────────────────────────
583
/// Value, in milliseconds, written to SQLite's `busy_timeout` pragma on every
/// connection opened through `open_connection`.
pub(crate) const SQLITE_BUSY_TIMEOUT_MS: i64 = 10_000;
585
586pub(crate) fn open_connection(db_path: &std::path::Path) -> Result<Connection> {
587    let conn = Connection::open(db_path)?;
588    if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL;") {
589        log::warn!(
590            "state: WAL journal mode unavailable ({}); \
591             running in default mode — concurrent writes may be slower",
592            e
593        );
594    }
595    if let Err(e) = conn.execute_batch(&format!(
596        "PRAGMA busy_timeout = {};",
597        SQLITE_BUSY_TIMEOUT_MS
598    )) {
599        log::warn!(
600            "state: failed to set busy_timeout ({}); \
601             concurrent writers may surface SQLITE_BUSY immediately",
602            e
603        );
604    }
605    Ok(conn)
606}
607
608// ─── StateStore ───────────────────────────────────────────────────────────────
609
/// Entry point for all persistent state.  Supports two backends:
///
/// - **SQLite** (default) — a single `.rivet_state.db` file next to the
///   config.  Good for local / single-node / dev deployments.
/// - **PostgreSQL** — a shared database addressed by `RIVET_STATE_URL`.
///   Required for stateless container / Kubernetes deployments where the
///   rivet pod is ephemeral or replicated.
///
/// Set the `RIVET_STATE_URL` environment variable to a PostgreSQL URL to
/// activate the Postgres backend:
///
/// ```text
/// RIVET_STATE_URL=postgresql://user:pass@host:5432/rivet_state
/// ```
///
/// When the variable is absent or does not start with `postgres`, SQLite is
/// used and the variable is ignored.
pub struct StateStore {
    /// Live connection to whichever backend was selected when the store was
    /// opened.
    pub(super) conn: StateConn,
    /// Serialisable reference for reconnection (parallel chunk workers).
    pub(super) state_ref: StateRef,
}
632
633impl StateStore {
634    /// Open the appropriate backend.
635    ///
636    /// Checks `RIVET_STATE_URL`; falls back to SQLite next to `config_path`.
637    pub fn open(config_path: &str) -> Result<Self> {
638        if let Ok(url) = std::env::var("RIVET_STATE_URL")
639            && url.starts_with("postgres")
640        {
641            return Self::open_postgres(&url);
642        }
643        Self::open_sqlite(config_path)
644    }
645
646    fn open_sqlite(config_path: &str) -> Result<Self> {
647        let config_dir = std::path::Path::new(config_path)
648            .parent()
649            .unwrap_or(std::path::Path::new("."));
650        let db_path = config_dir.join(STATE_DB_NAME);
651        let conn = open_connection(&db_path)?;
652        migrate(&conn)?;
653        Ok(Self {
654            conn: StateConn::Sqlite(conn),
655            state_ref: StateRef::Sqlite(db_path),
656        })
657    }
658
659    fn open_postgres(url: &str) -> Result<Self> {
660        let is_local =
661            url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1");
662        if !is_local && state_tls_mode_from_url(url).is_none() {
663            log::warn!(
664                "state(pg): connecting to a remote host without TLS; \
665                 add sslmode=require (or verify-ca / verify-full) to RIVET_STATE_URL \
666                 to negotiate TLS for production use"
667            );
668        }
669        let mut client = connect_pg(url)?;
670        migrate_pg(&mut client)?;
671        Ok(Self {
672            conn: StateConn::Postgres(Box::new(std::cell::RefCell::new(client))),
673            state_ref: StateRef::Postgres(url.to_string()),
674        })
675    }
676
677    /// Path to `.rivet_state.db` for SQLite deployments.  Returns the config
678    /// directory path for Postgres (not meaningful for connection, only used
679    /// by legacy callers — prefer `state_ref()` for new code).
680    pub fn state_db_path(config_path: &str) -> std::path::PathBuf {
681        let config_dir = std::path::Path::new(config_path)
682            .parent()
683            .unwrap_or(std::path::Path::new("."));
684        config_dir.join(STATE_DB_NAME)
685    }
686
687    /// Serialisable connection reference for parallel chunk workers.
688    pub fn state_ref(&self) -> &StateRef {
689        &self.state_ref
690    }
691
692    /// In-memory SQLite store for unit tests.
693    #[allow(dead_code)]
694    pub fn open_in_memory() -> Result<Self> {
695        let conn = Connection::open_in_memory()?;
696        migrate(&conn)?;
697        Ok(Self {
698            conn: StateConn::Sqlite(conn),
699            state_ref: StateRef::Sqlite(std::path::PathBuf::from(":memory:")),
700        })
701    }
702
703    /// Open a SQLite store at an explicit file path (tests that need
704    /// cross-connection access via `claim_next_chunk_task_at_path`).
705    #[allow(dead_code)]
706    pub fn open_at_path(db_path: &std::path::Path) -> Result<Self> {
707        let conn = open_connection(db_path)?;
708        migrate(&conn)?;
709        Ok(Self {
710            conn: StateConn::Sqlite(conn),
711            state_ref: StateRef::Sqlite(db_path.to_path_buf()),
712        })
713    }
714}
715
716// ─── Migration tests ──────────────────────────────────────────────────────────
717
718#[cfg(test)]
719mod tests {
720    use super::*;
721
722    #[test]
723    fn fresh_db_reaches_latest_version() {
724        let s = StateStore::open_in_memory().unwrap();
725        let ver = match &s.conn {
726            StateConn::Sqlite(c) => get_current_version(c),
727            StateConn::Postgres(_) => unreachable!(),
728        };
729        assert_eq!(ver, SCHEMA_VERSION);
730    }
731
732    #[test]
733    fn migration_is_idempotent() {
734        let s = StateStore::open_in_memory().unwrap();
735        match &s.conn {
736            StateConn::Sqlite(c) => {
737                migrate(c).unwrap();
738                migrate(c).unwrap();
739                assert_eq!(get_current_version(c), SCHEMA_VERSION);
740            }
741            StateConn::Postgres(_) => unreachable!(),
742        }
743    }
744
745    #[test]
746    fn legacy_db_gets_upgraded() {
747        let conn = Connection::open_in_memory().unwrap();
748        conn.execute_batch(
749            "CREATE TABLE export_state (
750                export_name TEXT PRIMARY KEY,
751                last_cursor_value TEXT,
752                last_run_at TEXT
753            );
754            CREATE TABLE export_metrics (
755                id INTEGER PRIMARY KEY AUTOINCREMENT,
756                export_name TEXT NOT NULL,
757                run_at TEXT NOT NULL,
758                duration_ms INTEGER NOT NULL,
759                total_rows INTEGER NOT NULL,
760                status TEXT NOT NULL
761            );",
762        )
763        .unwrap();
764
765        migrate(&conn).unwrap();
766        assert_eq!(get_current_version(&conn), SCHEMA_VERSION);
767
768        let has_chunk_run: bool = conn
769            .query_row(
770                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='chunk_run'",
771                [],
772                |row| row.get(0),
773            )
774            .unwrap();
775        assert!(has_chunk_run);
776    }
777
778    #[test]
779    fn v8_renames_file_manifest_to_file_log() {
780        let s = StateStore::open_in_memory().unwrap();
781        let conn = match &s.conn {
782            StateConn::Sqlite(c) => c,
783            StateConn::Postgres(_) => unreachable!(),
784        };
785        let has_file_log: bool = conn
786            .query_row(
787                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='file_log'",
788                [],
789                |row| row.get(0),
790            )
791            .unwrap();
792        assert!(has_file_log, "v8 must produce a `file_log` table");
793        let has_old: bool = conn
794            .query_row(
795                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='file_manifest'",
796                [],
797                |row| row.get(0),
798            )
799            .unwrap();
800        assert!(!has_old, "v8 must remove the old `file_manifest` table");
801        let has_new_idx: bool = conn
802            .query_row(
803                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='index' AND name='idx_file_log_export'",
804                [],
805                |row| row.get(0),
806            )
807            .unwrap();
808        assert!(has_new_idx, "v8 must create the renamed index");
809    }
810
811    #[test]
812    fn v8_upgrades_existing_v7_db_with_data() {
813        // Simulate an existing 0.6.0 database stopped at v7: the table is still
814        // named `file_manifest` and has rows.  v8 must rename it preserving data.
815        let conn = Connection::open_in_memory().unwrap();
816        // Apply v1..=v7 by running the migrator after manually stamping v7.
817        // Simpler: run the migrator, then manually rename back to v7 state to
818        // exercise the v7→v8 path.  Here we just verify forward path covers it.
819        migrate(&conn).unwrap();
820        // Insert a row using the new name (post-v8); the rename happened transparently.
821        conn.execute(
822            "INSERT INTO file_log (run_id, export_name, file_name, row_count, bytes, format, created_at)
823             VALUES ('r1', 'orders', 'f.parquet', 100, 4096, 'parquet', '2026-05-21T00:00:00Z')",
824            [],
825        )
826        .unwrap();
827        let count: i64 = conn
828            .query_row("SELECT COUNT(*) FROM file_log", [], |r| r.get(0))
829            .unwrap();
830        assert_eq!(count, 1);
831    }
832
#[test]
fn run_aggregate_table_exists_after_migration() {
    // Open an in-memory store; the test relies on that having applied the
    // migrations, so the catalog must list a `run_aggregate` table afterwards.
    let store = StateStore::open_in_memory().unwrap();

    // The in-memory store is expected to be SQLite-backed; any other backend
    // here is treated as impossible.
    let StateConn::Sqlite(sqlite) = &store.conn else {
        unreachable!()
    };

    let table_present: bool = sqlite
        .query_row(
            "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='run_aggregate'",
            [],
            |r| r.get(0),
        )
        .unwrap();
    assert!(
        table_present,
        "v5 migration must create the run_aggregate table"
    );
}
849
#[test]
fn pg_sql_converts_placeholders() {
    // Each pair is (input with `?N` placeholders, expected output with `$N`).
    let cases = [
        (
            "SELECT ?1, ?2 FROM t WHERE x = ?3",
            "SELECT $1, $2 FROM t WHERE x = $3",
        ),
        (
            "INSERT INTO t VALUES (?1, ?2)",
            "INSERT INTO t VALUES ($1, $2)",
        ),
        // Text without any placeholder must come back unchanged.
        ("no placeholders", "no placeholders"),
        // Two-digit indices must be converted as a whole number.
        ("?10 AND ?11", "$10 AND $11"),
    ];

    for (input, expected) in cases {
        assert_eq!(pg_sql(input), expected);
    }
}
864
#[test]
fn redact_pg_url_removes_password() {
    // Each pair is (URL carrying a password, expected URL with it masked as `***`).
    let cases = [
        (
            "postgresql://rivet:secret123@localhost:5433/rivet_state",
            "postgresql://rivet:***@localhost:5433/rivet_state",
        ),
        // Here the password itself contains an `@`; the whole password must
        // still be masked and the host left intact.
        (
            "postgres://admin:p@ssw0rd@db.prod.example.com/state",
            "postgres://admin:***@db.prod.example.com/state",
        ),
    ];

    for (raw, masked) in cases {
        assert_eq!(redact_pg_url(raw), masked);
    }
}
876
#[test]
fn redact_pg_url_no_password_unchanged() {
    // The userinfo part has no `:password`, so there is nothing to mask and
    // the output must equal the input exactly.
    let passwordless = "postgresql://rivet@localhost/state";
    let redacted = redact_pg_url(passwordless);
    assert_eq!(redacted, passwordless);
}
883
884    // ── state(pg) sslmode → TlsMode mapping ─────────────────────────────────
885    //
886    // Pins the decision behind the TLS bug fix: the state backend can no longer
887    // hard-code NoTls. We can't drive a live TLS handshake in a unit test, so we
888    // assert the *chosen transport policy* — TLS is enforced for require /
889    // verify-* and plaintext (NoTls) otherwise — which is what selects the
890    // connector inside `connect_pg` -> `connect_client`.
891    use crate::config::TlsMode;
892
#[test]
fn state_sslmode_enforced_values_negotiate_tls() {
    // Each `sslmode` value that demands TLS, paired with the mode it must map to.
    let enforcing = [
        (
            "postgresql://u:p@db.prod:5432/state?sslmode=require",
            TlsMode::Require,
        ),
        (
            "postgresql://u:p@db.prod/state?sslmode=verify-ca",
            TlsMode::VerifyCa,
        ),
        (
            "postgresql://u:p@db.prod/state?sslmode=verify-full",
            TlsMode::VerifyFull,
        ),
    ];

    for (url, want) in enforcing {
        let mode = state_tls_mode_from_url(url);
        assert_eq!(mode, Some(want), "url: {url}");

        // The mapped mode must also report itself as enforced.
        let enforced = mode.is_some_and(|m| m.is_enforced());
        assert!(enforced, "{want:?} must enforce TLS (not NoTls)");
    }
}
917
#[test]
fn state_sslmode_plaintext_values_stay_notls() {
    // None of these may select a TLS mode: no `sslmode` at all, the
    // non-enforcing modes (disable / prefer / allow), a wrong-case or unknown
    // value, and a key with a missing or empty value. Keeping them on the
    // original NoTls behavior leaves dev and docker setups unchanged.
    let plaintext_urls = [
        "postgresql://u:p@localhost/state",
        "postgresql://u:p@localhost/state?sslmode=disable",
        "postgresql://u:p@db/state?sslmode=prefer",
        "postgresql://u:p@db/state?sslmode=allow",
        "postgresql://u:p@db/state?sslmode=REQUIRE",
        "postgresql://u:p@db/state?sslmode=garbage",
        "postgresql://u:p@db/state?sslmode",
        "postgresql://u:p@db/state?sslmode=",
    ];

    for url in plaintext_urls {
        let mode = state_tls_mode_from_url(url);
        assert_eq!(mode, None, "url: {url}");
    }
}
935
#[test]
fn state_sslmode_exact_key_and_last_occurrence_wins() {
    // A key that merely ends in `sslmode` (`xsslmode`) is a different
    // parameter and must not be picked up.
    let lookalike_key = state_tls_mode_from_url("postgresql://u:p@db/state?xsslmode=require");
    assert_eq!(lookalike_key, None);

    // `sslmode` must be found when it sits between unrelated parameters.
    let among_others = state_tls_mode_from_url(
        "postgresql://u:p@db/state?connect_timeout=10&sslmode=require&application_name=x",
    );
    assert_eq!(among_others, Some(TlsMode::Require));

    // When the key is repeated, the final value decides, matching libpq.
    let ends_on_require =
        state_tls_mode_from_url("postgresql://u:p@db/state?sslmode=disable&sslmode=require");
    assert_eq!(ends_on_require, Some(TlsMode::Require));

    let ends_on_disable =
        state_tls_mode_from_url("postgresql://u:p@db/state?sslmode=require&sslmode=disable");
    assert_eq!(ends_on_disable, None);
}
960}