oxillama-runtime 0.1.0

Inference engine — KV cache, sampling, tokenizer bridge
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
//! Continuous batching scheduler.
//!
//! Manages multiple in-flight inference requests, scheduling them into
//! batched forward passes for efficient GPU/CPU utilization.
//!
//! The scheduler maintains a pool of active sequences, each with its own
//! KV cache state, and decides which sequences to process in each iteration.
//!
//! ## Scheduling Algorithm
//!
//! 1. **Prefill priority**: Newly admitted sequences are scheduled first.
//!    Prompts longer than `max_prefill_tokens` are prefilled in chunks
//!    across successive batches.
//! 2. **Decode round-robin**: Active sequences in decode phase are
//!    processed in round-robin order.
//! 3. **Eviction** (not yet implemented): `SeqState::Preempted` is reserved
//!    for sequences whose KV cache is evicted under memory pressure; the
//!    scheduler does not preempt sequences yet.

use std::collections::HashMap;

/// Unique identifier for an inference sequence (request).
///
/// IDs are handed out by [`Scheduler::add_request`] from an incrementing
/// counter, so a smaller ID means an earlier arrival.
pub type SeqId = u64;

/// Lifecycle state of a sequence in the scheduler.
///
/// Derives `Hash` alongside `Eq` so states can be used as map or set keys
/// (e.g. when counting sequences per state).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SeqState {
    /// Queued in the waiting list, not yet admitted.
    Waiting,
    /// Admitted; prompt tokens are still being scheduled (possibly over
    /// several batches).
    Prefilling,
    /// Whole prompt scheduled; generating one token per batch.
    Decoding,
    /// Explicitly finished via [`Scheduler::finish_sequence`] (e.g. EOS or a
    /// user stop). Sequences that merely reach `max_tokens` keep their
    /// previous state.
    Finished,
    /// Reserved for preemption (KV cache evicted to make room, can be
    /// restarted). [`Scheduler`] does not currently set this state.
    Preempted,
}

/// A single inference sequence managed by the scheduler.
///
/// Holds the request's tokens plus the bookkeeping needed to resume chunked
/// prefill and to decide when generation must stop.
#[derive(Debug)]
pub struct Sequence {
    /// Unique sequence ID.
    pub id: SeqId,
    /// Current lifecycle state.
    pub state: SeqState,
    /// Prompt token IDs.
    pub prompt_tokens: Vec<u32>,
    /// Generated output token IDs, in generation order.
    pub output_tokens: Vec<u32>,
    /// Number of prompt tokens already scheduled for prefill (used to resume
    /// a chunked prefill where the previous batch left off).
    pub prompt_pos: usize,
    /// Maximum total tokens (prompt + output). The prompt counts against
    /// this budget.
    pub max_tokens: usize,
    /// Whether the sequence has been stopped (by user or EOS).
    pub stopped: bool,
    /// Priority (lower = higher priority). Default: arrival order.
    /// NOTE(review): not currently consulted by the scheduler.
    pub priority: u64,
}

impl Sequence {
    /// Create a new waiting sequence.
    pub fn new(id: SeqId, prompt_tokens: Vec<u32>, max_tokens: usize) -> Self {
        Self {
            id,
            state: SeqState::Waiting,
            prompt_tokens,
            output_tokens: Vec::new(),
            prompt_pos: 0,
            max_tokens,
            stopped: false,
            priority: id, // FIFO by default
        }
    }

    /// Total tokens in this sequence (prompt + generated).
    pub fn total_tokens(&self) -> usize {
        self.prompt_tokens.len() + self.output_tokens.len()
    }

    /// Whether this sequence has reached its generation limit.
    pub fn at_limit(&self) -> bool {
        self.total_tokens() >= self.max_tokens
    }

    /// All tokens in order (prompt + output).
    pub fn all_tokens(&self) -> Vec<u32> {
        let mut tokens = self.prompt_tokens.clone();
        tokens.extend_from_slice(&self.output_tokens);
        tokens
    }
}

