// cloakpipe_vector/adcpe.rs

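//! ADCPE: Approximately Distance-Preserving Cryptographic Encryption
//!
//! Applies a secret orthogonal transformation to embedding vectors.
//! Orthogonal transformations preserve inner products (and thus cosine
//! similarity), so encrypted vectors remain usable for similarity search.
//!
//! Security note: the transformation is linear. Anyone holding `dim` linearly
//! independent plaintext/ciphertext pairs produced without noise can solve for
//! the secret matrix, and the preserved geometry is itself visible to whoever
//! stores the ciphertexts.
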
use std::fmt;

use anyhow::{bail, Result};
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;

13/// Configuration for ADCPE encryption.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AdcpeConfig {
16    /// Dimensionality of the embedding vectors.
17    pub dimensions: usize,
18    /// Optional noise scale (0.0 = exact distance preservation, >0 adds noise).
19    #[serde(default)]
20    pub noise_scale: f64,
21}
22
23/// ADCPE vector encryptor.
24///
25/// Holds a secret orthogonal matrix derived from the encryption key.
26/// The matrix is generated via Gram-Schmidt orthogonalization of a
27/// seeded pseudo-random matrix.
28pub struct AdcpeEncryptor {
29    /// The orthogonal transformation matrix (row-major, dim x dim).
30    matrix: Vec<f64>,
31    /// The inverse (transpose for orthogonal) matrix for decryption.
32    matrix_inv: Vec<f64>,
33    /// Dimensionality.
34    dim: usize,
35    /// Noise scale for differential privacy.
36    noise_scale: f64,
37    /// RNG for noise generation.
38    rng: StdRng,
39}
40
41impl Drop for AdcpeEncryptor {
42    fn drop(&mut self) {
43        self.matrix.zeroize();
44        self.matrix_inv.zeroize();
45    }
46}
47
48impl AdcpeEncryptor {
49    /// Create a new ADCPE encryptor from a 32-byte key.
50    ///
51    /// The key is used to seed a PRNG that generates the random matrix,
52    /// which is then orthogonalized via Gram-Schmidt.
53    pub fn new(key: &[u8; 32], config: &AdcpeConfig) -> Result<Self> {
54        let dim = config.dimensions;
55        if dim == 0 {
56            bail!("Vector dimensions must be > 0");
57        }
58
59        // Seed RNG from key
60        let mut seed = [0u8; 32];
61        seed.copy_from_slice(key);
62        let mut rng = StdRng::from_seed(seed);
63
64        // Generate random matrix
65        let mut matrix = vec![0.0f64; dim * dim];
66        for v in matrix.iter_mut() {
67            *v = rng.gen::<f64>() * 2.0 - 1.0;
68        }
69
70        // Gram-Schmidt orthogonalization
71        gram_schmidt(&mut matrix, dim)?;
72
73        // Transpose = inverse for orthogonal matrices
74        let matrix_inv = transpose(&matrix, dim);
75
76        // Fresh RNG for noise (different seed)
77        let mut noise_seed = [0u8; 32];
78        for (i, b) in key.iter().enumerate() {
79            noise_seed[i] = b.wrapping_add(0x5A);
80        }
81        let noise_rng = StdRng::from_seed(noise_seed);
82
83        Ok(Self {
84            matrix,
85            matrix_inv,
86            dim,
87            noise_scale: config.noise_scale,
88            rng: noise_rng,
89        })
90    }
91
92    /// Encrypt a single embedding vector.
93    ///
94    /// Returns the transformed vector with the same dimensionality.
95    pub fn encrypt(&mut self, vector: &[f64]) -> Result<Vec<f64>> {
96        if vector.len() != self.dim {
97            bail!(
98                "Vector dimension mismatch: expected {}, got {}",
99                self.dim,
100                vector.len()
101            );
102        }
103
104        let mut result = mat_vec_mul(&self.matrix, vector, self.dim);
105
106        // Add optional noise
107        if self.noise_scale > 0.0 {
108            for v in result.iter_mut() {
109                *v += self.rng.gen::<f64>() * self.noise_scale;
110            }
111        }
112
113        Ok(result)
114    }
115
116    /// Decrypt a single embedding vector (inverse transformation).
117    ///
118    /// Note: if noise was added during encryption, decryption will not
119    /// recover the exact original vector.
120    pub fn decrypt(&self, encrypted: &[f64]) -> Result<Vec<f64>> {
121        if encrypted.len() != self.dim {
122            bail!(
123                "Vector dimension mismatch: expected {}, got {}",
124                self.dim,
125                encrypted.len()
126            );
127        }
128
129        Ok(mat_vec_mul(&self.matrix_inv, encrypted, self.dim))
130    }
131
132    /// Encrypt a batch of vectors.
133    pub fn encrypt_batch(&mut self, vectors: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
134        vectors.iter().map(|v| self.encrypt(v)).collect()
135    }
136
137    /// Decrypt a batch of vectors.
138    pub fn decrypt_batch(&self, encrypted: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
139        encrypted.iter().map(|v| self.decrypt(v)).collect()
140    }
141
142    /// Get the dimensionality.
143    pub fn dimensions(&self) -> usize {
144        self.dim
145    }
146}
147
148/// Encrypt f32 vectors (common for embedding APIs).
149pub fn encrypt_f32(encryptor: &mut AdcpeEncryptor, vector: &[f32]) -> Result<Vec<f32>> {
150    let f64_vec: Vec<f64> = vector.iter().map(|&v| v as f64).collect();
151    let encrypted = encryptor.encrypt(&f64_vec)?;
152    Ok(encrypted.iter().map(|&v| v as f32).collect())
153}
154
155/// Decrypt f32 vectors.
156pub fn decrypt_f32(encryptor: &AdcpeEncryptor, encrypted: &[f32]) -> Result<Vec<f32>> {
157    let f64_vec: Vec<f64> = encrypted.iter().map(|&v| v as f64).collect();
158    let decrypted = encryptor.decrypt(&f64_vec)?;
159    Ok(decrypted.iter().map(|&v| v as f32).collect())
160}
161
/// Multiply a row-major `dim x dim` matrix by a vector of length `dim`.
///
/// Entry `i` of the result is the dot product of row `i` with `vector`.
/// Panics if either slice is shorter than the requested dimension.
fn mat_vec_mul(matrix: &[f64], vector: &[f64], dim: usize) -> Vec<f64> {
    // Slice once up front so the inner products run over equal-length slices.
    let input = &vector[..dim];
    (0..dim)
        .map(|row| {
            let cells = &matrix[row * dim..(row + 1) * dim];
            cells.iter().zip(input).map(|(m, x)| m * x).sum()
        })
        .collect()
}

/// Return the transpose of a square `dim x dim` row-major matrix.
///
/// For an orthogonal matrix the transpose is also its inverse.
fn transpose(matrix: &[f64], dim: usize) -> Vec<f64> {
    // Output cell (row, col) is input cell (col, row).
    (0..dim * dim)
        .map(|flat| {
            let (row, col) = (flat / dim, flat % dim);
            matrix[col * dim + row]
        })
        .collect()
}

183/// Gram-Schmidt orthogonalization (in-place, row-major).
184fn gram_schmidt(matrix: &mut [f64], dim: usize) -> Result<()> {
185    for i in 0..dim {
186        // Subtract projections onto previous rows
187        for j in 0..i {
188            let dot = dot_rows(matrix, i, j, dim);
189            let norm_sq = dot_rows(matrix, j, j, dim);
190            if norm_sq < 1e-10 {
191                bail!("Gram-Schmidt failed: degenerate matrix (row {} near-zero)", j);
192            }
193            let scale = dot / norm_sq;
194            for k in 0..dim {
195                let val = matrix[j * dim + k];
196                matrix[i * dim + k] -= scale * val;
197            }
198        }
199
200        // Normalize
201        let norm = dot_rows(matrix, i, i, dim).sqrt();
202        if norm < 1e-10 {
203            bail!("Gram-Schmidt failed: zero norm at row {}", i);
204        }
205        for k in 0..dim {
206            matrix[i * dim + k] /= norm;
207        }
208    }
209    Ok(())
210}
211
/// Inner product of rows `row_a` and `row_b` of a row-major matrix whose rows
/// have `dim` entries.
///
/// Accumulates left to right starting from `0.0`.
fn dot_rows(matrix: &[f64], row_a: usize, row_b: usize, dim: usize) -> f64 {
    let lhs = &matrix[row_a * dim..][..dim];
    let rhs = &matrix[row_b * dim..][..dim];
    lhs.iter().zip(rhs).fold(0.0, |acc, (x, y)| acc + x * y)
}

/// Compute the cosine similarity between two vectors.
///
/// The result lies in `[-1.0, 1.0]`; values that drift marginally outside
/// that range through floating-point rounding are clamped back into it. If
/// either vector has a Euclidean norm below `1e-10` the similarity is
/// undefined, and `0.0` is returned.
///
/// Both slices must have the same length. A mismatch panics in debug builds.
/// In release builds the result is meaningless, because the dot product only
/// covers the shorter prefix while each norm covers its whole slice.
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity: vectors must have equal length"
    );
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_key() -> [u8; 32] {
        [0xAB; 32]
    }

    fn test_config(dim: usize) -> AdcpeConfig {
        AdcpeConfig {
            dimensions: dim,
            noise_scale: 0.0,
        }
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let original = vec![1.0, 2.0, 3.0, 4.0];

        let encrypted = enc.encrypt(&original).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();

        for (a, b) in original.iter().zip(decrypted.iter()) {
            assert!((a - b).abs() < 1e-10, "Roundtrip failed: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_cosine_similarity_preserved() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(8)).unwrap();

        let a = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let c = vec![1.0, 0.1, 1.0, 0.1, 1.0, 0.1, 1.0, 0.1];

        let cos_ab_orig = cosine_similarity(&a, &b);
        let cos_ac_orig = cosine_similarity(&a, &c);

        let ea = enc.encrypt(&a).unwrap();
        let eb = enc.encrypt(&b).unwrap();
        let ec = enc.encrypt(&c).unwrap();

        let cos_ab_enc = cosine_similarity(&ea, &eb);
        let cos_ac_enc = cosine_similarity(&ea, &ec);

        assert!(
            (cos_ab_orig - cos_ab_enc).abs() < 1e-10,
            "Cosine AB not preserved: {} vs {}",
            cos_ab_orig, cos_ab_enc
        );
        assert!(
            (cos_ac_orig - cos_ac_enc).abs() < 1e-10,
            "Cosine AC not preserved: {} vs {}",
            cos_ac_orig, cos_ac_enc
        );
    }

    #[test]
    fn test_encrypted_vectors_differ() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let encrypted = enc.encrypt(&v).unwrap();

        // Encrypted should NOT equal original
        assert_ne!(v, encrypted);
    }

    #[test]
    fn test_different_keys_produce_different_output() {
        let config = test_config(4);
        let v = vec![1.0, 2.0, 3.0, 4.0];

        let mut enc1 = AdcpeEncryptor::new(&[0xAB; 32], &config).unwrap();
        let mut enc2 = AdcpeEncryptor::new(&[0xCD; 32], &config).unwrap();

        let e1 = enc1.encrypt(&v).unwrap();
        let e2 = enc2.encrypt(&v).unwrap();

        assert_ne!(e1, e2);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let wrong_dim = vec![1.0, 2.0, 3.0]; // 3 instead of 4

        assert!(enc.encrypt(&wrong_dim).is_err());
    }

    #[test]
    fn test_decrypt_dimension_mismatch_error() {
        let enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();

        assert!(enc.decrypt(&[1.0, 2.0, 3.0]).is_err());
        assert!(enc.decrypt(&[1.0; 5]).is_err());
    }

    #[test]
    fn test_zero_dimensions_rejected() {
        assert!(AdcpeEncryptor::new(&test_key(), &test_config(0)).is_err());
    }

    #[test]
    fn test_batch_encrypt_decrypt() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let encrypted = enc.encrypt_batch(&vectors).unwrap();
        assert_eq!(encrypted.len(), 3);

        let decrypted = enc.decrypt_batch(&encrypted).unwrap();
        for (orig, dec) in vectors.iter().zip(decrypted.iter()) {
            for (a, b) in orig.iter().zip(dec.iter()) {
                assert!((a - b).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_batch_rejects_mismatched_vector() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![1.0, 2.0]];

        assert!(enc.encrypt_batch(&vectors).is_err());
    }

    #[test]
    fn test_same_key_is_deterministic() {
        // Ciphertexts written by one process must be decryptable and
        // searchable by another built from the same key, so the matrix has to
        // be a pure function of (key, dimensions).
        let config = test_config(16);
        let mut enc1 = AdcpeEncryptor::new(&test_key(), &config).unwrap();
        let enc2 = AdcpeEncryptor::new(&test_key(), &config).unwrap();
        assert_eq!(enc1.matrix, enc2.matrix);

        let v: Vec<f64> = (0..16).map(|i| f64::from(i) * 0.25 - 1.0).collect();
        let encrypted = enc1.encrypt(&v).unwrap();
        let decrypted = enc2.decrypt(&encrypted).unwrap();

        for (a, b) in v.iter().zip(decrypted.iter()) {
            assert!(
                (a - b).abs() < 1e-10,
                "Cross-instance roundtrip failed: {} vs {}",
                a, b
            );
        }
    }

    #[test]
    fn test_norm_preserved() {
        fn norm(x: &[f64]) -> f64 {
            x.iter().map(|c| c * c).sum::<f64>().sqrt()
        }

        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(8)).unwrap();
        let v = vec![3.0, -1.0, 0.5, 2.0, 0.0, -4.0, 1.5, 0.25];
        let e = enc.encrypt(&v).unwrap();

        assert!((norm(&v) - norm(&e)).abs() < 1e-10);
    }

    #[test]
    fn test_f32_roundtrip() {
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let original: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];

        let encrypted = encrypt_f32(&mut enc, &original).unwrap();
        let decrypted = decrypt_f32(&enc, &encrypted).unwrap();

        for (a, b) in original.iter().zip(decrypted.iter()) {
            assert!((a - b).abs() < 1e-5, "f32 roundtrip: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_noise_adds_distortion() {
        let config = AdcpeConfig {
            dimensions: 4,
            noise_scale: 0.01,
        };
        let mut enc = AdcpeEncryptor::new(&test_key(), &config).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];

        let encrypted = enc.encrypt(&v).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();

        // With noise, roundtrip won't be exact
        let max_err: f64 = v
            .iter()
            .zip(decrypted.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f64::max);

        assert!(max_err > 1e-12, "Expected some distortion from noise");
        assert!(max_err < 1.0, "Distortion too large: {}", max_err);
    }

    #[test]
    fn test_orthogonality() {
        // Verify the matrix is orthogonal: Q * Q^T = I
        let enc = AdcpeEncryptor::new(&test_key(), &test_config(4)).unwrap();
        let dim = enc.dim;

        for i in 0..dim {
            for j in 0..dim {
                let dot = dot_rows(&enc.matrix, i, j, dim);
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-10,
                    "Not orthogonal at ({}, {}): {} vs {}",
                    i, j, dot, expected
                );
            }
        }
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_realistic_embedding_dimensions() {
        // Test with realistic dimensions (128, simulating a small model)
        let mut enc = AdcpeEncryptor::new(&test_key(), &test_config(128)).unwrap();

        let mut rng = StdRng::seed_from_u64(42);
        let a: Vec<f64> = (0..128).map(|_| rng.gen::<f64>() - 0.5).collect();
        let b: Vec<f64> = (0..128).map(|_| rng.gen::<f64>() - 0.5).collect();

        let cos_orig = cosine_similarity(&a, &b);

        let ea = enc.encrypt(&a).unwrap();
        let eb = enc.encrypt(&b).unwrap();

        let cos_enc = cosine_similarity(&ea, &eb);

        assert!(
            (cos_orig - cos_enc).abs() < 1e-10,
            "Cosine not preserved at dim=128: {} vs {}",
            cos_orig, cos_enc
        );
    }
}