use anyhow::Result;
use memmap::MmapOptions;
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
use std::fs::File;
use std::fs::OpenOptions;
use hkdf::Hkdf;
use sha2::{Digest, Sha256};
use spake2::{Ed25519Group, Identity, Password, SPAKE2};
use chacha20poly1305::aead::NewAead;
use chacha20poly1305::{aead::AeadInPlace, ChaCha20Poly1305, Key, Nonce, Tag};
use rand::Rng;
mod chunks;
pub mod errors;
pub mod file;
use errors::PortalError::*;
use file::{PortalFile, StateMetadata};
/// Default network port for portal connections.
pub const DEFAULT_PORT: u16 = 13_265;
/// Default chunk size in bytes (64 KiB).
pub const CHUNK_SIZE: usize = 64 * 1024;
/// Outbound SPAKE2 message returned by `Portal::init` and read back by
/// `Portal::read_confirmation_from`: a 32-byte compressed Edwards point plus one extra byte.
pub type PortalConfirmation = [u8; 32 + 1];
/// One side of a password-authenticated (SPAKE2) file-transfer session.
///
/// Only `id` and `direction` are serialized, and their declaration order is the bincode
/// wire layout. `metadata`, `state` and `key` are `#[serde(skip)]`, so they come back as
/// their `Default` values (`Metadata::default()` / `None`) after deserialization.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub struct Portal {
/// Set by `Portal::init` to the hex-encoded SHA-256 of the caller's id; `set_id` can overwrite it.
id: String,
/// Whether this side is the sender or the receiver.
direction: Direction,
/// File size/name; exchanged with the peer encrypted via `write_metadata_to`/`read_metadata_from`.
#[serde(skip)]
metadata: Metadata,
/// In-progress SPAKE2 exchange; taken (left `None`) by `derive_key`.
#[serde(skip)]
state: Option<SPAKE2<Ed25519Group>>,
/// Session key produced by SPAKE2; `None` until `derive_key` succeeds.
#[serde(skip)]
key: Option<Vec<u8>>,
}
/// File information sent from one peer to the other (encrypted) during a transfer.
///
/// Serialized with bincode, so the field order is part of the wire format.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)]
pub struct Metadata {
/// File size (presumably in bytes); 0 until `Portal::set_file_size` is called.
filesize: u64,
/// Raw file-name bytes; `Portal::get_file_name` expects them to be UTF-8.
filename: Option<Vec<u8>>,
}
/// Role a portal plays in a transfer.
///
/// Serialized as part of `Portal`; the variant order determines the encoded discriminant,
/// so variants must not be reordered.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub enum Direction {
Sender,
Receiver,
}
/// Lexicographically compares two key-confirmation messages without data-dependent
/// early exits.
///
/// The result equals `a.cmp(b)` for byte slices: the first differing byte decides, and
/// if the common prefix matches, the shorter slice sorts first. Unlike a short-circuiting
/// comparison, every byte pair of the common prefix is examined, regardless of where the
/// first mismatch is. This way the running time does not reveal how many leading bytes of
/// a secret-derived confirmation value a peer guessed correctly. Only the (non-secret)
/// lengths influence control flow.
///
/// NOTE: this avoids branches at the source level only. The optimizer is not formally
/// prevented from reintroducing them, so treat this as best-effort constant time.
fn compare_key_derivations(a: &[u8], b: &[u8]) -> std::cmp::Ordering {
    // Signed difference of the first mismatching byte pair, or 0 while every pair matched.
    let mut first_diff: i16 = 0;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let diff = i16::from(x) - i16::from(y);
        // `v | -v` has its top bit set exactly when `v != 0`, so `seen` is 1 once a
        // mismatch has been recorded and 0 before; the mask is then 0 or all-ones.
        let v = first_diff as u16;
        let seen = (v | v.wrapping_neg()) >> 15;
        let keep_mask = (seen as i16).wrapping_sub(1);
        first_diff |= diff & keep_mask;
    }
    match first_diff.cmp(&0) {
        std::cmp::Ordering::Equal => a.len().cmp(&b.len()),
        ordering => ordering,
    }
}
impl Portal {
pub fn init(
direction: Direction,
id: String,
password: String,
mut filename: Option<String>,
) -> (Portal, PortalConfirmation) {
let mut hasher = Sha256::new();
hasher.update(&id);
let id_bytes = hasher.finalize();
let id_hash = hex::encode(&id_bytes);
let (s1, outbound_msg) = SPAKE2::<Ed25519Group>::start_symmetric(
&Password::new(&password.as_bytes()),
&Identity::new(&id_bytes),
);
if let Some(file) = filename {
let f = std::path::Path::new(&file);
let f = f.file_name().unwrap().to_str().unwrap();
filename = Some(f.to_string());
}
let metadata = Metadata {
filesize: 0,
filename: filename.map_or(None, |v| Some(v.as_bytes().to_vec())),
};
(
Portal {
direction,
id: id_hash,
metadata,
state: Some(s1),
key: None,
},
outbound_msg.try_into().expect("Bad message format"),
)
}
pub fn parse(data: &[u8]) -> Result<Portal> {
Ok(bincode::deserialize(&data)?)
}
pub fn read_response_from<R>(reader: R) -> Result<Portal>
where
R: std::io::Read,
{
Ok(bincode::deserialize_from::<R, Portal>(reader)?)
}
pub fn read_confirmation_from<R>(mut reader: R) -> Result<PortalConfirmation>
where
R: std::io::Read,
{
let mut res: PortalConfirmation = [0u8; 33];
reader.read_exact(&mut res)?;
Ok(res)
}
pub fn read_metadata_from<R>(&mut self, mut reader: R) -> Result<()>
where
R: std::io::Read,
{
let key = self.key.as_ref().ok_or(NoPeer)?;
let cha_key = Key::from_slice(&key[..]);
let cipher = ChaCha20Poly1305::new(cha_key);
let state: StateMetadata = bincode::deserialize_from(&mut reader).or(Err(BadMsg))?;
let mut data: Vec<u8> = bincode::deserialize_from(&mut reader).or(Err(BadMsg))?;
if state.nonce.len() != std::mem::size_of::<Nonce>()
|| state.tag.len() != std::mem::size_of::<Tag>()
{
return Err(BadState.into());
}
let nonce = Nonce::from_slice(&state.nonce);
let tag = Tag::from_slice(&state.tag);
match cipher.decrypt_in_place_detached(&nonce, b"", &mut data, &tag) {
Ok(_) => {}
Err(_e) => return Err(DecryptError.into()),
}
let mdata: Metadata = bincode::deserialize(&data)?;
self.metadata = mdata;
Ok(())
}
pub fn write_metadata_to<W>(&mut self, mut writer: W) -> Result<usize>
where
W: std::io::Write,
{
let mut state = StateMetadata::default();
let mut rng = rand::thread_rng();
let rbytes = rng.gen::<[u8; 12]>();
let nonce = Nonce::from_slice(&rbytes);
state.nonce.extend(nonce);
let key = self.key.as_ref().ok_or(NoPeer)?;
let cha_key = Key::from_slice(&key[..]);
let cipher = ChaCha20Poly1305::new(cha_key);
let mut data: Vec<u8> = bincode::serialize(&self.metadata)?;
let tag = match cipher.encrypt_in_place_detached(nonce, b"", &mut data) {
Ok(tag) => tag,
Err(_e) => return Err(EncryptError.into()),
};
state.tag.extend(tag);
let mut finaldata = bincode::serialize(&state)?;
finaldata.extend_from_slice(&bincode::serialize(&data)?);
writer.write_all(&finaldata).or(Err(IOError))?;
Ok(data.len())
}
pub fn serialize(&self) -> Result<Vec<u8>> {
Ok(bincode::serialize(&self)?)
}
pub fn load_file<'a>(&'a self, f: &str) -> Result<PortalFile> {
let file = File::open(f)?;
let mmap = unsafe { MmapOptions::new().map_copy(&file)? };
let key = self.key.as_ref().ok_or(NoPeer)?;
let cha_key = Key::from_slice(&key[..]);
let cipher = ChaCha20Poly1305::new(cha_key);
Ok(PortalFile::init(mmap, cipher))
}
pub fn create_file<'a>(&'a self, f: &str, size: u64) -> Result<PortalFile> {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&f)?;
file.set_len(size)?;
let key = self.key.as_ref().ok_or(NoPeer)?;
let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
let cha_key = Key::from_slice(&key[..]);
let cipher = ChaCha20Poly1305::new(cha_key);
Ok(PortalFile::init(mmap, cipher))
}
pub fn derive_key(&mut self, msg_data: &[u8]) -> Result<()> {
let state = std::mem::replace(&mut self.state, None);
let state = state.ok_or(BadState)?;
self.key = match state.finish(msg_data) {
Ok(res) => Some(res),
Err(_) => {
return Err(BadMsg.into());
}
};
Ok(())
}
pub fn confirm_peer<R>(&mut self, mut client: R) -> Result<()>
where
R: std::io::Read + std::io::Write,
{
let key = self.key.as_ref().ok_or(NoPeer)?;
let sender_info = format!("{}-{}", self.id, "senderinfo");
let receiver_info = format!("{}-{}", self.id, "receiverinfo");
let h = Hkdf::<Sha256>::new(None, &key);
let mut peer_msg = [0u8; 42];
let mut sender_confirm = [0u8; 42];
let mut receiver_confirm = [0u8; 42];
h.expand(&sender_info.as_bytes(), &mut sender_confirm)
.unwrap();
h.expand(&receiver_info.as_bytes(), &mut receiver_confirm)
.unwrap();
match self.direction {
Direction::Sender => {
client.write_all(&sender_confirm)?;
client.read_exact(&mut peer_msg)?;
if compare_key_derivations(&peer_msg, &receiver_confirm)
!= std::cmp::Ordering::Equal
{
return Err(BadMsg.into());
}
}
Direction::Receiver => {
client.write_all(&receiver_confirm)?;
client.read_exact(&mut peer_msg)?;
if compare_key_derivations(&peer_msg, &sender_confirm) != std::cmp::Ordering::Equal
{
return Err(BadMsg.into());
}
}
}
Ok(())
}
pub fn get_file_size(&self) -> u64 {
self.metadata.filesize
}
pub fn set_file_size(&mut self, size: u64) {
self.metadata.filesize = size;
}
pub fn get_file_name<'a>(&'a self) -> Result<&'a str> {
match &self.metadata.filename {
Some(f) => Ok(std::str::from_utf8(f)?),
None => Err(NoneError.into()),
}
}
pub fn get_direction(&self) -> Direction {
self.direction.clone()
}
pub fn set_direction(&mut self, direction: Direction) {
self.direction = direction;
}
pub fn get_id(&self) -> &String {
&self.id
}
pub fn set_id(&mut self, id: String) {
self.id = id;
}
}
#[cfg(test)]
mod tests {
    use crate::file::tests::MockTcpStream;
    use crate::{errors::PortalError, Direction, Portal, StateMetadata};
    use hkdf::Hkdf;
    use rand::Rng;
    use sha2::Sha256;
    use std::io::Write;

