use std::cmp::{max, min};
use std::fs::File;
use std::io::{self, BufReader, Cursor, Error as IoError, Read, Write};
use std::sync::{Arc, Mutex};
use openssl::symm::{Cipher, Crypter, Mode as CrypterMode};
/// Length in bytes of the tag that `EncryptedFileReader` emits after the encrypted data and
/// that `EncryptedFileWriter` expects as the final bytes of its input.
const TAG_LEN: usize = 16;
/// A reader that encrypts the contents of a file while it is being read.
///
/// Reading yields the encrypted file data first, followed by a `TAG_LEN`-byte tag obtained
/// from the crypter once the file has been fully consumed.
pub struct EncryptedFileReader {
/// The file the unencrypted data is read from.
file: File,
/// The cipher in use; queried for its block size when sizing output buffers.
cipher: Cipher,
/// The crypter, created in encryption mode, that the file data is fed through.
crypter: Crypter,
/// The tag produced when the crypter is finalized; `None` until then.
/// Wrapped in a cursor so it can be handed out across several reads.
tag: Option<Cursor<Vec<u8>>>,
/// Encrypted bytes that did not fit in a caller's buffer; served first on the next read.
internal_buf: Vec<u8>,
}
impl EncryptedFileReader {
    /// Create a reader that encrypts `file` on the fly while it is being read.
    ///
    /// The crypter is set up in encryption mode with the given `cipher`, `key` and `iv`.
    ///
    /// # Errors
    ///
    /// Returns an error if the crypter cannot be constructed.
    pub fn new(file: File, cipher: Cipher, key: &[u8], iv: &[u8]) -> Result<Self, io::Error> {
        let crypter = Crypter::new(cipher, CrypterMode::Encrypt, key, Some(iv))?;
        let reader = Self {
            tag: None,
            internal_buf: Vec::new(),
            file,
            cipher,
            crypter,
        };
        Ok(reader)
    }

    /// Move as many bytes as possible from the internal overflow buffer into `buf`.
    ///
    /// Returns the number of bytes copied, which is zero when either buffer is empty.
    fn read_internal(&mut self, buf: &mut [u8]) -> usize {
        let count = min(buf.len(), self.internal_buf.len());
        if count == 0 {
            return 0;
        }
        buf[..count].copy_from_slice(&self.internal_buf[..count]);
        // Drop what was handed out so the next call continues where this one stopped.
        self.internal_buf.drain(..count);
        count
    }

    /// Read up to `buf.len()` bytes from the file, encrypt them and place the result in `buf`.
    ///
    /// Encrypted output that does not fit in `buf` is kept in the internal overflow buffer.
    /// Returns the number of bytes written to `buf`; `Ok(0)` if the file read yielded nothing.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the file or updating the crypter fails.
    fn read_file_encrypted(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        let mut plain = vec![0u8; buf.len()];
        let read = self.file.read(&mut plain)?;
        if read == 0 {
            return Ok(0);
        }

        // The crypter's output buffer is given one block of headroom beyond the input size.
        let mut encrypted = vec![0u8; read + self.cipher.block_size()];
        let produced = self.crypter.update(&plain[..read], &mut encrypted)?;

        let given = min(buf.len(), produced);
        buf[..given].copy_from_slice(&encrypted[..given]);
        self.internal_buf
            .extend_from_slice(&encrypted[given..produced]);
        Ok(given)
    }

    /// Finalize the crypter and fetch the tag.
    ///
    /// Any bytes emitted by finalization are queued in the internal overflow buffer, after
    /// which the `TAG_LEN`-byte tag is stored so that it can be read next.
    ///
    /// # Errors
    ///
    /// Returns an error if finalizing the crypter or obtaining the tag fails.
    fn finalize_file(&mut self) -> Result<(), io::Error> {
        let mut trailing = vec![0u8; self.cipher.block_size()];
        let produced = self.crypter.finalize(&mut trailing)?;
        trailing.truncate(produced);
        self.internal_buf.append(&mut trailing);

        let mut tag = vec![0u8; TAG_LEN];
        self.crypter.get_tag(&mut tag)?;
        self.tag = Some(Cursor::new(tag));
        Ok(())
    }
}
impl ExactLengthReader for EncryptedFileReader {
    /// Report the size of the underlying file, as given by its metadata, plus `TAG_LEN`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file metadata cannot be queried.
    fn len(&self) -> Result<u64, io::Error> {
        let file_len = self.file.metadata()?.len();
        Ok(file_len + TAG_LEN as u64)
    }
}
impl Read for EncryptedFileReader {
    /// Read encrypted file data, and after the file is exhausted the tag, into `buf`.
    ///
    /// Bytes are served in this order: leftover encrypted bytes from an earlier call, freshly
    /// encrypted file data, bytes emitted by finalizing the crypter, and finally the tag.
    ///
    /// A call may return fewer bytes than `buf` can hold even though more data will follow;
    /// `Ok(0)` for a non-empty `buf` is only returned once the tag has been fully handed out.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the file or any crypter operation fails.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        // Serve bytes that were encrypted earlier but did not fit the caller's buffer.
        let buffered = self.read_internal(buf);
        if buffered >= buf.len() {
            return Ok(buffered);
        }
        let buf = &mut buf[buffered..];

