use crate::weight_registry::WeightRegistry;
use std::collections::VecDeque;
/// One queued request, optionally tagged with the name of the LoRA adapter it wants.
#[derive(Debug, Clone)]
pub struct LoraRequest {
/// Request identifier. Nothing in this file interprets it; it is carried through unchanged.
pub id: u64,
/// Requested adapter name, or `None` when the request names no adapter
/// (presumably a base-model request — confirm with callers).
pub adapter: Option<String>,
/// Request body; carried through the scheduler without being inspected.
pub payload: LoraPayload,
}
/// Body of a request. The scheduler in this file only moves it around and never reads it.
#[derive(Debug, Clone)]
pub struct LoraPayload {
/// Prompt tokens (presumably token ids — not interpreted in this file).
pub prompt_tokens: Vec<u32>,
/// Generation length limit (presumably an upper bound on new tokens — confirm with the consumer).
pub max_new_tokens: usize,
}
/// A group of requests sharing one `adapter` value, as built by `LoraScheduler::drain_batch`.
#[derive(Debug)]
pub struct LoraBatch {
/// The adapter value common to every request in `requests` (`None` = no adapter named).
pub adapter: Option<String>,
/// The batched requests, in the order they were queued.
pub requests: Vec<LoraRequest>,
}
impl LoraBatch {
    /// Returns how many requests this batch holds.
    pub fn len(&self) -> usize {
        let Self { requests, .. } = self;
        requests.len()
    }

    /// Returns `true` when the batch holds no requests.
    pub fn is_empty(&self) -> bool {
        let Self { requests, .. } = self;
        requests.is_empty()
    }
}
/// FIFO scheduler that groups consecutive queued requests with the same adapter into batches.
pub struct LoraScheduler {
// Requests waiting to be batched, oldest at the front.
pending: VecDeque<LoraRequest>,
/// Upper bound on the number of requests `drain_batch` puts in one batch. A drained batch
/// always contains at least the head request, so a value of 0 behaves like 1.
pub max_batch: usize,
// Optional raw pointer to the registry used by `push` to validate adapter names.
// It is set from a plain `&WeightRegistry` in `bind_registry`, so no lifetime ties the
// registry to this scheduler.
// NOTE(review): the pointee must stay alive and in place while `push` can be called.
registry: Option<*const WeightRegistry>,
}
// SAFETY: NOTE(review) — the `*const WeightRegistry` field is what prevents `Send` from being
// derived automatically. This impl asserts the scheduler may be moved to another thread, which
// presumably requires `WeightRegistry` to be `Sync` and to outlive the scheduler. Neither is
// checked in this file; confirm against the registry's definition and the call sites.
unsafe impl Send for LoraScheduler {}
impl LoraScheduler {
    /// Creates an empty scheduler with no registry bound.
    ///
    /// `max_batch` caps how many requests `drain_batch` groups together; a batch always
    /// contains at least one request, so `0` behaves like `1`.
    pub fn new(max_batch: usize) -> Self {
        let pending = VecDeque::new();
        Self {
            pending,
            max_batch,
            registry: None,
        }
    }

    /// Binds a registry so that later `push` calls validate adapter names against it.
    ///
    /// Only a raw pointer to `registry` is kept; the borrow is not tracked by the type system.
    /// The caller must keep `registry` alive and at the same address for as long as `push`
    /// may be called on this scheduler.
    pub fn bind_registry(&mut self, registry: &WeightRegistry) {
        let raw: *const WeightRegistry = registry;
        self.registry = Some(raw);
    }

    /// Appends `req` to the back of the queue.
    ///
    /// Validation runs only when a registry is bound *and* the request names an adapter;
    /// otherwise the request is accepted as-is.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownAdapter`] (and does not enqueue the request) when the bound registry
    /// reports no handles for the requested adapter.
    pub fn push(&mut self, req: LoraRequest) -> Result<(), UnknownAdapter> {
        let rejected: Option<String> = match (self.registry, &req.adapter) {
            (Some(reg_ptr), Some(adapter)) => {
                // SAFETY: `reg_ptr` was created from a valid `&WeightRegistry` in
                // `bind_registry`. Soundness relies on the caller keeping that registry alive
                // and unmoved while the scheduler is in use; nothing here enforces it.
                let reg = unsafe { &*reg_ptr };
                reg.lora_adapter_handles(adapter)
                    .is_empty()
                    .then(|| adapter.clone())
            }
            _ => None,
        };

        if let Some(name) = rejected {
            return Err(UnknownAdapter { name });
        }

        self.pending.push_back(req);
        Ok(())
    }

    /// Returns the adapter of the request at the head of the queue.
    ///
    /// The outer `Option` is `None` when the queue is empty; the inner `Option` is the head
    /// request's own (possibly absent) adapter name.
    pub fn peek_adapter(&self) -> Option<Option<String>> {
        let head = self.pending.front()?;
        Some(head.adapter.clone())
    }

