#[cfg(test)]
mod tests;
use openssl::cipher as ossl_cipher;
use openssl::cipher_ctx::CipherCtx;
use openssl::error::ErrorStack;
use std::str::FromStr;
use std::{cmp, fmt};
use thiserror::Error;
use crate::buffer::{Buffer, BufferError, BufferMut};
use crate::svec::SecureVec;
/// Errors returned by the cipher operations in this module.
#[derive(Debug, Error)]
pub enum CipherError {
    /// The supplied key is shorter than the cipher's required key length.
    #[error("invalid key")]
    InvalidKey,
    /// The supplied IV is shorter than the cipher's required IV length.
    #[error("invalid iv")]
    InvalidIv,
    /// The data length is not a multiple of the cipher's block size.
    #[error("invalid block-size")]
    InvalidBlockSize,
    /// Authenticated decryption failed; the decrypted data must not be used.
    #[error("the plaintext is not trustworthy")]
    NotTrustworthy,
    /// An error reported by OpenSSL.
    #[error(transparent)]
    OpenSSL(#[from] ErrorStack),
}
/// Symmetric ciphers supported by this module.
///
/// `Display`/`FromStr` use textual names such as `"aes128-ctr"`; the buffer
/// encoding uses a `u32` index that does not follow declaration order.
///
/// `Eq` and `Hash` are derived so the type can serve as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Cipher {
    /// No encryption: data is passed through unchanged.
    None,
    /// AES-128 in counter mode (no authentication tag).
    Aes128Ctr,
    /// AES-192 in counter mode (no authentication tag).
    Aes192Ctr,
    /// AES-256 in counter mode (no authentication tag).
    Aes256Ctr,
    /// AES-128 in GCM mode (16-byte authentication tag).
    Aes128Gcm,
    /// AES-192 in GCM mode (16-byte authentication tag).
    Aes192Gcm,
    /// AES-256 in GCM mode (16-byte authentication tag).
    Aes256Gcm,
}
impl Cipher {
    /// Block size in bytes as reported by OpenSSL; `1` for [`Cipher::None`].
    pub fn block_size(&self) -> usize {
        self.to_openssl().map_or(1, |c| c.block_size())
    }

    /// Key length in bytes required by the cipher; `0` for [`Cipher::None`].
    pub fn key_len(&self) -> usize {
        self.to_openssl().map_or(0, |c| c.key_length())
    }

    /// IV length in bytes required by the cipher; `0` for [`Cipher::None`].
    pub fn iv_len(&self) -> usize {
        self.to_openssl().map_or(0, |c| c.iv_length())
    }

    /// Length in bytes of the authentication tag: 16 for the GCM variants,
    /// 0 for everything else.
    pub fn tag_size(&self) -> u32 {
        match *self {
            Cipher::Aes128Gcm | Cipher::Aes192Gcm | Cipher::Aes256Gcm => 16,
            Cipher::None | Cipher::Aes128Ctr | Cipher::Aes192Ctr | Cipher::Aes256Ctr => 0,
        }
    }

    /// Reads a cipher encoded as a `u32` index from `buf`.
    ///
    /// The numbering is not declaration order; it must stay in sync with
    /// `put_into_buffer`.
    ///
    /// # Errors
    ///
    /// Propagates any read error from `buf`, and returns
    /// `BufferError::InvalidIndex` for an unknown index.
    pub(crate) fn get_from_buffer<T: Buffer>(buf: &mut T) -> Result<Cipher, BufferError> {
        let index = buf.get_u32()?;
        let cipher = match index {
            0 => Cipher::None,
            1 => Cipher::Aes128Ctr,
            2 => Cipher::Aes128Gcm,
            3 => Cipher::Aes192Ctr,
            4 => Cipher::Aes256Ctr,
            5 => Cipher::Aes192Gcm,
            6 => Cipher::Aes256Gcm,
            unknown => return Err(BufferError::InvalidIndex("Cipher".to_string(), unknown)),
        };
        Ok(cipher)
    }

    /// Writes `self` to `buf` as its `u32` index (the inverse of `get_from_buffer`).
    pub(crate) fn put_into_buffer<T: BufferMut>(&self, buf: &mut T) -> Result<(), BufferError> {
        buf.put_u32(match *self {
            Cipher::None => 0,
            Cipher::Aes128Ctr => 1,
            Cipher::Aes128Gcm => 2,
            Cipher::Aes192Ctr => 3,
            Cipher::Aes256Ctr => 4,
            Cipher::Aes192Gcm => 5,
            Cipher::Aes256Gcm => 6,
        })
    }

