use std::{convert::TryFrom, io};
use tink_core::{utils::wrap_err, EncryptingWrite, TinkError};
/// Encrypts individual segments of a nonce-based ciphertext stream.
///
/// [`Writer`] calls `encrypt_segment` with one plaintext segment and the nonce it
/// generated for that segment, and writes the returned bytes to its inner writer
/// as that segment's ciphertext.
pub trait SegmentEncrypter {
/// Returns the ciphertext for `segment` encrypted under `nonce`.
fn encrypt_segment(&self, segment: &[u8], nonce: &[u8]) -> Result<Vec<u8>, TinkError>;
}
/// Streaming encrypter that splits plaintext into fixed-size segments, encrypts
/// each one under a per-segment nonce, and writes the ciphertext to an inner writer.
///
/// Construct with [`Writer::new`]. The final segment is only emitted by `close`
/// (which `Drop` also attempts, discarding any error).
pub struct Writer {
/// Destination for ciphertext segments.
w: Box<dyn io::Write>,
/// Encrypts each completed plaintext segment.
segment_encrypter: Box<dyn SegmentEncrypter>,
/// Number of segments encrypted so far; also the counter placed in the next nonce.
encrypted_segment_cnt: u64,
/// Number of bytes by which the first segment's plaintext capacity is reduced.
first_ciphertext_segment_offset: usize,
/// Total length of each segment nonce in bytes.
nonce_size: usize,
/// Bytes copied to the start of every segment nonce.
nonce_prefix: Vec<u8>,
/// Buffer for the plaintext of the segment currently being filled.
plaintext: Vec<u8>,
/// Number of valid bytes at the start of `plaintext`.
plaintext_pos: usize,
/// Set once `close` has successfully written the final segment.
closed: bool,
}
/// Parameters for [`Writer::new`].
pub struct WriterParams {
/// Destination for the ciphertext segments.
pub w: Box<dyn io::Write>,
/// Encrypts each plaintext segment under its segment nonce.
pub segment_encrypter: Box<dyn SegmentEncrypter>,
/// Total length of each segment nonce in bytes; `Writer::new` rejects values that
/// leave fewer than 5 bytes after `nonce_prefix`.
pub nonce_size: usize,
/// Bytes copied to the start of every segment nonce.
pub nonce_prefix: Vec<u8>,
/// Plaintext bytes per segment; the first segment holds
/// `first_ciphertext_segment_offset` fewer.
pub plaintext_segment_size: usize,
/// Number of bytes by which the first segment is shortened
/// (NOTE(review): presumably room for a stream header written elsewhere — confirm).
pub first_ciphertext_segment_offset: usize,
}
impl Writer {
    /// Creates a [`Writer`] that segments and encrypts plaintext according to `params`.
    ///
    /// # Errors
    ///
    /// Returns an error if
    /// * `nonce_size` leaves fewer than 5 bytes after `nonce_prefix`: every segment
    ///   nonce needs room for a 4-byte segment counter and a 1-byte last-segment flag;
    /// * `first_ciphertext_segment_offset` exceeds `plaintext_segment_size + nonce_size`;
    /// * what remains of the first segment after the offset is not longer than `nonce_size`.
    pub fn new(params: WriterParams) -> Result<Writer, TinkError> {
        // Checked subtraction: a prefix longer than the nonce must be rejected here
        // instead of wrapping around (release) or panicking (debug).
        let nonce_room = params.nonce_size.checked_sub(params.nonce_prefix.len());
        if nonce_room.map_or(true, |room| room < 5) {
            return Err("nonce size too short".into());
        }
        let ct_size = params.plaintext_segment_size + params.nonce_size;
        match ct_size.checked_sub(params.first_ciphertext_segment_offset) {
            None => {
                return Err(
                    "first ciphertext segment offset bigger than ciphertext segment size".into(),
                )
            }
            Some(sz) if sz <= params.nonce_size => {
                return Err("first ciphertext segment not large enough for full nonce".into())
            }
            _ => {}
        }
        Ok(Writer {
            w: params.w,
            segment_encrypter: params.segment_encrypter,
            encrypted_segment_cnt: 0,
            first_ciphertext_segment_offset: params.first_ciphertext_segment_offset,
            nonce_size: params.nonce_size,
            nonce_prefix: params.nonce_prefix,
            // Sized for a full segment; the first segment only uses part of it.
            plaintext: vec![0; params.plaintext_segment_size],
            plaintext_pos: 0,
            closed: false,
        })
    }
}
impl io::Write for Writer {
    /// Buffers `buf` into plaintext segments, encrypting each full segment and
    /// writing its ciphertext to the inner writer.
    ///
    /// A full segment is only encrypted once more input arrives. Until then it may
    /// still turn out to be the last segment, which `close` encrypts with the
    /// last-segment flag set in its nonce.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidInput` after `close` or when a segment cannot be
    /// encrypted, and propagates errors from the inner writer. As `io::Write`
    /// requires, an `Err` means no byte of `buf` was consumed. If a failure happens
    /// after part of `buf` was already buffered, that count is returned instead,
    /// and the pending segment is attempted again on the next call.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.closed {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "write on closed writer",
            ));
        }
        let mut pos = 0;
        loop {
            // The first segment holds `first_ciphertext_segment_offset` fewer plaintext bytes.
            let pt_lim = if self.encrypted_segment_cnt == 0 {
                self.plaintext.len() - self.first_ciphertext_segment_offset
            } else {
                self.plaintext.len()
            };
            let n = std::cmp::min(pt_lim - self.plaintext_pos, buf.len() - pos);
            self.plaintext[self.plaintext_pos..self.plaintext_pos + n]
                .copy_from_slice(&buf[pos..pos + n]);
            self.plaintext_pos += n;
            pos += n;
            if pos == buf.len() {
                // Everything is buffered; a full segment stays pending (see above).
                return Ok(pos);
            }
            if self.plaintext_pos != pt_lim {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!(
                        "internal error: pos={} != pt_lim={}",
                        self.plaintext_pos, pt_lim
                    ),
                ));
            }
            // The segment is full and more input follows, so it is not the last one.
            let sealed = generate_segment_nonce(
                self.nonce_size,
                &self.nonce_prefix,
                self.encrypted_segment_cnt,
                false,
            )
            .and_then(|nonce| {
                self.segment_encrypter
                    .encrypt_segment(&self.plaintext[..pt_lim], &nonce)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", e)))
            })
            .and_then(|ciphertext| self.w.write_all(&ciphertext));
            if let Err(e) = sealed {
                // Bytes of `buf` already copied into the segment buffer cannot be
                // handed back, so report them as written. The segment stays full,
                // so the next call retries it and surfaces the error with nothing consumed.
                return if pos > 0 { Ok(pos) } else { Err(e) };
            }
            self.plaintext_pos = 0;
            self.encrypted_segment_cnt += 1;
        }
    }

    /// Flushes the inner writer so every segment encrypted so far reaches it.
    ///
    /// Plaintext of the segment currently being filled stays buffered: it can only
    /// be encrypted once the segment is full and more input arrives, or on `close`.
    fn flush(&mut self) -> io::Result<()> {
        self.w.flush()
    }
}
impl EncryptingWrite for Writer {
fn close(&mut self) -> Result<(), TinkError> {
if self.closed {
return Ok(());
}
let nonce = generate_segment_nonce(
self.nonce_size,
&self.nonce_prefix,
self.encrypted_segment_cnt,
true,
)
.map_err(|e| wrap_err("internal error", e))?;
let ciphertext = self
.segment_encrypter
.encrypt_segment(&self.plaintext[..self.plaintext_pos], &nonce)?;
self.w
.write_all(&ciphertext)
.map_err(|e| wrap_err("write failure", e))?;
self.plaintext_pos = 0;
self.encrypted_segment_cnt += 1;
self.closed = true;
Ok(())
}
}
impl Drop for Writer {
    /// Makes a best-effort attempt to emit the final segment if `close` was
    /// never called (or did not succeed).
    ///
    /// `drop` has no way to report failure, so any error from `close` is
    /// discarded. Call `close` explicitly to observe it.
    fn drop(&mut self) {
        let _close_result = self.close();
    }
}
/// Decrypts individual segments of a nonce-based ciphertext stream.
///
/// [`Reader`] calls `decrypt_segment` with one ciphertext segment and the nonce it
/// derived for that segment. An `Err` from this method is surfaced by `Reader`
/// as an `io::ErrorKind::InvalidInput` error. The reader performs no integrity
/// check of its own (NOTE(review): implementations are presumably AEAD
/// decryption that rejects tampered or truncated segments — confirm).
pub trait SegmentDecrypter {
/// Returns the plaintext for `segment` decrypted under `nonce`.
fn decrypt_segment(&self, segment: &[u8], nonce: &[u8]) -> Result<Vec<u8>, TinkError>;
}
/// Streaming decrypter that reads nonce-based ciphertext segments from an inner
/// reader and yields the decrypted plaintext through [`io::Read`].
///
/// Construct with [`Reader::new`].
pub struct Reader {
/// Source of ciphertext segments.
r: Box<dyn io::Read>,
/// Decrypts each ciphertext segment.
segment_decrypter: Box<dyn SegmentDecrypter>,
/// Number of segments decrypted so far; also the counter placed in the next nonce.
decrypted_segment_cnt: u64,
/// Number of bytes by which the first ciphertext segment is shorter than the others.
first_ciphertext_segment_offset: usize,
/// Total length of each segment nonce in bytes.
nonce_size: usize,
/// Bytes copied to the start of every segment nonce.
nonce_prefix: Vec<u8>,
/// Plaintext of the most recently decrypted segment.
plaintext: Vec<u8>,
/// Number of bytes of `plaintext` already returned to the caller.
plaintext_pos: usize,
/// Ciphertext buffer: one full segment plus one look-ahead byte, whose presence
/// shows that another segment follows the current one.
ciphertext: Vec<u8>,
/// Number of ciphertext bytes carried over at the start of `ciphertext`
/// (the look-ahead byte from the previous segment, if any).
ciphertext_pos: usize,
}
/// Parameters for [`Reader::new`].
pub struct ReaderParams {
/// Source of the ciphertext segments.
pub r: Box<dyn io::Read>,
/// Decrypts each ciphertext segment under its segment nonce.
pub segment_decrypter: Box<dyn SegmentDecrypter>,
/// Total length of each segment nonce in bytes; `Reader::new` rejects values that
/// leave fewer than 5 bytes after `nonce_prefix`.
pub nonce_size: usize,
/// Bytes copied to the start of every segment nonce.
pub nonce_prefix: Vec<u8>,
/// Length of a full ciphertext segment in bytes. The first segment is
/// `first_ciphertext_segment_offset` bytes shorter, and the last may be shorter still.
pub ciphertext_segment_size: usize,
/// Number of bytes by which the first ciphertext segment is shortened
/// (NOTE(review): presumably a stream header consumed before this reader — confirm).
pub first_ciphertext_segment_offset: usize,
}
impl Reader {
    /// Creates a [`Reader`] that decrypts segments according to `params`.
    ///
    /// # Errors
    ///
    /// Returns an error if
    /// * `nonce_size` leaves fewer than 5 bytes after `nonce_prefix`: every segment
    ///   nonce needs room for a 4-byte segment counter and a 1-byte last-segment flag;
    /// * `first_ciphertext_segment_offset` exceeds `ciphertext_segment_size`;
    /// * what remains of the first segment after the offset is not longer than `nonce_size`.
    pub fn new(params: ReaderParams) -> Result<Reader, TinkError> {
        // Checked subtraction: a prefix longer than the nonce must be rejected here
        // instead of wrapping around (release) or panicking (debug).
        let nonce_room = params.nonce_size.checked_sub(params.nonce_prefix.len());
        if nonce_room.map_or(true, |room| room < 5) {
            return Err("nonce size too short".into());
        }
        match params
            .ciphertext_segment_size
            .checked_sub(params.first_ciphertext_segment_offset)
        {
            None => {
                return Err(
                    "first ciphertext segment offset bigger than ciphertext segment size".into(),
                )
            }
            Some(sz) if sz <= params.nonce_size => {
                return Err("first ciphertext segment not large enough for full nonce".into())
            }
            _ => {}
        }
        Ok(Reader {
            r: params.r,
            segment_decrypter: params.segment_decrypter,
            decrypted_segment_cnt: 0,
            first_ciphertext_segment_offset: params.first_ciphertext_segment_offset,
            nonce_size: params.nonce_size,
            nonce_prefix: params.nonce_prefix,
            plaintext: vec![],
            plaintext_pos: 0,
            // One extra byte: reading past a full segment reveals whether another follows.
            ciphertext: vec![0; params.ciphertext_segment_size + 1],
            ciphertext_pos: 0,
        })
    }
}
trait ReadFullExt {
    /// Reads repeatedly until `buf` is full or the reader reports end of input,
    /// and returns the number of bytes placed in `buf`.
    ///
    /// `Interrupted` errors are retried. Any other error is returned immediately,
    /// and the count of bytes read before it is lost.
    fn read_full(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
}
impl ReadFullExt for dyn std::io::Read {
    fn read_full(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut filled = 0;
        loop {
            let remaining = &mut buf[filled..];
            if remaining.is_empty() {
                return Ok(filled);
            }
            match self.read(remaining) {
                // End of input before `buf` was full: report what we got.
                Ok(0) => return Ok(filled),
                Ok(got) => filled += got,
                Err(ref err) if err.kind() == std::io::ErrorKind::Interrupted => {}
                Err(err) => return Err(err),
            }
        }
    }
}
impl io::Read for Reader {
    /// Reads decrypted plaintext into `buf`.
    ///
    /// Plaintext left over from the most recently decrypted segment is served
    /// first. Once it is used up, the next ciphertext segment is read from the
    /// inner reader and decrypted. After the final segment has been decrypted and
    /// fully returned, `Ok(0)` signals end of plaintext.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner reader. Fails with `InvalidInput` if the
    /// segment counter overflows or a segment fails to decrypt. A stream that ends
    /// without a properly flagged final segment is never reported as a clean EOF:
    /// the leftover bytes (possibly none) are decrypted as the last segment, and
    /// rejecting them is up to the `SegmentDecrypter`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        // Serve plaintext left over from the most recently decrypted segment.
        if self.plaintext_pos < self.plaintext.len() {
            let n = std::cmp::min(buf.len(), self.plaintext.len() - self.plaintext_pos);
            buf[..n].copy_from_slice(&self.plaintext[self.plaintext_pos..self.plaintext_pos + n]);
            self.plaintext_pos += n;
            return Ok(n);
        }
        // Every non-last segment leaves its look-ahead byte buffered
        // (`ciphertext_pos == 1`). The last segment resets `ciphertext_pos` to 0
        // below. So this state means the final segment was already delivered.
        if self.decrypted_segment_cnt > 0 && self.ciphertext_pos == 0 {
            return Ok(0);
        }
        // Discard the consumed segment up front, so that a failure below cannot
        // make a later call serve the same plaintext a second time.
        self.plaintext.clear();
        self.plaintext_pos = 0;

