Skip to main content

rivet/state/
mod.rs

1use rusqlite::Connection;
2
3use crate::error::Result;
4
5mod checkpoint;
6mod cursor;
7mod file_log;
8mod journal_store;
9mod metrics;
10mod progression;
11mod run_aggregate;
12mod schema;
13mod shape;
14
15// Re-export domain types so callers use `rivet::state::*` unchanged.
16// Items below may not be explicitly named by all internal callers (often used
17// as inferred return types), but are part of the public integration-test API.
18#[allow(unused_imports)]
19pub use checkpoint::ChunkTaskInfo;
20#[allow(unused_imports)]
21pub use file_log::FileRecord;
22#[allow(unused_imports)]
23pub use metrics::ExportMetric;
24pub use metrics::MetricRow;
25#[allow(unused_imports)]
26pub use progression::{Boundary, ExportProgression};
27#[allow(unused_imports)]
28pub use run_aggregate::{RunAggregate, RunAggregateEntry};
29#[allow(unused_imports)]
30pub use schema::{SchemaChange, SchemaColumn, arrow_schema_to_columns, schema_fingerprint};
31#[allow(unused_imports)]
32pub use shape::ShapeWarning;
33
34const STATE_DB_NAME: &str = ".rivet_state.db";
35
36/// Current schema version — always the last entry in `MIGRATIONS`.
37const SCHEMA_VERSION: i64 = MIGRATIONS[MIGRATIONS.len() - 1].0;
38
/// SQLite schema migrations, applied by `migrate`.
///
/// Each entry is `(version, sql)`.  Applied in order when the DB is behind:
/// every entry whose version is greater than the highest version recorded in
/// `schema_version` is executed, in a single transaction together with
/// recording its version.  `sql` may hold several `;`-separated statements.
///
/// Editing rules:
/// - Append new steps with a higher version.  Never edit or renumber a step
///   that has shipped — databases already past it never re-run it, so the
///   change would silently not reach them.
/// - `SCHEMA_VERSION` is the version of the last entry here.
/// - Keep `PG_MIGRATIONS` in lockstep (same version numbers, equivalent DDL):
///   `migrate_pg` also checks its result against `SCHEMA_VERSION`.
const MIGRATIONS: &[(i64, &str)] = &[
    // v1: core tables.  `IF NOT EXISTS` keeps tables already present in
    // databases created before version tracking; `migrate` backfills a fixed
    // set of `export_metrics` columns for those separately.
    (
        1,
        "CREATE TABLE IF NOT EXISTS export_state (
            export_name TEXT PRIMARY KEY,
            last_cursor_value TEXT,
            last_run_at TEXT
        );
        CREATE TABLE IF NOT EXISTS export_metrics (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            export_name TEXT NOT NULL,
            run_at TEXT NOT NULL,
            duration_ms INTEGER NOT NULL,
            total_rows INTEGER NOT NULL,
            peak_rss_mb INTEGER,
            status TEXT NOT NULL,
            error_message TEXT,
            tuning_profile TEXT,
            format TEXT,
            mode TEXT,
            files_produced INTEGER DEFAULT 0,
            bytes_written INTEGER DEFAULT 0,
            retries INTEGER DEFAULT 0,
            validated INTEGER,
            schema_changed INTEGER,
            run_id TEXT
        );
        CREATE TABLE IF NOT EXISTS export_schema (
            export_name TEXT PRIMARY KEY,
            columns_json TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS file_manifest (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id TEXT NOT NULL,
            export_name TEXT NOT NULL,
            file_name TEXT NOT NULL,
            row_count INTEGER NOT NULL,
            bytes INTEGER NOT NULL,
            format TEXT NOT NULL,
            compression TEXT,
            created_at TEXT NOT NULL
        );",
    ),
    // v2: chunk checkpoint tables (one `chunk_run` per chunked run, one
    // `chunk_task` row per chunk, unique per (run_id, chunk_index)).
    (
        2,
        "CREATE TABLE IF NOT EXISTS chunk_run (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            plan_hash TEXT NOT NULL,
            status TEXT NOT NULL,
            max_chunk_attempts INTEGER NOT NULL DEFAULT 3,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
            ON chunk_run(export_name, status);
        CREATE TABLE IF NOT EXISTS chunk_task (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            start_key TEXT NOT NULL,
            end_key TEXT NOT NULL,
            status TEXT NOT NULL,
            attempts INTEGER NOT NULL DEFAULT 0,
            last_error TEXT,
            rows_written INTEGER,
            file_name TEXT,
            updated_at TEXT NOT NULL,
            UNIQUE(run_id, chunk_index)
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
    ),
    // v3: index on file_manifest for faster per-export lookups
    // (dropped and recreated under the new table name by v8).
    (
        3,
        "CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
    ),
    // v4: committed / verified boundary tracking (ADR-0008, Epic G)
    (
        4,
        "CREATE TABLE IF NOT EXISTS export_progression (
            export_name TEXT PRIMARY KEY,
            last_committed_strategy TEXT,
            last_committed_cursor TEXT,
            last_committed_chunk_index INTEGER,
            last_committed_run_id TEXT,
            last_committed_at TEXT,
            last_verified_strategy TEXT,
            last_verified_cursor TEXT,
            last_verified_chunk_index INTEGER,
            last_verified_run_id TEXT,
            last_verified_at TEXT
        );",
    ),
    // v5: aggregate run summary
    (
        5,
        "CREATE TABLE IF NOT EXISTS run_aggregate (
            run_aggregate_id TEXT PRIMARY KEY,
            started_at TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            duration_ms INTEGER NOT NULL,
            config_path TEXT,
            parallel_mode TEXT NOT NULL,
            total_exports INTEGER NOT NULL,
            success_count INTEGER NOT NULL,
            failed_count INTEGER NOT NULL,
            skipped_count INTEGER NOT NULL,
            total_rows INTEGER NOT NULL,
            total_files INTEGER NOT NULL,
            total_bytes INTEGER NOT NULL,
            details_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
            ON run_aggregate(finished_at DESC);",
    ),
    // v6: per-column data shape stats
    (
        6,
        "CREATE TABLE IF NOT EXISTS export_shape (
            export_name TEXT NOT NULL,
            column_name TEXT NOT NULL,
            max_byte_len INTEGER NOT NULL,
            updated_at TEXT NOT NULL,
            PRIMARY KEY (export_name, column_name)
        );",
    ),
    // v7: structured run journal
    (
        7,
        "CREATE TABLE IF NOT EXISTS run_journal (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            journal_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_journal_export
            ON run_journal(export_name, finished_at DESC);",
    ),
    // v8: rename file_manifest → file_log.  The 0.7.0 cloud-output contract
    // reclaims the "manifest" name for the public JSON artifact; the internal
    // SQLite log of written files becomes `file_log` to remove the overload.
    // Not idempotent: the RENAME fails once `file_manifest` is gone, so this
    // relies on the version check running it exactly once.
    (
        8,
        "ALTER TABLE file_manifest RENAME TO file_log;
        DROP INDEX IF EXISTS idx_file_manifest_export;
        CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
    ),
    // v9: extended per-run metrics for post-pilot analysis — source harm
    // (pg_temp_bytes_delta), completeness (reconciled, source_count,
    // quality_passed), memory (batch_size[_memory_mb]), and config dimensions
    // (chunk_size, parallel, source/destination type, rivet_version). All
    // additive + nullable: old rows read NULL, no backfill, reads stay forward-
    // compatible.  SQLite has no `ADD COLUMN IF NOT EXISTS`, so re-running
    // this step would fail with "duplicate column"; it relies on the version
    // check as well.
    (
        9,
        "ALTER TABLE export_metrics ADD COLUMN files_committed INTEGER;
        ALTER TABLE export_metrics ADD COLUMN reconciled INTEGER;
        ALTER TABLE export_metrics ADD COLUMN source_count INTEGER;
        ALTER TABLE export_metrics ADD COLUMN quality_passed INTEGER;
        ALTER TABLE export_metrics ADD COLUMN pg_temp_bytes_delta INTEGER;
        ALTER TABLE export_metrics ADD COLUMN batch_size INTEGER;
        ALTER TABLE export_metrics ADD COLUMN batch_size_memory_mb INTEGER;
        ALTER TABLE export_metrics ADD COLUMN skip_reason TEXT;
        ALTER TABLE export_metrics ADD COLUMN schema_fingerprint TEXT;
        ALTER TABLE export_metrics ADD COLUMN chunk_size INTEGER;
        ALTER TABLE export_metrics ADD COLUMN parallel INTEGER;
        ALTER TABLE export_metrics ADD COLUMN source_type TEXT;
        ALTER TABLE export_metrics ADD COLUMN destination_type TEXT;
        ALTER TABLE export_metrics ADD COLUMN rivet_version TEXT;",
    ),
    // v10: longest single-chunk wall time (ms) — the #5 source-harm lever,
    // aggregated at finalize from the run journal's per-chunk timings.
    (
        10,
        "ALTER TABLE export_metrics ADD COLUMN longest_chunk_ms INTEGER;",
    ),
    // v11: per-run source-harm deltas (locks, rows read, buffer misses, temp
    // files) — one row per counter, keyed on run_id. Engine-neutral key/value so
    // each engine's counter set lands without schema churn. Written from
    // pipeline::job::harm_snapshot via source::{postgres,mysql,mssql}.
    (
        11,
        "CREATE TABLE IF NOT EXISTS export_harm (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id TEXT NOT NULL,
            export_name TEXT NOT NULL,
            metric TEXT NOT NULL,
            delta INTEGER NOT NULL,
            recorded_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_export_harm_run ON export_harm(run_id);",
    ),
];
237
/// PostgreSQL-compatible DDL, applied by `migrate_pg`.  Column types differ
/// from SQLite (BIGSERIAL, BIGINT, BOOLEAN); placeholder style is `$N`
/// (handled by callers via `pg_sql()`).
///
/// Must list exactly the same versions as `MIGRATIONS`, in the same order:
/// `migrate_pg` checks the final database version against `SCHEMA_VERSION`,
/// which is derived from the SQLite table.  Append-only, like `MIGRATIONS` —
/// a shipped step is never re-run on databases already past it.
const PG_MIGRATIONS: &[(i64, &str)] = &[
    // v1: core tables (see the SQLite array).
    (
        1,
        "CREATE TABLE IF NOT EXISTS export_state (
            export_name TEXT PRIMARY KEY,
            last_cursor_value TEXT,
            last_run_at TEXT
        );
        CREATE TABLE IF NOT EXISTS export_metrics (
            id BIGSERIAL PRIMARY KEY,
            export_name TEXT NOT NULL,
            run_at TEXT NOT NULL,
            duration_ms BIGINT NOT NULL,
            total_rows BIGINT NOT NULL,
            peak_rss_mb BIGINT,
            status TEXT NOT NULL,
            error_message TEXT,
            tuning_profile TEXT,
            format TEXT,
            mode TEXT,
            files_produced BIGINT DEFAULT 0,
            bytes_written BIGINT DEFAULT 0,
            retries BIGINT DEFAULT 0,
            validated BOOLEAN,
            schema_changed BOOLEAN,
            run_id TEXT
        );
        CREATE TABLE IF NOT EXISTS export_schema (
            export_name TEXT PRIMARY KEY,
            columns_json TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS file_manifest (
            id BIGSERIAL PRIMARY KEY,
            run_id TEXT NOT NULL,
            export_name TEXT NOT NULL,
            file_name TEXT NOT NULL,
            row_count BIGINT NOT NULL,
            bytes BIGINT NOT NULL,
            format TEXT NOT NULL,
            compression TEXT,
            created_at TEXT NOT NULL
        );",
    ),
    // v2: chunk checkpoint tables (see the SQLite array).
    (
        2,
        "CREATE TABLE IF NOT EXISTS chunk_run (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            plan_hash TEXT NOT NULL,
            status TEXT NOT NULL,
            max_chunk_attempts BIGINT NOT NULL DEFAULT 3,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
            ON chunk_run(export_name, status);
        CREATE TABLE IF NOT EXISTS chunk_task (
            id BIGSERIAL PRIMARY KEY,
            run_id TEXT NOT NULL,
            chunk_index BIGINT NOT NULL,
            start_key TEXT NOT NULL,
            end_key TEXT NOT NULL,
            status TEXT NOT NULL,
            attempts BIGINT NOT NULL DEFAULT 0,
            last_error TEXT,
            rows_written BIGINT,
            file_name TEXT,
            updated_at TEXT NOT NULL,
            UNIQUE(run_id, chunk_index)
        );
        CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
    ),
    // v3: per-export lookup index on file_manifest (replaced by v8).
    (
        3,
        "CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
    ),
    // v4: committed / verified boundary tracking (see the SQLite array).
    (
        4,
        "CREATE TABLE IF NOT EXISTS export_progression (
            export_name TEXT PRIMARY KEY,
            last_committed_strategy TEXT,
            last_committed_cursor TEXT,
            last_committed_chunk_index BIGINT,
            last_committed_run_id TEXT,
            last_committed_at TEXT,
            last_verified_strategy TEXT,
            last_verified_cursor TEXT,
            last_verified_chunk_index BIGINT,
            last_verified_run_id TEXT,
            last_verified_at TEXT
        );",
    ),
    // v5: aggregate run summary.
    (
        5,
        "CREATE TABLE IF NOT EXISTS run_aggregate (
            run_aggregate_id TEXT PRIMARY KEY,
            started_at TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            duration_ms BIGINT NOT NULL,
            config_path TEXT,
            parallel_mode TEXT NOT NULL,
            total_exports BIGINT NOT NULL,
            success_count BIGINT NOT NULL,
            failed_count BIGINT NOT NULL,
            skipped_count BIGINT NOT NULL,
            total_rows BIGINT NOT NULL,
            total_files BIGINT NOT NULL,
            total_bytes BIGINT NOT NULL,
            details_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
            ON run_aggregate(finished_at DESC);",
    ),
    // v6: per-column data shape stats.
    (
        6,
        "CREATE TABLE IF NOT EXISTS export_shape (
            export_name TEXT NOT NULL,
            column_name TEXT NOT NULL,
            max_byte_len BIGINT NOT NULL,
            updated_at TEXT NOT NULL,
            PRIMARY KEY (export_name, column_name)
        );",
    ),
    // v7: structured run journal.
    (
        7,
        "CREATE TABLE IF NOT EXISTS run_journal (
            run_id TEXT PRIMARY KEY,
            export_name TEXT NOT NULL,
            finished_at TEXT NOT NULL,
            journal_json TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_run_journal_export
            ON run_journal(export_name, finished_at DESC);",
    ),
    // v8: rename file_manifest → file_log.  Mirrors the SQLite v8 migration;
    // see the SQLite array for rationale.  Not idempotent (the RENAME fails
    // once `file_manifest` is gone); relies on the version check.
    (
        8,
        "ALTER TABLE file_manifest RENAME TO file_log;
        DROP INDEX IF EXISTS idx_file_manifest_export;
        CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
    ),
    // v9: extended per-run metrics (see the SQLite array for rationale).
    // Additive + nullable; BOOLEAN for the bool flags, BIGINT for counts.
    (
        9,
        "ALTER TABLE export_metrics ADD COLUMN files_committed BIGINT;
        ALTER TABLE export_metrics ADD COLUMN reconciled BOOLEAN;
        ALTER TABLE export_metrics ADD COLUMN source_count BIGINT;
        ALTER TABLE export_metrics ADD COLUMN quality_passed BOOLEAN;
        ALTER TABLE export_metrics ADD COLUMN pg_temp_bytes_delta BIGINT;
        ALTER TABLE export_metrics ADD COLUMN batch_size BIGINT;
        ALTER TABLE export_metrics ADD COLUMN batch_size_memory_mb BIGINT;
        ALTER TABLE export_metrics ADD COLUMN skip_reason TEXT;
        ALTER TABLE export_metrics ADD COLUMN schema_fingerprint TEXT;
        ALTER TABLE export_metrics ADD COLUMN chunk_size BIGINT;
        ALTER TABLE export_metrics ADD COLUMN parallel BIGINT;
        ALTER TABLE export_metrics ADD COLUMN source_type TEXT;
        ALTER TABLE export_metrics ADD COLUMN destination_type TEXT;
        ALTER TABLE export_metrics ADD COLUMN rivet_version TEXT;",
    ),
    // v10: longest single-chunk wall time (ms). See the SQLite array.
    (
        10,
        "ALTER TABLE export_metrics ADD COLUMN longest_chunk_ms BIGINT;",
    ),
    // v11: per-run source-harm deltas (see the SQLite array for rationale).
    (
        11,
        "CREATE TABLE IF NOT EXISTS export_harm (
            id BIGSERIAL PRIMARY KEY,
            run_id TEXT NOT NULL,
            export_name TEXT NOT NULL,
            metric TEXT NOT NULL,
            delta BIGINT NOT NULL,
            recorded_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_export_harm_run ON export_harm(run_id);",
    ),
];
421
422// ─── SQL helpers ──────────────────────────────────────────────────────────────
423
424/// Convert SQLite `?N` placeholders to PostgreSQL `$N` style.
425/// `"WHERE x = ?1 AND y = ?2"` → `"WHERE x = $1 AND y = $2"`.
426pub(super) fn pg_sql(sql: &str) -> String {
427    let bytes = sql.as_bytes();
428    let mut out = String::with_capacity(sql.len());
429    let mut i = 0;
430    while i < bytes.len() {
431        if bytes[i] == b'?' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
432            out.push('$');
433        } else {
434            out.push(bytes[i] as char);
435        }
436        i += 1;
437    }
438    out
439}
440
441/// Open a Postgres client for the state backend, honoring the URL's `sslmode`.
442///
443/// The state backend connects to its store using only a URL (`RIVET_STATE_URL`)
444/// — there is no YAML `tls:` block — so the transport-security policy is derived
445/// from the URL's `sslmode` query parameter, exactly as `rivet init` does for
446/// source connections. The connection itself goes through the shared
447/// [`crate::source::postgres::connect_client`] path so the state backend and
448/// source connections apply identical TLS rules.
449///
450/// - missing / `disable` / `prefer` / `allow` / unrecognized → `NoTls`
451///   (plaintext), keeping local and dev setups working unchanged.
452/// - `require` / `verify-ca` / `verify-full` → negotiate TLS.
453///
454/// Used by both [`StateStore::open_postgres`] and the parallel chunk-worker
455/// reconnection paths in `checkpoint.rs`, so every PG state connection is
456/// TLS-aware.
457pub(super) fn connect_pg(url: &str) -> Result<postgres::Client> {
458    let tls = state_tls_mode_from_url(url).map(|mode| crate::config::TlsConfig {
459        mode,
460        ..crate::config::TlsConfig::default()
461    });
462    crate::source::postgres::connect_client(url, tls.as_ref())
463        .map_err(|e| anyhow::anyhow!("state(pg): connect to '{}': {:#}", redact_pg_url(url), e))
464}
465
466/// Map the state URL's `sslmode` query parameter to a [`crate::config::TlsMode`].
467///
468/// Mirrors the source-side mapping in `crate::init::postgres`: `require` /
469/// `verify-ca` / `verify-full` enforce TLS; everything else — parameter missing,
470/// `disable`, `prefer`, `allow`, or an unrecognized value — returns `None`
471/// (plaintext `NoTls`). [`crate::config::TlsMode`] has no `prefer` variant, so no
472/// try-TLS-then-fallback is attempted. Last occurrence wins, matching libpq.
473fn state_tls_mode_from_url(url: &str) -> Option<crate::config::TlsMode> {
474    use crate::config::TlsMode;
475    let (_, query) = url.split_once('?')?;
476    let mut mode = None;
477    for pair in query.split('&') {
478        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
479        if key != "sslmode" {
480            continue;
481        }
482        mode = match value {
483            "require" => Some(TlsMode::Require),
484            "verify-ca" => Some(TlsMode::VerifyCa),
485            "verify-full" => Some(TlsMode::VerifyFull),
486            _ => None,
487        };
488    }
489    mode
490}
491
492// ─── Backend connection ────────────────────────────────────────────────────────
493
494/// Internal storage for the active database connection.
495pub(super) enum StateConn {
496    Sqlite(rusqlite::Connection),
497    /// postgres::Client requires `&mut self` for queries; RefCell provides
498    /// interior mutability so `StateStore` methods can keep `&self` signatures.
499    /// StateStore is not Sync (neither backend is), so RefCell is safe here.
500    /// Boxed to keep the enum variant sizes balanced (postgres::Client is ~320 B).
501    Postgres(Box<std::cell::RefCell<postgres::Client>>),
502}
503
504/// Serialisable reference that identifies a state database without holding a
505/// live connection.  Passed to parallel chunk workers so they can open their
506/// own connection for atomic `claim_next_chunk_task` operations.
507#[derive(Clone)]
508pub enum StateRef {
509    Sqlite(std::path::PathBuf),
510    Postgres(String),
511}
512
513// ─── SQLite migration ─────────────────────────────────────────────────────────
514
515fn ensure_schema_version_table(conn: &Connection) {
516    let _ = conn.execute_batch(
517        "CREATE TABLE IF NOT EXISTS schema_version (
518            version INTEGER NOT NULL
519        );",
520    );
521}
522
523fn get_current_version(conn: &Connection) -> i64 {
524    conn.query_row(
525        "SELECT COALESCE(MAX(version), 0) FROM schema_version",
526        [],
527        |row| row.get(0),
528    )
529    .unwrap_or(0)
530}
531
532fn migrate(conn: &Connection) -> Result<()> {
533    ensure_schema_version_table(conn);
534
535    let current = get_current_version(conn);
536
537    if current == 0 {
538        let has_export_state: bool = conn
539            .query_row(
540                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='export_state'",
541                [],
542                |row| row.get(0),
543            )
544            .unwrap_or(false);
545
546        if has_export_state {
547            let metrics_cols = [
548                "files_produced INTEGER DEFAULT 0",
549                "bytes_written INTEGER DEFAULT 0",
550                "retries INTEGER DEFAULT 0",
551                "validated INTEGER",
552                "schema_changed INTEGER",
553                "run_id TEXT",
554            ];
555            for col_def in &metrics_cols {
556                let sql = format!("ALTER TABLE export_metrics ADD COLUMN {}", col_def);
557                let _ = conn.execute(&sql, []);
558            }
559        }
560    }
561
562    for &(ver, sql) in MIGRATIONS {
563        if ver > current {
564            log::debug!("state: applying migration v{}", ver);
565            let atomic_sql = format!(
566                "BEGIN;\n{}\nINSERT INTO schema_version (version) VALUES ({});\nCOMMIT;",
567                sql, ver
568            );
569            conn.execute_batch(&atomic_sql)
570                .map_err(|e| anyhow::anyhow!("state: migration v{} failed: {}", ver, e))?;
571        }
572    }
573
574    let _ = conn.execute(
575        "DELETE FROM schema_version WHERE version < (SELECT MAX(version) FROM schema_version)",
576        [],
577    );
578
579    let final_version = get_current_version(conn);
580    if final_version != SCHEMA_VERSION {
581        anyhow::bail!(
582            "state: migration incomplete — expected schema v{} but reached v{}",
583            SCHEMA_VERSION,
584            final_version
585        );
586    }
587
588    Ok(())
589}
590
591// ─── PostgreSQL migration ─────────────────────────────────────────────────────
592
593fn migrate_pg(client: &mut postgres::Client) -> Result<()> {
594    client
595        .batch_execute("CREATE TABLE IF NOT EXISTS rivet_schema_version (version BIGINT NOT NULL);")
596        .map_err(|e| anyhow::anyhow!("state(pg): create version table: {:#}", e))?;
597
598    let current: i64 = client
599        .query_one(
600            "SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
601            &[],
602        )
603        .map_err(|e| anyhow::anyhow!("state(pg): read schema version: {:#}", e))?
604        .get(0);
605
606    for &(ver, sql) in PG_MIGRATIONS {
607        if ver > current {
608            log::debug!("state(pg): applying migration v{}", ver);
609            let batch = format!(
610                "BEGIN; {} INSERT INTO rivet_schema_version (version) VALUES ({}); COMMIT;",
611                sql, ver
612            );
613            client
614                .batch_execute(&batch)
615                .map_err(|e| anyhow::anyhow!("state(pg): migration v{} failed: {:#}", ver, e))?;
616        }
617    }
618
619    // Remove superseded version rows so MAX() stays unambiguous (mirrors SQLite behaviour).
620    let _ = client.batch_execute(
621        "DELETE FROM rivet_schema_version \
622         WHERE version < (SELECT MAX(version) FROM rivet_schema_version);",
623    );
624
625    // Verify the DB actually reached the expected version.
626    let final_version: i64 = client
627        .query_one(
628            "SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
629            &[],
630        )
631        .map_err(|e| anyhow::anyhow!("state(pg): read final schema version: {:#}", e))?
632        .get(0);
633    if final_version != SCHEMA_VERSION {
634        anyhow::bail!(
635            "state(pg): migration incomplete — expected schema v{} but reached v{}",
636            SCHEMA_VERSION,
637            final_version
638        );
639    }
640
641    Ok(())
642}
643
644/// Redact the password from a PostgreSQL URL for safe use in log/error messages.
645/// `postgresql://user:SECRET@host/db` → `postgresql://user:***@host/db`
646/// Uses `rfind('@')` so passwords containing `@` are handled correctly.
647fn redact_pg_url(url: &str) -> String {
648    if let Some(at_pos) = url.rfind('@')
649        && let Some(scheme_end) = url.find("://")
650    {
651        let authority = &url[scheme_end + 3..at_pos];
652        if let Some(colon) = authority.rfind(':') {
653            let user = &authority[..colon];
654            return format!(
655                "{}://{}:***@{}",
656                &url[..scheme_end],
657                user,
658                &url[at_pos + 1..]
659            );
660        }
661    }
662    url.to_string()
663}
664
665// ─── SQLite connection helper ─────────────────────────────────────────────────
666
667pub(crate) const SQLITE_BUSY_TIMEOUT_MS: i64 = 10_000;
668
669pub(crate) fn open_connection(db_path: &std::path::Path) -> Result<Connection> {
670    let conn = Connection::open(db_path)?;
671    if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL;") {
672        log::warn!(
673            "state: WAL journal mode unavailable ({}); \
674             running in default mode — concurrent writes may be slower",
675            e
676        );
677    }
678    if let Err(e) = conn.execute_batch(&format!(
679        "PRAGMA busy_timeout = {};",
680        SQLITE_BUSY_TIMEOUT_MS
681    )) {
682        log::warn!(
683            "state: failed to set busy_timeout ({}); \
684             concurrent writers may surface SQLITE_BUSY immediately",
685            e
686        );
687    }
688    Ok(conn)
689}
690
691// ─── StateStore ───────────────────────────────────────────────────────────────
692
693/// Entry point for all persistent state.  Supports two backends:
694///
695/// - **SQLite** (default) — a single `.rivet_state.db` file next to the
696///   config.  Good for local / single-node / dev deployments.
697/// - **PostgreSQL** — a shared database addressed by `RIVET_STATE_URL`.
698///   Required for stateless container / Kubernetes deployments where the
699///   rivet pod is ephemeral or replicated.
700///
701/// Set the `RIVET_STATE_URL` environment variable to a PostgreSQL URL to
702/// activate the Postgres backend:
703///
704/// ```text
705/// RIVET_STATE_URL=postgresql://user:pass@host:5432/rivet_state
706/// ```
707///
708/// When the variable is absent or does not start with `postgres`, SQLite is
709/// used and the variable is ignored.
710pub struct StateStore {
711    pub(super) conn: StateConn,
712    /// Serialisable reference for reconnection (parallel chunk workers).
713    pub(super) state_ref: StateRef,
714}
715
716impl StateStore {
717    /// Open the appropriate backend.
718    ///
719    /// Checks `RIVET_STATE_URL`; falls back to SQLite next to `config_path`.
720    pub fn open(config_path: &str) -> Result<Self> {
721        if let Ok(url) = std::env::var("RIVET_STATE_URL")
722            && url.starts_with("postgres")
723        {
724            return Self::open_postgres(&url);
725        }
726        Self::open_sqlite(config_path)
727    }
728
729    fn open_sqlite(config_path: &str) -> Result<Self> {
730        let config_dir = std::path::Path::new(config_path)
731            .parent()
732            .unwrap_or(std::path::Path::new("."));
733        let db_path = config_dir.join(STATE_DB_NAME);
734        let conn = open_connection(&db_path)?;
735        migrate(&conn)?;
736        Ok(Self {
737            conn: StateConn::Sqlite(conn),
738            state_ref: StateRef::Sqlite(db_path),
739        })
740    }
741
742    fn open_postgres(url: &str) -> Result<Self> {
743        let is_local =
744            url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1");
745        if !is_local && state_tls_mode_from_url(url).is_none() {
746            log::warn!(
747                "state(pg): connecting to a remote host without TLS; \
748                 add sslmode=require (or verify-ca / verify-full) to RIVET_STATE_URL \
749                 to negotiate TLS for production use"
750            );
751        }
752        let mut client = connect_pg(url)?;
753        migrate_pg(&mut client)?;
754        Ok(Self {
755            conn: StateConn::Postgres(Box::new(std::cell::RefCell::new(client))),
756            state_ref: StateRef::Postgres(url.to_string()),
757        })
758    }
759
760    /// Path to `.rivet_state.db` for SQLite deployments.  Returns the config
761    /// directory path for Postgres (not meaningful for connection, only used
762    /// by legacy callers — prefer `state_ref()` for new code).
763    pub fn state_db_path(config_path: &str) -> std::path::PathBuf {
764        let config_dir = std::path::Path::new(config_path)
765            .parent()
766            .unwrap_or(std::path::Path::new("."));
767        config_dir.join(STATE_DB_NAME)
768    }
769
770    /// Serialisable connection reference for parallel chunk workers.
771    pub fn state_ref(&self) -> &StateRef {
772        &self.state_ref
773    }
774
775    /// In-memory SQLite store for unit tests.
776    #[allow(dead_code)]
777    pub fn open_in_memory() -> Result<Self> {
778        let conn = Connection::open_in_memory()?;
779        migrate(&conn)?;
780        Ok(Self {
781            conn: StateConn::Sqlite(conn),
782            state_ref: StateRef::Sqlite(std::path::PathBuf::from(":memory:")),
783        })
784    }
785
786    /// Open a SQLite store at an explicit file path (tests that need
787    /// cross-connection access via `claim_next_chunk_task_at_path`).
788    #[allow(dead_code)]
789    pub fn open_at_path(db_path: &std::path::Path) -> Result<Self> {
790        let conn = open_connection(db_path)?;
791        migrate(&conn)?;
792        Ok(Self {
793            conn: StateConn::Sqlite(conn),
794            state_ref: StateRef::Sqlite(db_path.to_path_buf()),
795        })
796    }
797}
798
799// ─── Migration tests ──────────────────────────────────────────────────────────
800
// Unit tests for the SQLite schema migrator (`migrate`, `get_current_version`,
// `SCHEMA_VERSION`) and for the Postgres helpers `pg_sql`, `redact_pg_url` and
// `state_tls_mode_from_url`.
#[cfg(test)]
mod tests {
    use super::*;

