use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::info;
use crate::auth::AuthClaims;
use crate::config::AppConfig;
use crate::error::AppError;
use crate::operations::did_webvh::UpdateDidWebvhError;
use crate::operations::protocol::disable_rest::{
DisableRestError, DisableRestParams, disable_rest,
};
use crate::operations::protocol::document::{DocumentPatchError, current_rest_service};
use crate::operations::protocol::enable_rest::{EnableRestError, EnableRestParams, enable_rest};
use crate::operations::protocol::snapshot::{
self, RestSnapshot, ServiceConfigSnapshot, ServiceKind,
};
use crate::operations::protocol::update_rest::{UpdateRestError, UpdateRestParams, update_rest};
use crate::operations::protocol::{OpContext, ServiceOpDeps};
use crate::store::KeyspaceHandle;
/// Parameters for [`rollback_rest`].
///
/// Carries no data: the rollback target is taken entirely from the stored
/// REST snapshot. Kept as a type so the call signature matches the other
/// `services rest` operations.
#[derive(Default, Clone, Debug)]
pub struct RollbackRestParams;
/// Which action [`rollback_rest`] took to restore the snapshotted REST state.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RollbackKind {
    /// The REST service was removed by dispatching to `disable_rest`.
    Disabled,
    /// The REST service was re-added by dispatching to `enable_rest`.
    Enabled,
    /// The REST service URL was restored by dispatching to `update_rest`.
    Updated,
    /// The current state already matched the snapshot; nothing was dispatched.
    NoOp,
}
/// Outcome of [`rollback_rest`].
#[derive(Clone, Debug)]
pub struct RollbackRestResult {
    /// New version id reported by the dispatched enable/update/disable
    /// operation; `None` when the rollback was a no-op.
    pub new_version_id: Option<String>,
    /// Which action the rollback took.
    pub kind: RollbackKind,
    /// The VTA DID the rollback applied to.
    pub vta_did: String,
    /// Mirrors the `serverless` flag of the dispatched operation's result.
    pub serverless: bool,
}
#[derive(Debug, Error)]
pub enum RollbackRestError {
#[error(
"no prior mutation for `services rest` to roll back from. \
Use `services rest enable / update / disable` directly instead."
)]
NoPriorMutation,
#[error("VTA DID is not configured — run `vta setup` first")]
VtaDidNotConfigured,
#[error("VTA DID `{0}` has no webvh record")]
VtaDidRecordMissing(String),
#[error("VTA DID `{0}` has no published log")]
VtaDidLogMissing(String),
#[error("VTA DID log is empty")]
EmptyLog,
#[error(transparent)]
EnableForward(#[from] EnableRestError),
#[error(transparent)]
UpdateForward(#[from] UpdateRestError),
#[error(transparent)]
DisableForward(#[from] DisableRestError),
#[error("DID document patch failed: {0}")]
DocumentPatch(#[from] DocumentPatchError),
#[error("WebVH update failed: {0}")]
WebVHUpdate(#[from] UpdateDidWebvhError),
#[error("auth: {0}")]
Auth(String),
#[error("storage error: {0}")]
Storage(String),
}
impl From<AppError> for RollbackRestError {
fn from(value: AppError) -> Self {
Self::Storage(value.to_string())
}
}
impl From<crate::operations::protocol::preconditions::ProtocolPreconditionError>
for RollbackRestError
{
fn from(value: crate::operations::protocol::preconditions::ProtocolPreconditionError) -> Self {
use crate::operations::protocol::preconditions::ProtocolPreconditionError as E;
match value {
E::VtaDidNotConfigured => Self::VtaDidNotConfigured,
E::VtaDidRecordMissing(s) => Self::VtaDidRecordMissing(s),
E::VtaDidLogMissing(s) => Self::VtaDidLogMissing(s),
E::EmptyLog => Self::EmptyLog,
E::Storage(s) | E::DocumentParse(s) => Self::Storage(s),
}
}
}
/// Restores the REST service to the state captured in the stored REST snapshot.
///
/// The stored [`RestSnapshot`] is compared with the REST service currently
/// present in the VTA DID document. The function then dispatches to the
/// matching forward operation, passing [`OpContext::Rollback`]:
///
/// | snapshot          | current REST URL   | action           |
/// |-------------------|--------------------|------------------|
/// | `Disabled`        | present            | [`disable_rest`] |
/// | `Enabled { url }` | absent             | [`enable_rest`]  |
/// | `Enabled { url }` | present, differs   | [`update_rest`]  |
/// | anything else     |                    | no-op            |
///
/// `channel` is passed through to the dispatched operation and recorded in
/// the log events.
///
/// # Errors
///
/// - [`RollbackRestError::Auth`] if the caller is not a super admin.
/// - [`RollbackRestError::NoPriorMutation`] if no REST snapshot is stored.
/// - [`RollbackRestError::Storage`] if the snapshot cannot be read, or if the
///   stored snapshot is not a REST snapshot.
/// - Precondition errors from loading the VTA DID document state.
/// - Any error forwarded from the dispatched enable/update/disable operation.
pub async fn rollback_rest(
    deps: &ServiceOpDeps<'_>,
    auth: &AuthClaims,
    _params: RollbackRestParams,
    channel: &str,
) -> Result<RollbackRestResult, RollbackRestError> {
    auth.require_super_admin()
        .map_err(|e| RollbackRestError::Auth(e.to_string()))?;

    let snap = snapshot::read(deps.snapshot_ks, ServiceKind::Rest)
        .await
        .map_err(|e| RollbackRestError::Storage(format!("snapshot read: {e}")))?
        .ok_or(RollbackRestError::NoPriorMutation)?;
    let rest_snap = match snap {
        ServiceConfigSnapshot::Rest(s) => s,
        other => {
            return Err(RollbackRestError::Storage(format!(
                "snapshot kind mismatch: stored {other:?}, requested Rest",
            )));
        }
    };

    let current_url = read_current_rest_url(deps.config, deps.webvh_ks).await?;
    info!(
        channel,
        snapshot = ?rest_snap,
        current = ?current_url,
        "rollback_rest dispatching",
    );

    // `rest_snap` is moved into the scrutinee, so `url` below is owned and
    // can be handed to the forward operation without cloning.
    match (rest_snap, current_url.as_deref()) {
        (RestSnapshot::Disabled, Some(_)) => {
            let result =
                disable_rest(deps, auth, DisableRestParams, OpContext::Rollback, channel).await?;
            Ok(RollbackRestResult {
                new_version_id: Some(result.new_version_id),
                kind: RollbackKind::Disabled,
                vta_did: result.vta_did,
                serverless: result.serverless,
            })
        }
        (RestSnapshot::Enabled { url }, None) => {
            let result = enable_rest(
                deps,
                auth,
                EnableRestParams { url },
                OpContext::Rollback,
                channel,
            )
            .await?;
            Ok(RollbackRestResult {
                new_version_id: Some(result.new_version_id),
                kind: RollbackKind::Enabled,
                vta_did: result.vta_did,
                serverless: result.serverless,
            })
        }
        (RestSnapshot::Enabled { url }, Some(current)) if url != current => {
            let result = update_rest(
                deps,
                auth,
                UpdateRestParams { url },
                OpContext::Rollback,
                channel,
            )
            .await?;
            Ok(RollbackRestResult {
                new_version_id: Some(result.new_version_id),
                kind: RollbackKind::Updated,
                vta_did: result.vta_did,
                serverless: result.serverless,
            })
        }
        _ => {
            info!(
                channel,
                "rollback_rest: snapshot matches current state — no-op"
            );
            // Report the configured VTA DID instead of an empty placeholder.
            // `read_current_rest_url` has already required it to be set, so
            // the default is only a defensive fallback. The read guard is
            // dropped at the end of this statement.
            let vta_did = deps.config.read().await.vta_did.clone().unwrap_or_default();
            Ok(RollbackRestResult {
                new_version_id: None,
                kind: RollbackKind::NoOp,
                vta_did,
                // No forward operation ran, so there is no result to read
                // this flag from; it is reported as `false`.
                serverless: false,
            })
        }
    }
}
/// Returns the URL of the REST service currently present in the VTA DID
/// document, or `None` if the document has no REST service entry.
///
/// Errors from loading the document state are propagated through
/// `RollbackRestError`'s `From` conversion.
async fn read_current_rest_url(
    config: &Arc<RwLock<AppConfig>>,
    webvh_ks: &KeyspaceHandle,
) -> Result<Option<String>, RollbackRestError> {
    let doc_state = super::preconditions::load_vta_doc_state(config, webvh_ks).await?;
    let rest_url = current_rest_service(&doc_state.current_doc).map(|service| service.url);
    Ok(rest_url)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::Store;
    use vti_common::config::StoreConfig as VtiStoreConfig;

    /// Temporary on-disk environment for the snapshot tests.
    ///
    /// The temp dir and config are held only so they outlive the store.
    struct TestFixture {
        _dir: tempfile::TempDir,
        _config: Arc<RwLock<AppConfig>>,
        store: Store,
    }

    impl TestFixture {
        /// Opens the snapshot keyspace on this fixture's store.
        fn snapshot_ks(&self) -> KeyspaceHandle {
            self.store.keyspace(snapshot::KEYSPACE_NAME).unwrap()
        }
    }

    /// Builds a fixture with the given service flags.
    ///
    /// The fixture also gets a VTA DID and a config file written to
    /// `<tmp>/vta.toml`.
    fn build_fixture(rest: bool, didcomm: bool) -> TestFixture {
        use crate::test_support::test_app_config;

        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        let mut cfg = test_app_config(root.into());
        cfg.services.rest = rest;
        cfg.services.didcomm = didcomm;
        cfg.vta_did = Some("did:webvh:scid123:host:vta".into());
        cfg.config_path = root.join("vta.toml");

        // Write the config to `config_path`, presumably so code that re-reads
        // the file sees the same values.
        let serialized = toml::to_string_pretty(&cfg).unwrap();
        std::fs::write(&cfg.config_path, serialized).unwrap();

        let store = Store::open(&VtiStoreConfig {
            data_dir: root.into(),
        })
        .unwrap();

        TestFixture {
            _dir: dir,
            _config: Arc::new(RwLock::new(cfg)),
            store,
        }
    }

    /// Reads the stored REST snapshot, panicking on storage errors.
    async fn read_rest_snapshot(ks: &KeyspaceHandle) -> Option<ServiceConfigSnapshot> {
        snapshot::read(ks, ServiceKind::Rest).await.unwrap()
    }

    #[tokio::test]
    async fn no_prior_mutation_when_snapshot_empty() {
        let fx = build_fixture(true, true);
        let ks = fx.snapshot_ks();

        assert!(read_rest_snapshot(&ks).await.is_none());

        let msg = RollbackRestError::NoPriorMutation.to_string();
        assert!(msg.contains("no prior mutation"));
    }

    #[tokio::test]
    async fn snapshot_disabled_round_trips() {
        let fx = build_fixture(true, true);
        let ks = fx.snapshot_ks();

        snapshot::write(&ks, ServiceConfigSnapshot::Rest(RestSnapshot::Disabled))
            .await
            .unwrap();

        let stored = read_rest_snapshot(&ks).await.unwrap();
        assert!(
            matches!(stored, ServiceConfigSnapshot::Rest(RestSnapshot::Disabled)),
            "expected Rest(Disabled), got {stored:?}"
        );
    }

    #[tokio::test]
    async fn snapshot_enabled_with_url_round_trips() {
        let fx = build_fixture(true, true);
        let ks = fx.snapshot_ks();

        let written = ServiceConfigSnapshot::Rest(RestSnapshot::Enabled {
            url: "https://prior.example.com".into(),
        });
        snapshot::write(&ks, written).await.unwrap();

        let stored = read_rest_snapshot(&ks).await.unwrap();
        match stored {
            ServiceConfigSnapshot::Rest(RestSnapshot::Enabled { url }) => {
                assert_eq!(url, "https://prior.example.com");
            }
            other => panic!("expected Rest(Enabled {{ url }}), got {other:?}"),
        }
    }

    #[test]
    fn rollback_kind_variants_are_distinct() {
        let pairs = [
            (RollbackKind::Disabled, RollbackKind::Enabled),
            (RollbackKind::Enabled, RollbackKind::Updated),
            (RollbackKind::Updated, RollbackKind::NoOp),
            (RollbackKind::NoOp, RollbackKind::Disabled),
        ];
        for (left, right) in pairs {
            assert_ne!(left, right);
        }
    }
}