1use std::{
2 fmt, io,
3 ops::{Deref, DerefMut},
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use futures_core::Stream;
9use futures_sink::Sink;
10use tokio::{
11 io::{AsyncRead, ReadBuf},
12 net::{
13 tcp::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf},
14 TcpStream,
15 },
16};
17
18use crate::{
19 AsyncDestination, AsyncFrameDestination, AsyncProstReader, AsyncProstWriter, SyncDestination,
20};
21
/// A duplex wrapper around an asynchronous stream `S`.
///
/// As a `Stream` it yields `Result<R, io::Error>` items produced by an
/// [`AsyncProstReader`]; as a `Sink<W>` it forwards `W` values to an
/// [`AsyncProstWriter`]. `D` is the destination marker type (e.g.
/// `SyncDestination`, `AsyncDestination`, `AsyncFrameDestination`) shared by
/// the reader and the writer.
#[derive(Debug)]
pub struct AsyncProstStream<S, R, W, D> {
    // Layering: the reader wraps the writer, and the writer owns the underlying
    // stream `S`. Reads pass through the writer to `S` via the `AsyncRead` impl
    // on `InternalAsyncWriter`; sink operations reach the writer through the
    // reader's `get_mut`.
    stream: AsyncProstReader<InternalAsyncWriter<S, W, D>, R, D>,
}
27
// Newtype around `AsyncProstWriter` used as the inner transport of the reader in
// `AsyncProstStream`. It forwards reads to the underlying `S` through an
// `AsyncRead` impl, presumably so that impl is not exposed on the public
// `AsyncProstWriter` type itself (NOTE(review): confirm intent).
#[doc(hidden)]
pub struct InternalAsyncWriter<S, T, D>(AsyncProstWriter<S, T, D>);
30
31impl<S: fmt::Debug, T, D> fmt::Debug for InternalAsyncWriter<S, T, D> {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 self.get_ref().fmt(f)
34 }
35}
36
37impl<S, T, D> Deref for InternalAsyncWriter<S, T, D> {
38 type Target = AsyncProstWriter<S, T, D>;
39
40 fn deref(&self) -> &Self::Target {
41 &self.0
42 }
43}
44impl<S, T, D> DerefMut for InternalAsyncWriter<S, T, D> {
45 fn deref_mut(&mut self) -> &mut Self::Target {
46 &mut self.0
47 }
48}
49
50impl<S, R, W> Default for AsyncProstStream<S, R, W, SyncDestination>
51where
52 S: Default,
53{
54 fn default() -> Self {
55 Self::from(S::default())
56 }
57}
58
59impl<S, R, W> From<S> for AsyncProstStream<S, R, W, SyncDestination> {
60 fn from(stream: S) -> Self {
61 Self {
62 stream: AsyncProstReader::from(InternalAsyncWriter(AsyncProstWriter::from(stream))),
63 }
64 }
65}
66
67impl<S, R, W, D> AsyncProstStream<S, R, W, D> {
68 pub fn get_ref(&self) -> &S {
72 self.stream.get_ref().0.get_ref()
73 }
74
75 pub fn get_mut(&mut self) -> &mut S {
79 self.stream.get_mut().0.get_mut()
80 }
81
82 pub fn into_inner(self) -> S {
87 self.stream.into_inner().0.into_inner()
88 }
89}
90
91impl<S, R, W, D> AsyncProstStream<S, R, W, D> {
92 pub fn for_async(self) -> AsyncProstStream<S, R, W, AsyncDestination> {
94 let stream = self.into_inner();
95 AsyncProstStream {
96 stream: AsyncProstReader::from(InternalAsyncWriter(
97 AsyncProstWriter::from(stream).for_async(),
98 )),
99 }
100 }
101
102 pub fn for_async_framed(self) -> AsyncProstStream<S, R, W, AsyncFrameDestination> {
104 let stream = self.into_inner();
105 AsyncProstStream {
106 stream: AsyncProstReader::from(InternalAsyncWriter(
107 AsyncProstWriter::from(stream).for_async_framed(),
108 )),
109 }
110 }
111
112 pub fn for_sync(self) -> AsyncProstStream<S, R, W, SyncDestination> {
114 AsyncProstStream::from(self.into_inner())
115 }
116}
117
118impl<R, W, D> AsyncProstStream<TcpStream, R, W, D> {
119 pub fn tcp_split(
121 &mut self,
122 ) -> (
123 AsyncProstReader<ReadHalf, R, D>,
124 AsyncProstWriter<WriteHalf, W, D>,
125 ) {
126 let rbuff = self.stream.buffer.split();
128 let writer = &mut self.stream.get_mut().0;
130 let wbuff = writer.buffer.split_off(0);
132 let wsize = writer.written;
133 let (r, w) = writer.get_mut().split();
135 let mut reader = AsyncProstReader::from(r);
137 reader.buffer = rbuff;
138 let mut writer = AsyncProstWriter::from(w).make_for();
140 writer.buffer = wbuff;
141 writer.written = wsize;
142
143 (reader, writer)
144 }
145
146 pub fn tcp_into_split(
150 mut self,
151 ) -> (
152 AsyncProstReader<OwnedReadHalf, R, D>,
153 AsyncProstWriter<OwnedWriteHalf, W, D>,
154 ) {
155 let rbuff = self.stream.buffer.split();
157 let mut writer = self.stream.into_inner().0;
159 let wbuff = writer.buffer.split_off(0);
161 let wsize = writer.written;
162 let (r, w) = writer.into_inner().into_split();
164 let mut reader = AsyncProstReader::from(r);
166 reader.buffer = rbuff;
167 let mut writer = AsyncProstWriter::from(w).make_for();
169 writer.buffer = wbuff;
170 writer.written = wsize;
171
172 (reader, writer)
173 }
174}
175
176impl<S, T, D> AsyncRead for InternalAsyncWriter<S, T, D>
177where
178 S: AsyncRead + Unpin,
179{
180 fn poll_read(
181 self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 buf: &mut ReadBuf<'_>,
184 ) -> Poll<std::io::Result<()>> {
185 Pin::new(self.get_mut().get_mut()).poll_read(cx, buf)
186 }
187}
188
189impl<S, R, W, D> Stream for AsyncProstStream<S, R, W, D>
190where
191 S: Unpin,
192 AsyncProstReader<InternalAsyncWriter<S, W, D>, R, D>: Stream<Item = Result<R, io::Error>>,
193{
194 type Item = Result<R, io::Error>;
195
196 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
197 Pin::new(&mut self.stream).poll_next(cx)
198 }
199}
200
201impl<S, R, W, D> Sink<W> for AsyncProstStream<S, R, W, D>
202where
203 S: Unpin,
204 AsyncProstWriter<S, W, D>: Sink<W, Error = io::Error>,
205{
206 type Error = io::Error;
207
208 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209 Pin::new(&mut **self.stream.get_mut()).poll_ready(cx)
210 }
211
212 fn start_send(mut self: Pin<&mut Self>, item: W) -> Result<(), Self::Error> {
213 Pin::new(&mut **self.stream.get_mut()).start_send(item)
214 }
215
216 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
217 Pin::new(&mut **self.stream.get_mut()).poll_flush(cx)
218 }
219
220 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
221 Pin::new(&mut **self.stream.get_mut()).poll_close(cx)
222 }
223}