use serde_json::Value as JsonValue;
use thiserror::Error;
use tracing::info;
use vta_sdk::error::VtaError;
use vti_common::telemetry::{TelemetryEvent, TelemetryKind};
use crate::auth::AuthClaims;
use crate::error::AppError;
use crate::operations::did_webvh::UpdateDidWebvhError;
use crate::operations::protocol::document::DocumentPatchError;
use crate::operations::protocol::service_lifecycle::{
DisableMutationError, RestService, ServiceLifecycle, check_disable_preconditions, publish_patch,
};
use crate::operations::protocol::{OpContext, ServiceOpDeps};
use crate::operations::protocol::{PROTOCOL_LOCK, snapshot};
/// Parameters accepted by [`disable_rest`].
///
/// Currently a unit struct with no fields; `disable_rest` receives it as
/// `_params` and does not read it.
#[derive(Debug, Clone, Default)]
pub struct DisableRestParams;
/// Value returned by a successful [`disable_rest`] call.
#[derive(Debug, Clone)]
pub struct DisableRestResult {
/// `new_version_id` reported by the `publish_patch` step.
pub new_version_id: String,
/// URL returned by the disable preconditions check; the same value is
/// recorded in the snapshot written before publishing.
pub prior_url: String,
/// `vta_did` of the state resolved by the disable preconditions check.
pub vta_did: String,
/// Copied from the publish result's `serverless` flag.
/// NOTE(review): the exact meaning is defined by `publish_patch` — confirm there.
pub serverless: bool,
}
/// Errors returned by [`disable_rest`] and its precondition checks.
#[derive(Debug, Error)]
pub enum DisableRestError {
/// Produced by `DisableMutationError::not_present()` for this error type.
#[error("REST is not currently enabled — nothing to disable.")]
ServiceNotPresent,
/// Mapped from `VtaError::LastServiceRefused`.
#[error(
"refusing operation: would leave the VTA with no advertised services. \
Enable DIDComm first via `services didcomm enable --mediator-did <did>`, \
then retry."
)]
LastServiceRefused,
/// Mapped from `ProtocolPreconditionError::VtaDidNotConfigured`.
#[error("VTA DID is not configured — run `vta setup` first")]
VtaDidNotConfigured,
/// Mapped from `ProtocolPreconditionError::VtaDidRecordMissing`; carries its payload.
#[error("VTA DID `{0}` has no webvh record")]
VtaDidRecordMissing(String),
/// Mapped from `ProtocolPreconditionError::VtaDidLogMissing`; carries its payload.
#[error("VTA DID `{0}` has no published log")]
VtaDidLogMissing(String),
/// Mapped from `ProtocolPreconditionError::EmptyLog`.
#[error("VTA DID log is empty")]
EmptyLog,
/// Wraps a `DocumentPatchError` via `#[from]`.
#[error("DID document patch failed: {0}")]
DocumentPatch(#[from] DocumentPatchError),
/// Wraps an `UpdateDidWebvhError` via `#[from]`.
#[error("WebVH update failed: {0}")]
WebVHUpdate(#[from] UpdateDidWebvhError),
/// NOTE(review): not constructed anywhere in this file — confirm whether a
/// caller or helper elsewhere still produces it.
#[error("config persistence failed: {0}")]
ConfigPersistence(String),
/// Built in `disable_rest` when `require_super_admin` fails; holds the
/// stringified auth error.
#[error("auth: {0}")]
Auth(String),
/// Catch-all carrying a message: used for `AppError`, for the `Storage` and
/// `DocumentParse` precondition errors, for any `VtaError` other than
/// `LastServiceRefused`, and for snapshot / runtime-state write failures.
#[error("storage error: {0}")]
Storage(String),
}
impl From<AppError> for DisableRestError {
fn from(value: AppError) -> Self {
Self::Storage(value.to_string())
}
}
impl From<crate::operations::protocol::preconditions::ProtocolPreconditionError>
for DisableRestError
{
fn from(value: crate::operations::protocol::preconditions::ProtocolPreconditionError) -> Self {
use crate::operations::protocol::preconditions::ProtocolPreconditionError as E;
match value {
E::VtaDidNotConfigured => Self::VtaDidNotConfigured,
E::VtaDidRecordMissing(s) => Self::VtaDidRecordMissing(s),
E::VtaDidLogMissing(s) => Self::VtaDidLogMissing(s),
E::EmptyLog => Self::EmptyLog,
E::Storage(s) | E::DocumentParse(s) => Self::Storage(s),
}
}
}
impl From<VtaError> for DisableRestError {
fn from(value: VtaError) -> Self {
match value {
VtaError::LastServiceRefused => Self::LastServiceRefused,
other => Self::Storage(other.to_string()),
}
}
}
impl DisableMutationError for DisableRestError {
fn not_present() -> Self {
Self::ServiceNotPresent
}
}
/// Disables the REST service and publishes the resulting DID document.
///
/// The steps run strictly in this order:
/// 1. Require super-admin claims on `auth`.
/// 2. Acquire `PROTOCOL_LOCK`; the guard is held until the function returns.
/// 3. Run the shared disable preconditions for [`RestService`], yielding the
///    current state and the prior REST URL.
/// 4. Write a snapshot built from `RestService::snapshot_enabled(prior_url)`.
/// 5. Publish the document returned by `RestService::without_service`.
/// 6. Store `false` via `runtime_state::set_rest_enabled`, then clear
///    `services.rest` in the in-memory config.
/// 7. Record a telemetry event (its result is discarded) and log the outcome.
///
/// `_params` is accepted for interface stability and is not read.
///
/// # Errors
/// - [`DisableRestError::Auth`] when the super-admin check fails.
/// - [`DisableRestError::Storage`] when the snapshot write or the
///   runtime-state write fails.
/// - Whatever `check_disable_preconditions` or `publish_patch` return for
///   this error type.
pub async fn disable_rest(
    deps: &ServiceOpDeps<'_>,
    auth: &AuthClaims,
    _params: DisableRestParams,
    ctx: OpContext,
    channel: &str,
) -> Result<DisableRestResult, DisableRestError> {
    if let Err(denied) = auth.require_super_admin() {
        return Err(DisableRestError::Auth(denied.to_string()));
    }

    // Held for the remainder of the function.
    let _protocol_guard = PROTOCOL_LOCK.lock().await;

    let (state, prior_url) =
        check_disable_preconditions::<RestService, DisableRestError>(deps.config, deps.webvh_ks)
            .await?;

    // The snapshot is written before anything is published.
    // NOTE(review): presumably so the enabled state can be restored later — confirm.
    let enabled_snapshot = RestService::snapshot_enabled(prior_url.clone());
    snapshot::write(deps.snapshot_ks, enabled_snapshot)
        .await
        .map_err(|e| DisableRestError::Storage(format!("snapshot write: {e}")))?;

    let doc_without_rest = RestService::without_service(state.current_doc);
    let published = publish_patch::<DisableRestError>(
        deps,
        auth,
        &state.scid,
        &state.vta_did,
        doc_without_rest,
        channel,
    )
    .await?;

    crate::operations::protocol::runtime_state::set_rest_enabled(deps.service_state_ks, false)
        .await
        .map_err(|e| DisableRestError::Storage(format!("runtime state: {e}")))?;

    // The write guard is a temporary and is released at the end of this statement.
    deps.config.write().await.services.rest = false;

    let base_event = TelemetryEvent::new(TelemetryKind::ServicesRestDisable)
        .with_field("channel", JsonValue::from(channel))
        .with_field(
            "new_version_id",
            JsonValue::from(published.new_version_id.as_str()),
        )
        .with_field("prior_url", JsonValue::from(prior_url.as_str()));
    let event = match ctx.telemetry_triggered_by() {
        Some(tag) => base_event.with_field("triggered_by", JsonValue::from(tag)),
        None => base_event,
    };
    // Telemetry is best-effort: a failed record does not fail the operation.
    let _ = deps.telemetry.record(event).await;

    info!(
        channel,
        prior_url = %prior_url,
        new_version_id = %published.new_version_id,
        vta_did = %state.vta_did,
        "REST disabled"
    );

    Ok(DisableRestResult {
        new_version_id: published.new_version_id,
        prior_url,
        vta_did: state.vta_did,
        serverless: published.serverless,
    })
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::RwLock;
    use vti_common::config::StoreConfig as VtiStoreConfig;

    use super::*;
    use crate::config::AppConfig;
    use crate::operations::protocol::invariant::{
        CurrentServices, ProposedOp, would_violate_last_service,
    };
    use crate::operations::protocol::snapshot::ServiceKind;
    use crate::store::{KeyspaceHandle, Store};

    /// Temp directory, shared config and store backing a single test.
    /// `_dir` is kept only so the directory lives as long as the fixture.
    struct TestFixture {
        _dir: tempfile::TempDir,
        config: Arc<RwLock<AppConfig>>,
        store: Store,
    }

    impl TestFixture {
        /// Opens the WEBVH keyspace on the fixture's store.
        fn webvh_ks(&self) -> KeyspaceHandle {
            self.store.keyspace(crate::keyspaces::WEBVH).unwrap()
        }
    }

    /// Builds a fixture whose config has the given REST / DIDComm flags and a
    /// fixed VTA DID, writes that config to `vta.toml` in a fresh temp dir,
    /// and opens a store in the same directory.
    fn build_fixture(rest_initially: bool, didcomm_initially: bool) -> TestFixture {
        use crate::test_support::test_app_config;

        let tmp = tempfile::tempdir().unwrap();

        let mut app_cfg = test_app_config(tmp.path().into());
        app_cfg.services.rest = rest_initially;
        app_cfg.services.didcomm = didcomm_initially;
        app_cfg.vta_did = Some("did:webvh:scid123:host:vta".into());
        app_cfg.config_path = tmp.path().join("vta.toml");

        let serialized = toml::to_string_pretty(&app_cfg).unwrap();
        std::fs::write(&app_cfg.config_path, serialized).unwrap();

        let store_cfg = VtiStoreConfig {
            data_dir: tmp.path().into(),
        };
        let store = Store::open(&store_cfg).unwrap();

        TestFixture {
            _dir: tmp,
            config: Arc::new(RwLock::new(app_cfg)),
            store,
        }
    }

    /// Runs the REST disable preconditions against `fx` and returns the error,
    /// panicking if the check unexpectedly succeeds.
    async fn precondition_failure(fx: &TestFixture) -> DisableRestError {
        let webvh = fx.webvh_ks();
        check_disable_preconditions::<RestService, DisableRestError>(&fx.config, &webvh)
            .await
            .unwrap_err()
    }

    #[tokio::test]
    async fn preconditions_reject_when_rest_disabled() {
        let fx = build_fixture(false, true);
        let failure = precondition_failure(&fx).await;
        assert!(matches!(failure, DisableRestError::ServiceNotPresent));
    }

    #[tokio::test]
    async fn preconditions_reject_when_would_brick() {
        let fx = build_fixture(true, false);
        let failure = precondition_failure(&fx).await;
        assert!(matches!(failure, DisableRestError::LastServiceRefused));
    }

    #[tokio::test]
    async fn preconditions_reject_without_vta_did() {
        let fx = build_fixture(true, true);
        fx.config.write().await.vta_did = None;
        let failure = precondition_failure(&fx).await;
        assert!(matches!(failure, DisableRestError::VtaDidNotConfigured));
    }

    #[test]
    fn brick_prevention_rejects_disable_rest_when_didcomm_off() {
        let services = CurrentServices::new(true, false, false);
        let verdict = would_violate_last_service(&services, ProposedOp::disable(ServiceKind::Rest));
        let mapped: DisableRestError = verdict.unwrap_err().into();
        assert!(matches!(mapped, DisableRestError::LastServiceRefused));
    }

    #[test]
    fn brick_prevention_allows_disable_rest_when_didcomm_on() {
        let services = CurrentServices::new(true, true, false);
        let verdict = would_violate_last_service(&services, ProposedOp::disable(ServiceKind::Rest));
        assert!(verdict.is_ok());
    }

    #[test]
    fn brick_prevention_allows_disable_rest_when_webauthn_on() {
        let services = CurrentServices::new(true, false, true);
        let verdict = would_violate_last_service(&services, ProposedOp::disable(ServiceKind::Rest));
        assert!(verdict.is_ok());
    }

    #[test]
    fn vta_error_to_disable_rest_error_mapping_is_typed() {
        // The one variant with a typed counterpart keeps its identity.
        let typed = DisableRestError::from(VtaError::LastServiceRefused);
        assert!(matches!(typed, DisableRestError::LastServiceRefused));

        // Any other variant is folded into the stringly `Storage` bucket.
        let folded = DisableRestError::from(VtaError::ServiceNotPresent);
        assert!(matches!(folded, DisableRestError::Storage(_)));
    }
}