use fst::raw::{Fst, Output};
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::mem::replace;
use std::ops::Range;
use std::path::PathBuf;
#[cfg(feature = "huggingface")]
use tokenizers::tokenizer::{Model, Token as HfToken};
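/// Walks `fst` along `input` and returns the longest matching prefix as a
/// `(byte_length, value)` pair, or `None` if not even one byte matches.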
#[inline]
fn find_longest_prefix<D: AsRef<[u8]>>(fst: &Fst<D>, input: &[u8]) -> Option<(usize, u64)> {
let mut node = fst.root();
let mut out = Output::zero();
let mut last_match: Option<(usize, Output)> = None;
for (i, &b) in input.iter().enumerate() {
if let Some(trans_index) = node.find_input(b) {
let t = node.transition(trans_index);
node = fst.node(t.addr);
out = out.cat(t.out);
if node.is_final() {
last_match = Some((i + 1, out.cat(node.final_output())));
}
} else {
break;
}
}
last_match.map(|(i, o)| (i, o.value()))
}
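/// Returns the character offset at the end of `range`, given the character
/// offset `last_known_char` at its start.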
fn char_offs(text: &str, last_known_char: usize, range: Range<usize>) -> usize {
text[range].chars().count() + last_known_char
}
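/// A numeric type into which token IDs can be converted.
///
/// The FSTs store IDs as `u64`; implementing this for `i64`, `i32` and `f64`
/// lets callers fill buffers of whatever type their tensor library expects.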
pub trait TokenID: PartialEq + Clone {
fn zero() -> Self;
fn coerce(t: u64) -> Self;
fn restore(self) -> u64;
}
impl TokenID for u64 {
fn zero() -> Self {
0
}
#[inline(always)]
fn coerce(t: u64) -> Self {
t
}
#[inline(always)]
fn restore(self) -> u64 {
self
}
}
impl TokenID for i64 {
fn zero() -> Self {
0
}
#[inline(always)]
fn coerce(t: u64) -> Self {
t as i64
}
#[inline(always)]
fn restore(self) -> u64 {
self as u64
}
}
impl TokenID for i32 {
fn zero() -> Self {
0
}
#[inline(always)]
fn coerce(t: u64) -> Self {
t as i32
}
#[inline(always)]
fn restore(self) -> u64 {
self as u64
}
}
impl TokenID for f64 {
fn zero() -> Self {
0.0
}
#[inline(always)]
fn coerce(t: u64) -> Self {
t as f64
}
#[inline(always)]
fn restore(self) -> u64 {
self as u64
}
}
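/// A fast vocabulary-based tokenizer in the WordPiece style.
///
/// Word-initial tokens and `##`-prefixed continuation tokens live in two
/// separate finite state transducers (`starters` and `followers`), which are
/// matched greedily, longest prefix first.
///
/// # Example
///
/// A minimal usage sketch, assuming the crate is named `aleph_alpha_tokenizer`
/// and a BERT-style `vocab.txt` exists at the given path:
///
/// ```no_run
/// use aleph_alpha_tokenizer::AlephAlphaTokenizer;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
/// let tokenizer = AlephAlphaTokenizer::from_vocab("vocab.txt")?;
/// let mut ids: Vec<i64> = Vec::new();
/// let mut ranges = Vec::new();
/// tokenizer.tokens_into("Hello world", &mut ids, &mut ranges, None);
/// for (id, range) in ids.iter().zip(ranges.iter()) {
///     println!("token {} covers bytes {:?}", tokenizer.text_of(*id), range);
/// }
/// # Ok(())
/// # }
/// ```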
pub struct AlephAlphaTokenizer {
tokens: Vec<String>,
starters: Fst<Vec<u8>>,
followers: Fst<Vec<u8>>,
special_tokens: Vec<u64>,
unk_id: u32,
prefix: Option<u32>,
suffix: Option<u32>,
}
impl AlephAlphaTokenizer {
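    /// Reads a newline-separated vocabulary file and builds the tokenizer.
    ///
    /// Bracketed entries like `[CLS]` are recorded as special tokens,
    /// `[unused…]` entries are skipped, and `##`-prefixed entries go into the
    /// follower FST. Fails if the vocabulary has no `[UNK]` token.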
pub fn from_vocab(path: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
let vocab = File::open(path)?;
let tokens = BufReader::new(vocab)
.lines()
.collect::<Result<Vec<String>, std::io::Error>>()?;
let mut starter: Vec<(Vec<u8>, u64)> = Vec::new();
let mut follower: Vec<(Vec<u8>, u64)> = Vec::new();
let mut special_tokens = Vec::new();
let mut unk_id = None;
let mut prefix = None;
let mut suffix = None;
for (i, tok) in tokens.iter().enumerate() {
let token = tok.trim().as_bytes();
if token.starts_with(b"[") && token.ends_with(b"]") {
if token.starts_with(b"[unused") {
continue;
}
if token == b"[UNK]" {
unk_id = Some(i as u32);
} else if token == b"[CLS]" {
prefix = Some(i as u32);
} else if token == b"[SEP]" {
suffix = Some(i as u32);
}
special_tokens.push(i as u64);
}
if token.starts_with(b"##") {
follower.push((token[2..].to_vec(), i as u64));
} else {
starter.push((token.to_vec(), i as u64));
}
}
        let unk_id = if let Some(u) = unk_id {
            u
        } else {
            // without [UNK] there is no fallback for out-of-vocabulary words
            return Err(Box::new(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "vocabulary does not contain an [UNK] token",
            )));
        };
starter.sort_by(|(k, _), (j, _)| k.cmp(j));
follower.sort_by(|(k, _), (j, _)| k.cmp(j));
let starters = Fst::from_iter_map(starter)?;
let followers = Fst::from_iter_map(follower)?;
Ok(AlephAlphaTokenizer {
tokens,
starters,
followers,
special_tokens,
unk_id,
prefix,
suffix,
})
}
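    /// Pairs each byte range with the corresponding character range.
    /// The input ranges must be ascending and non-overlapping.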
pub fn char_ranges<'i>(
text: &'i str,
ranges: impl Iterator<Item = &'i Range<usize>> + 'i,
) -> impl Iterator<Item = (Range<usize>, Range<usize>)> + 'i {
let (mut last_char, mut last_byte) = (0, 0);
ranges.map(move |r| {
let (s, e) = (r.start, r.end);
let cs = char_offs(text, last_char, last_byte..s);
last_char = char_offs(text, cs, s..e);
last_byte = e;
(r.clone(), cs..last_char)
})
}
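    /// Appends the `[CLS]` token (with an empty range) if the vocabulary defines one.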
#[inline]
fn add_prefix<T: TokenID>(&self, token_ids: &mut Vec<T>, token_ranges: &mut Vec<Range<usize>>) {
if let Some(id) = self.prefix {
token_ids.push(T::coerce(u64::from(id)));
token_ranges.push(0..0);
}
}
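    /// Appends the `[SEP]` token at the end position if the vocabulary defines one.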
#[inline]
fn add_suffix<T: TokenID>(&self, token_ids: &mut Vec<T>, token_ranges: &mut Vec<Range<usize>>) {
if let Some(id) = self.suffix {
let pos = token_ranges.last().map_or(0, |range| range.end);
token_ids.push(T::coerce(u64::from(id)));
token_ranges.push(pos..pos);
}
}
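    /// Tokenizes a single whitespace-free word: one greedy longest-prefix
    /// match against the starter FST, then repeated matches against the
    /// follower FST. If the word cannot be consumed completely, all partial
    /// matches are discarded and the whole word becomes `[UNK]`.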
fn tokenize_word<T: TokenID>(
&self,
text: &str,
range: Range<usize>,
token_ids: &mut Vec<T>,
token_ranges: &mut Vec<Range<usize>>,
) {
let (start, end) = (range.start, range.end);
let word_index = token_ids.len();
let mut last_index = start;
if let Some((len, id)) = find_longest_prefix(&self.starters, text[start..end].as_bytes()) {
last_index = start + len;
token_ids.push(T::coerce(id));
token_ranges.push(start..last_index);
while last_index < end {
if let Some((len, id)) =
                    find_longest_prefix(&self.followers, text[last_index..end].as_bytes())
                {
                    let next_index = last_index + len;
                    token_ids.push(T::coerce(id));
                    // record this follower's byte range, then advance `last_index` past it
                    token_ranges.push(replace(&mut last_index, next_index)..next_index);
} else {
break;
}
}
}
if last_index < end {
assert!(word_index <= token_ids.len());
token_ids.truncate(word_index);
token_ids.push(T::coerce(u64::from(self.unk_id)));
token_ranges.truncate(word_index);
token_ranges.push(range);
}
}
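    /// Tokenizes `text` into `token_ids` and the corresponding byte
    /// `token_ranges`, clearing all output vectors first. Words are split on
    /// whitespace; `[CLS]`/`[SEP]` are added if the vocabulary defines them.
    /// If `words` is `Some`, it receives for each word the range of token
    /// indices that the word produced.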
pub fn tokens_into<T: TokenID>(
&self,
text: &str,
token_ids: &mut Vec<T>,
token_ranges: &mut Vec<Range<usize>>,
words: Option<&mut Vec<Range<usize>>>,
) {
token_ids.clear();
token_ranges.clear();
let text_len = text.len();
let mut words = words;
if let Some(w) = words.as_mut() {
w.clear();
}
let mut last_offs = 0;
self.add_prefix(token_ids, token_ranges);
let mut last_token = token_ids.len();
while let Some(next_ws) = text[last_offs..].find(char::is_whitespace) {
if next_ws != 0 {
self.tokenize_word(
text,
last_offs..last_offs + next_ws,
token_ids,
token_ranges,
);
if let Some(w) = words.as_mut() {
                    // record which token indices belong to this word, then advance
                    let next_token = token_ids.len();
                    w.push(replace(&mut last_token, next_token)..next_token);
}
}
            // advance past the whitespace character, which may be more than
            // one byte (e.g. U+00A0); a fixed `+ 1` would split such chars
            let ws_len = text[last_offs + next_ws..]
                .chars()
                .next()
                .map_or(1, char::len_utf8);
            last_offs += next_ws + ws_len;
}
        if last_offs < text_len {
            self.tokenize_word(text, last_offs..text_len, token_ids, token_ranges);
            // the final word has no trailing whitespace, so record its token
            // indices here as well
            if let Some(w) = words.as_mut() {
                w.push(last_token..token_ids.len());
            }
        }
self.add_suffix(token_ids, token_ranges);
}
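    /// Returns the text of the token with the given ID.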
#[inline]
pub fn text_of<T: TokenID>(&self, token_id: T) -> &str {
&self.tokens[token_id.restore() as usize]
}
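    /// Looks up the text of each ID in `token_ids`.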
pub fn texts_of<'t, T: TokenID>(&'t self, token_ids: &[T]) -> Vec<&'t str> {
token_ids
.iter()
.cloned()
.map(|id| self.text_of(id))
.collect()
}
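    /// Returns `true` if the given ID belongs to a special token such as
    /// `[CLS]`, `[SEP]` or `[UNK]`.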
#[inline]
pub fn is_special<T: TokenID>(&self, token_id: T) -> bool {
self.special_tokens.contains(&token_id.restore())
}
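    /// Computes the attention mask value for one token: `0` for ID zero
    /// (conventionally the padding token), `1` for everything else.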
#[inline]
pub fn attention<T: TokenID, U: TokenID>(token_id: T) -> U {
if token_id == T::zero() {
U::zero()
} else {
U::coerce(1)
}
}
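    /// Fills `attns` with the attention mask values for `token_ids`, clearing it first.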
pub fn attentions_into<T: TokenID, U: TokenID>(token_ids: &[T], attns: &mut Vec<U>) {
attns.clear();
attns.extend(
token_ids
.iter()
.cloned()
.map(AlephAlphaTokenizer::attention),
);
}
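    /// Writes the vocabulary to `vocab_path`, one token per line.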
pub fn save_vocab(&self, vocab_path: PathBuf) -> Result<PathBuf, Box<dyn Error + Send + Sync>> {
let vocab = File::create(&vocab_path)?;
let mut vocab_writer = BufWriter::new(vocab);
for token in &self.tokens {
writeln!(vocab_writer, "{}", token)?;
}
Ok(vocab_path)
}
}
#[cfg(feature = "huggingface")]
use std::{borrow::Cow, path::Path};
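// Implements the huggingface `tokenizers` `Model` trait so this tokenizer can
// be used as a drop-in model in a `tokenizers` pipeline.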
#[cfg(feature = "huggingface")]
impl Model for AlephAlphaTokenizer {
fn tokenize(
&self,
tokens: Vec<(String, (usize, usize))>,
) -> Result<Vec<HfToken>, Box<dyn Error + Send + Sync>> {
let mut result = Vec::with_capacity(tokens.len());
let mut last_byte = 0;
let mut last_char = 0;
for (index, (word_str, offsets)) in tokens.into_iter().enumerate() {
let word = index as u32;
let word_index = result.len();
let word_bytes = word_str.as_bytes();
let word_len = word_bytes.len();
let mut last_index = 0;
if let Some((start_index, id)) = find_longest_prefix(&self.starters, word_bytes) {
result.push(HfToken {
id: id as u32,
value: word_str[..start_index].to_string(),
offsets: (offsets.0, offsets.0 + start_index),
word,
});
last_index = start_index;
while last_index < word_len {
if let Some((len, id)) =
find_longest_prefix(&self.followers, &word_bytes[last_index..])
{
let start = offsets.0 + last_index;
result.push(HfToken {
id: id as u32,
value: "##".to_string() + &word_str[last_index..last_index + len],
offsets: (start, start + len),
word,
});
last_index += len;
} else {
break;
}
}
}
if last_index < word_len {
assert!(word_index <= result.len());
result.truncate(word_index);
result.push(HfToken {
id: self.unk_id,
value: "[UNK±".to_string(),
offsets: (offsets.0, offsets.1),
word,
});
}
}
Ok(result)
}
fn token_to_id(&self, token: &str) -> Option<u32> {
if token.starts_with("##") {
self.followers.get(&token[2..])
} else {
self.starters.get(token)
}
.map(|x| x.value() as u32)
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.tokens.get(id as usize).cloned()
}
fn get_vocab_size(&self) -> usize {
self.tokens.len()
}
fn save(
&self,
folder: &Path,
name: Option<&str>,
) -> Result<Vec<PathBuf>, Box<dyn Error + Send + Sync>> {
let vocab_name = name.map_or(Cow::Borrowed("vocab.txt"), |n| {
Cow::Borrowed(n) + "-vocab.txt"
});
let mut vocab_path = folder.to_path_buf();
        vocab_path.push(vocab_name.as_ref());
self.save_vocab(vocab_path).map(|p| vec![p])
}
}