shadowsocks/copy.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use awak::io::{AsyncRead, AsyncWrite};
8
9use crate::cipher::Cipher;
10use crate::util::eof;
11
12pub async fn copy_bidirectional<A, B>(
13    a: &mut A,
14    b: &mut B,
15    cipher: Arc<Mutex<Cipher>>,
16) -> io::Result<(u64, u64)>
17where
18    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
19    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
20{
21    CopyBidirectional {
22        a,
23        b,
24        a_to_b: TransferState::Running,
25        b_to_a: TransferState::Running,
26        decrypter: CopyBuffer::new(cipher.clone()),
27        encrypter: CopyBuffer::new(cipher),
28    }
29    .await
30}
31
32impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
33where
34    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
35    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
36{
37    type Output = io::Result<(u64, u64)>;
38
39    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        let a_to_b = self.transfer_one_direction(cx, Direction::Encrypt)?;
41        let b_to_a = self.transfer_one_direction(cx, Direction::Decrypt)?;
42
43        // It is not a problem if ready! returns early because transfer_one_direction for the
44        // other direction will keep returning TransferState::Done(count) in future calls to poll
45        let a_to_b = ready!(a_to_b);
46        let b_to_a = ready!(b_to_a);
47
48        Poll::Ready(Ok((a_to_b, b_to_a)))
49    }
50}
51
52struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
53    a: &'a mut A,
54    b: &'a mut B,
55    a_to_b: TransferState,
56    b_to_a: TransferState,
57    decrypter: CopyBuffer,
58    encrypter: CopyBuffer,
59}
60
/// Progress of one relay direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransferState {
    /// Still copying data from the source to the destination.
    Running,
    /// The source reached EOF after this many bytes were written; the destination is
    /// being closed.
    ShuttingDown(u64),
    /// The destination is closed; holds the final number of bytes written.
    Done(u64),
}
/// Which way a relay direction moves and transforms data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
    /// Read from `a`, encrypt, write to `b`.
    Encrypt,
    /// Read from `b`, decrypt, write to `a`.
    Decrypt,
}