    /// Maps to the OpenSSL cipher descriptor, or `None` for [`Cipher::None`].
    fn to_openssl(self) -> Option<&'static ossl_cipher::CipherRef> {
        let descriptor = match self {
            Cipher::None => return None,
            Cipher::Aes128Ctr => ossl_cipher::Cipher::aes_128_ctr(),
            Cipher::Aes192Ctr => ossl_cipher::Cipher::aes_192_ctr(),
            Cipher::Aes256Ctr => ossl_cipher::Cipher::aes_256_ctr(),
            Cipher::Aes128Gcm => ossl_cipher::Cipher::aes_128_gcm(),
            Cipher::Aes192Gcm => ossl_cipher::Cipher::aes_192_gcm(),
            Cipher::Aes256Gcm => ossl_cipher::Cipher::aes_256_gcm(),
        };
        Some(descriptor)
    }
}
/// Formats the cipher as its textual name (e.g. `aes128-gcm`), the same names
/// accepted by the `FromStr` implementation.
impl fmt::Display for Cipher {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match *self {
            Cipher::None => "none",
            Cipher::Aes128Ctr => "aes128-ctr",
            Cipher::Aes192Ctr => "aes192-ctr",
            Cipher::Aes256Ctr => "aes256-ctr",
            Cipher::Aes128Gcm => "aes128-gcm",
            Cipher::Aes192Gcm => "aes192-gcm",
            Cipher::Aes256Gcm => "aes256-gcm",
        })
    }
}
/// Parses a cipher from its exact, case-sensitive textual name (the same
/// names produced by `Display`). Any other string yields `Err(())`.
impl FromStr for Cipher {
    type Err = ();
    fn from_str(s: &str) -> Result<Self, ()> {
        let cipher = match s {
            "none" => Cipher::None,
            "aes128-ctr" => Cipher::Aes128Ctr,
            "aes192-ctr" => Cipher::Aes192Ctr,
            "aes256-ctr" => Cipher::Aes256Ctr,
            "aes128-gcm" => Cipher::Aes128Gcm,
            "aes192-gcm" => Cipher::Aes192Gcm,
            "aes256-gcm" => Cipher::Aes256Gcm,
            _ => return Err(()),
        };
        Ok(cipher)
    }
}
/// A cipher choice together with reusable input and output buffers.
///
/// Callers fill the input via `copy_from_slice` or `inp_mut`, then call
/// `encrypt` or `decrypt`, which place the result in the output buffer.
// NOTE(review): `Debug` is derived over `inp`/`outp`, which hold plaintext or
// ciphertext; confirm `SecureVec`'s `Debug` impl redacts its contents.
#[derive(Debug)]
pub(super) struct CipherContext {
    // Cipher used by `encrypt`/`decrypt`.
    cipher: Cipher,
    // Input: plaintext to encrypt, or ciphertext (followed by the tag for GCM) to decrypt.
    inp: SecureVec,
    // Output of the most recent `encrypt`/`decrypt` call.
    outp: SecureVec,
}
impl CipherContext {
    /// Creates a context for `cipher` with empty input and output buffers.
    pub(super) fn new(cipher: Cipher) -> CipherContext {
        CipherContext {
            cipher,
            inp: vec![].into(),
            outp: vec![].into(),
        }
    }

    /// Loads `buf` into the input buffer, which is resized to exactly `buf_size` bytes.
    ///
    /// At most `buf_size` bytes are copied from `buf`; any remaining space is
    /// zero-filled. Returns the number of bytes copied.
    pub fn copy_from_slice(&mut self, buf_size: usize, buf: &[u8]) -> usize {
        let len = cmp::min(buf_size, buf.len());
        self.inp.resize(buf_size, 0);
        self.inp[..len].copy_from_slice(&buf[..len]);
        // `resize` keeps existing bytes, so clear the tail explicitly to avoid
        // leaking data from an earlier, larger input.
        self.inp[len..].iter_mut().for_each(|n| *n = 0);
        len
    }

    /// Returns the input buffer resized to `buf_size` bytes and fully zeroed, so
    /// the caller can fill it in place before calling `encrypt` or `decrypt`.
    pub fn inp_mut(&mut self, buf_size: usize) -> &mut [u8] {
        self.copy_from_slice(buf_size, &[]);
        &mut self.inp
    }

    /// Encrypts the input buffer and returns the result.
    ///
    /// For GCM ciphers the output is the ciphertext followed by a
    /// `tag_size()`-byte authentication tag. For `Cipher::None` the input is
    /// copied unchanged. An empty input yields an empty output. Only the first
    /// `key_len()` / `iv_len()` bytes of `key` / `iv` are used.
    ///
    /// # Errors
    ///
    /// - `InvalidKey` / `InvalidIv` if `key` / `iv` are too short.
    /// - `InvalidBlockSize` for a misaligned input.
    /// - `OpenSSL` for library failures.
    pub fn encrypt(&mut self, key: &[u8], iv: &[u8]) -> Result<&[u8], CipherError> {
        match self.cipher {
            Cipher::None => self.make_none(),
            _ => {
                self.encrypt_aad(None, key, iv)?;
            }
        }
        Ok(self.outp.as_slice())
    }

