1use bytes::Bytes;
2use itertools::Itertools;
3use protobuf::Message;
4
5use crate::{Error, Result, vector_encryption_metadata::VectorEncryptionMetadata};
6use std::fmt::Display;
7
// Layout of the header's type byte: the EDEK type lives in the high nibble (decoded with a
// `0xF0` mask) and the payload type in the low nibble (decoded with a `0x0F` mask), so one
// tag from each group is OR-ed together when a header is written.

/// High-nibble tag for SaaS Shield EDEKs.
const SAAS_SHIELD_EDEK_TYPE_NUM: u8 = 0x00;
/// High-nibble tag for Standalone EDEKs (`0b1000_0000`).
const STANDALONE_EDEK_TYPE_NUM: u8 = 0x80;
/// High-nibble tag for Data Control Platform EDEKs (`0b0100_0000`).
const DCP_EDEK_TYPE_NUM: u8 = 0x40;

/// Low-nibble tag for deterministically encrypted fields.
const DETERMINISTIC_PAYLOAD_TYPE_NUM: u8 = 0x00;
/// Low-nibble tag for vector metadata payloads.
const VECTOR_METADATA_PAYLOAD_TYPE_NUM: u8 = 0x01;
/// Low-nibble tag for standard EDEK payloads.
const STANDARD_EDEK_PAYLOAD_TYPE_NUM: u8 = 0x02;

/// Length in bytes of a serialized key id header: 4 big-endian key id bytes, one type byte,
/// and one trailing zero byte.
pub(crate) const KEY_ID_HEADER_LEN: usize = 6;
28
/// Identifier of the key referenced by a header; serialized as 4 big-endian bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct KeyId(pub u32);
31
/// Kind of payload that follows a key id header; encoded in the low nibble of the header's
/// type byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PayloadType {
    DeterministicField,
    VectorMetadata,
    StandardEdek,
}

impl Display for PayloadType {
    /// Writes the human-readable name of the payload type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::DeterministicField => "Deterministic Field",
            Self::VectorMetadata => "Vector Metadata",
            Self::StandardEdek => "Standard EDEK",
        };
        f.write_str(label)
    }
}
48
49impl PayloadType {
50 pub(crate) fn to_numeric_value(self) -> u8 {
51 match self {
52 PayloadType::DeterministicField => DETERMINISTIC_PAYLOAD_TYPE_NUM,
53 PayloadType::VectorMetadata => VECTOR_METADATA_PAYLOAD_TYPE_NUM,
54 PayloadType::StandardEdek => STANDARD_EDEK_PAYLOAD_TYPE_NUM,
55 }
56 }
57
58 pub(crate) fn from_numeric_value(candidate: &u8) -> Result<PayloadType> {
59 let masked_candidate = candidate & 0x0F; match masked_candidate {
61 DETERMINISTIC_PAYLOAD_TYPE_NUM => Ok(PayloadType::DeterministicField),
62 VECTOR_METADATA_PAYLOAD_TYPE_NUM => Ok(PayloadType::VectorMetadata),
63 STANDARD_EDEK_PAYLOAD_TYPE_NUM => Ok(PayloadType::StandardEdek),
64 _ => Err(Error::PayloadTypeError(format!(
65 "Byte {masked_candidate} isn't a valid payload type."
66 ))),
67 }
68 }
69}
70
/// Origin of the EDEK referenced by a key id header; encoded in the high nibble of the header's
/// type byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum EdekType {
    Standalone,
    SaasShield,
    DataControlPlatform,
}

