Skip to main content

ironcore_alloy/vector/
mod.rs

1use self::crypto::{EncryptResult, shuffle, unshuffle};
2use crate::{
3    AlloyMetadata, DerivationPath, EncryptedBytes, Secret, SecretPath, TenantId,
4    create_batch_result_struct, create_batch_result_struct_using_newtype,
5    errors::AlloyError,
6    util::{self, AuthHash},
7};
8use bytes::Bytes;
9use ironcore_documents::{
10    impl_secret_debug,
11    v5::{
12        self,
13        key_id_header::{EdekType, KeyId, KeyIdHeader, PayloadType},
14    },
15    vector_encryption_metadata::VectorEncryptionMetadata,
16};
17use itertools::Itertools;
18use rand::CryptoRng;
19use serde::Serialize;
20use std::{
21    collections::HashMap,
22    sync::{Arc, Mutex},
23};
24use uniffi::custom_newtype;
25
/// Low-level vector crypto primitives (`encrypt`, `decrypt`, `shuffle`, `unshuffle`,
/// `EncryptResult`) used by this module.
pub(crate) mod crypto;

/// Identifier for a vector within a batch; the key type of `PlaintextVectors`,
/// `EncryptedVectors` and the batch result types.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct VectorId(pub String);
custom_newtype!(VectorId, String);

/// An encrypted vector embedding together with the information needed to decrypt it.
#[derive(Clone, Debug, uniffi::Record)]
pub struct EncryptedVector {
    /// The encrypted embedding values. `encrypt_internal` shuffles the plaintext values with a
    /// permutation that depends on the key before encrypting them.
    pub encrypted_vector: Vec<f32>,
    /// Secret path of the source `PlaintextVector`, carried over unchanged.
    pub secret_path: SecretPath,
    /// Derivation path of the source `PlaintextVector`, carried over unchanged.
    pub derivation_path: DerivationPath,
    /// Encoded key-id header plus vector metadata (IV and authentication hash); needed
    /// alongside `encrypted_vector` in order to decrypt it.
    pub paired_icl_info: EncryptedBytes,
}

/// A plaintext vector embedding plus the secret and derivation paths used to choose and
/// derive its encryption key.
#[derive(Clone, Debug, uniffi::Record)]
pub struct PlaintextVector {
    /// The embedding values; assumed to be normalized (see `VectorOps::encrypt`).
    pub plaintext_vector: Vec<f32>,
    pub secret_path: SecretPath,
    pub derivation_path: DerivationPath,
}
46
47#[derive(Debug, Clone)]
48pub struct PlaintextVectors(pub HashMap<VectorId, PlaintextVector>);
49custom_newtype!(PlaintextVectors, HashMap<VectorId, PlaintextVector>);
50#[derive(Debug, Clone)]
51pub struct EncryptedVectors(pub HashMap<VectorId, EncryptedVector>);
52custom_newtype!(EncryptedVectors, HashMap<VectorId, EncryptedVector>);
53pub struct GenerateVectorQueryResult(pub HashMap<VectorId, Vec<EncryptedVector>>);
54custom_newtype!(GenerateVectorQueryResult, HashMap<VectorId, Vec<EncryptedVector>>);
55create_batch_result_struct!(VectorRotateResult, EncryptedVector, VectorId);
56
57create_batch_result_struct_using_newtype!(
58    VectorEncryptBatchResult,
59    EncryptedVector,
60    VectorId,
61    EncryptedVectors
62);
63create_batch_result_struct_using_newtype!(
64    VectorDecryptBatchResult,
65    PlaintextVector,
66    VectorId,
67    PlaintextVectors
68);
69
70/// Key used to for vector encryption.
71#[derive(Debug, Clone)]
72pub(crate) struct VectorEncryptionKey {
73    /// The amount to scale embedding values during encryption
74    pub scaling_factor: ScalingFactor,
75    /// The actual key used for encryption/decryption operations
76    pub key: EncryptionKey,
77}
78
79#[derive(Debug, Serialize, Clone, Copy)]
80pub(crate) struct ScalingFactor(pub f32); // Based on page 135 having a size 2^30
81
82#[derive(Clone)]
83pub(crate) struct EncryptionKey(pub Vec<u8>);
84impl_secret_debug!(EncryptionKey);
85
86impl VectorEncryptionKey {
87    /// A way to generate a key from the secret, tenant_id and derivation_path. This is done in the context of
88    /// a standalone secret where we don't have a TSP to call to for derivation.
89    pub(crate) fn derive_from_secret(
90        secret: &Secret,
91        tenant_id: &TenantId,
92        derivation_path: &DerivationPath,
93    ) -> Self {
94        let hash_result = util::hash512(
95            &secret.secret[..],
96            format!("{}-{}", tenant_id.0, derivation_path.0),
97        );
98        Self::unsafe_bytes_to_key(&hash_result[..])
99    }
100
101    /// This function *will* panic on you if the slice is not of size >= 35.
102    /// It will take the first 3 bytes and make it into a scaling factor and use the next 32 bytes
103    /// as the encryption key. Ensure you've checked the size before calling this.
104    pub(crate) fn unsafe_bytes_to_key(key_bytes: &[u8]) -> VectorEncryptionKey {
105        let (scaling_factor_bytes, rest) = key_bytes.split_at(3);
106        let (key_bytes, _) = rest.split_at(32);
107        // Put a 0 on the front so that it's the right number of bytes for `u32::from_be_bytes`
108        let scaling_byte_vec = std::iter::once(0)
109            .chain(scaling_factor_bytes.iter().cloned())
110            .collect_vec();
111        let scaling_factor_u32: u32 = u32::from_be_bytes(
112            scaling_byte_vec
113                .try_into()
114                .expect("The vector above is always size 4, so this shouldn't happen."),
115        );
116        VectorEncryptionKey {
117            scaling_factor: ScalingFactor(scaling_factor_u32 as f32),
118            key: EncryptionKey(key_bytes.to_vec()),
119        }
120    }
121}
122
/// Operations on vector embeddings: encrypt/decrypt (single and batch), query-vector
/// generation, in-rotation prefix lookup, and rotation to the current secret or a new tenant.
#[uniffi::export]
#[async_trait::async_trait]
pub trait VectorOps: Send + Sync {
    /// Encrypt a vector embedding with the provided metadata. The provided embedding is assumed to be normalized
    /// and its values will be shuffled as part of the encryption.
    /// The same tenant ID must be provided in the metadata when decrypting the embedding.
    async fn encrypt(
        &self,
        plaintext_vector: PlaintextVector,
        metadata: &AlloyMetadata,
    ) -> Result<EncryptedVector, AlloyError>;

