use crate::nat_traversal_api::PeerId;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::net::SocketAddr;
use std::time::{Duration, SystemTime};
/// A cached record about a remote peer.
///
/// Holds the peer's identity, known addresses, advertised capabilities,
/// timestamps, connection statistics and a derived quality score in `[0, 1]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPeer {
/// Peer identity; (de)serialized as a hex string via `peer_id_serde`.
#[serde(with = "peer_id_serde")]
pub peer_id: PeerId,
/// Known socket addresses for the peer. `merge_addresses` caps this list at 10 entries.
pub addresses: Vec<SocketAddr>,
/// Capabilities last reported for the peer. `record_success` replaces them when new ones are supplied.
pub capabilities: PeerCapabilities,
/// When this entry was created.
pub first_seen: SystemTime,
/// Time of the most recent successful contact. Initialized to the creation time.
pub last_seen: SystemTime,
/// Time of the most recent connection attempt, successful or not. `None` until one is recorded.
pub last_attempt: Option<SystemTime>,
/// Aggregated success/failure counts and RTT statistics.
pub stats: ConnectionStats,
/// Derived score in `[0, 1]`, recomputed by `calculate_quality`.
/// Falls back to 0.5 when missing from serialized data.
#[serde(default = "default_quality_score")]
pub quality_score: f64,
/// Where this entry originated (see [`PeerSource`]).
pub source: PeerSource,
/// Relay path hints associated with this peer. Empty when missing from serialized data.
#[serde(default)]
pub relay_paths: Vec<RelayPathHint>,
/// Optional opaque token bytes. `None` when missing from serialized data.
/// NOTE(review): the token's purpose is not visible in this file; confirm with its producer.
#[serde(default)]
pub token: Option<Vec<u8>>,
}
/// Neutral quality score used by serde when a serialized `CachedPeer`
/// lacks a `quality_score` field.
fn default_quality_score() -> f64 {
    const NEUTRAL_QUALITY: f64 = 0.5;
    NEUTRAL_QUALITY
}
/// Capabilities a peer has advertised or been observed to have.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PeerCapabilities {
/// Whether the peer supports relaying. Adds a bonus in `CachedPeer::calculate_quality`.
pub supports_relay: bool,
/// Whether the peer supports coordination. Adds a bonus in `CachedPeer::calculate_quality`.
pub supports_coordination: bool,
/// Protocol identifiers supported by the peer. Empty when missing from serialized data.
#[serde(default)]
pub protocols: HashSet<String>,
/// NAT type of the peer, if known.
pub nat_type: Option<NatType>,
/// External addresses of the peer. These drive the IPv4/IPv6 and dual-stack queries.
/// Empty when missing from serialized data.
#[serde(default)]
pub external_addresses: Vec<SocketAddr>,
}
impl PeerCapabilities {
    /// Returns `true` when any external address is IPv4.
    pub fn has_ipv4(&self) -> bool {
        self.external_addresses.iter().any(SocketAddr::is_ipv4)
    }

    /// Returns `true` when any external address is IPv6.
    pub fn has_ipv6(&self) -> bool {
        self.external_addresses.iter().any(SocketAddr::is_ipv6)
    }

    /// Returns `true` when the external addresses include at least one IPv4
    /// and at least one IPv6 address.
    pub fn supports_dual_stack(&self) -> bool {
        if !self.has_ipv4() {
            return false;
        }
        self.has_ipv6()
    }

    /// Returns the external addresses of a single IP family, in their
    /// original order: IPv4 when `ipv4` is `true`, IPv6 otherwise.
    pub fn addresses_by_version(&self, ipv4: bool) -> Vec<SocketAddr> {
        let mut matching = Vec::new();
        for addr in &self.external_addresses {
            if addr.is_ipv4() == ipv4 {
                matching.push(*addr);
            }
        }
        matching
    }