        // Once the crypter has been finalized only the tag is left to hand out.
        if let Some(ref mut tag) = self.tag {
            return Ok(tag.read(buf)? + buffered);
        }

        // A short, non-empty read does not mean the file is exhausted, so return what we got
        // and let the caller ask again. Finalizing here would cut the remaining file data off.
        let encrypted = self.read_file_encrypted(buf)?;
        if encrypted > 0 {
            return Ok(buffered + encrypted);
        }

        // Nothing was produced for a non-empty buffer: treat this as end of file.
        // NOTE(review): assumes the crypter emits output for every non-empty input, which
        // presumably holds for the cipher used with TAG_LEN-byte tags; confirm.
        self.finalize_file()?;
        Ok(self.read(buf)? + buffered)
    }
}
// SAFETY: NOTE(review): this asserts that every field of `EncryptedFileReader` may be moved to
// another thread. That is presumably needed because an openssl type held by the struct is not
// `Send` on its own; confirm against the openssl crate version in use before relying on it.
unsafe impl Send for EncryptedFileReader {}
pub struct ProgressReader<R> {
inner: R,
len: u64,
progress: u64,
reporter: Option<Arc<Mutex<ProgressReporter>>>,
}
impl<R: Read> ProgressReader<R> {
pub fn new(inner: R) -> Result<Self, IoError>
where
R: ExactLengthReader,
{
Ok(Self {
len: inner.len()?,
inner,
progress: 0,
reporter: None,
})
}
pub fn from(inner: R, len: u64) -> Self {
Self {
inner,
len,
progress: 0,
reporter: None,
}
}
pub fn set_reporter(&mut self, reporter: Arc<Mutex<ProgressReporter>>) {
self.reporter = Some(reporter);
}
pub fn progress(&self) -> u64 {
self.progress
}
}
impl<R: Read> Read for ProgressReader<R> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
let len = self.inner.read(buf)?;
self.progress += len as u64;
if self.progress > self.len {
self.len = self.progress;
}
if let Some(reporter) = self.reporter.as_mut() {
let progress = self.progress;
let _ = reporter.lock().map(|mut r| r.progress(progress));
}
Ok(len)
}
}
impl<R: Read> ExactLengthReader for ProgressReader<R> {
    /// Report the expected total length currently stored in this reader. Never fails.
    fn len(&self) -> Result<u64, io::Error> {
        let total = self.len;
        Ok(total)
    }
}
/// Receiver of progress updates for a transfer.
///
/// `ProgressReader` and `ProgressWriter` call `progress` after every read or write; `start` and
/// `finish` are not called from within this file.
pub trait ProgressReporter: Send {
/// NOTE(review): presumably signals the start of a transfer of `total` bytes; confirm
/// against the callers.
fn start(&mut self, total: u64);
/// Called with the cumulative number of bytes transferred so far.
fn progress(&mut self, progress: u64);
/// NOTE(review): presumably signals that the transfer is done; confirm against the callers.
fn finish(&mut self);
}
/// Something that can report its total length in bytes.
pub trait ExactLengthReader {
    /// The total length in bytes.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the length cannot be determined.
    fn len(&self) -> Result<u64, io::Error>;