/// Configuration for the batch scheduler.
///
/// See the `Default` impl for the default values.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Maximum number of concurrently admitted (prefilling or decoding)
    /// sequences; further requests stay in the waiting queue.
    pub max_sequences: usize,
    /// Maximum batch size for a single forward pass (number of tokens).
    pub max_batch_tokens: usize,
    /// Maximum number of sequences in a single batch (counts both prefill
    /// and decode entries).
    pub max_batch_sequences: usize,
    /// Maximum prompt tokens scheduled for one sequence in a single forward
    /// pass; longer prompts are prefilled in chunks.
    pub max_prefill_tokens: usize,
}

impl Default for SchedulerConfig {
    fn default() -> Self {
        Self {
            max_sequences: 32,
            max_batch_tokens: 512,
            max_batch_sequences: 8,
            max_prefill_tokens: 256,
        }
    }
}

/// A batch of work to be executed in a single forward pass.
///
/// The three vectors are parallel: entry `i` of each describes the same
/// sequence.
#[derive(Debug)]
pub struct ScheduledBatch {
    /// Sequence IDs in this batch.
    pub seq_ids: Vec<SeqId>,
    /// Token IDs to process for each sequence: a prompt chunk for prefill
    /// entries, a single token for decode entries.
    pub tokens: Vec<Vec<u32>>,
    /// Whether each sequence is in prefill (true) or decode (false).
    pub is_prefill: Vec<bool>,
}

impl ScheduledBatch {
    /// Sum of the per-sequence token counts, i.e. the forward-pass size.
    pub fn total_tokens(&self) -> usize {
        self.tokens.iter().map(Vec::len).sum()
    }

    /// `true` when no sequence was scheduled into this batch.
    pub fn is_empty(&self) -> bool {
        self.seq_ids.is_empty()
    }
}

/// Continuous batching scheduler.
///
/// Manages a pool of sequences and produces batches for the inference engine.
/// A sequence stays in `sequences` from `add_request` until it is removed with
/// `remove_sequence` or returned by `drain_finished`; the two ID lists below
/// only track which of those sequences are queued or admitted.
pub struct Scheduler {
    /// Scheduler configuration.
    config: SchedulerConfig,
    /// All known sequences (any state), indexed by ID.
    sequences: HashMap<SeqId, Sequence>,
    /// Next sequence ID to assign (starts at 1, incremented per request).
    next_id: SeqId,
    /// Waiting queue (sequence IDs in arrival order); consumed from the front.
    waiting_queue: Vec<SeqId>,
    /// Active sequences (prefilling or decoding), in the order they are
    /// considered for scheduling.
    active_ids: Vec<SeqId>,
}

impl Scheduler {
    /// Create a new scheduler with the given configuration.
    ///
    /// Sequence IDs are assigned sequentially starting from 1.
    pub fn new(config: SchedulerConfig) -> Self {
        Self {
            config,
            sequences: HashMap::new(),
            next_id: 1,
            waiting_queue: Vec::new(),
            active_ids: Vec::new(),
        }
    }

    /// Add a new inference request to the scheduler.
    ///
    /// Returns the assigned sequence ID. The sequence starts in `Waiting` state
    /// and will be promoted to `Prefilling` when capacity is available.
    /// `max_tokens` bounds prompt and output tokens combined.
    pub fn add_request(&mut self, prompt_tokens: Vec<u32>, max_tokens: usize) -> SeqId {
        let id = self.next_id;
        self.next_id += 1;

        let seq = Sequence::new(id, prompt_tokens, max_tokens);
        self.sequences.insert(id, seq);
        self.waiting_queue.push(id);
        id
    }

    /// Cancel a sequence and forget it entirely.
    ///
    /// The sequence will not be returned by [`Scheduler::drain_finished`].
    /// Unknown IDs are ignored.
    pub fn remove_sequence(&mut self, id: SeqId) {
        self.sequences.remove(&id);
        self.waiting_queue.retain(|&x| x != id);
        self.active_ids.retain(|&x| x != id);
    }

    /// Mark a sequence as finished.
    ///
    /// The sequence leaves both the active list and the waiting queue, so a
    /// request finished before it was ever scheduled is never promoted to
    /// prefill. It stays in the pool until [`Scheduler::drain_finished`].
    pub fn finish_sequence(&mut self, id: SeqId) {
        if let Some(seq) = self.sequences.get_mut(&id) {
            seq.state = SeqState::Finished;
            seq.stopped = true;
        }
        self.waiting_queue.retain(|&x| x != id);
        self.active_ids.retain(|&x| x != id);
    }

