#![deny(missing_docs)]
#![deny(missing_doc_code_examples)]
#[macro_use]
extern crate log;
use std::cell::RefCell;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
use chacha20::stream_cipher::{NewStreamCipher, SyncStreamCipher};
use chacha20::{ChaCha20, Key, Nonce};
use crossbeam_channel::{unbounded, Receiver, Sender};
use failure::{bail, Error};
use rand::rngs::OsRng;
use rand::RngCore;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// Value exchanged (encrypted) by both peers during the handshake: each side decrypts the value
/// received from the other and checks that it matches, which fails when the peers disagree on
/// the key.
const MAGIC: u32 = 0x69421997;

/// Result type returned by the fallible operations of this crate.
pub type Result<T> = std::result::Result<T, Error>;
/// Unit of data carried by a channel.
///
/// Remote channels serialize it with `bincode`, which encodes the variant by its declaration
/// index: do not reorder or remove variants, or peers built from different versions will
/// misinterpret each other.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ChannelMessage<T> {
/// A typed message, as sent by `ChannelSender::send`.
Message(T),
/// A raw byte payload; only local channels send the bytes inside the message.
RawData(Vec<u8>),
/// Header announcing a raw payload of the given length; remote channels write the payload
/// right after this header.
RawDataStart(usize),
}
#[derive(Clone)]
enum ChannelSenderInner<T> {
Local(Sender<ChannelMessage<T>>),
Remote(Arc<Mutex<TcpStream>>),
RemoteEnc(Arc<Mutex<(TcpStream, ChaCha20)>>),
}
#[derive(Clone)]
pub struct ChannelSender<T> {
inner: ChannelSenderInner<T>,
}
/// Transport behind a [`ChannelReceiver`].
enum ChannelReceiverInner<T> {
    /// In-process channel.
    Local(Receiver<ChannelMessage<T>>),
    /// Plaintext TCP connection. The `RefCell` lets the receiving methods take `&self` while
    /// reading from the socket.
    Remote(RefCell<TcpStream>),
    /// TCP connection encrypted with ChaCha20; the cipher state advances with every read.
    RemoteEnc(RefCell<(TcpStream, ChaCha20)>),
}