    /// Returns `true` when `source` and `target` share an IP family.
    /// For a cross-family pair, returns `true` only if this peer is dual-stack.
    pub fn can_bridge(&self, source: &SocketAddr, target: &SocketAddr) -> bool {
        let same_family = source.is_ipv4() == target.is_ipv4();
        same_family || self.supports_dual_stack()
    }
}
/// NAT classification of a peer.
///
/// The variant names follow the classic cone/symmetric NAT taxonomy (RFC 3489 style).
/// `CachedPeer::calculate_quality` gives a capability bonus to `None` and `FullCone`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NatType {
/// No NAT in the path.
None,
FullCone,
AddressRestrictedCone,
PortRestrictedCone,
Symmetric,
/// NAT type could not be determined.
Unknown,
}
/// Aggregated connection statistics for a cached peer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConnectionStats {
/// Successful connections. Incremented with saturation by `CachedPeer::record_success`.
pub success_count: u32,
/// Failed connections. Incremented with saturation by `CachedPeer::record_failure`.
pub failure_count: u32,
/// Exponentially weighted moving average RTT in milliseconds (weight 1/8 on each new sample).
/// A value of 0 means "no sample yet".
pub avg_rtt_ms: u32,
/// Smallest RTT observed, in milliseconds.
pub min_rtt_ms: u32,
/// Largest RTT observed, in milliseconds.
pub max_rtt_ms: u32,
/// Bytes relayed. Not updated anywhere in this file.
pub bytes_relayed: u64,
/// Completed coordinations. Not updated anywhere in this file.
pub coordinations_completed: u32,
}
/// Origin of a cached peer entry. Defaults to `Unknown`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PeerSource {
Seed,
Connection,
Relay,
Coordination,
Merge,
#[default]
Unknown,
}
/// Result of a single connection attempt.
///
/// NOTE(review): this type is not constructed or consumed in this file.
/// Presumably callers map it onto `CachedPeer::record_success` / `record_failure`; confirm.
#[derive(Debug, Clone)]
pub struct ConnectionOutcome {
/// Whether the attempt succeeded.
pub success: bool,
/// Measured round-trip time in milliseconds, if available.
pub rtt_ms: Option<u32>,
/// Capabilities learned during the attempt, if any.
pub capabilities_discovered: Option<PeerCapabilities>,
}
/// Hint describing a relay associated with a cached peer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelayPathHint {
/// Relay identity; (de)serialized as a hex string via `peer_id_serde`.
#[serde(with = "peer_id_serde")]
pub relay_endpoint_id: PeerId,
/// Socket addresses for the relay.
pub relay_locators: Vec<SocketAddr>,
/// Latency observed via this relay in milliseconds, if measured.
pub observed_latency_ms: Option<u32>,
/// When this relay hint was last used.
pub last_used: SystemTime,
}
impl CachedPeer {
pub fn new(peer_id: PeerId, addresses: Vec<SocketAddr>, source: PeerSource) -> Self {
let now = SystemTime::now();
Self {
peer_id,
addresses,
capabilities: PeerCapabilities::default(),
first_seen: now,
last_seen: now,
last_attempt: None,
stats: ConnectionStats::default(),
quality_score: 0.5, source,
relay_paths: Vec::new(),
token: None,
}
}
pub fn record_success(&mut self, rtt_ms: u32, caps: Option<PeerCapabilities>) {
self.last_seen = SystemTime::now();
self.last_attempt = Some(SystemTime::now());
self.stats.success_count = self.stats.success_count.saturating_add(1);
if self.stats.avg_rtt_ms == 0 {
self.stats.avg_rtt_ms = rtt_ms;
self.stats.min_rtt_ms = rtt_ms;
self.stats.max_rtt_ms = rtt_ms;
} else {
self.stats.avg_rtt_ms = (self.stats.avg_rtt_ms * 7 + rtt_ms) / 8;
self.stats.min_rtt_ms = self.stats.min_rtt_ms.min(rtt_ms);
self.stats.max_rtt_ms = self.stats.max_rtt_ms.max(rtt_ms);
}
if let Some(caps) = caps {
self.capabilities = caps;
}
}
pub fn record_failure(&mut self) {
self.last_attempt = Some(SystemTime::now());
self.stats.failure_count = self.stats.failure_count.saturating_add(1);
}
pub fn calculate_quality(&mut self, weights: &super::config::QualityWeights) {
let total_attempts = self.stats.success_count + self.stats.failure_count;
let success_rate = if total_attempts > 0 {
self.stats.success_count as f64 / total_attempts as f64
} else {
0.5 };
let rtt_score = if self.stats.avg_rtt_ms > 0 {
1.0 - (self.stats.avg_rtt_ms as f64 / 1000.0).min(1.0)
} else {
0.5 };
let age_secs = self
.last_seen
.duration_since(SystemTime::UNIX_EPOCH)
.ok()
.and_then(|last_seen_epoch| {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.ok()
.map(|now_epoch| {
now_epoch
.as_secs()
.saturating_sub(last_seen_epoch.as_secs())
})
})
.unwrap_or(0) as f64;
let freshness = (-age_secs * 0.693 / 86400.0).exp();
let mut cap_bonus: f64 = 0.0;
if self.capabilities.supports_relay {
cap_bonus += 0.25;
}
if self.capabilities.supports_coordination {
cap_bonus += 0.25;
}
if self.capabilities.supports_dual_stack() {
cap_bonus += 0.2; }
if matches!(
self.capabilities.nat_type,
Some(NatType::None) | Some(NatType::FullCone)
) {
cap_bonus += 0.3; }
let cap_score = cap_bonus.min(1.0);
self.quality_score = (success_rate * weights.success_rate
+ rtt_score * weights.rtt
+ freshness * weights.freshness
+ cap_score * weights.capabilities)
.clamp(0.0, 1.0);
}
pub fn is_stale(&self, threshold: Duration) -> bool {
self.last_seen
.elapsed()
.map(|age| age > threshold)
.unwrap_or(true)
}
pub fn success_rate(&self) -> f64 {
let total = self.stats.success_count + self.stats.failure_count;
if total == 0 {
0.5
} else {
self.stats.success_count as f64 / total as f64
}
}
pub fn merge_addresses(&mut self, other: &CachedPeer) {
for addr in &other.addresses {
if !self.addresses.contains(addr) {
self.addresses.push(*addr);
}
}
if self.addresses.len() > 10 {
self.addresses.truncate(10);
}
}
}
mod peer_id_serde {
use super::PeerId;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(peer_id: &PeerId, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
hex::encode(peer_id.0).serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<PeerId, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let bytes = hex::decode(&s).map_err(serde::de::Error::custom)?;
if bytes.len() != 32 {
return Err(serde::de::Error::custom("PeerId must be 32 bytes"));
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(PeerId(arr))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address literal; panics on a malformed fixture.
    fn sock(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// A seed-sourced peer whose id is 32 copies of `fill`, at 127.0.0.1:9000.
    fn seed_peer(fill: u8) -> CachedPeer {
        CachedPeer::new(
            PeerId([fill; 32]),
            vec![sock("127.0.0.1:9000")],
            PeerSource::Seed,
        )
    }

    /// Capabilities whose external addresses are exactly `addrs`, in order.
    fn caps_with(addrs: &[&str]) -> PeerCapabilities {
        let mut caps = PeerCapabilities::default();
        caps.external_addresses
            .extend(addrs.iter().map(|text| sock(text)));
        caps
    }

    #[test]
    fn test_cached_peer_new() {
        let peer = seed_peer(1);
        assert_eq!(peer.peer_id, PeerId([1u8; 32]));
        assert_eq!(peer.addresses.len(), 1);
        assert_eq!(peer.source, PeerSource::Seed);
        assert!((peer.quality_score - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_record_success() {
        let mut peer = seed_peer(1);

        peer.record_success(100, None);
        assert_eq!(peer.stats.success_count, 1);
        assert_eq!(
            (peer.stats.avg_rtt_ms, peer.stats.min_rtt_ms, peer.stats.max_rtt_ms),
            (100, 100, 100)
        );

        peer.record_success(200, None);
        assert_eq!(peer.stats.success_count, 2);
        // EWMA with integer division: (100 * 7 + 200) / 8 = 112.
        assert_eq!(
            (peer.stats.avg_rtt_ms, peer.stats.min_rtt_ms, peer.stats.max_rtt_ms),
            (112, 100, 200)
        );
    }

    #[test]
    fn test_record_failure() {
        let mut peer = seed_peer(1);
        peer.record_failure();
        assert_eq!(peer.stats.failure_count, 1);
        assert!(peer.last_attempt.is_some());
    }

    #[test]
    fn test_success_rate() {
        let rate_is = |peer: &CachedPeer, expected: f64| {
            (peer.success_rate() - expected).abs() < f64::EPSILON
        };
        let mut peer = seed_peer(1);
        assert!(rate_is(&peer, 0.5));
        peer.record_success(100, None);
        assert!(rate_is(&peer, 1.0));
        peer.record_failure();
        assert!(rate_is(&peer, 0.5));
    }

    #[test]
    fn test_quality_calculation() {
        let weights = super::super::config::QualityWeights::default();
        let mut peer = seed_peer(1);

        peer.calculate_quality(&weights);
        assert!(peer.quality_score > 0.3 && peer.quality_score < 0.7);

        (0..5).for_each(|_| peer.record_success(50, None));
        peer.calculate_quality(&weights);
        assert!(peer.quality_score > 0.6);
    }

    #[test]
    fn test_peer_serialization() {
        let original = seed_peer(0xab);
        let json = serde_json::to_string(&original).unwrap();
        let restored: CachedPeer = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.peer_id, original.peer_id);
        assert_eq!(restored.addresses, original.addresses);
        assert_eq!(restored.source, original.source);
    }

    #[test]
    fn test_peer_capabilities_dual_stack() {
        let mut caps = PeerCapabilities::default();
        assert!(!caps.supports_dual_stack());
        assert!(!caps.has_ipv4());
        assert!(!caps.has_ipv6());

        caps.external_addresses.push(sock("127.0.0.1:9000"));
        assert!(!caps.supports_dual_stack());
        assert!(caps.has_ipv4());
        assert!(!caps.has_ipv6());

        caps.external_addresses.push(sock("[::1]:9001"));
        assert!(caps.supports_dual_stack());
        assert!(caps.has_ipv4());
        assert!(caps.has_ipv6());
    }

    #[test]
    fn test_peer_capabilities_ipv6_only() {
        let caps = caps_with(&["[::1]:9000", "[::1]:9001"]);
        assert!(!caps.supports_dual_stack());
        assert!(!caps.has_ipv4());
        assert!(caps.has_ipv6());
    }

    #[test]
    fn test_peer_capabilities_can_bridge() {
        let caps = caps_with(&["127.0.0.1:9000", "[::1]:9001"]);
        let v4_src = sock("192.168.1.1:1000");
        let v4_dst = sock("192.168.1.2:2000");
        let v6_src = sock("[2001:db8::1]:1000");
        let v6_dst = sock("[2001:db8::2]:2000");
        let pairs = [
            (v4_src, v4_dst),
            (v6_src, v6_dst),
            (v4_src, v6_dst),
            (v6_src, v4_dst),
        ];
        for (src, dst) in &pairs {
            assert!(caps.can_bridge(src, dst));
        }
    }

    #[test]
    fn test_peer_capabilities_cannot_bridge_ipv4_only() {
        let caps = caps_with(&["127.0.0.1:9000"]);
        let v4_addr = sock("192.168.1.1:1000");
        let v6_addr = sock("[2001:db8::1]:1000");
        assert!(caps.can_bridge(&v4_addr, &v4_addr));
        assert!(!caps.can_bridge(&v4_addr, &v6_addr));
        assert!(!caps.can_bridge(&v6_addr, &v4_addr));
    }

    #[test]
    fn test_addresses_by_version() {
        let caps = caps_with(&["127.0.0.1:9000", "10.0.0.1:9001", "[::1]:9002"]);
        assert_eq!(caps.addresses_by_version(true).len(), 2);
        assert_eq!(caps.addresses_by_version(false).len(), 1);
    }
}