1use std::f64::consts::PI;
7
/// Strategy for collapsing a sequence of per-token embeddings into a
/// single fixed-size vector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Element-wise average over all token embeddings.
    Mean,
    /// Element-wise maximum over all token embeddings.
    Max,
    /// Use only the first token's embedding.
    CLS,
    /// Use only the last token's embedding.
    Last,
}
20
/// Tunable parameters for a `BatchEncoder`.
#[derive(Debug, Clone)]
pub struct EncodingConfig {
    /// Maximum number of tokens kept per text; extra tokens are truncated.
    pub max_length: usize,
    /// Number of texts processed per chunk in `encode_batch`.
    pub batch_size: usize,
    /// How token embeddings are pooled into one sentence vector.
    pub pooling: PoolingStrategy,
    /// When true, pooled embeddings are L2-normalized to unit length.
    pub normalize: bool,
}
33
34impl Default for EncodingConfig {
35 fn default() -> Self {
36 Self {
37 max_length: 128,
38 batch_size: 32,
39 pooling: PoolingStrategy::Mean,
40 normalize: true,
41 }
42 }
43}
44
/// Result of tokenizing one text.
#[derive(Debug, Clone)]
pub struct TokenizedText {
    /// Lowercased whitespace-split tokens, truncated to `max_length`.
    pub tokens: Vec<String>,
    /// Vocabulary id for each token (parallel to `tokens`).
    pub ids: Vec<u32>,
    /// One entry per token; always 1 here since no padding is produced.
    pub attention_mask: Vec<u8>,
}
55
/// Result of encoding a batch of texts.
#[derive(Debug, Clone)]
pub struct EncodedBatch {
    /// One embedding (length `EMBED_DIM`) per input text, in input order.
    pub embeddings: Vec<Vec<f32>>,
    /// Post-truncation token count per input text.
    pub token_counts: Vec<usize>,
    /// Number of texts encoded (equals `embeddings.len()`).
    pub batch_size: usize,
}
66
/// Dimensionality of every embedding produced by this module.
const EMBED_DIM: usize = 128;

