pub use tweetnacly::*;
pub use expry::*;
use bytes::{BytesMut, Buf};
use tokio_util::codec::{Decoder, Framed, FramedRead};
use tokio::io::{AsyncWriteExt as _, AsyncWrite, AsyncRead};
use tokio_stream::Stream;
/// Frame-length limit (MAC + ciphertext, in bytes) checked by `NaClCodec` while
/// decoding; the decoder also pre-reserves this much space in its read buffer.
pub const MAX_NACL_RECEIVE_BUFFER: u64 = 64 * 1024;
/// Largest plaintext chunk that `send_frame` encrypts into a single frame.
pub const MAX_NACL_SEND_BUFFER: usize = 32 * 1024;
pub struct NaClCodec {
recv_cache: CryptoBoxCache,
recv_nonce: Nonce,
server: bool,
}
impl NaClCodec {
pub fn new_server(sk: SecretBoxKey, public_key_of_session: &PublicBoxKey) -> Self {
Self {
recv_cache: crypto_box_prepare(&sk, public_key_of_session),
recv_nonce: Nonce{data: [0u8; tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize]},
server: true,
}
}
pub fn new_client(sk: SecretBoxKey, public_key_of_session: &PublicBoxKey) -> Self {
Self {
recv_cache: crypto_box_prepare(&sk, public_key_of_session),
recv_nonce: Nonce{data: [255u8; tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize]},
server: false,
}
}
}
impl tokio_util::codec::Decoder for NaClCodec {
type Item = BytesMut;
type Error = std::io::Error;
fn decode(
&mut self,
src: &mut BytesMut
) -> Result<Option<Self::Item>, Self::Error> {
let mut reader = RawReader::with(src);
if let Ok(frame_size) = reader.read_var_u64() {
let remaining = reader.len();
if remaining >= frame_size as usize {
src.advance(src.len() - remaining); let mut data = src.split_to(frame_size as usize);
if data.len() < tweetnacly::bindings::crypto_box_MACBYTES as usize {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "NaCl packet did not contain enough bytes for MAC"));
}
let mut raw_data = [0u8; tweetnacly::bindings::crypto_box_MACBYTES as usize];
raw_data.copy_from_slice(&data[0..tweetnacly::bindings::crypto_box_MACBYTES as usize]);
let tag = AuthenticationTag{data: raw_data};
data.advance(tweetnacly::bindings::crypto_box_MACBYTES as usize);
crypto_box_open_in_place(&mut data, &tag, &self.recv_nonce, &self.recv_cache).map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "crypto error"))?;
if self.server {
increase_nonce(&mut self.recv_nonce.data);
} else {
decrease_nonce(&mut self.recv_nonce.data);
}
assert_eq!(frame_size as usize, tweetnacly::bindings::crypto_box_MACBYTES as usize + data.len());
return Ok(Some(data));
} else {
if frame_size > MAX_NACL_RECEIVE_BUFFER {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "max frame size exceeded"));
}
if src.capacity() < MAX_NACL_RECEIVE_BUFFER as usize{
src.reserve(MAX_NACL_RECEIVE_BUFFER as usize - src.len());
}
src.reserve(frame_size as usize - remaining);
}
}
Ok(None)
}
}
pub struct FramedReadWrapper<T,U>
where
T: AsyncRead,
U: Decoder,
{
framed: FramedRead<T,U>,
current_read: Option<BytesMut>,
}
impl<T,U> FramedReadWrapper<T,U>
where
T: AsyncRead,
U: Decoder,
{
pub fn new(framed: FramedRead<T,U>) -> Self {
Self {
framed,
current_read: None,
}
}
pub fn get_mut(&mut self) -> &mut T {
self.framed.get_mut()
}
}
impl<T,U> AsyncRead for FramedReadWrapper<T,U>
where
    T: AsyncRead,
    U: Decoder<Item=BytesMut,Error=std::io::Error>,
{
    /// Copies decoded frame payloads into `buf`, concatenating frames.
    ///
    /// Bytes left over from the previous call are delivered first. Then
    /// frames are pulled from the decoder until `buf` is full or no frame is
    /// ready. A frame that does not fit is split, and its tail is kept for the
    /// next call.
    ///
    /// Returns `Poll::Pending` only when nothing was copied and the decoder is
    /// pending. End of stream is reported as `Poll::Ready(Ok(()))` with no
    /// bytes added, as the `AsyncRead` contract requires. Zero-length frames
    /// are skipped so they are never mistaken for EOF.
    ///
    /// # Errors
    /// Returns the decoder/transport error as soon as the frame stream yields it.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::task::Poll;
        // SAFETY: `framed` is treated as structurally pinned. It is only
        // accessed through the `Pin::new_unchecked` projection below and is
        // never moved out. `current_read` holds a `BytesMut`, which is `Unpin`.
        let this = unsafe { self.get_unchecked_mut() };
        // Whether at least one byte was copied into `buf` during this call.
        let mut copied_any = false;
        if let Some(leftover) = this.current_read.as_mut() {
            let len = std::cmp::min(leftover.len(), buf.remaining());
            buf.put_slice(&leftover[..len]);
            leftover.advance(len);
            if !leftover.is_empty() {
                return Poll::Ready(Ok(()));
            }
            this.current_read = None;
            copied_any = len > 0;
        }
        while buf.remaining() > 0 {
            // SAFETY: see above; `framed` is never moved while `self` is pinned.
            let framed = unsafe { std::pin::Pin::new_unchecked(&mut this.framed) };
            match Stream::poll_next(framed, cx) {
                Poll::Ready(Some(Ok(mut frame))) => {
                    let len = std::cmp::min(frame.len(), buf.remaining());
                    buf.put_slice(&frame[..len]);
                    copied_any |= len > 0;
                    if len < frame.len() {
                        frame.advance(len);
                        this.current_read = Some(frame);
                        return Poll::Ready(Ok(()));
                    }
                }
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)),
                // End of stream: must return Ready, because no waker is
                // registered once the stream has finished.
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Pending => {
                    return if copied_any { Poll::Ready(Ok(())) } else { Poll::Pending };
                }
            }
        }
        Poll::Ready(Ok(()))
    }
}
pub struct FramedWrapper<T,U>
where
T: AsyncRead + AsyncWrite,
U: Decoder,
{
framed: Framed<T,U>,
current_read: Option<BytesMut>,
}
impl<T,U> FramedWrapper<T,U>
where
T: AsyncRead + AsyncWrite,
U: Decoder,
{
pub fn new(framed: Framed<T,U>) -> Self {
Self {
framed,
current_read: None,
}
}
pub fn get_mut(&mut self) -> &mut T {
self.framed.get_mut()
}
}
impl<T,U> AsyncRead for FramedWrapper<T,U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder<Item=BytesMut,Error=std::io::Error>,
{
    /// Copies decoded frame payloads into `buf`, concatenating frames.
    ///
    /// Bytes left over from the previous call are delivered first. Then
    /// frames are pulled from the decoder until `buf` is full or no frame is
    /// ready. A frame that does not fit is split, and its tail is kept for the
    /// next call.
    ///
    /// Returns `Poll::Pending` only when nothing was copied and the decoder is
    /// pending. End of stream is reported as `Poll::Ready(Ok(()))` with no
    /// bytes added, as the `AsyncRead` contract requires. Zero-length frames
    /// are skipped so they are never mistaken for EOF.
    ///
    /// # Errors
    /// Returns the decoder/transport error as soon as the frame stream yields it.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::task::Poll;
        // SAFETY: `framed` is treated as structurally pinned. It is only
        // accessed through the `Pin::new_unchecked` projection below (or via
        // the pin-preserving projection in `AsyncWrite`) and is never moved
        // out. `current_read` holds a `BytesMut`, which is `Unpin`.
        let this = unsafe { self.get_unchecked_mut() };
        // Whether at least one byte was copied into `buf` during this call.
        let mut copied_any = false;
        if let Some(leftover) = this.current_read.as_mut() {
            let len = std::cmp::min(leftover.len(), buf.remaining());
            buf.put_slice(&leftover[..len]);
            leftover.advance(len);
            if !leftover.is_empty() {
                return Poll::Ready(Ok(()));
            }
            this.current_read = None;
            copied_any = len > 0;
        }
        while buf.remaining() > 0 {
            // SAFETY: see above; `framed` is never moved while `self` is pinned.
            let framed = unsafe { std::pin::Pin::new_unchecked(&mut this.framed) };
            match Stream::poll_next(framed, cx) {
                Poll::Ready(Some(Ok(mut frame))) => {
                    let len = std::cmp::min(frame.len(), buf.remaining());
                    buf.put_slice(&frame[..len]);
                    copied_any |= len > 0;
                    if len < frame.len() {
                        frame.advance(len);
                        this.current_read = Some(frame);
                        return Poll::Ready(Ok(()));
                    }
                }
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)),
                // End of stream: must return Ready, because no waker is
                // registered once the stream has finished.
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Pending => {
                    return if copied_any { Poll::Ready(Ok(())) } else { Poll::Pending };
                }
            }
        }
        Poll::Ready(Ok(()))
    }
}
/// Writes go straight to the transport inside `framed`; the decoder is not involved.
impl<T,U> AsyncWrite for FramedWrapper<T,U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder<Item=BytesMut>,
{
    /// Forwards `buf` to the underlying transport.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // SAFETY: the transport inside `framed` is never moved out of a pinned
        // `FramedWrapper`, so projecting the pin onto it is sound.
        let transport = unsafe { self.map_unchecked_mut(|wrapper| wrapper.framed.get_mut()) };
        T::poll_write(transport, cx, buf)
    }
    /// Flushes the underlying transport.
    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
        // SAFETY: same pin projection as in `poll_write`.
        let transport = unsafe { self.map_unchecked_mut(|wrapper| wrapper.framed.get_mut()) };
        T::poll_flush(transport, cx)
    }
    /// Shuts down the write half of the underlying transport.
    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
        // SAFETY: same pin projection as in `poll_write`.
        let transport = unsafe { self.map_unchecked_mut(|wrapper| wrapper.framed.get_mut()) };
        T::poll_shutdown(transport, cx)
    }
}
/// Returns the all-zero nonce (the starting nonce associated with the client).
pub fn nonce_for_client() -> Nonce {
    const LEN: usize = tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize;
    Nonce { data: [u8::MIN; LEN] }
}
/// Returns the all-`0xFF` nonce (the starting nonce associated with the server).
pub fn nonce_for_server() -> Nonce {
    const LEN: usize = tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize;
    Nonce { data: [u8::MAX; LEN] }
}
/// Decrements `numbers`, read as a little-endian unsigned integer, by one.
///
/// The borrow propagates from index 0 upward; an all-zero array wraps to all `0xFF`.
pub fn decrease_nonce<const N: usize>(numbers: &mut [u8; N]) {
    let mut borrow = true;
    for byte in numbers.iter_mut() {
        if !borrow {
            break;
        }
        let (next, wrapped) = byte.overflowing_sub(1);
        *byte = next;
        borrow = wrapped;
    }
}
/// Increments `numbers`, read as a little-endian unsigned integer, by one.
///
/// The carry propagates from index 0 upward; an all-`0xFF` array wraps to all zeros.
pub fn increase_nonce<const N: usize>(numbers: &mut [u8; N]) {
    let mut carry = true;
    for byte in numbers.iter_mut() {
        if !carry {
            break;
        }
        let (next, wrapped) = byte.overflowing_add(1);
        *byte = next;
        carry = wrapped;
    }
}
/// Encrypts `data` in place and writes it to `stream` as one or more frames.
///
/// `data` is split into chunks of at most `MAX_NACL_SEND_BUFFER` bytes. Each
/// chunk is boxed with `send_nonce`/`send_cache` and written as a `var_u64`
/// length (MAC + ciphertext), the MAC bytes, then the ciphertext. After each
/// chunk is written, `send_nonce` is stepped: down when `server` is true, up
/// otherwise. An empty `data` writes nothing.
///
/// # Errors
/// Propagates any I/O error from `stream`. The nonce is not stepped for a
/// chunk whose write failed.
pub async fn send_frame<Out: AsyncWrite + std::marker::Unpin>(data: &mut [u8], send_nonce: &mut Nonce, send_cache: &CryptoBoxCache, server: bool, stream: &mut Out) -> Result<(),std::io::Error> {
    const MAC_LEN: usize = tweetnacly::bindings::crypto_box_MACBYTES as usize;
    for chunk in data.chunks_mut(MAX_NACL_SEND_BUFFER) {
        let tag = crypto_box_in_place(chunk, send_nonce, send_cache);
        let frame_len = (tag.data.len() + chunk.len()) as u64;
        // 9 bytes reserved for the length prefix, followed by the MAC.
        let mut header_buf = [0u8; 9 + MAC_LEN];
        let mut writer = RawWriter::with(&mut header_buf);
        writer.write_var_u64(frame_len).unwrap();
        writer.write_bytes(&tag.data).unwrap();
        let header = writer.build();
        stream.write_all(header).await?;
        stream.write_all(chunk).await?;
        // The two directions walk the nonce space from opposite ends.
        if !server {
            increase_nonce(&mut send_nonce.data);
        } else {
            decrease_nonce(&mut send_nonce.data);
        }
    }
    Ok(())
}
/// Decoder for plain `var_u64`-length-prefixed frames.
pub struct WireCodec {
    /// Frame-size limit, in bytes, checked by the decoder.
    max: usize,
}
impl WireCodec {
    /// Creates a codec with frame-size limit `max`.
    pub fn new(max: usize) -> Self {
        WireCodec { max }
    }
}
impl tokio_util::codec::Decoder for WireCodec {
    type Item = BytesMut;
    type Error = std::io::Error;
    /// Decodes one `var_u64`-length-prefixed frame from `src`.
    ///
    /// Returns `Ok(None)` when more bytes are needed and `Ok(Some(payload))`
    /// (without the prefix) once the whole frame is buffered.
    ///
    /// # Errors
    /// Fails when the declared length exceeds `self.max`. This is checked
    /// before any bytes are consumed, whether or not the frame is fully
    /// buffered.
    fn decode(
        &mut self,
        src: &mut BytesMut
    ) -> Result<Option<Self::Item>, Self::Error> {
        let mut reader = RawReader::with(src);
        // NOTE(review): any `read_var_u64` error is treated as "length prefix
        // not complete yet"; confirm a malformed prefix cannot stall here.
        let frame_size = match reader.read_var_u64() {
            Ok(size) => size,
            Err(_) => return Ok(None),
        };
        let remaining = reader.len();
        // Enforce the limit before converting to `usize`: the cast would
        // truncate on 32-bit targets, and a complete oversized frame must be
        // rejected just like a partially received one.
        if frame_size > self.max as u64 {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("max wire frame size exceeded: {}", frame_size)));
        }
        let frame_size = frame_size as usize;
        if remaining >= frame_size {
            // Drop the length prefix, then take exactly one frame.
            let prefix_len = src.len() - remaining;
            src.advance(prefix_len);
            return Ok(Some(src.split_to(frame_size)));
        }
        // Incomplete frame: keep at least 4 KiB of buffer and room for the rest.
        if src.capacity() < 4096 {
            src.reserve(4096 - src.len());
        }
        src.reserve(frame_size - remaining);
        Ok(None)
    }
}