// dynamo_llm/tokens.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::kv_router::indexer::compute_hash;
17use bytemuck::cast_slice;
18use derive_getters::{Dissolve, Getters};
19use rayon::prelude::*;
20
21pub type Token = u32;
22
23/// A hash of the only the tokens within a block computed from [compute_hash].
24pub type BlockHash = u64;
25
26/// A sequence aware hash that combines the previous block's sequence hash with the current block's hash.
27pub type SequenceHash = u64;
28
29#[derive(Debug, Clone, Dissolve, Default)]
30pub struct Tokens(Vec<Token>);
31
32impl AsRef<[Token]> for Tokens {
33    fn as_ref(&self) -> &[Token] {
34        &self.0
35    }
36}
37
38impl AsMut<[Token]> for Tokens {
39    fn as_mut(&mut self) -> &mut [Token] {
40        &mut self.0
41    }
42}
43
44impl std::ops::Deref for Tokens {
45    type Target = [Token];
46
47    fn deref(&self) -> &Self::Target {
48        &self.0
49    }
50}
51
52impl std::ops::DerefMut for Tokens {
53    fn deref_mut(&mut self) -> &mut Self::Target {
54        &mut self.0
55    }
56}
57
58impl std::borrow::Borrow<[Token]> for Tokens {
59    fn borrow(&self) -> &[Token] {
60        &self.0
61    }
62}
63
64impl IntoIterator for Tokens {
65    type Item = Token;
66
67    type IntoIter = std::vec::IntoIter<Token>;
68
69    fn into_iter(self) -> Self::IntoIter {
70        self.0.into_iter()
71    }
72}
73
74impl From<Vec<Token>> for Tokens {
75    fn from(tokens: Vec<Token>) -> Self {
76        Tokens(tokens)
77    }
78}
79
80impl From<&[Token]> for Tokens {
81    fn from(tokens: &[Token]) -> Self {
82        Tokens(tokens.to_vec())
83    }
84}
85
86impl From<Vec<i32>> for Tokens {
87    fn from(tokens: Vec<i32>) -> Self {
88        Tokens(tokens.into_iter().map(|t| t as u32).collect())
89    }
90}
91
92impl From<&[i32]> for Tokens {
93    fn from(tokens: &[i32]) -> Self {
94        Tokens(tokens.iter().map(|&t| t as u32).collect())
95    }
96}
97
98impl From<Tokens> for Vec<Token> {
99    fn from(tokens: Tokens) -> Self {
100        tokens.0
101    }
102}
103
104impl Tokens {
105    pub fn into_sequence(self, block_size: usize) -> TokenSequence {
106        TokenSequence::new(self, block_size)
107    }
108
109    pub fn compute_block_hash(tokens: &[Token], block_size: usize) -> Vec<BlockHash> {
110        tokens
111            .par_chunks_exact(block_size)
112            .map(|chunk| compute_hash(cast_slice(chunk)))
113            .collect()
114    }
115}
116
117pub struct PartialTokenBlock {
118    tokens: Tokens,
119    block_size: usize,
120    parent_sequence_hash: Option<SequenceHash>,
121}
122
123impl PartialTokenBlock {
124    /// Push a token onto the block, if the block is full, return a new [TokenBlock]
125    /// and reset the incomplete block
126    pub fn push_token(&mut self, token: Token) -> Option<TokenBlock> {
127        self.tokens.0.push(token);
128        if self.tokens.0.len() == self.block_size {
129            let block = std::mem::take(&mut self.tokens);
130            let block_hash = compute_hash(cast_slice(&block));
131            let sequence_hash = compute_hash(bytemuck::cast_slice(&[
132                self.parent_sequence_hash.unwrap_or_default(),
133                block_hash,
134            ]));
135            Some(TokenBlock {
136                tokens: block,
137                sequence_hash,
138                block_hash,
139                parent_sequence_hash: self.parent_sequence_hash,
140            })
141        } else {
142            None
143        }
144    }
145
146    pub fn tokens(&self) -> &Tokens {
147        &self.tokens
148    }
149}
150
151impl std::ops::Deref for PartialTokenBlock {
152    type Target = Tokens;
153
154    fn deref(&self) -> &Self::Target {
155        &self.tokens
156    }
157}
158
159#[derive(Debug, Clone, Getters, Default)]
160pub struct TokenBlock {
161    tokens: Tokens,
162
163    #[getter(copy)]
164    block_hash: BlockHash,
165
166    #[getter(copy)]
167    sequence_hash: SequenceHash,
168
169    #[getter(copy)]
170    parent_sequence_hash: Option<SequenceHash>,
171}
172
173pub struct TokenSequence {
174    blocks: Vec<TokenBlock>,
175    current_block: PartialTokenBlock,
176}
177
178impl TokenSequence {
179    pub fn new(tokens: Tokens, block_size: usize) -> Self {
180        let (blocks, current_block) = Self::split_tokens(tokens, block_size);
181
182        Self {
183            blocks,
184            current_block,
185        }
186    }
187
188    pub fn push_token(&mut self, token: Token) -> Option<&TokenBlock> {
189        if let Some(block) = self.current_block.push_token(token) {
190            self.blocks.push(block);
191            self.blocks.last()
192        } else {
193            None
194        }
195    }
196
197    pub fn blocks(&self) -> &[TokenBlock] {
198        &self.blocks
199    }
200
201    pub fn current_block(&self) -> &PartialTokenBlock {
202        &self.current_block
203    }
204
205    pub fn into_parts(self) -> (Vec<TokenBlock>, PartialTokenBlock) {
206        (self.blocks, self.current_block)
207    }
208
209    pub fn split_tokens(tokens: Tokens, block_size: usize) -> (Vec<TokenBlock>, PartialTokenBlock) {
210        // Use rayon's parallel iterator to process chunks in parallel
211        let mut blocks: Vec<TokenBlock> = tokens
212            .par_chunks_exact(block_size)
213            .map(|chunk| TokenBlock {
214                tokens: chunk.to_vec().into(),
215                sequence_hash: 0,
216                block_hash: compute_hash(cast_slice(chunk)),
217                parent_sequence_hash: None,
218            })
219            .collect();
220
221        blocks[0].sequence_hash = blocks[0].block_hash;
222
223        // compute the sequence hash for each block
224        // this is the sequence hash of the previous block with the current block's hash
225        for i in 1..blocks.len() {
226            let previous_block = &blocks[i - 1];
227            let parent_sequence_hash = previous_block.sequence_hash;
228            let vals = &[parent_sequence_hash, blocks[i].block_hash];
229            blocks[i].sequence_hash = compute_hash(bytemuck::cast_slice(vals));
230            blocks[i].parent_sequence_hash = Some(parent_sequence_hash);
231        }
232
233        let remainder = tokens.chunks_exact(block_size).remainder();
234
235        let next_block = PartialTokenBlock {
236            tokens: remainder.into(),
237            block_size,
238            parent_sequence_hash: blocks.last().map(|b| b.sequence_hash),
239        };
240
241        (blocks, next_block)
242    }
243}
244
245impl PartialEq<Vec<Token>> for Tokens {
246    fn eq(&self, other: &Vec<Token>) -> bool {
247        self.0 == *other
248    }
249}
250
251impl PartialEq<Tokens> for Vec<Token> {
252    fn eq(&self, other: &Tokens) -> bool {
253        *self == other.0
254    }
255}
256
257impl PartialEq<[Token]> for Tokens {
258    fn eq(&self, other: &[Token]) -> bool {
259        self.0.as_slice() == other
260    }
261}
262
263impl PartialEq<Tokens> for &[Token] {
264    fn eq(&self, other: &Tokens) -> bool {
265        *self == other.0.as_slice()
266    }
267}
268
269impl PartialEq<Vec<Token>> for &Tokens {
270    fn eq(&self, other: &Vec<Token>) -> bool {
271        self.0 == *other
272    }
273}
274
275impl<'a> PartialEq<&'a Tokens> for Vec<Token> {
276    fn eq(&self, other: &&'a Tokens) -> bool {
277        *self == other.0
278    }
279}
280
281impl PartialEq<[Token]> for &Tokens {
282    fn eq(&self, other: &[Token]) -> bool {
283        self.0.as_slice() == other
284    }
285}
286
287impl<'a> PartialEq<&'a [Token]> for Tokens {
288    fn eq(&self, other: &&'a [Token]) -> bool {
289        self.0.as_slice() == *other
290    }
291}
292
293impl PartialEq for Tokens {
294    fn eq(&self, other: &Self) -> bool {
295        self.0 == other.0
296    }
297}
298
299impl Eq for Tokens {}
300
301#[cfg(test)]
302mod tests {
303    use super::*;
304
305    #[test]
306    fn test_tokens_slice_operations() {
307        let tokens = Tokens(vec![1, 2, 3, 4, 5]);
308
309        // Test AsRef<[Token]>
310        let slice: &[Token] = tokens.as_ref();
311        assert_eq!(slice, &[1, 2, 3, 4, 5]);
312
313        // Test Deref
314        assert_eq!(tokens.len(), 5);
315        assert_eq!(tokens[0], 1);
316        assert_eq!(tokens[4], 5);
317
318        // Test iteration
319        let sum: u32 = tokens.iter().sum();
320        assert_eq!(sum, 15);
321
322        // Test slicing
323        let slice = &tokens[1..4];
324        assert_eq!(slice, &[2, 3, 4]);
325
326        // Test Borrow
327        let borrowed: &[Token] = std::borrow::Borrow::borrow(&tokens);
328        assert_eq!(borrowed, &[1, 2, 3, 4, 5]);
329
330        // Test with functions that accept &[Token]
331        fn takes_slice(slice: &[Token]) -> usize {
332            slice.len()
333        }
334
335        assert_eq!(takes_slice(&tokens), 5);
336    }
337
338    #[test]
339    fn test_tokens_conversions() {
340        // Test From<Vec<Token>> for Tokens
341        let vec = vec![1, 2, 3, 4, 5];
342        let tokens: Tokens = vec.clone().into();
343        assert_eq!(tokens.0, vec);
344
345        // Test Into<Vec<Token>> for Tokens
346        let tokens = Tokens(vec![6, 7, 8, 9, 10]);
347        let vec: Vec<Token> = tokens.into();
348        assert_eq!(vec, vec![6, 7, 8, 9, 10]);
349
350        // Test From<&[Token]> for Tokens
351        let slice: &[Token] = &[11, 12, 13];
352        let tokens: Tokens = slice.into();
353        assert_eq!(tokens.0, vec![11, 12, 13]);
354
355        // Test From<Vec<i32>> for Tokens
356        let i32_values = vec![100_i32, 200_i32, 300_i32];
357        let tokens: Tokens = i32_values.into();
358        assert_eq!(tokens.0, vec![100, 200, 300]);
359
360        // Test From<&[i32]> for Tokens
361        let i32_slice: &[i32] = &[400_i32, 500_i32, 600_i32];
362        let tokens: Tokens = i32_slice.into();
363        assert_eq!(tokens.0, vec![400, 500, 600]);
364    }
365
366    #[test]
367    fn test_tokens_blocks() {
368        let tokens = Tokens(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
369        let sequence = TokenSequence::new(tokens, 4);
370
371        assert_eq!(sequence.blocks().len(), 2);
372        assert_eq!(sequence.current_block().len(), 2);
373
374        assert_eq!(sequence.blocks()[0].tokens(), vec![1, 2, 3, 4]);
375        assert_eq!(sequence.blocks()[0].block_hash(), 14643705804678351452);
376        assert_eq!(sequence.blocks()[0].sequence_hash(), 14643705804678351452);
377        println!("blocks[0]: {:?}", sequence.blocks()[0]);
378
379        assert_eq!(sequence.blocks()[1].tokens(), vec![5, 6, 7, 8]);
380        assert_eq!(sequence.blocks()[1].block_hash(), 16777012769546811212);
381        assert_eq!(sequence.blocks()[1].sequence_hash(), 4945711292740353085);
382        println!("blocks[1]: {:?}", sequence.blocks()[1]);
383
384        assert_eq!(sequence.current_block().tokens(), vec![9, 10]);
385
386        let mut sequence = sequence;
387
388        let new_block = sequence.push_token(11);
389        assert!(new_block.is_none());
390        assert_eq!(sequence.blocks().len(), 2);
391
392        let new_block = sequence.push_token(12);
393        assert!(new_block.is_some());
394        assert_eq!(sequence.blocks().len(), 3);
395        assert_eq!(sequence.current_block().tokens().len(), 0);
396        println!("blocks[2]: {:?}", sequence.blocks()[2]);
397
398        let (blocks, mut current_block) = sequence.into_parts();
399
400        let new_block = current_block.push_token(13);
401        assert!(new_block.is_none());
402        assert_eq!(current_block.tokens().len(), 1);
403
404        let new_block = current_block.push_token(14);
405        assert!(new_block.is_none());
406        assert_eq!(current_block.tokens().len(), 2);
407
408        let new_block = current_block.push_token(15);
409        assert!(new_block.is_none());
410        assert_eq!(current_block.tokens().len(), 3);
411
412        let new_block = current_block.push_token(16);
413        assert!(new_block.is_some());
414        assert_eq!(blocks.len(), 3);
415        assert_eq!(current_block.tokens().len(), 0);
416    }
417}