1use self::crypto::{EncryptResult, shuffle, unshuffle};
2use crate::{
3 AlloyMetadata, DerivationPath, EncryptedBytes, Secret, SecretPath, TenantId,
4 create_batch_result_struct, create_batch_result_struct_using_newtype,
5 errors::AlloyError,
6 util::{self, AuthHash},
7};
8use bytes::Bytes;
9use ironcore_documents::{
10 impl_secret_debug,
11 v5::{
12 self,
13 key_id_header::{EdekType, KeyId, KeyIdHeader, PayloadType},
14 },
15 vector_encryption_metadata::VectorEncryptionMetadata,
16};
17use itertools::Itertools;
18use rand::CryptoRng;
19use serde::Serialize;
20use std::{
21 collections::HashMap,
22 sync::{Arc, Mutex},
23};
24use uniffi::custom_newtype;
25
// Low-level vector crypto primitives used by this module: `encrypt`/`decrypt`,
// `shuffle`/`unshuffle`, and the `EncryptResult` (ciphertext, IV, auth hash) type.
pub(crate) mod crypto;
27
28#[derive(Debug, Clone, Hash, PartialEq, Eq)]
29pub struct VectorId(pub String);
30custom_newtype!(VectorId, String);
31
/// An encrypted vector together with the information needed to decrypt it later.
///
/// Exported over FFI as a `uniffi::Record`; field order is part of that record's layout.
#[derive(Clone, Debug, uniffi::Record)]
pub struct EncryptedVector {
    /// Encrypted vector components. When produced by `encrypt_internal`, these are the
    /// ciphertext of the key-shuffled plaintext components.
    pub encrypted_vector: Vec<f32>,
    /// Secret path of the plaintext this was encrypted from (`encrypt_internal` copies it
    /// through unchanged).
    pub secret_path: SecretPath,
    /// Derivation path of the plaintext this was encrypted from (`encrypt_internal` copies it
    /// through unchanged).
    pub derivation_path: DerivationPath,
    /// Encoded key-id header plus vector metadata (IV and auth hash), as written by
    /// `encrypt_internal` via `encode_vector_metadata`.
    pub paired_icl_info: EncryptedBytes,
}
39
/// A plaintext vector plus the secret/derivation paths that select its encryption key.
///
/// Exported over FFI as a `uniffi::Record`; field order is part of that record's layout.
#[derive(Clone, Debug, uniffi::Record)]
pub struct PlaintextVector {
    /// Unencrypted vector components.
    pub plaintext_vector: Vec<f32>,
    /// Secret path used when choosing the key (presumably selects which secret — confirm).
    pub secret_path: SecretPath,
    /// Derivation path; `VectorEncryptionKey::derive_from_secret` mixes it into the key.
    pub derivation_path: DerivationPath,
}
46
47#[derive(Debug, Clone)]
48pub struct PlaintextVectors(pub HashMap<VectorId, PlaintextVector>);
49custom_newtype!(PlaintextVectors, HashMap<VectorId, PlaintextVector>);
50#[derive(Debug, Clone)]
51pub struct EncryptedVectors(pub HashMap<VectorId, EncryptedVector>);
52custom_newtype!(EncryptedVectors, HashMap<VectorId, EncryptedVector>);
53pub struct GenerateVectorQueryResult(pub HashMap<VectorId, Vec<EncryptedVector>>);
54custom_newtype!(GenerateVectorQueryResult, HashMap<VectorId, Vec<EncryptedVector>>);
// Return type of `VectorOps::rotate_vectors`, generated by the crate's
// `create_batch_result_struct!` macro with `EncryptedVector` values keyed by `VectorId`.
// NOTE(review): presumably holds per-id successes and failures — see the macro definition.
create_batch_result_struct!(VectorRotateResult, EncryptedVector, VectorId);

