1use std::{
2    fmt, io,
3    ops::{Deref, DerefMut},
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures_core::Stream;
9use futures_sink::Sink;
10use tokio::{
11    io::{AsyncRead, ReadBuf},
12    net::{
13        tcp::{ReadHalf, WriteHalf},
14        TcpStream,
15    },
16};
17
18use crate::{
19    AsyncDestination, AsyncFrameDestination, AsyncProstReader, AsyncProstWriter, SyncDestination,
20};
21
/// A duplex adapter over an I/O object `S`: it yields `R` values as a `Stream`
/// and accepts `W` values as a `Sink`.
///
/// `D` is a destination marker type (`SyncDestination`, `AsyncDestination`,
/// `AsyncFrameDestination`) that selects the reader/writer mode; switch it
/// with `for_sync`, `for_async` or `for_async_framed`.
#[derive(Debug)]
pub struct AsyncProstStream<S, R, W, D> {
    // The writer wraps `S`, and the reader wraps the writer. Reads reach `S`
    // through `InternalAsyncWriter`'s `AsyncRead` forwarding impl.
    stream: AsyncProstReader<InternalAsyncWriter<S, W, D>, R, D>,
}
27
/// Newtype around `AsyncProstWriter` that also implements `AsyncRead` by
/// forwarding reads to the wrapped `S`, so an `AsyncProstReader` can sit on
/// top of it.
///
/// NOTE(review): presumably `pub` (but hidden from docs) only because it
/// appears in the where-clause of `AsyncProstStream`'s public `Stream` impl.
#[doc(hidden)]
pub struct InternalAsyncWriter<S, T, D>(AsyncProstWriter<S, T, D>);
30
31impl<S: fmt::Debug, T, D> fmt::Debug for InternalAsyncWriter<S, T, D> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        self.get_ref().fmt(f)
34    }
35}
36
37impl<S, T, D> Deref for InternalAsyncWriter<S, T, D> {
38    type Target = AsyncProstWriter<S, T, D>;
39
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44impl<S, T, D> DerefMut for InternalAsyncWriter<S, T, D> {
45    fn deref_mut(&mut self) -> &mut Self::Target {
46        &mut self.0
47    }
48}
49
50impl<S, R, W> Default for AsyncProstStream<S, R, W, SyncDestination>
51where
52    S: Default,
53{
54    fn default() -> Self {
55        Self::from(S::default())
56    }
57}
58
59impl<S, R, W> From<S> for AsyncProstStream<S, R, W, SyncDestination> {
60    fn from(stream: S) -> Self {
61        Self {
62            stream: AsyncProstReader::from(InternalAsyncWriter(AsyncProstWriter::from(stream))),
63        }
64    }
65}
66
67impl<S, R, W, D> AsyncProstStream<S, R, W, D> {
68    pub fn get_ref(&self) -> &S {
72        self.stream.get_ref().0.get_ref()
73    }
74
75    pub fn get_mut(&mut self) -> &mut S {
79        self.stream.get_mut().0.get_mut()
80    }
81
82    pub fn into_inner(self) -> S {
87        self.stream.into_inner().0.into_inner()
88    }
89}
90
91impl<S, R, W, D> AsyncProstStream<S, R, W, D> {
92    pub fn for_async(self) -> AsyncProstStream<S, R, W, AsyncDestination> {
94        let stream = self.into_inner();
95        AsyncProstStream {
96            stream: AsyncProstReader::from(InternalAsyncWriter(
97                AsyncProstWriter::from(stream).for_async(),
98            )),
99        }
100    }
101
102    pub fn for_async_framed(self) -> AsyncProstStream<S, R, W, AsyncFrameDestination> {
104        let stream = self.into_inner();
105        AsyncProstStream {
106            stream: AsyncProstReader::from(InternalAsyncWriter(
107                AsyncProstWriter::from(stream).for_async_framed(),
108            )),
109        }
110    }
111
112    pub fn for_sync(self) -> AsyncProstStream<S, R, W, SyncDestination> {
114        AsyncProstStream::from(self.into_inner())
115    }
116}
117
118impl<R, W, D> AsyncProstStream<TcpStream, R, W, D> {
119    pub fn tcp_split(
121        &mut self,
122    ) -> (
123        AsyncProstReader<ReadHalf, R, D>,
124        AsyncProstWriter<WriteHalf, W, D>,
125    ) {
126        let rbuff = self.stream.buffer.split();
128        let writer = &mut self.stream.get_mut().0;
130        let wbuff = writer.buffer.split_off(0);
132        let wsize = writer.written;
133        let (r, w) = writer.get_mut().split();
135        let mut reader = AsyncProstReader::from(r);
137        reader.buffer = rbuff;
138        let mut writer = AsyncProstWriter::from(w).make_for();
140        writer.buffer = wbuff;
141        writer.written = wsize;
142
143        (reader, writer)
144    }
145}
146
147impl<S, T, D> AsyncRead for InternalAsyncWriter<S, T, D>
148where
149    S: AsyncRead + Unpin,
150{
151    fn poll_read(
152        self: Pin<&mut Self>,
153        cx: &mut Context<'_>,
154        buf: &mut ReadBuf<'_>,
155    ) -> Poll<std::io::Result<()>> {
156        Pin::new(self.get_mut().get_mut()).poll_read(cx, buf)
157    }
158}
159
160impl<S, R, W, D> Stream for AsyncProstStream<S, R, W, D>
161where
162    S: Unpin,
163    AsyncProstReader<InternalAsyncWriter<S, W, D>, R, D>: Stream<Item = Result<R, io::Error>>,
164{
165    type Item = Result<R, io::Error>;
166
167    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
168        Pin::new(&mut self.stream).poll_next(cx)
169    }
170}
171
172impl<S, R, W, D> Sink<W> for AsyncProstStream<S, R, W, D>
173where
174    S: Unpin,
175    AsyncProstWriter<S, W, D>: Sink<W, Error = io::Error>,
176{
177    type Error = io::Error;
178
179    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180        Pin::new(&mut **self.stream.get_mut()).poll_ready(cx)
181    }
182
183    fn start_send(mut self: Pin<&mut Self>, item: W) -> Result<(), Self::Error> {
184        Pin::new(&mut **self.stream.get_mut()).start_send(item)
185    }
186
187    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188        Pin::new(&mut **self.stream.get_mut()).poll_flush(cx)
189    }
190
191    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192        Pin::new(&mut **self.stream.get_mut()).poll_close(cx)
193    }
194}