    /// Builds a (receiver, sender) pair sharing id `"id"` and password `"test"`, with both
    /// keys already derived. `sender_file` is passed to the sender's `init`.
    fn keyed_pair(sender_file: Option<String>) -> (Portal, Portal) {
        let (mut receiver, receiver_msg) =
            Portal::init(Direction::Receiver, "id".to_string(), "test".to_string(), None);
        let (mut sender, sender_msg) = Portal::init(
            Direction::Sender,
            "id".to_string(),
            "test".to_string(),
            sender_file,
        );
        receiver.derive_key(&sender_msg).unwrap();
        sender.derive_key(&receiver_msg).unwrap();
        (receiver, sender)
    }

    /// Returns an empty in-memory stream.
    fn empty_stream() -> MockTcpStream {
        MockTcpStream {
            data: Vec::with_capacity(crate::CHUNK_SIZE),
        }
    }

    /// Recomputes the 42-byte HKDF confirmation for `label` from `portal`'s id and key.
    fn confirmation(portal: &Portal, label: &str) -> [u8; 42] {
        let info = format!("{}-{}", portal.get_id(), label);
        let hkdf = Hkdf::<Sha256>::new(None, portal.key.as_ref().unwrap());
        let mut okm = [0u8; 42];
        hkdf.expand(info.as_bytes(), &mut okm).unwrap();
        okm
    }

