1use bytes::{Buf, Bytes};
52
53use crate::codec::{read_b_varchar, read_us_varchar};
54use crate::error::ProtocolError;
55use crate::prelude::*;
56
/// Bit in a column's `u16` flags word that marks the column as encrypted
/// (`0x0800`); tested by [`is_column_encrypted`].
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// Wire identifier of the `AEAD_AES_256_CBC_HMAC_SHA256` algorithm, as compared
/// against [`CryptoMetadata::algorithm_id`] by `CryptoMetadata::is_aead_aes_256`.
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 2;

/// Wire byte selecting deterministic encryption
/// ([`EncryptionTypeWire::Deterministic`]).
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 1;

/// Wire byte selecting randomized encryption
/// ([`EncryptionTypeWire::Randomized`]).
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 2;

/// Normalization rule version value. `CryptoMetadata::decode` does not check
/// against it — presumably it is the expected value of
/// [`CryptoMetadata::normalization_version`]; confirm with callers.
pub const NORMALIZATION_RULE_VERSION: u8 = 1;
71
72#[derive(Debug, Clone)]
77pub struct CekTableEntry {
78 pub database_id: u32,
80 pub cek_id: u32,
82 pub cek_version: u32,
84 pub cek_md_version: u64,
86 pub values: Vec<CekValue>,
88}
89
90#[derive(Debug, Clone)]
95pub struct CekValue {
96 pub encrypted_value: Bytes,
98 pub key_store_provider_name: String,
100 pub cmk_path: String,
102 pub encryption_algorithm: String,
104}
105
106#[derive(Debug, Clone)]
111pub struct CryptoMetadata {
112 pub cek_table_ordinal: u16,
114 pub algorithm_id: u8,
116 pub encryption_type: EncryptionTypeWire,
118 pub normalization_version: u8,
120}
121
/// Encryption type carried as a single byte in column crypto metadata.
///
/// Converted from and to its wire byte with `from_u8` / `to_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptionTypeWire {
    /// Deterministic encryption (wire byte `1`).
    Deterministic,
    /// Randomized encryption (wire byte `2`).
    Randomized,
}
130
131impl EncryptionTypeWire {
132 #[must_use]
134 pub fn from_u8(value: u8) -> Option<Self> {
135 match value {
136 ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
137 ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
138 _ => None,
139 }
140 }
141
142 #[must_use]
144 pub fn to_u8(self) -> u8 {
145 match self {
146 Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
147 Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
148 }
149 }
150}
151
152#[derive(Debug, Clone, Default)]
154pub struct CekTable {
155 pub entries: Vec<CekTableEntry>,
157}
158
159impl CekTable {
160 #[must_use]
162 pub fn new() -> Self {
163 Self::default()
164 }
165
166 #[must_use]
168 pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
169 self.entries.get(ordinal as usize)
170 }
171
172 #[must_use]
174 pub fn is_empty(&self) -> bool {
175 self.entries.is_empty()
176 }
177
178 #[must_use]
180 pub fn len(&self) -> usize {
181 self.entries.len()
182 }
183
184 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
209 if src.remaining() < 2 {
210 return Err(ProtocolError::UnexpectedEof);
211 }
212
213 let cek_count = src.get_u16_le() as usize;
214
215 let mut entries = Vec::with_capacity(cek_count);
216
217 for _ in 0..cek_count {
218 let entry = CekTableEntry::decode(src)?;
219 entries.push(entry);
220 }
221
222 Ok(Self { entries })
223 }
224}
225
226impl CekTableEntry {
227 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
229 if src.remaining() < 21 {
231 return Err(ProtocolError::UnexpectedEof);
232 }
233
234 let database_id = src.get_u32_le();
235 let cek_id = src.get_u32_le();
236 let cek_version = src.get_u32_le();
237 let cek_md_version = src.get_u64_le();
238 let value_count = src.get_u8() as usize;
239
240 let mut values = Vec::with_capacity(value_count);
241
242 for _ in 0..value_count {
243 let value = CekValue::decode(src)?;
244 values.push(value);
245 }
246
247 Ok(Self {
248 database_id,
249 cek_id,
250 cek_version,
251 cek_md_version,
252 values,
253 })
254 }
255
256 #[must_use]
258 pub fn primary_value(&self) -> Option<&CekValue> {
259 self.values.first()
260 }
261}
262
263impl CekValue {
264 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
266 if src.remaining() < 2 {
268 return Err(ProtocolError::UnexpectedEof);
269 }
270
271 let encrypted_value_length = src.get_u16_le() as usize;
272
273 if src.remaining() < encrypted_value_length {
274 return Err(ProtocolError::UnexpectedEof);
275 }
276
277 let encrypted_value = src.copy_to_bytes(encrypted_value_length);
278
279 let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
281
282 let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
284
285 let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
287
288 Ok(Self {
289 encrypted_value,
290 key_store_provider_name,
291 cmk_path,
292 encryption_algorithm,
293 })
294 }
295}
296
297impl CryptoMetadata {
298 pub const SIZE: usize = 5; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
303 if src.remaining() < Self::SIZE {
304 return Err(ProtocolError::UnexpectedEof);
305 }
306
307 let cek_table_ordinal = src.get_u16_le();
308 let algorithm_id = src.get_u8();
309 let encryption_type_byte = src.get_u8();
310 let normalization_version = src.get_u8();
311
312 let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
313 ProtocolError::InvalidField {
314 field: "encryption_type",
315 value: encryption_type_byte as u32,
316 },
317 )?;
318
319 Ok(Self {
320 cek_table_ordinal,
321 algorithm_id,
322 encryption_type,
323 normalization_version,
324 })
325 }
326
327 #[must_use]
329 pub fn is_aead_aes_256(&self) -> bool {
330 self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
331 }
332
333 #[must_use]
335 pub fn is_deterministic(&self) -> bool {
336 self.encryption_type == EncryptionTypeWire::Deterministic
337 }
338
339 #[must_use]
341 pub fn is_randomized(&self) -> bool {
342 self.encryption_type == EncryptionTypeWire::Randomized
343 }
344}
345
346#[derive(Debug, Clone, Default)]
351pub struct ColumnCryptoInfo {
352 pub crypto_metadata: Option<CryptoMetadata>,
354}
355
356impl ColumnCryptoInfo {
357 #[must_use]
359 pub fn unencrypted() -> Self {
360 Self {
361 crypto_metadata: None,
362 }
363 }
364
365 #[must_use]
367 pub fn encrypted(metadata: CryptoMetadata) -> Self {
368 Self {
369 crypto_metadata: Some(metadata),
370 }
371 }
372
373 #[must_use]
375 pub fn is_encrypted(&self) -> bool {
376 self.crypto_metadata.is_some()
377 }
378}
379
380#[must_use]
382pub fn is_column_encrypted(flags: u16) -> bool {
383 (flags & COLUMN_FLAG_ENCRYPTED) != 0
384}
385
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data = [
            0x00, 0x00, // cek_table_ordinal = 0
            0x02, // algorithm_id = AEAD_AES_256_CBC_HMAC_SHA256
            0x01, // encryption_type = deterministic
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());
        assert!(cursor.is_empty(), "decode must consume exactly SIZE bytes");
    }

    #[test]
    fn test_crypto_metadata_decode_rejects_unknown_encryption_type() {
        let data = [0x00, 0x00, 0x02, 0x03, 0x01];

        let mut cursor: &[u8] = &data;
        match CryptoMetadata::decode(&mut cursor) {
            Err(ProtocolError::InvalidField { field, value }) => {
                assert_eq!(field, "encryption_type");
                assert_eq!(value, 3);
            }
            other => panic!("expected InvalidField, got {other:?}"),
        }
    }

    #[test]
    fn test_crypto_metadata_decode_truncated() {
        let data = [0x00, 0x00, 0x02, 0x01];

        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CryptoMetadata::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();

        // Encrypted value: u16 length 4, then 4 bytes.
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        // Key store provider name: B_VARCHAR, 4 UTF-16LE chars.
        data.extend_from_slice(&[0x04]);
        data.extend_from_slice(&[b'T', 0x00, b'E', 0x00, b'S', 0x00, b'T', 0x00]);
        // CMK path: US_VARCHAR, 4 UTF-16LE chars.
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[b'k', 0x00, b'e', 0x00, b'y', 0x00, b'1', 0x00]);
        // Encryption algorithm: B_VARCHAR, 3 UTF-16LE chars.
        data.extend_from_slice(&[0x03]);
        data.extend_from_slice(&[b'R', 0x00, b'S', 0x00, b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_value_decode_missing_length() {
        let data = [0x04];

        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekValue::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_value_decode_truncated_encrypted_value() {
        // Declares 4 bytes of key material but supplies only 2.
        let data = [0x04, 0x00, 0xDE, 0xAD];

        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekValue::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();

        // database_id = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        // cek_id = 2
        data.extend_from_slice(&[0x02, 0x00, 0x00, 0x00]);
        // cek_version = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        // cek_md_version = 100
        data.extend_from_slice(&[0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        // value count = 1
        data.extend_from_slice(&[0x01]);

        // Single value: encrypted bytes, provider "KS", path "P", algorithm "A".
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0x11, 0x22, 0x33, 0x44]);
        data.extend_from_slice(&[0x02]);
        data.extend_from_slice(&[b'K', 0x00, b'S', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]);
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
        assert_eq!(value.key_store_provider_name, "KS");
        assert_eq!(value.cmk_path, "P");
        assert_eq!(value.encryption_algorithm, "A");
    }

    #[test]
    fn test_cek_table_entry_decode_short_header() {
        // One byte short of the 21-byte fixed header.
        let data = [0u8; 20];

        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekTableEntry::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_table_entry_without_values_has_no_primary_value() {
        // All identifiers zero and value count zero.
        let data = [0u8; 21];

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert!(entry.values.is_empty());
        assert!(entry.primary_value().is_none());
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();

        // Entry count = 1
        data.extend_from_slice(&[0x01, 0x00]);

        // Entry header: database_id, cek_id, cek_version, cek_md_version, value count.
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01]);
        // Value: encrypted bytes, provider "K", path "P", algorithm "A".
        data.extend_from_slice(&[0x02, 0x00]);
        data.extend_from_slice(&[0xAB, 0xCD]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'K', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]);
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
        assert!(table.get(1).is_none());
    }

    #[test]
    fn test_cek_table_decode_empty() {
        let data = [0x00, 0x00];

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode empty table");

        assert!(table.is_empty());
        assert!(table.get(0).is_none());
    }

    #[test]
    fn test_cek_table_decode_missing_count() {
        let data = [0x01];

        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekTable::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_table_decode_count_exceeds_data() {
        // Claims 65 535 entries but carries no entry bytes at all.
        let data = [0xFF, 0xFF];

        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekTable::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_is_column_encrypted() {
        assert!(!is_column_encrypted(0x0000));
        // Unrelated flag bit only.
        assert!(!is_column_encrypted(0x0001));
        // Encrypted bit alone.
        assert!(is_column_encrypted(0x0800));
        // Encrypted bit combined with another flag.
        assert!(is_column_encrypted(0x0801));
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}