narrowlink_network/
async_tools.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::{error::NetworkError, AsyncSocket, UniversalStream};
8use futures_util::{Future, Sink, SinkExt, Stream, StreamExt};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11pub struct AsyncToStream {
12 socket: Box<dyn AsyncSocket>,
13 buffer: Option<(usize, Vec<u8>)>,
14}
15
16impl AsyncToStream {
17 pub fn new(socket: impl AsyncSocket) -> Self {
18 Self {
19 socket: Box::new(socket),
20 buffer: None,
21 }
22 }
23}
24
25impl Stream for AsyncToStream {
26 type Item = Result<Vec<u8>, NetworkError>;
27
28 fn poll_next(
29 mut self: Pin<&mut Self>,
30 cx: &mut std::task::Context<'_>,
31 ) -> Poll<Option<Self::Item>> {
32 let mut buf = [0u8; 65536];
33 let mut buffer = ReadBuf::new(&mut buf);
34 match Pin::new(&mut self.socket).poll_read(cx, &mut buffer)? {
35 Poll::Ready(_) => {
36 if buffer.filled().is_empty() {
37 Poll::Ready(None)
38 } else {
39 Poll::Ready(Some(Ok(buffer.filled().to_vec())))
40 }
41 }
42 Poll::Pending => Poll::Pending,
43 }
44 }
45}
46
47impl Sink<Vec<u8>> for AsyncToStream {
48 type Error = NetworkError;
49
50 fn poll_ready(
51 mut self: Pin<&mut Self>,
52 cx: &mut std::task::Context<'_>,
53 ) -> Poll<Result<(), Self::Error>> {
54 if let Some((mut len, buffer)) = self.buffer.take() {
55 loop {
56 len = match Pin::new(&mut self.socket).poll_write(cx, &buffer)? {
57 Poll::Ready(written) => written,
58 Poll::Pending => return Poll::Pending,
59 };
60 if len == buffer.len() {
61 break;
62 }
63 }
64 }
65
66 Poll::Ready(Ok(()))
67 }
68
69 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
70 self.buffer = Some((0, item));
71 Ok(())
72 }
73
74 fn poll_flush(
75 mut self: Pin<&mut Self>,
76 cx: &mut std::task::Context<'_>,
77 ) -> Poll<Result<(), Self::Error>> {
78 let _ = Pin::new(&mut self).poll_ready(cx)?;
79 Pin::new(&mut self.socket)
80 .poll_flush(cx)
81 .map_err(|e| e.into())
82 }
83
84 fn poll_close(
85 mut self: Pin<&mut Self>,
86 cx: &mut std::task::Context<'_>,
87 ) -> Poll<Result<(), Self::Error>> {
88 let _ = Pin::new(&mut self).poll_ready(cx)?;
89 Pin::new(&mut self.socket)
90 .poll_shutdown(cx)
91 .map_err(|e| e.into())
92 }
93}
94
95pub struct StreamToAsync {
96 stream: Box<dyn UniversalStream<Vec<u8>, NetworkError>>,
97 remaining_bytes: Option<Vec<u8>>,
98}
99impl StreamToAsync {
100 pub fn new(socket: impl UniversalStream<Vec<u8>, NetworkError>) -> Self {
101 Self {
102 stream: Box::new(socket),
103 remaining_bytes: None,
104 }
105 }
106}
107impl AsyncRead for StreamToAsync {
108 fn poll_read(
109 mut self: Pin<&mut Self>,
110 cx: &mut std::task::Context<'_>,
111 buf: &mut ReadBuf<'_>,
112 ) -> Poll<std::io::Result<()>> {
113 loop {
114 if let Some(mut remaining_buf) = self.remaining_bytes.take() {
115 if buf.remaining() < remaining_buf.len() {
116 self.remaining_bytes = Some(remaining_buf.split_off(buf.remaining()));
117 buf.put_slice(&remaining_buf);
118 } else {
119 buf.put_slice(&remaining_buf);
120 self.remaining_bytes = None;
121 }
122 return Poll::Ready(Ok(()));
123 }
124
125 match self.stream.poll_next_unpin(cx) {
126 Poll::Ready(Some(Ok(d))) => {
127 self.remaining_bytes = Some(d);
128 continue;
129 }
130 Poll::Ready(Some(Err(e))) => {
131 return Poll::Ready(Err(std::io::Error::new(
132 std::io::ErrorKind::Other,
133 e.to_string(),
134 )))
135 }
136 Poll::Ready(None) => return Poll::Ready(Ok(())),
137 Poll::Pending => return Poll::Pending,
138 };
139 }
140 }
141}
142
143impl AsyncWrite for StreamToAsync {
144 fn poll_write(
145 mut self: std::pin::Pin<&mut Self>,
146 cx: &mut Context<'_>,
147 buf: &[u8],
148 ) -> Poll<Result<usize, std::io::Error>> {
149 match Pin::new(&mut self.stream.send(buf.to_vec()))
150 .poll(cx)
151 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
152 {
153 Poll::Ready(_) => Poll::Ready(Ok(buf.len())),
154 Poll::Pending => Poll::Pending,
155 }
156 }
157
158 fn poll_flush(
159 mut self: std::pin::Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 ) -> Poll<Result<(), io::Error>> {
162 self.stream
163 .poll_flush_unpin(cx)
164 .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
165 }
166
167 fn poll_shutdown(
168 mut self: std::pin::Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 ) -> Poll<Result<(), io::Error>> {
171 self.stream
172 .poll_close_unpin(cx)
173 .map_err(|_| io::Error::from(io::ErrorKind::UnexpectedEof))
174 }
175}