impl AsMut<Self> for TokenDict{
fn as_mut(&mut self)->&mut Self{self}
}
impl AsRef<Self> for TokenDict{
fn as_ref(&self)->&Self{self}
}
impl Default for TokenDict{
fn default()->Self{std::iter::empty::<&[u8]>().collect()}
}
impl DoubleEndedIterator for DictIntoIter{
fn next_back(&mut self)->Option<Self::Item>{
self.range.next_back().map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
}
fn nth_back(&mut self,n:usize)->Option<Self::Item>{
self.range.nth_back(n).map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
}
fn rfold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
let (range,tokens)=(self.range,self.tokens);
let (start,stop)=(range.start,range.end);
init=tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().cloned().rfold(init,&mut f);
SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().cloned().rfold(init,f)
}
}
impl ExactSizeIterator for DictIntoIter{
fn len(&self)->usize{self.range.len()}
}
impl Index<u32> for TokenDict{
fn index(&self,ix:u32)->&Self::Output{
let ix=ix as usize;
if ix<256{&SINGLE_BYTES[ix..ix+1]}else{&self.tokens[ix-256]}
}
type Output=[u8];
}
impl Index<usize> for TokenDict{
fn index(&self,ix:usize)->&Self::Output{
if ix<256{&SINGLE_TOKENS[ix]}else{&self.tokens[ix-256]}
}
type Output=Token;
}
impl IntoIterator for TokenDict{
fn into_iter(self)->Self::IntoIter{
DictIntoIter{range:0..self.len(),tokens:self.tokens}
}
type IntoIter=DictIntoIter;
type Item=Token;
}
impl Iterator for DictIntoIter{
fn count(self)->usize{self.range.count()}
fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
let (range,tokens)=(self.range,self.tokens);
let (start,stop)=(range.start,range.end);
init=SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().cloned().fold(init,&mut f);
tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().cloned().fold(init,f)
}
fn last(mut self)->Option<Self::Item>{self.next_back()}
fn next(&mut self)->Option<Self::Item>{
self.range.next().map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
}
fn nth(&mut self,n:usize)->Option<Self::Item>{
self.range.nth(n).map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
}
fn size_hint(&self)->(usize,Option<usize>){self.range.size_hint()}
type Item=Token;
}
#[cfg(feature="serial")]
impl Serialize for TokenDict{
fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{
let data:Vec<Vec<u8>>=self.iter().map(|t|t.to_vec()).collect();
data.serialize(serializer)
}
}
// defined here, before its use in TokenDict::pairs, since macro_rules! macros are textually scoped
macro_rules! post_inc {
($e:expr) => {{
let old = $e;
$e += 1;
old
}};
}
impl TokenDict{
/// decodes the tokens into bytes
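/// a minimal sketch (doctest ignored since the crate path is omitted; `"ab"` receives id 256 per `FromIterator`):
/// ```ignore
/// let dict:TokenDict=["ab"].into_iter().collect();
/// let bytes:Vec<u8>=dict.detokenize([256u32,b'c' as u32]).collect();
/// assert_eq!(bytes,b"abc");
/// ```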
pub fn detokenize<I:IntoIterator>(&self,tokens:I)->Detokenization<I::IntoIter> where I::Item:Val<u32>{
Detokenization{inner:tokens.into_iter().fuse(),maxtokenlen:self.maxtokenlen,position:1,tokenid:0,tokens:self.tokens.clone()}
}
/// creates an iterator that maps token ids to their Tokens
pub fn detoken_iter<I:IntoIterator>(&self,tokens:I)->impl Iterator<Item=Token> where I::Item:Val<u32>{
let tokenizer=self.clone();
tokens.into_iter().map(move|id|tokenizer[id.val() as usize].clone())
}
/// decodes the tokens into chars, replacing invalid unicode with the replacement character
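/// a minimal sketch (doctest ignored since the crate path is omitted):
/// ```ignore
/// let dict=TokenDict::default();
/// let s:String=dict.detokenize_str([0xFFu32]).collect();
/// assert!(!s.is_empty()&&s.chars().all(|c|c==char::REPLACEMENT_CHARACTER)); // a lone 0xFF byte is invalid UTF-8
/// ```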
pub fn detokenize_str<I:IntoIterator>(&self,tokens:I)->impl Iterator<Item=char> where I::Item:Val<u32>{
UTF8CharIter::from(self.detokenize(tokens)).map(|r|if let Ok(c)=r{c}else{char::REPLACEMENT_CHARACTER})
}
/// decodes the tokens into a String, replacing invalid unicode with the replacement character
pub fn detokenize_string<I:IntoIterator>(&self,tokens:I)->String where I::Item:Val<u32>{self.detokenize_str(tokens).collect()}
/// accumulates the frequency of each token id produced by tokenizing the data, optionally continuing from existing counts
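/// a minimal sketch (doctest ignored since the crate path is omitted; ids 0..=255 are the single bytes, `"ab"` gets 256):
/// ```ignore
/// let dict:TokenDict=["ab"].into_iter().collect();
/// let freq=dict.frequencies(*b"abc",None);
/// assert_eq!((freq[256],freq[b'c' as usize]),(1,1));
/// ```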
pub fn frequencies<I:IntoIterator,O:Into<Option<Vec<usize>>>>(&self,data:I,freq:O)->Vec<usize> where I::Item:Val<u8>{
let mut freq=freq.into().unwrap_or_default();
if freq.len()<self.len(){freq.resize(self.len(),0)}
for t in self.tokenize(data){freq[t as usize]+=1}
freq
}
/// gets the id for the token if it is in the dictionary; only multi-byte tokens are stored, so single bytes return None
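/// a minimal sketch (doctest ignored since the crate path is omitted):
/// ```ignore
/// let dict:TokenDict=["ab"].into_iter().collect();
/// assert_eq!(dict.get_id(b"ab"),Some(256));
/// assert_eq!(dict.get_id(b"zz"),None);
/// ```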
pub fn get_id(&self,token:&[u8])->Option<u32>{self.ids[*token.get(0)? as usize].get(token.iter().copied().skip(1)).copied()}
/// returns an iterator over the possible tokens generated by this tokenizer
pub fn iter(&self)->DictIter<'_>{
DictIter{range:0..self.len(),tokens:&self.tokens}
}
/// returns the number of possible token ids generated by this tokenizer
pub fn len(&self)->usize{self.tokens.len()+256}
/// counts adjacent token pairs in the data and returns the merged pair tokens mapped to their frequencies, assigning new ids as if they were appended to this dictionary; pairs already in the dictionary keep their existing ids
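/// a minimal sketch of one counting pass (doctest ignored; crate path omitted; the byte-slice lookup assumes the same Borrow used by this method internally):
/// ```ignore
/// let dict=TokenDict::default();
/// let pairs=dict.pairs(*b"abab",None);
/// assert_eq!(pairs.get(b"ab".as_slice()).copied(),Some(2)); // "ab" occurs twice, "ba" once
/// ```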
pub fn pairs<I:IntoIterator,O:Into<Option<HashMap<Token,usize>>>>(&self,data:I,freq:O)->HashMap<Token,usize> where I::Item:Val<u8>{
let mut freq=freq.into().unwrap_or_default();
let mut nexttokenid=freq.keys().map(|t|t.id()+1).chain([self.len() as u32]).max().unwrap();
let mut previous:Option<u32>=None;
let mut temp:Vec<u8>=Vec::new();
self.tokenize(data).for_each(|id|{
if let Some(previous)=previous{
temp.clear();
temp.extend(self[previous as usize].clone());
temp.extend(self[id as usize].clone());
if let Some(f)=freq.get_mut(temp.as_slice()){
*f+=1
}else{
let newid=if let Some(id)=self.get_id(&temp){id}else{post_inc!(nexttokenid)};
let token=Token::new(newid,Some(Arc::from(temp.as_slice())));
freq.insert(token,1);
}
}
previous=Some(id);
});
freq
}
/// adds the token to the dictionary; tokens shorter than two bytes are ignored since single bytes are implicit
pub fn push<A:AsRef<[u8]>>(&mut self,token:A){self.extend(Some(token))}
/// converts the string to a token vec
pub fn string_to_tokens<S:?Sized+AsRef<str>>(&self,input:&S)->Vec<u32>{self.tokenize(input.as_ref().as_bytes()).collect()}
/// tokenizes the bytes, yielding Tokens instead of ids
pub fn token_iter<I:IntoIterator>(&self,bytes:I)->impl Iterator<Item=Token> where I::Item:Val<u8>{
let tokenizer=self.clone();
self.tokenize(bytes).map(move|id|tokenizer[id as usize].clone())
}
/// converts the bytes to tokens
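/// greedy longest-match from each starting byte (a minimal sketch; doctest ignored since the crate path is omitted):
/// ```ignore
/// let dict:TokenDict=["hello"].into_iter().collect();
/// let ids:Vec<u32>=dict.tokenize_str("hello!").collect();
/// assert_eq!(ids,[256,b'!' as u32]); // "hello" is id 256, '!' stays a single byte
/// ```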
pub fn tokenize<I:IntoIterator>(&self,bytes:I)->Tokenization<I::IntoIter> where I::Item:Val<u8>{
Tokenization{ids:self.ids.clone(),inner:bytes.into_iter().fuse(),state:VecDeque::with_capacity(self.maxtokenlen)}
}
/// converts the string to tokens, borrowing its bytes
pub fn tokenize_str<'a,S:?Sized+AsRef<str>>(&self,input:&'a S)->Tokenization<SliceIter<'a,u8>>{self.tokenize(input.as_ref().as_bytes())}
/// converts an owned string to tokens
pub fn tokenize_string(&self,input:String)->Tokenization<VecIntoIter<u8>>{self.tokenize(Vec::from(input))}
/// converts the tokens to a string, lossily decoding UTF-8
pub fn tokens_to_string<V:?Sized+AsRef<[u32]>>(&self,input:&V)->String{String::from_utf8_lossy(&self.detokenize(input.as_ref()).collect::<Vec<u8>>()).to_string()}
}
#[cfg(feature="serial")]
impl<'a> Deserialize<'a> for TokenDict{
fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{
let data:Vec<Vec<u8>>=Deserialize::deserialize(deserializer)?;
Ok(data.into_iter().collect())
}
}
impl<'a> DoubleEndedIterator for DictIter<'a>{
fn next_back(&mut self)->Option<Self::Item>{
self.range.next_back().map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
}
fn nth_back(&mut self,n:usize)->Option<Self::Item>{
self.range.nth_back(n).map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
}
fn rfold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
let (range,tokens)=(self.range,self.tokens);
let (start,stop)=(range.start,range.end);
init=tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().rfold(init,&mut f);
SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().rfold(init,f)
}
}
impl<'a> ExactSizeIterator for DictIter<'a>{
fn len(&self)->usize{self.range.len()}
}
impl<'a> Iterator for DictIter<'a>{
fn count(self)->usize{self.range.count()}
fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
let (range,tokens)=(self.range,self.tokens);
let (start,stop)=(range.start,range.end);
init=SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().fold(init,&mut f);
tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().fold(init,f)
}
fn last(mut self)->Option<Self::Item>{self.next_back()}
fn next(&mut self)->Option<Self::Item>{
self.range.next().map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
}
fn nth(&mut self,n:usize)->Option<Self::Item>{
self.range.nth(n).map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
}
fn size_hint(&self)->(usize,Option<usize>){self.range.size_hint()}
type Item=&'a Token;
}
impl<A:AsRef<[u8]>> Extend<A> for TokenDict{
fn extend<I:IntoIterator<Item=A>>(&mut self,iter:I){
let (ids,tokens)=(Arc::make_mut(&mut self.ids),Arc::make_mut(&mut self.tokens));
let maxtokenlen=&mut self.maxtokenlen;
iter.into_iter().filter(|a|a.as_ref().len()>1).for_each(|a|{
let id=u32::try_from(tokens.len()+256).unwrap();
let token:Arc<[u8]>=Arc::from(a.as_ref());
ids[token[0] as usize].insert(token.iter().copied().skip(1),id);
*maxtokenlen=(*maxtokenlen).max(token.len());
tokens.push(Token::new(id,Some(token)))
});
}
}
impl<A:AsRef<[u8]>> FromIterator<A> for TokenDict{
fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{
let mut maxtokenlen=1;
let mut ids:[Trie<_,_>;256]=std::array::from_fn(|_|Trie::new());
let tokens:Vec<Token>=iter.into_iter().filter(|t|t.as_ref().len()>1).enumerate().map(|(n,t)|{
let id=u32::try_from(n+256).unwrap();
let token:Arc<[u8]>=Arc::from(t.as_ref());
ids[token[0] as usize].insert(token.iter().copied().skip(1),id);
maxtokenlen=maxtokenlen.max(token.len());
Token::new(id,Some(token))
}).collect();
let (ids,tokens)=(Arc::new(ids),Arc::new(tokens));
Self{ids,maxtokenlen,tokens}
}
}
impl<I:Iterator> Iterator for Detokenization<I> where I::Item:Val<u32>{
fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
// first emit the bytes remaining from a partially consumed token, then drain the inner iterator
if self.tokenid>=256{init=self.tokens[self.tokenid as usize-256][self.position..].iter().fold(init,|acc,&b|f(acc,b))}
let tokens=self.tokens;
self.inner.map(Val::val).fold(init,move|acc,tokenid|if tokenid<256{f(acc,tokenid as u8)}else{tokens[tokenid as usize-256].iter().fold(acc,|acc,&b|f(acc,b))})
}
fn next(&mut self)->Option<u8>{
let (inner,position)=(&mut self.inner,&mut self.position);
let tokenid=&mut self.tokenid;
let tokens=&self.tokens;
// continue emitting the current token if bytes remain; otherwise fetch the next token id below
if let Some(b)=if *tokenid<256{(*position==0).then_some(*tokenid as u8)}else{tokens[*tokenid as usize-256].get(*position).map(|&b|b)}{
*position+=1;
b
}else{
*position=1;
*tokenid=inner.map(Val::val).next()?;
if *tokenid<256{*tokenid as u8}else{tokens[*tokenid as usize-256][0]}
}.into()
}
fn size_hint(&self)->(usize,Option<usize>){
let (lowertokens,uppertokens)=self.inner.size_hint();
// account for the bytes still pending from a partially consumed token
let pending=if self.tokenid>=256{self.tokens[self.tokenid as usize-256].len()-self.position}else{0};
(lowertokens+pending,uppertokens.map(|h|h*self.maxtokenlen+pending))
}
type Item=u8;
}
impl<I:Iterator> Iterator for Tokenization<I> where I::Item:Val<u8>{
fn next(&mut self)->Option<u32>{
let (inner,state)=(&mut self.inner,&mut self.state);
let ids=&self.ids;
// refill the lookahead window up to the longest known token length
state.extend(inner.map(Val::val).take(state.capacity()-state.len()));
if state.is_empty(){return None}
// greedily match the longest known token starting at the front, falling back to the single byte
let (tokenlen,&tokenid)=if let Some(t)=ids[state[0] as usize].find_longest_prefix_len(state.iter().copied().skip(1)).filter(|(tokenlen,_tokenid)|*tokenlen>0){t}else{return Some(state.pop_front().unwrap() as u32)};
state.drain(..tokenlen+1);
Some(tokenid)
}
fn size_hint(&self)->(usize,Option<usize>){
let (lowerbytes,upperbytes)=self.inner.size_hint();
let maxtoken=self.state.capacity();
let statelen=self.state.len();
((lowerbytes+statelen).div_ceil(maxtoken),upperbytes.map(|b|b+statelen))
}
type Item=u32;
}
#[cfg(test)]
mod tests{
#[test]
fn tokenizer_iter(){
let tokenizer:TokenDict=["aa","bb","cc"].into_iter().collect();
let t2:TokenDict=tokenizer.iter().collect();
assert_eq!(tokenizer.tokenize_str("ccaabb").collect::<Vec<_>>(),t2.tokenize_str("ccaabb").collect::<Vec<_>>());
}
#[test]
fn bytes_only(){
let teststring="oishsoghohhduihahdufghud";
let tokenizer=TokenDict::default();
let tokens:Vec<u32>=tokenizer.tokenize_str(teststring).collect();
let detokens:Vec<u8>=tokenizer.detokenize(tokens).collect();
assert_eq!(detokens.as_slice(),teststring.as_bytes());
}
#[test]
fn there_are_tokens_yay(){
let teststring="there are tokens! yay";
let tokenizer:TokenDict=["there","are","tokens","yay"].into_iter().collect();
let tokens:Vec<u32>=tokenizer.tokenize(teststring.bytes()).collect();
let detokens:Vec<u8>=tokenizer.detokenize(&tokens).collect();
assert_eq!(tokens.len(),8);
assert_eq!(detokens.as_slice(),teststring.as_bytes());
}
#[test]
fn detokenize_empty(){
let dict=TokenDict::default();
let inp:Vec<u32>=vec![];
let out:Vec<u8>=dict.detokenize(inp).collect();
assert!(out.is_empty(),"detokenizing an empty input should yield no bytes");
}
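// additional round-trip sketch: assumes "ab" and "cd" receive ids 256 and 257 per FromIterator
#[test]
fn string_round_trip(){
let dict:TokenDict=["ab","cd"].into_iter().collect();
let ids=dict.string_to_tokens("abcd!");
assert_eq!(ids,[256,257,b'!' as u32]);
assert_eq!(dict.tokens_to_string(&ids),"abcd!");
}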
use super::*;
}
pub(crate) const SINGLE_TOKENS:&[Token;256]=&[Token::single(0),Token::single(1),Token::single(2),Token::single(3),Token::single(4),Token::single(5),Token::single(6),Token::single(7),Token::single(8),Token::single(9),Token::single(10),Token::single(11),Token::single(12),Token::single(13),Token::single(14),Token::single(15),Token::single(16),Token::single(17),Token::single(18),Token::single(19),Token::single(20),Token::single(21),Token::single(22),Token::single(23),Token::single(24),Token::single(25),Token::single(26),Token::single(27),Token::single(28),Token::single(29),Token::single(30),Token::single(31),Token::single(32),Token::single(33),Token::single(34),Token::single(35),Token::single(36),Token::single(37),Token::single(38),Token::single(39),Token::single(40),Token::single(41),Token::single(42),Token::single(43),Token::single(44),Token::single(45),Token::single(46),Token::single(47),Token::single(48),Token::single(49),Token::single(50),Token::single(51),Token::single(52),Token::single(53),Token::single(54),Token::single(55),Token::single(56),Token::single(57),Token::single(58),Token::single(59),Token::single(60),Token::single(61),Token::single(62),Token::single(63),Token::single(64),Token::single(65),Token::single(66),Token::single(67),Token::single(68),Token::single(69),Token::single(70),Token::single(71),Token::single(72),Token::single(73),Token::single(74),Token::single(75),Token::single(76),Token::single(77),Token::single(78),Token::single(79),Token::single(80),Token::single(81),Token::single(82),Token::single(83),Token::single(84),Token::single(85),Token::single(86),Token::single(87),Token::single(88),Token::single(89),Token::single(90),Token::single(91),Token::single(92),Token::single(93),Token::single(94),Token::single(95),Token::single(96),Token::single(97),Token::single(98),Token::single(99),Token::single(100),Token::single(101),Token::single(102),Token::single(103),Token::single(104),Token::single(105),Token::single(106),Token::single(107),Token::single(108),Token::single(109),Token::single(110),Token::single(111),Token::single(112),Token::single(113),Token::single(114),Token::single(115),Token::single(116),Token::single(117),Token::single(118),Token::single(119),Token::single(120),Token::single(121),Token::single(122),Token::single(123),Token::single(124),Token::single(125),Token::single(126),Token::single(127),Token::single(128),Token::single(129),Token::single(130),Token::single(131),Token::single(132),Token::single(133),Token::single(134),Token::single(135),Token::single(136),Token::single(137),Token::single(138),Token::single(139),Token::single(140),Token::single(141),Token::single(142),Token::single(143),Token::single(144),Token::single(145),Token::single(146),Token::single(147),Token::single(148),Token::single(149),Token::single(150),Token::single(151),Token::single(152),Token::single(153),Token::single(154),Token::single(155),Token::single(156),Token::single(157),Token::single(158),Token::single(159),Token::single(160),Token::single(161),Token::single(162),Token::single(163),Token::single(164),Token::single(165),Token::single(166),Token::single(167),Token::single(168),Token::single(169),Token::single(170),Token::single(171),Token::single(172),Token::single(173),Token::single(174),Token::single(175),Token::single(176),Token::single(177),Token::single(178),Token::single(179),Token::single(180),Token::single(181),Token::single(182),Token::single(183),Token::single(184),Token::single(185),Token::single(186),Token::single(187),Token::single(188),Token::single(189),Token::single(190),Token::single(191),Token::single(192),Token::single(193),Token::single(194),Token::single(195),Token::single(196),Token::single(197),Token::single(198),Token::single(199),Token::single(200),Token::single(201),Token::single(202),Token::single(203),Token::single(204),Token::single(205),Token::single(206),Token::single(207),Token::single(208),Token::single(209),Token::single(210),Token::single(211),Token::single(212),Token::single(213),Token::single(214),Token::single(215),Token::single(216),Token::single(217),Token::single(218),Token::single(219),Token::single(220),Token::single(221),Token::single(222),Token::single(223),Token::single(224),Token::single(225),Token::single(226),Token::single(227),Token::single(228),Token::single(229),Token::single(230),Token::single(231),Token::single(232),Token::single(233),Token::single(234),Token::single(235),Token::single(236),Token::single(237),Token::single(238),Token::single(239),Token::single(240),Token::single(241),Token::single(242),Token::single(243),Token::single(244),Token::single(245),Token::single(246),Token::single(247),Token::single(248),Token::single(249),Token::single(250),Token::single(251),Token::single(252),Token::single(253),Token::single(254),Token::single(255)];
#[derive(Clone,Debug)]
/// a simple dictionary-based detokenizer
pub struct Detokenization<I:Iterator> where I::Item:Val<u32>{inner:Fuse<I>,maxtokenlen:usize,position:usize,tokenid:u32,tokens:Arc<Vec<Token>>}
#[derive(Clone,Debug)]
/// borrowing iterator over the tokens of a TokenDict
pub struct DictIter<'a>{range:Range<usize>,tokens:&'a [Token]}
#[derive(Clone,Debug)]
/// owning iterator over the tokens of a TokenDict
pub struct DictIntoIter{range:Range<usize>,tokens:Arc<Vec<Token>>}
#[derive(Clone,Debug)]
/// a simple dictionary-based tokenizer where the 256 single bytes implicitly form the lowest ids; it is reference counted and cheap to clone
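/// a minimal sketch (doctest ignored since the crate path is omitted):
/// ```ignore
/// let dict:TokenDict=["ab","cd"].into_iter().collect();
/// assert_eq!(dict.len(),258); // 256 implicit single bytes + 2 added tokens
/// let _cheap=dict.clone(); // shares the same Arcs
/// ```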
pub struct TokenDict{ids:Arc<[Trie<u8,u32>;256]>,maxtokenlen:usize,tokens:Arc<Vec<Token>>}
#[derive(Clone,Debug)]
/// a simple dictionary-based tokenizing iterator
pub struct Tokenization<I:Iterator> where I::Item:Val<u8>{ids:Arc<[Trie<u8,u32>;256]>,inner:Fuse<I>,state:VecDeque<u8>}
use crate::{Token,UTF8CharIter,Val,token::SINGLE_BYTES};
use ptrie::Trie;
#[cfg(feature="serial")]
use serde::{Deserialize,Deserializer,Serialize,Serializer};
use std::{
collections::{HashMap,VecDeque},iter::{Extend,Fuse},ops::{Index,Range},slice::Iter as SliceIter,sync::Arc,vec::IntoIter as VecIntoIter,
};