use std::{
marker::Unpin,
pin::Pin,
task::{Context, Poll},
io,
};
use futures::{pin_mut, ready};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use cipher::{SyncStreamCipher, SyncStreamCipherSeek};
use salsa20::Salsa20;
const BUF_SIZE: usize = 64;
/// A stream adapter that runs all traffic through Salsa20 keystreams.
///
/// Bytes read from `stream` have `read_cipher`'s keystream applied before they
/// reach the caller, and bytes written by the caller have `write_cipher`'s
/// keystream applied before they reach `stream`. The two directions use
/// separate cipher instances and therefore keep independent keystream
/// positions.
#[derive(Debug)]
pub struct SalsaStream<S> {
    /// The wrapped transport that carries the enciphered bytes.
    stream: S,
    /// Keystream applied to data coming out of `stream` (used by `poll_read`).
    read_cipher: Salsa20,
    /// Keystream applied to data going into `stream` (used by `poll_write`).
    write_cipher: Salsa20,
}
impl<S> SalsaStream<S> {
pub fn new(stream: S, read_cipher: Salsa20, write_cipher: Salsa20) -> Self {
Self {
stream,
read_cipher,
write_cipher,
}
}
}
impl<S> AsyncRead for SalsaStream<S>
where S: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let starting_fill = buf.filled().len();
let stream = &mut self.stream;
pin_mut!(stream);
ready!(stream.poll_read(ctx, buf))?;
let (_, to_w) = buf.filled_mut().split_at_mut(starting_fill);
let written = to_w.len();
let cipherd = self.read_cipher
.try_apply_keystream(to_w)
.map(|()| written)
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)))?;
assert!(cipherd == written);
Poll::Ready(Ok(()))
}
}
impl<S> AsyncWrite for SalsaStream<S>
where S: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut buffer: [u8; BUF_SIZE] = [0; BUF_SIZE];
let mut chunks = buf.chunks(BUF_SIZE);
let mut total = 0;
while let Some(block) = chunks.next() {
let (mut out_buf, _) = buffer.split_at_mut(block.len());
out_buf.copy_from_slice(block);
let result = {
self.write_cipher
.try_apply_keystream(&mut out_buf)
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)))
};
if let Err(e) = result {
return Poll::Ready(Err(e));
}
let write_result = {
let stream = &mut self.stream;
pin_mut!(stream);
stream.poll_write(cx, &mut out_buf)
};
total += match write_result {
Poll::Ready(Ok(written)) if written < out_buf.len() => {
let delta = out_buf.len() - written;
let maybie_current_position = self.write_cipher
.try_current_pos::<usize>()
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)));
let current_position = match maybie_current_position {
Ok(pos) => pos,
Err(e) => return Poll::Ready(Err(e)),
};
let new_position = current_position - delta;
let result = self.write_cipher
.try_seek(new_position)
.map(|()| written + total)
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)));
return Poll::Ready(result);
},
Poll::Ready(Ok(written)) => written,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {
let maybie_current_position = self.write_cipher
.try_current_pos::<usize>()
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)));
let current_position = match maybie_current_position {
Ok(pos) => pos,
Err(e) => return Poll::Ready(Err(e)),
};
let new_position = current_position - out_buf.len();
let result = self.write_cipher
.try_seek(new_position)
.map(|()| total)
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)));
return match result {
Ok(total) if total > 0 => {
Poll::Ready(Ok(total))
},
Ok(_) => {
Poll::Pending
},
Err(e) => Poll::Ready(Err(e)),
};
},
};
}
Poll::Ready(Ok(total))
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let stream = &mut self.stream;
pin_mut!(stream);
stream.poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let stream = &mut self.stream;
pin_mut!(stream);
stream.poll_shutdown(cx)
}
}
#[cfg(test)]
mod test {
    use std::io::Cursor;
    use futures::{stream, StreamExt};
    use tokio::io::{split, copy, duplex, AsyncReadExt, AsyncWriteExt};
    use cipher::NewStreamCipher;
    use salsa20::{Key, Nonce};
    use base64::decode_config_slice;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaChaRng;
    use super::*;

    const B64KEY_1: &str = "KarnNMxbGDlnKgR+HaSxoU4LA7zQ1BlJB5qgg+BkJys=";
    const B64KEY_2: &str = "UkkPhhlMVMkitLuZrMonKgtM7KewjPP5WzYQ4bE5lDM=";
    const NONCE_1: [u8; 8] = 1001_u64.to_le_bytes();
    const NONCE_2: [u8; 8] = 2002_u64.to_le_bytes();

    /// Decodes a standard-alphabet base64 string into a 32-byte key.
    fn to_key(b64: &str) -> [u8; 32] {
        let mut key: [u8; 32] = [0; 32];
        decode_config_slice(b64, base64::STANDARD, &mut key).unwrap();
        key
    }

    /// Returns `(key1/nonce1 cipher, key2/nonce2 cipher)`, both at position 0.
    /// The keys and nonces are fixed, so every call yields identical keystreams.
    fn salsa_setup() -> (Salsa20, Salsa20) {
        let key = Key::clone_from_slice(&to_key(B64KEY_1));
        let nonce = Nonce::clone_from_slice(&NONCE_1);
        let salsa1 = Salsa20::new(&key, &nonce);
        let key = Key::clone_from_slice(&to_key(B64KEY_2));
        let nonce = Nonce::clone_from_slice(&NONCE_2);
        let salsa2 = Salsa20::new(&key, &nonce);
        (salsa1, salsa2)
    }

