/// Accumulates token ids until a configured capacity is reached, at which
/// point the whole batch is handed back to the caller for processing.
#[derive(Debug)]
pub struct TokenBatch {
    tokens: Vec<usize>,
    capacity: usize,
}

impl TokenBatch {
    /// Creates an empty batch that flushes once `capacity` tokens are queued.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            tokens: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Configured flush threshold.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of tokens currently buffered.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// `true` when no tokens are buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }

    /// `true` once the buffered count has reached (or exceeded) capacity.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.tokens.len() >= self.capacity
    }

    /// Buffers `token`. Returns the drained batch when the push fills the
    /// buffer, `None` otherwise.
    pub fn push(&mut self, token: usize) -> Option<Vec<usize>> {
        self.tokens.push(token);
        self.is_full().then(|| self.flush())
    }

    /// Drains and returns all buffered tokens, leaving the batch empty
    /// (the internal `Vec` is replaced via `mem::take`).
    pub fn flush(&mut self) -> Vec<usize> {
        std::mem::take(&mut self.tokens)
    }
}
/// A draft token paired with the confidence score it was proposed with.
#[derive(Debug, Clone)]
struct SpeculativeCandidate {
    token: usize,
    #[allow(dead_code)]
    confidence: f32,
}

/// Fixed-capacity buffer of speculative (draft) tokens awaiting verification
/// against the actually generated token sequence.
#[derive(Debug)]
pub struct SpeculativeBuffer {
    candidates: Vec<SpeculativeCandidate>,
    capacity: usize,
}

impl SpeculativeBuffer {
    /// Creates an empty buffer holding at most `capacity` candidates.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            candidates: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Maximum number of candidates the buffer will hold.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of candidates currently buffered.
    #[must_use]
    pub fn len(&self) -> usize {
        self.candidates.len()
    }

    /// `true` when no candidates are buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.candidates.is_empty()
    }

    /// Buffers a candidate token. Candidates offered while the buffer is
    /// already at capacity are silently dropped.
    pub fn add_candidate(&mut self, token: usize, confidence: f32) {
        if self.candidates.len() >= self.capacity {
            return;
        }
        self.candidates
            .push(SpeculativeCandidate { token, confidence });
    }

    /// Compares buffered candidates prefix-wise against `actual_tokens`.
    ///
    /// Returns `(accepted, mismatch)`: `accepted` is the number of leading
    /// candidates that match, and `mismatch` is `Some(index)` of the first
    /// candidate that disagreed (or ran past the end of `actual_tokens`),
    /// `None` when every candidate matched.
    #[must_use]
    pub fn verify(&self, actual_tokens: &[usize]) -> (usize, Option<usize>) {
        // `zip` truncates at the shorter side, so candidates beyond
        // `actual_tokens.len()` count as a mismatch at that index.
        let accepted = self
            .candidates
            .iter()
            .zip(actual_tokens)
            .take_while(|(candidate, &actual)| candidate.token == actual)
            .count();
        if accepted == self.candidates.len() {
            (accepted, None)
        } else {
            (accepted, Some(accepted))
        }
    }

    /// Removes the first `n` candidates; removes everything when `n`
    /// exceeds the current count.
    pub fn accept(&mut self, n: usize) {
        let cut = n.min(self.candidates.len());
        self.candidates.drain(..cut);
    }

    /// Discards every buffered candidate.
    pub fn reject(&mut self) {
        self.candidates.clear();
    }
}
/// Identifier assigned to each submitted batch, monotonically increasing.
pub type BatchId = u64;

/// Tracks in-flight inference batches and queues finished results in
/// completion order for consumers to poll.
#[derive(Debug)]
pub struct InferenceBatchScheduler {
    next_id: BatchId,
    pending: std::collections::HashMap<BatchId, Vec<usize>>,
    completed: std::collections::VecDeque<(BatchId, Vec<usize>)>,
}

impl InferenceBatchScheduler {
    /// Creates a scheduler with no pending or completed batches.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_id: 0,
            pending: std::collections::HashMap::new(),
            completed: std::collections::VecDeque::new(),
        }
    }

    /// Number of batches submitted but not yet completed.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Number of completed batches awaiting `poll`/`drain`.
    #[must_use]
    pub fn completed_count(&self) -> usize {
        self.completed.len()
    }

    /// Registers `tokens` as a new pending batch and returns its id.
    pub fn submit(&mut self, tokens: Vec<usize>) -> BatchId {
        let id = self.next_id;
        self.next_id += 1;
        self.pending.insert(id, tokens);
        id
    }

    /// Marks `batch_id` finished and queues its results.
    ///
    /// NOTE(review): ids that were never submitted (or already completed)
    /// are still queued as completed — confirm this is intended.
    pub fn complete(&mut self, batch_id: BatchId, results: Vec<usize>) {
        self.pending.remove(&batch_id);
        self.completed.push_back((batch_id, results));
    }

    /// Pops the oldest completed batch, if any.
    pub fn poll(&mut self) -> Option<(BatchId, Vec<usize>)> {
        self.completed.pop_front()
    }

    /// Removes and returns every completed batch in completion order.
    pub fn drain(&mut self) -> Vec<(BatchId, Vec<usize>)> {
        std::mem::take(&mut self.completed).into_iter().collect()
    }
}

