use crate::kv_router::indexer::compute_hash;
use bytemuck::cast_slice;
use derive_getters::{Dissolve, Getters};
use rayon::prelude::*;

pub type Token = u32;

pub type BlockHash = u64;

pub type SequenceHash = u64;

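/// A sequence of tokens stored as a contiguous `Vec<Token>`. The newtype
/// derefs and borrows as `[Token]`, so slice operations work on it directly.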
#[derive(Debug, Clone, Dissolve, Default)]
pub struct Tokens(Vec<Token>);

impl AsRef<[Token]> for Tokens {
    fn as_ref(&self) -> &[Token] {
        &self.0
    }
}

impl AsMut<[Token]> for Tokens {
    fn as_mut(&mut self) -> &mut [Token] {
        &mut self.0
    }
}

impl std::ops::Deref for Tokens {
    type Target = [Token];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for Tokens {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl std::borrow::Borrow<[Token]> for Tokens {
    fn borrow(&self) -> &[Token] {
        &self.0
    }
}

impl IntoIterator for Tokens {
    type Item = Token;

    type IntoIter = std::vec::IntoIter<Token>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

impl From<Vec<Token>> for Tokens {
    fn from(tokens: Vec<Token>) -> Self {
        Tokens(tokens)
    }
}

impl From<&[Token]> for Tokens {
    fn from(tokens: &[Token]) -> Self {
        Tokens(tokens.to_vec())
    }
}

impl From<Vec<i32>> for Tokens {
    fn from(tokens: Vec<i32>) -> Self {
        Tokens(tokens.into_iter().map(|t| t as u32).collect())
    }
}

impl From<&[i32]> for Tokens {
    fn from(tokens: &[i32]) -> Self {
        Tokens(tokens.iter().map(|&t| t as u32).collect())
    }
}

impl From<Tokens> for Vec<Token> {
    fn from(tokens: Tokens) -> Self {
        tokens.0
    }
}

impl Tokens {
    pub fn into_sequence(self, block_size: usize) -> TokenSequence {
        TokenSequence::new(self, block_size)
    }

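    /// Hashes each complete `block_size` chunk of `tokens` in parallel; any
    /// trailing tokens that do not fill a full block are ignored.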
    pub fn compute_block_hash(tokens: &[Token], block_size: usize) -> Vec<BlockHash> {
        tokens
            .par_chunks_exact(block_size)
            .map(|chunk| compute_hash(cast_slice(chunk)))
            .collect()
    }
}

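/// A block that is still being filled. Once it reaches `block_size` tokens it
/// is emitted as a complete [`TokenBlock`] by [`PartialTokenBlock::push_token`].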
pub struct PartialTokenBlock {
    tokens: Tokens,
    block_size: usize,
    parent_sequence_hash: Option<SequenceHash>,
}

impl PartialTokenBlock {
    /// Appends a token; returns a completed [`TokenBlock`] when the block
    /// reaches `block_size`, otherwise `None`.
    pub fn push_token(&mut self, token: Token) -> Option<TokenBlock> {
        self.tokens.0.push(token);
        if self.tokens.0.len() == self.block_size {
            let block = std::mem::take(&mut self.tokens);
            let block_hash = compute_hash(cast_slice(&block));
            let sequence_hash = compute_hash(bytemuck::cast_slice(&[
                self.parent_sequence_hash.unwrap_or_default(),
                block_hash,
            ]));
            let parent_sequence_hash = self.parent_sequence_hash;
            // Chain the next block off the block that was just completed.
            self.parent_sequence_hash = Some(sequence_hash);
            Some(TokenBlock {
                tokens: block,
                sequence_hash,
                block_hash,
                parent_sequence_hash,
            })
        } else {
            None
        }
    }

    pub fn tokens(&self) -> &Tokens {
        &self.tokens
    }
}

impl std::ops::Deref for PartialTokenBlock {
    type Target = Tokens;

    fn deref(&self) -> &Self::Target {
        &self.tokens
    }
}

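/// A completed block of tokens together with its block hash, its cumulative
/// sequence hash, and the sequence hash of its parent block (if any).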
#[derive(Debug, Clone, Getters, Default)]
pub struct TokenBlock {
    tokens: Tokens,

    #[getter(copy)]
    block_hash: BlockHash,

    #[getter(copy)]
    sequence_hash: SequenceHash,

    #[getter(copy)]
    parent_sequence_hash: Option<SequenceHash>,
}

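/// A token sequence partitioned into complete [`TokenBlock`]s plus one
/// trailing [`PartialTokenBlock`] holding any remaining tokens.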
pub struct TokenSequence {
    blocks: Vec<TokenBlock>,
    current_block: PartialTokenBlock,
}

impl TokenSequence {
    pub fn new(tokens: Tokens, block_size: usize) -> Self {
        let (blocks, current_block) = Self::split_tokens(tokens, block_size);

        Self {
            blocks,
            current_block,
        }
    }

    pub fn push_token(&mut self, token: Token) -> Option<&TokenBlock> {
        if let Some(block) = self.current_block.push_token(token) {
            self.blocks.push(block);
            self.blocks.last()
        } else {
            None
        }
    }

    pub fn blocks(&self) -> &[TokenBlock] {
        &self.blocks
    }

    pub fn current_block(&self) -> &PartialTokenBlock {
        &self.current_block
    }

    pub fn into_parts(self) -> (Vec<TokenBlock>, PartialTokenBlock) {
        (self.blocks, self.current_block)
    }

    pub fn split_tokens(tokens: Tokens, block_size: usize) -> (Vec<TokenBlock>, PartialTokenBlock) {
        // Hash every complete block in parallel; sequence hashes are filled in below.
        let mut blocks: Vec<TokenBlock> = tokens
            .par_chunks_exact(block_size)
            .map(|chunk| TokenBlock {
                tokens: chunk.to_vec().into(),
                sequence_hash: 0,
                block_hash: compute_hash(cast_slice(chunk)),
                parent_sequence_hash: None,
            })
            .collect();

        // The first block's sequence hash is just its block hash; guard against
        // inputs shorter than one block, which produce no complete blocks.
        if let Some(first) = blocks.first_mut() {
            first.sequence_hash = first.block_hash;
        }

        // Chain each subsequent block's sequence hash off its predecessor's.
        for i in 1..blocks.len() {
            let previous_block = &blocks[i - 1];
            let parent_sequence_hash = previous_block.sequence_hash;
            let vals = &[parent_sequence_hash, blocks[i].block_hash];
            blocks[i].sequence_hash = compute_hash(bytemuck::cast_slice(vals));
            blocks[i].parent_sequence_hash = Some(parent_sequence_hash);
        }

        // Any leftover tokens seed the partial block that follows the last complete block.
        let remainder = tokens.chunks_exact(block_size).remainder();

        let next_block = PartialTokenBlock {
            tokens: remainder.into(),
            block_size,
            parent_sequence_hash: blocks.last().map(|b| b.sequence_hash),
        };

        (blocks, next_block)
    }
}

impl PartialEq<Vec<Token>> for Tokens {
    fn eq(&self, other: &Vec<Token>) -> bool {
        self.0 == *other
    }
}

impl PartialEq<Tokens> for Vec<Token> {
    fn eq(&self, other: &Tokens) -> bool {
        *self == other.0
    }
}

impl PartialEq<[Token]> for Tokens {
    fn eq(&self, other: &[Token]) -> bool {
        self.0.as_slice() == other
    }
}

impl PartialEq<Tokens> for &[Token] {
    fn eq(&self, other: &Tokens) -> bool {
        *self == other.0.as_slice()
    }
}

impl PartialEq<Vec<Token>> for &Tokens {
    fn eq(&self, other: &Vec<Token>) -> bool {
        self.0 == *other
    }
}

impl<'a> PartialEq<&'a Tokens> for Vec<Token> {
    fn eq(&self, other: &&'a Tokens) -> bool {
        *self == other.0
    }
}

impl PartialEq<[Token]> for &Tokens {
    fn eq(&self, other: &[Token]) -> bool {
        self.0.as_slice() == other
    }
}

impl<'a> PartialEq<&'a [Token]> for Tokens {
    fn eq(&self, other: &&'a [Token]) -> bool {
        self.0.as_slice() == *other
    }
}

impl PartialEq for Tokens {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl Eq for Tokens {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokens_slice_operations() {
        let tokens = Tokens(vec![1, 2, 3, 4, 5]);

        let slice: &[Token] = tokens.as_ref();
        assert_eq!(slice, &[1, 2, 3, 4, 5]);

        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0], 1);
        assert_eq!(tokens[4], 5);

        let sum: u32 = tokens.iter().sum();
        assert_eq!(sum, 15);

        let slice = &tokens[1..4];
        assert_eq!(slice, &[2, 3, 4]);

        let borrowed: &[Token] = std::borrow::Borrow::borrow(&tokens);
        assert_eq!(borrowed, &[1, 2, 3, 4, 5]);

        fn takes_slice(slice: &[Token]) -> usize {
            slice.len()
        }

        assert_eq!(takes_slice(&tokens), 5);
    }

    #[test]
    fn test_tokens_conversions() {
        let vec = vec![1, 2, 3, 4, 5];
        let tokens: Tokens = vec.clone().into();
        assert_eq!(tokens.0, vec);

        let tokens = Tokens(vec![6, 7, 8, 9, 10]);
        let vec: Vec<Token> = tokens.into();
        assert_eq!(vec, vec![6, 7, 8, 9, 10]);

        let slice: &[Token] = &[11, 12, 13];
        let tokens: Tokens = slice.into();
        assert_eq!(tokens.0, vec![11, 12, 13]);

        let i32_values = vec![100_i32, 200_i32, 300_i32];
        let tokens: Tokens = i32_values.into();
        assert_eq!(tokens.0, vec![100, 200, 300]);

        let i32_slice: &[i32] = &[400_i32, 500_i32, 600_i32];
        let tokens: Tokens = i32_slice.into();
        assert_eq!(tokens.0, vec![400, 500, 600]);
    }

    #[test]
    fn test_tokens_blocks() {
        let tokens = Tokens(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let sequence = TokenSequence::new(tokens, 4);

        assert_eq!(sequence.blocks().len(), 2);
        assert_eq!(sequence.current_block().len(), 2);

        assert_eq!(sequence.blocks()[0].tokens(), vec![1, 2, 3, 4]);
        assert_eq!(sequence.blocks()[0].block_hash(), 14643705804678351452);
        assert_eq!(sequence.blocks()[0].sequence_hash(), 14643705804678351452);
        println!("blocks[0]: {:?}", sequence.blocks()[0]);

        assert_eq!(sequence.blocks()[1].tokens(), vec![5, 6, 7, 8]);
        assert_eq!(sequence.blocks()[1].block_hash(), 16777012769546811212);
        assert_eq!(sequence.blocks()[1].sequence_hash(), 4945711292740353085);
        println!("blocks[1]: {:?}", sequence.blocks()[1]);

        assert_eq!(sequence.current_block().tokens(), vec![9, 10]);

        let mut sequence = sequence;

        let new_block = sequence.push_token(11);
        assert!(new_block.is_none());
        assert_eq!(sequence.blocks().len(), 2);

        let new_block = sequence.push_token(12);
        assert!(new_block.is_some());
        assert_eq!(sequence.blocks().len(), 3);
        assert_eq!(sequence.current_block().tokens().len(), 0);
        println!("blocks[2]: {:?}", sequence.blocks()[2]);

        let (blocks, mut current_block) = sequence.into_parts();

        let new_block = current_block.push_token(13);
        assert!(new_block.is_none());
        assert_eq!(current_block.tokens().len(), 1);

        let new_block = current_block.push_token(14);
        assert!(new_block.is_none());
        assert_eq!(current_block.tokens().len(), 2);

        let new_block = current_block.push_token(15);
        assert!(new_block.is_none());
        assert_eq!(current_block.tokens().len(), 3);

        let new_block = current_block.push_token(16);
        assert!(new_block.is_some());
        assert_eq!(blocks.len(), 3);
        assert_eq!(current_block.tokens().len(), 0);
    }
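
    // Sketch of a consistency check, assuming `compute_hash` is deterministic:
    // `Tokens::compute_block_hash` and `TokenSequence::split_tokens` hash the
    // same complete `block_size` chunks, so their block hashes should agree,
    // and the trailing partial block should contribute no hash.
    #[test]
    fn test_compute_block_hash_matches_blocks() {
        let tokens = Tokens(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);

        let hashes = Tokens::compute_block_hash(&tokens, 4);
        let sequence = TokenSequence::new(tokens, 4);

        // Only the two complete blocks are hashed; tokens 9 and 10 are ignored.
        assert_eq!(hashes.len(), 2);
        assert_eq!(hashes[0], sequence.blocks()[0].block_hash());
        assert_eq!(hashes[1], sequence.blocks()[1].block_hash());
    }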
}