//! UDP socket and per-peer endpoint implementation for Linux, built on io_uring.
//!
//! A [`Socket`] owns one datagram socket and hands out [`Endpoint`]s keyed by
//! peer address; each endpoint implements `AsyncRead`/`AsyncWrite`.

use crate::sync::map::FastMap;
use alloc::boxed::Box;
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::cell::UnsafeCell;
use core::mem;
use core::net::SocketAddr;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures::{AsyncRead, AsyncWrite, Future};

use crate::linux::io_uring::ffi::{SOCK_CLOEXEC, SOCK_NONBLOCK, SocketDomain, SocketType};
use crate::linux::io_uring::{
    Fd, IoUring, IoVec, MsgHdr, SocketAddrStorage, socket_addr_to_dual_stack,
};
use crate::linux::net::{NetworkError, Result, SocketBufferAllocation, get_buffer_pool};
use crate::linux::sys::{self};
use crate::sync::Waiter;
use crate::sync::lock::{Lock, Mutex};
use crate::{linux, net::udp};

/// A bound UDP socket driven through an io_uring instance.
///
/// One `Socket` multiplexes datagrams for many peers: every peer that has been
/// accepted or connected to gets an entry in `endpoints`.
pub struct Socket {
    /// Ring used to submit socket operations; cloned into every `Endpoint`.
    ring: Arc<IoUring>,
    /// Descriptor of the underlying datagram socket.
    fd: Fd,
    /// The address that was passed to `bind` (stored as given, not queried back
    /// from the kernel).
    local_addr: SocketAddr,
    // Track endpoints by peer address
    endpoints: Arc<FastMap<SocketAddr, Arc<EndpointState>>>,
}

/// Shared state for an endpoint.
///
/// Held in an `Arc` by both the socket's endpoint map and the `Endpoint` handle
/// for the same peer.
struct EndpointState {
    /// Address of the remote peer this state belongs to.
    peer_addr: SocketAddr,
    // Queue for received packets from this peer
    // (one inner `Vec<u8>` per datagram; new packets are appended with `push`).
    recv_queue: Mutex<Vec<Vec<u8>>>,
}

impl udp::Socket<linux::runtime::Runtime, linux::runtime::Share> for Socket {
    fn bind(addr: SocketAddr) -> impl Future<Output = Result<Self>>
    where
        Self: Sized,
    {
        async move {
            // Create io_uring instance
            let ring = Arc::new(IoUring::with_capacity(256).map_err(|e| {
                NetworkError::Internal(format!("Failed to create io_uring: {}", e))
            })?);

            // Create UDP socket
            let domain = match addr {
                SocketAddr::V4(_) => SocketDomain::Inet as i32,
                SocketAddr::V6(_) => SocketDomain::Inet6 as i32,
            };

            let fd = ring
                .socket(
                    domain,
                    SocketType::Dgram as i32 | SOCK_NONBLOCK | SOCK_CLOEXEC,
                    0,
                )
                .await
                .await
                .map_err(|e| NetworkError::Internal(format!("Failed to create socket: {}", e)))?;

            // Convert address
            let (sock_addr, addr_len) = socket_addr_to_dual_stack(addr);

            // Bind the socket
            unsafe {
                sys::bind(
                    *fd,
                    &sock_addr as *const _ as *const sys::SockAddr,
                    addr_len as u32,
                )
                .map_err(|e| NetworkError::AddressInUse)?;
            }

            let socket = Socket {
                ring: ring.clone(),
                fd,
                local_addr: addr,
                endpoints: Arc::new(FastMap::new()),
            };

            // Start background packet receiver task
            let bg_ring = ring.clone();
            let bg_fd = fd;
            let bg_endpoints = socket.endpoints.clone();

            crate::runtime::spawn(async move {
                loop {
                    // Allocate buffer for receiving
                    let mut buf = vec![0u8; 65536];
                    let mut addr_storage = SocketAddrStorage::new();
                    let mut addr_len = mem::size_of::<SocketAddrStorage>() as u32;

                    // Set up iovec
                    let mut iovec = IoVec {
                        base: buf.as_mut_ptr(),
                        len: buf.len(),
                    };

                    // Create message header
                    let mut msghdr = MsgHdr {
                        name: &mut addr_storage as *mut _ as *mut core::ffi::c_void,
                        namelen: addr_len,
                        iov: &mut iovec as *mut IoVec,
                        iovlen: 1,
                        control: core::ptr::null_mut(),
                        controllen: 0,
                        flags: 0,
                    };

                    // Receive a packet
                    match bg_ring.recvmsg(bg_fd, &mut msghdr).await.await {
                        Ok(n) => {
                            // Queue packet for the appropriate endpoint
                            if let Some(peer_addr) =
                                addr_storage.to_socket_addr(msghdr.namelen as usize)
                            {
                                bg_endpoints
                                    .get_ref(&peer_addr, |state| {
                                        let state_clone = state.clone();
                                        let packet = buf[..n].to_vec();
                                        crate::runtime::spawn(async move {
                                            state_clone
                                                .recv_queue
                                                .lock(Waiter::default())
                                                .await
                                                .push(packet);
                                        });
                                    })
                                    .await;
                            }
                        }
                        Err(_) => {
                            // Socket closed or error, exit task
                            break;
                        }
                    }
                }
            });

            Ok(socket)
        }
    }

