finalfusion/
embeddings.rs

1//! Word embeddings.
2
3use std::collections::HashMap;
4use std::io::{Read, Seek, Write};
5use std::iter::Enumerate;
6use std::mem;
7use std::slice;
8
9use ndarray::{s, Array1, Array2, ArrayViewMut1, ArrayViewMut2, Axis, CowArray, Ix1};
10use rand::{CryptoRng, RngCore, SeedableRng};
11use rand_chacha::ChaChaRng;
12use reductive::pq::TrainPq;
13
14use crate::chunks::io::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
15use crate::chunks::metadata::Metadata;
16use crate::chunks::norms::NdNorms;
17use crate::chunks::storage::{
18    sealed::CloneFromMapping, NdArray, Quantize as QuantizeStorage, QuantizedArray, Storage,
19    StorageView, StorageViewWrap, StorageWrap,
20};
21use crate::chunks::vocab::{
22    BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, SimpleVocab, SubwordVocab,
23    Vocab, VocabWrap, WordIndex,
24};
25use crate::error::{Error, Result};
26use crate::io::{ReadEmbeddings, WriteEmbeddings};
27use crate::subword::BucketIndexer;
28use crate::util::l2_normalize;
29use crate::vocab::FloretSubwordVocab;
30
31/// Word embeddings.
32///
33/// This data structure stores word embeddings (also known as *word vectors*)
34/// and provides some useful methods on the embeddings, such as similarity
35/// and analogy queries.
36#[derive(Clone, Debug)]
37pub struct Embeddings<V, S> {
38    metadata: Option<Metadata>,
39    storage: S,
40    vocab: V,
41    norms: Option<NdNorms>,
42}
43
44impl<V, S> Embeddings<V, S>
45where
46    V: Vocab,
47    S: Storage,
48{
49    /// Construct an embeddings from a vocabulary, storage, and norms.
50    ///
51    /// The embeddings for known words **must** be
52    /// normalized. However, this is not verified due to the high
53    /// computational cost.
54    pub fn new(metadata: Option<Metadata>, vocab: V, storage: S, norms: NdNorms) -> Self {
55        assert_eq!(
56            vocab.words_len(),
57            norms.len(),
58            "Vocab and norms do not have the same length"
59        );
60        Embeddings::new_with_maybe_norms(metadata, vocab, storage, Some(norms))
61    }
62
63    pub(crate) fn new_with_maybe_norms(
64        metadata: Option<Metadata>,
65        vocab: V,
66        storage: S,
67        norms: Option<NdNorms>,
68    ) -> Self {
69        assert_eq!(
70            vocab.vocab_len(),
71            storage.shape().0,
72            "Max vocab index must match number of rows in the embedding matrix."
73        );
74        Embeddings {
75            metadata,
76            storage,
77            vocab,
78            norms,
79        }
80    }
81}
82
83impl<V, S> Embeddings<V, S> {
84    /// Decompose embeddings in its vocabulary, storage, and
85    /// optionally norms.
86    pub fn into_parts(self) -> (Option<Metadata>, V, S, Option<NdNorms>) {
87        (self.metadata, self.vocab, self.storage, self.norms)
88    }
89
90    /// Get metadata.
91    pub fn metadata(&self) -> Option<&Metadata> {
92        self.metadata.as_ref()
93    }
94
95    /// Get metadata mutably.
96    pub fn metadata_mut(&mut self) -> Option<&mut Metadata> {
97        self.metadata.as_mut()
98    }
99
100    /// Get embedding norms.
101    pub fn norms(&self) -> Option<&NdNorms> {
102        self.norms.as_ref()
103    }
104
105    /// Set metadata.
106    ///
107    /// Returns the previously-stored metadata.
108    pub fn set_metadata(&mut self, mut metadata: Option<Metadata>) -> Option<Metadata> {
109        mem::swap(&mut self.metadata, &mut metadata);
110        metadata
111    }
112
113    /// Get the embedding storage.
114    pub fn storage(&self) -> &S {
115        &self.storage
116    }
117
118    /// Get the vocabulary.
119    pub fn vocab(&self) -> &V {
120        &self.vocab
121    }
122}
123
124#[allow(clippy::len_without_is_empty)]
125impl<V, S> Embeddings<V, S>
126where
127    V: Vocab,
128    S: Storage,
129{
130    /// Return the length (in vector components) of the word embeddings.
131    pub fn dims(&self) -> usize {
132        self.storage.shape().1
133    }
134
135    /// Get the embedding of a word.
136    pub fn embedding(&self, word: &str) -> Option<CowArray<f32, Ix1>> {
137        match self.vocab.idx(word)? {
138            WordIndex::Word(idx) => Some(self.storage.embedding(idx)),
139            WordIndex::Subword(indices) => {
140                let embeds = self.storage.embeddings(&indices);
141                let mut embed = embeds.sum_axis(Axis(0));
142                l2_normalize(embed.view_mut());
143
144                Some(CowArray::from(embed))
145            }
146        }
147    }
148
149    /// Realize the embedding of a word into the given vector.
150    ///
151    /// This variant of `embedding` realizes the embedding into the
152    /// given vector. This makes it possible to look up embeddings
153    /// without any additional allocations. This method returns
154    /// `false` and does not modify the vector if no embedding could
155    /// be found.
156    ///
157    /// Panics when then the vector does not have the same
158    /// dimensionality as the word embeddings.
159    pub fn embedding_into(&self, word: &str, mut target: ArrayViewMut1<f32>) -> bool {
160        assert_eq!(
161            target.len(),
162            self.dims(),
163            "Embeddings have {} dimensions, whereas target array has {}",
164            self.dims(),
165            target.len()
166        );
167
168        let index = if let Some(idx) = self.vocab.idx(word) {
169            idx
170        } else {
171            return false;
172        };
173
174        match index {
175            WordIndex::Word(idx) => target.assign(&self.storage.embedding(idx)),
176            WordIndex::Subword(indices) => {
177                target.fill(0.);
178
179                let embeds = self.storage.embeddings(&indices);
180
181                for embed in embeds.outer_iter() {
182                    target += &embed;
183                }
184
185                l2_normalize(target.view_mut());
186            }
187        }
188
189        true
190    }
191
192    /// Get a batch of embeddings.
193    ///
194    /// The embeddings of all `words` are computed and returned. This method also
195    /// return a `Vec` indicating for each word if an embedding could be found.
196    pub fn embedding_batch(&self, words: &[impl AsRef<str>]) -> (Array2<f32>, Vec<bool>) {
197        let mut embeddings = Array2::zeros((words.len(), self.dims()));
198        let found = self.embedding_batch_into(words, embeddings.view_mut());
199        (embeddings, found)
200    }
201
202    /// Get a batch of embeddings.
203    ///
204    /// The embeddings of all `words` are computed and written to `output`. A `Vec` is
205    /// returned that indicates for each word if an embedding could be found.
206    ///
207    /// This method panics when `output` does not have the correct shape.
208    pub fn embedding_batch_into(
209        &self,
210        words: &[impl AsRef<str>],
211        mut output: ArrayViewMut2<f32>,
212    ) -> Vec<bool> {
213        assert_eq!(
214            output.len_of(Axis(0)),
215            words.len(),
216            "Expected embedding matrix for batch size {}, got {}",
217            words.len(),
218            output.len_of(Axis(0))
219        );
220        assert_eq!(
221            output.len_of(Axis(1)),
222            self.dims(),
223            "Expected embedding matrix for embeddings with dimensionality {}, got {}",
224            self.dims(),
225            output.len_of(Axis(1))
226        );
227
228        let mut found = vec![false; words.len()];
229
230        let mut word_indices: HashMap<_, Vec<_>> = HashMap::new();
231        for (idx, word) in words.iter().enumerate() {
232            let indices = word_indices.entry(word.as_ref()).or_default();
233            indices.push(idx);
234        }
235
236        for (word, indices) in word_indices {
237            // Look up the embedding for the first occurence of the word and
238            // then copy to other occurences.
239            let idx_first = indices[0];
240            if self.embedding_into(word, output.index_axis_mut(Axis(0), idx_first)) {
241                found[idx_first] = true;
242
243                for &idx in indices.iter().skip(1) {
244                    let (first, mut cur) = output.multi_slice_mut((s![idx_first, ..], s![idx, ..]));
245                    cur.assign(&first);
246                    found[idx] = true;
247                }
248            }
249        }
250
251        found
252    }
253
254    /// Get the embedding and original norm of a word.
255    ///
256    /// Returns for a word:
257    ///
258    /// * The word embedding.
259    /// * The norm of the embedding before normalization to a unit vector.
260    ///
261    /// The original embedding can be reconstructed by multiplying all
262    /// embedding components by the original norm.
263    ///
264    /// If the model does not have associated norms, *1* will be
265    /// returned as the norm for vocabulary words.
266    pub fn embedding_with_norm(&self, word: &str) -> Option<EmbeddingWithNorm> {
267        match self.vocab.idx(word)? {
268            WordIndex::Word(idx) => Some(EmbeddingWithNorm {
269                embedding: self.storage.embedding(idx),
270                norm: self.norms().map(|n| n[idx]).unwrap_or(1.),
271            }),
272            WordIndex::Subword(indices) => {
273                let embeds = self.storage.embeddings(&indices);
274                let mut embed = embeds.sum_axis(Axis(0));
275
276                let norm = l2_normalize(embed.view_mut());
277
278                Some(EmbeddingWithNorm {
279                    embedding: CowArray::from(embed),
280                    norm,
281                })
282            }
283        }
284    }
285
286    /// Get an iterator over pairs of words and the corresponding embeddings.
287    pub fn iter(&self) -> Iter {
288        Iter {
289            storage: &self.storage,
290            inner: self.vocab.words().iter().enumerate(),
291        }
292    }
293
294    /// Get an iterator over triples of words, embeddings, and norms.
295    ///
296    /// Returns an iterator that returns triples of:
297    ///
298    /// * A word.
299    /// * Its word embedding.
300    /// * The original norm of the embedding before normalization to a unit vector.
301    ///
302    /// The original embedding can be reconstructed by multiplying all
303    /// embedding components by the original norm.
304    ///
305    /// If the model does not have associated norms, the norm is
306    /// always *1*.
307    pub fn iter_with_norms(&self) -> IterWithNorms {
308        IterWithNorms {
309            storage: &self.storage,
310            norms: self.norms(),
311            inner: self.vocab.words().iter().enumerate(),
312        }
313    }
314
315    /// Get the vocabulary size.
316    ///
317    /// The vocabulary size excludes subword units.
318    pub fn len(&self) -> usize {
319        self.vocab.words_len()
320    }
321}
322
323impl<I, S> Embeddings<SubwordVocab<I>, S>
324where
325    I: BucketIndexer,
326    S: Storage + CloneFromMapping,
327{
328    /// Convert to explicitly indexed subword Embeddings.
329    pub fn to_explicit(&self) -> Result<Embeddings<ExplicitSubwordVocab, S::Result>> {
330        to_explicit_impl(self.vocab(), self.storage(), self.norms())
331    }
332}
333
334impl<S> Embeddings<VocabWrap, S>
335where
336    S: Storage + CloneFromMapping,
337{
338    /// Try to convert to explicitly indexed subword embeddings.
339    ///
340    /// Conversion fails if the wrapped vocabulary is `SimpleVocab`, `FloretSubwordVocab` or
341    /// already an `ExplicitSubwordVocab`.
342    pub fn try_to_explicit(&self) -> Result<Embeddings<ExplicitSubwordVocab, S::Result>> {
343        match &self.vocab {
344            VocabWrap::BucketSubwordVocab(sw) => {
345                Ok(to_explicit_impl(sw, self.storage(), self.norms())?)
346            }
347            VocabWrap::FastTextSubwordVocab(sw) => {
348                Ok(to_explicit_impl(sw, self.storage(), self.norms())?)
349            }
350            VocabWrap::SimpleVocab(_) => {
351                Err(Error::conversion_error("SimpleVocab", "ExplicitVocab"))
352            }
353            VocabWrap::ExplicitSubwordVocab(_) => {
354                Err(Error::conversion_error("ExplicitVocab", "ExplicitVocab"))
355            }
356            VocabWrap::FloretSubwordVocab(_) => Err(Error::conversion_error(
357                "FloretSubwordVocab",
358                "ExplicitVocab",
359            )),
360        }
361    }
362}
363
364fn to_explicit_impl<I, S>(
365    vocab: &SubwordVocab<I>,
366    storage: &S,
367    norms: Option<&NdNorms>,
368) -> Result<Embeddings<ExplicitSubwordVocab, S::Result>>
369where
370    S: Storage + CloneFromMapping,
371    I: BucketIndexer,
372{
373    let (expl_voc, old_to_new) = vocab.to_explicit()?;
374    let mut mapping = (0..expl_voc.vocab_len()).collect::<Vec<_>>();
375    for (old, new) in old_to_new {
376        mapping[expl_voc.words_len() + new] = old as usize + expl_voc.words_len();
377    }
378
379    let new_storage = storage.clone_from_mapping(&mapping);
380    Ok(Embeddings::new_with_maybe_norms(
381        None,
382        expl_voc,
383        new_storage,
384        norms.cloned(),
385    ))
386}
387
// Implements `From<Embeddings<$vocab, $storage>>` for
// `Embeddings<VocabWrap, $storage_wrap>` by converting the vocabulary and the
// storage into their wrapper types; metadata and norms are moved over as-is.
macro_rules! impl_embeddings_from {
    ($vocab:ty, $storage:ty, $storage_wrap:ty) => {
        impl From<Embeddings<$vocab, $storage>> for Embeddings<VocabWrap, $storage_wrap> {
            fn from(embeddings: Embeddings<$vocab, $storage>) -> Self {
                let (metadata, vocab, storage, norms) = embeddings.into_parts();
                let vocab: VocabWrap = vocab.into();
                let storage: $storage_wrap = storage.into();
                Embeddings {
                    metadata,
                    vocab,
                    storage,
                    norms,
                }
            }
        }
    };
}
403
404// Hmpf. With the blanket From<T> for T implementation, we need
405// specialization to generalize this.
406impl_embeddings_from!(SimpleVocab, NdArray, StorageWrap);
407impl_embeddings_from!(SimpleVocab, NdArray, StorageViewWrap);
408impl_embeddings_from!(SimpleVocab, QuantizedArray, StorageWrap);
409impl_embeddings_from!(BucketSubwordVocab, NdArray, StorageWrap);
410impl_embeddings_from!(BucketSubwordVocab, NdArray, StorageViewWrap);
411impl_embeddings_from!(BucketSubwordVocab, QuantizedArray, StorageWrap);
412impl_embeddings_from!(FastTextSubwordVocab, NdArray, StorageWrap);
413impl_embeddings_from!(FastTextSubwordVocab, NdArray, StorageViewWrap);
414impl_embeddings_from!(FastTextSubwordVocab, QuantizedArray, StorageWrap);
415impl_embeddings_from!(ExplicitSubwordVocab, NdArray, StorageWrap);
416impl_embeddings_from!(ExplicitSubwordVocab, NdArray, StorageViewWrap);
417impl_embeddings_from!(ExplicitSubwordVocab, QuantizedArray, StorageWrap);
418impl_embeddings_from!(FloretSubwordVocab, NdArray, StorageWrap);
419impl_embeddings_from!(FloretSubwordVocab, NdArray, StorageViewWrap);
420impl_embeddings_from!(FloretSubwordVocab, QuantizedArray, StorageWrap);
421impl_embeddings_from!(VocabWrap, QuantizedArray, StorageWrap);
422
423impl<'a, V, S> IntoIterator for &'a Embeddings<V, S>
424where
425    V: Vocab,
426    S: Storage,
427{
428    type Item = (&'a str, CowArray<'a, f32, Ix1>);
429    type IntoIter = Iter<'a>;
430
431    fn into_iter(self) -> Self::IntoIter {
432        self.iter()
433    }
434}
435
#[cfg(feature = "memmap")]
mod mmap {
    use std::fs::File;
    use std::io::BufReader;

