use std::collections::BTreeSet;
use std::fs;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow, bail};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use surrealdb::Surreal;
use surrealdb::engine::any::Any;
use surrealdb_types::SurrealValue;
use time::OffsetDateTime;
use time::format_description::well_known::Rfc3339;
use time::macros::format_description;
use crate::core::{exec_surql, sha256_hex};
use crate::schema_state::{
CATALOG_SNAPSHOT_PATH, CatalogDiff, CatalogEntity, CatalogSnapshot, EntityKey, FileDiff,
ROLLOUTS_DIR, SchemaFile, build_catalog_snapshot, collect_schema_files, diff_catalog,
diff_schema, ensure_local_state_dirs, ensure_overwrite, hash_schema_snapshot,
load_catalog_snapshot, load_schema_snapshot, render_remove_sql, save_catalog_snapshot,
save_schema_snapshot, snapshot_from_files,
};
use crate::setup::run_setup;
/// Options for `run_plan`: an optional human-readable rollout name and a
/// dry-run flag that previews the plan without writing any files.
#[derive(Debug, Clone)]
pub struct RolloutPlanOpts {
    // Rollout name; defaults to "schema_rollout" when absent.
    pub name: Option<String>,
    // When true, print the plan summary and write nothing.
    pub dry_run: bool,
}
/// Options shared by the rollout lifecycle commands (lint/start/complete/rollback).
#[derive(Debug, Clone)]
pub struct RolloutExecutionOpts {
    // Rollout selector: a filesystem path, a file name under the rollouts
    // dir, or a bare rollout id (".toml" appended) — see resolve_rollout_path.
    pub selector: Option<String>,
}
/// Lifecycle phase in which a rollout step executes.
/// Serialized in snake_case inside the TOML manifest and DB records.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RolloutPhase {
    /// Expand phase, executed by `run_start`.
    Start,
    /// Contract phase, executed by `run_complete`.
    Complete,
    /// Compensation phase, executed by `run_rollback`.
    Rollback,
}
/// What a rollout step does when executed (see `execute_step`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RolloutStepKind {
    /// Apply schema files (with OVERWRITE semantics) to the database.
    ApplySchema,
    /// Run arbitrary SQL; must be declared idempotent in the manifest.
    RunSql,
    /// Run SQL and compare its result against an expected string.
    AssertSql,
    /// Drop managed entities via generated REMOVE statements.
    RemoveEntities,
}
/// Persisted status of a rollout record in the `__rollout` table.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RolloutStatus {
    Planned,
    RunningStart,
    ReadyToComplete,
    RunningComplete,
    Completed,
    RunningRollback,
    RolledBack,
    Failed,
}
impl RolloutStatus {
    /// Snake_case string form, matching the serde rename and the raw status
    /// strings compared against in `run_start`/`run_complete`/`run_rollback`.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Planned => "planned",
            Self::RunningStart => "running_start",
            Self::ReadyToComplete => "ready_to_complete",
            Self::RunningComplete => "running_complete",
            Self::Completed => "completed",
            Self::RunningRollback => "running_rollback",
            Self::RolledBack => "rolled_back",
            Self::Failed => "failed",
        }
    }
}
/// On-disk rollout manifest (TOML) describing one phased schema migration.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RolloutSpec {
    // Unique rollout id, e.g. "20260302153045__customer".
    pub id: String,
    // Human-readable rollout name.
    pub name: String,
    // Hash of the schema snapshot the rollout starts from.
    pub source_schema_hash: String,
    // Hash of the schema snapshot the rollout produces; checked against the
    // files on disk before lint/start.
    pub target_schema_hash: String,
    // Compatibility label; auto-generated plans use "phased".
    pub compatibility: String,
    // Declared entity renames (authored manually; never auto-generated).
    #[serde(default)]
    pub renames: Vec<RolloutRename>,
    // Ordered steps, filtered by phase at execution time.
    #[serde(default)]
    pub steps: Vec<RolloutStep>,
}
/// A declared rename of a managed entity between source and target schemas.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RolloutRename {
    // Entity kind, e.g. "table" or "field".
    pub kind: String,
    // Enclosing scope (e.g. table name for a field); None for top-level kinds.
    pub scope: Option<String>,
    pub from: String,
    pub to: String,
}
/// One executable unit of a rollout; which optional fields are required
/// depends on `kind` (enforced by `validate_rollout_spec`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RolloutStep {
    // Unique step id within the rollout.
    pub id: String,
    pub phase: RolloutPhase,
    pub kind: RolloutStepKind,
    // Schema files to apply (apply_schema only).
    #[serde(default)]
    pub files: Vec<String>,
    // SQL text (run_sql / assert_sql).
    pub sql: Option<String>,
    // Expected result string (assert_sql).
    pub expect: Option<String>,
    // Entities to remove (remove_entities).
    #[serde(default)]
    pub entities: Vec<EntityKey>,
    // run_sql steps must explicitly declare idempotent = true, since failed
    // rollouts may re-run steps on retry.
    pub idempotent: Option<bool>,
}
/// A manifest loaded from disk, plus its path and SHA-256 checksum (used to
/// detect manifest edits after a rollout has started).
#[derive(Debug, Clone)]
pub struct LoadedRolloutSpec {
    pub path: PathBuf,
    pub checksum: String,
    pub spec: RolloutSpec,
}
/// A managed-entity bookkeeping row as stored in `__entity` (ns = 'schema').
#[expect(dead_code)]
#[derive(Debug, Clone)]
pub struct ManagedEntityRecord {
    pub entity: CatalogEntity,
    // Rollout currently owning this entity, if any.
    pub active_rollout_id: Option<String>,
    // Lifecycle state string; defaults to "active" when missing in the DB.
    pub state: String,
}
/// Seed the managed-entity baseline from the schema files on disk.
///
/// Refuses to run when any rollout history already exists, then records the
/// current catalog in the database and writes the local snapshot files.
pub async fn run_baseline(db: &Surreal<Any>) -> Result<()> {
    run_setup(db).await?;
    ensure_local_state_dirs()?;
    // Baseline is a one-shot operation; any existing rollout row means it ran.
    if rollout_rows_exist(db).await? {
        bail!("rollout state already exists; baseline can only be run once");
    }
    let schema_files = collect_schema_files()?;
    let schema = snapshot_from_files(&schema_files);
    let catalog = build_catalog_snapshot(&schema_files)?;
    // Database side: managed entities plus per-file sync hashes.
    replace_managed_entities(db, &catalog.entities, None, "active").await?;
    replace_sync_hashes(db, &schema_files).await?;
    // Local side: snapshot files used by future `plan` runs for diffing.
    save_schema_snapshot(&schema)?;
    save_catalog_snapshot(&catalog)?;
    println!(
        "Seeded managed entity baseline with {} schema file(s) and {} managed object(s).",
        schema_files.len(),
        catalog.entities.len()
    );
    Ok(())
}
/// Generate a rollout manifest from the diff between the saved local
/// snapshots and the schema files currently on disk.
///
/// Purely local: reads and writes files under the state directories and never
/// touches the database. With `dry_run` set, only a summary is printed;
/// otherwise the manifest is written under `ROLLOUTS_DIR` and the local
/// schema/catalog snapshots are advanced to the new state.
pub async fn run_plan(opts: RolloutPlanOpts) -> Result<()> {
    ensure_local_state_dirs()?;
    let files = collect_schema_files()?;
    let old_schema = load_schema_snapshot()?;
    let old_catalog = load_catalog_snapshot()?;
    let new_schema = snapshot_from_files(&files);
    let new_catalog = build_catalog_snapshot(&files)?;
    let file_diff = diff_schema(&old_schema, &new_schema);
    let catalog_diff = diff_catalog(&old_catalog, &new_catalog);
    // Refuse diffs automatic planning cannot express safely (modified
    // entities, add+remove within the same scope).
    validate_autoplan(&catalog_diff)?;
    let name = opts.name.unwrap_or_else(|| "schema_rollout".to_string());
    let slug = slugify(&name);
    // Timestamp prefix keeps manifest ids sortable by creation time.
    let ts = OffsetDateTime::now_utc()
        .format(&format_description!("[year][month][day][hour][minute][second]"))?;
    let rollout_id = format!("{ts}__{slug}");
    let path = Path::new(ROLLOUTS_DIR).join(format!("{rollout_id}.toml"));
    let spec = build_rollout_spec(
        &rollout_id,
        &name,
        &files,
        &file_diff,
        &catalog_diff,
        &old_schema,
        &new_schema,
    )?;
    // Serialize before the dry-run branch so serialization errors surface
    // even when nothing is written.
    let raw = toml::to_string_pretty(&spec).context("serializing rollout spec")?;
    if opts.dry_run {
        println!("Pending rollout plan:");
        println!(
            " files: +{} ~{} -{}",
            file_diff.added.len(),
            file_diff.modified.len(),
            file_diff.removed.len()
        );
        println!(
            " entities: +{} ~{} -{}",
            catalog_diff.added.len(),
            catalog_diff.modified.len(),
            catalog_diff.removed.len()
        );
        println!(" would create: {}", path.display());
        return Ok(());
    }
    fs::write(&path, raw).with_context(|| format!("writing rollout file {}", path.display()))?;
    // Advance local snapshots so the next plan diffs against this state.
    save_schema_snapshot(&new_schema)?;
    save_catalog_snapshot(&new_catalog)?;
    println!("Generated rollout manifest {}", path.display());
    println!("Updated {}", CATALOG_SNAPSHOT_PATH);
    Ok(())
}
/// Validate a rollout manifest against the current on-disk schema without
/// touching the database.
pub async fn run_lint(opts: RolloutExecutionOpts) -> Result<()> {
    ensure_local_state_dirs()?;
    let manifest_path = resolve_rollout_path(opts.selector.as_deref())?;
    let rollout = load_rollout_spec(manifest_path)?;
    validate_rollout_spec(&rollout.spec)?;
    // The manifest must target exactly the schema currently on disk.
    let snapshot = snapshot_from_files(&collect_schema_files()?);
    let current_hash = hash_schema_snapshot(&snapshot)?;
    if current_hash != rollout.spec.target_schema_hash {
        bail!(
            "target schema hash mismatch for '{}': manifest={}, current={}",
            rollout.spec.id,
            rollout.spec.target_schema_hash,
            current_hash
        );
    }
    println!("Rollout {} is valid (checksum {}).", rollout.spec.id, rollout.checksum);
    Ok(())
}
/// Print rollout records (optionally filtered by id) with their step states.
///
/// Output format is human-oriented: one header line per rollout followed by
/// indented timestamps, last error, and per-step status lines.
pub async fn run_status(db: &Surreal<Any>, selector: Option<String>) -> Result<()> {
    run_setup(db).await?;
    let mut query =
        "SELECT id, name, status, started_at, completed_at, last_error, steps FROM __rollout"
            .to_string();
    // Filter to one rollout when a selector is given; bound below.
    if selector.is_some() {
        query.push_str(" WHERE record::id(id) = $id");
    }
    query.push_str(" ORDER BY started_at DESC;");
    let mut req = db.query(query);
    if let Some(id) = selector {
        req = req.bind(("id", id));
    }
    let mut resp = req.await?;
    let rows: Vec<Value> = resp.take(0)?;
    if rows.is_empty() {
        println!("No rollout records found.");
        return Ok(());
    }
    for row in rows {
        // Fields may be missing or non-string; fall back to placeholders
        // rather than failing the whole status listing.
        let id = string_field(&row, "id").unwrap_or_else(|| "<unknown>".to_string());
        let name = string_field(&row, "name").unwrap_or_else(|| "<unnamed>".to_string());
        let status = string_field(&row, "status").unwrap_or_else(|| "<unknown>".to_string());
        println!("{} [{}] {}", id, status, name);
        if let Some(started_at) = string_field(&row, "started_at") {
            println!(" started_at: {}", started_at);
        }
        if let Some(completed_at) = string_field(&row, "completed_at") {
            println!(" completed_at: {}", completed_at);
        }
        if let Some(last_error) = string_field(&row, "last_error") {
            println!(" last_error: {}", last_error);
        }
        let steps = row.get("steps").and_then(|v| v.as_array()).cloned().unwrap_or_default();
        for step in steps {
            let step_id = string_field(&step, "step_id").unwrap_or_else(|| "<step>".to_string());
            let phase = string_field(&step, "phase").unwrap_or_else(|| "?".to_string());
            let kind = string_field(&step, "kind").unwrap_or_else(|| "?".to_string());
            let status = string_field(&step, "status").unwrap_or_else(|| "?".to_string());
            println!(" - {} [{}:{}] {}", step_id, phase, kind, status);
            if let Some(err) = string_field(&step, "error") {
                println!(" error: {}", err);
            }
        }
    }
    Ok(())
}
/// Execute the expand (`Start`) phase of a rollout under the global lock.
///
/// Verifies the manifest against the on-disk schema, creates (or re-verifies)
/// the `__rollout` record capturing source and target entity sets, runs the
/// Start-phase steps, and moves the record to `ready_to_complete`. On step
/// failure the record is marked `failed`; re-running resumes because
/// completed steps are skipped by `execute_phase`.
pub async fn run_start(db: &Surreal<Any>, opts: RolloutExecutionOpts) -> Result<()> {
    run_setup(db).await?;
    ensure_local_state_dirs()?;
    let rollout = load_rollout_spec(resolve_rollout_path(opts.selector.as_deref())?)?;
    validate_rollout_spec(&rollout.spec)?;
    // The on-disk schema must match what the manifest was planned against.
    let files = collect_schema_files()?;
    let target_schema = snapshot_from_files(&files);
    let target_hash = hash_schema_snapshot(&target_schema)?;
    if target_hash != rollout.spec.target_schema_hash {
        bail!(
            "target schema hash mismatch for '{}': manifest={}, current={}",
            rollout.spec.id,
            rollout.spec.target_schema_hash,
            target_hash
        );
    }
    let target_catalog = build_catalog_snapshot(&files)?;
    // The source catalog is whatever the DB currently considers managed;
    // it is persisted on the rollout record for use during rollback.
    let source_entities = load_managed_entities(db).await?;
    let source_catalog = CatalogSnapshot {
        version: 2,
        entities: source_entities.iter().map(|row| row.entity.clone()).collect(),
    };
    acquire_lock(db, "global").await?;
    // All locked work lives in this async block so the lock is always
    // released afterwards, regardless of the outcome.
    let result = async {
        ensure_no_conflicting_active_rollout(db, &rollout.spec.id).await?;
        let record = load_rollout_record(db, &rollout.spec.id).await?;
        match record.as_ref().and_then(|row| string_field(row, "status")).as_deref() {
            Some("completed") => bail!("rollout '{}' is already completed", rollout.spec.id),
            Some("rolled_back") => {
                bail!("rollout '{}' has already been rolled back", rollout.spec.id)
            }
            _ => {}
        }
        if let Some(ref row) = record {
            // Re-run of an existing rollout: the manifest must be unchanged.
            verify_rollout_record_matches(row, &rollout)?;
        } else {
            create_rollout_record(
                db,
                &rollout,
                &source_catalog.entities,
                &target_catalog.entities,
                RolloutStatus::Planned,
            )
            .await?;
        }
        set_rollout_status(db, &rollout.spec.id, RolloutStatus::RunningStart, None, None).await?;
        if let Err(err) = execute_phase(db, &rollout, RolloutPhase::Start).await {
            set_rollout_status(
                db,
                &rollout.spec.id,
                RolloutStatus::Failed,
                Some(&format!("{err:#}")),
                None,
            )
            .await?;
            return Err(err);
        }
        set_rollout_status(db, &rollout.spec.id, RolloutStatus::ReadyToComplete, None, None)
            .await?;
        println!("Rollout {} is ready to complete.", rollout.spec.id);
        Ok(())
    }
    .await;
    let release = release_lock(db, "global").await;
    // Prefer the execution error; surface a release failure only on success.
    match (result, release) {
        (Err(err), _) => Err(err),
        (Ok(_), Err(err)) => Err(err),
        (Ok(value), Ok(())) => Ok(value),
    }
}
/// Execute the contract (`Complete`) phase of a started rollout under the
/// global lock, then promote the target entity set to the active catalog and
/// mark the rollout `completed`.
///
/// Accepts rollouts in `ready_to_complete`, `running_complete` (resume), or
/// `failed` (retry) states.
pub async fn run_complete(db: &Surreal<Any>, opts: RolloutExecutionOpts) -> Result<()> {
    run_setup(db).await?;
    let rollout = load_rollout_spec(resolve_rollout_path(opts.selector.as_deref())?)?;
    validate_rollout_spec(&rollout.spec)?;
    acquire_lock(db, "global").await?;
    let result = async {
        let row = load_rollout_record(db, &rollout.spec.id)
            .await?
            .ok_or_else(|| anyhow!("rollout '{}' has not been started", rollout.spec.id))?;
        // The manifest on disk must be the one the rollout was started with.
        verify_rollout_record_matches(&row, &rollout)?;
        match string_field(&row, "status").as_deref() {
            Some("ready_to_complete") | Some("running_complete") | Some("failed") => {}
            Some(other) => {
                bail!("rollout '{}' is not ready to complete (status={})", rollout.spec.id, other)
            }
            None => bail!("rollout '{}' has no status", rollout.spec.id),
        }
        set_rollout_status(db, &rollout.spec.id, RolloutStatus::RunningComplete, None, None)
            .await?;
        if let Err(err) = execute_phase(db, &rollout, RolloutPhase::Complete).await {
            set_rollout_status(
                db,
                &rollout.spec.id,
                RolloutStatus::Failed,
                Some(&format!("{err:#}")),
                None,
            )
            .await?;
            return Err(err);
        }
        // Promote the target entity set captured at start time.
        let target_entities = deserialize_entities_field(&row, "target_entities")?;
        replace_managed_entities(db, &target_entities, None, "active").await?;
        set_rollout_status(
            db,
            &rollout.spec.id,
            RolloutStatus::Completed,
            None,
            Some(OffsetDateTime::now_utc().format(&Rfc3339)?),
        )
        .await?;
        println!("Completed rollout {}.", rollout.spec.id);
        Ok(())
    }
    .await;
    let release = release_lock(db, "global").await;
    // Prefer the execution error; surface a release failure only on success.
    match (result, release) {
        (Err(err), _) => Err(err),
        (Ok(_), Err(err)) => Err(err),
        (Ok(value), Ok(())) => Ok(value),
    }
}
/// Execute the compensation (`Rollback`) phase of a rollout under the global
/// lock, then restore the source entity set captured at start time and mark
/// the rollout `rolled_back`.
///
/// Completed rollouts cannot be rolled back; an already rolled-back rollout
/// is a no-op.
pub async fn run_rollback(db: &Surreal<Any>, opts: RolloutExecutionOpts) -> Result<()> {
    run_setup(db).await?;
    let rollout = load_rollout_spec(resolve_rollout_path(opts.selector.as_deref())?)?;
    validate_rollout_spec(&rollout.spec)?;
    acquire_lock(db, "global").await?;
    let result = async {
        let row = load_rollout_record(db, &rollout.spec.id)
            .await?
            .ok_or_else(|| anyhow!("rollout '{}' has not been started", rollout.spec.id))?;
        verify_rollout_record_matches(&row, &rollout)?;
        match string_field(&row, "status").as_deref() {
            Some("completed") => bail!("rollout '{}' is already completed", rollout.spec.id),
            Some("rolled_back") => {
                // Idempotent: repeating a rollback succeeds without work.
                println!("Rollout {} is already rolled back.", rollout.spec.id);
                return Ok(());
            }
            _ => {}
        }
        set_rollout_status(db, &rollout.spec.id, RolloutStatus::RunningRollback, None, None)
            .await?;
        if let Err(err) = execute_phase(db, &rollout, RolloutPhase::Rollback).await {
            set_rollout_status(
                db,
                &rollout.spec.id,
                RolloutStatus::Failed,
                Some(&format!("{err:#}")),
                None,
            )
            .await?;
            return Err(err);
        }
        // Restore the pre-rollout entity set captured on the record.
        let source_entities = deserialize_entities_field(&row, "source_entities")?;
        replace_managed_entities(db, &source_entities, None, "active").await?;
        set_rollout_status(
            db,
            &rollout.spec.id,
            RolloutStatus::RolledBack,
            None,
            Some(OffsetDateTime::now_utc().format(&Rfc3339)?),
        )
        .await?;
        println!("Rolled back rollout {}.", rollout.spec.id);
        Ok(())
    }
    .await;
    let release = release_lock(db, "global").await;
    // Prefer the execution error; surface a release failure only on success.
    match (result, release) {
        (Err(err), _) => Err(err),
        (Ok(_), Err(err)) => Err(err),
        (Ok(value), Ok(())) => Ok(value),
    }
}
/// Return the id of the most recently started rollout still in a
/// non-terminal (or failed) state, if any.
pub async fn load_active_rollout_id(db: &Surreal<Any>) -> Result<Option<String>> {
    let sql = "SELECT id, status, started_at FROM __rollout \
         WHERE status INSIDE ['planned', 'running_start', 'ready_to_complete', 'running_complete', 'running_rollback', 'failed'] \
         ORDER BY started_at DESC LIMIT 1;";
    let mut resp = db.query(sql).await?;
    let row: Option<Value> = resp.take(0)?;
    match row {
        Some(value) => Ok(string_field(&value, "id")),
        None => Ok(None),
    }
}
/// Load all managed-entity bookkeeping rows (ns = 'schema') from `__entity`,
/// parsing each `key` ("kind:scope:name", empty scope meaning none) and the
/// `val` object back into `ManagedEntityRecord`s, sorted by entity.
///
/// Rows whose key does not contain at least two ':' separators are skipped;
/// missing `val` fields default to empty strings / "active".
pub async fn load_managed_entities(db: &Surreal<Any>) -> Result<Vec<ManagedEntityRecord>> {
    let mut resp = db.query("SELECT key, val FROM __entity WHERE ns = 'schema';").await?;
    let rows: Vec<Value> = resp.take(0)?;
    let mut out = Vec::with_capacity(rows.len());
    for row in rows {
        let key = string_field_req(&row, "key")?;
        let val = row.get("val").cloned().unwrap_or(Value::Null);
        // splitn(3, ..) keeps any further ':' characters inside the name part.
        let parts: Vec<&str> = key.splitn(3, ':').collect();
        if parts.len() < 3 {
            // Malformed key; ignore the row rather than failing the load.
            continue;
        }
        let kind = parts[0].to_string();
        // Empty middle segment encodes "no scope" (see entity_key_string).
        let scope = if parts[1].is_empty() {
            None
        } else {
            Some(parts[1].to_string())
        };
        let name = parts[2].to_string();
        let source_path =
            val.get("source_path").and_then(|v| v.as_str()).unwrap_or_default().to_string();
        let statement_hash =
            val.get("statement_hash").and_then(|v| v.as_str()).unwrap_or_default().to_string();
        let file_hash =
            val.get("file_hash").and_then(|v| v.as_str()).unwrap_or_default().to_string();
        let active_rollout_id =
            val.get("active_rollout_id").and_then(|v| v.as_str()).map(str::to_string);
        let state = val.get("state").and_then(|v| v.as_str()).unwrap_or("active").to_string();
        out.push(ManagedEntityRecord {
            entity: CatalogEntity {
                kind,
                scope,
                name,
                source_path,
                statement_hash,
                file_hash,
            },
            active_rollout_id,
            state,
        });
    }
    out.sort_by(|a, b| a.entity.cmp(&b.entity));
    Ok(out)
}
/// Insert or replace managed-entity rows in `__entity` (ns = 'schema').
///
/// Each entity is keyed by "kind:scope:name" (empty scope segment when the
/// entity has no scope). Upsert is implemented as DELETE-then-CREATE within
/// one query per entity.
pub async fn upsert_managed_entities(
    db: &Surreal<Any>,
    entities: &[CatalogEntity],
    active_rollout_id: Option<&str>,
    state: &str,
) -> Result<()> {
    for entity in entities {
        let entity_key = entity_key_string(&entity.kind, entity.scope.as_deref(), &entity.name);
        db.query(
            "DELETE __entity WHERE ns = 'schema' AND key = $key; \
             CREATE __entity CONTENT { \
                 ns: 'schema', \
                 key: $key, \
                 val: { \
                     source_path: $source_path, \
                     statement_hash: $statement_hash, \
                     file_hash: $file_hash, \
                     active_rollout_id: $active_rollout_id, \
                     state: $state \
                 }, \
                 updated_at: time::now() \
             };",
        )
        .bind(("key", entity_key))
        .bind(("source_path", entity.source_path.clone()))
        .bind(("statement_hash", entity.statement_hash.clone()))
        .bind(("file_hash", entity.file_hash.clone()))
        .bind(("active_rollout_id", active_rollout_id.map(str::to_string)))
        .bind(("state", state.to_string()))
        .await?
        .check()?;
    }
    Ok(())
}
/// Build the "kind:scope:name" key used for `__entity` rows; a missing scope
/// becomes an empty middle segment ("kind::name").
fn entity_key_string(kind: &str, scope: Option<&str>, name: &str) -> String {
    let scope_part = scope.unwrap_or_default();
    [kind, scope_part, name].join(":")
}
/// Delete the managed-entity rows (ns = 'schema') identified by the given keys.
pub async fn delete_managed_entities(db: &Surreal<Any>, entities: &[EntityKey]) -> Result<()> {
    for target in entities {
        let entity_key = entity_key_string(&target.kind, target.scope.as_deref(), &target.name);
        let resp = db
            .query("DELETE __entity WHERE ns = 'schema' AND key = $key;")
            .bind(("key", entity_key))
            .await?;
        resp.check()?;
    }
    Ok(())
}
/// Replace the entire managed-entity set (ns = 'schema') with `entities`:
/// wipes the namespace, then re-inserts via `upsert_managed_entities`.
pub async fn replace_managed_entities(
    db: &Surreal<Any>,
    entities: &[CatalogEntity],
    active_rollout_id: Option<&str>,
    state: &str,
) -> Result<()> {
    db.query("DELETE __entity WHERE ns = 'schema';").await?.check()?;
    upsert_managed_entities(db, entities, active_rollout_id, state).await
}
/// Replace all per-file sync hashes (ns = 'sync') with one row per schema
/// file, keyed by the file path and holding its content hash.
pub async fn replace_sync_hashes(db: &Surreal<Any>, files: &[SchemaFile]) -> Result<()> {
    db.query("DELETE __entity WHERE ns = 'sync';").await?.check()?;
    for file in files {
        db.query(
            "CREATE __entity CONTENT { ns: 'sync', key: $path, val: { hash: $hash }, updated_at: time::now() };",
        )
        .bind(("path", file.path.clone()))
        .bind(("hash", file.hash.clone()))
        .await?
        .check()?;
    }
    Ok(())
}
/// Delete the sync-hash rows (ns = 'sync') for the given file paths.
pub async fn delete_sync_hashes(db: &Surreal<Any>, paths: &[String]) -> Result<()> {
    for file_path in paths {
        let resp = db
            .query("DELETE __entity WHERE ns = 'sync' AND key = $path;")
            .bind(("path", file_path.clone()))
            .await?;
        resp.check()?;
    }
    Ok(())
}
/// Build an automatic "phased" rollout spec from the file and catalog diffs.
///
/// Generates up to three steps: apply changed schema files in the Start
/// phase, remove newly added entities in the Rollback phase (compensation),
/// and remove legacy entities in the Complete phase. Modified entities and
/// renames are never generated here — `validate_autoplan` rejects them first.
fn build_rollout_spec(
    rollout_id: &str,
    name: &str,
    files: &[SchemaFile],
    file_diff: &FileDiff,
    catalog_diff: &CatalogDiff,
    old_schema: &crate::schema_state::SchemaSnapshot,
    new_schema: &crate::schema_state::SchemaSnapshot,
) -> Result<RolloutSpec> {
    let changed_paths = changed_files(files, file_diff);
    let mut steps = Vec::new();
    if !changed_paths.is_empty() {
        // Start phase: apply every added/modified schema file.
        steps.push(RolloutStep {
            id: "apply_expand_schema".to_string(),
            phase: RolloutPhase::Start,
            kind: RolloutStepKind::ApplySchema,
            files: changed_paths,
            sql: None,
            expect: None,
            entities: Vec::new(),
            idempotent: None,
        });
    }
    let added_entities: Vec<EntityKey> =
        catalog_diff.added.iter().map(CatalogEntity::key).collect();
    if !added_entities.is_empty() {
        // Rollback phase: undo the expand by removing what it added.
        steps.push(RolloutStep {
            id: "rollback_expand_schema".to_string(),
            phase: RolloutPhase::Rollback,
            kind: RolloutStepKind::RemoveEntities,
            files: Vec::new(),
            sql: None,
            expect: None,
            entities: added_entities,
            idempotent: None,
        });
    }
    let removed_entities: Vec<EntityKey> =
        catalog_diff.removed.iter().map(CatalogEntity::key).collect();
    if !removed_entities.is_empty() {
        // Complete phase: contract by dropping entities no longer in the schema.
        steps.push(RolloutStep {
            id: "remove_legacy_entities".to_string(),
            phase: RolloutPhase::Complete,
            kind: RolloutStepKind::RemoveEntities,
            files: Vec::new(),
            sql: None,
            expect: None,
            entities: removed_entities,
            idempotent: None,
        });
    }
    Ok(RolloutSpec {
        id: rollout_id.to_string(),
        name: name.to_string(),
        source_schema_hash: hash_schema_snapshot(old_schema)?,
        target_schema_hash: hash_schema_snapshot(new_schema)?,
        compatibility: "phased".to_string(),
        renames: Vec::new(),
        steps,
    })
}
fn validate_autoplan(diff: &CatalogDiff) -> Result<()> {
if !diff.modified.is_empty() {
let names = diff
.modified
.iter()
.map(|change| format!("{}:{}", change.old.kind, change.old.name))
.collect::<Vec<_>>()
.join(", ");
bail!(
"automatic rollout planning refuses modified managed entities: {}. \
Author a manual rollout manifest for non-additive changes.",
names
);
}
let removed_by_scope: BTreeSet<(String, Option<String>)> =
diff.removed.iter().map(|entity| (entity.kind.clone(), entity.scope.clone())).collect();
let added_by_scope: BTreeSet<(String, Option<String>)> =
diff.added.iter().map(|entity| (entity.kind.clone(), entity.scope.clone())).collect();
if removed_by_scope.intersection(&added_by_scope).next().is_some() {
bail!(
"automatic rollout planning detected add/remove changes in the same scope. \
Author a manual rollout manifest with explicit renames/backfill steps."
);
}
Ok(())
}
/// Return the sorted paths of schema files that were added or modified,
/// restricted to files actually present in `files`.
fn changed_files(files: &[SchemaFile], diff: &FileDiff) -> Vec<String> {
    let mut touched: BTreeSet<&str> = BTreeSet::new();
    for path in diff.added.iter().chain(&diff.modified) {
        touched.insert(path.as_str());
    }
    let mut selected = Vec::new();
    for file in files {
        if touched.contains(file.path.as_str()) {
            selected.push(file.path.clone());
        }
    }
    selected.sort();
    selected
}
/// Read and parse a rollout manifest from disk, recording the SHA-256
/// checksum of the raw bytes so later runs can detect edits.
fn load_rollout_spec(path: PathBuf) -> Result<LoadedRolloutSpec> {
    let raw = fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
    let checksum = sha256_hex(raw.as_bytes());
    let spec: RolloutSpec =
        toml::from_str(&raw).with_context(|| format!("parsing {}", path.display()))?;
    Ok(LoadedRolloutSpec { path, checksum, spec })
}
fn resolve_rollout_path(selector: Option<&str>) -> Result<PathBuf> {
let selector = selector.ok_or_else(|| anyhow!("rollout id or path is required"))?;
let path = Path::new(selector);
if path.exists() {
return Ok(path.to_path_buf());
}
let direct = Path::new(ROLLOUTS_DIR).join(selector);
if direct.exists() {
return Ok(direct);
}
let with_ext = Path::new(ROLLOUTS_DIR).join(format!("{selector}.toml"));
if with_ext.exists() {
return Ok(with_ext);
}
bail!("unable to find rollout '{}'", selector)
}
/// Validate a manifest's structural invariants: non-empty id/name/
/// compatibility, unique step ids, and per-kind required fields (files for
/// apply_schema, sql + `idempotent = true` for run_sql, sql + expect for
/// assert_sql, entities for remove_entities).
fn validate_rollout_spec(spec: &RolloutSpec) -> Result<()> {
    if spec.id.trim().is_empty() {
        bail!("rollout id is required");
    }
    if spec.name.trim().is_empty() {
        bail!("rollout name is required");
    }
    if spec.compatibility.trim().is_empty() {
        bail!("compatibility is required");
    }
    let mut step_ids = BTreeSet::new();
    for step in &spec.steps {
        // insert() returns false when the id was already present.
        if !step_ids.insert(step.id.clone()) {
            bail!("duplicate rollout step id '{}'", step.id);
        }
        match step.kind {
            RolloutStepKind::ApplySchema => {
                if step.files.is_empty() {
                    bail!("apply_schema step '{}' requires files", step.id);
                }
            }
            RolloutStepKind::RunSql => {
                if step.sql.as_deref().unwrap_or("").trim().is_empty() {
                    bail!("run_sql step '{}' requires sql", step.id);
                }
                // Steps may be re-run after a failure, so arbitrary SQL must
                // be explicitly declared idempotent by the author.
                if step.idempotent != Some(true) {
                    bail!("run_sql step '{}' must declare idempotent = true", step.id);
                }
            }
            RolloutStepKind::AssertSql => {
                if step.sql.as_deref().unwrap_or("").trim().is_empty() {
                    bail!("assert_sql step '{}' requires sql", step.id);
                }
                if step.expect.as_deref().unwrap_or("").trim().is_empty() {
                    bail!("assert_sql step '{}' requires expect", step.id);
                }
            }
            RolloutStepKind::RemoveEntities => {
                if step.entities.is_empty() {
                    bail!("remove_entities step '{}' requires entities", step.id);
                }
            }
        }
    }
    Ok(())
}
/// Run every step of `rollout` belonging to `phase`, in manifest order,
/// recording per-step start/complete/failure on the rollout record. Steps
/// already marked completed are skipped, making retries resumable. The first
/// failing step aborts the phase.
async fn execute_phase(
    db: &Surreal<Any>,
    rollout: &LoadedRolloutSpec,
    phase: RolloutPhase,
) -> Result<()> {
    for step in &rollout.spec.steps {
        if step.phase != phase {
            continue;
        }
        // Idempotent resume: a previous attempt may have finished this step.
        if step_already_completed(db, &rollout.spec.id, &step.id).await? {
            continue;
        }
        record_step_start(db, &rollout.spec.id, step).await?;
        if let Err(err) = execute_step(db, step).await {
            record_step_failure(db, &rollout.spec.id, step, &format!("{err:#}")).await?;
            return Err(err);
        }
        record_step_complete(db, &rollout.spec.id, step).await?;
    }
    Ok(())
}
/// Execute a single rollout step against the database according to its kind.
///
/// apply_schema reads each listed file and applies it with OVERWRITE
/// semantics; run_sql executes the declared SQL; assert_sql compares the
/// query result (rendered via `value_to_expect_string`) to the trimmed
/// `expect` string; remove_entities runs generated REMOVE statements.
async fn execute_step(db: &Surreal<Any>, step: &RolloutStep) -> Result<()> {
    match step.kind {
        RolloutStepKind::ApplySchema => {
            for file in &step.files {
                let raw = fs::read_to_string(file).with_context(|| format!("reading {}", file))?;
                // Rewrites DEFINE statements so re-application overwrites.
                let sql = ensure_overwrite(&raw);
                exec_surql(db, &sql).await?;
            }
            Ok(())
        }
        RolloutStepKind::RunSql => {
            // validate_rollout_spec guarantees sql is present for this kind.
            let sql = step.sql.as_deref().ok_or_else(|| anyhow!("missing sql"))?;
            exec_surql(db, sql).await
        }
        RolloutStepKind::AssertSql => {
            let sql = step.sql.as_deref().ok_or_else(|| anyhow!("missing sql"))?;
            let expect = step.expect.as_deref().ok_or_else(|| anyhow!("missing expect"))?;
            let actual = execute_sql_value(db, sql).await?;
            if value_to_expect_string(&actual) != expect.trim() {
                bail!(
                    "assert step '{}' failed: expected {}, got {}",
                    step.id,
                    expect,
                    value_to_expect_string(&actual)
                );
            }
            Ok(())
        }
        RolloutStepKind::RemoveEntities => {
            let sql = render_remove_sql(&step.entities, true)?.join("\n");
            // Nothing to remove (e.g. every entity filtered out) is a no-op.
            if sql.trim().is_empty() {
                return Ok(());
            }
            exec_surql(db, &sql).await
        }
    }
}
/// Run `sql` and convert the first statement's result into a
/// `serde_json::Value`, mapping conversion failure to `Value::Null`.
async fn execute_sql_value(db: &Surreal<Any>, sql: &str) -> Result<Value> {
    let mut response = db.query(sql).await?.check()?;
    let raw: surrealdb_types::Value = response.take(0)?;
    Ok(Value::from_value(raw).unwrap_or(Value::Null))
}
/// Render a JSON value in the canonical form used by assert_sql comparisons:
/// strings are taken verbatim (no quotes), null becomes "null", and all other
/// values use their Display form.
fn value_to_expect_string(value: &Value) -> String {
    match value {
        Value::String(text) => text.clone(),
        Value::Null => String::from("null"),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(num) => num.to_string(),
        other => other.to_string(),
    }
}
/// True when at least one `__rollout` record exists (any status).
async fn rollout_rows_exist(db: &Surreal<Any>) -> Result<bool> {
    let mut resp = db.query("SELECT id FROM __rollout LIMIT 1;").await?;
    let first: Option<Value> = resp.take(0)?;
    Ok(first.is_some())
}
/// Fail when a different rollout is currently active; the same rollout id is
/// allowed through so an interrupted run can be resumed.
async fn ensure_no_conflicting_active_rollout(db: &Surreal<Any>, rollout_id: &str) -> Result<()> {
    match load_active_rollout_id(db).await? {
        Some(active_id) if active_id != rollout_id => {
            bail!("rollout '{}' cannot start while rollout '{}' is active", rollout_id, active_id)
        }
        _ => Ok(()),
    }
}
/// Create (or recreate) the `__rollout` record for a rollout, capturing the
/// manifest identity (path + checksum), the schema hashes, and the full
/// source/target entity sets needed later for complete/rollback promotion.
async fn create_rollout_record(
    db: &Surreal<Any>,
    rollout: &LoadedRolloutSpec,
    source_entities: &[CatalogEntity],
    target_entities: &[CatalogEntity],
    status: RolloutStatus,
) -> Result<()> {
    let started_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
    // DELETE first so re-creating a planned record is idempotent.
    db.query(
        "DELETE __rollout WHERE record::id(id) = $id; \
         CREATE __rollout CONTENT { \
             id: $id, \
             name: $name, \
             manifest_path: $manifest_path, \
             manifest_checksum: $manifest_checksum, \
             source_schema_hash: $source_schema_hash, \
             target_schema_hash: $target_schema_hash, \
             status: $status, \
             source_entities: $source_entities, \
             target_entities: $target_entities, \
             started_at: <datetime> $started_at, \
             updated_at: time::now(), \
             last_error: NONE \
         };",
    )
    .bind(("id", rollout.spec.id.clone()))
    .bind(("name", rollout.spec.name.clone()))
    .bind(("manifest_path", rollout.path.to_string_lossy().to_string()))
    .bind(("manifest_checksum", rollout.checksum.clone()))
    .bind(("source_schema_hash", rollout.spec.source_schema_hash.clone()))
    .bind(("target_schema_hash", rollout.spec.target_schema_hash.clone()))
    .bind(("status", status.as_str().to_string()))
    .bind(("source_entities", serde_json::to_value(source_entities)?))
    .bind(("target_entities", serde_json::to_value(target_entities)?))
    .bind(("started_at", started_at))
    .await?
    .check()?;
    Ok(())
}
/// Fetch the `__rollout` record for `rollout_id`, or `None` when absent.
async fn load_rollout_record(db: &Surreal<Any>, rollout_id: &str) -> Result<Option<Value>> {
    let mut resp = db
        .query("SELECT * FROM __rollout WHERE record::id(id) = $id LIMIT 1;")
        .bind(("id", rollout_id.to_string()))
        .await?;
    let record: Option<Value> = resp.take(0)?;
    Ok(record)
}
/// Ensure the manifest on disk is byte-identical (by checksum) and targets
/// the same schema hashes as the rollout record stored in the database —
/// i.e. the manifest was not edited after the rollout started.
fn verify_rollout_record_matches(row: &Value, rollout: &LoadedRolloutSpec) -> Result<()> {
    let stored_checksum = string_field_req(row, "manifest_checksum")?;
    if stored_checksum != rollout.checksum {
        bail!(
            "manifest checksum mismatch for '{}': db={}, file={}",
            rollout.spec.id,
            stored_checksum,
            rollout.checksum
        );
    }
    let stored_source = string_field_req(row, "source_schema_hash")?;
    let stored_target = string_field_req(row, "target_schema_hash")?;
    let hashes_match = stored_source == rollout.spec.source_schema_hash
        && stored_target == rollout.spec.target_schema_hash;
    if !hashes_match {
        bail!("schema hash mismatch for rollout '{}'", rollout.spec.id);
    }
    Ok(())
}
/// Deserialize an entity-list field (`source_entities` / `target_entities`)
/// from a rollout record back into `CatalogEntity` values.
fn deserialize_entities_field(row: &Value, key: &str) -> Result<Vec<CatalogEntity>> {
    match row.get(key) {
        Some(value) => {
            serde_json::from_value(value.clone()).with_context(|| format!("parsing {}", key))
        }
        None => Err(anyhow!("missing '{}' on rollout record", key)),
    }
}
/// Update a rollout record's status, last error, and optional completion
/// timestamp (RFC 3339 string, cast to datetime in SurrealQL; NONE clears it).
async fn set_rollout_status(
    db: &Surreal<Any>,
    rollout_id: &str,
    status: RolloutStatus,
    last_error: Option<&str>,
    completed_at: Option<String>,
) -> Result<()> {
    db.query(
        "UPDATE __rollout SET \
             status = $status, \
             last_error = $last_error, \
             completed_at = IF $completed_at THEN <datetime> $completed_at ELSE NONE END, \
             updated_at = time::now() \
         WHERE record::id(id) = $id;",
    )
    .bind(("id", rollout_id.to_string()))
    .bind(("status", status.as_str().to_string()))
    .bind(("last_error", last_error.map(str::to_string)))
    .bind(("completed_at", completed_at))
    .await?
    .check()?;
    Ok(())
}
/// True when the rollout record lists `step_id` with status "completed";
/// missing record or missing steps array both count as not completed.
async fn step_already_completed(
    db: &Surreal<Any>,
    rollout_id: &str,
    step_id: &str,
) -> Result<bool> {
    let Some(record) = load_rollout_record(db, rollout_id).await? else {
        return Ok(false);
    };
    let Some(steps) = record.get("steps").and_then(|v| v.as_array()) else {
        return Ok(false);
    };
    let done = steps.iter().any(|entry| {
        let id_matches = entry.get("step_id").and_then(|v| v.as_str()) == Some(step_id);
        let completed = entry.get("status").and_then(|v| v.as_str()) == Some("completed");
        id_matches && completed
    });
    Ok(done)
}
/// Append (or replace) a "running" entry for `step` in the rollout record's
/// steps array. Any previous entry for the same step id is dropped first so
/// retries reset the step's status and error.
async fn record_step_start(db: &Surreal<Any>, rollout_id: &str, step: &RolloutStep) -> Result<()> {
    let new_step = serde_json::json!({
        "step_id": step.id,
        // Debug of the enum variants lowercased matches the serde snake_case
        // names for these single-word variants.
        "phase": format!("{:?}", step.phase).to_ascii_lowercase(),
        "kind": format!("{:?}", step.kind).to_ascii_lowercase(),
        "checksum": step_checksum(step)?,
        "status": "running",
        "error": null
    });
    // Read-modify-write of the steps array; callers hold the global lock.
    let row = load_rollout_record(db, rollout_id)
        .await?
        .ok_or_else(|| anyhow!("rollout '{}' not found", rollout_id))?;
    let mut steps: Vec<Value> =
        row.get("steps").and_then(|v| v.as_array()).cloned().unwrap_or_default();
    steps.retain(|s| s.get("step_id").and_then(|v| v.as_str()) != Some(&step.id));
    steps.push(new_step);
    db.query(
        "UPDATE __rollout SET steps = $steps, updated_at = time::now() \
         WHERE record::id(id) = $id;",
    )
    .bind(("id", rollout_id.to_string()))
    .bind(("steps", steps))
    .await?
    .check()?;
    Ok(())
}
/// Mark a step "completed" on the rollout record.
async fn record_step_complete(
    db: &Surreal<Any>,
    rollout_id: &str,
    step: &RolloutStep,
) -> Result<()> {
    update_step_status(db, rollout_id, &step.id, "completed", None).await
}
/// Mark a step "failed" on the rollout record, storing the error text.
async fn record_step_failure(
    db: &Surreal<Any>,
    rollout_id: &str,
    step: &RolloutStep,
    error: &str,
) -> Result<()> {
    update_step_status(db, rollout_id, &step.id, "failed", Some(error)).await
}
/// Set the status (and error, or null) of one entry in the rollout record's
/// steps array, then write the whole array back. A step id with no matching
/// entry leaves the array unchanged. Callers hold the global lock, so the
/// read-modify-write is not raced.
async fn update_step_status(
    db: &Surreal<Any>,
    rollout_id: &str,
    step_id: &str,
    status: &str,
    error: Option<&str>,
) -> Result<()> {
    let row = load_rollout_record(db, rollout_id)
        .await?
        .ok_or_else(|| anyhow!("rollout '{}' not found", rollout_id))?;
    let mut steps: Vec<Value> =
        row.get("steps").and_then(|v| v.as_array()).cloned().unwrap_or_default();
    for s in &mut steps {
        if s.get("step_id").and_then(|v| v.as_str()) == Some(step_id)
            && let Some(obj) = s.as_object_mut()
        {
            obj.insert("status".into(), Value::String(status.to_string()));
            obj.insert(
                "error".into(),
                error.map(|e| Value::String(e.to_string())).unwrap_or(Value::Null),
            );
        }
    }
    db.query(
        "UPDATE __rollout SET steps = $steps, updated_at = time::now() \
         WHERE record::id(id) = $id;",
    )
    .bind(("id", rollout_id.to_string()))
    .bind(("steps", steps))
    .await?
    .check()?;
    Ok(())
}
/// SHA-256 checksum of a step's JSON serialization, recorded on the rollout
/// record when the step starts.
fn step_checksum(step: &RolloutStep) -> Result<String> {
    serde_json::to_vec(step)
        .context("serializing rollout step")
        .map(|encoded| sha256_hex(&encoded))
}
pub async fn acquire_lock(db: &Surreal<Any>, lock_key: &str) -> Result<()> {
let owner = std::env::var("SURREALKIT_OWNER").unwrap_or_else(|_| "surrealkit".to_string());
db.query(
"DELETE __entity WHERE ns = 'lock' AND key = $key; \
CREATE __entity CONTENT { \
ns: 'lock', \
key: $key, \
val: { owner: $owner }, \
updated_at: time::now() \
};",
)
.bind(("key", lock_key.to_string()))
.bind(("owner", owner))
.await?
.check()?;
Ok(())
}
/// Release the advisory lock identified by `lock_key` by deleting its row.
/// Releasing an unheld lock is a no-op.
pub async fn release_lock(db: &Surreal<Any>, lock_key: &str) -> Result<()> {
    let resp = db
        .query("DELETE __entity WHERE ns = 'lock' AND key = $key;")
        .bind(("key", lock_key.to_string()))
        .await?;
    resp.check()?;
    Ok(())
}
/// Turn an arbitrary name into a filesystem/identifier-safe slug: ASCII
/// lowercased, with every run of non-alphanumeric characters collapsed to a
/// single '_', leading/trailing underscores trimmed, and "schema_rollout"
/// as the fallback when nothing survives.
fn slugify(input: &str) -> String {
    let mut slug = String::with_capacity(input.len());
    for ch in input.chars().map(|c| c.to_ascii_lowercase()) {
        if ch.is_ascii_alphanumeric() {
            slug.push(ch);
        } else if !slug.ends_with('_') {
            // First separator of a run; further ones are collapsed.
            // (ends_with on an empty slug is false, matching the original
            // behavior of emitting a leading '_' that is trimmed below.)
            slug.push('_');
        }
    }
    match slug.trim_matches('_') {
        "" => "schema_rollout".to_string(),
        kept => kept.to_string(),
    }
}
/// Read a string field from a JSON row, returning `None` when the key is
/// absent or the value is not a string.
fn string_field(row: &Value, key: &str) -> Option<String> {
    let value = row.get(key)?;
    value.as_str().map(ToOwned::to_owned)
}
/// Like `string_field`, but a missing/non-string field is an error.
fn string_field_req(row: &Value, key: &str) -> Result<String> {
    match string_field(row, key) {
        Some(text) => Ok(text),
        None => Err(anyhow!("missing '{}' in database row", key)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_state::{CatalogChange, SchemaSnapshot, SchemaSnapshotEntry};

    // Auto-planning only expresses pure add/remove diffs; an entity whose
    // statement hash changed in place must be rejected, and the error text
    // mentions "refuses modified".
    #[test]
    fn plan_rejects_modified_entities() {
        let diff = CatalogDiff {
            added: Vec::new(),
            removed: Vec::new(),
            modified: vec![CatalogChange {
                old: CatalogEntity {
                    kind: "field".to_string(),
                    scope: Some("person".to_string()),
                    name: "name".to_string(),
                    source_path: "database/schema/person.surql".to_string(),
                    statement_hash: "a".to_string(),
                    file_hash: "fa".to_string(),
                },
                new: CatalogEntity {
                    kind: "field".to_string(),
                    scope: Some("person".to_string()),
                    name: "name".to_string(),
                    source_path: "database/schema/person.surql".to_string(),
                    statement_hash: "b".to_string(),
                    file_hash: "fb".to_string(),
                },
            }],
        };
        let err = validate_autoplan(&diff).expect_err("should reject modified entities");
        assert!(err.to_string().contains("refuses modified"));
    }

    // One added table plus one removed field should plan exactly three steps:
    // apply_schema in the start phase, remove_entities in the rollback phase
    // (to undo the addition), and remove_entities in the complete phase (to
    // finally drop the removed field).
    #[test]
    fn build_plan_creates_add_and_remove_steps() {
        let files = vec![SchemaFile {
            path: "database/schema/customer.surql".to_string(),
            sql: "DEFINE TABLE customer SCHEMAFULL;".to_string(),
            hash: "file-a".to_string(),
        }];
        let spec = build_rollout_spec(
            "20260302153045__customer",
            "customer",
            &files,
            &FileDiff {
                added: vec!["database/schema/customer.surql".to_string()],
                modified: Vec::new(),
                removed: Vec::new(),
            },
            &CatalogDiff {
                added: vec![CatalogEntity {
                    kind: "table".to_string(),
                    scope: None,
                    name: "customer".to_string(),
                    source_path: "database/schema/customer.surql".to_string(),
                    statement_hash: "stmt".to_string(),
                    file_hash: "file-a".to_string(),
                }],
                removed: vec![CatalogEntity {
                    kind: "field".to_string(),
                    scope: Some("person".to_string()),
                    name: "nickname".to_string(),
                    source_path: "database/schema/person.surql".to_string(),
                    statement_hash: "old".to_string(),
                    file_hash: "file-old".to_string(),
                }],
                modified: Vec::new(),
            },
            // Source snapshot: the state before the rollout.
            &SchemaSnapshot {
                version: 1,
                files: vec![SchemaSnapshotEntry {
                    path: "database/schema/person.surql".to_string(),
                    hash: "old".to_string(),
                }],
            },
            // Target snapshot: the state the rollout should reach.
            &SchemaSnapshot {
                version: 1,
                files: vec![SchemaSnapshotEntry {
                    path: "database/schema/customer.surql".to_string(),
                    hash: "new".to_string(),
                }],
            },
        )
        .expect("build rollout");
        assert_eq!(spec.steps.len(), 3);
        assert!(
            spec.steps.iter().any(|step| step.phase == RolloutPhase::Start
                && step.kind == RolloutStepKind::ApplySchema)
        );
        assert!(spec.steps.iter().any(|step| {
            step.phase == RolloutPhase::Rollback && step.kind == RolloutStepKind::RemoveEntities
        }));
        assert!(spec.steps.iter().any(|step| {
            step.phase == RolloutPhase::Complete && step.kind == RolloutStepKind::RemoveEntities
        }));
    }

    // A run_sql step carrying a mutating UPDATE with `idempotent: Some(false)`
    // must fail validation; the error text instructs "idempotent = true".
    #[test]
    fn rollout_lint_rejects_non_idempotent_run_sql() {
        let spec = RolloutSpec {
            id: "a".to_string(),
            name: "a".to_string(),
            source_schema_hash: "1".to_string(),
            target_schema_hash: "2".to_string(),
            compatibility: "phased".to_string(),
            renames: Vec::new(),
            steps: vec![RolloutStep {
                id: "step".to_string(),
                phase: RolloutPhase::Start,
                kind: RolloutStepKind::RunSql,
                files: Vec::new(),
                sql: Some("UPDATE person SET name = 'a';".to_string()),
                expect: None,
                entities: Vec::new(),
                idempotent: Some(false),
            }],
        };
        let err = validate_rollout_spec(&spec).expect_err("must reject non-idempotent run_sql");
        assert!(err.to_string().contains("idempotent = true"));
    }

    // Spins up an in-memory SurrealDB with all capabilities enabled, selects
    // a throwaway ns/db, and applies the default scaffold setup SQL so the
    // bookkeeping tables (__rollout etc.) exist for the async tests below.
    async fn connect_mem_db() -> Surreal<Any> {
        use surrealdb::engine::any::connect;
        use surrealdb::opt::Config;
        use surrealdb::opt::capabilities::Capabilities;
        let config = Config::new().capabilities(Capabilities::all());
        let db = connect(("mem://", config)).await.expect("connect mem://");
        db.use_ns("surrealkit_test").use_db("rollout_test").await.expect("use_ns/use_db");
        db.query(crate::scaffold::DEFAULT_SETUP)
            .await
            .expect("setup schema")
            .check()
            .expect("setup schema check");
        db
    }

    // Minimal loaded spec (no steps, fixed placeholder hashes) with the given
    // id, for the record-persistence tests.
    fn sample_loaded_spec(id: &str) -> LoadedRolloutSpec {
        LoadedRolloutSpec {
            path: PathBuf::from(format!("database/rollouts/{id}.toml")),
            checksum: "sum".to_string(),
            spec: RolloutSpec {
                id: id.to_string(),
                name: "test".to_string(),
                source_schema_hash: "src".to_string(),
                target_schema_hash: "tgt".to_string(),
                compatibility: "phased".to_string(),
                renames: Vec::new(),
                steps: Vec::new(),
            },
        }
    }

    // Fetches the single __rollout row as raw JSON; panics if none exists.
    async fn load_single_row(db: &Surreal<Any>) -> Value {
        let mut resp =
            db.query("SELECT * FROM __rollout LIMIT 1;").await.expect("select __rollout");
        let rows: Vec<Value> = resp.take(0).expect("take rows");
        rows.into_iter().next().expect("one row exists")
    }

    // create_rollout_record must store started_at as a real datetime: the
    // serialized row exposes it as a string that parses back under RFC 3339.
    #[tokio::test]
    async fn create_rollout_record_accepts_rfc3339_started_at() {
        let db = connect_mem_db().await;
        let loaded = sample_loaded_spec("20260417181055__initial_schema");
        create_rollout_record(&db, &loaded, &[], &[], RolloutStatus::Planned)
            .await
            .expect("create_rollout_record should coerce started_at string to datetime");
        let row = load_single_row(&db).await;
        let started = row
            .get("started_at")
            .and_then(|v| v.as_str())
            .expect("started_at is serialized as a datetime string");
        time::OffsetDateTime::parse(started, &Rfc3339)
            .expect("started_at should round-trip through RFC3339");
        assert_eq!(row.get("status").and_then(|v| v.as_str()), Some("planned"));
    }

    // set_rollout_status must accept an RFC 3339 string for completed_at and
    // persist it so it round-trips as a datetime.
    #[tokio::test]
    async fn set_rollout_status_accepts_rfc3339_completed_at() {
        let db = connect_mem_db().await;
        let loaded = sample_loaded_spec("20260417181055__complete_path");
        create_rollout_record(&db, &loaded, &[], &[], RolloutStatus::RunningComplete)
            .await
            .expect("seed rollout record");
        let completed_at = OffsetDateTime::now_utc().format(&Rfc3339).expect("format rfc3339");
        set_rollout_status(
            &db,
            &loaded.spec.id,
            RolloutStatus::Completed,
            None,
            Some(completed_at),
        )
        .await
        .expect("set_rollout_status should coerce completed_at string to datetime");
        let row = load_single_row(&db).await;
        let completed = row
            .get("completed_at")
            .and_then(|v| v.as_str())
            .expect("completed_at is serialized as a datetime string");
        time::OffsetDateTime::parse(completed, &Rfc3339)
            .expect("completed_at should round-trip through RFC3339");
    }

    // Passing None for completed_at must not error, and the stored row must
    // keep completed_at absent/null while the status still updates.
    #[tokio::test]
    async fn set_rollout_status_accepts_none_completed_at() {
        let db = connect_mem_db().await;
        let loaded = sample_loaded_spec("20260417181055__running_path");
        create_rollout_record(&db, &loaded, &[], &[], RolloutStatus::Planned)
            .await
            .expect("seed rollout record");
        set_rollout_status(&db, &loaded.spec.id, RolloutStatus::RunningStart, None, None)
            .await
            .expect("set_rollout_status with None completed_at should succeed");
        let row = load_single_row(&db).await;
        assert!(
            row.get("completed_at").is_none_or(Value::is_null),
            "completed_at should be NONE/null, got {:?}",
            row.get("completed_at")
        );
        assert_eq!(row.get("status").and_then(|v| v.as_str()), Some("running_start"));
    }

    // load_rollout_record must find a row previously written by
    // create_rollout_record when looked up by its rollout id.
    #[tokio::test]
    async fn load_rollout_record_finds_created_row() {
        let db = connect_mem_db().await;
        let loaded = sample_loaded_spec("20260417181055__lookup");
        create_rollout_record(&db, &loaded, &[], &[], RolloutStatus::Planned)
            .await
            .expect("seed rollout record");
        let row = load_rollout_record(&db, &loaded.spec.id)
            .await
            .expect("load_rollout_record query")
            .expect("row must be found by rollout id");
        assert_eq!(row.get("status").and_then(|v| v.as_str()), Some("planned"));
    }
}