1use bytes::{Buf, Bytes};
52
53use crate::codec::{read_b_varchar, read_us_varchar};
54use crate::error::ProtocolError;
55use crate::prelude::*;
56
/// Column flag bit that marks a column as encrypted; tested by [`is_column_encrypted`].
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// Algorithm id for AEAD_AES_256_CBC_HMAC_SHA256, compared against
/// [`CryptoMetadata::algorithm_id`] by [`CryptoMetadata::is_aead_aes_256`].
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 0x02;

/// Wire value for deterministic encryption ([`EncryptionTypeWire::Deterministic`]).
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 0x01;

/// Wire value for randomized encryption ([`EncryptionTypeWire::Randomized`]).
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 0x02;

/// Normalization rule version value. Presumably the version expected in
/// [`CryptoMetadata::normalization_version`] — NOTE(review): not referenced in this file; confirm.
pub const NORMALIZATION_RULE_VERSION: u8 = 0x01;
71
72#[derive(Debug, Clone)]
77pub struct CekTableEntry {
78 pub database_id: u32,
80 pub cek_id: u32,
82 pub cek_version: u32,
84 pub cek_md_version: u64,
86 pub values: Vec<CekValue>,
88}
89
90#[derive(Debug, Clone)]
95pub struct CekValue {
96 pub encrypted_value: Bytes,
98 pub key_store_provider_name: String,
100 pub cmk_path: String,
102 pub encryption_algorithm: String,
104}
105
106#[derive(Debug, Clone)]
118pub struct CryptoMetadata {
119 pub cek_table_ordinal: u16,
121 pub base_user_type: u32,
123 pub base_col_type: u8,
125 pub base_type_info: crate::token::TypeInfo,
127 pub algorithm_id: u8,
129 pub encryption_type: EncryptionTypeWire,
131 pub normalization_version: u8,
133}
134
/// Encryption type of a column as carried on the wire.
///
/// Convert from / to the raw byte with [`EncryptionTypeWire::from_u8`] and
/// [`EncryptionTypeWire::to_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum EncryptionTypeWire {
    /// Equal plaintexts produce equal ciphertexts (wire value [`ENCRYPTION_TYPE_DETERMINISTIC`]).
    Deterministic,
    /// Wire value [`ENCRYPTION_TYPE_RANDOMIZED`].
    Randomized,
}
144
145impl EncryptionTypeWire {
146 #[must_use]
148 pub fn from_u8(value: u8) -> Option<Self> {
149 match value {
150 ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
151 ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
152 _ => None,
153 }
154 }
155
156 #[must_use]
158 pub fn to_u8(self) -> u8 {
159 match self {
160 Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
161 Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
162 }
163 }
164}
165
166#[derive(Debug, Clone, Default)]
168pub struct CekTable {
169 pub entries: Vec<CekTableEntry>,
171}
172
173impl CekTable {
174 #[must_use]
176 pub fn new() -> Self {
177 Self::default()
178 }
179
180 #[must_use]
182 pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
183 self.entries.get(ordinal as usize)
184 }
185
186 #[must_use]
188 pub fn is_empty(&self) -> bool {
189 self.entries.is_empty()
190 }
191
192 #[must_use]
194 pub fn len(&self) -> usize {
195 self.entries.len()
196 }
197
198 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
223 if src.remaining() < 2 {
224 return Err(ProtocolError::UnexpectedEof);
225 }
226
227 let cek_count = src.get_u16_le() as usize;
228
229 let mut entries = Vec::with_capacity(cek_count);
230
231 for _ in 0..cek_count {
232 let entry = CekTableEntry::decode(src)?;
233 entries.push(entry);
234 }
235
236 Ok(Self { entries })
237 }
238}
239
240impl CekTableEntry {
241 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
243 if src.remaining() < 21 {
245 return Err(ProtocolError::UnexpectedEof);
246 }
247
248 let database_id = src.get_u32_le();
249 let cek_id = src.get_u32_le();
250 let cek_version = src.get_u32_le();
251 let cek_md_version = src.get_u64_le();
252 let value_count = src.get_u8() as usize;
253
254 let mut values = Vec::with_capacity(value_count);
255
256 for _ in 0..value_count {
257 let value = CekValue::decode(src)?;
258 values.push(value);
259 }
260
261 Ok(Self {
262 database_id,
263 cek_id,
264 cek_version,
265 cek_md_version,
266 values,
267 })
268 }
269
270 #[must_use]
272 pub fn primary_value(&self) -> Option<&CekValue> {
273 self.values.first()
274 }
275}
276
277impl CekValue {
278 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
280 if src.remaining() < 2 {
282 return Err(ProtocolError::UnexpectedEof);
283 }
284
285 let encrypted_value_length = src.get_u16_le() as usize;
286
287 if src.remaining() < encrypted_value_length {
288 return Err(ProtocolError::UnexpectedEof);
289 }
290
291 let encrypted_value = src.copy_to_bytes(encrypted_value_length);
292
293 let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
295
296 let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
298
299 let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
301
302 Ok(Self {
303 encrypted_value,
304 key_store_provider_name,
305 cmk_path,
306 encryption_algorithm,
307 })
308 }
309}
310
311impl CryptoMetadata {
312 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
325 if src.remaining() < 7 {
327 return Err(ProtocolError::UnexpectedEof);
328 }
329
330 let cek_table_ordinal = src.get_u16_le();
331 let base_user_type = src.get_u32_le();
332 let base_col_type = src.get_u8();
333
334 let base_type_id = crate::types::TypeId::from_u8(base_col_type)
338 .ok_or(ProtocolError::InvalidDataType(base_col_type))?;
339 let base_type_info = crate::token::decode_type_info(src, base_type_id, base_col_type)?;
340
341 if src.remaining() < 3 {
343 return Err(ProtocolError::UnexpectedEof);
344 }
345
346 let algorithm_id = src.get_u8();
347 let encryption_type_byte = src.get_u8();
348 let normalization_version = src.get_u8();
349
350 let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
351 ProtocolError::InvalidField {
352 field: "encryption_type",
353 value: encryption_type_byte as u32,
354 },
355 )?;
356
357 Ok(Self {
358 cek_table_ordinal,
359 base_user_type,
360 base_col_type,
361 base_type_info,
362 algorithm_id,
363 encryption_type,
364 normalization_version,
365 })
366 }
367
368 #[must_use]
370 pub fn is_aead_aes_256(&self) -> bool {
371 self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
372 }
373
374 #[must_use]
376 pub fn is_deterministic(&self) -> bool {
377 self.encryption_type == EncryptionTypeWire::Deterministic
378 }
379
380 #[must_use]
382 pub fn is_randomized(&self) -> bool {
383 self.encryption_type == EncryptionTypeWire::Randomized
384 }
385
386 #[must_use]
388 pub fn base_type_id(&self) -> crate::types::TypeId {
389 crate::types::TypeId::from_u8(self.base_col_type).unwrap_or(crate::types::TypeId::Null)
390 }
391}
392
393#[derive(Debug, Clone, Default)]
398pub struct ColumnCryptoInfo {
399 pub crypto_metadata: Option<CryptoMetadata>,
401}
402
403impl ColumnCryptoInfo {
404 #[must_use]
406 pub fn unencrypted() -> Self {
407 Self {
408 crypto_metadata: None,
409 }
410 }
411
412 #[must_use]
414 pub fn encrypted(metadata: CryptoMetadata) -> Self {
415 Self {
416 crypto_metadata: Some(metadata),
417 }
418 }
419
420 #[must_use]
422 pub fn is_encrypted(&self) -> bool {
423 self.crypto_metadata.is_some()
424 }
425}
426
427#[must_use]
429pub fn is_column_encrypted(flags: u16) -> bool {
430 (flags & COLUMN_FLAG_ENCRYPTED) != 0
431}
432
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Encodes `s` as UTF-16LE bytes.
    fn utf16le(s: &str) -> Vec<u8> {
        s.encode_utf16().flat_map(u16::to_le_bytes).collect()
    }

    /// Encodes `s` as a `B_VARCHAR`: one-byte code-unit count, then UTF-16LE data.
    fn b_varchar(s: &str) -> Vec<u8> {
        let payload = utf16le(s);
        let mut out = vec![u8::try_from(payload.len() / 2).expect("string too long for B_VARCHAR")];
        out.extend(payload);
        out
    }

    /// Encodes `s` as a `US_VARCHAR`: two-byte LE code-unit count, then UTF-16LE data.
    fn us_varchar(s: &str) -> Vec<u8> {
        let payload = utf16le(s);
        let count = u16::try_from(payload.len() / 2).expect("string too long for US_VARCHAR");
        let mut out = count.to_le_bytes().to_vec();
        out.extend(payload);
        out
    }

    /// Encodes a CEK value: length-prefixed key bytes, provider, path and algorithm.
    fn cek_value(key: &[u8], provider: &str, path: &str, algorithm: &str) -> Vec<u8> {
        let key_len = u16::try_from(key.len()).expect("key too long");
        let mut out = key_len.to_le_bytes().to_vec();
        out.extend_from_slice(key);
        out.extend(b_varchar(provider));
        out.extend(us_varchar(path));
        out.extend(b_varchar(algorithm));
        out
    }

    /// Encodes the fixed 21-byte header of a CEK table entry.
    fn cek_entry_header(
        database_id: u32,
        cek_id: u32,
        cek_version: u32,
        md_version: u64,
        value_count: u8,
    ) -> Vec<u8> {
        let mut out = Vec::with_capacity(21);
        out.extend_from_slice(&database_id.to_le_bytes());
        out.extend_from_slice(&cek_id.to_le_bytes());
        out.extend_from_slice(&cek_version.to_le_bytes());
        out.extend_from_slice(&md_version.to_le_bytes());
        out.push(value_count);
        out
    }

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data: [u8; 11] = [
            0x00, 0x00, // CEK table ordinal
            0x00, 0x00, 0x00, 0x00, // user type
            0x26, // base type id (IntN)
            0x04, // IntN length
            0x02, // algorithm id
            0x01, // encryption type (deterministic)
            0x01, // normalization version
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(metadata.base_user_type, 0);
        assert_eq!(metadata.base_col_type, 0x26);
        assert_eq!(metadata.base_type_info.max_length, Some(4));
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());

        assert_eq!(metadata.base_type_id(), crate::types::TypeId::IntN);
    }

    #[test]
    fn test_cek_value_decode() {
        let data = cek_value(&[0xDE, 0xAD, 0xBE, 0xEF], "TEST", "key1", "RSA");

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = cek_entry_header(1, 2, 1, 100, 1);
        data.extend(cek_value(&[0x11, 0x22, 0x33, 0x44], "KS", "P", "A"));

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = 1u16.to_le_bytes().to_vec();
        data.extend(cek_entry_header(1, 1, 1, 1, 1));
        data.extend(cek_value(&[0xAB, 0xCD], "K", "P", "A"));

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
    }

    #[test]
    fn test_is_column_encrypted() {
        let cases: [(u16, bool); 4] = [
            (0x0000, false),
            (0x0001, false),
            (0x0800, true),
            (0x0801, true),
        ];
        for (flags, expected) in cases {
            assert_eq!(is_column_encrypted(flags), expected, "flags = {:#06x}", flags);
        }
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: crate::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}