Skip to main content

pink_sidevm/
net.rs

1//! Networking support.
2
3use std::future::Future;
4use std::io::Error;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use env::tls::TlsServerConfig;
10
11use crate::env::{self, tasks, Result};
12use crate::{ocall, ResourceId};
13
14/// A TCP socket server, listening for connections.
15pub struct TcpListener {
16    res_id: ResourceId,
17}
18
19/// A resource pointing to a connecting TCP socket.
20#[derive(Debug)]
21pub struct TcpConnector {
22    res: Result<ResourceId, env::OcallError>,
23}
24
25/// A connected TCP socket.
26#[derive(Debug)]
27pub struct TcpStream {
28    res_id: ResourceId,
29}
30
31/// Future returned by `TcpListener::accept`.
32pub struct Acceptor<'a> {
33    listener: &'a TcpListener,
34}
35
36impl Future for Acceptor<'_> {
37    type Output = Result<(TcpStream, SocketAddr)>;
38
39    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        use env::OcallError;
41        let waker_id = tasks::intern_waker(cx.waker().clone());
42        match ocall::tcp_accept(waker_id, self.listener.res_id.0) {
43            Ok((res_id, remote_addr)) => Poll::Ready(Ok((
44                TcpStream::new(ResourceId(res_id)),
45                remote_addr
46                    .parse()
47                    .expect("ocall::tcp_accept returned an invalid remote address"),
48            ))),
49            Err(OcallError::Pending) => Poll::Pending,
50            Err(err) => Poll::Ready(Err(err)),
51        }
52    }
53}
54
55impl TcpListener {
56    /// Bind and listen on the specified address for incoming TCP connections.
57    pub async fn bind(addr: &str) -> Result<Self> {
58        // Side notes: could be used to probe enabled interfaces and occupied ports. We may
59        // consider to introduce some manifest file to further limit the capability in the future
60        let todo = "prevent local interface probing and port occupation";
61        let res_id = ResourceId(ocall::tcp_listen(addr.into(), None)?);
62        Ok(Self { res_id })
63    }
64
65    /// Bind and listen on the specified address for incoming TLS-enabled TCP connections.
66    pub async fn bind_tls(addr: &str, config: TlsServerConfig) -> Result<Self> {
67        let res_id = ResourceId(ocall::tcp_listen(addr.into(), Some(config))?);
68        Ok(Self { res_id })
69    }
70
71    /// Accept a new incoming connection.
72    pub fn accept(&self) -> Acceptor {
73        Acceptor { listener: self }
74    }
75}
76
77impl Future for TcpConnector {
78    type Output = Result<TcpStream>;
79
80    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
81        use env::OcallError;
82
83        let res_id = match &self.get_mut().res {
84            Ok(res_id) => res_id,
85            Err(err) => return Poll::Ready(Err(*err)),
86        };
87
88        match ocall::poll_res(env::tasks::intern_waker(ctx.waker().clone()), res_id.0) {
89            Ok(res_id) => Poll::Ready(Ok(TcpStream::new(ResourceId(res_id)))),
90            Err(OcallError::Pending) => Poll::Pending,
91            Err(err) => Poll::Ready(Err(err)),
92        }
93    }
94}
95
96impl TcpStream {
97    fn new(res_id: ResourceId) -> Self {
98        Self { res_id }
99    }
100
101    /// Initiate a TCP connection to a remote host.
102    pub fn connect(host: &str, port: u16, enable_tls: bool) -> TcpConnector {
103        let res = if enable_tls {
104            ocall::tcp_connect_tls(host.into(), port, env::tls::TlsClientConfig::V0)
105        } else {
106            ocall::tcp_connect(host, port)
107        };
108        let res = res.map(|res_id| ResourceId(res_id));
109        TcpConnector { res }
110    }
111}
112
113#[cfg(feature = "hyper")]
114pub use impl_hyper::{AddrIncoming, AddrStream, HttpConnector};
#[cfg(feature = "hyper")]
mod impl_hyper {
    use super::*;
    use env::OcallError;
    use hyper::client::connect::{Connected, Connection};
    use hyper::server::accept::Accept;
    use hyper::{service::Service, Uri};
    use std::io;