    fn accept(&self) -> impl Future<Output = Result<(Endpoint, SocketAddr)>> {
        async move {
            // For UDP, "accept" means receiving the first packet from a new peer
            // We'll use recvmsg to get the first packet and peer address

            loop {
                // Allocate buffer for receiving
                let mut buf = vec![0u8; 65536]; // Max UDP packet size
                let mut addr_storage = SocketAddrStorage::new();
                let mut addr_len = mem::size_of::<SocketAddrStorage>() as u32;

                // Set up iovec for the buffer
                let mut iovec = IoVec {
                    base: buf.as_mut_ptr(),
                    len: buf.len(),
                };

                // Create message header with address
                let mut msghdr = MsgHdr {
                    name: &mut addr_storage as *mut _ as *mut core::ffi::c_void,
                    namelen: addr_len,
                    iov: &mut iovec as *mut IoVec,
                    iovlen: 1,
                    control: core::ptr::null_mut(),
                    controllen: 0,
                    flags: 0,
                };

                // Receive a packet
                let n = self
                    .ring
                    .recvmsg(self.fd, &mut msghdr)
                    .await
                    .await
                    .map_err(|e| NetworkError::Internal(format!("recvmsg failed: {}", e)))?;

                // Extract peer address
                if let Some(peer_addr) = addr_storage.to_socket_addr(msghdr.namelen as usize) {
                    // Check if this is a new peer
                    if !self.endpoints.contains_key(&peer_addr).await {
                        // Create new endpoint state
                        let endpoint_state = Arc::new(EndpointState {
                            peer_addr,
                            recv_queue: Mutex::new(vec![buf[..n].to_vec()]),
                        });

                        // Add to tracked endpoints
                        self.endpoints
                            .insert(peer_addr, endpoint_state.clone())
                            .await;

                        // Create endpoint
                        let endpoint = Endpoint {
                            ring: self.ring.clone(),
                            fd: self.fd,
                            local_addr: self.local_addr,
                            peer_addr,
                            state: endpoint_state,
                            socket_endpoints: self.endpoints.clone(),
                        };

                        return Ok((endpoint, peer_addr));
                    }
                    // If it's an existing peer, store the packet and continue looking for new peers
                    else {
                        self.endpoints
                            .get_ref(&peer_addr, |state| {
                                let state_clone = state.clone();
                                let packet = buf[..n].to_vec();
                                crate::runtime::spawn(async move {
                                    state_clone
                                        .recv_queue
                                        .lock(Waiter::default())
                                        .await
                                        .push(packet);
                                });
                            })
                            .await;
                    }
                }
            }
        }
    }

