enclave_runner/stream_router/
os.rs1use std::cmp;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::io::{self, ErrorKind as IoErrorKind, Read, Result as IoResult};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::thread;
14
15use futures::FutureExt;
16use futures::lock::Mutex;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use tokio::sync::{mpsc as async_mpsc};
20
21use super::{AsyncListener, AsyncStream, StreamRouter};
22
23pub struct OsStreamRouter {
31 _private: ()
32}
33
34impl OsStreamRouter {
35 pub fn new() -> Box<dyn StreamRouter + Send + Sync> {
36 Box::new(OsStreamRouter {
37 _private: ()
38 })
39 }
40}
41
/// Writes the textual form of a socket address into `dst`, if a destination
/// was supplied.
///
/// The address is produced lazily by `f`, which is only invoked when `dst` is
/// `Some`. If `f` returns an error, the destination receives the literal
/// string `"error"`; the error itself is discarded, not propagated.
fn set_opt_from_sockaddr(dst: Option<&mut String>, f: impl FnOnce() -> IoResult<std::net::SocketAddr>) {
    if let Some(slot) = dst {
        *slot = f().map_or_else(|_| "error".into(), |addr| addr.to_string());
    }
}
50
51impl StreamRouter for OsStreamRouter {
52 fn basic_streams(&self) -> Vec<Box<dyn AsyncStream>> {
53 vec![
54 Box::new(ReadOnly { inner: Stdin }),
55 Box::new(WriteOnly { inner: tokio::io::stdout() }),
56 Box::new(WriteOnly { inner: tokio::io::stderr() }),
57 ]
58 }
59
60 fn connect_stream<'future>(
61 &'future self,
62 addr: &'future str,
63 local_addr: Option<&'future mut String>,
64 peer_addr: Option<&'future mut String>,
65 ) -> std::pin::Pin<Box<dyn Future<Output = IoResult<Box<dyn AsyncStream>>> + Send +'future>> {
66 (async move {
67 let stream = tokio::net::TcpStream::connect(addr).await?;
68
69 set_opt_from_sockaddr(local_addr, || stream.local_addr());
70 set_opt_from_sockaddr(peer_addr, || stream.peer_addr());
71
72 Ok(Box::new(stream) as _)
73 }).boxed()
74 }
75
76 fn bind_stream<'future>(
77 &'future self,
78 addr: &'future str,
79 local_addr: Option<&'future mut String>,
80 ) -> std::pin::Pin<Box<dyn Future<Output = IoResult<Box<dyn AsyncListener>>> + Send + 'future>> {
81 (async move {
82 let socket = tokio::net::TcpListener::bind(addr).await?;
83
84 set_opt_from_sockaddr(local_addr, || socket.local_addr());
85
86 Ok(Box::new(socket) as _)
87 }).boxed()
88 }
89}
90
91impl AsyncListener for tokio::net::TcpListener {
92 fn poll_accept(
93 self: Pin<&mut Self>,
94 cx: &mut Context,
95 local_addr: Option<&mut String>,
96 peer_addr: Option<&mut String>,
97 ) -> Poll<tokio::io::Result<Box<dyn AsyncStream>>> {
98 tokio::net::TcpListener::poll_accept(&self, cx).map_ok(|(stream, stream_peer_addr)| {
99 set_opt_from_sockaddr(local_addr, || stream.local_addr());
100 set_opt_from_sockaddr(peer_addr, move || Ok(stream_peer_addr));
101
102 Box::new(stream) as _
103 })
104 }
105}
106
107pin_project! {
108 struct ReadOnly<R> {
109 #[pin]
110 inner: R
111 }
112}
113pin_project! {
114 struct WriteOnly<W> {
115 #[pin]
116 inner: W
117 }
118}
119
/// Generates a method that delegates to the same-named method on the pinned
/// `inner` field, reached through the type's `project()` projection.
///
/// The literal `mut self` in the pattern is matched but not reproduced; the
/// generated method takes plain `self: Pin<&mut Self>`. Each extra
/// `name: Type` parameter is passed through to the inner call in order.
macro_rules! forward {
    (fn $method:ident(mut self: Pin<&mut Self> $(, $arg:ident : $arg_ty:ty)*) -> $ret:ty) => {
        fn $method(self: Pin<&mut Self> $(, $arg: $arg_ty)*) -> $ret {
            let projected = self.project();
            projected.inner.$method($($arg),*)
        }
    };
}
127
128impl<R: std::marker::Unpin + AsyncRead> AsyncRead for ReadOnly<R> {
129 forward!(fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<tokio::io::Result<()>>);
130}
131
132impl<T> AsyncRead for WriteOnly<T> {
133 fn poll_read(self: Pin<&mut Self>, _cx: &mut Context, _buf: &mut ReadBuf) -> Poll<tokio::io::Result<()>> {
134 Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
135 }
136}
137
138impl<T> AsyncWrite for ReadOnly<T> {
139 fn poll_write(self: Pin<&mut Self>, _cx: &mut Context, _buf: &[u8]) -> Poll<tokio::io::Result<usize>> {
140 Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
141 }
142
143 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<tokio::io::Result<()>> {
144 Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
145 }
146
147 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<tokio::io::Result<()>> {
148 Poll::Ready(Err(IoErrorKind::BrokenPipe.into()))
149 }
150}
151
152impl<W: std::marker::Unpin + AsyncWrite> AsyncWrite for WriteOnly<W> {
153 forward!(fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<tokio::io::Result<usize>>);
154 forward!(fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>>);
155 forward!(fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>>);
156}
157
158struct Stdin;
159
160impl AsyncRead for Stdin {
161 fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<tokio::io::Result<()>> {
162 const BUF_SIZE: usize = 8192;
163
164 struct AsyncStdin {
165 rx: async_mpsc::Receiver<VecDeque<u8>>,
166 buf: VecDeque<u8>,
167 }
168
169 lazy_static::lazy_static! {
170 static ref STDIN: Mutex<AsyncStdin> = {
171 let (tx, rx) = async_mpsc::channel(8);
172 thread::spawn(move || {
173 let mut buf = [0u8; BUF_SIZE];
174 while let Ok(len) = io::stdin().read(&mut buf) {
175 if len == 0 {
176 continue
177 }
178
179 if tx.try_send(buf[..len].to_vec().into()).is_err() {
180 return
181 };
182 }
183 });
184 Mutex::new(AsyncStdin { rx, buf: VecDeque::new() })
185 };
186 }
187
188 match Pin::new(&mut STDIN.lock()).poll(cx) {
189 Poll::Ready(mut stdin) => {
190 if stdin.buf.is_empty() {
191 let pipeerr = tokio::io::Error::new(tokio::io::ErrorKind::BrokenPipe, "broken pipe");
192 stdin.buf = match Pin::new(&mut stdin.rx).poll_recv(cx) {
193 Poll::Ready(Some(vec)) => vec,
194 Poll::Ready(None) => return Poll::Ready(Err(pipeerr)),
195 _ => return Poll::Pending,
196 };
197 }
198 let inbuf = match stdin.buf.as_slices() {
199 (&[], inbuf) => inbuf,
200 (inbuf, _) => inbuf,
201 };
202 let len = cmp::min(buf.remaining(), inbuf.len());
203 buf.put_slice(&inbuf[..len]);
204 stdin.buf.drain(..len);
205 Poll::Ready(Ok(()))
206 }
207 Poll::Pending => Poll::Pending
208 }
209 }
210}