phantom_protocol/transport/
udp_transport.rs1#![allow(unsafe_code)]
21
22use super::buffer_pool::BufferPool;
23use super::pacer::Pacer;
24use crate::crypto::aes_session::AesSession;
25use crate::transport::bandwidth_estimator;
26use crate::transport::handshake::{ClientHello, HandshakeResponse, HandshakeServer};
27use std::net::SocketAddr;
28use std::sync::Arc;
29use tokio::io::{self, Result as IoResult};
30use tokio::net::UdpSocket;
31
32pub struct UdpTransport {
34 socket: Arc<UdpSocket>,
35 peer_addr: SocketAddr,
36 session: Arc<AesSession>,
37 buffer_pool: Arc<BufferPool>,
38}
39
40impl UdpTransport {
41 pub async fn bind(local_addr: &str) -> IoResult<Self> {
43 let socket = UdpSocket::bind(local_addr).await?;
44 socket.set_broadcast(false)?;
45
46 let peer_addr = "0.0.0.0:0"
47 .parse()
48 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
49
50 let session = AesSession::from_shared_secret(&[0u8; 32]).map_err(io::Error::other)?;
51
52 Ok(Self {
53 socket: Arc::new(socket),
54 peer_addr,
55 session: Arc::new(session),
56 buffer_pool: Arc::new(BufferPool::new(65536, 16, 256)),
57 })
58 }
59
60 pub async fn connect(&mut self, peer_addr: SocketAddr, session: AesSession) {
62 self.peer_addr = peer_addr;
63 self.session = Arc::new(session);
64 }
65
66 #[inline]
68 pub async fn send(&self, data: &[u8]) -> IoResult<usize> {
69 let encrypted = self.session.encrypt(&[], data).map_err(io::Error::other)?;
70 self.socket.send_to(&encrypted, self.peer_addr).await
71 }
72
73 #[inline]
75 pub async fn send_zero_copy(&self, data: &[u8]) -> IoResult<usize> {
76 let mut buf = Vec::with_capacity(data.len() + 16);
77 buf.extend_from_slice(data);
78 self.session
79 .encrypt_in_place(&[], &mut buf)
80 .map_err(io::Error::other)?;
81 self.socket.send_to(&buf, self.peer_addr).await
82 }
83
84 #[inline]
86 pub async fn recv(&self) -> IoResult<(Vec<u8>, SocketAddr)> {
87 let mut buf = self.buffer_pool.acquire();
88 buf.resize(65536, 0);
89
90 let (len, addr) = self.socket.recv_from(&mut buf).await?;
91
92 let decrypted = self
93 .session
94 .decrypt(&[], &buf[..len])
95 .map_err(io::Error::other)?;
96
97 Ok((decrypted, addr))
98 }
99
100 #[inline]
102 pub async fn send_batch(&self, packets: &[&[u8]]) -> IoResult<usize> {
103 let mut total = 0;
104 for packet in packets {
105 total += self.send(packet).await?;
106 }
107 Ok(total)
108 }
109
110 pub fn socket(&self) -> &Arc<UdpSocket> {
112 &self.socket
113 }
114
115 pub fn set_pacing_rate(&self, rate_bps: u64) -> IoResult<()> {
127 #[cfg(not(target_os = "linux"))]
130 let _ = rate_bps;
131 #[cfg(target_os = "linux")]
132 {
133 use std::os::unix::io::AsRawFd;
134 let rate_u32 = rate_bps.min(u32::MAX as u64) as u32;
137 let fd = self.socket.as_ref().as_raw_fd();
138 let ret = unsafe {
141 libc::setsockopt(
142 fd,
143 libc::SOL_SOCKET,
144 47, &rate_u32 as *const u32 as *const libc::c_void,
146 std::mem::size_of::<u32>() as libc::socklen_t,
147 )
148 };
149 if ret != 0 {
150 return Err(io::Error::last_os_error());
151 }
152 }
153 Ok(())
154 }
155
156 pub fn buffer_stats(&self) -> super::buffer_pool::PoolStats {
158 self.buffer_pool.stats()
159 }
160}
161
162pub struct UdpHandshakeListener {
164 socket: Arc<UdpSocket>,
165 buffer_pool: Arc<BufferPool>,
166}
167
168impl UdpHandshakeListener {
169 pub async fn bind(local_addr: &str) -> IoResult<Self> {
170 let socket = UdpSocket::bind(local_addr).await?;
171 socket.set_broadcast(false)?;
172
173 Ok(Self {
174 socket: Arc::new(socket),
175 buffer_pool: Arc::new(BufferPool::new(65536, 16, 256)),
176 })
177 }
178
179 pub async fn accept_handshake(&self, server: &HandshakeServer, difficulty: u8) -> IoResult<()> {
181 let mut buf = self.buffer_pool.acquire();
182 buf.resize(65536, 0);
183
184 loop {
185 let (len, addr) = self.socket.recv_from(&mut buf).await?;
186
187 if len < 1200 {
190 continue;
191 }
192
193 let client_hello = match borsh::from_slice::<ClientHello>(&buf[..len]) {
198 Ok(ch) => ch,
199 Err(_) => {
200 continue;
202 }
203 };
204
205 match server.process_client_hello(&client_hello, difficulty, addr.ip()) {
207 HandshakeResponse::Retry(retry_req) => {
208 if let Ok(encoded) = borsh::to_vec(&retry_req) {
210 let _ = self.socket.send_to(&encoded, addr).await;
211 }
212 }
213 HandshakeResponse::Success(server_hello, _session, _early_data) => {
214 if let Ok(encoded) = borsh::to_vec(&server_hello) {
216 let _ = self.socket.send_to(&encoded, addr).await;
217 }
218 }
220 HandshakeResponse::Reject(reject) => {
221 if let Ok(encoded) = borsh::to_vec(&reject) {
224 let _ = self.socket.send_to(&encoded, addr).await;
225 }
226 }
227 HandshakeResponse::Fail(_) => {
228 }
230 }
231
232 break;
234 }
235
236 Ok(())
237 }
238}
239
240pub struct PacedSender {
247 transport: Arc<UdpTransport>,
248 pacer: Arc<Pacer>,
249 estimator: Arc<parking_lot::Mutex<bandwidth_estimator::BandwidthEstimator>>,
250}
251
252impl PacedSender {
253 pub fn new(
255 transport: Arc<UdpTransport>,
256 pacer: Arc<Pacer>,
257 estimator: Arc<parking_lot::Mutex<bandwidth_estimator::BandwidthEstimator>>,
258 ) -> Self {
259 Self {
260 transport,
261 pacer,
262 estimator,
263 }
264 }
265
266 pub async fn send(&self, data: &[u8]) -> IoResult<usize> {
269 let bytes = data.len() as u64;
270
271 loop {
273 if self.pacer.try_consume(bytes) {
274 break;
275 }
276 let wait = self.pacer.time_until_available(bytes);
277 if wait.is_zero() {
278 break;
279 }
280 tokio::time::sleep(wait).await;
281 }
282
283 self.estimator.lock().on_send(bytes);
285
286 self.transport.send(data).await
287 }
288
289 pub async fn send_unpaced(&self, data: &[u8]) -> IoResult<usize> {
291 self.transport.send(data).await
292 }
293
294 pub fn on_ack(&self, sample: bandwidth_estimator::DeliverySample) {
296 let mut est = self.estimator.lock();
297 est.on_ack(sample);
298 let new_rate = est.pacing_rate();
299 self.pacer.set_rate(new_rate);
300 }
301
302 pub fn set_rate(&self, rate_bps: u64) {
304 self.pacer.set_rate(rate_bps);
305 }
306
307 pub fn rate(&self) -> u64 {
309 self.pacer.rate()
310 }
311}
312
313impl std::fmt::Debug for PacedSender {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("PacedSender")
316 .field("rate_bps", &self.pacer.rate())
317 .field("pacer_enabled", &self.pacer.is_enabled())
318 .finish()
319 }
320}
321
322pub struct FastSender {
324 socket: Arc<UdpSocket>,
325 session: Arc<AesSession>,
326 peer_addr: SocketAddr,
327}
328
329impl FastSender {
330 pub fn new(socket: Arc<UdpSocket>, session: Arc<AesSession>, peer_addr: SocketAddr) -> Self {
331 Self {
332 socket,
333 session,
334 peer_addr,
335 }
336 }
337
338 #[inline]
340 pub async fn send(&self, data: &[u8]) -> IoResult<usize> {
341 let mut buf = Vec::with_capacity(data.len() + 16);
342 buf.extend_from_slice(data);
343 self.session
344 .encrypt_in_place(&[], &mut buf)
345 .map_err(io::Error::other)?;
346 self.socket.send_to(&buf, self.peer_addr).await
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly bound transport reports the pool size it was configured with.
    #[tokio::test]
    async fn test_udp_transport_create() {
        let stats = UdpTransport::bind("127.0.0.1:0")
            .await
            .unwrap()
            .buffer_stats();
        assert_eq!(stats.pool_size, 16);
    }

    /// The paced sender mirrors the pacer's rate and follows `set_rate`.
    #[tokio::test]
    async fn test_paced_sender_creation() {
        let transport = UdpTransport::bind("127.0.0.1:0").await.unwrap();
        let estimator = bandwidth_estimator::BandwidthEstimator::new();
        let sender = PacedSender::new(
            Arc::new(transport),
            Arc::new(Pacer::new(1_000_000)),
            Arc::new(parking_lot::Mutex::new(estimator)),
        );

        assert_eq!(sender.rate(), 1_000_000);

        sender.set_rate(2_000_000);
        assert_eq!(sender.rate(), 2_000_000);
    }
}