1use std::future::Future;
4use std::io::Error;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use env::tls::TlsServerConfig;
10
11use crate::env::{self, tasks, Result};
12use crate::{ocall, ResourceId};
13
14pub struct TcpListener {
16 res_id: ResourceId,
17}
18
/// Future returned by [`TcpStream::connect`].
///
/// It stores the outcome of the initial connect ocall. Polling it either
/// reports the stored error right away or polls the stored resource via
/// `ocall::poll_res` until a [`TcpStream`] can be produced.
#[derive(Debug)]
pub struct TcpConnector {
    // Result of the `tcp_connect` / `tcp_connect_tls` ocall: the id of the
    // in-progress connection resource, or the error that ocall returned.
    res: Result<ResourceId, env::OcallError>,
}
24
/// Handle to a TCP connection.
///
/// Obtained either by awaiting the [`TcpConnector`] returned from
/// [`TcpStream::connect`] or from [`TcpListener::accept`]. Reads and writes
/// are provided by the async I/O trait implementations further down in this
/// file.
#[derive(Debug)]
pub struct TcpStream {
    // Resource id passed to the read/write/shutdown ocalls.
    res_id: ResourceId,
}
30
/// Future returned by [`TcpListener::accept`].
///
/// Borrows the listener for as long as the accept operation is pending.
pub struct Acceptor<'a> {
    // Listener whose resource id is handed to `ocall::tcp_accept` on each poll.
    listener: &'a TcpListener,
}
35
36impl Future for Acceptor<'_> {
37 type Output = Result<(TcpStream, SocketAddr)>;
38
39 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 use env::OcallError;
41 let waker_id = tasks::intern_waker(cx.waker().clone());
42 match ocall::tcp_accept(waker_id, self.listener.res_id.0) {
43 Ok((res_id, remote_addr)) => Poll::Ready(Ok((
44 TcpStream::new(ResourceId(res_id)),
45 remote_addr
46 .parse()
47 .expect("ocall::tcp_accept returned an invalid remote address"),
48 ))),
49 Err(OcallError::Pending) => Poll::Pending,
50 Err(err) => Poll::Ready(Err(err)),
51 }
52 }
53}
54
55impl TcpListener {
56 pub async fn bind(addr: &str) -> Result<Self> {
58 let todo = "prevent local interface probing and port occupation";
61 let res_id = ResourceId(ocall::tcp_listen(addr.into(), None)?);
62 Ok(Self { res_id })
63 }
64
65 pub async fn bind_tls(addr: &str, config: TlsServerConfig) -> Result<Self> {
67 let res_id = ResourceId(ocall::tcp_listen(addr.into(), Some(config))?);
68 Ok(Self { res_id })
69 }
70
71 pub fn accept(&self) -> Acceptor {
73 Acceptor { listener: self }
74 }
75}
76
77impl Future for TcpConnector {
78 type Output = Result<TcpStream>;
79
80 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
81 use env::OcallError;
82
83 let res_id = match &self.get_mut().res {
84 Ok(res_id) => res_id,
85 Err(err) => return Poll::Ready(Err(*err)),
86 };
87
88 match ocall::poll_res(env::tasks::intern_waker(ctx.waker().clone()), res_id.0) {
89 Ok(res_id) => Poll::Ready(Ok(TcpStream::new(ResourceId(res_id)))),
90 Err(OcallError::Pending) => Poll::Pending,
91 Err(err) => Poll::Ready(Err(err)),
92 }
93 }
94}
95
96impl TcpStream {
97 fn new(res_id: ResourceId) -> Self {
98 Self { res_id }
99 }
100
101 pub fn connect(host: &str, port: u16, enable_tls: bool) -> TcpConnector {
103 let res = if enable_tls {
104 ocall::tcp_connect_tls(host.into(), port, env::tls::TlsClientConfig::V0)
105 } else {
106 ocall::tcp_connect(host, port)
107 };
108 let res = res.map(|res_id| ResourceId(res_id));
109 TcpConnector { res }
110 }
111}
112
113#[cfg(feature = "hyper")]
114pub use impl_hyper::{AddrIncoming, AddrStream, HttpConnector};
#[cfg(feature = "hyper")]
mod impl_hyper {
    // Adapters that let the TCP types of this file be used with hyper:
    // `Accept` implementations for the server side and a `Service<Uri>`
    // connector for the client side.

    use super::*;
    use env::OcallError;
    use hyper::client::connect::{Connected, Connection};
    use hyper::server::accept::Accept;
    use hyper::{service::Service, Uri};
    use std::{io, task};

