use std::fmt;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::UnboundedSender;
use crate::{Error, Result};
/// Divisor turning a peer's age (whole seconds since `connected_at`) into the
/// age bonus added by `PeerInfo::health_score`.
const AGE_BONUS_DIVISOR: f64 = 300.0;
/// Upper bound on the age bonus added to a peer's health score.
const MAX_AGE_BONUS: f64 = 0.2;
/// Health-score penalty applied per consecutive send failure.
const CONSECUTIVE_FAILURE_PENALTY: f64 = 0.1;
/// Upper bound on the total consecutive-failure penalty in `PeerInfo::health_score`.
const MAX_CONSECUTIVE_PENALTY: f64 = 0.5;
/// Number of consecutive send failures at which `PeerInfo::should_disconnect` returns `true`.
const MAX_CONSECUTIVE_FAILURES: u64 = 5;
/// Identifier of a remote peer: a newtype over its socket address.
///
/// Cheap to copy, hashable, and (de)serializable with serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId(pub SocketAddr);
impl fmt::Display for PeerId {
    /// Formats the peer as its socket address (e.g. `127.0.0.1:8000`).
    ///
    /// Delegates straight to the address's own `Display` impl with the
    /// caller's formatter. That way width, fill and alignment flags such as
    /// `{:>20}` are honoured. Re-formatting through `write!(f, "{}", ..)`
    /// would start a fresh format invocation and silently drop them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl From<SocketAddr> for PeerId {
fn from(addr: SocketAddr) -> Self {
Self(addr)
}
}
/// Lifecycle state of a peer as tracked by [`PeerInfo`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeerState {
/// Initial state assigned by `PeerInfo::new`.
Connecting,
/// Set by `PeerInfo::mark_connected`; also restored from `Stale` by `PeerInfo::update_last_seen`.
Connected,
/// Set by `PeerInfo::mark_stale`.
Stale,
/// Set by `PeerInfo::mark_disconnected`.
Disconnected,
}
/// Per-peer bookkeeping: lifecycle state, liveness timestamps and message counters.
#[derive(Debug, Clone)]
pub struct PeerInfo {
/// Identifier of the peer this record describes.
pub id: PeerId,
/// Current lifecycle state; starts as `PeerState::Connecting` in `PeerInfo::new`.
pub state: PeerState,
/// Last time inbound activity was recorded (`update_last_seen` / `increment_received`).
pub last_seen: Instant,
/// Set once when the record is created in `PeerInfo::new`; feeds the health-score age bonus.
/// NOTE(review): not refreshed by `mark_connected`, so it measures time since creation — confirm intended.
pub connected_at: Instant,
/// Inbound messages counted by `increment_received` (saturating).
pub messages_received: u64,
/// Successful outbound sends counted by `increment_sent` (saturating).
pub messages_sent: u64,
/// Total failed sends counted by `record_failure` (saturating).
pub message_failures: u64,
/// Failed sends since the last successful one; reset to zero by `increment_sent`.
pub consecutive_failures: u64,
}
impl PeerInfo {
    /// Creates tracking state for a peer.
    ///
    /// The peer starts in [`PeerState::Connecting`]. `last_seen` and
    /// `connected_at` are both set to the current instant, and every counter
    /// starts at zero.
    pub fn new(id: PeerId) -> Self {
        let now = Instant::now();
        Self {
            id,
            state: PeerState::Connecting,
            last_seen: now,
            connected_at: now,
            messages_received: 0,
            messages_sent: 0,
            message_failures: 0,
            consecutive_failures: 0,
        }
    }

    /// Moves the peer to [`PeerState::Connected`] unconditionally.
    pub fn mark_connected(&mut self) {
        self.state = PeerState::Connected;
    }

    /// Moves the peer to [`PeerState::Stale`] unconditionally.
    pub fn mark_stale(&mut self) {
        self.state = PeerState::Stale;
    }

    /// Moves the peer to [`PeerState::Disconnected`] unconditionally.
    pub fn mark_disconnected(&mut self) {
        self.state = PeerState::Disconnected;
    }

