ss_light/crypto/aead.rs

1use bytes::{BufMut, Bytes, BytesMut};
2
3use futures::ready;
4use ring::aead::{Aad, BoundKey, OpeningKey, SealingKey, UnboundKey, AES_256_GCM};
5
6use core::slice;
7use std::io::{self, ErrorKind};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tracing::trace;
12
13use super::kind::CipherKind;
14use super::util;
15
/// Progress of `EncryptedWriter::poll_write` through a single chunk.
enum EncryptWriteState {
    /// No chunk is pending; the next call seals the caller's data into the buffer.
    AssemblePacket,
    /// A sealed chunk is buffered and is being flushed to the stream.
    Writing {
        /// Number of buffered bytes already accepted by the stream.
        pos: usize,
    },
}
20
21pub struct EncryptedWriter {
22    sealing_key: Option<SealingKey<util::NonceSequence>>, // for encrypt
23    buf: BytesMut,
24    state: EncryptWriteState,
25    kind: CipherKind,
26}
27
28impl EncryptedWriter {
29    pub fn new(kind: CipherKind, key: &[u8], salt: &[u8]) -> Self {
30        match kind {
31            CipherKind::AES_256_GCM => {
32                let mut buf = BytesMut::with_capacity(salt.len());
33                buf.put(salt);
34
35                // cacl sub_key
36                let sub_key = util::hkdf_sha1(key, salt);
37
38                let unbound =
39                    UnboundKey::new(&AES_256_GCM, &sub_key).expect("key.len != algorithm.key_len");
40                let sealing_key = SealingKey::new(unbound, util::NonceSequence::new());
41
42                Self {
43                    sealing_key: Some(sealing_key),
44                    buf,
45                    state: EncryptWriteState::AssemblePacket,
46                    kind,
47                }
48            }
49            _ => panic!("unsupport chipher kind"),
50        }
51    }
52
53    /// Write buf to stream, return num_bytes_written
54    ///
55    /// An AEAD encrypted TCP stream starts with a randomly generated salt to derive the per-session subkey, followed by any number of encrypted chunks.
56    /// Each chunk has the following structure:
57    ///
58    /// [encrypted payload length][length tag][encrypted payload][payload tag]
59    ///
60    /// More details in the [wiki](https://shadowsocks.org/en/wiki/AEAD-Ciphers.html)
61    pub fn poll_write<S>(
62        &mut self,
63        cx: &mut Context,
64        stream: &mut S,
65        mut buf: &[u8],
66    ) -> Poll<io::Result<usize>>
67    where
68        S: AsyncWrite + Unpin + ?Sized,
69    {
70        if buf.len() > self.kind.max_package_size() {
71            buf = &buf[..self.kind.max_package_size()]
72        }
73
74        loop {
75            match self.state {
76                EncryptWriteState::AssemblePacket => {
77                    // 1. append length
78                    let befor_len = self.buf.len(); // salt in buf at begain
79                    let length_size = 2;
80                    self.buf.reserve(length_size);
81                    self.buf.put_u16(buf.len() as u16);
82                    let view = &mut self.buf.as_mut()[befor_len..];
83                    debug_assert!(view.len() == length_size);
84                    let tag = self
85                        .sealing_key
86                        .as_mut()
87                        .unwrap()
88                        .seal_in_place_separate_tag(Aad::<[u8; 0]>::empty(), view)
89                        .expect("seal_in_place_separate_tag for length");
90                    self.buf.extend_from_slice(tag.as_ref());
91
92                    // 2. append data
93                    let befor_len = self.buf.len(); // length data at before
94                    self.buf.extend_from_slice(buf);
95                    let view = &mut self.buf.as_mut()[befor_len..];
96                    let tag = self
97                        .sealing_key
98                        .as_mut()
99                        .unwrap()
100                        .seal_in_place_separate_tag(Aad::<[u8; 0]>::empty(), view)
101                        .expect("seal_in_place_separate_tag for data");
102                    self.buf.extend_from_slice(tag.as_ref());
103
104                    // 3. write all
105                    self.state = EncryptWriteState::Writing { pos: 0 };
106                }
107                EncryptWriteState::Writing { ref mut pos } => {
108                    while *pos < self.buf.len() {
109                        let n = ready!(Pin::new(&mut *stream).poll_write(cx, &self.buf[*pos..]))?;
110                        *pos += n;
111                    }
112
113                    // reset
114                    self.state = EncryptWriteState::AssemblePacket;
115                    self.buf.clear();
116                    return Ok(buf.len()).into();
117                }
118            }
119        }
120    }
121}
122
/// Progress of `DecryptedReader::poll_read` through the incoming stream.
enum DecryptReadState {
    /// Nothing has been read yet; the session salt comes first.
    WaitSalt,
    /// Expecting the next chunk's encrypted length followed by its tag.
    ReadLength,
    /// Expecting the encrypted payload of the current chunk followed by its tag.
    ReadData {
        /// Plaintext payload length announced by the preceding length header.
        length: usize,
    },
    /// The decrypted payload sits in the buffer and is being handed to the caller.
    BufferedData {
        /// Number of payload bytes already returned.
        pos: usize,
    },
}
129pub struct DecryptedReader {
130    opening_key: Option<OpeningKey<util::NonceSequence>>, // for decrypt
131    buf: BytesMut,
132    state: DecryptReadState,
133    kind: CipherKind,
134    salt: Option<Bytes>,
135    key: Bytes,
136}
137
138impl DecryptedReader {
139    pub fn new(kind: CipherKind, key: &[u8]) -> Self {
140        match kind {
141            CipherKind::AES_256_GCM => Self {
142                opening_key: None,
143                buf: BytesMut::new(),
144                state: DecryptReadState::WaitSalt,
145                kind,
146                salt: None,
147                key: Bytes::copy_from_slice(key),
148            },
149            _ => panic!("unsupport chipher kind"),
150        }
151    }
152
153    pub fn poll_read<S>(
154        &mut self,
155        cx: &mut Context,
156        stream: &mut S,
157        buf: &mut ReadBuf,
158    ) -> Poll<io::Result<()>>
159    where
160        S: AsyncRead + Unpin + ?Sized,
161    {
162        loop {
163            match self.state {
164                DecryptReadState::WaitSalt => {
165                    let salt_len = self.kind.salt_len();
166                    let n = ready!(self.poll_read_exact_or_zero(cx, stream, salt_len))?;
167                    if n == 0 {
168                        return Err(ErrorKind::UnexpectedEof.into()).into();
169                    }
170                    debug_assert!(self.buf.len() == salt_len);
171                    self.salt = Some(Bytes::copy_from_slice(&self.buf));
172
173                    // cacl sub_key
174                    let sub_key = util::hkdf_sha1(&self.key, &self.salt.as_ref().unwrap());
175                    trace!("peer sub_key is {:?}", sub_key);
176
177                    let unbound = UnboundKey::new(&AES_256_GCM, &sub_key)
178                        .expect("key.len != algorithm.key_len");
179                    let opening_key = OpeningKey::new(unbound, util::NonceSequence::new());
180
181                    self.buf.clear();
182                    self.state = DecryptReadState::ReadLength;
183                    self.buf.reserve(2 + self.kind.tag_len());
184                    self.opening_key = Some(opening_key);
185                }
186                DecryptReadState::ReadLength => {
187                    let usize =
188                        ready!(self.poll_read_exact_or_zero(cx, stream, 2 + self.kind.tag_len()))?;
189                    if usize == 0 {
190                        return Ok(()).into();
191                    } else {
192                        let result = self
193                            .opening_key
194                            .as_mut()
195                            .unwrap()
196                            .open_in_place(Aad::<[u8; 0]>::empty(), &mut self.buf)
197                            .map_err(|_| {
198                                io::Error::new(ErrorKind::Other, "ReadLength invalid tag-in")
199                            })?;
200                        let plen = u16::from_be_bytes([result[0], result[1]]) as usize;
201                        if plen > self.kind.max_package_size() {
202                            let  err = io::Error::new(
203                                ErrorKind::InvalidData,
204                                format!(
205                                    "buffer size too large ({:#x}), AEAD encryption protocol requires buffer to be smaller than 0x3FFF, the higher two bits must be set to zero",
206                                    plen
207                                ),
208                            );
209                            return Err(err).into();
210                        }
211                        self.buf.clear();
212                        self.state = DecryptReadState::ReadData { length: plen };
213                        self.buf.reserve(plen + self.kind.tag_len())
214                    }
215                }
216                DecryptReadState::ReadData { length } => {
217                    let data_len = length + self.kind.tag_len();
218                    let n = ready!(self.poll_read_exact_or_zero(cx, stream, data_len))?;
219                    if n == 0 {
220                        return Err(ErrorKind::UnexpectedEof.into()).into();
221                    }
222                    debug_assert_eq!(data_len, self.buf.len());
223
224                    let _ = self
225                        .opening_key
226                        .as_mut()
227                        .unwrap()
228                        .open_in_place(Aad::<[u8; 0]>::empty(), &mut self.buf)
229                        .map_err(|_| io::Error::new(ErrorKind::Other, "ReadData invalid tag-in"))?;
230
231                    // remove tag
232                    self.buf.truncate(length);
233                    self.state = DecryptReadState::BufferedData { pos: 0 };
234                }
235                DecryptReadState::BufferedData { ref mut pos } => {
236                    if *pos < self.buf.len() {
237                        let buffered = &self.buf[*pos..];
238                        let consumed = usize::min(buffered.len(), buf.remaining());
239                        buf.put_slice(&buffered[..consumed]);
240                        *pos += consumed;
241
242                        return Ok(()).into();
243                    }
244                    self.buf.clear();
245                    self.state = DecryptReadState::ReadLength;
246                    self.buf.reserve(2 + self.kind.tag_len());
247                }
248            }
249        }
250    }
251
252    fn poll_read_exact_or_zero<S>(
253        &mut self,
254        cx: &mut Context,
255        stream: &mut S,
256        size: usize,
257    ) -> Poll<io::Result<usize>>
258    where
259        S: AsyncRead + Unpin + ?Sized,
260    {
261        assert!(size != 0);
262        while self.buf.len() < size {
263            let remaing = size - self.buf.len();
264
265            let view = &mut self.buf.chunk_mut()[..remaing];
266            debug_assert_eq!(view.len(), remaing);
267            let mut read_buf = ReadBuf::uninit(unsafe {
268                slice::from_raw_parts_mut(view.as_mut_ptr() as *mut _, remaing)
269            });
270
271            ready!(Pin::new(&mut *stream).poll_read(cx, &mut read_buf))?;
272            let n = read_buf.filled().len();
273
274            unsafe { self.buf.advance_mut(n) }
275
276            if n == 0 {
277                if !self.buf.is_empty() {
278                    return Err(ErrorKind::UnexpectedEof.into()).into();
279                } else {
280                    return Ok(0).into();
281                }
282            }
283        }
284        Ok(size).into()
285    }
286}
287
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use futures::{ready, Future};
    use tokio::io::ReadBuf;

    use super::EncryptedWriter;
    use crate::crypto::{aead::DecryptedReader, kind::CipherKind, util};

    /// Pushes one message through the writer into an in-memory wire, then reads it back
    /// through the reader.
    struct RoundTrip {
        reader: DecryptedReader,
        writer: EncryptedWriter,
        wire: Vec<u8>,
    }

    impl Future for RoundTrip {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            const PLAINTEXT: &str = "hello";

            let RoundTrip {
                reader,
                writer,
                wire,
            } = self.get_mut();

            // Encrypt into the wire buffer.
            let written = ready!(writer.poll_write(cx, wire, PLAINTEXT.as_bytes())).unwrap();
            assert_eq!(written, PLAINTEXT.len());

            // Decrypt it back out.
            let mut storage = [0u8; 1024];
            let mut out = ReadBuf::new(&mut storage);
            let mut incoming: &[u8] = wire.as_slice();
            ready!(reader.poll_read(cx, &mut incoming, &mut out)).unwrap();
            assert_eq!(out.filled(), PLAINTEXT.as_bytes());

            Poll::Ready(())
        }
    }

    #[tokio::test]
    async fn test_reader_writer() {
        let password = "123456";
        let salt = &[0u8; 32];
        let key = util::evp_bytes_to_key(password.as_bytes(), CipherKind::AES_256_GCM.key_len());

        let round_trip = RoundTrip {
            reader: DecryptedReader::new(CipherKind::AES_256_GCM, &key),
            writer: EncryptedWriter::new(CipherKind::AES_256_GCM, &key, salt),
            wire: Vec::<u8>::new(),
        };
        round_trip.await
    }
}
346}