use std::sync::Arc;
use std::time::Duration;
use affinidi_did_resolver_cache_sdk::DIDCacheClient;
use chrono::Utc;
use serde_json::Value as JsonValue;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::info;
use vti_common::seed_store::SeedStore;
use vti_common::telemetry::{SharedTelemetrySink, TelemetryEvent, TelemetryKind};
use crate::auth::AuthClaims;
use crate::config::AppConfig;
use crate::didcomm_bridge::DIDCommBridge;
use crate::error::AppError;
use crate::messaging::drain_sweeper::DrainSweeper;
use crate::messaging::registry::{MediatorListenerRegistry, RegistryError};
use crate::operations::did_webvh::{UpdateDidWebvhError, UpdateDidWebvhOptions, update_did_webvh};
use crate::operations::protocol::document::{DocumentPatchError, without_didcomm_service};
use crate::operations::protocol::{OpContext, PROTOCOL_LOCK};
use crate::store::KeyspaceHandle;
/// Minimum drain TTL (one hour) associated with a DIDComm-transported disable
/// request.
///
/// NOTE(review): not referenced in this module; presumably consumed by
/// `validate_drain_ttl` — confirm.
pub const MIN_DRAIN_TTL_OVER_DIDCOMM: Duration = Duration::from_secs(60 * 60);
/// Transport over which a "disable DIDComm" request arrived.
///
/// Passed to drain-TTL validation and recorded in the telemetry event.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DisableTransport {
    /// The request came in over the REST API.
    Rest,
    /// The request came in over DIDComm.
    Didcomm,
}
/// Caller-supplied inputs for [`disable_didcomm`].
#[derive(Clone, Debug)]
pub struct DisableDidcommParams {
    /// How long the prior mediator is kept draining after the disable.
    /// A zero duration skips the drain window entirely.
    pub drain_ttl: Duration,
    /// Transport the request arrived on; it is passed to drain-TTL validation.
    pub transport: DisableTransport,
}
/// Outcome of a successful [`disable_didcomm`] call.
#[derive(Clone, Debug)]
pub struct DisableDidcommResult {
    /// Version id of the DID log entry produced by the WebVH update.
    pub new_version_id: String,
    /// Mediator DID named by the DIDComm service that was removed.
    pub prior_mediator_did: String,
    /// End of the drain window, or `None` when the drain TTL was zero.
    pub drains_until: Option<chrono::DateTime<chrono::Utc>>,
    /// The VTA's own DID.
    pub vta_did: String,
    /// Passed through unchanged from the WebVH update result.
    pub serverless: bool,
}
/// Every way [`disable_didcomm`] can fail.
#[derive(Debug, Error)]
pub enum DisableDidcommError {
    // --- request / configuration preconditions -------------------------------
    /// DIDComm is already off in the running configuration.
    #[error(
        "DIDComm is not currently enabled. Enable it first: \
         `pnm services didcomm enable --mediator-did <did>` (online) \
         or `vta services didcomm enable --mediator-did <did>` (offline, daemon stopped)."
    )]
    DidcommNotEnabled,

    /// Disabling DIDComm would leave the VTA without any protocol surface.
    #[error(
        "cannot disable DIDComm — REST is also disabled. The VTA would have no protocol surface left. \
         Enable REST first: `pnm services rest enable --url <url>` (online) \
         or `vta services rest enable --url <url>` (offline, daemon stopped). \
         Then retry the disable."
    )]
    NoProtocolRemaining,

    /// The requested drain TTL falls outside the allowed window (seconds).
    #[error("drain ttl {requested}s outside allowed range [{min}s, {max}s]")]
    DrainTtlOutOfBounds { min: u64, max: u64, requested: u64 },

    /// The caller lacks the required privileges.
    #[error("auth: {0}")]
    Auth(String),

    // --- VTA DID state -------------------------------------------------------
    /// No VTA DID is present in the configuration.
    #[error("VTA DID is not configured — run `vta setup` first")]
    VtaDidNotConfigured,

    /// The VTA DID has no WebVH record in storage.
    #[error("VTA DID `{0}` has no webvh record")]
    VtaDidRecordMissing(String),

    /// The VTA DID's WebVH record has no published log.
    #[error("VTA DID `{0}` has no published log")]
    VtaDidLogMissing(String),

    /// The VTA DID's log exists but contains no entries.
    #[error("VTA DID log is empty")]
    EmptyLog,

    /// DIDComm is flagged as enabled but the DID document lacks the service.
    #[error(
        "DIDComm is enabled but the VTA's DID document has no `#vta-didcomm` service entry — \
         on-disk state is inconsistent (re-run setup)"
    )]
    NoActiveMediator,

    // --- downstream failures -------------------------------------------------
    /// Removing the DIDComm service from the DID document failed.
    #[error("DID document patch failed: {0}")]
    DocumentPatch(#[from] DocumentPatchError),

    /// Publishing the updated DID document via WebVH failed.
    #[error("WebVH update failed: {0}")]
    WebVHUpdate(#[from] UpdateDidWebvhError),

    /// Persisting the new service state failed.
    #[error("config persistence failed: {0}")]
    ConfigPersistence(String),

    /// The mediator listener registry reported an error.
    #[error(transparent)]
    Registry(#[from] RegistryError),

    /// A storage-layer failure, carried as text.
    #[error("storage error: {0}")]
    Storage(String),
}
impl From<AppError> for DisableDidcommError {
fn from(value: AppError) -> Self {
Self::Storage(value.to_string())
}
}
impl From<crate::operations::protocol::preconditions::ProtocolPreconditionError>
for DisableDidcommError
{
fn from(value: crate::operations::protocol::preconditions::ProtocolPreconditionError) -> Self {
use crate::operations::protocol::preconditions::ProtocolPreconditionError as E;
match value {
E::VtaDidNotConfigured => Self::VtaDidNotConfigured,
E::VtaDidRecordMissing(s) => Self::VtaDidRecordMissing(s),
E::VtaDidLogMissing(s) => Self::VtaDidLogMissing(s),
E::EmptyLog => Self::EmptyLog,
E::Storage(s) | E::DocumentParse(s) => Self::Storage(s),
}
}
}
/// Turns off the VTA's DIDComm protocol surface.
///
/// After a super-admin check, the whole operation runs under [`PROTOCOL_LOCK`]:
///
/// 1. validate preconditions: DIDComm is currently enabled, another protocol
///    surface remains, the drain TTL is within bounds for `params.transport`,
///    and the current DID document carries a DIDComm service naming a mediator;
/// 2. write a snapshot recording the prior mediator to `snapshot_ks`;
/// 3. run [`update_did_webvh`] with the document minus its DIDComm service;
/// 4. persist `didcomm = false` to `service_state_ks`, then mirror it into the
///    in-memory [`AppConfig`];
/// 5. record the listener deactivation. When `params.drain_ttl` is non-zero,
///    also persist a drain record for the prior mediator and arm the sweeper
///    for the drain deadline;
/// 6. emit a `ServicesDidcommDisable` telemetry event (best-effort) and an
///    info log.
///
/// # Errors
///
/// - [`DisableDidcommError::Auth`] when the caller is not a super admin.
/// - A precondition variant (`DidcommNotEnabled`, `NoProtocolRemaining`,
///   `DrainTtlOutOfBounds`, `VtaDid*`, `EmptyLog`, `NoActiveMediator`) when the
///   request cannot proceed.
/// - A storage, WebVH, persistence or registry variant when a side-effecting
///   step fails.
///
/// No rollback is performed here: a failure after step 3 leaves the DID
/// update in place.
// Every dependency is injected explicitly so the operation stays testable.
#[allow(clippy::too_many_arguments)]
pub async fn disable_didcomm(
    config: &Arc<RwLock<AppConfig>>,
    keys_ks: &KeyspaceHandle,
    imported_ks: &KeyspaceHandle,
    contexts_ks: &KeyspaceHandle,
    webvh_ks: &KeyspaceHandle,
    audit_ks: &KeyspaceHandle,
    drains_ks: &KeyspaceHandle,
    snapshot_ks: &KeyspaceHandle,
    service_state_ks: &KeyspaceHandle,
    seed_store: &dyn SeedStore,
    did_resolver: &DIDCacheClient,
    didcomm_bridge: &Arc<DIDCommBridge>,
    registry: &MediatorListenerRegistry,
    sweeper: &DrainSweeper,
    telemetry: &SharedTelemetrySink,
    auth: &AuthClaims,
    params: DisableDidcommParams,
    ctx: OpContext,
    webvh_auth_locks: &crate::operations::did_webvh::WebvhAuthLocks,
    channel: &str,
) -> Result<DisableDidcommResult, DisableDidcommError> {
    use crate::operations::protocol::snapshot::{self, DidcommSnapshot, ServiceConfigSnapshot};

    auth.require_super_admin()
        .map_err(|e| DisableDidcommError::Auth(e.to_string()))?;

    // Serialize against other holders of the protocol lock.
    let _guard = PROTOCOL_LOCK.lock().await;

    let (vta_did, scid, current_doc, prior_mediator) =
        read_preconditions(config, webvh_ks, &params).await?;

    // Convert the drain TTL before anything is mutated. That way a value chrono
    // cannot represent is rejected up front, rather than after the DID log has
    // already been updated.
    let drain_span = chrono::Duration::from_std(params.drain_ttl).map_err(|e| {
        DisableDidcommError::ConfigPersistence(format!("drain TTL out of range: {e}"))
    })?;

    // NOTE(review): routing keys are recorded as empty here. Confirm that the
    // prior service's routing keys do not need to be captured in the snapshot.
    snapshot::write(
        snapshot_ks,
        ServiceConfigSnapshot::Didcomm(DidcommSnapshot::Enabled {
            mediator_did: prior_mediator.clone(),
            routing_keys: vec![],
        }),
    )
    .await
    .map_err(|e| DisableDidcommError::Storage(format!("snapshot write: {e}")))?;

    let patched = without_didcomm_service(current_doc);
    let update_result = update_did_webvh(
        keys_ks,
        imported_ks,
        contexts_ks,
        webvh_ks,
        audit_ks,
        seed_store,
        auth,
        &scid,
        UpdateDidWebvhOptions {
            document: Some(patched),
            ..Default::default()
        },
        did_resolver,
        didcomm_bridge,
        Some(vta_did.as_str()),
        webvh_auth_locks,
        channel,
    )
    .await?;

    crate::operations::protocol::runtime_state::set_didcomm_enabled(service_state_ks, false)
        .await
        .map_err(|e| DisableDidcommError::ConfigPersistence(format!("runtime state: {e}")))?;
    config.write().await.services.didcomm = false;

    // The listener is deactivated whether or not a drain window follows.
    registry.record_deactivate().await;
    let drains_until = if params.drain_ttl.is_zero() {
        None
    } else {
        let deadline = Utc::now() + drain_span;
        let endpoint = best_effort_endpoint(did_resolver, &prior_mediator).await;
        registry
            .record_drain_persisted(drains_ks, &prior_mediator, endpoint, deadline)
            .await?;
        sweeper.arm(&prior_mediator, deadline).await;
        Some(deadline)
    };

    let transport_label = match params.transport {
        DisableTransport::Rest => "rest",
        DisableTransport::Didcomm => "didcomm",
    };
    let mut event = TelemetryEvent::new(TelemetryKind::ServicesDidcommDisable)
        .with_mediator(&prior_mediator)
        .with_field(
            "drain_ttl_secs",
            JsonValue::from(params.drain_ttl.as_secs()),
        )
        .with_field(
            "new_version_id",
            JsonValue::from(update_result.new_version_id.clone()),
        )
        .with_field("transport", JsonValue::from(transport_label));
    if let Some(tag) = ctx.telemetry_triggered_by() {
        event = event.with_field("triggered_by", JsonValue::from(tag));
    }
    // Telemetry is best-effort: its result is deliberately discarded.
    let _ = telemetry.record(event).await;

    info!(
        channel,
        prior_mediator = %prior_mediator,
        new_version_id = %update_result.new_version_id,
        drain_ttl_secs = params.drain_ttl.as_secs(),
        "DIDComm disabled"
    );

    Ok(DisableDidcommResult {
        new_version_id: update_result.new_version_id,
        prior_mediator_did: prior_mediator,
        drains_until,
        vta_did,
        serverless: update_result.serverless,
    })
}
/// Checks that DIDComm can be disabled and gathers the state needed to do so.
///
/// Returns `(vta_did, scid, current_doc, prior_mediator_did)`.
///
/// Checks run in this order, and the first failure wins:
/// 1. DIDComm is enabled in the config;
/// 2. disabling it would not violate the last-service invariant;
/// 3. the drain TTL passes `validate_drain_ttl` for the request's transport;
/// 4. the VTA DID state loads;
/// 5. the current DID document has a DIDComm service.
async fn read_preconditions(
    config: &Arc<RwLock<AppConfig>>,
    webvh_ks: &KeyspaceHandle,
    params: &DisableDidcommParams,
) -> Result<(String, String, JsonValue, String), DisableDidcommError> {
    use crate::operations::protocol::document::current_didcomm_service;
    use crate::operations::protocol::invariant::{
        CurrentServices, ProposedOp, would_violate_last_service,
    };
    use crate::operations::protocol::snapshot::ServiceKind;
    use crate::operations::protocol::validate_drain_ttl;
    use vta_sdk::error::VtaError;

    // Config-only checks. The read guard is released at the end of this scope,
    // before the DID state is loaded.
    {
        let cfg = config.read().await;
        let services = &cfg.services;

        if !services.didcomm {
            return Err(DisableDidcommError::DidcommNotEnabled);
        }

        let current = CurrentServices::new(services.rest, services.didcomm, services.webauthn);
        let verdict =
            would_violate_last_service(&current, ProposedOp::disable(ServiceKind::Didcomm));
        if matches!(verdict, Err(VtaError::LastServiceRefused)) {
            return Err(DisableDidcommError::NoProtocolRemaining);
        }

        if let Err(bounds) = validate_drain_ttl(params.transport, params.drain_ttl) {
            return Err(DisableDidcommError::DrainTtlOutOfBounds {
                min: bounds.min,
                max: bounds.max,
                requested: bounds.requested,
            });
        }
    }

    let state = super::preconditions::load_vta_doc_state(config, webvh_ks).await?;

    let Some(prior_mediator) =
        current_didcomm_service(&state.current_doc).map(|svc| svc.mediator_did)
    else {
        return Err(DisableDidcommError::NoActiveMediator);
    };

    Ok((state.vta_did, state.scid, state.current_doc, prior_mediator))
}
async fn best_effort_endpoint(resolver: &DIDCacheClient, mediator_did: &str) -> String {
match crate::messaging::handshake::resolve_mediator(resolver, mediator_did).await {
Ok(r) => r.endpoint,
Err(_) => String::new(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AppConfig;
    use crate::keys::seed_store::PlaintextSeedStore;
    use crate::operations::did_webvh::WebvhAuthLocks;
    use crate::operations::protocol::snapshot;
    use crate::store::Store;
    use crate::test_support::test_app_config;
    use vti_common::telemetry::RingBufferTelemetry;

    fn fresh_config(tmpdir: &std::path::Path, didcomm: bool, rest: bool) -> Arc<RwLock<AppConfig>> {
        let mut cfg = test_app_config(tmpdir.into());
        cfg.services.rest = rest;
        cfg.services.didcomm = didcomm;
        cfg.vta_did = Some("did:webvh:scid123:host:vta".into());
        cfg.config_path = tmpdir.join("config.toml");
        Arc::new(RwLock::new(cfg))
    }

    fn registry() -> (
        Arc<DIDCommBridge>,
        Arc<MediatorListenerRegistry>,
        SharedTelemetrySink,
    ) {
        let bridge = Arc::new(DIDCommBridge::placeholder());
        let sink: SharedTelemetrySink = Arc::new(RingBufferTelemetry::with_capacity(64));
        let registry = Arc::new(MediatorListenerRegistry::new(Arc::clone(&sink)));
        (bridge, registry, sink)
    }

    fn sweeper_for(
        registry: Arc<MediatorListenerRegistry>,
        drains_ks: KeyspaceHandle,
    ) -> Arc<DrainSweeper> {
        let (tx, _rx) = crate::messaging::drain_sweeper::teardown_channel(8);
        Arc::new(DrainSweeper::new(registry, drains_ks, tx))
    }

    async fn empty_keyspace(name: &str) -> (tempfile::TempDir, KeyspaceHandle) {
        use vti_common::config::StoreConfig as VtiStoreConfig;
        let dir = tempfile::tempdir().unwrap();
        let store = Store::open(&VtiStoreConfig {
            data_dir: dir.path().into(),
        })
        .unwrap();
        let ks = store.keyspace(name).unwrap();
        (dir, ks)
    }

    fn super_admin() -> AuthClaims {
        AuthClaims::unsafe_local_cli_super_admin("test")
    }

    fn dummy_seed(dir: &std::path::Path) -> Arc<dyn SeedStore> {
        Arc::new(PlaintextSeedStore::new(dir))
    }

    async fn resolver() -> DIDCacheClient {
        DIDCacheClient::new(
            affinidi_did_resolver_cache_sdk::config::DIDCacheConfigBuilder::default().build(),
        )
        .await
        .unwrap()
    }

    fn rest_params(ttl: Duration) -> DisableDidcommParams {
        DisableDidcommParams {
            drain_ttl: ttl,
            transport: DisableTransport::Rest,
        }
    }

    fn didcomm_params(ttl: Duration) -> DisableDidcommParams {
        DisableDidcommParams {
            drain_ttl: ttl,
            transport: DisableTransport::Didcomm,
        }
    }

    /// Every dependency `disable_didcomm` needs, backed by throwaway
    /// directories. Each keyspace lives in its own temporary store.
    struct Harness {
        config: Arc<RwLock<AppConfig>>,
        seed: Arc<dyn SeedStore>,
        bridge: Arc<DIDCommBridge>,
        reg: Arc<MediatorListenerRegistry>,
        sink: SharedTelemetrySink,
        resolver: DIDCacheClient,
        keys_ks: KeyspaceHandle,
        imported_ks: KeyspaceHandle,
        contexts_ks: KeyspaceHandle,
        webvh_ks: KeyspaceHandle,
        audit_ks: KeyspaceHandle,
        drains_ks: KeyspaceHandle,
        snapshot_ks: KeyspaceHandle,
        service_state_ks: KeyspaceHandle,
        // Declared last so the directories are removed only after every
        // handle above has been dropped.
        _dirs: Vec<tempfile::TempDir>,
    }

    impl Harness {
        async fn new(didcomm: bool, rest: bool) -> Self {
            let root = tempfile::tempdir().unwrap();
            let config = fresh_config(root.path(), didcomm, rest);
            let (bridge, reg, sink) = registry();
            let (d_keys, keys_ks) = empty_keyspace("keys").await;
            let (d_imported, imported_ks) = empty_keyspace("imported_secrets").await;
            let (d_contexts, contexts_ks) = empty_keyspace("contexts").await;
            let (d_webvh, webvh_ks) = empty_keyspace("webvh").await;
            let (d_audit, audit_ks) = empty_keyspace("audit").await;
            let (d_drains, drains_ks) = empty_keyspace("drains").await;
            let (d_snapshot, snapshot_ks) = empty_keyspace(snapshot::KEYSPACE_NAME).await;
            let (d_state, service_state_ks) = empty_keyspace("service_state").await;
            let resolver = resolver().await;
            let seed = dummy_seed(root.path());
            Self {
                config,
                seed,
                bridge,
                reg,
                sink,
                resolver,
                keys_ks,
                imported_ks,
                contexts_ks,
                webvh_ks,
                audit_ks,
                drains_ks,
                snapshot_ks,
                service_state_ks,
                _dirs: vec![
                    root, d_keys, d_imported, d_contexts, d_webvh, d_audit, d_drains,
                    d_snapshot, d_state,
                ],
            }
        }

        /// Runs `disable_didcomm` as a local super admin with a fresh sweeper.
        async fn disable(
            &self,
            params: DisableDidcommParams,
        ) -> Result<DisableDidcommResult, DisableDidcommError> {
            let sweeper = sweeper_for(Arc::clone(&self.reg), self.drains_ks.clone());
            let locks = WebvhAuthLocks::new();
            disable_didcomm(
                &self.config,
                &self.keys_ks,
                &self.imported_ks,
                &self.contexts_ks,
                &self.webvh_ks,
                &self.audit_ks,
                &self.drains_ks,
                &self.snapshot_ks,
                &self.service_state_ks,
                &*self.seed,
                &self.resolver,
                &self.bridge,
                &self.reg,
                &sweeper,
                &self.sink,
                &super_admin(),
                params,
                OpContext::Direct,
                &locks,
                "test",
            )
            .await
        }
    }

    #[tokio::test]
    async fn refuses_when_not_currently_enabled() {
        let h = Harness::new(false, true).await;
        let err = h
            .disable(rest_params(Duration::from_secs(3600)))
            .await
            .unwrap_err();
        assert!(matches!(err, DisableDidcommError::DidcommNotEnabled));
    }

    #[tokio::test]
    async fn refuses_when_rest_also_disabled() {
        let h = Harness::new(true, false).await;
        let err = h
            .disable(rest_params(Duration::from_secs(3600)))
            .await
            .unwrap_err();
        assert!(matches!(err, DisableDidcommError::NoProtocolRemaining));
    }

    #[tokio::test]
    async fn refuses_short_drain_over_didcomm() {
        let h = Harness::new(true, true).await;
        let err = h
            .disable(didcomm_params(Duration::from_secs(1800)))
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                DisableDidcommError::DrainTtlOutOfBounds { min: 3600, .. }
            ),
            "expected DrainTtlOutOfBounds with 1h min, got {err:?}"
        );
    }

    #[tokio::test]
    async fn refuses_drain_ttl_above_max() {
        let h = Harness::new(true, true).await;
        let err = h
            .disable(rest_params(Duration::from_secs(31 * 86_400)))
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                DisableDidcommError::DrainTtlOutOfBounds { max: 2_592_000, .. }
            ),
            "expected DrainTtlOutOfBounds with 30d max, got {err:?}"
        );
    }

    #[tokio::test]
    async fn allows_zero_drain_over_rest() {
        let h = Harness::new(true, true).await;
        let err = h
            .disable(rest_params(Duration::from_secs(0)))
            .await
            .unwrap_err();
        assert!(
            matches!(err, DisableDidcommError::VtaDidRecordMissing(_)),
            "TTL=0 over REST must pass the TTL guard; saw {err:?}"
        );
    }

    #[tokio::test]
    async fn allows_1h_drain_over_didcomm() {
        let h = Harness::new(true, true).await;
        let err = h
            .disable(didcomm_params(Duration::from_secs(3600)))
            .await
            .unwrap_err();
        assert!(matches!(err, DisableDidcommError::VtaDidRecordMissing(_)));
    }
}