use std::cmp::min;
use std::io::{self, Read, Write};
use byteorder::{BigEndian, ByteOrder};
use bytes::BytesMut;
#[cfg(feature = "crypto-openssl")]
use openssl::symm;
#[cfg(feature = "crypto-ring")]
use ring::aead;
use super::{Crypt, CryptMode};
use crate::config::{self, TAG_LEN};
use crate::crypto::{hkdf::hkdf, rand_bytes};
use crate::pipe::{prelude::*, DEFAULT_BUF_SIZE};
pub const RS: u32 = config::ECE_RECORD_SIZE;
/// Length in bytes of the AES-128-GCM content encryption key.
const KEY_LEN: usize = 16;

/// Length in bytes of the AES-GCM nonce.
const NONCE_LEN: usize = 12;

/// Length in bytes of the salt at the start of the ECE header.
const SALT_LEN: usize = 16;

/// Length in bytes of the big-endian record size field in the ECE header.
const RS_LEN: usize = std::mem::size_of::<u32>();

/// Total ECE header length: salt, record size, and a single key ID length byte (no key ID).
pub const HEADER_LEN: u32 = (SALT_LEN + RS_LEN + 1) as u32;

/// HKDF info string used to derive the content encryption key.
const KEY_INFO: &str = "Content-Encoding: aes128gcm\0";

/// HKDF info string used to derive the base nonce.
const NONCE_INFO: &str = "Content-Encoding: nonce\0";
/// Encrypted content encoding (`aes128gcm`) transformer, usable in both directions.
///
/// Data is encrypted or decrypted record by record as it is piped through.
pub struct EceCrypt {
    /// Whether this transformer encrypts or decrypts.
    mode: CryptMode,

    /// Plaintext payload length in bytes.
    len: usize,

    /// Record size in bytes: the default, or the value read from the header when decrypting.
    rs: u32,

    /// Input keying material that the key and base nonce are derived from.
    ikm: Vec<u8>,

    /// Key derivation salt: given or generated when encrypting, read from the header when
    /// decrypting.
    salt: Option<Vec<u8>>,

    /// Derived content encryption key, `None` until derived.
    key: Option<Vec<u8>>,

    /// Derived base nonce, combined with `seq` for every record, `None` until derived.
    nonce: Option<Vec<u8>>,

    /// Sequence number of the next record.
    seq: u32,

    /// Total number of input bytes fed through `pipe` so far.
    cur_in: usize,

    /// Number of plaintext bytes processed so far.
    cur: usize,
}
impl EceCrypt {
    /// Construct a new ECE transformer without deriving any key material.
    ///
    /// - `mode`: whether this transformer encrypts or decrypts.
    /// - `len`: length of the plaintext payload in bytes.
    /// - `ikm`: input keying material that the content key and base nonce are derived from.
    /// - `salt`: salt used for key derivation, if already known.
    ///
    /// The record size starts out as [`RS`], and the key and nonce stay unset. Prefer
    /// [`EceCrypt::encrypt`] or [`EceCrypt::decrypt`], which take care of that.
    pub fn new(mode: CryptMode, len: usize, ikm: Vec<u8>, salt: Option<Vec<u8>>) -> Self {
        Self {
            mode,
            ikm,
            key: None,
            nonce: None,
            salt,
            seq: 0,
            cur_in: 0,
            cur: 0,
            len,
            rs: RS,
        }
    }

    /// Construct an encrypting transformer for a plaintext payload of `len` bytes.
    ///
    /// A random salt is generated if `salt` is `None`. The salt is known up front, so the key
    /// and base nonce are derived right away.
    pub fn encrypt(len: usize, ikm: Vec<u8>, salt: Option<Vec<u8>>) -> Self {
        let mut crypt = Self::new(
            CryptMode::Encrypt,
            len,
            ikm,
            salt.or_else(|| Some(generate_salt())),
        );
        crypt.derive_key_and_nonce();
        crypt
    }

    /// Construct a decrypting transformer for a payload that decrypts to `len` bytes.
    ///
    /// The salt and record size are read from the ECE header at the start of the input. The
    /// key and base nonce are therefore only derived once that header is parsed.
    pub fn decrypt(len: usize, ikm: Vec<u8>) -> Self {
        Self::new(CryptMode::Decrypt, len, ikm, None)
    }

