use vta_sdk::error::VtaError;
use crate::operations::protocol::snapshot::ServiceKind;
/// Which of the three services (REST, DIDComm, WebAuthn) are currently enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CurrentServices {
    /// Whether the REST service is enabled.
    pub rest_enabled: bool,
    /// Whether the DIDComm service is enabled.
    pub didcomm_enabled: bool,
    /// Whether the WebAuthn service is enabled.
    pub webauthn_enabled: bool,
}

impl CurrentServices {
    /// Builds a state from the per-service flags, given in
    /// `(rest, didcomm, webauthn)` order.
    ///
    /// This is a `const fn`, so it can be used to initialise `const` items.
    pub const fn new(rest_enabled: bool, didcomm_enabled: bool, webauthn_enabled: bool) -> Self {
        CurrentServices {
            webauthn_enabled,
            didcomm_enabled,
            rest_enabled,
        }
    }
}
/// A requested change to a single service: switch `kind` on or off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProposedOp {
    /// The service this operation targets.
    pub kind: ServiceKind,
    /// `true` when the operation enables `kind`, `false` when it disables it.
    pub kind_will_be_enabled: bool,
}

impl ProposedOp {
    /// Operation that enables `kind`.
    pub const fn enable(kind: ServiceKind) -> Self {
        Self::with_enabled(kind, true)
    }

    /// Operation that disables `kind`.
    pub const fn disable(kind: ServiceKind) -> Self {
        Self::with_enabled(kind, false)
    }