/// Prime multiplier used to hash token ids into embedding phases.
const HASH_PRIME: u32 = 7919;
72
/// Hash-based text encoder with an incrementally built vocabulary.
pub struct BatchEncoder {
    // Encoding parameters, fixed at construction time.
    config: EncodingConfig,
    // Token -> id mapping, grown lazily as new tokens are seen.
    vocab: std::collections::HashMap<String, u32>,
    // Next vocabulary id to hand out.
    next_id: u32,
}
81
82impl BatchEncoder {
83 pub fn new(config: EncodingConfig) -> Self {
85 Self {
86 config,
87 vocab: std::collections::HashMap::new(),
88 next_id: 1, }
90 }
91
92 pub fn tokenize(&mut self, text: &str) -> TokenizedText {
95 let raw_tokens: Vec<String> = text.split_whitespace().map(|t| t.to_lowercase()).collect();
96
97 let truncated: Vec<String> = raw_tokens
98 .into_iter()
99 .take(self.config.max_length)
100 .collect();
101
102 let ids: Vec<u32> = truncated
103 .iter()
104 .map(|tok| {
105 if let Some(&id) = self.vocab.get(tok) {
106 id
107 } else {
108 let id = self.next_id;
109 self.vocab.insert(tok.clone(), id);
110 self.next_id = self.next_id.saturating_add(1);
111 id
112 }
113 })
114 .collect();
115
116 let attention_mask = vec![1u8; truncated.len()];
117
118 TokenizedText {
119 tokens: truncated,
120 ids,
121 attention_mask,
122 }
123 }
124
125 fn token_embedding(id: u32) -> Vec<f32> {
132 let mut emb = Vec::with_capacity(EMBED_DIM);
133 for d in 0..EMBED_DIM {
134 let phase = ((id.wrapping_mul(HASH_PRIME).wrapping_add(d as u32)) % 997) as f64 / 997.0
135 * 2.0
136 * PI;
137 let val = if d % 2 == 0 { phase.cos() } else { phase.sin() };
138 emb.push(val as f32);
139 }
140 emb
141 }
142
143 pub fn encode_single(&mut self, text: &str) -> Vec<f32> {
147 let tokenized = self.tokenize(text);
148
149 if tokenized.ids.is_empty() {
150 return vec![0.0f32; EMBED_DIM];
152 }
153
154 let token_embs: Vec<Vec<f32>> = tokenized
155 .ids
156 .iter()
157 .map(|&id| Self::token_embedding(id))
158 .collect();
159
160 let mut pooled = Self::pool(token_embs, &self.config.pooling.clone());
161
162 if self.config.normalize {
163 Self::normalize_l2(&mut pooled);
164 }
165
166 pooled
167 }
168
169 pub fn encode_batch(&mut self, texts: &[&str]) -> EncodedBatch {
171 let mut embeddings = Vec::with_capacity(texts.len());
172 let mut token_counts = Vec::with_capacity(texts.len());
173
174 for chunk in texts.chunks(self.config.batch_size) {
176 for &text in chunk {
177 let tokenized = self.tokenize(text);
178 let count = tokenized.ids.len();
179 token_counts.push(count);
180
181 if tokenized.ids.is_empty() {
182 embeddings.push(vec![0.0f32; EMBED_DIM]);
183 continue;
184 }
185
186 let token_embs: Vec<Vec<f32>> = tokenized
187 .ids
188 .iter()
189 .map(|&id| Self::token_embedding(id))
190 .collect();
191
192 let mut pooled = Self::pool(token_embs, &self.config.pooling.clone());
193
194 if self.config.normalize {
195 Self::normalize_l2(&mut pooled);
196 }
197
198 embeddings.push(pooled);
199 }
200 }
201
202 let batch_size = embeddings.len();
203 EncodedBatch {
204 embeddings,
205 token_counts,
206 batch_size,
207 }
208 }
209
210 pub fn pool(token_embeddings: Vec<Vec<f32>>, strategy: &PoolingStrategy) -> Vec<f32> {
212 if token_embeddings.is_empty() {
213 return vec![0.0f32; EMBED_DIM];
214 }
215
216 let dim = token_embeddings[0].len();
217 let n = token_embeddings.len();
218
219 match strategy {
220 PoolingStrategy::Mean => {
221 let mut result = vec![0.0f32; dim];
222 for emb in &token_embeddings {
223 for (r, &v) in result.iter_mut().zip(emb.iter()) {
224 *r += v;
225 }
226 }
227 for r in result.iter_mut() {
228 *r /= n as f32;
229 }
230 result
231 }
232 PoolingStrategy::Max => {
233 let mut result = vec![f32::NEG_INFINITY; dim];
234 for emb in &token_embeddings {
235 for (r, &v) in result.iter_mut().zip(emb.iter()) {
236 if v > *r {
237 *r = v;
238 }
239 }
240 }
241 result
242 }
243 PoolingStrategy::CLS => {
244 token_embeddings[0].clone()
246 }
247 PoolingStrategy::Last => {
248 token_embeddings[n - 1].clone()
250 }
251 }
252 }
253
254 pub fn normalize_l2(v: &mut [f32]) {
257 let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
258 if norm > 1e-10 {
259 for x in v.iter_mut() {
260 *x /= norm;
261 }
262 }
263 }
264
265 pub fn similarity(a: &[f32], b: &[f32]) -> f64 {
268 if a.len() != b.len() || a.is_empty() {
269 return 0.0;
270 }
271 let dot: f64 = a
272 .iter()
273 .zip(b.iter())
274 .map(|(&x, &y)| x as f64 * y as f64)
275 .sum();
276 let norm_a: f64 = a
277 .iter()
278 .map(|&x| (x as f64) * (x as f64))
279 .sum::<f64>()
280 .sqrt();
281 let norm_b: f64 = b
282 .iter()
283 .map(|&x| (x as f64) * (x as f64))
284 .sum::<f64>()
285 .sqrt();
286 if norm_a < 1e-10 || norm_b < 1e-10 {
287 return 0.0;
288 }
289 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
290 }
291
    /// Number of distinct tokens seen so far across all `tokenize` calls.
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh encoder built with the default `EncodingConfig`.
    fn default_encoder() -> BatchEncoder {
        BatchEncoder::new(EncodingConfig::default())
    }

    // ---- tokenize ----

    #[test]
    fn test_tokenize_basic() {
        let mut enc = default_encoder();
        let t = enc.tokenize("hello world");
        assert_eq!(t.tokens, vec!["hello", "world"]);
        assert_eq!(t.ids.len(), 2);
        assert_eq!(t.attention_mask, vec![1, 1]);
    }

    #[test]
    fn test_tokenize_empty_string() {
        let mut enc = default_encoder();
        let t = enc.tokenize("");
        assert!(t.tokens.is_empty());
        assert!(t.ids.is_empty());
        assert!(t.attention_mask.is_empty());
    }

    #[test]
    fn test_tokenize_single_token() {
        let mut enc = default_encoder();
        let t = enc.tokenize("rust");
        assert_eq!(t.tokens, vec!["rust"]);
        assert_eq!(t.ids.len(), 1);
    }

    #[test]
    fn test_tokenize_lowercases() {
        let mut enc = default_encoder();
        let t = enc.tokenize("Hello WORLD");
        assert_eq!(t.tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_tokenize_truncation() {
        // max_length of 3 must drop the trailing two tokens.
        let config = EncodingConfig {
            max_length: 3,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let t = enc.tokenize("a b c d e");
        assert_eq!(t.tokens.len(), 3);
        assert_eq!(t.ids.len(), 3);
    }

    #[test]
    fn test_tokenize_max_length_exact() {
        // Input exactly at max_length is kept whole.
        let config = EncodingConfig {
            max_length: 2,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let t = enc.tokenize("x y");
        assert_eq!(t.tokens.len(), 2);
    }

    #[test]
    fn test_tokenize_consistent_ids() {
        // The same token must map to the same id on repeated calls.
        let mut enc = default_encoder();
        let t1 = enc.tokenize("hello");
        let t2 = enc.tokenize("hello");
        assert_eq!(t1.ids, t2.ids);
    }

    #[test]
    fn test_tokenize_different_words_different_ids() {
        let mut enc = default_encoder();
        let t1 = enc.tokenize("foo");
        let t2 = enc.tokenize("bar");
        assert_ne!(t1.ids[0], t2.ids[0]);
    }

    // ---- encode_single ----

    #[test]
    fn test_encode_single_returns_128_dim() {
        let mut enc = default_encoder();
        let emb = enc.encode_single("hello world");
        assert_eq!(emb.len(), EMBED_DIM);
    }

    #[test]
    fn test_encode_single_deterministic() {
        // Two independent encoders must produce identical embeddings.
        let mut enc1 = default_encoder();
        let mut enc2 = default_encoder();
        let e1 = enc1.encode_single("deterministic test");
        let e2 = enc2.encode_single("deterministic test");
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_encode_single_normalized_when_flag_set() {
        let mut enc = default_encoder();
        let emb = enc.encode_single("normalize me please");
        let norm: f32 = emb.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Expected unit norm, got {norm}");
    }

    #[test]
    fn test_encode_single_no_normalize() {
        let config = EncodingConfig {
            normalize: false,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let emb = enc.encode_single("no norm");
        let norm: f32 = emb.iter().map(|&x| x * x).sum::<f32>().sqrt();
        // NOTE(review): `norm >= 0.0` holds even when normalization runs, so
        // this assertion cannot catch a regression; a stricter check (e.g.
        // norm != 1.0 for this input) may be worth adding.
        assert!(norm >= 0.0);
    }

    #[test]
    fn test_encode_single_empty_returns_zeros() {
        let mut enc = default_encoder();
        let emb = enc.encode_single("");
        assert_eq!(emb.len(), EMBED_DIM);
        assert!(emb.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_encode_single_different_texts_different_embeddings() {
        let mut enc = default_encoder();
        let e1 = enc.encode_single("apple banana cherry");
        let e2 = enc.encode_single("dog cat fish");
        assert_ne!(e1, e2);
    }

    // ---- encode_batch ----

    #[test]
    fn test_encode_batch_count() {
        let mut enc = default_encoder();
        let texts = ["one", "two", "three"];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 3);
        assert_eq!(batch.embeddings.len(), 3);
        assert_eq!(batch.token_counts.len(), 3);
    }

    #[test]
    fn test_encode_batch_each_128_dim() {
        let mut enc = default_encoder();
        let texts = ["alpha", "beta gamma", "delta epsilon zeta"];
        let batch = enc.encode_batch(&texts);
        for emb in &batch.embeddings {
            assert_eq!(emb.len(), EMBED_DIM);
        }
    }

    #[test]
    fn test_encode_batch_token_counts_correct() {
        let mut enc = BatchEncoder::new(EncodingConfig {
            max_length: 10,
            ..EncodingConfig::default()
        });
        let texts = ["a b c", "x", "one two three four"];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.token_counts[0], 3);
        assert_eq!(batch.token_counts[1], 1);
        assert_eq!(batch.token_counts[2], 4);
    }

    #[test]
    fn test_encode_batch_chunking() {
        // 5 texts with batch_size 2 exercises a final partial chunk.
        let config = EncodingConfig {
            batch_size: 2,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let texts: Vec<&str> = (0..5).map(|_| "hello world").collect();
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 5);
    }

    #[test]
    fn test_encode_batch_empty_texts() {
        let mut enc = default_encoder();
        let texts: Vec<&str> = vec![];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 0);
    }

    #[test]
    fn test_encode_batch_single_text() {
        let mut enc = default_encoder();
        let texts = ["only one"];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 1);
    }

    // ---- pool ----

    /// Three 4-dim token embeddings with hand-checkable column statistics.
    fn sample_token_embeddings() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 2.0, -1.0],
            vec![0.0, 3.0, -1.0, 2.0],
            vec![2.0, 1.0, 0.0, 0.5],
        ]
    }

    #[test]
    fn test_pool_mean() {
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::Mean);
        // Column-wise means of the fixture above.
        let expected = [1.0, 4.0 / 3.0, 1.0 / 3.0, 0.5];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-5, "{r} != {e}");
        }
    }

    #[test]
    fn test_pool_max() {
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::Max);
        // Column-wise maxima of the fixture above.
        let expected = vec![2.0f32, 3.0, 2.0, 2.0];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_pool_cls() {
        // CLS pooling returns the first token embedding unchanged.
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::CLS);
        assert_eq!(result, vec![1.0, 0.0, 2.0, -1.0]);
    }

    #[test]
    fn test_pool_last() {
        // Last pooling returns the final token embedding unchanged.
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::Last);
        assert_eq!(result, vec![2.0, 1.0, 0.0, 0.5]);
    }

    #[test]
    fn test_pool_empty() {
        // An empty token list falls back to an EMBED_DIM zero vector.
        let result = BatchEncoder::pool(vec![], &PoolingStrategy::Mean);
        assert_eq!(result.len(), EMBED_DIM);
        assert!(result.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_pool_single_token_mean() {
        // Mean of a single embedding is that embedding.
        let embs = vec![vec![1.0, 2.0, 3.0]];
        let result = BatchEncoder::pool(embs.clone(), &PoolingStrategy::Mean);
        assert_eq!(result, embs[0]);
    }

    #[test]
    fn test_pool_single_token_max() {
        // Max of a single embedding is that embedding.
        let embs = vec![vec![4.0, 5.0, 6.0]];
        let result = BatchEncoder::pool(embs.clone(), &PoolingStrategy::Max);
        assert_eq!(result, embs[0]);
    }

    // ---- normalize_l2 ----

    #[test]
    fn test_normalize_unit_norm() {
        // 3-4-0 normalizes to (0.6, 0.8, 0.0).
        let mut v = vec![3.0f32, 4.0, 0.0];
        BatchEncoder::normalize_l2(&mut v);
        let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
        assert!((v[0] - 0.6).abs() < 1e-5);
        assert!((v[1] - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_zero_vector() {
        // Zero vectors are left untouched (no division by zero).
        let mut v = vec![0.0f32, 0.0, 0.0];
        BatchEncoder::normalize_l2(&mut v);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_normalize_already_unit() {
        let mut v = vec![1.0f32, 0.0, 0.0];
        BatchEncoder::normalize_l2(&mut v);
        assert!((v[0] - 1.0).abs() < 1e-6);
    }

    // ---- similarity ----

    #[test]
    fn test_similarity_identical_vectors() {
        let v = vec![1.0f32, 0.0, 0.0];
        let sim = BatchEncoder::similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_orthogonal_vectors() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        let sim = BatchEncoder::similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_similarity_opposite_vectors() {
        let a = vec![1.0f32, 0.0];
        let b = vec![-1.0f32, 0.0];
        let sim = BatchEncoder::similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_zero_vector() {
        // Cosine is undefined for a zero vector; the API defines it as 0.
        let a = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 0.0];
        let sim = BatchEncoder::similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_similarity_mismatched_len() {
        // Length mismatch is defined as 0 similarity rather than a panic.
        let a = vec![1.0f32, 0.0];
        let b = vec![1.0f32, 0.0, 0.5];
        let sim = BatchEncoder::similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_similarity_empty_vectors() {
        let sim = BatchEncoder::similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_similarity_bounded() {
        // Similarity of real encodings must stay within [-1, 1].
        let mut enc = default_encoder();
        let e1 = enc.encode_single("semantic similarity test");
        let e2 = enc.encode_single("another sentence here");
        let sim = BatchEncoder::similarity(&e1, &e2);
        assert!((-1.0..=1.0).contains(&sim));
    }

    // ---- vocabulary / integration ----

    #[test]
    fn test_vocab_grows() {
        let mut enc = default_encoder();
        assert_eq!(enc.vocab_size(), 0);
        enc.tokenize("alpha beta gamma");
        assert_eq!(enc.vocab_size(), 3);
        // "alpha" is already known; only "delta" is new.
        enc.tokenize("alpha delta");
        assert_eq!(enc.vocab_size(), 4);
    }

    #[test]
    fn test_encode_batch_matches_single() {
        // Batch encoding must agree exactly with one-at-a-time encoding.
        let mut enc = default_encoder();
        let texts = ["hello world", "foo bar baz"];
        let e_single_a = enc.encode_single(texts[0]);
        let e_single_b = enc.encode_single(texts[1]);
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.embeddings[0], e_single_a);
        assert_eq!(batch.embeddings[1], e_single_b);
    }
}