    /// Number of input bytes that make up the next chunk to transform.
    ///
    /// When encrypting, this is the plaintext room in a record: the record size minus the tag
    /// and the padding delimiter byte. When decrypting, it is the header length until the
    /// header has been parsed, and the full record size after that.
    #[inline(always)]
    fn chunk_size(&self) -> u32 {
        match self.mode {
            CryptMode::Encrypt => self.rs - TAG_LEN as u32 - 1,
            CryptMode::Decrypt => {
                if self.has_header() {
                    self.rs
                } else {
                    HEADER_LEN
                }
            }
        }
    }

    /// Encrypt one chunk of plaintext, prefixing the ECE header to the very first chunk.
    ///
    /// Returns the number of plaintext bytes consumed and the produced ciphertext.
    fn pipe_encrypt(&mut self, input: Vec<u8>) -> (usize, Option<Vec<u8>>) {
        if !self.has_header() {
            let mut ciphertext = self.create_header();
            let (read, chunk) = self.encrypt_chunk(input);
            if let Some(chunk) = chunk {
                ciphertext.extend_from_slice(&chunk)
            }
            self.increase_seq();
            (read, Some(ciphertext))
        } else {
            let result = self.encrypt_chunk(input);
            self.increase_seq();
            result
        }
    }

    /// Decrypt one chunk of input.
    ///
    /// The first chunk must be the ECE header. It is parsed (which derives the key and base
    /// nonce), fully consumed, and produces no output. Every following chunk is one record.
    fn pipe_decrypt(&mut self, input: &[u8]) -> (usize, Option<Vec<u8>>) {
        if !self.has_header() {
            self.parse_header(input);
            return (input.len(), None);
        }

        let result = self.decrypt_chunk(input);
        self.increase_seq();
        result
    }

    /// Encrypt a single record of `plaintext` with AES-128-GCM.
    ///
    /// The record is padded first: a `2` delimiter for the last record, otherwise a `1`
    /// delimiter plus zero bytes up to the record size. It is then encrypted with the nonce
    /// for the current sequence number. The sequence number is not advanced here.
    ///
    /// Returns the number of plaintext bytes consumed and the ciphertext with the tag appended.
    #[inline(always)]
    fn encrypt_chunk(&mut self, mut plaintext: Vec<u8>) -> (usize, Option<Vec<u8>>) {
        let read = plaintext.len();
        self.cur += read;
        let nonce = self.generate_nonce(self.seq);
        pad(&mut plaintext, self.rs as usize, self.is_last());

        #[cfg(feature = "crypto-openssl")]
        {
            let mut tag = vec![0u8; TAG_LEN];
            let mut ciphertext = symm::encrypt_aead(
                symm::Cipher::aes_128_gcm(),
                self.key
                    .as_ref()
                    .expect("failed to encrypt ECE chunk, missing crypto key"),
                Some(&nonce),
                &[],
                &plaintext,
                &mut tag,
            )
            .expect("failed to encrypt ECE chunk");
            ciphertext.extend_from_slice(&tag);
            (read, Some(ciphertext))
        }

        #[cfg(feature = "crypto-ring")]
        {
            let nonce = aead::Nonce::try_assume_unique_for_key(&nonce)
                .expect("failed to encrypt ECE chunk, invalid nonce");
            let aad = aead::Aad::empty();
            let key = self
                .key
                .as_ref()
                .expect("failed to encrypt ECE chunk, missing crypto key");
            let unbound_key = aead::UnboundKey::new(&aead::AES_128_GCM, key).unwrap();
            let key = aead::LessSafeKey::new(unbound_key);
            key.seal_in_place_append_tag(nonce, aad, &mut plaintext)
                .expect("failed to encrypt ECE chunk");

            // The buffer now holds ciphertext followed by the tag; hand it out without copying
            (read, Some(plaintext))
        }
    }

