use std::fmt;
use std::time::Duration;
use affinidi_did_resolver_cache_sdk::DIDCacheClient;
use async_trait::async_trait;
use serde_json::Value as JsonValue;
use thiserror::Error;
use vti_common::telemetry::{SharedTelemetrySink, TelemetryEvent, TelemetryKind};
/// Ten-second time budget used by [`HandshakeOptions::default`]; it is the
/// value forwarded to the listener prover unless the caller overrides it.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);
/// A mediator DID paired with the endpoint selected from its DID document
/// (see [`resolve_mediator`]).
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct ResolvedMediator {
    /// The DID that was resolved, stored exactly as the caller supplied it.
    pub mediator_did: String,
    /// URI taken from the mediator's `DIDCommMessaging` service entry.
    pub endpoint: String,
}
/// Step of the mediator handshake at which a failure is reported.
///
/// `Resolve` is reported by [`mediator_handshake`] itself when DID
/// resolution fails; the other stages are whatever a [`ListenerProver`]
/// reports in its [`ProverFailure`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeStage {
    Resolve,
    Connect,
    Authenticate,
    Register,
    TrustPing,
}

impl HandshakeStage {
    /// Stable lowercase name of the stage, as written into telemetry events
    /// and error messages.
    pub fn as_str(self) -> &'static str {
        match self {
            HandshakeStage::Resolve => "resolve",
            HandshakeStage::Connect => "connect",
            HandshakeStage::Authenticate => "authenticate",
            HandshakeStage::Register => "register",
            HandshakeStage::TrustPing => "trust-ping",
        }
    }
}