// Return type of `VectorOps::encrypt_batch`. This macro variant additionally takes the
// `EncryptedVectors` newtype (presumably as the type of the successes collection — confirm).
create_batch_result_struct_using_newtype!(
    VectorEncryptBatchResult,
    EncryptedVector,
    VectorId,
    EncryptedVectors
);
// Return type of `VectorOps::decrypt_batch`, built the same way over `PlaintextVectors`.
create_batch_result_struct_using_newtype!(
    VectorDecryptBatchResult,
    PlaintextVector,
    VectorId,
    PlaintextVectors
);
69
/// Key material for vector encryption: a scaling factor plus raw encryption-key bytes.
///
/// `Debug` is derived, but `key` formats through `EncryptionKey`'s `impl_secret_debug!`
/// implementation rather than a derived one (presumably redacting the bytes).
#[derive(Debug, Clone)]
pub(crate) struct VectorEncryptionKey {
    /// Scaling factor passed to `crypto::encrypt`/`crypto::decrypt` along with the key
    /// (presumably only applied when `use_scaling_factor` is set — confirm in `crypto`).
    pub scaling_factor: ScalingFactor,
    /// Raw key bytes; 32 bytes when built by `unsafe_bytes_to_key`. Also used as the
    /// shuffle/unshuffle key in `encrypt_internal`/`decrypt_internal`.
    pub key: EncryptionKey,
}
78
79#[derive(Debug, Serialize, Clone, Copy)]
80pub(crate) struct ScalingFactor(pub f32); #[derive(Clone)]
83pub(crate) struct EncryptionKey(pub Vec<u8>);
84impl_secret_debug!(EncryptionKey);
85
86impl VectorEncryptionKey {
87 pub(crate) fn derive_from_secret(
90 secret: &Secret,
91 tenant_id: &TenantId,
92 derivation_path: &DerivationPath,
93 ) -> Self {
94 let hash_result = util::hash512(
95 &secret.secret[..],
96 format!("{}-{}", tenant_id.0, derivation_path.0),
97 );
98 Self::unsafe_bytes_to_key(&hash_result[..])
99 }
100
101 pub(crate) fn unsafe_bytes_to_key(key_bytes: &[u8]) -> VectorEncryptionKey {
105 let (scaling_factor_bytes, rest) = key_bytes.split_at(3);
106 let (key_bytes, _) = rest.split_at(32);
107 let scaling_byte_vec = std::iter::once(0)
109 .chain(scaling_factor_bytes.iter().cloned())
110 .collect_vec();
111 let scaling_factor_u32: u32 = u32::from_be_bytes(
112 scaling_byte_vec
113 .try_into()
114 .expect("The vector above is always size 4, so this shouldn't happen."),
115 );
116 VectorEncryptionKey {
117 scaling_factor: ScalingFactor(scaling_factor_u32 as f32),
118 key: EncryptionKey(key_bytes.to_vec()),
119 }
120 }
121}
122
/// Vector encryption operations exported over FFI via `uniffi`.
///
/// Implementations must be `Send + Sync`. Every method is `async` (via `async_trait`) and
/// receives the caller's [`AlloyMetadata`].
// NOTE(review): how `metadata` selects the tenant/secret is implementation-defined and not
// visible here; the summaries below describe the signatures only.
#[uniffi::export]
#[async_trait::async_trait]
pub trait VectorOps: Send + Sync {
    /// Encrypts a single [`PlaintextVector`], producing an [`EncryptedVector`].
    async fn encrypt(
        &self,
        plaintext_vector: PlaintextVector,
        metadata: &AlloyMetadata,
    ) -> Result<EncryptedVector, AlloyError>;

    /// Encrypts every vector in `plaintext_vectors`. Results are returned as a
    /// [`VectorEncryptBatchResult`], built over [`VectorId`] / [`EncryptedVector`].
    // NOTE(review): presumably per-vector failures are reported inside the batch result and
    // `Err` is reserved for whole-call failures — confirm against implementations.
    async fn encrypt_batch(
        &self,
        plaintext_vectors: PlaintextVectors,
        metadata: &AlloyMetadata,
    ) -> Result<VectorEncryptBatchResult, AlloyError>;

    /// Decrypts a single [`EncryptedVector`] back into a [`PlaintextVector`].
    async fn decrypt(
        &self,
        encrypted_vector: EncryptedVector,
        metadata: &AlloyMetadata,
    ) -> Result<PlaintextVector, AlloyError>;

    /// Decrypts every vector in `encrypted_vectors`. Results are returned as a
    /// [`VectorDecryptBatchResult`], built over [`VectorId`] / [`PlaintextVector`].
    async fn decrypt_batch(
        &self,
        encrypted_vectors: EncryptedVectors,
        metadata: &AlloyMetadata,
    ) -> Result<VectorDecryptBatchResult, AlloyError>;

    /// Produces encrypted query vectors for each input vector. Each [`VectorId`] maps to a
    /// `Vec<EncryptedVector>` (presumably one per key usable for search — confirm).
    async fn generate_query_vectors(
        &self,
        vectors_to_query: PlaintextVectors,
        metadata: &AlloyMetadata,
    ) -> Result<GenerateVectorQueryResult, AlloyError>;

