Skip to main content

pushwire_server/
relay.rs

1use std::time::Instant;
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use crate::{TransportDispatcher, TransportError, TransportPacket};
8
9/// Relay bandwidth configuration.
10#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
11pub struct RelayBandwidth {
12    /// Allowed bytes per second per sender.
13    pub bytes_per_second: u64,
14    /// Maximum burst allowed immediately.
15    pub burst_bytes: u64,
16}
17
18impl RelayBandwidth {
19    pub fn unbounded() -> Self {
20        Self {
21            bytes_per_second: u64::MAX,
22            burst_bytes: u64::MAX / 2,
23        }
24    }
25}
26
/// Per-sender token-bucket bookkeeping kept by `RelayController`.
#[derive(Debug)]
struct RelayState {
    /// Bytes the sender may still relay right now; topped up on each relay attempt.
    allowance: f64,
    /// Instant at which `allowance` was last topped up.
    last_refill: Instant,
    /// Cumulative bytes relayed on behalf of this sender.
    total_bytes: u64,
}

impl RelayState {
    /// Fresh state as of `now`: a full `burst_bytes` allowance and nothing relayed yet.
    fn new(now: Instant, burst_bytes: u64) -> Self {
        let allowance = burst_bytes as f64;
        Self {
            total_bytes: 0,
            last_refill: now,
            allowance,
        }
    }
}
43
44/// Result of a relay attempt.
45#[derive(Debug, Clone)]
46pub struct RelayOutcome {
47    pub from: Uuid,
48    pub to: Uuid,
49    pub bytes: u64,
50}
51
52/// Performs server-side relay with bandwidth accounting and rate limiting.
53pub struct RelayController {
54    limits: RelayBandwidth,
55    peers: DashMap<Uuid, RelayState>,
56}
57
58impl RelayController {
59    pub fn new(limits: RelayBandwidth) -> Self {
60        Self {
61            limits,
62            peers: DashMap::new(),
63        }
64    }
65
66    /// Relay a packet from one peer to another with rate limiting applied to the sender.
67    pub fn relay<D: TransportDispatcher>(
68        &self,
69        from: Uuid,
70        to: Uuid,
71        packet: TransportPacket,
72        dispatcher: &D,
73    ) -> Result<RelayOutcome, TransportError> {
74        let size = estimate_packet_size(&packet)?;
75        let mut state = self.ensure_state(from);
76        self.consume_allowance(&mut state, size)?;
77
78        dispatcher.send_relay(to, packet)?;
79        state.total_bytes = state.total_bytes.saturating_add(size);
80
81        Ok(RelayOutcome {
82            from,
83            to,
84            bytes: size,
85        })
86    }
87
88    /// Returns the total number of bytes relayed for a peer (sender-scoped).
89    pub fn total_bytes(&self, peer: Uuid) -> u64 {
90        self.peers.get(&peer).map(|s| s.total_bytes).unwrap_or(0)
91    }
92
93    fn ensure_state(&self, peer: Uuid) -> dashmap::mapref::one::RefMut<'_, Uuid, RelayState> {
94        let burst = self.limits.burst_bytes;
95        self.peers
96            .entry(peer)
97            .or_insert_with(|| RelayState::new(Instant::now(), burst))
98    }
99
100    fn consume_allowance(
101        &self,
102        state: &mut dashmap::mapref::one::RefMut<'_, Uuid, RelayState>,
103        size: u64,
104    ) -> Result<(), TransportError> {
105        let now = Instant::now();
106        let elapsed = now.saturating_duration_since(state.last_refill);
107        let tokens_to_add = (elapsed.as_secs_f64() * self.limits.bytes_per_second as f64)
108            .min(self.limits.burst_bytes as f64);
109        state.allowance = (state.allowance + tokens_to_add).min(self.limits.burst_bytes as f64);
110        state.last_refill = now;
111
112        if state.allowance < size as f64 {
113            return Err(TransportError::RateLimited("relay bandwidth exceeded"));
114        }
115
116        state.allowance -= size as f64;
117        Ok(())
118    }
119}
120
121fn estimate_packet_size(packet: &TransportPacket) -> Result<u64, TransportError> {
122    let payload_len = serde_json::to_vec(&packet.payload)
123        .map_err(|_| TransportError::DispatchFailed("payload serialization failed"))?
124        .len() as u64;
125    let channel_len = packet.channel.len() as u64;
126    let cursor_len = if packet.cursor.is_some() { 8 } else { 0 };
127    Ok(payload_len + channel_len + cursor_len)
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TransportRoute;
    use serde_json::json;
    use std::sync::Mutex;

    /// Dispatcher double that records every relay send and always succeeds.
    #[derive(Default)]
    struct MockDispatcher {
        calls: Mutex<Vec<(TransportRoute, Uuid, TransportPacket)>>,
    }

    impl MockDispatcher {
        /// Destination peers of the recorded sends, in call order; asserts each went via relay.
        fn relayed_to(&self) -> Vec<Uuid> {
            self.calls
                .lock()
                .unwrap()
                .iter()
                .map(|(route, peer, _)| {
                    assert!(matches!(route, TransportRoute::Relay));
                    *peer
                })
                .collect()
        }
    }

    impl TransportDispatcher for MockDispatcher {
        fn send_direct(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }

        fn send_p2p(&self, _peer: Uuid, _packet: TransportPacket) -> Result<(), TransportError> {
            Ok(())
        }

        fn send_relay(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::Relay, peer, packet));
            Ok(())
        }
    }

    #[test]
    fn relays_and_accounts_bytes() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 10_000,
            burst_bytes: 10_000,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();

        let packet = TransportPacket::new("data", json!({ "hello": "world" }));
        let outcome = controller.relay(from, to, packet, &dispatcher).unwrap();

        assert_eq!(outcome.from, from);
        assert_eq!(outcome.to, to);
        assert!(outcome.bytes > 0);
        assert_eq!(controller.total_bytes(from), outcome.bytes);
        assert_eq!(dispatcher.relayed_to(), vec![to]);
    }

    #[test]
    fn accumulates_bytes_per_sender() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 10_000,
            burst_bytes: 10_000,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let first_to = Uuid::new_v4();
        let second_to = Uuid::new_v4();

        let first = controller
            .relay(from, first_to, TransportPacket::new("data", json!({ "n": 1 })), &dispatcher)
            .unwrap();
        let second = controller
            .relay(from, second_to, TransportPacket::new("data", json!({ "n": 2 })), &dispatcher)
            .unwrap();

        assert_eq!(controller.total_bytes(from), first.bytes + second.bytes);
        assert_eq!(controller.total_bytes(first_to), 0, "totals are sender-scoped");
        assert_eq!(dispatcher.relayed_to(), vec![first_to, second_to]);
    }

    #[test]
    fn rate_limits_when_budget_exceeded() {
        let controller = RelayController::new(RelayBandwidth {
            bytes_per_second: 1,
            burst_bytes: 4,
        });
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();
        let packet = TransportPacket::new("data", json!({ "blob": "12345" })); // >4 bytes

        let err = controller
            .relay(from, to, packet, &dispatcher)
            .expect_err("should be rate limited");
        assert_eq!(err, TransportError::RateLimited("relay bandwidth exceeded"));
        assert_eq!(controller.total_bytes(from), 0);
        assert!(dispatcher.relayed_to().is_empty(), "rejected packet must not be sent");
    }

    #[test]
    fn unbounded_limits_accept_large_packets() {
        let controller = RelayController::new(RelayBandwidth::unbounded());
        let dispatcher = MockDispatcher::default();
        let from = Uuid::new_v4();
        let to = Uuid::new_v4();
        let packet = TransportPacket::new("data", json!({ "blob": "x".repeat(100_000) }));

        let outcome = controller.relay(from, to, packet, &dispatcher).unwrap();
        assert!(outcome.bytes > 100_000);
        assert_eq!(controller.total_bytes(from), outcome.bytes);
        assert_eq!(dispatcher.relayed_to(), vec![to]);
    }
}