Skip to main content

pushwire_server/
transport.rs

1use std::time::{Duration, Instant};
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use thiserror::Error;
7use uuid::Uuid;
8
/// Active transport path for a peer.
///
/// Serialized with serde using snake_case variant names
/// (`"direct"`, `"p2p"`, `"relay"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransportMode {
    /// All traffic flows through the server-managed RPS socket.
    ///
    /// This is the mode reported for peers that have no recorded state yet.
    Direct,
    /// Peer-to-peer data channel is available.
    P2p,
    /// RPS relay is used as fallback for peer traffic.
    Relay,
}
20
/// Why a mode change was initiated.
///
/// Serialized with serde using snake_case variant names
/// (`"requested"`, `"path_degraded"`, `"path_recovered"`, `"operator_override"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModeSwitchReason {
    /// Supplied by callers of `TransportManager::switch_mode`; presumably an
    /// explicit request from a client or peer (never produced automatically here).
    Requested,
    /// Attached to the signal emitted when repeated P2P failures cause a
    /// downgrade to relay.
    PathDegraded,
    /// Attached to the signal emitted when a peer is moved back from relay to
    /// P2P after its backoff has elapsed.
    PathRecovered,
    /// Supplied by callers of `TransportManager::switch_mode`; presumably a
    /// manual operator decision (never produced automatically here).
    OperatorOverride,
}
30
/// Per-peer bookkeeping kept by `TransportManager`.
#[derive(Debug, Clone)]
struct ModeState {
    /// Transport mode currently selected for the peer.
    mode: TransportMode,
    /// Counter incremented every time `mode` changes; copied into each
    /// `ModeSwitchSignal` so signals can be ordered.
    generation: u64,
    /// Set when the state is created and refreshed by explicit switches and by
    /// recovery back to P2P. Not read anywhere in this file.
    updated_at: Instant,
    /// Time of the most recent P2P failure, if any. Not read anywhere in this file.
    last_failure: Option<Instant>,
    /// Failure counter and retry backoff for the P2P path.
    health: PathHealth,
    /// Last cursor recorded per channel name; used to drop packets whose
    /// cursor is not newer than the recorded one.
    last_cursors: HashMap<String, u64>,
}
40
/// Failure bookkeeping for a peer's P2P path.
#[derive(Debug, Clone)]
struct PathHealth {
    /// Number of P2P failures counted since the counter was last reset.
    consecutive_failures: u32,
    /// Earliest instant at which a return to P2P may be attempted, when a
    /// backoff is armed; `None` when no backoff is pending.
    backoff_until: Option<Instant>,
}

impl PathHealth {
    /// Returns a clean record: zero failures counted and no backoff armed.
    fn new() -> Self {
        PathHealth {
            backoff_until: None,
            consecutive_failures: 0,
        }
    }
}
55
/// Tunables that govern when a peer is moved off P2P and when it may return.
#[derive(Debug, Clone, Copy)]
pub struct FailoverPolicy {
    /// Consecutive P2P failures required before the peer is switched to relay.
    pub failure_threshold: u32,
    /// Delay after a downgrade before a recovery back to P2P may be attempted.
    pub retry_backoff: Duration,
}

