// atm0s_sdn_virtual_socket/vnet/udp_socket.rs

use std::{
    fmt::Debug,
    io,
    net::{SocketAddr, SocketAddrV4},
    ops::DerefMut,
    task::{Context, Poll},
};

use atm0s_sdn_identity::NodeId;
use quinn::{udp::EcnCodepoint, AsyncUdpSocket};

use crate::VirtualSocketPkt;

use super::{async_queue::AsyncQueue, internal::VirtualNetInternal, VirtualNetError};

14pub struct VirtualUdpSocket {
15 local_port: u16,
16 internal: VirtualNetInternal,
17 queue: AsyncQueue<VirtualSocketPkt>,
18}
19
20impl VirtualUdpSocket {
21 pub(crate) fn new(internal: VirtualNetInternal, port: u16, buffer_size: usize) -> Result<Self, VirtualNetError> {
22 let (queue, local_port) = internal.register_socket(port, buffer_size)?;
23 Ok(Self { internal, queue, local_port })
24 }
25
26 pub fn local_port(&self) -> u16 {
27 self.local_port
28 }
29
30 pub fn send_to_node(&self, node: NodeId, port: u16, payload: &[u8], ecn: Option<u8>) -> Result<(), VirtualNetError> {
31 self.internal.send_to_node(self.local_port, node, port, payload, ecn)
32 }
33
34 pub fn send_to(&self, dest: SocketAddrV4, payload: &[u8], ecn: Option<u8>) -> Result<(), VirtualNetError> {
35 self.internal.send_to(self.local_port, dest, payload, ecn)
36 }
37
38 pub fn try_recv_from(&self) -> Option<VirtualSocketPkt> {
39 self.queue.try_pop()
40 }
41
42 pub async fn recv_from(&self) -> Option<VirtualSocketPkt> {
43 self.queue.recv().await
44 }
45}
46
47impl Debug for VirtualUdpSocket {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("VirtualUdpSocket").field("local_port", &self.local_port).finish()
50 }
51}
52
53impl AsyncUdpSocket for VirtualUdpSocket {
54 fn poll_send(&self, _state: &quinn::udp::UdpState, _cx: &mut std::task::Context, transmits: &[quinn::udp::Transmit]) -> std::task::Poll<Result<usize, std::io::Error>> {
55 for transmit in transmits {
56 let res = match transmit.destination {
57 SocketAddr::V4(addr) => self.internal.send_to(self.local_port, addr, &transmit.contents, transmit.ecn.map(|x| x as u8)),
58 _ => return std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Only IPv4 supported"))),
59 };
60 if res.is_err() {
61 break;
62 }
63 }
64 std::task::Poll::Ready(Ok(transmits.len()))
65 }
66
67 fn poll_recv(&self, cx: &mut std::task::Context, bufs: &mut [std::io::IoSliceMut<'_>], meta: &mut [quinn::udp::RecvMeta]) -> std::task::Poll<std::io::Result<usize>> {
68 match self.queue.poll_pop(cx) {
69 std::task::Poll::Pending => std::task::Poll::Pending,
70 std::task::Poll::Ready(Some(pkt)) => {
71 let len = pkt.payload.len();
72 if len <= bufs[0].len() {
73 bufs[0].deref_mut()[0..len].copy_from_slice(&pkt.payload);
74 meta[0] = quinn::udp::RecvMeta {
75 addr: SocketAddr::V4(pkt.src),
76 len,
77 stride: len,
78 ecn: pkt.ecn.map(|x| EcnCodepoint::from_bits(x).expect("Invalid ECN codepoint")),
79 dst_ip: None,
80 };
81 std::task::Poll::Ready(Ok(1))
82 } else {
83 log::warn!("Buffer too small for packet {} vs {}, dropping", len, bufs[0].len());
84 std::task::Poll::Pending
85 }
86 }
87 std::task::Poll::Ready(None) => std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Socket closed"))),
88 }
89 }
90
91 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
92 Ok(SocketAddr::V4(SocketAddrV4::new(self.internal.local_node().into(), self.local_port)))
93 }
94}
95
96impl Drop for VirtualUdpSocket {
97 fn drop(&mut self) {
98 self.internal.unregister_socket(self.local_port);
99 }
100}