use aes::cipher::{KeyIvInit, StreamCipher};
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
use byteorder::{BigEndian, ByteOrder};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::io::Write;
use thiserror::Error;
use time::{Duration, OffsetDateTime};
/// HMAC-SHA256, used to compute the truncated integrity signature of a package.
type HmacSha256 = Hmac<Sha256>;
/// AES-256 in CTR mode (the `ctr` crate's `Ctr64BE` flavor), used to encrypt the payload.
type Aes256Ctr64BE = ctr::Ctr64BE<aes::Aes256>;
/// The Unix epoch; reference point for the microsecond timestamp stored in the IV.
const UNIX_EPOCH: OffsetDateTime = time::OffsetDateTime::UNIX_EPOCH;
/// Raw key material for the package codec: a 256-bit AES key and a 256-bit HMAC key.
///
/// `Debug` is implemented by hand so that formatting a `Keys` value (for example in a
/// log line, an `unwrap` panic, or a `{:?}` of a containing struct) never prints the
/// secret key bytes.
#[derive(Clone)]
pub struct Keys {
    /// 32-byte AES-256 key used to generate the CTR keystream for the payload.
    pub encryption_key: [u8; 32],
    /// 32-byte HMAC-SHA256 key used to compute the 4-byte integrity signature.
    pub integrity_key: [u8; 32],
}

