use crate::common::protocols::MoveBlock;
use derive_getters::Getters;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{TokenBlockSequence, Tokens};
use rand::random;
use validator::Validate;
/// Derives the per-block bookkeeping for a freshly tokenized sequence.
///
/// Returns `(unique_blocks, block_hashes)`:
///
/// * `block_hashes` holds `block_hash()` of every complete block in `tokens`,
///   in order.
/// * `unique_blocks` holds one `UniqueBlock::FullBlock` per complete block,
///   keyed by the block's `sequence_hash()` when `enable_prefix_caching` is
///   set and by a fresh random `u64` otherwise. When the total token count is
///   not a multiple of `block_size`, one extra `UniqueBlock::default()` entry
///   is appended for the trailing, not-yet-complete block.
fn create_sequence_cache(
    tokens: &TokenBlockSequence,
    block_size: usize,
    enable_prefix_caching: bool,
) -> (Vec<UniqueBlock>, Vec<u64>) {
    let complete = tokens.blocks();

    let hashes: Vec<u64> = complete.iter().map(|block| block.block_hash()).collect();

    // Reserve one extra slot for the possible trailing partial block.
    let mut identities: Vec<UniqueBlock> = Vec::with_capacity(complete.len() + 1);
    identities.extend(complete.iter().map(|block| {
        let id = if enable_prefix_caching {
            block.sequence_hash()
        } else {
            // Without prefix caching every block gets an unshareable identity.
            random::<u64>()
        };
        UniqueBlock::FullBlock(id)
    }));

    let has_partial_tail = !tokens.total_tokens().is_multiple_of(block_size);
    if has_partial_tail {
        identities.push(UniqueBlock::default());
    }

    (identities, hashes)
}
/// Token sequence of a single request plus the block-level bookkeeping used
/// to build `MoveBlock` signals.
///
/// Field getters come from `derive(Getters)`; fields marked `#[getter(copy)]`
/// are returned by value. `derive(Validate)` enforces the `block_size` range
/// declared below.
#[derive(Debug, Getters, Validate)]
pub struct ActiveSequence {
/// One identity per block: a `FullBlock` for each completed (or promoted)
/// block, optionally followed by one trailing entry for the block that is
/// still being filled.
unique_blocks: Vec<UniqueBlock>,
/// `block_hash()` of each complete block, filled at construction and
/// extended whenever `push` promotes a block. `pop` does not shrink it.
block_hashes: Vec<u64>,
/// The underlying tokens, grouped into blocks of `block_size`.
tokens: TokenBlockSequence,
/// Number of tokens per block. Must be at least 2: `push` detects a newly
/// opened block via `len % block_size == 1`, which can never hold for 1.
#[getter(copy)]
#[validate(range(min = 2))]
block_size: usize,
/// Upper bound on `generated_tokens`; `generate` asserts it is not exceeded.
#[getter(copy)]
max_output_tokens: usize,
/// Incremented by every `push`, decremented (saturating) by every `pop`.
#[getter(copy)]
generated_tokens: usize,
/// Number of tokens the sequence was constructed with.
#[getter(copy)]
num_input_tokens: usize,
/// Cumulative token count recorded by the last `commit_allocation`;
/// reset to 0 by `reset_with_signal`.
#[getter(copy)]
num_allocated_tokens: usize,
/// When true, full blocks are identified by their `sequence_hash()`;
/// when false, by a random `u64`.
#[getter(copy)]
enable_prefix_caching: bool,
/// When true, `Use` and `Promote` signals carry the token ids of the
/// complete blocks they refer to.
#[getter(copy)]
emit_token_ids: bool,
}
impl ActiveSequence {
    /// Builds a sequence from the prompt `tokens`.
    ///
    /// `block_size` falls back to 64 when `None`. Nothing is marked as
    /// allocated yet (`num_allocated_tokens == 0`) and no token counts as
    /// generated.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid ActiveSequence"` when the derived validation
    /// rejects the struct (i.e. `block_size < 2`).
    pub fn new(
        tokens: Vec<u32>,
        max_output_tokens: usize,
        block_size: Option<usize>,
        enable_prefix_caching: bool,
        emit_token_ids: bool,
    ) -> Self {
        let block_size = match block_size {
            Some(size) => size,
            None => 64,
        };
        let num_input_tokens = tokens.len();

        // Fixed salt (1337) so that identical prompts hash identically.
        let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
        let (unique_blocks, block_hashes) =
            create_sequence_cache(&tokens, block_size, enable_prefix_caching);

        let built = Self {
            unique_blocks,
            block_hashes,
            tokens,
            block_size,
            max_output_tokens,
            generated_tokens: 0,
            num_input_tokens,
            num_allocated_tokens: 0,
            enable_prefix_caching,
            emit_token_ids,
        };
        built.validate().expect("invalid ActiveSequence");
        built
    }