    // Unwraps a `Poll<Result<T, OcallError>>` inside a `poll_accept` implementation.
    // Evaluates to `T` when ready and successful. Otherwise it returns early from the
    // enclosing function with:
    // - `Poll::Pending` when not ready,
    // - `Poll::Ready(None)` on `OcallError::EndOfFile` (the accept stream is finished),
    // - `Poll::Ready(Some(Err(err)))` for any other error.
    macro_rules! ready_ok {
        ($poll: expr) => {
            match $poll {
                Poll::Ready(Ok(val)) => val,
                Poll::Ready(Err(OcallError::EndOfFile)) => return Poll::Ready(None),
                Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                Poll::Pending => return Poll::Pending,
            }
        };
    }

    impl Accept for TcpListener {
        type Conn = TcpStream;

        type Error = env::OcallError;

        /// Polls for the next incoming connection, discarding its remote address.
        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            // `Acceptor` only borrows the listener and keeps no state of its own, so building a
            // fresh one on every poll is equivalent to storing it.
            let (conn, _addr) = ready_ok!(Pin::new(&mut self.accept()).poll(cx));
            Poll::Ready(Some(Ok(conn)))
        }
    }

    impl Connection for TcpStream {
        /// Returns default connection metadata; nothing extra is reported to hyper.
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    impl TcpListener {
        /// Converts the listener into an [`AddrIncoming`], whose accepted connections are
        /// [`AddrStream`]s that keep the remote address.
        pub fn into_addr_incoming(self) -> AddrIncoming {
            AddrIncoming { listener: self }
        }
    }

    /// An HTTP/HTTPS connector for hyper working under sidevm.
    ///
    /// URIs with the `https` scheme are connected with TLS enabled; all others use plain TCP.
    /// Without an explicit port, 443 (HTTPS) or 80 (HTTP) is used.
    #[derive(Clone, Default, Debug)]
    pub struct HttpConnector;

    impl HttpConnector {
        /// Creates a new `HttpConnector`.
        pub fn new() -> Self {
            Self
        }
    }

    impl Service<Uri> for HttpConnector {
        type Response = TcpStream;
        type Error = OcallError;
        type Future = TcpConnector;

        /// The connector holds no state, so it is always ready for another request.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Starts connecting to the host and port named by `dst`.
        fn call(&mut self, dst: Uri) -> Self::Future {
            let is_https = dst.scheme_str() == Some("https");
            // `Uri::host` keeps the brackets around IPv6 literals (e.g. `[::1]`); strip them so
            // the bare address is passed on. A URI without a host yields an empty host name,
            // which is left to the connect ocall to reject.
            let host = dst
                .host()
                .unwrap_or("")
                .trim_matches(|c| c == '[' || c == ']');
            let port = dst.port_u16().unwrap_or(if is_https { 443 } else { 80 });
            TcpStream::connect(host, port, is_https)
        }
    }

    /// A wrapper of [`TcpListener`] whose accepted connections are [`AddrStream`]s.
    pub struct AddrIncoming {
        listener: TcpListener,
    }

    impl Accept for AddrIncoming {
        type Conn = AddrStream;
        type Error = env::OcallError;

        /// Polls for the next incoming connection, keeping its remote address.
        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            let (stream, remote_addr) = ready_ok!(Pin::new(&mut self.listener.accept()).poll(cx));
            Poll::Ready(Some(Ok(AddrStream {
                stream,
                remote_addr,
            })))
        }
    }

    /// A wrapper of [`TcpStream`] that keeps the remote address of the connection.
    #[pin_project::pin_project]
    pub struct AddrStream {
        #[pin]
        stream: TcpStream,
        remote_addr: SocketAddr,
    }

    impl AddrStream {
        /// Returns the remote address of the connection.
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }
    }

    // tokio I/O traits: forward every call to the wrapped `TcpStream`.
    #[cfg(feature = "tokio")]
    const _: () = {
        use tokio::io::{AsyncRead, AsyncWrite};
        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                self.project().stream.poll_read(cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                self.project().stream.poll_write(cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_flush(cx)
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<()>> {
                self.project().stream.poll_shutdown(cx)
            }
        }
    };

    // futures I/O traits: forward every call to the wrapped `TcpStream`.
    const _: () = {
        use futures::{AsyncRead, AsyncWrite};

        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<io::Result<usize>> {
                self.project().stream.poll_read(cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                self.project().stream.poll_write(cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_flush(cx)
            }

            fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_close(cx)
            }
        }
    };
}
299
#[cfg(feature = "tokio")]
mod impl_tokio {
    use super::*;
    use tokio::io::{AsyncRead, AsyncWrite};

