1use bytes::Bytes;
2use itertools::Itertools;
3use protobuf::Message;
4
5use crate::{Error, vector_encryption_metadata::VectorEncryptionMetadata};
6use std::fmt::Display;
7
// `EdekType` codes. They occupy only the high nibble of the header's type byte
// (`EdekType::from_numeric_value` masks with 0xF0).
const SAAS_SHIELD_EDEK_TYPE_NUM: u8 = 0x00;
const STANDALONE_EDEK_TYPE_NUM: u8 = 0x80;
const DCP_EDEK_TYPE_NUM: u8 = 0x40;

// `PayloadType` codes. They occupy only the low nibble of the header's type byte
// (`PayloadType::from_numeric_value` masks with 0x0F).
const DETERMINISTIC_PAYLOAD_TYPE_NUM: u8 = 0x00;
const VECTOR_METADATA_PAYLOAD_TYPE_NUM: u8 = 0x01;
const STANDARD_EDEK_PAYLOAD_TYPE_NUM: u8 = 0x02;

/// Number of bytes in a serialized key id header: 4 key id bytes, 1 type byte, 1 zero byte.
pub(crate) const KEY_ID_HEADER_LEN: usize = 6;
28
/// Module-private result alias used by the header encode/decode functions below.
// NOTE(review): the rest of this file imports `crate::Error`, while this alias names
// `super::Error`; presumably both resolve to the same type — confirm.
type Result<A> = std::result::Result<A, super::Error>;
/// Numeric identifier of a key, written as the first four (big-endian) bytes of a key id header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KeyId(pub u32);
32
/// The kind of payload that follows a key id header.
///
/// Serialized into the low nibble of the header's type byte. The declaration order of
/// the variants is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PayloadType {
    DeterministicField,
    VectorMetadata,
    StandardEdek,
}

impl Display for PayloadType {
    /// Writes a human-readable name for the payload type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::DeterministicField => "Deterministic Field",
            Self::VectorMetadata => "Vector Metadata",
            Self::StandardEdek => "Standard EDEK",
        };
        f.write_str(label)
    }
}
49
50impl PayloadType {
51 pub(crate) fn to_numeric_value(self) -> u8 {
52 match self {
53 PayloadType::DeterministicField => DETERMINISTIC_PAYLOAD_TYPE_NUM,
54 PayloadType::VectorMetadata => VECTOR_METADATA_PAYLOAD_TYPE_NUM,
55 PayloadType::StandardEdek => STANDARD_EDEK_PAYLOAD_TYPE_NUM,
56 }
57 }
58
59 pub(crate) fn from_numeric_value(candidate: &u8) -> Result<PayloadType> {
60 let masked_candidate = candidate & 0x0F; match masked_candidate {
62 DETERMINISTIC_PAYLOAD_TYPE_NUM => Ok(PayloadType::DeterministicField),
63 VECTOR_METADATA_PAYLOAD_TYPE_NUM => Ok(PayloadType::VectorMetadata),
64 STANDARD_EDEK_PAYLOAD_TYPE_NUM => Ok(PayloadType::StandardEdek),
65 _ => Err(Error::PayloadTypeError(format!(
66 "Byte {masked_candidate} isn't a valid payload type."
67 ))),
68 }
69 }
70}
71
/// The kind of EDEK a key id header refers to.
///
/// Serialized into the high nibble of the header's type byte. The declaration order of
/// the variants is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum EdekType {
    Standalone,
    SaasShield,
    DataControlPlatform,
}

