use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use super::batch::{ActiveBatch, FinishReason, SequenceState};
/// Settings that control how a batch is iterated token by token.
///
/// Start from [`IterationConfig::new`] (identical to `Default`) and refine the
/// value with the consuming `with_*` methods.
#[derive(Debug, Clone)]
pub struct IterationConfig {
    /// Per-step token limit; defaults to `1`.
    // NOTE(review): meaning taken from the field name — confirm against the consumer.
    pub max_tokens_per_step: usize,
    /// When `true`, a pending preemption request makes the next step return early.
    pub allow_preemption: bool,
    /// Number of tokens between checks; defaults to `1`.
    // NOTE(review): which check this refers to is not visible in this file — confirm.
    pub tokens_between_checks: usize,
    /// Metrics-collection switch; defaults to `true`.
    pub collect_metrics: bool,
    /// Token ids that end a sequence when produced; defaults to `[2]`.
    pub stop_tokens: Vec<u32>,
}

impl Default for IterationConfig {
    /// One token per step, preemption allowed, metrics on, stop token `2`.
    fn default() -> Self {
        IterationConfig {
            stop_tokens: vec![2],
            collect_metrics: true,
            allow_preemption: true,
            tokens_between_checks: 1,
            max_tokens_per_step: 1,
        }
    }
}

impl IterationConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces `max_tokens_per_step` with `n`, keeping every other field.
    pub fn with_max_tokens_per_step(self, n: usize) -> Self {
        Self {
            max_tokens_per_step: n,
            ..self
        }
    }

    /// Replaces `allow_preemption` with `allow`, keeping every other field.
    pub fn with_preemption(self, allow: bool) -> Self {
        Self {
            allow_preemption: allow,
            ..self
        }
    }

    /// Replaces the stop-token list with `tokens`, keeping every other field.
    pub fn with_stop_tokens(self, tokens: Vec<u32>) -> Self {
        Self {
            stop_tokens: tokens,
            ..self
        }
    }
}
/// Outcome of one iteration step, or the aggregate of several steps.
#[derive(Debug, Clone)]
pub struct IterationResult {
    /// Number of tokens appended to sequences.
    pub tokens_generated: usize,
    /// Number of sequences that reached a finish state.
    pub sequences_finished: usize,
    /// `true` when the batch reported every sequence as finished.
    pub batch_complete: bool,
    /// `true` when the step returned early because of a preemption request.
    pub preemption_requested: bool,
    /// Wall-clock time covered by this result.
    pub duration: Duration,
    /// Per-sequence outputs: `Some(token)` where a token was emitted, `None` otherwise.
    pub outputs: Vec<Option<u32>>,
}

impl IterationResult {
    /// A result describing no work at all: zero counters, both flags clear,
    /// zero duration and no outputs.
    pub fn empty() -> Self {
        IterationResult {
            outputs: Vec::new(),
            duration: Duration::ZERO,
            preemption_requested: false,
            batch_complete: false,
            sequences_finished: 0,
            tokens_generated: 0,
        }
    }

    /// An otherwise empty result flagged as "batch complete", taking `duration`.
    pub fn complete(duration: Duration) -> Self {
        IterationResult {
            batch_complete: true,
            duration,
            ..Self::empty()
        }
    }

    /// An otherwise empty result flagged as "preemption requested", taking `duration`.
    pub fn preempted(duration: Duration) -> Self {
        IterationResult {
            preemption_requested: true,
            duration,
            ..Self::empty()
        }
    }
}
/// Names of the phases that make up one iteration.
///
/// The `Display` form is the snake_case spelling of the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterationStep {
    /// Displayed as `prepare_inputs`.
    PrepareInputs,
    /// Displayed as `forward`.
    Forward,
    /// Displayed as `sample`.
    Sample,
    /// Displayed as `update_sequences`.
    UpdateSequences,
    /// Displayed as `check_completion`.
    CheckCompletion,
    /// Displayed as `process_outputs`.
    ProcessOutputs,
}

