use crate::dm::DmEnvelope;
use crate::identity::AgentId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Domain-separation tag prepended to every relay-header signing payload.
const RELAY_HEADER_SIGN_DOMAIN: &[u8] = b"x0x-relay-hdr-v1";

/// Default count of direct-delivery failures (inside [`DEFAULT_FAIL_WINDOW`]) at which a
/// peer is considered to need a relay.
pub const DEFAULT_FAIL_THRESHOLD: u32 = 3;

/// Default sliding window over which direct-delivery failures are counted.
pub const DEFAULT_FAIL_WINDOW: Duration = Duration::from_millis(60_000);

/// Default maximum age of a relayed DM's origin timestamp before it is refused as stale.
pub const DEFAULT_RELAY_FRESHNESS: Duration = Duration::from_millis(30_000);

/// How far ahead of local time (in milliseconds) a relayed DM's origin timestamp may be
/// before it is refused, to tolerate clock skew between agents.
pub const RELAY_CLOCK_SKEW_TOLERANCE_MS: u64 = 30 * 1_000;
/// Signed routing header attached to a DM that travels through a third-party relay.
///
/// `signature` covers every other field (see [`RelayHeader::signing_bytes`]), so changing
/// any of them invalidates the header under [`RelayHeader::verify`].
///
/// NOTE(review): the signature does not cover the wrapped `DmEnvelope`; that payload is
/// presumably authenticated by its own signature — confirm.
/// NOTE: field order may be part of the wire format for non-self-describing serde
/// encodings; do not reorder fields without a version bump.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RelayHeader {
    /// Header format version; `verify` rejects anything other than [`RelayHeader::VERSION`].
    pub version: u16,
    /// Agent id of the final recipient.
    pub dst_agent_id: [u8; 32],
    /// Agent id of the originator; must equal the id derived from `sender_public_key`.
    pub sender_agent_id: [u8; 32],
    /// Encoded ML-DSA public key of the originator.
    pub sender_public_key: Vec<u8>,
    /// Originator's wall-clock send time in milliseconds since the Unix epoch.
    pub originated_at_unix_ms: u64,
    /// ML-DSA signature over [`RelayHeader::signing_bytes`] of the fields above.
    pub signature: Vec<u8>,
}
impl RelayHeader {
    /// The only header format version this code produces and accepts.
    pub const VERSION: u16 = 1;

    /// Canonical byte string covered by the relay-header signature.
    ///
    /// Layout: `RELAY_HEADER_SIGN_DOMAIN || version (u16 BE) || dst_agent_id (32 bytes) ||
    /// sender_agent_id (32 bytes) || sender_public_key || originated_at_unix_ms (u64 BE)`.
    /// The public key is the only variable-length part and is bracketed by fixed-size
    /// fields, so the encoding is unambiguous.
    #[must_use]
    pub fn signing_bytes(
        version: u16,
        dst_agent_id: &[u8; 32],
        sender_agent_id: &[u8; 32],
        sender_public_key: &[u8],
        originated_at_unix_ms: u64,
    ) -> Vec<u8> {
        let version_be = version.to_be_bytes();
        let originated_be = originated_at_unix_ms.to_be_bytes();
        let segments: [&[u8]; 6] = [
            RELAY_HEADER_SIGN_DOMAIN,
            &version_be,
            dst_agent_id,
            sender_agent_id,
            sender_public_key,
            &originated_be,
        ];
        segments.concat()
    }

    /// Recomputes [`Self::signing_bytes`] from this header's own fields.
    #[must_use]
    pub fn own_signing_bytes(&self) -> Vec<u8> {
        let Self {
            version,
            dst_agent_id,
            sender_agent_id,
            sender_public_key,
            originated_at_unix_ms,
            ..
        } = self;
        Self::signing_bytes(
            *version,
            dst_agent_id,
            sender_agent_id,
            sender_public_key,
            *originated_at_unix_ms,
        )
    }

