1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
use std::fs::File; use std::io::BufReader; use failure::{format_err, Error, ResultExt}; use rust2vec::prelude::*; #[derive(Clone, Copy, PartialEq, Eq)] pub enum EmbeddingFormat { FinalFusion, FinalFusionMmap, Word2Vec, Text, TextDims, } impl EmbeddingFormat { pub fn try_from(format: impl AsRef<str>) -> Result<Self, Error> { use EmbeddingFormat::*; match format.as_ref() { "finalfusion" => Ok(FinalFusion), "finalfusion_mmap" => Ok(FinalFusionMmap), "word2vec" => Ok(Word2Vec), "text" => Ok(Text), "textdims" => Ok(TextDims), unknown => Err(format_err!("Unknown embedding format: {}", unknown)), } } } pub fn read_embeddings_view( filename: &str, embedding_format: EmbeddingFormat, ) -> Result<Embeddings<VocabWrap, StorageViewWrap>, Error> { let f = File::open(filename).context("Cannot open embeddings file")?; let mut reader = BufReader::new(f); use EmbeddingFormat::*; let embeddings = match embedding_format { FinalFusion => ReadEmbeddings::read_embeddings(&mut reader), FinalFusionMmap => MmapEmbeddings::mmap_embeddings(&mut reader), Word2Vec => ReadWord2Vec::read_word2vec_binary(&mut reader, true).map(Embeddings::into), Text => ReadText::read_text(&mut reader, true).map(Embeddings::into), TextDims => ReadTextDims::read_text_dims(&mut reader, true).map(Embeddings::into), } .context("Cannot read embeddings")?; Ok(embeddings) }