1use std::collections::HashMap;
4use std::io::{Read, Seek, Write};
5use std::iter::Enumerate;
6use std::mem;
7use std::slice;
8
9use ndarray::{s, Array1, Array2, ArrayViewMut1, ArrayViewMut2, Axis, CowArray, Ix1};
10use rand::{CryptoRng, RngCore, SeedableRng};
11use rand_chacha::ChaChaRng;
12use reductive::pq::TrainPq;
13
14use crate::chunks::io::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
15use crate::chunks::metadata::Metadata;
16use crate::chunks::norms::NdNorms;
17use crate::chunks::storage::{
18 sealed::CloneFromMapping, NdArray, Quantize as QuantizeStorage, QuantizedArray, Storage,
19 StorageView, StorageViewWrap, StorageWrap,
20};
21use crate::chunks::vocab::{
22 BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, SimpleVocab, SubwordVocab,
23 Vocab, VocabWrap, WordIndex,
24};
25use crate::error::{Error, Result};
26use crate::io::{ReadEmbeddings, WriteEmbeddings};
27use crate::subword::BucketIndexer;
28use crate::util::l2_normalize;
29use crate::vocab::FloretSubwordVocab;
30
/// Word embeddings: a vocabulary paired with an embedding matrix (storage),
/// plus optional metadata and optional per-word norms.
///
/// `V` is the vocabulary type and `S` the storage type.
#[derive(Clone, Debug)]
pub struct Embeddings<V, S> {
    /// Optional metadata; when present it is the first chunk after the header.
    metadata: Option<Metadata>,
    /// Embedding matrix; its row count is checked against `vocab.vocab_len()` on construction.
    storage: S,
    /// Vocabulary mapping words (and possibly subwords) to storage rows.
    vocab: V,
    /// Optional norms, indexed by word index (one entry per vocabulary word).
    norms: Option<NdNorms>,
}
43
44impl<V, S> Embeddings<V, S>
45where
46 V: Vocab,
47 S: Storage,
48{
49 pub fn new(metadata: Option<Metadata>, vocab: V, storage: S, norms: NdNorms) -> Self {
55 assert_eq!(
56 vocab.words_len(),
57 norms.len(),
58 "Vocab and norms do not have the same length"
59 );
60 Embeddings::new_with_maybe_norms(metadata, vocab, storage, Some(norms))
61 }
62
63 pub(crate) fn new_with_maybe_norms(
64 metadata: Option<Metadata>,
65 vocab: V,
66 storage: S,
67 norms: Option<NdNorms>,
68 ) -> Self {
69 assert_eq!(
70 vocab.vocab_len(),
71 storage.shape().0,
72 "Max vocab index must match number of rows in the embedding matrix."
73 );
74 Embeddings {
75 metadata,
76 storage,
77 vocab,
78 norms,
79 }
80 }
81}
82
83impl<V, S> Embeddings<V, S> {
84 pub fn into_parts(self) -> (Option<Metadata>, V, S, Option<NdNorms>) {
87 (self.metadata, self.vocab, self.storage, self.norms)
88 }
89
90 pub fn metadata(&self) -> Option<&Metadata> {
92 self.metadata.as_ref()
93 }
94
95 pub fn metadata_mut(&mut self) -> Option<&mut Metadata> {
97 self.metadata.as_mut()
98 }
99
100 pub fn norms(&self) -> Option<&NdNorms> {
102 self.norms.as_ref()
103 }
104
105 pub fn set_metadata(&mut self, mut metadata: Option<Metadata>) -> Option<Metadata> {
109 mem::swap(&mut self.metadata, &mut metadata);
110 metadata
111 }
112
113 pub fn storage(&self) -> &S {
115 &self.storage
116 }
117
118 pub fn vocab(&self) -> &V {
120 &self.vocab
121 }
122}
123
124#[allow(clippy::len_without_is_empty)]
125impl<V, S> Embeddings<V, S>
126where
127 V: Vocab,
128 S: Storage,
129{
130 pub fn dims(&self) -> usize {
132 self.storage.shape().1
133 }
134
135 pub fn embedding(&self, word: &str) -> Option<CowArray<f32, Ix1>> {
137 match self.vocab.idx(word)? {
138 WordIndex::Word(idx) => Some(self.storage.embedding(idx)),
139 WordIndex::Subword(indices) => {
140 let embeds = self.storage.embeddings(&indices);
141 let mut embed = embeds.sum_axis(Axis(0));
142 l2_normalize(embed.view_mut());
143
144 Some(CowArray::from(embed))
145 }
146 }
147 }
148
149 pub fn embedding_into(&self, word: &str, mut target: ArrayViewMut1<f32>) -> bool {
160 assert_eq!(
161 target.len(),
162 self.dims(),
163 "Embeddings have {} dimensions, whereas target array has {}",
164 self.dims(),
165 target.len()
166 );
167
168 let index = if let Some(idx) = self.vocab.idx(word) {
169 idx
170 } else {
171 return false;
172 };
173
174 match index {
175 WordIndex::Word(idx) => target.assign(&self.storage.embedding(idx)),
176 WordIndex::Subword(indices) => {
177 target.fill(0.);
178
179 let embeds = self.storage.embeddings(&indices);
180
181 for embed in embeds.outer_iter() {
182 target += &embed;
183 }
184
185 l2_normalize(target.view_mut());
186 }
187 }
188
189 true
190 }
191
192 pub fn embedding_batch(&self, words: &[impl AsRef<str>]) -> (Array2<f32>, Vec<bool>) {
197 let mut embeddings = Array2::zeros((words.len(), self.dims()));
198 let found = self.embedding_batch_into(words, embeddings.view_mut());
199 (embeddings, found)
200 }
201
202 pub fn embedding_batch_into(
209 &self,
210 words: &[impl AsRef<str>],
211 mut output: ArrayViewMut2<f32>,
212 ) -> Vec<bool> {
213 assert_eq!(
214 output.len_of(Axis(0)),
215 words.len(),
216 "Expected embedding matrix for batch size {}, got {}",
217 words.len(),
218 output.len_of(Axis(0))
219 );
220 assert_eq!(
221 output.len_of(Axis(1)),
222 self.dims(),
223 "Expected embedding matrix for embeddings with dimensionality {}, got {}",
224 self.dims(),
225 output.len_of(Axis(1))
226 );
227
228 let mut found = vec![false; words.len()];
229
230 let mut word_indices: HashMap<_, Vec<_>> = HashMap::new();
231 for (idx, word) in words.iter().enumerate() {
232 let indices = word_indices.entry(word.as_ref()).or_default();
233 indices.push(idx);
234 }
235
236 for (word, indices) in word_indices {
237 let idx_first = indices[0];
240 if self.embedding_into(word, output.index_axis_mut(Axis(0), idx_first)) {
241 found[idx_first] = true;
242
243 for &idx in indices.iter().skip(1) {
244 let (first, mut cur) = output.multi_slice_mut((s![idx_first, ..], s![idx, ..]));
245 cur.assign(&first);
246 found[idx] = true;
247 }
248 }
249 }
250
251 found
252 }
253
254 pub fn embedding_with_norm(&self, word: &str) -> Option<EmbeddingWithNorm> {
267 match self.vocab.idx(word)? {
268 WordIndex::Word(idx) => Some(EmbeddingWithNorm {
269 embedding: self.storage.embedding(idx),
270 norm: self.norms().map(|n| n[idx]).unwrap_or(1.),
271 }),
272 WordIndex::Subword(indices) => {
273 let embeds = self.storage.embeddings(&indices);
274 let mut embed = embeds.sum_axis(Axis(0));
275
276 let norm = l2_normalize(embed.view_mut());
277
278 Some(EmbeddingWithNorm {
279 embedding: CowArray::from(embed),
280 norm,
281 })
282 }
283 }
284 }
285
286 pub fn iter(&self) -> Iter {
288 Iter {
289 storage: &self.storage,
290 inner: self.vocab.words().iter().enumerate(),
291 }
292 }
293
294 pub fn iter_with_norms(&self) -> IterWithNorms {
308 IterWithNorms {
309 storage: &self.storage,
310 norms: self.norms(),
311 inner: self.vocab.words().iter().enumerate(),
312 }
313 }
314
315 pub fn len(&self) -> usize {
319 self.vocab.words_len()
320 }
321}
322
323impl<I, S> Embeddings<SubwordVocab<I>, S>
324where
325 I: BucketIndexer,
326 S: Storage + CloneFromMapping,
327{
328 pub fn to_explicit(&self) -> Result<Embeddings<ExplicitSubwordVocab, S::Result>> {
330 to_explicit_impl(self.vocab(), self.storage(), self.norms())
331 }
332}
333
334impl<S> Embeddings<VocabWrap, S>
335where
336 S: Storage + CloneFromMapping,
337{
338 pub fn try_to_explicit(&self) -> Result<Embeddings<ExplicitSubwordVocab, S::Result>> {
343 match &self.vocab {
344 VocabWrap::BucketSubwordVocab(sw) => {
345 Ok(to_explicit_impl(sw, self.storage(), self.norms())?)
346 }
347 VocabWrap::FastTextSubwordVocab(sw) => {
348 Ok(to_explicit_impl(sw, self.storage(), self.norms())?)
349 }
350 VocabWrap::SimpleVocab(_) => {
351 Err(Error::conversion_error("SimpleVocab", "ExplicitVocab"))
352 }
353 VocabWrap::ExplicitSubwordVocab(_) => {
354 Err(Error::conversion_error("ExplicitVocab", "ExplicitVocab"))
355 }
356 VocabWrap::FloretSubwordVocab(_) => Err(Error::conversion_error(
357 "FloretSubwordVocab",
358 "ExplicitVocab",
359 )),
360 }
361 }
362}
363
364fn to_explicit_impl<I, S>(
365 vocab: &SubwordVocab<I>,
366 storage: &S,
367 norms: Option<&NdNorms>,
368) -> Result<Embeddings<ExplicitSubwordVocab, S::Result>>
369where
370 S: Storage + CloneFromMapping,
371 I: BucketIndexer,
372{
373 let (expl_voc, old_to_new) = vocab.to_explicit()?;
374 let mut mapping = (0..expl_voc.vocab_len()).collect::<Vec<_>>();
375 for (old, new) in old_to_new {
376 mapping[expl_voc.words_len() + new] = old as usize + expl_voc.words_len();
377 }
378
379 let new_storage = storage.clone_from_mapping(&mapping);
380 Ok(Embeddings::new_with_maybe_norms(
381 None,
382 expl_voc,
383 new_storage,
384 norms.cloned(),
385 ))
386}
387
// Implements `From<Embeddings<$vocab, $storage>>` for `Embeddings<VocabWrap, $storage_wrap>`.
// Metadata and norms are moved over unchanged; vocabulary and storage are converted with `Into`.
macro_rules! impl_embeddings_from(
    ($vocab:ty, $storage:ty, $storage_wrap:ty) => {
        impl From<Embeddings<$vocab, $storage>> for Embeddings<VocabWrap, $storage_wrap> {
            fn from(embeddings: Embeddings<$vocab, $storage>) -> Self {
                let Embeddings {
                    metadata,
                    vocab,
                    storage,
                    norms,
                } = embeddings;
                let vocab: VocabWrap = vocab.into();
                let storage: $storage_wrap = storage.into();
                Embeddings {
                    metadata,
                    vocab,
                    storage,
                    norms,
                }
            }
        }
    }
);
403
404impl_embeddings_from!(SimpleVocab, NdArray, StorageWrap);
407impl_embeddings_from!(SimpleVocab, NdArray, StorageViewWrap);
408impl_embeddings_from!(SimpleVocab, QuantizedArray, StorageWrap);
409impl_embeddings_from!(BucketSubwordVocab, NdArray, StorageWrap);
410impl_embeddings_from!(BucketSubwordVocab, NdArray, StorageViewWrap);
411impl_embeddings_from!(BucketSubwordVocab, QuantizedArray, StorageWrap);
412impl_embeddings_from!(FastTextSubwordVocab, NdArray, StorageWrap);
413impl_embeddings_from!(FastTextSubwordVocab, NdArray, StorageViewWrap);
414impl_embeddings_from!(FastTextSubwordVocab, QuantizedArray, StorageWrap);
415impl_embeddings_from!(ExplicitSubwordVocab, NdArray, StorageWrap);
416impl_embeddings_from!(ExplicitSubwordVocab, NdArray, StorageViewWrap);
417impl_embeddings_from!(ExplicitSubwordVocab, QuantizedArray, StorageWrap);
418impl_embeddings_from!(FloretSubwordVocab, NdArray, StorageWrap);
419impl_embeddings_from!(FloretSubwordVocab, NdArray, StorageViewWrap);
420impl_embeddings_from!(FloretSubwordVocab, QuantizedArray, StorageWrap);
421impl_embeddings_from!(VocabWrap, QuantizedArray, StorageWrap);
422
423impl<'a, V, S> IntoIterator for &'a Embeddings<V, S>
424where
425 V: Vocab,
426 S: Storage,
427{
428 type Item = (&'a str, CowArray<'a, f32, Ix1>);
429 type IntoIter = Iter<'a>;
430
431 fn into_iter(self) -> Self::IntoIter {
432 self.iter()
433 }
434}
435
#[cfg(feature = "memmap")]
mod mmap {
    use std::fs::File;
    use std::io::BufReader;

