use anyhow::Result;
use serde::{Serialize, Deserialize};
use std::fs::File;
use memmap::MmapOptions;
use std::fs::OpenOptions;
use spake2::{Ed25519Group, Identity, Password, SPAKE2};
use sha2::{Sha256, Digest};
use hkdf::Hkdf;
use chacha20poly1305::{ChaCha20Poly1305, Key};
use chacha20poly1305::aead::{NewAead};
pub mod errors;
pub mod file;
mod chunks;
use errors::PortalError;
use file::PortalFile;
/// Default port number exported by this crate.
/// NOTE(review): presumably the TCP port peers/relay use by default — it is
/// not referenced elsewhere in this file; confirm against the callers.
pub const DEFAULT_PORT: u16 = 13265;
/// Chunk size in bytes exported by this crate. In this file it is only used
/// by the tests, to pre-size mock stream buffers.
/// NOTE(review): presumably the transfer chunk size used by `file`/`chunks` — confirm.
pub const CHUNK_SIZE: usize = 65535;
/// Fixed 33-byte buffer filled by `Portal::read_confirmation_from`.
/// NOTE(review): 33 is the size asserted by `test_compressed_edwards_size`
/// (compressed Edwards point + 1), whereas `confirm_peer` exchanges 42-byte
/// values — so this looks like the SPAKE2 message size rather than the
/// key-confirmation size; confirm the intended use.
pub type PortalConfirmation = [u8;33];
/// One endpoint of a transfer: serializable metadata plus local-only
/// key-exchange state.
///
/// The `state` and `key` fields are marked `#[serde(skip)]`, so they are
/// never serialized and come back as `None` after deserialization.
///
/// NOTE(review): the struct is serialized with bincode, which presumably
/// encodes fields in declaration order — confirm before reordering fields.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub struct Portal{
// Identifier of this portal. `init` stores the hex-encoded SHA-256 of the
// caller-supplied id; `set_id` can overwrite it with any string.
id: String,
// Role of this endpoint (sender or receiver).
direction: Direction,
// Final path component of the path given to `init`, if one was given.
filename: Option<String>,
// File size as set through `set_file_size`; `init` starts it at 0.
filesize: u64,
// In-progress SPAKE2 exchange created by `init`; taken (left as `None`)
// by `derive_key`.
#[serde(skip)]
state: Option<SPAKE2<Ed25519Group>>,
// Key produced by `derive_key`; `None` until it has succeeded.
#[serde(skip)]
key: Option<Vec<u8>>,
}
/// Role a `Portal` plays in a transfer.
///
/// `confirm_peer` uses it to decide which confirmation value to send and
/// which one to expect back from the peer.
/// NOTE(review): the variant order is presumably part of the serialized
/// encoding — confirm before reordering.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub enum Direction {
// This endpoint sends the "senderinfo" confirmation value.
Sender,
// This endpoint sends the "receiverinfo" confirmation value.
Receiver,
}
/// Lexicographically compares two derived-key byte strings.
///
/// The result is the ordering of the first byte pair that differs. When the
/// slices agree over their common prefix, the shorter one orders first and
/// equal lengths give `Ordering::Equal`. This is the same result as the
/// standard slice ordering.
///
/// One argument typically comes from the peer and the other is derived from
/// the secret session key, so the scan deliberately never exits early: every
/// overlapping byte is visited regardless of where the first mismatch is, so
/// the running time depends on the (public) lengths rather than on how many
/// leading bytes matched. This is best effort only — the compiler gives no
/// constant-time guarantee.
fn compare_key_derivations(a: &[u8], b: &[u8]) -> std::cmp::Ordering {
    // -1 / 0 / +1 once the first differing pair has been seen; 0 until then.
    let mut verdict: i8 = 0;
    for (&x, &y) in a.iter().zip(b.iter()) {
        // Sign of (x - y), computed without branching on the data.
        let sign = (x > y) as i8 - (x < y) as i8;
        // Only the first non-zero sign may be recorded.
        let undecided = (verdict == 0) as i8;
        verdict += undecided * sign;
    }
    // Lengths are not secret, so a plain comparison is fine as tie-breaker.
    verdict.cmp(&0).then_with(|| a.len().cmp(&b.len()))
}
impl Portal {
pub fn init(direction: Direction,
id: String,
password: String,
mut filename: Option<String>) -> (Portal,Vec<u8>) {
let mut hasher = Sha256::new();
hasher.update(&id);
let id_bytes = hasher.finalize();
let id_hash = hex::encode(&id_bytes);
let (s1, outbound_msg) = SPAKE2::<Ed25519Group>::start_symmetric(
&Password::new(&password.as_bytes()),
&Identity::new(&id_bytes));
if let Some(file) = filename {
let f = std::path::Path::new(&file);
let f = f.file_name().unwrap().to_str().unwrap();
filename = Some(f.to_string());
}
(Portal {
direction,
id: id_hash,
filename,
filesize: 0,
state: Some(s1),
key: None,
}, outbound_msg)
}
pub fn parse(data: &[u8]) -> Result<Portal> {
Ok(bincode::deserialize(&data)?)
}
pub fn read_response_from<R>(reader: R) -> Result<Portal>
where
R: std::io::Read {
Ok(bincode::deserialize_from::<R,Portal>(reader)?)
}
pub fn read_confirmation_from<R>(mut reader: R) -> Result<PortalConfirmation>
where
R: std::io::Read {
let mut res: PortalConfirmation= [0u8;33];
reader.read_exact(&mut res)?;
Ok(res)
}
pub fn serialize(&self) -> Result<Vec<u8>> {
Ok(bincode::serialize(&self)?)
}
pub fn load_file<'a>(&'a self, f: &str) -> Result<PortalFile> {
let file = File::open(f)?;
let mmap = unsafe { MmapOptions::new().map_copy(&file)? };
let key = self.key.as_ref().ok_or_else(|| PortalError::NoPeer)?;
let cha_key = Key::from_slice(&key[..]);
let cipher = ChaCha20Poly1305::new(cha_key);
Ok(PortalFile::init(mmap,cipher))
}
pub fn create_file<'a>(&'a self, f: &str, size: u64) -> Result<PortalFile> {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&f)?;
file.set_len(size)?;
let key = self.key.as_ref().ok_or_else(|| PortalError::NoPeer)?;
let mmap = unsafe {
MmapOptions::new().map_mut(&file)?
};
let cha_key = Key::from_slice(&key[..]);
let cipher = ChaCha20Poly1305::new(cha_key);
Ok(PortalFile::init(mmap,cipher))
}
pub fn derive_key(&mut self, msg_data: &[u8]) -> Result<()> {
let state = std::mem::replace(&mut self.state, None);
let state = state.ok_or_else(|| PortalError::BadState)?;
self.key = match state.finish(msg_data) {
Ok(res) => Some(res),
Err(_) => {return Err(PortalError::BadMsg.into());}
};
Ok(())
}
pub fn confirm_peer<R>(&mut self, mut client: R) -> Result<()>
where
R: std::io::Read + std::io::Write {
let key = self.key.as_ref().ok_or_else(|| PortalError::NoPeer)?;
let sender_info = format!("{}-{}",self.id,"senderinfo");
let receiver_info = format!("{}-{}",self.id,"receiverinfo");
let h = Hkdf::<Sha256>::new(None,&key);
let mut peer_msg = [0u8;42];
let mut sender_confirm = [0u8; 42];
let mut receiver_confirm = [0u8; 42];
h.expand(&sender_info.as_bytes(), &mut sender_confirm).unwrap();
h.expand(&receiver_info.as_bytes(), &mut receiver_confirm).unwrap();
match self.direction {
Direction::Sender => {
client.write_all(&sender_confirm)?;
client.read_exact(&mut peer_msg)?;
if compare_key_derivations(&peer_msg,&receiver_confirm) == std::cmp::Ordering::Equal {
return Ok(());
}
Err(PortalError::BadMsg.into())
}
Direction::Receiver => {
client.write_all(&receiver_confirm)?;
client.read_exact(&mut peer_msg)?;
if compare_key_derivations(&peer_msg,&sender_confirm) == std::cmp::Ordering::Equal {
return Ok(());
}
Err(PortalError::BadMsg.into())
}
}
}
pub fn get_file_size(&self) -> u64 {
self.filesize
}
pub fn set_file_size(&mut self, size: u64) {
self.filesize = size;
}
pub fn get_file_name<'a>(&'a self) -> Result<&'a str> {
match &self.filename {
Some(f) => Ok(f.as_str()),
None => Err(PortalError::NoneError.into()),
}
}
pub fn get_direction(&self) -> Direction {
self.direction.clone()
}
pub fn set_direction(&mut self, direction: Direction) {
self.direction = direction;
}
pub fn get_id(&self) -> &String {
&self.id
}
pub fn set_id(&mut self, id: String) {
self.id = id;
}
}
#[cfg(test)]
mod tests {
    use super::{Portal,Direction};
    use sha2::Sha256;
    use hkdf::Hkdf;
    use std::io::Write;
    use crate::file::tests::MockTcpStream;