    /// Encrypt multiple vector embeddings with the provided metadata. The provided embeddings are assumed to be normalized
    /// and their values will be shuffled as part of the encryption.
    /// The same tenant ID must be provided in the metadata when decrypting the embeddings.
    async fn encrypt_batch(
        &self,
        plaintext_vectors: PlaintextVectors,
        metadata: &AlloyMetadata,
    ) -> Result<VectorEncryptBatchResult, AlloyError>;

    /// Decrypt a vector embedding that was encrypted with the provided metadata. The values of the embedding will
    /// be unshuffled to their original positions during decryption.
    async fn decrypt(
        &self,
        encrypted_vector: EncryptedVector,
        metadata: &AlloyMetadata,
    ) -> Result<PlaintextVector, AlloyError>;

    /// Decrypt multiple vector embeddings that were encrypted with the provided metadata. The values of the embeddings
    /// will be unshuffled to their original positions during decryption.
    /// Note that because the metadata is shared between the vectors, they all must correspond to the
    /// same tenant ID.
    async fn decrypt_batch(
        &self,
        encrypted_vectors: EncryptedVectors,
        metadata: &AlloyMetadata,
    ) -> Result<VectorDecryptBatchResult, AlloyError>;

    /// Encrypt each plaintext vector with any Current and InRotation keys for the provided secret path.
    /// The resulting encrypted vectors should be used in tandem when querying the vector database.
    async fn generate_query_vectors(
        &self,
        vectors_to_query: PlaintextVectors,
        metadata: &AlloyMetadata,
    ) -> Result<GenerateVectorQueryResult, AlloyError>;

    /// Generate a prefix that could be used to search a data store for documents encrypted using an identifier (KMS
    /// config ID for SaaS Shield, secret ID for Standalone). These bytes should be encoded into
    /// a format matching the encoding in the data store. z85/ascii85 users should first pass these bytes through
    /// `encode_prefix_z85` or `base85_prefix_padding`. Make sure you've read the documentation of those functions to
    /// avoid pitfalls when encoding across byte boundaries.
    async fn get_in_rotation_prefix(
        &self,
        secret_path: SecretPath,
        derivation_path: DerivationPath,
        metadata: &AlloyMetadata,
    ) -> Result<Vec<u8>, AlloyError>;