72impl<'a, A, B> CopyBidirectional<'a, A, B>
73where
74    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
75    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
76{
77    fn transfer_one_direction(
78        &mut self,
79        cx: &mut Context<'_>,
80        direction: Direction,
81    ) -> Poll<io::Result<u64>> {
82        let state = match direction {
83            Direction::Encrypt => &mut self.a_to_b,
84            Direction::Decrypt => &mut self.b_to_a,
85        };
86        loop {
87            match state {
88                TransferState::Running => {
89                    let count = match direction {
90                        Direction::Encrypt => {
91                            ready!(self.encrypter.poll_encrypt(cx, self.a, self.b))?
92                        }
93                        Direction::Decrypt => {
94                            ready!(self.decrypter.poll_decrypt(cx, self.b, self.a))?
95                        }
96                    };
97
98                    *state = TransferState::ShuttingDown(count);
99                }
100                TransferState::ShuttingDown(count) => {
101                    match direction {
102                        Direction::Encrypt => ready!(Pin::new(&mut self.b).poll_close(cx))?,
103                        Direction::Decrypt => ready!(Pin::new(&mut self.a).poll_close(cx))?,
104                    };
105
106                    *state = TransferState::Done(*count);
107                }
108                TransferState::Done(count) => return Poll::Ready(Ok(*count)),
109            }
110        }
111    }
112}
113
114struct CopyBuffer {
115    cipher: Arc<Mutex<Cipher>>,
116    read_done: bool,
117    pos: usize,
118    cap: usize,
119    amt: u64,
120    buf: Box<[u8]>,
121    need_flush: bool,
122}
123
124impl CopyBuffer {
125    fn new(cipher: Arc<Mutex<Cipher>>) -> CopyBuffer {
126        CopyBuffer {
127            cipher,
128            read_done: false,
129            amt: 0,
130            pos: 0,
131            cap: 0,
132            buf: Box::new([0; 1024 * 2]),
133            need_flush: false,
134        }
135    }
136
137    fn poll_decrypt<'a, R, W>(
138        &mut self,
139        cx: &mut Context,
140        mut reader: &'a mut R,
141        writer: &'a mut W,
142    ) -> Poll<io::Result<u64>>
143    where
144        R: AsyncRead + Unpin + ?Sized,
145        W: AsyncWrite + Unpin + ?Sized,
146    {
147        {
148            let mut cipher = self.cipher.lock().unwrap();
149            if cipher.dec.is_none() {
150                while self.pos < cipher.iv.len() {
151                    let n =
152                        ready!(Pin::new(&mut reader).poll_read(cx, &mut cipher.iv[self.pos..]))?;
153                    self.pos += n;
154                    if n == 0 {
155                        return Err(eof()).into();
156                    }
157                }
158                self.pos = 0;
159                cipher.init_decrypt();
160            }
161        }
162        self.poll_copy(cx, reader, writer, Direction::Decrypt)
163    }
164
165    fn poll_encrypt<'a, R, W>(
166        &mut self,
167        cx: &mut Context,
168        reader: &'a mut R,
169        writer: &'a mut W,
170    ) -> Poll<io::Result<u64>>
171    where
172        R: AsyncRead + Unpin + ?Sized,
173        W: AsyncWrite + Unpin + ?Sized,
174    {
175        {
176            let mut cipher = self.cipher.lock().unwrap();
177            if cipher.enc.is_none() {
178                cipher.init_encrypt();
179                self.pos = 0;
180                let n = cipher.iv.len();
181                self.cap = n;
182                self.buf[..n].copy_from_slice(&cipher.iv);
183            }
184        }
185        self.poll_copy(cx, reader, writer, Direction::Encrypt)
186    }
187
188    fn poll_copy<'a, R, W>(
189        &mut self,
190        cx: &mut Context,
191        mut reader: &'a mut R,
192        mut writer: &'a mut W,
193        direction: Direction,
194    ) -> Poll<io::Result<u64>>
195    where
196        R: AsyncRead + Unpin + ?Sized,
197        W: AsyncWrite + Unpin + ?Sized,
198    {
199        loop {
200            // If our buffer is empty, then we need to read some data to
201            // continue.
202            if self.pos == self.cap && !self.read_done {
203                let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf) {
204                    Poll::Ready(Ok(n)) => n,
205                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
206                    Poll::Pending => {
207                        // Try flushing when the reader has no progress to avoid deadlock
208                        // when the reader depends on buffered writer.
209                        if self.need_flush {
210                            ready!(Pin::new(&mut writer).poll_flush(cx))?;
211                            self.need_flush = false;
212                        }
213                        return Poll::Pending;
214                    }
215                };
216                if n == 0 {
217                    self.read_done = true;
218                } else {
219                    let mut cipher = self.cipher.lock().unwrap();
220                    match direction {
221                        Direction::Decrypt => {
222                            cipher.decrypt(&mut self.buf[..n]);
223                        }
224                        Direction::Encrypt => {
225                            cipher.encrypt(&mut self.buf[..n]);
226                        }
227                    }
228                    self.pos = 0;
229                    self.cap = n;
230                }
231            }
232
233            // If our buffer has some data, let's write it out!
234            while self.pos < self.cap {
235                let i =
236                    ready!(Pin::new(&mut writer).poll_write(cx, &self.buf[self.pos..self.cap]))?;
237                if i == 0 {
238                    return Poll::Ready(Err(io::Error::new(
239                        io::ErrorKind::WriteZero,
240                        "write zero byte into writer",
241                    )));
242                } else {
243                    self.pos += i;
244                    self.amt += i as u64;
245                    self.need_flush = true;
246                }
247            }
248
249            // If we've written all the data and we've seen EOF, flush out the
250            // data and finish the transfer.
251            if self.pos == self.cap && self.read_done {
252                ready!(Pin::new(&mut writer).poll_flush(cx))?;
253                return Poll::Ready(Ok(self.amt));
254            }
255        }
256    }
257}