use std::{
io,
marker::Unpin,
pin::Pin,
sync::Arc,
task::{self, Poll},
};
use byte_string::ByteStr;
use bytes::Bytes;
use futures::ready;
use log::trace;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::{
config::ServerUserManager,
context::Context,
crypto::{CipherCategory, CipherKind},
};
use super::aead::{DecryptedReader as AeadDecryptedReader, EncryptedWriter as AeadEncryptedWriter};
#[cfg(feature = "aead-cipher-2022")]
use super::aead_2022::{DecryptedReader as Aead2022DecryptedReader, EncryptedWriter as Aead2022EncryptedWriter};
#[cfg(feature = "stream-cipher")]
use super::stream::{DecryptedReader as StreamDecryptedReader, EncryptedWriter as StreamEncryptedWriter};
/// Error returned by the cipher-dispatching reader and writer in this file.
///
/// Every variant is marked `#[error(transparent)]`, so it forwards to the
/// wrapped error rather than adding a message of its own. Thanks to `#[from]`
/// each wrapped error type converts into `ProtocolError` automatically.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
/// A plain I/O error.
#[error(transparent)]
IoError(#[from] io::Error),
/// An error from the stream-cipher implementation
/// (only present with the `stream-cipher` feature).
#[cfg(feature = "stream-cipher")]
#[error(transparent)]
StreamError(#[from] super::stream::ProtocolError),
/// An error from the AEAD implementation.
#[error(transparent)]
AeadError(#[from] super::aead::ProtocolError),
/// An error from the AEAD-2022 implementation
/// (only present with the `aead-cipher-2022` feature).
#[cfg(feature = "aead-cipher-2022")]
#[error(transparent)]
Aead2022Error(#[from] super::aead_2022::ProtocolError),
}
/// Shorthand for a `Result` whose error type is [`ProtocolError`].
pub type ProtocolResult<T> = Result<T, ProtocolError>;
impl From<ProtocolError> for io::Error {
fn from(e: ProtocolError) -> io::Error {
match e {
ProtocolError::IoError(err) => err,
#[cfg(feature = "stream-cipher")]
ProtocolError::StreamError(err) => err.into(),
ProtocolError::AeadError(err) => err.into(),
#[cfg(feature = "aead-cipher-2022")]
ProtocolError::Aead2022Error(err) => err.into(),
}
}
}
/// Role tag handed to the reader and writer constructors.
///
/// Within this file the value is only forwarded to the AEAD-2022
/// implementations; the other cipher categories never look at it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamType {
/// Client role.
Client,
/// Server role.
Server,
}
/// Reading half of the cipher layer: one variant per cipher category, each
/// (except `None`) holding that category's reader implementation.
// NOTE(review): the size lint is silenced instead of boxing a variant; the
// reason is not recorded here -- presumably to avoid an extra allocation.
#[allow(clippy::large_enum_variant)]
pub enum DecryptedReader {
/// No cipher: reads go straight to the underlying stream.
None,
/// Reader for the AEAD cipher category.
Aead(AeadDecryptedReader),
/// Reader for the stream cipher category (`stream-cipher` feature only).
#[cfg(feature = "stream-cipher")]
Stream(StreamDecryptedReader),
/// Reader for the AEAD-2022 cipher category (`aead-cipher-2022` feature only).
#[cfg(feature = "aead-cipher-2022")]
Aead2022(Aead2022DecryptedReader),
}
/// Constructors and accessors that forward to whichever reader is active.
impl DecryptedReader {
    /// Builds a reader for `method` and `key` with no user manager.
    ///
    /// Shorthand for [`DecryptedReader::with_user_manager`] called with `None`.
    pub fn new(stream_ty: StreamType, method: CipherKind, key: &[u8]) -> DecryptedReader {
        Self::with_user_manager(stream_ty, method, key, None)
    }

    /// Builds a reader whose variant is picked from `method.category()`.
    ///
    /// `stream_ty` and `user_manager` are only forwarded to the AEAD-2022
    /// reader; every other category ignores them.
    pub fn with_user_manager(
        stream_ty: StreamType,
        method: CipherKind,
        key: &[u8],
        user_manager: Option<Arc<ServerUserManager>>,
    ) -> DecryptedReader {
        // Only the AEAD-2022 arm consumes these arguments; reference them so
        // that builds without that feature do not warn about unused parameters.
        #[cfg(not(feature = "aead-cipher-2022"))]
        let _ = (&stream_ty, &user_manager);

        match method.category() {
            CipherCategory::None => Self::None,
            CipherCategory::Aead => Self::Aead(AeadDecryptedReader::new(method, key)),
            #[cfg(feature = "stream-cipher")]
            CipherCategory::Stream => Self::Stream(StreamDecryptedReader::new(method, key)),
            #[cfg(feature = "aead-cipher-2022")]
            CipherCategory::Aead2022 => {
                let reader = Aead2022DecryptedReader::with_user_manager(stream_ty, method, key, user_manager);
                Self::Aead2022(reader)
            }
        }
    }

    /// Polls `stream` for data and fills `buf`.
    ///
    /// With the `None` variant this is a direct `poll_read` on `stream`;
    /// otherwise the call is delegated to the active reader's
    /// `poll_read_decrypted`. Errors are converted into [`ProtocolError`].
    #[inline]
    pub fn poll_read_decrypted<S>(
        &mut self,
        cx: &mut task::Context<'_>,
        context: &Context,
        stream: &mut S,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<ProtocolResult<()>>
    where
        S: AsyncRead + Unpin + ?Sized,
    {
        match self {
            Self::None => Pin::new(stream).poll_read(cx, buf).map_err(Into::into),
            Self::Aead(inner) => inner
                .poll_read_decrypted(cx, context, stream, buf)
                .map_err(Into::into),
            #[cfg(feature = "stream-cipher")]
            Self::Stream(inner) => inner
                .poll_read_decrypted(cx, context, stream, buf)
                .map_err(Into::into),
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner
                .poll_read_decrypted(cx, context, stream, buf)
                .map_err(Into::into),
        }
    }

    /// Returns the active reader's IV (stream cipher) or salt (AEAD and
    /// AEAD-2022), as reported by that reader. Always `None` for the `None`
    /// variant.
    pub fn nonce(&self) -> Option<&[u8]> {
        match self {
            Self::None => None,
            Self::Aead(inner) => inner.salt(),
            #[cfg(feature = "stream-cipher")]
            Self::Stream(inner) => inner.iv(),
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.salt(),
        }
    }

    /// Returns the AEAD-2022 reader's `request_salt()`.
    ///
    /// Every other variant yields `None`.
    pub fn request_nonce(&self) -> Option<&[u8]> {
        match self {
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.request_salt(),
            Self::None | Self::Aead(..) => None,
            #[cfg(feature = "stream-cipher")]
            Self::Stream(..) => None,
        }
    }

    /// Returns the AEAD-2022 reader's `user_key()`.
    ///
    /// Every other variant yields `None`.
    pub fn user_key(&self) -> Option<&[u8]> {
        match self {
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.user_key(),
            Self::None | Self::Aead(..) => None,
            #[cfg(feature = "stream-cipher")]
            Self::Stream(..) => None,
        }
    }

    /// Reports whether the active reader considers its handshake done.
    ///
    /// The `None` variant has no handshake and always answers `true`.
    pub fn handshaked(&self) -> bool {
        match self {
            Self::None => true,
            Self::Aead(inner) => inner.handshaked(),
            #[cfg(feature = "stream-cipher")]
            Self::Stream(inner) => inner.handshaked(),
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.handshaked(),
        }
    }
}
/// Writing half of the cipher layer: one variant per cipher category, each
/// (except `None`) holding that category's writer implementation.
pub enum EncryptedWriter {
/// No cipher: writes go straight to the underlying stream.
None,
/// Writer for the AEAD cipher category.
Aead(AeadEncryptedWriter),
/// Writer for the stream cipher category (`stream-cipher` feature only).
#[cfg(feature = "stream-cipher")]
Stream(StreamEncryptedWriter),
/// Writer for the AEAD-2022 cipher category (`aead-cipher-2022` feature only).
#[cfg(feature = "aead-cipher-2022")]
Aead2022(Aead2022EncryptedWriter),
}
/// Constructors and accessors that forward to whichever writer is active.
impl EncryptedWriter {
    /// Builds a writer whose variant is picked from `method.category()`.
    ///
    /// `nonce` is handed to the chosen writer's constructor; `stream_ty` is
    /// only forwarded to the AEAD-2022 writer.
    pub fn new(stream_ty: StreamType, method: CipherKind, key: &[u8], nonce: &[u8]) -> EncryptedWriter {
        // Only the AEAD-2022 arm consumes `stream_ty`; reference it so that
        // builds without that feature do not warn about an unused parameter.
        #[cfg(not(feature = "aead-cipher-2022"))]
        let _ = stream_ty;

        match method.category() {
            CipherCategory::None => Self::None,
            CipherCategory::Aead => Self::Aead(AeadEncryptedWriter::new(method, key, nonce)),
            #[cfg(feature = "stream-cipher")]
            CipherCategory::Stream => Self::Stream(StreamEncryptedWriter::new(method, key, nonce)),
            #[cfg(feature = "aead-cipher-2022")]
            CipherCategory::Aead2022 => {
                Self::Aead2022(Aead2022EncryptedWriter::new(stream_ty, method, key, nonce))
            }
        }
    }

    /// Like [`EncryptedWriter::new`], but additionally passes `identity_keys`
    /// to the AEAD-2022 writer.
    ///
    /// For every other category the result is built exactly as in `new`;
    /// `stream_ty` and `identity_keys` are ignored there.
    pub fn with_identity(
        stream_ty: StreamType,
        method: CipherKind,
        key: &[u8],
        nonce: &[u8],
        identity_keys: &[Bytes],
    ) -> EncryptedWriter {
        // Same unused-parameter guard as in `new`, for both extra arguments.
        #[cfg(not(feature = "aead-cipher-2022"))]
        let _ = (stream_ty, identity_keys);

        match method.category() {
            CipherCategory::None => Self::None,
            CipherCategory::Aead => Self::Aead(AeadEncryptedWriter::new(method, key, nonce)),
            #[cfg(feature = "stream-cipher")]
            CipherCategory::Stream => Self::Stream(StreamEncryptedWriter::new(method, key, nonce)),
            #[cfg(feature = "aead-cipher-2022")]
            CipherCategory::Aead2022 => {
                let writer =
                    Aead2022EncryptedWriter::with_identity(stream_ty, method, key, nonce, identity_keys);
                Self::Aead2022(writer)
            }
        }
    }

    /// Attempts to write `buf` to `stream`.
    ///
    /// With the `None` variant this is a direct `poll_write` on `stream`;
    /// otherwise the call is delegated to the active writer's
    /// `poll_write_encrypted`. Errors are converted into [`ProtocolError`].
    #[inline]
    pub fn poll_write_encrypted<S>(
        &mut self,
        cx: &mut task::Context<'_>,
        stream: &mut S,
        buf: &[u8],
    ) -> Poll<ProtocolResult<usize>>
    where
        S: AsyncWrite + Unpin + ?Sized,
    {
        match self {
            Self::None => Pin::new(stream).poll_write(cx, buf).map_err(Into::into),
            Self::Aead(inner) => inner.poll_write_encrypted(cx, stream, buf).map_err(Into::into),
            #[cfg(feature = "stream-cipher")]
            Self::Stream(inner) => inner.poll_write_encrypted(cx, stream, buf).map_err(Into::into),
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.poll_write_encrypted(cx, stream, buf).map_err(Into::into),
        }
    }

    /// Returns the active writer's IV (stream cipher) or salt (AEAD and
    /// AEAD-2022). The `None` variant yields an empty slice.
    pub fn nonce(&self) -> &[u8] {
        match self {
            Self::None => &[],
            Self::Aead(inner) => inner.salt(),
            #[cfg(feature = "stream-cipher")]
            Self::Stream(inner) => inner.iv(),
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.salt(),
        }
    }

    /// Hands `request_nonce` to the AEAD-2022 writer as its request salt.
    ///
    /// # Panics
    ///
    /// Panics when the writer is not the AEAD-2022 variant, which is always
    /// the case if the `aead-cipher-2022` feature is disabled.
    pub fn set_request_nonce(&mut self, request_nonce: Bytes) {
        match self {
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.set_request_salt(request_nonce),
            _ => {
                // Referenced only to keep the parameter "used" in builds
                // where the arm above is compiled out.
                let _ = request_nonce;
                panic!("only AEAD-2022 cipher could send request salt");
            }
        }
    }

    /// Asks the AEAD-2022 writer to reset its cipher using `key`.
    ///
    /// # Panics
    ///
    /// Panics when the writer is not the AEAD-2022 variant, which is always
    /// the case if the `aead-cipher-2022` feature is disabled.
    pub fn reset_cipher_with_key(&mut self, key: &[u8]) {
        match self {
            #[cfg(feature = "aead-cipher-2022")]
            Self::Aead2022(inner) => inner.reset_cipher_with_key(key),
            _ => {
                // Referenced only to keep the parameter "used" in builds
                // where the arm above is compiled out.
                let _ = key;
                panic!("only AEAD-2022 cipher could authenticate with multiple users");
            }
        }
    }
}
/// A stream `S` paired with a [`DecryptedReader`] and an [`EncryptedWriter`]
/// for one cipher method.
pub struct CryptoStream<S> {
// The wrapped underlying stream.
stream: S,
// Reading half of the cipher layer.
dec: DecryptedReader,
// Writing half of the cipher layer.
enc: EncryptedWriter,
// Cipher method this stream was created with.
method: CipherKind,
// Starts as `false`; set to `true` by the `CryptoRead` impl the first time
// a read completes and `dec.handshaked()` reports `true`.
has_handshaked: bool,
}
/// Construction and plain accessors; none of these need `S` to be an I/O type.
impl<S> CryptoStream<S> {
    /// Wraps `stream` for `method` and `key` with no identity keys and no
    /// user manager.
    ///
    /// Shorthand for [`CryptoStream::from_stream_with_identity`] called with
    /// an empty identity list and `None`.
    pub fn from_stream(
        context: &Context,
        stream: S,
        stream_ty: StreamType,
        method: CipherKind,
        key: &[u8],
    ) -> CryptoStream<S> {
        let no_identity: &[Bytes] = &[];
        Self::from_stream_with_identity(context, stream, stream_ty, method, key, no_identity, None)
    }

    /// Wraps `stream`, creating the reader and writer for `method`.
    ///
    /// When the method's category is `None`, a pass-through stream is
    /// returned and nothing is generated. Otherwise a local IV (stream
    /// cipher, `method.iv_len()` bytes) or salt (AEAD / AEAD-2022,
    /// `method.salt_len()` bytes) is filled by `context.generate_nonce` and
    /// given to the writer. `identity_keys` goes to the writer and
    /// `user_manager` to the reader.
    pub fn from_stream_with_identity(
        context: &Context,
        stream: S,
        stream_ty: StreamType,
        method: CipherKind,
        key: &[u8],
        identity_keys: &[Bytes],
        user_manager: Option<Arc<ServerUserManager>>,
    ) -> CryptoStream<S> {
        let category = method.category();

        // Fast path: with no cipher there is no nonce to generate.
        if matches!(category, CipherCategory::None) {
            return Self::new_none(stream, method);
        }

        // The last argument of `generate_nonce` is `true` for stream and AEAD
        // ciphers and `false` for AEAD-2022; what the flag means is defined by
        // `Context::generate_nonce`, which is not visible from this file.
        let local_nonce = match category {
            // Already handled by the early return; kept so the match stays exhaustive.
            CipherCategory::None => Vec::new(),
            CipherCategory::Aead => {
                let mut salt = vec![0u8; method.salt_len()];
                context.generate_nonce(method, &mut salt, true);
                trace!("generated AEAD cipher salt {:?}", ByteStr::new(&salt));
                salt
            }
            #[cfg(feature = "stream-cipher")]
            CipherCategory::Stream => {
                let mut iv = vec![0u8; method.iv_len()];
                context.generate_nonce(method, &mut iv, true);
                trace!("generated Stream cipher IV {:?}", ByteStr::new(&iv));
                iv
            }
            #[cfg(feature = "aead-cipher-2022")]
            CipherCategory::Aead2022 => {
                let mut salt = vec![0u8; method.salt_len()];
                context.generate_nonce(method, &mut salt, false);
                trace!("generated AEAD cipher salt {:?}", ByteStr::new(&salt));
                salt
            }
        };

        let dec = DecryptedReader::with_user_manager(stream_ty, method, key, user_manager);
        let enc = EncryptedWriter::with_identity(stream_ty, method, key, &local_nonce, identity_keys);

        Self {
            stream,
            dec,
            enc,
            method,
            has_handshaked: false,
        }
    }

    /// Builds the pass-through stream used when the cipher category is `None`.
    fn new_none(stream: S, method: CipherKind) -> CryptoStream<S> {
        Self {
            stream,
            dec: DecryptedReader::None,
            enc: EncryptedWriter::None,
            method,
            has_handshaked: false,
        }
    }

    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { stream, .. } = self;
        stream
    }

    /// Mutably borrows the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }

    /// Consumes the wrapper and returns the wrapped stream.
    pub fn into_inner(self) -> S {
        let Self { stream, .. } = self;
        stream
    }

    /// Nonce (IV or salt) reported by the reading half, if it has one.
    #[inline]
    pub fn received_nonce(&self) -> Option<&[u8]> {
        let reader = &self.dec;
        reader.nonce()
    }

    /// Nonce (IV or salt) held by the writing half; empty when no cipher is in use.
    #[inline]
    pub fn sent_nonce(&self) -> &[u8] {
        let writer = &self.enc;
        writer.nonce()
    }

    /// Request nonce reported by the reading half; only an AEAD-2022 reader
    /// can provide one.
    #[inline]
    pub fn received_request_nonce(&self) -> Option<&[u8]> {
        let reader = &self.dec;
        reader.request_nonce()
    }

    /// Copies `request_nonce` and installs it on the writing half.
    ///
    /// # Panics
    ///
    /// Panics if the writing half is not an AEAD-2022 writer.
    #[inline]
    pub fn set_request_nonce(&mut self, request_nonce: &[u8]) {
        let owned = Bytes::copy_from_slice(request_nonce);
        self.enc.set_request_nonce(owned)
    }

    /// Installs the reading half's nonce as the writing half's request nonce.
    ///
    /// Returns `false`, changing nothing, when the reader reports no nonce;
    /// returns `true` after the nonce has been copied across.
    ///
    /// # Panics
    ///
    /// Panics if a nonce is available but the writing half is not an
    /// AEAD-2022 writer.
    #[cfg(feature = "aead-cipher-2022")]
    pub(crate) fn set_request_nonce_with_received(&mut self) -> bool {
        if let Some(received) = self.dec.nonce() {
            self.enc.set_request_nonce(Bytes::copy_from_slice(received));
            true
        } else {
            false
        }
    }

    /// Forwards to the AEAD-2022 reader's `current_data_chunk_remaining()`.
    ///
    /// # Panics
    ///
    /// Panics if the reading half is not an AEAD-2022 reader.
    #[cfg(feature = "aead-cipher-2022")]
    pub(crate) fn current_data_chunk_remaining(&self) -> (u64, usize) {
        match &self.dec {
            DecryptedReader::Aead2022(reader) => reader.current_data_chunk_remaining(),
            _ => panic!("only AEAD-2022 protocol has data chunk counter"),
        }
    }
}
/// Poll-based read interface that reports failures as [`ProtocolError`].
///
/// Implemented in this file by `CryptoStream`.
pub trait CryptoRead {
/// Attempts to read data into `buf`.
///
/// `cx` is the task context used for wake-ups and `context` is the
/// project `Context` handed through to the implementation.
fn poll_read_decrypted(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
context: &Context,
buf: &mut ReadBuf<'_>,
) -> Poll<ProtocolResult<()>>;
}
/// Poll-based write interface that reports failures as [`ProtocolError`].
///
/// Implemented in this file by `CryptoStream`.
pub trait CryptoWrite {
/// Attempts to write `buf`, resolving to a `usize` on success
/// (presumably the number of bytes of `buf` consumed -- confirm against
/// the implementations).
fn poll_write_encrypted(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<ProtocolResult<usize>>;
}
impl<S> CryptoStream<S> {
    /// Returns the cipher method this stream was created with.
    pub fn method(&self) -> CipherKind {
        let Self { method, .. } = self;
        *method
    }
}
impl<S> CryptoRead for CryptoStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads through the reading half and, on the first completed read after
    /// which the reader reports `handshaked()`, records the handshake.
    ///
    /// At that moment, if the reader exposes a user key, the writing half's
    /// cipher is reset with that key.
    #[inline]
    fn poll_read_decrypted(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        context: &Context,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<ProtocolResult<()>> {
        let this = self.get_mut();

        ready!(this.dec.poll_read_decrypted(cx, context, &mut this.stream, buf))?;

        // One-shot transition: runs only the first time the reader says the
        // handshake is complete.
        if !this.has_handshaked && this.dec.handshaked() {
            this.has_handshaked = true;

            if let Some(user_key) = this.dec.user_key() {
                this.enc.reset_cipher_with_key(user_key);
            }
        }

        Poll::Ready(Ok(()))
    }
}
impl<S> CryptoWrite for CryptoStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes `buf` to the wrapped stream through the writing half.
    #[inline]
    fn poll_write_encrypted(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<ProtocolResult<usize>> {
        let this = self.get_mut();
        this.enc.poll_write_encrypted(cx, &mut this.stream, buf)
    }
}
/// Flush and shutdown forward directly to the wrapped stream; the cipher
/// halves are not involved.
impl<S> CryptoStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Polls the wrapped stream's `poll_flush`, converting any I/O error
    /// into [`ProtocolError`].
    #[inline]
    pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<ProtocolResult<()>> {
        let inner = Pin::new(&mut self.stream);
        inner.poll_flush(cx).map_err(Into::into)
    }

    /// Polls the wrapped stream's `poll_shutdown`, converting any I/O error
    /// into [`ProtocolError`].
    #[inline]
    pub fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<ProtocolResult<()>> {
        let inner = Pin::new(&mut self.stream);
        inner.poll_shutdown(cx).map_err(Into::into)
    }
}