impl fmt::Display for IterationStep {
    /// Writes the snake_case label of the step.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            IterationStep::PrepareInputs => "prepare_inputs",
            IterationStep::Forward => "forward",
            IterationStep::Sample => "sample",
            IterationStep::UpdateSequences => "update_sequences",
            IterationStep::CheckCompletion => "check_completion",
            IterationStep::ProcessOutputs => "process_outputs",
        };
        f.write_str(label)
    }
}
/// Advances an [`ActiveBatch`] one step at a time, producing at most one token
/// per runnable sequence per step.
///
/// Token values come from a deterministic placeholder (`simulate_next_token`)
/// rather than from a model.
#[derive(Debug)]
pub struct TokenIterator {
    /// Settings consulted by [`TokenIterator::step`].
    config: IterationConfig,
    /// Number of `step` calls since construction or the last `reset`.
    iteration: u64,
    /// Tokens appended to sequences since construction or the last `reset`.
    tokens_generated: u64,
    /// Set by `request_preemption`; cleared only by `reset`.
    preemption_requested: bool,
    /// Reference point for `elapsed`.
    started_at: Instant,
    /// Cumulative counters; `reset` leaves them untouched.
    metrics: IterationMetrics,
}

impl TokenIterator {
    /// Creates an iterator with zeroed counters and the clock started now.
    pub fn new(config: IterationConfig) -> Self {
        Self {
            config,
            iteration: 0,
            tokens_generated: 0,
            preemption_requested: false,
            started_at: Instant::now(),
            metrics: IterationMetrics::new(),
        }
    }

    /// Number of `step` calls since construction or the last `reset`.
    ///
    /// Steps that return early (preempted or already complete) are counted too.
    pub fn iteration(&self) -> u64 {
        self.iteration
    }

    /// Tokens appended to sequences since construction or the last `reset`.
    pub fn tokens_generated(&self) -> u64 {
        self.tokens_generated
    }

    /// Time elapsed since construction or the last `reset`.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Asks the iterator to stop at the next step.
    ///
    /// The request is honoured only when `allow_preemption` is set in the
    /// configuration, and it stays pending until `reset` is called.
    pub fn request_preemption(&mut self) {
        self.preemption_requested = true;
    }

    /// Whether a preemption request is pending.
    pub fn is_preemption_requested(&self) -> bool {
        self.preemption_requested
    }

    /// Read access to the cumulative metrics.
    pub fn metrics(&self) -> &IterationMetrics {
        &self.metrics
    }

    /// Runs one iteration over every sequence in `batch`.
    ///
    /// Order of checks:
    /// 1. a pending, permitted preemption request returns
    ///    [`IterationResult::preempted`] without touching the batch;
    /// 2. an already finished batch returns [`IterationResult::complete`];
    /// 3. otherwise each `Running`/`Waiting` sequence either finishes (stop
    ///    token or maximum length) or has one token appended.
    ///
    /// `outputs` receives one entry per visited sequence: `Some(token)` for an
    /// appended token or a sampled stop token, `None` for sequences that were
    /// skipped or finished on length.
    ///
    /// Metrics are recorded only when `collect_metrics` is enabled in the
    /// configuration.
    pub fn step(&mut self, batch: &mut ActiveBatch) -> IterationResult {
        let step_start = Instant::now();
        self.iteration += 1;

        if self.config.allow_preemption && self.preemption_requested {
            if self.config.collect_metrics {
                self.metrics.record_preemption();
            }
            return IterationResult::preempted(step_start.elapsed());
        }
        if batch.all_finished() {
            return IterationResult::complete(step_start.elapsed());
        }

        let mut tokens_generated: usize = 0;
        let mut sequences_finished: usize = 0;
        let mut outputs = Vec::new();
        batch.for_each_group_mut(|group| {
            for seq in &mut group.sequences {
                if seq.state != SequenceState::Running && seq.state != SequenceState::Waiting {
                    outputs.push(None);
                    continue;
                }
                // A waiting sequence starts running the first time it is stepped.
                if seq.state == SequenceState::Waiting {
                    seq.state = SequenceState::Running;
                }

                let next_token = self.simulate_next_token(seq.tokens.last().copied());

                // The stop token is reported to the caller but never appended.
                if self.config.stop_tokens.contains(&next_token) {
                    seq.finish(FinishReason::Stop);
                    sequences_finished += 1;
                    outputs.push(Some(next_token));
                    continue;
                }
                // A full sequence finishes without emitting the sampled token.
                if seq.is_at_max_length() {
                    seq.finish(FinishReason::Length);
                    sequences_finished += 1;
                    outputs.push(None);
                    continue;
                }

                seq.append_token(next_token, Some(-0.5));
                tokens_generated += 1;
                outputs.push(Some(next_token));
            }
        });

        self.tokens_generated += tokens_generated as u64;
        batch.record_tokens(tokens_generated as u64);
        let duration = step_start.elapsed();
        if self.config.collect_metrics {
            self.metrics.record_step(duration, tokens_generated);
        }

        IterationResult {
            tokens_generated,
            sequences_finished,
            batch_complete: batch.all_finished(),
            preemption_requested: false,
            duration,
            outputs,
        }
    }

