Skip to main content

token_dict/
dict.rs

1impl AsMut<Self> for TokenDict{
2	fn as_mut(&mut self)->&mut Self{self}
3}
4impl AsRef<Self> for TokenDict{
5	fn as_ref(&self)->&Self{self}
6}
7impl Default for TokenDict{
8	fn default()->Self{
9		let x:[Vec<u8>;0]=[];
10		x.into_iter().collect()
11	}
12}
13impl DoubleEndedIterator for DictIntoIter{
14	fn next_back(&mut self)->Option<Self::Item>{
15		self.range.next_back().map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
16	}
17	fn nth_back(&mut self,n:usize)->Option<Self::Item>{
18		self.range.nth_back(n).map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
19	}
20	fn rfold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
21		let (range,tokens)=(self.range,self.tokens);
22		let (start,stop)=(range.start,range.end);
23
24		init=tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().cloned().rfold(init,&mut f);
25		SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().cloned().rfold(init,f)
26	}
27}
28impl ExactSizeIterator for DictIntoIter{
29	fn len(&self)->usize{self.range.len()}
30}
31impl Index<u32> for TokenDict{
32	fn index(&self,ix:u32)->&Self::Output{
33		let ix=ix as usize;
34		if ix<256{&SINGLE_BYTES[ix..ix+1]}else{&self.tokens[ix-256]}
35	}
36	type Output=[u8];
37}
38impl Index<usize> for TokenDict{
39	fn index(&self,ix:usize)->&Self::Output{
40		if ix<256{&SINGLE_TOKENS[ix]}else{&self.tokens[ix-256]}
41	}
42	type Output=Token;
43}
44impl IntoIterator for TokenDict{
45	fn into_iter(self)->Self::IntoIter{
46		DictIntoIter{range:0..self.len(),tokens:self.tokens}
47	}
48	type IntoIter=DictIntoIter;
49	type Item=Token;
50}
51impl Iterator for DictIntoIter{
52	fn count(self)->usize{self.range.count()}
53	fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
54		let (range,tokens)=(self.range,self.tokens);
55		let (start,stop)=(range.start,range.end);
56
57		init=SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().cloned().fold(init,&mut f);
58		tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().cloned().fold(init,f)
59	}
60	fn last(mut self)->Option<Self::Item>{self.next_back()}
61	fn next(&mut self)->Option<Self::Item>{
62		self.range.next().map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
63	}
64	fn nth(&mut self,n:usize)->Option<Self::Item>{
65		self.range.nth(n).map(|n|if n<256{SINGLE_TOKENS[n].clone()}else{self.tokens[n-256].clone()})
66	}
67	fn size_hint(&self)->(usize,Option<usize>){self.range.size_hint()}
68	type Item=Token;
69}
70#[cfg(feature="serial")]
71impl Serialize for TokenDict{
72	fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{
73		let data:Vec<Vec<u8>>=self.iter().map(|t|t.to_vec()).collect();
74		data.serialize(serializer)
75	}
76}
77impl TokenDict{
78	/// decodes the tokens into bytes
79	pub fn detokenize<I:IntoIterator>(&self,tokens:I)->Detokenization<I::IntoIter> where I::Item:Val<u32>{
80		Detokenization{inner:tokens.into_iter().fuse(),maxtokenlen:self.maxtokenlen,position:1,tokenid:0,tokens:self.tokens.clone()}
81	}
82	/// creates an iterator over tokens
83	pub fn detoken_iter<I:IntoIterator>(&self,tokens:I)->impl Iterator<Item=Token> where I::Item:Val<u32>{
84		let tokenizer=self.clone();
85		tokens.into_iter().map(move|id|tokenizer[id.val() as usize].clone())
86	}
87	/// decodes the tokens into chars, replacing invalid unicode with replacement character
88	pub fn detokenize_str<I:IntoIterator>(&self,tokens:I)->impl Iterator<Item=char> where I::Item:Val<u32>{
89		UTF8CharIter::from(self.detokenize(tokens)).map(|r|if let Ok(c)=r{c}else{char::REPLACEMENT_CHARACTER})
90	}
91	/// decodes the tokens into chars, replacing invalid unicode with replacement character
92	pub fn detokenize_string<I:IntoIterator>(&self,tokens:I)->String where I::Item:Val<u32>{self.detokenize_str(tokens).collect()}
93	/// accumulates frequencies of each token in the text as if they were tokenized by this tokenizer
94	pub fn frequencies<I:IntoIterator,O:Into<Option<Vec<usize>>>>(&self,data:I,freq:O)->Vec<usize> where I::Item:Val<u8>{
95		let mut freq=freq.into().unwrap_or_default();
96		if freq.len()<self.len(){freq.resize(self.len(),0)}
97
98		for t in self.tokenize(data){freq[t as usize]+=1}
99		freq
100	}
101	/// gets an id for the token if it is in the dictionary
102	pub fn get_id(&self,token:&[u8])->Option<u32>{self.ids[*token.get(0)? as usize].get(token.iter().copied().skip(1)).copied()}
103	/// returns an interator over the possible tokens generated by this tokenizer
104	pub fn iter(&self)->DictIter<'_>{
105		DictIter{range:0..self.len(),tokens:&self.tokens}
106	}
107	/// returns the number of possible token ids generated by this tokenizer
108	pub fn len(&self)->usize{self.tokens.len()+256}
109	/// finds token pairs and returns new tokens of them mapped to their frequencies with ids as if they were added to this dictionary. Tokens with ids within the current dictionary will have those ids
110	pub fn pairs<I:IntoIterator,O:Into<Option<HashMap<Token,usize>>>>(&self,data:I,freq:O)->HashMap<Token,usize> where I::Item:Val<u8>{
111		let mut freq=freq.into().unwrap_or_default();
112		let mut nexttokenid=freq.keys().map(|t|t.id()+1).chain([self.len() as u32]).max().unwrap();
113		let mut previous:Option<u32>=None;
114		let mut temp:Vec<u8>=Vec::new();
115
116		self.tokenize(data).for_each(|id|{
117			if let Some(previous)=previous{
118				temp.clear();
119				temp.extend(self[previous as usize].clone());
120				temp.extend(self[id as usize].clone());
121
122				if let Some(f)=freq.get_mut(temp.as_slice()){
123					*f+=1
124				}else{
125					let newid=if let Some(id)=self.get_id(&temp){id}else{post_inc!(nexttokenid)};
126					let token=Token::new(newid,Some(Arc::from(temp.as_slice())));
127
128					freq.insert(token,1);
129				}
130			}
131			previous=Some(id);
132		});
133		freq
134	}
135	/// adds the token to the dictionary
136	pub fn push<A:AsRef<[u8]>>(&mut self,token:A){self.extend(Some(token))}
137	/// converts the string to a token vec
138	pub fn string_to_tokens<S:?Sized+AsRef<str>>(&self,input:&S)->Vec<u32>{self.tokenize(input.as_ref().as_bytes()).collect()}
139	/// creates an iterator over tokens
140	pub fn token_iter<I:IntoIterator>(&self,bytes:I)->impl Iterator<Item=Token> where I::Item:Val<u8>{
141		let tokenizer=self.clone();
142		self.tokenize(bytes).map(move|id|tokenizer[id as usize].clone())
143	}
144	/// converts the bytes to tokens
145	pub fn tokenize<I:IntoIterator>(&self,bytes:I)->Tokenization<I::IntoIter> where I::Item:Val<u8>{
146		Tokenization{ids:self.ids.clone(),inner:bytes.into_iter().fuse(),state:VecDeque::with_capacity(self.maxtokenlen)}
147	}
148	/// converts the string to tokens
149	pub fn tokenize_str<'a,S:?Sized+AsRef<str>>(&self,input:&'a S)->Tokenization<SliceIter<'a,u8>>{self.tokenize(input.as_ref().as_bytes())}
150	/// converts the string to tokens
151	pub fn tokenize_string(&self,input:String)->Tokenization<VecIntoIter<u8>>{self.tokenize(Vec::from(input))}
152	/// converts the token vec to string
153	pub fn tokens_to_string<V:?Sized+AsRef<[u32]>>(&self,input:&V)->String{String::from_utf8_lossy(&self.detokenize(input.as_ref()).collect::<Vec<u8>>()).to_string()}
154}
155#[cfg(feature="serial")]
156impl<'a> Deserialize<'a> for TokenDict{
157	fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{
158		let data:Vec<Vec<u8>>=Deserialize::deserialize(deserializer)?;
159		Ok(data.into_iter().collect())
160	}
161}
162impl<'a> DoubleEndedIterator for DictIter<'a>{
163	fn next_back(&mut self)->Option<Self::Item>{
164		self.range.next_back().map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
165	}
166	fn nth_back(&mut self,n:usize)->Option<Self::Item>{
167		self.range.nth_back(n).map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
168	}
169	fn rfold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
170		let (range,tokens)=(self.range,self.tokens);
171		let (start,stop)=(range.start,range.end);
172
173		init=tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().rfold(init,&mut f);
174		SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().rfold(init,f)
175	}
176}
177impl<'a> ExactSizeIterator for DictIter<'a>{
178	fn len(&self)->usize{self.range.len()}
179}
180impl<'a> Iterator for DictIter<'a>{
181	fn count(self)->usize{self.range.count()}
182	fn fold<B,F:FnMut(B,Self::Item)->B>(self,mut init:B,mut f:F)->B{
183		let (range,tokens)=(self.range,self.tokens);
184		let (start,stop)=(range.start,range.end);
185
186		init=SINGLE_TOKENS[start.min(256)..stop.min(256)].iter().fold(init,&mut f);
187		tokens[start.saturating_sub(256)..stop.saturating_sub(256)].iter().fold(init,f)
188	}
189	fn last(mut self)->Option<Self::Item>{self.next_back()}
190	fn next(&mut self)->Option<Self::Item>{
191		self.range.next().map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
192	}
193	fn nth(&mut self,n:usize)->Option<Self::Item>{
194		self.range.nth(n).map(|n|if n<256{&SINGLE_TOKENS[n]}else{&self.tokens[n-256]})
195	}
196	fn size_hint(&self)->(usize,Option<usize>){self.range.size_hint()}
197	type Item=&'a Token;
198}
199impl<A:AsRef<[u8]>> Extend<A> for TokenDict{
200	fn extend<I:IntoIterator<Item=A>>(&mut self,iter:I){
201		let (ids,tokens)=(Arc::make_mut(&mut self.ids),Arc::make_mut(&mut self.tokens));
202		let maxtokenlen=&mut self.maxtokenlen;
203
204		iter.into_iter().filter(|a|a.as_ref().len()>1).for_each(|a|{
205			let id=u32::try_from(tokens.len()+256).unwrap();
206			let token:Arc<[u8]>=Arc::from(a.as_ref());
207
208			ids[token[0] as usize].insert(token.iter().copied().skip(1),id);
209			*maxtokenlen=(*maxtokenlen).max(token.len());
210			tokens.push(Token::new(id,Some(token)))
211		});
212	}
213}
214impl<A:AsRef<[u8]>> FromIterator<A> for TokenDict{
215	fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{
216		let mut maxtokenlen=1;
217		let mut ids:[Trie<_,_>;256]=std::array::from_fn(|_|Trie::new());
218		let tokens:Vec<Token>=iter.into_iter().filter(|t|t.as_ref().len()>1).enumerate().map(|(n,t)|{
219			let id=u32::try_from(n+256).unwrap();
220			let token:Arc<[u8]>=Arc::from(t.as_ref());
221
222			ids[token[0] as usize].insert(token.iter().copied().skip(1),id);
223			maxtokenlen=maxtokenlen.max(token.len());
224
225			Token::new(id,Some(token))
226		}).collect();
227
228		let (ids,tokens)=(Arc::new(ids),Arc::new(tokens));
229		Self{ids,maxtokenlen,tokens}
230	}
231}
232impl<I:Iterator> Iterator for Detokenization<I> where I::Item:Val<u32>{
233	fn fold<B,F:FnMut(B,Self::Item)->B>(self,init:B,mut f:F)->B{
234		self.inner.map(Val::val).fold(init,|acc,tokenid|if tokenid<256{f(acc,tokenid as u8)}else{self.tokens[tokenid as usize].iter().fold(acc,|acc,&b|f(acc,b))})
235	}
236	fn next(&mut self)->Option<u8>{
237		let (inner,position)=(&mut self.inner,&mut self.position);
238		let tokenid=&mut self.tokenid;
239		let tokens=&self.tokens;
240
241		if let Some(b)=if *tokenid<256{(*position==0).then_some(*tokenid as u8)}else{tokens[*tokenid as usize-256].get(*position).map(|&b|b)}{
242			*position+=1;
243			b
244		}else{
245			*position=1;
246			*tokenid=inner.map(Val::val).next()?;
247			if *tokenid<256{*tokenid as u8}else{tokens[*tokenid as usize-256][0]}
248		}.into()
249	}
250	fn size_hint(&self)->(usize,Option<usize>){
251		let (lowertokens,uppertokens)=self.inner.size_hint();
252		let maxtoken=self.maxtokenlen;
253
254		(lowertokens,uppertokens.map(|h|h*maxtoken))
255	}
256	type Item=u8;
257}
258impl<I:Iterator> Iterator for Tokenization<I> where I::Item:Val<u8>{
259	fn next(&mut self)->Option<u32>{
260		let (inner,state)=(&mut self.inner,&mut self.state);
261		let ids=&self.ids;
262		state.extend(inner.map(Val::val).take(state.capacity()-state.len()));
263		if state.len()==0{return None}
264
265		let (tokenlen,&tokenid)=if let Some(t)=ids[state[0] as usize].find_longest_prefix_len(state.iter().copied().skip(1)).filter(|(tokenlen,_tokenid)|*tokenlen>0){t}else{return Some(state.pop_front().unwrap() as u32)};
266		state.drain(..tokenlen+1);
267		return Some(tokenid)
268	}
269	fn size_hint(&self)->(usize,Option<usize>){
270		let (lowerbytes,upperbytes)=self.inner.size_hint();
271		let maxtoken=self.state.capacity();
272		let statelen=self.state.len();
273
274		((lowerbytes+statelen).div_ceil(maxtoken),upperbytes.map(|b|b+statelen))
275	}
276	type Item=u32;
277}
278
279macro_rules! post_inc {
280	($e:expr) => {{
281		let old = $e;
282		$e += 1;
283		old
284	}};
285}
286
287#[cfg(test)]
288mod tests{
289	#[test]
290	fn tokenizer_iter(){
291		let tokenizer:TokenDict=["aa","bb","cc"].into_iter().collect();
292		let t2:TokenDict=tokenizer.iter().collect();
293		assert_eq!(tokenizer.tokenize_str("ccaabb").collect::<Vec<_>>(),t2.tokenize_str("ccaabb").collect::<Vec<_>>());
294	}
295	#[test]
296	fn bytes_only(){
297		let teststring="oishsoghohhduihahdufghud";
298		let tokenizer=TokenDict::default();
299		let tokens:Vec<u32>=tokenizer.tokenize_str(teststring).collect();
300		let detokens:Vec<u8>=tokenizer.detokenize(tokens).collect();
301
302		assert_eq!(detokens.as_slice(),teststring.as_bytes());
303	}
304	#[test]
305	fn there_are_tokens_yay(){
306		let teststring="there are tokens! yay";
307		let tokenizer:TokenDict=["there","are","tokens","yay"].into_iter().collect();
308		let tokens:Vec<u32>=tokenizer.tokenize(teststring.bytes()).collect();
309		let detokens:Vec<u8>=tokenizer.detokenize(&tokens).collect();
310
311		assert_eq!(tokens.len(),8);
312		assert_eq!(detokens.as_slice(),teststring.as_bytes());
313	}
314	#[test]
315	fn test_default_token_dict_detokenize_empty() {
316		let dict = TokenDict::default();
317		let inp:Vec<u32>=vec![];
318		let out: Vec<u8> = dict.detokenize(inp).collect();
319		assert!(out.is_empty(), "Detokenizing an empty input should yield no bytes");
320	}
321	use super::*;
322}
323
324pub (crate) const SINGLE_TOKENS:&[Token;256]=&[Token::single(0),Token::single(1),Token::single(2),Token::single(3),Token::single(4),Token::single(5),Token::single(6),Token::single(7),Token::single(8),Token::single(9),Token::single(10),Token::single(11),Token::single(12),Token::single(13),Token::single(14),Token::single(15),Token::single(16),Token::single(17),Token::single(18),Token::single(19),Token::single(20),Token::single(21),Token::single(22),Token::single(23),Token::single(24),Token::single(25),Token::single(26),Token::single(27),Token::single(28),Token::single(29),Token::single(30),Token::single(31),Token::single(32),Token::single(33),Token::single(34),Token::single(35),Token::single(36),Token::single(37),Token::single(38),Token::single(39),Token::single(40),Token::single(41),Token::single(42),Token::single(43),Token::single(44),Token::single(45),Token::single(46),Token::single(47),Token::single(48),Token::single(49),Token::single(50),Token::single(51),Token::single(52),Token::single(53),Token::single(54),Token::single(55),Token::single(56),Token::single(57),Token::single(58),Token::single(59),Token::single(60),Token::single(61),Token::single(62),Token::single(63),Token::single(64),Token::single(65),Token::single(66),Token::single(67),Token::single(68),Token::single(69),Token::single(70),Token::single(71),Token::single(72),Token::single(73),Token::single(74),Token::single(75),Token::single(76),Token::single(77),Token::single(78),Token::single(79),Token::single(80),Token::single(81),Token::single(82),Token::single(83),Token::single(84),Token::single(85),Token::single(86),Token::single(87),Token::single(88),Token::single(89),Token::single(90),Token::single(91),Token::single(92),Token::single(93),Token::single(94),Token::single(95),Token::single(96),Token::single(97),Token::single(98),Token::single(99),Token::single(100),Token::single(101),Token::single(102),Token::single(103),Token::single(104),Token::single(105),Token::single(106),Token::single(107),Token::single(108),Token::single(109),Token::single(110),Token::single(111),Token::single(112),Token::single(113),Token::single(114),Token::single(115),Token::single(116),Token::single(117),Token::single(118),Token::single(119),Token::single(120),Token::single(121),Token::single(122),Token::single(123),Token::single(124),Token::single(125),Token::single(126),Token::single(127),Token::single(128),Token::single(129),Token::single(130),Token::single(131),Token::single(132),Token::single(133),Token::single(134),Token::single(135),Token::single(136),Token::single(137),Token::single(138),Token::single(139),Token::single(140),Token::single(141),Token::single(142),Token::single(143),Token::single(144),Token::single(145),Token::single(146),Token::single(147),Token::single(148),Token::single(149),Token::single(150),Token::single(151),Token::single(152),Token::single(153),Token::single(154),Token::single(155),Token::single(156),Token::single(157),Token::single(158),Token::single(159),Token::single(160),Token::single(161),Token::single(162),Token::single(163),Token::single(164),Token::single(165),Token::single(166),Token::single(167),Token::single(168),Token::single(169),Token::single(170),Token::single(171),Token::single(172),Token::single(173),Token::single(174),Token::single(175),Token::single(176),Token::single(177),Token::single(178),Token::single(179),Token::single(180),Token::single(181),Token::single(182),Token::single(183),Token::single(184),Token::single(185),Token::single(186),Token::single(187),Token::single(188),Token::single(189),Token::single(190),Token::single(191),Token::single(192),Token::single(193),Token::single(194),Token::single(195),Token::single(196),Token::single(197),Token::single(198),Token::single(199),Token::single(200),Token::single(201),Token::single(202),Token::single(203),Token::single(204),Token::single(205),Token::single(206),Token::single(207),Token::single(208),Token::single(209),Token::single(210),Token::single(211),Token::single(212),Token::single(213),Token::single(214),Token::single(215),Token::single(216),Token::single(217),Token::single(218),Token::single(219),Token::single(220),Token::single(221),Token::single(222),Token::single(223),Token::single(224),Token::single(225),Token::single(226),Token::single(227),Token::single(228),Token::single(229),Token::single(230),Token::single(231),Token::single(232),Token::single(233),Token::single(234),Token::single(235),Token::single(236),Token::single(237),Token::single(238),Token::single(239),Token::single(240),Token::single(241),Token::single(242),Token::single(243),Token::single(244),Token::single(245),Token::single(246),Token::single(247),Token::single(248),Token::single(249),Token::single(250),Token::single(251),Token::single(252),Token::single(253),Token::single(254),Token::single(255)];
325#[derive(Clone,Debug)]
326/// a simple dictionary based detokenizer
327pub struct Detokenization<I:Iterator> where I::Item:Val<u32>{inner:Fuse<I>,maxtokenlen:usize,position:usize,tokenid:u32,tokens:Arc<Vec<Token>>}
328#[derive(Clone,Debug)]
329/// iterator for TokenDict
330pub struct DictIter<'a>{range:Range<usize>,tokens:&'a [Token]}
331#[derive(Clone,Debug)]
332/// intoiterator for TokenDict
333pub struct DictIntoIter{range:Range<usize>,tokens:Arc<Vec<Token>>}
334#[derive(Clone,Debug)]
335/// a simple dictionary based tokenizer dictionary where single bytes implicitly form the lower 256 ids. It's reference counted and cheap to clone
336pub struct TokenDict{ids:Arc<[Trie<u8,u32>;256]>,maxtokenlen:usize,tokens:Arc<Vec<Token>>}
337#[derive(Clone,Debug)]
338/// a simple dictionary based tokenizer iterator
339pub struct Tokenization<I:Iterator> where I::Item:Val<u8>{ids:Arc<[Trie<u8,u32>;256]>,inner:Fuse<I>,state:VecDeque<u8>}
340
341use crate::{Token,UTF8CharIter,Val,token::SINGLE_BYTES};
342use post_inc;
343use ptrie::Trie;
344#[cfg(feature="serial")]
345use serde::{Deserialize,Deserializer,Serialize,Serializer};
346use std::{
347	collections::{HashMap,VecDeque},iter::{Extend,Fuse},ops::{Index,Range},slice::Iter as SliceIter,sync::Arc,vec::IntoIter as VecIntoIter,
348};