Skip to main content

ironcore_documents/v5/
key_id_header.rs

1use bytes::Bytes;
2use itertools::Itertools;
3use protobuf::Message;
4
5use crate::{Error, Result, vector_encryption_metadata::VectorEncryptionMetadata};
6use std::fmt::Display;
7
// This file contains functions for working with our key id header value.
// The header is 6 bytes long and has the following structure:
// - 4 bytes: the key id, a u32 encoded in big-endian format.
// - 1 byte: the high 4 bits denote which type of EDEK the id points to (Standalone, SaaS Shield,
//   DCP), and the low 4 bits denote which type of data follows the header (vector metadata,
//   IronCore Edoc, deterministic ciphertext).
// - 1 byte: always 0.
14
// Numeric values for each EdekType. They occupy the high 4 bits of the header's type byte, so a
// byte read from a header must have its low 4 bits masked off (`& 0xF0`) before it is compared
// against these values.
const SAAS_SHIELD_EDEK_TYPE_NUM: u8 = 0x00;
const STANDALONE_EDEK_TYPE_NUM: u8 = 0x80;
const DCP_EDEK_TYPE_NUM: u8 = 0x40;
20
// Numeric values for each PayloadType. They occupy the low 4 bits of the header's type byte, so
// a byte read from a header must have its high 4 bits masked off (`& 0x0F`) before it is
// compared against these values.
const DETERMINISTIC_PAYLOAD_TYPE_NUM: u8 = 0x00;
const VECTOR_METADATA_PAYLOAD_TYPE_NUM: u8 = 0x01;
const STANDARD_EDEK_PAYLOAD_TYPE_NUM: u8 = 0x02;
26
/// Length in bytes of the key id header: a 4 byte big-endian key id, 1 byte holding the EDEK and
/// payload types, and 1 trailing zero byte.
pub(crate) const KEY_ID_HEADER_LEN: usize = 4 + 1 + 1;
28
/// The numeric key id carried in the first 4 bytes of a key id header (big-endian `u32`).
///
/// Derives `Hash` so ids can be used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KeyId(pub u32);
31
/// The kind of data that follows a key id header. Its numeric value lives in the low 4 bits of
/// the header's type byte.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by declaration order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum PayloadType {
    /// Deterministic ciphertext for a field.
    DeterministicField,
    /// Vector encryption metadata.
    VectorMetadata,
    /// A standard EDEK.
    StandardEdek,
}
38
39impl Display for PayloadType {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            PayloadType::DeterministicField => write!(f, "Deterministic Field"),
43            PayloadType::VectorMetadata => write!(f, "Vector Metadata"),
44            PayloadType::StandardEdek => write!(f, "Standard EDEK"),
45        }
46    }
47}
48
49impl PayloadType {
50    pub(crate) fn to_numeric_value(self) -> u8 {
51        match self {
52            PayloadType::DeterministicField => DETERMINISTIC_PAYLOAD_TYPE_NUM,
53            PayloadType::VectorMetadata => VECTOR_METADATA_PAYLOAD_TYPE_NUM,
54            PayloadType::StandardEdek => STANDARD_EDEK_PAYLOAD_TYPE_NUM,
55        }
56    }
57
58    pub(crate) fn from_numeric_value(candidate: &u8) -> Result<PayloadType> {
59        let masked_candidate = candidate & 0x0F; // Mask off the top 4 bits.
60        match masked_candidate {
61            DETERMINISTIC_PAYLOAD_TYPE_NUM => Ok(PayloadType::DeterministicField),
62            VECTOR_METADATA_PAYLOAD_TYPE_NUM => Ok(PayloadType::VectorMetadata),
63            STANDARD_EDEK_PAYLOAD_TYPE_NUM => Ok(PayloadType::StandardEdek),
64            _ => Err(Error::PayloadTypeError(format!(
65                "Byte {masked_candidate} isn't a valid payload type."
66            ))),
67        }
68    }
69}
70
/// Which type of EDEK the key id in a header points to. Its numeric value lives in the high 4
/// bits of the header's type byte.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by declaration order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum EdekType {
    /// A Standalone EDEK.
    Standalone,
    /// A SaaS Shield EDEK.
    SaasShield,
    /// A Data Control Platform (DCP) EDEK.
    DataControlPlatform,
}
77
78impl Display for EdekType {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            EdekType::Standalone => write!(f, "Standalone"),
82            EdekType::SaasShield => write!(f, "SaaS Shield"),
83            EdekType::DataControlPlatform => write!(f, "Data Control Platform"),
84        }
85    }
86}
87
88impl EdekType {
89    pub(crate) fn to_numeric_value(self) -> u8 {
90        match self {
91            EdekType::SaasShield => SAAS_SHIELD_EDEK_TYPE_NUM,
92            EdekType::Standalone => STANDALONE_EDEK_TYPE_NUM,
93            EdekType::DataControlPlatform => DCP_EDEK_TYPE_NUM,
94        }
95    }
96
97    pub(crate) fn from_numeric_value(candidate: &u8) -> Result<EdekType> {
98        let masked_candidate = candidate & 0xF0; // Mask off the bottom 4 bits.
99        match masked_candidate {
100            SAAS_SHIELD_EDEK_TYPE_NUM => Ok(EdekType::SaasShield),
101            STANDALONE_EDEK_TYPE_NUM => Ok(EdekType::Standalone),
102            DCP_EDEK_TYPE_NUM => Ok(EdekType::DataControlPlatform),
103            _ => Err(Error::EdekTypeError(format!(
104                "Byte {masked_candidate} isn't a valid edek type."
105            ))),
106        }
107    }
108}
109
110/// The key id header parsed into its pieces.
111#[derive(Debug, PartialEq)]
112pub struct KeyIdHeader {
113    pub key_id: KeyId,
114    pub edek_type: EdekType,
115    pub payload_type: PayloadType,
116}
117
118impl KeyIdHeader {
119    pub fn new(edek_type: EdekType, payload_type: PayloadType, key_id: KeyId) -> KeyIdHeader {
120        KeyIdHeader {
121            edek_type,
122            payload_type,
123            key_id,
124        }
125    }
126
127    /// Write this header onto the front of the document.
128    pub fn put_header_on_document<U: IntoIterator<Item = u8>>(&self, document: U) -> Bytes {
129        self.write_to_bytes().into_iter().chain(document).collect()
130    }
131
132    /// Write the header to bytes. This is done by writing the key_id to be 4 bytes, putting the edek and payload types into
133    /// the next byte and padding with a zero. See the comment at the top of this file for more information.
134    pub fn write_to_bytes(&self) -> Bytes {
135        let iter = u32::to_be_bytes(self.key_id.0).into_iter().chain([
136            self.edek_type.to_numeric_value() | self.payload_type.to_numeric_value(),
137            0u8,
138        ]);
139        Bytes::from_iter(iter)
140    }
141
142    /// This is not public because callers should use use decode_version_prefixed_value instead.
143    pub(crate) fn parse_from_bytes(b: [u8; 6]) -> Result<KeyIdHeader> {
144        let [one, two, three, four, five, six] = b;
145        if six == 0u8 {
146            let key_id = KeyId(u32::from_be_bytes([one, two, three, four]));
147            let edek_type = EdekType::from_numeric_value(&five)?;
148            let payload_type = PayloadType::from_numeric_value(&five)?;
149            Ok(KeyIdHeader {
150                edek_type,
151                payload_type,
152                key_id,
153            })
154        } else {
155            Err(Error::KeyIdHeaderMalformed(format!(
156                "The last byte of the header should be 0, but it was {six}"
157            )))
158        }
159    }
160}
161
162/// Create the key_id_header and vector metadata. The first value is the key_id header and the
163/// second is the vector metadata. These can be passed to encode_vector_metadata to create a single
164/// byte string.
165pub fn create_vector_metadata(
166    key_id_header: KeyIdHeader,
167    iv: Bytes,
168    auth_hash: Bytes,
169) -> (Bytes, VectorEncryptionMetadata) {
170    let vector_encryption_metadata = VectorEncryptionMetadata {
171        iv,
172        auth_hash,
173        ..Default::default()
174    };
175    (key_id_header.write_to_bytes(), vector_encryption_metadata)
176}
177
178/// Form the bytes that represent the vector metadata to the outside world.
179/// This is the protobuf with the key_id_header put onto the front.
180pub fn encode_vector_metadata(
181    key_id_header_bytes: Bytes,
182    vector_metadata: VectorEncryptionMetadata,
183) -> Bytes {
184    key_id_header_bytes
185        .into_iter()
186        .chain(
187            vector_metadata
188                .write_to_bytes()
189                .expect("Writing to in memory bytes failed"),
190        )
191        .collect_vec()
192        .into()
193}
194
195/// Decode a value which has the key_id_header put on the front by breaking it up.
196/// This returns the key id, edek type and the remaining bytes.
197pub fn decode_version_prefixed_value(mut value: Bytes) -> Result<(KeyIdHeader, Bytes)> {
198    let value_len = value.len();
199    if value_len >= KEY_ID_HEADER_LEN {
200        let rest = value.split_off(KEY_ID_HEADER_LEN);
201        match value[..] {
202            [one, two, three, four, five, six] => {
203                let key_id_header =
204                    KeyIdHeader::parse_from_bytes([one, two, three, four, five, six])?;
205                Ok((key_id_header, rest))
206            }
207            // This should not ever be able to happen since we sliced off 6 above
208            _ => Err(Error::KeyIdHeaderTooShort(value_len)),
209        }
210    } else {
211        Err(Error::KeyIdHeaderTooShort(value_len))
212    }
213}
214
215/// Get the bytes that can be used for a prefix search of key_id headers.
216pub fn get_prefix_bytes_for_search(key_id_header: KeyIdHeader) -> Bytes {
217    key_id_header.write_to_bytes()
218}
219
#[cfg(test)]
mod test {