    /// Record a generated token for a sequence. Unknown IDs are ignored.
    ///
    /// The sequence's state is not changed here; reaching `max_tokens` is
    /// picked up by the next `schedule` / `drain_finished` call.
    pub fn append_token(&mut self, id: SeqId, token: u32) {
        if let Some(seq) = self.sequences.get_mut(&id) {
            seq.output_tokens.push(token);
        }
    }

    /// Get a reference to a sequence by ID.
    pub fn get_sequence(&self, id: SeqId) -> Option<&Sequence> {
        self.sequences.get(&id)
    }

    /// Number of active (non-finished, non-waiting) sequences.
    pub fn active_count(&self) -> usize {
        self.active_ids.len()
    }

    /// Number of waiting sequences.
    pub fn waiting_count(&self) -> usize {
        self.waiting_queue.len()
    }

    /// Total number of sequences (all states).
    pub fn total_count(&self) -> usize {
        self.sequences.len()
    }

    /// Whether there is work to do (waiting or active sequences).
    pub fn has_work(&self) -> bool {
        !self.waiting_queue.is_empty() || !self.active_ids.is_empty()
    }

    /// Schedule the next batch of work.
    ///
    /// Returns a batch describing which sequences to process and what tokens
    /// to feed them. Returns an empty batch if there's nothing to do.
    ///
    /// Every phase respects both `max_batch_sequences` and `max_batch_tokens`:
    ///
    /// 1. Waiting requests are admitted (while fewer than `max_sequences` are
    ///    active) and their first prefill chunk is scheduled.
    /// 2. Previously admitted sequences continue chunked prefill.
    /// 3. Decoding sequences get one token each. Sequences that decoded in this
    ///    call are moved behind the others in the active list, so decode slots
    ///    rotate round-robin when more sequences are decoding than fit.
    ///
    /// Prompt positions advance as soon as a chunk is scheduled, so every
    /// returned batch must be executed. Both batch limits must be non-zero for
    /// any work to be scheduled.
    pub fn schedule(&mut self) -> ScheduledBatch {
        let mut batch = ScheduledBatch {
            seq_ids: Vec::new(),
            tokens: Vec::new(),
            is_prefill: Vec::new(),
        };
        let max_batch_seqs = self.config.max_batch_sequences;
        let max_batch_tokens = self.config.max_batch_tokens;
        // Running total of tokens in `batch`, kept in sync by `push_entry`.
        let mut batch_tokens = 0usize;

        // Phase 1: admit waiting sequences and schedule their first prefill chunk.
        while !self.waiting_queue.is_empty()
            && self.active_ids.len() < self.config.max_sequences
            && batch.seq_ids.len() < max_batch_seqs
            && batch_tokens < max_batch_tokens
        {
            let id = self.waiting_queue.remove(0);
            let Some(seq) = self.sequences.get_mut(&id) else {
                continue;
            };
            if seq.stopped || seq.state == SeqState::Finished {
                // Finished while still queued: never admit it. It stays in the
                // pool so `drain_finished` can still return it.
                continue;
            }
            seq.state = SeqState::Prefilling;
            self.active_ids.push(id);

            // The loop condition guarantees `batch_tokens < max_batch_tokens`.
            let limit = self.config.max_prefill_tokens.min(max_batch_tokens - batch_tokens);
            let prefill_tokens = Self::take_prefill_chunk(seq, limit);
            batch_tokens += Self::push_entry(&mut batch, id, prefill_tokens, true);
        }

        // Phase 2: continue chunked prefill for sequences admitted earlier
        // (sequences already scheduled by Phase 1 are skipped).
        for &id in &self.active_ids {
            if batch.seq_ids.len() >= max_batch_seqs || batch_tokens >= max_batch_tokens {
                break;
            }
            if batch.seq_ids.contains(&id) {
                continue;
            }
            let Some(seq) = self.sequences.get_mut(&id) else {
                continue;
            };
            if seq.state != SeqState::Prefilling {
                continue;
            }
            let remaining = seq.prompt_tokens.len() - seq.prompt_pos;
            let limit = self.config.max_prefill_tokens.min(max_batch_tokens - batch_tokens);
            if remaining.min(limit) == 0 {
                continue;
            }
            let prefill_tokens = Self::take_prefill_chunk(seq, limit);
            batch_tokens += Self::push_entry(&mut batch, id, prefill_tokens, true);
        }

        // Phase 3: schedule one decode token per decoding sequence.
        let mut decoded: Vec<SeqId> = Vec::new();
        for &id in &self.active_ids {
            if batch.seq_ids.len() >= max_batch_seqs || batch_tokens >= max_batch_tokens {
                break;
            }
            let Some(seq) = self.sequences.get(&id) else {
                continue;
            };
            let eligible = seq.state == SeqState::Decoding
                && !seq.stopped
                && !seq.at_limit()
                // A sequence that finished prefill this step is already in the batch.
                && !batch.seq_ids.contains(&id);
            if !eligible {
                continue;
            }
            // Feed the last generated token, or the last prompt token if nothing
            // has been generated yet (0 for an empty prompt).
            let last_token = seq
                .output_tokens
                .last()
                .or_else(|| seq.prompt_tokens.last())
                .copied()
                .unwrap_or(0);
            batch_tokens += Self::push_entry(&mut batch, id, vec![last_token], false);
            decoded.push(id);
        }

        // Drop sequences that can no longer make progress from the active list.
        let sequences = &self.sequences;
        self.active_ids.retain(|id| {
            sequences
                .get(id)
                .is_some_and(|s| !s.stopped && !s.at_limit() && s.state != SeqState::Finished)
        });

        // Round-robin: move the sequences that just decoded behind those that
        // did not. `sort_by_key` is stable, so each group keeps its order.
        if !decoded.is_empty() {
            self.active_ids.sort_by_key(|id| decoded.contains(id));
        }

        batch
    }