    /// Number of tokens sitting in the trailing, incomplete block
    /// (`len % block_size`).
    pub fn extra_tokens(&self) -> u32 {
        let remainder = self.len() % self.block_size;
        remainder as u32
    }

    /// Total number of tokens currently in the sequence.
    pub fn len(&self) -> usize {
        self.tokens.total_tokens()
    }

    /// Whether the sequence holds no tokens at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of blocks needed to hold `token_count` tokens, capped at the
    /// number of blocks this sequence actually tracks.
    fn block_count_for(&self, token_count: usize) -> usize {
        token_count
            .div_ceil(self.block_size)
            .min(self.unique_blocks.len())
    }

    /// Computes the `MoveBlock::Use` signal that would extend the allocation
    /// from `num_allocated_tokens` up to `cumulative_tokens`, without
    /// recording anything.
    ///
    /// Returns `None` when the target does not require any block beyond the
    /// ones already covered. Otherwise the signal carries, in order: the
    /// newly covered blocks, the cached hashes of those that are complete,
    /// their token ids (only when `emit_token_ids` is set and at least one of
    /// them is complete), and the block immediately preceding the new range
    /// as parent (if any).
    pub fn prepare_allocation(&self, cumulative_tokens: usize) -> Option<MoveBlock> {
        let already_covered = self.block_count_for(self.num_allocated_tokens);
        let wanted = self.block_count_for(cumulative_tokens);
        if wanted <= already_covered {
            return None;
        }

        let blocks = self.unique_blocks[already_covered..wanted].to_vec();

        // Only complete blocks have a cached hash, so clamp the range to them.
        let hashed = self.block_hashes.len();
        let (lo, hi) = (already_covered.min(hashed), wanted.min(hashed));
        let hashes = self.block_hashes[lo..hi].to_vec();

        let token_ids = (self.emit_token_ids && lo < hi).then(|| {
            self.tokens.blocks()[lo..hi]
                .iter()
                .map(|block| block.tokens().to_vec())
                .collect()
        });

        let parent = already_covered
            .checked_sub(1)
            .map(|idx| self.unique_blocks[idx].clone());

        Some(MoveBlock::Use(blocks, hashes, token_ids, parent))
    }

    /// Records `cumulative_tokens` as the amount of tokens now allocated.
    pub fn commit_allocation(&mut self, cumulative_tokens: usize) {
        self.num_allocated_tokens = cumulative_tokens;
    }

    /// `prepare_allocation` followed by `commit_allocation` for the same
    /// target; returns the prepared signal.
    pub fn allocate_blocks_for_chunk(&mut self, cumulative_tokens: usize) -> Option<MoveBlock> {
        let prepared = self.prepare_allocation(cumulative_tokens);
        self.commit_allocation(cumulative_tokens);
        prepared
    }

    /// Allocates everything up to the current length in one step and returns
    /// the resulting signal.
    pub fn take_creation_signal(&mut self) -> Option<MoveBlock> {
        let total = self.len();
        self.allocate_blocks_for_chunk(total)
    }

