/// Settings for an [`EmbeddingCompressor`]: the embedding length going in, the
/// vector length coming out, and the seed its projection matrix is generated from.
///
/// Two configs compare equal when all three fields are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionConfig {
    /// Length every embedding passed to `compress` must have.
    pub input_dim: usize,
    /// Number of rows in the projection matrix, i.e. the length of each
    /// compressed vector.
    pub output_dim: usize,
    /// Seed handed to the deterministic generator that fills the projection matrix.
    pub seed: u64,
}
17
18pub struct EmbeddingCompressor {
23 config: CompressionConfig,
24 projection: Vec<Vec<f32>>,
26}
27
/// Small deterministic linear congruential generator (modulus 2^64) used to
/// fill the projection matrix reproducibly from a seed.
struct LcgRng {
    /// Current 64-bit generator state.
    state: u64,
}

impl LcgRng {
    /// LCG multiplier (the constant Knuth gives for the MMIX generator).
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// LCG increment (the companion MMIX constant); odd, as is the multiplier.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator whose initial state is `seed + 1` (wrapping on overflow).
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the state by one LCG step and returns the new state.
    ///
    /// Arithmetic wraps, so the effective modulus is 2^64.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state
    }

    /// Returns a value in `0..6` derived from the next generator output.
    ///
    /// The value is taken from the HIGH 32 bits of the output. With a
    /// power-of-two modulus and odd multiplier/increment, the lowest state bit
    /// simply alternates 0, 1, 0, 1, ..., so `next_u64() % 6` would alternate
    /// between even and odd results: 0 could only ever appear on draws of one
    /// parity and 5 only on the other. Callers that map 0 and 5 to opposite
    /// signs would then get a sign that is fixed by the draw's position.
    fn next_sixths(&mut self) -> u64 {
        let high = self.next_u64() >> 32;
        // Multiply-shift range reduction: `high` < 2^32, so `high * 6` < 2^35
        // cannot overflow and the shifted result lies in 0..=5.
        (high * 6) >> 32
    }
}
55
56impl EmbeddingCompressor {
57 pub fn new(config: CompressionConfig) -> Self {
64 let scale = (3.0_f32).sqrt();
65 let mut rng = LcgRng::new(config.seed);
66 let mut projection = Vec::with_capacity(config.output_dim);
67
68 for _ in 0..config.output_dim {
69 let mut row = Vec::with_capacity(config.input_dim);
70 for _ in 0..config.input_dim {
71 let val = match rng.next_sixths() {
72 0 => scale, 5 => -scale, _ => 0.0_f32, };
76 row.push(val);
77 }
78 projection.push(row);
79 }
80
81 Self { config, projection }
82 }
83
84 pub fn compress(&self, embedding: &[f32]) -> Result<Vec<f32>, String> {
88 if embedding.len() != self.config.input_dim {
89 return Err(format!(
90 "Expected embedding of length {}, got {}",
91 self.config.input_dim,
92 embedding.len()
93 ));
94 }
95
96 let scale = 1.0_f32 / (self.config.output_dim as f32).sqrt();
97 let compressed = self
98 .projection
99 .iter()
100 .map(|row| {
101 let dot: f32 = row.iter().zip(embedding.iter()).map(|(r, e)| r * e).sum();
102 dot * scale
103 })
104 .collect();
105
106 Ok(compressed)
107 }
108
109 pub fn compress_batch(&self, embeddings: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
113 embeddings.iter().map(|e| self.compress(e)).collect()
114 }
115
116 pub fn similarity_preservation_ratio(&self, a: &[f32], b: &[f32]) -> Result<f32, String> {
122 let original_sim = cosine_similarity(a, b)?;
123 let a_comp = self.compress(a)?;
124 let b_comp = self.compress(b)?;
125 let compressed_sim = cosine_similarity(&a_comp, &b_comp)?;
126
127 if original_sim.abs() < 1e-9 {
129 return Ok(compressed_sim.abs());
130 }
131
132 Ok(compressed_sim / original_sim)
133 }
134
135 pub fn config(&self) -> &CompressionConfig {
137 &self.config
138 }
139
140 pub fn compression_ratio(&self) -> f32 {
142 self.config.input_dim as f32 / self.config.output_dim as f32
143 }
144}
145
/// Computes the cosine similarity of two equally long, non-empty vectors.
///
/// Returns `Ok(0.0)` when either vector has an (almost) zero Euclidean norm.
///
/// # Errors
///
/// Returns an error message when the lengths differ or when the vectors are empty.
fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, String> {
    /// Norms below this are treated as zero.
    const NORM_EPSILON: f32 = 1e-9;

    if a.len() != b.len() {
        return Err(format!(
            "Vector length mismatch: {} vs {}",
            a.len(),
            b.len()
        ));
    }
    if a.is_empty() {
        return Err("Cannot compute cosine similarity of empty vectors".to_string());
    }

    let euclidean_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (euclidean_norm(a), euclidean_norm(b));

    // A zero-length direction has no angle; report "no similarity".
    if norm_a < NORM_EPSILON || norm_b < NORM_EPSILON {
        return Ok(0.0);
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    Ok(dot / (norm_a * norm_b))
}
169
#[cfg(test)]
mod tests {
    use super::*;

    // ---- helpers -----------------------------------------------------------

    /// Bundles the three settings into a `CompressionConfig`.
    fn config_for(input_dim: usize, output_dim: usize, seed: u64) -> CompressionConfig {
        CompressionConfig {
            input_dim,
            output_dim,
            seed,
        }
    }

    /// Builds a compressor straight from dimensions and a seed.
    fn compressor(input_dim: usize, output_dim: usize, seed: u64) -> EmbeddingCompressor {
        EmbeddingCompressor::new(config_for(input_dim, output_dim, seed))
    }

    /// A vector of `len` copies of `value`.
    fn constant_vec(len: usize, value: f32) -> Vec<f32> {
        vec![value; len]
    }

    /// A vector of `len` zeros with a single 1.0 at position `hot`.
    fn basis_vec(len: usize, hot: usize) -> Vec<f32> {
        let mut out = vec![0.0_f32; len];
        out[hot] = 1.0;
        out
    }

    // ---- construction ------------------------------------------------------

    #[test]
    fn test_new_creates_correct_projection_dims() {
        let c = compressor(128, 32, 42);
        assert_eq!(c.projection.len(), 32);
        c.projection
            .iter()
            .for_each(|row| assert_eq!(row.len(), 128));
    }

    #[test]
    fn test_new_entries_are_valid_achlioptas_values() {
        let scale = 3.0_f32.sqrt();
        let c = compressor(64, 16, 7);
        for &v in c.projection.iter().flatten() {
            let is_plus = (v - scale).abs() < 1e-6;
            let is_zero = v.abs() < 1e-6;
            let is_minus = (v + scale).abs() < 1e-6;
            assert!(is_plus || is_zero || is_minus, "Unexpected value: {v}");
        }
    }

    #[test]
    fn test_seed_reproducibility() {
        let first = compressor(64, 16, 99);
        let second = compressor(64, 16, 99);
        assert_eq!(first.projection, second.projection);
    }

    #[test]
    fn test_different_seeds_produce_different_matrices() {
        let with_seed_one = compressor(64, 16, 1);
        let with_seed_two = compressor(64, 16, 2);
        assert_ne!(with_seed_one.projection, with_seed_two.projection);
    }

    // ---- compress ----------------------------------------------------------

    #[test]
    fn test_compress_output_length_equals_output_dim() {
        let c = compressor(128, 32, 0);
        let out = c
            .compress(&constant_vec(128, 1.0))
            .expect("compress should succeed");
        assert_eq!(out.len(), 32);
    }

    #[test]
    fn test_compress_wrong_input_length_returns_error() {
        let c = compressor(128, 32, 0);
        let too_short = constant_vec(64, 1.0);
        assert!(c.compress(&too_short).is_err());
    }

    #[test]
    fn test_compress_zero_vector() {
        let c = compressor(64, 16, 5);
        let out = c
            .compress(&constant_vec(64, 0.0))
            .expect("compress should succeed");
        for &x in out.iter() {
            assert!(x.abs() < 1e-9, "Expected zero vector, got {x}");
        }
    }

    #[test]
    fn test_compress_single_dim_input() {
        let c = compressor(1, 1, 0);
        let input = vec![2.0_f32];
        let out = c.compress(&input).expect("compress should succeed");
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_compress_exact_output_dimension() {
        const CASES: [(usize, usize); 3] = [(256, 64), (512, 128), (100, 50)];
        for (input_dim, output_dim) in CASES {
            let c = compressor(input_dim, output_dim, 42);
            let out = c
                .compress(&constant_vec(input_dim, 1.0))
                .expect("compress ok");
            assert_eq!(out.len(), output_dim);
        }
    }

    #[test]
    fn test_compress_is_deterministic() {
        let c = compressor(64, 16, 13);
        let input: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        let first = c.compress(&input).expect("ok");
        let second = c.compress(&input).expect("ok");
        assert_eq!(first, second);
    }

    #[test]
    fn test_compress_linearity_scalar_multiple() {
        let c = compressor(32, 8, 17);
        let base: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let doubled: Vec<f32> = base.iter().map(|&x| x * 2.0).collect();
        let base_out = c.compress(&base).expect("ok");
        let doubled_out = c.compress(&doubled).expect("ok");
        for (a, b) in base_out.iter().zip(doubled_out.iter()) {
            assert!((b - 2.0 * a).abs() < 1e-5, "Linearity failed: {a} vs {b}");
        }
    }

    #[test]
    fn test_compress_unit_vector() {
        let c = compressor(32, 8, 11);
        let out = c.compress(&basis_vec(32, 0)).expect("ok");
        assert_eq!(out.len(), 8);
    }

    // ---- compress_batch ----------------------------------------------------

    #[test]
    fn test_compress_batch_correct_count() {
        let c = compressor(64, 16, 0);
        let batch: Vec<Vec<f32>> = vec![constant_vec(64, 1.0); 5];
        let result = c.compress_batch(&batch).expect("ok");
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_compress_batch_each_output_length() {
        let c = compressor(64, 16, 0);
        let batch: Vec<Vec<f32>> = vec![constant_vec(64, 1.0); 3];
        let result = c.compress_batch(&batch).expect("ok");
        result.iter().for_each(|out| assert_eq!(out.len(), 16));
    }

    #[test]
    fn test_compress_batch_empty() {
        let c = compressor(64, 16, 0);
        let result = c.compress_batch(&[]).expect("ok");
        assert!(result.is_empty());
    }

    #[test]
    fn test_compress_batch_error_on_wrong_size() {
        let c = compressor(64, 16, 0);
        // The second entry has the wrong length and must fail the whole batch.
        let batch = vec![constant_vec(64, 1.0), constant_vec(32, 1.0)];
        assert!(c.compress_batch(&batch).is_err());
    }

    #[test]
    fn test_compress_batch_single_element() {
        let c = compressor(64, 16, 0);
        let batch = vec![constant_vec(64, 0.5)];
        let result = c.compress_batch(&batch).expect("ok");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 16);
    }

    #[test]
    fn test_compress_batch_matches_individual_compress() {
        let c = compressor(32, 8, 55);
        let ramp: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let wave: Vec<f32> = (0..32).map(|i| (i as f32 * 0.1).sin()).collect();
        let ramp_alone = c.compress(&ramp).expect("ok");
        let wave_alone = c.compress(&wave).expect("ok");
        let batch = c.compress_batch(&[ramp, wave]).expect("ok");
        assert_eq!(batch[0], ramp_alone);
        assert_eq!(batch[1], wave_alone);
    }

    // ---- compression_ratio -------------------------------------------------

    #[test]
    fn test_compression_ratio_basic() {
        let ratio = compressor(128, 32, 0).compression_ratio();
        assert!((ratio - 4.0).abs() < 1e-6, "Expected 4.0, got {ratio}");
    }

    #[test]
    fn test_compression_ratio_no_compression() {
        let c = compressor(64, 64, 0);
        assert!((c.compression_ratio() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_compression_ratio_high() {
        let c = compressor(512, 8, 0);
        assert!((c.compression_ratio() - 64.0).abs() < 1e-6);
    }

    // ---- config accessor ---------------------------------------------------

    #[test]
    fn test_config_returns_correct_input_dim() {
        let c = compressor(100, 25, 42);
        assert_eq!(c.config().input_dim, 100);
    }

    #[test]
    fn test_config_returns_correct_output_dim() {
        let c = compressor(100, 25, 42);
        assert_eq!(c.config().output_dim, 25);
    }

    #[test]
    fn test_config_returns_correct_seed() {
        let c = compressor(100, 25, 42);
        assert_eq!(c.config().seed, 42);
    }

    // ---- similarity_preservation_ratio -------------------------------------

    #[test]
    fn test_similarity_preservation_ratio_in_range() {
        let c = compressor(128, 32, 42);
        let sines: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).sin()).collect();
        let cosines: Vec<f32> = (0..128).map(|i| (i as f32 * 0.2).cos()).collect();
        let ratio = c
            .similarity_preservation_ratio(&sines, &cosines)
            .expect("ok");
        assert!(ratio.is_finite(), "Ratio should be finite: {ratio}");
    }

    #[test]
    fn test_similarity_preservation_parallel_vectors() {
        let c = compressor(64, 16, 7);
        let ones = constant_vec(64, 1.0);
        let twos = constant_vec(64, 2.0);
        let ratio = c.similarity_preservation_ratio(&ones, &twos).expect("ok");
        assert!((ratio - 1.0).abs() < 0.5, "Expected ~1.0, got {ratio}");
    }

    #[test]
    fn test_similarity_preservation_wrong_length() {
        let c = compressor(64, 16, 7);
        let full = constant_vec(64, 1.0);
        let half = constant_vec(32, 1.0);
        assert!(c.similarity_preservation_ratio(&full, &half).is_err());
    }

    #[test]
    fn test_similarity_preservation_zero_vector() {
        let c = compressor(32, 8, 3);
        let zeros = constant_vec(32, 0.0);
        let ones = constant_vec(32, 1.0);
        assert!(c.similarity_preservation_ratio(&zeros, &ones).is_ok());
    }

    #[test]
    fn test_similarity_preservation_identical_vectors() {
        let c = compressor(64, 16, 9);
        let ramp: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let result = c.similarity_preservation_ratio(&ramp, &ramp);
        assert!(result.is_ok());
        let ratio = result.expect("ok");
        assert!((0.0..=2.0).contains(&ratio), "ratio={ratio}");
    }

    // ---- edge cases --------------------------------------------------------

    #[test]
    // 3.14 is just an arbitrary sample value here, not an approximation of pi.
    #[allow(clippy::approx_constant)]
    fn test_minimum_dimensions() {
        let c = compressor(1, 1, 0);
        let out = c.compress(&[3.14]).expect("ok");
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_large_dimension() {
        let c = compressor(1024, 128, 42);
        let out = c.compress(&constant_vec(1024, 0.5)).expect("ok");
        assert_eq!(out.len(), 128);
    }

    #[test]
    fn test_seed_zero() {
        let c = compressor(32, 8, 0);
        let out = c.compress(&constant_vec(32, 1.0)).expect("ok");
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_projection_sparsity() {
        let c = compressor(300, 100, 12345);
        let total: usize = c.projection.len() * c.projection[0].len();
        let zeros: usize = c
            .projection
            .iter()
            .flatten()
            .filter(|v| v.abs() < 1e-9)
            .count();
        let zero_fraction = zeros as f64 / total as f64;
        // Four of the six possible draws map to zero; allow generous slack.
        assert!(
            zero_fraction > 0.50 && zero_fraction < 0.80,
            "Expected ~2/3 zeros, got {zero_fraction:.3}"
        );
    }

    #[test]
    fn test_batch_size_large() {
        let c = compressor(64, 16, 42);
        let batch: Vec<Vec<f32>> = vec![constant_vec(64, 0.5); 100];
        let result = c.compress_batch(&batch).expect("ok");
        assert_eq!(result.len(), 100);
    }

    #[test]
    fn test_different_seeds_compress_differently() {
        let input = constant_vec(64, 1.0);
        let from_seed_one = compressor(64, 16, 1).compress(&input).expect("ok");
        let from_seed_two = compressor(64, 16, 2).compress(&input).expect("ok");
        assert_ne!(from_seed_one, from_seed_two);
    }

    #[test]
    fn test_config_clone() {
        let cfg = config_for(64, 16, 99);
        let c = EmbeddingCompressor::new(cfg.clone());
        let stored = c.config();
        assert_eq!(stored.input_dim, cfg.input_dim);
        assert_eq!(stored.output_dim, cfg.output_dim);
        assert_eq!(stored.seed, cfg.seed);
    }

    #[test]
    fn test_debug_format_config() {
        let cfg = config_for(64, 16, 42);
        let debug_str = format!("{cfg:?}");
        for expected in ["64", "16"] {
            assert!(debug_str.contains(expected));
        }
    }
}