use anyhow::Result;
use chacha20poly1305::ChaCha20Poly1305;
use chacha20poly1305::{aead::AeadInPlace, Nonce, Tag};
use memmap::MmapMut;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::io::Write;
use crate::chunks::PortalChunks;
use crate::errors::PortalError;
/// A memory-mapped file that is being sent or received, together with the
/// ChaCha20-Poly1305 cipher used to encrypt/decrypt its contents in place.
pub struct PortalFile {
/// The mapped file contents; encryption and decryption operate on all of it.
pub mmap: MmapMut,
/// Cipher used by `encrypt` / `decrypt`.
pub cipher: ChaCha20Poly1305,
/// Nonce and detached tag describing the current ciphertext in `mmap`.
state: StateMetadata,
/// Current write offset into `mmap`, advanced by `download_file` and
/// `write_given_chunk`.
pos: usize,
}
/// Encryption metadata that the sender transmits ahead of the file contents
/// (`sync_file_state`) and the receiver reads back (`download_file`).
///
/// NOTE: this is serialized with bincode, which encodes fields in declaration
/// order, so reordering or retyping fields changes the wire format.
#[derive(Serialize, Deserialize, PartialEq, Default, Debug)]
pub struct StateMetadata {
/// Nonce the file was encrypted under; `decrypt` requires exactly
/// `size_of::<Nonce>()` bytes.
pub nonce: Vec<u8>,
/// Detached authentication tag from encryption; `decrypt` requires exactly
/// `size_of::<Tag>()` bytes.
pub tag: Vec<u8>,
}
impl PortalFile {
pub fn init(mmap: MmapMut, cipher: ChaCha20Poly1305) -> PortalFile {
PortalFile {
mmap,
cipher,
pos: 0,
state: StateMetadata {
nonce: Vec::new(),
tag: Vec::new(),
},
}
}
pub fn encrypt(&mut self) -> Result<()> {
let mut rng = rand::thread_rng();
let rbytes = rng.gen::<[u8; 12]>();
let nonce = Nonce::from_slice(&rbytes); self.state.nonce.extend(nonce);
let tag = match self
.cipher
.encrypt_in_place_detached(nonce, b"", &mut self.mmap[..])
{
Ok(tag) => tag,
Err(_e) => return Err(PortalError::EncryptError.into()),
};
self.state.tag.extend(tag);
Ok(())
}
pub fn decrypt(&mut self) -> Result<()> {
if self.state.nonce.len() != std::mem::size_of::<Nonce>()
|| self.state.tag.len() != std::mem::size_of::<Tag>()
{
return Err(PortalError::DecryptError.into());
}
let nonce = Nonce::from_slice(&self.state.nonce);
let tag = Tag::from_slice(&self.state.tag);
match self
.cipher
.decrypt_in_place_detached(nonce, b"", &mut self.mmap[..], &tag)
{
Ok(_) => Ok(()),
Err(_e) => Err(PortalError::DecryptError.into()),
}
}
pub fn sync_file_state<W>(&mut self, writer: &mut W) -> Result<usize>
where
W: std::io::Write,
{
let data: Vec<u8> = bincode::serialize(&self.state)?;
writer.write_all(&data)?;
Ok(data.len())
}
pub fn download_file<R, F>(&mut self, mut reader: R, callback: F) -> Result<u64>
where
R: std::io::Read,
F: Fn(u64),
{
let remote_state: StateMetadata = bincode::deserialize_from(&mut reader)?;
self.state.nonce.extend(&remote_state.nonce);
self.state.tag.extend(&remote_state.tag);
loop {
let len = match reader.read(&mut self.mmap[self.pos..]) {
Ok(0) => {
return Ok(self.pos as u64);
}
Ok(len) => len,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e.into()),
};
self.pos += len;
callback(self.pos as u64);
}
}
pub fn get_chunks<'a>(
&'a self,
chunk_size: usize,
) -> impl std::iter::Iterator<Item = &'a [u8]> {
PortalChunks::init(&self.mmap[..], chunk_size)
}
pub fn write_given_chunk(&mut self, data: &[u8]) -> Result<u64> {
(&mut self.mmap[self.pos..]).write_all(&data)?;
self.pos += data.len();
Ok(data.len() as u64)
}
}
#[cfg(test)]
pub mod tests {
    use crate::errors::PortalError;
    use crate::{Direction, Portal};
    use std::io::{Read, Write};

    /// In-memory stand-in for a TCP stream: writes append to `data`, and reads
    /// consume bytes from its front.
    pub struct MockTcpStream {
        pub data: Vec<u8>,
    }