    /// Whether the total length is zero.
    ///
    /// # Errors
    ///
    /// Passes on any error returned by `len`.
    fn is_empty(&self) -> Result<bool, io::Error> {
        Ok(self.len()? == 0)
    }
}
impl<R: ExactLengthReader + Read> ExactLengthReader for BufReader<R> {
    /// Forward the length query to the reader wrapped by the `BufReader`.
    fn len(&self) -> Result<u64, io::Error> {
        let inner: &R = self.get_ref();
        inner.len()
    }
}
/// A writer that decrypts the data written to it and stores the result in a file.
///
/// The last `TAG_LEN` bytes of the input are not decrypted but collected as the tag, which is
/// handed to the crypter before it is finalized.
pub struct EncryptedFileWriter {
/// The file the decrypted data is written to.
file: File,
/// The number of input bytes consumed so far.
cur: usize,
/// The total expected input length in bytes, including the trailing tag.
len: usize,
/// The cipher in use; queried for its block size when sizing output buffers.
cipher: Cipher,
/// The crypter, created in decryption mode, that the input is fed through.
crypter: Crypter,
/// The tag bytes collected from the end of the input so far.
tag_buf: Vec<u8>,
/// Set once the crypter has been finalized successfully with the collected tag.
verified: bool,
}
impl EncryptedFileWriter {
    /// Create a writer that decrypts everything written to it into `file`.
    ///
    /// `len` is the total number of input bytes to expect, including the trailing tag. The
    /// crypter is set up in decryption mode with the given `cipher`, `key` and `iv`.
    ///
    /// # Errors
    ///
    /// Returns an error if the crypter cannot be constructed.
    pub fn new(
        file: File,
        len: usize,
        cipher: Cipher,
        key: &[u8],
        iv: &[u8],
    ) -> Result<Self, io::Error> {
        let crypter = Crypter::new(cipher, CrypterMode::Decrypt, key, Some(iv))?;
        let writer = Self {
            cur: 0,
            verified: false,
            tag_buf: Vec::with_capacity(TAG_LEN),
            file,
            len,
            cipher,
            crypter,
        };
        Ok(writer)
    }

    /// Whether all `TAG_LEN` bytes of the tag have been collected.
    pub fn has_tag(&self) -> bool {
        TAG_LEN <= self.tag_buf.len()
    }

    /// Whether the crypter was finalized successfully with the collected tag.
    pub fn verified(&self) -> bool {
        self.verified
    }
}
impl ExactLengthReader for EncryptedFileWriter {
    /// Report the total expected input length that was given at construction. Never fails.
    fn len(&self) -> Result<u64, IoError> {
        let expected = self.len as u64;
        Ok(expected)
    }
}
impl Write for EncryptedFileWriter {
fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
if self.verified() || self.has_tag() {
return Ok(0);
}
let file_bytes = max(self.len - TAG_LEN - self.cur, 0);
let tag_bytes = TAG_LEN - self.tag_buf.len();
let (file_buf, tag_buf) = buf.split_at(min(file_bytes, buf.len()));
if !file_buf.is_empty() {
let block_size = self.cipher.block_size();
let mut decrypted = vec![0u8; file_bytes + block_size];
let len = self.crypter.update(file_buf, &mut decrypted)?;
self.file.write_all(&decrypted[..len])?;
}
if !tag_buf.is_empty() {
self.tag_buf.extend(tag_buf.iter().take(tag_bytes));
}
if self.has_tag() {
self.crypter.set_tag(&self.tag_buf)?;
let block_size = self.cipher.block_size();
let mut extra = vec![0u8; block_size];
let len = self.crypter.finalize(&mut extra)?;
self.file.write_all(&extra[..len])?;
self.verified = true;
}
let len = file_buf.len() + min(tag_buf.len(), TAG_LEN);
self.cur += len;
Ok(len)
}
fn flush(&mut self) -> Result<(), io::Error> {
self.file.flush()
}
}
pub struct ProgressWriter<W> {
inner: W,
len: u64,
progress: u64,
reporter: Option<Arc<Mutex<ProgressReporter>>>,
}
impl<W: Write> ProgressWriter<W> {
pub fn new(inner: W) -> Result<Self, IoError>
where
W: ExactLengthReader,
{
Ok(Self {
len: inner.len()?,
inner,
progress: 0,
reporter: None,
})
}
pub fn from(inner: W, len: u64) -> Self {
Self {
inner,
len,
progress: 0,
reporter: None,
}
}
pub fn set_reporter(&mut self, reporter: Arc<Mutex<ProgressReporter>>) {
self.reporter = Some(reporter);
}
pub fn progress(&self) -> u64 {
self.progress
}
pub fn unwrap(self) -> W {
self.inner
}
}
impl<W: Write> Write for ProgressWriter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
let len = self.inner.write(buf)?;
self.progress += len as u64;
if self.progress > self.len {
self.len = self.progress;
}
if let Some(reporter) = self.reporter.as_mut() {
let progress = self.progress;
let _ = reporter.lock().map(|mut r| r.progress(progress));
}
Ok(len)
}
fn flush(&mut self) -> Result<(), IoError> {
self.inner.flush()
}
}
impl<W: Write> ExactLengthReader for ProgressWriter<W> {
    /// Report the expected total length currently stored in this writer. Never fails.
    fn len(&self) -> Result<u64, io::Error> {
        let total = self.len;
        Ok(total)
    }
}