    // Unwraps the value of a ready, successful poll inside a `poll_accept`
    // body. Every other outcome returns from the enclosing function:
    // `Pending` stays pending, `EndOfFile` ends the accept stream (`None`),
    // and any other error is yielded as `Some(Err(..))`.
    macro_rules! ready_ok {
        ($polled:expr) => {
            match $polled {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Err(OcallError::EndOfFile)) => return Poll::Ready(None),
                Poll::Ready(Err(other)) => return Poll::Ready(Some(Err(other))),
                Poll::Ready(Ok(value)) => value,
            }
        };
    }

    impl Accept for TcpListener {
        type Conn = TcpStream;

        type Error = env::OcallError;

        /// Accepts the next connection, discarding the peer address.
        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            // A fresh accept future is created on every poll.
            let mut pending_accept = self.accept();
            let (conn, _peer) = ready_ok!(Pin::new(&mut pending_accept).poll(cx));
            Poll::Ready(Some(Ok(conn)))
        }
    }

    impl Connection for TcpStream {
        /// Reports default connection metadata to hyper.
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    impl TcpListener {
        /// Converts the listener into an [`AddrIncoming`], whose accepted
        /// connections also carry the peer address.
        pub fn into_addr_incoming(self) -> AddrIncoming {
            AddrIncoming { listener: self }
        }
    }

    /// Stateless hyper connector that opens connections through
    /// [`TcpStream::connect`].
    #[derive(Clone, Default, Debug)]
    pub struct HttpConnector;

    impl HttpConnector {
        /// Creates a connector.
        pub fn new() -> Self {
            HttpConnector
        }
    }

    impl Service<Uri> for HttpConnector {
        type Response = TcpStream;
        type Error = OcallError;
        type Future = TcpConnector;

        /// The connector holds no state, so it is always ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Starts a connection to the host and port named by `dst`.
        ///
        /// TLS is enabled exactly when the scheme is `https`. A missing port
        /// falls back to 443 for `https` and 80 otherwise, and a missing host
        /// is passed on as the empty string.
        fn call(&mut self, dst: Uri) -> Self::Future {
            let use_tls = matches!(dst.scheme_str(), Some("https"));
            let fallback_port = if use_tls { 443 } else { 80 };
            let port = dst.port_u16().unwrap_or(fallback_port);

            // Strip the square brackets that surround IPv6 literals in URIs.
            let raw_host = dst.host().unwrap_or("");
            let host = raw_host.trim_matches(|c| c == '[' || c == ']');

            TcpStream::connect(host, port, use_tls)
        }
    }

    /// Accept stream that yields [`AddrStream`]s, i.e. connections paired
    /// with the address of their peer.
    pub struct AddrIncoming {
        // Listener the connections are accepted from.
        listener: TcpListener,
    }

    impl Accept for AddrIncoming {
        type Conn = AddrStream;
        type Error = env::OcallError;

        /// Accepts the next connection and keeps the peer address with it.
        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> task::Poll<Option<Result<Self::Conn, Self::Error>>> {
            let mut pending_accept = self.listener.accept();
            let (stream, remote_addr) = ready_ok!(Pin::new(&mut pending_accept).poll(cx));
            let conn = AddrStream {
                stream,
                remote_addr,
            };
            Poll::Ready(Some(Ok(conn)))
        }
    }

    /// A [`TcpStream`] together with the address of the remote peer.
    #[pin_project::pin_project]
    pub struct AddrStream {
        // All I/O is forwarded to this stream.
        #[pin]
        stream: TcpStream,
        // Peer address as reported when the connection was accepted.
        remote_addr: SocketAddr,
    }

    impl AddrStream {
        /// Returns the address of the remote peer.
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }
    }

    // tokio I/O traits for `AddrStream`; every call is forwarded to the
    // wrapped `TcpStream`. Only the tokio traits are imported in this scope,
    // which is what the method calls below resolve to.
    #[cfg(feature = "tokio")]
    const _: () = {
        use tokio::io::{AsyncRead, AsyncWrite};

        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> task::Poll<io::Result<()>> {
                let inner = self.project().stream;
                inner.poll_read(cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
                buf: &[u8],
            ) -> task::Poll<io::Result<usize>> {
                let inner = self.project().stream;
                inner.poll_write(cx, buf)
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
            ) -> task::Poll<io::Result<()>> {
                let inner = self.project().stream;
                inner.poll_flush(cx)
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
            ) -> task::Poll<io::Result<()>> {
                let inner = self.project().stream;
                inner.poll_shutdown(cx)
            }
        }
    };

    // futures I/O traits for `AddrStream`; every call is forwarded to the
    // wrapped `TcpStream`. Only the futures traits are imported in this
    // scope, which is what the method calls below resolve to.
    const _: () = {
        use futures::{AsyncRead, AsyncWrite};

        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<io::Result<usize>> {
                let inner = self.project().stream;
                inner.poll_read(cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let inner = self.project().stream;
                inner.poll_write(cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let inner = self.project().stream;
                inner.poll_flush(cx)
            }

            fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let inner = self.project().stream;
                inner.poll_close(cx)
            }
        }
    };
}
299
#[cfg(feature = "tokio")]
mod impl_tokio {
    // tokio `AsyncRead` / `AsyncWrite` for `TcpStream`, implemented on top
    // of the `poll_read`, `poll_write` and `poll_shutdown` ocalls.

