use std::net::SocketAddr;
use async_trait::async_trait;
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::{MultiAddr, PeerId};
/// A relay that [`RelayAcquisition::acquire`] may try to open a session with.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RelayCandidate {
    /// Identity of the relay; reported back as `relayer` when it is acquired.
    pub peer_id: PeerId,
    /// Address used to contact the relay. Candidates whose address yields no
    /// dialable socket address are skipped during acquisition.
    pub direct_address: MultiAddr,
}

impl RelayCandidate {
    /// Builds a candidate from a peer identity and its direct address.
    pub fn new(peer_id: PeerId, direct_address: MultiAddr) -> Self {
        RelayCandidate {
            direct_address,
            peer_id,
        }
    }
}
/// Outcome of a successful [`RelayAcquisition::acquire`] walk.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AcquiredRelay {
    /// Peer identity of the relay that granted the session.
    pub relayer: PeerId,
    /// Address returned by [`RelaySessionEstablisher::establish`] for this session.
    pub allocated_public_addr: SocketAddr,
}
#[derive(Debug, Clone, Error)]
pub enum RelaySessionEstablishError {
#[error("relay at client capacity: {0}")]
AtCapacity(String),
#[error("relay unreachable: {0}")]
Unreachable(String),
}
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RelayAcquisitionError {
#[error("no candidate relays available")]
NoCandidates,
#[error("all candidate relays exhausted without success")]
AllCandidatesExhausted,
}
/// Strategy used by [`RelayAcquisition`] to open a session with one relay.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait RelaySessionEstablisher: 'static + Send + Sync {
    /// Tries to establish a relay session with the relay at `relay_addr`.
    ///
    /// On success, returns the allocated public address, which
    /// [`RelayAcquisition::acquire`] surfaces as
    /// [`AcquiredRelay::allocated_public_addr`].
    ///
    /// # Errors
    ///
    /// Returns a [`RelaySessionEstablishError`] describing why this relay
    /// could not provide a session.
    async fn establish(
        &self,
        relay_addr: SocketAddr,
    ) -> Result<SocketAddr, RelaySessionEstablishError>;
}
/// Coordinator that walks relay candidates until one grants a session.
///
/// The `E: RelaySessionEstablisher` bound is required only by the `impl`
/// block that provides `new`/`acquire`, not by the struct itself, so the type
/// can be named in other generic code without repeating the bound.
pub struct RelayAcquisition<E> {
    /// Backend used to contact each candidate relay.
    establisher: E,
}
impl<E: RelaySessionEstablisher> RelayAcquisition<E> {
pub fn new(establisher: E) -> Self {
Self { establisher }
}
pub async fn acquire(
&self,
candidates: Vec<RelayCandidate>,
) -> Result<AcquiredRelay, RelayAcquisitionError> {
if candidates.is_empty() {
debug!("relay acquisition called with empty candidate list");
return Err(RelayAcquisitionError::NoCandidates);
}
let candidate_count = candidates.len();
debug!(
candidates = candidate_count,
"starting proactive relay acquisition walk"
);
for (index, candidate) in candidates.into_iter().enumerate() {
let Some(socket) = candidate.direct_address.dialable_socket_addr() else {
warn!(
relayer = ?candidate.peer_id,
address = %candidate.direct_address,
"candidate has no dialable socket address, skipping"
);
continue;
};
match self.establisher.establish(socket).await {
Ok(allocated) => {
info!(
relayer = ?candidate.peer_id,
allocated = %allocated,
index = index,
"acquired proactive relay session"
);
return Ok(AcquiredRelay {
relayer: candidate.peer_id,
allocated_public_addr: allocated,
});
}
Err(RelaySessionEstablishError::AtCapacity(reason)) => {
debug!(
relayer = ?candidate.peer_id,
reason = %reason,
index = index,
"candidate relay at capacity, walking to next"
);
}
Err(RelaySessionEstablishError::Unreachable(reason)) => {
debug!(
relayer = ?candidate.peer_id,
reason = %reason,
index = index,
"candidate relay unreachable, walking to next"
);
}
}
}
warn!(
candidates = candidate_count,
"all candidate relays exhausted without success"
);
Err(RelayAcquisitionError::AllCandidatesExhausted)
}
}
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    use super::*;

    fn peer_id(seed: u8) -> PeerId {
        PeerId::from_bytes([seed; 32])
    }

    /// Address built with `MultiAddr::from_ipv4`; the walk tests rely on it
    /// being dialable (checked by `socket_of`).
    fn dialable_addr(port: u16) -> MultiAddr {
        MultiAddr::from_ipv4(Ipv4Addr::new(192, 0, 2, 1), port)
    }

    fn candidate(seed: u8, port: u16) -> RelayCandidate {
        RelayCandidate::new(peer_id(seed), dialable_addr(port))
    }

    /// A TCP address; the skip tests rely on it having no dialable socket address.
    fn tcp_addr() -> MultiAddr {
        MultiAddr::tcp(SocketAddr::new(
            std::net::IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1)),
            80,
        ))
    }

    /// The socket address `acquire` is expected to hand to the establisher
    /// for `candidate`. Panics if a test candidate is unexpectedly not dialable.
    fn socket_of(candidate: &RelayCandidate) -> SocketAddr {
        candidate
            .direct_address
            .dialable_socket_addr()
            .expect("test candidate should be dialable")
    }

    /// Copies out the addresses the establisher has been asked to dial so far.
    fn dialed_so_far(dialed: &Mutex<Vec<SocketAddr>>) -> Vec<SocketAddr> {
        dialed.lock().expect("mutex poisoned in test").clone()
    }

    fn allocated(port: u16) -> SocketAddr {
        SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7)), port)
    }

    /// Test double that replays a fixed script of outcomes (one per call, in
    /// order) and records every relay address it is asked to dial.
    struct ScriptedEstablisher {
        outcomes: Mutex<VecDeque<Result<SocketAddr, RelaySessionEstablishError>>>,
        calls: Arc<AtomicUsize>,
        dialed: Arc<Mutex<Vec<SocketAddr>>>,
    }

    impl ScriptedEstablisher {
        fn new(outcomes: Vec<Result<SocketAddr, RelaySessionEstablishError>>) -> Self {
            Self {
                outcomes: Mutex::new(outcomes.into()),
                calls: Arc::new(AtomicUsize::new(0)),
                dialed: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl RelaySessionEstablisher for ScriptedEstablisher {
        async fn establish(
            &self,
            relay_addr: SocketAddr,
        ) -> Result<SocketAddr, RelaySessionEstablishError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.dialed
                .lock()
                .expect("mutex poisoned in test")
                .push(relay_addr);
            // No `.await` below, so the std mutex guards are never held across
            // a suspension point.
            self.outcomes
                .lock()
                .expect("mutex poisoned in test")
                .pop_front()
                .expect("scripted establisher ran out of outcomes")
        }
    }

    #[tokio::test]
    async fn empty_candidate_list_returns_no_candidates_error() {
        let establisher = ScriptedEstablisher::new(Vec::new());
        let calls = Arc::clone(&establisher.calls);
        let coordinator = RelayAcquisition::new(establisher);
        let result = coordinator.acquire(Vec::new()).await;
        assert_eq!(result.unwrap_err(), RelayAcquisitionError::NoCandidates);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn first_candidate_success_returns_immediately() {
        let establisher = ScriptedEstablisher::new(vec![Ok(allocated(9000))]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let candidates = vec![candidate(1, 10000), candidate(2, 10001)];
        let expected_dials = vec![socket_of(&candidates[0])];
        let result = coordinator
            .acquire(candidates)
            .await
            .expect("should succeed");
        assert_eq!(result.relayer, peer_id(1));
        assert_eq!(result.allocated_public_addr, allocated(9000));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "should stop on first success, not walk further"
        );
        assert_eq!(dialed_so_far(&dialed), expected_dials);
    }

    #[tokio::test]
    async fn at_capacity_walks_to_next_candidate() {
        let establisher = ScriptedEstablisher::new(vec![
            Err(RelaySessionEstablishError::AtCapacity("full".to_string())),
            Ok(allocated(9001)),
        ]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let candidates = vec![candidate(1, 10000), candidate(2, 10001)];
        let expected_dials: Vec<SocketAddr> = candidates.iter().map(socket_of).collect();
        let result = coordinator
            .acquire(candidates)
            .await
            .expect("should succeed");
        assert_eq!(result.relayer, peer_id(2));
        assert_eq!(result.allocated_public_addr, allocated(9001));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert_eq!(dialed_so_far(&dialed), expected_dials);
    }

    #[tokio::test]
    async fn unreachable_walks_to_next_candidate() {
        let establisher = ScriptedEstablisher::new(vec![
            Err(RelaySessionEstablishError::Unreachable(
                "timeout".to_string(),
            )),
            Ok(allocated(9002)),
        ]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let candidates = vec![candidate(1, 10000), candidate(2, 10001)];
        let expected_dials: Vec<SocketAddr> = candidates.iter().map(socket_of).collect();
        let result = coordinator
            .acquire(candidates)
            .await
            .expect("should succeed");
        assert_eq!(result.relayer, peer_id(2));
        assert_eq!(result.allocated_public_addr, allocated(9002));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert_eq!(dialed_so_far(&dialed), expected_dials);
    }

    #[tokio::test]
    async fn all_at_capacity_returns_exhausted() {
        let establisher = ScriptedEstablisher::new(vec![
            Err(RelaySessionEstablishError::AtCapacity("full".to_string())),
            Err(RelaySessionEstablishError::AtCapacity("full".to_string())),
            Err(RelaySessionEstablishError::AtCapacity("full".to_string())),
        ]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let candidates = vec![
            candidate(1, 10000),
            candidate(2, 10001),
            candidate(3, 10002),
        ];
        let expected_dials: Vec<SocketAddr> = candidates.iter().map(socket_of).collect();
        let result = coordinator.acquire(candidates).await;
        assert_eq!(
            result.unwrap_err(),
            RelayAcquisitionError::AllCandidatesExhausted
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "should have tried every candidate exactly once"
        );
        assert_eq!(dialed_so_far(&dialed), expected_dials);
    }

    #[tokio::test]
    async fn all_unreachable_returns_exhausted() {
        let establisher = ScriptedEstablisher::new(vec![
            Err(RelaySessionEstablishError::Unreachable("a".to_string())),
            Err(RelaySessionEstablishError::Unreachable("b".to_string())),
        ]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let candidates = vec![candidate(1, 10000), candidate(2, 10001)];
        let expected_dials: Vec<SocketAddr> = candidates.iter().map(socket_of).collect();
        let result = coordinator.acquire(candidates).await;
        assert_eq!(
            result.unwrap_err(),
            RelayAcquisitionError::AllCandidatesExhausted
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "should have tried every candidate exactly once"
        );
        assert_eq!(dialed_so_far(&dialed), expected_dials);
    }

    #[tokio::test]
    async fn mixed_errors_then_success() {
        let establisher = ScriptedEstablisher::new(vec![
            Err(RelaySessionEstablishError::Unreachable("dead".to_string())),
            Err(RelaySessionEstablishError::AtCapacity("full".to_string())),
            Ok(allocated(9003)),
        ]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let candidates = vec![
            candidate(1, 10000),
            candidate(2, 10001),
            candidate(3, 10002),
        ];
        let expected_dials: Vec<SocketAddr> = candidates.iter().map(socket_of).collect();
        let result = coordinator
            .acquire(candidates)
            .await
            .expect("should succeed");
        assert_eq!(result.relayer, peer_id(3));
        assert_eq!(result.allocated_public_addr, allocated(9003));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(
            dialed_so_far(&dialed),
            expected_dials,
            "candidates must be tried in the order given"
        );
    }

    #[tokio::test]
    async fn candidate_with_non_dialable_address_is_skipped() {
        let skipped = RelayCandidate::new(peer_id(1), tcp_addr());
        let ok = candidate(2, 10001);
        let expected_dials = vec![socket_of(&ok)];
        let establisher = ScriptedEstablisher::new(vec![Ok(allocated(9004))]);
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let result = coordinator
            .acquire(vec![skipped, ok])
            .await
            .expect("should succeed");
        assert_eq!(result.relayer, peer_id(2));
        assert_eq!(result.allocated_public_addr, allocated(9004));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "TCP candidate must be skipped without invoking the establisher"
        );
        assert_eq!(
            dialed_so_far(&dialed),
            expected_dials,
            "only the dialable candidate's address should be dialed"
        );
    }

    #[tokio::test]
    async fn all_candidates_non_dialable_returns_exhausted() {
        let c1 = RelayCandidate::new(peer_id(1), tcp_addr());
        let c2 = RelayCandidate::new(peer_id(2), tcp_addr());
        let establisher = ScriptedEstablisher::new(Vec::new());
        let calls = Arc::clone(&establisher.calls);
        let dialed = Arc::clone(&establisher.dialed);
        let coordinator = RelayAcquisition::new(establisher);
        let result = coordinator.acquire(vec![c1, c2]).await;
        assert_eq!(
            result.unwrap_err(),
            RelayAcquisitionError::AllCandidatesExhausted
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "establisher must not be called when no candidate has a dialable address"
        );
        assert!(dialed_so_far(&dialed).is_empty());
    }
}