    use super::Embeddings;
    use crate::chunks::io::MmapChunk;
    use crate::chunks::io::{ChunkIdentifier, Header, ReadChunk};
    use crate::chunks::metadata::Metadata;
    use crate::chunks::norms::NdNorms;
    #[cfg(target_endian = "little")]
    use crate::chunks::storage::StorageViewWrap;
    use crate::chunks::storage::{MmapArray, MmapQuantizedArray, StorageWrap};
    use crate::chunks::vocab::{
        BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, SimpleVocab, VocabWrap,
    };
    use crate::error::{Error, Result};
    use crate::io::MmapEmbeddings;
    use crate::vocab::FloretSubwordVocab;

    // Conversions of memory-mapped embeddings into the wrapped types. The
    // `StorageViewWrap` conversions are restricted to little-endian targets,
    // matching the conditional import of `StorageViewWrap` above. The whole
    // module is already gated on the `memmap` feature, so the invocations
    // need no further feature gate.
    impl_embeddings_from!(SimpleVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(SimpleVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(SimpleVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(BucketSubwordVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(BucketSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(BucketSubwordVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(FloretSubwordVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(FloretSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(FloretSubwordVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(FastTextSubwordVocab, MmapArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(FastTextSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(FastTextSubwordVocab, MmapQuantizedArray, StorageWrap);
    impl_embeddings_from!(ExplicitSubwordVocab, MmapArray, StorageWrap);
    impl_embeddings_from!(ExplicitSubwordVocab, MmapQuantizedArray, StorageWrap);
    #[cfg(target_endian = "little")]
    impl_embeddings_from!(ExplicitSubwordVocab, MmapArray, StorageViewWrap);
    impl_embeddings_from!(VocabWrap, MmapQuantizedArray, StorageWrap);

    impl<V, S> MmapEmbeddings for Embeddings<V, S>
    where
        Self: Sized,
        V: ReadChunk,
        S: MmapChunk,
    {
        /// Read embeddings from `read`, memory-mapping the storage chunk.
        ///
        /// The header is read first, then the metadata (only when the header
        /// lists it as the first chunk), the vocabulary, and the storage.
        /// Norms are read last; a failure to read them yields embeddings
        /// without norms rather than an error.
        ///
        /// # Errors
        ///
        /// Returns an error when the header lists no chunks, or when reading
        /// the header, metadata, vocabulary, or storage fails.
        fn mmap_embeddings(read: &mut BufReader<File>) -> Result<Self> {
            let header = Header::read_chunk(read)?;
            let chunks = header.chunk_identifiers();
            if chunks.is_empty() {
                return Err(Error::Format(String::from(
                    "Embedding file does not contain chunks",
                )));
            }

            let metadata = if chunks[0] == ChunkIdentifier::Metadata {
                Some(Metadata::read_chunk(read)?)
            } else {
                None
            };

            let vocab = V::read_chunk(read)?;
            let storage = S::mmap_chunk(read)?;
            // Norms are optional: any read failure means "no norms".
            let norms = NdNorms::read_chunk(read).ok();

            Ok(Embeddings {
                metadata,
                storage,
                vocab,
                norms,
            })
        }
    }
}
513
514impl<V, S> ReadEmbeddings for Embeddings<V, S>
515where
516    V: ReadChunk,
517    S: ReadChunk,
518{
519    fn read_embeddings<R>(read: &mut R) -> Result<Self>
520    where
521        R: Read + Seek,
522    {
523        let header = Header::read_chunk(read)?;
524        let chunks = header.chunk_identifiers();
525        if chunks.is_empty() {
526            return Err(Error::Format(String::from(
527                "Embedding file does not contain chunks",
528            )));
529        }
530
531        let metadata = if header.chunk_identifiers()[0] == ChunkIdentifier::Metadata {
532            Some(Metadata::read_chunk(read)?)
533        } else {
534            None
535        };
536
537        let vocab = V::read_chunk(read)?;
538        let storage = S::read_chunk(read)?;
539        let norms = NdNorms::read_chunk(read).ok();
540
541        Ok(Embeddings {
542            metadata,
543            storage,
544            vocab,
545            norms,
546        })
547    }
548}
549
550impl<V, S> WriteEmbeddings for Embeddings<V, S>
551where
552    V: WriteChunk,
553    S: WriteChunk,
554{
555    fn write_embeddings<W>(&self, write: &mut W) -> Result<()>
556    where
557        W: Write + Seek,
558    {
559        let mut chunks = match self.metadata {
560            Some(ref metadata) => vec![metadata.chunk_identifier()],
561            None => vec![],
562        };
563
564        chunks.extend_from_slice(&[
565            self.vocab.chunk_identifier(),
566            self.storage.chunk_identifier(),
567        ]);
568
569        if let Some(ref norms) = self.norms {
570            chunks.push(norms.chunk_identifier());
571        }
572
573        Header::new(chunks).write_chunk(write)?;
574        if let Some(ref metadata) = self.metadata {
575            metadata.write_chunk(write)?;
576        }
577
578        self.vocab.write_chunk(write)?;
579        self.storage.write_chunk(write)?;
580
581        if let Some(norms) = self.norms() {
582            norms.write_chunk(write)?;
583        }
584
585        Ok(())
586    }
587
588    fn write_embeddings_len(&self, offset: u64) -> u64 {
589        let mut len = 0;
590
591        let mut chunks = match self.metadata {
592            Some(ref metadata) => vec![metadata.chunk_identifier()],
593            None => vec![],
594        };
595
596        chunks.extend_from_slice(&[
597            self.vocab.chunk_identifier(),
598            self.storage.chunk_identifier(),
599        ]);
600
601        if let Some(ref norms) = self.norms {
602            chunks.push(norms.chunk_identifier());
603        }
604
605        let header = Header::new(chunks);
606        len += header.chunk_len(offset + len);
607
608        if let Some(ref metadata) = self.metadata {
609            len += metadata.chunk_len(offset + len);
610        }
611
612        len += self.vocab.chunk_len(offset + len);
613        len += self.storage.chunk_len(offset + len);
614
615        if let Some(ref norms) = self.norms {
616            len += norms.chunk_len(offset + len);
617        }
618
619        len
620    }
621}
622
623/// Quantizable embedding matrix.
624pub trait Quantize<V> {
625    /// Quantize the embedding matrix.
626    ///
627    /// This method trains a quantizer for the embedding matrix and
628    /// then quantizes the matrix using this quantizer.
629    ///
630    /// The xorshift PRNG is used for picking the initial quantizer
631    /// centroids.
632    fn quantize<T>(
633        &self,
634        n_subquantizers: usize,
635        n_subquantizer_bits: u32,
636        n_iterations: usize,
637        n_attempts: usize,
638        normalize: bool,
639    ) -> Result<Embeddings<V, QuantizedArray>>
640    where
641        T: TrainPq<f32>,
642    {
643        self.quantize_using::<T, _>(
644            n_subquantizers,
645            n_subquantizer_bits,
646            n_iterations,
647            n_attempts,
648            normalize,
649            ChaChaRng::from_entropy(),
650        )
651    }
652
653    /// Quantize the embedding matrix using the provided RNG.
654    ///
655    /// This method trains a quantizer for the embedding matrix and
656    /// then quantizes the matrix using this quantizer.
657    fn quantize_using<T, R>(
658        &self,
659        n_subquantizers: usize,
660        n_subquantizer_bits: u32,
661        n_iterations: usize,
662        n_attempts: usize,
663        normalize: bool,
664        rng: R,
665    ) -> Result<Embeddings<V, QuantizedArray>>
666    where
667        T: TrainPq<f32>,
668        R: CryptoRng + RngCore + SeedableRng + Send;
669}
670
671impl<V, S> Quantize<V> for Embeddings<V, S>
672where
673    V: Vocab + Clone,
674    S: StorageView,
675{
676    fn quantize_using<T, R>(
677        &self,
678        n_subquantizers: usize,
679        n_subquantizer_bits: u32,
680        n_iterations: usize,
681        n_attempts: usize,
682        normalize: bool,
683        rng: R,
684    ) -> Result<Embeddings<V, QuantizedArray>>
685    where
686        T: TrainPq<f32>,
687        R: CryptoRng + RngCore + SeedableRng + Send,
688    {
689        let quantized_storage = self.storage().quantize_using::<T, R>(
690            n_subquantizers,
691            n_subquantizer_bits,
692            n_iterations,
693            n_attempts,
694            normalize,
695            rng,
696        )?;
697
698        Ok(Embeddings {
699            metadata: self.metadata().cloned(),
700            vocab: self.vocab.clone(),
701            storage: quantized_storage,
702            norms: self.norms().cloned(),
703        })
704    }
705}
706
707/// An embedding with its (pre-normalization) l2 norm.
708pub struct EmbeddingWithNorm<'a> {
709    pub embedding: CowArray<'a, f32, Ix1>,
710    pub norm: f32,
711}
712
713impl<'a> EmbeddingWithNorm<'a> {
714    // Compute the unnormalized embedding.
715    pub fn into_unnormalized(self) -> Array1<f32> {
716        let mut unnormalized = self.embedding.into_owned();
717        unnormalized *= self.norm;
718        unnormalized
719    }
720}
721
722/// Iterator over embeddings.
723pub struct Iter<'a> {
724    storage: &'a dyn Storage,
725    inner: Enumerate<slice::Iter<'a, String>>,
726}
727
728impl<'a> Iterator for Iter<'a> {
729    type Item = (&'a str, CowArray<'a, f32, Ix1>);
730
731    fn next(&mut self) -> Option<Self::Item> {
732        self.inner
733            .next()
734            .map(|(idx, word)| (word.as_str(), self.storage.embedding(idx)))
735    }
736}
737
738/// Iterator over embeddings.
739pub struct IterWithNorms<'a> {
740    storage: &'a dyn Storage,
741    norms: Option<&'a NdNorms>,
742    inner: Enumerate<slice::Iter<'a, String>>,
743}
744
745impl<'a> Iterator for IterWithNorms<'a> {
746    type Item = (&'a str, EmbeddingWithNorm<'a>);
747
748    fn next(&mut self) -> Option<Self::Item> {
749        self.inner.next().map(|(idx, word)| {
750            (
751                word.as_str(),
752                EmbeddingWithNorm {
753                    embedding: self.storage.embedding(idx),
754                    norm: self.norms.map(|n| n[idx]).unwrap_or(1.),
755                },
756            )
757        })
758    }
759}
760
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::{BufReader, Cursor, Seek, SeekFrom};

    use ndarray::{array, Array1};
    use toml::toml;

    use super::Embeddings;
    use crate::chunks::metadata::Metadata;
    use crate::chunks::norms::NdNorms;
    use crate::chunks::storage::{NdArray, Storage, StorageView};
    use crate::chunks::vocab::{FastTextSubwordVocab, SimpleVocab, Vocab};
    use crate::compat::fasttext::ReadFastText;
    use crate::compat::word2vec::ReadWord2VecRaw;
    use crate::io::{ReadEmbeddings, WriteEmbeddings};
    use crate::prelude::StorageWrap;
    use crate::storage::QuantizedArray;
    use crate::subword::Indexer;
    use crate::vocab::VocabWrap;

    fn test_embeddings() -> Embeddings<SimpleVocab, NdArray> {
        let mut reader = BufReader::new(File::open("testdata/similarity.bin").unwrap());
        Embeddings::read_word2vec_binary_raw(&mut reader, false).unwrap()
    }

    fn test_embeddings_with_metadata() -> Embeddings<SimpleVocab, NdArray> {
        let mut embeds = test_embeddings();
        embeds.set_metadata(Some(test_metadata()));
        embeds
    }

    fn test_embeddings_fasttext() -> Embeddings<FastTextSubwordVocab, NdArray> {
        let mut reader = BufReader::new(File::open("testdata/fasttext.bin").unwrap());
        Embeddings::read_fasttext(&mut reader).unwrap()
    }

    fn test_embeddings_quantized() -> Embeddings<SimpleVocab, QuantizedArray> {
        let mut reader = BufReader::new(File::open("testdata/quantized.fifu").unwrap());
        Embeddings::read_embeddings(&mut reader).unwrap()
    }

    fn test_metadata() -> Metadata {
        Metadata::new(toml! {
            [hyperparameters]
            dims = 300
            ns = 5

            [description]
            description = "Test model"
            language = "de"
        })
    }

    #[test]
    fn embedding_into_equal_to_embedding() {
        let embeds = test_embeddings_fasttext();

        // Known word
        let mut target = Array1::zeros(embeds.dims());
        assert!(embeds.embedding_into("ganz", target.view_mut()));
        assert_eq!(target, embeds.embedding("ganz").unwrap());

        // Unknown word
        let mut target = Array1::zeros(embeds.dims());
        assert!(embeds.embedding_into("iddqd", target.view_mut()));
        assert_eq!(target, embeds.embedding("iddqd").unwrap());

        // Unknown word, non-zero vector: `target` still holds the previous
        // embedding, so this checks that it is overwritten, not added to.
        assert!(embeds.embedding_into("idspispopd", target.view_mut()));
        assert_eq!(target, embeds.embedding("idspispopd").unwrap());
    }

    #[test]
    fn embedding_batch_returns_correct_embeddings() {
        let embeds = test_embeddings();
        let (lookups, found) = embeds.embedding_batch(&[
            "Berlin",
            "Bremen",
            "Groningen",
            "Bremen",
            "Berlin",
            "Amsterdam",
        ]);

        assert_eq!(lookups.row(0), embeds.embedding("Berlin").unwrap());
        assert_eq!(lookups.row(1), embeds.embedding("Bremen").unwrap());
        assert_eq!(lookups.row(3), embeds.embedding("Bremen").unwrap());
        assert_eq!(lookups.row(4), embeds.embedding("Berlin").unwrap());
        assert_eq!(found, &[true, true, false, true, true, false]);

        // Rows of words without an embedding stay zero.
        assert!(lookups.row(2).iter().all(|&v| v == 0.));
        assert!(lookups.row(5).iter().all(|&v| v == 0.));
    }

    #[test]
    #[cfg(feature = "memmap")]
    fn mmap() {
        use crate::chunks::storage::MmapArray;
        use crate::io::MmapEmbeddings;
        let check_embeds = test_embeddings();
        let mut reader = BufReader::new(File::open("testdata/similarity.fifu").unwrap());
        let embeds: Embeddings<SimpleVocab, MmapArray> =
            Embeddings::mmap_embeddings(&mut reader).unwrap();
        assert_eq!(embeds.vocab(), check_embeds.vocab());

        #[cfg(target_endian = "little")]
        assert_eq!(embeds.storage().view(), check_embeds.storage().view());
    }

    #[test]
    fn to_explicit() {
        let embeds = test_embeddings_fasttext();
        let expl_embeds = embeds.to_explicit().unwrap();
        let ganz = embeds.embedding("ganz").unwrap();
        let ganz_expl = expl_embeds.embedding("ganz").unwrap();
        assert_eq!(ganz, ganz_expl);

        assert!(embeds
            .vocab
            .idx("anz")
            .map(|i| i.subword().is_some())
            .unwrap_or(false));
        let anz_idx = embeds.vocab.indexer().index_ngram(&"anz".into())[0] as usize
            + embeds.vocab.words_len();
        let anz = embeds.storage.embedding(anz_idx);
        let anz_idx = expl_embeds.vocab.indexer().index_ngram(&"anz".into())[0] as usize
            + embeds.vocab.words_len();
        let anz_expl = expl_embeds.storage.embedding(anz_idx);
        assert_eq!(anz, anz_expl);
    }

    #[test]
    fn norms() {
        let vocab = SimpleVocab::new(vec!["norms".to_string(), "test".to_string()]);
        let storage = NdArray::new(array![[1f32], [-1f32]]);
        let norms = NdNorms::new(array![2f32, 3f32]);
        let check = Embeddings::new(None, vocab, storage, norms);

        let mut serialized = Cursor::new(Vec::new());
        check.write_embeddings(&mut serialized).unwrap();
        serialized.seek(SeekFrom::Start(0)).unwrap();

        let embeddings: Embeddings<SimpleVocab, NdArray> =
            Embeddings::read_embeddings(&mut serialized).unwrap();

        assert!(check
            .norms()
            .unwrap()
            .view()
            .abs_diff_eq(&embeddings.norms().unwrap().view(), 1e-8));
    }

    #[test]
    fn write_read_simple_roundtrip() {
        let check_embeds = test_embeddings();
        let mut cursor = Cursor::new(Vec::new());
        check_embeds.write_embeddings(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let embeds: Embeddings<SimpleVocab, NdArray> =
            Embeddings::read_embeddings(&mut cursor).unwrap();
        assert_eq!(embeds.storage().view(), check_embeds.storage().view());
        assert_eq!(embeds.vocab(), check_embeds.vocab());
        assert_eq!(
            cursor.into_inner().len() as u64,
            check_embeds.write_embeddings_len(0)
        );
    }

    #[test]
    fn write_read_simple_metadata_roundtrip() {
        let check_embeds = test_embeddings_with_metadata();

        let mut cursor = Cursor::new(Vec::new());
        check_embeds.write_embeddings(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let embeds: Embeddings<SimpleVocab, NdArray> =
            Embeddings::read_embeddings(&mut cursor).unwrap();
        assert_eq!(embeds.storage().view(), check_embeds.storage().view());
        assert_eq!(embeds.vocab(), check_embeds.vocab());
        assert_eq!(
            cursor.into_inner().len() as u64,
            check_embeds.write_embeddings_len(0)
        );
    }

    #[test]
    fn embeddings_write_length_different_offsets() {
        let embeddings: Vec<Embeddings<VocabWrap, StorageWrap>> = vec![
            test_embeddings().into(),
            test_embeddings_with_metadata().into(),
            test_embeddings_fasttext().into(),
            test_embeddings_quantized().into(),
        ];

        for check_embeddings in &embeddings {
            for offset in 0..16u64 {
                let mut cursor = Cursor::new(Vec::new());
                cursor.seek(SeekFrom::Start(offset)).unwrap();
                check_embeddings.write_embeddings(&mut cursor).unwrap();
                assert_eq!(
                    cursor.into_inner().len() as u64 - offset,
                    check_embeddings.write_embeddings_len(offset)
                );
            }
        }
    }
}