1use std::time::Instant;
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use crate::{TransportDispatcher, TransportError, TransportPacket};
8
9#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
11pub struct RelayBandwidth {
12 pub bytes_per_second: u64,
14 pub burst_bytes: u64,
16}
17
18impl RelayBandwidth {
19 pub fn unbounded() -> Self {
20 Self {
21 bytes_per_second: u64::MAX,
22 burst_bytes: u64::MAX / 2,
23 }
24 }
25}
26
/// Bookkeeping kept for a single relaying peer.
#[derive(Debug)]
struct RelayState {
    /// Bytes the peer may still relay right now (fractional because refills
    /// are computed from elapsed wall-clock time).
    allowance: f64,
    /// Moment the allowance was last topped up.
    last_refill: Instant,
    /// Running total of bytes successfully relayed for this peer.
    total_bytes: u64,
}

impl RelayState {
    /// Builds the state for a newly seen peer: a full allowance of
    /// `burst_bytes`, a refill clock starting at `now`, and nothing relayed.
    fn new(now: Instant, burst_bytes: u64) -> Self {
        let initial_allowance = burst_bytes as f64;
        RelayState {
            total_bytes: 0,
            last_refill: now,
            allowance: initial_allowance,
        }
    }
}
43
44#[derive(Debug, Clone)]
46pub struct RelayOutcome {
47 pub from: Uuid,
48 pub to: Uuid,
49 pub bytes: u64,
50}
51
52pub struct RelayController {
54 limits: RelayBandwidth,
55 peers: DashMap<Uuid, RelayState>,
56}
57
58impl RelayController {
59 pub fn new(limits: RelayBandwidth) -> Self {
60 Self {
61 limits,
62 peers: DashMap::new(),
63 }
64 }
65
66 pub fn relay<D: TransportDispatcher>(
68 &self,
69 from: Uuid,
70 to: Uuid,
71 packet: TransportPacket,
72 dispatcher: &D,
73 ) -> Result<RelayOutcome, TransportError> {
74 let size = estimate_packet_size(&packet)?;
75 let mut state = self.ensure_state(from);
76 self.consume_allowance(&mut state, size)?;
77
78 dispatcher.send_relay(to, packet)?;
79 state.total_bytes = state.total_bytes.saturating_add(size);
80
81 Ok(RelayOutcome {
82 from,
83 to,
84 bytes: size,
85 })
86 }
87
88 pub fn total_bytes(&self, peer: Uuid) -> u64 {
90 self.peers.get(&peer).map(|s| s.total_bytes).unwrap_or(0)
91 }
92
93 fn ensure_state(&self, peer: Uuid) -> dashmap::mapref::one::RefMut<'_, Uuid, RelayState> {
94 let burst = self.limits.burst_bytes;
95 self.peers
96 .entry(peer)
97 .or_insert_with(|| RelayState::new(Instant::now(), burst))
98 }
99
100 fn consume_allowance(
101 &self,
102 state: &mut dashmap::mapref::one::RefMut<'_, Uuid, RelayState>,
103 size: u64,
104 ) -> Result<(), TransportError> {
105 let now = Instant::now();
106 let elapsed = now.saturating_duration_since(state.last_refill);
107 let tokens_to_add = (elapsed.as_secs_f64() * self.limits.bytes_per_second as f64)
108 .min(self.limits.burst_bytes as f64);
109 state.allowance = (state.allowance + tokens_to_add).min(self.limits.burst_bytes as f64);
110 state.last_refill = now;
111
112 if state.allowance < size as f64 {
113 return Err(TransportError::RateLimited("relay bandwidth exceeded"));
114 }
115
116 state.allowance -= size as f64;
117 Ok(())
118 }
119}
120
121fn estimate_packet_size(packet: &TransportPacket) -> Result<u64, TransportError> {
122 let payload_len = serde_json::to_vec(&packet.payload)
123 .map_err(|_| TransportError::DispatchFailed("payload serialization failed"))?
124 .len() as u64;
125 let channel_len = packet.channel.len() as u64;
126 let cursor_len = if packet.cursor.is_some() { 8 } else { 0 };
127 Ok(payload_len + channel_len + cursor_len)
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TransportRoute;
    use serde_json::json;
    use std::sync::Mutex;

    /// Dispatcher double that accepts every send and records relay calls so
    /// tests can check what was (or was not) dispatched.
    #[derive(Default)]
    struct MockDispatcher {
        calls: Mutex<Vec<(TransportRoute, Uuid, TransportPacket)>>,
    }

    impl TransportDispatcher for MockDispatcher {
        fn send_direct(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }

        fn send_p2p(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }

        fn send_relay(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::Relay, peer, packet));
            Ok(())
        }
    }

    #[test]
    fn relays_and_accounts_bytes() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 10_000,
            burst_bytes: 10_000,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();

        let packet = TransportPacket::new("data", json!({ "hello": "world" }));
        let outcome = controller.relay(from, to, packet, &dispatcher).unwrap();

        assert_eq!(outcome.from, from);
        assert_eq!(outcome.to, to);
        assert!(outcome.bytes > 0);
        assert_eq!(controller.total_bytes(from), outcome.bytes);

        // The packet must have reached the dispatcher exactly once, on the
        // relay route, addressed to the destination peer.
        let calls = dispatcher.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        let (route, peer, _) = &calls[0];
        assert!(matches!(route, TransportRoute::Relay));
        assert_eq!(*peer, to);
    }

    #[test]
    fn rate_limits_when_budget_exceeded() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 1,
            burst_bytes: 4,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();

        // The serialized payload alone is larger than the 4-byte burst.
        let packet = TransportPacket::new("data", json!({ "blob": "12345" }));
        let err = controller
            .relay(from, to, packet, &dispatcher)
            .expect_err("should be rate limited");

        assert_eq!(err, TransportError::RateLimited("relay bandwidth exceeded"));
        assert_eq!(controller.total_bytes(from), 0);
        // A rejected packet must never be handed to the dispatcher.
        assert!(dispatcher.calls.lock().unwrap().is_empty());
    }
}