    /// Builds a portal for `direction` with the fixed test id and password.
    fn new_portal(direction: Direction, filename: Option<String>) -> (Portal, Vec<u8>) {
        Portal::init(direction, "id".to_string(), "test".to_string(), filename)
    }

    /// Receiver and sender that have each derived a key from the other's message.
    fn keyed_pair() -> (Portal, Portal) {
        let (mut receiver, receiver_msg) = new_portal(Direction::Receiver, None);
        let (mut sender, sender_msg) = new_portal(Direction::Sender, None);
        receiver.derive_key(sender_msg.as_slice()).unwrap();
        sender.derive_key(receiver_msg.as_slice()).unwrap();
        (receiver, sender)
    }

    /// Sender that has derived its key from a throwaway receiver's message.
    fn keyed_sender() -> Portal {
        let (_receiver, receiver_msg) = new_portal(Direction::Receiver, None);
        let (mut sender, _sender_msg) = new_portal(Direction::Sender, None);
        sender.derive_key(receiver_msg.as_slice()).unwrap();
        sender
    }

    /// Empty mock stream with room for one chunk.
    fn mock_stream() -> MockTcpStream {
        MockTcpStream {
            data: Vec::with_capacity(crate::CHUNK_SIZE),
        }
    }

    // Both sides must end up with the same key.
    #[test]
    fn key_derivation() {
        let (receiver, sender) = keyed_pair();
        assert_eq!(receiver.key, sender.key);
    }