    /// Returns `true` only if the header is authentic.
    ///
    /// All of the following must hold:
    /// - the version is [`Self::VERSION`];
    /// - the public key parses;
    /// - `sender_agent_id` is the id derived from that key;
    /// - the signature parses and verifies over [`Self::own_signing_bytes`].
    #[must_use]
    pub fn verify(&self) -> bool {
        use ant_quic::crypto::raw_public_keys::pqc::{verify_with_ml_dsa, MlDsaSignature};

        if self.version != Self::VERSION {
            return false;
        }
        let Ok(key) = ant_quic::MlDsaPublicKey::from_bytes(&self.sender_public_key) else {
            return false;
        };
        // Binds the claimed origin to the key: a relay cannot sign under someone else's id.
        if AgentId::from_public_key(&key).0 != self.sender_agent_id {
            return false;
        }
        let Ok(sig) = MlDsaSignature::from_bytes(&self.signature) else {
            return false;
        };
        verify_with_ml_dsa(&key, &self.own_signing_bytes(), &sig).is_ok()
    }
}
/// A DM wrapped for delivery through a third-party relay.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelayedDm {
    /// Signed routing header, checked by [`PeerRelay::disposition_for`].
    pub header: RelayHeader,
    /// The original DM envelope, stored as given to `build_relayed_dm`.
    /// It is NOT covered by the header signature.
    /// NOTE(review): presumably authenticated by the envelope's own signature — confirm.
    pub inner: DmEnvelope,
}
/// What a node should do with an incoming [`RelayedDm`], as decided by
/// [`PeerRelay::disposition_for`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelayDisposition {
    /// The local agent is the header's destination.
    DeliverLocally,
    /// The local agent is an intermediate hop; pass the message on toward `dst_agent_id`.
    Forward { dst_agent_id: [u8; 32] },
    /// Drop the message for the given reason.
    Refuse(RelayRefusal),
}
/// Why a relayed DM was refused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelayRefusal {
    /// [`RelayHeader::verify`] failed. Causes: wrong version, unparsable key or signature,
    /// sender id not derived from the key, or an invalid signature.
    BadSignature,
    /// The origin timestamp is older than the freshness window, or further in the future
    /// than the clock-skew tolerance.
    Stale,
    /// The local relay policy is disabled.
    PolicyDisabled,
}
/// Tunables controlling when peers are routed via a relay and which relayed DMs are accepted.
#[derive(Debug, Clone, Copy)]
pub struct RelayPolicy {
    /// Master switch.
    /// When `false`, `needs_relay` never reports `true` and every relayed DM is refused.
    pub enabled: bool,
    /// Number of direct failures inside `fail_window` at which a peer needs a relay.
    pub fail_threshold: u32,
    /// Sliding window over which direct failures are counted.
    pub fail_window: Duration,
    /// Maximum accepted age of a relayed DM's origin timestamp.
    pub freshness: Duration,
}
impl Default for RelayPolicy {
fn default() -> Self {
Self {
enabled: false,
fail_threshold: DEFAULT_FAIL_THRESHOLD,
fail_window: DEFAULT_FAIL_WINDOW,
freshness: DEFAULT_RELAY_FRESHNESS,
}
}
}
impl RelayPolicy {
    /// The default policy with relaying switched on.
    #[must_use]
    pub fn enabled() -> Self {
        let defaults = Self::default();
        Self {
            enabled: true,
            fail_threshold: defaults.fail_threshold,
            fail_window: defaults.fail_window,
            freshness: defaults.freshness,
        }
    }