    use super::*;
    #[test]
    fn test_create_produces_saas_shield() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::SaasShield,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, SAAS_SHIELD_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_create_produces_standalone() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::Standalone,
                PayloadType::DeterministicField,
                KeyId(72000),
            ),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );
        assert_eq!(
            header.to_vec(),
            vec![0, 1, 25, 64, STANDALONE_EDEK_TYPE_NUM, 0]
        );
        assert_eq!(result.iv, iv_bytes);
        assert_eq!(result.auth_hash, auth_hash_bytes);
    }

    #[test]
    fn test_create_packs_dcp_and_vector_metadata_into_one_byte() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let (header, _) = create_vector_metadata(
            KeyIdHeader::new(
                EdekType::DataControlPlatform,
                PayloadType::VectorMetadata,
                KeyId(72000),
            ),
            iv_bytes,
            auth_hash_bytes,
        );
        // EDEK type in the high nibble, payload type in the low nibble of the fifth byte.
        assert_eq!(
            header.to_vec(),
            vec![
                0,
                1,
                25,
                64,
                DCP_EDEK_TYPE_NUM | VECTOR_METADATA_PAYLOAD_TYPE_NUM,
                0
            ]
        );
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let iv_bytes: Bytes = (1..12).collect_vec().into();
        let auth_hash_bytes: Bytes = (1..16).collect_vec().into();
        let key_id = KeyId(72000);
        let (header, result) = create_vector_metadata(
            KeyIdHeader::new(EdekType::Standalone, PayloadType::StandardEdek, key_id),
            iv_bytes.clone(),
            auth_hash_bytes.clone(),
        );

        let encode_result = encode_vector_metadata(header, result.clone());
        let (final_key_id_header, final_vector_bytes) =
            decode_version_prefixed_value(encode_result).unwrap();
        assert_eq!(final_key_id_header.key_id, key_id);
        assert_eq!(final_key_id_header.edek_type, EdekType::Standalone);
        assert_eq!(final_key_id_header.payload_type, PayloadType::StandardEdek);
        assert_eq!(final_vector_bytes, result.write_to_bytes().unwrap());
    }

    #[test]
    fn test_put_header_on_document() {
        let header = KeyIdHeader::new(
            EdekType::DataControlPlatform,
            PayloadType::VectorMetadata,
            KeyId(1),
        );
        let result = header.put_header_on_document(vec![9u8, 8]);
        assert_eq!(
            result.to_vec(),
            vec![
                0,
                0,
                0,
                1,
                DCP_EDEK_TYPE_NUM | VECTOR_METADATA_PAYLOAD_TYPE_NUM,
                0,
                9,
                8
            ]
        );
    }

    #[test]
    fn test_decode_header_only_gives_empty_rest() {
        let header = KeyIdHeader::new(
            EdekType::DataControlPlatform,
            PayloadType::VectorMetadata,
            KeyId(7),
        );
        let (decoded, rest) = decode_version_prefixed_value(header.write_to_bytes()).unwrap();
        assert_eq!(decoded, header);
        assert!(rest.is_empty());
    }

    #[test]
    fn test_decode_too_short_errors() {
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 1, 2]));
        assert!(matches!(result, Err(Error::KeyIdHeaderTooShort(3))));
    }

    #[test]
    fn test_decode_nonzero_padding_errors() {
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0, 7]));
        assert!(matches!(result, Err(Error::KeyIdHeaderMalformed(_))));
    }

    #[test]
    fn test_decode_invalid_edek_type_errors() {
        // 0x10 has no known EDEK type in its high nibble.
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0x10, 0]));
        assert!(matches!(result, Err(Error::EdekTypeError(_))));
    }

    #[test]
    fn test_decode_invalid_payload_type_errors() {
        // 0x03 has a valid (SaaS Shield) EDEK nibble but no known payload type in its low nibble.
        let result = decode_version_prefixed_value(Bytes::from_static(&[0, 0, 0, 1, 0x03, 0]));
        assert!(matches!(result, Err(Error::PayloadTypeError(_))));
    }

    fn edek_type_roundtrip(e: EdekType) -> Result<EdekType> {
        EdekType::from_numeric_value(&e.to_numeric_value())
    }
    #[test]
    fn test_edek_type_to_and_from_roundtrip() {
        let all_types = [
            EdekType::Standalone,
            EdekType::SaasShield,
            EdekType::DataControlPlatform,
        ];

        // If you add to this match, add to the array above otherwise the test will pass but you won't be testing them all.
        for e in all_types {
            match e {
                EdekType::Standalone => edek_type_roundtrip(EdekType::Standalone),
                EdekType::SaasShield => edek_type_roundtrip(EdekType::SaasShield),
                EdekType::DataControlPlatform => edek_type_roundtrip(EdekType::DataControlPlatform),
            }
            .unwrap();
        }
    }

    fn payload_type_roundtrip(e: PayloadType) -> Result<PayloadType> {
        PayloadType::from_numeric_value(&e.to_numeric_value())
    }

    #[test]
    fn test_payload_type_to_and_from_roundtrip() {
        let all_types = [
            PayloadType::DeterministicField,
            PayloadType::VectorMetadata,
            PayloadType::StandardEdek,
        ];

        // If you add to this match, add to the array above otherwise the test will pass but you won't be testing them all.
        for e in all_types {
            match e {
                PayloadType::DeterministicField => {
                    payload_type_roundtrip(PayloadType::DeterministicField)
                }
                PayloadType::VectorMetadata => payload_type_roundtrip(PayloadType::VectorMetadata),
                PayloadType::StandardEdek => payload_type_roundtrip(PayloadType::StandardEdek),
            }
            .unwrap();
        }
    }
}