Skip to main content

phantom_protocol/transport/
virtual_socket.rs

1//! Phantom Transport - Virtual Socket
2//!
3//! Unified socket abstraction over multiple transport legs.
4//! Routes packets through the scheduler, handles fallback.
5
6use crate::transport::bandwidth_estimator;
7use crate::transport::{
8    fallback::{FallbackStateMachine, TransportMode},
9    legs::TransportLeg,
10    scheduler::Scheduler,
11    types::{LegType, SchedulerMode},
12};
13
14use bytes::Bytes;
15use std::collections::HashMap;
16use std::io;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::sync::{mpsc, Mutex, RwLock};
20
21/// Virtual socket configuration
22#[derive(Debug, Clone)]
23pub struct VirtualSocketConfig {
24    /// Maximum packet size (MTU)
25    pub max_packet_size: u32,
26    /// Send buffer size
27    pub send_buffer_size: u32,
28    /// Receive buffer size  
29    pub recv_buffer_size: u32,
30    /// Enable automatic fallback
31    pub auto_fallback: bool,
32}
33
34impl Default for VirtualSocketConfig {
35    fn default() -> Self {
36        Self {
37            max_packet_size: 1400,
38            send_buffer_size: 1024,
39            recv_buffer_size: 1024,
40            auto_fallback: true,
41        }
42    }
43}
44
45/// Virtual socket - unified interface over multiple transport legs
46pub struct VirtualSocket {
47    /// Configuration
48    config: VirtualSocketConfig,
49    /// Transport legs
50    legs: RwLock<HashMap<LegType, Arc<dyn TransportLeg>>>,
51    /// Multi-path scheduler
52    scheduler: Arc<Scheduler>,
53    /// Fallback state machine
54    fallback: Arc<FallbackStateMachine>,
55    /// Receive channel
56    recv_tx: mpsc::Sender<Bytes>,
57    recv_rx: Mutex<mpsc::Receiver<Bytes>>,
58    /// Per-leg bandwidth estimators — each network path has independent BBR state.
59    /// Sharing a single estimator across LTE and Wi-Fi is mathematically incorrect
60    /// since RTT, BDP, and loss patterns differ significantly per path.
61    estimators: Arc<Mutex<HashMap<LegType, bandwidth_estimator::BandwidthEstimator>>>,
62    /// Whether socket is closed. `Arc` so the per-leg recv tasks share the
63    /// SAME flag as `close()` — cloning the bool's value (LEGS-004) gave each
64    /// task a private copy that `close()` could never signal.
65    closed: Arc<std::sync::atomic::AtomicBool>,
66}
67
68impl VirtualSocket {
69    /// Create a new virtual socket
70    pub fn new(
71        config: VirtualSocketConfig,
72        scheduler: Arc<Scheduler>,
73        fallback: Arc<FallbackStateMachine>,
74    ) -> Self {
75        let (recv_tx, recv_rx) = mpsc::channel(config.recv_buffer_size as usize);
76
77        Self {
78            config,
79            legs: RwLock::new(HashMap::new()),
80            scheduler,
81            fallback,
82            recv_tx,
83            recv_rx: Mutex::new(recv_rx),
84            estimators: Arc::new(Mutex::new(HashMap::new())),
85            closed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
86        }
87    }
88
89    /// Create with default configuration
90    pub fn with_defaults() -> Self {
91        let scheduler = Arc::new(Scheduler::new(SchedulerMode::LowLatency));
92        let fallback = Arc::new(FallbackStateMachine::with_defaults());
93        Self::new(VirtualSocketConfig::default(), scheduler, fallback)
94    }
95
96    /// Register a transport leg
97    pub async fn register_leg(&self, leg_type: LegType, leg: Arc<dyn TransportLeg>) {
98        self.legs.write().await.insert(leg_type, leg);
99        self.scheduler.register_path(leg_type);
100    }
101
102    /// Unregister a transport leg
103    pub async fn unregister_leg(&self, leg_type: LegType) -> Option<Arc<dyn TransportLeg>> {
104        let leg = self.legs.write().await.remove(&leg_type);
105        self.scheduler.set_path_available(leg_type, false);
106        leg
107    }
108
109    /// Get a transport leg
110    pub async fn get_leg(&self, leg_type: LegType) -> Option<Arc<dyn TransportLeg>> {
111        self.legs.read().await.get(&leg_type).cloned()
112    }
113
114    /// Send data through the virtual socket
115    ///
116    /// The scheduler selects the optimal path(s).
117    pub async fn send(&self, data: Bytes, is_priority: bool) -> io::Result<()> {
118        // Allow one fallback retry
119        const MAX_FALLBACK_ATTEMPTS: u8 = 2;
120
121        for attempt in 0..MAX_FALLBACK_ATTEMPTS {
122            if self.is_closed() {
123                return Err(io::Error::new(io::ErrorKind::NotConnected, "Socket closed"));
124            }
125
126            // Select paths via scheduler
127            let paths = self.scheduler.select_paths(is_priority);
128
129            if paths.is_empty() {
130                // Check for fallback on first attempt
131                if attempt == 0 && self.config.auto_fallback {
132                    self.fallback.check_and_fallback();
133                    continue; // Retry with new mode
134                }
135                return Err(io::Error::new(
136                    io::ErrorKind::NotConnected,
137                    "No available paths",
138                ));
139            }
140
141            let legs = self.legs.read().await;
142            let mut last_error = None;
143            let mut send_succeeded = false;
144
145            for leg_type in paths {
146                if let Some(leg) = legs.get(&leg_type) {
147                    self.fallback.metrics().record_sent();
148
149                    match leg.send(data.clone()).await {
150                        Ok(()) => {
151                            self.fallback.metrics().record_success();
152                            self.scheduler.record_sent(leg_type, data.len() as u64);
153
154                            // Update RTT from leg
155                            self.scheduler.update_rtt(leg_type, leg.rtt_ms());
156
157                            send_succeeded = true;
158                            break;
159                        }
160                        Err(e) => {
161                            self.fallback.metrics().record_failure();
162                            last_error = Some(e);
163
164                            // Mark path as potentially unavailable
165                            if leg.loss_percent() > 50 {
166                                self.scheduler.set_path_available(leg_type, false);
167                            }
168                        }
169                    }
170                }
171            }
172
173            if send_succeeded {
174                return Ok(());
175            }
176
177            // All paths failed on this attempt, try fallback (only on first attempt)
178            if attempt == 0 && self.config.auto_fallback && self.fallback.check_and_fallback() {
179                // Will retry in next loop iteration with new mode
180                continue;
181            }
182
183            return Err(last_error.unwrap_or_else(|| io::Error::other("All paths failed")));
184        }
185
186        Err(io::Error::other("Max fallback attempts reached"))
187    }
188
189    /// Receive data from the virtual socket
190    pub async fn recv(&self) -> io::Result<Bytes> {
191        if self.is_closed() {
192            return Err(io::Error::new(io::ErrorKind::NotConnected, "Socket closed"));
193        }
194
195        let mut rx = self.recv_rx.lock().await;
196
197        rx.recv()
198            .await
199            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "Channel closed"))
200    }
201
202    /// Try to receive without blocking
203    pub async fn try_recv(&self) -> Option<Bytes> {
204        let mut rx = self.recv_rx.lock().await;
205        rx.try_recv().ok()
206    }
207
208    /// Start background receive loop for a leg
209    pub async fn start_recv_loop(&self, leg_type: LegType) -> io::Result<()> {
210        let leg = self
211            .legs
212            .read()
213            .await
214            .get(&leg_type)
215            .cloned()
216            .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Leg not found"))?;
217
218        let tx = self.recv_tx.clone();
219        let scheduler = self.scheduler.clone();
220        let estimators = self.estimators.clone();
221        let fallback = self.fallback.clone();
222        // Share the SAME close flag with the task (LEGS-004): clone the `Arc`,
223        // not the bool's value, so `close()` actually stops this recv loop.
224        let closed = self.closed.clone();
225
226        tokio::spawn(async move {
227            loop {
228                if closed.load(std::sync::atomic::Ordering::Relaxed) {
229                    break;
230                }
231
232                match leg.recv().await {
233                    Ok(data) => {
234                        // If we receive ANY data on KCP leg, try to upgrade from degraded modes
235                        if leg_type == LegType::Kcp {
236                            fallback.upgrade();
237                        }
238
239                        // Update RTT
240                        scheduler.update_rtt(leg_type, leg.rtt_ms());
241
242                        // Detect ACKs for BBR feedback; update the leg-specific estimator.
243                        // Each leg maintains independent BBR state so that different network
244                        // paths (LTE, Wi-Fi, TCP) don't corrupt each other's bandwidth estimate.
245                        // LEGS-005: parse the canonical 45-byte big-endian header
246                        // via `PacketHeader::from_wire` instead of magic byte
247                        // offsets / little-endian reads. The previous code read
248                        // `data[38]` (the sequence LSB) as the flags byte and
249                        // `data[39..41]` LE (the big-endian flags field) as
250                        // ack_delay — both wrong. `from_wire` reads the header off
251                        // the front and ignores the payload.
252                        if let Ok(header) = crate::transport::types::PacketHeader::from_wire(&data)
253                        {
254                            if header
255                                .flags
256                                .contains(crate::transport::types::PacketFlags::ACK)
257                            {
258                                let mut ests: tokio::sync::MutexGuard<
259                                    '_,
260                                    HashMap<LegType, bandwidth_estimator::BandwidthEstimator>,
261                                > = estimators.lock().await;
262                                let est = ests.entry(leg_type).or_default();
263                                let ack_delay_us = header.ack_delay as u64;
264                                let sample = bandwidth_estimator::DeliverySample {
265                                    delivered_bytes: 0,
266                                    sent_at: Instant::now()
267                                        - Duration::from_millis(leg.rtt_ms() as u64),
268                                    acked_at: Instant::now(),
269                                    packet_bytes: data.len() as u64,
270                                    is_app_limited: false,
271                                    ack_delay_us,
272                                };
273                                est.on_ack(sample);
274                            }
275                        }
276
277                        if tx.send(data).await.is_err() {
278                            break; // Receiver dropped
279                        }
280                    }
281                    Err(e) => {
282                        log::error!("Recv error on {:?}: {}", leg_type, e);
283                        scheduler.set_path_available(leg_type, false);
284                        break;
285                    }
286                }
287            }
288        });
289
290        Ok(())
291    }
292
293    /// Get current transport mode
294    pub fn current_mode(&self) -> TransportMode {
295        self.fallback.current_mode()
296    }
297
298    /// Get available leg types
299    pub async fn available_legs(&self) -> Vec<LegType> {
300        self.legs.read().await.keys().cloned().collect()
301    }
302
303    /// Check if socket is closed
304    pub fn is_closed(&self) -> bool {
305        self.closed.load(std::sync::atomic::Ordering::Relaxed)
306    }
307
308    /// Close the virtual socket
309    pub async fn close(&self) -> io::Result<()> {
310        self.closed
311            .store(true, std::sync::atomic::Ordering::Relaxed);
312
313        // Close all legs
314        let legs = self.legs.write().await;
315        for (_, leg) in legs.iter() {
316            let _ = leg.close().await;
317        }
318
319        Ok(())
320    }
321
322    /// Get scheduler reference
323    pub fn scheduler(&self) -> &Arc<Scheduler> {
324        &self.scheduler
325    }
326
327    /// Get fallback state machine reference
328    pub fn fallback(&self) -> &Arc<FallbackStateMachine> {
329        &self.fallback
330    }
331
332    /// Start the background probe loop to allow transport healing
333    pub fn start_probe_loop(self: Arc<Self>) {
334        let socket = self.clone();
335        tokio::spawn(async move {
336            let mut interval = tokio::time::interval(Duration::from_secs(30));
337            loop {
338                interval.tick().await;
339
340                if socket.is_closed() {
341                    break;
342                }
343
344                if socket.fallback.should_probe() {
345                    let legs = socket.legs.read().await;
346                    if let Some(leg) = legs.get(&LegType::Kcp) {
347                        socket.fallback.record_probe();
348
349                        // Send a dummy probe packet (40 bytes of zeros)
350                        // Peer will receive it, and its recv_loop will trigger their upgrade
351                        let probe = Bytes::from(vec![0u8; 40]);
352                        let _ = leg.send(probe).await;
353                        log::debug!("Sent transport upgrade probe via KCP");
354                    }
355                }
356            }
357        });
358    }
359}
360
361impl std::fmt::Debug for VirtualSocket {
362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        f.debug_struct("VirtualSocket")
364            .field("mode", &self.current_mode())
365            .field("closed", &self.is_closed())
366            .finish()
367    }
368}
369
370#[cfg(test)]
371mod tests {
372    use super::*;
373
374    #[tokio::test]
375    async fn test_virtual_socket_creation() {
376        let socket = VirtualSocket::with_defaults();
377
378        assert!(!socket.is_closed());
379        assert_eq!(socket.current_mode(), TransportMode::Turbo);
380        assert!(socket.available_legs().await.is_empty());
381    }
382
383    /// LEGS-004: `close()` flips the shared flag the per-leg recv tasks observe.
384    /// (The recv loop captures `self.closed.clone()` — the SAME `Arc` — so this
385    /// store is visible to it; the old code cloned the bool's value and the
386    /// task never saw `close()`.)
387    #[tokio::test]
388    async fn close_signals_the_shared_flag() {
389        let socket = VirtualSocket::with_defaults();
390        assert!(!socket.is_closed());
391        socket.close().await.expect("close");
392        assert!(socket.is_closed());
393    }
394
395    /// LEGS-005: the recv loop's BBR ACK detection decodes the header through
396    /// the canonical big-endian `PacketHeader::from_wire`, not magic byte
397    /// offsets. This pins the contract that decode relies on.
398    #[test]
399    fn ack_header_decodes_via_canonical_codec() {
400        use crate::transport::types::{PacketFlags, PacketHeader, PhantomPacket, SessionId};
401        let mut header = PacketHeader::new(
402            SessionId::from_bytes([0x55; 32]),
403            3,
404            7,
405            PacketFlags::new(PacketFlags::ACK),
406        );
407        header.ack_delay = 1234;
408        let wire = PhantomPacket::new(header, Vec::new()).to_wire();
409
410        let parsed = PacketHeader::from_wire(&wire).expect("header parses");
411        assert!(parsed.flags.contains(PacketFlags::ACK));
412        assert_eq!(parsed.ack_delay, 1234);
413
414        // Regression guard: byte 38 is the LSB of the big-endian u32 `sequence`
415        // (= 7), which the old code misread as the "flags byte". The flags field
416        // actually lives at bytes 39..41 (big-endian) — exactly what `from_wire`
417        // now reads.
418        assert_eq!(wire[38], 7);
419    }
420}