    /// Deterministic pseudorandom bytes for a given seed.
    fn pseudorandom_data(len: usize, seed: u64) -> Vec<u8> {
        let mut rng = ChaChaRng::seed_from_u64(seed);
        (0..len)
            .map(|_| rng.gen::<u8>())
            .collect()
    }

    #[test]
    fn check_salsa_cipher() {
        let (mut encrypt, _) = salsa_setup();
        let (mut decrypt, _) = salsa_setup();
        let check = pseudorandom_data(1024 * 1024, 59);
        let mut message = check.clone();
        encrypt.try_apply_keystream(&mut message).unwrap();
        assert!(message[..] != check[..]);
        decrypt.try_apply_keystream(&mut message).unwrap();
        assert!(message[..] == check[..]);
    }

    /// Copies everything read from `salsa_s` back into it until EOF.
    async fn echo_servlet<S>(salsa_s: SalsaStream<S>)
    where
        S: AsyncWrite + AsyncRead + Unpin,
    {
        let (mut reader, mut writer) = split(salsa_s);
        copy(&mut reader, &mut writer)
            .await
            .unwrap();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn salsa_stream_check() {
        // Each direction uses one key on both ends: client->server traffic is
        // key 1, server->client traffic is key 2. That way the echo server
        // really deciphers to plaintext instead of round-tripping garbage.
        let (client_write, server_write) = salsa_setup();
        let (server_read, client_read) = salsa_setup();
        let (client_s, server_s) = duplex(64);
        let sstream_c = SalsaStream::new(client_s, client_read, client_write);
        let sstream_s = SalsaStream::new(server_s, server_read, server_write);
        let length = 1024 * 1024;
        let message = pseudorandom_data(length as usize, 110);
        let write_buf = message.clone();
        tokio::spawn(echo_servlet(sstream_s));
        let (reader, mut writer) = split(sstream_c);
        let write_all = async move {
            writer
                .write_all(write_buf.as_slice())
                .await
                .unwrap();
        };
        let read_all = async move {
            let mut buffer = Vec::new();
            reader
                .take(length)
                .read_to_end(&mut buffer)
                .await
                .unwrap();
            buffer
        };
        // try_join! fails fast if the writer panics, instead of leaving the
        // reader waiting forever for bytes that will never arrive.
        let ((), echoed) = tokio::try_join!(tokio::spawn(write_all), tokio::spawn(read_all))
            .unwrap();
        assert_eq!(message.len(), echoed.len());
        assert!(message == echoed, "echoed bytes differ from the original message");
    }

    /// Enciphers `length` pseudorandom bytes up front, then checks that
    /// reading them back through a `SalsaStream` yields the plaintext.
    async fn async_read_test(length: usize, seed: u64) {
        eprintln!("Length: {}, Seed: {}", length, seed);
        let message = pseudorandom_data(length, seed);
        let mut ciphertext = message.clone();
        let (read_c, write_c) = salsa_setup();
        let (mut cipher, _) = salsa_setup();
        cipher.apply_keystream(&mut ciphertext);
        let mut salsa_stream = SalsaStream::new(ciphertext.as_slice(), read_c, write_c);
        let mut buffer = Vec::new();
        salsa_stream
            .read_to_end(&mut buffer)
            .await
            .unwrap();
        assert_eq!(buffer.len(), message.len());
        assert!(message == buffer);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn salsa_stream_async_read() {
        // Lengths 5 * count^2 for count in 0..100 (0 to 49_005 bytes).
        let read_tests = stream::unfold(0usize, |count| async move {
            if count < 100 {
                let size = count * count * 5;
                let seed = (size * size) as u64;
                Some(((size, seed), count + 1))
            } else {
                None
            }
        });
        read_tests
            .for_each(|(size, seed)| async move {
                async_read_test(size, seed).await;
            })
            .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn salsa_stream_async_write_cursor() {
        let message = pseudorandom_data(65, 123);
        let (read_c, write_c) = salsa_setup();
        let (_, mut cipher) = salsa_setup();
        let pipe_b = vec![0; 64].into_boxed_slice();
        let pipe = Cursor::new(pipe_b);
        let mut salsa_stream = SalsaStream::new(pipe, read_c, write_c);
        let mut total = 0;
        let mut output = Vec::new();
        while total < message.len() {
            let (_, left) = message
                .as_slice()
                .split_at(total);
            let written = salsa_stream
                .write(left)
                .await
                .unwrap();
            let (to_copy, _) = salsa_stream.stream
                .get_ref()
                .split_at(written);
            output.extend(to_copy);
            salsa_stream.stream.set_position(0);
            total += written;
        }
        assert_eq!(output.len(), message.len());
        cipher.apply_keystream(output.as_mut());
        assert!(output == message);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn salsa_stream_async_write_duplex() {
        let message = pseudorandom_data(65, 123);
        let (read_c, write_c) = salsa_setup();
        let (_, mut cipher) = salsa_setup();
        let (mut client, server) = duplex(64);
        let mut salsa_stream = SalsaStream::new(server, read_c, write_c);
        let mut output = Vec::new();
        let mut total = 0;
        while total < message.len() {
            let (_, left) = message
                .as_slice()
                .split_at(total);
            let written = salsa_stream
                .write(left)
                .await
                .unwrap();
            // read_exact: a bare `read` may legally return fewer bytes.
            let mut temp_buf = vec![0; written];
            client
                .read_exact(&mut temp_buf)
                .await
                .unwrap();
            output.extend(temp_buf);
            total += written;
        }
        assert_eq!(output.len(), message.len());
        cipher.apply_keystream(output.as_mut());
        assert!(output == message);
    }
}