use anyhow::Result;
use serde::{Serialize, Deserialize};
use std::fs::File;
use memmap::MmapOptions;
use std::fs::OpenOptions;
use spake2::{Ed25519Group, Identity, Password, SPAKE2,Group};
use sha2::{Sha256, Digest};
use chacha20poly1305::{ChaCha20Poly1305, Key};
use chacha20poly1305::aead::{NewAead};
pub mod errors;
mod file;
mod chunks;
use errors::PortalError;
use file::PortalFile;
use chunks::PortalChunks;
/// Default port number; presumably used for the connection when the caller
/// does not choose one — confirm against callers.
pub const DEFAULT_PORT: u16 = 13_265;
/// Chunk size in bytes, `0xFFFF` (2^16 - 1); presumably the default
/// `chunk_size` handed to `Portal::get_chunks` — confirm against callers.
pub const CHUNK_SIZE: usize = 0xFFFF;
/// State for one side of a portal transfer.
///
/// Only `id`, `direction`, `filename` and `filesize` are serialized (bincode,
/// via `Portal::serialize` / `Portal::parse`); the SPAKE2 state and the derived
/// session key are `#[serde(skip)]`, so they are never serialized and come back
/// as `None` after deserialization.
// NOTE(review): bincode encodes struct fields in declaration order, so
// reordering these fields changes the wire format seen by the peer.
// NOTE(review): the derived `Debug` prints `key`; avoid `{:?}`-logging a
// `Portal` once the key exchange has completed.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub struct Portal{
/// Portal identifier; `Portal::init` stores the hex-encoded SHA-256 of the
/// caller-supplied id here (`set_id` can overwrite it).
id: String,
/// Which side of the transfer this portal represents, if known.
direction: Option<Direction>,
/// Name of the file being transferred; `Portal::init` keeps only the final
/// path component of what it is given.
filename: Option<String>,
/// File size; 0 until `set_file_size` is called (presumably bytes — confirm).
filesize: u64,
/// In-progress SPAKE2 exchange; consumed by `confirm_peer`.
#[serde(skip)]
state: Option<SPAKE2<Ed25519Group>>,
/// Session key produced by SPAKE2; `None` until `confirm_peer` succeeds.
/// Used to build the ChaCha20Poly1305 cipher for file I/O.
#[serde(skip)]
key: Option<Vec<u8>>,
}
/// Which side of a transfer a `Portal` represents.
// NOTE(review): bincode encodes enum variants by their index, so the variant
// order below is part of the serialized format.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub enum Direction {
/// The side that reads the local file and sends it.
Sender,
/// The side that creates the destination file and receives it.
Receiver,
}
impl Portal {
    /// Creates a new `Portal` and starts a symmetric SPAKE2 exchange.
    ///
    /// `id` is hashed with SHA-256: the raw digest is the SPAKE2 identity and
    /// its hex encoding becomes the portal id. `password` is the shared SPAKE2
    /// password. If `filename` is given, only its final path component is
    /// kept so no directory information is carried in the portal. A path with
    /// no final component (e.g. `"/"` or one ending in `".."`) yields no
    /// filename, and non-UTF-8 names are converted lossily.
    ///
    /// Returns the portal together with the outbound SPAKE2 message that must
    /// be delivered to the peer, who passes it to [`Portal::confirm_peer`].
    pub fn init(
        direction: Option<Direction>,
        id: String,
        password: String,
        filename: Option<String>,
    ) -> (Portal, Vec<u8>) {
        let mut hasher = Sha256::new();
        hasher.update(&id);
        let id_bytes = hasher.finalize();
        let id_hash = hex::encode(&id_bytes);

        let (s1, outbound_msg) = SPAKE2::<Ed25519Group>::start_symmetric(
            &Password::new(password.as_bytes()),
            &Identity::new(&id_bytes),
        );

        // Keep only the final component. A name without one is dropped rather
        // than sent as-is, which would expose directory components (e.g. "..")
        // to the peer.
        let filename = filename.and_then(|file| {
            std::path::Path::new(&file)
                .file_name()
                .map(|name| name.to_string_lossy().into_owned())
        });

        let portal = Portal {
            direction,
            id: id_hash,
            filename,
            filesize: 0,
            state: Some(s1),
            key: None,
        };
        (portal, outbound_msg)
    }

