use std::sync::Arc;
use std::time::Duration;
use affinidi_did_resolver_cache_sdk::DIDCacheClient;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::info;
use vti_common::seed_store::SeedStore;
use vti_common::telemetry::SharedTelemetrySink;
use crate::auth::AuthClaims;
use crate::config::AppConfig;
use crate::didcomm_bridge::DIDCommBridge;
use crate::error::AppError;
use crate::messaging::drain_sweeper::DrainSweeper;
use crate::messaging::handshake::{HandshakeError, ListenerProver};
use crate::messaging::registry::MediatorListenerRegistry;
use crate::operations::did_webvh::UpdateDidWebvhError;
use crate::operations::protocol::OpContext;
use crate::operations::protocol::disable_didcomm::{
DisableDidcommError, DisableDidcommParams, DisableTransport, disable_didcomm,
};
use crate::operations::protocol::document::{DocumentPatchError, current_didcomm_service};
use crate::operations::protocol::enable_didcomm::{
EnableDidcommError, EnableDidcommParams, enable_didcomm,
};
use crate::operations::protocol::snapshot::{
self, DidcommSnapshot, ServiceConfigSnapshot, ServiceKind,
};
use crate::operations::protocol::update_didcomm::{
MigrateAuditKind, UpdateDidcommError, UpdateDidcommParams, update_didcomm,
};
use crate::store::KeyspaceHandle;
const DEFAULT_ROLLBACK_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Debug, Clone)]
pub struct RollbackDidcommParams {
pub drain_ttl: Duration,
pub transport: DisableTransport,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RollbackKind {
Disabled,
Enabled,
Updated,
NoOp,
}
#[derive(Debug, Clone)]
pub struct RollbackDidcommResult {
pub new_version_id: Option<String>,
pub kind: RollbackKind,
pub draining_mediator: Option<String>,
pub vta_did: String,
pub serverless: bool,
}
#[derive(Debug, Error)]
pub enum RollbackDidcommError {
#[error(
"no prior mutation for `services didcomm` to roll back from. \
Use `services didcomm enable / update / disable` directly instead."
)]
NoPriorMutation,
#[error("VTA DID is not configured — run `vta setup` first")]
VtaDidNotConfigured,
#[error("VTA DID `{0}` has no webvh record")]
VtaDidRecordMissing(String),
#[error("VTA DID `{0}` has no published log")]
VtaDidLogMissing(String),
#[error("VTA DID log is empty")]
EmptyLog,
#[error(transparent)]
EnableForward(#[from] EnableDidcommError),
#[error(transparent)]
UpdateForward(#[from] UpdateDidcommError),
#[error(transparent)]
DisableForward(#[from] DisableDidcommError),
#[error(transparent)]
Handshake(#[from] HandshakeError),
#[error("DID document patch failed: {0}")]
DocumentPatch(#[from] DocumentPatchError),
#[error("WebVH update failed: {0}")]
WebVHUpdate(#[from] UpdateDidWebvhError),
#[error("auth: {0}")]
Auth(String),
#[error("storage error: {0}")]
Storage(String),
}
impl From<AppError> for RollbackDidcommError {
fn from(value: AppError) -> Self {
Self::Storage(value.to_string())
}
}
impl From<crate::operations::protocol::preconditions::ProtocolPreconditionError>
for RollbackDidcommError
{
fn from(value: crate::operations::protocol::preconditions::ProtocolPreconditionError) -> Self {
use crate::operations::protocol::preconditions::ProtocolPreconditionError as E;
match value {
E::VtaDidNotConfigured => Self::VtaDidNotConfigured,
E::VtaDidRecordMissing(s) => Self::VtaDidRecordMissing(s),
E::VtaDidLogMissing(s) => Self::VtaDidLogMissing(s),
E::EmptyLog => Self::EmptyLog,
E::Storage(s) | E::DocumentParse(s) => Self::Storage(s),
}
}
}
#[allow(clippy::too_many_arguments)]
pub async fn rollback_didcomm(
config: &Arc<RwLock<AppConfig>>,
keys_ks: &KeyspaceHandle,
imported_ks: &KeyspaceHandle,
contexts_ks: &KeyspaceHandle,
webvh_ks: &KeyspaceHandle,
audit_ks: &KeyspaceHandle,
drains_ks: &KeyspaceHandle,
snapshot_ks: &KeyspaceHandle,
service_state_ks: &KeyspaceHandle,
seed_store: &dyn SeedStore,
did_resolver: &DIDCacheClient,
didcomm_bridge: &Arc<DIDCommBridge>,
registry: &MediatorListenerRegistry,
sweeper: &DrainSweeper,
telemetry: &SharedTelemetrySink,
prover: &(dyn ListenerProver + Send + Sync),
auth: &AuthClaims,
params: RollbackDidcommParams,
webvh_auth_locks: &crate::operations::did_webvh::WebvhAuthLocks,
channel: &str,
) -> Result<RollbackDidcommResult, RollbackDidcommError> {
auth.require_super_admin()
.map_err(|e| RollbackDidcommError::Auth(e.to_string()))?;
let snap = snapshot::read(snapshot_ks, ServiceKind::Didcomm)
.await
.map_err(|e| RollbackDidcommError::Storage(format!("snapshot read: {e}")))?
.ok_or(RollbackDidcommError::NoPriorMutation)?;
let didcomm_snap = match snap {
ServiceConfigSnapshot::Didcomm(s) => s,
other => {
return Err(RollbackDidcommError::Storage(format!(
"snapshot kind mismatch: stored {other:?}, requested Didcomm",
)));
}
};
let current_mediator = read_current_didcomm_mediator(config, webvh_ks).await?;
info!(
channel,
snapshot = ?didcomm_snap,
current = ?current_mediator,
"rollback_didcomm dispatching",
);
match (didcomm_snap, current_mediator.as_deref()) {
(DidcommSnapshot::Disabled, Some(prior)) => {
let result = disable_didcomm(
config,
keys_ks,
imported_ks,
contexts_ks,
webvh_ks,
audit_ks,
drains_ks,
snapshot_ks,
service_state_ks,
seed_store,
did_resolver,
didcomm_bridge,
registry,
sweeper,
telemetry,
auth,
DisableDidcommParams {
drain_ttl: params.drain_ttl,
transport: params.transport,
},
OpContext::Rollback,
webvh_auth_locks,
channel,
)
.await?;
Ok(RollbackDidcommResult {
new_version_id: Some(result.new_version_id),
kind: RollbackKind::Disabled,
draining_mediator: Some(prior.to_string()),
vta_did: result.vta_did,
serverless: result.serverless,
})
}
(
DidcommSnapshot::Enabled {
mediator_did,
routing_keys: _,
},
None,
) => {
let result = enable_didcomm(
config,
keys_ks,
imported_ks,
contexts_ks,
webvh_ks,
audit_ks,
snapshot_ks,
service_state_ks,
seed_store,
did_resolver,
didcomm_bridge,
registry,
telemetry,
prover,
auth,
EnableDidcommParams {
mediator_did: mediator_did.clone(),
force: false,
handshake_timeout: DEFAULT_ROLLBACK_HANDSHAKE_TIMEOUT,
},
OpContext::Rollback,
webvh_auth_locks,
channel,
)
.await?;
Ok(RollbackDidcommResult {
new_version_id: Some(result.new_version_id),
kind: RollbackKind::Enabled,
draining_mediator: None,
vta_did: result.vta_did,
serverless: result.serverless,
})
}
(
DidcommSnapshot::Enabled {
mediator_did,
routing_keys: _,
},
Some(current),
) if mediator_did != current => {
let result = update_didcomm(
config,
keys_ks,
imported_ks,
contexts_ks,
webvh_ks,
audit_ks,
drains_ks,
snapshot_ks,
service_state_ks,
seed_store,
did_resolver,
didcomm_bridge,
registry,
sweeper,
telemetry,
prover,
auth,
UpdateDidcommParams {
new_mediator_did: mediator_did.clone(),
drain_ttl: params.drain_ttl,
force: false,
handshake_timeout: DEFAULT_ROLLBACK_HANDSHAKE_TIMEOUT,
audit_kind: MigrateAuditKind::Rollback,
transport: params.transport,
},
OpContext::Rollback,
webvh_auth_locks,
channel,
)
.await?;
Ok(RollbackDidcommResult {
new_version_id: Some(result.new_version_id),
kind: RollbackKind::Updated,
draining_mediator: Some(result.prior_mediator_did),
vta_did: result.vta_did,
serverless: result.serverless,
})
}
_ => {
info!(
channel,
"rollback_didcomm: snapshot matches current state — no-op"
);
Ok(RollbackDidcommResult {
new_version_id: None,
kind: RollbackKind::NoOp,
draining_mediator: None,
vta_did: String::new(),
serverless: false,
})
}
}
}
async fn read_current_didcomm_mediator(
config: &Arc<RwLock<AppConfig>>,
webvh_ks: &KeyspaceHandle,
) -> Result<Option<String>, RollbackDidcommError> {
let state = super::preconditions::load_vta_doc_state(config, webvh_ks).await?;
Ok(current_didcomm_service(&state.current_doc).map(|svc| svc.mediator_did))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::store::Store;
use vti_common::config::StoreConfig as VtiStoreConfig;
async fn empty_keyspace(name: &str) -> (tempfile::TempDir, KeyspaceHandle) {
let dir = tempfile::tempdir().unwrap();
let store = Store::open(&VtiStoreConfig {
data_dir: dir.path().into(),
})
.unwrap();
let ks = store.keyspace(name).unwrap();
(dir, ks)
}
#[tokio::test]
async fn no_prior_mutation_error_message_is_actionable() {
let (_dir, ks) = empty_keyspace(snapshot::KEYSPACE_NAME).await;
let snap = snapshot::read(&ks, ServiceKind::Didcomm).await.unwrap();
assert!(snap.is_none());
let err = RollbackDidcommError::NoPriorMutation;
let msg = err.to_string();
assert!(
msg.contains("no prior mutation"),
"error message must point operator at the right next step, got: {msg}",
);
assert!(msg.contains("services didcomm enable"));
}
#[tokio::test]
async fn snapshot_disabled_round_trips() {
let (_dir, ks) = empty_keyspace(snapshot::KEYSPACE_NAME).await;
snapshot::write(
&ks,
ServiceConfigSnapshot::Didcomm(DidcommSnapshot::Disabled),
)
.await
.unwrap();
let read = snapshot::read(&ks, ServiceKind::Didcomm)
.await
.unwrap()
.unwrap();
match read {
ServiceConfigSnapshot::Didcomm(DidcommSnapshot::Disabled) => {}
other => panic!("expected Didcomm(Disabled), got {other:?}"),
}
}
#[tokio::test]
async fn snapshot_enabled_with_mediator_round_trips() {
let (_dir, ks) = empty_keyspace(snapshot::KEYSPACE_NAME).await;
snapshot::write(
&ks,
ServiceConfigSnapshot::Didcomm(DidcommSnapshot::Enabled {
mediator_did: "did:peer:2.M".into(),
routing_keys: vec![],
}),
)
.await
.unwrap();
let read = snapshot::read(&ks, ServiceKind::Didcomm)
.await
.unwrap()
.unwrap();
match read {
ServiceConfigSnapshot::Didcomm(DidcommSnapshot::Enabled { mediator_did, .. }) => {
assert_eq!(mediator_did, "did:peer:2.M");
}
other => panic!("expected Didcomm(Enabled), got {other:?}"),
}
}
#[test]
fn rollback_kind_variants_are_distinct() {
assert_ne!(RollbackKind::Disabled, RollbackKind::Enabled);
assert_ne!(RollbackKind::Enabled, RollbackKind::Updated);
assert_ne!(RollbackKind::Updated, RollbackKind::NoOp);
}
#[test]
fn params_carry_transport_and_ttl() {
let p = RollbackDidcommParams {
drain_ttl: Duration::from_secs(86_400),
transport: DisableTransport::Rest,
};
assert_eq!(p.drain_ttl.as_secs(), 86_400);
assert_eq!(p.transport, DisableTransport::Rest);
}
}