rlx_runtime/lora_scheduler.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! LoRA-aware request scheduling (plan #33).
17//!
18//! Borrowed from MAX's `serve/scheduler/lora_scheduler_utils.py`.
19//! When multiple LoRA adapters are loaded, each request specifies
20//! which adapter it wants. Naïvely interleaving requests forces an
21//! adapter swap per request — wasted bandwidth + latency. The
22//! scheduler here groups consecutive same-adapter requests into
23//! one "batch", so a swap happens once per batch boundary instead
24//! of once per request.
25//!
26//! Pure data-layer scheduling — no executor, no compiled graphs.
27//! Plug into a future serving loop by:
28//! 1. Push incoming requests into [`LoraScheduler::push`].
29//! 2. Drain runnable batches with [`LoraScheduler::drain_batch`].
30//! 3. For each batch, swap to that adapter once and run all
31//! requests in the batch back-to-back.
32//!
33//! Pairs with #24 (named weights registry) which owns the adapter
34//! bytes, and #9 (LoRA kernel) which is the actual compute.
35
36use crate::weight_registry::WeightRegistry;
37use std::collections::VecDeque;
38
/// One serving request with its target LoRA adapter (or `None`
/// for "use the base model"). The opaque `id: u64` is whatever
/// the caller wants — typically a request UUID hash.
#[derive(Debug, Clone)]
pub struct LoraRequest {
    /// Caller-chosen identifier; the scheduler stores it and hands it
    /// back unchanged inside the drained batch.
    pub id: u64,
    /// Name of the LoRA adapter to run with, or `None` for the base
    /// model. Consecutive requests whose `adapter` compares equal are
    /// drained into the same batch.
    pub adapter: Option<String>,
    /// Caller-provided payload. The scheduler doesn't interpret
    /// it; downstream code reads it after `drain_batch`.
    pub payload: LoraPayload,
}
50
/// Concrete request payload for the common case (text generation).
///
/// `LoraRequest::payload` is this concrete type rather than a
/// type-erased `Box<dyn Any>`; the scheduler stores it as-is and
/// never inspects its fields.
#[derive(Debug, Clone)]
pub struct LoraPayload {
    /// Prompt token ids supplied by the caller.
    pub prompt_tokens: Vec<u32>,
    /// Caller-supplied generation budget; not read by the scheduler.
    pub max_new_tokens: usize,
}
59
/// One batch handed to the executor. All requests in a batch
/// target the same adapter (or all None).
#[derive(Debug)]
pub struct LoraBatch {
    /// Adapter shared by every request in `requests`; `None` means
    /// the base model.
    pub adapter: Option<String>,
    /// The batched requests, in the order they were pushed.
    pub requests: Vec<LoraRequest>,
}
67
68impl LoraBatch {
69 pub fn len(&self) -> usize {
70 self.requests.len()
71 }
72 pub fn is_empty(&self) -> bool {
73 self.requests.is_empty()
74 }
75}
76
/// FIFO-ish scheduler with same-adapter coalescing.
///
/// Requests leave in insertion order; consecutive requests that name
/// the same adapter are grouped into one [`LoraBatch`].
pub struct LoraScheduler {
    /// Pending requests; insertion order preserved.
    pending: VecDeque<LoraRequest>,
    /// Maximum requests per drained batch. A drained batch always
    /// holds at least the head request, so `0` behaves like `1`.
    pub max_batch: usize,
    /// Optional pointer to a registry — when set, `push` checks that
    /// adapter names are registered before enqueueing. Storing as a raw
    /// pointer avoids lifetime entanglement; callers ensure the
    /// registry outlives the scheduler.
    registry: Option<*const WeightRegistry>,
}
89
// `*const WeightRegistry` is not Send/Sync by default, so the raw
// pointer field would otherwise make `LoraScheduler` `!Send`. We mark
// the scheduler Send because the pointer is only used for read
// queries on a registry that's itself Send + Sync once we wrap
// it in Arc<RwLock>. The pointer should never be mutated.
//
// NOTE(review): moving the scheduler to another thread lets that
// thread read the registry through the pointer, so this impl relies
// on `WeightRegistry` being safe to read from several threads (`Sync`)
// and on the registry outliving the scheduler. Neither is enforced by
// the compiler here — confirm against `WeightRegistry`.
unsafe impl Send for LoraScheduler {}
95
96impl LoraScheduler {
97 pub fn new(max_batch: usize) -> Self {
98 Self {
99 pending: VecDeque::new(),
100 max_batch,
101 registry: None,
102 }
103 }
104
105 /// Bind a registry for adapter-name validation. Caller must
106 /// ensure `registry` outlives the scheduler.
107 pub fn bind_registry(&mut self, registry: &WeightRegistry) {
108 self.registry = Some(registry as *const _);
109 }
110
111 /// Push a request. Returns `Err` only if a registry is bound
112 /// and the adapter name isn't registered.
113 pub fn push(&mut self, req: LoraRequest) -> Result<(), UnknownAdapter> {
114 if let (Some(reg_ptr), Some(adapter)) = (self.registry, &req.adapter) {
115 // Safety: caller guaranteed the registry outlives us.
116 let reg = unsafe { &*reg_ptr };
117 if reg.lora_adapter_handles(adapter).is_empty() {
118 return Err(UnknownAdapter {
119 name: adapter.clone(),
120 });
121 }
122 }
123 self.pending.push_back(req);
124 Ok(())
125 }
126
127 /// Look at the next batch's adapter without draining.
128 pub fn peek_adapter(&self) -> Option<Option<String>> {
129 self.pending.front().map(|r| r.adapter.clone())
130 }
131
132 /// Drain the next runnable batch — up to `max_batch` requests
133 /// that all share the same `adapter`. Returns `None` if empty.
134 pub fn drain_batch(&mut self) -> Option<LoraBatch> {
135 let head = self.pending.pop_front()?;
136 let target = head.adapter.clone();
137 let mut requests = vec![head];
138 while requests.len() < self.max_batch {
139 match self.pending.front() {
140 Some(next) if next.adapter == target => {
141 requests.push(self.pending.pop_front().unwrap());
142 }
143 _ => break,
144 }
145 }
146 Some(LoraBatch {
147 adapter: target,
148 requests,
149 })
150 }
151
152 pub fn pending(&self) -> usize {
153 self.pending.len()
154 }
155 pub fn is_empty(&self) -> bool {
156 self.pending.is_empty()
157 }
158}
159
/// Error returned by [`LoraScheduler::push`] when a registry is bound
/// and it reports no handles for the requested adapter name.
#[derive(Debug, Clone)]
pub struct UnknownAdapter {
    /// The adapter name that was rejected.
    pub name: String,
}
164
165impl std::fmt::Display for UnknownAdapter {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 write!(f, "adapter `{}` is not registered", self.name)
168 }
169}
// Lets `UnknownAdapter` be used wherever a `std::error::Error` is
// expected; the empty body keeps the trait's default methods.
impl std::error::Error for UnknownAdapter {}
171
172/// Count adapter swaps that would occur if a sequence of
173/// requests were processed in declared order without coalescing.
174/// Useful for observability / unit tests.
175pub fn naive_swap_count(reqs: &[LoraRequest]) -> usize {
176 let mut swaps: usize = 0;
177 let mut last: Option<&Option<String>> = None;
178 for r in reqs {
179 if last.map(|l| l != &r.adapter).unwrap_or(true) {
180 swaps += 1;
181 }
182 last = Some(&r.adapter);
183 }
184 swaps.saturating_sub(1) // first "swap" is the initial setup, not a swap
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::{DType, Shape};
    use std::sync::Arc;

    /// Build a request with the given id and adapter and a minimal
    /// payload (the scheduler never inspects the payload).
    fn req(id: u64, adapter: Option<&str>) -> LoraRequest {
        LoraRequest {
            id,
            adapter: adapter.map(|s| s.to_string()),
            payload: LoraPayload {
                prompt_tokens: vec![],
                max_new_tokens: 4,
            },
        }
    }

    #[test]
    fn coalesces_same_adapter_runs() {
        let mut s = LoraScheduler::new(8);
        // Mixed-adapter input order:
        //   code, code, math, math, code, base, base
        for r in [
            req(1, Some("code")),
            req(2, Some("code")),
            req(3, Some("math")),
            req(4, Some("math")),
            req(5, Some("code")),
            req(6, None),
            req(7, None),
        ] {
            s.push(r).unwrap();
        }

        let b1 = s.drain_batch().unwrap();
        assert_eq!(b1.adapter.as_deref(), Some("code"));
        assert_eq!(b1.len(), 2); // 1, 2 — stops at the math at front

        let b2 = s.drain_batch().unwrap();
        assert_eq!(b2.adapter.as_deref(), Some("math"));
        assert_eq!(b2.len(), 2);

        let b3 = s.drain_batch().unwrap();
        assert_eq!(b3.adapter.as_deref(), Some("code"));
        assert_eq!(b3.len(), 1);

        let b4 = s.drain_batch().unwrap();
        assert!(b4.adapter.is_none());
        assert_eq!(b4.len(), 2);

        assert!(s.drain_batch().is_none());
    }

    #[test]
    fn respects_max_batch_cap() {
        let mut s = LoraScheduler::new(3);
        for i in 0..10 {
            s.push(req(i, Some("code"))).unwrap();
        }
        let b = s.drain_batch().unwrap();
        assert_eq!(b.len(), 3, "max_batch=3 should split a long run");
        assert_eq!(s.pending(), 7);
    }

    #[test]
    fn registry_validation_rejects_unknown_adapter() {
        // `reg` is declared before the scheduler, so it is dropped after
        // it and the pointer stored by `bind_registry` stays valid for
        // the whole test.
        let mut reg = WeightRegistry::new();
        reg.register(
            "ffn",
            Shape::new(&[8, 8], DType::F32),
            Arc::from(vec![0u8; 256]),
            WeightKind::Base,
        );
        reg.register(
            "ffn.lora.a",
            Shape::new(&[8, 4], DType::F32),
            Arc::from(vec![0u8; 128]),
            WeightKind::LoraAdapter {
                adapter: "code".into(),
            },
        );

        let mut s = LoraScheduler::new(4);
        s.bind_registry(&reg);

        // Known adapter passes.
        assert!(s.push(req(1, Some("code"))).is_ok());
        // None (base model) always passes.
        assert!(s.push(req(2, None)).is_ok());
        // Unknown adapter rejected.
        let err = s.push(req(3, Some("nonexistent"))).unwrap_err();
        assert_eq!(err.name, "nonexistent");
    }

    #[test]
    fn swap_count_metric() {
        let reqs = [
            req(1, Some("a")),
            req(2, Some("a")),
            req(3, Some("b")),
            req(4, Some("a")),
        ];
        // Sequence transitions: a, a→a (no swap), a→b (swap), b→a
        // (swap) = 2 swaps after initial.
        assert_eq!(naive_swap_count(&reqs), 2);
    }
}