impl Display for EdekType {
    /// Writes the human-readable name of the EDEK type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Standalone => "Standalone",
            Self::SaasShield => "SaaS Shield",
            Self::DataControlPlatform => "Data Control Platform",
        };
        f.write_str(label)
    }
}
87
88impl EdekType {
89 pub(crate) fn to_numeric_value(self) -> u8 {
90 match self {
91 EdekType::SaasShield => SAAS_SHIELD_EDEK_TYPE_NUM,
92 EdekType::Standalone => STANDALONE_EDEK_TYPE_NUM,
93 EdekType::DataControlPlatform => DCP_EDEK_TYPE_NUM,
94 }
95 }
96
97 pub(crate) fn from_numeric_value(candidate: &u8) -> Result<EdekType> {
98 let masked_candidate = candidate & 0xF0; match masked_candidate {
100 SAAS_SHIELD_EDEK_TYPE_NUM => Ok(EdekType::SaasShield),
101 STANDALONE_EDEK_TYPE_NUM => Ok(EdekType::Standalone),
102 DCP_EDEK_TYPE_NUM => Ok(EdekType::DataControlPlatform),
103 _ => Err(Error::EdekTypeError(format!(
104 "Byte {masked_candidate} isn't a valid edek type."
105 ))),
106 }
107 }
108}
109
110#[derive(Debug, PartialEq)]
112pub struct KeyIdHeader {
113 pub key_id: KeyId,
114 pub edek_type: EdekType,
115 pub payload_type: PayloadType,
116}
117
118impl KeyIdHeader {
119 pub fn new(edek_type: EdekType, payload_type: PayloadType, key_id: KeyId) -> KeyIdHeader {
120 KeyIdHeader {
121 edek_type,
122 payload_type,
123 key_id,
124 }
125 }
126
127 pub fn put_header_on_document<U: IntoIterator<Item = u8>>(&self, document: U) -> Bytes {
129 self.write_to_bytes().into_iter().chain(document).collect()
130 }
131
132 pub fn write_to_bytes(&self) -> Bytes {
135 let iter = u32::to_be_bytes(self.key_id.0).into_iter().chain([
136 self.edek_type.to_numeric_value() | self.payload_type.to_numeric_value(),
137 0u8,
138 ]);
139 Bytes::from_iter(iter)
140 }
141
142 pub(crate) fn parse_from_bytes(b: [u8; 6]) -> Result<KeyIdHeader> {
144 let [one, two, three, four, five, six] = b;
145 if six == 0u8 {
146 let key_id = KeyId(u32::from_be_bytes([one, two, three, four]));
147 let edek_type = EdekType::from_numeric_value(&five)?;
148 let payload_type = PayloadType::from_numeric_value(&five)?;
149 Ok(KeyIdHeader {
150 edek_type,
151 payload_type,
152 key_id,
153 })
154 } else {
155 Err(Error::KeyIdHeaderMalformed(format!(
156 "The last byte of the header should be 0, but it was {six}"
157 )))
158 }
159 }
160}
161
162pub fn create_vector_metadata(
166 key_id_header: KeyIdHeader,
167 iv: Bytes,
168 auth_hash: Bytes,
169) -> (Bytes, VectorEncryptionMetadata) {
170 let vector_encryption_metadata = VectorEncryptionMetadata {
171 iv,
172 auth_hash,
173 ..Default::default()
174 };
175 (key_id_header.write_to_bytes(), vector_encryption_metadata)
176}
177
178pub fn encode_vector_metadata(
181 key_id_header_bytes: Bytes,
182 vector_metadata: VectorEncryptionMetadata,
183) -> Bytes {
184 key_id_header_bytes
185 .into_iter()
186 .chain(
187 vector_metadata
188 .write_to_bytes()
189 .expect("Writing to in memory bytes failed"),
190 )
191 .collect_vec()
192 .into()
193}
194
195pub fn decode_version_prefixed_value(mut value: Bytes) -> Result<(KeyIdHeader, Bytes)> {
198 let value_len = value.len();
199 if value_len >= KEY_ID_HEADER_LEN {
200 let rest = value.split_off(KEY_ID_HEADER_LEN);
201 match value[..] {
202 [one, two, three, four, five, six] => {
203 let key_id_header =
204 KeyIdHeader::parse_from_bytes([one, two, three, four, five, six])?;
205 Ok((key_id_header, rest))
206 }
207 _ => Err(Error::KeyIdHeaderTooShort(value_len)),
209 }
210 } else {
211 Err(Error::KeyIdHeaderTooShort(value_len))
212 }
213}
214
215pub fn get_prefix_bytes_for_search(key_id_header: KeyIdHeader) -> Bytes {
217 key_id_header.write_to_bytes()
218}
219
#[cfg(test)]
mod test {

    use super::*;