    /// Returns a copy of this policy with a custom failure trigger.
    ///
    /// Relay is triggered after `threshold` direct failures within `window`. A `threshold`
    /// of 0 is raised to 1.
    #[must_use]
    pub fn with_failure_trigger(self, threshold: u32, window: Duration) -> Self {
        Self {
            fail_threshold: std::cmp::max(threshold, 1),
            fail_window: window,
            ..self
        }
    }
}
/// Per-peer bookkeeping kept by [`PeerRelay`].
#[derive(Debug, Default)]
struct PeerRelayState {
    /// Times of recent direct-delivery failures. Entries older than the policy's
    /// `fail_window` are pruned whenever the entry is touched.
    recent_failures: Vec<Instant>,
    /// Set when `needs_relay` last reported `true`; cleared by a direct success, which
    /// then counts as a recovery.
    in_relay_mode: bool,
}
/// Lock-free relay counters, updated with `Relaxed` atomics.
#[derive(Debug, Default)]
pub struct RelayStats {
    /// Relayed DMs built by `build_relayed_dm`.
    relay_sent: AtomicU64,
    /// Relayed DMs accepted for local delivery.
    relay_received: AtomicU64,
    /// Relayed DMs accepted for forwarding to another agent.
    relay_forwarded: AtomicU64,
    /// Relayed DMs refused because the header failed verification.
    relay_refused_bad_signature: AtomicU64,
    /// Relayed DMs refused as too old or too far in the future.
    relay_refused_stale: AtomicU64,
    /// Relayed DMs refused because the relay policy is disabled.
    relay_refused_policy_disabled: AtomicU64,
    /// Direct successes to peers that were in relay mode.
    direct_recovered_after_relay: AtomicU64,
}
/// Plain, serializable copy of [`RelayStats`]; field meanings match the live counters.
///
/// NOTE: serialized key order follows field order.
#[derive(Debug, Clone, Default, Serialize)]
pub struct RelayStatsSnapshot {
    pub relay_sent: u64,
    pub relay_received: u64,
    pub relay_forwarded: u64,
    pub relay_refused_bad_signature: u64,
    pub relay_refused_stale: u64,
    pub relay_refused_policy_disabled: u64,
    pub direct_recovered_after_relay: u64,
}
impl RelayStats {
#[must_use]
pub fn snapshot(&self) -> RelayStatsSnapshot {
RelayStatsSnapshot {
relay_sent: self.relay_sent.load(Ordering::Relaxed),
relay_received: self.relay_received.load(Ordering::Relaxed),
relay_forwarded: self.relay_forwarded.load(Ordering::Relaxed),
relay_refused_bad_signature: self.relay_refused_bad_signature.load(Ordering::Relaxed),
relay_refused_stale: self.relay_refused_stale.load(Ordering::Relaxed),
relay_refused_policy_disabled: self
.relay_refused_policy_disabled
.load(Ordering::Relaxed),
direct_recovered_after_relay: self.direct_recovered_after_relay.load(Ordering::Relaxed),
}
}
}
/// Tracks direct-delivery health per peer and decides how relayed DMs are handled.
///
/// All methods take `&self`: per-peer state sits behind a `Mutex` and counters are atomics.
#[derive(Debug)]
pub struct PeerRelay {
    /// Relay tunables, set at construction.
    policy: RelayPolicy,
    /// Relay counters (only ever incremented).
    stats: RelayStats,
    /// Failure history and relay-mode flag, keyed by the peer's raw 32-byte agent id.
    per_peer: Mutex<HashMap<[u8; 32], PeerRelayState>>,
}
impl Default for PeerRelay {
fn default() -> Self {
Self::new()
}
}
impl PeerRelay {
    /// Creates a tracker with the default policy (relaying disabled).
    #[must_use]
    pub fn new() -> Self {
        Self::with_policy(RelayPolicy::default())
    }

    /// Creates a tracker governed by `policy`, with zeroed stats and no tracked peers.
    #[must_use]
    pub fn with_policy(policy: RelayPolicy) -> Self {
        Self {
            policy,
            stats: RelayStats::default(),
            per_peer: Mutex::new(HashMap::new()),
        }
    }

