use std::{
io,
thread,
str::FromStr,
net::SocketAddr,
sync::mpsc,
};
use futures::TryFutureExt;
use tokio::{
io::{split, copy, AsyncWriteExt, AsyncReadExt},
net::{TcpListener, TcpStream},
};
use ciph::salsa::{Psk, Acceptor, Connector, Randomness};
/// Pre-shared key text (base64, per the name) embedded at compile time from `test.psk`,
/// resolved relative to this source file; the test parses it with `Psk::from_str`.
const PSK_B64: &str = include_str!("test.psk");
/// Starts an echo server on an ephemeral localhost port and returns the address
/// it is listening on.
///
/// The server runs on its own detached OS thread with a dedicated current-thread
/// Tokio runtime. Every accepted TCP connection is passed through
/// `Acceptor::accept` (the acceptor is built from `psk` and `randomness`), and the
/// resulting stream is echoed back to itself until the peer reaches EOF.
/// Listener and per-connection errors are printed to stderr with a `Server:`
/// prefix instead of panicking the thread.
///
/// # Panics
///
/// Panics if the runtime cannot be built, or if the server thread exits before
/// reporting its bound address (for example when binding fails; the underlying
/// error is printed to stderr).
fn spawn_localhost_acceptor(psk: Psk, randomness: Randomness) -> SocketAddr {
    let (addr_send, addr_recv) = mpsc::channel();
    thread::spawn(move || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .build()
            .expect("failed to build the acceptor's Tokio runtime");
        let handle = rt.handle().clone();
        let acceptor = Acceptor::new(psk, randomness);
        let server_fut = async move {
            // Port 0 asks the OS for any free port; the real one is reported below.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));
            let listener = TcpListener::bind(&addr).await?;
            // Route failures through the `io::Error` path instead of panicking the thread.
            let bound_addr = listener.local_addr()?;
            // The caller is blocked in `recv` waiting for exactly this message.
            addr_send
                .send(bound_addr)
                .expect("caller stopped waiting for the acceptor address");
            loop {
                let (stream, _) = listener.accept().await?;
                let establish = acceptor.accept(stream);
                let service_fut = async move {
                    let stream = establish.await?;
                    let (mut reader, mut writer) = split(stream);
                    copy(&mut reader, &mut writer).await?;
                    Ok(()) as io::Result<()>
                }
                .unwrap_or_else(|err| eprintln!("Server: {:?}", err));
                // Serve each connection as its own task so the accept loop keeps running.
                handle.spawn(service_fut);
            }
        }
        .unwrap_or_else(|err: io::Error| eprintln!("Server: {:?}", err));
        rt.block_on(server_fut);
    });
    // A closed channel means the server thread gave up before binding succeeded.
    addr_recv.recv().expect(
        "acceptor thread exited before reporting its address; see the `Server:` error on stderr",
    )
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn salsa_stream_cipher_works() {
const FILE: &'static [u8] = include_bytes!("../README.md");
let mut buffer: [u8; FILE.len()] = [0; FILE.len()];
let psk = Psk::from_str(PSK_B64).unwrap();
let addr = spawn_localhost_acceptor(psk.clone(), Randomness::Entropy);
let tcp_stream = TcpStream::connect(&addr)
.await
.unwrap();
let connector = Connector::new(psk, Randomness::Entropy);
let connect = connector.connect(tcp_stream);
let mut stream = connect
.await
.unwrap();
stream
.write_all(FILE)
.await
.unwrap();
stream
.flush()
.await
.unwrap();
stream
.read_exact(&mut buffer)
.await
.unwrap();
assert_eq!(&buffer[..], FILE);
}