atm0s_sdn_virtual_socket/vnet/
udp_socket.rs

1use std::{
2    fmt::Debug,
3    net::{SocketAddr, SocketAddrV4},
4    ops::DerefMut,
5};
6
7use atm0s_sdn_identity::NodeId;
8use quinn::{udp::EcnCodepoint, AsyncUdpSocket};
9
10use crate::VirtualSocketPkt;
11
12use super::{async_queue::AsyncQueue, internal::VirtualNetInternal, VirtualNetError};
13
14pub struct VirtualUdpSocket {
15    local_port: u16,
16    internal: VirtualNetInternal,
17    queue: AsyncQueue<VirtualSocketPkt>,
18}
19
20impl VirtualUdpSocket {
21    pub(crate) fn new(internal: VirtualNetInternal, port: u16, buffer_size: usize) -> Result<Self, VirtualNetError> {
22        let (queue, local_port) = internal.register_socket(port, buffer_size)?;
23        Ok(Self { internal, queue, local_port })
24    }
25
26    pub fn local_port(&self) -> u16 {
27        self.local_port
28    }
29
30    pub fn send_to_node(&self, node: NodeId, port: u16, payload: &[u8], ecn: Option<u8>) -> Result<(), VirtualNetError> {
31        self.internal.send_to_node(self.local_port, node, port, payload, ecn)
32    }
33
34    pub fn send_to(&self, dest: SocketAddrV4, payload: &[u8], ecn: Option<u8>) -> Result<(), VirtualNetError> {
35        self.internal.send_to(self.local_port, dest, payload, ecn)
36    }
37
38    pub fn try_recv_from(&self) -> Option<VirtualSocketPkt> {
39        self.queue.try_pop()
40    }
41
42    pub async fn recv_from(&self) -> Option<VirtualSocketPkt> {
43        self.queue.recv().await
44    }
45}
46
47impl Debug for VirtualUdpSocket {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("VirtualUdpSocket").field("local_port", &self.local_port).finish()
50    }
51}
52
53impl AsyncUdpSocket for VirtualUdpSocket {
54    fn poll_send(&self, _state: &quinn::udp::UdpState, _cx: &mut std::task::Context, transmits: &[quinn::udp::Transmit]) -> std::task::Poll<Result<usize, std::io::Error>> {
55        for transmit in transmits {
56            let res = match transmit.destination {
57                SocketAddr::V4(addr) => self.internal.send_to(self.local_port, addr, &transmit.contents, transmit.ecn.map(|x| x as u8)),
58                _ => return std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Only IPv4 supported"))),
59            };
60            if res.is_err() {
61                break;
62            }
63        }
64        std::task::Poll::Ready(Ok(transmits.len()))
65    }
66
67    fn poll_recv(&self, cx: &mut std::task::Context, bufs: &mut [std::io::IoSliceMut<'_>], meta: &mut [quinn::udp::RecvMeta]) -> std::task::Poll<std::io::Result<usize>> {
68        match self.queue.poll_pop(cx) {
69            std::task::Poll::Pending => std::task::Poll::Pending,
70            std::task::Poll::Ready(Some(pkt)) => {
71                let len = pkt.payload.len();
72                if len <= bufs[0].len() {
73                    bufs[0].deref_mut()[0..len].copy_from_slice(&pkt.payload);
74                    meta[0] = quinn::udp::RecvMeta {
75                        addr: SocketAddr::V4(pkt.src),
76                        len,
77                        stride: len,
78                        ecn: pkt.ecn.map(|x| EcnCodepoint::from_bits(x).expect("Invalid ECN codepoint")),
79                        dst_ip: None,
80                    };
81                    std::task::Poll::Ready(Ok(1))
82                } else {
83                    log::warn!("Buffer too small for packet {} vs {}, dropping", len, bufs[0].len());
84                    std::task::Poll::Pending
85                }
86            }
87            std::task::Poll::Ready(None) => std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Socket closed"))),
88        }
89    }
90
91    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
92        Ok(SocketAddr::V4(SocketAddrV4::new(self.internal.local_node().into(), self.local_port)))
93    }
94}
95
96impl Drop for VirtualUdpSocket {
97    fn drop(&mut self) {
98        self.internal.unregister_socket(self.local_port);
99    }
100}