    /// Deserializes a bincode-encoded `Portal` from `reader`.
    ///
    /// `state` and `key` are not part of the encoding and come back as `None`.
    ///
    /// # Errors
    /// Returns an error if reading from `reader` or decoding fails.
    // NOTE(review): this decodes peer-supplied data with bincode's default
    // configuration, presumably without a size limit; a hostile length prefix
    // could trigger a very large allocation. Consider a bounded config.
    pub fn read_response_from<R>(reader: R) -> Result<Portal>
    where
        R: std::io::Read,
    {
        Ok(bincode::deserialize_from::<R, Portal>(reader)?)
    }

    /// Reads the peer's SPAKE2 message, exactly 33 bytes, from `reader`.
    ///
    /// # Errors
    /// Returns an error if the reader fails or hits end-of-stream before all
    /// 33 bytes have arrived.
    ///
    /// # Panics
    /// Panics if the message size computed from the SPAKE2 group is not 33.
    pub fn read_confirmation_from<R>(mut reader: R) -> Result<[u8; 33]>
    where
        R: std::io::Read,
    {
        assert_eq!(33, Portal::get_peer_msg_size());
        let mut res = [0u8; 33];
        // `read` may legally return fewer bytes than requested (common on
        // sockets), which would silently leave zero padding in the message.
        reader.read_exact(&mut res)?;
        Ok(res)
    }

    /// Deserializes a bincode-encoded `Portal` from an in-memory buffer.
    ///
    /// # Errors
    /// Returns an error if `data` is not a valid encoding.
    pub fn parse(data: &[u8]) -> Result<Portal> {
        Ok(bincode::deserialize(data)?)
    }

    /// Serializes the transferable fields of this portal with bincode.
    ///
    /// # Errors
    /// Returns an error if bincode fails to encode the value.
    pub fn serialize(&self) -> Result<Vec<u8>> {
        Ok(bincode::serialize(self)?)
    }

    /// Returns the stored file size.
    pub fn get_file_size(&self) -> u64 {
        self.filesize
    }

    /// Sets the stored file size.
    pub fn set_file_size(&mut self, size: u64) {
        self.filesize = size;
    }

    /// Returns the stored file name.
    ///
    /// # Errors
    /// `PortalError::NoneError` if no file name is set.
    pub fn get_file_name(&self) -> Result<&str> {
        self.filename
            .as_deref()
            .ok_or_else(|| PortalError::NoneError.into())
    }

    /// Returns the portal id.
    pub fn get_id(&self) -> &String {
        &self.id
    }

    /// Returns a copy of the portal direction.
    pub fn get_direction(&self) -> Option<Direction> {
        self.direction.clone()
    }

    /// Replaces the portal id.
    pub fn set_id(&mut self, id: String) {
        self.id = id;
    }

    /// Replaces the portal direction.
    pub fn set_direction(&mut self, direction: Option<Direction>) {
        self.direction = direction;
    }

    /// Memory-maps the existing file at `f` (copy-on-write) and pairs it with
    /// a ChaCha20Poly1305 cipher keyed by the session key.
    ///
    /// # Errors
    /// `PortalError::NoPeer` if the key exchange has not completed (checked
    /// before touching the filesystem); otherwise any error from opening or
    /// mapping the file.
    pub fn load_file(&self, f: &str) -> Result<PortalFile> {
        let key = self.key.as_ref().ok_or(PortalError::NoPeer)?;
        let file = File::open(f)?;
        // SAFETY: `map_copy` creates a private copy-on-write mapping, so writes
        // through it never reach the file. Mapping a file is still unsound if
        // another process modifies the file while it is mapped; this code
        // assumes the source file is not changed during the transfer.
        let mmap = unsafe { MmapOptions::new().map_copy(&file)? };
        // The SPAKE2 key is expected to be exactly 32 bytes; `from_slice`
        // panics on any other length.
        let cipher = ChaCha20Poly1305::new(Key::from_slice(key.as_slice()));
        Ok(PortalFile::init(mmap, cipher))
    }

    /// Creates (or reuses) the file at `f`, sets its length to `size`, and
    /// memory-maps it writable together with a ChaCha20Poly1305 cipher keyed
    /// by the session key.
    ///
    /// # Errors
    /// `PortalError::NoPeer` if the key exchange has not completed. This is
    /// checked first, so a failed call never creates or resizes the file.
    /// Otherwise returns any error from opening, resizing or mapping the file.
    pub fn create_file(&self, f: &str, size: u64) -> Result<PortalFile> {
        let key = self.key.as_ref().ok_or(PortalError::NoPeer)?;
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(f)?;
        file.set_len(size)?;
        // SAFETY: the file was just opened read/write and sized to `size`, so
        // the mapping covers exactly its contents. This code assumes no other
        // process modifies or truncates the file while it is mapped.
        let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
        let cipher = ChaCha20Poly1305::new(Key::from_slice(key.as_slice()));
        Ok(PortalFile::init(mmap, cipher))
    }