impl std::fmt::Debug for Keys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Field names stay visible for diagnostics; the key bytes never do.
        f.debug_struct("Keys")
            .field("encryption_key", &"<redacted>")
            .field("integrity_key", &"<redacted>")
            .finish()
    }
}
impl Keys {
    /// Builds a key pair from two byte slices.
    ///
    /// # Errors
    /// Returns [`CryptoError::InvalidKey`] unless both `encryption_key` and
    /// `integrity_key` are exactly 32 bytes long. The encryption key is checked first.
    pub fn new(encryption_key: &[u8], integrity_key: &[u8]) -> Result<Self, CryptoError> {
        // Shared length check + copy into a fixed-size array.
        let to_key =
            |bytes: &[u8]| <[u8; 32]>::try_from(bytes).map_err(|_| CryptoError::InvalidKey);
        Ok(Self {
            encryption_key: to_key(encryption_key)?,
            integrity_key: to_key(integrity_key)?,
        })
    }
}
/// Errors produced when building [`Keys`] and by the `Crypto` package codec
/// (base64 decoding, encryption, decryption, and writing output).
#[derive(Error, Debug)]
pub enum CryptoError {
/// A key slice was not exactly 32 bytes, or the HMAC could not be initialized.
#[error("invalid key")]
InvalidKey,
/// The signature recomputed during decryption did not match the stored signature.
#[error("invalid signature")]
InvalidSign,
/// The 16-byte initialization vector could not be taken from the input.
#[error("invalid init vector")]
InvalidInitVector,
/// The input is shorter than the fixed IV + signature overhead (`Crypto::OVERHEAD_SIZE`).
#[error("data too short")]
DataTooShort,
/// The payload length does not match the space reserved in the plain-data buffer.
#[error("payload size mismatch")]
PayloadSizeMismatch,
/// The input was not valid base64 (the codec decodes with the URL-safe alphabet).
#[error("decode error: {0}")]
DecodeError(#[from] base64::DecodeError),
/// An I/O error, e.g. while writing to the output writer of `package_to`/`unpackage_to`.
#[error("io error: {0}")]
IoError(#[from] std::io::Error),
}
/// Packs, signs, and encrypts payloads with the key pair in [`Keys`].
///
/// Raw package layout (before URL-safe base64 encoding):
/// a 16-byte IV (as built by `create_init_vector`: big-endian `i64` microseconds
/// since the Unix epoch, then a big-endian `i64` server id), the AES-256-CTR
/// encrypted payload, and a trailing 4-byte signature. The signature is the first
/// four bytes of HMAC-SHA256 over the plaintext payload followed by the IV.
pub struct Crypto {
/// Encryption and integrity keys used by every operation.
pub keys: Keys,
}
impl Crypto {
    /// Creates a codec that encrypts and signs with `keys`.
    pub fn new(keys: Keys) -> Self {
        Self { keys }
    }

    /// Byte offset of the initialization vector inside a package.
    pub const IV_BASE: usize = 0;
    /// Length of the initialization vector in bytes.
    pub const IV_SIZE: usize = 16;
    /// Offset, relative to the IV, of the big-endian timestamp
    /// (microseconds since the Unix epoch).
    pub const IV_TIME_OFFSET: usize = 0;
    /// Length of the IV timestamp field in bytes.
    pub const IV_TIME_SIZE: usize = 8;
    /// Offset, relative to the IV, of the big-endian server id.
    pub const IV_SERVER_ID_OFFSET: usize = 8;
    /// Length of the IV server-id field in bytes.
    pub const IV_SERVER_ID_SIZE: usize = 8;
    /// Length of the trailing truncated HMAC signature in bytes.
    pub const SIGNATURE_SIZE: usize = 4;
    /// Byte offset at which the payload starts (immediately after the IV).
    pub const PAYLOAD_BASE: usize = Crypto::IV_BASE + Crypto::IV_SIZE;
    /// Bytes a package adds around the payload: the IV plus the signature.
    pub const OVERHEAD_SIZE: usize = Crypto::IV_SIZE + Crypto::SIGNATURE_SIZE;

    /// Decodes URL-safe base64 text into raw bytes.
    ///
    /// # Errors
    /// [`CryptoError::DecodeError`] if `data` is not valid URL-safe base64.
    #[inline]
    pub fn decode<T>(&self, data: T) -> Result<Vec<u8>, CryptoError>
    where
        T: AsRef<[u8]>,
    {
        // The engine already returns an owned Vec; no extra copy is needed.
        Ok(URL_SAFE.decode(data)?)
    }

    /// Encodes raw bytes as URL-safe base64 text.
    #[inline]
    pub fn encode<T>(&self, data: T) -> String
    where
        T: AsRef<[u8]>,
    {
        URL_SAFE.encode(data)
    }

    /// Decrypts a raw (already base64-decoded) package and verifies its signature.
    ///
    /// On success returns the whole package (IV, plaintext payload, signature).
    ///
    /// # Errors
    /// * [`CryptoError::DataTooShort`] if `cipher_data` is shorter than [`Self::OVERHEAD_SIZE`].
    /// * [`CryptoError::InvalidSign`] if the stored signature does not match.
    #[inline]
    pub fn decrypt(&self, cipher_data: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if cipher_data.len() < Self::OVERHEAD_SIZE {
            return Err(CryptoError::DataTooShort);
        }
        let mut data = cipher_data.to_vec();
        // The signature covers the plaintext, so the payload is decrypted first.
        self.xor_payload(&mut data)?;
        let signature_offset = data.len() - Self::SIGNATURE_SIZE;
        let expected = self.hmac_signature(&data)?;
        let received = self.read_i32(&data, signature_offset);
        if expected != received {
            return Err(CryptoError::InvalidSign);
        }
        Ok(data)
    }

    /// Signs and encrypts a plain-data buffer laid out as IV || payload || signature slot.
    ///
    /// Whatever the signature slot contains on input is ignored and overwritten.
    ///
    /// # Errors
    /// [`CryptoError::DataTooShort`] if `plain_data` is shorter than [`Self::OVERHEAD_SIZE`].
    #[inline]
    pub fn encrypt(&self, plain_data: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if plain_data.len() < Self::OVERHEAD_SIZE {
            return Err(CryptoError::DataTooShort);
        }
        let mut data = plain_data.to_vec();
        let signature_offset = data.len() - Self::SIGNATURE_SIZE;
        // Sign the plaintext (payload || IV) before encrypting the payload in place.
        let signature = self.hmac_signature(&data)?;
        self.write_i32(&mut data, signature_offset, signature);
        self.xor_payload(&mut data)?;
        Ok(data)
    }

    /// Builds, signs, encrypts, and base64-encodes a package around `payload`.
    ///
    /// When `iv` is `None` a fresh IV is generated from the current UTC time and a
    /// random server id.
    ///
    /// # Errors
    /// [`CryptoError::InvalidInitVector`] if `iv` is given but is not [`Self::IV_SIZE`] bytes.
    #[inline]
    pub fn package<T>(&self, payload: T, iv: Option<&[u8]>) -> Result<String, CryptoError>
    where
        T: AsRef<[u8]>,
    {
        let mut out = Vec::new();
        self.package_to(payload, iv, &mut out)?;
        Ok(String::from_utf8(out).expect("base64 output is valid UTF-8"))
    }

    /// Base64-decodes, decrypts, and verifies `data`, returning only the payload.
    ///
    /// # Errors
    /// [`CryptoError::DecodeError`], [`CryptoError::DataTooShort`], or
    /// [`CryptoError::InvalidSign`].
    #[inline]
    pub fn unpackage<T>(&self, data: T) -> Result<Vec<u8>, CryptoError>
    where
        T: AsRef<[u8]>,
    {
        let mut out = Vec::new();
        self.unpackage_to(data, &mut out)?;
        Ok(out)
    }

    /// Like [`Self::package`], but writes the base64 text to `out`.
    ///
    /// # Errors
    /// Same as [`Self::package`], plus [`CryptoError::IoError`] if writing fails.
    #[inline]
    pub fn package_to<T, W>(
        &self,
        payload: T,
        iv: Option<&[u8]>,
        out: &mut W,
    ) -> Result<(), CryptoError>
    where
        T: AsRef<[u8]>,
        W: Write,
    {
        let payload = payload.as_ref();
        let mut pkg = self.init_plain_data(payload.len(), iv)?;
        self.set_payload(&mut pkg, payload)?;
        let encrypted = self.encrypt(&pkg)?;
        out.write_all(URL_SAFE.encode(&encrypted).as_bytes())?;
        Ok(())
    }

    /// Like [`Self::unpackage`], but writes the recovered payload to `out`.
    ///
    /// # Errors
    /// Same as [`Self::unpackage`], plus [`CryptoError::IoError`] if writing fails.
    #[inline]
    pub fn unpackage_to<T, W>(&self, data: T, out: &mut W) -> Result<(), CryptoError>
    where
        T: AsRef<[u8]>,
        W: Write,
    {
        let decoded = self.decode(data)?;
        let decrypted = self.decrypt(&decoded)?;
        let payload = self.payload(&decrypted).ok_or(CryptoError::DataTooShort)?;
        out.write_all(payload)?;
        Ok(())
    }

    /// Builds a 16-byte IV holding `timestamp` (big-endian microseconds since the
    /// Unix epoch) followed by `server_id` (big-endian).
    #[inline]
    pub fn create_init_vector(&self, timestamp: OffsetDateTime, server_id: i64) -> Vec<u8> {
        let mut iv = vec![0; Self::IV_SIZE];
        self.write_i64(&mut iv, Self::IV_TIME_OFFSET, Self::unix_micros(timestamp));
        self.write_i64(&mut iv, Self::IV_SERVER_ID_OFFSET, server_id);
        iv
    }

    /// Reads the IV timestamp from the start of `data` (a package or a bare IV).
    ///
    /// Returns `None` if `data` is shorter than the IV. If the stored value is out of
    /// range for `OffsetDateTime`, the Unix epoch is returned instead.
    #[inline]
    pub fn timestamp(&self, data: &[u8]) -> Option<OffsetDateTime> {
        if data.len() < Self::IV_SIZE {
            return None;
        }
        let ts = self.read_i64(data, Self::IV_BASE + Self::IV_TIME_OFFSET);
        Some(
            UNIX_EPOCH
                .checked_add(Duration::microseconds(ts))
                .unwrap_or(UNIX_EPOCH),
        )
    }

    /// Reads the IV server id from `data`, or `None` if `data` is shorter than the IV.
    #[inline]
    pub fn server_id(&self, data: &[u8]) -> Option<i64> {
        if data.len() < Self::IV_SIZE {
            return None;
        }
        Some(self.read_i64(data, Self::IV_BASE + Self::IV_SERVER_ID_OFFSET))
    }

    /// Returns the bytes between the IV and the signature, without decrypting them.
    ///
    /// Returns `None` if `data` is shorter than [`Self::OVERHEAD_SIZE`].
    #[inline]
    pub fn payload<'a>(&self, data: &'a [u8]) -> Option<&'a [u8]> {
        if data.len() < Self::OVERHEAD_SIZE {
            return None;
        }
        Some(&data[Self::PAYLOAD_BASE..data.len() - Self::SIGNATURE_SIZE])
    }