    /// Decrypt and authenticate a single record, then strip its padding.
    ///
    /// Returns the number of ciphertext bytes consumed and the recovered plaintext.
    ///
    /// # Panics
    ///
    /// Panics if the record is shorter than the authentication tag, if decryption or
    /// authentication fails, or if the padding is malformed.
    #[inline(always)]
    fn decrypt_chunk(&mut self, ciphertext: &[u8]) -> (usize, Option<Vec<u8>>) {
        // Checked explicitly: the tag split below would otherwise underflow on short input
        assert!(
            ciphertext.len() >= TAG_LEN,
            "failed to decrypt ECE chunk, record is shorter than the authentication tag",
        );
        let nonce = self.generate_nonce(self.seq);

        #[cfg(feature = "crypto-openssl")]
        {
            let (payload, tag) = ciphertext.split_at(ciphertext.len() - TAG_LEN);
            let mut plaintext = symm::decrypt_aead(
                symm::Cipher::aes_128_gcm(),
                self.key
                    .as_ref()
                    .expect("failed to decrypt ECE chunk, missing crypto key"),
                Some(&nonce),
                &[],
                payload,
                tag,
            )
            .expect("failed to decrypt ECE chunk");
            unpad(&mut plaintext, self.is_last());
            self.cur += plaintext.len();
            (ciphertext.len(), Some(plaintext))
        }

        #[cfg(feature = "crypto-ring")]
        {
            let mut ciphertext = ciphertext.to_vec();
            let nonce = aead::Nonce::try_assume_unique_for_key(&nonce)
                .expect("failed to decrypt ECE chunk, invalid nonce");
            let aad = aead::Aad::empty();
            let key = self
                .key
                .as_ref()
                .expect("failed to decrypt ECE chunk, missing crypto key");
            let unbound_key = aead::UnboundKey::new(&aead::AES_128_GCM, key).unwrap();
            let key = aead::LessSafeKey::new(unbound_key);
            let mut plaintext = key
                .open_in_place(nonce, aad, &mut ciphertext)
                .expect("failed to decrypt ECE chunk")
                .to_vec();
            unpad(&mut plaintext, self.is_last());
            self.cur += plaintext.len();
            (ciphertext.len(), Some(plaintext))
        }
    }

    /// Build the ECE header.
    ///
    /// The layout is: the salt (`SALT_LEN` bytes), the record size as a big-endian `u32`, and
    /// a key ID length byte that is always `0`.
    ///
    /// # Panics
    ///
    /// Panics if no salt is set or the salt is not `SALT_LEN` bytes long.
    #[inline(always)]
    fn create_header(&self) -> Vec<u8> {
        let mut header = Vec::with_capacity(HEADER_LEN as usize);

        let salt = self
            .salt
            .as_ref()
            .expect("failed to create ECE header, no crypto salt specified");
        assert_eq!(salt.len(), SALT_LEN);
        header.extend_from_slice(salt);

        let mut rs = [0u8; 4];
        BigEndian::write_u32(&mut rs, self.rs);
        header.extend_from_slice(&rs);

        // Key ID length, no key ID is included
        header.push(0);

        header
    }

    /// Parse an ECE header, store its salt and record size, and derive the key and base nonce.
    ///
    /// # Panics
    ///
    /// Panics if the header is not exactly `HEADER_LEN` bytes, if it declares a key ID, or if
    /// its record size cannot hold the tag, the padding delimiter and at least one data byte.
    #[inline(always)]
    fn parse_header(&mut self, header: &[u8]) {
        assert_eq!(
            header.len() as u32,
            HEADER_LEN,
            "failed to decrypt, ECE header is not 21 bytes long",
        );

        // Layout: salt || record size (big-endian u32) || key ID length (1 byte)
        let (salt, header) = header.split_at(SALT_LEN);
        let (rs, header) = header.split_at(RS_LEN);
        let (key_id_len, header) = header.split_at(1);
        assert!(
            header.is_empty(),
            "failed to decrypt, not all ECE header bytes are used",
        );

        // The fixed-size header leaves no room for a key ID. Any key ID bytes following it
        // would otherwise silently be treated as ciphertext.
        assert_eq!(
            key_id_len[0], 0,
            "failed to decrypt, ECE headers with a key ID are not supported",
        );

        // A record must fit the tag, the padding delimiter and at least one data byte. Smaller
        // values break the record length arithmetic (see `len_encrypted` and `chunk_size`).
        let rs = BigEndian::read_u32(rs);
        assert!(
            rs > TAG_LEN as u32 + 1,
            "failed to decrypt, ECE record size is too small",
        );

        self.salt = Some(salt.to_vec());
        self.rs = rs;
        self.derive_key_and_nonce();
    }

    /// Derive the content encryption key and the base nonce from the IKM and salt using HKDF.
    #[inline(always)]
    fn derive_key_and_nonce(&mut self) {
        self.key = Some(hkdf(
            self.salt.as_deref(),
            KEY_LEN,
            &self.ikm,
            Some(KEY_INFO.as_bytes()),
        ));
        self.nonce = Some(hkdf(
            self.salt.as_deref(),
            NONCE_LEN,
            &self.ikm,
            Some(NONCE_INFO.as_bytes()),
        ));
    }

