1use anyhow::{bail, Result};
8use rand::{Rng, SeedableRng};
9use rand::rngs::StdRng;
10use serde::{Deserialize, Serialize};
11use zeroize::Zeroize;
12
/// Serializable settings used to build an [`AdcpeEncryptor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdcpeConfig {
    /// Length of every vector the encryptor accepts.
    ///
    /// `AdcpeEncryptor::new` rejects a value of `0`.
    pub dimensions: usize,
    /// Multiplier applied to the random noise added to each encrypted
    /// coordinate.
    ///
    /// Falls back to `f64::default()` (`0.0`) when the field is absent from the
    /// serialized form. `AdcpeEncryptor::encrypt` only adds noise when this is
    /// strictly greater than `0.0`.
    #[serde(default)]
    pub noise_scale: f64,
}
22
/// Stateful vector encryptor built around a key-derived orthonormal matrix.
///
/// The struct is `&mut` for encryption because each noisy encryption advances
/// the internal noise generator.
pub struct AdcpeEncryptor {
    /// Row-major `dim x dim` matrix with orthonormal rows; multiplied onto the
    /// input in `encrypt`. Wiped in `Drop`.
    matrix: Vec<f64>,
    /// Transpose of `matrix` (its inverse, since the rows are orthonormal);
    /// multiplied onto the ciphertext in `decrypt`. Wiped in `Drop`.
    matrix_inv: Vec<f64>,
    /// Number of components in every accepted vector (side length of `matrix`).
    dim: usize,
    /// Multiplier for per-coordinate noise; noise is skipped unless this is
    /// strictly positive.
    noise_scale: f64,
    /// Noise generator, seeded in `new` from a byte-wise transformation of the key.
    rng: StdRng,
}
40
41impl Drop for AdcpeEncryptor {
42 fn drop(&mut self) {
43 self.matrix.zeroize();
44 self.matrix_inv.zeroize();
45 }
46}
47
48impl AdcpeEncryptor {
49 pub fn new(key: &[u8; 32], config: &AdcpeConfig) -> Result<Self> {
54 let dim = config.dimensions;
55 if dim == 0 {
56 bail!("Vector dimensions must be > 0");
57 }
58
59 let mut seed = [0u8; 32];
61 seed.copy_from_slice(key);
62 let mut rng = StdRng::from_seed(seed);
63
64 let mut matrix = vec![0.0f64; dim * dim];
66 for v in matrix.iter_mut() {
67 *v = rng.gen::<f64>() * 2.0 - 1.0;
68 }
69
70 gram_schmidt(&mut matrix, dim)?;
72
73 let matrix_inv = transpose(&matrix, dim);
75
76 let mut noise_seed = [0u8; 32];
78 for (i, b) in key.iter().enumerate() {
79 noise_seed[i] = b.wrapping_add(0x5A);
80 }
81 let noise_rng = StdRng::from_seed(noise_seed);
82
83 Ok(Self {
84 matrix,
85 matrix_inv,
86 dim,
87 noise_scale: config.noise_scale,
88 rng: noise_rng,
89 })
90 }
91
92 pub fn encrypt(&mut self, vector: &[f64]) -> Result<Vec<f64>> {
96 if vector.len() != self.dim {
97 bail!(
98 "Vector dimension mismatch: expected {}, got {}",
99 self.dim,
100 vector.len()
101 );
102 }
103
104 let mut result = mat_vec_mul(&self.matrix, vector, self.dim);
105
106 if self.noise_scale > 0.0 {
108 for v in result.iter_mut() {
109 *v += self.rng.gen::<f64>() * self.noise_scale;
110 }
111 }
112
113 Ok(result)
114 }
115
116 pub fn decrypt(&self, encrypted: &[f64]) -> Result<Vec<f64>> {
121 if encrypted.len() != self.dim {
122 bail!(
123 "Vector dimension mismatch: expected {}, got {}",
124 self.dim,
125 encrypted.len()
126 );
127 }
128
129 Ok(mat_vec_mul(&self.matrix_inv, encrypted, self.dim))
130 }
131
132 pub fn encrypt_batch(&mut self, vectors: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
134 vectors.iter().map(|v| self.encrypt(v)).collect()
135 }
136
137 pub fn decrypt_batch(&self, encrypted: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
139 encrypted.iter().map(|v| self.decrypt(v)).collect()
140 }
141
142 pub fn dimensions(&self) -> usize {
144 self.dim
145 }
146}
147
148pub fn encrypt_f32(encryptor: &mut AdcpeEncryptor, vector: &[f32]) -> Result<Vec<f32>> {
150 let f64_vec: Vec<f64> = vector.iter().map(|&v| v as f64).collect();
151 let encrypted = encryptor.encrypt(&f64_vec)?;
152 Ok(encrypted.iter().map(|&v| v as f32).collect())
153}
154
155pub fn decrypt_f32(encryptor: &AdcpeEncryptor, encrypted: &[f32]) -> Result<Vec<f32>> {
157 let f64_vec: Vec<f64> = encrypted.iter().map(|&v| v as f64).collect();
158 let decrypted = encryptor.decrypt(&f64_vec)?;
159 Ok(decrypted.iter().map(|&v| v as f32).collect())
160}
161
/// Multiplies a row-major `dim x dim` matrix by a `dim`-element vector.
///
/// Panics if `matrix` holds fewer than `dim * dim` entries or `vector` fewer
/// than `dim`.
fn mat_vec_mul(matrix: &[f64], vector: &[f64], dim: usize) -> Vec<f64> {
    let mut product = Vec::with_capacity(dim);
    for row in 0..dim {
        let base = row * dim;
        let coefficients = &matrix[base..base + dim];
        // Each output entry is the dot product of one matrix row with the vector.
        let entry: f64 = coefficients
            .iter()
            .zip(&vector[..dim])
            .map(|(m, x)| m * x)
            .sum();
        product.push(entry);
    }
    product
}
171
/// Returns the transpose of a row-major `dim x dim` matrix as a new vector.
///
/// Panics if `matrix` holds fewer than `dim * dim` entries.
fn transpose(matrix: &[f64], dim: usize) -> Vec<f64> {
    // Output row `col` is input column `col`: walk the output in row-major
    // order and pull each entry from the mirrored position.
    (0..dim)
        .flat_map(|col| (0..dim).map(move |row| matrix[row * dim + col]))
        .collect()
}
182
183fn gram_schmidt(matrix: &mut [f64], dim: usize) -> Result<()> {
185 for i in 0..dim {
186 for j in 0..i {
188 let dot = dot_rows(matrix, i, j, dim);
189 let norm_sq = dot_rows(matrix, j, j, dim);
190 if norm_sq < 1e-10 {
191 bail!("Gram-Schmidt failed: degenerate matrix (row {} near-zero)", j);
192 }
193 let scale = dot / norm_sq;
194 for k in 0..dim {
195 let val = matrix[j * dim + k];
196 matrix[i * dim + k] -= scale * val;
197 }
198 }
199
200 let norm = dot_rows(matrix, i, i, dim).sqrt();
202 if norm < 1e-10 {
203 bail!("Gram-Schmidt failed: zero norm at row {}", i);
204 }
205 for k in 0..dim {
206 matrix[i * dim + k] /= norm;
207 }
208 }
209 Ok(())
210}
211
/// Dot product of rows `row_a` and `row_b` of a row-major matrix whose rows are
/// `dim` entries long.
///
/// Panics if either row extends past the end of `matrix`.
fn dot_rows(matrix: &[f64], row_a: usize, row_b: usize, dim: usize) -> f64 {
    let first = &matrix[row_a * dim..][..dim];
    let second = &matrix[row_b * dim..][..dim];
    // Explicit left-to-right accumulation starting from 0.0.
    first
        .iter()
        .zip(second)
        .fold(0.0, |acc, (x, y)| acc + x * y)
}
222
/// Cosine of the angle between `a` and `b`.
///
/// Returns `0.0` when either vector's Euclidean norm is below `1e-10`. The dot
/// product pairs elements up to the shorter slice, while each norm is taken
/// over its full slice.
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let magnitude = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();

    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }

    let inner: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    inner / (norm_a * norm_b)
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key shared by most tests so results are repeatable.
    fn test_key() -> [u8; 32] {
        [0xAB; 32]
    }

    /// Noise-free configuration for the requested dimensionality.
    fn test_config(dim: usize) -> AdcpeConfig {
        AdcpeConfig {
            noise_scale: 0.0,
            dimensions: dim,
        }
    }

    /// Without noise, decrypting a ciphertext recovers the input.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(4)).unwrap();
        let original = vec![1.0, 2.0, 3.0, 4.0];

        let encrypted = enc.encrypt(&original).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();

        for (a, b) in original.iter().zip(decrypted.iter()) {
            assert!((a - b).abs() < 1e-10, "Roundtrip failed: {} vs {}", a, b);
        }
    }

    /// Without noise, pairwise cosine similarity survives encryption.
    #[test]
    fn test_cosine_similarity_preserved() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(8)).unwrap();

        let a = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let c = vec![1.0, 0.1, 1.0, 0.1, 1.0, 0.1, 1.0, 0.1];

        let ea = enc.encrypt(&a).unwrap();
        let eb = enc.encrypt(&b).unwrap();
        let ec = enc.encrypt(&c).unwrap();

        let cos_ab_orig = cosine_similarity(&a, &b);
        let cos_ab_enc = cosine_similarity(&ea, &eb);
        assert!(
            (cos_ab_orig - cos_ab_enc).abs() < 1e-10,
            "Cosine AB not preserved: {} vs {}",
            cos_ab_orig, cos_ab_enc
        );

        let cos_ac_orig = cosine_similarity(&a, &c);
        let cos_ac_enc = cosine_similarity(&ea, &ec);
        assert!(
            (cos_ac_orig - cos_ac_enc).abs() < 1e-10,
            "Cosine AC not preserved: {} vs {}",
            cos_ac_orig, cos_ac_enc
        );
    }

    /// The ciphertext is not simply the plaintext.
    #[test]
    fn test_encrypted_vectors_differ() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(4)).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];

        let encrypted = enc.encrypt(&v).unwrap();

        assert_ne!(v, encrypted);
    }

    /// Two different keys map the same input to different ciphertexts.
    #[test]
    fn test_different_keys_produce_different_output() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let config = test_config(4);

        let mut enc1 = AdcpeEncryptor::new(&[0xAB; 32], &config).unwrap();
        let e1 = enc1.encrypt(&v).unwrap();

        let mut enc2 = AdcpeEncryptor::new(&[0xCD; 32], &config).unwrap();
        let e2 = enc2.encrypt(&v).unwrap();

        assert_ne!(e1, e2);
    }

    /// A vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch_error() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(4)).unwrap();
        let wrong_dim = vec![1.0, 2.0, 3.0];

        let outcome = enc.encrypt(&wrong_dim);

        assert!(outcome.is_err());
    }

    /// Batch encryption followed by batch decryption recovers every input.
    #[test]
    fn test_batch_encrypt_decrypt() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(4)).unwrap();
        let vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let encrypted = enc.encrypt_batch(&vectors).unwrap();
        assert_eq!(encrypted.len(), 3);

        let decrypted = enc.decrypt_batch(&encrypted).unwrap();
        let component_pairs = vectors
            .iter()
            .zip(decrypted.iter())
            .flat_map(|(orig, dec)| orig.iter().zip(dec.iter()));
        for (a, b) in component_pairs {
            assert!((a - b).abs() < 1e-10);
        }
    }

    /// The `f32` helpers round-trip within single-precision tolerance.
    #[test]
    fn test_f32_roundtrip() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(4)).unwrap();
        let original: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];

        let encrypted = encrypt_f32(&mut enc, &original).unwrap();
        let decrypted = decrypt_f32(&enc, &encrypted).unwrap();

        for (a, b) in original.iter().zip(decrypted.iter()) {
            assert!((a - b).abs() < 1e-5, "f32 roundtrip: {} vs {}", a, b);
        }
    }

    /// A positive noise scale perturbs the round trip, but only slightly.
    #[test]
    fn test_noise_adds_distortion() {
        let config = AdcpeConfig {
            dimensions: 4,
            noise_scale: 0.01,
        };
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &config).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];

        let encrypted = enc.encrypt(&v).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();

        let max_err: f64 = v
            .iter()
            .zip(decrypted.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f64::max);

        assert!(max_err > 1e-12, "Expected some distortion from noise");
        assert!(max_err < 1.0, "Distortion too large: {}", max_err);
    }

    /// Every pair of matrix rows has dot product 1 (same row) or 0 (different rows).
    #[test]
    fn test_orthogonality() {
        let key = test_key();
        let enc = AdcpeEncryptor::new(&key, &test_config(4)).unwrap();
        let dim = enc.dim;

        for i in 0..dim {
            for j in 0..dim {
                let expected = if i == j { 1.0 } else { 0.0 };
                let dot = dot_rows(&enc.matrix, i, j, dim);
                assert!(
                    (dot - expected).abs() < 1e-10,
                    "Not orthogonal at ({}, {}): {} vs {}",
                    i, j, dot, expected
                );
            }
        }
    }

    /// Cosine similarity is still preserved for 128-dimensional random vectors.
    #[test]
    fn test_realistic_embedding_dimensions() {
        let key = test_key();
        let mut enc = AdcpeEncryptor::new(&key, &test_config(128)).unwrap();

        let mut rng = StdRng::seed_from_u64(42);
        let mut sample = || -> Vec<f64> { (0..128).map(|_| rng.gen::<f64>() - 0.5).collect() };
        let a = sample();
        let b = sample();

        let cos_orig = cosine_similarity(&a, &b);

        let ea = enc.encrypt(&a).unwrap();
        let eb = enc.encrypt(&b).unwrap();
        let cos_enc = cosine_similarity(&ea, &eb);

        assert!(
            (cos_orig - cos_enc).abs() < 1e-10,
            "Cosine not preserved at dim=128: {} vs {}",
            cos_orig, cos_enc
        );
    }
}