    /// Allocates a zeroed plain-data buffer of `OVERHEAD_SIZE + payload_size` bytes
    /// and writes the IV into it.
    ///
    /// Uses `iv` when given. Otherwise it builds one with [`Self::create_init_vector`]
    /// from the current UTC time and a random server id.
    ///
    /// # Errors
    /// [`CryptoError::InvalidInitVector`] if `iv` is not exactly [`Self::IV_SIZE`] bytes.
    #[inline]
    pub fn init_plain_data(
        &self,
        payload_size: usize,
        iv: Option<&[u8]>,
    ) -> Result<Vec<u8>, CryptoError> {
        let generated: Vec<u8>;
        let iv: &[u8] = match iv {
            Some(iv) => iv,
            None => {
                generated =
                    self.create_init_vector(OffsetDateTime::now_utc(), rand::random::<i64>());
                generated.as_slice()
            }
        };
        // Reject a malformed caller-supplied IV instead of panicking in copy_from_slice.
        if iv.len() != Self::IV_SIZE {
            return Err(CryptoError::InvalidInitVector);
        }
        let mut plain_data = vec![0; Self::OVERHEAD_SIZE + payload_size];
        plain_data[Self::IV_BASE..Self::IV_BASE + Self::IV_SIZE].copy_from_slice(iv);
        Ok(plain_data)
    }

