narrowlink_network/
async_tools.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::{error::NetworkError, AsyncSocket, UniversalStream};
8use futures_util::{Future, Sink, SinkExt, Stream, StreamExt};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11pub struct AsyncToStream {
12    socket: Box<dyn AsyncSocket>,
13    buffer: Option<(usize, Vec<u8>)>,
14}
15
16impl AsyncToStream {
17    pub fn new(socket: impl AsyncSocket) -> Self {
18        Self {
19            socket: Box::new(socket),
20            buffer: None,
21        }
22    }
23}
24
25impl Stream for AsyncToStream {
26    type Item = Result<Vec<u8>, NetworkError>;
27
28    fn poll_next(
29        mut self: Pin<&mut Self>,
30        cx: &mut std::task::Context<'_>,
31    ) -> Poll<Option<Self::Item>> {
32        let mut buf = [0u8; 65536];
33        let mut buffer = ReadBuf::new(&mut buf);
34        match Pin::new(&mut self.socket).poll_read(cx, &mut buffer)? {
35            Poll::Ready(_) => {
36                if buffer.filled().is_empty() {
37                    Poll::Ready(None)
38                } else {
39                    Poll::Ready(Some(Ok(buffer.filled().to_vec())))
40                }
41            }
42            Poll::Pending => Poll::Pending,
43        }
44    }
45}
46
47impl Sink<Vec<u8>> for AsyncToStream {
48    type Error = NetworkError;
49
50    fn poll_ready(
51        mut self: Pin<&mut Self>,
52        cx: &mut std::task::Context<'_>,
53    ) -> Poll<Result<(), Self::Error>> {
54        if let Some((mut len, buffer)) = self.buffer.take() {
55            loop {
56                len = match Pin::new(&mut self.socket).poll_write(cx, &buffer)? {
57                    Poll::Ready(written) => written,
58                    Poll::Pending => return Poll::Pending,
59                };
60                if len == buffer.len() {
61                    break;
62                }
63            }
64        }
65
66        Poll::Ready(Ok(()))
67    }
68
69    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
70        self.buffer = Some((0, item));
71        Ok(())
72    }
73
74    fn poll_flush(
75        mut self: Pin<&mut Self>,
76        cx: &mut std::task::Context<'_>,
77    ) -> Poll<Result<(), Self::Error>> {
78        let _ = Pin::new(&mut self).poll_ready(cx)?;
79        Pin::new(&mut self.socket)
80            .poll_flush(cx)
81            .map_err(|e| e.into())
82    }
83
84    fn poll_close(
85        mut self: Pin<&mut Self>,
86        cx: &mut std::task::Context<'_>,
87    ) -> Poll<Result<(), Self::Error>> {
88        let _ = Pin::new(&mut self).poll_ready(cx)?;
89        Pin::new(&mut self.socket)
90            .poll_shutdown(cx)
91            .map_err(|e| e.into())
92    }
93}
94
95pub struct StreamToAsync {
96    stream: Box<dyn UniversalStream<Vec<u8>, NetworkError>>,
97    remaining_bytes: Option<Vec<u8>>,
98}
99impl StreamToAsync {
100    pub fn new(socket: impl UniversalStream<Vec<u8>, NetworkError>) -> Self {
101        Self {
102            stream: Box::new(socket),
103            remaining_bytes: None,
104        }
105    }
106}
107impl AsyncRead for StreamToAsync {
108    fn poll_read(
109        mut self: Pin<&mut Self>,
110        cx: &mut std::task::Context<'_>,
111        buf: &mut ReadBuf<'_>,
112    ) -> Poll<std::io::Result<()>> {
113        loop {
114            if let Some(mut remaining_buf) = self.remaining_bytes.take() {
115                if buf.remaining() < remaining_buf.len() {
116                    self.remaining_bytes = Some(remaining_buf.split_off(buf.remaining()));
117                    buf.put_slice(&remaining_buf);
118                } else {
119                    buf.put_slice(&remaining_buf);
120                    self.remaining_bytes = None;
121                }
122                return Poll::Ready(Ok(()));
123            }
124
125            match self.stream.poll_next_unpin(cx) {
126                Poll::Ready(Some(Ok(d))) => {
127                    self.remaining_bytes = Some(d);
128                    continue;
129                }
130                Poll::Ready(Some(Err(e))) => {
131                    return Poll::Ready(Err(std::io::Error::new(
132                        std::io::ErrorKind::Other,
133                        e.to_string(),
134                    )))
135                }
136                Poll::Ready(None) => return Poll::Ready(Ok(())),
137                Poll::Pending => return Poll::Pending,
138            };
139        }
140    }
141}
142
143impl AsyncWrite for StreamToAsync {
144    fn poll_write(
145        mut self: std::pin::Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        buf: &[u8],
148    ) -> Poll<Result<usize, std::io::Error>> {
149        match Pin::new(&mut self.stream.send(buf.to_vec()))
150            .poll(cx)
151            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
152        {
153            Poll::Ready(_) => Poll::Ready(Ok(buf.len())),
154            Poll::Pending => Poll::Pending,
155        }
156    }
157
158    fn poll_flush(
159        mut self: std::pin::Pin<&mut Self>,
160        cx: &mut Context<'_>,
161    ) -> Poll<Result<(), io::Error>> {
162        self.stream
163            .poll_flush_unpin(cx)
164            .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
165    }
166
167    fn poll_shutdown(
168        mut self: std::pin::Pin<&mut Self>,
169        cx: &mut Context<'_>,
170    ) -> Poll<Result<(), io::Error>> {
171        self.stream
172            .poll_close_unpin(cx)
173            .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
174    }
175}