    // Each side accepts the confirmation value expected from its peer.
    #[test]
    fn key_confirmation() {
        let mut receiver_side = mock_stream();
        let mut sender_side = mock_stream();
        let (mut receiver, mut sender) = keyed_pair();

        // Recompute both confirmation values independently of `confirm_peer`.
        let portal_id = receiver.get_id();
        let sender_info = format!("{}-{}", portal_id, "senderinfo");
        let receiver_info = format!("{}-{}", portal_id, "receiverinfo");
        let hk = Hkdf::<Sha256>::new(None, &sender.key.as_ref().unwrap());
        let mut sender_confirm = [0u8; 42];
        let mut receiver_confirm = [0u8; 42];
        hk.expand(&sender_info.as_bytes(), &mut sender_confirm).unwrap();
        hk.expand(&receiver_info.as_bytes(), &mut receiver_confirm).unwrap();

        // Pre-load each stream with what the opposite peer would have sent.
        receiver_side.write(&sender_confirm).unwrap();
        sender_side.write(&receiver_confirm).unwrap();

        receiver.confirm_peer(&mut receiver_side).unwrap();
        sender.confirm_peer(&mut sender_side).unwrap();
    }

    // A keyed portal can map an existing file.
    #[test]
    fn portal_load_file() {
        let sender = keyed_sender();
        let _file = sender.load_file("/etc/passwd").unwrap();
    }

