// async_prost/stream.rs

use std::{
    fmt, io, mem,
    ops::{Deref, DerefMut},
    pin::Pin,
    task::{Context, Poll},
};

use futures_core::Stream;
use futures_sink::Sink;
use tokio::{
    io::{AsyncRead, ReadBuf},
    net::{
        tcp::{ReadHalf, WriteHalf},
        TcpStream,
    },
};

use crate::{
    AsyncDestination, AsyncFrameDestination, AsyncProstReader, AsyncProstWriter, SyncDestination,
};

22/// A wrapper around an async stream that receives and sends prost-encoded values
23#[derive(Debug)]
24pub struct AsyncProstStream<S, R, W, D> {
25    stream: AsyncProstReader<InternalAsyncWriter<S, W, D>, R, D>,
26}
27
28#[doc(hidden)]
29pub struct InternalAsyncWriter<S, T, D>(AsyncProstWriter<S, T, D>);
30
31impl<S: fmt::Debug, T, D> fmt::Debug for InternalAsyncWriter<S, T, D> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        self.get_ref().fmt(f)
34    }
35}
36
37impl<S, T, D> Deref for InternalAsyncWriter<S, T, D> {
38    type Target = AsyncProstWriter<S, T, D>;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44impl<S, T, D> DerefMut for InternalAsyncWriter<S, T, D> {
45    fn deref_mut(&mut self) -> &mut Self::Target {
46        &mut self.0
47    }
48}
49
50impl<S, R, W> Default for AsyncProstStream<S, R, W, SyncDestination>
51where
52    S: Default,
53{
54    fn default() -> Self {
55        Self::from(S::default())
56    }
57}
58
59impl<S, R, W> From<S> for AsyncProstStream<S, R, W, SyncDestination> {
60    fn from(stream: S) -> Self {
61        Self {
62            stream: AsyncProstReader::from(InternalAsyncWriter(AsyncProstWriter::from(stream))),
63        }
64    }
65}
66
67impl<S, R, W, D> AsyncProstStream<S, R, W, D> {
68    /// Gets a reference to the underlying stream.
69    ///
70    /// It is inadvisable to directly read from or write to the underlying stream.
71    pub fn get_ref(&self) -> &S {
72        self.stream.get_ref().0.get_ref()
73    }
74
75    /// Gets a mutable reference to the underlying stream.
76    ///
77    /// It is inadvisable to directly read from or write to the underlying stream.
78    pub fn get_mut(&mut self) -> &mut S {
79        self.stream.get_mut().0.get_mut()
80    }
81
82    /// Unwraps this `AsyncProstStream`, returning the underlying stream.
83    ///
84    /// Note that any leftover serialized data that has not yet been sent, or received data that
85    /// has not yet been deserialized, is lost.
86    pub fn into_inner(self) -> S {
87        self.stream.into_inner().0.into_inner()
88    }
89}
90
91impl<S, R, W, D> AsyncProstStream<S, R, W, D> {
92    /// make this stream include the serialized data's size before each serialized value
93    pub fn for_async(self) -> AsyncProstStream<S, R, W, AsyncDestination> {
94        let stream = self.into_inner();
95        AsyncProstStream {
96            stream: AsyncProstReader::from(InternalAsyncWriter(
97                AsyncProstWriter::from(stream).for_async(),
98            )),
99        }
100    }
101
102    /// make this stream include the serialized data's size before each serialized value
103    pub fn for_async_framed(self) -> AsyncProstStream<S, R, W, AsyncFrameDestination> {
104        let stream = self.into_inner();
105        AsyncProstStream {
106            stream: AsyncProstReader::from(InternalAsyncWriter(
107                AsyncProstWriter::from(stream).for_async_framed(),
108            )),
109        }
110    }
111
112    /// Make this stream only send prost-encoded values
113    pub fn for_sync(self) -> AsyncProstStream<S, R, W, SyncDestination> {
114        AsyncProstStream::from(self.into_inner())
115    }
116}
117
118impl<R, W, D> AsyncProstStream<TcpStream, R, W, D> {
119    /// split a TCP-based stream into a read half and a write half
120    pub fn tcp_split(
121        &mut self,
122    ) -> (
123        AsyncProstReader<ReadHalf, R, D>,
124        AsyncProstWriter<WriteHalf, W, D>,
125    ) {
126        // first, steal the reader state so it isn't lost
127        let rbuff = self.stream.buffer.split();
128        // then fish out the writer
129        let writer = &mut self.stream.get_mut().0;
130        // and steal the writer state so it isn't lost
131        let wbuff = writer.buffer.split_off(0);
132        let wsize = writer.written;
133        // now split the stream
134        let (r, w) = writer.get_mut().split();
135        // then put the reader back together
136        let mut reader = AsyncProstReader::from(r);
137        reader.buffer = rbuff;
138        // and then writer
139        let mut writer = AsyncProstWriter::from(w).make_for();
140        writer.buffer = wbuff;
141        writer.written = wsize;
142
143        (reader, writer)
144    }
145}
146
147impl<S, T, D> AsyncRead for InternalAsyncWriter<S, T, D>
148where
149    S: AsyncRead + Unpin,
150{
151    fn poll_read(
152        self: Pin<&mut Self>,
153        cx: &mut Context<'_>,
154        buf: &mut ReadBuf<'_>,
155    ) -> Poll<std::io::Result<()>> {
156        Pin::new(self.get_mut().get_mut()).poll_read(cx, buf)
157    }
158}
159
160impl<S, R, W, D> Stream for AsyncProstStream<S, R, W, D>
161where
162    S: Unpin,
163    AsyncProstReader<InternalAsyncWriter<S, W, D>, R, D>: Stream<Item = Result<R, io::Error>>,
164{
165    type Item = Result<R, io::Error>;
166
167    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
168        Pin::new(&mut self.stream).poll_next(cx)
169    }
170}
171
172impl<S, R, W, D> Sink<W> for AsyncProstStream<S, R, W, D>
173where
174    S: Unpin,
175    AsyncProstWriter<S, W, D>: Sink<W, Error = io::Error>,
176{
177    type Error = io::Error;
178
179    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180        Pin::new(&mut **self.stream.get_mut()).poll_ready(cx)
181    }
182
183    fn start_send(mut self: Pin<&mut Self>, item: W) -> Result<(), Self::Error> {
184        Pin::new(&mut **self.stream.get_mut()).start_send(item)
185    }
186
187    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188        Pin::new(&mut **self.stream.get_mut()).poll_flush(cx)
189    }
190
191    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192        Pin::new(&mut **self.stream.get_mut()).poll_close(cx)
193    }
194}