    /// Removes the head request plus the run of immediately following requests that share
    /// its adapter, up to `max_batch` requests in total.
    ///
    /// Returns `None` when the queue is empty. Only *consecutive* requests are coalesced;
    /// a later request with the same adapter behind a different one stays queued.
    pub fn drain_batch(&mut self) -> Option<LoraBatch> {
        let first = self.pending.pop_front()?;
        let adapter = first.adapter.clone();
        let mut requests = vec![first];

        while requests.len() < self.max_batch
            && matches!(self.pending.front(), Some(next) if next.adapter == adapter)
        {
            // The guard above just observed a front element, so this moves exactly one request.
            requests.extend(self.pending.pop_front());
        }

        Some(LoraBatch { adapter, requests })
    }

    /// Number of requests currently queued.
    pub fn pending(&self) -> usize {
        let queue = &self.pending;
        queue.len()
    }

    /// Returns `true` when no requests are queued.
    pub fn is_empty(&self) -> bool {
        let queue = &self.pending;
        queue.is_empty()
    }
}
/// Error returned by `LoraScheduler::push` when the bound registry reports no handles for the
/// adapter a request names.
#[derive(Debug, Clone)]
pub struct UnknownAdapter {
/// The adapter name that was requested.
pub name: String,
}
impl std::fmt::Display for UnknownAdapter {
    /// Writes the human-readable message, quoting the offending adapter name in backticks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { name } = self;
        f.write_fmt(format_args!("adapter `{}` is not registered", name))
    }
}
// Marks `UnknownAdapter` as a standard error type. The impl is empty, so the trait's default
// methods apply and no underlying `source()` is reported.
impl std::error::Error for UnknownAdapter {}
/// Counts how many times the adapter value changes between neighbouring requests when
/// `reqs` is walked strictly in order (no regrouping).
///
/// Two requests differ when their `adapter` fields compare unequal, so `None` and
/// `Some(_)` count as different adapters. Empty and single-element slices yield `0`.
pub fn naive_swap_count(reqs: &[LoraRequest]) -> usize {
    reqs.windows(2)
        .filter(|pair| pair[0].adapter != pair[1].adapter)
        .count()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::{DType, Shape};
    use std::sync::Arc;

    /// Builds a request with the given id and adapter and a minimal placeholder payload.
    fn req(id: u64, adapter: Option<&str>) -> LoraRequest {
        LoraRequest {
            id,
            adapter: adapter.map(|s| s.to_string()),
            payload: LoraPayload {
                prompt_tokens: vec![],
                max_new_tokens: 4,
            },
        }
    }

    /// Consecutive requests with the same adapter are batched together; a later run of the
    /// same adapter behind a different one forms its own batch.
    #[test]
    fn coalesces_same_adapter_runs() {
        let mut s = LoraScheduler::new(8);
        for r in [
            req(1, Some("code")),
            req(2, Some("code")),
            req(3, Some("math")),
            req(4, Some("math")),
            req(5, Some("code")),
            req(6, None),
            req(7, None),
        ] {
            s.push(r).unwrap();
        }
        let b1 = s.drain_batch().unwrap();
        assert_eq!(b1.adapter.as_deref(), Some("code"));
        assert_eq!(b1.len(), 2);
        let b2 = s.drain_batch().unwrap();
        assert_eq!(b2.adapter.as_deref(), Some("math"));
        assert_eq!(b2.len(), 2);
        let b3 = s.drain_batch().unwrap();
        assert_eq!(b3.adapter.as_deref(), Some("code"));
        assert_eq!(b3.len(), 1);
        let b4 = s.drain_batch().unwrap();
        assert!(b4.adapter.is_none());
        assert_eq!(b4.len(), 2);
        assert!(s.drain_batch().is_none());
    }

    /// A run longer than `max_batch` is split rather than drained in one go.
    #[test]
    fn respects_max_batch_cap() {
        let mut s = LoraScheduler::new(3);
        for i in 0..10 {
            s.push(req(i, Some("code"))).unwrap();
        }
        let b = s.drain_batch().unwrap();
        assert_eq!(b.len(), 3, "max_batch=3 should split a long run");
        assert_eq!(s.pending(), 7);
    }

    /// With a registry bound, `push` accepts known adapters and adapter-less requests and
    /// rejects names the registry has no handles for.
    #[test]
    fn registry_validation_rejects_unknown_adapter() {
        let mut reg = WeightRegistry::new();
        reg.register(
            "ffn",
            Shape::new(&[8, 8], DType::F32),
            Arc::from(vec![0u8; 256]),
            WeightKind::Base,
        );
        reg.register(
            "ffn.lora.a",
            Shape::new(&[8, 4], DType::F32),
            Arc::from(vec![0u8; 128]),
            WeightKind::LoraAdapter {
                adapter: "code".into(),
            },
        );
        // `s` is declared after `reg`, so it is dropped first and never outlives the registry
        // it points at.
        let mut s = LoraScheduler::new(4);
        s.bind_registry(&reg);
        assert!(s.push(req(1, Some("code"))).is_ok());
        assert!(s.push(req(2, None)).is_ok());
        let err = s.push(req(3, Some("nonexistent"))).unwrap_err();
        assert_eq!(err.name, "nonexistent");
    }

    /// `a, a, b, a` contains two adapter changes.
    #[test]
    fn swap_count_metric() {
        let reqs = [
            req(1, Some("a")),
            req(2, Some("a")),
            req(3, Some("b")),
            req(4, Some("a")),
        ];
        assert_eq!(naive_swap_count(&reqs), 2);
    }
}