1use bytes::{Buf, Bytes};
52
53use crate::codec::{read_b_varchar, read_us_varchar};
54use crate::error::ProtocolError;
55use crate::prelude::*;
56
/// Bit in a column's flags word marking the column as encrypted
/// (bit 11, i.e. `0x0800`); tested by [`is_column_encrypted`].
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// `algorithm_id` wire value for AEAD_AES_256_CBC_HMAC_SHA256; see
/// [`CryptoMetadata::is_aead_aes_256`].
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 0x02;

/// Encryption-type wire byte for [`EncryptionTypeWire::Deterministic`].
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 0x01;

/// Encryption-type wire byte for [`EncryptionTypeWire::Randomized`].
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 0x02;

/// Normalization rule version value `1`.
/// NOTE(review): presumably the version this client expects in
/// `CryptoMetadata::normalization_version` — confirm against callers.
pub const NORMALIZATION_RULE_VERSION: u8 = 0x01;
71
/// One entry of a [`CekTable`], as produced by [`CekTableEntry::decode`].
///
/// The scalar fields are listed in the order they are read from the wire
/// (all little-endian); `values` follows a one-byte count.
#[derive(Debug, Clone)]
pub struct CekTableEntry {
    /// First `u32` of the entry header (database id).
    pub database_id: u32,
    /// Second `u32` of the entry header (CEK id).
    pub cek_id: u32,
    /// Third `u32` of the entry header (CEK version).
    pub cek_version: u32,
    /// `u64` CEK metadata version that follows the three `u32`s.
    pub cek_md_version: u64,
    /// Encrypted copies of the key; the count is a single byte on the wire,
    /// so at most 255 values. [`CekTableEntry::primary_value`] returns the first.
    pub values: Vec<CekValue>,
}
89
/// One encrypted value of a column encryption key, as produced by
/// [`CekValue::decode`].
#[derive(Debug, Clone)]
pub struct CekValue {
    /// Encrypted key bytes, prefixed on the wire by a little-endian `u16` length.
    pub encrypted_value: Bytes,
    /// Key store provider name, read with `read_b_varchar`.
    pub key_store_provider_name: String,
    /// Column master key path, read with `read_us_varchar`.
    pub cmk_path: String,
    /// Name of the algorithm used to encrypt `encrypted_value`, read with
    /// `read_b_varchar`.
    pub encryption_algorithm: String,
}
105
/// Per-column encryption metadata attached to an encrypted column, as
/// produced by [`CryptoMetadata::decode`].
#[derive(Debug, Clone)]
pub struct CryptoMetadata {
    /// Position of this column's key in the [`CekTable`] (the type matches
    /// the `ordinal` argument of [`CekTable::get`]).
    pub cek_table_ordinal: u16,
    /// Little-endian `u32` user type of the underlying plaintext column.
    pub base_user_type: u32,
    /// Raw type byte of the underlying plaintext column; see
    /// [`CryptoMetadata::base_type_id`] for the decoded form.
    pub base_col_type: u8,
    /// Type info of the plaintext column, decoded according to `base_col_type`.
    pub base_type_info: crate::token::TypeInfo,
    /// Encryption algorithm id; compare with
    /// [`ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256`].
    pub algorithm_id: u8,
    /// Deterministic vs. randomized encryption.
    pub encryption_type: EncryptionTypeWire,
    /// Raw normalization-rule version byte, stored as received.
    pub normalization_version: u8,
}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137#[non_exhaustive]
138pub enum EncryptionTypeWire {
139 Deterministic,
141 Randomized,
143}
144
145impl EncryptionTypeWire {
146 #[must_use]
148 pub fn from_u8(value: u8) -> Option<Self> {
149 match value {
150 ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
151 ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
152 _ => None,
153 }
154 }
155
156 #[must_use]
158 pub fn to_u8(self) -> u8 {
159 match self {
160 Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
161 Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
162 }
163 }
164}
165
166#[derive(Debug, Clone, Default)]
168pub struct CekTable {
169 pub entries: Vec<CekTableEntry>,
171}
172
173impl CekTable {
174 #[must_use]
176 pub fn new() -> Self {
177 Self::default()
178 }
179
180 #[must_use]
182 pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
183 self.entries.get(ordinal as usize)
184 }
185
186 #[must_use]
188 pub fn is_empty(&self) -> bool {
189 self.entries.is_empty()
190 }
191
192 #[must_use]
194 pub fn len(&self) -> usize {
195 self.entries.len()
196 }
197
198 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
223 if src.remaining() < 2 {
224 return Err(ProtocolError::UnexpectedEof);
225 }
226
227 let cek_count = src.get_u16_le() as usize;
228
229 let mut entries = Vec::with_capacity(cek_count);
230
231 for _ in 0..cek_count {
232 let entry = CekTableEntry::decode(src)?;
233 entries.push(entry);
234 }
235
236 Ok(Self { entries })
237 }
238}
239
240impl CekTableEntry {
241 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
243 if src.remaining() < 21 {
245 return Err(ProtocolError::UnexpectedEof);
246 }
247
248 let database_id = src.get_u32_le();
249 let cek_id = src.get_u32_le();
250 let cek_version = src.get_u32_le();
251 let cek_md_version = src.get_u64_le();
252 let value_count = src.get_u8() as usize;
253
254 let mut values = Vec::with_capacity(value_count);
255
256 for _ in 0..value_count {
257 let value = CekValue::decode(src)?;
258 values.push(value);
259 }
260
261 Ok(Self {
262 database_id,
263 cek_id,
264 cek_version,
265 cek_md_version,
266 values,
267 })
268 }
269
270 #[must_use]
272 pub fn primary_value(&self) -> Option<&CekValue> {
273 self.values.first()
274 }
275}
276
277impl CekValue {
278 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
280 if src.remaining() < 2 {
282 return Err(ProtocolError::UnexpectedEof);
283 }
284
285 let encrypted_value_length = src.get_u16_le() as usize;
286
287 if src.remaining() < encrypted_value_length {
288 return Err(ProtocolError::UnexpectedEof);
289 }
290
291 let encrypted_value = src.copy_to_bytes(encrypted_value_length);
292
293 let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
295
296 let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
298
299 let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
301
302 Ok(Self {
303 encrypted_value,
304 key_store_provider_name,
305 cmk_path,
306 encryption_algorithm,
307 })
308 }
309}
310
311impl CryptoMetadata {
312 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
325 if src.remaining() < 7 {
327 return Err(ProtocolError::UnexpectedEof);
328 }
329
330 let cek_table_ordinal = src.get_u16_le();
331 let base_user_type = src.get_u32_le();
332 let base_col_type = src.get_u8();
333
334 let base_type_id =
336 crate::types::TypeId::from_u8(base_col_type).unwrap_or(crate::types::TypeId::Null);
337 let base_type_info = crate::token::decode_type_info(src, base_type_id, base_col_type)?;
338
339 if src.remaining() < 3 {
341 return Err(ProtocolError::UnexpectedEof);
342 }
343
344 let algorithm_id = src.get_u8();
345 let encryption_type_byte = src.get_u8();
346 let normalization_version = src.get_u8();
347
348 let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
349 ProtocolError::InvalidField {
350 field: "encryption_type",
351 value: encryption_type_byte as u32,
352 },
353 )?;
354
355 Ok(Self {
356 cek_table_ordinal,
357 base_user_type,
358 base_col_type,
359 base_type_info,
360 algorithm_id,
361 encryption_type,
362 normalization_version,
363 })
364 }
365
366 #[must_use]
368 pub fn is_aead_aes_256(&self) -> bool {
369 self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
370 }
371
372 #[must_use]
374 pub fn is_deterministic(&self) -> bool {
375 self.encryption_type == EncryptionTypeWire::Deterministic
376 }
377
378 #[must_use]
380 pub fn is_randomized(&self) -> bool {
381 self.encryption_type == EncryptionTypeWire::Randomized
382 }
383
384 #[must_use]
386 pub fn base_type_id(&self) -> crate::types::TypeId {
387 crate::types::TypeId::from_u8(self.base_col_type).unwrap_or(crate::types::TypeId::Null)
388 }
389}
390
391#[derive(Debug, Clone, Default)]
396pub struct ColumnCryptoInfo {
397 pub crypto_metadata: Option<CryptoMetadata>,
399}
400
401impl ColumnCryptoInfo {
402 #[must_use]
404 pub fn unencrypted() -> Self {
405 Self {
406 crypto_metadata: None,
407 }
408 }
409
410 #[must_use]
412 pub fn encrypted(metadata: CryptoMetadata) -> Self {
413 Self {
414 crypto_metadata: Some(metadata),
415 }
416 }
417
418 #[must_use]
420 pub fn is_encrypted(&self) -> bool {
421 self.crypto_metadata.is_some()
422 }
423}
424
425#[must_use]
427pub fn is_column_encrypted(flags: u16) -> bool {
428 (flags & COLUMN_FLAG_ENCRYPTED) != 0
429}
430
/// Unit tests for the crypto-metadata and CEK-table decoders, driven by
/// hand-built little-endian byte fixtures.
#[cfg(test)]
// unwrap/expect are acceptable here: a decode failure should fail the test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Known wire bytes map to variants and back; unknown bytes map to `None`.
    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    /// Decodes crypto metadata for a 4-byte IntN base column.
    #[test]
    fn test_crypto_metadata_decode() {
        let data = [
            0x00, 0x00, // cek_table_ordinal (u16 LE) = 0
            0x00, 0x00, 0x00, 0x00, // base_user_type (u32 LE) = 0
            0x26, // base_col_type = 0x26 (IntN, per the base_type_id assertion)
            0x04, // base type info: max length 4
            0x02, // algorithm_id = ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
            0x01, // encryption type = deterministic
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(metadata.base_user_type, 0);
        assert_eq!(metadata.base_col_type, 0x26);
        assert_eq!(metadata.base_type_info.max_length, Some(4));
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());

        assert_eq!(metadata.base_type_id(), crate::types::TypeId::IntN);
    }

    /// Decodes a single CEK value; the string fields are UTF-16LE with a
    /// character-count prefix (1 byte for B_VARCHAR, 2 bytes for US_VARCHAR).
    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();

        // Encrypted value: u16 LE length 4, then the key bytes.
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        // key_store_provider_name: 1-byte count 4, "TEST".
        data.extend_from_slice(&[0x04]);
        data.extend_from_slice(&[b'T', 0x00, b'E', 0x00, b'S', 0x00, b'T', 0x00]);
        // cmk_path: u16 LE count 4, "key1".
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[b'k', 0x00, b'e', 0x00, b'y', 0x00, b'1', 0x00]);
        // encryption_algorithm: 1-byte count 3, "RSA".
        data.extend_from_slice(&[0x03]);
        data.extend_from_slice(&[b'R', 0x00, b'S', 0x00, b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    /// Decodes one CEK table entry holding a single value.
    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();

        // database_id = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        // cek_id = 2
        data.extend_from_slice(&[0x02, 0x00, 0x00, 0x00]);
        // cek_version = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        // cek_md_version = 100
        data.extend_from_slice(&[0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        // value count = 1
        data.extend_from_slice(&[0x01]);

        // The single value: 4 key bytes, provider "KS", path "P", algorithm "A".
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0x11, 0x22, 0x33, 0x44]);
        data.extend_from_slice(&[0x02]);
        data.extend_from_slice(&[b'K', 0x00, b'S', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]);
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'A', 0x00]);
        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    /// Decodes a table with one entry that holds one value.
    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();

        // cek_count (u16 LE) = 1
        data.extend_from_slice(&[0x01, 0x00]);

        // Entry header: database_id, cek_id, cek_version = 1; cek_md_version = 1;
        // value count = 1. Then the value: 2 key bytes and provider "K".
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[0x02, 0x00]);
        data.extend_from_slice(&[0xAB, 0xCD]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'K', 0x00]);
        // cmk_path "P"
        data.extend_from_slice(&[0x01, 0x00]);
        data.extend_from_slice(&[b'P', 0x00]);
        // encryption_algorithm "A"
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
    }

    /// Only bit 0x0800 of the column flags marks encryption.
    #[test]
    fn test_is_column_encrypted() {
        assert!(!is_column_encrypted(0x0000));
        assert!(!is_column_encrypted(0x0001));
        assert!(is_column_encrypted(0x0800));
        assert!(is_column_encrypted(0x0801));
    }

    /// `is_encrypted` follows the presence of crypto metadata.
    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: crate::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}