impl Display for EdekType {
    /// Writes a human-readable name for the EDEK type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Standalone => "Standalone",
            Self::SaasShield => "SaaS Shield",
            Self::DataControlPlatform => "Data Control Platform",
        };
        f.write_str(label)
    }
}
88
89impl EdekType {
90 pub(crate) fn to_numeric_value(self) -> u8 {
91 match self {
92 EdekType::SaasShield => SAAS_SHIELD_EDEK_TYPE_NUM,
93 EdekType::Standalone => STANDALONE_EDEK_TYPE_NUM,
94 EdekType::DataControlPlatform => DCP_EDEK_TYPE_NUM,
95 }
96 }
97
98 pub(crate) fn from_numeric_value(candidate: &u8) -> Result<EdekType> {
99 let masked_candidate = candidate & 0xF0; match masked_candidate {
101 SAAS_SHIELD_EDEK_TYPE_NUM => Ok(EdekType::SaasShield),
102 STANDALONE_EDEK_TYPE_NUM => Ok(EdekType::Standalone),
103 DCP_EDEK_TYPE_NUM => Ok(EdekType::DataControlPlatform),
104 _ => Err(Error::EdekTypeError(format!(
105 "Byte {masked_candidate} isn't a valid edek type."
106 ))),
107 }
108 }
109}
110
111#[derive(Debug, PartialEq)]
113pub struct KeyIdHeader {
114 pub key_id: KeyId,
115 pub edek_type: EdekType,
116 pub payload_type: PayloadType,
117}
118
119impl KeyIdHeader {
120 pub fn new(edek_type: EdekType, payload_type: PayloadType, key_id: KeyId) -> KeyIdHeader {
121 KeyIdHeader {
122 edek_type,
123 payload_type,
124 key_id,
125 }
126 }
127
128 pub fn put_header_on_document<U: IntoIterator<Item = u8>>(&self, document: U) -> Bytes {
130 self.write_to_bytes().into_iter().chain(document).collect()
131 }
132
133 pub fn write_to_bytes(&self) -> Bytes {
136 let iter = u32::to_be_bytes(self.key_id.0).into_iter().chain([
137 self.edek_type.to_numeric_value() | self.payload_type.to_numeric_value(),
138 0u8,
139 ]);
140 Bytes::from_iter(iter)
141 }
142
143 pub(crate) fn parse_from_bytes(b: [u8; 6]) -> Result<KeyIdHeader> {
145 let [one, two, three, four, five, six] = b;
146 if six == 0u8 {
147 let key_id = KeyId(u32::from_be_bytes([one, two, three, four]));
148 let edek_type = EdekType::from_numeric_value(&five)?;
149 let payload_type = PayloadType::from_numeric_value(&five)?;
150 Ok(KeyIdHeader {
151 edek_type,
152 payload_type,
153 key_id,
154 })
155 } else {
156 Err(Error::KeyIdHeaderMalformed(format!(
157 "The last byte of the header should be 0, but it was {six}"
158 )))
159 }
160 }
161}
162
163pub fn create_vector_metadata(
167 key_id_header: KeyIdHeader,
168 iv: Bytes,
169 auth_hash: Bytes,
170) -> (Bytes, VectorEncryptionMetadata) {
171 let vector_encryption_metadata = VectorEncryptionMetadata {
172 iv,
173 auth_hash,
174 ..Default::default()
175 };
176 (key_id_header.write_to_bytes(), vector_encryption_metadata)
177}
178
179pub fn encode_vector_metadata(
182 key_id_header_bytes: Bytes,
183 vector_metadata: VectorEncryptionMetadata,
184) -> Bytes {
185 key_id_header_bytes
186 .into_iter()
187 .chain(
188 vector_metadata
189 .write_to_bytes()
190 .expect("Writing to in memory bytes failed"),
191 )
192 .collect_vec()
193 .into()
194}
195
196pub fn decode_version_prefixed_value(mut value: Bytes) -> Result<(KeyIdHeader, Bytes)> {
199 let value_len = value.len();
200 if value_len >= KEY_ID_HEADER_LEN {
201 let rest = value.split_off(KEY_ID_HEADER_LEN);
202 match value[..] {
203 [one, two, three, four, five, six] => {
204 let key_id_header =
205 KeyIdHeader::parse_from_bytes([one, two, three, four, five, six])?;
206 Ok((key_id_header, rest))
207 }
208 _ => Err(Error::KeyIdHeaderTooShort(value_len)),
210 }
211 } else {
212 Err(Error::KeyIdHeaderTooShort(value_len))
213 }
214}
215
216pub fn get_prefix_bytes_for_search(key_id_header: KeyIdHeader) -> Bytes {
218 key_id_header.write_to_bytes()
219}
220
#[cfg(test)]
mod test {