    #[test]
    fn test_create_produces_saas_shield() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::SaasShield,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        // 72000 == 0x0001_1940, written big-endian.
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, SAAS_SHIELD_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_create_produces_standalone() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::Standalone,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, STANDALONE_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let key_id = KeyId(72000);
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(EdekType::Standalone, PayloadType::StandardEdek, key_id),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );

        let encode_result = encode_vector_metadata(header, result.clone());
        let (final_key_id_header, final_vector_bytes) =
            decode_version_prefixed_value(encode_result).unwrap();
        assert_eq!(final_key_id_header.key_id, key_id);
        assert_eq!(final_key_id_header.edek_type, EdekType::Standalone);
        assert_eq!(final_key_id_header.payload_type, PayloadType::StandardEdek);
        assert_eq!(final_vector_bytes, result.write_to_bytes().unwrap());
    }

    #[test]
    fn test_header_roundtrip_dcp_vector_metadata() {
        let header = KeyIdHeader::new(
            EdekType::DataControlPlatform,
            PayloadType::VectorMetadata,
            KeyId(u32::MAX),
        );
        let bytes = header.write_to_bytes();
        assert_eq!(
            bytes.to_vec(),
            vec![
                255,
                255,
                255,
                255,
                DCP_EDEK_TYPE_NUM | VECTOR_METADATA_PAYLOAD_TYPE_NUM,
                0
            ]
        );
        let raw: [u8; KEY_ID_HEADER_LEN] = bytes[..].try_into().unwrap();
        assert_eq!(KeyIdHeader::parse_from_bytes(raw).unwrap(), header);
    }

    #[test]
    fn test_decode_rejects_short_value() {
        let result = decode_version_prefixed_value(Bytes::from(vec![0u8, 1, 2]));
        assert!(matches!(result, Err(Error::KeyIdHeaderTooShort(3))));
    }

    #[test]
    fn test_decode_accepts_exact_header_length() {
        let (header, rest) = decode_version_prefixed_value(Bytes::from(vec![
            0u8,
            0,
            0,
            7,
            STANDALONE_EDEK_TYPE_NUM,
            0,
        ]))
        .unwrap();
        assert_eq!(header.key_id, KeyId(7));
        assert_eq!(header.edek_type, EdekType::Standalone);
        assert_eq!(header.payload_type, PayloadType::DeterministicField);
        assert!(rest.is_empty());
    }

    #[test]
    fn test_parse_rejects_nonzero_trailing_byte() {
        assert!(matches!(
            KeyIdHeader::parse_from_bytes([0, 0, 0, 1, 0, 1]),
            Err(Error::KeyIdHeaderMalformed(_))
        ));
    }

    #[test]
    fn test_parse_rejects_unknown_type_nibbles() {
        // 0x10 in the high nibble is not a known EDEK type tag.
        assert!(matches!(
            KeyIdHeader::parse_from_bytes([0, 0, 0, 1, 0x10, 0]),
            Err(Error::EdekTypeError(_))
        ));
        // 0x03 in the low nibble is not a known payload type tag.
        assert!(matches!(
            KeyIdHeader::parse_from_bytes([0, 0, 0, 1, 0x03, 0]),
            Err(Error::PayloadTypeError(_))
        ));
    }

    fn edek_type_roundtrip(e: EdekType) -> Result<EdekType> {
        EdekType::from_numeric_value(&e.to_numeric_value())
    }

    #[test]
    fn test_edek_type_to_and_from_roundtrip() {
        let all_types = [
            EdekType::Standalone,
            EdekType::SaasShield,
            EdekType::DataControlPlatform,
        ];

        for e in all_types {
            // Exhaustive match (no wildcard) so adding a variant forces this test to be updated.
            let roundtripped = match e {
                EdekType::Standalone => edek_type_roundtrip(EdekType::Standalone),
                EdekType::SaasShield => edek_type_roundtrip(EdekType::SaasShield),
                EdekType::DataControlPlatform => edek_type_roundtrip(EdekType::DataControlPlatform),
            }
            .unwrap();
            // Previously the result was only unwrapped; a wrong variant would have passed.
            assert_eq!(roundtripped, e);
        }
    }

    fn payload_type_roundtrip(e: PayloadType) -> Result<PayloadType> {
        PayloadType::from_numeric_value(&e.to_numeric_value())
    }

    #[test]
    fn test_payload_type_to_and_from_roundtrip() {
        let all_types = [
            PayloadType::DeterministicField,
            PayloadType::VectorMetadata,
            PayloadType::StandardEdek,
        ];

        for e in all_types {
            // Exhaustive match (no wildcard) so adding a variant forces this test to be updated.
            let roundtripped = match e {
                PayloadType::DeterministicField => {
                    payload_type_roundtrip(PayloadType::DeterministicField)
                }
                PayloadType::VectorMetadata => payload_type_roundtrip(PayloadType::VectorMetadata),
                PayloadType::StandardEdek => payload_type_roundtrip(PayloadType::StandardEdek),
            }
            .unwrap();
            assert_eq!(roundtripped, e);
        }
    }
}