    // A brand-new in-memory store is migrated all the way to `SCHEMA_VERSION`.
    #[test]
    fn fresh_db_reaches_latest_version() {
        let s = StateStore::open_in_memory().unwrap();
        let ver = match &s.conn {
            StateConn::Sqlite(c) => get_current_version(c),
            // `open_in_memory` always builds a SQLite connection.
            StateConn::Postgres(_) => unreachable!(),
        };
        assert_eq!(ver, SCHEMA_VERSION);
    }

    // Running `migrate` twice more on an already-current DB must succeed and
    // leave the version at `SCHEMA_VERSION`.
    #[test]
    fn migration_is_idempotent() {
        let s = StateStore::open_in_memory().unwrap();
        match &s.conn {
            StateConn::Sqlite(c) => {
                migrate(c).unwrap();
                migrate(c).unwrap();
                assert_eq!(get_current_version(c), SCHEMA_VERSION);
            }
            StateConn::Postgres(_) => unreachable!(),
        }
    }

    // A hand-built DB that never went through the migrator. It holds only
    // `export_state` and a slimmer `export_metrics` table (fewer columns than
    // v1 defines). `migrate` must still bring it to `SCHEMA_VERSION`. v1's
    // `CREATE TABLE IF NOT EXISTS` statements leave the existing tables in
    // place instead of failing. The `chunk_run` check confirms that later
    // migrations actually ran.
    #[test]
    fn legacy_db_gets_upgraded() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE export_state (
                export_name TEXT PRIMARY KEY,
                last_cursor_value TEXT,
                last_run_at TEXT
            );
            CREATE TABLE export_metrics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                export_name TEXT NOT NULL,
                run_at TEXT NOT NULL,
                duration_ms INTEGER NOT NULL,
                total_rows INTEGER NOT NULL,
                status TEXT NOT NULL
            );",
        )
        .unwrap();

        migrate(&conn).unwrap();
        assert_eq!(get_current_version(&conn), SCHEMA_VERSION);

        let has_chunk_run: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='chunk_run'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(has_chunk_run);
    }

    // After a full migration, the v8 rename is visible: `file_log` and its
    // `idx_file_log_export` index exist, and no `file_manifest` table remains.
    #[test]
    fn v8_renames_file_manifest_to_file_log() {
        let s = StateStore::open_in_memory().unwrap();
        let conn = match &s.conn {
            StateConn::Sqlite(c) => c,
            StateConn::Postgres(_) => unreachable!(),
        };
        let has_file_log: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='file_log'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(has_file_log, "v8 must produce a `file_log` table");
        let has_old: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='file_manifest'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(!has_old, "v8 must remove the old `file_manifest` table");
        let has_new_idx: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='index' AND name='idx_file_log_export'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(has_new_idx, "v8 must create the renamed index");
    }

    #[test]
    fn v8_upgrades_existing_v7_db_with_data() {
        // NOTE(review): despite its name, this test does NOT build a v7 database
        // (one whose table is still named `file_manifest` and already holds rows).
        // The v7 -> v8 rename-with-data path is therefore not exercised here.
        // What it does check: a freshly migrated DB accepts a row in `file_log`
        // and reads it back.
        // TODO: stamp a DB at v7 with a populated `file_manifest` table, run
        // `migrate`, and assert the rows survive in `file_log`.
        let conn = Connection::open_in_memory().unwrap();
        migrate(&conn).unwrap();
        // Insert through the post-v8 table name, then count it back.
        conn.execute(
            "INSERT INTO file_log (run_id, export_name, file_name, row_count, bytes, format, created_at)
             VALUES ('r1', 'orders', 'f.parquet', 100, 4096, 'parquet', '2026-05-21T00:00:00Z')",
            [],
        )
        .unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM file_log", [], |r| r.get(0))
            .unwrap();
        assert_eq!(count, 1);
    }

    // The `run_aggregate` table exists after a full migration.
    #[test]
    fn run_aggregate_table_exists_after_migration() {
        let s = StateStore::open_in_memory().unwrap();
        let conn = match &s.conn {
            StateConn::Sqlite(c) => c,
            StateConn::Postgres(_) => unreachable!(),
        };
        let exists: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='run_aggregate'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(exists, "v5 migration must create the run_aggregate table");
    }

    // `pg_sql` rewrites SQLite-style `?N` placeholders to Postgres `$N`. That
    // includes multi-digit indices. Text without placeholders is left untouched.
    #[test]
    fn pg_sql_converts_placeholders() {
        assert_eq!(
            pg_sql("SELECT ?1, ?2 FROM t WHERE x = ?3"),
            "SELECT $1, $2 FROM t WHERE x = $3"
        );
        assert_eq!(
            pg_sql("INSERT INTO t VALUES (?1, ?2)"),
            "INSERT INTO t VALUES ($1, $2)"
        );
        assert_eq!(pg_sql("no placeholders"), "no placeholders");
        // ?N with two digits
        assert_eq!(pg_sql("?10 AND ?11"), "$10 AND $11");
    }

    // Only the password is masked with `***`; scheme, user, host, port and
    // database name are kept. In the second case the password itself contains
    // `@`, so the userinfo/host split must not stop at the first `@`.
    #[test]
    fn redact_pg_url_removes_password() {
        assert_eq!(
            redact_pg_url("postgresql://rivet:secret123@localhost:5433/rivet_state"),
            "postgresql://rivet:***@localhost:5433/rivet_state"
        );
        assert_eq!(
            redact_pg_url("postgres://admin:p@ssw0rd@db.prod.example.com/state"),
            "postgres://admin:***@db.prod.example.com/state"
        );
    }

    #[test]
    fn redact_pg_url_no_password_unchanged() {
        // URL without a password should come back as-is.
        let url = "postgresql://rivet@localhost/state";
        assert_eq!(redact_pg_url(url), url);
    }

    // ── state(pg) sslmode → TlsMode mapping ─────────────────────────────────
    //
    // Pins the decision behind the TLS bug fix: the state backend can no longer
    // hard-code NoTls. We can't drive a live TLS handshake in a unit test, so we
    // assert the *chosen transport policy* — TLS is enforced for require /
    // verify-* and plaintext (NoTls) otherwise — which is what selects the
    // connector inside `connect_pg` -> `connect_client`.
    use crate::config::TlsMode;

    // `require`, `verify-ca` and `verify-full` each map to their `TlsMode`, and
    // every one of them reports `is_enforced()`.
    #[test]
    fn state_sslmode_enforced_values_negotiate_tls() {
        for (url, want) in [
            (
                "postgresql://u:p@db.prod:5432/state?sslmode=require",
                TlsMode::Require,
            ),
            (
                "postgresql://u:p@db.prod/state?sslmode=verify-ca",
                TlsMode::VerifyCa,
            ),
            (
                "postgresql://u:p@db.prod/state?sslmode=verify-full",
                TlsMode::VerifyFull,
            ),
        ] {
            let mode = state_tls_mode_from_url(url);
            assert_eq!(mode, Some(want), "url: {url}");
            assert!(
                mode.unwrap().is_enforced(),
                "{want:?} must enforce TLS (not NoTls)"
            );
        }
    }

    #[test]
    fn state_sslmode_plaintext_values_stay_notls() {
        // Missing / disable / prefer / allow / unrecognized / uppercase all keep
        // the original NoTls behavior, so dev + docker setups are unchanged.
        // A bare `sslmode` key and an empty value also map to `None` (NoTls).
        for url in [
            "postgresql://u:p@localhost/state",
            "postgresql://u:p@localhost/state?sslmode=disable",
            "postgresql://u:p@db/state?sslmode=prefer",
            "postgresql://u:p@db/state?sslmode=allow",
            "postgresql://u:p@db/state?sslmode=REQUIRE",
            "postgresql://u:p@db/state?sslmode=garbage",
            "postgresql://u:p@db/state?sslmode",
            "postgresql://u:p@db/state?sslmode=",
        ] {
            assert_eq!(state_tls_mode_from_url(url), None, "url: {url}");
        }
    }

    #[test]
    fn state_sslmode_exact_key_and_last_occurrence_wins() {
        // `xsslmode` is a different parameter; the exact `sslmode` key matters.
        assert_eq!(
            state_tls_mode_from_url("postgresql://u:p@db/state?xsslmode=require"),
            None
        );
        // Found among other params.
        assert_eq!(
            state_tls_mode_from_url(
                "postgresql://u:p@db/state?connect_timeout=10&sslmode=require&application_name=x"
            ),
            Some(TlsMode::Require)
        );
        // Last occurrence wins, matching libpq.
        assert_eq!(
            state_tls_mode_from_url("postgresql://u:p@db/state?sslmode=disable&sslmode=require"),
            Some(TlsMode::Require)
        );
        assert_eq!(
            state_tls_mode_from_url("postgresql://u:p@db/state?sslmode=require&sslmode=disable"),
            None
        );
    }
}