impl Default for FailoverPolicy {
    /// Three consecutive failures trigger a downgrade; recovery waits ten seconds.
    fn default() -> Self {
        const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
        const DEFAULT_RETRY_BACKOFF: Duration = Duration::from_secs(10);

        FailoverPolicy {
            failure_threshold: DEFAULT_FAILURE_THRESHOLD,
            retry_backoff: DEFAULT_RETRY_BACKOFF,
        }
    }
}
73
/// Signal emitted to clients when transport mode changes.
///
/// Produced by `TransportManager` whenever a peer's mode is switched, either
/// explicitly or as a result of failover/recovery.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModeSwitchSignal {
    /// Peer whose transport mode changed.
    pub peer_id: Uuid,
    /// Mode that is active after the switch.
    pub mode: TransportMode,
    /// Mode that was active immediately before the switch.
    pub previous_mode: TransportMode,
    /// What triggered the switch.
    pub reason: ModeSwitchReason,
    /// Value of the peer's generation counter after the switch; it grows by
    /// one per mode change.
    pub generation: u64,
}
83
84/// Packet that needs to be routed to a peer.
85#[derive(Debug, Clone)]
86pub struct TransportPacket {
87    pub channel: String,
88    pub cursor: Option<u64>,
89    pub payload: serde_json::Value,
90}
91
92impl TransportPacket {
93    pub fn new(channel: impl Into<String>, payload: serde_json::Value) -> Self {
94        Self {
95            channel: channel.into(),
96            cursor: None,
97            payload,
98        }
99    }
100
101    pub fn with_cursor(mut self, cursor: u64) -> Self {
102        self.cursor = Some(cursor);
103        self
104    }
105}
106
/// How the packet was ultimately routed.
///
/// Mirrors the variants of `TransportMode`, but describes the sink a specific
/// packet went through rather than the peer's configured mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportRoute {
    /// Handed to `TransportDispatcher::send_direct`.
    Direct,
    /// Handed to `TransportDispatcher::send_p2p`.
    P2p,
    /// Handed to `TransportDispatcher::send_relay`.
    Relay,
}
114
/// Result of a single routing operation.
#[derive(Debug, Clone)]
pub struct RouteOutcome {
    /// Sink the packet went through (or, for a dropped duplicate, the sink
    /// matching the peer's current mode).
    pub route: TransportRoute,
    /// Transport mode reported for the peer alongside this outcome.
    pub active_mode: TransportMode,
    /// Mode-switch signal to broadcast, present only when this routing
    /// operation caused the peer's mode to change.
    pub switch: Option<ModeSwitchSignal>,
    /// `true` when the packet was handed to a dispatcher sink; `false` when it
    /// was dropped as a duplicate.
    pub delivered: bool,
}
123
/// Errors that can occur while routing.
///
/// Each variant carries a static description that is included in the
/// `Display` output.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TransportError {
    /// The selected path cannot be used right now. When returned by the P2P
    /// sink, `TransportManager::route` retries the packet via relay.
    #[error("path unavailable: {0}")]
    PathUnavailable(&'static str),
    /// The sink failed to dispatch the packet; `TransportManager::route`
    /// returns it to the caller without retrying.
    #[error("dispatch failed: {0}")]
    DispatchFailed(&'static str),
    /// The sink rejected the packet due to rate limiting;
    /// `TransportManager::route` returns it to the caller without retrying.
    #[error("rate limited: {0}")]
    RateLimited(&'static str),
}
134
/// Trait that abstracts the actual transport sinks.
///
/// Implementations must be `Send + Sync`. Every method takes ownership of the
/// packet and reports failure through `TransportError`.
pub trait TransportDispatcher: Send + Sync {
    /// Sends `packet` to `peer` over the direct path.
    fn send_direct(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError>;
    /// Sends `packet` to `peer` over the peer-to-peer path. Returning
    /// `TransportError::PathUnavailable` makes `TransportManager::route` fall
    /// back to `send_relay`.
    fn send_p2p(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError>;
    /// Sends `packet` to `peer` over the relay path.
    fn send_relay(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError>;
}
141
142/// Keeps track of transport mode per peer and routes packets across the best available path.
143pub struct TransportManager {
144    modes: DashMap<Uuid, ModeState>,
145    policy: FailoverPolicy,
146}
147
148impl Default for TransportManager {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl TransportManager {
155    pub fn new() -> Self {
156        Self {
157            modes: DashMap::new(),
158            policy: FailoverPolicy::default(),
159        }
160    }
161
162    pub fn with_policy(policy: FailoverPolicy) -> Self {
163        Self {
164            modes: DashMap::new(),
165            policy,
166        }
167    }
168
169    /// Returns the active mode for the peer (defaults to Direct).
170    pub fn mode_for(&self, peer: Uuid) -> TransportMode {
171        self.modes
172            .get(&peer)
173            .map(|s| s.mode)
174            .unwrap_or(TransportMode::Direct)
175    }
176
177    /// Forces a mode change and returns a signal to broadcast, if the mode actually changed.
178    pub fn switch_mode(
179        &self,
180        peer: Uuid,
181        new_mode: TransportMode,
182        reason: ModeSwitchReason,
183    ) -> Option<ModeSwitchSignal> {
184        let mut state = self.ensure_state(peer);
185        if state.mode == new_mode {
186            return None;
187        }
188
189        let previous_mode = state.mode;
190        state.mode = new_mode;
191        state.generation += 1;
192        state.updated_at = Instant::now();
193
194        Some(ModeSwitchSignal {
195            peer_id: peer,
196            mode: new_mode,
197            previous_mode,
198            reason,
199            generation: state.generation,
200        })
201    }
202
203    /// Route a packet according to the current transport mode. If P2P is unavailable,
204    /// the packet is transparently retried via relay and a mode switch signal is returned.
205    pub fn route<D: TransportDispatcher>(
206        &self,
207        peer: Uuid,
208        packet: TransportPacket,
209        dispatcher: &D,
210    ) -> Result<RouteOutcome, TransportError> {
211        if self.is_duplicate(peer, &packet) {
212            let mode = self.mode_for(peer);
213            return Ok(RouteOutcome {
214                route: match mode {
215                    TransportMode::Direct => TransportRoute::Direct,
216                    TransportMode::P2p => TransportRoute::P2p,
217                    TransportMode::Relay => TransportRoute::Relay,
218                },
219                active_mode: mode,
220                switch: None,
221                delivered: false,
222            });
223        }
224
225        let mode = self.mode_for(peer);
226        match mode {
227            TransportMode::Direct => {
228                let record = packet.clone();
229                dispatcher.send_direct(peer, packet)?;
230                self.record_cursor(peer, &record);
231                Ok(RouteOutcome {
232                    route: TransportRoute::Direct,
233                    active_mode: TransportMode::Direct,
234                    switch: None,
235                    delivered: true,
236                })
237            }
238            TransportMode::Relay => {
239                let record = packet.clone();
240                dispatcher.send_relay(peer, packet)?;
241                self.record_cursor(peer, &record);
242                Ok(RouteOutcome {
243                    route: TransportRoute::Relay,
244                    active_mode: TransportMode::Relay,
245                    switch: None,
246                    delivered: true,
247                })
248            }
249            TransportMode::P2p => match dispatcher.send_p2p(peer, packet.clone()) {
250                Ok(_) => {
251                    self.reset_failures(peer);
252                    self.record_cursor(peer, &packet);
253                    Ok(RouteOutcome {
254                        route: TransportRoute::P2p,
255                        active_mode: TransportMode::P2p,
256                        switch: None,
257                        delivered: true,
258                    })
259                }
260                Err(TransportError::PathUnavailable(_)) => {
261                    let switch = self.mark_failure_and_maybe_downgrade(peer);
262                    let record = packet.clone();
263                    dispatcher.send_relay(peer, packet)?;
264                    self.record_cursor(peer, &record);
265                    Ok(RouteOutcome {
266                        route: TransportRoute::Relay,
267                        active_mode: TransportMode::Relay,
268                        switch,
269                        delivered: true,
270                    })
271                }
272                Err(err) => Err(err),
273            },
274        }
275    }
276
277    /// Attempt a background recovery back to P2P when backoff has elapsed.
278    pub fn maybe_retry_p2p(&self, peer: Uuid) -> Option<ModeSwitchSignal> {
279        let mut state = self.ensure_state(peer);
280        if state.mode != TransportMode::Relay {
281            return None;
282        }
283
284        let now = Instant::now();
285        let ready = state
286            .health
287            .backoff_until
288            .map(|until| now >= until)
289            .unwrap_or(false);
290
291        if !ready {
292            return None;
293        }
294
295        state.health.backoff_until = None;
296        state.health.consecutive_failures = 0;
297
298        let previous_mode = state.mode;
299        state.mode = TransportMode::P2p;
300        state.generation += 1;
301        state.updated_at = now;
302
303        Some(ModeSwitchSignal {
304            peer_id: peer,
305            mode: TransportMode::P2p,
306            previous_mode,
307            reason: ModeSwitchReason::PathRecovered,
308            generation: state.generation,
309        })
310    }
311
312    fn ensure_state(&self, peer: Uuid) -> dashmap::mapref::one::RefMut<'_, Uuid, ModeState> {
313        self.modes.entry(peer).or_insert_with(|| ModeState {
314            mode: TransportMode::Direct,
315            generation: 0,
316            updated_at: Instant::now(),
317            last_failure: None,
318            health: PathHealth::new(),
319            last_cursors: HashMap::new(),
320        })
321    }
322
323    fn mark_failure_and_maybe_downgrade(&self, peer: Uuid) -> Option<ModeSwitchSignal> {
324        let mut state = self.ensure_state(peer);
325        let now = Instant::now();
326        state.last_failure = Some(now);
327        state.health.consecutive_failures = state.health.consecutive_failures.saturating_add(1);
328
329        if state.health.consecutive_failures < self.policy.failure_threshold {
330            return None;
331        }
332
333        state.health.backoff_until = Some(now + self.policy.retry_backoff);
334        Some(ModeSwitchSignal {
335            peer_id: peer,
336            mode: TransportMode::Relay,
337            previous_mode: state.mode,
338            reason: ModeSwitchReason::PathDegraded,
339            generation: {
340                state.mode = TransportMode::Relay;
341                state.generation += 1;
342                state.generation
343            },
344        })
345    }
346
347    fn reset_failures(&self, peer: Uuid) {
348        if let Some(mut state) = self.modes.get_mut(&peer) {
349            state.health.consecutive_failures = 0;
350            state.health.backoff_until = None;
351        }
352    }
353
354    fn is_duplicate(&self, peer: Uuid, packet: &TransportPacket) -> bool {
355        let Some(cursor) = packet.cursor else {
356            return false;
357        };
358        let state = self.ensure_state(peer);
359        if let Some(last) = state.last_cursors.get(&packet.channel) {
360            return cursor <= *last;
361        }
362        false
363    }
364
365    fn record_cursor(&self, peer: Uuid, packet: &TransportPacket) {
366        let Some(cursor) = packet.cursor else {
367            return;
368        };
369        if let Some(mut state) = self.modes.get_mut(&peer) {
370            state.last_cursors.insert(packet.channel.clone(), cursor);
371        }
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Mutex;

    /// In-memory dispatcher that records every successful send and can be told
    /// to reject P2P sends with `PathUnavailable`.
    #[derive(Default)]
    struct MockDispatcher {
        /// When `true`, `send_p2p` fails without recording the call.
        fail_p2p: Mutex<bool>,
        /// Every accepted send, in order: (sink, peer, packet).
        calls: Mutex<Vec<(TransportRoute, Uuid, TransportPacket)>>,
    }

    impl MockDispatcher {
        /// Dispatcher whose P2P sink is unavailable from the start.
        fn with_p2p_failure() -> Self {
            Self {
                fail_p2p: Mutex::new(true),
                calls: Mutex::new(Vec::new()),
            }
        }

        /// Makes subsequent P2P sends succeed.
        fn allow_p2p(&self) {
            *self.fail_p2p.lock().unwrap() = false;
        }

        /// Snapshot of all accepted sends so far.
        fn calls(&self) -> Vec<(TransportRoute, Uuid, TransportPacket)> {
            self.calls.lock().unwrap().clone()
        }
    }

    impl TransportDispatcher for MockDispatcher {
        fn send_direct(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::Direct, peer, packet));
            Ok(())
        }

        fn send_p2p(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            if *self.fail_p2p.lock().unwrap() {
                return Err(TransportError::PathUnavailable("p2p unavailable"));
            }

            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::P2p, peer, packet));
            Ok(())
        }

        fn send_relay(&self, peer: Uuid, packet: TransportPacket) -> Result<(), TransportError> {
            self.calls
                .lock()
                .unwrap()
                .push((TransportRoute::Relay, peer, packet));
            Ok(())
        }
    }

    #[test]
    fn tracks_modes_and_generations() {
        let manager = TransportManager::new();
        let peer = Uuid::new_v4();

        assert_eq!(manager.mode_for(peer), TransportMode::Direct);

        let switch = manager
            .switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested)
            .unwrap();
        assert_eq!(switch.previous_mode, TransportMode::Direct);
        assert_eq!(switch.mode, TransportMode::P2p);
        assert_eq!(switch.generation, 1);

        // Switching to the mode that is already active is a no-op.
        assert!(
            manager
                .switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested)
                .is_none()
        );

        let switch = manager
            .switch_mode(peer, TransportMode::Relay, ModeSwitchReason::PathDegraded)
            .unwrap();
        assert_eq!(switch.generation, 2);
        assert_eq!(manager.mode_for(peer), TransportMode::Relay);
    }

    #[test]
    fn routes_using_current_mode() {
        let manager = TransportManager::new();
        let dispatcher = MockDispatcher::default();
        dispatcher.allow_p2p();

        let peer = Uuid::new_v4();
        let _ = manager.switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested);
        let packet = TransportPacket::new("presence", json!({ "message": "hi" }));

        let outcome = manager.route(peer, packet, &dispatcher).unwrap();
        assert_eq!(outcome.route, TransportRoute::P2p);
        assert_eq!(outcome.active_mode, TransportMode::P2p);
        assert!(outcome.delivered);
        assert!(outcome.switch.is_none());

        let calls = dispatcher.calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, TransportRoute::P2p);
        assert_eq!(calls[0].2.channel, "presence");
    }

    #[test]
    fn falls_back_to_relay_on_p2p_failure() {
        let manager = TransportManager::with_policy(FailoverPolicy {
            failure_threshold: 1,
            retry_backoff: Duration::from_secs(1),
        });
        let dispatcher = MockDispatcher::with_p2p_failure();
        let peer = Uuid::new_v4();
        let _ = manager.switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested);

        let packet = TransportPacket::new("data", json!({ "seq": 1 })).with_cursor(42);
        let outcome = manager.route(peer, packet, &dispatcher).unwrap();

        assert_eq!(outcome.route, TransportRoute::Relay);
        assert_eq!(outcome.active_mode, TransportMode::Relay);
        assert!(outcome.delivered);
        assert_eq!(
            manager.mode_for(peer),
            TransportMode::Relay,
            "mode should be downgraded after failure"
        );

        let switch = outcome.switch.expect("switch signal missing");
        assert_eq!(switch.reason, ModeSwitchReason::PathDegraded);
        assert_eq!(switch.previous_mode, TransportMode::P2p);
        assert_eq!(switch.mode, TransportMode::Relay);

        let calls = dispatcher.calls();
        assert_eq!(calls.len(), 1, "relay should get the retried packet");
        assert_eq!(calls[0].0, TransportRoute::Relay);
        assert_eq!(calls[0].2.cursor, Some(42));
        assert_eq!(calls[0].2.channel, "data");
    }

    #[test]
    fn respects_failure_threshold_and_backoff() {
        let manager = TransportManager::with_policy(FailoverPolicy {
            failure_threshold: 3,
            retry_backoff: Duration::from_secs(5),
        });
        let dispatcher = MockDispatcher::with_p2p_failure();
        let peer = Uuid::new_v4();
        let _ = manager.switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested);

        for _ in 0..(manager.policy.failure_threshold - 1) {
            let packet = TransportPacket::new("data", json!({ "seq": 1 }));
            let outcome = manager.route(peer, packet, &dispatcher).unwrap();
            assert!(outcome.switch.is_none(), "should not downgrade yet");
            // Below the threshold the packet is still relayed, but the peer
            // itself stays in P2P mode.
            assert_eq!(outcome.route, TransportRoute::Relay);
            assert!(outcome.delivered);
            assert_eq!(manager.mode_for(peer), TransportMode::P2p);
        }

        let packet = TransportPacket::new("data", json!({ "seq": 2 }));
        let outcome = manager.route(peer, packet, &dispatcher).unwrap();
        assert_eq!(outcome.active_mode, TransportMode::Relay);
        assert!(outcome.delivered);
        assert_eq!(manager.mode_for(peer), TransportMode::Relay);

        // Backoff not elapsed -> no retry.
        assert!(manager.maybe_retry_p2p(peer).is_none());
    }

    #[test]
    fn retries_after_backoff() {
        let manager = TransportManager::with_policy(FailoverPolicy {
            failure_threshold: 1,
            retry_backoff: Duration::from_millis(0),
        });
        let dispatcher = MockDispatcher::with_p2p_failure();
        let peer = Uuid::new_v4();
        let _ = manager.switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested);

        // Force downgrade.
        for _ in 0..manager.policy.failure_threshold {
            let packet = TransportPacket::new("data", json!({ "seq": 1 }));
            let _ = manager.route(peer, packet, &dispatcher).unwrap();
        }

        // Manually expire backoff.
        if let Some(mut state) = manager.modes.get_mut(&peer) {
            state.health.backoff_until = Some(Instant::now());
        }

        let switch = manager.maybe_retry_p2p(peer).expect("should retry");
        assert_eq!(switch.mode, TransportMode::P2p);
        assert_eq!(switch.previous_mode, TransportMode::Relay);
        assert_eq!(switch.reason, ModeSwitchReason::PathRecovered);
        assert_eq!(manager.mode_for(peer), TransportMode::P2p);
    }

    #[test]
    fn drops_duplicates_and_preserves_ordering() {
        let manager = TransportManager::with_policy(FailoverPolicy {
            failure_threshold: 1,
            retry_backoff: Duration::from_secs(1),
        });
        let dispatcher = MockDispatcher::default();
        dispatcher.allow_p2p();

        let peer = Uuid::new_v4();
        let _ = manager.switch_mode(peer, TransportMode::P2p, ModeSwitchReason::Requested);

        let packet1 = TransportPacket::new("data", json!({ "seq": 1 })).with_cursor(10);
        let outcome1 = manager.route(peer, packet1.clone(), &dispatcher).unwrap();
        assert!(outcome1.delivered);

        // Duplicate cursor should be dropped (no additional dispatch).
        let outcome_dup = manager.route(peer, packet1, &dispatcher).unwrap();
        assert!(!outcome_dup.delivered);

        let packet2 = TransportPacket::new("data", json!({ "seq": 2 })).with_cursor(11);
        let outcome2 = manager.route(peer, packet2, &dispatcher).unwrap();
        assert!(outcome2.delivered);

        let calls = dispatcher.calls();
        // Only two actual dispatches (cursor 10 and 11).
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].2.cursor, Some(10));
        assert_eq!(calls[1].2.cursor, Some(11));
    }
}