Skip to main content

enclave_runner/stream_router/
os.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7use std::cmp;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::io::{self, ErrorKind as IoErrorKind, Read, Result as IoResult};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::thread;
14
15use futures::FutureExt;
16use futures::lock::Mutex;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use tokio::sync::{mpsc as async_mpsc};
20
21use super::{AsyncListener, AsyncStream, StreamRouter};
22
/// Stream router backed directly by the host operating system.
///
/// Basic streams 0, 1 and 2 map to the host's stdin, stdout and stderr.
///
/// A connect request opens an outbound TCP connection, and a bind request
/// opens an incoming TCP listener.
pub struct OsStreamRouter {
    // Prevents construction outside this module; use `OsStreamRouter::new`.
    _private: (),
}
33
34impl OsStreamRouter {
35    pub fn new() -> Box<dyn StreamRouter + Send + Sync> {
36        Box::new(OsStreamRouter {
37            _private: ()
38        })
39    }
40}
41
/// Writes the textual form of the socket address produced by `f` into `dst`.
///
/// Nothing happens, and `f` is never called, when `dst` is `None`. If `f`
/// returns an error, `dst` is set to the literal string `"error"`.
fn set_opt_from_sockaddr(dst: Option<&mut String>, f: impl FnOnce() -> IoResult<std::net::SocketAddr>) {
    if let Some(slot) = dst {
        *slot = f().map_or_else(|_| String::from("error"), |addr| addr.to_string());
    }
}
50
51impl StreamRouter for OsStreamRouter {
52    fn basic_streams(&self) -> Vec<Box<dyn AsyncStream>> {
53        vec![
54            Box::new(ReadOnly { inner: Stdin }),
55            Box::new(WriteOnly { inner: tokio::io::stdout() }),
56            Box::new(WriteOnly { inner: tokio::io::stderr() }),
57        ]
58    }
59
60    fn connect_stream<'future>(
61        &'future self,
62        addr: &'future str,
63        local_addr: Option<&'future mut String>,
64        peer_addr: Option<&'future mut String>,
65    ) -> std::pin::Pin<Box<dyn Future<Output = IoResult<Box<dyn AsyncStream>>> + Send +'future>> {
66        (async move {
67            let stream = tokio::net::TcpStream::connect(addr).await?;
68
69            set_opt_from_sockaddr(local_addr, || stream.local_addr());
70            set_opt_from_sockaddr(peer_addr, || stream.peer_addr());
71            
72            Ok(Box::new(stream) as _)
73        }).boxed()
74    }
75
76    fn bind_stream<'future>(
77        &'future self,
78        addr: &'future str,
79        local_addr: Option<&'future mut String>,
80    ) -> std::pin::Pin<Box<dyn Future<Output = IoResult<Box<dyn AsyncListener>>> + Send + 'future>> {
81        (async move {
82            let socket = tokio::net::TcpListener::bind(addr).await?;
83
84            set_opt_from_sockaddr(local_addr, || socket.local_addr());
85
86            Ok(Box::new(socket) as _)
87        }).boxed()
88    }
89}
90
91impl AsyncListener for tokio::net::TcpListener {
92    fn poll_accept(
93        self: Pin<&mut Self>,
94        cx: &mut Context,
95        local_addr: Option<&mut String>,
96        peer_addr: Option<&mut String>,
97    ) -> Poll<tokio::io::Result<Box<dyn AsyncStream>>> {
98        tokio::net::TcpListener::poll_accept(&self, cx).map_ok(|(stream, stream_peer_addr)| {
99                set_opt_from_sockaddr(local_addr, || stream.local_addr());
100                set_opt_from_sockaddr(peer_addr, move || Ok(stream_peer_addr));
101
102                Box::new(stream) as _
103        })
104    }
105}
106
107pin_project! {
108    struct ReadOnly<R> {
109        #[pin]
110        inner: R
111    }
112}
113pin_project! {
114    struct WriteOnly<W> {
115        #[pin]
116        inner: W
117    }
118}
119
/// Generates a method that delegates to the same-named method on the `inner`
/// field returned by `self.project()`.
///
/// The leading `mut self: Pin<&mut Self>` is matched literally. The generated
/// method takes `self: Pin<&mut Self>`, and every listed parameter is passed
/// through unchanged.
macro_rules! forward {
    (fn $method:ident(mut self: Pin<&mut Self> $(, $arg:ident : $arg_ty:ty)*) -> $output:ty) => {
        fn $method(self: Pin<&mut Self> $(, $arg: $arg_ty)*) -> $output {
            let projected = self.project();
            projected.inner.$method($($arg),*)
        }
    };
}
127
128impl<R: std::marker::Unpin + AsyncRead> AsyncRead for ReadOnly<R> {
129    forward!(fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<tokio::io::Result<()>>);
130}
131
132impl<T> AsyncRead for WriteOnly<T> {
133    fn poll_read(self: Pin<&mut Self>, _cx: &mut Context, _buf: &mut ReadBuf) -> Poll<tokio::io::Result<()>> {
134        Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
135    }
136}
137
138impl<T> AsyncWrite for ReadOnly<T> {
139    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context, _buf: &[u8]) -> Poll<tokio::io::Result<usize>> {
140        Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
141    }
142
143    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<tokio::io::Result<()>> {
144        Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
145    }
146
147    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<tokio::io::Result<()>> {
148        Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
149    }
150}
151
152impl<W: std::marker::Unpin + AsyncWrite> AsyncWrite for WriteOnly<W> {
153    forward!(fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<tokio::io::Result<usize>>);
154    forward!(fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>>);
155    forward!(fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>>);
156}
157
158struct Stdin;
159
160impl AsyncRead for Stdin {
161    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<tokio::io::Result<()>> {
162        const BUF_SIZE: usize = 8192;
163
164        struct AsyncStdin {
165            rx: async_mpsc::Receiver<VecDeque<u8>>,
166            buf: VecDeque<u8>,
167        }
168
169        lazy_static::lazy_static! {
170            static ref STDIN: Mutex<AsyncStdin> = {
171                let (tx, rx) = async_mpsc::channel(8);
172                thread::spawn(move || {
173                    let mut buf = [0u8; BUF_SIZE];
174                    while let Ok(len) = io::stdin().read(&mut buf) {
175                        if len == 0 {
176                            continue
177                        }
178
179                        if tx.try_send(buf[..len].to_vec().into()).is_err() {
180                            return
181                        };
182                    }
183                });
184                Mutex::new(AsyncStdin { rx, buf: VecDeque::new() })
185            };
186        }
187
188        match Pin::new(&mut STDIN.lock()).poll(cx) {
189            Poll::Ready(mut stdin) => {
190                if stdin.buf.is_empty() {
191                    let pipeerr = tokio::io::Error::new(tokio::io::ErrorKind::BrokenPipe, "broken pipe");
192                    stdin.buf = match Pin::new(&mut stdin.rx).poll_recv(cx) {
193                        Poll::Ready(Some(vec)) => vec,
194                        Poll::Ready(None) => return Poll::Ready(Err(pipeerr)),
195                        _ => return Poll::Pending,
196                    };
197                }
198                let inbuf = match stdin.buf.as_slices() {
199                    (&[], inbuf) => inbuf,
200                    (inbuf, _) => inbuf,
201                };
202                let len = cmp::min(buf.remaining(), inbuf.len());
203                buf.put_slice(&inbuf[..len]);
204                stdin.buf.drain(..len);
205                Poll::Ready(Ok(()))
206            }
207            Poll::Pending => Poll::Pending
208        }
209    }
210}