ironcore_alloy/
util.rs

1use crate::{AlloyMetadata, TenantId, VectorEncryptionKey, errors::AlloyError};
2use ironcore_documents::v5::key_id_header::KeyId;
3use itertools::Either;
4use protobuf::Message;
5use rand::{
6    SeedableRng,
7    rngs::{OsRng, adapter::ReseedingRng},
8};
9use rand_chacha::{ChaCha20Core, ChaCha20Rng};
10use rayon::iter::ParallelIterator;
11use rayon::iter::{IntoParallelIterator, ParallelExtend};
12use ring::hmac::{HMAC_SHA256, HMAC_SHA512, Key as HMACKey};
13use std::hash::Hash;
14use std::{
15    collections::HashMap,
16    sync::{Arc, Mutex, MutexGuard},
17};
18
/// Number of bytes an RNG may produce before it is reseeded (1 MiB).
20const BYTES_BEFORE_RESEEDING: u64 = 1024 * 1024;
21
22pub(crate) type OurReseedingRng = ReseedingRng<ChaCha20Core, OsRng>;
23
/// A 32-byte authentication hash, as produced by `compute_auth_hash`.
///
/// The derived `==` is a plain byte comparison and is not guaranteed to run in constant time.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct AuthHash(pub(crate) [u8; 32]);
26
/// Acquires the mutex `m`, blocking the current thread until the lock is available.
///
/// The lock is released when the returned [`MutexGuard`] falls out of scope.
///
/// # Panics
/// Panics with a message describing the error if the mutex is poisoned, i.e. another thread
/// panicked while holding the lock.
///
/// # Usage:
/// single statement (mut)
/// `let result = take_lock(&t).deref_mut().call_method_on_t();`
///
/// multi-statement (mut)
/// ```ignore
/// let t = Mutex::new(T {});
/// let result = {
///     let g = &mut *take_lock(&t);
///     g.call_method_on_t()
/// }; // lock released here
/// ```
///
pub fn take_lock<T>(m: &Mutex<T>) -> MutexGuard<'_, T> {
    // `'_` makes it visible at the call site that the guard borrows from `m`.
    m.lock()
        .unwrap_or_else(|e| panic!("Error when acquiring lock: {e}"))
}
51
52pub(crate) fn hash256<K: AsRef<[u8]>, T: AsRef<[u8]>>(key: K, payload: T) -> [u8; 32] {
53    let hmac_key = HMACKey::new(HMAC_SHA256, key.as_ref());
54    ring::hmac::sign(&hmac_key, payload.as_ref())
55        .as_ref()
56        // this is safe because digest output len (SHA256_OUTPUT_LEN) == 32
57        .try_into()
58        .unwrap()
59}
60
61pub(crate) fn hash512<K: AsRef<[u8]>, T: AsRef<[u8]>>(key: K, payload: T) -> [u8; 64] {
62    let hmac_key = HMACKey::new(HMAC_SHA512, key.as_ref());
63    ring::hmac::sign(&hmac_key, payload.as_ref())
64        .as_ref()
65        // this is safe because digest output len (SHA256_OUTPUT_LEN) == 32
66        .try_into()
67        .unwrap()
68}
69
70pub(crate) fn compute_auth_hash<'a, A: AsRef<[u8]>, B: Iterator<Item = &'a f32>>(
71    key: &VectorEncryptionKey,
72    approximation_factor: &f32,
73    iv: A,
74    encrypted_embedding: B,
75) -> AuthHash {
76    let hmac_key = HMACKey::new(HMAC_SHA256, &key.key.0);
77    let mut ctx = ring::hmac::Context::with_key(&hmac_key);
78    ctx.update(key.scaling_factor.0.to_be_bytes().as_ref());
79    ctx.update(approximation_factor.to_be_bytes().as_ref());
80    ctx.update(iv.as_ref());
81    for embedding in encrypted_embedding {
82        ctx.update(embedding.to_be_bytes().as_ref());
83    }
84    let signature: [u8; 32] = ctx
85        .sign()
86        .as_ref()
87        .try_into() // this is safe because digest output len (SHA256_OUTPUT_LEN) == 32
88        .unwrap();
89    AuthHash(signature)
90}
91
92pub(crate) fn check_auth_hash<'a, A: AsRef<[u8]>, B: Iterator<Item = &'a f32>>(
93    key: &VectorEncryptionKey,
94    approximation_factor: &f32,
95    iv: A,
96    encrypted_embedding: B,
97    auth_hash: AuthHash,
98) -> bool {
99    compute_auth_hash(key, approximation_factor, iv, encrypted_embedding) == auth_hash
100}
101
102pub(crate) fn create_reseeding_rng() -> Arc<Mutex<OurReseedingRng>> {
103    Arc::new(Mutex::new(ReseedingRng::new(
104        ChaCha20Core::from_entropy(),
105        BYTES_BEFORE_RESEEDING,
106        OsRng,
107    )))
108}
109
110/// Creates a seeded RNG that won't actually ever reseed to use in test functions from the FFI and in the case
111/// that users are creating a client for testing
112pub(crate) fn create_test_seeded_rng(seed: u64) -> Arc<Mutex<OurReseedingRng>> {
113    //Note that this will never actually reseed because the threshold is 0.
114    Arc::new(Mutex::new(ReseedingRng::new(
115        ChaCha20Core::seed_from_u64(seed),
116        0,
117        OsRng,
118    )))
119}
120
121pub(crate) fn create_rng_maybe_seeded(maybe_seed: Option<i32>) -> Arc<Mutex<OurReseedingRng>> {
122    maybe_seed
123        //We don't care that the negative numbers turn into giant numbers for the seed we just need a static value.
124        .map(|seed| create_test_seeded_rng(seed as u64))
125        .unwrap_or_else(create_reseeding_rng)
126}
127
128pub(crate) fn create_rng<K: AsRef<[u8]>, T: AsRef<[u8]>>(key: K, hash_payload: T) -> ChaCha20Rng {
129    ChaCha20Rng::from_seed(hash256(key, hash_payload))
130}
131
132pub(crate) struct BatchResult<K: Hash, U> {
133    pub successes: HashMap<K, U>,
134    pub failures: HashMap<K, AlloyError>,
135}
136
/// Generates a batch result struct.
///
/// Arguments, in order:
/// 1. the name of the struct to create,
/// 2. the success type (the value type of the `successes` map),
/// 3. the key type used by both the `successes` and `failures` maps.
///
/// The generated type is a uniffi Record and implements `From<BatchResult<key, success>>`.
#[macro_export]
macro_rules! create_batch_result_struct {
    ($struct_name:ident, $success_type:ident, $map_key_type:ident) => {
        #[derive(Debug, Clone, uniffi::Record)]
        pub struct $struct_name {
            pub successes: std::collections::HashMap<$map_key_type, $success_type>,
            pub failures: std::collections::HashMap<$map_key_type, $crate::errors::AlloyError>,
        }

        impl From<$crate::util::BatchResult<$map_key_type, $success_type>> for $struct_name {
            fn from(value: $crate::util::BatchResult<$map_key_type, $success_type>) -> Self {
                let $crate::util::BatchResult {
                    successes,
                    failures,
                } = value;
                Self {
                    successes,
                    failures,
                }
            }
        }
    };
}
160
/// Generates a batch result struct whose successes are wrapped in a newtype.
///
/// Arguments, in order:
/// 1. the name of the struct to create,
/// 2. the success type,
/// 3. the key type used by the `failures` map,
/// 4. a newtype containing a `HashMap<key type, success type>`, used for the `successes` field.
///
/// The generated type is a uniffi Record and implements `From<BatchResult<key, success>>`.
#[macro_export]
macro_rules! create_batch_result_struct_using_newtype {
    ($struct_name:ident, $success_type:ident, $map_key_type:ident, $newtype:ident) => {
        #[derive(Debug, Clone, uniffi::Record)]
        pub struct $struct_name {
            pub successes: $newtype,
            pub failures: std::collections::HashMap<$map_key_type, $crate::errors::AlloyError>,
        }

        impl From<$crate::util::BatchResult<$map_key_type, $success_type>> for $struct_name {
            fn from(value: $crate::util::BatchResult<$map_key_type, $success_type>) -> Self {
                let $crate::util::BatchResult {
                    successes,
                    failures,
                } = value;
                Self {
                    successes: $newtype(successes),
                    failures,
                }
            }
        }
    };
}
185
186/// Applies the function `func` to all the values of `collection`, then partitions them into
187/// success and failure hashmaps.
188pub(crate) fn perform_batch_action<T, U, F, I, K>(collection: I, func: F) -> BatchResult<K, U>
189where
190    F: Fn(T) -> Result<U, AlloyError> + Sync,
191    I: IntoParallelIterator<Item = (K, T)>,
192    K: Hash + Eq + Send,
193    U: Send,
194    HashMap<K, U>: ParallelExtend<(K, U)>,
195    HashMap<K, AlloyError>: ParallelExtend<(K, AlloyError)>,
196{
197    let (successes, failures) = collection
198        .into_par_iter()
199        .map(|(key, value)| match func(value) {
200            Ok(x) => Ok((key, x)),
201            Err(x) => Err((key, x)),
202        })
203        .partition_map(|x| match x {
204            Ok(success) => Either::Left(success),
205            Err(failure) => Either::Right(failure),
206        });
207    BatchResult {
208        successes,
209        failures,
210    }
211}
212
213/// Returns `true` if the key IDs and tenant IDs are identical, otherwise `false`.
214pub(crate) fn check_rotation_no_op(
215    encrypted_key_id: KeyId,
216    maybe_current_key: &Option<u32>,
217    new_tenant_id: &TenantId,
218    metadata: &AlloyMetadata,
219) -> bool {
220    maybe_current_key == &Some(encrypted_key_id.0) && new_tenant_id == &metadata.tenant_id
221}
222
223pub(crate) fn v4_proto_from_bytes<B: AsRef<[u8]>>(
224    b: B,
225) -> Result<ironcore_documents::icl_header_v4::V4DocumentHeader, AlloyError> {
226    Ok(Message::parse_from_bytes(b.as_ref())?)
227}
228
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::vector::{EncryptionKey, ScalingFactor};
    use base64::Engine;
    use bytes::Bytes;
    use itertools::Itertools;
    use proptest::prelude::*;
    use std::collections::HashSet;

    /// Removes up to `count` characters from the end of `encoded`.
    fn drop_trailing_chars(mut encoded: String, count: usize) -> String {
        for _ in 0..count {
            encoded.pop();
        }
        encoded
    }

    // Shows that an empty key still produces a reasonable signature.
    #[test]
    fn test_hash_empty_key() {
        let expected: [u8; 32] = [
            61, 122, 251, 102, 49, 36, 236, 191, 44, 149, 63, 134, 61, 79, 200, 121, 110, 235, 45,
            55, 43, 100, 170, 213, 134, 151, 236, 82, 100, 100, 156, 219,
        ];
        assert_eq!(hash256(Bytes::default(), [1u8]), expected);
    }

    proptest! {
        // Checks that values written out and read back by serde_json hash the same as the
        // original f32 values.
        #[test]
        fn roundtrip(arb_msg: Vec<f32>, key: [u8; 32], iv: [u8; 12], scaling_factor: u32) {
            // A generated scaling factor of 0 is replaced with 1.
            let f32_scaling_factor = match scaling_factor {
                0 => 1.,
                nonzero => nonzero as f32,
            };
            let key = VectorEncryptionKey {
                scaling_factor: ScalingFactor(f32_scaling_factor),
                key: EncryptionKey(key.to_vec()),
            };

            let json = serde_json::to_string(&arb_msg).unwrap();
            let reparsed: Vec<f32> = serde_json::from_str(json.as_str()).unwrap();

            let hash_before = compute_auth_hash(&key, &1.2f32, iv, arb_msg.iter());
            let hash_after = compute_auth_hash(&key, &1.2f32, iv, reparsed.iter());

            proptest::prop_assert_eq!(hash_before, hash_after);
        }

        #[test]
        fn ascii85_encoding_produces_consistent_prefix(arb_msg: [u8; 2], id in 1..u32::MAX) {
            let prefix_func = |prefix_bytes: &[u8]| {
                // The prefix has to be padded out to a full 4 byte width so that the algorithm
                // produces the correct character in the 2nd position; base-85 encodings change
                // the last character when the input is not a full 4 byte group.
                // This test pads with the 2 random bytes, showing that the padding's value is
                // irrelevant as long as padding is present.
                let padded = [prefix_bytes, &arb_msg[..]].concat();
                // Drop the trailing `~>` that ascii85 appends (2 chars), then the 3 characters
                // that could be affected by the random bytes.
                drop_trailing_chars(ascii85::encode(padded.as_slice()), 2 + 3)
            };

            let encode_func = |whole: &[u8]| ascii85::encode(whole);
            encoding_produces_consistent_prefix(id, &arb_msg[..], prefix_func, encode_func)?
        }

        // arb_msg is 4 bytes so that padding cases without a full chunk are exercised.
        #[test]
        fn z85_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {
            let prefix_func = |prefix_bytes: &[u8]| {
                // The prefix has to be padded out to a full 4 byte width so that the algorithm
                // produces the correct character in the 2nd position; base-85 encodings change
                // the last character when the input is not a full 4 byte group.
                let padded = [prefix_bytes, &[0, 0]].concat();
                // Drop the 3 characters that could be affected by the random bytes.
                drop_trailing_chars(z85::encode(padded.as_slice()), 3)
            };

            let encode_func = |whole: &[u8]| z85::encode(whole);
            encoding_produces_consistent_prefix(id, &arb_msg[..], prefix_func, encode_func)?
        }

        // arb_msg is 4 bytes so that padding cases without a full chunk are exercised.
        #[test]
        fn base85_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {
            let prefix_func = |prefix_bytes: &[u8]| {
                // The prefix has to be padded out to a full 4 byte width so that the algorithm
                // produces the correct character in the 2nd position; base-85 encodings change
                // the last character when the input is not a full 4 byte group.
                let padded = [prefix_bytes, &[0, 0]].concat();
                // Drop the 3 characters that could be affected by the random bytes.
                drop_trailing_chars(base85::encode(padded.as_slice()), 3)
            };

            let encode_func = |whole: &[u8]| base85::encode(whole);
            encoding_produces_consistent_prefix(id, &arb_msg[..], prefix_func, encode_func)?
        }

        // arb_msg is 4 bytes so that padding cases without a full chunk are exercised.
        #[test]
        fn base64_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {
            let prefix_func = |prefix_bytes: &[u8]| {
                base64::engine::general_purpose::STANDARD_NO_PAD.encode(prefix_bytes)
            };

            let encode_func =
                |whole: &[u8]| base64::engine::general_purpose::STANDARD.encode(whole);
            encoding_produces_consistent_prefix(id, &arb_msg[..], prefix_func, encode_func)?
        }

        // arb_msg is 4 bytes so that padding cases without a full chunk are exercised.
        #[test]
        fn base64_url_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {
            let prefix_func = |prefix_bytes: &[u8]| {
                base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(prefix_bytes)
            };

            let encode_func =
                |whole: &[u8]| base64::engine::general_purpose::URL_SAFE.encode(whole);
            encoding_produces_consistent_prefix(id, &arb_msg[..], prefix_func, encode_func)?
        }
    }

    /// Shared check for the encoding tests above: for several values of the first padding byte,
    /// the separately encoded prefix must start the encoded full message, and every padding value
    /// must give a different prefix.
    ///
    /// Note that this does _not_ work if you allow an id of 0.
    fn encoding_produces_consistent_prefix<F, G>(
        id: u32,
        arb_msg: &[u8],
        encode_prefix: F,
        encode_message: G,
    ) -> Result<(), TestCaseError>
    where
        F: Fn(&[u8]) -> String,
        G: Fn(&[u8]) -> String,
    {
        const SECOND_PADDING_BYTE: u8 = 0;

        // Try the first padding byte with each single bit set, plus all bits clear. Each value
        // must yield a distinct prefix that is still a prefix of the encoded message.
        let first_padding_bytes = [128, 64, 32, 16, 8, 4, 2, 1, 0u8];
        // Collects every prefix calculated from the values above.
        let mut prefixes: HashSet<String> = HashSet::new();
        for first_padding_byte in first_padding_bytes {
            // These are the bytes that get put on the front of the encrypted data.
            let mut prefix_bytes = id.to_be_bytes().to_vec();
            prefix_bytes.extend([first_padding_byte, SECOND_PADDING_BYTE]);
            // String representation of the prefix bytes alone.
            let prefix = encode_prefix(prefix_bytes.as_slice());
            // Prefix bytes followed by the arbitrary message bytes.
            let full_message = prefix_bytes
                .iter()
                .chain(arb_msg)
                .copied()
                .collect_vec();
            let encoded_message = encode_message(full_message.as_slice());
            // The calculated prefix must always be on the front of the encoded message.
            prop_assert!(encoded_message.starts_with(&prefix));
            prefixes.insert(prefix);
        }
        prop_assert_eq!(prefixes.len(), first_padding_bytes.len());
        Ok(())
    }
}