    /// Rotates vectors from the in-rotation secret for their secret path to the current secret.
    /// This can also be used to rotate data from one tenant ID to a new one, which is most useful when a tenant is
    /// internally migrated.
    ///
    /// WARNINGS:
    ///     * this involves decrypting then encrypting vectors. Since the vectors are full of floating point numbers,
    ///       this process is lossy, which will cause some drift over time. If you need perfectly preserved accuracy,
    ///       store the source vector encrypted with `standard` next to the encrypted vector. `standard` decrypt
    ///       that, `vector` encrypt it again, and replace the encrypted vector with the result.
    ///     * only one metadata and new tenant ID argument means each call to this needs to have one tenant's vectors.
    async fn rotate_vectors(
        &self,
        encrypted_vectors: EncryptedVectors,
        metadata: &AlloyMetadata,
        new_tenant_id: Option<TenantId>,
    ) -> Result<VectorRotateResult, AlloyError>;
}
199
200pub(crate) fn get_iv_and_auth_hash(b: &[u8]) -> Result<([u8; 12], AuthHash), AlloyError> {
201    let vector_proto: VectorEncryptionMetadata = protobuf::Message::parse_from_bytes(b)?;
202    let iv = vector_proto.iv;
203    let auth_hash = vector_proto.auth_hash;
204    Ok((
205        iv[..].try_into().map_err(|_| AlloyError::DecryptError {
206            msg: "Invalid IV".to_string(),
207        })?,
208        AuthHash(
209            auth_hash[..]
210                .try_into()
211                .map_err(|_| AlloyError::DecryptError {
212                    msg: "Invalid authentication hash".to_string(),
213                })?,
214        ),
215    ))
216}
217
218pub(crate) fn encrypt_internal<R: CryptoRng + Send + Sync>(
219    approximation_factor: f32,
220    key: &VectorEncryptionKey,
221    key_id: KeyId,
222    edek_type: EdekType,
223    plaintext_vector: PlaintextVector,
224    rng: Arc<Mutex<R>>,
225    use_scaling_factor: bool,
226) -> Result<EncryptedVector, AlloyError> {
227    let result = crypto::encrypt(
228        key,
229        approximation_factor,
230        shuffle(&key.key, plaintext_vector.plaintext_vector)
231            .into_iter()
232            .collect(),
233        rng,
234        use_scaling_factor,
235    )?;
236    let (header, vector_metadata) = v5::key_id_header::create_vector_metadata(
237        KeyIdHeader::new(edek_type, PayloadType::VectorMetadata, key_id),
238        result.iv.to_vec().into(),
239        result.auth_hash.0.to_vec().into(),
240    );
241    Ok(EncryptedVector {
242        encrypted_vector: result.ciphertext.to_vec(),
243        secret_path: plaintext_vector.secret_path,
244        derivation_path: plaintext_vector.derivation_path,
245        paired_icl_info: EncryptedBytes(
246            v5::key_id_header::encode_vector_metadata(header, vector_metadata).to_vec(),
247        ),
248    })
249}
250
251pub(crate) fn decrypt_internal(
252    approximation_factor: f32,
253    key: &VectorEncryptionKey,
254    encrypted_vector: EncryptedVector,
255    icl_metadata_bytes: Bytes,
256    use_scaling_factor: bool,
257) -> Result<PlaintextVector, AlloyError> {
258    let (iv, auth_hash) = get_iv_and_auth_hash(&icl_metadata_bytes)?;
259    Ok(crypto::decrypt(
260        key,
261        approximation_factor,
262        EncryptResult {
263            ciphertext: encrypted_vector.encrypted_vector.into(),
264            iv,
265            auth_hash,
266        },
267        use_scaling_factor,
268    )
269    .map(|r| unshuffle(&key.key, r))?)
270    .map(|dec| PlaintextVector {
271        plaintext_vector: dec,
272        secret_path: encrypted_vector.secret_path,
273        derivation_path: encrypted_vector.derivation_path,
274    })
275}
276
277pub(crate) fn get_approximation_factor(maybe_approx: Option<f32>) -> Result<f32, AlloyError> {
278    maybe_approx.ok_or_else(|| AlloyError::InvalidConfiguration {
279        msg: "`approximation_factor` was not set in the vector configuration.".to_string(),
280    })
281}