Skip to main content

oxirs_embed/
embedding_compressor.rs

1//! Random projection for embedding dimensionality reduction.
2//!
3//! Implements Achlioptas (2003) sparse random projection for efficient
4//! dimensionality reduction of embedding vectors while approximately
5//! preserving pairwise distances.
6
/// Configuration for random projection compression.
///
/// `input_dim` is the length every vector passed to
/// [`EmbeddingCompressor::compress`] must have; `output_dim` is the length of
/// the compressed vectors it returns.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CompressionConfig {
    /// Original embedding dimension.
    pub input_dim: usize,
    /// Target (compressed) dimension.
    pub output_dim: usize,
    /// Random seed for reproducible projections: the same seed with the same
    /// dimensions always yields the same projection matrix.
    pub seed: u64,
}
17
18/// Random projection matrix compressor using Achlioptas (2003) sparse projection.
19///
20/// Each entry of the projection matrix is +sqrt(3), 0, or -sqrt(3)
21/// with probabilities 1/6, 2/3, 1/6 respectively.
22pub struct EmbeddingCompressor {
23    config: CompressionConfig,
24    /// Projection matrix [output_dim x input_dim].
25    projection: Vec<Vec<f32>>,
26}
27
/// Simple 64-bit LCG random number generator for seeded, reproducible projections.
///
/// The low-order bits of a power-of-two-modulus LCG have very short periods
/// (with an odd multiplier and odd increment, bit 0 simply alternates
/// 0, 1, 0, 1, ...), so derived values must be taken from the high-order bits.
struct LcgRng {
    state: u64,
}

impl LcgRng {
    /// Create a generator from `seed`.
    ///
    /// The seed is offset by one (wrapping) before use; every `u64` is accepted.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the generator and return the raw 64-bit state.
    ///
    /// Only the high-order bits of this value are of good statistical quality.
    fn next_u64(&mut self) -> u64 {
        // Multiplier/increment are Knuth's MMIX LCG constants.
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Return a value in [0, 5], (almost exactly) uniformly.
    ///
    /// Taking `next_u64() % 6` would depend on bit 0, which alternates on every
    /// call: even-numbered draws could only be 0/2/4 and odd-numbered draws only
    /// 1/3/5, correlating matrix entries with their position. Instead the top
    /// 32 bits are scaled into [0, 6) with a multiply-shift (bias <= 6 / 2^32).
    fn next_sixths(&mut self) -> u64 {
        let high = self.next_u64() >> 32;
        // high < 2^32, so high * 6 < 2^35 cannot overflow.
        (high * 6) >> 32
    }
}
55
56impl EmbeddingCompressor {
57    /// Build compressor with a sparse random projection matrix (Achlioptas, 2003).
58    ///
59    /// Each matrix entry is:
60    /// - +sqrt(3) with probability 1/6
61    /// - 0        with probability 2/3
62    /// - -sqrt(3) with probability 1/6
63    pub fn new(config: CompressionConfig) -> Self {
64        let scale = (3.0_f32).sqrt();
65        let mut rng = LcgRng::new(config.seed);
66        let mut projection = Vec::with_capacity(config.output_dim);
67
68        for _ in 0..config.output_dim {
69            let mut row = Vec::with_capacity(config.input_dim);
70            for _ in 0..config.input_dim {
71                let val = match rng.next_sixths() {
72                    0 => scale,   // prob 1/6
73                    5 => -scale,  // prob 1/6
74                    _ => 0.0_f32, // prob 4/6 = 2/3
75                };
76                row.push(val);
77            }
78            projection.push(row);
79        }
80
81        Self { config, projection }
82    }
83
84    /// Compress a single embedding vector.
85    ///
86    /// Returns an error if the input length does not match `input_dim`.
87    pub fn compress(&self, embedding: &[f32]) -> Result<Vec<f32>, String> {
88        if embedding.len() != self.config.input_dim {
89            return Err(format!(
90                "Expected embedding of length {}, got {}",
91                self.config.input_dim,
92                embedding.len()
93            ));
94        }
95
96        let scale = 1.0_f32 / (self.config.output_dim as f32).sqrt();
97        let compressed = self
98            .projection
99            .iter()
100            .map(|row| {
101                let dot: f32 = row.iter().zip(embedding.iter()).map(|(r, e)| r * e).sum();
102                dot * scale
103            })
104            .collect();
105
106        Ok(compressed)
107    }
108
109    /// Compress a batch of embeddings.
110    ///
111    /// Returns an error if any embedding has incorrect length.
112    pub fn compress_batch(&self, embeddings: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
113        embeddings.iter().map(|e| self.compress(e)).collect()
114    }
115
116    /// Approximate the similarity preservation ratio between two vectors.
117    ///
118    /// Computes cosine similarity in both original and compressed spaces and
119    /// returns the ratio (compressed / original). Should be close to 1.0 for
120    /// high-dimensional inputs (Johnson-Lindenstrauss lemma).
121    pub fn similarity_preservation_ratio(&self, a: &[f32], b: &[f32]) -> Result<f32, String> {
122        let original_sim = cosine_similarity(a, b)?;
123        let a_comp = self.compress(a)?;
124        let b_comp = self.compress(b)?;
125        let compressed_sim = cosine_similarity(&a_comp, &b_comp)?;
126
127        // Avoid division by zero; if original similarity is ~0, return compressed_sim
128        if original_sim.abs() < 1e-9 {
129            return Ok(compressed_sim.abs());
130        }
131
132        Ok(compressed_sim / original_sim)
133    }
134
135    /// Return a reference to the compression configuration.
136    pub fn config(&self) -> &CompressionConfig {
137        &self.config
138    }
139
140    /// Return the compression ratio (input_dim / output_dim).
141    pub fn compression_ratio(&self) -> f32 {
142        self.config.input_dim as f32 / self.config.output_dim as f32
143    }
144}
145
/// Cosine similarity of two equal-length, non-empty vectors.
///
/// If either vector has a Euclidean norm below `1e-9`, the similarity is
/// reported as `0.0` rather than dividing by (nearly) zero.
///
/// # Errors
///
/// Returns an error when the lengths differ or when the vectors are empty.
fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, String> {
    match (a.len(), b.len()) {
        (len_a, len_b) if len_a != len_b => {
            return Err(format!("Vector length mismatch: {} vs {}", len_a, len_b));
        }
        (0, _) => {
            return Err("Cannot compute cosine similarity of empty vectors".to_string());
        }
        _ => {}
    }

    let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let (mag_a, mag_b) = (l2(a), l2(b));

    let degenerate = mag_a < 1e-9 || mag_b < 1e-9;
    Ok(if degenerate {
        0.0
    } else {
        inner / (mag_a * mag_b)
    })
}
169
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config(input_dim: usize, output_dim: usize, seed: u64) -> CompressionConfig {
        CompressionConfig {
            input_dim,
            output_dim,
            seed,
        }
    }