    impl Read for MockTcpStream {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
            let size: usize = std::cmp::min(self.data.len(), buf.len());
            buf[..size].copy_from_slice(&self.data[..size]);
            self.data.drain(0..size);
            Ok(size)
        }
    }

    impl Write for MockTcpStream {
        fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
            self.data.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), std::io::Error> {
            Ok(())
        }
    }

    /// Asserts that `result` failed with `PortalError::DecryptError`.
    fn assert_decrypt_error(result: anyhow::Result<()>) {
        let err = result.expect_err("decryption unexpectedly succeeded");
        assert!(
            matches!(
                err.downcast_ref::<PortalError>(),
                Some(PortalError::DecryptError)
            ),
            "unexpected error: {:?}",
            err
        );
    }

    // NOTE: these tests load `/etc/passwd` (and the download test uses
    // `/tmp/passwd` as its target), so they assume a Unix-like filesystem.

    #[test]
    fn test_failed_decryption() {
        let dir = Direction::Receiver;
        let pass = "test".to_string();
        let (mut receiver, receiver_msg) = Portal::init(dir, "id".to_string(), pass, None);
        let dir = Direction::Sender;
        let pass = "test".to_string();
        let (mut sender, sender_msg) = Portal::init(dir, "id".to_string(), pass, None);
        receiver.derive_key(sender_msg.as_slice()).unwrap();
        sender.derive_key(receiver_msg.as_slice()).unwrap();

        let mut file = sender.load_file("/etc/passwd").unwrap();
        file.encrypt().unwrap();
        let old_tag = file.state.tag.clone();

        // A tag of the wrong length must be rejected.
        file.state.tag.truncate(0);
        assert_decrypt_error(file.decrypt());

        // A tag of the right length but wrong value must fail authentication.
        // XOR flips one bit; `+= 1` would panic on overflow when the byte is 255.
        file.state.tag = old_tag;
        file.state.tag[0] ^= 1;
        assert_decrypt_error(file.decrypt());
    }

    #[test]
    fn test_sync_file_download_file() {
        let dir = Direction::Receiver;
        let pass = "test".to_string();
        let (mut receiver, receiver_msg) = Portal::init(dir, "id".to_string(), pass, None);
        let dir = Direction::Sender;
        let pass = "test".to_string();
        let (mut sender, sender_msg) = Portal::init(dir, "id".to_string(), pass, None);
        receiver.derive_key(sender_msg.as_slice()).unwrap();
        sender.derive_key(receiver_msg.as_slice()).unwrap();

        let mut file = sender.load_file("/etc/passwd").unwrap();
        file.encrypt().unwrap();

        // Sender side: metadata first, then the ciphertext chunks.
        let mut stream = MockTcpStream {
            data: Vec::with_capacity(crate::CHUNK_SIZE),
        };
        file.sync_file_state(&mut stream).unwrap();
        for data in file.get_chunks(crate::CHUNK_SIZE) {
            stream.write_all(data).unwrap();
        }

        // Receiver side: read everything back into a same-sized file.
        let mut new_file = receiver
            .create_file("/tmp/passwd", file.mmap[..].len() as u64)
            .unwrap();
        new_file
            .download_file(&mut stream, |x| println!("{:?}", x))
            .unwrap();

        assert_eq!(&file.state.tag, &new_file.state.tag);
        assert_eq!(&file.state.nonce, &new_file.state.nonce);
        assert_eq!(&file.mmap[..], &new_file.mmap[..]);
        new_file.decrypt().unwrap();
    }

    #[test]
    fn test_encrypt_decrypt() {
        let dir = Direction::Receiver;
        let pass = "test".to_string();
        let (mut receiver, receiver_msg) = Portal::init(dir, "id".to_string(), pass, None);
        let dir = Direction::Sender;
        let pass = "test".to_string();
        let (mut sender, sender_msg) = Portal::init(dir, "id".to_string(), pass, None);
        receiver.derive_key(sender_msg.as_slice()).unwrap();
        sender.derive_key(receiver_msg.as_slice()).unwrap();

        let mut file = sender.load_file("/etc/passwd").unwrap();
        // Compare raw bytes so the check does not depend on the file being UTF-8.
        let before = file.mmap[..].to_vec();
        file.encrypt().unwrap();
        let encrypted = file.mmap[..].to_vec();
        file.decrypt().unwrap();

        assert_ne!(before, encrypted);
        assert_eq!(&before[..], &file.mmap[..]);
    }
}