    /// Splits the mapped bytes of `data` into chunks of (at most) `chunk_size`
    /// bytes; the returned chunks borrow `data`.
    // NOTE(review): behaviour for `chunk_size == 0` depends on `PortalChunks`
    // and is not established here — confirm before allowing it.
    pub fn get_chunks<'a>(&self, data: &'a PortalFile, chunk_size: usize) -> PortalChunks<'a, u8> {
        PortalChunks::init(&data.mmap[..], chunk_size)
    }

    /// Completes the SPAKE2 exchange with the peer's message and stores the
    /// derived session key.
    ///
    /// The SPAKE2 state is consumed even if this fails, so it can be called
    /// at most once per `Portal`. SPAKE2 does not itself detect a mismatched
    /// password: both sides still obtain a key, but the keys differ.
    ///
    /// # Errors
    /// * `PortalError::BadState` if there is no SPAKE2 state. This happens
    ///   when the exchange was already finished, or when the portal was
    ///   deserialized, since the state is not serialized.
    /// * `PortalError::BadMsg` if SPAKE2 rejects `msg_data`.
    pub fn confirm_peer(&mut self, msg_data: &[u8]) -> Result<()> {
        let state = self.state.take().ok_or(PortalError::BadState)?;
        let key = state.finish(msg_data).map_err(|_| PortalError::BadMsg)?;
        self.key = Some(key);
        Ok(())
    }

    /// Size in bytes of a SPAKE2 message for this group: one compressed
    /// Ed25519 point plus one extra byte (spake2's side identifier).
    fn get_peer_msg_size() -> usize {
        let edwards_point = <spake2::Ed25519Group as Group>::Element::default();
        let compressed = edwards_point.compress();
        std::mem::size_of_val(&compressed) + 1
    }
}
#[cfg(test)]
mod tests {
    use super::errors::PortalError;
    use super::{Direction, Portal};
    use std::path::PathBuf;

    /// A path in the system temp dir that is removed when dropped.
    ///
    /// Names include the process id and a per-test suffix, so tests running in
    /// parallel (the default for `cargo test`) never share a file.
    struct TempFile(PathBuf);

