use rusqlite::Connection;
use crate::error::{Error, Result};
/// Schema version that `ensure_schema` migrates a database up to.
///
/// `ensure_schema` applies every migration numbered from the stored version + 1
/// through this value, so raising it requires a matching arm in `apply_migration`.
const CURRENT_SCHEMA_VERSION: i32 = 2;
pub fn ensure_schema(conn: &Connection) -> Result<()> {
conn.execute_batch("CREATE TABLE IF NOT EXISTS kg_schema_version (version INTEGER NOT NULL);")?;
let stored: Option<i32> = conn
.query_row("SELECT version FROM kg_schema_version", [], |r| r.get(0))
.ok();
let current_version = match stored {
Some(v) => v,
None => {
if schema_exists(conn)? {
1 } else {
0 }
}
};
if current_version >= CURRENT_SCHEMA_VERSION {
return Ok(()); }
let tx = conn.unchecked_transaction()?;
for v in (current_version + 1)..=CURRENT_SCHEMA_VERSION {
apply_migration(&tx, v)?;
}
tx.execute("DELETE FROM kg_schema_version", [])?;
tx.execute(
"INSERT INTO kg_schema_version (version) VALUES (?1)",
[CURRENT_SCHEMA_VERSION],
)?;
tx.commit()?;
Ok(())
}
/// Alias for `ensure_schema`: forwards the connection and returns its result unchanged.
///
/// NOTE(review): presumably kept so callers using the older name keep working — confirm.
#[inline]
pub fn create_schema(conn: &Connection) -> Result<()> {
ensure_schema(conn)
}
/// Reads the version recorded in `kg_schema_version`, if any.
///
/// Returns `Ok(Some(version))` when the query yields a row. Any failure of the
/// query — including the table not existing or holding no rows — is reported
/// as `Ok(None)`; this function never returns `Err`.
pub fn schema_version(conn: &Connection) -> Result<Option<i32>> {
    let lookup = conn.query_row("SELECT version FROM kg_schema_version", [], |row| {
        row.get::<_, i32>(0)
    });
    match lookup {
        Ok(version) => Ok(Some(version)),
        // Deliberately lenient: every query failure means "no version known".
        Err(_) => Ok(None),
    }
}
/// Reports whether the `kg_entities` table is present in `sqlite_master`.
///
/// # Errors
///
/// Propagates any error raised while preparing or running the lookup.
pub fn schema_exists(conn: &Connection) -> Result<bool> {
    let table_count: i64 = conn.query_row(
        "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='kg_entities'",
        [],
        |row| row.get(0),
    )?;
    Ok(table_count > 0)
}
fn apply_migration(conn: &Connection, version: i32) -> Result<()> {
match version {
1 => migration_v1(conn),
2 => migration_v2(conn),
_ => Err(Error::Other(format!(
"Unknown schema migration version: {}",
version
))),
}
}
/// Migration 1: creates the base tables and indexes.
///
/// Tables: `kg_entities`, `kg_relations`, `kg_vectors`, `kg_hyperedges`,
/// `kg_hyperedge_entities` and `kg_turboquant_cache`. Every statement uses
/// `IF NOT EXISTS`, so objects that are already present are left untouched.
///
/// # Errors
///
/// Propagates any error from executing the batch.
fn migration_v1(conn: &Connection) -> Result<()> {
    // Full version-1 DDL, executed as one batch.
    const V1_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS kg_entities (
id INTEGER PRIMARY KEY AUTOINCREMENT,
entity_type TEXT NOT NULL,
name TEXT NOT NULL,
properties TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
);
CREATE TABLE IF NOT EXISTS kg_relations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
rel_type TEXT NOT NULL,
weight REAL DEFAULT 1.0,
properties TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (source_id) REFERENCES kg_entities(id) ON DELETE CASCADE,
FOREIGN KEY (target_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS kg_vectors (
entity_id INTEGER NOT NULL PRIMARY KEY,
vector BLOB NOT NULL,
dimension INTEGER NOT NULL,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (entity_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_entities_type ON kg_entities(entity_type);
CREATE INDEX IF NOT EXISTS idx_entities_name ON kg_entities(name);
CREATE INDEX IF NOT EXISTS idx_relations_source ON kg_relations(source_id);
CREATE INDEX IF NOT EXISTS idx_relations_target ON kg_relations(target_id);
CREATE INDEX IF NOT EXISTS idx_relations_type ON kg_relations(rel_type);
CREATE TABLE IF NOT EXISTS kg_hyperedges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
hyperedge_type TEXT NOT NULL,
entity_ids TEXT NOT NULL,
weight REAL DEFAULT 1.0,
arity INTEGER NOT NULL,
properties TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
);
CREATE TABLE IF NOT EXISTS kg_hyperedge_entities (
hyperedge_id INTEGER NOT NULL,
entity_id INTEGER NOT NULL,
position INTEGER NOT NULL,
PRIMARY KEY (hyperedge_id, entity_id),
FOREIGN KEY (hyperedge_id) REFERENCES kg_hyperedges(id) ON DELETE CASCADE,
FOREIGN KEY (entity_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_hyperedges_type ON kg_hyperedges(hyperedge_type);
CREATE INDEX IF NOT EXISTS idx_hyperedges_arity ON kg_hyperedges(arity);
CREATE INDEX IF NOT EXISTS idx_he_entities_entity ON kg_hyperedge_entities(entity_id);
CREATE INDEX IF NOT EXISTS idx_he_entities_hyperedge ON kg_hyperedge_entities(hyperedge_id);
CREATE TABLE IF NOT EXISTS kg_turboquant_cache (
id INTEGER PRIMARY KEY CHECK (id = 1),
index_blob BLOB NOT NULL,
vector_count INTEGER NOT NULL
);
"#;

    conn.execute_batch(V1_DDL)?;
    Ok(())
}
fn migration_v2(conn: &Connection) -> Result<()> {
conn.execute_batch(
"ALTER TABLE kg_turboquant_cache \
ADD COLUMN vectors_checksum INTEGER NOT NULL DEFAULT 0;",
)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a throwaway in-memory database for a single test.
    fn memory_db() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// An empty database ends up at the current schema version.
    #[test]
    fn test_fresh_db_reaches_current_version() {
        let db = memory_db();
        ensure_schema(&db).unwrap();

        assert_eq!(schema_version(&db).unwrap(), Some(CURRENT_SCHEMA_VERSION));
    }

    /// Running the migration entry point twice succeeds and keeps the version.
    #[test]
    fn test_idempotent_second_call() {
        let db = memory_db();
        for _ in 0..2 {
            ensure_schema(&db).unwrap();
        }

        assert_eq!(schema_version(&db).unwrap(), Some(CURRENT_SCHEMA_VERSION));
    }

    /// A database holding only the v1 tables (no version table) is upgraded.
    #[test]
    fn test_legacy_db_migrates_from_v1() {
        let db = memory_db();

        // Simulate a pre-versioning database: v1 tables, no version row.
        migration_v1(&db).unwrap();
        assert!(schema_exists(&db).unwrap());
        assert_eq!(schema_version(&db).unwrap(), None);

        ensure_schema(&db).unwrap();
        assert_eq!(schema_version(&db).unwrap(), Some(CURRENT_SCHEMA_VERSION));

        // The column added by migration 2 must now be usable.
        let insert_result = db.execute(
            "INSERT INTO kg_turboquant_cache (id, index_blob, vector_count, vectors_checksum) \
             VALUES (1, X'', 0, 0)",
            [],
        );
        insert_result.unwrap();
    }

    /// Every expected table is present exactly once after migration.
    #[test]
    fn test_all_tables_created() {
        let db = memory_db();
        ensure_schema(&db).unwrap();

        let expected = [
            "kg_entities",
            "kg_relations",
            "kg_vectors",
            "kg_hyperedges",
            "kg_hyperedge_entities",
            "kg_turboquant_cache",
            "kg_schema_version",
        ];
        for table in expected.iter() {
            let found: i64 = db
                .query_row(
                    "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                    [table],
                    |row| row.get(0),
                )
                .unwrap();
            assert_eq!(found, 1, "table {table} should exist");
        }
    }

    /// `create_schema` behaves like `ensure_schema`.
    #[test]
    fn test_create_schema_alias() {
        let db = memory_db();
        create_schema(&db).unwrap();

        assert_eq!(schema_version(&db).unwrap(), Some(CURRENT_SCHEMA_VERSION));
    }
}