    /// Placeholder for real sampling: returns `(iteration % 1000) + 100`, i.e. a
    /// value in `100..=1099` that is identical for every sequence in a step.
    fn simulate_next_token(&self, _last_token: Option<u32>) -> u32 {
        ((self.iteration % 1000) + 100) as u32
    }

    /// Shared driver behind the `run_*` methods.
    ///
    /// Calls `step` for as long as `keep_going(total_tokens, elapsed)` returns
    /// `true`, and stops early when the batch completes, a preemption is
    /// reported, or a step makes no progress. Counters are summed over all
    /// steps; `outputs` and the two flags come from the last step executed.
    fn run_while<F>(&mut self, batch: &mut ActiveBatch, mut keep_going: F) -> IterationResult
    where
        F: FnMut(usize, Duration) -> bool,
    {
        let start = Instant::now();
        let mut total_tokens: usize = 0;
        let mut total_finished: usize = 0;
        let mut last_outputs = Vec::new();
        let mut batch_complete = false;
        let mut preemption_requested = false;

        while keep_going(total_tokens, start.elapsed()) {
            let result = self.step(batch);
            total_tokens += result.tokens_generated;
            total_finished += result.sequences_finished;
            batch_complete = result.batch_complete;
            preemption_requested = result.preemption_requested;
            // Every runnable sequence either finishes or gains a token in a
            // step, so a step with neither found nothing runnable and another
            // step would repeat it. Bail out instead of spinning.
            let stalled = result.tokens_generated == 0 && result.sequences_finished == 0;
            last_outputs = result.outputs;

            if batch_complete || preemption_requested || stalled {
                break;
            }
        }

        IterationResult {
            tokens_generated: total_tokens,
            sequences_finished: total_finished,
            batch_complete,
            preemption_requested,
            duration: start.elapsed(),
            outputs: last_outputs,
        }
    }

    /// Steps until the batch completes or a preemption is reported.
    ///
    /// Also returns, with `batch_complete == false`, if a step makes no
    /// progress, which would otherwise loop forever.
    pub fn run_to_completion(&mut self, batch: &mut ActiveBatch) -> IterationResult {
        self.run_while(batch, |_, _| true)
    }

    /// Steps until at least `max_tokens` tokens were generated, the batch
    /// completes, a preemption is reported, or a step makes no progress.
    ///
    /// The budget is checked between steps, so the total may exceed
    /// `max_tokens` by what the final step produced.
    pub fn run_for_tokens(
        &mut self,
        batch: &mut ActiveBatch,
        max_tokens: usize,
    ) -> IterationResult {
        self.run_while(batch, |generated, _| generated < max_tokens)
    }

    /// Steps until `max_duration` has elapsed, the batch completes, a
    /// preemption is reported, or a step makes no progress.
    ///
    /// The deadline is checked between steps; a running step is not interrupted.
    pub fn run_for_duration(
        &mut self,
        batch: &mut ActiveBatch,
        max_duration: Duration,
    ) -> IterationResult {
        self.run_while(batch, |_, elapsed| elapsed < max_duration)
    }