    // Chunks never exceed the requested size.
    #[test]
    fn portalfile_chunks_iterator() {
        let sender = keyed_sender();
        let file = sender.load_file("/etc/passwd").unwrap();
        for &chunk_size in [10usize, 1024].iter() {
            for chunk in file.get_chunks(chunk_size) {
                assert!(chunk.len() <= chunk_size);
            }
        }
    }

    // A keyed portal can create a destination file.
    #[test]
    fn portal_createfile() {
        let (receiver, _sender) = keyed_pair();
        let _file_dst = receiver.create_file("/tmp/passwd",4096).unwrap();
    }

    // Chunks read from a source file can be written to a destination file.
    #[test]
    fn portal_write_chunk() {
        let (receiver, sender) = keyed_pair();
        let file_src = sender.load_file("/etc/passwd").unwrap();
        let mut file_dst = receiver.create_file("/tmp/passwd",4096).unwrap();
        let chunk_size = 4096;
        for chunk in file_src.get_chunks(chunk_size) {
            assert!(chunk.len() <= chunk_size);
            file_dst.write_given_chunk(&chunk).unwrap();
        }
    }

    // Without a derived key, creating a file must fail.
    #[test]
    #[should_panic]
    fn portal_createfile_no_peer() {
        let (portal, _msg) = new_portal(Direction::Sender, None);
        let _file_dst = portal.create_file("/tmp/passwd",4096).unwrap();
    }

    // Without a derived key, loading a file must fail.
    #[test]
    #[should_panic]
    fn portal_loadfile_no_peer() {
        let (portal, _msg) = new_portal(Direction::Sender, None);
        let _file_src = portal.load_file("/etc/passwd").unwrap();
    }

    // Only the final path component is kept as the file name.
    #[test]
    fn test_file_trim() {
        let path = Some("/my/path/filename.txt".to_string());
        let (receiver, _receiver_msg) = new_portal(Direction::Receiver, path);
        assert_eq!(receiver.get_file_name().unwrap(), "filename.txt");
    }

    // A compressed Edwards point plus one byte is 33 bytes.
    #[test]
    fn test_compressed_edwards_size() {
        let point = <spake2::Ed25519Group as spake2::Group>::Element::default();
        let compressed = point.compress();
        let msg_size: usize = std::mem::size_of_val(&compressed)+1;
        assert_eq!(33,msg_size);
    }

    // Setters are reflected by the matching getters.
    #[test]
    fn test_getters_setters() {
        let (mut portal, _msg) = new_portal(Direction::Sender, None);

        portal.set_id("newID".to_string());
        assert_eq!("newID",portal.get_id());

        portal.set_direction(Direction::Receiver);
        assert_eq!(portal.get_direction(),Direction::Receiver);

        portal.set_file_size(25);
        assert_eq!(portal.get_file_size(),25);
    }

    // Round-tripping keeps the metadata and drops the skipped fields.
    #[test]
    fn test_serialize_deserialize() {
        let (portal, _msg) = new_portal(Direction::Sender, None);
        let encoded = portal.serialize().unwrap();
        let decoded = Portal::parse(&encoded).unwrap();

        assert_eq!(decoded.id, portal.id);
        assert_eq!(decoded.direction, portal.direction);
        assert_eq!(decoded.filename, portal.filename);
        assert_eq!(decoded.filesize, portal.filesize);

        // `state` and `key` are `#[serde(skip)]`.
        assert_ne!(decoded.state, portal.state);
        assert_eq!(decoded.state, None);
        assert_eq!(decoded.key, None);
    }
}