Skip to main content

curseofrust_net_foundation/
lib.rs

1//! Fundamental async socket backend based on `unisock`.
2
3#![warn(missing_docs)]
4
5use std::net::{SocketAddr, ToSocketAddrs};
6
7use unisock::*;
8
9mod util;
10
11#[allow(unused_imports)]
12use util::*;
13
/// Transport protocol selected when creating a [`Handle`].
///
/// When no protocol is specified explicitly, [`Protocol::Udp`] is the [`Default`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum Protocol {
    /// Plain TCP.
    Tcp,
    /// UDP backed by **one single socket** shared by all connections.
    #[default]
    Udp,
    /// WebSocket (only available with the `ws` feature).
    #[cfg(feature = "ws")]
    WebSocket,
}
29
30/// The main handler.
31#[derive(Debug)]
32pub struct Handle(HandleInner);
33
34#[derive(Debug)]
35#[non_exhaustive]
36enum HandleInner {
37    Tcp(unisock_smol::Tcp),
38    Udp(unisock_smol::UdpSingle),
39    #[cfg(feature = "ws")]
40    WebSocket(unisock_smol_tungstenite::WebSocket),
41}
42
// Forwards an `async` method call to whichever backend variant an enum currently holds.
//
// Usage: `call!(place, EnumType => method(args...).await)`.
// - The place is matched by mutable reference, so `method` may take `&mut self`.
// - The WebSocket arm exists only with the `ws` feature. It runs its error through
//   `err_ws2io` so every arm yields the same error type.
macro_rules! call {
    ($target:expr, $kind:ident => $method:ident($($arg:expr),* $(,)?).await) => {
        match &mut $target {
            $kind::Tcp(backend) => backend.$method($($arg),*).await,
            $kind::Udp(backend) => backend.$method($($arg),*).await,
            #[cfg(feature = "ws")]
            $kind::WebSocket(backend) => backend.$method($($arg),*).await.map_err(err_ws2io),
        }
    };
}
53
54impl Handle {
55    /// Connect to the address with the specified protocol.
56    pub fn bind<A>(addr: A, protocol: Protocol) -> Result<Self, std::io::Error>
57    where
58        A: ToSocketAddrs,
59    {
60        let mut err = None;
61        for addr in addr.to_socket_addrs()? {
62            match protocol {
63                Protocol::Tcp => match unisock_smol::Tcp::bind(addr) {
64                    Ok(back) => return Ok(Self(HandleInner::Tcp(back))),
65                    Err(e) => err = Some(e),
66                },
67                Protocol::Udp => match unisock_smol::UdpSingle::bind(addr) {
68                    Ok(back) => return Ok(Self(HandleInner::Udp(back))),
69                    Err(e) => err = Some(e),
70                },
71                #[cfg(feature = "ws")]
72                Protocol::WebSocket => match unisock_smol_tungstenite::WebSocket::bind(addr) {
73                    Ok(back) => return Ok(Self(HandleInner::WebSocket(back))),
74                    Err(e) => err = Some(err_ws2io(e)),
75                },
76            }
77        }
78
79        Err(err.unwrap_or_else(|| {
80            std::io::Error::new(std::io::ErrorKind::InvalidInput, "no valid address found")
81        }))
82    }
83
84    /// Returns the listener.
85    pub fn listen(&'_ self) -> Result<Listener<'_>, std::io::Error> {
86        match &self.0 {
87            HandleInner::Tcp(back) => back.listen().map(|l| Listener(ListenerInner::Tcp(l))),
88            HandleInner::Udp(back) => Ok(Listener(ListenerInner::Udp(back))),
89            #[cfg(feature = "ws")]
90            HandleInner::WebSocket(back) => back
91                .listen()
92                .map(|l| Listener(ListenerInner::WebSocket(l)))
93                .map_err(err_ws2io),
94        }
95    }
96
97    /// Connect to the address.
98    pub async fn connect<A>(&'_ self, addr: A) -> Result<Connection<'_>, std::io::Error>
99    where
100        A: ToSocketAddrs,
101    {
102        let mut err = None;
103        for addr in addr.to_socket_addrs()? {
104            match &self.0 {
105                HandleInner::Tcp(back) => match back.connect(addr).await {
106                    Ok(conn) => return Ok(Connection(ConnectionInner::Tcp(conn))),
107                    Err(e) => err = Some(e),
108                },
109                HandleInner::Udp(back) => match back.connect(addr).await {
110                    Ok(conn) => return Ok(Connection(ConnectionInner::Udp(conn))),
111                    Err(e) => err = Some(e),
112                },
113                #[cfg(feature = "ws")]
114                HandleInner::WebSocket(back) => match back.connect(addr).await {
115                    Ok(conn) => return Ok(Connection(ConnectionInner::WebSocket(Box::new(conn)))),
116                    Err(e) => err = Some(err_ws2io(e)),
117                },
118            }
119        }
120
121        Err(err.unwrap_or_else(|| {
122            std::io::Error::new(std::io::ErrorKind::InvalidInput, "no valid address found")
123        }))
124    }
125}
126
127/// The listener.
128#[derive(Debug)]
129pub struct Listener<'a>(ListenerInner<'a>);
130
131#[derive(Debug)]
132#[non_exhaustive]
133enum ListenerInner<'a> {
134    Tcp(unisock_smol::tcp::Listener),
135    Udp(&'a unisock_smol::UdpSingle),
136    #[cfg(feature = "ws")]
137    WebSocket(unisock_smol_tungstenite::Listener),
138}
139
140impl Listener<'_> {
141    /// Accept a connection.
142    pub async fn accept(&'_ self) -> Result<(Connection<'_>, SocketAddr), std::io::Error> {
143        match &self.0 {
144            ListenerInner::Tcp(back) => back
145                .accept()
146                .await
147                .map(|(c, a)| (Connection(ConnectionInner::Tcp(c)), a)),
148            ListenerInner::Udp(back) => back
149                .accept()
150                .await
151                .map(|(c, a)| (Connection(ConnectionInner::Udp(c)), a)),
152            #[cfg(feature = "ws")]
153            ListenerInner::WebSocket(back) => back
154                .accept()
155                .await
156                .map(|(c, a)| (Connection(ConnectionInner::WebSocket(Box::new(c))), a))
157                .map_err(err_ws2io),
158        }
159    }
160}
161
162/// The connection.
163#[derive(Debug)]
164pub struct Connection<'a>(ConnectionInner<'a>);
165
166#[derive(Debug)]
167#[non_exhaustive]
168enum ConnectionInner<'a> {
169    Tcp(unisock_smol::tcp::Connection),
170    Udp(unisock_smol::udp_single_sock::Connection<'a>),
171    #[cfg(feature = "ws")]
172    WebSocket(Box<unisock_smol_tungstenite::Connection>),
173}
174
175impl Connection<'_> {
176    /// Send data.
177    pub async fn send(&mut self, data: &[u8]) -> Result<usize, std::io::Error> {
178        call!(self.0, ConnectionInner => write(data).await)
179    }
180
181    /// Receive data.
182    pub async fn recv(&mut self, data: &mut [u8]) -> Result<usize, std::io::Error> {
183        call!(self.0, ConnectionInner => read(data).await)
184    }
185
186    /// Poll the connection for readability.
187    pub fn poll_readable(&self, cx: &mut std::task::Context<'_>) -> bool {
188        match &self.0 {
189            ConnectionInner::Tcp(back) => back.poll_readable(cx),
190            ConnectionInner::Udp(back) => back.poll_readable(cx),
191            #[cfg(feature = "ws")]
192            ConnectionInner::WebSocket(back) => back.poll_readable(cx),
193        }
194    }
195
196    /// Poll the connection for writability.
197    pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> bool {
198        match &self.0 {
199            ConnectionInner::Tcp(back) => back.poll_writable(cx),
200            ConnectionInner::Udp(back) => back.poll_writable(cx),
201            #[cfg(feature = "ws")]
202            ConnectionInner::WebSocket(back) => back.poll_writable(cx),
203        }
204    }
205
206    /// Close the connection.
207    pub async fn close(self) -> Result<(), std::io::Error> {
208        match self.0 {
209            ConnectionInner::Tcp(back) => back.close().await,
210            ConnectionInner::Udp(back) => back.close().await,
211            #[cfg(feature = "ws")]
212            ConnectionInner::WebSocket(back) => back.close().await.map_err(err_ws2io),
213        }
214    }
215}