use boostr::inference::memory::CpuBlockAllocator;
use boostr::inference::scheduler::{
ScheduledBatch, SchedulerConfig, SchedulingPriority, SequenceRequest, SequenceScheduler,
SequenceState,
};
use std::collections::HashMap;
fn create_scheduler() -> SequenceScheduler<CpuBlockAllocator> {
let allocator = CpuBlockAllocator::new(100, 16);
let config = SchedulerConfig {
max_batch_size: 4,
max_batch_tokens: 1024,
max_seq_len: 512,
block_size: 16,
enable_preemption: true,
max_preempt_per_step: 2,
};
SequenceScheduler::new(allocator, config)
}
#[test]
fn test_scheduler_create() {
let scheduler = create_scheduler();
assert!(!scheduler.has_work());
assert_eq!(scheduler.stats().total_requests, 0);
}
#[test]
fn test_scheduler_add_request() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3, 4, 5]);
scheduler.add_request(request).unwrap();
assert!(scheduler.has_work());
assert_eq!(scheduler.stats().total_requests, 1);
assert_eq!(scheduler.stats().waiting_count, 1);
}
#[test]
fn test_scheduler_add_duplicate_request() {
let mut scheduler = create_scheduler();
let request1 = SequenceRequest::new(1, vec![1, 2, 3]);
scheduler.add_request(request1).unwrap();
let request2 = SequenceRequest::new(1, vec![4, 5, 6]);
let result = scheduler.add_request(request2);
assert!(result.is_err());
}
#[test]
fn test_scheduler_schedule_prefill() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3, 4, 5]);
scheduler.add_request(request).unwrap();
let batch = scheduler.schedule().unwrap().unwrap();
assert_eq!(batch.prefill_sequences, vec![1]);
assert!(batch.decode_sequences.is_empty());
assert!(batch.block_tables.contains_key(&1));
assert_eq!(scheduler.stats().waiting_count, 0);
assert_eq!(scheduler.stats().running_count, 1);
}
#[test]
fn test_scheduler_schedule_decode() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3, 4, 5]);
scheduler.add_request(request).unwrap();
scheduler.schedule().unwrap();
scheduler.prefill_complete(1).unwrap();
let batch = scheduler.schedule().unwrap().unwrap();
assert!(batch.prefill_sequences.is_empty());
assert_eq!(batch.decode_sequences, vec![1]);
}
#[test]
fn test_scheduler_append_token() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3]).with_max_tokens(5);
scheduler.add_request(request).unwrap();
scheduler.schedule().unwrap();
for i in 0..5 {
let finished = scheduler.append_token(1, 100 + i).unwrap();
if i < 4 {
assert!(!finished);
} else {
assert!(finished);
}
}
assert_eq!(scheduler.get_generated_tokens(1).unwrap().len(), 5);
assert_eq!(
scheduler.get_sequence_state(1),
Some(SequenceState::Finished)
);
}
#[test]
fn test_scheduler_finish_sequence() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3]);
scheduler.add_request(request).unwrap();
scheduler.schedule().unwrap();
scheduler.finish_sequence(1).unwrap();
assert_eq!(
scheduler.get_sequence_state(1),
Some(SequenceState::Finished)
);
assert_eq!(scheduler.stats().finished_count, 1);
assert_eq!(scheduler.stats().running_count, 0);
}
#[test]
fn test_scheduler_abort_sequence() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3]);
scheduler.add_request(request).unwrap();
scheduler.abort_sequence(1).unwrap();
assert_eq!(
scheduler.get_sequence_state(1),
Some(SequenceState::Finished)
);
assert_eq!(scheduler.stats().waiting_count, 0);
}
#[test]
fn test_scheduler_multiple_sequences() {
let mut scheduler = create_scheduler();
for i in 1..=3 {
let request = SequenceRequest::new(i, vec![1, 2, 3]);
scheduler.add_request(request).unwrap();
}
let batch = scheduler.schedule().unwrap().unwrap();
assert_eq!(batch.prefill_sequences.len(), 3);
assert_eq!(scheduler.stats().running_count, 3);
}
#[test]
fn test_scheduler_batch_size_limit() {
let mut scheduler = create_scheduler();
for i in 1..=10 {
let request = SequenceRequest::new(i, vec![1, 2, 3]);
scheduler.add_request(request).unwrap();
}
let batch = scheduler.schedule().unwrap().unwrap();
assert_eq!(batch.prefill_sequences.len(), 4);
assert_eq!(scheduler.stats().waiting_count, 6);
}
#[test]
fn test_scheduler_block_allocation() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![0; 17]);
scheduler.add_request(request).unwrap();
scheduler.schedule().unwrap();
let block_table = scheduler.get_block_table(1).unwrap();
assert_eq!(block_table.num_blocks(), 2);
}
#[test]
fn test_scheduler_no_work() {
let mut scheduler = create_scheduler();
assert!(scheduler.schedule().unwrap().is_none());
}
#[test]
fn test_scheduler_cleanup_finished() {
let mut scheduler = create_scheduler();
let request = SequenceRequest::new(1, vec![1, 2, 3]);
scheduler.add_request(request).unwrap();
scheduler.schedule().unwrap();
scheduler.finish_sequence(1).unwrap();
assert!(scheduler.has_sequence(1));
scheduler.cleanup_finished();
assert!(!scheduler.has_sequence(1));
}
#[test]
fn test_scheduler_stats() {
let mut scheduler = create_scheduler();
for i in 1..=3 {
let request = SequenceRequest::new(i, vec![1, 2, 3]).with_max_tokens(2);
scheduler.add_request(request).unwrap();
}
scheduler.schedule().unwrap();
for i in 1..=3 {
scheduler.append_token(i, 100).unwrap();
scheduler.append_token(i, 101).unwrap();
}
let stats = scheduler.stats();
assert_eq!(stats.total_requests, 3);
assert_eq!(stats.total_tokens_generated, 6);
assert_eq!(stats.finished_count, 3);
}
#[test]
fn test_sequence_request_builder() {
let request = SequenceRequest::new(1, vec![1, 2, 3])
.with_max_tokens(100)
.with_priority(SchedulingPriority::High);
assert_eq!(request.max_new_tokens, 100);
assert_eq!(request.priority, SchedulingPriority::High);
}
#[test]
fn test_append_token_returns_error_when_blocks_exhausted() {
let allocator = CpuBlockAllocator::new(1, 2);
let config = SchedulerConfig {
max_batch_size: 4,
max_batch_tokens: 1024,
max_seq_len: 512,
block_size: 2,
enable_preemption: false,
max_preempt_per_step: 0,
};
let mut scheduler = SequenceScheduler::new(allocator, config);
let req = SequenceRequest::new(1, vec![10]);
scheduler.add_request(req).unwrap();
scheduler.schedule().unwrap();
let done = scheduler.append_token(1, 20).unwrap();
assert!(!done);
let result = scheduler.append_token(1, 30);
assert!(result.is_err(), "must fail when no block available");
let bt = scheduler.get_block_table(1).unwrap();
assert!(
bt.num_tokens <= 2,
"num_tokens must not advance beyond available blocks: got {}",
bt.num_tokens
);
}
#[test]
fn test_abort_sequence_is_idempotent() {
let mut scheduler = create_scheduler();
let req = SequenceRequest::new(42, vec![1, 2, 3, 4]);
scheduler.add_request(req).unwrap();
scheduler.schedule().unwrap();
scheduler.abort_sequence(42).unwrap();
let stats_after_first = scheduler.stats();
scheduler.abort_sequence(42).unwrap();
let stats_after_second = scheduler.stats();
assert_eq!(
stats_after_first.finished_count, stats_after_second.finished_count,
"finished_count must not change on second abort"
);
assert_eq!(
stats_after_first.running_count, stats_after_second.running_count,
"running_count must not underflow on second abort"
);
}
#[test]
fn test_scheduled_batch_methods() {
let batch = ScheduledBatch {
prefill_sequences: vec![1, 2],
decode_sequences: vec![3, 4, 5],
block_tables: HashMap::new(),
preempted_sequences: vec![],
cached_token_counts: HashMap::new(),
};
assert!(!batch.is_empty());
assert_eq!(batch.len(), 5);
assert_eq!(batch.all_sequences(), vec![1, 2, 3, 4, 5]);
}
#[test]
fn test_scheduler_with_prefix_cache() {
use boostr::inference::memory::CpuBlockAllocator;
use boostr::inference::prefix_cache::{PrefixCache, PrefixCacheConfig};
let block_size = 16;
let allocator = CpuBlockAllocator::new(64, block_size);
let config = SchedulerConfig {
max_batch_size: 4,
max_batch_tokens: 1024,
max_seq_len: 512,
block_size,
..Default::default()
};
let cache_config = PrefixCacheConfig {
enabled: true,
block_size,
max_cached_blocks: 32,
min_prefix_tokens: 16,
};
let cache = PrefixCache::new(allocator.clone(), cache_config);
let mut scheduler = SequenceScheduler::new(allocator, config).with_prefix_cache(cache);
let prompt: Vec<u32> = (0..64).collect();
let req1 = SequenceRequest::new(1, prompt.clone()).with_max_tokens(10);
scheduler.add_request(req1).unwrap();
let batch1 = scheduler.schedule().unwrap().unwrap();
assert!(batch1.prefill_sequences.contains(&1));
let cached1 = batch1.cached_token_counts.get(&1).copied().unwrap_or(0);
assert_eq!(cached1, 0, "First request should have no cache hits");
scheduler.prefill_complete(1).unwrap();
for i in 0..10 {
let done = scheduler.append_token(1, 1000 + i).unwrap();
if done {
break; }
}
scheduler.cleanup_finished();
let req2 = SequenceRequest::new(2, prompt.clone()).with_max_tokens(10);
scheduler.add_request(req2).unwrap();
let batch2 = scheduler.schedule().unwrap().unwrap();
assert!(batch2.prefill_sequences.contains(&2));
let cached2 = batch2.cached_token_counts.get(&2).copied().unwrap_or(0);
assert!(
cached2 > 0,
"Second request with same prefix should have cache hits, got {}",
cached2
);
}