use std::{convert::From, io, iter, mem};
mod model;
mod serialization;
mod trainer;
mod word;
// A pair of two `u32` values.
// NOTE(review): presumably two token ids forming a merge candidate; the alias
// is not used in this part of the file, so confirm against the submodules.
type Pair = (u32, u32);
#[derive(Debug)]
pub enum Error {
Io(std::io::Error),
JsonError(serde_json::Error),
BadVocabulary,
BadMerges(usize),
MergeTokenOutOfVocabulary(String),
UnkTokenOutOfVocabulary(String),
InvalidDropout,
}
impl From<io::Error> for Error {
fn from(error: io::Error) -> Self {
Error::Io(error)
}
}
/// Lets `?` convert a `serde_json` error into [`Error::JsonError`].
impl From<serde_json::Error> for Error {
    fn from(error: serde_json::Error) -> Self {
        Self::JsonError(error)
    }
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::Io(e) => write!(f, "IoError: {}", e),
Error::JsonError(e) => write!(f, "JsonError: {}", e),
Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
Error::MergeTokenOutOfVocabulary(token) => {
write!(f, "Token `{}` out of vocabulary", token)
}
Error::UnkTokenOutOfVocabulary(token) => {
write!(f, "Unk token `{}` not found in the vocabulary", token)
}
Error::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Io(e) => Some(e),
Error::JsonError(e) => Some(e),
_ => None,
}
}
}
/// Extension trait that adds `with_first_and_last` to iterators.
pub(crate) trait WithFirstLastIterator: Iterator + Sized {
    /// Consumes the iterator and wraps it in a `FirstLastIterator`, whose
    /// items are `(is_first, is_last, item)` tuples.
    fn with_first_and_last(self) -> FirstLastIterator<Self>;
}

/// Blanket implementation: every iterator gets `with_first_and_last`.
impl<I: Iterator> WithFirstLastIterator for I {
    fn with_first_and_last(self) -> FirstLastIterator<Self> {
        // Peeking one element ahead is how the adaptor detects the last item.
        let iter = self.peekable();
        FirstLastIterator { first: true, iter }
    }
}
/// Iterator adaptor that tags every item with two flags: whether it is the
/// first item yielded and whether it is the last one.
pub(crate) struct FirstLastIterator<I>
where
    I: Iterator,
{
    /// `true` until the first call to `next`.
    first: bool,
    /// Wrapped iterator; peeking one element ahead detects the last item.
    iter: iter::Peekable<I>,
}

impl<I> Iterator for FirstLastIterator<I>
where
    I: Iterator,
{
    /// `(is_first, is_last, item)`.
    type Item = (bool, bool, I::Item);

    /// Returns the next item with its first/last flags, or `None` once the
    /// wrapped iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // The flag is cleared on every call, even when nothing is yielded.
        let is_first = mem::replace(&mut self.first, false);
        let item = self.iter.next()?;
        // Peek only after taking the item so the look-ahead targets the
        // element that follows it.
        let is_last = self.iter.peek().is_none();
        Some((is_first, is_last, item))
    }

    /// Forwards the wrapped iterator's hint: this adaptor yields exactly one
    /// tuple per inner element, so the bounds carry over unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
pub use model::*;
pub use trainer::*;
use word::*;