    /// Maximum number of bytes requested from the host by a single `poll_read` call.
    const MAX_READ_CHUNK: usize = 512;

    impl AsyncRead for TcpStream {
        /// Reads up to `MAX_READ_CHUNK` bytes into the unfilled part of `buf`.
        ///
        /// Returns `Poll::Pending` when the ocall reports `OcallError::Pending`. A length
        /// reported by the host that exceeds the chunk actually handed to it is rejected
        /// with an `InvalidEncoding` error instead of being trusted.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            use env::OcallError;
            // Keep `size` in scope: it is the only number of bytes the host was given and
            // that was initialized, so it bounds any valid reply.
            let size = buf.remaining().min(MAX_READ_CHUNK);
            let result = {
                let chunk = buf.initialize_unfilled_to(size);
                let waker_id = tasks::intern_waker(cx.waker().clone());
                ocall::poll_read(waker_id, self.res_id.0, chunk)
            };
            match result {
                Ok(len) => {
                    let len = len as usize;
                    // Checking against `buf.remaining()` is not enough: advancing past `size`
                    // could make `ReadBuf::advance` panic, or mark bytes as filled that
                    // this read never wrote.
                    if len > size {
                        Poll::Ready(Err(Error::from_raw_os_error(
                            OcallError::InvalidEncoding as i32,
                        )))
                    } else {
                        buf.advance(len);
                        Poll::Ready(Ok(()))
                    }
                }
                Err(OcallError::Pending) => Poll::Pending,
                Err(err) => Poll::Ready(Err(Error::from_raw_os_error(err as i32))),
            }
        }
    }

    impl AsyncWrite for TcpStream {
        /// Writes `buf` through the `poll_write` ocall, returning the number of bytes written.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, Error>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
        }

        /// `TcpStream` keeps no guest-side write buffer (writes go straight to the
        /// `poll_write` ocall), so there is nothing to flush here.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            Poll::Ready(Ok(()))
        }

        /// Shuts the connection down through the `poll_shutdown` ocall.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
        }
    }
}
356
357mod impl_futures_io {
358    use super::*;
359    use futures::io::{AsyncRead, AsyncWrite};
360
361    impl AsyncRead for TcpStream {
362        fn poll_read(
363            self: Pin<&mut Self>,
364            cx: &mut Context<'_>,
365            buf: &mut [u8],
366        ) -> Poll<std::io::Result<usize>> {
367            let waker_id = tasks::intern_waker(cx.waker().clone());
368            into_poll(ocall::poll_read(waker_id, self.res_id.0, buf).map(|len| len as usize))
369        }
370    }
371
372    impl AsyncWrite for TcpStream {
373        fn poll_write(
374            self: Pin<&mut Self>,
375            cx: &mut Context<'_>,
376            buf: &[u8],
377        ) -> Poll<std::io::Result<usize>> {
378            let waker_id = tasks::intern_waker(cx.waker().clone());
379            into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
380        }
381
382        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
383            Poll::Ready(Ok(()))
384        }
385
386        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
387            let waker_id = tasks::intern_waker(cx.waker().clone());
388            into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
389        }
390    }
391}
392
393fn into_poll<T>(res: Result<T, env::OcallError>) -> Poll<std::io::Result<T>> {
394    match res {
395        Ok(v) => Poll::Ready(Ok(v)),
396        Err(env::OcallError::Pending) => Poll::Pending,
397        Err(err) => Poll::Ready(Err(std::io::Error::from_raw_os_error(err as i32))),
398    }
399}