    use super::*;
    use tokio::io::{AsyncRead, AsyncWrite};

    /// Largest number of bytes requested from the host in one `poll_read`.
    // NOTE(review): the value 512 is carried over unchanged; the reason for
    // this particular limit is not visible in this file.
    const MAX_READ_CHUNK: usize = 512;

    impl AsyncRead for TcpStream {
        /// Reads at most `MAX_READ_CHUNK` bytes into the unfilled part of
        /// `buf`.
        ///
        /// # Errors
        ///
        /// Ocall errors other than `Pending` are converted with
        /// `Error::from_raw_os_error`. A reported length larger than the
        /// slice that was handed to the ocall is rejected as
        /// `OcallError::InvalidEncoding`.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            use env::OcallError;

            // Number of bytes actually offered to the ocall.
            let size = buf.remaining().min(MAX_READ_CHUNK);
            let result = {
                let unfilled = buf.initialize_unfilled_to(size);
                let waker_id = tasks::intern_waker(cx.waker().clone());
                ocall::poll_read(waker_id, self.res_id.0, unfilled)
            };
            match result {
                Ok(len) => {
                    let len = len as usize;
                    // Validate against the slice that was offered, not the
                    // whole remaining capacity. Only `size` bytes can have
                    // been written, and `ReadBuf::advance` panics when asked
                    // to move past the initialized region.
                    if len > size {
                        Poll::Ready(Err(Error::from_raw_os_error(
                            OcallError::InvalidEncoding as i32,
                        )))
                    } else {
                        buf.advance(len);
                        Poll::Ready(Ok(()))
                    }
                }
                Err(OcallError::Pending) => Poll::Pending,
                Err(err) => Poll::Ready(Err(Error::from_raw_os_error(err as i32))),
            }
        }
    }

    impl AsyncWrite for TcpStream {
        /// Writes `buf` through the `poll_write` ocall and reports the
        /// number of bytes it returned.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, Error>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
        }

        /// Always ready: this wrapper performs no buffering of its own.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            Poll::Ready(Ok(()))
        }

        /// Shuts the stream down through the `poll_shutdown` ocall.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
        }
    }
}
356
357mod impl_futures_io {
358 use super::*;
359 use futures::io::{AsyncRead, AsyncWrite};
360
361 impl AsyncRead for TcpStream {
362 fn poll_read(
363 self: Pin<&mut Self>,
364 cx: &mut Context<'_>,
365 buf: &mut [u8],
366 ) -> Poll<std::io::Result<usize>> {
367 let waker_id = tasks::intern_waker(cx.waker().clone());
368 into_poll(ocall::poll_read(waker_id, self.res_id.0, buf).map(|len| len as usize))
369 }
370 }
371
372 impl AsyncWrite for TcpStream {
373 fn poll_write(
374 self: Pin<&mut Self>,
375 cx: &mut Context<'_>,
376 buf: &[u8],
377 ) -> Poll<std::io::Result<usize>> {
378 let waker_id = tasks::intern_waker(cx.waker().clone());
379 into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
380 }
381
382 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
383 Poll::Ready(Ok(()))
384 }
385
386 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
387 let waker_id = tasks::intern_waker(cx.waker().clone());
388 into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
389 }
390 }
391}
392
393fn into_poll<T>(res: Result<T, env::OcallError>) -> Poll<std::io::Result<T>> {
394 match res {
395 Ok(v) => Poll::Ready(Ok(v)),
396 Err(env::OcallError::Pending) => Poll::Pending,
397 Err(err) => Poll::Ready(Err(std::io::Error::from_raw_os_error(err as i32))),
398 }
399}