use std::time::Instant;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{TransportDispatcher, TransportError, TransportPacket};
/// Token-bucket limits applied to each peer's relayed traffic.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RelayBandwidth {
    /// Rate at which a peer's allowance refills, in bytes per second of elapsed time.
    pub bytes_per_second: u64,
    /// Maximum allowance a peer can accumulate; also the largest packet that can ever pass.
    pub burst_bytes: u64,
}

impl RelayBandwidth {
    /// Limits so large that relaying is effectively never throttled.
    ///
    /// NOTE(review): the burst is `u64::MAX / 2` rather than `u64::MAX`; presumably this is
    /// deliberate headroom, but the reason is not visible here — confirm before changing.
    pub fn unbounded() -> Self {
        let burst = u64::MAX / 2;
        let rate = u64::MAX;
        Self {
            burst_bytes: burst,
            bytes_per_second: rate,
        }
    }
}
/// Mutable token-bucket bookkeeping for one relaying peer.
#[derive(Debug)]
struct RelayState {
    /// Bytes the peer may still send; fractional because refills are proportional to time.
    allowance: f64,
    /// Moment `allowance` was last topped up.
    last_refill: Instant,
    /// Running total of bytes relayed on this peer's behalf.
    total_bytes: u64,
}

impl RelayState {
    /// Builds a bucket that starts full (`burst_bytes` of allowance) as of `now`.
    fn new(now: Instant, burst_bytes: u64) -> Self {
        let allowance = burst_bytes as f64;
        Self {
            last_refill: now,
            total_bytes: 0,
            allowance,
        }
    }
}
/// Result of a successful [`RelayController::relay`] call.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare outcomes directly; every field
/// already supports equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RelayOutcome {
    /// Peer whose relay allowance was charged.
    pub from: Uuid,
    /// Peer the packet was dispatched to.
    pub to: Uuid,
    /// Estimated packet size, in bytes, charged against `from`.
    pub bytes: u64,
}
/// Token-bucket rate limiter for traffic relayed on behalf of peers.
///
/// Each sending peer gets an independent bucket sized by [`RelayBandwidth`]. Buckets are
/// created lazily on a peer's first relay and stored in a concurrent map.
pub struct RelayController {
    /// Limits applied to every peer's bucket.
    limits: RelayBandwidth,
    /// Per-sender bucket state, keyed by the `from` peer id.
    peers: DashMap<Uuid, RelayState>,
}

impl RelayController {
    /// Creates a controller that applies `limits` to every peer.
    pub fn new(limits: RelayBandwidth) -> Self {
        Self {
            limits,
            peers: DashMap::new(),
        }
    }

    /// Relays `packet` from `from` to `to` through `dispatcher`, charging `from`'s bucket.
    ///
    /// The packet's size is estimated and deducted from `from`'s allowance before the packet
    /// is handed to `dispatcher.send_relay`. If dispatch fails, the charge is refunded so a
    /// failed send does not consume budget. `from`'s running byte total only grows on success.
    ///
    /// # Errors
    ///
    /// - `TransportError::DispatchFailed` if the payload cannot be serialized for sizing.
    /// - `TransportError::RateLimited` if `from` lacks enough allowance for the packet.
    /// - Any error returned by `dispatcher.send_relay` (after the refund).
    pub fn relay<D: TransportDispatcher>(
        &self,
        from: Uuid,
        to: Uuid,
        packet: TransportPacket,
        dispatcher: &D,
    ) -> Result<RelayOutcome, TransportError> {
        let size = estimate_packet_size(&packet)?;

        // Charge the bucket in its own scope so the DashMap shard lock is released before
        // calling into the dispatcher. Holding it across `send_relay` would stall every other
        // peer in the same shard for the whole send, and would deadlock if the dispatcher
        // re-entered this controller (e.g. via `total_bytes`).
        {
            let mut state = self.ensure_state(from);
            self.consume_allowance(&mut state, size)?;
        }

        if let Err(err) = dispatcher.send_relay(to, packet) {
            self.refund_allowance(from, size);
            return Err(err);
        }

        let mut state = self.ensure_state(from);
        state.total_bytes = state.total_bytes.saturating_add(size);
        Ok(RelayOutcome {
            from,
            to,
            bytes: size,
        })
    }

    /// Returns the total bytes successfully relayed for `peer`, or 0 if it has no bucket yet.
    pub fn total_bytes(&self, peer: Uuid) -> u64 {
        self.peers.get(&peer).map_or(0, |s| s.total_bytes)
    }