    fn connect_to(&self, peer: SocketAddr) -> impl Future<Output = Result<Endpoint>> {
        async move {
            // Check if endpoint already exists
            let endpoint_state = if let Some(state) = self.endpoints.get(&peer).await {
                state
            } else {
                // Create new endpoint state
                let state = Arc::new(EndpointState {
                    peer_addr: peer,
                    recv_queue: Mutex::new(Vec::new()),
                });
                self.endpoints.insert(peer, state.clone()).await;
                state
            };

            // Create endpoint
            let endpoint = Endpoint {
                ring: self.ring.clone(),
                fd: self.fd,
                local_addr: self.local_addr,
                peer_addr: peer,
                state: endpoint_state,
                socket_endpoints: self.endpoints.clone(),
            };

            Ok(endpoint)
        }
    }

    fn local_addr(&self) -> Result<SocketAddr> {
        Ok(self.local_addr)
    }
}

/// Handle for exchanging datagrams with a single peer over a shared `Socket`.
///
/// Created by `Socket::accept` and `Socket::connect_to`; it reuses the socket's
/// ring and descriptor rather than owning a socket of its own.
pub struct Endpoint {
    /// Ring shared with the owning socket.
    ring: Arc<IoUring>,
    /// Descriptor of the owning socket (copied from it).
    fd: Fd,
    /// Local address copied from the owning socket.
    local_addr: SocketAddr,
    /// Address of the remote peer this endpoint sends to and reads from.
    peer_addr: SocketAddr,
    /// Per-peer state, including the queue of datagrams received for this peer.
    state: Arc<EndpointState>,
    /// The owning socket's endpoint map, used to route datagrams from other
    /// peers and to deregister this endpoint on close.
    socket_endpoints: Arc<FastMap<SocketAddr, Arc<EndpointState>>>,
}

// Make Endpoint Send + Sync
// NOTE(review): these impls carry no SAFETY justification. They are only sound
// if every field (`Arc<IoUring>`, `Fd`, `Arc<EndpointState>`, the `FastMap`) may
// be moved to and shared between threads; if those types are already
// `Send + Sync` the impls are redundant — TODO confirm and document.
unsafe impl Send for Endpoint {}
unsafe impl Sync for Endpoint {}

impl AsyncRead for Endpoint {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<futures::io::Result<usize>> {
        let mut buf = buf.to_vec();
        let mut short_circuit = buf.clone();
        // First check if we have queued packets
        let this = unsafe { self.get_unchecked_mut() };

        // Create a future to check the recv queue
        let check_queue = async {
            let mut queue = this.state.recv_queue.lock(Waiter::default()).await;
            if let Some(packet) = queue.pop() {
                let len = packet.len().min(short_circuit.len());
                short_circuit[..len].copy_from_slice(&packet[..len]);
                return Ok(len);
            }
            Err(())
        };

        // Poll the queue check
        let mut queue_future = Box::pin(check_queue);
        match queue_future.as_mut().poll(cx) {
            Poll::Ready(Ok(n)) => return Poll::Ready(Ok(n)),
            Poll::Ready(Err(())) => {} // Queue empty, continue to recv
            Poll::Pending => return Poll::Pending,
        }

        // If no queued packets, receive new packets
        let ring = this.ring.clone();
        let fd = this.fd;
        let peer_addr = this.peer_addr;
        let state = this.state.clone();
        let socket_endpoints = this.socket_endpoints.clone();