    /// Returns a byte prefix associated with the key currently "in rotation" for the given
    /// `secret_path` / `derivation_path`.
    // NOTE(review): presumably the encoded key-id header that `paired_icl_info` begins with,
    // usable for prefix searches — confirm against implementations.
    async fn get_in_rotation_prefix(
        &self,
        secret_path: SecretPath,
        derivation_path: DerivationPath,
        metadata: &AlloyMetadata,
    ) -> Result<Vec<u8>, AlloyError>;

    /// Rotates `encrypted_vectors`, returning the outcome as a [`VectorRotateResult`].
    // NOTE(review): presumably re-encrypts under the current key, and under `new_tenant_id`
    // instead of the metadata's tenant when it is `Some` — confirm against implementations.
    async fn rotate_vectors(
        &self,
        encrypted_vectors: EncryptedVectors,
        metadata: &AlloyMetadata,
        new_tenant_id: Option<TenantId>,
    ) -> Result<VectorRotateResult, AlloyError>;
}
199
200pub(crate) fn get_iv_and_auth_hash(b: &[u8]) -> Result<([u8; 12], AuthHash), AlloyError> {
201 let vector_proto: VectorEncryptionMetadata = protobuf::Message::parse_from_bytes(b)?;
202 let iv = vector_proto.iv;
203 let auth_hash = vector_proto.auth_hash;
204 Ok((
205 iv[..].try_into().map_err(|_| AlloyError::DecryptError {
206 msg: "Invalid IV".to_string(),
207 })?,
208 AuthHash(
209 auth_hash[..]
210 .try_into()
211 .map_err(|_| AlloyError::DecryptError {
212 msg: "Invalid authentication hash".to_string(),
213 })?,
214 ),
215 ))
216}
217
218pub(crate) fn encrypt_internal<R: CryptoRng + Send + Sync>(
219 approximation_factor: f32,
220 key: &VectorEncryptionKey,
221 key_id: KeyId,
222 edek_type: EdekType,
223 plaintext_vector: PlaintextVector,
224 rng: Arc<Mutex<R>>,
225 use_scaling_factor: bool,
226) -> Result<EncryptedVector, AlloyError> {
227 let result = crypto::encrypt(
228 key,
229 approximation_factor,
230 shuffle(&key.key, plaintext_vector.plaintext_vector)
231 .into_iter()
232 .collect(),
233 rng,
234 use_scaling_factor,
235 )?;
236 let (header, vector_metadata) = v5::key_id_header::create_vector_metadata(
237 KeyIdHeader::new(edek_type, PayloadType::VectorMetadata, key_id),
238 result.iv.to_vec().into(),
239 result.auth_hash.0.to_vec().into(),
240 );
241 Ok(EncryptedVector {
242 encrypted_vector: result.ciphertext.to_vec(),
243 secret_path: plaintext_vector.secret_path,
244 derivation_path: plaintext_vector.derivation_path,
245 paired_icl_info: EncryptedBytes(
246 v5::key_id_header::encode_vector_metadata(header, vector_metadata).to_vec(),
247 ),
248 })
249}
250
251pub(crate) fn decrypt_internal(
252 approximation_factor: f32,
253 key: &VectorEncryptionKey,
254 encrypted_vector: EncryptedVector,
255 icl_metadata_bytes: Bytes,
256 use_scaling_factor: bool,
257) -> Result<PlaintextVector, AlloyError> {
258 let (iv, auth_hash) = get_iv_and_auth_hash(&icl_metadata_bytes)?;
259 Ok(crypto::decrypt(
260 key,
261 approximation_factor,
262 EncryptResult {
263 ciphertext: encrypted_vector.encrypted_vector.into(),
264 iv,
265 auth_hash,
266 },
267 use_scaling_factor,
268 )
269 .map(|r| unshuffle(&key.key, r))?)
270 .map(|dec| PlaintextVector {
271 plaintext_vector: dec,
272 secret_path: encrypted_vector.secret_path,
273 derivation_path: encrypted_vector.derivation_path,
274 })
275}
276
277pub(crate) fn get_approximation_factor(maybe_approx: Option<f32>) -> Result<f32, AlloyError> {
278 maybe_approx.ok_or_else(|| AlloyError::InvalidConfiguration {
279 msg: "`approximation_factor` was not set in the vector configuration.".to_string(),
280 })
281}