    /// Returns a write guard on `peer`'s bucket, inserting a full bucket on first use.
    ///
    /// The guard holds a DashMap shard lock, so keep it short-lived and never call out to
    /// foreign code while holding it.
    fn ensure_state(&self, peer: Uuid) -> dashmap::mapref::one::RefMut<'_, Uuid, RelayState> {
        let burst = self.limits.burst_bytes;
        self.peers
            .entry(peer)
            .or_insert_with(|| RelayState::new(Instant::now(), burst))
    }

    /// Refills `state` for the time since its last refill, then deducts `size` bytes.
    ///
    /// The allowance is capped at `burst_bytes`, so a packet larger than the burst can never
    /// pass. When the request is rejected, the refill is still recorded but nothing is deducted.
    ///
    /// # Errors
    ///
    /// `TransportError::RateLimited` if the refilled allowance is smaller than `size`.
    fn consume_allowance(&self, state: &mut RelayState, size: u64) -> Result<(), TransportError> {
        let now = Instant::now();
        let burst = self.limits.burst_bytes as f64;
        let elapsed = now.saturating_duration_since(state.last_refill);
        let refill = (elapsed.as_secs_f64() * self.limits.bytes_per_second as f64).min(burst);
        state.allowance = (state.allowance + refill).min(burst);
        state.last_refill = now;

        let cost = size as f64;
        if state.allowance < cost {
            return Err(TransportError::RateLimited("relay bandwidth exceeded"));
        }
        state.allowance -= cost;
        Ok(())
    }

    /// Returns `size` bytes to `peer`'s allowance after a failed dispatch, capped at the burst.
    fn refund_allowance(&self, peer: Uuid, size: u64) {
        let burst = self.limits.burst_bytes as f64;
        let mut state = self.ensure_state(peer);
        state.allowance = (state.allowance + size as f64).min(burst);
    }
}
/// Estimates the number of bytes `packet` costs against a relay allowance.
///
/// The estimate is the length of the payload's JSON encoding, plus the channel name's byte
/// length, plus a fixed 8 bytes when a cursor is present. NOTE(review): the 8 presumably
/// reflects a 64-bit cursor — confirm against `TransportPacket`.
///
/// The payload is streamed into a byte-counting sink rather than an intermediate `Vec`, so
/// large payloads are measured without being copied. The counted length is identical to
/// `serde_json::to_vec(..).len()`.
///
/// # Errors
///
/// `TransportError::DispatchFailed` if the payload cannot be serialized.
fn estimate_packet_size(packet: &TransportPacket) -> Result<u64, TransportError> {
    /// `io::Write` sink that discards everything written to it and only counts the bytes.
    struct ByteCounter(u64);

    impl std::io::Write for ByteCounter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0 = self.0.saturating_add(buf.len() as u64);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    let mut counter = ByteCounter(0);
    serde_json::to_writer(&mut counter, &packet.payload)
        .map_err(|_| TransportError::DispatchFailed("payload serialization failed"))?;
    let channel_len = packet.channel.len() as u64;
    let cursor_len: u64 = if packet.cursor.is_some() { 8 } else { 0 };
    Ok(counter.0.saturating_add(channel_len).saturating_add(cursor_len))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TransportRoute;
    use serde_json::json;
    use std::sync::Mutex;

    /// Dispatcher that records every relay call and always succeeds.
    #[derive(Default)]
    struct MockDispatcher {
        calls: Mutex<Vec<(TransportRoute, Uuid, TransportPacket)>>,
    }

    impl TransportDispatcher for MockDispatcher {
        fn send_direct(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }
        fn send_p2p(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }
        fn send_relay(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::Relay, peer, packet));
            Ok(())
        }
    }

    #[test]
    fn relays_and_accounts_bytes() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 10_000,
            burst_bytes: 10_000,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();
        let packet = TransportPacket::new("data", json!({ "hello": "world" }));
        let outcome = controller.relay(from, to, packet, &dispatcher).unwrap();
        assert_eq!(outcome.from, from);
        assert_eq!(outcome.to, to);
        assert!(outcome.bytes > 0);
        assert_eq!(controller.total_bytes(from), outcome.bytes);

        // The packet must actually have been handed to the dispatcher, once, for `to`.
        let calls = dispatcher.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        let (route, peer, _) = &calls[0];
        assert!(matches!(route, TransportRoute::Relay));
        assert_eq!(*peer, to);
    }

    #[test]
    fn rate_limits_when_budget_exceeded() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 1,
            burst_bytes: 4,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();
        let packet = TransportPacket::new("data", json!({ "blob": "12345" }));
        let err = controller
            .relay(from, to, packet, &dispatcher)
            .expect_err("should be rate limited");
        assert_eq!(err, TransportError::RateLimited("relay bandwidth exceeded"));
        assert_eq!(controller.total_bytes(from), 0);
        // A rejected relay must never reach the dispatcher.
        assert!(dispatcher.calls.lock().unwrap().is_empty());
    }
}