use bytes::{Buf, Bytes};
use crate::codec::{read_b_varchar, read_us_varchar};
use crate::error::ProtocolError;
use crate::prelude::*;
/// Column flag bit (`0x0800`) that marks a column as encrypted; tested by [`is_column_encrypted`].
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;
/// Algorithm id byte for `AEAD_AES_256_CBC_HMAC_SHA256` in [`CryptoMetadata::algorithm_id`].
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 2;
/// Wire byte for [`EncryptionTypeWire::Deterministic`].
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 1;
/// Wire byte for [`EncryptionTypeWire::Randomized`].
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 2;
/// Normalization rule version; presumably the expected value of
/// [`CryptoMetadata::normalization_version`] (not checked by the decoders in this module).
pub const NORMALIZATION_RULE_VERSION: u8 = 1;
/// One entry of a column encryption key (CEK) table, as decoded by `CekTableEntry::decode`.
#[derive(Clone, Debug)]
pub struct CekTableEntry {
    /// Database identifier (little-endian `u32` on the wire).
    pub database_id: u32,
    /// Key identifier (little-endian `u32` on the wire).
    pub cek_id: u32,
    /// Key version (little-endian `u32` on the wire).
    pub cek_version: u32,
    /// Key metadata version (little-endian `u64` on the wire).
    pub cek_md_version: u64,
    /// Encrypted copies of the key; the count is a single byte on the wire.
    pub values: Vec<CekValue>,
}
/// One encrypted form of a column encryption key, plus how to locate the key that protects it.
#[derive(Clone, Debug)]
pub struct CekValue {
    /// Encrypted key bytes (prefixed by a little-endian `u16` length on the wire).
    pub encrypted_value: Bytes,
    /// Key store provider name (`B_VARCHAR` on the wire).
    pub key_store_provider_name: String,
    /// Column master key path (`US_VARCHAR` on the wire).
    pub cmk_path: String,
    /// Name of the algorithm used to encrypt the key (`B_VARCHAR` on the wire).
    pub encryption_algorithm: String,
}
/// Per-column encryption metadata describing how an encrypted column was protected.
#[derive(Clone, Debug)]
pub struct CryptoMetadata {
    /// Ordinal of the key entry in the CEK table (compare `CekTable::get`).
    pub cek_table_ordinal: u16,
    /// User type of the underlying plaintext column.
    pub base_user_type: u32,
    /// Raw type byte of the underlying plaintext column.
    pub base_col_type: u8,
    /// Type information decoded for the underlying plaintext column.
    pub base_type_info: crate::token::TypeInfo,
    /// Encryption algorithm id (see `ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256`).
    pub algorithm_id: u8,
    /// Whether encryption is deterministic or randomized.
    pub encryption_type: EncryptionTypeWire,
    /// Normalization rule version byte.
    pub normalization_version: u8,
}
/// Encryption type carried on the wire: deterministic (byte `1`) or randomized (byte `2`).
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EncryptionTypeWire {
    /// Deterministic encryption.
    Deterministic,
    /// Randomized encryption.
    Randomized,
}
impl EncryptionTypeWire {
#[must_use]
pub fn from_u8(value: u8) -> Option<Self> {
match value {
ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
_ => None,
}
}
#[must_use]
pub fn to_u8(self) -> u8 {
match self {
Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
}
}
}
/// Table of column encryption keys; entries are looked up by their position via `CekTable::get`.
#[derive(Clone, Debug, Default)]
pub struct CekTable {
    /// Entries in wire order.
    pub entries: Vec<CekTableEntry>,
}
impl CekTable {
    /// Creates an empty table.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the entry at position `ordinal`, or `None` if it is out of range.
    #[must_use]
    pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
        self.entries.get(usize::from(ordinal))
    }

    /// Returns `true` if the table has no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns the number of entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Decodes a CEK table: a little-endian `u16` entry count followed by that many
    /// entries, each decoded by `CekTableEntry::decode`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] if the count is truncated, and propagates
    /// the first error from `CekTableEntry::decode`.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        // Smallest possible encoded entry: database_id + cek_id + cek_version (3 x u32),
        // cek_md_version (u64) and the value count (u8).
        const MIN_ENTRY_LEN: usize = 4 + 4 + 4 + 8 + 1;

        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let cek_count = usize::from(src.get_u16_le());

        // The count comes from the peer, so don't trust it for preallocation: reserve no
        // more entries than the remaining bytes could possibly encode. Truncated input then
        // fails with `UnexpectedEof` instead of first allocating for 65535 entries.
        let mut entries = Vec::with_capacity(cek_count.min(src.remaining() / MIN_ENTRY_LEN));
        for _ in 0..cek_count {
            entries.push(CekTableEntry::decode(src)?);
        }
        Ok(Self { entries })
    }
}
impl CekTableEntry {
    /// Decodes one CEK table entry.
    ///
    /// The entry is a fixed 21-byte header (three little-endian `u32` ids/versions, a
    /// little-endian `u64` metadata version and a `u8` value count), followed by that many
    /// values decoded by `CekValue::decode`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] if the header is truncated, and propagates
    /// the first error from `CekValue::decode`.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        const HEADER_LEN: usize = 4 + 4 + 4 + 8 + 1;