    impl TempFile {
        fn new(name: &str) -> TempFile {
            let file_name = format!("portal-test-{}-{}", std::process::id(), name);
            TempFile(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &str {
            self.0.to_str().expect("temp dir path is not valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: some tests never create the file.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Writes a deterministic 2000-byte source file. This is smaller than the
    /// 4096-byte destinations used below and replaces the non-portable
    /// dependency on `/etc/passwd`.
    fn source_file(name: &str) -> TempFile {
        let file = TempFile::new(name);
        let contents: Vec<u8> = (0..2000u32).map(|i| (i % 251) as u8).collect();
        std::fs::write(&file.0, contents).unwrap();
        file
    }

    fn new_portal(direction: Direction, password: &str) -> (Portal, Vec<u8>) {
        Portal::init(Some(direction), "id".to_string(), password.to_string(), None)
    }

    /// Runs a complete key exchange; returns `(receiver, sender)`.
    fn confirmed_pair() -> (Portal, Portal) {
        let (mut receiver, receiver_msg) = new_portal(Direction::Receiver, "test");
        let (mut sender, sender_msg) = new_portal(Direction::Sender, "test");
        receiver.confirm_peer(&sender_msg).unwrap();
        sender.confirm_peer(&receiver_msg).unwrap();
        (receiver, sender)
    }

    /// Asserts that `result` failed specifically with `PortalError::NoPeer`.
    fn assert_no_peer<T>(result: anyhow::Result<T>) {
        let err = match result {
            Ok(_) => panic!("expected PortalError::NoPeer, got Ok"),
            Err(err) => err,
        };
        assert!(matches!(err.downcast_ref::<PortalError>(), Some(PortalError::NoPeer)));
    }

    #[test]
    fn key_derivation() {
        let (receiver, sender) = confirmed_pair();
        assert!(receiver.key.is_some());
        assert_eq!(receiver.key, sender.key);
        // `Key::from_slice` requires a 32-byte key.
        assert_eq!(receiver.key.as_ref().map(Vec::len), Some(32));
    }

    #[test]
    fn key_derivation_password_mismatch() {
        let (mut receiver, receiver_msg) = new_portal(Direction::Receiver, "test");
        let (mut sender, sender_msg) = new_portal(Direction::Sender, "wrong");
        receiver.confirm_peer(&sender_msg).unwrap();
        sender.confirm_peer(&receiver_msg).unwrap();
        assert_ne!(receiver.key, sender.key);
    }

    #[test]
    fn confirm_peer_twice_is_bad_state() {
        let (mut receiver, _receiver_msg) = new_portal(Direction::Receiver, "test");
        let (_sender, sender_msg) = new_portal(Direction::Sender, "test");
        receiver.confirm_peer(&sender_msg).unwrap();
        let err = receiver.confirm_peer(&sender_msg).unwrap_err();
        assert!(matches!(err.downcast_ref::<PortalError>(), Some(PortalError::BadState)));
    }

    #[test]
    fn read_confirmation_roundtrip() {
        let (mut receiver, _receiver_msg) = new_portal(Direction::Receiver, "test");
        let (_sender, sender_msg) = new_portal(Direction::Sender, "test");
        assert_eq!(sender_msg.len(), 33);
        let msg = Portal::read_confirmation_from(&sender_msg[..]).unwrap();
        receiver.confirm_peer(&msg).unwrap();
    }

    #[test]
    fn init_keeps_only_file_name() {
        let (portal, _msg) = Portal::init(
            Some(Direction::Sender),
            "id".to_string(),
            "test".to_string(),
            Some("some/dir/file.txt".to_string()),
        );
        assert_eq!(portal.get_file_name().unwrap(), "file.txt");
    }

    #[test]
    fn serialize_roundtrip() {
        let (mut portal, _msg) = Portal::init(
            Some(Direction::Sender),
            "id".to_string(),
            "test".to_string(),
            Some("file.txt".to_string()),
        );
        portal.set_file_size(1234);
        let bytes = portal.serialize().unwrap();

        let parsed = Portal::parse(&bytes).unwrap();
        assert_eq!(parsed.get_id(), portal.get_id());
        assert_eq!(parsed.get_file_size(), 1234);
        assert_eq!(parsed.get_file_name().unwrap(), "file.txt");
        assert_eq!(parsed.get_direction(), Some(Direction::Sender));

        let streamed = Portal::read_response_from(&bytes[..]).unwrap();
        assert_eq!(streamed, parsed);
    }

    #[test]
    fn portal_load_file() {
        let (_receiver, sender) = confirmed_pair();
        let src = source_file("load_file");
        let _file = sender.load_file(src.path()).unwrap();
    }

    #[test]
    fn portalfile_chunks_iterator() {
        let (_receiver, sender) = confirmed_pair();
        let src = source_file("chunks_iterator");
        let file = sender.load_file(src.path()).unwrap();
        for &chunk_size in &[10usize, 1024] {
            let chunks = sender.get_chunks(&file, chunk_size);
            for v in chunks.into_iter() {
                assert!(v.len() <= chunk_size);
            }
        }
    }

    #[test]
    fn portal_createfile() {
        let (receiver, _sender) = confirmed_pair();
        let dst = TempFile::new("createfile");
        let _file_dst = receiver.create_file(dst.path(), 4096).unwrap();
        assert_eq!(std::fs::metadata(&dst.0).unwrap().len(), 4096);
    }

    #[test]
    fn portal_write_chunk() {
        let (receiver, sender) = confirmed_pair();
        let src = source_file("write_chunk_src");
        let dst = TempFile::new("write_chunk_dst");
        let file_src = sender.load_file(src.path()).unwrap();
        let mut file_dst = receiver.create_file(dst.path(), 4096).unwrap();
        let chunk_size = 4096;
        let chunks = sender.get_chunks(&file_src, chunk_size);
        for v in chunks.into_iter() {
            assert!(v.len() <= chunk_size);
            file_dst.write_given_chunk(&v).unwrap();
        }
    }

    #[test]
    fn portal_createfile_no_peer() {
        let (portal, _msg) = new_portal(Direction::Sender, "test");
        let dst = TempFile::new("createfile_no_peer");
        assert_no_peer(portal.create_file(dst.path(), 4096));
    }

    #[test]
    fn portal_loadfile_no_peer() {
        let (portal, _msg) = new_portal(Direction::Sender, "test");
        let src = source_file("loadfile_no_peer");
        assert_no_peer(portal.load_file(src.path()));
    }
}