1use crate::{AlloyMetadata, TenantId, VectorEncryptionKey, errors::AlloyError};
2use ironcore_documents::v5::key_id_header::KeyId;
3use itertools::Either;
4use protobuf::Message;
5use rand::{
6 SeedableRng,
7 rngs::{OsRng, adapter::ReseedingRng},
8};
9use rand_chacha::{ChaCha20Core, ChaCha20Rng};
10use rayon::iter::ParallelIterator;
11use rayon::iter::{IntoParallelIterator, ParallelExtend};
12use ring::hmac::{HMAC_SHA256, HMAC_SHA512, Key as HMACKey};
13use std::hash::Hash;
14use std::{
15 collections::HashMap,
16 sync::{Arc, Mutex, MutexGuard},
17};
18
/// Reseeding threshold handed to `ReseedingRng::new` by `create_reseeding_rng`:
/// 1024 * 1024 bytes (1 MiB).
const BYTES_BEFORE_RESEEDING: u64 = 1024 * 1024;
21
/// The RNG type handed out by the constructors in this module: a `ChaCha20Core` generator
/// wrapped in a `ReseedingRng` that uses `OsRng` as its reseeding source.
pub(crate) type OurReseedingRng = ReseedingRng<ChaCha20Core, OsRng>;
23
/// A 32-byte authentication tag, as produced by `compute_auth_hash`.
#[derive(Debug, PartialEq)]
pub(crate) struct AuthHash(pub(crate) [u8; 32]);
26
/// Acquires `m` and returns the guard that grants access to the protected value.
///
/// The guard's lifetime is tied to the borrow of `m`; this is spelled out with `'_` in the
/// return type so the borrow is visible in the signature.
///
/// # Panics
///
/// Panics with the message `Error when acquiring lock: <poison error>` if the mutex is
/// poisoned, i.e. another thread panicked while holding it.
pub fn take_lock<T>(m: &Mutex<T>) -> MutexGuard<'_, T> {
    m.lock()
        .unwrap_or_else(|e| panic!("Error when acquiring lock: {e}"))
}
51
52pub(crate) fn hash256<K: AsRef<[u8]>, T: AsRef<[u8]>>(key: K, payload: T) -> [u8; 32] {
53 let hmac_key = HMACKey::new(HMAC_SHA256, key.as_ref());
54 ring::hmac::sign(&hmac_key, payload.as_ref())
55 .as_ref()
56 .try_into()
58 .unwrap()
59}
60
61pub(crate) fn hash512<K: AsRef<[u8]>, T: AsRef<[u8]>>(key: K, payload: T) -> [u8; 64] {
62 let hmac_key = HMACKey::new(HMAC_SHA512, key.as_ref());
63 ring::hmac::sign(&hmac_key, payload.as_ref())
64 .as_ref()
65 .try_into()
67 .unwrap()
68}
69
70pub(crate) fn compute_auth_hash<'a, A: AsRef<[u8]>, B: Iterator<Item = &'a f32>>(
71 key: &VectorEncryptionKey,
72 approximation_factor: &f32,
73 iv: A,
74 encrypted_embedding: B,
75) -> AuthHash {
76 let hmac_key = HMACKey::new(HMAC_SHA256, &key.key.0);
77 let mut ctx = ring::hmac::Context::with_key(&hmac_key);
78 ctx.update(key.scaling_factor.0.to_be_bytes().as_ref());
79 ctx.update(approximation_factor.to_be_bytes().as_ref());
80 ctx.update(iv.as_ref());
81 for embedding in encrypted_embedding {
82 ctx.update(embedding.to_be_bytes().as_ref());
83 }
84 let signature: [u8; 32] = ctx
85 .sign()
86 .as_ref()
87 .try_into() .unwrap();
89 AuthHash(signature)
90}
91
92pub(crate) fn check_auth_hash<'a, A: AsRef<[u8]>, B: Iterator<Item = &'a f32>>(
93 key: &VectorEncryptionKey,
94 approximation_factor: &f32,
95 iv: A,
96 encrypted_embedding: B,
97 auth_hash: AuthHash,
98) -> bool {
99 compute_auth_hash(key, approximation_factor, iv, encrypted_embedding) == auth_hash
100}
101
102pub(crate) fn create_reseeding_rng() -> Arc<Mutex<OurReseedingRng>> {
103 Arc::new(Mutex::new(ReseedingRng::new(
104 ChaCha20Core::from_entropy(),
105 BYTES_BEFORE_RESEEDING,
106 OsRng,
107 )))
108}
109
110pub(crate) fn create_test_seeded_rng(seed: u64) -> Arc<Mutex<OurReseedingRng>> {
113 Arc::new(Mutex::new(ReseedingRng::new(
115 ChaCha20Core::seed_from_u64(seed),
116 0,
117 OsRng,
118 )))
119}
120
121pub(crate) fn create_rng_maybe_seeded(maybe_seed: Option<i32>) -> Arc<Mutex<OurReseedingRng>> {
122 maybe_seed
123 .map(|seed| create_test_seeded_rng(seed as u64))
125 .unwrap_or_else(create_reseeding_rng)
126}
127
128pub(crate) fn create_rng<K: AsRef<[u8]>, T: AsRef<[u8]>>(key: K, hash_payload: T) -> ChaCha20Rng {
129 ChaCha20Rng::from_seed(hash256(key, hash_payload))
130}
131
/// Outcome of a batch operation, split into per-key successes and per-key failures.
///
/// `perform_batch_action` fills it so that each input key lands in exactly one of the two maps.
pub(crate) struct BatchResult<K: Hash, U> {
    // Values for the keys whose operation returned `Ok`.
    pub successes: HashMap<K, U>,
    // Errors for the keys whose operation returned `Err`.
    pub failures: HashMap<K, AlloyError>,
}
136
/// Declares a public struct `$struct_name` deriving `Debug`, `Clone` and `uniffi::Record`, with
/// a `successes` map (`$map_key_type` -> `$success_type`) and a `failures` map
/// (`$map_key_type` -> `AlloyError`), plus a `From<BatchResult<$map_key_type, $success_type>>`
/// impl that moves both maps across unchanged.
///
/// All three arguments are matched as plain identifiers, so each must be a single name that is
/// in scope at the call site.
#[macro_export]
macro_rules! create_batch_result_struct {
    ($struct_name:ident, $success_type:ident, $map_key_type:ident) => {
        #[derive(Debug, Clone, uniffi::Record)]
        pub struct $struct_name {
            pub successes: std::collections::HashMap<$map_key_type, $success_type>,
            pub failures: std::collections::HashMap<$map_key_type, $crate::errors::AlloyError>,
        }

        // Conversion from the crate-internal generic result type into the concrete record.
        impl From<$crate::util::BatchResult<$map_key_type, $success_type>> for $struct_name {
            fn from(value: $crate::util::BatchResult<$map_key_type, $success_type>) -> Self {
                Self {
                    successes: value.successes,
                    failures: value.failures,
                }
            }
        }
    };
}
160
/// Variant of `create_batch_result_struct` whose `successes` field has the type `$newtype`
/// instead of a bare `HashMap`.
///
/// The generated `From<BatchResult<$map_key_type, $success_type>>` impl builds the field with
/// `$newtype(value.successes)`, so `$newtype` must be a tuple struct constructible from a
/// `HashMap<$map_key_type, $success_type>`. The `failures` field is still a plain map from
/// `$map_key_type` to `AlloyError`.
///
/// All four arguments are matched as plain identifiers.
#[macro_export]
macro_rules! create_batch_result_struct_using_newtype {
    ($struct_name:ident, $success_type:ident, $map_key_type:ident, $newtype:ident) => {
        #[derive(Debug, Clone, uniffi::Record)]
        pub struct $struct_name {
            pub successes: $newtype,
            pub failures: std::collections::HashMap<$map_key_type, $crate::errors::AlloyError>,
        }

        // Conversion from the crate-internal generic result type; wraps the success map in the
        // newtype.
        impl From<$crate::util::BatchResult<$map_key_type, $success_type>> for $struct_name {
            fn from(value: $crate::util::BatchResult<$map_key_type, $success_type>) -> Self {
                Self {
                    successes: $newtype(value.successes),
                    failures: value.failures,
                }
            }
        }
    };
}
185
186pub(crate) fn perform_batch_action<T, U, F, I, K>(collection: I, func: F) -> BatchResult<K, U>
189where
190 F: Fn(T) -> Result<U, AlloyError> + Sync,
191 I: IntoParallelIterator<Item = (K, T)>,
192 K: Hash + Eq + Send,
193 U: Send,
194 HashMap<K, U>: ParallelExtend<(K, U)>,
195 HashMap<K, AlloyError>: ParallelExtend<(K, AlloyError)>,
196{
197 let (successes, failures) = collection
198 .into_par_iter()
199 .map(|(key, value)| match func(value) {
200 Ok(x) => Ok((key, x)),
201 Err(x) => Err((key, x)),
202 })
203 .partition_map(|x| match x {
204 Ok(success) => Either::Left(success),
205 Err(failure) => Either::Right(failure),
206 });
207 BatchResult {
208 successes,
209 failures,
210 }
211}
212
213pub(crate) fn check_rotation_no_op(
215 encrypted_key_id: KeyId,
216 maybe_current_key: &Option<u32>,
217 new_tenant_id: &TenantId,
218 metadata: &AlloyMetadata,
219) -> bool {
220 maybe_current_key == &Some(encrypted_key_id.0) && new_tenant_id == &metadata.tenant_id
221}
222
223pub(crate) fn v4_proto_from_bytes<B: AsRef<[u8]>>(
224 b: B,
225) -> Result<ironcore_documents::icl_header_v4::V4DocumentHeader, AlloyError> {
226 Ok(Message::parse_from_bytes(b.as_ref())?)
227}
228
// Unit and property tests for the hashing helpers, plus checks that several binary-to-text
// encodings keep a stable, distinguishable prefix for a key-id header.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::{vector::EncryptionKey, vector::ScalingFactor};
    use base64::Engine;
    use bytes::Bytes;
    use itertools::Itertools;
    use proptest::prelude::*;
    use std::collections::HashSet;

    // Pins the HMAC-SHA256 output for an empty key and the single-byte payload `[1]`.
    #[test]
    fn test_hash_empty_key() {
        assert_eq!(
            hash256(Bytes::default(), [1u8]).to_vec(),
            vec![
                61, 122, 251, 102, 49, 36, 236, 191, 44, 149, 63, 134, 61, 79, 200, 121, 110, 235,
                45, 55, 43, 100, 170, 213, 134, 151, 236, 82, 100, 100, 156, 219
            ]
        );
    }

    proptest! {
        // The auth hash of a vector must be unchanged after the vector is serialized to JSON
        // and parsed back with serde_json.
        #[test]
        fn roundtrip(arb_msg: Vec<f32>, key: [u8; 32], iv: [u8; 12], scaling_factor: u32) {
            // A generated scaling factor of 0 is replaced with 1.
            let f32_scaling_factor = if scaling_factor == 0 {
                1.
            } else {
                scaling_factor as f32
            };
            let key = VectorEncryptionKey {
                scaling_factor: ScalingFactor(f32_scaling_factor),
                key: EncryptionKey(key.to_vec()),
            };

            let first_hash = compute_auth_hash(&key, &1.2f32, iv, arb_msg.iter());
            let roundtrip: Vec<f32> =
                serde_json::from_str(serde_json::to_string(&arb_msg).unwrap().as_str()).unwrap();
            let second_hash = compute_auth_hash(&key, &1.2f32, iv, roundtrip.iter());

            proptest::prop_assert_eq!(first_hash, second_hash);
        }

        // ascii85: the prefix is derived by encoding the 6 header bytes followed by the 2
        // message bytes and then dropping the last five characters of the encoding.
        #[test]
        fn ascii85_encoding_produces_consistent_prefix(arb_msg: [u8; 2], id in 1..u32::MAX) {

            let prefix_func = |prefix_bytes:&[u8]| {
                let prefix_padded = [prefix_bytes, &arb_msg[..]].concat();
                let mut starting = ascii85::encode(prefix_padded.as_slice());
                // NOTE(review): five pops — presumably the encoder's trailing delimiter plus
                // the characters that depend on the message bytes; confirm against the ascii85
                // crate's output format.
                starting.pop();
                starting.pop();
                starting.pop();
                starting.pop();
                starting.pop();
                starting
            };

            let encode_func = |whole:&[u8]| ascii85::encode(whole);
            encoding_produces_consistent_prefix(id,&arb_msg[..], prefix_func, encode_func)?
        }

        // z85: the prefix is derived by encoding the header bytes padded with two zero bytes
        // and then dropping the last three characters.
        #[test]
        fn z85_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {

            let prefix_func = |prefix_bytes:&[u8]| {
                let prefix_padded = [prefix_bytes, &[0,0]].concat();
                let mut starting = z85::encode(prefix_padded.as_slice());
                // NOTE(review): three pops — presumably the characters influenced by the two
                // padding bytes; confirm against the z85 block layout.
                starting.pop();
                starting.pop();
                starting.pop();
                starting
            };

            let encode_func = |whole:&[u8]| z85::encode(whole);
            encoding_produces_consistent_prefix(id,&arb_msg[..], prefix_func, encode_func)?
        }

        // base85: same construction as the z85 test, using the base85 crate's encoder.
        #[test]
        fn base85_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {

            let prefix_func = |prefix_bytes:&[u8]| {
                let prefix_padded = [prefix_bytes, &[0, 0]].concat();
                let mut starting = base85::encode(prefix_padded.as_slice());
                // NOTE(review): three pops — presumably the characters influenced by the two
                // padding bytes; confirm against the base85 block layout.
                starting.pop();
                starting.pop();
                starting.pop();
                starting
            };

            let encode_func = |whole:&[u8]| base85::encode(whole);
            encoding_produces_consistent_prefix(id,&arb_msg[..], prefix_func, encode_func)?
        }
        // base64 (standard alphabet): the prefix is the unpadded encoding of the header bytes,
        // while the full message is encoded with the padded engine.
        #[test]
        fn base64_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {

            let prefix_func = |prefix_bytes:&[u8]| {
                base64::engine::general_purpose::STANDARD_NO_PAD.encode(prefix_bytes)
            };

            let encode_func = |whole:&[u8]| base64::engine::general_purpose::STANDARD.encode(whole);
            encoding_produces_consistent_prefix(id,&arb_msg[..], prefix_func, encode_func)?
        }

        // base64 (URL-safe alphabet): same construction as the standard base64 test.
        #[test]
        fn base64_url_encoding_produces_consistent_prefix(arb_msg: [u8; 4], id in 1..u32::MAX) {

            let prefix_func = |prefix_bytes:&[u8]| {
                base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(prefix_bytes)
            };

            let encode_func = |whole:&[u8]| base64::engine::general_purpose::URL_SAFE.encode(whole);
            encoding_produces_consistent_prefix(id,&arb_msg[..], prefix_func, encode_func)?
        }
    }

    // Shared driver for the encoding tests above.
    //
    // For each of nine distinct first padding bytes it builds a 6-byte header (the big-endian
    // bytes of `id`, the padding byte, then a zero byte), encodes the header with
    // `encode_prefix`, and encodes header + `arb_msg` with `encode_message`. It asserts that
    // the full encoding starts with the prefix and, after the loop, that all nine prefixes are
    // different from one another.
    fn encoding_produces_consistent_prefix<F, G>(
        id: u32,
        arb_msg: &[u8],
        encode_prefix: F,
        encode_message: G,
    ) -> Result<(), TestCaseError>
    where
        F: Fn(&[u8]) -> String,
        G: Fn(&[u8]) -> String,
    {
        // The second padding byte is held at zero for every iteration.
        let second_byte_padding = 0u8;

        // One value per single set bit, plus zero.
        let unique_padding_byte_vec = vec![128, 64, 32, 16, 8, 4, 2, 1, 0u8];
        let unique_padding_count = unique_padding_byte_vec.len();
        let mut prefixes: HashSet<String> = HashSet::new();
        for first_byte_padding in unique_padding_byte_vec {
            let prefix_bytes = [
                &u32::to_be_bytes(id)[..],
                &[first_byte_padding],
                &[second_byte_padding],
            ]
            .concat();
            let prefix = encode_prefix(prefix_bytes.as_slice());
            let full_vec = prefix_bytes
                .into_iter()
                .chain(arb_msg.iter().copied())
                .collect_vec();
            let encoded_string = encode_message(full_vec.as_slice());
            prop_assert!(encoded_string.starts_with(&prefix));
            prefixes.insert(prefix);
        }
        // Every padding byte must have produced its own prefix.
        prop_assert_eq!(prefixes.len(), unique_padding_count);
        Ok(())
    }
}