    /// Build the nonce for record `seq`.
    ///
    /// `seq` is XORed into the last four bytes of the base nonce, read as a big-endian `u32`.
    ///
    /// # Panics
    ///
    /// Panics if the base nonce has not been derived yet.
    #[inline(always)]
    fn generate_nonce(&self, seq: u32) -> Vec<u8> {
        let mut nonce = self
            .nonce
            .clone()
            .expect("failed to generate nonce, no base nonce available");
        let nonce_len = nonce.len();
        let m = BigEndian::read_u32(&nonce[nonce_len - 4..nonce_len]);
        let xor = m ^ seq;
        BigEndian::write_u32(&mut nonce[nonce_len - 4..nonce_len], xor);
        nonce
    }

    /// Whether the header has been handled.
    ///
    /// When encrypting, this becomes true once plaintext has been processed, because the header
    /// is emitted together with the first record. When decrypting, it is true once a salt is
    /// set, which normally happens by parsing the header.
    #[inline(always)]
    fn has_header(&self) -> bool {
        match self.mode {
            CryptMode::Encrypt => self.cur > 0,
            CryptMode::Decrypt => self.salt.is_some(),
        }
    }

    /// Whether all expected input has been fed through this transformer.
    #[inline(always)]
    fn is_last(&self) -> bool {
        self.is_last_with(0)
    }

    /// Whether all expected input will have been fed through once `extra` more bytes arrive.
    #[inline(always)]
    fn is_last_with(&self, extra: usize) -> bool {
        self.cur_in + extra >= self.len_in()
    }

    /// Advance the record sequence number.
    ///
    /// # Panics
    ///
    /// Panics when the sequence number would overflow `u32`.
    #[inline(always)]
    fn increase_seq(&mut self) {
        self.seq = self
            .seq
            .checked_add(1)
            .expect("failed to crypt ECE payload, record sequence number exceeds limit");
    }
}
impl Pipe for EceCrypt {
    type Reader = EceReader;
    type Writer = EceWriter;

    /// Transform `input` according to the configured mode.
    ///
    /// Returns the number of input bytes consumed and the produced output, if any.
    fn pipe(&mut self, input: &[u8]) -> (usize, Option<Vec<u8>>) {
        // Count every byte handed in; the last-record detection relies on this total
        self.cur_in += input.len();

        match self.mode {
            CryptMode::Decrypt => self.pipe_decrypt(input),
            CryptMode::Encrypt => {
                let plaintext = input.to_vec();
                self.pipe_encrypt(plaintext)
            }
        }
    }
}

/// All `Crypt` methods use their default implementations.
impl Crypt for EceCrypt {}
impl PipeLen for EceCrypt {
    /// Number of input bytes expected.
    ///
    /// This is the plaintext length when encrypting, and the full ciphertext length (header
    /// included) when decrypting.
    fn len_in(&self) -> usize {
        match self.mode {
            CryptMode::Decrypt => {
                let record_size = self.rs as usize;
                len_encrypted(self.len, record_size)
            }
            CryptMode::Encrypt => self.len,
        }
    }

    /// Number of output bytes produced.
    ///
    /// This is the full ciphertext length (header included) when encrypting, and the plaintext
    /// length when decrypting.
    fn len_out(&self) -> usize {
        match self.mode {
            CryptMode::Decrypt => self.len,
            CryptMode::Encrypt => {
                let record_size = self.rs as usize;
                len_encrypted(self.len, record_size)
            }
        }
    }
}
/// Reader that encrypts or decrypts the data pulled from an inner reader.
pub struct EceReader {
    /// Underlying reader that raw input is pulled from.
    inner: Box<dyn Read>,

    /// Transformer applied to each chunk of input.
    crypt: EceCrypt,

    /// Input read from `inner` that has not been transformed yet.
    buf_in: BytesMut,

    /// Transformed output that did not yet fit in the caller's buffer.
    buf_out: BytesMut,
}

/// Writer that encrypts or decrypts data before passing it on to an inner writer.
pub struct EceWriter {
    /// Underlying writer that transformed output is written to.
    inner: Box<dyn Write>,

    /// Transformer applied to each chunk of input.
    crypt: EceCrypt,