    /// Take the next prefill chunk of at most `max_chunk` tokens from `seq`.
    ///
    /// Advances `prompt_pos` past the returned tokens and switches the sequence
    /// to `Decoding` once the whole prompt has been scheduled.
    fn take_prefill_chunk(seq: &mut Sequence, max_chunk: usize) -> Vec<u32> {
        let start = seq.prompt_pos;
        let end = start + (seq.prompt_tokens.len() - start).min(max_chunk);
        seq.prompt_pos = end;
        if end >= seq.prompt_tokens.len() {
            seq.state = SeqState::Decoding;
        }
        seq.prompt_tokens[start..end].to_vec()
    }

    /// Append one entry to `batch`, keeping its parallel vectors in lockstep.
    /// Returns the number of tokens added.
    fn push_entry(batch: &mut ScheduledBatch, id: SeqId, tokens: Vec<u32>, is_prefill: bool) -> usize {
        let added = tokens.len();
        batch.seq_ids.push(id);
        batch.tokens.push(tokens);
        batch.is_prefill.push(is_prefill);
        added
    }

    /// Get all finished sequences and remove them from the scheduler.
    ///
    /// A sequence counts as finished if it was marked via `finish_sequence`,
    /// was stopped, or has reached its `max_tokens` limit. The result is
    /// ordered by sequence ID (i.e. arrival order).
    pub fn drain_finished(&mut self) -> Vec<Sequence> {
        let mut finished_ids: Vec<SeqId> = self
            .sequences
            .iter()
            .filter(|(_, s)| s.state == SeqState::Finished || s.stopped || s.at_limit())
            .map(|(&id, _)| id)
            .collect();
        // HashMap iteration order is unspecified; sort for a deterministic
        // result and so the retains below can binary-search.
        finished_ids.sort_unstable();

        self.active_ids.retain(|id| finished_ids.binary_search(id).is_err());
        self.waiting_queue.retain(|id| finished_ids.binary_search(id).is_err());

        finished_ids
            .iter()
            .filter_map(|id| self.sequences.remove(id))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Scheduler built from the default configuration.
    fn default_scheduler() -> Scheduler {
        Scheduler::new(SchedulerConfig::default())
    }

    /// Scheduler whose default configuration has been adjusted by `tweak`.
    fn scheduler_with(tweak: impl FnOnce(&mut SchedulerConfig)) -> Scheduler {
        let mut config = SchedulerConfig::default();
        tweak(&mut config);
        Scheduler::new(config)
    }

    #[test]
    fn test_add_and_schedule_single() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1, 2, 3], 10);
        assert_eq!((sched.waiting_count(), sched.active_count()), (1, 0));

        let batch = sched.schedule();
        assert_eq!(batch.seq_ids, vec![seq]);
        assert_eq!(batch.tokens[0], vec![1, 2, 3]);
        assert!(batch.is_prefill[0]);

        // Once scheduled, the request moves from waiting to active.
        assert_eq!((sched.waiting_count(), sched.active_count()), (0, 1));
    }

    #[test]
    fn test_decode_after_prefill() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1, 2, 3], 10);

        let prefill = sched.schedule();
        assert!(prefill.is_prefill[0]);

        // Pretend the engine sampled token 4 from the prefill pass.
        sched.append_token(seq, 4);

        let decode = sched.schedule();
        assert_eq!(decode.seq_ids.len(), 1);
        assert!(!decode.is_prefill[0]);
        assert_eq!(decode.tokens[0], vec![4]);
    }

    #[test]
    fn test_finish_sequence() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1], 5);
        sched.schedule();

        sched.finish_sequence(seq);
        assert!(sched.schedule().is_empty());
        assert_eq!(sched.active_count(), 0);
    }

    #[test]
    fn test_multiple_sequences() {
        let mut sched = scheduler_with(|c| c.max_batch_sequences = 4);
        let ids: Vec<SeqId> = vec![vec![1, 2], vec![3, 4], vec![5, 6]]
            .into_iter()
            .map(|prompt| sched.add_request(prompt, 10))
            .collect();

        let batch = sched.schedule();
        assert_eq!(batch.seq_ids.len(), 3);
        assert!(ids.iter().all(|id| batch.seq_ids.contains(id)));
    }

    #[test]
    fn test_max_sequences_respected() {
        let mut sched = scheduler_with(|c| c.max_sequences = 2);
        for token in 1..=3 {
            sched.add_request(vec![token], 10);
        }

        // Only two fit; the third request stays queued.
        let batch = sched.schedule();
        assert_eq!(batch.seq_ids.len(), 2);
        assert_eq!(sched.waiting_count(), 1);
    }

    #[test]
    fn test_remove_sequence() {
        let mut sched = default_scheduler();
        let seq = sched.add_request(vec![1, 2, 3], 10);
        assert_eq!(sched.total_count(), 1);

        sched.remove_sequence(seq);
        assert_eq!((sched.total_count(), sched.waiting_count()), (0, 0));
    }

    #[test]
    fn test_at_limit_stops_scheduling() {
        let mut sched = default_scheduler();
        // Budget of 3 tokens in total: 1 prompt + 2 generated.
        let seq = sched.add_request(vec![1], 3);

        sched.schedule(); // prefill the single prompt token
        sched.append_token(seq, 2);
        sched.schedule(); // decode token 2
        sched.append_token(seq, 3);

        let batch = sched.schedule();
        let has_id = batch.seq_ids.contains(&seq);
        assert!(!has_id, "at-limit sequence should not be scheduled");
    }

    #[test]
    fn test_drain_finished() {
        let mut sched = default_scheduler();
        let first = sched.add_request(vec![1], 10);
        let second = sched.add_request(vec![2], 10);
        sched.schedule();

        sched.finish_sequence(first);

        let drained = sched.drain_finished();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].id, first);
        assert_eq!(sched.total_count(), 1);
        assert!(sched.get_sequence(second).is_some());
    }

    #[test]
    fn test_long_prefill_chunked() {
        let mut sched = scheduler_with(|c| c.max_prefill_tokens = 4);
        sched.add_request((1..=10).collect(), 20);

        let first = sched.schedule();
        assert_eq!(first.tokens[0], vec![1, 2, 3, 4]);
        assert!(first.is_prefill[0]);

        // The remaining six prompt tokens arrive as chunks of 4, then 2.
        for expected in [4, 2] {
            assert_eq!(sched.schedule().tokens[0].len(), expected);
        }
    }

    #[test]
    fn test_has_work() {
        let mut sched = default_scheduler();
        assert!(!sched.has_work());

        sched.add_request(vec![1], 5);
        assert!(sched.has_work());
    }
}