Skip to main content

rlx_runtime/
lora_scheduler.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! LoRA-aware request scheduling (plan #33).
17//!
18//! Borrowed from MAX's `serve/scheduler/lora_scheduler_utils.py`.
//! When multiple LoRA adapters are loaded, each request specifies
//! which adapter it wants. An executor that swaps adapters for every
//! request wastes bandwidth and latency. The scheduler here groups each
//! run of consecutive same-adapter requests into one "batch", so a swap
//! happens once per batch boundary instead of once per request. It never
//! reorders the queue: strictly alternating adapters (a, b, a, b) still
//! yield one batch per request.
25//!
26//! Pure data-layer scheduling — no executor, no compiled graphs.
27//! Plug into a future serving loop by:
28//!   1. Push incoming requests into [`LoraScheduler::push`].
29//!   2. Drain runnable batches with [`LoraScheduler::drain_batch`].
30//!   3. For each batch, swap to that adapter once and run all
31//!      requests in the batch back-to-back.
32//!
33//! Pairs with #24 (named weights registry) which owns the adapter
34//! bytes, and #9 (LoRA kernel) which is the actual compute.
35
36use crate::weight_registry::WeightRegistry;
37use std::collections::VecDeque;
38
39/// One serving request with its target LoRA adapter (or `None`
40/// for "use the base model"). The opaque `id: u64` is whatever
41/// the caller wants — typically a request UUID hash.
42#[derive(Debug, Clone)]
43pub struct LoraRequest {
44    pub id: u64,
45    pub adapter: Option<String>,
46    /// Caller-provided payload. The scheduler doesn't interpret
47    /// it; downstream code reads it after `drain_batch`.
48    pub payload: LoraPayload,
49}
50
/// Request payload carried through the scheduler untouched.
///
/// The scheduler is not generic over the payload type; this concrete
/// struct covers the common case (text generation). The scheduler never
/// reads these fields.
#[derive(Debug, Clone)]
pub struct LoraPayload {
    /// Prompt tokens. NOTE(review): presumably tokenizer ids for the
    /// target model — confirm with the executor.
    pub prompt_tokens: Vec<u32>,
    /// Per-request generation limit in tokens (as the name says); interpreted
    /// by downstream code, not by the scheduler.
    pub max_new_tokens: usize,
}
59
60/// One batch handed to the executor. All requests in a batch
61/// target the same adapter (or all None).
62#[derive(Debug)]
63pub struct LoraBatch {
64    pub adapter: Option<String>,
65    pub requests: Vec<LoraRequest>,
66}
67
68impl LoraBatch {
69    pub fn len(&self) -> usize {
70        self.requests.len()
71    }
72    pub fn is_empty(&self) -> bool {
73        self.requests.is_empty()
74    }
75}
76
77/// FIFO-ish scheduler with same-adapter coalescing.
78pub struct LoraScheduler {
79    /// Pending requests; insertion order preserved.
80    pending: VecDeque<LoraRequest>,
81    /// Maximum requests per drained batch.
82    pub max_batch: usize,
83    /// Optional reference to a registry — `validate` checks that
84    /// adapter names are registered before push. Storing as a raw
85    /// pointer avoids lifetime entanglement; callers ensure the
86    /// registry outlives the scheduler.
87    registry: Option<*const WeightRegistry>,
88}
89
90// `*const WeightRegistry` is not Send/Sync by default; we mark
91// the scheduler Send because the pointer is only used for read
92// queries on a registry that's itself Send + Sync once we wrap
93// it in Arc<RwLock>. The pointer should never be mutated.
94unsafe impl Send for LoraScheduler {}
95
96impl LoraScheduler {
97    pub fn new(max_batch: usize) -> Self {
98        Self {
99            pending: VecDeque::new(),
100            max_batch,
101            registry: None,
102        }
103    }
104
105    /// Bind a registry for adapter-name validation. Caller must
106    /// ensure `registry` outlives the scheduler.
107    pub fn bind_registry(&mut self, registry: &WeightRegistry) {
108        self.registry = Some(registry as *const _);
109    }
110
111    /// Push a request. Returns `Err` only if a registry is bound
112    /// and the adapter name isn't registered.
113    pub fn push(&mut self, req: LoraRequest) -> Result<(), UnknownAdapter> {
114        if let (Some(reg_ptr), Some(adapter)) = (self.registry, &req.adapter) {
115            // Safety: caller guaranteed the registry outlives us.
116            let reg = unsafe { &*reg_ptr };
117            if reg.lora_adapter_handles(adapter).is_empty() {
118                return Err(UnknownAdapter {
119                    name: adapter.clone(),
120                });
121            }
122        }
123        self.pending.push_back(req);
124        Ok(())
125    }
126
127    /// Look at the next batch's adapter without draining.
128    pub fn peek_adapter(&self) -> Option<Option<String>> {
129        self.pending.front().map(|r| r.adapter.clone())
130    }
131
132    /// Drain the next runnable batch — up to `max_batch` requests
133    /// that all share the same `adapter`. Returns `None` if empty.
134    pub fn drain_batch(&mut self) -> Option<LoraBatch> {
135        let head = self.pending.pop_front()?;
136        let target = head.adapter.clone();
137        let mut requests = vec![head];
138        while requests.len() < self.max_batch {
139            match self.pending.front() {
140                Some(next) if next.adapter == target => {
141                    requests.push(self.pending.pop_front().unwrap());
142                }
143                _ => break,
144            }
145        }
146        Some(LoraBatch {
147            adapter: target,
148            requests,
149        })
150    }
151
152    pub fn pending(&self) -> usize {
153        self.pending.len()
154    }
155    pub fn is_empty(&self) -> bool {
156        self.pending.is_empty()
157    }
158}
159
/// Error returned by `LoraScheduler::push` when a registry is bound and
/// the request names an adapter for which that registry has no handles.
#[derive(Debug, Clone)]
pub struct UnknownAdapter {
    /// The rejected adapter name, exactly as given in the request.
    pub name: String,
}

impl std::fmt::Display for UnknownAdapter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "adapter `{name}` is not registered", name = self.name)
    }
}