    /// Written input that has not been transformed yet.
    buf: BytesMut,
}
impl PipeRead<EceCrypt> for EceReader {
    /// Wrap `inner` so that everything read through it is transformed by `crypt`.
    ///
    /// The input buffer is presized to the current chunk size of `crypt`.
    fn new(crypt: EceCrypt, inner: Box<dyn Read>) -> Self {
        let buf_in = BytesMut::with_capacity(crypt.chunk_size() as usize);
        let buf_out = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
        Self {
            inner,
            crypt,
            buf_in,
            buf_out,
        }
    }
}

impl PipeWrite<EceCrypt> for EceWriter {
    /// Wrap `inner` so that everything written through it is transformed by `crypt` first.
    ///
    /// The chunk buffer is presized to the current chunk size of `crypt`.
    fn new(crypt: EceCrypt, inner: Box<dyn Write>) -> Self {
        let buf = BytesMut::with_capacity(crypt.chunk_size() as usize);
        Self { inner, crypt, buf }
    }
}
impl Read for EceReader {
    /// Read transformed (encrypted or decrypted) data into `buf`.
    ///
    /// Output left over from an earlier call is handed out first. Input is then pulled from the
    /// inner reader until a full chunk, or the final (possibly shorter) chunk, is buffered. That
    /// chunk is transformed as a whole, and whatever does not fit in `buf` is kept for the next
    /// call.
    ///
    /// `Ok(0)` is returned only when `buf` is empty, or when the inner reader is exhausted and
    /// no transformed output is pending. Short reads from the inner reader are therefore never
    /// mistaken for the end of the stream.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut total = 0;

        loop {
            // Hand out output left over from a previously transformed chunk first
            if !self.buf_out.is_empty() {
                let write = min(self.buf_out.len(), buf.len() - total);
                buf[total..total + write].copy_from_slice(&self.buf_out.split_to(write));
                total += write;
            }
            if total >= buf.len() {
                return Ok(total);
            }

            // Top up the input buffer towards a full chunk
            let chunk_size = self.crypt.chunk_size() as usize;
            if self.buf_in.len() < chunk_size {
                let mut inner_buf = vec![0u8; chunk_size - self.buf_in.len()];
                let read = self.inner.read(&mut inner_buf)?;
                if read == 0 {
                    return Ok(total);
                }
                self.buf_in.extend_from_slice(&inner_buf[..read]);

                // An incomplete chunk may only be transformed if it is the final one. Count all
                // buffered bytes, not just this read, because the final record may arrive over
                // several short reads.
                if self.buf_in.len() < chunk_size
                    && !self.crypt.is_last_with(self.buf_in.len())
                {
                    // Never report `Ok(0)` here, as that would signal a premature end of stream
                    if total > 0 {
                        return Ok(total);
                    }
                    continue;
                }
            }

            // Transform the buffered chunk, hand out what fits and keep the rest
            let (read, out) = self.crypt.crypt(&self.buf_in);
            let _ = self.buf_in.split_to(read);
            if let Some(out) = out {
                let write = min(out.len(), buf.len() - total);
                buf[total..total + write].copy_from_slice(&out[..write]);
                total += write;
                self.buf_out.extend_from_slice(&out[write..]);
            }
        }
    }
}
impl Write for EceWriter {
    /// Buffer input from `buf`, transforming and writing complete chunks to the inner writer.
    ///
    /// At most one chunk's worth of input is accepted per call, and the number of accepted
    /// bytes is returned. Once a chunk is full, it is transformed and written out. Once all
    /// expected input has been received, the remaining partial chunk is transformed and written
    /// as the final record.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let chunk_size = self.crypt.chunk_size() as usize;

        // Buffer as much input as fits in the current chunk
        let capacity = chunk_size - self.buf.len();
        let read = min(capacity, buf.len());
        if capacity > 0 {
            self.buf.extend_from_slice(&buf[..read]);
        }

        // Transform and write out a complete chunk
        if self.buf.len() >= chunk_size {
            let (read, data) = self.crypt.crypt(&self.buf.split_off(0));
            assert_eq!(read, chunk_size, "ECE crypto did not transform full chunk");
            if let Some(data) = data {
                self.inner.write_all(&data)?;
            }
        }