    /// Copies `payload` into the payload region of a buffer from [`Self::init_plain_data`].
    ///
    /// # Errors
    /// [`CryptoError::PayloadSizeMismatch`] if `payload` does not exactly fill the space
    /// between the IV and the signature slot, including when `plain_data` is shorter
    /// than the overhead.
    #[inline]
    pub fn set_payload(&self, plain_data: &mut [u8], payload: &[u8]) -> Result<(), CryptoError> {
        // checked_sub avoids an underflow panic for buffers shorter than the overhead.
        if plain_data.len().checked_sub(Self::OVERHEAD_SIZE) != Some(payload.len()) {
            return Err(CryptoError::PayloadSizeMismatch);
        }
        plain_data[Self::PAYLOAD_BASE..Self::PAYLOAD_BASE + payload.len()].copy_from_slice(payload);
        Ok(())
    }

    /// Reads a big-endian `i32` at `offset`. Panics if fewer than 4 bytes remain.
    #[inline]
    fn read_i32(&self, data: &[u8], offset: usize) -> i32 {
        BigEndian::read_i32(&data[offset..offset + 4])
    }

    /// Reads a big-endian `i64` at `offset`. Panics if fewer than 8 bytes remain.
    #[inline]
    fn read_i64(&self, data: &[u8], offset: usize) -> i64 {
        BigEndian::read_i64(&data[offset..offset + 8])
    }

    /// Writes `value` as a big-endian `i32` at `offset`. Panics if fewer than 4 bytes remain.
    #[inline]
    fn write_i32(&self, data: &mut [u8], offset: usize, value: i32) {
        BigEndian::write_i32(&mut data[offset..offset + 4], value);
    }

    /// Writes `value` as a big-endian `i64` at `offset`. Panics if fewer than 8 bytes remain.
    #[inline]
    fn write_i64(&self, data: &mut [u8], offset: usize, value: i64) {
        BigEndian::write_i64(&mut data[offset..offset + 8], value);
    }

    /// Converts `timestamp` to whole microseconds since the Unix epoch. It saturates at
    /// the `i64` bounds instead of silently truncating the `i128` nanosecond count.
    fn unix_micros(timestamp: OffsetDateTime) -> i64 {
        let micros = timestamp.unix_timestamp_nanos() / 1_000;
        i64::try_from(micros).unwrap_or(if micros < 0 { i64::MIN } else { i64::MAX })
    }

    /// XORs the payload region with the AES-256-CTR keystream derived from the
    /// package's own IV. The same call both encrypts and decrypts.
    #[inline]
    fn xor_payload(&self, data: &mut [u8]) -> Result<(), CryptoError> {
        if data.len() < Self::OVERHEAD_SIZE {
            return Err(CryptoError::DataTooShort);
        }
        let iv: &[u8; 16] = &data[Self::IV_BASE..Self::IV_BASE + Self::IV_SIZE]
            .try_into()
            .map_err(|_| CryptoError::InvalidInitVector)?;
        let mut cipher = Aes256Ctr64BE::new(&self.keys.encryption_key.into(), iv.into());
        let payload_end = data.len() - Self::SIGNATURE_SIZE;
        cipher.apply_keystream(&mut data[Self::PAYLOAD_BASE..payload_end]);
        Ok(())
    }

