use crate::cryptotensors::CryptoTensorsError;
use ring::rand::SecureRandom;
use ring::{aead, rand};
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
/// Authenticated-encryption (AEAD) ciphers that tensor data can be protected with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptionAlgorithm {
    /// AES with a 128-bit key in Galois/Counter Mode.
    Aes128Gcm,
    /// AES with a 256-bit key in Galois/Counter Mode.
    Aes256Gcm,
    /// The ChaCha20 stream cipher combined with the Poly1305 authenticator.
    ChaCha20Poly1305,
}

impl FromStr for EncryptionAlgorithm {
    type Err = ();

    /// Parses an algorithm name case-insensitively, ignoring every `-`.
    ///
    /// `"AES-256-GCM"`, `"aes-256-gcm"` and `"aes256gcm"` all map to the same
    /// variant. Any other spelling (including `_` separators) yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let canonical = s
            .chars()
            .filter(|&ch| ch != '-')
            .collect::<String>()
            .to_lowercase();
        let algo = match &*canonical {
            "aes128gcm" => Self::Aes128Gcm,
            "aes256gcm" => Self::Aes256Gcm,
            "chacha20poly1305" => Self::ChaCha20Poly1305,
            _ => return Err(()),
        };
        Ok(algo)
    }
}

impl fmt::Display for EncryptionAlgorithm {
    /// Writes the canonical lowercase, hyphen-free name, which `from_str` parses back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Aes128Gcm => "aes128gcm",
            Self::Aes256Gcm => "aes256gcm",
            Self::ChaCha20Poly1305 => "chacha20poly1305",
        })
    }
}
impl EncryptionAlgorithm {
pub fn get_aead_algo(&self) -> &'static aead::Algorithm {
match self {
EncryptionAlgorithm::Aes128Gcm => &aead::AES_128_GCM,
EncryptionAlgorithm::Aes256Gcm => &aead::AES_256_GCM,
EncryptionAlgorithm::ChaCha20Poly1305 => &aead::CHACHA20_POLY1305,
}
}
pub fn key_len(&self) -> usize {
match self {
EncryptionAlgorithm::Aes128Gcm => 16, EncryptionAlgorithm::Aes256Gcm => 32, EncryptionAlgorithm::ChaCha20Poly1305 => 32, }
}
pub fn tag_len(&self) -> usize {
match self {
EncryptionAlgorithm::Aes128Gcm => 16,
EncryptionAlgorithm::Aes256Gcm => 16,
EncryptionAlgorithm::ChaCha20Poly1305 => 16,
}
}
pub fn create_tag(&self, tag_bytes: &[u8]) -> Result<aead::Tag, String> {
let expected_len = self.tag_len();
if tag_bytes.len() != expected_len {
return Err(format!(
"Invalid tag length: expected {} bytes, got {} bytes",
expected_len,
tag_bytes.len()
));
}
let mut tag = [0u8; 16]; tag.copy_from_slice(tag_bytes);
Ok(aead::Tag::from(tag))
}
}
/// An encryption algorithm paired with a ready-to-use `ring` AEAD key.
///
/// The key is held behind an `Arc`, so cloning a context only bumps a
/// reference count and all clones share the same key.
#[derive(Clone)]
pub struct PreparedKeyContext {
    /// Cipher the key was prepared for.
    pub algo: EncryptionAlgorithm,
    /// Key used for sealing and opening; shared between clones.
    pub key: Arc<aead::LessSafeKey>,
}