impl Default for InferenceBatchScheduler {
    fn default() -> Self {
        Self::new()
    }
}
/// Bounded FIFO queue with non-blocking push and pop.
#[derive(Debug)]
pub struct AsyncRequestQueue<T> {
    items: std::collections::VecDeque<T>,
    capacity: usize,
}

impl<T> AsyncRequestQueue<T> {
    /// Creates an empty queue bounded at `capacity` items.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            items: std::collections::VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Configured maximum number of queued items.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of items currently queued.
    #[must_use]
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// `true` when nothing is queued.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// `true` once the queue has reached (or exceeded) capacity.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.items.len() >= self.capacity
    }

    /// Enqueues `item` at the back unless the queue is full; returns
    /// whether the item was accepted.
    pub fn try_push(&mut self, item: T) -> bool {
        if self.is_full() {
            return false;
        }
        self.items.push_back(item);
        true
    }

    /// Dequeues the oldest item, or `None` when the queue is empty.
    pub fn try_pop(&mut self) -> Option<T> {
        self.items.pop_front()
    }
}
/// Callback invoked with a request id and its generated tokens.
pub type InferenceCompletionHandler = Box<dyn Fn(u64, &[usize]) + Send + Sync>;

/// Fan-out registry of completion callbacks; `notify` invokes every
/// registered handler in registration order.
pub struct InferenceEventNotifier {
    handlers: Vec<InferenceCompletionHandler>,
}

impl std::fmt::Debug for InferenceEventNotifier {
    // Handlers are opaque closures, so only their count is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InferenceEventNotifier")
            .field("handler_count", &self.handlers.len())
            .finish()
    }
}

impl InferenceEventNotifier {
    /// Creates a notifier with no registered handlers.
    #[must_use]
    pub fn new() -> Self {
        Self { handlers: Vec::new() }
    }

    /// Number of registered handlers.
    #[must_use]
    pub fn handler_count(&self) -> usize {
        self.handlers.len()
    }

    /// Adds `handler` to the end of the invocation list.
    pub fn register(&mut self, handler: InferenceCompletionHandler) {
        self.handlers.push(handler);
    }

    /// Invokes every handler with `request_id` and `tokens`, in
    /// registration order.
    pub fn notify(&self, request_id: u64, tokens: &[usize]) {
        self.handlers
            .iter()
            .for_each(|handler| handler(request_id, tokens));
    }

    /// Drops every registered handler.
    pub fn clear(&mut self) {
        self.handlers.clear();
    }
}

impl Default for InferenceEventNotifier {
    fn default() -> Self {
        Self::new()
    }
}
/// Identifier for an in-flight request.
pub type RequestId = u64;

/// Tracks per-request deadlines and reports which ones have expired.
#[derive(Debug)]
pub struct TimeoutManager {
    deadlines: std::collections::HashMap<RequestId, std::time::Instant>,
}

impl TimeoutManager {
    /// Creates a manager with no registered deadlines.
    #[must_use]
    pub fn new() -> Self {
        Self {
            deadlines: std::collections::HashMap::new(),
        }
    }

    /// Number of requests still being tracked.
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.deadlines.len()
    }

    /// Registers (or replaces) the deadline for `request_id`.
    pub fn register(&mut self, request_id: RequestId, deadline: std::time::Instant) {
        self.deadlines.insert(request_id, deadline);
    }

    /// Stops tracking `request_id` (e.g. after it completed normally).
    pub fn remove(&mut self, request_id: RequestId) {
        self.deadlines.remove(&request_id);
    }

    /// Removes and returns the ids of every request whose deadline has
    /// passed. Ids of requests whose deadline is still in the future are
    /// kept. Returned order is unspecified (hash-map iteration order).
    pub fn check_expired(&mut self) -> Vec<RequestId> {
        let now = std::time::Instant::now();
        let mut expired = Vec::new();
        // Single pass with `retain` instead of the previous
        // collect-then-remove double traversal of the map.
        self.deadlines.retain(|&id, &mut deadline| {
            if now >= deadline {
                expired.push(id);
                false // drop expired entry
            } else {
                true // keep live entry
            }
        });
        expired
    }
}

impl Default for TimeoutManager {
    fn default() -> Self {
        Self::new()
    }
}
/// Scheduling priority level attached to a request.
pub type Priority = u32;

/// A payload tagged with a priority and a sequence number.
///
/// NOTE(review): `sequence` is presumably an insertion-order tie-breaker —
/// confirm against the included `priority_queue.rs` implementation.
#[derive(Debug, Clone)]
pub struct PriorityRequest<T> {
    priority: Priority,
    sequence: u64,
    data: T,
}
include!("priority_queue.rs");
include!("batch_scheduling_token_rate.rs");