    /// Convenience constructor: `new` with `emit_token_ids = false`, followed
    /// by `take_creation_signal`.
    pub fn new_with_signal(
        tokens: Vec<u32>,
        max_output_tokens: usize,
        block_size: Option<usize>,
        enable_prefix_caching: bool,
    ) -> (Self, Option<MoveBlock>) {
        let emit_token_ids = false;
        let mut sequence = Self::new(
            tokens,
            max_output_tokens,
            block_size,
            enable_prefix_caching,
            emit_token_ids,
        );
        let creation = sequence.take_creation_signal();
        (sequence, creation)
    }

    /// If the last tracked block is a partial block, replaces it with a full
    /// block describing the most recently completed token block and returns
    /// the matching `MoveBlock::Promote`. Returns `None` (and changes
    /// nothing) when the last tracked block is not partial.
    fn promote_trailing_partial(&mut self) -> Option<MoveBlock> {
        let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() else {
            return None;
        };

        let completed = self.tokens.last_complete_block().unwrap();
        let sequence_hash = if self.enable_prefix_caching {
            completed.sequence_hash()
        } else {
            random::<u64>()
        };
        let content_hash = completed.block_hash();
        let token_ids = self.emit_token_ids.then(|| completed.tokens().to_vec());

        self.block_hashes.push(content_hash);

        // Drop the partial entry so that `last()` now yields its predecessor.
        self.unique_blocks.pop();
        let parent_hash = match self.unique_blocks.last() {
            None => None,
            Some(UniqueBlock::FullBlock(hash)) => Some(*hash),
            Some(UniqueBlock::PartialBlock(_)) => {
                panic!("Cannot have a partial block as parent")
            }
        };
        self.unique_blocks
            .push(UniqueBlock::FullBlock(sequence_hash));

        Some(MoveBlock::Promote(
            uuid,
            sequence_hash,
            parent_hash,
            content_hash,
            token_ids,
        ))
    }

    /// Appends `token` and counts it as generated.
    ///
    /// Signals are produced only when the new token is the first one of a
    /// fresh block (`len % block_size == 1`): an optional `Promote` for the
    /// previous trailing partial block, followed by a `Use` for a brand-new
    /// partial block. In every other case `None` is returned.
    ///
    /// # Panics
    ///
    /// Panics with `"Token push failed."` if the token cannot be appended, if
    /// a promotion is due but there is no complete block, or if the block
    /// preceding the promoted one is itself partial.
    pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
        self.tokens.append(token).expect("Token push failed.");
        self.generated_tokens += 1;

        let opened_new_block = self.len() % self.block_size == 1;
        if !opened_new_block {
            return None;
        }

        let mut signals = Vec::with_capacity(2);
        if let Some(promotion) = self.promote_trailing_partial() {
            signals.push(promotion);
        }

        let fresh_partial = UniqueBlock::default();
        self.unique_blocks.push(fresh_partial.clone());
        signals.push(MoveBlock::Use(vec![fresh_partial], vec![], None, None));