    /// Records that the peer was just heard from.
    ///
    /// Refreshes `last_seen`. A peer in [`PeerState::Stale`] is restored to
    /// [`PeerState::Connected`]; every other state is left as is.
    pub fn update_last_seen(&mut self) {
        self.last_seen = Instant::now();
        if self.state == PeerState::Stale {
            self.state = PeerState::Connected;
        }
    }

    /// Returns `true` if strictly more than `timeout` has elapsed since `last_seen`.
    pub fn is_stale(&self, timeout: Duration) -> bool {
        self.last_seen.elapsed() > timeout
    }

    /// Counts one inbound message and refreshes liveness via [`Self::update_last_seen`].
    ///
    /// The counter saturates at `u64::MAX` instead of overflowing.
    pub fn increment_received(&mut self) {
        self.messages_received = self.messages_received.saturating_add(1);
        self.update_last_seen();
    }

    /// Counts one successful outbound send and clears the consecutive-failure streak.
    ///
    /// Unlike [`Self::increment_received`], this does not touch `last_seen`.
    pub fn increment_sent(&mut self) {
        self.messages_sent = self.messages_sent.saturating_add(1);
        self.consecutive_failures = 0;
    }

    /// Counts one failed outbound send, extending the consecutive-failure streak.
    pub fn record_failure(&mut self) {
        self.message_failures = self.message_failures.saturating_add(1);
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
    }

    /// Returns a health score in `[0.0, 1.0]`.
    ///
    /// A peer with no send attempts yet scores `1.0`. Otherwise the score is
    /// built from three terms and then clamped to `[0.0, 1.0]`:
    /// - the send success rate, `sent / (sent + failures)`;
    /// - plus an age bonus of `age_secs / AGE_BONUS_DIVISOR`, capped at
    ///   `MAX_AGE_BONUS`, where `age_secs` counts from `connected_at`;
    /// - minus `CONSECUTIVE_FAILURE_PENALTY` per consecutive failure, capped
    ///   at `MAX_CONSECUTIVE_PENALTY`.
    pub fn health_score(&self) -> f64 {
        // Both counters saturate at u64::MAX, so a plain `+` here could
        // overflow (panic in debug builds, wrap to a tiny total in release).
        // Saturating keeps the ratio in [0, 1].
        let total_attempts = self.messages_sent.saturating_add(self.message_failures);
        if total_attempts == 0 {
            return 1.0;
        }
        let success_rate = self.messages_sent as f64 / total_attempts as f64;
        let age_seconds = self.connected_at.elapsed().as_secs() as f64;
        let age_bonus = (age_seconds / AGE_BONUS_DIVISOR).min(MAX_AGE_BONUS);
        let consecutive_penalty = (self.consecutive_failures as f64
            * CONSECUTIVE_FAILURE_PENALTY)
            .min(MAX_CONSECUTIVE_PENALTY);
        (success_rate + age_bonus - consecutive_penalty).clamp(0.0, 1.0)
    }