        let mut ct_lim = self.ciphertext.len();
        if self.decrypted_segment_cnt == 0 {
            ct_lim -= self.first_ciphertext_segment_offset;
        }
        let n = self
            .r
            .read_full(&mut self.ciphertext[self.ciphertext_pos..ct_lim])?;
        let available = self.ciphertext_pos + n;
        // A short read means the input is exhausted, so whatever is buffered must
        // be the last segment. That includes nothing at all, or only the look-ahead
        // byte. It is still decrypted rather than treated as EOF, so that a stream
        // cut off right after a non-last segment is not silently accepted.
        let last_segment = available < ct_lim;
        let segment_len = if last_segment {
            available
        } else {
            // Hold back the final byte: it starts the next segment.
            available.checked_sub(1).ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "ciphertext segment too short")
            })?
        };
        let nonce = generate_segment_nonce(
            self.nonce_size,
            &self.nonce_prefix,
            self.decrypted_segment_cnt,
            last_segment,
        )?;
        self.plaintext = self
            .segment_decrypter
            .decrypt_segment(&self.ciphertext[..segment_len], &nonce)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", e)))?;
        if last_segment {
            // Marks the stream as finished (see the end-of-stream check above).
            self.ciphertext_pos = 0;
        } else {
            // Carry the look-ahead byte over as the start of the next segment.
            self.ciphertext[0] = self.ciphertext[segment_len];
            self.ciphertext_pos = 1;
        }
        self.decrypted_segment_cnt += 1;
        let n = std::cmp::min(buf.len(), self.plaintext.len());
        buf[..n].copy_from_slice(&self.plaintext[..n]);
        self.plaintext_pos = n;
        Ok(n)
    }
}
/// Builds the `size`-byte nonce for segment number `segment_num`.
///
/// Layout: `prefix`, then `segment_num` as a 4-byte big-endian integer, then one
/// flag byte that is 1 when `last` is set and 0 otherwise. Any remaining bytes
/// are zero.
///
/// # Errors
///
/// Returns `InvalidInput` if `segment_num` does not fit in a `u32`, or if `size`
/// leaves no room for the counter and flag byte after `prefix`.
fn generate_segment_nonce(
    size: usize,
    prefix: &[u8],
    segment_num: u64,
    last: bool,
) -> io::Result<Vec<u8>> {
    let counter = u32::try_from(segment_num)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "too many segments"))?;
    let counter_start = prefix.len();
    let flag_pos = counter_start + 4;
    // Reject undersized nonces explicitly instead of panicking on the slices below.
    if size <= flag_pos {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "nonce size too short",
        ));
    }
    let mut nonce = vec![0; size];
    nonce[..counter_start].copy_from_slice(prefix);
    nonce[counter_start..flag_pos].copy_from_slice(&counter.to_be_bytes());
    nonce[flag_pos] = u8::from(last);
    Ok(nonce)
}