1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use awak::io::{AsyncRead, AsyncWrite};
8
9use crate::cipher::Cipher;
10use crate::util::eof;
11
12pub async fn copy_bidirectional<A, B>(
13 a: &mut A,
14 b: &mut B,
15 cipher: Arc<Mutex<Cipher>>,
16) -> io::Result<(u64, u64)>
17where
18 A: AsyncRead + AsyncWrite + Unpin + ?Sized,
19 B: AsyncRead + AsyncWrite + Unpin + ?Sized,
20{
21 CopyBidirectional {
22 a,
23 b,
24 a_to_b: TransferState::Running,
25 b_to_a: TransferState::Running,
26 decrypter: CopyBuffer::new(cipher.clone()),
27 encrypter: CopyBuffer::new(cipher),
28 }
29 .await
30}
31
32impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
33where
34 A: AsyncRead + AsyncWrite + Unpin + ?Sized,
35 B: AsyncRead + AsyncWrite + Unpin + ?Sized,
36{
37 type Output = io::Result<(u64, u64)>;
38
39 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 let a_to_b = self.transfer_one_direction(cx, Direction::Encrypt)?;
41 let b_to_a = self.transfer_one_direction(cx, Direction::Decrypt)?;
42
43 let a_to_b = ready!(a_to_b);
46 let b_to_a = ready!(b_to_a);
47
48 Poll::Ready(Ok((a_to_b, b_to_a)))
49 }
50}
51
52struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
53 a: &'a mut A,
54 b: &'a mut B,
55 a_to_b: TransferState,
56 b_to_a: TransferState,
57 decrypter: CopyBuffer,
58 encrypter: CopyBuffer,
59}
60
/// Lifecycle of one copy direction.
///
/// The `u64` carried by the later states is the number of bytes written to
/// that direction's destination.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransferState {
    /// Data is still being copied.
    Running,
    /// The source reached end-of-stream and the destination was flushed; the
    /// destination is now being closed.
    ShuttingDown(u64),
    /// The destination has been closed; the count is final.
    Done(u64),
}
/// Which way data flows through a copy buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    /// Data read from `a` is encrypted and written to `b`.
    Encrypt,
    /// Data read from `b` is decrypted and written to `a`.
    Decrypt,
}
71
72impl<'a, A, B> CopyBidirectional<'a, A, B>
73where
74 A: AsyncRead + AsyncWrite + Unpin + ?Sized,
75 B: AsyncRead + AsyncWrite + Unpin + ?Sized,
76{
77 fn transfer_one_direction(
78 &mut self,
79 cx: &mut Context<'_>,
80 direction: Direction,
81 ) -> Poll<io::Result<u64>> {
82 let state = match direction {
83 Direction::Encrypt => &mut self.a_to_b,
84 Direction::Decrypt => &mut self.b_to_a,
85 };
86 loop {
87 match state {
88 TransferState::Running => {
89 let count = match direction {
90 Direction::Encrypt => {
91 ready!(self.encrypter.poll_encrypt(cx, self.a, self.b))?
92 }
93 Direction::Decrypt => {
94 ready!(self.decrypter.poll_decrypt(cx, self.b, self.a))?
95 }
96 };
97
98 *state = TransferState::ShuttingDown(count);
99 }
100 TransferState::ShuttingDown(count) => {
101 match direction {
102 Direction::Encrypt => ready!(Pin::new(&mut self.b).poll_close(cx))?,
103 Direction::Decrypt => ready!(Pin::new(&mut self.a).poll_close(cx))?,
104 };
105
106 *state = TransferState::Done(*count);
107 }
108 TransferState::Done(count) => return Poll::Ready(Ok(*count)),
109 }
110 }
111 }
112}
113
114struct CopyBuffer {
115 cipher: Arc<Mutex<Cipher>>,
116 read_done: bool,
117 pos: usize,
118 cap: usize,
119 amt: u64,
120 buf: Box<[u8]>,
121 need_flush: bool,
122}
123
124impl CopyBuffer {
125 fn new(cipher: Arc<Mutex<Cipher>>) -> CopyBuffer {
126 CopyBuffer {
127 cipher,
128 read_done: false,
129 amt: 0,
130 pos: 0,
131 cap: 0,
132 buf: Box::new([0; 1024 * 2]),
133 need_flush: false,
134 }
135 }
136
137 fn poll_decrypt<'a, R, W>(
138 &mut self,
139 cx: &mut Context,
140 mut reader: &'a mut R,
141 writer: &'a mut W,
142 ) -> Poll<io::Result<u64>>
143 where
144 R: AsyncRead + Unpin + ?Sized,
145 W: AsyncWrite + Unpin + ?Sized,
146 {
147 {
148 let mut cipher = self.cipher.lock().unwrap();
149 if cipher.dec.is_none() {
150 while self.pos < cipher.iv.len() {
151 let n =
152 ready!(Pin::new(&mut reader).poll_read(cx, &mut cipher.iv[self.pos..]))?;
153 self.pos += n;
154 if n == 0 {
155 return Err(eof()).into();
156 }
157 }
158 self.pos = 0;
159 cipher.init_decrypt();
160 }
161 }
162 self.poll_copy(cx, reader, writer, Direction::Decrypt)
163 }
164
165 fn poll_encrypt<'a, R, W>(
166 &mut self,
167 cx: &mut Context,
168 reader: &'a mut R,
169 writer: &'a mut W,
170 ) -> Poll<io::Result<u64>>
171 where
172 R: AsyncRead + Unpin + ?Sized,
173 W: AsyncWrite + Unpin + ?Sized,
174 {
175 {
176 let mut cipher = self.cipher.lock().unwrap();
177 if cipher.enc.is_none() {
178 cipher.init_encrypt();
179 self.pos = 0;
180 let n = cipher.iv.len();
181 self.cap = n;
182 self.buf[..n].copy_from_slice(&cipher.iv);
183 }
184 }
185 self.poll_copy(cx, reader, writer, Direction::Encrypt)
186 }
187
188 fn poll_copy<'a, R, W>(
189 &mut self,
190 cx: &mut Context,
191 mut reader: &'a mut R,
192 mut writer: &'a mut W,
193 direction: Direction,
194 ) -> Poll<io::Result<u64>>
195 where
196 R: AsyncRead + Unpin + ?Sized,
197 W: AsyncWrite + Unpin + ?Sized,
198 {
199 loop {
200 if self.pos == self.cap && !self.read_done {
203 let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf) {
204 Poll::Ready(Ok(n)) => n,
205 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
206 Poll::Pending => {
207 if self.need_flush {
210 ready!(Pin::new(&mut writer).poll_flush(cx))?;
211 self.need_flush = false;
212 }
213 return Poll::Pending;
214 }
215 };
216 if n == 0 {
217 self.read_done = true;
218 } else {
219 let mut cipher = self.cipher.lock().unwrap();
220 match direction {
221 Direction::Decrypt => {
222 cipher.decrypt(&mut self.buf[..n]);
223 }
224 Direction::Encrypt => {
225 cipher.encrypt(&mut self.buf[..n]);
226 }
227 }
228 self.pos = 0;
229 self.cap = n;
230 }
231 }
232
233 while self.pos < self.cap {
235 let i =
236 ready!(Pin::new(&mut writer).poll_write(cx, &self.buf[self.pos..self.cap]))?;
237 if i == 0 {
238 return Poll::Ready(Err(io::Error::new(
239 io::ErrorKind::WriteZero,
240 "write zero byte into writer",
241 )));
242 } else {
243 self.pos += i;
244 self.amt += i as u64;
245 self.need_flush = true;
246 }
247 }
248
249 if self.pos == self.cap && self.read_done {
252 ready!(Pin::new(&mut writer).poll_flush(cx))?;
253 return Poll::Ready(Ok(self.amt));
254 }
255 }
256 }
257}