phantom_protocol/transport/
virtual_socket.rs1use crate::transport::bandwidth_estimator;
7use crate::transport::{
8 fallback::{FallbackStateMachine, TransportMode},
9 legs::TransportLeg,
10 scheduler::Scheduler,
11 types::{LegType, SchedulerMode},
12};
13
14use bytes::Bytes;
15use std::collections::HashMap;
16use std::io;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::sync::{mpsc, Mutex, RwLock};
20
/// Socket-wide settings for a [`VirtualSocket`].
///
/// Defaults: `max_packet_size` 1400, both buffer sizes 1024, `auto_fallback` enabled.
#[derive(Debug, Clone)]
pub struct VirtualSocketConfig {
    /// Maximum packet size. NOTE(review): not read by `VirtualSocket` itself —
    /// presumably enforced by the legs; confirm.
    pub max_packet_size: u32,
    /// Send buffer size. NOTE(review): not read by `VirtualSocket` itself — confirm consumer.
    pub send_buffer_size: u32,
    /// Capacity (in packets) of the inbound channel shared by all receive loops.
    pub recv_buffer_size: u32,
    /// When `true`, `send` asks the fallback state machine to switch transport
    /// and retries once after all selected paths fail or none are available.
    pub auto_fallback: bool,
}

impl Default for VirtualSocketConfig {
    fn default() -> Self {
        const DEFAULT_MAX_PACKET_SIZE: u32 = 1400;
        const DEFAULT_BUFFER_SIZE: u32 = 1024;

        VirtualSocketConfig {
            max_packet_size: DEFAULT_MAX_PACKET_SIZE,
            send_buffer_size: DEFAULT_BUFFER_SIZE,
            recv_buffer_size: DEFAULT_BUFFER_SIZE,
            auto_fallback: true,
        }
    }
}
44
45pub struct VirtualSocket {
47 config: VirtualSocketConfig,
49 legs: RwLock<HashMap<LegType, Arc<dyn TransportLeg>>>,
51 scheduler: Arc<Scheduler>,
53 fallback: Arc<FallbackStateMachine>,
55 recv_tx: mpsc::Sender<Bytes>,
57 recv_rx: Mutex<mpsc::Receiver<Bytes>>,
58 estimators: Arc<Mutex<HashMap<LegType, bandwidth_estimator::BandwidthEstimator>>>,
62 closed: Arc<std::sync::atomic::AtomicBool>,
66}
67
68impl VirtualSocket {
69 pub fn new(
71 config: VirtualSocketConfig,
72 scheduler: Arc<Scheduler>,
73 fallback: Arc<FallbackStateMachine>,
74 ) -> Self {
75 let (recv_tx, recv_rx) = mpsc::channel(config.recv_buffer_size as usize);
76
77 Self {
78 config,
79 legs: RwLock::new(HashMap::new()),
80 scheduler,
81 fallback,
82 recv_tx,
83 recv_rx: Mutex::new(recv_rx),
84 estimators: Arc::new(Mutex::new(HashMap::new())),
85 closed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
86 }
87 }
88
89 pub fn with_defaults() -> Self {
91 let scheduler = Arc::new(Scheduler::new(SchedulerMode::LowLatency));
92 let fallback = Arc::new(FallbackStateMachine::with_defaults());
93 Self::new(VirtualSocketConfig::default(), scheduler, fallback)
94 }
95
96 pub async fn register_leg(&self, leg_type: LegType, leg: Arc<dyn TransportLeg>) {
98 self.legs.write().await.insert(leg_type, leg);
99 self.scheduler.register_path(leg_type);
100 }
101
102 pub async fn unregister_leg(&self, leg_type: LegType) -> Option<Arc<dyn TransportLeg>> {
104 let leg = self.legs.write().await.remove(&leg_type);
105 self.scheduler.set_path_available(leg_type, false);
106 leg
107 }
108
109 pub async fn get_leg(&self, leg_type: LegType) -> Option<Arc<dyn TransportLeg>> {
111 self.legs.read().await.get(&leg_type).cloned()
112 }
113
114 pub async fn send(&self, data: Bytes, is_priority: bool) -> io::Result<()> {
118 const MAX_FALLBACK_ATTEMPTS: u8 = 2;
120
121 for attempt in 0..MAX_FALLBACK_ATTEMPTS {
122 if self.is_closed() {
123 return Err(io::Error::new(io::ErrorKind::NotConnected, "Socket closed"));
124 }
125
126 let paths = self.scheduler.select_paths(is_priority);
128
129 if paths.is_empty() {
130 if attempt == 0 && self.config.auto_fallback {
132 self.fallback.check_and_fallback();
133 continue; }
135 return Err(io::Error::new(
136 io::ErrorKind::NotConnected,
137 "No available paths",
138 ));
139 }
140
141 let legs = self.legs.read().await;
142 let mut last_error = None;
143 let mut send_succeeded = false;
144
145 for leg_type in paths {
146 if let Some(leg) = legs.get(&leg_type) {
147 self.fallback.metrics().record_sent();
148
149 match leg.send(data.clone()).await {
150 Ok(()) => {
151 self.fallback.metrics().record_success();
152 self.scheduler.record_sent(leg_type, data.len() as u64);
153
154 self.scheduler.update_rtt(leg_type, leg.rtt_ms());
156
157 send_succeeded = true;
158 break;
159 }
160 Err(e) => {
161 self.fallback.metrics().record_failure();
162 last_error = Some(e);
163
164 if leg.loss_percent() > 50 {
166 self.scheduler.set_path_available(leg_type, false);
167 }
168 }
169 }
170 }
171 }
172
173 if send_succeeded {
174 return Ok(());
175 }
176
177 if attempt == 0 && self.config.auto_fallback && self.fallback.check_and_fallback() {
179 continue;
181 }
182
183 return Err(last_error.unwrap_or_else(|| io::Error::other("All paths failed")));
184 }
185
186 Err(io::Error::other("Max fallback attempts reached"))
187 }
188
189 pub async fn recv(&self) -> io::Result<Bytes> {
191 if self.is_closed() {
192 return Err(io::Error::new(io::ErrorKind::NotConnected, "Socket closed"));
193 }
194
195 let mut rx = self.recv_rx.lock().await;
196
197 rx.recv()
198 .await
199 .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "Channel closed"))
200 }
201
202 pub async fn try_recv(&self) -> Option<Bytes> {
204 let mut rx = self.recv_rx.lock().await;
205 rx.try_recv().ok()
206 }
207
208 pub async fn start_recv_loop(&self, leg_type: LegType) -> io::Result<()> {
210 let leg = self
211 .legs
212 .read()
213 .await
214 .get(&leg_type)
215 .cloned()
216 .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Leg not found"))?;
217
218 let tx = self.recv_tx.clone();
219 let scheduler = self.scheduler.clone();
220 let estimators = self.estimators.clone();
221 let fallback = self.fallback.clone();
222 let closed = self.closed.clone();
225
226 tokio::spawn(async move {
227 loop {
228 if closed.load(std::sync::atomic::Ordering::Relaxed) {
229 break;
230 }
231
232 match leg.recv().await {
233 Ok(data) => {
234 if leg_type == LegType::Kcp {
236 fallback.upgrade();
237 }
238
239 scheduler.update_rtt(leg_type, leg.rtt_ms());
241
242 if let Ok(header) = crate::transport::types::PacketHeader::from_wire(&data)
253 {
254 if header
255 .flags
256 .contains(crate::transport::types::PacketFlags::ACK)
257 {
258 let mut ests: tokio::sync::MutexGuard<
259 '_,
260 HashMap<LegType, bandwidth_estimator::BandwidthEstimator>,
261 > = estimators.lock().await;
262 let est = ests.entry(leg_type).or_default();
263 let ack_delay_us = header.ack_delay as u64;
264 let sample = bandwidth_estimator::DeliverySample {
265 delivered_bytes: 0,
266 sent_at: Instant::now()
267 - Duration::from_millis(leg.rtt_ms() as u64),
268 acked_at: Instant::now(),
269 packet_bytes: data.len() as u64,
270 is_app_limited: false,
271 ack_delay_us,
272 };
273 est.on_ack(sample);
274 }
275 }
276
277 if tx.send(data).await.is_err() {
278 break; }
280 }
281 Err(e) => {
282 log::error!("Recv error on {:?}: {}", leg_type, e);
283 scheduler.set_path_available(leg_type, false);
284 break;
285 }
286 }
287 }
288 });
289
290 Ok(())
291 }
292
293 pub fn current_mode(&self) -> TransportMode {
295 self.fallback.current_mode()
296 }
297
298 pub async fn available_legs(&self) -> Vec<LegType> {
300 self.legs.read().await.keys().cloned().collect()
301 }
302
303 pub fn is_closed(&self) -> bool {
305 self.closed.load(std::sync::atomic::Ordering::Relaxed)
306 }
307
308 pub async fn close(&self) -> io::Result<()> {
310 self.closed
311 .store(true, std::sync::atomic::Ordering::Relaxed);
312
313 let legs = self.legs.write().await;
315 for (_, leg) in legs.iter() {
316 let _ = leg.close().await;
317 }
318
319 Ok(())
320 }
321
322 pub fn scheduler(&self) -> &Arc<Scheduler> {
324 &self.scheduler
325 }
326
327 pub fn fallback(&self) -> &Arc<FallbackStateMachine> {
329 &self.fallback
330 }
331
332 pub fn start_probe_loop(self: Arc<Self>) {
334 let socket = self.clone();
335 tokio::spawn(async move {
336 let mut interval = tokio::time::interval(Duration::from_secs(30));
337 loop {
338 interval.tick().await;
339
340 if socket.is_closed() {
341 break;
342 }
343
344 if socket.fallback.should_probe() {
345 let legs = socket.legs.read().await;
346 if let Some(leg) = legs.get(&LegType::Kcp) {
347 socket.fallback.record_probe();
348
349 let probe = Bytes::from(vec![0u8; 40]);
352 let _ = leg.send(probe).await;
353 log::debug!("Sent transport upgrade probe via KCP");
354 }
355 }
356 }
357 });
358 }
359}
360
361impl std::fmt::Debug for VirtualSocket {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.debug_struct("VirtualSocket")
364 .field("mode", &self.current_mode())
365 .field("closed", &self.is_closed())
366 .finish()
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built socket is open, starts in Turbo mode, and has no legs.
    #[tokio::test]
    async fn test_virtual_socket_creation() {
        let socket = VirtualSocket::with_defaults();

        assert!(!socket.is_closed());
        assert_eq!(socket.current_mode(), TransportMode::Turbo);

        let legs = socket.available_legs().await;
        assert!(legs.is_empty());
    }

    /// `close` flips the flag that `is_closed` (and spawned tasks) read.
    #[tokio::test]
    async fn close_signals_the_shared_flag() {
        let socket = VirtualSocket::with_defaults();
        assert!(!socket.is_closed());

        socket.close().await.expect("close");

        assert!(socket.is_closed());
    }

    /// An ACK header survives an encode/decode round trip through the packet codec.
    #[test]
    fn ack_header_decodes_via_canonical_codec() {
        use crate::transport::types::{PacketFlags, PacketHeader, PhantomPacket, SessionId};

        let session = SessionId::from_bytes([0x55; 32]);
        let mut header = PacketHeader::new(session, 3, 7, PacketFlags::new(PacketFlags::ACK));
        header.ack_delay = 1234;
        let wire = PhantomPacket::new(header, Vec::new()).to_wire();

        let parsed = PacketHeader::from_wire(&wire).expect("header parses");
        assert!(parsed.flags.contains(PacketFlags::ACK));
        assert_eq!(parsed.ack_delay, 1234);

        // NOTE(review): pins wire byte 38 to the `7` passed to PacketHeader::new;
        // confirm which header field sits at that offset in the wire layout.
        assert_eq!(wire[38], 7);
    }
}