use rusqlite::Connection;
use std::path::Path;
use super::schema;
/// Opens the SQLite database at `path` and applies this store's per-connection PRAGMAs.
///
/// The PRAGMAs are sent as one batch, in this order: WAL journaling,
/// `synchronous = NORMAL`, `cache_size = -64000` (a negative value is a size in KiB,
/// so roughly 64 MB), in-memory temp storage, and a 5000 ms busy timeout.
///
/// # Errors
///
/// Returns the SQLite error if the file cannot be opened or any PRAGMA in the batch fails.
pub fn open_connection(path: &Path) -> rusqlite::Result<Connection> {
    // Same text as a single multi-line literal; `concat!` just keeps one PRAGMA per line.
    const PRAGMAS: &str = concat!(
        "PRAGMA journal_mode = WAL;\n",
        "PRAGMA synchronous = NORMAL;\n",
        "PRAGMA cache_size = -64000;\n",
        "PRAGMA temp_store = MEMORY;\n",
        "PRAGMA busy_timeout = 5000;",
    );

    let conn = Connection::open(path)?;
    // Hand the connection back only if the whole PRAGMA batch succeeded.
    conn.execute_batch(PRAGMAS).map(|()| conn)
}
pub fn run_migrations(conn: &Connection, mcpr_version: &str) -> rusqlite::Result<()> {
let version = get_schema_version(conn);
if version < 1 {
conn.execute_batch(schema::V1_SCHEMA)?;
conn.execute_batch(schema::V1_META_SEED)?;
}
if version < 2 {
conn.execute_batch(schema::V2_SCHEMA)?;
}
if version < 3 {
conn.execute_batch(schema::V3_SCHEMA)?;
}
if version < 4 {
conn.execute_batch(schema::V4_SCHEMA)?;
}
if version < 5 {
conn.execute_batch(schema::V5_SCHEMA)?;
}
conn.execute(schema::UPSERT_MCPR_VERSION, rusqlite::params![mcpr_version])?;
Ok(())
}
fn get_schema_version(conn: &Connection) -> u32 {
let table_exists: bool = conn
.query_row(
"SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = 'table' AND name = 'meta'",
[],
|row| row.get(0),
)
.unwrap_or(false);
if !table_exists {
return 0;
}
conn.query_row(
"SELECT value FROM meta WHERE key = 'schema_version'",
[],
|row| {
let v: String = row.get(0)?;
Ok(v.parse::<u32>().unwrap_or(0))
},
)
.unwrap_or(0)
}
#[cfg(test)]
// Test names follow a `function__scenario` convention, which trips `non_snake_case`.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Opens `test.db` inside `dir` via [`open_connection`].
    ///
    /// `dir` is borrowed, not created here. The caller's `TempDir` must therefore stay
    /// alive for as long as the returned connection is used.
    fn open_temp_db(dir: &tempfile::TempDir) -> Connection {
        let db_path = dir.path().join("test.db");
        open_connection(&db_path).unwrap()
    }

    /// Runs a parameterless query and returns column 0 of its single result row.
    fn scalar<T: rusqlite::types::FromSql>(conn: &Connection, sql: &str) -> T {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    /// Executes one parameterless statement, panicking on failure.
    fn exec(conn: &Connection, sql: &str) {
        conn.execute(sql, []).unwrap();
    }

    /// Applies raw schema batches in order, bypassing `run_migrations`. This lets a test
    /// stop at an intermediate schema version.
    fn apply_batches(conn: &Connection, batches: &[&str]) {
        for sql in batches {
            conn.execute_batch(sql).unwrap();
        }
    }

    #[test]
    fn run_migrations__fresh_db() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        run_migrations(&conn, "0.3.0-test").unwrap();

        let core_tables: i64 = scalar(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name IN ('requests', 'sessions', 'meta')",
        );
        assert_eq!(core_tables, 3);

        let schema_version: String =
            scalar(&conn, "SELECT value FROM meta WHERE key = 'schema_version'");
        assert_eq!(schema_version, "5");

        let mcpr_version: String =
            scalar(&conn, "SELECT value FROM meta WHERE key = 'mcpr_version'");
        assert_eq!(mcpr_version, "0.3.0-test");
    }

    #[test]
    fn run_migrations__idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        run_migrations(&conn, "0.3.0").unwrap();
        // A second run on an up-to-date schema must succeed and refresh mcpr_version.
        run_migrations(&conn, "0.3.1").unwrap();

        let mcpr_version: String =
            scalar(&conn, "SELECT value FROM meta WHERE key = 'mcpr_version'");
        assert_eq!(mcpr_version, "0.3.1");
    }

    #[test]
    fn run_migrations__v3_adds_proxy() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        run_migrations(&conn, "test").unwrap();

        exec(
            &conn,
            concat!(
                "INSERT INTO server_schema (proxy, upstream_url, method, payload, captured_at, schema_hash)\n",
                "VALUES ('search', 'http://localhost:9000', 'tools/list', '{}', 1000, 'abc')",
            ),
        );
        let schema_proxy: String = scalar(
            &conn,
            "SELECT proxy FROM server_schema WHERE upstream_url = 'http://localhost:9000'",
        );
        assert_eq!(schema_proxy, "search");

        exec(
            &conn,
            concat!(
                "INSERT INTO schema_changes (proxy, upstream_url, method, change_type, detected_at)\n",
                "VALUES ('search', 'http://localhost:9000', 'tools/list', 'initial', 1000)",
            ),
        );
        let change_proxy: String = scalar(
            &conn,
            "SELECT proxy FROM schema_changes WHERE upstream_url = 'http://localhost:9000'",
        );
        assert_eq!(change_proxy, "search");

        // Same upstream and method under a different proxy is stored as a separate row.
        exec(
            &conn,
            concat!(
                "INSERT INTO server_schema (proxy, upstream_url, method, payload, captured_at, schema_hash)\n",
                "VALUES ('email', 'http://localhost:9000', 'tools/list', '{}', 2000, 'def')",
            ),
        );
        let rows: i64 = scalar(&conn, "SELECT COUNT(*) FROM server_schema");
        assert_eq!(rows, 2);
    }

    #[test]
    fn run_migrations__v4_renames_latency() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        run_migrations(&conn, "test").unwrap();

        exec(
            &conn,
            concat!(
                "INSERT INTO requests (request_id, ts, proxy, method, latency_us, status)\n",
                "VALUES ('r1', 1000, 'api', 'tools/call', 142000, 'ok')",
            ),
        );
        let latency: i64 =
            scalar(&conn, "SELECT latency_us FROM requests WHERE request_id = 'r1'");
        assert_eq!(latency, 142_000);
    }

    #[test]
    fn run_migrations__v4_converts_ms_to_us() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        // Stop at V3 so the rows below are written with the millisecond column.
        apply_batches(
            &conn,
            &[
                schema::V1_SCHEMA,
                schema::V1_META_SEED,
                schema::V2_SCHEMA,
                schema::V3_SCHEMA,
            ],
        );

        exec(
            &conn,
            concat!(
                "INSERT INTO requests (request_id, ts, proxy, method, latency_ms, status)\n",
                "VALUES ('r1', 1000, 'api', 'tools/call', 42, 'ok')",
            ),
        );
        exec(
            &conn,
            concat!(
                "INSERT INTO requests (request_id, ts, proxy, method, latency_ms, status)\n",
                "VALUES ('r2', 2000, 'api', 'tools/call', 1500, 'ok')",
            ),
        );

        apply_batches(&conn, &[schema::V4_SCHEMA]);

        let latency1: i64 =
            scalar(&conn, "SELECT latency_us FROM requests WHERE request_id = 'r1'");
        assert_eq!(latency1, 42_000);

        let latency2: i64 =
            scalar(&conn, "SELECT latency_us FROM requests WHERE request_id = 'r2'");
        assert_eq!(latency2, 1_500_000);
    }

    #[test]
    fn run_migrations__v4_rebuilds_slow_index() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        run_migrations(&conn, "test").unwrap();

        let idx_sql: String = scalar(
            &conn,
            "SELECT sql FROM sqlite_master WHERE type = 'index' AND name = 'idx_requests_slow'",
        );
        assert!(idx_sql.contains("latency_us"));
    }

    #[test]
    fn run_migrations__v3_defaults_proxy() {
        let dir = tempfile::tempdir().unwrap();
        let conn = open_temp_db(&dir);
        // Populate a V2 database (no proxy column), then upgrade it to V3.
        apply_batches(
            &conn,
            &[schema::V1_SCHEMA, schema::V1_META_SEED, schema::V2_SCHEMA],
        );
        exec(
            &conn,
            concat!(
                "INSERT INTO server_schema (upstream_url, method, payload, captured_at, schema_hash)\n",
                "VALUES ('http://localhost:9000', 'tools/list', '{}', 1000, 'abc')",
            ),
        );
        apply_batches(&conn, &[schema::V3_SCHEMA]);

        let proxy: String = scalar(
            &conn,
            "SELECT proxy FROM server_schema WHERE upstream_url = 'http://localhost:9000'",
        );
        assert_eq!(proxy, "default");
    }
}