    fn make_vec(dim: usize, val: f32) -> Vec<f32> {
        vec![val; dim]
    }

    fn unit_vec(dim: usize, idx: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        v[idx] = 1.0;
        v
    }

    // --- Construction ---

    #[test]
    fn test_new_creates_correct_projection_dims() {
        let cfg = make_config(128, 32, 42);
        let c = EmbeddingCompressor::new(cfg);
        assert_eq!(c.projection.len(), 32);
        for row in &c.projection {
            assert_eq!(row.len(), 128);
        }
    }

    #[test]
    fn test_new_entries_are_valid_achlioptas_values() {
        let scale = (3.0_f32).sqrt();
        let cfg = make_config(64, 16, 7);
        let c = EmbeddingCompressor::new(cfg);
        for row in &c.projection {
            for &v in row {
                assert!(
                    (v - scale).abs() < 1e-6 || v.abs() < 1e-6 || (v + scale).abs() < 1e-6,
                    "Unexpected value: {v}"
                );
            }
        }
    }

    #[test]
    fn test_seed_reproducibility() {
        let cfg1 = make_config(64, 16, 99);
        let cfg2 = make_config(64, 16, 99);
        let c1 = EmbeddingCompressor::new(cfg1);
        let c2 = EmbeddingCompressor::new(cfg2);
        assert_eq!(c1.projection, c2.projection);
    }

    #[test]
    fn test_different_seeds_produce_different_matrices() {
        let c1 = EmbeddingCompressor::new(make_config(64, 16, 1));
        let c2 = EmbeddingCompressor::new(make_config(64, 16, 2));
        // With high probability the matrices differ
        assert_ne!(c1.projection, c2.projection);
    }

