Skip to main content

tfserver/structures/
transport.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5use tokio::net::TcpStream;
6use tokio_rustls::{client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream};
7
8/// Unified transport wrapper, for different types of streams
9pub struct Transport {
10    inner: Box<dyn AsyncReadWrite>,
11}
12
13/// Trait object to unify AsyncRead + AsyncWrite
14pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}
15impl<T: AsyncRead + AsyncWrite + ?Sized + Send + Sync + Unpin + 'static> AsyncReadWrite for T {}
16
17impl Transport {
18    /// Wrap a plain TcpStream
19    pub fn plain(stream: TcpStream) -> Self {
20        Self {
21            inner: Box::new(stream),
22        }
23    }
24
25    /// Wrap a server-side TLS stream
26    pub fn tls_server(stream: ServerTlsStream<TcpStream>) -> Self {
27        Self {
28            inner: Box::new(stream),
29        }
30    }
31
32    /// Wrap a client-side TLS stream
33    pub fn tls_client(stream: ClientTlsStream<TcpStream>) -> Self {
34        Self {
35            inner: Box::new(stream),
36        }
37    }
38
39    /// Optionally expose inner (if needed)
40    pub fn inner(&mut self) -> &mut dyn AsyncReadWrite {
41        &mut *self.inner
42    }
43}
44
45impl AsyncRead for Transport {
46    fn poll_read(
47        mut self: Pin<&mut Self>,
48        cx: &mut Context<'_>,
49        buf: &mut ReadBuf<'_>,
50    ) -> Poll<io::Result<()>> {
51        Pin::new(&mut *self.inner).poll_read(cx, buf)
52    }
53}
54
55impl AsyncWrite for Transport {
56    fn poll_write(
57        mut self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        buf: &[u8],
60    ) -> Poll<io::Result<usize>> {
61        Pin::new(&mut *self.inner).poll_write(cx, buf)
62    }
63
64    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
65        Pin::new(&mut *self.inner).poll_flush(cx)
66    }
67
68    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
69        Pin::new(&mut *self.inner).poll_shutdown(cx)
70    }
71}