    /// The policy this tracker was constructed with.
    #[must_use]
    pub fn policy(&self) -> &RelayPolicy {
        &self.policy
    }

    /// Live relay counters; call [`RelayStats::snapshot`] for a serializable copy.
    #[must_use]
    pub fn stats(&self) -> &RelayStats {
        &self.stats
    }

    /// Locks the per-peer state map.
    ///
    /// Poisoning is deliberately ignored. The map holds only failure timestamps and a
    /// relay-mode flag per peer, so continuing after a panicked holder beats propagating
    /// the panic into every caller.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<[u8; 32], PeerRelayState>> {
        self.per_peer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records a failed direct delivery to `peer`.
    ///
    /// Failures older than the policy's `fail_window` are pruned first, so the stored
    /// history spans at most one window. Failures are recorded even when relaying is
    /// disabled.
    pub fn record_direct_failure(&self, peer: &AgentId) {
        let now = Instant::now();
        let window = self.policy.fail_window;
        let mut guard = self.lock();
        let entry = guard.entry(peer.0).or_default();
        entry
            .recent_failures
            .retain(|t| now.saturating_duration_since(*t) < window);
        entry.recent_failures.push(now);
    }

    /// Records a successful direct delivery to `peer`.
    ///
    /// Clears the peer's failure history. If the peer was in relay mode, it leaves relay
    /// mode and `direct_recovered_after_relay` is incremented once. Untracked peers are
    /// ignored.
    pub fn record_direct_success(&self, peer: &AgentId) {
        // The guard is a temporary of this statement, so the lock is released before the
        // counter is touched.
        let was_relaying = match self.lock().get_mut(&peer.0) {
            Some(entry) => {
                entry.recent_failures.clear();
                std::mem::take(&mut entry.in_relay_mode)
            }
            None => false,
        };
        if was_relaying {
            self.stats
                .direct_recovered_after_relay
                .fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Returns `true` if relaying is enabled and `peer` has at least `fail_threshold`
    /// direct failures within `fail_window`.
    ///
    /// Side effects: prunes expired failures and, on `true`, marks the peer as in relay
    /// mode. Untracked peers never need a relay.
    #[must_use]
    pub fn needs_relay(&self, peer: &AgentId) -> bool {
        if !self.policy.enabled {
            return false;
        }
        let now = Instant::now();
        let window = self.policy.fail_window;
        // `fail_threshold` is a public field and can be set to 0 without passing through
        // `with_failure_trigger`. Clamp here too, so an empty failure history never
        // triggers relay.
        let threshold = usize::try_from(self.policy.fail_threshold.max(1)).unwrap_or(usize::MAX);
        let mut guard = self.lock();
        let Some(entry) = guard.get_mut(&peer.0) else {
            return false;
        };
        entry
            .recent_failures
            .retain(|t| now.saturating_duration_since(*t) < window);
        let needs = entry.recent_failures.len() >= threshold;
        if needs {
            entry.in_relay_mode = true;
        }
        needs
    }

    /// Picks the first candidate that is neither `dst` nor `sender`.
    ///
    /// Returns `None` when no third party is available.
    #[must_use]
    pub fn select_relay(
        &self,
        candidates: &[AgentId],
        dst: &AgentId,
        sender: &AgentId,
    ) -> Option<AgentId> {
        candidates
            .iter()
            .find(|c| c.0 != dst.0 && c.0 != sender.0)
            .copied()
    }

    /// Wraps `inner` in a signed [`RelayHeader`] addressed to `dst`.
    ///
    /// `sign` receives [`RelayHeader::signing_bytes`] and must return the sender's
    /// signature over them.
    ///
    /// # Errors
    ///
    /// Returns whatever error `sign` returns. `relay_sent` is only incremented on success.
    pub fn build_relayed_dm<F>(
        &self,
        dst: &AgentId,
        sender: &AgentId,
        sender_public_key: Vec<u8>,
        originated_at_unix_ms: u64,
        inner: DmEnvelope,
        sign: F,
    ) -> Result<RelayedDm, String>
    where
        F: FnOnce(&[u8]) -> Result<Vec<u8>, String>,
    {
        let signing_bytes = RelayHeader::signing_bytes(
            RelayHeader::VERSION,
            &dst.0,
            &sender.0,
            &sender_public_key,
            originated_at_unix_ms,
        );
        let signature = sign(&signing_bytes)?;
        let header = RelayHeader {
            version: RelayHeader::VERSION,
            dst_agent_id: dst.0,
            sender_agent_id: sender.0,
            sender_public_key,
            originated_at_unix_ms,
            signature,
        };
        self.stats.relay_sent.fetch_add(1, Ordering::Relaxed);
        Ok(RelayedDm { header, inner })
    }

    /// Decides what to do with an incoming relayed DM and bumps the matching counter.
    ///
    /// `now_unix_ms` is the local wall-clock time in milliseconds since the Unix epoch.
    /// Checks run in this order:
    /// 1. policy enabled → otherwise `PolicyDisabled`;
    /// 2. header signature → otherwise `BadSignature`;
    /// 3. timestamp freshness → otherwise `Stale`.
    ///
    /// An accepted message is delivered locally if `local_agent_id` is the destination,
    /// and forwarded otherwise.
    #[must_use]
    pub fn disposition_for(
        &self,
        relayed: &RelayedDm,
        local_agent_id: &AgentId,
        now_unix_ms: u64,
    ) -> RelayDisposition {
        // Cheap policy gate before the ML-DSA verification: nodes that do not relay (the
        // default) should not spend signature-verification work on untrusted traffic they
        // will refuse anyway.
        if !self.policy.enabled {
            self.stats
                .relay_refused_policy_disabled
                .fetch_add(1, Ordering::Relaxed);
            return RelayDisposition::Refuse(RelayRefusal::PolicyDisabled);
        }
        if !relayed.header.verify() {
            self.stats
                .relay_refused_bad_signature
                .fetch_add(1, Ordering::Relaxed);
            return RelayDisposition::Refuse(RelayRefusal::BadSignature);
        }
        // `as_millis` is u128. Saturate instead of truncating, so a huge configured
        // freshness means "effectively unlimited" rather than a wrapped small value.
        let freshness_ms = u64::try_from(self.policy.freshness.as_millis()).unwrap_or(u64::MAX);
        let originated = relayed.header.originated_at_unix_ms;
        let from_future = originated > now_unix_ms.saturating_add(RELAY_CLOCK_SKEW_TOLERANCE_MS);
        // A future timestamp within the skew tolerance saturates to age 0 here.
        let too_old = now_unix_ms.saturating_sub(originated) > freshness_ms;
        if from_future || too_old {
            self.stats
                .relay_refused_stale
                .fetch_add(1, Ordering::Relaxed);
            return RelayDisposition::Refuse(RelayRefusal::Stale);
        }
        if relayed.header.dst_agent_id == local_agent_id.0 {
            self.stats.relay_received.fetch_add(1, Ordering::Relaxed);
            RelayDisposition::DeliverLocally
        } else {
            self.stats.relay_forwarded.fetch_add(1, Ordering::Relaxed);
            RelayDisposition::Forward {
                dst_agent_id: relayed.header.dst_agent_id,
            }
        }
    }

    /// Number of peers with per-peer relay state.
    #[must_use]
    pub fn tracked_peer_count(&self) -> usize {
        self.lock().len()
    }

    /// Drops all relay state for `peer`.
    pub fn forget_peer(&self, peer: &AgentId) {
        self.lock().remove(&peer.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dm::{DmBody, DmPayload};
    use crate::identity::AgentKeypair;

    /// Placeholder agent id whose 32 bytes are all `seed`.
    fn aid(seed: u8) -> AgentId {
        AgentId([seed; 32])
    }

    /// Structurally valid inner envelope; `PeerRelay` never inspects its contents.
    fn dummy_inner() -> DmEnvelope {
        DmEnvelope {
            protocol_version: 1,
            request_id: [7u8; 16],
            sender_agent_id: [1u8; 32],
            sender_machine_id: [2u8; 32],
            recipient_agent_id: [3u8; 32],
            created_at_unix_ms: 1_000,
            expires_at_unix_ms: 60_000,
            body: DmBody::Payload(DmPayload {
                kem_ciphertext: vec![0u8; 8],
                body_nonce: [0u8; 12],
                body_ciphertext: vec![0u8; 8],
            }),
            signature: vec![0u8; 8],
        }
    }

    /// Fresh identity: (agent id, encoded public key, secret key).
    fn fresh_identity() -> (AgentId, Vec<u8>, ant_quic::MlDsaSecretKey) {
        let kp = AgentKeypair::generate().expect("keypair");
        let agent = kp.agent_id();
        let (pub_bytes, sec_bytes) = kp.to_bytes();
        let secret = ant_quic::MlDsaSecretKey::from_bytes(&sec_bytes).expect("secret");
        (agent, pub_bytes, secret)
    }

    /// ML-DSA signer shaped like the `build_relayed_dm` signing callback.
    fn ml_dsa_sign(secret: &ant_quic::MlDsaSecretKey, bytes: &[u8]) -> Result<Vec<u8>, String> {
        ant_quic::crypto::raw_public_keys::pqc::sign_with_ml_dsa(secret, bytes)
            .map(|s| s.as_bytes().to_vec())
            .map_err(|e| format!("{e:?}"))
    }

    /// Hand-built header whose signature covers `claimed_sender` (which may be forged).
    fn signed_header(
        dst: &AgentId,
        claimed_sender: &AgentId,
        pub_bytes: Vec<u8>,
        secret: &ant_quic::MlDsaSecretKey,
        originated_ms: u64,
    ) -> RelayHeader {
        let signing_bytes = RelayHeader::signing_bytes(
            RelayHeader::VERSION,
            &dst.0,
            &claimed_sender.0,
            &pub_bytes,
            originated_ms,
        );
        RelayHeader {
            version: RelayHeader::VERSION,
            dst_agent_id: dst.0,
            sender_agent_id: claimed_sender.0,
            sender_public_key: pub_bytes,
            originated_at_unix_ms: originated_ms,
            signature: ml_dsa_sign(secret, &signing_bytes).expect("sign"),
        }
    }

    /// Builds a correctly signed relayed DM through `relay`.
    fn build_signed(
        relay: &PeerRelay,
        dst: &AgentId,
        sender: &AgentId,
        pub_bytes: Vec<u8>,
        secret: &ant_quic::MlDsaSecretKey,
        originated_ms: u64,
    ) -> RelayedDm {
        relay
            .build_relayed_dm(dst, sender, pub_bytes, originated_ms, dummy_inner(), |bytes| {
                ml_dsa_sign(secret, bytes)
            })
            .expect("build_relayed_dm")
    }

    #[test]
    fn relay_disabled_by_default() {
        let relay = PeerRelay::new();
        assert!(!relay.policy().enabled);
        let peer = aid(9);
        for _ in 0..10 {
            relay.record_direct_failure(&peer);
        }
        assert!(
            !relay.needs_relay(&peer),
            "disabled policy must never trigger relay regardless of failures"
        );
    }

    #[test]
    fn needs_relay_after_threshold_failures_within_window() {
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let peer = aid(1);
        relay.record_direct_failure(&peer);
        relay.record_direct_failure(&peer);
        assert!(
            !relay.needs_relay(&peer),
            "2 failures < default threshold 3 — no relay yet"
        );
        relay.record_direct_failure(&peer);
        assert!(
            relay.needs_relay(&peer),
            "3 failures == threshold — peer now needs a relay"
        );
    }

    #[test]
    fn failure_trigger_threshold_is_clamped_to_one() {
        let policy = RelayPolicy::enabled().with_failure_trigger(0, Duration::from_secs(5));
        assert_eq!(policy.fail_threshold, 1);
        assert_eq!(policy.fail_window, Duration::from_secs(5));
        let relay = PeerRelay::with_policy(policy);
        let peer = aid(3);
        assert!(!relay.needs_relay(&peer), "untracked peer never needs a relay");
        relay.record_direct_failure(&peer);
        assert!(
            relay.needs_relay(&peer),
            "a single failure meets the clamped threshold of 1"
        );
    }

    #[test]
    fn direct_success_clears_failures_and_counts_recovery() {
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let peer = aid(2);
        for _ in 0..3 {
            relay.record_direct_failure(&peer);
        }
        assert!(relay.needs_relay(&peer), "peer entered relay mode");
        relay.record_direct_success(&peer);
        assert!(
            !relay.needs_relay(&peer),
            "direct success clears the failure history"
        );
        assert_eq!(
            relay.stats().snapshot().direct_recovered_after_relay,
            1,
            "recovery from relay mode is counted once"
        );
        relay.record_direct_success(&peer);
        assert_eq!(
            relay.stats().snapshot().direct_recovered_after_relay,
            1,
            "recovery counter does not double-count"
        );
    }

    #[test]
    fn select_relay_skips_dst_and_sender() {
        let relay = PeerRelay::new();
        let sender = aid(1);
        let dst = aid(2);
        let r1 = aid(3);
        let r2 = aid(4);
        let candidates = [dst, sender, r1, r2];
        assert_eq!(relay.select_relay(&candidates, &dst, &sender), Some(r1));
        let only_endpoints = [dst, sender];
        assert_eq!(
            relay.select_relay(&only_endpoints, &dst, &sender),
            None,
            "no third party available — cannot relay"
        );
    }

    #[test]
    fn relay_header_sign_verify_roundtrip() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let header = signed_header(&aid(50), &sender, pub_bytes, &secret, 1_700_000_000_000);
        assert!(header.verify(), "a correctly signed header must verify");
    }

    #[test]
    fn relay_header_verify_rejects_tampered_dst() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let mut header = signed_header(&aid(50), &sender, pub_bytes, &secret, 1_700_000_000_000);
        header.dst_agent_id = aid(99).0;
        assert!(
            !header.verify(),
            "a tampered dst must break the header signature"
        );
    }

    #[test]
    fn relay_header_verify_rejects_forged_origin() {
        let (_, pub_bytes, secret) = fresh_identity();
        let forged_sender = aid(123);
        let header = signed_header(
            &aid(50),
            &forged_sender,
            pub_bytes,
            &secret,
            1_700_000_000_000,
        );
        assert!(
            !header.verify(),
            "sender_agent_id must derive from sender_public_key"
        );
    }

    #[test]
    fn build_relayed_dm_increments_relay_sent_and_produces_verifiable_header() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let relayed = build_signed(&relay, &aid(60), &sender, pub_bytes, &secret, 1_700_000_000_000);
        assert!(
            relayed.header.verify(),
            "build_relayed_dm must produce a verifiable header"
        );
        assert_eq!(relay.stats().snapshot().relay_sent, 1);
    }

    #[test]
    fn disposition_delivers_locally_when_we_are_the_dst() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let local = aid(70);
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let now_ms = 1_700_000_000_000u64;
        let relayed = build_signed(&relay, &local, &sender, pub_bytes, &secret, now_ms);
        assert_eq!(
            relay.disposition_for(&relayed, &local, now_ms + 100),
            RelayDisposition::DeliverLocally
        );
        assert_eq!(relay.stats().snapshot().relay_received, 1);
    }

