use rusqlite::Connection;
use crate::error::Result;
mod checkpoint;
mod cursor;
mod file_log;
mod journal_store;
mod metrics;
mod progression;
mod run_aggregate;
mod schema;
mod shape;
#[allow(unused_imports)]
pub use checkpoint::ChunkTaskInfo;
#[allow(unused_imports)]
pub use file_log::FileRecord;
#[allow(unused_imports)]
pub use metrics::ExportMetric;
#[allow(unused_imports)]
pub use progression::{Boundary, ExportProgression};
#[allow(unused_imports)]
pub use run_aggregate::{RunAggregate, RunAggregateEntry};
#[allow(unused_imports)]
pub use schema::{SchemaChange, SchemaColumn, arrow_schema_to_columns, schema_fingerprint};
#[allow(unused_imports)]
pub use shape::ShapeWarning;
/// File name of the SQLite state database; it is placed in the same directory
/// as the config file (see `StateStore::state_db_path`).
const STATE_DB_NAME: &str = ".rivet_state.db";
/// Newest schema version this binary knows about: the version number of the
/// final entry in `MIGRATIONS`. Both the SQLite and the Postgres migrators
/// check that the database ends up at exactly this version.
const SCHEMA_VERSION: i64 = {
    let last_index = MIGRATIONS.len() - 1;
    MIGRATIONS[last_index].0
};
/// Ordered SQLite schema migrations as `(version, sql_batch)` pairs.
///
/// The migrator runs every batch whose version is greater than the highest
/// version recorded in `schema_version`, in list order, and records each
/// applied version. A batch is therefore never re-run once recorded: do not
/// edit a shipped entry — append a new version instead. `SCHEMA_VERSION` is
/// derived from the last entry of this list.
const MIGRATIONS: &[(i64, &str)] = &[
// v1: base tables — per-export cursor state, per-run metrics, last seen
// schema, and the per-file manifest (renamed to `file_log` in v8).
(
1,
"CREATE TABLE IF NOT EXISTS export_state (
export_name TEXT PRIMARY KEY,
last_cursor_value TEXT,
last_run_at TEXT
);
CREATE TABLE IF NOT EXISTS export_metrics (
id INTEGER PRIMARY KEY AUTOINCREMENT,
export_name TEXT NOT NULL,
run_at TEXT NOT NULL,
duration_ms INTEGER NOT NULL,
total_rows INTEGER NOT NULL,
peak_rss_mb INTEGER,
status TEXT NOT NULL,
error_message TEXT,
tuning_profile TEXT,
format TEXT,
mode TEXT,
files_produced INTEGER DEFAULT 0,
bytes_written INTEGER DEFAULT 0,
retries INTEGER DEFAULT 0,
validated INTEGER,
schema_changed INTEGER,
run_id TEXT
);
CREATE TABLE IF NOT EXISTS export_schema (
export_name TEXT PRIMARY KEY,
columns_json TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS file_manifest (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
export_name TEXT NOT NULL,
file_name TEXT NOT NULL,
row_count INTEGER NOT NULL,
bytes INTEGER NOT NULL,
format TEXT NOT NULL,
compression TEXT,
created_at TEXT NOT NULL
);",
),
// v2: chunked-run bookkeeping — one `chunk_run` row per run and one
// `chunk_task` row per (run_id, chunk_index). Presumably backs the
// `checkpoint` module — confirm.
(
2,
"CREATE TABLE IF NOT EXISTS chunk_run (
run_id TEXT PRIMARY KEY,
export_name TEXT NOT NULL,
plan_hash TEXT NOT NULL,
status TEXT NOT NULL,
max_chunk_attempts INTEGER NOT NULL DEFAULT 3,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
ON chunk_run(export_name, status);
CREATE TABLE IF NOT EXISTS chunk_task (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
chunk_index INTEGER NOT NULL,
start_key TEXT NOT NULL,
end_key TEXT NOT NULL,
status TEXT NOT NULL,
attempts INTEGER NOT NULL DEFAULT 0,
last_error TEXT,
rows_written INTEGER,
file_name TEXT,
updated_at TEXT NOT NULL,
UNIQUE(run_id, chunk_index)
);
CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
),
// v3: index for newest-first manifest lookups per export (dropped and
// recreated under a new name in v8).
(
3,
"CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
),
// v4: per-export progression — last committed and last verified position.
(
4,
"CREATE TABLE IF NOT EXISTS export_progression (
export_name TEXT PRIMARY KEY,
last_committed_strategy TEXT,
last_committed_cursor TEXT,
last_committed_chunk_index INTEGER,
last_committed_run_id TEXT,
last_committed_at TEXT,
last_verified_strategy TEXT,
last_verified_cursor TEXT,
last_verified_chunk_index INTEGER,
last_verified_run_id TEXT,
last_verified_at TEXT
);",
),
// v5: one summary row per multi-export invocation, indexed by finish time.
(
5,
"CREATE TABLE IF NOT EXISTS run_aggregate (
run_aggregate_id TEXT PRIMARY KEY,
started_at TEXT NOT NULL,
finished_at TEXT NOT NULL,
duration_ms INTEGER NOT NULL,
config_path TEXT,
parallel_mode TEXT NOT NULL,
total_exports INTEGER NOT NULL,
success_count INTEGER NOT NULL,
failed_count INTEGER NOT NULL,
skipped_count INTEGER NOT NULL,
total_rows INTEGER NOT NULL,
total_files INTEGER NOT NULL,
total_bytes INTEGER NOT NULL,
details_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
ON run_aggregate(finished_at DESC);",
),
// v6: observed maximum byte length per (export, column).
(
6,
"CREATE TABLE IF NOT EXISTS export_shape (
export_name TEXT NOT NULL,
column_name TEXT NOT NULL,
max_byte_len INTEGER NOT NULL,
updated_at TEXT NOT NULL,
PRIMARY KEY (export_name, column_name)
);",
),
// v7: serialized per-run journal, indexed newest-first per export.
(
7,
"CREATE TABLE IF NOT EXISTS run_journal (
run_id TEXT PRIMARY KEY,
export_name TEXT NOT NULL,
finished_at TEXT NOT NULL,
journal_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_run_journal_export
ON run_journal(export_name, finished_at DESC);",
),
// v8: rename `file_manifest` -> `file_log` and swap the index name.
// Unlike the earlier entries this is NOT idempotent: the RENAME fails if
// `file_manifest` is absent or `file_log` already exists.
(
8,
"ALTER TABLE file_manifest RENAME TO file_log;
DROP INDEX IF EXISTS idx_file_manifest_export;
CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
),
];
/// Ordered Postgres schema migrations as `(version, sql_batch)` pairs.
///
/// Version-for-version counterparts of `MIGRATIONS`, with column types
/// adapted for Postgres (`BIGSERIAL`/`BIGINT` for integers, `BOOLEAN` for the
/// flag columns). `SCHEMA_VERSION` is derived from the SQLite list only, and
/// the Postgres migrator checks against it, so this list must end at the same
/// version. Never edit a shipped entry — append a new version instead.
const PG_MIGRATIONS: &[(i64, &str)] = &[
// v1: base tables (see the SQLite v1 entry).
(
1,
"CREATE TABLE IF NOT EXISTS export_state (
export_name TEXT PRIMARY KEY,
last_cursor_value TEXT,
last_run_at TEXT
);
CREATE TABLE IF NOT EXISTS export_metrics (
id BIGSERIAL PRIMARY KEY,
export_name TEXT NOT NULL,
run_at TEXT NOT NULL,
duration_ms BIGINT NOT NULL,
total_rows BIGINT NOT NULL,
peak_rss_mb BIGINT,
status TEXT NOT NULL,
error_message TEXT,
tuning_profile TEXT,
format TEXT,
mode TEXT,
files_produced BIGINT DEFAULT 0,
bytes_written BIGINT DEFAULT 0,
retries BIGINT DEFAULT 0,
validated BOOLEAN,
schema_changed BOOLEAN,
run_id TEXT
);
CREATE TABLE IF NOT EXISTS export_schema (
export_name TEXT PRIMARY KEY,
columns_json TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS file_manifest (
id BIGSERIAL PRIMARY KEY,
run_id TEXT NOT NULL,
export_name TEXT NOT NULL,
file_name TEXT NOT NULL,
row_count BIGINT NOT NULL,
bytes BIGINT NOT NULL,
format TEXT NOT NULL,
compression TEXT,
created_at TEXT NOT NULL
);",
),
// v2: chunk_run / chunk_task bookkeeping tables.
(
2,
"CREATE TABLE IF NOT EXISTS chunk_run (
run_id TEXT PRIMARY KEY,
export_name TEXT NOT NULL,
plan_hash TEXT NOT NULL,
status TEXT NOT NULL,
max_chunk_attempts BIGINT NOT NULL DEFAULT 3,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_chunk_run_export_status
ON chunk_run(export_name, status);
CREATE TABLE IF NOT EXISTS chunk_task (
id BIGSERIAL PRIMARY KEY,
run_id TEXT NOT NULL,
chunk_index BIGINT NOT NULL,
start_key TEXT NOT NULL,
end_key TEXT NOT NULL,
status TEXT NOT NULL,
attempts BIGINT NOT NULL DEFAULT 0,
last_error TEXT,
rows_written BIGINT,
file_name TEXT,
updated_at TEXT NOT NULL,
UNIQUE(run_id, chunk_index)
);
CREATE INDEX IF NOT EXISTS idx_chunk_task_run_status ON chunk_task(run_id, status);",
),
// v3: newest-first manifest index per export.
(
3,
"CREATE INDEX IF NOT EXISTS idx_file_manifest_export ON file_manifest(export_name, id DESC);",
),
// v4: per-export committed/verified progression.
(
4,
"CREATE TABLE IF NOT EXISTS export_progression (
export_name TEXT PRIMARY KEY,
last_committed_strategy TEXT,
last_committed_cursor TEXT,
last_committed_chunk_index BIGINT,
last_committed_run_id TEXT,
last_committed_at TEXT,
last_verified_strategy TEXT,
last_verified_cursor TEXT,
last_verified_chunk_index BIGINT,
last_verified_run_id TEXT,
last_verified_at TEXT
);",
),
// v5: run_aggregate summary table plus finish-time index.
(
5,
"CREATE TABLE IF NOT EXISTS run_aggregate (
run_aggregate_id TEXT PRIMARY KEY,
started_at TEXT NOT NULL,
finished_at TEXT NOT NULL,
duration_ms BIGINT NOT NULL,
config_path TEXT,
parallel_mode TEXT NOT NULL,
total_exports BIGINT NOT NULL,
success_count BIGINT NOT NULL,
failed_count BIGINT NOT NULL,
skipped_count BIGINT NOT NULL,
total_rows BIGINT NOT NULL,
total_files BIGINT NOT NULL,
total_bytes BIGINT NOT NULL,
details_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_run_aggregate_finished
ON run_aggregate(finished_at DESC);",
),
// v6: observed maximum byte length per (export, column).
(
6,
"CREATE TABLE IF NOT EXISTS export_shape (
export_name TEXT NOT NULL,
column_name TEXT NOT NULL,
max_byte_len BIGINT NOT NULL,
updated_at TEXT NOT NULL,
PRIMARY KEY (export_name, column_name)
);",
),
// v7: serialized per-run journal.
(
7,
"CREATE TABLE IF NOT EXISTS run_journal (
run_id TEXT PRIMARY KEY,
export_name TEXT NOT NULL,
finished_at TEXT NOT NULL,
journal_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_run_journal_export
ON run_journal(export_name, finished_at DESC);",
),
// v8: rename `file_manifest` -> `file_log` and swap the index name.
// Not idempotent: the RENAME fails if `file_manifest` is absent.
// NOTE(review): the BIGSERIAL sequence and primary-key index likely keep
// their `file_manifest_*` names after the rename — harmless, but confirm.
(
8,
"ALTER TABLE file_manifest RENAME TO file_log;
DROP INDEX IF EXISTS idx_file_manifest_export;
CREATE INDEX IF NOT EXISTS idx_file_log_export ON file_log(export_name, id DESC);",
),
];
/// Rewrites SQLite-style numbered placeholders (`?1`, `?2`, …) into
/// Postgres-style ones (`$1`, `$2`, …).
///
/// Only a `?` that is immediately followed by an ASCII digit is rewritten; a
/// bare `?` is copied through unchanged. The scan works on `char`s so
/// multi-byte UTF-8 text (non-ASCII literals or comments) is preserved
/// exactly.
///
/// Note: the scan is not SQL-aware, so a `?<digit>` sequence inside a quoted
/// string literal is rewritten as well.
pub(super) fn pg_sql(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '?' && chars.peek().is_some_and(|next| next.is_ascii_digit()) {
            out.push('$');
        } else {
            out.push(c);
        }
    }
    out
}
/// Live connection to whichever backend holds the state.
pub(super) enum StateConn {
    /// Local SQLite connection (file-backed or in-memory).
    Sqlite(rusqlite::Connection),
    /// Postgres client wrapped in a boxed `RefCell`.
    ///
    /// NOTE(review): the `RefCell` presumably lets `&self` methods obtain the
    /// `&mut postgres::Client` that query calls need, and the `Box` keeps the
    /// enum small — confirm against the sibling modules.
    Postgres(Box<std::cell::RefCell<postgres::Client>>),
}
/// Describes where a state store keeps its data.
///
/// The `Postgres` variant holds the full connection URL, which may embed a
/// password; take care before printing it.
#[derive(Clone)]
pub enum StateRef {
    /// Path of the SQLite database file (`:memory:` for in-memory stores).
    Sqlite(std::path::PathBuf),
    /// Connection URL exactly as read from `RIVET_STATE_URL` (not redacted).
    Postgres(String),
}
/// Creates the `schema_version` bookkeeping table if it does not exist yet.
///
/// Best effort: any error is discarded here, and nothing is reported.
fn ensure_schema_version_table(conn: &Connection) {
    const CREATE_VERSION_TABLE: &str =
        "CREATE TABLE IF NOT EXISTS schema_version (\nversion INTEGER NOT NULL\n);";
    if let Err(_ignored) = conn.execute_batch(CREATE_VERSION_TABLE) {
        // Intentionally swallowed: a later read of `schema_version` will fail
        // if the table really could not be created.
    }
}
/// Returns the highest version recorded in `schema_version`.
///
/// An empty table yields `0`. Any query error is also mapped to `0`, so a
/// failure cannot be told apart from a fresh database with this helper.
fn get_current_version(conn: &Connection) -> i64 {
    const MAX_VERSION_QUERY: &str = "SELECT COALESCE(MAX(version), 0) FROM schema_version";
    conn.query_row(MAX_VERSION_QUERY, [], |row| row.get(0))
        .unwrap_or_default()
}
fn migrate(conn: &Connection) -> Result<()> {
ensure_schema_version_table(conn);
let current = get_current_version(conn);
if current == 0 {
let has_export_state: bool = conn
.query_row(
"SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='export_state'",
[],
|row| row.get(0),
)
.unwrap_or(false);
if has_export_state {
let metrics_cols = [
"files_produced INTEGER DEFAULT 0",
"bytes_written INTEGER DEFAULT 0",
"retries INTEGER DEFAULT 0",
"validated INTEGER",
"schema_changed INTEGER",
"run_id TEXT",
];
for col_def in &metrics_cols {
let sql = format!("ALTER TABLE export_metrics ADD COLUMN {}", col_def);
let _ = conn.execute(&sql, []);
}
}
}
for &(ver, sql) in MIGRATIONS {
if ver > current {
log::debug!("state: applying migration v{}", ver);
let atomic_sql = format!(
"BEGIN;\n{}\nINSERT INTO schema_version (version) VALUES ({});\nCOMMIT;",
sql, ver
);
conn.execute_batch(&atomic_sql)
.map_err(|e| anyhow::anyhow!("state: migration v{} failed: {}", ver, e))?;
}
}
let _ = conn.execute(
"DELETE FROM schema_version WHERE version < (SELECT MAX(version) FROM schema_version)",
[],
);
let final_version = get_current_version(conn);
if final_version != SCHEMA_VERSION {
anyhow::bail!(
"state: migration incomplete — expected schema v{} but reached v{}",
SCHEMA_VERSION,
final_version
);
}
Ok(())
}
fn migrate_pg(client: &mut postgres::Client) -> Result<()> {
client
.batch_execute("CREATE TABLE IF NOT EXISTS rivet_schema_version (version BIGINT NOT NULL);")
.map_err(|e| anyhow::anyhow!("state(pg): create version table: {:#}", e))?;
let current: i64 = client
.query_one(
"SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
&[],
)
.map_err(|e| anyhow::anyhow!("state(pg): read schema version: {:#}", e))?
.get(0);
for &(ver, sql) in PG_MIGRATIONS {
if ver > current {
log::debug!("state(pg): applying migration v{}", ver);
let batch = format!(
"BEGIN; {} INSERT INTO rivet_schema_version (version) VALUES ({}); COMMIT;",
sql, ver
);
client
.batch_execute(&batch)
.map_err(|e| anyhow::anyhow!("state(pg): migration v{} failed: {:#}", ver, e))?;
}
}
let _ = client.batch_execute(
"DELETE FROM rivet_schema_version \
WHERE version < (SELECT MAX(version) FROM rivet_schema_version);",
);
let final_version: i64 = client
.query_one(
"SELECT COALESCE(MAX(version), 0) FROM rivet_schema_version",
&[],
)
.map_err(|e| anyhow::anyhow!("state(pg): read final schema version: {:#}", e))?
.get(0);
if final_version != SCHEMA_VERSION {
anyhow::bail!(
"state(pg): migration incomplete — expected schema v{} but reached v{}",
SCHEMA_VERSION,
final_version
);
}
Ok(())
}
/// Returns `url` with the password in its userinfo replaced by `***`.
///
/// The userinfo is everything between `://` and the LAST `@`, so passwords
/// that themselves contain `@` are fully masked. Within the userinfo, the
/// FIRST `:` separates user from password (RFC 3986), so a password
/// containing `:` is fully masked too. URLs without a password, without
/// `://`, or with an `@` before `://` are returned unchanged, never panicking.
fn redact_pg_url(url: &str) -> String {
    let (Some(at_pos), Some(scheme_end)) = (url.rfind('@'), url.find("://")) else {
        return url.to_string();
    };
    // `get` rather than slicing: an '@' before "://" gives an inverted range.
    let Some(userinfo) = url.get(scheme_end + 3..at_pos) else {
        return url.to_string();
    };
    match userinfo.find(':') {
        Some(colon) => format!(
            "{}://{}:***@{}",
            &url[..scheme_end],
            &userinfo[..colon],
            &url[at_pos + 1..]
        ),
        None => url.to_string(),
    }
}
/// Milliseconds a SQLite connection waits on a locked database before
/// reporting `SQLITE_BUSY` (ten seconds).
pub(crate) const SQLITE_BUSY_TIMEOUT_MS: i64 = 10 * 1_000;
/// Opens the SQLite database at `db_path` and applies connection settings.
///
/// The busy timeout is set first. Switching the journal mode needs a write
/// lock, and with the timeout in place it waits for a concurrent writer
/// instead of failing immediately. WAL mode is then requested and the mode
/// SQLite reports back is checked. Per SQLite, `PRAGMA journal_mode` returns
/// the unchanged mode, rather than an error, when WAL cannot be enabled.
/// Failing to apply either setting only logs a warning.
///
/// # Errors
/// Fails only if the database file cannot be opened.
pub(crate) fn open_connection(db_path: &std::path::Path) -> Result<Connection> {
    let conn = Connection::open(db_path)?;
    if let Err(e) = conn.execute_batch(&format!(
        "PRAGMA busy_timeout = {};",
        SQLITE_BUSY_TIMEOUT_MS
    )) {
        log::warn!(
            "state: failed to set busy_timeout ({}); \
             concurrent writers may surface SQLITE_BUSY immediately",
            e
        );
    }
    match conn.query_row("PRAGMA journal_mode=WAL", [], |row| row.get::<_, String>(0)) {
        Ok(mode) if mode.eq_ignore_ascii_case("wal") => {}
        Ok(mode) => log::warn!(
            "state: WAL journal mode unavailable (journal_mode is '{}'); \
             running in default mode — concurrent writes may be slower",
            mode
        ),
        Err(e) => log::warn!(
            "state: WAL journal mode unavailable ({}); \
             running in default mode — concurrent writes may be slower",
            e
        ),
    }
    Ok(conn)
}
/// Handle to rivet's persistent state, backed by SQLite or Postgres.
///
/// All constructors in this module migrate the schema to the latest version
/// before returning the store.
pub struct StateStore {
    /// Open backend connection.
    pub(super) conn: StateConn,
    /// Where the backend lives (database path or connection URL).
    pub(super) state_ref: StateRef,
}
impl StateStore {
pub fn open(config_path: &str) -> Result<Self> {
if let Ok(url) = std::env::var("RIVET_STATE_URL")
&& url.starts_with("postgres")
{
return Self::open_postgres(&url);
}
Self::open_sqlite(config_path)
}
fn open_sqlite(config_path: &str) -> Result<Self> {
let config_dir = std::path::Path::new(config_path)
.parent()
.unwrap_or(std::path::Path::new("."));
let db_path = config_dir.join(STATE_DB_NAME);
let conn = open_connection(&db_path)?;
migrate(&conn)?;
Ok(Self {
conn: StateConn::Sqlite(conn),
state_ref: StateRef::Sqlite(db_path),
})
}
fn open_postgres(url: &str) -> Result<Self> {
let is_local =
url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1");
if !is_local {
log::warn!(
"state(pg): connecting to a remote host without TLS; \
set RIVET_STATE_URL to a sslmode=require URL for production use"
);
}
let mut client = postgres::Client::connect(url, postgres::NoTls).map_err(|e| {
anyhow::anyhow!("state(pg): connect to '{}': {:#}", redact_pg_url(url), e)
})?;
migrate_pg(&mut client)?;
Ok(Self {
conn: StateConn::Postgres(Box::new(std::cell::RefCell::new(client))),
state_ref: StateRef::Postgres(url.to_string()),
})
}
pub fn state_db_path(config_path: &str) -> std::path::PathBuf {
let config_dir = std::path::Path::new(config_path)
.parent()
.unwrap_or(std::path::Path::new("."));
config_dir.join(STATE_DB_NAME)
}
pub fn state_ref(&self) -> &StateRef {
&self.state_ref
}
#[allow(dead_code)]
pub fn open_in_memory() -> Result<Self> {
let conn = Connection::open_in_memory()?;
migrate(&conn)?;
Ok(Self {
conn: StateConn::Sqlite(conn),
state_ref: StateRef::Sqlite(std::path::PathBuf::from(":memory:")),
})
}
#[allow(dead_code)]
pub fn open_at_path(db_path: &std::path::Path) -> Result<Self> {
let conn = open_connection(db_path)?;
migrate(&conn)?;
Ok(Self {
conn: StateConn::Sqlite(conn),
state_ref: StateRef::Sqlite(db_path.to_path_buf()),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fresh_db_reaches_latest_version() {
        let s = StateStore::open_in_memory().unwrap();
        let ver = match &s.conn {
            StateConn::Sqlite(c) => get_current_version(c),
            StateConn::Postgres(_) => unreachable!(),
        };
        assert_eq!(ver, SCHEMA_VERSION);
    }

    #[test]
    fn migration_is_idempotent() {
        let s = StateStore::open_in_memory().unwrap();
        match &s.conn {
            StateConn::Sqlite(c) => {
                migrate(c).unwrap();
                migrate(c).unwrap();
                assert_eq!(get_current_version(c), SCHEMA_VERSION);
            }
            StateConn::Postgres(_) => unreachable!(),
        }
    }

    #[test]
    fn legacy_db_gets_upgraded() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE export_state (
export_name TEXT PRIMARY KEY,
last_cursor_value TEXT,
last_run_at TEXT
);
CREATE TABLE export_metrics (
id INTEGER PRIMARY KEY AUTOINCREMENT,
export_name TEXT NOT NULL,
run_at TEXT NOT NULL,
duration_ms INTEGER NOT NULL,
total_rows INTEGER NOT NULL,
status TEXT NOT NULL
);",
        )
        .unwrap();
        migrate(&conn).unwrap();
        assert_eq!(get_current_version(&conn), SCHEMA_VERSION);
        let has_chunk_run: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='chunk_run'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(has_chunk_run);
    }

    #[test]
    fn v8_renames_file_manifest_to_file_log() {
        let s = StateStore::open_in_memory().unwrap();
        let conn = match &s.conn {
            StateConn::Sqlite(c) => c,
            StateConn::Postgres(_) => unreachable!(),
        };
        let has_file_log: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='file_log'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(has_file_log, "v8 must produce a `file_log` table");
        let has_old: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='file_manifest'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(!has_old, "v8 must remove the old `file_manifest` table");
        let has_new_idx: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='index' AND name='idx_file_log_export'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(has_new_idx, "v8 must create the renamed index");
    }

    #[test]
    fn v8_upgrades_existing_v7_db_with_data() {
        // Build a genuine v7 database (with `file_manifest` still present) by
        // replaying migrations 1..=7 by hand, then seed it with a row.
        let conn = Connection::open_in_memory().unwrap();
        ensure_schema_version_table(&conn);
        for &(ver, sql) in MIGRATIONS.iter().filter(|(v, _)| *v <= 7) {
            conn.execute_batch(sql).unwrap();
            conn.execute("INSERT INTO schema_version (version) VALUES (?1)", [ver])
                .unwrap();
        }
        assert_eq!(get_current_version(&conn), 7);
        conn.execute(
            "INSERT INTO file_manifest (run_id, export_name, file_name, row_count, bytes, format, created_at)
VALUES ('r1', 'orders', 'f.parquet', 100, 4096, 'parquet', '2026-05-21T00:00:00Z')",
            [],
        )
        .unwrap();

        migrate(&conn).unwrap();

        assert_eq!(get_current_version(&conn), SCHEMA_VERSION);
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM file_log WHERE run_id = 'r1'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "rows written before v8 must survive the rename");
    }