    /// Returns `true` once the consecutive-failure streak reaches `MAX_CONSECUTIVE_FAILURES`.
    pub fn should_disconnect(&self) -> bool {
        self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES
    }
}
/// Handle to a peer: its bookkeeping plus the channel used to queue outbound data.
#[derive(Debug)]
pub struct Peer {
/// State and counters for this peer; `Peer::send` updates the send/failure counters.
pub info: PeerInfo,
/// Outbound queue; payloads passed to `Peer::send` are pushed here.
/// NOTE(review): presumably drained by a per-connection writer task — confirm against the receiver side.
sender: UnboundedSender<Bytes>,
}
impl Peer {
pub fn new(id: PeerId, sender: UnboundedSender<Bytes>) -> Self {
Self {
info: PeerInfo::new(id),
sender,
}
}
pub fn send(&mut self, data: Bytes) -> Result<()> {
match self.sender.send(data) {
Ok(()) => {
self.info.increment_sent();
Ok(())
}
Err(err) => {
self.info.record_failure();
Err(Error::Channel(err.to_string()))
}
}
}
pub fn id(&self) -> PeerId {
self.info.id
}
pub fn state(&self) -> PeerState {
self.info.state
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: every test uses the same loopback peer address.
    fn test_peer_id() -> PeerId {
        PeerId("127.0.0.1:8000".parse().unwrap())
    }

    #[test]
    fn peer_info_saturating_counters() {
        let mut info = PeerInfo::new(test_peer_id());
        info.messages_received = u64::MAX;
        info.increment_received();
        assert_eq!(info.messages_received, u64::MAX);
        info.messages_sent = u64::MAX;
        info.increment_sent();
        assert_eq!(info.messages_sent, u64::MAX);
    }

    #[test]
    fn peer_info_starts_connecting_and_healthy() {
        let info = PeerInfo::new(test_peer_id());
        assert_eq!(info.state, PeerState::Connecting);
        assert_eq!(info.health_score(), 1.0);
        assert!(!info.should_disconnect());
    }

    #[test]
    fn peer_info_stale_peer_revives_on_traffic() {
        let mut info = PeerInfo::new(test_peer_id());
        info.mark_connected();
        info.mark_stale();
        info.increment_received();
        assert_eq!(info.state, PeerState::Connected);
        assert!(!info.is_stale(Duration::from_secs(60)));
        // Only Stale is revived by traffic; an explicit disconnect sticks.
        info.mark_disconnected();
        info.update_last_seen();
        assert_eq!(info.state, PeerState::Disconnected);
    }

    #[test]
    fn peer_info_consecutive_failures_trigger_disconnect() {
        let mut info = PeerInfo::new(test_peer_id());
        for _ in 0..MAX_CONSECUTIVE_FAILURES {
            info.record_failure();
        }
        assert!(info.should_disconnect());
        assert_eq!(info.health_score(), 0.0);
        // A single success clears the streak but keeps the lifetime total.
        info.increment_sent();
        assert_eq!(info.consecutive_failures, 0);
        assert_eq!(info.message_failures, MAX_CONSECUTIVE_FAILURES);
        assert!(!info.should_disconnect());
    }

    #[test]
    fn peer_send_success() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut peer = Peer::new(test_peer_id(), tx);
        let data = Bytes::from("test data");
        assert_eq!(peer.info.messages_sent, 0);
        assert!(peer.send(data.clone()).is_ok());
        assert_eq!(peer.info.messages_sent, 1);
        assert_eq!(rx.try_recv().unwrap(), data);
    }

    #[test]
    fn peer_send_failure_channel_closed() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        drop(rx);
        let mut peer = Peer::new(test_peer_id(), tx);
        let result = peer.send(Bytes::from("test data"));
        assert!(matches!(result, Err(Error::Channel(_))));
        // The failure must be reflected in the peer's bookkeeping.
        assert_eq!(peer.info.messages_sent, 0);
        assert_eq!(peer.info.message_failures, 1);
        assert_eq!(peer.info.consecutive_failures, 1);
    }

    #[test]
    fn peer_multiple_sends() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut peer = Peer::new(test_peer_id(), tx);
        for i in 0..5u64 {
            let data = Bytes::from(format!("message {i}"));
            assert!(peer.send(data.clone()).is_ok());
            assert_eq!(peer.info.messages_sent, i + 1);
            assert_eq!(rx.try_recv().unwrap(), data);
        }
    }

    #[test]
    fn peer_accessors_reflect_info() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let peer = Peer::new(test_peer_id(), tx);
        assert_eq!(peer.id(), test_peer_id());
        assert_eq!(peer.state(), PeerState::Connecting);
    }

    #[test]
    fn peer_id_display_matches_socket_addr() {
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
        assert_eq!(PeerId::from(addr).to_string(), "127.0.0.1:8000");
    }

    #[test]
    fn peer_id_serialization() {
        let peer_id = test_peer_id();
        let serialized = serde_json::to_string(&peer_id).unwrap();
        let deserialized: PeerId = serde_json::from_str(&serialized).unwrap();
        assert_eq!(peer_id, deserialized);
    }
}