    /// Per-test scratch path in the system temp dir, so parallel tests never share a file.
    fn scratch_path(name: &str) -> String {
        std::env::temp_dir()
            .join(name)
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn metadata_roundtrip() {
        let fsize = 1337;
        let fname = "filename".to_string();
        let (mut receiver, mut sender) = keyed_pair(Some(fname.clone()));
        sender.set_file_size(fsize);
        let mut stream = empty_stream();
        sender.write_metadata_to(&mut stream).unwrap();
        receiver.read_metadata_from(&mut stream).unwrap();
        assert_eq!(fsize, receiver.get_file_size());
        assert_eq!(fname, receiver.get_file_name().unwrap());
        assert_eq!(
            sender.get_file_name().unwrap(),
            receiver.get_file_name().unwrap()
        );
        assert_eq!(sender.get_file_size(), receiver.get_file_size());
    }

    #[test]
    fn fail_decrypt_metadata() {
        let (mut receiver, _sender) = keyed_pair(Some("filename".to_string()));
        let mut stream = empty_stream();

        // Empty nonce/tag: rejected before decryption is attempted.
        let mut garbage = bincode::serialize(&StateMetadata::default()).unwrap();
        garbage.extend_from_slice(&bincode::serialize(&vec![0u8]).unwrap());
        stream.write_all(&garbage).unwrap();
        let err = receiver.read_metadata_from(&mut stream).unwrap_err();
        assert!(
            matches!(err.downcast_ref::<PortalError>(), Some(PortalError::BadState)),
            "unexpected error: {:?}",
            err
        );

        // Correctly sized but random nonce/tag: authentication must fail.
        let state = StateMetadata {
            nonce: rand::thread_rng().gen::<[u8; 12]>().to_vec(),
            tag: rand::thread_rng().gen::<[u8; 16]>().to_vec(),
        };
        let mut garbage = bincode::serialize(&state).unwrap();
        garbage.extend_from_slice(&bincode::serialize(&vec![0u8]).unwrap());
        stream.write_all(&garbage).unwrap();
        let err = receiver.read_metadata_from(&mut stream).unwrap_err();
        assert!(
            matches!(err.downcast_ref::<PortalError>(), Some(PortalError::DecryptError)),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn key_derivation() {
        let (receiver, sender) = keyed_pair(None);
        assert!(receiver.key.is_some());
        assert_eq!(receiver.key, sender.key);
    }

    #[test]
    fn key_confirmation() {
        let (mut receiver, mut sender) = keyed_pair(None);
        let mut receiver_side = empty_stream();
        let mut sender_side = empty_stream();
        // Preload each side's stream with the message its peer would send.
        receiver_side
            .write_all(&confirmation(&sender, "senderinfo"))
            .unwrap();
        sender_side
            .write_all(&confirmation(&sender, "receiverinfo"))
            .unwrap();
        receiver.confirm_peer(&mut receiver_side).unwrap();
        sender.confirm_peer(&mut sender_side).unwrap();
    }

    #[test]
    fn key_confirmation_rejects_wrong_password() {
        let (mut receiver, receiver_msg) =
            Portal::init(Direction::Receiver, "id".to_string(), "test".to_string(), None);
        let (mut sender, sender_msg) =
            Portal::init(Direction::Sender, "id".to_string(), "wrong".to_string(), None);
        receiver.derive_key(&sender_msg).unwrap();
        sender.derive_key(&receiver_msg).unwrap();
        assert_ne!(receiver.key, sender.key);
        let mut receiver_side = empty_stream();
        receiver_side
            .write_all(&confirmation(&sender, "senderinfo"))
            .unwrap();
        let err = receiver.confirm_peer(&mut receiver_side).unwrap_err();
        assert!(
            matches!(err.downcast_ref::<PortalError>(), Some(PortalError::BadMsg)),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn compare_key_derivations_matches_slice_order() {
        let samples: [&[u8]; 8] = [
            b"",
            b"a",
            b"ab",
            b"abc",
            b"abd",
            &[0u8],
            &[255u8, 0],
            &[0u8, 255],
        ];
        for a in samples.iter() {
            for b in samples.iter() {
                assert_eq!(crate::compare_key_derivations(a, b), a.cmp(b));
            }
        }
    }

    #[test]
    fn portal_load_file() {
        let (_receiver, sender) = keyed_pair(None);
        let _file = sender.load_file("/etc/passwd").unwrap();
    }

    #[test]
    fn portalfile_chunks_iterator() {
        let (_receiver, sender) = keyed_pair(None);
        let file = sender.load_file("/etc/passwd").unwrap();
        for chunk_size in [10usize, 1024].iter().copied() {
            for v in file.get_chunks(chunk_size) {
                assert!(v.len() <= chunk_size);
            }
        }
    }

    #[test]
    fn portal_createfile() {
        let (receiver, _sender) = keyed_pair(None);
        let _file_dst = receiver
            .create_file(&scratch_path("portal_createfile"), 4096)
            .unwrap();
    }

    #[test]
    fn portal_write_chunk() {
        let (receiver, sender) = keyed_pair(None);
        let file_src = sender.load_file("/etc/passwd").unwrap();
        let mut file_dst = receiver
            .create_file(&scratch_path("portal_write_chunk"), 4096)
            .unwrap();
        let chunk_size = 4096;
        for v in file_src.get_chunks(chunk_size) {
            assert!(v.len() <= chunk_size);
            file_dst.write_given_chunk(&v).unwrap();
        }
    }

    #[test]
    fn portal_createfile_no_peer() {
        let (portal, _msg) =
            Portal::init(Direction::Sender, "id".to_string(), "test".to_string(), None);
        let err = portal
            .create_file(&scratch_path("portal_createfile_no_peer"), 4096)
            .err()
            .expect("create_file must fail before a key is derived");
        assert!(
            matches!(err.downcast_ref::<PortalError>(), Some(PortalError::NoPeer)),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn portal_loadfile_no_peer() {
        let (portal, _msg) =
            Portal::init(Direction::Sender, "id".to_string(), "test".to_string(), None);
        let err = portal
            .load_file("/etc/passwd")
            .err()
            .expect("load_file must fail before a key is derived");
        assert!(
            matches!(err.downcast_ref::<PortalError>(), Some(PortalError::NoPeer)),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_file_trim() {
        let file = Some("/my/path/filename.txt".to_string());
        let (receiver, _receiver_msg) =
            Portal::init(Direction::Receiver, "id".to_string(), "test".to_string(), file);
        assert_eq!(receiver.get_file_name().unwrap(), "filename.txt");
    }

    #[test]
    fn test_compressed_edwards_size() {
        let edwards_point = <spake2::Ed25519Group as spake2::Group>::Element::default();
        let compressed = edwards_point.compress();
        let msg_size: usize = std::mem::size_of_val(&compressed) + 1;
        assert_eq!(33, msg_size);
    }

    #[test]
    fn test_getters_setters() {
        let (mut portal, _msg) =
            Portal::init(Direction::Sender, "id".to_string(), "test".to_string(), None);
        portal.set_id("newID".to_string());
        assert_eq!("newID", portal.get_id());
        portal.set_direction(Direction::Receiver);
        assert_eq!(portal.get_direction(), Direction::Receiver);
        portal.set_file_size(25);
        assert_eq!(portal.get_file_size(), 25);
    }

    #[test]
    fn test_serialize_deserialize() {
        let (portal, _msg) =
            Portal::init(Direction::Sender, "id".to_string(), "test".to_string(), None);
        let ser = portal.serialize().unwrap();
        let res = Portal::parse(&ser).unwrap();
        assert_eq!(res.id, portal.id);
        assert_eq!(res.direction, portal.direction);
        assert_eq!(res.metadata.filename, portal.metadata.filename);
        assert_eq!(res.metadata.filesize, portal.metadata.filesize);
        // Skipped fields are not transported.
        assert_ne!(res.state, portal.state);
        assert_eq!(res.state, None);
        assert_eq!(res.key, None);
    }
}