ironcore_documents/v5/
key_id_header.rs

1use bytes::Bytes;
2use itertools::Itertools;
3use protobuf::Message;
4
5use crate::{Error, vector_encryption_metadata::VectorEncryptionMetadata};
6use std::fmt::Display;
7
// This file is for functions which work with our key id header value.
// This value has the following structure:
// 4 bytes: the key id. This value is a u32 encoded in big-endian format.
// 1 byte: the high (most significant) 4 bits record which type of EDEK the id points to
//   (Standalone, SaaS Shield, DCP). The low (least significant) 4 bits record which type of
//   data follows the header (vector metadata, IronCore EDOC, deterministic ciphertext).
// 1 byte: always 0.
14
// EdekType numeric values. They live in the high 4 bits of the header's type byte, so a byte
// must have its bottom 4 bits masked off (`byte & 0xF0`) before it is compared to these.
const SAAS_SHIELD_EDEK_TYPE_NUM: u8 = 0x00;
const STANDALONE_EDEK_TYPE_NUM: u8 = 0x80;
const DCP_EDEK_TYPE_NUM: u8 = 0x40;
20
// PayloadType numeric values. They live in the low 4 bits of the header's type byte, so a byte
// must have its top 4 bits masked off (`byte & 0x0F`) before it is compared to these.
const DETERMINISTIC_PAYLOAD_TYPE_NUM: u8 = 0x00;
const VECTOR_METADATA_PAYLOAD_TYPE_NUM: u8 = 0x01;
const STANDARD_EDEK_PAYLOAD_TYPE_NUM: u8 = 0x02;
26
/// Length in bytes of an encoded key id header: 4 key id bytes, 1 type byte, 1 zero byte.
pub(crate) const KEY_ID_HEADER_LEN: usize = 4 + 1 + 1;
28
29type Result<A> = std::result::Result<A, super::Error>;
/// Id of a key, written as the first 4 (big-endian) bytes of the key id header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct KeyId(pub u32);
32
/// Which kind of data follows a key id header; encoded in the low 4 bits of the header's
/// type byte.
///
/// Variant declaration order is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum PayloadType {
    /// Deterministic ciphertext.
    DeterministicField,
    /// Vector encryption metadata.
    VectorMetadata,
    /// Standard EDEK payload.
    StandardEdek,
}
39
40impl Display for PayloadType {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            PayloadType::DeterministicField => write!(f, "Deterministic Field"),
44            PayloadType::VectorMetadata => write!(f, "Vector Metadata"),
45            PayloadType::StandardEdek => write!(f, "Standard EDEK"),
46        }
47    }
48}
49
50impl PayloadType {
51    pub(crate) fn to_numeric_value(self) -> u8 {
52        match self {
53            PayloadType::DeterministicField => DETERMINISTIC_PAYLOAD_TYPE_NUM,
54            PayloadType::VectorMetadata => VECTOR_METADATA_PAYLOAD_TYPE_NUM,
55            PayloadType::StandardEdek => STANDARD_EDEK_PAYLOAD_TYPE_NUM,
56        }
57    }
58
59    pub(crate) fn from_numeric_value(candidate: &u8) -> Result<PayloadType> {
60        let masked_candidate = candidate & 0x0F; // Mask off the top 4 bits.
61        match masked_candidate {
62            DETERMINISTIC_PAYLOAD_TYPE_NUM => Ok(PayloadType::DeterministicField),
63            VECTOR_METADATA_PAYLOAD_TYPE_NUM => Ok(PayloadType::VectorMetadata),
64            STANDARD_EDEK_PAYLOAD_TYPE_NUM => Ok(PayloadType::StandardEdek),
65            _ => Err(Error::PayloadTypeError(format!(
66                "Byte {masked_candidate} isn't a valid payload type."
67            ))),
68        }
69    }
70}
71
/// Which kind of EDEK a key id points to; encoded in the high 4 bits of the header's type byte.
///
/// Variant declaration order is significant: it defines the derived `PartialOrd`/`Ord`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum EdekType {
    /// Standalone EDEK.
    Standalone,
    /// SaaS Shield EDEK.
    SaasShield,
    /// Data Control Platform (DCP) EDEK.
    DataControlPlatform,
}
78
79impl Display for EdekType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            EdekType::Standalone => write!(f, "Standalone"),
83            EdekType::SaasShield => write!(f, "SaaS Shield"),
84            EdekType::DataControlPlatform => write!(f, "Data Control Platform"),
85        }
86    }
87}
88
89impl EdekType {
90    pub(crate) fn to_numeric_value(self) -> u8 {
91        match self {
92            EdekType::SaasShield => SAAS_SHIELD_EDEK_TYPE_NUM,
93            EdekType::Standalone => STANDALONE_EDEK_TYPE_NUM,
94            EdekType::DataControlPlatform => DCP_EDEK_TYPE_NUM,
95        }
96    }
97
98    pub(crate) fn from_numeric_value(candidate: &u8) -> Result<EdekType> {
99        let masked_candidate = candidate & 0xF0; // Mask off the bottom 4 bits.
100        match masked_candidate {
101            SAAS_SHIELD_EDEK_TYPE_NUM => Ok(EdekType::SaasShield),
102            STANDALONE_EDEK_TYPE_NUM => Ok(EdekType::Standalone),
103            DCP_EDEK_TYPE_NUM => Ok(EdekType::DataControlPlatform),
104            _ => Err(Error::EdekTypeError(format!(
105                "Byte {masked_candidate} isn't a valid edek type."
106            ))),
107        }
108    }
109}
110
111/// The key id header parsed into its pieces.
112#[derive(Debug, PartialEq)]
113pub struct KeyIdHeader {
114    pub key_id: KeyId,
115    pub edek_type: EdekType,
116    pub payload_type: PayloadType,
117}
118
119impl KeyIdHeader {
120    pub fn new(edek_type: EdekType, payload_type: PayloadType, key_id: KeyId) -> KeyIdHeader {
121        KeyIdHeader {
122            edek_type,
123            payload_type,
124            key_id,
125        }
126    }
127
128    /// Write this header onto the front of the document.
129    pub fn put_header_on_document<U: IntoIterator<Item = u8>>(&self, document: U) -> Bytes {
130        self.write_to_bytes().into_iter().chain(document).collect()
131    }
132
133    /// Write the header to bytes. This is done by writing the key_id to be 4 bytes, putting the edek and payload types into
134    /// the next byte and padding with a zero. See the comment at the top of this file for more information.
135    pub fn write_to_bytes(&self) -> Bytes {
136        let iter = u32::to_be_bytes(self.key_id.0).into_iter().chain([
137            self.edek_type.to_numeric_value() | self.payload_type.to_numeric_value(),
138            0u8,
139        ]);
140        Bytes::from_iter(iter)
141    }
142
143    /// This is not public because callers should use use decode_version_prefixed_value instead.
144    pub(crate) fn parse_from_bytes(b: [u8; 6]) -> Result<KeyIdHeader> {
145        let [one, two, three, four, five, six] = b;
146        if six == 0u8 {
147            let key_id = KeyId(u32::from_be_bytes([one, two, three, four]));
148            let edek_type = EdekType::from_numeric_value(&five)?;
149            let payload_type = PayloadType::from_numeric_value(&five)?;
150            Ok(KeyIdHeader {
151                edek_type,
152                payload_type,
153                key_id,
154            })
155        } else {
156            Err(Error::KeyIdHeaderMalformed(format!(
157                "The last byte of the header should be 0, but it was {six}"
158            )))
159        }
160    }
161}
162
163/// Create the key_id_header and vector metadata. The first value is the key_id header and the
164/// second is the vector metadata. These can be passed to encode_vector_metadata to create a single
165/// byte string.
166pub fn create_vector_metadata(
167    key_id_header: KeyIdHeader,
168    iv: Bytes,
169    auth_hash: Bytes,
170) -> (Bytes, VectorEncryptionMetadata) {
171    let vector_encryption_metadata = VectorEncryptionMetadata {
172        iv,
173        auth_hash,
174        ..Default::default()
175    };
176    (key_id_header.write_to_bytes(), vector_encryption_metadata)
177}
178
179/// Form the bytes that represent the vector metadata to the outside world.
180/// This is the protobuf with the key_id_header put onto the front.
181pub fn encode_vector_metadata(
182    key_id_header_bytes: Bytes,
183    vector_metadata: VectorEncryptionMetadata,
184) -> Bytes {
185    key_id_header_bytes
186        .into_iter()
187        .chain(
188            vector_metadata
189                .write_to_bytes()
190                .expect("Writing to in memory bytes failed"),
191        )
192        .collect_vec()
193        .into()
194}
195
196/// Decode a value which has the key_id_header put on the front by breaking it up.
197/// This returns the key id, edek type and the remaining bytes.
198pub fn decode_version_prefixed_value(mut value: Bytes) -> Result<(KeyIdHeader, Bytes)> {
199    let value_len = value.len();
200    if value_len >= KEY_ID_HEADER_LEN {
201        let rest = value.split_off(KEY_ID_HEADER_LEN);
202        match value[..] {
203            [one, two, three, four, five, six] => {
204                let key_id_header =
205                    KeyIdHeader::parse_from_bytes([one, two, three, four, five, six])?;
206                Ok((key_id_header, rest))
207            }
208            // This should not ever be able to happen since we sliced off 6 above
209            _ => Err(Error::KeyIdHeaderTooShort(value_len)),
210        }
211    } else {
212        Err(Error::KeyIdHeaderTooShort(value_len))
213    }
214}
215
216/// Get the bytes that can be used for a prefix search of key_id headers.
217pub fn get_prefix_bytes_for_search(key_id_header: KeyIdHeader) -> Bytes {
218    key_id_header.write_to_bytes()
219}
220
#[cfg(test)]
mod test {