        Some(signals)
    }

    /// Pushes one random token and returns the signals it caused. When this
    /// token reaches `max_output_tokens`, the signals freeing every block of
    /// the sequence are appended as well.
    ///
    /// # Panics
    ///
    /// Panics if `max_output_tokens` tokens have already been generated.
    pub fn generate(&mut self) -> Vec<MoveBlock> {
        assert!(
            self.generated_tokens < self.max_output_tokens,
            "Cannot generate more tokens: reached max_output_tokens limit"
        );

        let token = random::<u32>();
        let mut signals = self.push(token).unwrap_or_default();

        let finished = self.generated_tokens == self.max_output_tokens;
        if finished {
            signals.extend(self.free_signal_for_tokens(self.len()));
        }
        signals
    }

    /// Builds one release signal per block covering the first `active_tokens`
    /// tokens, last block first: `Destroy` for a partial block, `Deref` for a
    /// full one.
    fn free_signal_for_tokens(&self, active_tokens: usize) -> Vec<MoveBlock> {
        let active_blocks = self.block_count_for(active_tokens);

        let mut signals = Vec::with_capacity(active_blocks);
        for block in self.unique_blocks[..active_blocks].iter().rev() {
            let release = match block {
                UniqueBlock::PartialBlock(uuid) => {
                    MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)])
                }
                UniqueBlock::FullBlock(hash) => {
                    MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)])
                }
            };
            signals.push(release);
        }
        signals
    }

    /// Release signals for the blocks covered by `num_allocated_tokens`.
    pub fn free_signal(&self) -> Vec<MoveBlock> {
        self.free_signal_for_tokens(self.num_allocated_tokens)
    }

    /// Returns `free_signal()` and marks the sequence as having nothing
    /// allocated. Tokens and `generated_tokens` are left untouched.
    pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
        let released = self.free_signal();
        self.num_allocated_tokens = 0;
        released
    }

    /// Removes the last token and decrements `generated_tokens` (saturating).
    /// If the sequence now ends exactly on a block boundary, the last tracked
    /// block entry is dropped as well.
    ///
    /// NOTE(review): `block_hashes` is never shrunk here and an already
    /// promoted block is not demoted, so popping back *into* a promoted block
    /// leaves the bookkeeping ahead of the tokens — confirm callers never pop
    /// past a block boundary.
    pub fn pop(&mut self) {
        self.tokens.pop();
        self.generated_tokens = self.generated_tokens.saturating_sub(1);

        let on_block_boundary = self.len().is_multiple_of(self.block_size);
        if on_block_boundary {
            self.unique_blocks.pop();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefix-caching sequence over the tokens `0..prompt_len` with block
    /// size 16, together with its creation signal.
    fn prompt_sequence(
        prompt_len: u32,
        max_output_tokens: usize,
    ) -> (ActiveSequence, Option<MoveBlock>) {
        let prompt: Vec<u32> = (0..prompt_len).collect();
        ActiveSequence::new_with_signal(prompt, max_output_tokens, Some(16), true)
    }

    /// Recomputes the block hashes straight from the token blocks.
    fn block_hashes_from_tokens(seq: &ActiveSequence) -> Vec<u64> {
        let mut recomputed = Vec::new();
        for block in seq.tokens.blocks() {
            recomputed.push(block.block_hash());
        }
        recomputed
    }

    /// The cached hashes must equal the recomputed hashes of as many leading
    /// blocks as there are full entries in `unique_blocks`.
    fn assert_cached_hashes_match_promoted_blocks(seq: &ActiveSequence) {
        let full_entries = seq
            .unique_blocks()
            .iter()
            .filter(|entry| matches!(entry, UniqueBlock::FullBlock(_)))
            .count();
        let recomputed = block_hashes_from_tokens(seq);
        assert_eq!(
            seq.block_hashes().as_slice(),
            &recomputed[..full_entries],
            "cached block hashes should match the promoted full blocks"
        );
    }

    /// Asserts that `signal` is a `Use` carrying exactly the given blocks and
    /// hashes.
    fn assert_use_signal(
        signal: &MoveBlock,
        expected_blocks: &[UniqueBlock],
        expected_hashes: &[u64],
    ) {
        let MoveBlock::Use(blocks, hashes, ..) = signal else {
            panic!("Expected MoveBlock::Use");
        };
        assert_eq!(blocks, expected_blocks);
        assert_eq!(hashes, expected_hashes);
    }

    /// Asserts that `signal` is a `Use` of one partial block with no hashes.
    fn assert_single_partial_use(signal: &MoveBlock) {
        let MoveBlock::Use(blocks, hashes, ..) = signal else {
            panic!("Expected MoveBlock::Use with a single partial block");
        };
        assert_eq!(blocks.len(), 1);
        assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
        assert!(hashes.is_empty());
    }

    /// Asserts that `signal` is a `Promote` with the given parent hash.
    fn assert_promote_parent(signal: &MoveBlock, expected_parent: Option<u64>) {
        let MoveBlock::Promote(_, _, parent_hash, _hash, ..) = signal else {
            panic!("Expected MoveBlock::Promote");
        };
        assert_eq!(*parent_hash, expected_parent);
    }

    /// Asserts that `signal` destroys exactly one partial block.
    fn assert_destroy_partial(signal: &MoveBlock) {
        let MoveBlock::Destroy(blocks) = signal else {
            panic!("Expected MoveBlock::Destroy for partial block");
        };
        assert_eq!(blocks.len(), 1);
        assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
    }

    /// Asserts that `signal` dereferences exactly one full block.
    fn assert_deref_full(signal: &MoveBlock) {
        let MoveBlock::Deref(blocks) = signal else {
            panic!("Expected MoveBlock::Deref for full block");
        };
        assert_eq!(blocks.len(), 1);
        assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
    }

    #[test]
    fn test_new_with_signal_creates_initial_partial_block() {
        // 15 tokens do not fill a 16-token block, so only a partial block exists.
        let (seq, creation) = prompt_sequence(15, 100);

        assert_eq!(seq.num_input_tokens(), 15);
        assert_eq!(seq.len(), 15);
        assert_single_partial_use(creation.as_ref().expect("Expected initial Use signal"));
    }

    #[test]
    fn test_push_across_block_boundary_promotes_and_allocates_partial() {
        let (mut seq, _) = prompt_sequence(15, 100);

        // Token #16 merely completes the block.
        let completing = seq.push(15);
        assert!(
            completing.is_none(),
            "Completing a block should not trigger signals"
        );

        // Token #17 opens the next block: promote the old one, use a new partial.
        let crossing = seq.push(16).expect("Expected boundary crossing signals");
        assert_eq!(crossing.len(), 2);
        assert_promote_parent(&crossing[0], None);
        assert_single_partial_use(&crossing[1]);

        assert_eq!(
            seq.unique_blocks().len(),
            2,
            "sequence should have one full block and one partial block"
        );
        assert_eq!(
            seq.len() % seq.block_size(),
            1,
            "sequence should have one token in the new partial block"
        );
    }

    #[test]
    fn test_equivalent_histories_preserve_full_block_identity() {
        // History A: 15 prompt tokens, then two pushes.
        let (mut grown, _) = prompt_sequence(15, 100);
        for token in [15, 16] {
            grown.push(token);
        }

        // History B: 16 prompt tokens, then push / pop / push.
        let (mut prefilled, _) = prompt_sequence(16, 100);
        prefilled.push(16);
        prefilled.pop();
        prefilled.push(16);

        // Same full block, but independently created partial blocks.
        assert_eq!(grown.unique_blocks()[0], prefilled.unique_blocks()[0]);
        assert_ne!(grown.unique_blocks()[1], prefilled.unique_blocks()[1]);
    }

    #[test]
    fn test_promote_uses_previous_full_block_as_parent() {
        let (mut seq, _) = prompt_sequence(15, 100);
        for token in 15..18 {
            seq.push(token);
        }
        seq.pop();
        seq.pop();
        seq.push(16);

        let (mut seq_equiv, _) = prompt_sequence(16, 100);
        seq_equiv.push(16);
        seq_equiv.pop();
        seq_equiv.push(16);

        for token in 17..33 {
            seq.push(token);
            seq_equiv.push(token);
        }
        assert_eq!(
            &seq.unique_blocks()[0..2],
            &seq_equiv.unique_blocks()[0..2],
            "first two full blocks should remain identical"
        );

        for token in 33..48 {
            seq.push(token);
        }
        let promotion = seq
            .push(48)
            .expect("Expected promote when opening next partial");

        let expected_parent = match &seq.unique_blocks()[1] {
            UniqueBlock::FullBlock(hash) => *hash,
            UniqueBlock::PartialBlock(_) => panic!("unique_blocks[1] should be a full block"),
        };
        assert_promote_parent(&promotion[0], Some(expected_parent));
        assert_single_partial_use(&promotion[1]);
    }

    #[test]
    fn test_reset_with_signal_frees_blocks_and_resets_allocation() {
        let (mut seq, _) = prompt_sequence(15, 100);
        seq.push(15);
        seq.push(16);

        let total = seq.len();
        seq.commit_allocation(total);

        let released = seq.reset_with_signal();
        assert!(!released.is_empty());
        assert_eq!(seq.num_allocated_tokens(), 0);
        assert_eq!(seq.generated_tokens(), 2);
    }

    #[test]
    fn test_active_sequence_generate_signals() {
        let (mut seq, creation) = prompt_sequence(14, 5);
        assert_single_partial_use(creation.as_ref().expect("Expected initial Use signal"));

        // 15th token: result intentionally ignored.
        seq.generate();

        // 16th token completes the block without any signal.
        let completing = seq.generate();
        assert_eq!(completing.len(), 0);

        // 17th token crosses the boundary.
        let crossing = seq.generate();
        assert_eq!(crossing.len(), 2);
        assert_promote_parent(&crossing[0], None);
        assert_single_partial_use(&crossing[1]);

        // 18th token stays inside the new block.
        let inside = seq.generate();
        assert_eq!(inside.len(), 0);

        // 19th token hits max_output_tokens: everything is released, newest first.
        let closing = seq.generate();
        assert_eq!(closing.len(), 2);
        assert_destroy_partial(&closing[0]);
        assert_deref_full(&closing[1]);
    }

    #[test]
    fn test_prepare_allocation_slices_full_and_partial_blocks() {
        let prompt: Vec<u32> = (0..10).collect();
        let seq = ActiveSequence::new(prompt, 4, Some(4), true, false);

        // Nothing is committed, so every request starts from block 0.
        let up_to_4 = seq.prepare_allocation(4).unwrap();
        assert_use_signal(
            &up_to_4,
            &seq.unique_blocks()[0..1],
            &seq.block_hashes()[0..1],
        );

        let up_to_8 = seq.prepare_allocation(8).unwrap();
        assert_use_signal(
            &up_to_8,
            &seq.unique_blocks()[0..2],
            &seq.block_hashes()[0..2],
        );

        // The third block is partial: it is listed but has no hash.
        let up_to_10 = seq.prepare_allocation(10).unwrap();
        assert_use_signal(
            &up_to_10,
            &seq.unique_blocks()[0..3],
            &seq.block_hashes()[0..2],
        );
    }

    #[test]
    fn test_prepare_allocation_is_stable_until_commit() {
        let prompt: Vec<u32> = (0..10).collect();
        let mut seq = ActiveSequence::new(prompt, 4, Some(4), true, false);

        let before = seq.prepare_allocation(4).unwrap();
        let again = seq.prepare_allocation(4).unwrap();
        assert_eq!(before, again);

        seq.commit_allocation(4);
        let after_commit = seq.prepare_allocation(8).unwrap();
        assert_use_signal(
            &after_commit,
            &seq.unique_blocks()[1..2],
            &seq.block_hashes()[1..2],
        );
    }

    #[test]
    fn test_block_hash_cache_stays_in_sync_after_promote_and_pop() {
        let (mut seq, _) = prompt_sequence(15, 4);
        assert_cached_hashes_match_promoted_blocks(&seq);

        seq.push(15);
        assert_cached_hashes_match_promoted_blocks(&seq);

        let promotion = seq.push(16).unwrap();
        assert_eq!(promotion.len(), 2);
        assert_cached_hashes_match_promoted_blocks(&seq);

        seq.pop();
        assert_cached_hashes_match_promoted_blocks(&seq);
    }
}