        if src.remaining() < HEADER_LEN {
            return Err(ProtocolError::UnexpectedEof);
        }

        let database_id = src.get_u32_le();
        let cek_id = src.get_u32_le();
        let cek_version = src.get_u32_le();
        let cek_md_version = src.get_u64_le();
        let value_count = src.get_u8();

        let values = (0..value_count)
            .map(|_| CekValue::decode(src))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self {
            database_id,
            cek_id,
            cek_version,
            cek_md_version,
            values,
        })
    }

    /// Returns the first encrypted value, or `None` if the entry has no values.
    #[must_use]
    pub fn primary_value(&self) -> Option<&CekValue> {
        self.values.first()
    }
}
impl CekValue {
    /// Decodes one encrypted key value.
    ///
    /// Layout: a little-endian `u16` byte length and that many encrypted-key bytes, then the
    /// key store provider name (`B_VARCHAR`), the master key path (`US_VARCHAR`) and the
    /// encryption algorithm name (`B_VARCHAR`).
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] if any field is truncated or a string
    /// cannot be read.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let eof = || ProtocolError::UnexpectedEof;

        if src.remaining() < 2 {
            return Err(eof());
        }
        let key_len = usize::from(src.get_u16_le());
        if src.remaining() < key_len {
            return Err(eof());
        }
        let encrypted_value = src.copy_to_bytes(key_len);

        let key_store_provider_name = read_b_varchar(src).ok_or_else(eof)?;
        let cmk_path = read_us_varchar(src).ok_or_else(eof)?;
        let encryption_algorithm = read_b_varchar(src).ok_or_else(eof)?;

        Ok(Self {
            encrypted_value,
            key_store_provider_name,
            cmk_path,
            encryption_algorithm,
        })
    }
}
impl CryptoMetadata {
pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
if src.remaining() < 7 {
return Err(ProtocolError::UnexpectedEof);
}
let cek_table_ordinal = src.get_u16_le();
let base_user_type = src.get_u32_le();
let base_col_type = src.get_u8();
let base_type_id =
crate::types::TypeId::from_u8(base_col_type).unwrap_or(crate::types::TypeId::Null);
let base_type_info = crate::token::decode_type_info(src, base_type_id, base_col_type)?;
if src.remaining() < 3 {
return Err(ProtocolError::UnexpectedEof);
}
let algorithm_id = src.get_u8();
let encryption_type_byte = src.get_u8();
let normalization_version = src.get_u8();
let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
ProtocolError::InvalidField {
field: "encryption_type",
value: encryption_type_byte as u32,
},
)?;
Ok(Self {
cek_table_ordinal,
base_user_type,
base_col_type,
base_type_info,
algorithm_id,
encryption_type,
normalization_version,
})
}
#[must_use]
pub fn is_aead_aes_256(&self) -> bool {
self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
}
#[must_use]
pub fn is_deterministic(&self) -> bool {
self.encryption_type == EncryptionTypeWire::Deterministic
}
#[must_use]
pub fn is_randomized(&self) -> bool {
self.encryption_type == EncryptionTypeWire::Randomized
}
#[must_use]
pub fn base_type_id(&self) -> crate::types::TypeId {
crate::types::TypeId::from_u8(self.base_col_type).unwrap_or(crate::types::TypeId::Null)
}
}
/// Encryption information attached to a column; `None` metadata means the column is plaintext.
#[derive(Clone, Debug, Default)]
pub struct ColumnCryptoInfo {
    /// Crypto metadata when the column is encrypted.
    pub crypto_metadata: Option<CryptoMetadata>,
}
impl ColumnCryptoInfo {
    /// Builds info for a plaintext column (no crypto metadata).
    #[must_use]
    pub fn unencrypted() -> Self {
        Self::default()
    }

    /// Builds info for an encrypted column carrying `metadata`.
    #[must_use]
    pub fn encrypted(metadata: CryptoMetadata) -> Self {
        let crypto_metadata = Some(metadata);
        Self { crypto_metadata }
    }