    use super::*;
    #[test]
    fn test_create_produces_saas_shield() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::SaasShield,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, SAAS_SHIELD_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_create_produces_standalone() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::Standalone,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, STANDALONE_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let key_id = KeyId(72000);
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(EdekType::Standalone, PayloadType::StandardEdek, key_id),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );

        let encode_result = encode_vector_metadata(header, result.clone());
        let (final_key_id_header, final_vector_bytes) =
            decode_version_prefixed_value(encode_result).unwrap();
        assert_eq!(final_key_id_header.key_id, key_id);
        assert_eq!(final_key_id_header.edek_type, EdekType::Standalone);
        assert_eq!(final_key_id_header.payload_type, PayloadType::StandardEdek);
        assert_eq!(final_vector_bytes, result.write_to_bytes().unwrap());
    }

    fn edek_type_roundtrip(e: EdekType) -> Result<EdekType> {
        EdekType::from_numeric_value(&e.to_numeric_value())
    }
    #[test]
    fn test_edek_type_to_and_from_roundtrip() {
        let all_types = [
            EdekType::Standalone,
            EdekType::SaasShield,
            EdekType::DataControlPlatform,
        ];

        // If you add to this match, add to the array above otherwise the test will pass but you won't be testing them all.
        for e in all_types {
            match e {
                EdekType::Standalone => edek_type_roundtrip(EdekType::Standalone),
                EdekType::SaasShield => edek_type_roundtrip(EdekType::SaasShield),
                EdekType::DataControlPlatform => edek_type_roundtrip(EdekType::DataControlPlatform),
            }
            .unwrap();
        }
    }

    fn payload_type_roundtrip(e: PayloadType) -> Result<PayloadType> {
        PayloadType::from_numeric_value(&e.to_numeric_value())
    }

    #[test]
    fn test_payload_type_to_and_from_roundtrip() {
        let all_types = [
            PayloadType::DeterministicField,
            PayloadType::VectorMetadata,
            PayloadType::StandardEdek,
        ];

        // If you add to this match, add to the array above otherwise the test will pass but you won't be testing them all.
        for e in all_types {
            match e {
                PayloadType::DeterministicField => {
                    payload_type_roundtrip(PayloadType::DeterministicField)
                }
                PayloadType::VectorMetadata => payload_type_roundtrip(PayloadType::VectorMetadata),
                PayloadType::StandardEdek => payload_type_roundtrip(PayloadType::StandardEdek),
            }
            .unwrap();
        }
    }

    #[test]
    fn test_write_to_bytes_packs_edek_and_payload_nibbles() {
        let header = KeyIdHeader::new(
            EdekType::DataControlPlatform,
            PayloadType::VectorMetadata,
            KeyId(1),
        );
        // DCP (0x40) in the high nibble, vector metadata (0x01) in the low nibble.
        assert_eq!(header.write_to_bytes().to_vec(), vec![0, 0, 0, 1, 0x41, 0]);
    }

    #[test]
    fn test_put_header_on_document_prepends_header() {
        let header = KeyIdHeader::new(
            EdekType::Standalone,
            PayloadType::StandardEdek,
            KeyId(72000),
        );
        let result = header.put_header_on_document(vec![9u8, 8, 7]);
        assert_eq!(result.to_vec(), vec![0, 1, 25, 64, 0x82, 0, 9, 8, 7]);
    }

    #[test]
    fn test_decode_header_only_yields_empty_rest() {
        let (header, rest) =
            decode_version_prefixed_value(Bytes::from_static(&[0, 1, 25, 64, 0x82, 0])).unwrap();
        assert_eq!(
            header,
            KeyIdHeader::new(EdekType::Standalone, PayloadType::StandardEdek, KeyId(72000))
        );
        assert!(rest.is_empty());
    }

    #[test]
    fn test_decode_too_short() {
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 1, 25, 64, 0x82]));
        assert!(matches!(result, Err(Error::KeyIdHeaderTooShort(5))));
    }

    #[test]
    fn test_decode_nonzero_padding_byte() {
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 1, 25, 64, 0x82, 1]));
        assert!(matches!(result, Err(Error::KeyIdHeaderMalformed(_))));
    }

    #[test]
    fn test_decode_invalid_edek_type() {
        // High nibble 0x10 is not a known EDEK type.
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0x10, 0]));
        assert!(matches!(result, Err(Error::EdekTypeError(_))));
    }

    #[test]
    fn test_decode_invalid_payload_type() {
        // High nibble is Standalone (0x80) but low nibble 0x03 is not a known payload type.
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0x83, 0]));
        assert!(matches!(result, Err(Error::PayloadTypeError(_))));
    }
}