/// Receiving half of a channel carrying messages of type `T`, either in-process or over TCP.
///
/// Because of the `RefCell`s inside, a receiver is not `Sync`: it cannot be shared by
/// reference between threads.
pub struct ChannelReceiver<T> {
    inner: ChannelReceiverInner<T>,
}
impl<T> ChannelSender<T>
where
T: 'static + Send + Sync + Serialize,
{
pub fn send(&self, data: T) -> Result<()> {
match &self.inner {
ChannelSenderInner::Local(sender) => sender
.send(ChannelMessage::Message(data))
.map_err(|e| e.into()),
ChannelSenderInner::Remote(sender) => {
let mut sender = sender.lock().unwrap();
let stream = sender.deref_mut();
ChannelSender::<T>::send_remote_raw(stream, ChannelMessage::Message(data))
}
ChannelSenderInner::RemoteEnc(stream) => {
let mut stream = stream.lock().unwrap();
let (stream, enc) = stream.deref_mut();
ChannelSender::<T>::send_remote_raw_enc(stream, enc, ChannelMessage::Message(data))
}
}
}
pub fn send_raw(&self, data: &[u8]) -> Result<()> {
match &self.inner {
ChannelSenderInner::Local(sender) => {
Ok(sender.send(ChannelMessage::RawData(data.into()))?)
}
ChannelSenderInner::Remote(sender) => {
let mut sender = sender.lock().expect("Cannot lock ChannelSender");
let stream = sender.deref_mut();
ChannelSender::<T>::send_remote_raw(
stream,
ChannelMessage::RawDataStart(data.len()),
)?;
Ok(stream.write_all(&data)?)
}
ChannelSenderInner::RemoteEnc(stream) => {
let mut stream = stream.lock().unwrap();
let (stream, enc) = stream.deref_mut();
ChannelSender::<T>::send_remote_raw_enc(
stream,
enc,
ChannelMessage::RawDataStart(data.len()),
)?;
let data = ChannelSender::<T>::encrypt_buffer(data.into(), enc)?;
Ok(stream.write_all(&data)?)
}
}
}
fn send_remote_raw(stream: &mut TcpStream, data: ChannelMessage<T>) -> Result<()> {
Ok(bincode::serialize_into(stream, &data)?)
}
fn send_remote_raw_enc(
stream: &mut TcpStream,
encryptor: &mut ChaCha20,
data: ChannelMessage<T>,
) -> Result<()> {
let data = bincode::serialize(&data)?;
let data = ChannelSender::<T>::encrypt_buffer(data, encryptor)?;
stream.write_all(&data)?;
Ok(())
}
fn encrypt_buffer(mut data: Vec<u8>, encryptor: &mut ChaCha20) -> Result<Vec<u8>> {
let mut res = Vec::from((data.len() as u32).to_le_bytes());
res.append(&mut data);
encryptor.apply_keystream(&mut res);
Ok(res)
}
pub fn change_type<T2>(self) -> ChannelSender<T2> {
match self.inner {
ChannelSenderInner::Remote(r) => ChannelSender {
inner: ChannelSenderInner::Remote(r),
},
ChannelSenderInner::RemoteEnc(r) => ChannelSender {
inner: ChannelSenderInner::RemoteEnc(r),
},
ChannelSenderInner::Local(_) => panic!("Cannot change ChannelSender::Local type"),
}
}
}
impl<T> ChannelReceiver<T>
where
T: 'static + DeserializeOwned,
{
pub fn recv(&self) -> Result<T> {
let message = match &self.inner {
ChannelReceiverInner::Local(receiver) => receiver.recv()?,
ChannelReceiverInner::Remote(receiver) => ChannelReceiver::recv_remote_raw(receiver)?,
ChannelReceiverInner::RemoteEnc(receiver) => {
let mut receiver = receiver.borrow_mut();
let (receiver, decryptor) = receiver.deref_mut();
ChannelReceiver::recv_remote_raw_enc(receiver, decryptor)?
}
};
match message {
ChannelMessage::Message(mex) => Ok(mex),
_ => panic!("Expected ChannelMessage::Message"),
}
}
pub fn recv_raw(&self) -> Result<Vec<u8>> {
match &self.inner {
ChannelReceiverInner::Local(receiver) => match receiver.recv()? {
ChannelMessage::RawData(data) => Ok(data),
_ => panic!("Expected ChannelMessage::RawData"),
},
ChannelReceiverInner::Remote(receiver) => {
match ChannelReceiver::<T>::recv_remote_raw(receiver)? {
ChannelMessage::RawDataStart(len) => {
let mut receiver = receiver.borrow_mut();
let mut buf = vec![0u8; len];
receiver.read_exact(&mut buf)?;
Ok(buf)
}
_ => panic!("Expected ChannelMessage::RawDataStart"),
}
}
ChannelReceiverInner::RemoteEnc(receiver) => {
let mut receiver = receiver.borrow_mut();
let (receiver, decryptor) = receiver.deref_mut();
match ChannelReceiver::<T>::recv_remote_raw_enc(receiver, decryptor)? {
ChannelMessage::RawDataStart(_) => {
let buf = ChannelReceiver::<T>::decrypt_buffer(receiver, decryptor)?;
Ok(buf)
}
_ => panic!("Expected ChannelMessage::RawDataStart"),
}
}
}
}
fn recv_remote_raw(receiver: &RefCell<TcpStream>) -> Result<ChannelMessage<T>> {
let mut receiver = receiver.borrow_mut();
Ok(bincode::deserialize_from(receiver.deref_mut())?)
}
fn recv_remote_raw_enc(
receiver: &mut TcpStream,
decryptor: &mut ChaCha20,
) -> Result<ChannelMessage<T>> {
let buf = ChannelReceiver::<T>::decrypt_buffer(receiver, decryptor)?;
Ok(bincode::deserialize(&buf)?)
}
fn decrypt_buffer(receiver: &mut TcpStream, decryptor: &mut ChaCha20) -> Result<Vec<u8>> {
let mut len = [0u8; 4];
receiver.read_exact(&mut len)?;
decryptor.apply_keystream(&mut len);
let len = u32::from_le_bytes(len) as usize;
let mut buf = vec![0u8; len];
receiver.read_exact(&mut buf)?;
decryptor.apply_keystream(&mut buf);
Ok(buf)
}
pub fn change_type<T2>(self) -> ChannelReceiver<T2> {
match self.inner {
ChannelReceiverInner::Local(_) => panic!("Cannot change ChannelReceiver::Local type"),
ChannelReceiverInner::Remote(r) => ChannelReceiver {
inner: ChannelReceiverInner::Remote(r),
},
ChannelReceiverInner::RemoteEnc(r) => ChannelReceiver {
inner: ChannelReceiverInner::RemoteEnc(r),
},
}
}
}
/// Create a connected pair of in-process channel endpoints.
///
/// The underlying crossbeam channel is unbounded, so sending never blocks. Messages are moved to
/// the receiver without being serialized. Local endpoints cannot use `change_type`.
pub fn new_local_channel<T>() -> (ChannelSender<T>, ChannelReceiver<T>) {
    let (tx, rx) = unbounded();
    let sender = ChannelSender {
        inner: ChannelSenderInner::Local(tx),
    };
    let receiver = ChannelReceiver {
        inner: ChannelReceiverInner::Local(rx),
    };
    (sender, receiver)
}
/// TCP server accepting remote channels.
///
/// Iterating over it blocks waiting for clients. Each client that completes the handshake
/// yields the channels to talk with it, plus its address. `S` is the type of the messages sent
/// to the clients; `R` is the type of the messages received from them.
pub struct ChannelServer<S, R> {
    /// Socket accepting the incoming connections.
    listener: TcpListener,
    /// When set, clients must encrypt the connection with this ChaCha20 key.
    enc_key: Option<[u8; 32]>,
    // `fn() -> T` records the type parameters without owning values of those types. The
    // server therefore stays `Send` + `Sync` like the `TcpListener` it wraps; `*const T` would
    // make it neither, preventing e.g. running the accept loop on another thread.
    _sender: PhantomData<fn() -> S>,
    _receiver: PhantomData<fn() -> R>,
}
impl<S, R> ChannelServer<S, R> {
pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<ChannelServer<S, R>> {
Ok(ChannelServer {
listener: TcpListener::bind(addr)?,
enc_key: None,
_sender: Default::default(),
_receiver: Default::default(),
})
}
pub fn bind_with_enc<A: ToSocketAddrs>(
addr: A,
enc_key: [u8; 32],
) -> Result<ChannelServer<S, R>> {
Ok(ChannelServer {
listener: TcpListener::bind(addr)?,
enc_key: Some(enc_key),
_sender: Default::default(),
_receiver: Default::default(),
})
}
}
impl<S, R> Deref for ChannelServer<S, R> {
type Target = TcpListener;
fn deref(&self) -> &Self::Target {
&self.listener
}
}
impl<S, R> Iterator for ChannelServer<S, R> {
type Item = (ChannelSender<S>, ChannelReceiver<R>, SocketAddr);
fn next(&mut self) -> Option<Self::Item> {
loop {
let next = self
.listener
.incoming()
.next()
.expect("TcpListener::incoming returned None");
if let Ok(mut sender) = next {
let peer_addr = sender.peer_addr().expect("Peer has no remote address");
let receiver = sender.try_clone().expect("Failed to clone the stream");
if let Some(enc_key) = &self.enc_key {
let key = Key::from_slice(enc_key);
let (enc_nonce, dec_nonce) = match nonce_handshake(&mut sender) {
Ok(x) => x,
Err(e) => {
warn!("Nonce handshake failed with {}: {:?}", peer_addr, e);
continue;
}
};
let enc_nonce = Nonce::from_slice(&enc_nonce);
let mut enc = ChaCha20::new(&key, &enc_nonce);
let dec_nonce = Nonce::from_slice(&dec_nonce);
let mut dec = ChaCha20::new(&key, &dec_nonce);
if let Err(e) = check_encryption_key(&mut sender, &mut enc, &mut dec) {
warn!("Magic handshake failed with {}: {:?}", peer_addr, e);
continue;
}
return Some((
ChannelSender {
inner: ChannelSenderInner::RemoteEnc(Arc::new(Mutex::new((
sender, enc,
)))),
},
ChannelReceiver {
inner: ChannelReceiverInner::RemoteEnc(RefCell::new((receiver, dec))),
},
peer_addr,
));
} else {
if let Err(e) = check_no_encryption(&mut sender) {
warn!("Magic handshake failed with {}: {:?}", peer_addr, e);
continue;
}
return Some((
ChannelSender {
inner: ChannelSenderInner::Remote(Arc::new(Mutex::new(sender))),
},
ChannelReceiver {
inner: ChannelReceiverInner::Remote(RefCell::new(receiver)),
},
peer_addr,
));
}
}
}
}
}
/// Connect to a [`ChannelServer`] listening on `addr`, without encryption.
///
/// Returns the channel for sending messages of type `S` to the server and the channel for
/// receiving messages of type `R` from it.
///
/// # Errors
/// Fails if the connection cannot be established or the handshake fails, for instance
/// because the server requires encryption.
pub fn connect_channel<A: ToSocketAddrs, S, R>(
    addr: A,
) -> Result<(ChannelSender<S>, ChannelReceiver<R>)> {
    let mut stream = TcpStream::connect(addr)?;
    check_no_encryption(&mut stream)?;
    // Separate handles for writing and reading, one for each half of the channel.
    let read_stream = stream.try_clone()?;
    Ok((
        ChannelSender {
            inner: ChannelSenderInner::Remote(Arc::new(Mutex::new(stream))),
        },
        ChannelReceiver {
            inner: ChannelReceiverInner::Remote(RefCell::new(read_stream)),
        },
    ))
}
/// Connect to a [`ChannelServer`] listening on `addr`, encrypting the traffic with ChaCha20
/// and `enc_key`.
///
/// Each side picks a random nonce for the data it sends. The peers then exchange an encrypted
/// magic value to verify they share the same key. No message authentication code is computed:
/// a peer with the wrong key is rejected, but tampering with the encrypted stream in transit is
/// not detected.
///
/// Returns the channel for sending messages of type `S` to the server and the channel for
/// receiving messages of type `R` from it.
///
/// # Errors
/// Fails if the connection cannot be established or the handshake fails, for instance
/// because the server uses a different key.
pub fn connect_channel_with_enc<A: ToSocketAddrs, S, R>(
    addr: A,
    enc_key: &[u8; 32],
) -> Result<(ChannelSender<S>, ChannelReceiver<R>)> {
    let mut stream = TcpStream::connect(addr)?;
    let read_stream = stream.try_clone()?;
    let (enc_nonce, dec_nonce) = nonce_handshake(&mut stream)?;
    let key = Key::from_slice(enc_key);
    let mut enc = ChaCha20::new(key, Nonce::from_slice(&enc_nonce));
    let mut dec = ChaCha20::new(key, Nonce::from_slice(&dec_nonce));
    check_encryption_key(&mut stream, &mut enc, &mut dec)?;
    Ok((
        ChannelSender {
            inner: ChannelSenderInner::RemoteEnc(Arc::new(Mutex::new((stream, enc)))),
        },
        ChannelReceiver {
            inner: ChannelReceiverInner::RemoteEnc(RefCell::new((read_stream, dec))),
        },
    ))
}
/// Exchange random 12-byte nonces with the peer over `s`.
///
/// Returns `(local, remote)`. `local` is freshly generated from the OS RNG and sent to the peer;
/// the callers use it for the outgoing cipher. `remote` is the peer's nonce; the callers use it
/// for the incoming cipher.
fn nonce_handshake(s: &mut TcpStream) -> Result<([u8; 12], [u8; 12])> {
    let mut remote = [0u8; 12];
    let mut local = [0u8; 12];
    OsRng.fill_bytes(&mut local);

    // Send ours first: both peers write before reading, and 12 bytes fit in the socket buffer.
    s.write_all(&local)?;
    s.flush()?;
    s.read_exact(&mut remote)?;

    Ok((local, remote))
}
/// Verify that the peer on `stream` uses the same key.
///
/// `MAGIC` is encrypted with `enc` and sent. The value received from the peer is decrypted with
/// `dec` and must equal `MAGIC`. Both ciphers advance by 4 bytes of keystream.
///
/// # Errors
/// Fails on I/O errors or with "Wrong encryption key" when the decrypted value differs.
fn check_encryption_key(
    stream: &mut TcpStream,
    enc: &mut ChaCha20,
    dec: &mut ChaCha20,
) -> Result<()> {
    let mut outgoing = MAGIC.to_le_bytes();
    enc.apply_keystream(&mut outgoing);
    stream.write_all(&outgoing)?;
    stream.flush()?;

    let mut incoming = [0u8; 4];
    stream.read_exact(&mut incoming)?;
    dec.apply_keystream(&mut incoming);

    if u32::from_le_bytes(incoming) == MAGIC {
        Ok(())
    } else {
        bail!("Wrong encryption key");
    }
}
/// Handshake for plaintext connections.
///
/// It runs the same magic-value exchange as encrypted connections, but with a fixed, publicly
/// known key and nonce. This provides no confidentiality; it only makes the handshake fail when
/// one side expects encryption and the other does not. Traffic after the handshake is sent
/// unencrypted.
fn check_no_encryption(stream: &mut TcpStream) -> Result<()> {
    let key = Key::from_slice(b"task-maker's the best thing ever");
    let nonce = Nonce::from_slice(b"task-maker!!");
    let (mut enc, mut dec) = (ChaCha20::new(key, nonce), ChaCha20::new(key, nonce));
    check_encryption_key(stream, &mut enc, &mut dec)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Port the OS assigned to `server`.
    ///
    /// The tests bind to port 0 rather than a randomly chosen fixed port. A fixed port could
    /// already be taken by another service or by a test running in parallel, making the tests
    /// flaky.
    fn port_of<S, R>(server: &ChannelServer<S, R>) -> u16 {
        server
            .local_addr()
            .expect("Bound listener has no local address")
            .port()
    }

    #[test]
    fn test_remote_channels_enc_wrong_key() {
        let enc_key = [42u8; 32];
        let mut server: ChannelServer<(), ()> =
            ChannelServer::bind_with_enc(("127.0.0.1", 0), enc_key).unwrap();
        let port = port_of(&server);
        let client_thread = std::thread::spawn(move || {
            let wrong_enc_key = [69u8; 32];
            assert!(
                connect_channel_with_enc::<_, (), ()>(("127.0.0.1", port), &wrong_enc_key).is_err()
            );
            connect_channel_with_enc::<_, (), ()>(("127.0.0.1", port), &enc_key).unwrap();
        });
        // The first (wrong-key) client is rejected; the second one is returned.
        server.next().unwrap();
        client_thread.join().unwrap();
    }

    #[test]
    fn test_remote_channels_enc_no_key() {
        let enc_key = [42u8; 32];
        let mut server: ChannelServer<(), ()> =
            ChannelServer::bind_with_enc(("127.0.0.1", 0), enc_key).unwrap();
        let port = port_of(&server);
        let client_thread = std::thread::spawn(move || {
            assert!(connect_channel::<_, (), ()>(("127.0.0.1", port)).is_err());
            connect_channel_with_enc::<_, (), ()>(("127.0.0.1", port), &enc_key).unwrap();
        });
        server.next().unwrap();
        client_thread.join().unwrap();
    }

    #[test]
    fn test_remote_channels_receiver_stops() {
        let mut server: ChannelServer<(), ()> = ChannelServer::bind(("127.0.0.1", 0)).unwrap();
        let port = port_of(&server);
        let client_thread = std::thread::spawn(move || {
            let (sender, _) = connect_channel::<_, (), ()>(("127.0.0.1", port)).unwrap();
            sender.send(()).unwrap();
        });
        let (_, receiver, _) = server.next().unwrap();
        client_thread.join().unwrap();
        receiver.recv().unwrap();
        // The client has disconnected: further receives must fail instead of blocking.
        assert!(receiver.recv().is_err());
    }
}