    // --- compress ---

    #[test]
    fn test_compress_output_length_equals_output_dim() {
        let cfg = make_config(128, 32, 0);
        let c = EmbeddingCompressor::new(cfg);
        let v = make_vec(128, 1.0);
        let out = c.compress(&v).expect("compress should succeed");
        assert_eq!(out.len(), 32);
    }

    #[test]
    fn test_compress_wrong_input_length_returns_error() {
        let cfg = make_config(128, 32, 0);
        let c = EmbeddingCompressor::new(cfg);
        let v = make_vec(64, 1.0);
        let result = c.compress(&v);
        assert!(result.is_err());
    }

    #[test]
    fn test_compress_zero_vector() {
        let cfg = make_config(64, 16, 5);
        let c = EmbeddingCompressor::new(cfg);
        let v = make_vec(64, 0.0);
        let out = c.compress(&v).expect("compress should succeed");
        for &x in &out {
            assert!(x.abs() < 1e-9, "Expected zero vector, got {x}");
        }
    }

    #[test]
    fn test_compress_single_dim_input() {
        let cfg = make_config(1, 1, 0);
        let c = EmbeddingCompressor::new(cfg);
        let v = vec![2.0_f32];
        let out = c.compress(&v).expect("compress should succeed");
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_compress_exact_output_dimension() {
        for (input_dim, output_dim) in [(256, 64), (512, 128), (100, 50)] {
            let cfg = make_config(input_dim, output_dim, 42);
            let c = EmbeddingCompressor::new(cfg);
            let v = make_vec(input_dim, 1.0);
            let out = c.compress(&v).expect("compress ok");
            assert_eq!(out.len(), output_dim);
        }
    }

    #[test]
    fn test_compress_is_deterministic() {
        let cfg = make_config(64, 16, 13);
        let c = EmbeddingCompressor::new(cfg);
        let v: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        let out1 = c.compress(&v).expect("ok");
        let out2 = c.compress(&v).expect("ok");
        assert_eq!(out1, out2);
    }

    #[test]
    fn test_compress_linearity_scalar_multiple() {
        let cfg = make_config(32, 8, 17);
        let c = EmbeddingCompressor::new(cfg);
        let v: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let out1 = c.compress(&v).expect("ok");
        let v2: Vec<f32> = v.iter().map(|&x| x * 2.0).collect();
        let out2 = c.compress(&v2).expect("ok");
        for (a, b) in out1.iter().zip(out2.iter()) {
            assert!((b - 2.0 * a).abs() < 1e-5, "Linearity failed: {a} vs {b}");
        }
    }

    #[test]
    fn test_compress_unit_vector() {
        let cfg = make_config(32, 8, 11);
        let c = EmbeddingCompressor::new(cfg);
        let v = unit_vec(32, 0);
        let out = c.compress(&v).expect("ok");
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_compress_unit_vector_selects_scaled_column() {
        // Projecting e_j must return column j of the matrix scaled by 1/sqrt(output_dim).
        let c = EmbeddingCompressor::new(make_config(32, 8, 11));
        let scale = 1.0_f32 / (8.0_f32).sqrt();
        let out = c.compress(&unit_vec(32, 3)).expect("ok");
        for (row, &x) in c.projection.iter().zip(out.iter()) {
            let expected = row[3] * scale;
            assert!((x - expected).abs() < 1e-6, "{x} vs {expected}");
        }
    }

    // --- compress_batch ---

    #[test]
    fn test_compress_batch_correct_count() {
        let cfg = make_config(64, 16, 0);
        let c = EmbeddingCompressor::new(cfg);
        let batch: Vec<Vec<f32>> = (0..5).map(|_| make_vec(64, 1.0)).collect();
        let result = c.compress_batch(&batch).expect("ok");
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_compress_batch_each_output_length() {
        let cfg = make_config(64, 16, 0);
        let c = EmbeddingCompressor::new(cfg);
        let batch: Vec<Vec<f32>> = (0..3).map(|_| make_vec(64, 1.0)).collect();
        let result = c.compress_batch(&batch).expect("ok");
        for out in &result {
            assert_eq!(out.len(), 16);
        }
    }

    #[test]
    fn test_compress_batch_empty() {
        let cfg = make_config(64, 16, 0);
        let c = EmbeddingCompressor::new(cfg);
        let result = c.compress_batch(&[]).expect("ok");
        assert!(result.is_empty());
    }

    #[test]
    fn test_compress_batch_error_on_wrong_size() {
        let cfg = make_config(64, 16, 0);
        let c = EmbeddingCompressor::new(cfg);
        let batch = vec![make_vec(64, 1.0), make_vec(32, 1.0)];
        let result = c.compress_batch(&batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_compress_batch_single_element() {
        let cfg = make_config(64, 16, 0);
        let c = EmbeddingCompressor::new(cfg);
        let batch = vec![make_vec(64, 0.5)];
        let result = c.compress_batch(&batch).expect("ok");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 16);
    }

    #[test]
    fn test_compress_batch_matches_individual_compress() {
        let cfg = make_config(32, 8, 55);
        let c = EmbeddingCompressor::new(cfg);
        let v1: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let v2: Vec<f32> = (0..32).map(|i| (i as f32 * 0.1).sin()).collect();
        let individual1 = c.compress(&v1).expect("ok");
        let individual2 = c.compress(&v2).expect("ok");
        let batch = c.compress_batch(&[v1, v2]).expect("ok");
        assert_eq!(batch[0], individual1);
        assert_eq!(batch[1], individual2);
    }

    // --- compression_ratio ---

    #[test]
    fn test_compression_ratio_basic() {
        let cfg = make_config(128, 32, 0);
        let c = EmbeddingCompressor::new(cfg);
        let ratio = c.compression_ratio();
        assert!((ratio - 4.0).abs() < 1e-6, "Expected 4.0, got {ratio}");
    }

    #[test]
    fn test_compression_ratio_no_compression() {
        let cfg = make_config(64, 64, 0);
        let c = EmbeddingCompressor::new(cfg);
        assert!((c.compression_ratio() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_compression_ratio_high() {
        let cfg = make_config(512, 8, 0);
        let c = EmbeddingCompressor::new(cfg);
        assert!((c.compression_ratio() - 64.0).abs() < 1e-6);
    }

    // --- config ---

    #[test]
    fn test_config_returns_correct_input_dim() {
        let cfg = make_config(100, 25, 42);
        let c = EmbeddingCompressor::new(cfg);
        assert_eq!(c.config().input_dim, 100);
    }

    #[test]
    fn test_config_returns_correct_output_dim() {
        let cfg = make_config(100, 25, 42);
        let c = EmbeddingCompressor::new(cfg);
        assert_eq!(c.config().output_dim, 25);
    }

    #[test]
    fn test_config_returns_correct_seed() {
        let cfg = make_config(100, 25, 42);
        let c = EmbeddingCompressor::new(cfg);
        assert_eq!(c.config().seed, 42);
    }

    // --- similarity_preservation_ratio ---

    #[test]
    fn test_similarity_preservation_ratio_in_range() {
        let cfg = make_config(128, 32, 42);
        let c = EmbeddingCompressor::new(cfg);
        let a: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).sin()).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.2).cos()).collect();
        let ratio = c.similarity_preservation_ratio(&a, &b).expect("ok");
        // The ratio can be large when original similarity is near zero;
        // just verify it is finite.
        assert!(ratio.is_finite(), "Ratio should be finite: {ratio}");
    }

    #[test]
    fn test_similarity_preservation_parallel_vectors() {
        let cfg = make_config(64, 16, 7);
        let c = EmbeddingCompressor::new(cfg);
        let a = make_vec(64, 1.0);
        let b = make_vec(64, 2.0); // parallel to a
        // Compression is linear, so b's image is exactly twice a's image and
        // both the original and compressed cosine similarities are 1.0.
        let ratio = c.similarity_preservation_ratio(&a, &b).expect("ok");
        assert!((ratio - 1.0).abs() < 1e-4, "Expected ~1.0, got {ratio}");
    }

    #[test]
    fn test_similarity_preservation_wrong_length() {
        let cfg = make_config(64, 16, 7);
        let c = EmbeddingCompressor::new(cfg);
        let a = make_vec(64, 1.0);
        let b = make_vec(32, 1.0); // wrong length
        let result = c.similarity_preservation_ratio(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_similarity_preservation_zero_vector() {
        let cfg = make_config(32, 8, 3);
        let c = EmbeddingCompressor::new(cfg);
        let a = make_vec(32, 0.0); // zero vector
        let b = make_vec(32, 1.0);
        // Should not panic; original cosine similarity is 0
        let result = c.similarity_preservation_ratio(&a, &b);
        assert!(result.is_ok());
    }

    #[test]
    fn test_similarity_preservation_orthogonal_falls_back_to_abs() {
        let c = EmbeddingCompressor::new(make_config(32, 8, 21));
        let a = unit_vec(32, 0);
        let b = unit_vec(32, 1);
        // Original cosine similarity is exactly 0, so the method returns
        // |compressed cosine similarity|, which lies in [0, 1] up to rounding.
        let ratio = c.similarity_preservation_ratio(&a, &b).expect("ok");
        assert!((0.0..=1.0 + 1e-6).contains(&ratio), "ratio={ratio}");
    }

    #[test]
    fn test_similarity_preservation_identical_vectors() {
        let cfg = make_config(64, 16, 9);
        let c = EmbeddingCompressor::new(cfg);
        let a: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let result = c.similarity_preservation_ratio(&a, &a);
        assert!(result.is_ok());
        // Cosine similarity of a vector with itself is 1.0 in both spaces.
        let ratio = result.expect("ok");
        assert!((ratio - 1.0).abs() < 1e-4, "ratio={ratio}");
    }

    // --- Edge cases ---

    #[test]
    // The literal 3.14 is arbitrary test data, not an approximation of PI.
    #[allow(clippy::approx_constant)]
    fn test_minimum_dimensions() {
        let cfg = make_config(1, 1, 0);
        let c = EmbeddingCompressor::new(cfg);
        let out = c.compress(&[3.14]).expect("ok");
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_large_dimension() {
        let cfg = make_config(1024, 128, 42);
        let c = EmbeddingCompressor::new(cfg);
        let v = make_vec(1024, 0.5);
        let out = c.compress(&v).expect("ok");
        assert_eq!(out.len(), 128);
    }

    #[test]
    fn test_seed_zero() {
        let cfg = make_config(32, 8, 0);
        let c = EmbeddingCompressor::new(cfg);
        let v = make_vec(32, 1.0);
        let out = c.compress(&v).expect("ok");
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_projection_sparsity() {
        // Achlioptas matrix should be ~2/3 zeros
        let cfg = make_config(300, 100, 12345);
        let c = EmbeddingCompressor::new(cfg);
        let total: usize = c.projection.len() * c.projection[0].len();
        let zeros: usize = c
            .projection
            .iter()
            .flat_map(|row| row.iter())
            .filter(|&&v| v.abs() < 1e-9)
            .count();
        let zero_fraction = zeros as f64 / total as f64;
        // Expected fraction is 0.667; the (0.50, 0.80) window is generous slack.
        assert!(
            zero_fraction > 0.50 && zero_fraction < 0.80,
            "Expected ~2/3 zeros, got {zero_fraction:.3}"
        );
    }

    #[test]
    fn test_batch_size_large() {
        let cfg = make_config(64, 16, 42);
        let c = EmbeddingCompressor::new(cfg);
        let batch: Vec<Vec<f32>> = (0..100).map(|_| make_vec(64, 0.5)).collect();
        let result = c.compress_batch(&batch).expect("ok");
        assert_eq!(result.len(), 100);
    }

    #[test]
    fn test_different_seeds_compress_differently() {
        let v = make_vec(64, 1.0);
        let c1 = EmbeddingCompressor::new(make_config(64, 16, 1));
        let c2 = EmbeddingCompressor::new(make_config(64, 16, 2));
        let out1 = c1.compress(&v).expect("ok");
        let out2 = c2.compress(&v).expect("ok");
        // With very high probability, outputs differ
        assert_ne!(out1, out2);
    }

    #[test]
    fn test_config_clone() {
        let cfg = make_config(64, 16, 99);
        let c = EmbeddingCompressor::new(cfg.clone());
        assert_eq!(c.config().input_dim, cfg.input_dim);
        assert_eq!(c.config().output_dim, cfg.output_dim);
        assert_eq!(c.config().seed, cfg.seed);
    }

    #[test]
    fn test_debug_format_config() {
        let cfg = make_config(64, 16, 42);
        let debug_str = format!("{cfg:?}");
        assert!(debug_str.contains("64"));
        assert!(debug_str.contains("16"));
    }
}