        // Once all expected input is here, finish the last partial chunk. If the buffer is empty
        // after output has started, the final record was already emitted above (the input length
        // is a multiple of the chunk size). Transforming the empty buffer would append a bogus
        // extra record when encrypting, or panic on empty ciphertext when decrypting.
        let pending = !self.buf.is_empty() || !self.crypt.has_header();
        if pending && self.crypt.is_last_with(self.buf.len()) {
            if let (_, Some(data)) = self.crypt.crypt(&self.buf.split_off(0)) {
                self.inner.write_all(&data)?;
            }
        }

        Ok(read)
    }

    /// Flush the inner writer. A buffered partial chunk is not transformed by this.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
impl PipeLen for EceReader {
fn len_in(&self) -> usize {
self.crypt.len_in()
}
fn len_out(&self) -> usize {
self.crypt.len_out()
}
}
impl ReadLen for EceReader {}
impl PipeLen for EceWriter {
fn len_in(&self) -> usize {
self.crypt.len_in()
}
fn len_out(&self) -> usize {
self.crypt.len_out()
}
}
impl WriteLen for EceWriter {}
// SAFETY: NOTE(review): these impls assert `Send` unconditionally, but the `inner` fields are
// `Box<dyn Read>` / `Box<dyn Write>` with no `Send` bound. A non-`Send` inner stream (for
// example one holding an `Rc`) could therefore be moved to another thread. This is only sound
// if every inner stream passed to `PipeRead::new` / `PipeWrite::new` is itself `Send`, and
// callers must uphold that. Consider requiring `Box<dyn Read + Send>` / `Box<dyn Write + Send>`
// in the pipe traits instead.
unsafe impl Send for EceReader {}
unsafe impl Send for EceWriter {}
/// Pad a plaintext record in place before encryption.
///
/// The last record only gets a `2` delimiter byte appended. Any other record gets a `1`
/// delimiter followed by zero bytes, filling it to exactly `rs - TAG_LEN` bytes. With the
/// `TAG_LEN` byte tag appended after encryption, the record is then `rs` bytes long.
///
/// # Panics
///
/// Panics if `block.len() + TAG_LEN` is not smaller than `rs`.
fn pad(block: &mut Vec<u8>, rs: usize, last: bool) {
    assert!(
        block.len() + TAG_LEN < rs,
        "failed to pad ECE ciphertext, data too large for record size"
    );

    if last {
        block.push(2);
        return;
    }

    block.push(1);
    block.resize(rs - TAG_LEN, 0);
}
/// Strip the padding from a decrypted record in place.
///
/// The padding delimiter is the last non-zero byte. It must be `2` for the last record and `1`
/// otherwise. The delimiter and any trailing zero bytes are removed.
///
/// # Panics
///
/// Panics if the record contains only zero bytes, or if the delimiter does not match `last`.
fn unpad(block: &mut Vec<u8>, last: bool) {
    let delimiter_at = block
        .iter()
        .rposition(|&byte| byte != 0)
        .unwrap_or_else(|| panic!("ciphertext is zero"));

    let expected = if last { 2 } else { 1 };
    assert_eq!(block[delimiter_at], expected, "ECE decrypt unpadding failure");

    block.truncate(delimiter_at);
}
/// Generate a fresh random salt of `SALT_LEN` bytes for a new ECE header.
///
/// # Panics
///
/// Panics if `rand_bytes` fails.
pub fn generate_salt() -> Vec<u8> {
    let mut salt = [0u8; SALT_LEN].to_vec();
    rand_bytes(&mut salt).expect("failed to generate encryption salt");
    salt
}
/// Compute the total ECE ciphertext length for a plaintext payload of `len` bytes.
///
/// The payload is split into records of `rs - TAG_LEN - 1` data bytes each. Every record adds
/// a one-byte padding delimiter and a `TAG_LEN`-byte authentication tag, and the whole
/// ciphertext is prefixed with a `HEADER_LEN`-byte header.
///
/// # Panics
///
/// Panics if `rs` is not larger than `TAG_LEN + 1`, because then no data fits in a record.
pub fn len_encrypted(len: usize, rs: usize) -> usize {
    let chunk_meta = TAG_LEN + 1;
    let chunk_data = rs
        .checked_sub(chunk_meta)
        .filter(|&data| data > 0)
        .expect("failed to compute ECE length, record size too small");
    let header = HEADER_LEN as usize;

    // Integer ceiling division. Unlike rounding through `f64`, this is exact for every length.
    let chunks = len / chunk_data + usize::from(len % chunk_data != 0);

    header + len + chunk_meta * chunks
}