    /// Computes HMAC-SHA256 over `payload || IV` and returns the first four digest
    /// bytes as a big-endian `i32`. The signature slot itself is not hashed.
    #[inline]
    fn hmac_signature(&self, data: &[u8]) -> Result<i32, CryptoError> {
        if data.len() < Self::OVERHEAD_SIZE {
            return Err(CryptoError::DataTooShort);
        }
        let mut mac = HmacSha256::new_from_slice(&self.keys.integrity_key)
            .map_err(|_| CryptoError::InvalidKey)?;
        mac.update(&data[Self::PAYLOAD_BASE..data.len() - Self::SIGNATURE_SIZE]);
        mac.update(&data[Self::IV_BASE..Self::IV_BASE + Self::IV_SIZE]);
        let digest = mac.finalize().into_bytes();
        Ok(self.read_i32(&digest, 0))
    }
}
// Unit tests for `Keys` and `Crypto`: base64 helpers, encrypt/decrypt, IV layout,
// and the `package*` / `unpackage*` entry points.
#[cfg(test)]
mod tests {
use super::*;
use base64::prelude::*;
// Fixed test keys in standard base64; each decodes to the 32 bytes `Keys::new` requires.
static TEST_ENCRYPTION_KEY: &str = "sIxwz7yw62yrfoLGt12lIHKuYrK/S5kLuApI2BQe7Ac=";
static TEST_INTEGRITY_KEY: &str = "v3fsVcMBMMHYzRhi7SpM0sdqwzvAxM6KPTu9OtVod5I=";
// Builds `Keys` from the fixed test vectors above; panics if they are malformed.
fn create_keys() -> Keys {
Keys::new(
&BASE64_STANDARD.decode(TEST_ENCRYPTION_KEY).unwrap(),
&BASE64_STANDARD.decode(TEST_INTEGRITY_KEY).unwrap(),
)
.unwrap()
}
// `decode` turns known base64 text back into its bytes.
#[test]
fn test_decode() {
let crypto = Crypto::new(create_keys());
let encoded = "aGVsbG8sIHdvcmxk";
let decoded = crypto.decode(encoded).unwrap();
assert_eq!(decoded, b"hello, world");
}
// `encode` produces the expected base64 text for known bytes.
#[test]
fn test_encode() {
let crypto = Crypto::new(create_keys());
let data = b"hello, world";
let encoded = crypto.encode(data);
assert_eq!(encoded, "aGVsbG8sIHdvcmxk");
}
// Full encrypt/decrypt cycle with a fixed IV. The IV fields survive, the payload is
// not readable while encrypted, and a corrupted signature is rejected with InvalidSign.
#[test]
fn test_decrypt() {
let crypto = Crypto::new(create_keys());
let timestamp = OffsetDateTime::UNIX_EPOCH + Duration::seconds(1);
let iv = crypto.create_init_vector(timestamp, 123456789);
let payload = "https://example.com".as_bytes();
let mut plain_data = crypto.init_plain_data(payload.len(), Some(&iv)).unwrap();
crypto.set_payload(&mut plain_data, payload).unwrap();
let encrypted_data = crypto.encrypt(&plain_data).unwrap();
assert_eq!(crypto.timestamp(&iv), Some(timestamp));
assert_eq!(crypto.server_id(&iv), Some(123456789));
assert_eq!(
crypto.payload(&encrypted_data).unwrap().len(),
payload.len()
);
// Encryption must change the payload bytes.
assert_ne!(crypto.payload(&encrypted_data), Some(payload));
let decrypted_data = crypto.decrypt(&encrypted_data).unwrap();
assert_eq!(crypto.timestamp(&decrypted_data), Some(timestamp));
assert_eq!(crypto.server_id(&decrypted_data), Some(123456789));
assert_eq!(crypto.payload(&decrypted_data), Some(payload));
// Overwrite the signature with an arbitrary value; decryption must then fail.
let mut encrypted_data_invalid_sign = encrypted_data.clone();
crypto.write_i32(
&mut encrypted_data_invalid_sign,
encrypted_data.len() - Crypto::SIGNATURE_SIZE,
123456789,
);
assert!(matches!(
crypto.decrypt(&encrypted_data_invalid_sign),
Err(CryptoError::InvalidSign)
));
assert_ne!(crypto.payload(&encrypted_data_invalid_sign), Some(payload))
}
// IV layout: the timestamp is stored as microseconds (1 s -> 1_000_000) and the
// server id is stored verbatim; both read back through the public accessors.
#[test]
fn test_create_init_vector() {
let crypto = Crypto::new(create_keys());
let timestamp = OffsetDateTime::UNIX_EPOCH + Duration::seconds(1);
let iv = crypto.create_init_vector(timestamp, 123456789);
assert_eq!(iv.len(), Crypto::IV_SIZE);
assert_eq!(crypto.read_i64(&iv, Crypto::IV_TIME_OFFSET), 1_000_000);
assert_eq!(crypto.read_i64(&iv, Crypto::IV_SERVER_ID_OFFSET), 123456789);
assert_eq!(crypto.timestamp(&iv), Some(timestamp));
assert_eq!(crypto.server_id(&iv), Some(123456789));
}
// A buffer with a generated IV has overhead + payload length and holds the payload.
#[test]
fn test_init_plain_data() {
let crypto = Crypto::new(create_keys());
let payload = "https://example.com".as_bytes();
let mut plain_data = crypto.init_plain_data(payload.len(), None).unwrap();
crypto.set_payload(&mut plain_data, payload).unwrap();
assert_eq!(plain_data.len(), Crypto::OVERHEAD_SIZE + payload.len());
assert_eq!(crypto.payload(&plain_data), Some(payload));
}
// Zero-length payloads are accepted and yield an empty payload slice.
#[test]
fn test_init_plain_data_empty_payload() {
let crypto = Crypto::new(create_keys());
let payload = "".as_bytes();
let mut plain_data = crypto.init_plain_data(0, None).unwrap();
crypto.set_payload(&mut plain_data, payload).unwrap();
assert_eq!(crypto.payload(&plain_data), Some(payload));
}
// package -> unpackage round trip with a generated IV.
#[test]
fn test_package_unpackage() {
let crypto = Crypto::new(create_keys());
let payload = b"Hello, world!".as_slice();
let encoded = crypto.package(payload, None).unwrap();
assert_ne!(encoded, "");
let decoded = crypto.unpackage(&encoded).unwrap();
assert_eq!(decoded, payload);
}
// With a caller-supplied IV, the IV fields are readable from the decoded package
// and the payload still round-trips.
#[test]
fn test_package_unpackage_with_iv() {
let crypto = Crypto::new(create_keys());
let timestamp = OffsetDateTime::UNIX_EPOCH + Duration::seconds(1);
let iv = crypto.create_init_vector(timestamp, 123456789);
let payload = b"https://example.com".as_slice();
let encoded = crypto.package(payload, Some(&iv)).unwrap();
let decoded = crypto.decode(encoded.as_bytes()).unwrap();
assert_eq!(crypto.timestamp(&decoded), Some(timestamp));
assert_eq!(crypto.server_id(&decoded), Some(123456789));
let recovered = crypto.unpackage(&encoded).unwrap();
assert_eq!(recovered, payload);
}
// package -> unpackage round trip of an empty payload.
#[test]
fn test_package_unpackage_empty_payload() {
let crypto = Crypto::new(create_keys());
let payload = b"".as_slice();
let encoded = crypto.package(payload, None).unwrap();
let recovered = crypto.unpackage(&encoded).unwrap();
assert_eq!(recovered, payload);
}
// Overwriting the signature of a packaged value makes unpackage fail with InvalidSign.
// NOTE(review): the IV here is generated, so there is a ~2^-32 chance the forged
// value equals the real signature.
#[test]
fn test_unpackage_tampered_signature() {
let crypto = Crypto::new(create_keys());
let encoded = crypto.package(b"Hello, world!", None).unwrap();
let mut bytes = crypto.decode(encoded.as_bytes()).unwrap();
let last = bytes.len() - Crypto::SIGNATURE_SIZE;
crypto.write_i32(&mut bytes, last, 123456789);
let tampered = crypto.encode(&bytes);
assert!(matches!(
crypto.unpackage(&tampered),
Err(CryptoError::InvalidSign)
));
}
// With the same fixed IV, package and package_to emit identical bytes.
#[test]
fn test_package_to_matches_package() {
let crypto = Crypto::new(create_keys());
let timestamp = OffsetDateTime::UNIX_EPOCH + Duration::seconds(1);
let iv = crypto.create_init_vector(timestamp, 123456789);
let payload = b"https://example.com".as_slice();
let encoded_alloc = crypto.package(payload, Some(&iv)).unwrap();
let mut buf = Vec::new();
crypto.package_to(payload, Some(&iv), &mut buf).unwrap();
assert_eq!(buf, encoded_alloc.as_bytes());
}
// package_to -> unpackage_to round trip through Vec writers.
#[test]
fn test_package_to_unpackage_to_roundtrip() {
let crypto = Crypto::new(create_keys());
let payload = b"Hello, world!".as_slice();
let mut enc_buf = Vec::new();
crypto.package_to(payload, None, &mut enc_buf).unwrap();
let mut dec_buf = Vec::new();
crypto.unpackage_to(&enc_buf, &mut dec_buf).unwrap();
assert_eq!(dec_buf, payload);
}
// package_to appends to a non-empty Vec without disturbing its existing content.
#[test]
fn test_package_to_appends_and_preserves_existing() {
let crypto = Crypto::new(create_keys());
let payload = b"Hello".as_slice();
let mut buf = b"prefix".to_vec();
let prefix_len = buf.len();
crypto.package_to(payload, None, &mut buf).unwrap();
assert_eq!(&buf[..prefix_len], b"prefix");
assert!(buf.len() > prefix_len);
let mut dec_buf = Vec::new();
crypto
.unpackage_to(&buf[prefix_len..], &mut dec_buf)
.unwrap();
assert_eq!(dec_buf, payload);
}
// Empty payload through the writer-based API.
#[test]
fn test_package_to_empty_payload() {
let crypto = Crypto::new(create_keys());
let mut enc_buf = Vec::new();
crypto.package_to(b"", None, &mut enc_buf).unwrap();
let mut dec_buf = Vec::new();
crypto.unpackage_to(&enc_buf, &mut dec_buf).unwrap();
assert_eq!(dec_buf, b"");
}
// unpackage_to rejects a package whose signature was overwritten.
#[test]
fn test_unpackage_to_tampered_signature() {
let crypto = Crypto::new(create_keys());
let mut enc_buf = Vec::new();
crypto.package_to(b"Hello", None, &mut enc_buf).unwrap();
let mut raw = crypto.decode(&enc_buf).unwrap();
let last = raw.len() - Crypto::SIGNATURE_SIZE;
crypto.write_i32(&mut raw, last, 123456789);
let tampered = crypto.encode(&raw);
let mut dec_buf = Vec::new();
assert!(matches!(
crypto.unpackage_to(tampered.as_bytes(), &mut dec_buf),
Err(CryptoError::InvalidSign)
));
}
// package_to accepts any io::Write implementation, here a BufWriter around a Vec.
#[test]
fn test_package_to_with_non_vec_writer() {
let crypto = Crypto::new(create_keys());
let payload = b"Hello, world!".as_slice();
let mut writer = std::io::BufWriter::new(Vec::<u8>::new());
crypto.package_to(payload, None, &mut writer).unwrap();
let encoded = writer.into_inner().unwrap();
let mut dec_buf = Vec::new();
crypto.unpackage_to(&encoded, &mut dec_buf).unwrap();
assert_eq!(dec_buf, payload);
}
// unpackage_to appends the payload after existing buffer content.
#[test]
fn test_unpackage_to_appends_to_existing_buffer() {
let crypto = Crypto::new(create_keys());
let payload = b"Hello".as_slice();
let encoded = crypto.package(payload, None).unwrap();
let mut buf = b"prefix".to_vec();
let prefix_len = buf.len();
crypto.unpackage_to(&encoded, &mut buf).unwrap();
assert_eq!(&buf[..prefix_len], b"prefix");
assert_eq!(&buf[prefix_len..], payload);
}
}