1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
use crate::{ExecutionContext, KernelCache};
use anyhow::Context;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::{Notify, Semaphore, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::task::JoinSet;
/// A small cloneable cancellation token.
///
/// Clones share the same underlying state: cancelling any clone cancels all
/// of them. Cancellation is one-way and permanent (the flag is never cleared).
#[derive(Clone, Debug)]
pub struct CancellationToken {
    // Shared state; `Arc` keeps clones cheap and linked to the same flag.
    inner: Arc<CancellationInner>,
}
/// Shared state behind a `CancellationToken`.
#[derive(Debug)]
struct CancellationInner {
    // Set once by `cancel()`, never cleared.
    cancelled: AtomicBool,
    // Wakes tasks blocked in `cancelled().await` when the flag flips.
    notify: Notify,
}
impl CancellationToken {
pub fn new() -> Self {
Self {
inner: Arc::new(CancellationInner {
cancelled: AtomicBool::new(false),
notify: Notify::new(),
}),
}
}
/// Trigger cancellation. Idempotent.
pub fn cancel(&self) {
if !self.inner.cancelled.swap(true, Ordering::SeqCst) {
self.inner.notify.notify_waiters();
}
}
/// Returns true if already cancelled.
pub fn is_cancelled(&self) -> bool {
self.inner.cancelled.load(Ordering::SeqCst)
}
/// Async wait until cancelled. Returns immediately if already cancelled.
pub async fn cancelled(&self) {
if self.is_cancelled() {
return;
}
self.inner.notify.notified().await;
}
}
/// A single unit of work for the precompile service.
#[derive(Debug)]
pub struct CompileTask {
    // Cache key identifying the compiled artifact.
    pub fingerprint: String,
    // Raw source bytes handed to the compile function.
    pub source: Vec<u8>,
    // Shared token; cancelling it aborts this task's compile.
    pub cancel: CancellationToken,
    /// Optional notify that will be signalled when the task actually starts compiling.
    pub started: Option<Arc<Notify>>,
}
impl Clone for CompileTask {
fn clone(&self) -> Self {
Self {
fingerprint: self.fingerprint.clone(),
source: self.source.clone(),
cancel: self.cancel.clone(),
started: self.started.clone(),
}
}
}
impl CompileTask {
    /// Build a task together with a start notification and a cancellation token.
    ///
    /// The returned `Arc<Notify>` is signalled when a worker actually begins
    /// compiling the task, and the `CancellationToken` (also stored inside the
    /// task) lets callers cancel the task after it has started.
    pub fn with_started<Fp, S>(fingerprint: Fp, source: S) -> (Self, Arc<Notify>, CancellationToken)
    where
        Fp: Into<String>,
        S: Into<Vec<u8>>,
    {
        let started = Arc::new(Notify::new());
        let cancel = CancellationToken::new();
        (
            Self {
                fingerprint: fingerprint.into(),
                source: source.into(),
                cancel: cancel.clone(),
                started: Some(Arc::clone(&started)),
            },
            started,
            cancel,
        )
    }
}
// (Service-level docs: use `PrecompileService::submit` to enqueue tasks and
// `PrecompileService::shutdown` to gracefully stop the background worker and
// wait for in-flight compiles. This text previously sat here as a `///` doc
// comment, which rustdoc wrongly attached to `CompileFn`.)
/// A function that performs compilation. It receives the CompileTask and returns a
/// Future resolving to the compiled artifact bytes or an error.
///
/// Boxed + pinned so implementations may be arbitrary async closures; the
/// `Send` bounds let the worker run them on spawned tokio tasks.
pub type CompileFn = dyn Fn(CompileTask) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<u8>>> + Send>>
+ Send
+ Sync;
/// Group-cancellation registry: every token registered here can be cancelled
/// at once via `cancel_all`. Cloning shares the same registry.
pub struct CancelHandle {
    // registry keyed by id for scalable cancellation
    inner: Arc<AsyncMutex<std::collections::HashMap<usize, CancellationToken>>>,
}
impl CancelHandle {
    /// Create an empty registry.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(AsyncMutex::new(std::collections::HashMap::new())),
        }
    }

    /// Register a token for group cancellation.
    ///
    /// Keyed by the address of the token's shared inner allocation, so
    /// registering the same token (or a clone of it) twice overwrites the
    /// existing entry instead of duplicating it.
    pub async fn register(&self, token: &CancellationToken) {
        // use token address as a simple id; alternatively a UUID could be generated
        let id = Arc::as_ptr(&token.inner) as usize;
        self.inner.lock().await.insert(id, token.clone());
    }

    /// Remove a token from the registry, e.g. once its task has finished.
    ///
    /// Without this, entries accumulated for the lifetime of the handle —
    /// nothing ever removed them. No-op if the token was never registered.
    pub async fn unregister(&self, token: &CancellationToken) {
        let id = Arc::as_ptr(&token.inner) as usize;
        self.inner.lock().await.remove(&id);
    }

    /// Cancel all registered tokens and clear the registry.
    ///
    /// Cancellation is permanent, so retaining already-cancelled tokens would
    /// only hold memory; draining here keeps the map from growing forever.
    pub async fn cancel_all(&self) {
        let mut g = self.inner.lock().await;
        for t in g.values() {
            t.cancel();
        }
        g.clear();
    }
}
impl Clone for CancelHandle {
fn clone(&self) -> Self {
CancelHandle {
inner: self.inner.clone(),
}
}
}
/// Handle for the running precompile service. Use `submit` to enqueue tasks and
/// `shutdown()` to gracefully stop the background worker and wait for in-flight compiles.
pub struct PrecompileService {
    // Sends (task, optional response channel) pairs to the background worker.
    tx: mpsc::Sender<(CompileTask, Option<oneshot::Sender<anyhow::Result<()>>>)>,
    // Join handle of the worker loop; awaited by `shutdown`.
    worker_handle: JoinHandle<()>,
    /// Current approximate queue length (tasks enqueued but not yet taken by worker)
    pub queue_len: Arc<AtomicUsize>,
    /// Current number of in-flight compile tasks
    pub in_flight: Arc<AtomicUsize>,
    // Group-cancel registry covering every submitted task's token.
    pub cancel_handle: CancelHandle,
}
impl PrecompileService {
pub async fn submit(&self, task: CompileTask) -> anyhow::Result<()> {
// increment queue length approximation, will be decremented by worker when taken
self.queue_len.fetch_add(1, Ordering::SeqCst);
let (resp_tx, resp_rx) = oneshot::channel();
self.tx
.send((task, Some(resp_tx)))
.await
.context("send task")?;
// resp_rx carries the result of the compile; forward the error if any
match resp_rx.await.context("await response")? {
Ok(_) => Ok(()),
Err(e) => Err(e),
}
}
/// Request graceful shutdown and wait for background worker to finish.
pub async fn shutdown(self) {
// Dropping self.tx will close the channel and cause the worker loop to exit
drop(self.tx);
// Await worker completion
let _ = self.worker_handle.await;
}
/// Return a snapshot of current metrics.
pub fn metrics_snapshot(&self) -> (usize, usize) {
(
self.queue_len.load(Ordering::SeqCst),
self.in_flight.load(Ordering::SeqCst),
)
}
}
/// Spawn a background precompile worker.
///
/// - `cache` is cloned (must be Arc) so multiple workers can share it.
/// - `ctx` is ExecutionContext used for profiler events.
/// - `concurrency_limit` caps concurrent compile tasks.
/// - `compile_fn` performs the actual compilation for each task.
pub fn spawn_precompile_worker(
    cache: Arc<tokio::sync::Mutex<KernelCache>>,
    ctx: Arc<ExecutionContext>,
    concurrency_limit: usize,
    compile_fn: Arc<CompileFn>,
) -> PrecompileService {
    let (tx, mut rx) =
        mpsc::channel::<(CompileTask, Option<oneshot::Sender<anyhow::Result<()>>>)>(64);
    let sem = Arc::new(Semaphore::new(concurrency_limit));
    let queue_len = Arc::new(AtomicUsize::new(0));
    let in_flight = Arc::new(AtomicUsize::new(0));
    let cancel_handle = CancelHandle::new();
    let qlen = queue_len.clone();
    let inflight = in_flight.clone();
    let ch = cancel_handle.clone();
    // Worker main loop: owns a JoinSet tracking one spawned task per compile.
    let worker_handle = tokio::spawn(async move {
        let mut joinset: JoinSet<()> = JoinSet::new();
        while let Some((task, resp)) = rx.recv().await {
            // Task taken off the queue: update the gauge.
            qlen.fetch_sub(1, Ordering::SeqCst);
            // Reap any already-finished compiles. Without this the JoinSet
            // retained every completed task's slot for the service's whole
            // lifetime (it was only drained after the channel closed).
            while joinset.try_join_next().is_some() {}
            // Backpressure: hold a permit for the duration of the compile.
            let permit = match sem.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => break, // semaphore closed
            };
            let cache = cache.clone();
            let ctx = ctx.clone();
            // Register the token so `cancel_handle.cancel_all()` reaches it.
            ch.register(&task.cancel).await;
            inflight.fetch_add(1, Ordering::SeqCst);
            let cf = compile_fn.clone();
            let inflight_clone = inflight.clone();
            joinset.spawn(async move {
                if let Some(p) = &ctx.profiler {
                    p.record_event("compile_start", 1);
                }
                // Tell interested callers the task has actually started.
                if let Some(n) = &task.started {
                    n.notify_waiters();
                }
                // Race the compile against cancellation; `biased` gives an
                // already-pending cancellation priority over a ready compile.
                let compile_future = (cf)(task.clone());
                let compiled = tokio::select! {
                    biased;
                    _ = task.cancel.cancelled() => {
                        if let Some(p) = &ctx.profiler {
                            p.record_event("compile_cancelled", 1);
                        }
                        Err(anyhow::anyhow!("cancelled"))
                    }
                    r = compile_future => r,
                };
                // Persist a successful compile to the shared cache.
                let res = match compiled {
                    Ok(bytes) => {
                        let mut guard = cache.lock().await;
                        let write_res = guard.write_artifact(&task.fingerprint, &bytes);
                        if write_res.is_ok() {
                            if let Some(p) = &ctx.profiler {
                                p.record_event("compile_done", 1);
                            }
                        }
                        write_res
                    }
                    Err(e) => Err(e),
                };
                // Reply if the submitter is waiting; a dropped receiver is fine.
                if let Some(r) = resp {
                    let _ = r.send(res);
                }
                inflight_clone.fetch_sub(1, Ordering::SeqCst);
                drop(permit); // release the semaphore slot
            });
        }
        // Channel closed: drain all remaining compile tasks before exiting.
        while joinset.join_next().await.is_some() {}
    });
    PrecompileService {
        tx,
        worker_handle,
        queue_len,
        in_flight,
        cancel_handle,
    }
}