use rusqlite::Connection;
use crate::error::SqliteError;
/// One step of a service's schema plan, applied (SQLite side) by [`apply_schema_plan`].
pub struct Migration {
/// Identifier stored in `_schema_versions.migration_id` once applied; uniqueness is checked
/// per service (together with [`ServiceSchemaPlan::service`]).
pub id: &'static str,
/// SQL batch executed with `execute_batch` to apply this migration.
pub up_sql: &'static str,
/// Optional SQL to revert this migration. NOTE(review): not read by any code in this file.
pub down_sql: Option<&'static str>,
/// Optional probe run before the `_schema_versions` lookup. When it returns `true` the
/// migration is skipped and is NOT recorded in `_schema_versions`, so the probe runs again on
/// every call.
pub is_already_applied: Option<fn(&Connection) -> bool>,
}
/// The migrations owned by one service, with separate lists for SQLite and Postgres.
pub struct ServiceSchemaPlan {
/// Service name; scopes rows in `_schema_versions` via its `service` column.
pub service: &'static str,
/// SQLite migrations, applied in slice order by [`apply_schema_plan`].
pub sqlite: &'static [Migration],
/// Postgres migrations. NOTE(review): not consumed by any code in this file.
pub postgres: &'static [Migration],
}
/// DDL for the `_schema_versions` table used by [`apply_schema_plan`]. It is executed at the
/// start of every call, so it presumably uses `CREATE TABLE IF NOT EXISTS` — confirm in the SQL file.
const SCHEMA_VERSION_TABLE: &str = include_str!("../sql/schema-version-table.sql");
/// Applies every not-yet-applied SQLite migration of `plan`, in slice order.
///
/// The `_schema_versions` table is ensured first. A migration is skipped when its
/// `is_already_applied` probe returns `true` (it is *not* recorded in that case), or when a row
/// for `(plan.service, migration.id)` already exists. Otherwise its `up_sql` is executed and a
/// row is recorded with the current UTC time in microseconds.
///
/// When the connection is in autocommit mode, each migration's `up_sql` and its bookkeeping
/// row run inside a single transaction, so a failure leaves neither behind. `up_sql` must
/// therefore not issue its own `BEGIN`/`COMMIT`. When the caller already holds an open
/// transaction, statements run inside it and atomicity is the caller's responsibility.
///
/// # Errors
///
/// Returns any SQLite error raised while creating the tracking table, checking, applying, or
/// recording a migration. Migrations committed before the failing one stay applied.
pub fn apply_schema_plan(conn: &Connection, plan: &ServiceSchemaPlan) -> Result<(), SqliteError> {
    conn.execute_batch(SCHEMA_VERSION_TABLE)?;
    for migration in plan.sqlite {
        if migration.is_already_applied.is_some_and(|check| check(conn)) {
            continue;
        }
        let already: bool = conn.query_row(
            "SELECT COUNT(*) > 0 FROM _schema_versions WHERE service = ?1 AND migration_id = ?2",
            rusqlite::params![plan.service, migration.id],
            |row| row.get(0),
        )?;
        if already {
            continue;
        }
        // `unchecked_transaction` works through a shared `&Connection`. Only open one when no
        // transaction is active, because SQLite rejects a nested BEGIN. If anything below
        // fails, `?` returns early and dropping the uncommitted `tx` rolls the migration back.
        let tx = if conn.is_autocommit() {
            Some(conn.unchecked_transaction()?)
        } else {
            None
        };
        conn.execute_batch(migration.up_sql)?;
        conn.execute(
            "INSERT INTO _schema_versions (service, migration_id, applied_at) VALUES (?1, ?2, ?3)",
            rusqlite::params![
                plan.service,
                migration.id,
                chrono::Utc::now().timestamp_micros(),
            ],
        )?;
        if let Some(tx) = tx {
            tx.commit()?;
        }
    }
    Ok(())
}
/// One entry of the crate's own linear migration history (see [`MIGRATIONS`]).
///
/// All fields are `Copy` (`u32` and `&'static str`), so the common traits are derived to make
/// entries printable and comparable in diagnostics and tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VersionedMigration {
    /// Version recorded in `_schema_migrations.version` once this migration is applied.
    pub version: u32,
    /// Name recorded in `_schema_migrations.name`.
    pub name: &'static str,
    /// SQL batch executed inside a transaction by `run_migrations`.
    pub up: &'static str,
}
// SQL bodies for the entries of `MIGRATIONS`, embedded at compile time.
const V1_UP: &str = include_str!("../sql/schema.sql");
const V2_UP: &str = include_str!("../sql/002-narrow-fts-sections-update-trigger.sql");
const V3_UP: &str = include_str!("../sql/003-backfill-domain-mirror-atoms.sql");
/// SQL embedded from `sql/embedding-models-ddl.sql`. NOTE(review): not executed anywhere in
/// this file; presumably it creates the `_embedding_models` table read by
/// [`query_embedding_models`] — confirm against the SQL file and its callers.
pub const EMBEDDING_MODELS_DDL: &str = include_str!("../sql/embedding-models-ddl.sql");
/// The crate's migration history, applied by [`run_migrations`].
///
/// Entries must be listed in ascending `version` order. `run_migrations` treats the last
/// entry's version as the latest known version and skips every entry whose version is `<=`
/// the recorded one. Append new migrations at the end: editing an entry that a database has
/// already recorded has no effect on that database.
pub const MIGRATIONS: &[VersionedMigration] = &[
VersionedMigration {
version: 1,
name: "initial_schema",
up: V1_UP,
},
VersionedMigration {
version: 2,
name: "narrow_fts_sections_update_trigger",
up: V2_UP,
},
VersionedMigration {
version: 3,
name: "backfill_domain_mirror_atoms",
up: V3_UP,
},
];
/// DDL for the `_schema_migrations` table, executed at the start of every [`run_migrations`]
/// call (so presumably idempotent — confirm in the SQL file).
const MIGRATION_TRACKING_TABLE: &str = include_str!("../sql/schema-migrations-table.sql");
/// Returns the highest version recorded in `_schema_migrations`, or 0 when none is recorded.
///
/// This is deliberately best-effort. Any error (for example a database that has no
/// `_schema_migrations` table) is reported as version 0 rather than propagated.
pub fn read_schema_version(conn: &Connection) -> u32 {
    const MAX_VERSION_SQL: &str = "SELECT COALESCE(MAX(version), 0) FROM _schema_migrations";
    conn.query_row(MAX_VERSION_SQL, [], |row| row.get::<_, u32>(0))
        .ok()
        .unwrap_or_default()
}
pub fn inspect_schema_version(path: &std::path::Path) -> Result<u32, SqliteError> {
let conn = Connection::open_with_flags(
path,
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
)?;
Ok(read_schema_version(&conn))
}
pub fn run_migrations(conn: &mut Connection) -> Result<u32, SqliteError> {
conn.execute_batch(MIGRATION_TRACKING_TABLE)?;
let current_version: u32 = conn
.query_row(
"SELECT COALESCE(MAX(version), 0) FROM _schema_migrations",
[],
|row| row.get(0),
)
.unwrap_or(0);
let latest_version = MIGRATIONS.last().map(|m| m.version).unwrap_or(0);
if current_version > latest_version {
return Err(SqliteError::InvalidData(format!(
"database schema version {current_version} is ahead of the latest known migration \
{latest_version}. This database predates the consolidated baseline (ADR-015) or was \
written by a newer build. Recreate it from the current schema; in-place downgrade is \
not supported."
)));
}
let mut applied_version = current_version;
for migration in MIGRATIONS {
if migration.version <= current_version {
continue;
}
let tx = conn.transaction().map_err(|e| SqliteError::Migration {
version: migration.version,
error: e.to_string(),
})?;
tx.execute_batch(migration.up)
.map_err(|e| SqliteError::Migration {
version: migration.version,
error: e.to_string(),
})?;
let now = chrono::Utc::now().timestamp_micros();
tx.execute(
"INSERT INTO _schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)",
rusqlite::params![migration.version, migration.name, now],
)
.map_err(|e| SqliteError::Migration {
version: migration.version,
error: e.to_string(),
})?;
tx.commit().map_err(|e| SqliteError::Migration {
version: migration.version,
error: e.to_string(),
})?;
applied_version = migration.version;
}
Ok(applied_version)
}
/// One row of the `_embedding_models` table, as returned by [`query_embedding_models`].
#[derive(Debug)]
pub struct EmbeddingModelRegistryRecord {
/// `engine_name` column.
pub engine_name: String,
/// `model_id` column.
pub model_id: String,
/// `key_version` column.
pub key_version: String,
/// `dim` column, checked to fit in `u32` when the row is read.
pub dimensions: u32,
/// `status` column, kept as free-form text (no enum is enforced here).
pub status: String,
/// `activated_at` column; `None` when NULL. NOTE(review): presumably a timestamp — units unconfirmed.
pub activated_at: Option<i64>,
/// `superseded_at` column; `None` when NULL. NOTE(review): presumably a timestamp — units unconfirmed.
pub superseded_at: Option<i64>,
}
/// Lists embedding-model registry rows from a database file.
///
/// `db` selects the file. When it is `None`, the default is `$HOME/.khive/khive.db`, falling
/// back to `./.khive/khive.db` when `HOME` is unset or not valid Unicode. A missing file yields
/// an empty list. `engine_filter` restricts results to one `engine_name`.
///
/// NOTE(review): the file is opened with `Connection::open` (read-write), unlike the read-only
/// open in `inspect_schema_version`.
///
/// # Errors
///
/// Propagates failures to open the database or to query it.
pub fn query_embedding_models(
    db: Option<&std::path::Path>,
    engine_filter: Option<&str>,
) -> Result<Vec<EmbeddingModelRegistryRecord>, SqliteError> {
    let path = match db {
        Some(explicit) => explicit.to_path_buf(),
        None => {
            let home = std::env::var("HOME")
                .map_or_else(|_| std::path::PathBuf::from("."), std::path::PathBuf::from);
            home.join(".khive/khive.db")
        }
    };

    if path.exists() {
        let conn = Connection::open(path)?;
        query_embedding_models_conn(&conn, engine_filter)
    } else {
        Ok(Vec::new())
    }
}
/// Reads `_embedding_models` rows through an existing connection.
///
/// Returns an empty list when the table does not exist. With `engine_filter` set, only rows
/// whose `engine_name` equals it are returned. Rows are ordered by `engine_name`. Within each
/// engine, rows with an `activated_at` value come first in ascending order, followed by rows
/// where it is NULL.
///
/// # Errors
///
/// Propagates SQLite errors. A `dim` value outside the `u32` range fails the whole query with
/// a conversion error.
pub(crate) fn query_embedding_models_conn(
    conn: &Connection,
    engine_filter: Option<&str>,
) -> Result<Vec<EmbeddingModelRegistryRecord>, SqliteError> {
    let table_present: bool = conn.query_row(
        "SELECT COUNT(*) > 0 FROM sqlite_master \
         WHERE type='table' AND name='_embedding_models'",
        [],
        |row| row.get(0),
    )?;
    if !table_present {
        return Ok(Vec::new());
    }

    let sql = match engine_filter {
        Some(_) => {
            "SELECT engine_name, model_id, key_version, dim, status, activated_at, superseded_at \
             FROM _embedding_models WHERE engine_name = ?1 \
             ORDER BY engine_name, activated_at IS NULL, activated_at"
        }
        None => {
            "SELECT engine_name, model_id, key_version, dim, status, activated_at, superseded_at \
             FROM _embedding_models \
             ORDER BY engine_name, activated_at IS NULL, activated_at"
        }
    };

    let mut stmt = conn.prepare(sql)?;
    // `Option<&str>` iterates over zero or one item, so `?1` is bound only when a filter is
    // present. That matches the placeholder count of whichever SQL text was chosen above.
    let rows = stmt.query_map(rusqlite::params_from_iter(engine_filter), |row| {
        // `dim` is decoded first, so an out-of-range value is the error reported for the row.
        let dim_raw: i64 = row.get(3)?;
        let dimensions = u32::try_from(dim_raw).map_err(|_| {
            rusqlite::Error::FromSqlConversionFailure(
                3,
                rusqlite::types::Type::Integer,
                Box::new(std::io::Error::other(format!(
                    "_embedding_models.dim value {dim_raw} is outside the valid u32 range [0, {}]",
                    u32::MAX,
                ))),
            )
        })?;
        Ok(EmbeddingModelRegistryRecord {
            engine_name: row.get(0)?,
            model_id: row.get(1)?,
            key_version: row.get(2)?,
            dimensions,
            status: row.get(4)?,
            activated_at: row.get(5)?,
            superseded_at: row.get(6)?,
        })
    })?;

    let records = rows.collect::<Result<Vec<_>, _>>()?;
    Ok(records)
}
#[cfg(test)]
#[path = "migrations_tests.rs"]
mod tests;