    /// Encrypts `inp` into `outp`, authenticating `aad` as associated data when given.
    /// Returns the ciphertext length (excluding any tag).
    fn encrypt_aad(
        &mut self,
        aad: Option<&[u8]>,
        key: &[u8],
        iv: &[u8],
    ) -> Result<usize, CipherError> {
        let (key, iv) = self.checked_key_iv(key, iv)?;
        let ptext_len = self.inp.len();
        // The supported modes (CTR, GCM) do not expand the data.
        let ctext_len = ptext_len;
        if ptext_len == 0 {
            // Reset the output so `encrypt` cannot return a previous call's result.
            self.outp.clear();
            return Ok(0);
        }
        if ptext_len % self.cipher.block_size() != 0 {
            return Err(CipherError::InvalidBlockSize);
        }
        let tag_len = self.cipher.tag_size() as usize;
        let mut ctx = CipherCtx::new()?;
        ctx.encrypt_init(self.cipher.to_openssl(), Some(key), Some(iv))?;
        ctx.set_padding(false);
        if let Some(buf) = aad {
            // With no output buffer, OpenSSL treats this update as associated data.
            ctx.cipher_update(buf, None)?;
        }
        self.outp.resize(ctext_len + tag_len, 0);
        ctx.cipher_update(&self.inp[..ptext_len], Some(&mut self.outp[..ctext_len]))?;
        if tag_len > 0 {
            ctx.cipher_final(&mut [])?;
            ctx.tag(&mut self.outp[ctext_len..])?;
        }
        Ok(ctext_len)
    }

    /// Decrypts the input buffer and returns the plaintext.
    ///
    /// For GCM ciphers the input must be the ciphertext followed by a
    /// `tag_size()`-byte tag, which is verified before success is reported.
    /// For `Cipher::None` the input is copied unchanged. An empty input yields
    /// an empty output.
    ///
    /// # Errors
    ///
    /// - `InvalidKey` / `InvalidIv` if `key` / `iv` are too short.
    /// - `InvalidBlockSize` for a misaligned ciphertext.
    /// - `NotTrustworthy` if the input is too short to hold a tag or the tag
    ///   does not verify.
    /// - `OpenSSL` for library failures.
    pub fn decrypt(&mut self, key: &[u8], iv: &[u8]) -> Result<&[u8], CipherError> {
        match self.cipher {
            Cipher::None => self.make_none(),
            _ => {
                self.decrypt_aad(None, key, iv)?;
            }
        }
        Ok(self.outp.as_slice())
    }

    /// Decrypts `inp` into `outp`, authenticating `aad` as associated data when given.
    /// Returns the plaintext length.
    fn decrypt_aad(
        &mut self,
        aad: Option<&[u8]>,
        key: &[u8],
        iv: &[u8],
    ) -> Result<usize, CipherError> {
        let (key, iv) = self.checked_key_iv(key, iv)?;
        if self.inp.is_empty() {
            // Mirrors `encrypt_aad`, which turns an empty input into an empty output.
            self.outp.clear();
            return Ok(0);
        }
        let tag_len = self.cipher.tag_size() as usize;
        // Non-empty input too short to hold the tag can never be authentic; it
        // must not be reported as a successful (empty) decryption.
        let ctext_bytes = self
            .inp
            .len()
            .checked_sub(tag_len)
            .ok_or(CipherError::NotTrustworthy)?;
        let ptext_bytes = ctext_bytes;
        if ctext_bytes % self.cipher.block_size() != 0 {
            return Err(CipherError::InvalidBlockSize);
        }
        let mut ctx = CipherCtx::new()?;
        ctx.decrypt_init(self.cipher.to_openssl(), Some(key), Some(iv))?;
        ctx.set_padding(false);
        if let Some(buf) = aad {
            ctx.cipher_update(buf, None)?;
        }
        self.outp.resize(ptext_bytes, 0);
        if ctext_bytes > 0 {
            ctx.cipher_update(
                &self.inp[..ctext_bytes],
                Some(&mut self.outp[..ptext_bytes]),
            )?;
        }
        if tag_len > 0 {
            // Verify the tag even when the ciphertext is empty.
            ctx.set_tag(&self.inp[ctext_bytes..])?;
            if ctx.cipher_final(&mut []).is_err() {
                // Do not keep unauthenticated plaintext around in `outp`.
                self.wipe_output();
                return Err(CipherError::NotTrustworthy);
            }
        }
        Ok(ctext_bytes)
    }

    /// Returns the prefixes of `key` and `iv` with exactly the lengths the
    /// configured cipher needs, or an error if either is too short.
    fn checked_key_iv<'a>(
        &self,
        key: &'a [u8],
        iv: &'a [u8],
    ) -> Result<(&'a [u8], &'a [u8]), CipherError> {
        let key = key
            .get(..self.cipher.key_len())
            .ok_or(CipherError::InvalidKey)?;
        let iv = iv
            .get(..self.cipher.iv_len())
            .ok_or(CipherError::InvalidIv)?;
        Ok((key, iv))
    }

    /// Zeroes and empties the output buffer.
    fn wipe_output(&mut self) {
        self.outp[0..].iter_mut().for_each(|b| *b = 0);
        self.outp.clear();
    }

    /// Pass-through for `Cipher::None`: the output becomes a copy of the input.
    fn make_none(&mut self) {
        self.outp.clear();
        self.outp.extend_from_slice(&self.inp);
    }
}