1use bytes::{Buf, Bytes};
52
53#[cfg(not(feature = "std"))]
54use alloc::string::String;
55#[cfg(not(feature = "std"))]
56use alloc::vec::Vec;
57
58use crate::codec::{read_b_varchar, read_us_varchar};
59use crate::error::ProtocolError;
60
/// Column-flags bit (bit 11 of the `u16` flags word) tested by
/// [`is_column_encrypted`].
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// Algorithm id that [`CryptoMetadata::is_aead_aes_256`] checks for.
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 0x02;

/// Wire byte for [`EncryptionTypeWire::Deterministic`].
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 0x01;

/// Wire byte for [`EncryptionTypeWire::Randomized`].
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 0x02;

/// Normalization rule version constant.
///
/// NOTE(review): presumably the version this crate supports; it is not
/// compared against `CryptoMetadata::normalization_version` during decoding.
pub const NORMALIZATION_RULE_VERSION: u8 = 0x01;
75
/// One entry of a [`CekTable`], as produced by [`CekTableEntry::decode`].
///
/// On the wire the four scalar fields appear little-endian in declaration
/// order, followed by a `u8` count and that many [`CekValue`] records.
#[derive(Debug, Clone)]
pub struct CekTableEntry {
    /// Database id (little-endian `u32` on the wire).
    pub database_id: u32,
    /// CEK id (little-endian `u32` on the wire).
    pub cek_id: u32,
    /// CEK version (little-endian `u32` on the wire).
    pub cek_version: u32,
    /// CEK metadata version (little-endian `u64` on the wire).
    pub cek_md_version: u64,
    /// Values decoded with [`CekValue::decode`]. The wire count is a `u8`, so
    /// there are at most 255 of them, and possibly none.
    pub values: Vec<CekValue>,
}
93
/// One value inside a [`CekTableEntry`], as produced by [`CekValue::decode`].
#[derive(Debug, Clone)]
pub struct CekValue {
    /// Encrypted key bytes. On the wire they are preceded by a little-endian
    /// `u16` length.
    pub encrypted_value: Bytes,
    /// Key store provider name (read with `read_b_varchar`).
    pub key_store_provider_name: String,
    /// Key path, presumably the column master key path (read with
    /// `read_us_varchar`). TODO confirm.
    pub cmk_path: String,
    /// Encryption algorithm name (read with `read_b_varchar`).
    pub encryption_algorithm: String,
}
109
/// Per-column encryption metadata, as produced by [`CryptoMetadata::decode`].
#[derive(Debug, Clone)]
pub struct CryptoMetadata {
    /// Index of the key entry. Presumably resolved through [`CekTable::get`];
    /// confirm against callers.
    pub cek_table_ordinal: u16,
    /// Algorithm id. [`CryptoMetadata::is_aead_aes_256`] compares it with
    /// [`ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256`].
    pub algorithm_id: u8,
    /// Deterministic or randomized encryption.
    pub encryption_type: EncryptionTypeWire,
    /// Normalization version byte, stored as read. `decode` does not validate it.
    pub normalization_version: u8,
}
125
/// Encryption type carried in [`CryptoMetadata`].
///
/// The variants have no explicit discriminants, so an `as u8` cast yields 0/1,
/// not the wire values. Use [`EncryptionTypeWire::to_u8`] and
/// [`EncryptionTypeWire::from_u8`] to convert.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncryptionTypeWire {
    /// Wire byte [`ENCRYPTION_TYPE_DETERMINISTIC`] (1).
    Deterministic,
    /// Wire byte [`ENCRYPTION_TYPE_RANDOMIZED`] (2).
    Randomized,
}
134
135impl EncryptionTypeWire {
136 #[must_use]
138 pub fn from_u8(value: u8) -> Option<Self> {
139 match value {
140 ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
141 ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
142 _ => None,
143 }
144 }
145
146 #[must_use]
148 pub fn to_u8(self) -> u8 {
149 match self {
150 Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
151 Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
152 }
153 }
154}
155
/// Table of [`CekTableEntry`] records, as produced by [`CekTable::decode`].
#[derive(Debug, Clone, Default)]
pub struct CekTable {
    /// Entries in the order they were decoded. [`CekTable::get`] indexes into this.
    pub entries: Vec<CekTableEntry>,
}
162
163impl CekTable {
164 #[must_use]
166 pub fn new() -> Self {
167 Self::default()
168 }
169
170 #[must_use]
172 pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
173 self.entries.get(ordinal as usize)
174 }
175
176 #[must_use]
178 pub fn is_empty(&self) -> bool {
179 self.entries.is_empty()
180 }
181
182 #[must_use]
184 pub fn len(&self) -> usize {
185 self.entries.len()
186 }
187
188 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
213 if src.remaining() < 2 {
214 return Err(ProtocolError::UnexpectedEof);
215 }
216
217 let cek_count = src.get_u16_le() as usize;
218
219 let mut entries = Vec::with_capacity(cek_count);
220
221 for _ in 0..cek_count {
222 let entry = CekTableEntry::decode(src)?;
223 entries.push(entry);
224 }
225
226 Ok(Self { entries })
227 }
228}
229
230impl CekTableEntry {
231 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
233 if src.remaining() < 21 {
235 return Err(ProtocolError::UnexpectedEof);
236 }
237
238 let database_id = src.get_u32_le();
239 let cek_id = src.get_u32_le();
240 let cek_version = src.get_u32_le();
241 let cek_md_version = src.get_u64_le();
242 let value_count = src.get_u8() as usize;
243
244 let mut values = Vec::with_capacity(value_count);
245
246 for _ in 0..value_count {
247 let value = CekValue::decode(src)?;
248 values.push(value);
249 }
250
251 Ok(Self {
252 database_id,
253 cek_id,
254 cek_version,
255 cek_md_version,
256 values,
257 })
258 }
259
260 #[must_use]
262 pub fn primary_value(&self) -> Option<&CekValue> {
263 self.values.first()
264 }
265}
266
267impl CekValue {
268 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
270 if src.remaining() < 2 {
272 return Err(ProtocolError::UnexpectedEof);
273 }
274
275 let encrypted_value_length = src.get_u16_le() as usize;
276
277 if src.remaining() < encrypted_value_length {
278 return Err(ProtocolError::UnexpectedEof);
279 }
280
281 let encrypted_value = src.copy_to_bytes(encrypted_value_length);
282
283 let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
285
286 let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
288
289 let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
291
292 Ok(Self {
293 encrypted_value,
294 key_store_provider_name,
295 cmk_path,
296 encryption_algorithm,
297 })
298 }
299}
300
301impl CryptoMetadata {
302 pub const SIZE: usize = 5; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
307 if src.remaining() < Self::SIZE {
308 return Err(ProtocolError::UnexpectedEof);
309 }
310
311 let cek_table_ordinal = src.get_u16_le();
312 let algorithm_id = src.get_u8();
313 let encryption_type_byte = src.get_u8();
314 let normalization_version = src.get_u8();
315
316 let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
317 ProtocolError::InvalidField {
318 field: "encryption_type",
319 value: encryption_type_byte as u32,
320 },
321 )?;
322
323 Ok(Self {
324 cek_table_ordinal,
325 algorithm_id,
326 encryption_type,
327 normalization_version,
328 })
329 }
330
331 #[must_use]
333 pub fn is_aead_aes_256(&self) -> bool {
334 self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
335 }
336
337 #[must_use]
339 pub fn is_deterministic(&self) -> bool {
340 self.encryption_type == EncryptionTypeWire::Deterministic
341 }
342
343 #[must_use]
345 pub fn is_randomized(&self) -> bool {
346 self.encryption_type == EncryptionTypeWire::Randomized
347 }
348}
349
/// Encryption information attached to a single column.
#[derive(Debug, Clone, Default)]
pub struct ColumnCryptoInfo {
    /// `Some` for an encrypted column, `None` otherwise. `None` is also the
    /// `Default`.
    pub crypto_metadata: Option<CryptoMetadata>,
}
359
360impl ColumnCryptoInfo {
361 #[must_use]
363 pub fn unencrypted() -> Self {
364 Self {
365 crypto_metadata: None,
366 }
367 }
368
369 #[must_use]
371 pub fn encrypted(metadata: CryptoMetadata) -> Self {
372 Self {
373 crypto_metadata: Some(metadata),
374 }
375 }
376
377 #[must_use]
379 pub fn is_encrypted(&self) -> bool {
380 self.crypto_metadata.is_some()
381 }
382}
383
384#[must_use]
386pub fn is_column_encrypted(flags: u16) -> bool {
387 (flags & COLUMN_FLAG_ENCRYPTED) != 0
388}
389
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data = [
            0x00, 0x00, // cek_table_ordinal
            0x02, // algorithm_id
            0x01, // encryption_type (deterministic)
            0x01, // normalization_version
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());
    }

    #[test]
    fn test_crypto_metadata_decode_truncated() {
        // One byte short of CryptoMetadata::SIZE.
        let data = [0x00, 0x00, 0x02, 0x01];
        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CryptoMetadata::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_crypto_metadata_decode_rejects_unknown_encryption_type() {
        let data = [0x00, 0x00, 0x02, 0x03, 0x01];
        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CryptoMetadata::decode(&mut cursor),
            Err(ProtocolError::InvalidField {
                field: "encryption_type",
                value: 3,
                ..
            })
        ));
        // The whole record is consumed even though it is rejected.
        assert!(cursor.is_empty());
    }

    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();

        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        data.extend_from_slice(&[0x04]);
        data.extend_from_slice(&[b'T', 0x00, b'E', 0x00, b'S', 0x00, b'T', 0x00]);
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[b'k', 0x00, b'e', 0x00, b'y', 0x00, b'1', 0x00]);
        data.extend_from_slice(&[0x03]);
        data.extend_from_slice(&[b'R', 0x00, b'S', 0x00, b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_value_decode_truncated_encrypted_value() {
        // The length prefix claims 4 key bytes, but only 2 follow.
        let data = [0x04, 0x00, 0xDE, 0xAD];
        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekValue::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();

        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x02, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01]);

        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0x11, 0x22, 0x33, 0x44]);
        data.extend_from_slice(&[0x02]);
        data.extend_from_slice(&[b'K', 0x00, b'S', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]);
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    #[test]
    fn test_cek_table_entry_decode_truncated_header() {
        // One byte short of the fixed 21-byte entry prefix.
        let data = [0u8; 20];
        let mut cursor: &[u8] = &data;
        assert!(matches!(
            CekTableEntry::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();

        data.extend_from_slice(&[0x01, 0x00]);

        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[0x02, 0x00]);
        data.extend_from_slice(&[0xAB, 0xCD]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'K', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]);
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]);
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
        assert!(table.get(1).is_none());
    }

    #[test]
    fn test_cek_table_decode_empty() {
        let mut cursor: &[u8] = &[0x00, 0x00];
        let table = CekTable::decode(&mut cursor).expect("should decode empty table");
        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert!(table.get(0).is_none());
    }

    #[test]
    fn test_cek_table_decode_missing_count() {
        let mut cursor: &[u8] = &[0x01];
        assert!(matches!(
            CekTable::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_cek_table_decode_count_exceeds_data() {
        // Claims 65535 entries but carries no entry bytes at all.
        let mut cursor: &[u8] = &[0xFF, 0xFF];
        assert!(matches!(
            CekTable::decode(&mut cursor),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_is_column_encrypted() {
        assert!(!is_column_encrypted(0x0000));
        assert!(!is_column_encrypted(0x0001));
        assert!(is_column_encrypted(0x0800));
        assert!(is_column_encrypted(0x0801));
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}