    use super::*;

    #[test]
    fn test_create_produces_saas_shield() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::SaasShield,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, SAAS_SHIELD_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_create_produces_standalone() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::Standalone,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, STANDALONE_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let key_id = KeyId(72000);
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(EdekType::Standalone, PayloadType::StandardEdek, key_id),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );

        let encode_result = encode_vector_metadata(header, result.clone());
        let (final_key_id_header, final_vector_bytes) =
            decode_version_prefixed_value(encode_result).unwrap();
        assert_eq!(final_key_id_header.key_id, key_id);
        assert_eq!(final_key_id_header.edek_type, EdekType::Standalone);
        assert_eq!(final_key_id_header.payload_type, PayloadType::StandardEdek);
        assert_eq!(final_vector_bytes, result.write_to_bytes().unwrap());
    }

    fn edek_type_roundtrip(e: EdekType) -> Result<EdekType> {
        EdekType::from_numeric_value(&e.to_numeric_value())
    }

    #[test]
    fn test_edek_type_to_and_from_roundtrip() {
        let all_types = [
            EdekType::Standalone,
            EdekType::SaasShield,
            EdekType::DataControlPlatform,
        ];

        for e in all_types {
            // Decoding must return exactly the variant that was encoded, not merely succeed.
            assert_eq!(edek_type_roundtrip(e).unwrap(), e);
        }
    }

    fn payload_type_roundtrip(e: PayloadType) -> Result<PayloadType> {
        PayloadType::from_numeric_value(&e.to_numeric_value())
    }

    #[test]
    fn test_payload_type_to_and_from_roundtrip() {
        let all_types = [
            PayloadType::DeterministicField,
            PayloadType::VectorMetadata,
            PayloadType::StandardEdek,
        ];

        for e in all_types {
            // Decoding must return exactly the variant that was encoded, not merely succeed.
            assert_eq!(payload_type_roundtrip(e).unwrap(), e);
        }
    }

    #[test]
    fn test_header_roundtrip_for_every_type_combination() {
        let edek_types = [
            EdekType::Standalone,
            EdekType::SaasShield,
            EdekType::DataControlPlatform,
        ];
        let payload_types = [
            PayloadType::DeterministicField,
            PayloadType::VectorMetadata,
            PayloadType::StandardEdek,
        ];

        // Exercises the combined type byte with non-zero nibbles on both sides.
        for edek_type in edek_types {
            for payload_type in payload_types {
                let header = KeyIdHeader::new(edek_type, payload_type, KeyId(u32::MAX));
                let document = header.put_header_on_document([9u8, 8, 7]);
                let (decoded, rest) = decode_version_prefixed_value(document).unwrap();
                assert_eq!(decoded, header);
                assert_eq!(rest.to_vec(), vec![9u8, 8, 7]);
            }
        }
    }

    #[test]
    fn test_decode_rejects_input_shorter_than_header() {
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 1, 2]));
        assert!(matches!(result, Err(Error::KeyIdHeaderTooShort(3))));
    }

    #[test]
    fn test_decode_rejects_nonzero_last_header_byte() {
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0, 7]));
        assert!(matches!(result, Err(Error::KeyIdHeaderMalformed(_))));
    }

    #[test]
    fn test_decode_rejects_unknown_type_nibbles() {
        // 0x10: the high nibble is not a known EDEK type code.
        let bad_edek = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0x10, 0]));
        assert!(matches!(bad_edek, Err(Error::EdekTypeError(_))));
        // 0x03: valid (SaaS Shield) high nibble, but the low nibble is not a known payload type.
        let bad_payload =
            decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0x03, 0]));
        assert!(matches!(bad_payload, Err(Error::PayloadTypeError(_))));
    }
}