    /// Shared constructor behind [`Self::enable`] and [`Self::disable`].
    const fn with_enabled(kind: ServiceKind, kind_will_be_enabled: bool) -> Self {
        Self {
            kind,
            kind_will_be_enabled,
        }
    }
}
/// Checks that applying `op` to `state` still leaves at least one of the
/// REST, DIDComm or WebAuthn services enabled.
///
/// Only the service named by `op.kind` changes. The other two keep the values
/// they have in `state`. The check looks at the resulting state, so disabling
/// a service that is already off is still refused when nothing else is on.
///
/// # Errors
///
/// Returns [`VtaError::LastServiceRefused`] when every service would be
/// disabled after the operation.
pub fn would_violate_last_service(
    state: &CurrentServices,
    op: ProposedOp,
) -> Result<(), VtaError> {
    // Turning a service on can never leave everything off.
    if op.kind_will_be_enabled {
        return Ok(());
    }

    // The target is being switched off, so the outcome depends only on
    // whether either of the two untouched services is currently on.
    let another_stays_on = match op.kind {
        ServiceKind::Rest => state.didcomm_enabled || state.webauthn_enabled,
        ServiceKind::Didcomm => state.rest_enabled || state.webauthn_enabled,
        ServiceKind::Webauthn => state.rest_enabled || state.didcomm_enabled,
    };

    if another_stays_on {
        Ok(())
    } else {
        Err(VtaError::LastServiceRefused)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture states, written as (rest, didcomm, webauthn).
    /// Nothing enabled.
    const S0: CurrentServices = CurrentServices::new(false, false, false);
    /// REST only.
    const S1: CurrentServices = CurrentServices::new(true, false, false);
    /// DIDComm only.
    const S2: CurrentServices = CurrentServices::new(false, true, false);
    /// REST and DIDComm, WebAuthn off.
    const S3: CurrentServices = CurrentServices::new(true, true, false);
    /// WebAuthn only.
    const SW: CurrentServices = CurrentServices::new(false, false, true);
    /// All three services enabled.
    const SALL: CurrentServices = CurrentServices::new(true, true, true);

    /// Every service kind, so loops exercise all three match arms.
    const ALL_KINDS: [ServiceKind; 3] = [
        ServiceKind::Rest,
        ServiceKind::Didcomm,
        ServiceKind::Webauthn,
    ];

    #[test]
    fn s3_can_disable_either_kind() {
        assert!(would_violate_last_service(&S3, ProposedOp::disable(ServiceKind::Rest)).is_ok());
        assert!(would_violate_last_service(&S3, ProposedOp::disable(ServiceKind::Didcomm)).is_ok());
    }

    #[test]
    fn s1_disable_rest_is_rejected() {
        let err =
            would_violate_last_service(&S1, ProposedOp::disable(ServiceKind::Rest)).unwrap_err();
        assert!(matches!(err, VtaError::LastServiceRefused));
    }

    #[test]
    fn s2_disable_didcomm_is_rejected() {
        let err =
            would_violate_last_service(&S2, ProposedOp::disable(ServiceKind::Didcomm)).unwrap_err();
        assert!(matches!(err, VtaError::LastServiceRefused));
    }

    #[test]
    fn s1_disable_didcomm_is_accepted_by_invariant() {
        assert!(would_violate_last_service(&S1, ProposedOp::disable(ServiceKind::Didcomm)).is_ok());
    }

    #[test]
    fn s2_disable_rest_is_accepted_by_invariant() {
        assert!(would_violate_last_service(&S2, ProposedOp::disable(ServiceKind::Rest)).is_ok());
    }

    #[test]
    fn enable_never_violates() {
        for state in [S0, S1, S2, S3, SW, SALL] {
            for kind in ALL_KINDS {
                assert!(
                    would_violate_last_service(&state, ProposedOp::enable(kind)).is_ok(),
                    "enable from {state:?} for {kind:?} must be accepted",
                );
            }
        }
    }

    #[test]
    fn webauthn_only_state_cannot_disable_webauthn() {
        let err = would_violate_last_service(&SW, ProposedOp::disable(ServiceKind::Webauthn))
            .unwrap_err();
        assert!(matches!(err, VtaError::LastServiceRefused));
    }

    #[test]
    fn webauthn_keeps_invariant_when_other_two_disabled() {
        assert!(would_violate_last_service(&SW, ProposedOp::disable(ServiceKind::Rest)).is_ok());
        assert!(
            would_violate_last_service(&SW, ProposedOp::disable(ServiceKind::Didcomm)).is_ok()
        );
    }

    #[test]
    fn all_three_on_any_single_disable_is_ok() {
        for kind in ALL_KINDS {
            assert!(
                would_violate_last_service(&SALL, ProposedOp::disable(kind)).is_ok(),
                "all-on, disable {kind:?} must be accepted",
            );
        }
    }

    #[test]
    fn s0_enable_any_kind_is_accepted() {
        for kind in ALL_KINDS {
            assert!(
                would_violate_last_service(&S0, ProposedOp::enable(kind)).is_ok(),
                "enable {kind:?} from all-off must be accepted",
            );
        }
    }

    #[test]
    fn s0_disable_is_rejected() {
        // Disabling an already-off service is still refused: the check is on
        // the resulting state, which has nothing enabled.
        for kind in ALL_KINDS {
            assert!(matches!(
                would_violate_last_service(&S0, ProposedOp::disable(kind)).unwrap_err(),
                VtaError::LastServiceRefused,
            ));
        }
    }

    #[test]
    fn full_truth_table() {
        // (state, kind, kind_will_be_enabled, expected_ok)
        let cases = [
            (S0, ServiceKind::Rest, false, false),
            (S0, ServiceKind::Rest, true, true),
            (S0, ServiceKind::Didcomm, false, false),
            (S0, ServiceKind::Didcomm, true, true),
            (S0, ServiceKind::Webauthn, false, false),
            (S0, ServiceKind::Webauthn, true, true),
            (S1, ServiceKind::Rest, false, false),
            (S1, ServiceKind::Rest, true, true),
            (S1, ServiceKind::Didcomm, false, true),
            (S1, ServiceKind::Didcomm, true, true),
            (S1, ServiceKind::Webauthn, false, true),
            (S1, ServiceKind::Webauthn, true, true),
            (S2, ServiceKind::Rest, false, true),
            (S2, ServiceKind::Rest, true, true),
            (S2, ServiceKind::Didcomm, false, false),
            (S2, ServiceKind::Didcomm, true, true),
            (S2, ServiceKind::Webauthn, false, true),
            (S2, ServiceKind::Webauthn, true, true),
            (S3, ServiceKind::Rest, false, true),
            (S3, ServiceKind::Rest, true, true),
            (S3, ServiceKind::Didcomm, false, true),
            (S3, ServiceKind::Didcomm, true, true),
            (S3, ServiceKind::Webauthn, false, true),
            (S3, ServiceKind::Webauthn, true, true),
            (SW, ServiceKind::Rest, false, true),
            (SW, ServiceKind::Rest, true, true),
            (SW, ServiceKind::Didcomm, false, true),
            (SW, ServiceKind::Didcomm, true, true),
            (SW, ServiceKind::Webauthn, false, false),
            (SW, ServiceKind::Webauthn, true, true),
            (SALL, ServiceKind::Rest, false, true),
            (SALL, ServiceKind::Rest, true, true),
            (SALL, ServiceKind::Didcomm, false, true),
            (SALL, ServiceKind::Didcomm, true, true),
            (SALL, ServiceKind::Webauthn, false, true),
            (SALL, ServiceKind::Webauthn, true, true),
        ];
        for (state, kind, kind_will_be_enabled, expected_ok) in cases {
            let op = ProposedOp {
                kind,
                kind_will_be_enabled,
            };
            let result = would_violate_last_service(&state, op);
            assert_eq!(
                result.is_ok(),
                expected_ok,
                "case ({state:?}, {kind:?}, {kind_will_be_enabled}) — expected ok={expected_ok}, got {result:?}",
            );
            if !expected_ok {
                assert!(matches!(result.unwrap_err(), VtaError::LastServiceRefused));
            }
        }
    }
}