use std::{
io::{self, Cursor, ErrorKind, Read},
marker::Unpin,
pin::Pin,
slice,
sync::Arc,
task::{self, Poll},
time::SystemTime,
u16,
};
use aes::{
cipher::{BlockDecrypt, BlockEncrypt, KeyInit},
Aes128,
Aes256,
Block,
};
use byte_string::ByteStr;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::ready;
use log::{error, trace};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::{crypto_io::StreamType, proxy_stream::protocol::v2::SERVER_STREAM_TIMESTAMP_MAX_DIFF};
use crate::{
config::{method_support_eih, ServerUserManager},
context::Context,
crypto::{v2::tcp::TcpCipher, CipherKind},
};
#[inline]
fn get_now_timestamp() -> u64 {
match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
Ok(n) => n.as_secs(),
Err(_) => panic!("SystemTime::now() is before UNIX Epoch!"),
}
}
pub const MAX_PACKET_SIZE: usize = 0xFFFF;
const AEAD2022_EIH_SUBKEY_CONTEXT: &str = "shadowsocks 2022 identity subkey";
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
#[error(transparent)]
IoError(#[from] io::Error),
#[error("header too short, expecting {0} bytes, but found {1} bytes")]
HeaderTooShort(usize, usize),
#[error("missing extended identity header")]
MissingExtendedIdentityHeader,
#[error("invalid client user identity {:?}", ByteStr::new(.0))]
InvalidClientUser(Bytes),
#[error("decrypt header chunk failed")]
DecryptHeaderChunkError,
#[error("decrypt data failed")]
DecryptDataError,
#[error("decrypt length failed")]
DecryptLengthError,
#[error("invalid stream type, expecting {0:#x}, but found {1:#x}")]
InvalidStreamType(u8, u8),
#[error("invalid timestamp {0} - now {1} = {}", *.0 as i64 - *.1 as i64)]
InvalidTimestamp(u64, u64),
}
pub type ProtocolResult<T> = Result<T, ProtocolError>;
impl From<ProtocolError> for io::Error {
fn from(e: ProtocolError) -> io::Error {
match e {
ProtocolError::IoError(err) => err,
_ => io::Error::new(ErrorKind::Other, e),
}
}
}
enum DecryptReadState {
ReadHeader { key: Bytes },
ReadLength,
ReadData { length: usize },
BufferedData { pos: usize },
}
pub struct DecryptedReader {
stream_ty: StreamType,
state: DecryptReadState,
cipher: Option<TcpCipher>,
buffer: BytesMut,
method: CipherKind,
salt: Option<Bytes>,
request_salt: Option<Bytes>,
data_chunk_count: u64,
user_manager: Option<Arc<ServerUserManager>>,
user_key: Option<Bytes>,
has_handshaked: bool,
}
impl DecryptedReader {
pub fn new(stream_ty: StreamType, method: CipherKind, key: &[u8]) -> DecryptedReader {
DecryptedReader::with_user_manager(stream_ty, method, key, None)
}
pub fn with_user_manager(
stream_ty: StreamType,
method: CipherKind,
key: &[u8],
user_manager: Option<Arc<ServerUserManager>>,
) -> DecryptedReader {
if method.salt_len() > 0 {
DecryptedReader {
stream_ty,
state: DecryptReadState::ReadHeader {
key: Bytes::copy_from_slice(key),
},
cipher: None,
buffer: BytesMut::new(),
method,
salt: None,
request_salt: None,
data_chunk_count: 0,
user_manager,
user_key: None,
has_handshaked: false,
}
} else {
DecryptedReader {
stream_ty,
state: DecryptReadState::ReadHeader {
key: Bytes::new(), },
cipher: Some(TcpCipher::new(method, key, &[])),
buffer: BytesMut::new(),
method,
salt: None,
request_salt: None,
data_chunk_count: 0,
user_manager,
user_key: None,
has_handshaked: false,
}
}
}
pub fn salt(&self) -> Option<&[u8]> {
self.salt.as_deref()
}
pub fn request_salt(&self) -> Option<&[u8]> {
self.request_salt.as_deref().filter(|&n| !n.is_empty())
}
pub fn poll_read_decrypted<S>(
&mut self,
cx: &mut task::Context<'_>,
context: &Context,
stream: &mut S,
buf: &mut ReadBuf<'_>,
) -> Poll<ProtocolResult<()>>
where
S: AsyncRead + Unpin + ?Sized,
{
loop {
match self.state {
DecryptReadState::ReadHeader { ref key } => {
let key = unsafe { &*(key.as_ref() as *const _) };
match ready!(self.poll_read_header(cx, context, stream, key))? {
None => {
return Ok(()).into();
}
Some(length) => {
self.buffer.clear();
self.state = DecryptReadState::ReadData { length };
self.buffer.reserve(length + self.method.tag_len());
self.has_handshaked = true;
}
}
}
DecryptReadState::ReadLength => match ready!(self.poll_read_length(cx, stream))? {
None => {
return Ok(()).into();
}
Some(length) => {
self.buffer.clear();
self.state = DecryptReadState::ReadData { length };
self.buffer.reserve(length + self.method.tag_len());
}
},
DecryptReadState::ReadData { length } => {
ready!(self.poll_read_data(cx, stream, length))?;
self.state = DecryptReadState::BufferedData { pos: 0 };
self.data_chunk_count = self.data_chunk_count.wrapping_add(1);
}
DecryptReadState::BufferedData { ref mut pos } => {
if *pos < self.buffer.len() {
let buffered = &self.buffer[*pos..];
let consumed = usize::min(buffered.len(), buf.remaining());
buf.put_slice(&buffered[..consumed]);
*pos += consumed;
return Ok(()).into();
}
self.buffer.clear();
self.state = DecryptReadState::ReadLength;
self.buffer.reserve(2 + self.method.tag_len());
}
}
}
}
fn poll_read_header<S>(
&mut self,
cx: &mut task::Context<'_>,
context: &Context,
stream: &mut S,
key: &[u8],
) -> Poll<ProtocolResult<Option<usize>>>
where
S: AsyncRead + Unpin + ?Sized,
{
let salt_len = self.method.salt_len();
let request_salt_len = match self.stream_ty {
StreamType::Client => salt_len,
StreamType::Server => 0,
};
let require_eih =
self.stream_ty == StreamType::Server && method_support_eih(self.method) && self.user_manager.is_some();
let eih_len = if require_eih { 16 } else { 0 };
let header_len = salt_len + eih_len + 1 + 8 + request_salt_len + 2 + self.method.tag_len();
if self.buffer.len() < header_len {
self.buffer.resize(header_len, 0);
}
let mut read_buf = ReadBuf::new(&mut self.buffer[..header_len]);
ready!(Pin::new(stream).poll_read(cx, &mut read_buf))?;
let header_buf = read_buf.filled_mut();
if header_buf.is_empty() {
return Ok(None).into();
} else if header_buf.len() != header_len {
return Err(ProtocolError::HeaderTooShort(header_len, header_buf.len())).into();
}
let (salt, mut header_chunk) = header_buf.split_at_mut(salt_len);
trace!("got AEAD salt {:?}", ByteStr::new(salt));
let mut cipher = if require_eih {
if let Some(ref user_manager) = self.user_manager {
if header_chunk.len() < 16 {
error!("expecting EIH, but header chunk len: {}", header_chunk.len());
return Err(ProtocolError::MissingExtendedIdentityHeader).into();
}
let (eih, remain_header_chunk) = header_chunk.split_at_mut(16);
header_chunk = remain_header_chunk;
let key_material = [key, salt].concat();
let identity_sub_key = blake3::derive_key(AEAD2022_EIH_SUBKEY_CONTEXT, &key_material);
let mut user_hash = Block::from([0u8; 16]);
match self.method {
CipherKind::AEAD2022_BLAKE3_AES_128_GCM => {
let cipher = Aes128::new_from_slice(&identity_sub_key[0..16]).expect("AES-128");
cipher.decrypt_block_b2b(Block::from_slice(eih), &mut user_hash);
}
CipherKind::AEAD2022_BLAKE3_AES_256_GCM => {
let cipher = Aes256::new_from_slice(&identity_sub_key[0..32]).expect("AES-256");
cipher.decrypt_block_b2b(Block::from_slice(eih), &mut user_hash);
}
_ => unreachable!("{} doesn't support EIH", self.method),
}
let user_hash = user_hash.as_slice();
trace!(
"server EIH {:?}, hash: {:?}",
ByteStr::new(eih),
ByteStr::new(user_hash)
);
match user_manager.get_user_by_hash(user_hash) {
None => {
return Err(ProtocolError::InvalidClientUser(Bytes::copy_from_slice(user_hash))).into();
}
Some(user) => {
trace!("{:?} chosen by EIH", user);
self.user_key = Some(Bytes::copy_from_slice(user.key()));
TcpCipher::new(self.method, user.key(), salt)
}
}
} else {
unreachable!("user_manager must not be None")
}
} else {
TcpCipher::new(self.method, key, salt)
};
if !cipher.decrypt_packet(header_chunk) {
return Err(ProtocolError::DecryptHeaderChunkError).into();
}
let mut header_reader = Cursor::new(header_chunk);
let stream_ty = header_reader.get_u8();
let expected_stream_ty = match self.stream_ty {
StreamType::Client => 1, StreamType::Server => 0,
};
if stream_ty != expected_stream_ty {
return Err(ProtocolError::InvalidStreamType(expected_stream_ty, stream_ty)).into();
}
let timestamp = header_reader.get_u64();
let now = get_now_timestamp();
if now.abs_diff(timestamp) > SERVER_STREAM_TIMESTAMP_MAX_DIFF {
return Err(ProtocolError::InvalidTimestamp(timestamp, now)).into();
}
if request_salt_len > 0 {
let mut request_salt = BytesMut::with_capacity(salt_len);
request_salt.resize(salt_len, 0);
header_reader.read_exact(&mut request_salt)?;
self.request_salt = Some(request_salt.freeze());
}
let data_length = header_reader.get_u16();
trace!(
"got AEAD header stream_type: {}, timestamp: {}, length: {}, request_salt: {:?}",
stream_ty,
timestamp,
data_length,
self.request_salt.as_deref().map(ByteStr::new)
);
if self.stream_ty == StreamType::Server {
context.check_nonce_replay(self.method, salt)?;
}
self.salt = Some(Bytes::copy_from_slice(salt));
self.cipher = Some(cipher);
Ok(Some(data_length as usize)).into()
}
fn poll_read_length<S>(&mut self, cx: &mut task::Context<'_>, stream: &mut S) -> Poll<io::Result<Option<usize>>>
where
S: AsyncRead + Unpin + ?Sized,
{
let length_len = 2 + self.method.tag_len();
let n = ready!(self.poll_read_exact(cx, stream, length_len))?;
if n == 0 {
return Ok(None).into();
}
let cipher = self.cipher.as_mut().expect("cipher is None");
let m = &mut self.buffer[..length_len];
let length = DecryptedReader::decrypt_length(cipher, m)?;
Ok(Some(length)).into()
}
fn poll_read_data<S>(&mut self, cx: &mut task::Context<'_>, stream: &mut S, size: usize) -> Poll<ProtocolResult<()>>
where
S: AsyncRead + Unpin + ?Sized,
{
let data_len = size + self.method.tag_len();
let n = ready!(self.poll_read_exact(cx, stream, data_len))?;
if n == 0 {
return Err(io::Error::from(ErrorKind::UnexpectedEof).into()).into();
}
let cipher = self.cipher.as_mut().expect("cipher is None");
let m = &mut self.buffer[..data_len];
if !cipher.decrypt_packet(m) {
return Err(ProtocolError::DecryptDataError).into();
}
self.buffer.truncate(size);
Ok(()).into()
}
fn poll_read_exact<S>(&mut self, cx: &mut task::Context<'_>, stream: &mut S, size: usize) -> Poll<io::Result<usize>>
where
S: AsyncRead + Unpin + ?Sized,
{
assert!(size != 0);
while self.buffer.len() < size {
let remaining = size - self.buffer.len();
let buffer = &mut self.buffer.chunk_mut()[..remaining];
let mut read_buf =
ReadBuf::uninit(unsafe { slice::from_raw_parts_mut(buffer.as_mut_ptr() as *mut _, remaining) });
ready!(Pin::new(&mut *stream).poll_read(cx, &mut read_buf))?;
let n = read_buf.filled().len();
if n == 0 {
if !self.buffer.is_empty() {
return Err(ErrorKind::UnexpectedEof.into()).into();
} else {
return Ok(0).into();
}
}
unsafe {
self.buffer.advance_mut(n);
}
}
Ok(size).into()
}
fn decrypt_length(cipher: &mut TcpCipher, m: &mut [u8]) -> ProtocolResult<usize> {
let plen = {
if !cipher.decrypt_packet(m) {
return Err(ProtocolError::DecryptLengthError);
}
u16::from_be_bytes([m[0], m[1]]) as usize
};
Ok(plen)
}
pub fn current_data_chunk_remaining(&self) -> (u64, usize) {
match self.state {
DecryptReadState::BufferedData { pos } => (self.data_chunk_count, self.buffer.len() - pos),
_ => (self.data_chunk_count, 0),
}
}
pub fn user_key(&self) -> Option<&[u8]> {
self.user_key.as_deref()
}
pub fn handshaked(&self) -> bool {
self.has_handshaked
}
}
enum EncryptWriteState {
AssembleHeader,
AssemblePacket,
Writing { pos: usize },
}
pub struct EncryptedWriter {
stream_ty: StreamType,
cipher: TcpCipher,
method: CipherKind,
buffer: BytesMut,
state: EncryptWriteState,
salt: Bytes,
request_salt: Option<Bytes>,
}
impl EncryptedWriter {
pub fn new(stream_ty: StreamType, method: CipherKind, key: &[u8], nonce: &[u8]) -> EncryptedWriter {
static EMPTY_IDENTITY: [Bytes; 0] = [];
EncryptedWriter::with_identity(stream_ty, method, key, nonce, &EMPTY_IDENTITY)
}
pub fn with_identity(
stream_ty: StreamType,
method: CipherKind,
key: &[u8],
nonce: &[u8],
identity_keys: &[Bytes],
) -> EncryptedWriter {
let mut buffer = BytesMut::with_capacity(nonce.len() + identity_keys.len() * 16);
buffer.put(nonce);
#[inline]
fn make_eih(method: CipherKind, sub_key: &[u8], ipsk: &[u8], buffer: &mut BytesMut) {
let ipsk_hash = blake3::hash(ipsk);
let ipsk_plain_text = &ipsk_hash.as_bytes()[0..16];
match method {
CipherKind::AEAD2022_BLAKE3_AES_128_GCM => {
let enc_key = &sub_key[0..16];
let cipher = Aes128::new_from_slice(enc_key).expect("AES-128");
let ipsk_plain_text = Block::from_slice(ipsk_plain_text);
let mut block = Block::from([0u8; 16]);
cipher.encrypt_block_b2b(ipsk_plain_text, &mut block);
trace!(
"client EIH {:?}, hash: {:?}",
ByteStr::new(block.as_slice()),
ByteStr::new(ipsk_plain_text)
);
buffer.put(block.as_slice());
}
CipherKind::AEAD2022_BLAKE3_AES_256_GCM => {
let enc_key = &sub_key[0..32];
let cipher = Aes256::new_from_slice(enc_key).expect("AES-256");
let ipsk_plain_text = Block::from_slice(ipsk_plain_text);
let mut block = Block::from([0u8; 16]);
cipher.encrypt_block_b2b(ipsk_plain_text, &mut block);
trace!(
"client EIH {:?}, hash: {:?}",
ByteStr::new(block.as_slice()),
ByteStr::new(ipsk_plain_text)
);
buffer.put(block.as_slice());
}
_ => unreachable!("{} doesn't support EIH", method),
}
}
if stream_ty == StreamType::Client && method_support_eih(method) {
let mut sub_key: Option<[u8; blake3::OUT_LEN]> = None;
for ipsk in identity_keys {
if let Some(ref sub_key) = sub_key {
make_eih(method, sub_key, ipsk, &mut buffer);
}
let key_material = [ipsk, nonce].concat();
sub_key = Some(blake3::derive_key(AEAD2022_EIH_SUBKEY_CONTEXT, &key_material));
}
if let Some(ref sub_key) = sub_key {
make_eih(method, sub_key, key, &mut buffer);
}
}
EncryptedWriter {
stream_ty,
cipher: TcpCipher::new(method, key, nonce),
method,
buffer,
state: EncryptWriteState::AssembleHeader,
salt: Bytes::copy_from_slice(nonce),
request_salt: None,
}
}
pub fn salt(&self) -> &[u8] {
self.salt.as_ref()
}
pub fn set_request_salt(&mut self, request_salt: Bytes) {
debug_assert!(self.stream_ty == StreamType::Server);
self.request_salt = Some(request_salt);
}
pub fn reset_cipher_with_key(&mut self, key: &[u8]) {
self.cipher = TcpCipher::new(self.method, key, &self.salt);
}
pub fn poll_write_encrypted<S>(
&mut self,
cx: &mut task::Context<'_>,
stream: &mut S,
mut buf: &[u8],
) -> Poll<io::Result<usize>>
where
S: AsyncWrite + Unpin + ?Sized,
{
if buf.len() > MAX_PACKET_SIZE {
buf = &buf[..MAX_PACKET_SIZE];
}
loop {
match self.state {
EncryptWriteState::AssembleHeader => {
let request_salt_len = match self.request_salt {
None => 0,
Some(ref salt) => salt.len(),
};
let header_len = 1 + 8 + request_salt_len + 2 + self.cipher.tag_len();
self.buffer.reserve(header_len);
let mbuf = &mut self.buffer.chunk_mut()[..header_len];
let mbuf = unsafe { slice::from_raw_parts_mut(mbuf.as_mut_ptr(), mbuf.len()) };
let stream_ty = match self.stream_ty {
StreamType::Client => 0,
StreamType::Server => 1,
};
self.buffer.put_u8(stream_ty);
self.buffer.put_u64(get_now_timestamp());
if let Some(ref salt) = self.request_salt {
self.buffer.put_slice(salt);
}
self.buffer.put_u16(buf.len() as u16);
self.cipher.encrypt_packet(mbuf);
unsafe { self.buffer.advance_mut(self.cipher.tag_len()) };
let data_size = buf.len() + self.cipher.tag_len();
self.buffer.reserve(data_size);
let mbuf = &mut self.buffer.chunk_mut()[..data_size];
let mbuf = unsafe { slice::from_raw_parts_mut(mbuf.as_mut_ptr(), mbuf.len()) };
self.buffer.put_slice(buf);
self.cipher.encrypt_packet(mbuf);
unsafe { self.buffer.advance_mut(self.cipher.tag_len()) };
self.state = EncryptWriteState::Writing { pos: 0 };
}
EncryptWriteState::AssemblePacket => {
let length_size = 2 + self.cipher.tag_len();
self.buffer.reserve(length_size);
let mbuf = &mut self.buffer.chunk_mut()[..length_size];
let mbuf = unsafe { slice::from_raw_parts_mut(mbuf.as_mut_ptr(), mbuf.len()) };
self.buffer.put_u16(buf.len() as u16);
self.cipher.encrypt_packet(mbuf);
unsafe { self.buffer.advance_mut(self.cipher.tag_len()) };
let data_size = buf.len() + self.cipher.tag_len();
self.buffer.reserve(data_size);
let mbuf = &mut self.buffer.chunk_mut()[..data_size];
let mbuf = unsafe { slice::from_raw_parts_mut(mbuf.as_mut_ptr(), mbuf.len()) };
self.buffer.put_slice(buf);
self.cipher.encrypt_packet(mbuf);
unsafe { self.buffer.advance_mut(self.cipher.tag_len()) };
self.state = EncryptWriteState::Writing { pos: 0 };
}
EncryptWriteState::Writing { ref mut pos } => {
while *pos < self.buffer.len() {
let n = ready!(Pin::new(&mut *stream).poll_write(cx, &self.buffer[*pos..]))?;
if n == 0 {
return Err(ErrorKind::UnexpectedEof.into()).into();
}
*pos += n;
}
self.state = EncryptWriteState::AssemblePacket;
self.buffer.clear();
return Ok(buf.len()).into();
}
}
}
}
}