use std::sync::Arc;
use std::time::Duration;
use affinidi_did_resolver_cache_sdk::DIDCacheClient;
use chrono::Utc;
use serde_json::Value as JsonValue;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::info;
use vti_common::seed_store::SeedStore;
use vti_common::telemetry::{SharedTelemetrySink, TelemetryEvent, TelemetryKind};
use crate::auth::AuthClaims;
use crate::config::AppConfig;
use crate::didcomm_bridge::DIDCommBridge;
use crate::error::AppError;
use crate::messaging::drain_sweeper::DrainSweeper;
use crate::messaging::registry::{MediatorListenerRegistry, RegistryError};
use crate::operations::did_webvh::{UpdateDidWebvhError, UpdateDidWebvhOptions, update_did_webvh};
use crate::operations::protocol::PROTOCOL_LOCK;
use crate::operations::protocol::document::{DocumentPatchError, without_didcomm_service};
use crate::store::KeyspaceHandle;
use crate::webvh_store;
/// Smallest drain TTL accepted when the disable request uses the DIDComm
/// transport: one hour. Requests over REST have no lower bound (`0s` is allowed).
pub const MIN_DRAIN_TTL_OVER_DIDCOMM: Duration = Duration::from_secs(60 * 60);
/// Transport used to deliver a disable-DIDComm request.
///
/// Over [`DisableTransport::Didcomm`] the drain TTL must be at least
/// `MIN_DRAIN_TTL_OVER_DIDCOMM`; over [`DisableTransport::Rest`] any TTL is accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisableTransport {
    /// Request delivered over the REST API.
    Rest,
    /// Request delivered over DIDComm.
    Didcomm,
}
/// Caller-supplied options for `disable_didcomm`.
#[derive(Debug, Clone)]
pub struct DisableDidcommParams {
    /// How long the prior mediator keeps draining after DIDComm is removed.
    /// Zero means no drain record is persisted and no sweeper deadline is armed.
    pub drain_ttl: Duration,
    /// Transport the request was delivered over; DIDComm requires a TTL of at
    /// least `MIN_DRAIN_TTL_OVER_DIDCOMM`.
    pub transport: DisableTransport,
}
/// Outcome of a successful `disable_didcomm` call.
#[derive(Debug, Clone)]
pub struct DisableDidcommResult {
    /// Version id of the WebVH log entry published without the DIDComm service.
    pub new_version_id: String,
    /// Mediator DID referenced by the DIDComm service that was removed.
    pub prior_mediator_did: String,
    /// Drain deadline for the prior mediator; `None` when the drain TTL was zero.
    pub drains_until: Option<chrono::DateTime<chrono::Utc>>,
}
/// Errors returned by `disable_didcomm`.
#[derive(Debug, Error)]
pub enum DisableDidcommError {
    /// `services.didcomm` is not enabled in the running config.
    #[error(
        "DIDComm is not currently enabled. Use `pnm services enable didcomm --mediator-did <did>` first."
    )]
    DidcommNotEnabled,
    /// `services.rest` is disabled, so removing DIDComm would leave no protocol surface.
    #[error(
        "cannot disable DIDComm — REST is also disabled. The VTA would have no protocol surface left. \
         Run `pnm services enable rest` first, or use `pnm services disable didcomm --drain-ttl 0s` \
         after enabling REST."
    )]
    NoProtocolRemaining,
    /// The request used the DIDComm transport with a drain TTL below
    /// `MIN_DRAIN_TTL_OVER_DIDCOMM`. Any TTL under the minimum triggers this,
    /// not only `0s`, so the message describes the actual rule.
    #[error(
        "drain-ttl below 1h over DIDComm transport is not permitted (minimum 1h). \
         Either retry over REST transport (`--transport rest`) or use a TTL >= 1h."
    )]
    DrainTtlTooShortForDidcomm,
    /// `vta_did` is not set in the config.
    #[error("VTA DID is not configured — run `vta setup` first")]
    VtaDidNotConfigured,
    /// No webvh record is stored for the configured VTA DID.
    #[error("VTA DID `{0}` has no webvh record")]
    VtaDidRecordMissing(String),
    /// No published log is stored for the configured VTA DID.
    #[error("VTA DID `{0}` has no published log")]
    VtaDidLogMissing(String),
    /// The stored DID log contains only blank lines.
    #[error("VTA DID log is empty")]
    EmptyLog,
    /// Config says DIDComm is enabled but the current DID document has no DIDComm service.
    #[error(
        "DIDComm is enabled but the VTA's DID document has no `#vta-didcomm` service entry — \
         on-disk state is inconsistent (re-run setup)"
    )]
    NoActiveMediator,
    /// Wraps a [`DocumentPatchError`] from the protocol document helpers.
    #[error("DID document patch failed: {0}")]
    DocumentPatch(#[from] DocumentPatchError),
    /// Publishing the updated WebVH log entry failed.
    #[error("WebVH update failed: {0}")]
    WebVHUpdate(#[from] UpdateDidWebvhError),
    /// Writing the config failed, or the drain TTL could not be turned into a deadline.
    #[error("config persistence failed: {0}")]
    ConfigPersistence(String),
    /// Error from the mediator listener registry (e.g. persisting the drain record).
    #[error(transparent)]
    Registry(#[from] RegistryError),
    /// The caller is not authorized for this operation.
    #[error("auth: {0}")]
    Auth(String),
    /// Keyspace/storage failure, including DID log parse failures.
    #[error("storage error: {0}")]
    Storage(String),
}
impl From<AppError> for DisableDidcommError {
fn from(value: AppError) -> Self {
Self::Storage(value.to_string())
}
}
#[allow(clippy::too_many_arguments)]
pub async fn disable_didcomm(
config: &Arc<RwLock<AppConfig>>,
keys_ks: &KeyspaceHandle,
contexts_ks: &KeyspaceHandle,
webvh_ks: &KeyspaceHandle,
audit_ks: &KeyspaceHandle,
drains_ks: &KeyspaceHandle,
seed_store: &dyn SeedStore,
did_resolver: &DIDCacheClient,
didcomm_bridge: &Arc<DIDCommBridge>,
registry: &MediatorListenerRegistry,
sweeper: &DrainSweeper,
telemetry: &SharedTelemetrySink,
auth: &AuthClaims,
params: DisableDidcommParams,
channel: &str,
) -> Result<DisableDidcommResult, DisableDidcommError> {
auth.require_super_admin()
.map_err(|e| DisableDidcommError::Auth(e.to_string()))?;
let _guard = PROTOCOL_LOCK.lock().await;
let (_vta_did, scid, current_doc, prior_mediator) =
read_preconditions(config, webvh_ks, ¶ms).await?;
let patched = without_didcomm_service(current_doc);
let update_result = update_did_webvh(
keys_ks,
contexts_ks,
webvh_ks,
audit_ks,
seed_store,
auth,
&scid,
UpdateDidWebvhOptions {
document: Some(patched),
..Default::default()
},
did_resolver,
didcomm_bridge,
channel,
)
.await?;
persist_didcomm_disabled(config).await?;
let drains_until = if params.drain_ttl.is_zero() {
registry.record_deactivate().await;
None
} else {
registry.record_deactivate().await;
let deadline = Utc::now()
+ chrono::Duration::from_std(params.drain_ttl).map_err(|e| {
DisableDidcommError::ConfigPersistence(format!("drain TTL out of range: {e}"))
})?;
let endpoint = best_effort_endpoint(did_resolver, &prior_mediator).await;
registry
.record_drain_persisted(drains_ks, &prior_mediator, endpoint, deadline)
.await?;
sweeper.arm(&prior_mediator, deadline).await;
Some(deadline)
};
let _ = telemetry
.record(
TelemetryEvent::new(TelemetryKind::ServicesDidcommDisable)
.with_mediator(&prior_mediator)
.with_field(
"drain_ttl_secs",
JsonValue::from(params.drain_ttl.as_secs()),
)
.with_field(
"new_version_id",
JsonValue::from(update_result.new_version_id.clone()),
)
.with_field(
"transport",
JsonValue::from(match params.transport {
DisableTransport::Rest => "rest",
DisableTransport::Didcomm => "didcomm",
}),
),
)
.await;
info!(
channel,
prior_mediator = %prior_mediator,
new_version_id = %update_result.new_version_id,
drain_ttl_secs = params.drain_ttl.as_secs(),
"DIDComm disabled"
);
Ok(DisableDidcommResult {
new_version_id: update_result.new_version_id,
prior_mediator_did: prior_mediator,
drains_until,
})
}
/// Validates that DIDComm can be disabled and gathers the state needed to do so.
///
/// Checks run in this order and the first failure is returned:
/// 1. `services.didcomm` is enabled, else [`DisableDidcommError::DidcommNotEnabled`].
/// 2. `services.rest` is enabled, else [`DisableDidcommError::NoProtocolRemaining`].
/// 3. Over DIDComm transport, `drain_ttl >= MIN_DRAIN_TTL_OVER_DIDCOMM`, else
///    [`DisableDidcommError::DrainTtlTooShortForDidcomm`].
/// 4. `vta_did` is configured, else [`DisableDidcommError::VtaDidNotConfigured`].
/// 5. The VTA DID has a webvh record and a published log.
/// 6. The latest log entry's document contains a DIDComm service, else
///    [`DisableDidcommError::NoActiveMediator`].
///
/// Returns `(vta_did, scid, current_document, prior_mediator_did)`.
async fn read_preconditions(
    config: &Arc<RwLock<AppConfig>>,
    webvh_ks: &KeyspaceHandle,
    params: &DisableDidcommParams,
) -> Result<(String, String, JsonValue, String), DisableDidcommError> {
    // Take everything needed from the config under one short-lived read lock;
    // the guard drops at the end of this block, before the keyspace lookups.
    let vta_did = {
        let cfg = config.read().await;
        if !cfg.services.didcomm {
            return Err(DisableDidcommError::DidcommNotEnabled);
        }
        if !cfg.services.rest {
            return Err(DisableDidcommError::NoProtocolRemaining);
        }
        if params.transport == DisableTransport::Didcomm
            && params.drain_ttl < MIN_DRAIN_TTL_OVER_DIDCOMM
        {
            return Err(DisableDidcommError::DrainTtlTooShortForDidcomm);
        }
        cfg.vta_did
            .clone()
            .ok_or(DisableDidcommError::VtaDidNotConfigured)?
    };

    let record = webvh_store::get_did(webvh_ks, &vta_did)
        .await?
        .ok_or_else(|| DisableDidcommError::VtaDidRecordMissing(vta_did.clone()))?;
    let scid = record.scid.clone();

    let did_log = webvh_store::get_did_log(webvh_ks, &vta_did)
        .await?
        .ok_or_else(|| DisableDidcommError::VtaDidLogMissing(vta_did.clone()))?;
    let current_doc = current_document_from_log(&did_log)?;

    let prior_mediator =
        crate::operations::protocol::document::current_didcomm_service(&current_doc)
            .map(|s| s.mediator_did)
            .ok_or(DisableDidcommError::NoActiveMediator)?;

    Ok((vta_did, scid, current_doc, prior_mediator))
}
/// Returns the DID document state carried by the last non-blank line of a WebVH log.
///
/// `did_log` holds one JSON log entry per line. Whitespace-only lines are skipped,
/// and the last remaining line is parsed as a `LogEntry`.
///
/// # Errors
/// - [`DisableDidcommError::EmptyLog`] if every line is blank.
/// - [`DisableDidcommError::Storage`] if that line is not a valid log entry.
fn current_document_from_log(did_log: &str) -> Result<JsonValue, DisableDidcommError> {
    use didwebvh_rs::log_entry::{LogEntry, LogEntryMethods};

    let latest = did_log
        .lines()
        .rev()
        .find(|candidate| !candidate.trim().is_empty())
        .ok_or(DisableDidcommError::EmptyLog)?;

    match serde_json::from_str::<LogEntry>(latest) {
        Ok(entry) => Ok(entry.get_state().clone()),
        Err(parse_err) => Err(DisableDidcommError::Storage(format!(
            "DID log line parse: {parse_err}"
        ))),
    }
}
/// Sets `services.didcomm = false` and writes the whole config to `config_path`.
///
/// The write lock is held across serialization and the file write. This stops a
/// concurrent config writer from interleaving between snapshot and write, which
/// would otherwise let this stale snapshot overwrite its on-disk change.
///
/// If serialization or the write fails, the in-memory flag is restored, so the
/// running config keeps matching what is on disk.
///
/// The write uses blocking `std::fs::write`, as before; the config file is small.
///
/// # Errors
/// [`DisableDidcommError::ConfigPersistence`] on TOML serialization or I/O failure.
async fn persist_didcomm_disabled(
    config: &Arc<RwLock<AppConfig>>,
) -> Result<(), DisableDidcommError> {
    let mut cfg = config.write().await;
    let previous = cfg.services.didcomm;
    cfg.services.didcomm = false;

    let outcome = toml::to_string_pretty(&*cfg)
        .map_err(|e| DisableDidcommError::ConfigPersistence(e.to_string()))
        .and_then(|contents| {
            std::fs::write(&cfg.config_path, contents)
                .map_err(|e| DisableDidcommError::ConfigPersistence(e.to_string()))
        });

    if outcome.is_err() {
        // Roll back so memory does not claim a state the file never recorded.
        cfg.services.didcomm = previous;
    }
    outcome
}
/// Resolves `mediator_did` and returns its endpoint, or an empty string when
/// resolution fails for any reason.
///
/// Failure is deliberately non-fatal. In this file the result is only passed to
/// `record_drain_persisted` alongside the drain deadline.
async fn best_effort_endpoint(resolver: &DIDCacheClient, mediator_did: &str) -> String {
    crate::messaging::handshake::resolve_mediator(resolver, mediator_did)
        .await
        .map(|resolved| resolved.endpoint)
        .unwrap_or_default()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AppConfig, ServerConfig, ServicesConfig, StoreConfig};
    use crate::keys::seed_store::PlaintextSeedStore;
    use crate::store::Store;
    use vti_common::telemetry::RingBufferTelemetry;

    /// Builds a config rooted at `tmpdir` with the given protocol flags and a
    /// fixed VTA DID that has no webvh record in the (empty) test keyspaces.
    fn fresh_config(tmpdir: &std::path::Path, didcomm: bool, rest: bool) -> Arc<RwLock<AppConfig>> {
        let cfg = AppConfig {
            server: ServerConfig {
                host: "127.0.0.1".into(),
                port: 0,
            },
            log: Default::default(),
            store: StoreConfig {
                data_dir: tmpdir.into(),
            },
            services: ServicesConfig { rest, didcomm },
            vta_did: Some("did:webvh:scid123:host:vta".into()),
            vta_name: None,
            public_url: None,
            messaging: None,
            secrets: Default::default(),
            auth: Default::default(),
            audit: Default::default(),
            #[cfg(feature = "tee")]
            tee: Default::default(),
            resolver_url: None,
            config_path: tmpdir.join("config.toml"),
        };
        Arc::new(RwLock::new(cfg))
    }

    /// Placeholder bridge, listener registry and ring-buffer telemetry sink.
    fn registry() -> (
        Arc<DIDCommBridge>,
        Arc<MediatorListenerRegistry>,
        SharedTelemetrySink,
    ) {
        let bridge = Arc::new(DIDCommBridge::placeholder());
        let sink: SharedTelemetrySink = Arc::new(RingBufferTelemetry::with_capacity(64));
        let registry = Arc::new(MediatorListenerRegistry::new(Arc::clone(&sink)));
        (bridge, registry, sink)
    }

    fn sweeper_for(
        registry: Arc<MediatorListenerRegistry>,
        drains_ks: KeyspaceHandle,
    ) -> Arc<DrainSweeper> {
        let (tx, _rx) = crate::messaging::drain_sweeper::teardown_channel(8);
        Arc::new(DrainSweeper::new(registry, drains_ks, tx))
    }

    /// Opens a store in a fresh temp dir and returns the dir (keep it alive) and keyspace.
    async fn empty_keyspace(name: &str) -> (tempfile::TempDir, KeyspaceHandle) {
        let dir = tempfile::tempdir().unwrap();
        let store = Store::open(&StoreConfig {
            data_dir: dir.path().into(),
        })
        .unwrap();
        let ks = store.keyspace(name).unwrap();
        (dir, ks)
    }

    fn super_admin() -> AuthClaims {
        AuthClaims::unsafe_local_cli_super_admin("test")
    }

    fn dummy_seed(dir: &std::path::Path) -> Arc<dyn SeedStore> {
        Arc::new(PlaintextSeedStore::new(dir))
    }

    async fn resolver() -> DIDCacheClient {
        DIDCacheClient::new(
            affinidi_did_resolver_cache_sdk::config::DIDCacheConfigBuilder::default().build(),
        )
        .await
        .unwrap()
    }

    fn rest_params(ttl: Duration) -> DisableDidcommParams {
        DisableDidcommParams {
            drain_ttl: ttl,
            transport: DisableTransport::Rest,
        }
    }

    fn didcomm_params(ttl: Duration) -> DisableDidcommParams {
        DisableDidcommParams {
            drain_ttl: ttl,
            transport: DisableTransport::Didcomm,
        }
    }

    /// Runs `disable_didcomm` against empty keyspaces with the given protocol flags
    /// and returns the error it fails with. Every temp dir stays alive until this
    /// function returns, i.e. past the call under test.
    async fn disable_expecting_err(
        didcomm: bool,
        rest: bool,
        params: DisableDidcommParams,
    ) -> DisableDidcommError {
        let dir = tempfile::tempdir().unwrap();
        let config = fresh_config(dir.path(), didcomm, rest);
        let (bridge, reg, sink) = registry();
        let (_d1, keys_ks) = empty_keyspace("keys").await;
        let (_d2, contexts_ks) = empty_keyspace("contexts").await;
        let (_d3, webvh_ks) = empty_keyspace("webvh").await;
        let (_d4, audit_ks) = empty_keyspace("audit").await;
        let (_d5, drains_ks) = empty_keyspace("drains").await;
        let resolver = resolver().await;
        let seed = dummy_seed(dir.path());
        let sweeper = sweeper_for(Arc::clone(&reg), drains_ks.clone());
        disable_didcomm(
            &config,
            &keys_ks,
            &contexts_ks,
            &webvh_ks,
            &audit_ks,
            &drains_ks,
            &*seed,
            &resolver,
            &bridge,
            &reg,
            &sweeper,
            &sink,
            &super_admin(),
            params,
            "test",
        )
        .await
        .unwrap_err()
    }

    #[tokio::test]
    async fn refuses_when_not_currently_enabled() {
        let err = disable_expecting_err(false, true, rest_params(Duration::from_secs(3600))).await;
        assert!(matches!(err, DisableDidcommError::DidcommNotEnabled));
    }

    #[tokio::test]
    async fn refuses_when_rest_also_disabled() {
        let err = disable_expecting_err(true, false, rest_params(Duration::from_secs(3600))).await;
        assert!(matches!(err, DisableDidcommError::NoProtocolRemaining));
    }

    #[tokio::test]
    async fn refuses_short_drain_over_didcomm() {
        let err =
            disable_expecting_err(true, true, didcomm_params(Duration::from_secs(1800))).await;
        assert!(matches!(
            err,
            DisableDidcommError::DrainTtlTooShortForDidcomm
        ));
    }

    #[tokio::test]
    async fn allows_zero_drain_over_rest() {
        let err = disable_expecting_err(true, true, rest_params(Duration::from_secs(0))).await;
        // Passing the TTL guard means the next failure is the missing webvh record.
        assert!(
            matches!(err, DisableDidcommError::VtaDidRecordMissing(_)),
            "TTL=0 over REST must pass the TTL guard; saw {err:?}"
        );
    }

    #[tokio::test]
    async fn allows_1h_drain_over_didcomm() {
        let err =
            disable_expecting_err(true, true, didcomm_params(Duration::from_secs(3600))).await;
        assert!(matches!(err, DisableDidcommError::VtaDidRecordMissing(_)));
    }
}