use rusqlite::Connection;
use crate::error::{Error, Result};
/// Schema version this build migrates databases up to.
///
/// `ensure_schema` applies every migration step from the database's current
/// version up to and including this value, then stores it in
/// `kg_schema_version`. `apply_migration` must have a step for each version
/// from 1 through this number.
const CURRENT_SCHEMA_VERSION: i32 = 4;
/// Brings the database schema up to [`CURRENT_SCHEMA_VERSION`].
///
/// Creates the `kg_schema_version` bookkeeping table when it is missing,
/// works out which version the database is currently at, and applies every
/// pending migration inside one transaction before recording the new version.
///
/// The starting version is:
/// * the value stored in `kg_schema_version`, if there is one;
/// * `1` when nothing is stored but `kg_entities` already exists (a legacy
///   database without a version row);
/// * `0` for a database with neither.
///
/// Calling this on a database that is already at (or beyond) the current
/// version changes nothing besides ensuring the bookkeeping table exists.
///
/// # Errors
///
/// Returns an error if any SQL statement fails — including the read of the
/// stored version — or if a pending version has no migration step. A failing
/// migration returns before the transaction is committed.
pub fn ensure_schema(conn: &Connection) -> Result<()> {
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS kg_schema_version (version INTEGER NOT NULL);",
    )?;

    // An aggregate query always yields exactly one row (NULL when the table is
    // empty), so "no version recorded" maps to `None` while genuine failures
    // (locked database, I/O error, unreadable value) propagate instead of being
    // mistaken for an unversioned database and re-running migrations.
    let stored: Option<i32> =
        conn.query_row("SELECT MAX(version) FROM kg_schema_version", [], |r| r.get(0))?;

    let current_version = match stored {
        Some(v) => v,
        // No version row: tell a legacy database apart from an empty one.
        None if schema_exists(conn)? => 1,
        None => 0,
    };

    if current_version >= CURRENT_SCHEMA_VERSION {
        return Ok(());
    }

    // All pending steps and the version bump share one transaction.
    let tx = conn.unchecked_transaction()?;
    for v in (current_version + 1)..=CURRENT_SCHEMA_VERSION {
        apply_migration(&tx, v)?;
    }
    // Keep the table at a single row holding the latest version.
    tx.execute("DELETE FROM kg_schema_version", [])?;
    tx.execute(
        "INSERT INTO kg_schema_version (version) VALUES (?1)",
        [CURRENT_SCHEMA_VERSION],
    )?;
    tx.commit()?;
    Ok(())
}
/// Alias for [`ensure_schema`]: forwards `conn` and returns its result
/// unchanged.
#[inline]
pub fn create_schema(conn: &Connection) -> Result<()> {
ensure_schema(conn)
}
/// Returns the schema version recorded in `kg_schema_version`.
///
/// Yields `Ok(None)` when the bookkeeping table does not exist yet or holds
/// no row, and `Ok(Some(v))` with the stored version otherwise.
///
/// # Errors
///
/// Returns an error when the database cannot be queried or the stored value
/// cannot be read as an integer. Such failures are reported rather than being
/// folded into `None`, so an unreadable database is not confused with an
/// unversioned one.
pub fn schema_version(conn: &Connection) -> Result<Option<i32>> {
    // A database that was never migrated has no bookkeeping table at all;
    // that is a legitimate "no version" answer, not an error.
    let table_count: i64 = conn.query_row(
        "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='kg_schema_version'",
        [],
        |r| r.get(0),
    )?;
    if table_count == 0 {
        return Ok(None);
    }

    // MAX() always produces one row (NULL for an empty table), so an empty
    // table maps to `None` and every real failure propagates via `?`.
    let version: Option<i32> =
        conn.query_row("SELECT MAX(version) FROM kg_schema_version", [], |r| r.get(0))?;
    Ok(version)
}
/// Reports whether the `kg_entities` table is present in `sqlite_master`.
///
/// # Errors
///
/// Returns an error if the lookup query cannot be prepared or executed.
pub fn schema_exists(conn: &Connection) -> Result<bool> {
    let entity_tables: i64 = conn.query_row(
        "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='kg_entities'",
        [],
        |found| found.get(0),
    )?;
    Ok(entity_tables > 0)
}
fn apply_migration(conn: &Connection, version: i32) -> Result<()> {
match version {
1 => migration_v1(conn),
2 => migration_v2(conn),
3 => migration_v3(conn),
4 => migration_v4(conn),
_ => Err(Error::Other(format!(
"Unknown schema migration version: {}",
version
))),
}
}
/// Version 1: creates the base tables and their indexes.
///
/// Tables: `kg_entities`, `kg_relations`, `kg_vectors`, `kg_hyperedges`,
/// `kg_hyperedge_entities` and the single-row `kg_turboquant_cache`
/// (`CHECK (id = 1)`). Every statement uses `IF NOT EXISTS`.
///
/// # Errors
///
/// Returns an error if the SQL batch fails to execute.
fn migration_v1(conn: &Connection) -> Result<()> {
    const BASE_SCHEMA_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS kg_entities (
id INTEGER PRIMARY KEY AUTOINCREMENT,
entity_type TEXT NOT NULL,
name TEXT NOT NULL,
properties TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
);
CREATE TABLE IF NOT EXISTS kg_relations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
rel_type TEXT NOT NULL,
weight REAL DEFAULT 1.0,
properties TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (source_id) REFERENCES kg_entities(id) ON DELETE CASCADE,
FOREIGN KEY (target_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS kg_vectors (
entity_id INTEGER NOT NULL PRIMARY KEY,
vector BLOB NOT NULL,
dimension INTEGER NOT NULL,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (entity_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_entities_type ON kg_entities(entity_type);
CREATE INDEX IF NOT EXISTS idx_entities_name ON kg_entities(name);
CREATE INDEX IF NOT EXISTS idx_relations_source ON kg_relations(source_id);
CREATE INDEX IF NOT EXISTS idx_relations_target ON kg_relations(target_id);
CREATE INDEX IF NOT EXISTS idx_relations_type ON kg_relations(rel_type);
CREATE TABLE IF NOT EXISTS kg_hyperedges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
hyperedge_type TEXT NOT NULL,
entity_ids TEXT NOT NULL,
weight REAL DEFAULT 1.0,
arity INTEGER NOT NULL,
properties TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
);
CREATE TABLE IF NOT EXISTS kg_hyperedge_entities (
hyperedge_id INTEGER NOT NULL,
entity_id INTEGER NOT NULL,
position INTEGER NOT NULL,
PRIMARY KEY (hyperedge_id, entity_id),
FOREIGN KEY (hyperedge_id) REFERENCES kg_hyperedges(id) ON DELETE CASCADE,
FOREIGN KEY (entity_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_hyperedges_type ON kg_hyperedges(hyperedge_type);
CREATE INDEX IF NOT EXISTS idx_hyperedges_arity ON kg_hyperedges(arity);
CREATE INDEX IF NOT EXISTS idx_he_entities_entity ON kg_hyperedge_entities(entity_id);
CREATE INDEX IF NOT EXISTS idx_he_entities_hyperedge ON kg_hyperedge_entities(hyperedge_id);
CREATE TABLE IF NOT EXISTS kg_turboquant_cache (
id INTEGER PRIMARY KEY CHECK (id = 1),
index_blob BLOB NOT NULL,
vector_count INTEGER NOT NULL
);
"#;

    conn.execute_batch(BASE_SCHEMA_SQL)?;
    Ok(())
}
/// Version 2: adds the `vectors_checksum` column (`INTEGER NOT NULL DEFAULT 0`)
/// to `kg_turboquant_cache`.
///
/// Unlike the `CREATE ... IF NOT EXISTS` statements of version 1, this
/// `ALTER TABLE` has no existence guard, so it is meant to run once per
/// database.
///
/// # Errors
///
/// Returns an error if the statement fails to execute.
fn migration_v2(conn: &Connection) -> Result<()> {
    const ADD_CHECKSUM_COLUMN_SQL: &str =
        "ALTER TABLE kg_turboquant_cache ADD COLUMN vectors_checksum INTEGER NOT NULL DEFAULT 0;";

    conn.execute_batch(ADD_CHECKSUM_COLUMN_SQL)?;
    Ok(())
}
/// Version 3: confidence, access and validity-window columns plus two new
/// tables.
///
/// * `kg_entities` gains `confidence`, `access_count`, `last_accessed`,
///   `valid_from`, `valid_until`, `base_confidence` and `decay_rate`.
/// * `kg_relations` gains `confidence`, `valid_from` and `valid_until`.
/// * `kg_dependencies` and `kg_confidence_log` are created with their indexes.
///
/// The `ALTER TABLE` statements carry no existence guard, so this step is
/// meant to run once per database.
///
/// # Errors
///
/// Returns an error if the SQL batch fails to execute.
fn migration_v3(conn: &Connection) -> Result<()> {
    const CONFIDENCE_SCHEMA_SQL: &str = r#"
ALTER TABLE kg_entities ADD COLUMN confidence REAL DEFAULT 1.0;
ALTER TABLE kg_entities ADD COLUMN access_count INTEGER DEFAULT 0;
ALTER TABLE kg_entities ADD COLUMN last_accessed INTEGER;
ALTER TABLE kg_entities ADD COLUMN valid_from INTEGER;
ALTER TABLE kg_entities ADD COLUMN valid_until INTEGER;
ALTER TABLE kg_entities ADD COLUMN base_confidence REAL DEFAULT 1.0;
ALTER TABLE kg_entities ADD COLUMN decay_rate REAL DEFAULT 0.05;
ALTER TABLE kg_relations ADD COLUMN confidence REAL DEFAULT 1.0;
ALTER TABLE kg_relations ADD COLUMN valid_from INTEGER;
ALTER TABLE kg_relations ADD COLUMN valid_until INTEGER;
CREATE TABLE IF NOT EXISTS kg_dependencies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
dep_type TEXT NOT NULL,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (source_id) REFERENCES kg_entities(id) ON DELETE CASCADE,
FOREIGN KEY (target_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_deps_source ON kg_dependencies(source_id);
CREATE INDEX IF NOT EXISTS idx_deps_target ON kg_dependencies(target_id);
CREATE TABLE IF NOT EXISTS kg_confidence_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
entity_id INTEGER NOT NULL,
old_value REAL NOT NULL,
new_value REAL NOT NULL,
reason TEXT NOT NULL,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (entity_id) REFERENCES kg_entities(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_conf_log_entity ON kg_confidence_log(entity_id);
CREATE INDEX IF NOT EXISTS idx_conf_log_entity_reason ON kg_confidence_log(entity_id, reason);
"#;

    conn.execute_batch(CONFIDENCE_SCHEMA_SQL)?;
    Ok(())
}
/// Version 4: version/branch bookkeeping.
///
/// Creates `kg_versions` (unique `name`, `branch` defaulting to `'main'`,
/// self-referencing `parent_id`, unique `bit_slot` constrained to 0..=63) with
/// indexes on `branch` and `parent_id`, and adds a nullable `validity` integer
/// column to both `kg_entities` and `kg_relations`.
///
/// NOTE(review): `validity` is presumably a bitmask indexed by `bit_slot` —
/// confirm against the code that reads it.
///
/// # Errors
///
/// Returns an error if the SQL batch fails to execute.
fn migration_v4(conn: &Connection) -> Result<()> {
    const VERSIONING_SCHEMA_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS kg_versions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
branch TEXT NOT NULL DEFAULT 'main',
parent_id INTEGER REFERENCES kg_versions(id) ON DELETE SET NULL,
description TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
is_merged INTEGER NOT NULL DEFAULT 0,
bit_slot INTEGER NOT NULL UNIQUE CHECK (bit_slot BETWEEN 0 AND 63)
);
CREATE INDEX IF NOT EXISTS idx_versions_branch ON kg_versions(branch);
CREATE INDEX IF NOT EXISTS idx_versions_parent ON kg_versions(parent_id);
ALTER TABLE kg_entities ADD COLUMN validity INTEGER DEFAULT NULL;
ALTER TABLE kg_relations ADD COLUMN validity INTEGER DEFAULT NULL;
"#;

    conn.execute_batch(VERSIONING_SCHEMA_SQL)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database and migrates it to the current version.
    fn migrated_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        ensure_schema(&db).unwrap();
        db
    }

    /// A brand-new database ends up at `CURRENT_SCHEMA_VERSION`.
    #[test]
    fn test_fresh_db_reaches_current_version() {
        let db = migrated_db();
        let recorded = schema_version(&db).unwrap();
        assert_eq!(recorded, Some(CURRENT_SCHEMA_VERSION));
    }

    /// Running the migration entry point twice succeeds and keeps the version.
    #[test]
    fn test_idempotent_second_call() {
        let db = migrated_db();
        ensure_schema(&db).unwrap();
        let recorded = schema_version(&db).unwrap();
        assert_eq!(recorded, Some(CURRENT_SCHEMA_VERSION));
    }

    /// A database holding only the v1 tables (no version row) is upgraded.
    #[test]
    fn test_legacy_db_migrates_from_v1() {
        let db = Connection::open_in_memory().unwrap();

        // Build the legacy layout by hand: v1 tables, no version bookkeeping.
        migration_v1(&db).unwrap();
        assert!(schema_exists(&db).unwrap());
        assert_eq!(schema_version(&db).unwrap(), None);

        ensure_schema(&db).unwrap();
        assert_eq!(schema_version(&db).unwrap(), Some(CURRENT_SCHEMA_VERSION));

        // The column added in v2 must be usable after the upgrade.
        db.execute(
            "INSERT INTO kg_turboquant_cache (id, index_blob, vector_count, vectors_checksum) \
             VALUES (1, X'', 0, 0)",
            [],
        )
        .unwrap();
    }

    /// Every table introduced by migrations 1 through 4 exists afterwards.
    #[test]
    fn test_all_tables_created() {
        let db = migrated_db();
        let tables = [
            "kg_entities",
            "kg_relations",
            "kg_vectors",
            "kg_hyperedges",
            "kg_hyperedge_entities",
            "kg_turboquant_cache",
            "kg_schema_version",
            "kg_versions",
            "kg_dependencies",
            "kg_confidence_log",
        ];
        for table in &tables {
            let count: i64 = db
                .query_row(
                    "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                    [table],
                    |row| row.get(0),
                )
                .unwrap();
            assert_eq!(count, 1, "table {table} should exist");
        }
    }

    /// The entity columns added in v3 accept and return values.
    #[test]
    fn test_v3_entity_columns_exist() {
        let db = migrated_db();
        db.execute(
            "INSERT INTO kg_entities \
             (entity_type, name, confidence, access_count, base_confidence, decay_rate) \
             VALUES ('test', 'T', 0.9, 5, 0.9, 0.05)",
            [],
        )
        .unwrap();

        let (confidence, access_count): (f64, i64) = db
            .query_row(
                "SELECT confidence, access_count FROM kg_entities WHERE name = 'T'",
                [],
                |row| Ok((row.get(0)?, row.get(1)?)),
            )
            .unwrap();
        assert!((confidence - 0.9).abs() < 1e-9);
        assert_eq!(access_count, 5);
    }

    /// The confidence log table created in v3 accepts rows.
    #[test]
    fn test_v3_new_tables_writable() {
        let db = migrated_db();
        db.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('a', 'X')",
            [],
        )
        .unwrap();
        let entity_id: i64 = db.last_insert_rowid();

        db.execute(
            "INSERT INTO kg_confidence_log (entity_id, old_value, new_value, reason) \
             VALUES (?1, 1.0, 0.8, 'test')",
            [entity_id],
        )
        .unwrap();

        let logged: i64 = db
            .query_row("SELECT COUNT(*) FROM kg_confidence_log", [], |row| row.get(0))
            .unwrap();
        assert_eq!(logged, 1);
    }

    /// `create_schema` migrates just like `ensure_schema`.
    #[test]
    fn test_create_schema_alias() {
        let db = Connection::open_in_memory().unwrap();
        create_schema(&db).unwrap();
        assert_eq!(schema_version(&db).unwrap(), Some(CURRENT_SCHEMA_VERSION));
    }

    /// The `validity` columns added in v4 exist and default to NULL.
    #[test]
    fn test_v4_validity_columns_exist() {
        let db = migrated_db();

        db.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('test', 'V')",
            [],
        )
        .unwrap();
        let first_id: i64 = db.last_insert_rowid();
        let entity_validity: Option<i64> = db
            .query_row(
                "SELECT validity FROM kg_entities WHERE name = 'V'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(entity_validity, None);

        db.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('test', 'V2')",
            [],
        )
        .unwrap();
        let second_id: i64 = db.last_insert_rowid();

        db.execute(
            "INSERT INTO kg_relations (source_id, target_id, rel_type) VALUES (?1, ?2, 'rel')",
            rusqlite::params![first_id, second_id],
        )
        .unwrap();
        let relation_validity: Option<i64> = db
            .query_row(
                "SELECT validity FROM kg_relations WHERE source_id = ?1",
                [first_id],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(relation_validity, None);
    }

    /// Rows can be written to and read back from `kg_versions`.
    #[test]
    fn test_v4_versions_table_writable() {
        let db = migrated_db();
        db.execute(
            "INSERT INTO kg_versions (name, branch, description, bit_slot) \
             VALUES ('v1', 'main', 'first', 0)",
            [],
        )
        .unwrap();

        let (name, branch, description): (String, String, Option<String>) = db
            .query_row(
                "SELECT name, branch, description FROM kg_versions WHERE name = 'v1'",
                [],
                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
            )
            .unwrap();
        assert_eq!(name, "v1");
        assert_eq!(branch, "main");
        assert_eq!(description.as_deref(), Some("first"));
    }
}