    /// Returns `true` when crypto metadata is present.
    #[must_use]
    pub fn is_encrypted(&self) -> bool {
        self.crypto_metadata.is_some()
    }
}
/// Returns `true` if the `COLUMN_FLAG_ENCRYPTED` bit is set in the column `flags`.
#[must_use]
pub fn is_column_encrypted(flags: u16) -> bool {
    let encrypted_bit = flags & COLUMN_FLAG_ENCRYPTED;
    encrypted_bit != 0
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Appends `s` as little-endian UTF-16 code units.
    fn put_utf16(buf: &mut BytesMut, s: &str) {
        for unit in s.encode_utf16() {
            buf.extend_from_slice(&unit.to_le_bytes());
        }
    }

    /// Appends a B_VARCHAR: a one-byte character count followed by UTF-16LE text.
    fn put_b_varchar(buf: &mut BytesMut, s: &str) {
        buf.extend_from_slice(&[s.encode_utf16().count() as u8]);
        put_utf16(buf, s);
    }

    /// Appends a US_VARCHAR: a little-endian two-byte character count followed by UTF-16LE text.
    fn put_us_varchar(buf: &mut BytesMut, s: &str) {
        buf.extend_from_slice(&(s.encode_utf16().count() as u16).to_le_bytes());
        put_utf16(buf, s);
    }

    /// Appends one encrypted key value in the layout read by `CekValue::decode`.
    fn put_cek_value(buf: &mut BytesMut, encrypted: &[u8], store: &str, path: &str, algorithm: &str) {
        buf.extend_from_slice(&(encrypted.len() as u16).to_le_bytes());
        buf.extend_from_slice(encrypted);
        put_b_varchar(buf, store);
        put_us_varchar(buf, path);
        put_b_varchar(buf, algorithm);
    }

    /// Appends the fixed 21-byte header read by `CekTableEntry::decode`.
    fn put_cek_entry_header(
        buf: &mut BytesMut,
        database_id: u32,
        cek_id: u32,
        cek_version: u32,
        cek_md_version: u64,
        value_count: u8,
    ) {
        buf.extend_from_slice(&database_id.to_le_bytes());
        buf.extend_from_slice(&cek_id.to_le_bytes());
        buf.extend_from_slice(&cek_version.to_le_bytes());
        buf.extend_from_slice(&cek_md_version.to_le_bytes());
        buf.extend_from_slice(&[value_count]);
    }

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        let known = [
            (1_u8, EncryptionTypeWire::Deterministic),
            (2_u8, EncryptionTypeWire::Randomized),
        ];
        for &(byte, kind) in &known {
            assert_eq!(EncryptionTypeWire::from_u8(byte), Some(kind));
            assert_eq!(kind.to_u8(), byte);
        }
        for &unknown in &[0_u8, 99] {
            assert_eq!(EncryptionTypeWire::from_u8(unknown), None);
        }
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data: [u8; 11] = [
            0x00, 0x00, // CEK table ordinal = 0
            0x00, 0x00, 0x00, 0x00, // base user type = 0
            0x26, // base column type (IntN)
            0x04, // IntN max length
            0x02, // algorithm id (AEAD_AES_256_CBC_HMAC_SHA256)
            0x01, // encryption type (deterministic)
            0x01, // normalization version
        ];
        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(metadata.base_user_type, 0);
        assert_eq!(metadata.base_col_type, 0x26);
        assert_eq!(metadata.base_type_info.max_length, Some(4));
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());
        assert_eq!(metadata.base_type_id(), crate::types::TypeId::IntN);
    }

    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();
        put_cek_value(&mut data, &[0xDE, 0xAD, 0xBE, 0xEF], "TEST", "key1", "RSA");

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();
        put_cek_entry_header(&mut data, 1, 2, 1, 100, 1);
        put_cek_value(&mut data, &[0x11, 0x22, 0x33, 0x44], "KS", "P", "A");

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);
        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();
        data.extend_from_slice(&1_u16.to_le_bytes()); // entry count
        put_cek_entry_header(&mut data, 1, 1, 1, 1, 1);
        put_cek_value(&mut data, &[0xAB, 0xCD], "K", "P", "A");

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());
        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
    }

    #[test]
    fn test_is_column_encrypted() {
        let cases = [
            (0x0000_u16, false),
            (0x0001, false),
            (0x0800, true),
            (0x0801, true),
        ];
        for &(flags, expected) in &cases {
            assert_eq!(is_column_encrypted(flags), expected, "flags = {flags:#06x}");
        }
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: crate::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}