1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use std::collections::HashMap;
use std::io::{Read, Seek, Write};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crate::error::{Error, Result};
mod subword;
pub use subword::{
BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, NGramIndices, SubwordIndices,
SubwordVocab,
};
mod simple;
pub use simple::SimpleVocab;
mod wrappers;
pub use wrappers::VocabWrap;
/// Trait implemented by vocabularies that map words to embedding-matrix
/// indices.
// `len_without_is_empty` is allowed because this trait exposes two distinct
// length notions (`words_len` / `vocab_len`) rather than a single `len`.
#[allow(clippy::len_without_is_empty)]
pub trait Vocab {
    /// Look up `word`, returning its index representation if available.
    ///
    /// Returns `None` when the vocabulary cannot produce any index for the
    /// word. Depending on the implementation, the result may be a regular
    /// word index or a set of subword indices (see [`WordIndex`]).
    fn idx(&self, word: &str) -> Option<WordIndex>;

    /// Number of in-vocabulary words.
    // NOTE(review): presumably excludes subword/n-gram entries, which
    // `vocab_len` would include — confirm against the implementations.
    fn words_len(&self) -> usize;

    /// Total vocabulary size.
    fn vocab_len(&self) -> usize;

    /// The in-vocabulary words, in index order.
    fn words(&self) -> &[String];
}
/// Index of a vocabulary word: either a single word index or the indices
/// of the word's subwords.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WordIndex {
    /// Index of an in-vocabulary word.
    Word(usize),
    /// Subword indices for a word represented by its subword units.
    Subword(Vec<usize>),
}

impl WordIndex {
    /// Return the word index, or `None` if this is a subword index.
    pub fn word(&self) -> Option<usize> {
        if let WordIndex::Word(idx) = self {
            Some(*idx)
        } else {
            None
        }
    }

    /// Return the subword indices, or `None` if this is a word index.
    pub fn subword(&self) -> Option<&[usize]> {
        if let WordIndex::Subword(indices) = self {
            Some(indices)
        } else {
            None
        }
    }
}
/// Construct a word → index mapping from a slice of words.
///
/// Each word is mapped to its position in `words`. If a word occurs more
/// than once, the index of its *last* occurrence wins (later insertions
/// overwrite earlier ones), matching the original loop's semantics.
pub(crate) fn create_indices(words: &[String]) -> HashMap<String, usize> {
    // Collecting from an ExactSizeIterator lets HashMap reserve capacity
    // up front instead of growing incrementally.
    words
        .iter()
        .enumerate()
        .map(|(idx, word)| (word.to_owned(), idx))
        .collect()
}
pub(crate) fn read_vocab_items<R>(read: &mut R, len: usize) -> Result<Vec<String>>
where
R: Read + Seek,
{
let mut items = Vec::with_capacity(len);
for _ in 0..len {
let item_len =
read.read_u32::<LittleEndian>()
.map_err(|e| Error::io_error("Cannot read item length", e))? as usize;
let mut bytes = vec![0; item_len];
read.read_exact(&mut bytes)
.map_err(|e| Error::io_error("Cannot read item", e))?;
let item = String::from_utf8(bytes)
.map_err(|e| Error::Format(format!("Item contains invalid UTF-8: {}", e)))
.map_err(Error::from)?;
items.push(item);
}
Ok(items)
}
/// Write `items` to `write` as length-prefixed strings.
///
/// Each item is written as a little-endian `u32` byte length followed by
/// the item's UTF-8 bytes — the inverse of `read_vocab_items`.
///
/// # Errors
///
/// Returns an I/O error if writing the length prefix or the item bytes
/// fails.
pub(crate) fn write_vocab_items<W>(write: &mut W, items: &[String]) -> Result<()>
where
    W: Write + Seek,
{
    for item in items {
        // NOTE(review): `as u32` silently truncates for items over 4 GiB;
        // the on-disk format fixes the prefix width, so this matches the
        // reader, but very large items would round-trip incorrectly.
        let byte_len = item.len() as u32;
        write
            .write_u32::<LittleEndian>(byte_len)
            .map_err(|e| Error::io_error("Cannot write token length", e))?;
        write
            .write_all(item.as_bytes())
            .map_err(|e| Error::io_error("Cannot write token", e))?;
    }
    Ok(())
}
/// Test helper: read a chunk header from `read` and return the chunk size.
///
/// Reads and discards a leading little-endian `u32` (presumably the chunk
/// identifier — confirm against the chunk format), then returns the
/// little-endian `u64` that follows.
///
/// Panics on any read failure; this is acceptable because the function is
/// only compiled for tests.
#[cfg(test)]
pub(crate) fn read_chunk_size(read: &mut impl Read) -> u64 {
    let _skipped_identifier = read.read_u32::<LittleEndian>().unwrap();
    read.read_u64::<LittleEndian>().unwrap()
}