        let recv_future = async move {
            loop {
                // Allocate buffer for receiving
                let mut recv_buf = vec![0u8; 65536];
                let mut addr_storage = SocketAddrStorage::new();
                let mut addr_len = mem::size_of::<SocketAddrStorage>() as u32;

                // Set up iovec
                let mut iovec = IoVec {
                    base: recv_buf.as_mut_ptr(),
                    len: recv_buf.len(),
                };

                // Create message header
                let mut msghdr = MsgHdr {
                    name: &mut addr_storage as *mut _ as *mut core::ffi::c_void,
                    namelen: addr_len,
                    iov: &mut iovec as *mut IoVec,
                    iovlen: 1,
                    control: core::ptr::null_mut(),
                    controllen: 0,
                    flags: 0,
                };

                // Receive a packet
                let n = ring.recvmsg(fd, &mut msghdr).await.await.map_err(|e| {
                    futures::io::Error::new(
                        futures::io::ErrorKind::Other,
                        format!("recvmsg error: {}", e),
                    )
                })?;

                // Check if it's from our peer
                if let Some(from_addr) = addr_storage.to_socket_addr(msghdr.namelen as usize) {
                    if from_addr == peer_addr {
                        // This packet is for us
                        let len = n.min(buf.len());
                        buf[..len].copy_from_slice(&recv_buf[..len]);
                        return Ok(len);
                    } else {
                        // Packet from different peer, queue it
                        socket_endpoints
                            .get_ref(&from_addr, |other_state| {
                                let state_clone = other_state.clone();
                                let packet = recv_buf[..n].to_vec();
                                crate::runtime::spawn(async move {
                                    state_clone
                                        .recv_queue
                                        .lock(Waiter::default())
                                        .await
                                        .push(packet);
                                });
                            })
                            .await;
                        // Continue looking for our packets
                    }
                }
            }
        };

        let mut recv_pin = Box::pin(recv_future);
        recv_pin.as_mut().poll(cx)
    }
}

impl AsyncWrite for Endpoint {
    /// Sends `buf` to this endpoint's peer as a single `sendmsg` call.
    ///
    /// `buf` is copied into an owned buffer so the send future does not borrow
    /// from the caller. Resolves to whatever byte count `sendmsg` reports.
    ///
    /// # Errors
    ///
    /// An `ErrorKind::Other` I/O error if `sendmsg` fails.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<futures::io::Result<usize>> {
        // Fields are only read, so plain deref through the pin is enough; no
        // `unsafe` projection is required.
        let ring = self.ring.clone();
        let fd = self.fd;
        let peer_addr = self.peer_addr;
        let data = buf.to_vec();

        // NOTE(review): the future is created per poll and dropped on `Pending`,
        // so a later poll issues a fresh `sendmsg`. Whether that can duplicate or
        // lose a datagram depends on `IoUring` — confirm.
        let send_future = async move {
            // Convert peer address
            let (mut addr_storage, addr_len) = socket_addr_to_dual_stack(peer_addr);

            // Set up iovec; the cast to `*mut` only satisfies the field type
            let mut iovec = IoVec {
                base: data.as_ptr() as *mut u8,
                len: data.len(),
            };

            // Create message header addressed to the peer
            let msghdr = MsgHdr {
                name: &mut addr_storage as *mut _ as *mut core::ffi::c_void,
                namelen: addr_len as u32,
                iov: &mut iovec as *mut IoVec,
                iovlen: 1,
                control: core::ptr::null_mut(),
                controllen: 0,
                flags: 0,
            };

            // Send the packet
            ring.sendmsg(fd, &msghdr, 0).await.await.map_err(|e| {
                futures::io::Error::new(
                    futures::io::ErrorKind::Other,
                    format!("sendmsg error: {}", e),
                )
            })
        };

        let mut send_pin = Box::pin(send_future);
        send_pin.as_mut().poll(cx)
    }

    /// Completes immediately: UDP doesn't need flushing.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Removes this endpoint's peer from the owning socket's endpoint map.
    ///
    /// The shared socket descriptor is left untouched. Never returns an error.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures::io::Result<()>> {
        let peer_addr = self.peer_addr;
        let socket_endpoints = self.socket_endpoints.clone();

        // The result of `remove` is intentionally ignored.
        let cleanup = async move {
            socket_endpoints.remove(&peer_addr).await;
        };

        let mut cleanup = Box::pin(cleanup);
        cleanup.as_mut().poll(cx).map(|()| Ok(()))
    }
}

impl udp::Endpoint<linux::runtime::Runtime, linux::runtime::Share> for Endpoint {
    /// Returns the local address stored in this endpoint; always `Ok`.
    fn local_addr(&self) -> Result<SocketAddr> {
        let Self { local_addr, .. } = self;
        Ok(*local_addr)
    }

    /// Returns the address of the peer this endpoint talks to; always `Ok`.
    fn peer_addr(&self) -> Result<SocketAddr> {
        let Self { peer_addr, .. } = self;
        Ok(*peer_addr)
    }
}