1use crate::encoders::Encoder;
6use crate::error::{MokoshError, Result};
7use crate::types::{Real, Sdr, UInt};
8use std::collections::HashSet;
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub struct WordEmbeddingEncoderParams {
17 pub embedding_dim: usize,
19
20 pub size: UInt,
22
23 pub active_bits: UInt,
25
26 pub num_hyperplanes: usize,
29}
30
31impl Default for WordEmbeddingEncoderParams {
32 fn default() -> Self {
33 Self {
34 embedding_dim: 300, size: 2048,
36 active_bits: 41,
37 num_hyperplanes: 128,
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
75#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
76pub struct WordEmbeddingEncoder {
77 embedding_dim: usize,
78 size: UInt,
79 active_bits: UInt,
80 num_hyperplanes: usize,
81 hyperplanes: Vec<Real>,
83 dimensions: Vec<UInt>,
84}
85
86impl WordEmbeddingEncoder {
87 pub fn new(params: WordEmbeddingEncoderParams) -> Result<Self> {
89 Self::with_seed(params, 42)
90 }
91
92 pub fn with_seed(params: WordEmbeddingEncoderParams, seed: u64) -> Result<Self> {
94 if params.embedding_dim == 0 {
95 return Err(MokoshError::InvalidParameter {
96 name: "embedding_dim",
97 message: "Must be > 0".to_string(),
98 });
99 }
100
101 if params.active_bits > params.size {
102 return Err(MokoshError::InvalidParameter {
103 name: "active_bits",
104 message: "Cannot exceed size".to_string(),
105 });
106 }
107
108 if params.num_hyperplanes == 0 {
109 return Err(MokoshError::InvalidParameter {
110 name: "num_hyperplanes",
111 message: "Must be > 0".to_string(),
112 });
113 }
114
115 let mut hyperplanes =
117 Vec::with_capacity(params.num_hyperplanes * params.embedding_dim);
118
119 let mut state = seed;
120 for _ in 0..(params.num_hyperplanes * params.embedding_dim) {
121 state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
123 let value = ((state >> 33) as Real / (u32::MAX as Real / 2.0)) - 1.0;
125 hyperplanes.push(value);
126 }
127
128 Ok(Self {
129 embedding_dim: params.embedding_dim,
130 size: params.size,
131 active_bits: params.active_bits,
132 num_hyperplanes: params.num_hyperplanes,
133 hyperplanes,
134 dimensions: vec![params.size],
135 })
136 }
137
138 pub fn embedding_dim(&self) -> usize {
140 self.embedding_dim
141 }
142
143 fn compute_lsh_hash(&self, embedding: &[Real]) -> u128 {
145 let mut hash: u128 = 0;
146
147 for hp_idx in 0..self.num_hyperplanes.min(128) {
148 let hp_start = hp_idx * self.embedding_dim;
149 let hyperplane = &self.hyperplanes[hp_start..hp_start + self.embedding_dim];
150
151 let dot: Real = embedding
153 .iter()
154 .zip(hyperplane.iter())
155 .map(|(&e, &h)| e * h)
156 .sum();
157
158 if dot >= 0.0 {
159 hash |= 1u128 << hp_idx;
160 }
161 }
162
163 hash
164 }
165}
166
167impl Encoder<Vec<Real>> for WordEmbeddingEncoder {
168 fn dimensions(&self) -> &[UInt] {
169 &self.dimensions
170 }
171
172 fn size(&self) -> usize {
173 self.size as usize
174 }
175
176 fn encode(&self, embedding: Vec<Real>, output: &mut Sdr) -> Result<()> {
177 if embedding.len() != self.embedding_dim {
178 return Err(MokoshError::InvalidParameter {
179 name: "embedding",
180 message: format!(
181 "Expected {} dimensions, got {}",
182 self.embedding_dim,
183 embedding.len()
184 ),
185 });
186 }
187
188 if output.dimensions() != self.dimensions.as_slice() {
189 return Err(MokoshError::DimensionMismatch {
190 expected: self.dimensions.clone(),
191 actual: output.dimensions().to_vec(),
192 });
193 }
194
195 let lsh_hash = self.compute_lsh_hash(&embedding);
196
197 let mut active_bits = HashSet::new();
199 let mut state = lsh_hash as u64;
200
201 while active_bits.len() < self.active_bits as usize {
202 state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
203 let bit = (state % self.size as u64) as UInt;
204 active_bits.insert(bit);
205 }
206
207 let mut sparse: Vec<UInt> = active_bits.into_iter().collect();
208 sparse.sort_unstable();
209 output.set_sparse_unchecked(sparse);
210
211 Ok(())
212 }
213}
214
215impl Encoder<&[Real]> for WordEmbeddingEncoder {
216 fn dimensions(&self) -> &[UInt] {
217 &self.dimensions
218 }
219
220 fn size(&self) -> usize {
221 self.size as usize
222 }
223
224 fn encode(&self, embedding: &[Real], output: &mut Sdr) -> Result<()> {
225 self.encode(embedding.to_vec(), output)
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encoder with the default seed, panicking if the parameters are rejected.
    fn build(
        embedding_dim: usize,
        size: UInt,
        active_bits: UInt,
        num_hyperplanes: usize,
    ) -> WordEmbeddingEncoder {
        let params = WordEmbeddingEncoderParams {
            embedding_dim,
            size,
            active_bits,
            num_hyperplanes,
        };
        WordEmbeddingEncoder::new(params).unwrap()
    }

    /// Construction succeeds and the accessors report the configured values.
    #[test]
    fn test_create_encoder() {
        let encoder = build(100, 500, 25, 64);

        assert_eq!(encoder.embedding_dim(), 100);
        assert_eq!(Encoder::<Vec<Real>>::size(&encoder), 500);
    }

    /// An encoding contains exactly `active_bits` set bits.
    #[test]
    fn test_encode_embedding() {
        let encoder = build(10, 200, 20, 32);

        let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5, -0.1, -0.2, -0.3, -0.4, -0.5];
        let sdr = encoder.encode_to_sdr(embedding).unwrap();

        assert_eq!(sdr.get_sum(), 20);
    }

    /// Identical embeddings overlap completely; a negated embedding does not.
    #[test]
    fn test_similar_embeddings_overlap() {
        let encoder = build(8, 500, 25, 64);

        let embed1 = vec![0.5, 0.3, 0.1, 0.8, 0.2, 0.4, 0.6, 0.1];
        let embed2 = vec![0.5, 0.3, 0.1, 0.8, 0.2, 0.4, 0.6, 0.1];
        let embed3 = vec![-0.5, -0.3, -0.1, -0.8, -0.2, -0.4, -0.6, -0.1];

        let sdr1 = encoder.encode_to_sdr(embed1).unwrap();
        let sdr2 = encoder.encode_to_sdr(embed2).unwrap();
        let sdr3 = encoder.encode_to_sdr(embed3).unwrap();

        assert_eq!(sdr1.get_overlap(&sdr2), 25);

        let diff_overlap = sdr1.get_overlap(&sdr3);
        assert!(diff_overlap < 25);
    }

    /// Encoding the same embedding twice with one encoder gives the same bits.
    #[test]
    fn test_deterministic() {
        let encoder = build(5, 100, 10, 16);

        let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        let first = encoder.encode_to_sdr(embedding.clone()).unwrap();
        let second = encoder.encode_to_sdr(embedding).unwrap();

        assert_eq!(first.get_sparse(), second.get_sparse());
    }

    /// An embedding of the wrong length is rejected.
    #[test]
    fn test_wrong_dimension() {
        let encoder = WordEmbeddingEncoder::new(WordEmbeddingEncoderParams {
            embedding_dim: 10,
            ..Default::default()
        })
        .unwrap();

        let result = encoder.encode_to_sdr(vec![0.1, 0.2, 0.3]);
        assert!(result.is_err());
    }

    /// Two encoders built with the same seed and parameters encode identically.
    #[test]
    fn test_with_seed() {
        let seeded = || {
            WordEmbeddingEncoder::with_seed(
                WordEmbeddingEncoderParams {
                    embedding_dim: 5,
                    size: 100,
                    active_bits: 10,
                    num_hyperplanes: 16,
                },
                123,
            )
            .unwrap()
        };

        let encoder1 = seeded();
        let encoder2 = seeded();

        let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        let sdr1 = encoder1.encode_to_sdr(embedding.clone()).unwrap();
        let sdr2 = encoder2.encode_to_sdr(embedding).unwrap();

        assert_eq!(sdr1.get_sparse(), sdr2.get_sparse());
    }
}