    #[test]
    fn disposition_forwards_when_we_are_an_intermediate_relay() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let dst = aid(80);
        let we_are_the_relay = aid(81);
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let now_ms = 1_700_000_000_000u64;
        let relayed = build_signed(&relay, &dst, &sender, pub_bytes, &secret, now_ms);
        assert_eq!(
            relay.disposition_for(&relayed, &we_are_the_relay, now_ms + 100),
            RelayDisposition::Forward {
                dst_agent_id: dst.0
            }
        );
        assert_eq!(relay.stats().snapshot().relay_forwarded, 1);
    }

    #[test]
    fn disposition_refuses_tampered_header_as_bad_signature() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let local = aid(92);
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let now_ms = 1_700_000_000_000u64;
        let mut relayed = build_signed(&relay, &aid(93), &sender, pub_bytes, &secret, now_ms);
        // Redirect the message to ourselves without re-signing.
        relayed.header.dst_agent_id = local.0;
        assert_eq!(
            relay.disposition_for(&relayed, &local, now_ms + 100),
            RelayDisposition::Refuse(RelayRefusal::BadSignature)
        );
        let stats = relay.stats().snapshot();
        assert_eq!(stats.relay_refused_bad_signature, 1);
        assert_eq!(stats.relay_received, 0);
    }

    #[test]
    fn disposition_refuses_stale_relayed_dm() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let local = aid(90);
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let originated_ms = 1_700_000_000_000u64;
        let relayed = build_signed(&relay, &local, &sender, pub_bytes, &secret, originated_ms);
        let now_ms = originated_ms + 31_000;
        assert_eq!(
            relay.disposition_for(&relayed, &local, now_ms),
            RelayDisposition::Refuse(RelayRefusal::Stale)
        );
        assert_eq!(relay.stats().snapshot().relay_refused_stale, 1);
    }

    #[test]
    fn disposition_refuses_far_future_relayed_dm() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let local = aid(91);
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let now_ms = 1_700_000_000_000u64;
        let originated_ms = now_ms + RELAY_CLOCK_SKEW_TOLERANCE_MS + 1_000;
        let relayed = build_signed(
            &relay,
            &local,
            &sender,
            pub_bytes.clone(),
            &secret,
            originated_ms,
        );
        assert_eq!(
            relay.disposition_for(&relayed, &local, now_ms),
            RelayDisposition::Refuse(RelayRefusal::Stale)
        );
        assert_eq!(relay.stats().snapshot().relay_refused_stale, 1);
        // Slightly in the future but within the skew tolerance is accepted.
        let fresh = build_signed(&relay, &local, &sender, pub_bytes, &secret, now_ms + 1_000);
        assert_eq!(
            relay.disposition_for(&fresh, &local, now_ms),
            RelayDisposition::DeliverLocally
        );
    }

    #[test]
    fn disposition_refuses_when_policy_disabled() {
        let (sender, pub_bytes, secret) = fresh_identity();
        let local = aid(95);
        let builder = PeerRelay::with_policy(RelayPolicy::enabled());
        let now_ms = 1_700_000_000_000u64;
        let relayed = build_signed(&builder, &local, &sender, pub_bytes, &secret, now_ms);
        let disabled = PeerRelay::new();
        assert_eq!(
            disabled.disposition_for(&relayed, &local, now_ms + 100),
            RelayDisposition::Refuse(RelayRefusal::PolicyDisabled)
        );
        assert_eq!(disabled.stats().snapshot().relay_refused_policy_disabled, 1);
    }

    #[test]
    fn forget_peer_drops_relay_state() {
        let relay = PeerRelay::with_policy(RelayPolicy::enabled());
        let peer = aid(1);
        relay.record_direct_failure(&peer);
        assert_eq!(relay.tracked_peer_count(), 1);
        relay.forget_peer(&peer);
        assert_eq!(relay.tracked_peer_count(), 0);
    }
}