use std::collections::HashMap;
use std::io::{Read, Seek, Write};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crate::error::{Error, Result};
mod subword;
pub use subword::{
BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, FloretSubwordVocab,
NGramIndices, SubwordIndices, SubwordVocab,
};
mod simple;
pub use simple::SimpleVocab;
mod wrappers;
pub use wrappers::VocabWrap;
#[allow(clippy::len_without_is_empty)]
/// Common interface for embedding vocabularies.
pub trait Vocab {
    /// Look up `word`, returning its index — either a known-word index or,
    /// for subword vocabularies, the subword indices of an unknown word.
    /// Returns `None` when the vocabulary cannot index the word at all.
    fn idx(&self, word: &str) -> Option<WordIndex>;
    /// Number of known (in-vocabulary) words.
    fn words_len(&self) -> usize;
    /// Total vocabulary length — presumably `words_len` plus any
    /// subword/bucket entries; confirm against the implementations in the
    /// `subword`/`simple` submodules.
    fn vocab_len(&self) -> usize;
    /// The known words, ordered by their indices.
    fn words(&self) -> &[String];
}
/// Result of a vocabulary lookup.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WordIndex {
    /// Index of an in-vocabulary word.
    Word(usize),
    /// Subword indices — used when the word itself is not in the vocabulary.
    Subword(Vec<usize>),
}

impl WordIndex {
    /// The word index, or `None` if this holds subword indices.
    pub fn word(&self) -> Option<usize> {
        if let WordIndex::Word(idx) = self {
            Some(*idx)
        } else {
            None
        }
    }

    /// The subword indices, or `None` if this holds a word index.
    pub fn subword(&self) -> Option<&[usize]> {
        match self {
            WordIndex::Subword(indices) => Some(indices.as_slice()),
            WordIndex::Word(_) => None,
        }
    }
}
/// Build a word → index map for `words`.
///
/// Indices correspond to the positions of the words in the slice. If a
/// word occurs more than once, the index of its *last* occurrence wins
/// (standard `HashMap` insertion semantics, unchanged from the previous
/// loop-based implementation).
pub(crate) fn create_indices(words: &[String]) -> HashMap<String, usize> {
    // `collect` preallocates from the iterator's size hint, avoiding the
    // incremental rehashing of repeated `insert`s into an empty map.
    words
        .iter()
        .enumerate()
        .map(|(idx, word)| (word.clone(), idx))
        .collect()
}
/// Read a length-prefixed string: a little-endian `u32` byte length
/// followed by that many UTF-8 bytes.
///
/// Returns an error when the length or bytes cannot be read, or when the
/// bytes are not valid UTF-8.
pub(crate) fn read_string<R>(read: &mut R) -> Result<String>
where
    R: Read,
{
    let string_len =
        read.read_u32::<LittleEndian>()
            .map_err(|e| Error::read_error("Cannot read string length", e))? as usize;
    // NOTE(review): `string_len` comes straight from the input; a corrupt
    // file can request a very large allocation here.
    let mut bytes = vec![0; string_len];
    read.read_exact(&mut bytes)
        .map_err(|e| Error::read_error("Cannot read item", e))?;
    // The `map_err` closure already produces an `Error`; the previous
    // trailing `.map_err(Error::from)` was an identity conversion.
    String::from_utf8(bytes)
        .map_err(|e| Error::Format(format!("Item contains invalid UTF-8: {}", e)))
}
/// Read `len` length-prefixed strings from `read`.
///
/// Stops and propagates the error as soon as any single read fails.
pub(crate) fn read_vocab_items<R>(read: &mut R, len: usize) -> Result<Vec<String>>
where
    R: Read,
{
    // Collecting into `Result<Vec<_>>` short-circuits on the first error,
    // exactly like the explicit loop with `?` did; the size hint of the
    // range preallocates the vector to `len`.
    (0..len).map(|_| read_string(read)).collect()
}
pub(crate) fn write_string<W>(write: &mut W, s: &str) -> Result<()>
where
W: Write,
{
write
.write_u32::<LittleEndian>(s.len() as u32)
.map_err(|e| Error::write_error("Cannot write string length", e))?;
write
.write_all(s.as_bytes())
.map_err(|e| Error::write_error("Cannot write string", e))
}
/// Write each item as a length-prefixed string, stopping at the first
/// failed write.
pub(crate) fn write_vocab_items<W>(write: &mut W, items: &[String]) -> Result<()>
where
    W: Write + Seek,
{
    // NOTE(review): `Seek` is not used in this body — presumably kept for
    // symmetry with other chunk writers; confirm before relaxing the bound.
    items.iter().try_for_each(|word| write_string(write, word))
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use byteorder::{LittleEndian, ReadBytesExt};

    use crate::chunks::io::WriteChunk;
    use crate::vocab::VocabWrap;

    /// Skip the chunk identifier (`u32`) and return the chunk size
    /// (`u64`) from a chunk header.
    fn read_chunk_size(read: &mut impl Read) -> u64 {
        // Discard the chunk identifier; only the size is needed here.
        read.read_u32::<LittleEndian>().unwrap();
        read.read_u64::<LittleEndian>().unwrap()
    }

    /// Check that the size stored in a written vocab chunk's header and
    /// the length reported by `chunk_len` both match the bytes actually
    /// written, at every start offset in `0..16` (padding may depend on
    /// the alignment of the chunk start).
    // Note: the redundant `#[cfg(test)]` on this function was removed —
    // the enclosing module is already gated by `#[cfg(test)]`.
    pub(crate) fn test_vocab_chunk_len(check_vocab: VocabWrap) {
        for offset in 0..16u64 {
            let mut cursor = Cursor::new(Vec::new());
            cursor.seek(SeekFrom::Start(offset)).unwrap();
            check_vocab.write_chunk(&mut cursor).unwrap();
            cursor.seek(SeekFrom::Start(offset)).unwrap();

            // The header's size field must cover the rest of the chunk.
            let chunk_size = read_chunk_size(&mut cursor);
            assert_eq!(
                cursor.read_to_end(&mut Vec::new()).unwrap(),
                chunk_size as usize
            );

            // `chunk_len` must agree with the number of bytes written
            // after `offset`.
            assert_eq!(
                cursor.into_inner().len() as u64 - offset,
                check_vocab.chunk_len(offset)
            );
        }
    }
}