    /// Zeroes the iteration and token counters, clears a pending preemption
    /// request and restarts the clock. Metrics are left as they are.
    pub fn reset(&mut self) {
        self.iteration = 0;
        self.tokens_generated = 0;
        self.preemption_requested = false;
        self.started_at = Instant::now();
    }
}
/// Cumulative counters describing iteration activity.
///
/// All fields are atomics, so recording takes `&self`. Each counter is updated
/// independently with `Ordering::Relaxed`; a reader may therefore observe
/// counters from different points in time.
#[derive(Debug)]
pub struct IterationMetrics {
    /// Number of recorded steps.
    iterations: AtomicU64,
    /// Sum of tokens over all recorded steps.
    tokens_generated: AtomicU64,
    /// Sum of step durations, in nanoseconds.
    total_time_ns: AtomicU64,
    /// Number of recorded preemptions.
    preemptions: AtomicU64,
    /// Incremented together with `iterations`; no accessor reads it in this file.
    total_steps: AtomicU64,
}

impl IterationMetrics {
    /// Creates a metrics set with every counter at zero.
    pub fn new() -> Self {
        Self {
            iterations: AtomicU64::new(0),
            tokens_generated: AtomicU64::new(0),
            total_time_ns: AtomicU64::new(0),
            preemptions: AtomicU64::new(0),
            total_steps: AtomicU64::new(0),
        }
    }

