use std::sync::Arc;
use std::time::Duration;
use affinidi_tdk::messaging::ATM;
use async_trait::async_trait;
use reqwest::Client;
use tracing::{debug, warn};
use super::client::{RegistryError, TrustRegistryClient};
use super::model::RegistryRecord;
/// Settings used to build an [`UpstreamRegistryClient`].
#[derive(Debug, Clone)]
pub struct UpstreamConfig {
/// Root URL of the upstream registry. Trailing `/` characters are stripped by
/// [`UpstreamRegistryClient::new`] before endpoint paths are appended.
pub base_url: String,
/// Timeout handed to the `reqwest` client builder in
/// [`UpstreamRegistryClient::new`].
pub http_timeout: Duration,
/// DID sent as `authority_id` in recognise queries. When `None`, `recognise`
/// returns a `RegistryError::Permanent` without sending a request.
pub authority_did: Option<String>,
}
/// Trust-registry client that talks to an upstream registry over HTTP.
pub struct UpstreamRegistryClient {
/// `reqwest` client built once in `new` with the configured timeout.
http: Client,
/// Registry base URL with trailing slashes already removed.
base_url: String,
/// DID used as `authority_id` in recognise queries; `None` makes `recognise`
/// fail before any request is sent.
authority_did: Option<String>,
/// Optional shared messaging handle, attached through `with_atm`.
// In this file the field is only inspected by the `Debug` impl.
// NOTE(review): presumably held for the DIDComm transport that the
// write-path stub messages refer to — confirm before removing the allow.
#[allow(dead_code)]
atm: Option<Arc<ATM>>,
}
impl std::fmt::Debug for UpstreamRegistryClient {
    /// Formats the client for diagnostics.
    ///
    /// Shows the base URL, the configured authority DID and whether an ATM
    /// handle is attached. The HTTP client and the ATM handle itself are left
    /// out, which is why the output ends with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UpstreamRegistryClient")
            .field("base_url", &self.base_url)
            // Needed to diagnose the "authority_did not configured" failure;
            // the same value is already emitted by the recognise debug log.
            .field("authority_did", &self.authority_did)
            .field("atm_set", &self.atm.is_some())
            // `http` and `atm` are deliberately not printed.
            .finish_non_exhaustive()
    }
}
impl UpstreamRegistryClient {
pub fn new(cfg: UpstreamConfig) -> Result<Self, RegistryError> {
let http = Client::builder()
.timeout(cfg.http_timeout)
.build()
.map_err(|e| RegistryError::Permanent(format!("reqwest client init: {e}")))?;
Ok(Self {
http,
base_url: cfg.base_url.trim_end_matches('/').to_string(),
authority_did: cfg.authority_did,
atm: None,
})
}
#[allow(dead_code)]
pub fn with_atm(mut self, atm: Arc<ATM>) -> Self {
self.atm = Some(atm);
self
}
fn classify_http_error(e: reqwest::Error) -> RegistryError {
if e.is_timeout() {
RegistryError::Unreachable(format!("timeout: {e}"))
} else if e.is_connect() {
RegistryError::Unreachable(format!("connect: {e}"))
} else if let Some(s) = e.status() {
if s.is_server_error() {
RegistryError::Transient(format!("{s}: {e}"))
} else {
RegistryError::Permanent(format!("{s}: {e}"))
}
} else {
RegistryError::Transient(format!("http: {e}"))
}
}
}
#[async_trait]
impl TrustRegistryClient for UpstreamRegistryClient {
async fn publish_member(&self, _record: &RegistryRecord) -> Result<(), RegistryError> {
warn!(
"UpstreamRegistryClient.publish_member called before M3.4's DIDComm transport landed"
);
Err(RegistryError::Permanent(
"publish_member is not yet implemented in M3.2 — DIDComm transport lands in M3.4"
.into(),
))
}
async fn delete_member(&self, _member_did: &str) -> Result<(), RegistryError> {
warn!("UpstreamRegistryClient.delete_member called before M3.4's DIDComm transport landed");
Err(RegistryError::Permanent(
"delete_member is not yet implemented in M3.2 — DIDComm transport lands in M3.4".into(),
))
}
async fn read_member(
&self,
_member_did: &str,
) -> Result<Option<RegistryRecord>, RegistryError> {
warn!("UpstreamRegistryClient.read_member called before M3.10's TRQP mapping landed");
Err(RegistryError::Permanent(
"read_member is not yet implemented in M3.2 — TRQP query mapping lands in M3.10".into(),
))
}
async fn recognise(&self, foreign_issuer_did: &str) -> Result<bool, RegistryError> {
let authority = self.authority_did.as_deref().ok_or_else(|| {
RegistryError::Permanent(
"authority_did not configured — cannot issue recognise query (set vtc_did in config)".into(),
)
})?;
#[derive(serde::Serialize)]
struct RecognitionRequest<'a> {
entity_id: &'a str,
authority_id: &'a str,
action: &'static str,
resource: &'static str,
}
#[derive(serde::Deserialize)]
#[allow(dead_code)]
struct RecognitionResponse {
recognized: bool,
}
let url = format!("{}/recognition", self.base_url);
let body = RecognitionRequest {
entity_id: foreign_issuer_did,
authority_id: authority,
action: "recognise",
resource: "trust-graph",
};
debug!(%url, entity = %foreign_issuer_did, authority = %authority, "trust-registry recognise");
let resp = self
.http
.post(&url)
.json(&body)
.send()
.await
.map_err(Self::classify_http_error)?;
let status = resp.status();
if status == reqwest::StatusCode::NOT_FOUND {
return Ok(false);
}
if status.is_client_error() {
return Err(RegistryError::Permanent(format!(
"registry returned {status} for recognise"
)));
}
if status.is_server_error() {
return Err(RegistryError::Transient(format!(
"registry returned {status} for recognise"
)));
}
if !status.is_success() {
return Err(RegistryError::Transient(format!(
"unexpected status {status} for recognise"
)));
}
let body: RecognitionResponse = resp
.json()
.await
.map_err(|e| RegistryError::Transient(format!("parse recognise response: {e}")))?;
Ok(body.recognized)
}
async fn health(&self) -> Result<(), RegistryError> {
let url = format!("{}/.well-known/did.json", self.base_url);
debug!(%url, "trust-registry health probe");
let resp = self
.http
.get(&url)
.send()
.await
.map_err(Self::classify_http_error)?;
let status = resp.status();
if status.is_success() || status.is_client_error() {
Ok(())
} else if status.is_server_error() {
Err(RegistryError::Transient(format!(
"registry returned {status}"
)))
} else {
Err(RegistryError::Transient(format!(
"unexpected status {status}"
)))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::Json;
use axum::Router;
use axum::extract::State;
use axum::http::StatusCode;
use axum::routing::post;
use serde_json::Value as JsonValue;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::oneshot;
/// Request body as seen by the mock `/recognition` endpoint; the tests read
/// these fields back to check what the client sent.
#[derive(Debug, Clone, serde::Deserialize)]
struct CapturedRequest {
entity_id: String,
authority_id: String,
action: String,
resource: String,
}
/// State shared between the mock server's handler and the test that spawned it.
#[derive(Clone)]
struct ServerState {
/// Last request body stored by the handler, if any.
captured: Arc<tokio::sync::Mutex<Option<CapturedRequest>>>,
/// Canned reply the handler returns.
response: Arc<tokio::sync::Mutex<TestResponse>>,
}
/// Canned replies the mock `/recognition` endpoint can produce.
#[derive(Clone, Debug)]
enum TestResponse {
/// 200 with the given value in the `recognized` field.
Recognized(bool),
/// 404 with a problem-style body.
NotFound,
/// 400 with a problem-style body.
BadRequest,
/// 500 with a problem-style body.
ServerError,
}
/// Mock `/recognition` endpoint: stores the request body in the shared state,
/// then answers with the currently configured [`TestResponse`].
async fn recognition_handler(
    State(state): State<ServerState>,
    Json(body): Json<CapturedRequest>,
) -> (StatusCode, Json<JsonValue>) {
    *state.captured.lock().await = Some(body);
    let canned = state.response.lock().await.clone();

    // Every error reply shares the same problem-style shape; only the title
    // and the numeric code differ.
    let problem = |status: StatusCode, title: &str| {
        let payload = serde_json::json!({
            "title": title,
            "type": "about:blank",
            "code": status.as_u16()
        });
        (status, Json(payload))
    };

    match canned {
        TestResponse::Recognized(flag) => {
            let payload = serde_json::json!({
                "entity_id": "did:webvh:peer.example",
                "authority_id": "did:webvh:vtc.example",
                "action": "recognise",
                "resource": "trust-graph",
                "recognized": flag,
                "context": {},
                "record_type": "recognition",
                "time_requested": "2026-05-13T10:00:00Z",
                "time_evaluated": "2026-05-13T10:00:00Z",
                "message": "mock"
            });
            (StatusCode::OK, Json(payload))
        }
        TestResponse::NotFound => problem(StatusCode::NOT_FOUND, "not_found"),
        TestResponse::BadRequest => problem(StatusCode::BAD_REQUEST, "bad_request"),
        TestResponse::ServerError => problem(StatusCode::INTERNAL_SERVER_ERROR, "internal_error"),
    }
}
/// Starts the mock registry on an OS-assigned loopback port.
///
/// Returns the server's base URL, the slot holding the last captured request,
/// the slot holding the canned response, and a sender that triggers graceful
/// shutdown when fired.
async fn spawn_server(
    initial: TestResponse,
) -> (
    String,
    Arc<tokio::sync::Mutex<Option<CapturedRequest>>>,
    Arc<tokio::sync::Mutex<TestResponse>>,
    oneshot::Sender<()>,
) {
    let captured = Arc::new(tokio::sync::Mutex::new(None));
    let response = Arc::new(tokio::sync::Mutex::new(initial));

    let app = Router::new()
        .route("/recognition", post(recognition_handler))
        .with_state(ServerState {
            captured: Arc::clone(&captured),
            response: Arc::clone(&response),
        });

    // Port 0 lets the OS pick a free port; read it back for the URL.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr: SocketAddr = listener.local_addr().unwrap();
    let base_url = format!("http://{addr}");

    let (stop_tx, stop_rx) = oneshot::channel::<()>();
    let serve = async move {
        axum::serve(listener, app)
            .with_graceful_shutdown(async move {
                let _ = stop_rx.await;
            })
            .await
            .unwrap();
    };
    tokio::spawn(serve);

    (base_url, captured, response, stop_tx)
}
/// Test config pointing at `base_url`, with a 2-second timeout and a fixed
/// authority DID.
fn config_for(base_url: &str) -> UpstreamConfig {
    let authority_did = Some(String::from("did:webvh:vtc.example"));
    UpstreamConfig {
        base_url: base_url.to_owned(),
        http_timeout: Duration::from_secs(2),
        authority_did,
    }
}
#[test]
fn config_trims_trailing_slash() {
    // A trailing '/' on the configured URL must not survive construction.
    let client = UpstreamRegistryClient::new(UpstreamConfig {
        base_url: String::from("https://registry.example.com/"),
        http_timeout: Duration::from_secs(5),
        authority_did: Some(String::from("did:webvh:vtc.example")),
    })
    .unwrap();
    assert_eq!(client.base_url, "https://registry.example.com");
}
#[tokio::test]
async fn health_against_unreachable_host_is_unreachable() {
    // Assumes nothing is listening on loopback port 1.
    let client = UpstreamRegistryClient::new(UpstreamConfig {
        base_url: String::from("http://127.0.0.1:1"),
        http_timeout: Duration::from_millis(200),
        authority_did: None,
    })
    .unwrap();

    let err = client.health().await.expect_err("should fail");
    assert!(
        err.is_retriable(),
        "connection refused is retriable: {err:?}"
    );
}
#[tokio::test]
async fn publish_member_returns_permanent_until_m3_4() {
    let client = UpstreamRegistryClient::new(UpstreamConfig {
        base_url: String::from("http://localhost:9999"),
        http_timeout: Duration::from_secs(1),
        authority_did: None,
    })
    .unwrap();
    let member = RegistryRecord::fresh_active("did:key:zMember");

    // The write path is a stub: it must fail permanently, never retriably.
    let err = client
        .publish_member(&member)
        .await
        .expect_err("M3.2 doesn't implement writes");
    assert!(matches!(err, RegistryError::Permanent(_)));
    assert!(!err.is_retriable());
}
#[tokio::test]
async fn recognise_sends_canonical_four_tuple() {
    let (base_url, captured, _response, stop) =
        spawn_server(TestResponse::Recognized(true)).await;
    let client = UpstreamRegistryClient::new(config_for(&base_url)).unwrap();

    let recognised = client.recognise("did:webvh:peer.example").await.unwrap();
    assert!(recognised);

    // Check what actually went over the wire.
    let seen = captured.lock().await.clone().expect("server saw request");
    assert_eq!(seen.entity_id, "did:webvh:peer.example");
    assert_eq!(seen.authority_id, "did:webvh:vtc.example");
    assert_eq!(seen.action, "recognise");
    assert_eq!(seen.resource, "trust-graph");

    let _ = stop.send(());
}
#[tokio::test]
async fn recognise_returns_false_when_response_recognized_is_false() {
    let (base_url, _captured, _response, stop) =
        spawn_server(TestResponse::Recognized(false)).await;
    let client = UpstreamRegistryClient::new(config_for(&base_url)).unwrap();

    // A 200 with `recognized: false` is a clean negative answer.
    let recognised = client.recognise("did:webvh:peer.example").await.unwrap();
    assert!(!recognised);

    let _ = stop.send(());
}
#[tokio::test]
async fn recognise_maps_404_to_clean_not_recognised() {
    let (base_url, _captured, _response, stop) = spawn_server(TestResponse::NotFound).await;
    let client = UpstreamRegistryClient::new(config_for(&base_url)).unwrap();

    // 404 must surface as `Ok(false)`, not as an error.
    let recognised = client
        .recognise("did:webvh:stranger.example")
        .await
        .unwrap();
    assert!(!recognised);

    let _ = stop.send(());
}
#[tokio::test]
async fn recognise_maps_400_to_permanent() {
    let (base_url, _captured, _response, stop) = spawn_server(TestResponse::BadRequest).await;
    let client = UpstreamRegistryClient::new(config_for(&base_url)).unwrap();

    // A malformed-request reply must not be retried.
    let err = client
        .recognise("did:webvh:peer.example")
        .await
        .unwrap_err();
    assert!(matches!(err, RegistryError::Permanent(_)));
    assert!(!err.is_retriable());

    let _ = stop.send(());
}
#[tokio::test]
async fn recognise_maps_500_to_transient() {
    let (base_url, _captured, _response, stop) = spawn_server(TestResponse::ServerError).await;
    let client = UpstreamRegistryClient::new(config_for(&base_url)).unwrap();

    // Server-side failures are worth retrying.
    let err = client
        .recognise("did:webvh:peer.example")
        .await
        .unwrap_err();
    assert!(matches!(err, RegistryError::Transient(_)));
    assert!(err.is_retriable());

    let _ = stop.send(());
}
#[tokio::test]
async fn recognise_connection_refused_is_unreachable() {
    // Assumes nothing is listening on loopback port 1.
    let client = UpstreamRegistryClient::new(UpstreamConfig {
        base_url: String::from("http://127.0.0.1:1"),
        http_timeout: Duration::from_millis(300),
        authority_did: Some(String::from("did:webvh:vtc.example")),
    })
    .unwrap();

    let err = client
        .recognise("did:webvh:peer.example")
        .await
        .unwrap_err();
    assert!(
        err.is_retriable(),
        "connection refused must be retriable: {err:?}"
    );
}
#[tokio::test]
async fn recognise_refuses_when_authority_did_missing() {
    let client = UpstreamRegistryClient::new(UpstreamConfig {
        base_url: String::from("http://127.0.0.1:1"),
        http_timeout: Duration::from_secs(1),
        authority_did: None,
    })
    .unwrap();

    // Without an authority DID the query cannot even be formed.
    let err = client
        .recognise("did:webvh:peer.example")
        .await
        .unwrap_err();
    assert!(matches!(err, RegistryError::Permanent(_)));
    assert!(
        err.to_string().contains("authority_did"),
        "error should mention authority_did: {err}"
    );
}
}