impl std::error::Error for UnknownAdapter {}
171
172/// Count adapter swaps that would occur if a sequence of
173/// requests were processed in declared order without coalescing.
174/// Useful for observability / unit tests.
175pub fn naive_swap_count(reqs: &[LoraRequest]) -> usize {
176    let mut swaps: usize = 0;
177    let mut last: Option<&Option<String>> = None;
178    for r in reqs {
179        if last.map(|l| l != &r.adapter).unwrap_or(true) {
180            swaps += 1;
181        }
182        last = Some(&r.adapter);
183    }
184    swaps.saturating_sub(1) // first "swap" is the initial setup, not a swap
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::{DType, Shape};
    use std::sync::Arc;

    /// Build a request with an empty prompt targeting `adapter`
    /// (`None` = base model).
    fn req(id: u64, adapter: Option<&str>) -> LoraRequest {
        LoraRequest {
            id,
            adapter: adapter.map(|s| s.to_string()),
            payload: LoraPayload {
                prompt_tokens: vec![],
                max_new_tokens: 4,
            },
        }
    }

    /// Request ids in order, for order-sensitive assertions.
    fn ids(requests: &[LoraRequest]) -> Vec<u64> {
        requests.iter().map(|r| r.id).collect()
    }

    #[test]
    fn coalesces_same_adapter_runs() {
        let mut s = LoraScheduler::new(8);
        // Mixed-adapter input order:
        //   code, code, math, math, code, base, base
        for r in [
            req(1, Some("code")),
            req(2, Some("code")),
            req(3, Some("math")),
            req(4, Some("math")),
            req(5, Some("code")),
            req(6, None),
            req(7, None),
        ] {
            s.push(r).unwrap();
        }

        // Only the contiguous run at the head is batched: request 5
        // ("code") is not pulled forward into the first batch.
        let b1 = s.drain_batch().unwrap();
        assert_eq!(b1.adapter.as_deref(), Some("code"));
        assert_eq!(ids(&b1.requests), vec![1, 2]);

        let b2 = s.drain_batch().unwrap();
        assert_eq!(b2.adapter.as_deref(), Some("math"));
        assert_eq!(ids(&b2.requests), vec![3, 4]);

        let b3 = s.drain_batch().unwrap();
        assert_eq!(b3.adapter.as_deref(), Some("code"));
        assert_eq!(ids(&b3.requests), vec![5]);

        let b4 = s.drain_batch().unwrap();
        assert!(b4.adapter.is_none());
        assert_eq!(ids(&b4.requests), vec![6, 7]);

        assert!(s.drain_batch().is_none());
        assert!(s.is_empty());
    }

    #[test]
    fn respects_max_batch_cap() {
        let mut s = LoraScheduler::new(3);
        for i in 0..10 {
            s.push(req(i, Some("code"))).unwrap();
        }
        let b = s.drain_batch().unwrap();
        assert_eq!(b.len(), 3, "max_batch=3 should split a long run");
        assert_eq!(ids(&b.requests), vec![0, 1, 2]);
        assert_eq!(s.pending(), 7);

        // The rest of the run drains in capped chunks, in order.
        let sizes: Vec<usize> =
            std::iter::from_fn(|| s.drain_batch().map(|b| b.len())).collect();
        assert_eq!(sizes, vec![3, 3, 1]);
    }

    #[test]
    fn zero_max_batch_still_drains_one_request() {
        // A batch always contains its head request, so `max_batch == 0`
        // cannot stall the queue.
        let mut s = LoraScheduler::new(0);
        s.push(req(1, Some("code"))).unwrap();
        s.push(req(2, Some("code"))).unwrap();
        let b = s.drain_batch().unwrap();
        assert_eq!(ids(&b.requests), vec![1]);
        assert_eq!(s.pending(), 1);
    }

    #[test]
    fn peek_adapter_reports_head_without_draining() {
        let mut s = LoraScheduler::new(4);
        assert_eq!(s.peek_adapter(), None);
        s.push(req(1, None)).unwrap();
        s.push(req(2, Some("code"))).unwrap();
        assert_eq!(s.peek_adapter(), Some(None));
        assert_eq!(s.pending(), 2, "peeking must not drain");
        s.drain_batch().unwrap();
        assert_eq!(s.peek_adapter(), Some(Some("code".to_string())));
    }

    #[test]
    fn registry_validation_rejects_unknown_adapter() {
        let mut reg = WeightRegistry::new();
        reg.register(
            "ffn",
            Shape::new(&[8, 8], DType::F32),
            Arc::from(vec![0u8; 256]),
            WeightKind::Base,
        );
        reg.register(
            "ffn.lora.a",
            Shape::new(&[8, 4], DType::F32),
            Arc::from(vec![0u8; 128]),
            WeightKind::LoraAdapter {
                adapter: "code".into(),
            },
        );

        let mut s = LoraScheduler::new(4);
        s.bind_registry(&reg);

        // Known adapter passes.
        assert!(s.push(req(1, Some("code"))).is_ok());
        // None (base model) always passes.
        assert!(s.push(req(2, None)).is_ok());
        // Unknown adapter rejected.
        let err = s.push(req(3, Some("nonexistent"))).unwrap_err();
        assert_eq!(err.name, "nonexistent");
        // ...and not enqueued.
        assert_eq!(s.pending(), 2, "rejected request must not be enqueued");
    }

    #[test]
    fn swap_count_metric() {
        let reqs = [
            req(1, Some("a")),
            req(2, Some("a")),
            req(3, Some("b")),
            req(4, Some("a")),
        ];
        // Sequence transitions: a, a→a (no swap), a→b (swap), b→a
        // (swap) = 2 swaps after initial.
        assert_eq!(naive_swap_count(&reqs), 2);
        // Loading the first adapter is setup, not a swap.
        assert_eq!(naive_swap_count(&[]), 0);
        assert_eq!(naive_swap_count(&reqs[..1]), 0);
    }
}