    /// Records one step that took `duration` and produced `tokens` tokens.
    ///
    /// A duration whose nanosecond count does not fit in `u64` is recorded as
    /// `u64::MAX` nanoseconds instead of being truncated to its low bits.
    pub fn record_step(&self, duration: Duration, tokens: usize) {
        // Clamp first so the narrowing cast below cannot lose information.
        let nanos = duration.as_nanos().min(u128::from(u64::MAX)) as u64;

        self.iterations.fetch_add(1, Ordering::Relaxed);
        self.tokens_generated
            .fetch_add(tokens as u64, Ordering::Relaxed);
        self.total_time_ns.fetch_add(nanos, Ordering::Relaxed);
        self.total_steps.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one preemption.
    pub fn record_preemption(&self) {
        self.preemptions.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of recorded steps.
    pub fn iterations(&self) -> u64 {
        self.iterations.load(Ordering::Relaxed)
    }

    /// Total tokens over all recorded steps.
    pub fn tokens_generated(&self) -> u64 {
        self.tokens_generated.load(Ordering::Relaxed)
    }

    /// Sum of all recorded step durations.
    pub fn total_time(&self) -> Duration {
        Duration::from_nanos(self.total_time_ns.load(Ordering::Relaxed))
    }

    /// Number of recorded preemptions.
    pub fn preemptions(&self) -> u64 {
        self.preemptions.load(Ordering::Relaxed)
    }

    /// Average throughput over all recorded steps, or `0.0` when no time has
    /// been recorded.
    pub fn tokens_per_second(&self) -> f64 {
        let tokens = self.tokens_generated() as f64;
        let secs = self.total_time().as_secs_f64();
        if secs > 0.0 {
            tokens / secs
        } else {
            0.0
        }
    }

    /// Average recorded time per token, rounded down to whole nanoseconds, or
    /// `Duration::ZERO` when no tokens have been recorded.
    pub fn avg_time_per_token(&self) -> Duration {
        // Divide in u64 nanoseconds: narrowing the token count to `u32` would
        // give a wrong divisor above `u32::MAX` and a zero divisor (panic) at
        // every multiple of 2^32.
        match self.tokens_generated() {
            0 => Duration::ZERO,
            tokens => Duration::from_nanos(self.total_time_ns.load(Ordering::Relaxed) / tokens),
        }
    }

    /// Renders the counters in the Prometheus text exposition format.
    ///
    /// Emits three counters (`infernum_batch_iterations_total`,
    /// `infernum_batch_tokens_generated_total`, `infernum_batch_preemptions_total`)
    /// and one gauge (`infernum_batch_tokens_per_second`, two decimals), each
    /// preceded by its `# HELP` and `# TYPE` lines.
    pub fn prometheus(&self) -> String {
        // (metric name, help text, metric type, rendered value)
        let samples = [
            (
                "infernum_batch_iterations_total",
                "Total batch iterations",
                "counter",
                self.iterations().to_string(),
            ),
            (
                "infernum_batch_tokens_generated_total",
                "Total tokens generated",
                "counter",
                self.tokens_generated().to_string(),
            ),
            (
                "infernum_batch_preemptions_total",
                "Total batch preemptions",
                "counter",
                self.preemptions().to_string(),
            ),
            (
                "infernum_batch_tokens_per_second",
                "Current tokens per second",
                "gauge",
                format!("{:.2}", self.tokens_per_second()),
            ),
        ];

        let mut output = String::new();
        for (name, help, kind, value) in &samples {
            output.push_str(&format!(
                "# HELP {0} {1}\n# TYPE {0} {2}\n{0} {3}\n",
                name, help, kind, value
            ));
        }
        output
    }
}

impl Default for IterationMetrics {
    /// Same as [`IterationMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::super::batch::{BatchId, SamplingParams, Sequence, SequenceGroup, SequenceId};
    use super::*;

    /// Builds a batch holding a single request ("req-1") with a three-token prompt.
    fn create_test_batch() -> ActiveBatch {
        let sequence = Sequence::new(SequenceId::new(1), vec![1, 2, 3], 10);
        let request = SequenceGroup::new("req-1", sequence, SamplingParams::default());
        let batch = ActiveBatch::new(BatchId::new(1), 32, 4096);
        let _ = batch.try_add(request);
        batch
    }

    /// Iterator with default settings.
    fn default_iterator() -> TokenIterator {
        TokenIterator::new(IterationConfig::default())
    }

    /// Iterator whose stop-token list is empty, so sequences end on length only.
    fn iterator_without_stop_tokens() -> TokenIterator {
        TokenIterator::new(IterationConfig::new().with_stop_tokens(vec![]))
    }

    #[test]
    fn test_iteration_config_default() {
        let defaults = IterationConfig::default();
        assert_eq!(defaults.max_tokens_per_step, 1);
        assert!(defaults.allow_preemption);
        assert!(defaults.collect_metrics);
    }

    #[test]
    fn test_iteration_config_builder() {
        let built = IterationConfig::new()
            .with_max_tokens_per_step(4)
            .with_preemption(false)
            .with_stop_tokens(vec![0, 2]);
        assert_eq!(built.max_tokens_per_step, 4);
        assert!(!built.allow_preemption);
        assert_eq!(built.stop_tokens, vec![0, 2]);
    }

    #[test]
    fn test_iteration_result_empty() {
        let outcome = IterationResult::empty();
        assert_eq!(outcome.tokens_generated, 0);
        assert_eq!(outcome.sequences_finished, 0);
        assert!(!outcome.batch_complete);
        assert!(!outcome.preemption_requested);
    }

    #[test]
    fn test_iteration_result_complete() {
        let outcome = IterationResult::complete(Duration::from_millis(100));
        assert!(outcome.batch_complete);
        assert_eq!(outcome.duration.as_millis(), 100);
    }

    #[test]
    fn test_iteration_result_preempted() {
        let outcome = IterationResult::preempted(Duration::from_millis(50));
        assert!(outcome.preemption_requested);
        assert!(!outcome.batch_complete);
    }

    #[test]
    fn test_iteration_step_display() {
        assert_eq!(IterationStep::PrepareInputs.to_string(), "prepare_inputs");
        assert_eq!(IterationStep::Forward.to_string(), "forward");
        assert_eq!(IterationStep::Sample.to_string(), "sample");
    }

    #[test]
    fn test_token_iterator_new() {
        let fresh = default_iterator();
        assert_eq!(fresh.iteration(), 0);
        assert_eq!(fresh.tokens_generated(), 0);
        assert!(!fresh.is_preemption_requested());
    }

    #[test]
    fn test_token_iterator_step() {
        let mut driver = default_iterator();
        let mut batch = create_test_batch();
        let outcome = driver.step(&mut batch);
        assert_eq!(driver.iteration(), 1);
        assert!(outcome.tokens_generated > 0 || outcome.sequences_finished > 0);
    }

    #[test]
    fn test_token_iterator_preemption() {
        let mut driver = default_iterator();
        let mut batch = create_test_batch();
        driver.request_preemption();
        assert!(driver.is_preemption_requested());
        let outcome = driver.step(&mut batch);
        assert!(outcome.preemption_requested);
    }

    #[test]
    fn test_token_iterator_run_for_tokens() {
        let mut driver = iterator_without_stop_tokens();
        let mut batch = create_test_batch();
        let outcome = driver.run_for_tokens(&mut batch, 5);
        assert!(outcome.tokens_generated >= 5 || outcome.batch_complete);
    }

    #[test]
    fn test_token_iterator_run_for_duration() {
        let mut driver = iterator_without_stop_tokens();
        let mut batch = create_test_batch();
        let outcome = driver.run_for_duration(&mut batch, Duration::from_millis(10));
        assert!(outcome.tokens_generated > 0 || outcome.batch_complete);
    }

    #[test]
    fn test_token_iterator_reset() {
        let mut driver = default_iterator();
        let mut batch = create_test_batch();
        let _ = driver.step(&mut batch);
        assert!(driver.iteration() > 0);
        driver.reset();
        assert_eq!(driver.iteration(), 0);
        assert_eq!(driver.tokens_generated(), 0);
    }

    #[test]
    fn test_iteration_metrics_new() {
        let counters = IterationMetrics::new();
        assert_eq!(counters.iterations(), 0);
        assert_eq!(counters.tokens_generated(), 0);
        assert_eq!(counters.preemptions(), 0);
    }

    #[test]
    fn test_iteration_metrics_record() {
        let counters = IterationMetrics::new();
        for &(millis, tokens) in &[(10u64, 5usize), (15, 3)] {
            counters.record_step(Duration::from_millis(millis), tokens);
        }
        assert_eq!(counters.iterations(), 2);
        assert_eq!(counters.tokens_generated(), 8);
        assert_eq!(counters.total_time().as_millis(), 25);
    }

    #[test]
    fn test_iteration_metrics_preemption() {
        let counters = IterationMetrics::new();
        for _ in 0..2 {
            counters.record_preemption();
        }
        assert_eq!(counters.preemptions(), 2);
    }

    #[test]
    fn test_iteration_metrics_tokens_per_second() {
        let counters = IterationMetrics::new();
        counters.record_step(Duration::from_secs(1), 100);
        let rate = counters.tokens_per_second();
        assert!(rate > 90.0 && rate < 110.0);
    }

    #[test]
    fn test_iteration_metrics_prometheus() {
        let counters = IterationMetrics::new();
        counters.record_step(Duration::from_millis(10), 5);
        let exported = counters.prometheus();
        let expected_lines = [
            "infernum_batch_iterations_total 1",
            "infernum_batch_tokens_generated_total 5",
            "infernum_batch_preemptions_total 0",
        ];
        for line in &expected_lines {
            assert!(exported.contains(line));
        }
    }

    #[test]
    fn test_batch_completion() {
        // A length limit of 2 lets the run finish on length with no stop tokens.
        let mut batch = ActiveBatch::new(BatchId::new(1), 32, 4096);
        let sequence = Sequence::new(SequenceId::new(1), vec![1, 2, 3], 2);
        let request = SequenceGroup::new("req-1", sequence, SamplingParams::default());
        let _ = batch.try_add(request);

        let mut driver = iterator_without_stop_tokens();
        let outcome = driver.run_to_completion(&mut batch);
        assert!(outcome.batch_complete);
    }
}