impl fmt::Debug for PreparedKeyContext {
    /// Formats only the algorithm.
    ///
    /// The `key` field is intentionally omitted. `finish_non_exhaustive` renders a
    /// trailing `..` so readers can see a field was left out, rather than the
    /// output looking complete.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PreparedKeyContext")
            .field("algo", &self.algo)
            .finish_non_exhaustive()
    }
}
pub fn prepare_key_context(
key: &[u8],
algo_name: &str,
) -> Result<PreparedKeyContext, CryptoTensorsError> {
let algo = algo_name
.parse::<EncryptionAlgorithm>()
.map_err(|_| CryptoTensorsError::InvalidAlgorithm(algo_name.to_string()))?;
if key.is_empty() {
return Err(CryptoTensorsError::InvalidKeyLength {
expected: algo.key_len(),
actual: 0,
});
}
if key.len() != algo.key_len() {
return Err(CryptoTensorsError::InvalidKeyLength {
expected: algo.key_len(),
actual: key.len(),
});
}
let aead_algo = algo.get_aead_algo();
let unbound_key = aead::UnboundKey::new(aead_algo, key)
.map_err(|e| CryptoTensorsError::KeyCreation(e.to_string()))?;
let less_safe_key = aead::LessSafeKey::new(unbound_key);
Ok(PreparedKeyContext {
algo,
key: Arc::new(less_safe_key),
})
}
/// Encrypts `in_out` in place under a freshly generated random nonce.
///
/// Returns `(nonce, tag)`. The nonce is `aead::NONCE_LEN` bytes drawn from
/// `ring`'s `SystemRandom`. The tag is the detached authentication tag.
///
/// An empty buffer is passed through untouched and returns two empty vectors.
/// `decrypt_data` accepts that all-empty triple.
/// NOTE(review): this means an empty payload carries no authentication tag.
///
/// # Errors
/// - `RandomGeneration` if the system RNG fails.
/// - `Encryption` if `ring` fails to seal the buffer.
pub fn encrypt_data(
    in_out: &mut [u8],
    ctx: &PreparedKeyContext,
) -> Result<(Vec<u8>, Vec<u8>), CryptoTensorsError> {
    if in_out.is_empty() {
        return Ok((Vec::new(), Vec::new()));
    }

    // `aead::Nonce` is always exactly `NONCE_LEN` bytes, so a fixed-size array
    // avoids the temporary `Vec`, the clone, and the `try_into().unwrap()`.
    let mut nonce_bytes = [0u8; aead::NONCE_LEN];
    rand::SystemRandom::new()
        .fill(&mut nonce_bytes)
        .map_err(|e| CryptoTensorsError::RandomGeneration(e.to_string()))?;

    // The array is `Copy`, so `nonce_bytes` is still available for the return value.
    let nonce = aead::Nonce::assume_unique_for_key(nonce_bytes);
    let tag = ctx
        .key
        .seal_in_place_separate_tag(nonce, aead::Aad::empty(), in_out)
        .map_err(|e| CryptoTensorsError::Encryption(e.to_string()))?;
    Ok((nonce_bytes.to_vec(), tag.as_ref().to_vec()))
}
/// Encrypts `in_out` in place using a caller-supplied IV and returns the
/// detached authentication tag.
///
/// The nonce is built with `aead::Nonce::assume_unique_for_key` semantics, so the
/// caller must never reuse an IV with the same key. `derive_chunk_iv` can produce
/// distinct per-chunk IVs from a base IV. Unlike `encrypt_data`, an empty buffer
/// is still sealed and yields a tag.
///
/// # Errors
/// - `InvalidIvLength` if `iv` is not exactly the algorithm's nonce length.
/// - `Encryption` if `ring` fails to seal the buffer.
pub fn encrypt_data_with_iv(
    in_out: &mut [u8],
    ctx: &PreparedKeyContext,
    iv: &[u8],
) -> Result<Vec<u8>, CryptoTensorsError> {
    let nonce_len = ctx.algo.get_aead_algo().nonce_len();
    let invalid_iv = || CryptoTensorsError::InvalidIvLength {
        expected: nonce_len,
        actual: iv.len(),
    };
    if iv.len() != nonce_len {
        return Err(invalid_iv());
    }
    // Build the nonce straight from the validated slice, as `decrypt_data` does,
    // instead of copying into a temporary `Vec` and `unwrap`ping the conversion.
    let nonce = aead::Nonce::try_assume_unique_for_key(iv).map_err(|_| invalid_iv())?;
    let tag = ctx
        .key
        .seal_in_place_separate_tag(nonce, aead::Aad::empty(), in_out)
        .map_err(|e| CryptoTensorsError::Encryption(e.to_string()))?;
    Ok(tag.as_ref().to_vec())
}
/// Authenticates and decrypts `in_out` in place using the given IV and detached tag.
///
/// An all-empty input (empty buffer, IV and tag) is accepted as a no-op. That
/// triple is what `encrypt_data` returns for an empty plaintext.
///
/// Validation happens in a fixed order:
/// 1. empty IV
/// 2. empty tag
/// 3. IV length
/// 4. tag length
///
/// The first failing check determines the error.
///
/// # Errors
/// - `InvalidIvLength` if the IV is empty or not the nonce length.
/// - `InvalidTagLength` if the tag is empty or not the tag length.
/// - `Decryption` if authentication fails. `ring` documents that the buffer
///   may have been overwritten in that case, so its contents must not be used.
pub fn decrypt_data(
    in_out: &mut [u8],
    ctx: &PreparedKeyContext,
    iv: &[u8],
    tag: &[u8],
) -> Result<(), CryptoTensorsError> {
    if matches!((in_out.len(), iv.len(), tag.len()), (0, 0, 0)) {
        return Ok(());
    }

    let nonce_len = ctx.algo.get_aead_algo().nonce_len();
    let tag_len = ctx.algo.tag_len();
    let bad_iv = || CryptoTensorsError::InvalidIvLength {
        expected: nonce_len,
        actual: iv.len(),
    };
    let bad_tag = || CryptoTensorsError::InvalidTagLength {
        expected: tag_len,
        actual: tag.len(),
    };

    if iv.is_empty() {
        return Err(bad_iv());
    }
    if tag.is_empty() {
        return Err(bad_tag());
    }

    let nonce = aead::Nonce::try_assume_unique_for_key(iv).map_err(|_| bad_iv())?;
    let auth_tag = ctx.algo.create_tag(tag).map_err(|_| bad_tag())?;

    ctx.key
        .open_in_place_separate_tag(nonce, aead::Aad::empty(), auth_tag, in_out, 0..)
        .map_err(|e| CryptoTensorsError::Decryption(e.to_string()))?;
    Ok(())
}
/// Derives the 12-byte IV for chunk number `chunk_index` from `base_iv`.
///
/// Only the first 12 bytes of `base_iv` are used; any extra bytes are ignored.
/// Bytes `0..8` are copied unchanged. Bytes `8..12` are read as a big-endian
/// `u32` counter, `chunk_index` is added to it, and the sum is written back
/// big-endian.
///
/// # Errors
/// - `InvalidIvLength { expected: 12, .. }` if `base_iv` is shorter than 12 bytes.
/// - `Encryption` if `chunk_index` does not fit in a `u32`, or if adding it to
///   the base counter would overflow.
pub fn derive_chunk_iv(base_iv: &[u8], chunk_index: usize) -> Result<Vec<u8>, CryptoTensorsError> {
    const IV_LEN: usize = 12;
    const COUNTER_START: usize = 8;

    if base_iv.len() < IV_LEN {
        return Err(CryptoTensorsError::InvalidIvLength {
            expected: IV_LEN,
            actual: base_iv.len(),
        });
    }
    let (fixed_part, counter_part) = base_iv[..IV_LEN].split_at(COUNTER_START);

    // Big-endian decode of the 4 counter bytes.
    let base_counter = counter_part
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));

    let offset = u32::try_from(chunk_index)
        .map_err(|_| CryptoTensorsError::Encryption("Chunk index exceeds u32 max".to_string()))?;
    let chunk_counter = match base_counter.checked_add(offset) {
        Some(value) => value,
        None => {
            return Err(CryptoTensorsError::Encryption(
                "Chunk index overflowed IV counter".to_string(),
            ))
        }
    };

    let mut derived = Vec::with_capacity(IV_LEN);
    derived.extend_from_slice(fixed_part);
    derived.extend_from_slice(&chunk_counter.to_be_bytes());
    Ok(derived)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_chunk_iv() {
        let base_iv = vec![0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0];
        let iv_0 = derive_chunk_iv(&base_iv, 0).unwrap();
        assert_eq!(iv_0, vec![0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
        let iv_1 = derive_chunk_iv(&base_iv, 1).unwrap();
        assert_eq!(iv_1, vec![0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1]);
        // 256 carries into the next counter byte.
        let iv_256 = derive_chunk_iv(&base_iv, 256).unwrap();
        assert_eq!(iv_256, vec![0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 1, 0]);
        // A base counter already at u32::MAX cannot be advanced.
        let base_iv_non_zero = vec![0, 1, 2, 3, 4, 5, 6, 7, 255, 255, 255, 255];
        let res = derive_chunk_iv(&base_iv_non_zero, 1);
        assert!(res.is_err());
    }

    #[test]
    fn test_derive_chunk_iv_invalid_length() {
        let short_iv = vec![0; 11];
        let res = derive_chunk_iv(&short_iv, 0);
        assert!(res.is_err());
        match res.unwrap_err() {
            CryptoTensorsError::InvalidIvLength { expected, actual } => {
                assert_eq!(expected, 12);
                assert_eq!(actual, 11);
            }
            _ => panic!("Expected InvalidIvLength error"),
        }
    }

    #[test]
    fn test_derive_chunk_iv_offsets_nonzero_base_counter() {
        let base_iv = [9u8, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 5];
        // 5 + 300 = 305 = 0x0131.
        assert_eq!(
            derive_chunk_iv(&base_iv, 300).unwrap(),
            vec![9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 1, 49]
        );
    }

    #[test]
    fn test_derive_chunk_iv_uses_only_first_12_bytes() {
        let base_iv = [1u8; 16];
        assert_eq!(
            derive_chunk_iv(&base_iv, 1).unwrap(),
            vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]
        );
    }

    #[test]
    fn test_derive_chunk_iv_overflow_is_encryption_error() {
        match derive_chunk_iv(&[0xffu8; 12], 1) {
            Err(CryptoTensorsError::Encryption(msg)) => {
                assert_eq!(msg, "Chunk index overflowed IV counter")
            }
            _ => panic!("Expected Encryption error"),
        }
    }

    #[test]
    fn test_encrypt_decrypt_round_trip_all_algorithms() {
        for (name, key_len) in [("aes-128-gcm", 16), ("AES-256-GCM", 32), ("chacha20-poly1305", 32)] {
            let ctx = prepare_key_context(&vec![7u8; key_len], name).unwrap();
            let plaintext = b"cryptotensors round trip".to_vec();
            let mut buf = plaintext.clone();
            let (iv, tag) = encrypt_data(&mut buf, &ctx).unwrap();
            assert_eq!(iv.len(), 12);
            assert_eq!(tag.len(), ctx.algo.tag_len());
            assert_ne!(buf, plaintext);
            decrypt_data(&mut buf, &ctx, &iv, &tag).unwrap();
            assert_eq!(buf, plaintext);
        }
    }

    #[test]
    fn test_encrypt_with_derived_iv_round_trip() {
        let ctx = prepare_key_context(&[5u8; 16], "aes128gcm").unwrap();
        let iv = derive_chunk_iv(&[0u8; 12], 3).unwrap();
        let mut buf = b"chunk three".to_vec();
        let tag = encrypt_data_with_iv(&mut buf, &ctx, &iv).unwrap();
        decrypt_data(&mut buf, &ctx, &iv, &tag).unwrap();
        assert_eq!(buf, b"chunk three".to_vec());
    }

    #[test]
    fn test_encrypt_with_iv_rejects_wrong_iv_length() {
        let ctx = prepare_key_context(&[5u8; 16], "aes128gcm").unwrap();
        let mut buf = b"data".to_vec();
        match encrypt_data_with_iv(&mut buf, &ctx, &[0u8; 11]) {
            Err(CryptoTensorsError::InvalidIvLength { expected, actual }) => {
                assert_eq!(expected, 12);
                assert_eq!(actual, 11);
            }
            _ => panic!("Expected InvalidIvLength error"),
        }
    }

    #[test]
    fn test_decrypt_rejects_tampered_tag() {
        let ctx = prepare_key_context(&[3u8; 32], "aes256gcm").unwrap();
        let mut buf = b"payload".to_vec();
        let (iv, mut tag) = encrypt_data(&mut buf, &ctx).unwrap();
        tag[0] ^= 1;
        match decrypt_data(&mut buf, &ctx, &iv, &tag) {
            Err(CryptoTensorsError::Decryption(_)) => {}
            _ => panic!("Expected Decryption error"),
        }
    }

    #[test]
    fn test_empty_buffer_passthrough() {
        let ctx = prepare_key_context(&[1u8; 32], "chacha20poly1305").unwrap();
        let mut empty: [u8; 0] = [];
        let (iv, tag) = encrypt_data(&mut empty, &ctx).unwrap();
        assert!(iv.is_empty());
        assert!(tag.is_empty());
        assert!(decrypt_data(&mut empty, &ctx, &iv, &tag).is_ok());
    }

    #[test]
    fn test_prepare_key_context_rejects_bad_inputs() {
        match prepare_key_context(&[0u8; 8], "aes256gcm") {
            Err(CryptoTensorsError::InvalidKeyLength { expected, actual }) => {
                assert_eq!(expected, 32);
                assert_eq!(actual, 8);
            }
            _ => panic!("Expected InvalidKeyLength error"),
        }
        match prepare_key_context(&[], "aes128gcm") {
            Err(CryptoTensorsError::InvalidKeyLength { expected, actual }) => {
                assert_eq!(expected, 16);
                assert_eq!(actual, 0);
            }
            _ => panic!("Expected InvalidKeyLength error"),
        }
        match prepare_key_context(&[0u8; 16], "rot13") {
            Err(CryptoTensorsError::InvalidAlgorithm(name)) => assert_eq!(name, "rot13"),
            _ => panic!("Expected InvalidAlgorithm error"),
        }
    }
}