// mokosh/encoders/word_embedding.rs

1//! Word Embedding Encoder implementation.
2//!
3//! Converts dense word embedding vectors (word2vec, GloVe, etc.) into SDRs.
4
5use crate::encoders::Encoder;
6use crate::error::{MokoshError, Result};
7use crate::types::{Real, Sdr, UInt};
8use std::collections::HashSet;
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
/// Parameters for creating a Word Embedding Encoder.
///
/// Validated by [`WordEmbeddingEncoder::with_seed`]: `embedding_dim` and
/// `num_hyperplanes` must be > 0, and `active_bits` must not exceed `size`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct WordEmbeddingEncoderParams {
    /// Dimension of input embedding vectors.
    pub embedding_dim: usize,

    /// Total number of bits in output SDR.
    pub size: UInt,

    /// Number of active bits in output SDR.
    pub active_bits: UInt,

    /// Number of random hyperplanes for LSH.
    /// More hyperplanes = more precision but less overlap for similar embeddings.
    /// The LSH hash is stored in a `u128`, so at most the first 128
    /// hyperplanes can influence the encoding (see `compute_lsh_hash`).
    pub num_hyperplanes: usize,
}
30
31impl Default for WordEmbeddingEncoderParams {
32    fn default() -> Self {
33        Self {
34            embedding_dim: 300, // Common for word2vec/GloVe
35            size: 2048,
36            active_bits: 41,
37            num_hyperplanes: 128,
38        }
39    }
40}
41
/// Encodes dense word embeddings into SDR representations.
///
/// Uses locality-sensitive hashing (LSH) with random hyperplanes
/// to convert continuous vectors into sparse binary representations
/// while preserving cosine similarity: vectors at a small angle agree on
/// the sign of most hyperplane dot products, so they hash alike and end
/// up sharing active bits.
///
/// # Example
///
/// ```rust
/// use mokosh::encoders::{WordEmbeddingEncoder, WordEmbeddingEncoderParams, Encoder};
///
/// let encoder = WordEmbeddingEncoder::new(WordEmbeddingEncoderParams {
///     embedding_dim: 4,
///     size: 100,
///     active_bits: 10,
///     num_hyperplanes: 32,
/// }).unwrap();
///
/// // Two similar embeddings
/// let embed1 = vec![0.5, 0.3, 0.1, 0.8];
/// let embed2 = vec![0.6, 0.35, 0.15, 0.75];
///
/// // A different embedding
/// let embed3 = vec![-0.5, -0.3, 0.9, -0.1];
///
/// let sdr1 = encoder.encode_to_sdr(embed1).unwrap();
/// let sdr2 = encoder.encode_to_sdr(embed2).unwrap();
/// let sdr3 = encoder.encode_to_sdr(embed3).unwrap();
///
/// // Similar embeddings should have more overlap
/// assert!(sdr1.get_overlap(&sdr2) > sdr1.get_overlap(&sdr3));
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct WordEmbeddingEncoder {
    /// Expected length of input embedding vectors.
    embedding_dim: usize,
    /// Total number of bits in the output SDR.
    size: UInt,
    /// Number of bits set in every encoding.
    active_bits: UInt,
    /// Configured hyperplane count; only the first 128 are ever consulted
    /// because the LSH hash is a `u128` (see `compute_lsh_hash`).
    num_hyperplanes: usize,
    /// Random hyperplanes for LSH (flattened: num_hyperplanes x embedding_dim).
    hyperplanes: Vec<Real>,
    /// Cached single-dimension shape `[size]` returned by `dimensions()`.
    dimensions: Vec<UInt>,
}
85
86impl WordEmbeddingEncoder {
87    /// Creates a new Word Embedding Encoder.
88    pub fn new(params: WordEmbeddingEncoderParams) -> Result<Self> {
89        Self::with_seed(params, 42)
90    }
91
92    /// Creates a new Word Embedding Encoder with a specific seed.
93    pub fn with_seed(params: WordEmbeddingEncoderParams, seed: u64) -> Result<Self> {
94        if params.embedding_dim == 0 {
95            return Err(MokoshError::InvalidParameter {
96                name: "embedding_dim",
97                message: "Must be > 0".to_string(),
98            });
99        }
100
101        if params.active_bits > params.size {
102            return Err(MokoshError::InvalidParameter {
103                name: "active_bits",
104                message: "Cannot exceed size".to_string(),
105            });
106        }
107
108        if params.num_hyperplanes == 0 {
109            return Err(MokoshError::InvalidParameter {
110                name: "num_hyperplanes",
111                message: "Must be > 0".to_string(),
112            });
113        }
114
115        // Generate random hyperplanes using a simple LCG for reproducibility
116        let mut hyperplanes =
117            Vec::with_capacity(params.num_hyperplanes * params.embedding_dim);
118
119        let mut state = seed;
120        for _ in 0..(params.num_hyperplanes * params.embedding_dim) {
121            // Simple LCG: state = (a * state + c) mod m
122            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
123            // Convert to [-1, 1] range
124            let value = ((state >> 33) as Real / (u32::MAX as Real / 2.0)) - 1.0;
125            hyperplanes.push(value);
126        }
127
128        Ok(Self {
129            embedding_dim: params.embedding_dim,
130            size: params.size,
131            active_bits: params.active_bits,
132            num_hyperplanes: params.num_hyperplanes,
133            hyperplanes,
134            dimensions: vec![params.size],
135        })
136    }
137
138    /// Returns the embedding dimension.
139    pub fn embedding_dim(&self) -> usize {
140        self.embedding_dim
141    }
142
143    /// Computes the LSH hash for an embedding.
144    fn compute_lsh_hash(&self, embedding: &[Real]) -> u128 {
145        let mut hash: u128 = 0;
146
147        for hp_idx in 0..self.num_hyperplanes.min(128) {
148            let hp_start = hp_idx * self.embedding_dim;
149            let hyperplane = &self.hyperplanes[hp_start..hp_start + self.embedding_dim];
150
151            // Compute dot product
152            let dot: Real = embedding
153                .iter()
154                .zip(hyperplane.iter())
155                .map(|(&e, &h)| e * h)
156                .sum();
157
158            if dot >= 0.0 {
159                hash |= 1u128 << hp_idx;
160            }
161        }
162
163        hash
164    }
165}
166
167impl Encoder<Vec<Real>> for WordEmbeddingEncoder {
168    fn dimensions(&self) -> &[UInt] {
169        &self.dimensions
170    }
171
172    fn size(&self) -> usize {
173        self.size as usize
174    }
175
176    fn encode(&self, embedding: Vec<Real>, output: &mut Sdr) -> Result<()> {
177        if embedding.len() != self.embedding_dim {
178            return Err(MokoshError::InvalidParameter {
179                name: "embedding",
180                message: format!(
181                    "Expected {} dimensions, got {}",
182                    self.embedding_dim,
183                    embedding.len()
184                ),
185            });
186        }
187
188        if output.dimensions() != self.dimensions.as_slice() {
189            return Err(MokoshError::DimensionMismatch {
190                expected: self.dimensions.clone(),
191                actual: output.dimensions().to_vec(),
192            });
193        }
194
195        let lsh_hash = self.compute_lsh_hash(&embedding);
196
197        // Use the LSH hash to deterministically select active bits
198        let mut active_bits = HashSet::new();
199        let mut state = lsh_hash as u64;
200
201        while active_bits.len() < self.active_bits as usize {
202            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
203            let bit = (state % self.size as u64) as UInt;
204            active_bits.insert(bit);
205        }
206
207        let mut sparse: Vec<UInt> = active_bits.into_iter().collect();
208        sparse.sort_unstable();
209        output.set_sparse_unchecked(sparse);
210
211        Ok(())
212    }
213}
214
215impl Encoder<&[Real]> for WordEmbeddingEncoder {
216    fn dimensions(&self) -> &[UInt] {
217        &self.dimensions
218    }
219
220    fn size(&self) -> usize {
221        self.size as usize
222    }
223
224    fn encode(&self, embedding: &[Real], output: &mut Sdr) -> Result<()> {
225        self.encode(embedding.to_vec(), output)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction records the configured embedding dimension and SDR size.
    #[test]
    fn test_create_encoder() {
        let params = WordEmbeddingEncoderParams {
            embedding_dim: 100,
            size: 500,
            active_bits: 25,
            num_hyperplanes: 64,
        };
        let encoder = WordEmbeddingEncoder::new(params).unwrap();

        assert_eq!(encoder.embedding_dim(), 100);
        assert_eq!(Encoder::<Vec<Real>>::size(&encoder), 500);
    }

    /// The encoding always sets exactly `active_bits` bits.
    #[test]
    fn test_encode_embedding() {
        let encoder = WordEmbeddingEncoder::new(WordEmbeddingEncoderParams {
            embedding_dim: 10,
            size: 200,
            active_bits: 20,
            num_hyperplanes: 32,
        })
        .unwrap();

        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5, -0.1, -0.2, -0.3, -0.4, -0.5];
        let encoded = encoder.encode_to_sdr(input).unwrap();

        assert_eq!(encoded.get_sum(), 20);
    }

    /// Identical inputs share every active bit; a sign-flipped input must
    /// overlap strictly less.
    #[test]
    fn test_similar_embeddings_overlap() {
        let encoder = WordEmbeddingEncoder::new(WordEmbeddingEncoderParams {
            embedding_dim: 8,
            size: 500,
            active_bits: 25,
            num_hyperplanes: 64,
        })
        .unwrap();

        let base = vec![0.5, 0.3, 0.1, 0.8, 0.2, 0.4, 0.6, 0.1];
        let twin = base.clone(); // identical embedding
        let negated: Vec<Real> = base.iter().map(|&x| -x).collect(); // opposite direction

        let sdr_base = encoder.encode_to_sdr(base).unwrap();
        let sdr_twin = encoder.encode_to_sdr(twin).unwrap();
        let sdr_neg = encoder.encode_to_sdr(negated).unwrap();

        // Identical should have full overlap.
        assert_eq!(sdr_base.get_overlap(&sdr_twin), 25);

        // Opposite should have less overlap (or different bits entirely).
        assert!(sdr_base.get_overlap(&sdr_neg) < 25);
    }

    /// Re-encoding the same input yields the same sparse indices.
    #[test]
    fn test_deterministic() {
        let encoder = WordEmbeddingEncoder::new(WordEmbeddingEncoderParams {
            embedding_dim: 5,
            size: 100,
            active_bits: 10,
            num_hyperplanes: 16,
        })
        .unwrap();

        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let first = encoder.encode_to_sdr(input.clone()).unwrap();
        let second = encoder.encode_to_sdr(input).unwrap();

        assert_eq!(first.get_sparse(), second.get_sparse());
    }

    /// Inputs with the wrong length are rejected with an error.
    #[test]
    fn test_wrong_dimension() {
        let encoder = WordEmbeddingEncoder::new(WordEmbeddingEncoderParams {
            embedding_dim: 10,
            ..Default::default()
        })
        .unwrap();

        // Three values against a 10-dimensional encoder must fail.
        assert!(encoder.encode_to_sdr(vec![0.1, 0.2, 0.3]).is_err());
    }

    /// Two encoders built with equal params and seeds encode identically.
    #[test]
    fn test_with_seed() {
        let build = || {
            WordEmbeddingEncoder::with_seed(
                WordEmbeddingEncoderParams {
                    embedding_dim: 5,
                    size: 100,
                    active_bits: 10,
                    num_hyperplanes: 16,
                },
                123,
            )
            .unwrap()
        };

        let encoder_a = build();
        let encoder_b = build();

        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let sdr_a = encoder_a.encode_to_sdr(input.clone()).unwrap();
        let sdr_b = encoder_b.encode_to_sdr(input).unwrap();

        assert_eq!(sdr_a.get_sparse(), sdr_b.get_sparse());
    }
}