    #[test]
    fn run_aggregate_table_exists_after_migration() {
        let s = StateStore::open_in_memory().unwrap();
        let conn = match &s.conn {
            StateConn::Sqlite(c) => c,
            StateConn::Postgres(_) => unreachable!(),
        };
        let exists: bool = conn
            .query_row(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name='run_aggregate'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(exists, "v5 migration must create the run_aggregate table");
    }

    #[test]
    fn pg_sql_converts_placeholders() {
        assert_eq!(
            pg_sql("SELECT ?1, ?2 FROM t WHERE x = ?3"),
            "SELECT $1, $2 FROM t WHERE x = $3"
        );
        assert_eq!(
            pg_sql("INSERT INTO t VALUES (?1, ?2)"),
            "INSERT INTO t VALUES ($1, $2)"
        );
        assert_eq!(pg_sql("no placeholders"), "no placeholders");
        assert_eq!(pg_sql("?10 AND ?11"), "$10 AND $11");
    }

    #[test]
    fn redact_pg_url_removes_password() {
        assert_eq!(
            redact_pg_url("postgresql://rivet:secret123@localhost:5433/rivet_state"),
            "postgresql://rivet:***@localhost:5433/rivet_state"
        );
        assert_eq!(
            redact_pg_url("postgres://admin:p@ssw0rd@db.prod.example.com/state"),
            "postgres://admin:***@db.prod.example.com/state"
        );
    }

    #[test]
    fn redact_pg_url_no_password_unchanged() {
        let url = "postgresql://rivet@localhost/state";
        assert_eq!(redact_pg_url(url), url);
    }
}