    use super::Embeddings;
    use crate::chunks::io::MmapChunk;
    use crate::chunks::io::{ChunkIdentifier, Header, ReadChunk};
    use crate::chunks::metadata::Metadata;
    use crate::chunks::norms::NdNorms;
    #[cfg(target_endian = "little")]
    use crate::chunks::storage::StorageViewWrap;
    use crate::chunks::storage::{MmapArray, MmapQuantizedArray, StorageWrap};
    use crate::chunks::vocab::{
        BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, SimpleVocab, VocabWrap,
    };
    use crate::error::{Error, Result};
    use crate::io::MmapEmbeddings;
    use crate::vocab::FloretSubwordVocab;

    impl_embeddings_from!(SimpleVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(SimpleVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(SimpleVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(BucketSubwordVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(BucketSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(BucketSubwordVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(FloretSubwordVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(FloretSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(FloretSubwordVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(FastTextSubwordVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(FastTextSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(FastTextSubwordVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(ExplicitSubwordVocab, MmapArray, StorageWrap);
    impl_embeddings_from!(ExplicitSubwordVocab, MmapQuantizedArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(ExplicitSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(VocabWrap, MmapQuantizedArray, StorageWrap);

    impl<V, S> MmapEmbeddings for Embeddings<V, S>
    where
        Self: Sized,
        V: ReadChunk,
        S: MmapChunk,
    {
        /// Read embeddings from `read`, memory-mapping the storage chunk.
        ///
        /// Chunks are read in header order: optional metadata (when it is the
        /// first listed chunk), vocabulary, storage, and norms when the header
        /// lists an `NdNorms` chunk.
        ///
        /// # Errors
        ///
        /// Returns an error when the header lists no chunks, or when any listed
        /// chunk (including norms) cannot be read.
        fn mmap_embeddings(read: &mut BufReader<File>) -> Result<Self> {
            let header = Header::read_chunk(read)?;
            let chunks = header.chunk_identifiers();
            if chunks.is_empty() {
                return Err(Error::Format(String::from(
                    "Embedding file does not contain chunks",
                )));
            }

            let metadata = if chunks[0] == ChunkIdentifier::Metadata {
                Some(Metadata::read_chunk(read)?)
            } else {
                None
            };

            let vocab = V::read_chunk(read)?;
            let storage = S::mmap_chunk(read)?;

            // The header is authoritative: only read norms when it announces them,
            // and surface read failures instead of silently dropping the norms.
            let norms = if chunks.contains(&ChunkIdentifier::NdNorms) {
                Some(NdNorms::read_chunk(read)?)
            } else {
                None
            };

            Ok(Embeddings {
                metadata,
                storage,
                vocab,
                norms,
            })
        }
    }
}
513
514impl<V, S> ReadEmbeddings for Embeddings<V, S>
515where
516 V: ReadChunk,
517 S: ReadChunk,
518{
519 fn read_embeddings<R>(read: &mut R) -> Result<Self>
520 where
521 R: Read + Seek,
522 {
523 let header = Header::read_chunk(read)?;
524 let chunks = header.chunk_identifiers();
525 if chunks.is_empty() {
526 return Err(Error::Format(String::from(
527 "Embedding file does not contain chunks",
528 )));
529 }
530
531 let metadata = if header.chunk_identifiers()[0] == ChunkIdentifier::Metadata {
532 Some(Metadata::read_chunk(read)?)
533 } else {
534 None
535 };
536
537 let vocab = V::read_chunk(read)?;
538 let storage = S::read_chunk(read)?;
539 let norms = NdNorms::read_chunk(read).ok();
540
541 Ok(Embeddings {
542 metadata,
543 storage,
544 vocab,
545 norms,
546 })
547 }
548}
549
550impl<V, S> WriteEmbeddings for Embeddings<V, S>
551where
552 V: WriteChunk,
553 S: WriteChunk,
554{
555 fn write_embeddings<W>(&self, write: &mut W) -> Result<()>
556 where
557 W: Write + Seek,
558 {
559 let mut chunks = match self.metadata {
560 Some(ref metadata) => vec![metadata.chunk_identifier()],
561 None => vec![],
562 };
563
564 chunks.extend_from_slice(&[
565 self.vocab.chunk_identifier(),
566 self.storage.chunk_identifier(),
567 ]);
568
569 if let Some(ref norms) = self.norms {
570 chunks.push(norms.chunk_identifier());
571 }
572
573 Header::new(chunks).write_chunk(write)?;
574 if let Some(ref metadata) = self.metadata {
575 metadata.write_chunk(write)?;
576 }
577
578 self.vocab.write_chunk(write)?;
579 self.storage.write_chunk(write)?;
580
581 if let Some(norms) = self.norms() {
582 norms.write_chunk(write)?;
583 }
584
585 Ok(())
586 }
587
588 fn write_embeddings_len(&self, offset: u64) -> u64 {
589 let mut len = 0;
590
591 let mut chunks = match self.metadata {
592 Some(ref metadata) => vec![metadata.chunk_identifier()],
593 None => vec![],
594 };
595
596 chunks.extend_from_slice(&[
597 self.vocab.chunk_identifier(),
598 self.storage.chunk_identifier(),
599 ]);
600
601 if let Some(ref norms) = self.norms {
602 chunks.push(norms.chunk_identifier());
603 }
604
605 let header = Header::new(chunks);
606 len += header.chunk_len(offset + len);
607
608 if let Some(ref metadata) = self.metadata {
609 len += metadata.chunk_len(offset + len);
610 }
611
612 len += self.vocab.chunk_len(offset + len);
613 len += self.storage.chunk_len(offset + len);
614
615 if let Some(ref norms) = self.norms {
616 len += norms.chunk_len(offset + len);
617 }
618
619 len
620 }
621}
622
623pub trait Quantize<V> {
625 fn quantize<T>(
633 &self,
634 n_subquantizers: usize,
635 n_subquantizer_bits: u32,
636 n_iterations: usize,
637 n_attempts: usize,
638 normalize: bool,
639 ) -> Result<Embeddings<V, QuantizedArray>>
640 where
641 T: TrainPq<f32>,
642 {
643 self.quantize_using::<T, _>(
644 n_subquantizers,
645 n_subquantizer_bits,
646 n_iterations,
647 n_attempts,
648 normalize,
649 ChaChaRng::from_entropy(),
650 )
651 }
652
653 fn quantize_using<T, R>(
658 &self,
659 n_subquantizers: usize,
660 n_subquantizer_bits: u32,
661 n_iterations: usize,
662 n_attempts: usize,
663 normalize: bool,
664 rng: R,
665 ) -> Result<Embeddings<V, QuantizedArray>>
666 where
667 T: TrainPq<f32>,
668 R: CryptoRng + RngCore + SeedableRng + Send;
669}
670
671impl<V, S> Quantize<V> for Embeddings<V, S>
672where
673 V: Vocab + Clone,
674 S: StorageView,
675{
676 fn quantize_using<T, R>(
677 &self,
678 n_subquantizers: usize,
679 n_subquantizer_bits: u32,
680 n_iterations: usize,
681 n_attempts: usize,
682 normalize: bool,
683 rng: R,
684 ) -> Result<Embeddings<V, QuantizedArray>>
685 where
686 T: TrainPq<f32>,
687 R: CryptoRng + RngCore + SeedableRng + Send,
688 {
689 let quantized_storage = self.storage().quantize_using::<T, R>(
690 n_subquantizers,
691 n_subquantizer_bits,
692 n_iterations,
693 n_attempts,
694 normalize,
695 rng,
696 )?;
697
698 Ok(Embeddings {
699 metadata: self.metadata().cloned(),
700 vocab: self.vocab.clone(),
701 storage: quantized_storage,
702 norms: self.norms().cloned(),
703 })
704 }
705}
706
/// An embedding together with an associated norm.
///
/// For vocabulary words the norm comes from the stored norms (`1.0` when none
/// are stored); for subword lookups it is the value returned by `l2_normalize`
/// on the summed subword embedding. `into_unnormalized` multiplies the
/// embedding by this norm.
pub struct EmbeddingWithNorm<'a> {
    /// The (possibly borrowed) embedding vector.
    pub embedding: CowArray<'a, f32, Ix1>,
    /// Norm associated with `embedding`.
    pub norm: f32,
}
712
713impl<'a> EmbeddingWithNorm<'a> {
714 pub fn into_unnormalized(self) -> Array1<f32> {
716 let mut unnormalized = self.embedding.into_owned();
717 unnormalized *= self.norm;
718 unnormalized
719 }
720}
721
722pub struct Iter<'a> {
724 storage: &'a dyn Storage,
725 inner: Enumerate<slice::Iter<'a, String>>,
726}
727
728impl<'a> Iterator for Iter<'a> {
729 type Item = (&'a str, CowArray<'a, f32, Ix1>);
730
731 fn next(&mut self) -> Option<Self::Item> {
732 self.inner
733 .next()
734 .map(|(idx, word)| (word.as_str(), self.storage.embedding(idx)))
735 }
736}
737
738pub struct IterWithNorms<'a> {
740 storage: &'a dyn Storage,
741 norms: Option<&'a NdNorms>,
742 inner: Enumerate<slice::Iter<'a, String>>,
743}
744
745impl<'a> Iterator for IterWithNorms<'a> {
746 type Item = (&'a str, EmbeddingWithNorm<'a>);
747
748 fn next(&mut self) -> Option<Self::Item> {
749 self.inner.next().map(|(idx, word)| {
750 (
751 word.as_str(),
752 EmbeddingWithNorm {
753 embedding: self.storage.embedding(idx),
754 norm: self.norms.map(|n| n[idx]).unwrap_or(1.),
755 },
756 )
757 })
758 }
759}
760
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::{BufReader, Cursor, Seek, SeekFrom};

    use ndarray::{array, Array1};
    use toml::toml;

    use super::Embeddings;
    use crate::chunks::metadata::Metadata;
    use crate::chunks::norms::NdNorms;
    use crate::chunks::storage::{NdArray, Storage, StorageView};
    use crate::chunks::vocab::{FastTextSubwordVocab, SimpleVocab, Vocab};
    use crate::compat::fasttext::ReadFastText;
    use crate::compat::word2vec::ReadWord2VecRaw;
    use crate::io::{ReadEmbeddings, WriteEmbeddings};
    use crate::prelude::StorageWrap;
    use crate::storage::QuantizedArray;
    use crate::subword::Indexer;
    use crate::vocab::VocabWrap;

    /// Load the word2vec binary fixture `testdata/similarity.bin` via `read_word2vec_binary_raw`.
    fn test_embeddings() -> Embeddings<SimpleVocab, NdArray> {
        let mut reader = BufReader::new(File::open("testdata/similarity.bin").unwrap());
        Embeddings::read_word2vec_binary_raw(&mut reader, false).unwrap()
    }

    /// The word2vec fixture with `test_metadata()` attached.
    fn test_embeddings_with_metadata() -> Embeddings<SimpleVocab, NdArray> {
        let mut embeds = test_embeddings();
        embeds.set_metadata(Some(test_metadata()));
        embeds
    }

    /// Load the fastText fixture `testdata/fasttext.bin`.
    fn test_embeddings_fasttext() -> Embeddings<FastTextSubwordVocab, NdArray> {
        let mut reader = BufReader::new(File::open("testdata/fasttext.bin").unwrap());
        Embeddings::read_fasttext(&mut reader).unwrap()
    }

    /// Load the quantized fixture `testdata/quantized.fifu`.
    fn test_embeddings_quantized() -> Embeddings<SimpleVocab, QuantizedArray> {
        let mut reader = BufReader::new(File::open("testdata/quantized.fifu").unwrap());
        Embeddings::read_embeddings(&mut reader).unwrap()
    }

    /// Small metadata table used by the metadata round-trip tests.
    fn test_metadata() -> Metadata {
        Metadata::new(toml! {
            [hyperparameters]
            dims = 300
            ns = 5

            [description]
            description = "Test model"
            language = "de"
        })
    }

    /// `embedding_into` must produce the same vector as `embedding`, both for a
    /// vocabulary word and for words resolved through subwords.
    #[test]
    fn embedding_into_equal_to_embedding() {
        let mut reader = BufReader::new(File::open("testdata/fasttext.bin").unwrap());
        let embeds = Embeddings::read_fasttext(&mut reader).unwrap();

        // Vocabulary word.
        let mut target = Array1::zeros(embeds.dims());
        assert!(embeds.embedding_into("ganz", target.view_mut()));
        assert_eq!(target, embeds.embedding("ganz").unwrap());

        // Presumably an out-of-vocabulary word looked up via subwords.
        let mut target = Array1::zeros(embeds.dims());
        assert!(embeds.embedding_into("iddqd", target.view_mut()));
        assert_eq!(target, embeds.embedding("iddqd").unwrap());

        // Reuses the previous (non-zero) target: it must be fully overwritten.
        assert!(embeds.embedding_into("idspispopd", target.view_mut()));
        assert_eq!(target, embeds.embedding("idspispopd").unwrap());
    }

    /// Batch lookup returns the right rows (including duplicates) and the
    /// per-word found flags.
    #[test]
    fn embedding_batch_returns_correct_embeddings() {
        let embeds = test_embeddings();
        let (lookups, found) = embeds.embedding_batch(&[
            "Berlin",
            "Bremen",
            "Groningen",
            "Bremen",
            "Berlin",
            "Amsterdam",
        ]);

        assert_eq!(lookups.row(0), embeds.embedding("Berlin").unwrap());
        assert_eq!(lookups.row(1), embeds.embedding("Bremen").unwrap());
        assert_eq!(lookups.row(3), embeds.embedding("Bremen").unwrap());
        assert_eq!(lookups.row(4), embeds.embedding("Berlin").unwrap());
        assert_eq!(found, &[true, true, false, true, true, false]);
    }

    /// Memory-mapped reading yields the same vocabulary (and, on little-endian
    /// targets, the same matrix) as the word2vec fixture.
    #[test]
    #[cfg(feature = "memmap")]
    fn mmap() {
        use crate::chunks::storage::MmapArray;
        use crate::io::MmapEmbeddings;
        let check_embeds = test_embeddings();
        let mut reader = BufReader::new(File::open("testdata/similarity.fifu").unwrap());
        let embeds: Embeddings<SimpleVocab, MmapArray> =
            Embeddings::mmap_embeddings(&mut reader).unwrap();
        assert_eq!(embeds.vocab(), check_embeds.vocab());

        #[cfg(target_endian = "little")]
        assert_eq!(embeds.storage().view(), check_embeds.storage().view());
    }

    /// Converting to an explicit vocabulary preserves word embeddings and the
    /// embedding of an individual n-gram.
    #[test]
    fn to_explicit() {
        let mut reader = BufReader::new(File::open("testdata/fasttext.bin").unwrap());
        let embeds = Embeddings::read_fasttext(&mut reader).unwrap();
        let expl_embeds = embeds.to_explicit().unwrap();
        let ganz = embeds.embedding("ganz").unwrap();
        let ganz_expl = expl_embeds.embedding("ganz").unwrap();
        assert_eq!(ganz, ganz_expl);

        assert!(embeds
            .vocab
            .idx("anz")
            .map(|i| i.subword().is_some())
            .unwrap_or(false));
        // Subword rows come after the word rows, hence the `words_len()` offset.
        let anz_idx = embeds.vocab.indexer().index_ngram(&"anz".into())[0] as usize
            + embeds.vocab.words_len();
        let anz = embeds.storage.embedding(anz_idx);
        let anz_idx = expl_embeds.vocab.indexer().index_ngram(&"anz".into())[0] as usize
            + embeds.vocab.words_len();
        let anz_expl = expl_embeds.storage.embedding(anz_idx);
        assert_eq!(anz, anz_expl);
    }

    /// Norms survive a write/read round trip.
    #[test]
    fn norms() {
        let vocab = SimpleVocab::new(vec!["norms".to_string(), "test".to_string()]);
        let storage = NdArray::new(array![[1f32], [-1f32]]);
        let norms = NdNorms::new(array![2f32, 3f32]);
        let check = Embeddings::new(None, vocab, storage, norms);

        let mut serialized = Cursor::new(Vec::new());
        check.write_embeddings(&mut serialized).unwrap();
        serialized.seek(SeekFrom::Start(0)).unwrap();

        let embeddings: Embeddings<SimpleVocab, NdArray> =
            Embeddings::read_embeddings(&mut serialized).unwrap();

        assert!(check
            .norms()
            .unwrap()
            .view()
            .abs_diff_eq(&embeddings.norms().unwrap().view(), 1e-8),);
    }

    /// Write/read round trip without metadata; also checks `write_embeddings_len`.
    #[test]
    fn write_read_simple_roundtrip() {
        let check_embeds = test_embeddings();
        let mut cursor = Cursor::new(Vec::new());
        check_embeds.write_embeddings(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let embeds: Embeddings<SimpleVocab, NdArray> =
            Embeddings::read_embeddings(&mut cursor).unwrap();
        assert_eq!(embeds.storage().view(), check_embeds.storage().view());
        assert_eq!(embeds.vocab(), check_embeds.vocab());
        assert_eq!(
            cursor.into_inner().len() as u64,
            check_embeds.write_embeddings_len(0)
        );
    }

    /// Write/read round trip with metadata; also checks `write_embeddings_len`.
    #[test]
    fn write_read_simple_metadata_roundtrip() {
        let check_embeds = test_embeddings_with_metadata();

        let mut cursor = Cursor::new(Vec::new());
        check_embeds.write_embeddings(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let embeds: Embeddings<SimpleVocab, NdArray> =
            Embeddings::read_embeddings(&mut cursor).unwrap();
        assert_eq!(embeds.storage().view(), check_embeds.storage().view());
        assert_eq!(embeds.vocab(), check_embeds.vocab());
        assert_eq!(
            cursor.into_inner().len() as u64,
            check_embeds.write_embeddings_len(0)
        );
    }

    /// `write_embeddings_len(offset)` matches the bytes actually written when
    /// writing starts at `offset`, for several vocab/storage combinations.
    #[test]
    fn embeddings_write_length_different_offsets() {
        let embeddings: Vec<Embeddings<VocabWrap, StorageWrap>> = vec![
            test_embeddings().into(),
            test_embeddings_with_metadata().into(),
            test_embeddings_fasttext().into(),
            test_embeddings_quantized().into(),
        ];

        for check_embeddings in &embeddings {
            for offset in 0..16u64 {
                let mut cursor = Cursor::new(Vec::new());
                cursor.seek(SeekFrom::Start(offset)).unwrap();
                check_embeddings.write_embeddings(&mut cursor).unwrap();
                assert_eq!(
                    cursor.into_inner().len() as u64 - offset,
                    check_embeddings.write_embeddings_len(offset)
                );
            }
        }
    }
}