impl fmt::Display for HandshakeStage {
    /// Writes [`HandshakeStage::as_str`] verbatim; width/fill flags are not
    /// applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
/// Error returned by [`mediator_handshake`] when any stage fails.
#[derive(Debug, Error)]
pub enum HandshakeError {
    /// The handshake stopped at `stage` for the reason given in `cause`.
    #[error("mediator handshake failed at stage `{stage}`: {cause}")]
    Failed {
        /// Stage at which the handshake failed.
        stage: HandshakeStage,
        /// Human-readable description of the failure.
        cause: String,
    },
}

impl HandshakeError {
    /// Returns the [`HandshakeStage`] at which the handshake failed.
    pub fn stage(&self) -> HandshakeStage {
        match *self {
            HandshakeError::Failed { stage, .. } => stage,
        }
    }
}
/// Caller-tunable settings for [`mediator_handshake`].
#[derive(Debug, Clone)]
pub struct HandshakeOptions {
    /// Time budget forwarded to [`ListenerProver::prove`].
    pub timeout: Duration,
    /// When `true`, the prover is skipped after successful resolution and a
    /// `MediatorHandshakeBypassed` telemetry event is recorded instead.
    pub force: bool,
}

impl Default for HandshakeOptions {
    /// No bypass, with a [`DEFAULT_HANDSHAKE_TIMEOUT`] budget.
    fn default() -> Self {
        HandshakeOptions {
            force: false,
            timeout: DEFAULT_HANDSHAKE_TIMEOUT,
        }
    }
}
/// Failure reported by a [`ListenerProver`], naming the stage that failed.
#[derive(Clone, Debug)]
pub struct ProverFailure {
    /// Stage at which the proof failed; copied into [`HandshakeError::Failed`].
    pub stage: HandshakeStage,
    /// Human-readable reason; copied into [`HandshakeError::Failed`].
    pub cause: String,
}
/// Check run against a freshly resolved mediator by [`mediator_handshake`].
///
/// Implementors must be `Send + Sync`. A failure is reported as a
/// [`ProverFailure`]; `mediator_handshake` copies its `stage` and `cause`
/// into [`HandshakeError::Failed`] unchanged.
#[async_trait]
pub trait ListenerProver: Send + Sync {
    /// Runs the check for `vta_did` against `resolved`, with `timeout` as the
    /// time budget.
    ///
    /// NOTE(review): `mediator_handshake` only forwards `timeout`; enforcing
    /// it is left to each implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`ProverFailure`] naming the [`HandshakeStage`] that failed.
    async fn prove(
        &self,
        resolved: &ResolvedMediator,
        vta_did: &str,
        timeout: Duration,
    ) -> Result<(), ProverFailure>;
}
pub async fn mediator_handshake(
resolver: &DIDCacheClient,
prover: &(dyn ListenerProver + Send + Sync),
telemetry: &SharedTelemetrySink,
mediator_did: &str,
vta_did: &str,
opts: HandshakeOptions,
) -> Result<ResolvedMediator, HandshakeError> {
let resolved = match resolve_mediator(resolver, mediator_did).await {
Ok(r) => r,
Err(cause) => {
emit_failed(telemetry, mediator_did, HandshakeStage::Resolve, &cause).await;
return Err(HandshakeError::Failed {
stage: HandshakeStage::Resolve,
cause,
});
}
};
if opts.force {
let _ = telemetry
.record(
TelemetryEvent::new(TelemetryKind::MediatorHandshakeBypassed)
.with_mediator(mediator_did)
.with_field("endpoint", JsonValue::from(resolved.endpoint.clone())),
)
.await;
return Ok(resolved);
}
if let Err(failure) = prover.prove(&resolved, vta_did, opts.timeout).await {
emit_failed(telemetry, mediator_did, failure.stage, &failure.cause).await;
return Err(HandshakeError::Failed {
stage: failure.stage,
cause: failure.cause,
});
}
let _ = telemetry
.record(
TelemetryEvent::new(TelemetryKind::MediatorHandshakeOk)
.with_mediator(mediator_did)
.with_field("endpoint", JsonValue::from(resolved.endpoint.clone())),
)
.await;
Ok(resolved)
}
/// Resolves `mediator_did` and extracts the endpoint of its DIDComm
/// messaging service.
///
/// The DID document must contain at least one service whose type list
/// includes `DIDCommMessaging`, and at least one `keyAgreement` verification
/// method. The endpoint is taken from the first such service whose
/// `service_endpoint` yields a non-empty URI.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - resolution fails;
/// - no `DIDCommMessaging` service is present;
/// - `keyAgreement` is empty;
/// - none of the `DIDCommMessaging` services yields a non-empty URI
///   endpoint.
pub async fn resolve_mediator(
    resolver: &DIDCacheClient,
    mediator_did: &str,
) -> Result<ResolvedMediator, String> {
    let resolved = resolver
        .resolve(mediator_did)
        .await
        .map_err(|e| format!("failed to resolve mediator DID `{mediator_did}`: {e}"))?;

    let messaging_services: Vec<_> = resolved
        .doc
        .service
        .iter()
        .filter(|s| s.type_.iter().any(|t| t == "DIDCommMessaging"))
        .collect();
    if messaging_services.is_empty() {
        return Err(format!(
            "mediator DID `{mediator_did}` has no DIDCommMessaging service entry"
        ));
    }

    if resolved.doc.key_agreement.is_empty() {
        return Err(format!(
            "mediator DID `{mediator_did}` exposes no keyAgreement verification method"
        ));
    }

    // An empty URI can never be dialled, so returning one as `Ok` would only
    // push the failure to a later, less descriptive point. Skip entries whose
    // endpoint yields no URI, and fail here if none remain.
    // NOTE(review): presumably `get_uri()` yields nothing for non-URI
    // (e.g. map-form) endpoints. Confirm against the DID document type.
    let endpoint = messaging_services
        .iter()
        .map(|s| s.service_endpoint.get_uri().unwrap_or_default())
        .find(|uri| !uri.is_empty())
        .ok_or_else(|| {
            format!(
                "mediator DID `{mediator_did}` has no DIDCommMessaging service with a URI endpoint"
            )
        })?;

    Ok(ResolvedMediator {
        mediator_did: mediator_did.to_string(),
        endpoint,
    })
}
/// Records a `MediatorHandshakeFailed` event for `mediator_did`.
///
/// The event carries the failing `stage` (as [`HandshakeStage::as_str`]) and
/// the `cause` text. Recording is best-effort: the sink's result is
/// discarded.
async fn emit_failed(
    telemetry: &SharedTelemetrySink,
    mediator_did: &str,
    stage: HandshakeStage,
    cause: &str,
) {
    let stage_name = stage.as_str();
    let event = TelemetryEvent::new(TelemetryKind::MediatorHandshakeFailed)
        .with_mediator(mediator_did)
        .with_field("stage", JsonValue::from(stage_name))
        .with_field("cause", JsonValue::from(cause));
    let _ = telemetry.record(event).await;
}
/// [`ListenerProver`] stub that accepts every mediator without inspecting
/// any of its arguments.
#[doc(hidden)]
pub struct AlwaysOkProver;

#[async_trait]
impl ListenerProver for AlwaysOkProver {
    async fn prove(
        &self,
        _mediator: &ResolvedMediator,
        _did: &str,
        _budget: Duration,
    ) -> Result<(), ProverFailure> {
        Ok(())
    }
}
/// [`ListenerProver`] stub that always fails with the configured `stage` and
/// `cause`, ignoring its arguments.
#[doc(hidden)]
pub struct FailingProver {
    /// Stage reported in every returned [`ProverFailure`].
    pub stage: HandshakeStage,
    /// Cause reported (cloned) in every returned [`ProverFailure`].
    pub cause: String,
}

#[async_trait]
impl ListenerProver for FailingProver {
    async fn prove(
        &self,
        _mediator: &ResolvedMediator,
        _did: &str,
        _budget: Duration,
    ) -> Result<(), ProverFailure> {
        let failure = ProverFailure {
            stage: self.stage,
            cause: self.cause.clone(),
        };
        Err(failure)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use affinidi_did_resolver_cache_sdk::config::DIDCacheConfigBuilder;
    use std::sync::Arc;
    use vti_common::telemetry::{RingBufferTelemetry, TelemetryFilter};

    /// `did:key` with a bogus payload; the resolver rejects it, so every
    /// handshake using it fails at the `Resolve` stage.
    const UNRESOLVABLE_DID: &str = "did:key:zNOTAREALKEY";
    /// Arbitrary VTA DID passed through to the prover.
    const VTA_DID: &str = "did:webvh:vta";

    /// Fresh in-memory telemetry sink with room for 64 events.
    fn telemetry() -> SharedTelemetrySink {
        Arc::new(RingBufferTelemetry::with_capacity(64))
    }

    /// Resolver built from the default cache configuration.
    async fn local_resolver() -> DIDCacheClient {
        let config = DIDCacheConfigBuilder::default().build();
        DIDCacheClient::new(config).await.expect("resolver init")
    }

    #[test]
    fn handshake_stage_string_form() {
        let cases = [
            (HandshakeStage::Resolve, "resolve"),
            (HandshakeStage::Connect, "connect"),
            (HandshakeStage::Authenticate, "authenticate"),
            (HandshakeStage::Register, "register"),
            (HandshakeStage::TrustPing, "trust-ping"),
        ];
        for (stage, expected) in cases {
            assert_eq!(stage.as_str(), expected);
            // `Display` must agree with `as_str`.
            assert_eq!(stage.to_string(), expected);
        }
    }

    #[test]
    fn handshake_error_reports_stage_and_message() {
        let err = HandshakeError::Failed {
            stage: HandshakeStage::Register,
            cause: "boom".into(),
        };
        assert_eq!(err.stage(), HandshakeStage::Register);
        assert_eq!(
            err.to_string(),
            "mediator handshake failed at stage `register`: boom"
        );
    }

    #[test]
    fn handshake_options_default_is_10s_no_force() {
        let opts = HandshakeOptions::default();
        assert_eq!(opts.timeout, Duration::from_secs(10));
        assert_eq!(opts.timeout, DEFAULT_HANDSHAKE_TIMEOUT);
        assert!(!opts.force);
    }

    #[tokio::test]
    async fn resolve_mediator_rejects_unresolvable_did() {
        let resolver = local_resolver().await;
        let err = resolve_mediator(&resolver, UNRESOLVABLE_DID)
            .await
            .unwrap_err();
        assert!(err.contains("failed to resolve"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn force_bypass_still_requires_resolution() {
        let resolver = local_resolver().await;
        let sink = telemetry();
        // If the prover were ever reached, the error stage would be
        // `TrustPing` rather than `Resolve`.
        let prover = FailingProver {
            stage: HandshakeStage::TrustPing,
            cause: "prover must not run".into(),
        };
        let err = mediator_handshake(
            &resolver,
            &prover,
            &sink,
            UNRESOLVABLE_DID,
            VTA_DID,
            HandshakeOptions {
                force: true,
                ..Default::default()
            },
        )
        .await
        .unwrap_err();
        // `force` only skips the prover; a resolution failure is still fatal.
        assert_eq!(err.stage(), HandshakeStage::Resolve);

        let bypassed = sink
            .query(&TelemetryFilter::new().kind(TelemetryKind::MediatorHandshakeBypassed))
            .await
            .unwrap();
        assert!(bypassed.is_empty());
    }

    #[tokio::test]
    async fn failed_resolve_emits_handshake_failed_with_stage() {
        let resolver = local_resolver().await;
        let sink = telemetry();
        let prover = AlwaysOkProver;
        let _ = mediator_handshake(
            &resolver,
            &prover,
            &sink,
            UNRESOLVABLE_DID,
            VTA_DID,
            HandshakeOptions::default(),
        )
        .await;
        let events = sink
            .query(&TelemetryFilter::new().kind(TelemetryKind::MediatorHandshakeFailed))
            .await
            .unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].fields.get("stage").and_then(|v| v.as_str()),
            Some("resolve"),
        );
        let cause = events[0]
            .fields
            .get("cause")
            .and_then(|v| v.as_str())
            .unwrap_or_default();
        assert!(
            cause.contains("failed to resolve"),
            "unexpected cause: {cause}"
        );
    }

    #[tokio::test]
    async fn failing_prover_reports_configured_stage_and_cause() {
        let prover = FailingProver {
            stage: HandshakeStage::TrustPing,
            cause: "pong timeout".into(),
        };
        let resolved = ResolvedMediator {
            mediator_did: "did:m:fake".into(),
            endpoint: "wss://fake".into(),
        };
        let failure = prover
            .prove(&resolved, VTA_DID, Duration::from_secs(1))
            .await
            .unwrap_err();
        assert_eq!(failure.stage, HandshakeStage::TrustPing);
        assert_eq!(failure.cause, "pong timeout");
    }

    #[tokio::test]
    async fn always_ok_prover_succeeds() {
        let prover = AlwaysOkProver;
        let resolved = ResolvedMediator {
            mediator_did: "did:m:fake".into(),
            endpoint: "wss://fake".into(),
        };
        assert!(prover
            .prove(&resolved, VTA_DID, Duration::from_secs(1))
            .await
            .is_ok());
    }
}