1use std::future::Future;
36use std::pin::Pin;
37use std::sync::atomic::{AtomicBool, Ordering};
38use std::sync::{Arc, Mutex};
39use std::task::{Context, Poll};
40use std::thread;
41use std::time::Duration;
42
43pub use firq_core::{
44 BackpressurePolicy, CancelResult, CloseMode, DequeueResult, EnqueueRejectReason, EnqueueResult,
45 EnqueueWithHandleResult, Priority, QueueTimeBucket, Scheduler, SchedulerConfig, SchedulerStats,
46 Task, TaskHandle, TenantCount, TenantKey,
47};
48use futures_core::Stream;
49use tokio::sync::{Semaphore, mpsc};
50
/// Upper bound on each blocking dequeue performed by the bridge worker
/// thread, so the worker re-checks its shutdown flag at least this often
/// even when the queue is idle (see `AsyncWorkerReceiver::new`).
const WORKER_DEQUEUE_TIMEOUT: Duration = Duration::from_millis(25);
52
/// Cheaply cloneable async facade over a shared [`Scheduler`].
///
/// All clones share the same underlying core via `Arc`.
pub struct AsyncScheduler<T> {
    // Shared core scheduler; every method delegates to it.
    inner: Arc<Scheduler<T>>,
}
57
58impl<T> Clone for AsyncScheduler<T> {
59 fn clone(&self) -> Self {
60 Self {
61 inner: Arc::clone(&self.inner),
62 }
63 }
64}
65
impl<T> AsyncScheduler<T> {
    /// Wraps an existing shared core scheduler in the async facade.
    pub fn new(inner: Arc<Scheduler<T>>) -> Self {
        Self { inner }
    }

    /// Borrows the shared core scheduler.
    pub fn inner(&self) -> &Arc<Scheduler<T>> {
        &self.inner
    }

    /// Enqueues `task` for `tenant`; direct delegation to the core.
    pub fn enqueue(&self, tenant: TenantKey, task: Task<T>) -> EnqueueResult {
        self.inner.enqueue(tenant, task)
    }

    /// Enqueues and returns a result carrying a handle usable with [`Self::cancel`].
    pub fn enqueue_with_handle(&self, tenant: TenantKey, task: Task<T>) -> EnqueueWithHandleResult {
        self.inner.enqueue_with_handle(tenant, task)
    }

    /// Non-blocking dequeue attempt.
    pub fn try_dequeue(&self) -> DequeueResult<T> {
        self.inner.try_dequeue()
    }

    /// Cancels a previously enqueued task by its handle.
    pub fn cancel(&self, handle: TaskHandle) -> CancelResult {
        self.inner.cancel(handle)
    }

    /// Snapshot of the core scheduler's statistics.
    pub fn stats(&self) -> SchedulerStats {
        self.inner.stats()
    }

    /// Closes the scheduler.
    ///
    /// NOTE: this delegates to `close_immediate`, not `close_drain` — use
    /// [`Self::close_drain`] if already-queued tasks should be delivered first.
    pub fn close(&self) {
        self.inner.close_immediate();
    }

    /// Delegates to the core's `close_immediate`.
    pub fn close_immediate(&self) {
        self.inner.close_immediate();
    }

    /// Delegates to the core's `close_drain`.
    pub fn close_drain(&self) {
        self.inner.close_drain();
    }

    /// Closes using an explicit [`CloseMode`].
    pub fn close_with_mode(&self, mode: CloseMode) {
        self.inner.close_with_mode(mode);
    }

    /// Receiver that awaits directly on the scheduler (no worker thread).
    pub fn receiver(&self) -> AsyncReceiver<T> {
        AsyncReceiver::new(self.clone())
    }

    /// Stream adapter over this scheduler (no worker thread).
    pub fn stream(&self) -> AsyncStream<T> {
        AsyncStream::new(self.clone())
    }
}
132
133impl<T: Send + 'static> AsyncScheduler<T> {
134 pub async fn dequeue_async(&self) -> DequeueResult<T> {
136 let scheduler = Arc::clone(&self.inner);
137 match tokio::task::spawn_blocking(move || scheduler.dequeue_blocking()).await {
138 Ok(result) => result,
139 Err(_) => DequeueResult::Closed,
140 }
141 }
142
143 pub fn receiver_with_worker(&self, buffer: usize) -> AsyncWorkerReceiver<T> {
145 AsyncWorkerReceiver::new(self.clone(), buffer)
146 }
147
148 pub fn stream_with_worker(&self, buffer: usize) -> AsyncWorkerReceiver<T> {
150 self.receiver_with_worker(buffer)
151 }
152}
153
/// A dequeued task together with the tenant it was scheduled for.
pub struct DequeueItem<T> {
    /// Tenant the task was enqueued under.
    pub tenant: TenantKey,
    /// The dequeued task (payload plus scheduling metadata).
    pub task: Task<T>,
}
161
/// Pull-style async receiver that awaits items directly on the scheduler.
///
/// Unlike [`AsyncWorkerReceiver`] this spawns no dedicated thread; each
/// `recv` drives a `dequeue_async` call itself.
#[derive(Clone)]
pub struct AsyncReceiver<T> {
    scheduler: AsyncScheduler<T>,
}
167
168impl<T> AsyncReceiver<T> {
169 pub fn new(scheduler: AsyncScheduler<T>) -> Self {
171 Self { scheduler }
172 }
173}
174
175impl<T: Send + 'static> AsyncReceiver<T> {
176 pub async fn recv(&self) -> Option<DequeueItem<T>> {
178 loop {
179 match self.scheduler.dequeue_async().await {
180 DequeueResult::Task { tenant, task } => {
181 return Some(DequeueItem { tenant, task });
182 }
183 DequeueResult::Closed => return None,
184 DequeueResult::Empty => {
185 tokio::task::yield_now().await;
186 }
187 }
188 }
189 }
190
191 pub fn new_worker(scheduler: AsyncScheduler<T>, buffer: usize) -> AsyncWorkerReceiver<T> {
193 AsyncWorkerReceiver::new(scheduler, buffer)
194 }
195}
196
/// Owns the bridge worker thread: a shared shutdown flag plus the join
/// handle, joined when this value is dropped.
struct WorkerThreadHandle {
    // Set to `true` to ask the worker loop to exit.
    shutdown: Arc<AtomicBool>,
    // `Mutex<Option<..>>` so drop can `take()` and join the handle exactly once.
    handle: Mutex<Option<thread::JoinHandle<()>>>,
}
201
202impl WorkerThreadHandle {
203 fn new(shutdown: Arc<AtomicBool>, handle: thread::JoinHandle<()>) -> Self {
204 Self {
205 shutdown,
206 handle: Mutex::new(Some(handle)),
207 }
208 }
209}
210
211impl Drop for WorkerThreadHandle {
212 fn drop(&mut self) {
213 self.shutdown.store(true, Ordering::Release);
214 let mut guard = self.handle.lock().expect("worker handle mutex poisoned");
215 if let Some(handle) = guard.take() {
216 let _ = handle.join();
217 }
218 }
219}
220
/// Receiver fed by a dedicated OS thread that performs blocking dequeues and
/// forwards items over a bounded tokio channel.
pub struct AsyncWorkerReceiver<T> {
    // Channel of items produced by the worker thread.
    rx: mpsc::Receiver<DequeueItem<T>>,
    // Keeps the worker alive; dropping this signals shutdown and joins it.
    _worker: WorkerThreadHandle,
}
226
impl<T: Send + 'static> AsyncWorkerReceiver<T> {
    /// Spawns a worker thread that blocks on the core scheduler and forwards
    /// dequeued items into a channel of capacity `buffer` (clamped to >= 1).
    pub fn new(scheduler: AsyncScheduler<T>, buffer: usize) -> Self {
        let buffer = buffer.max(1);
        let (tx, rx) = mpsc::channel(buffer);
        let shutdown = Arc::new(AtomicBool::new(false));
        let worker_shutdown = Arc::clone(&shutdown);
        let core = Arc::clone(scheduler.inner());

        let handle = thread::spawn(move || {
            // Timed dequeues so the shutdown flag is observed at least every
            // WORKER_DEQUEUE_TIMEOUT even when the queue stays empty.
            while !worker_shutdown.load(Ordering::Acquire) {
                match core.dequeue_blocking_timeout(WORKER_DEQUEUE_TIMEOUT) {
                    DequeueResult::Task { tenant, task } => {
                        // A send error means the receiving side was dropped or
                        // closed; stop forwarding.
                        if tx.blocking_send(DequeueItem { tenant, task }).is_err() {
                            break;
                        }
                    }
                    DequeueResult::Closed => break,
                    // Timed out with nothing available: loop to re-check shutdown.
                    DequeueResult::Empty => {}
                }
            }
        });

        Self {
            rx,
            _worker: WorkerThreadHandle::new(shutdown, handle),
        }
    }

    /// Awaits the next forwarded item; `None` once the worker has stopped
    /// and the channel is drained.
    pub async fn recv(&mut self) -> Option<DequeueItem<T>> {
        self.rx.recv().await
    }
}
261
impl<T> Drop for AsyncWorkerReceiver<T> {
    fn drop(&mut self) {
        // Close the channel first so a worker blocked in `blocking_send`
        // errors out promptly; the `_worker` field then sets the shutdown
        // flag and joins the thread when it is dropped after this body runs.
        self.rx.close();
    }
}
267
impl<T> Stream for AsyncWorkerReceiver<T> {
    type Item = DequeueItem<T>;

    /// Delegates to the channel; the stream ends (`None`) once the worker
    /// has stopped and all buffered items were consumed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut().rx.poll_recv(cx)
    }
}
275
/// Stream adapter that drives one `dequeue_async` future at a time.
pub struct AsyncStream<T> {
    scheduler: AsyncScheduler<T>,
    // In-flight dequeue future, created lazily by `poll_next`.
    pending: Option<Pin<Box<dyn Future<Output = DequeueResult<T>> + Send>>>,
}
281
282impl<T> AsyncStream<T> {
283 pub fn new(scheduler: AsyncScheduler<T>) -> Self {
285 Self {
286 scheduler,
287 pending: None,
288 }
289 }
290}
291
292impl<T: Send + 'static> Stream for AsyncStream<T> {
293 type Item = DequeueItem<T>;
294
295 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296 let this = self.get_mut();
297 if this.pending.is_none() {
298 let scheduler = this.scheduler.clone();
299 this.pending = Some(Box::pin(async move { scheduler.dequeue_async().await }));
300 }
301
302 let pending = match this.pending.as_mut() {
303 Some(pending) => pending,
304 None => return Poll::Pending,
305 };
306
307 match pending.as_mut().poll(cx) {
308 Poll::Pending => Poll::Pending,
309 Poll::Ready(result) => {
310 this.pending = None;
311 match result {
312 DequeueResult::Task { tenant, task } => {
313 Poll::Ready(Some(DequeueItem { tenant, task }))
314 }
315 DequeueResult::Closed => Poll::Ready(None),
316 DequeueResult::Empty => {
317 cx.waker().wake_by_ref();
318 Poll::Pending
319 }
320 }
321 }
322 }
323 }
324}
325
/// Runs a handler future per dequeued task with bounded concurrency.
pub struct Dispatcher<T> {
    scheduler: AsyncScheduler<T>,
    // Limits the number of concurrently running handler tasks.
    semaphore: Arc<Semaphore>,
    // Total permit count; re-acquired in full to drain on shutdown.
    max_in_flight: usize,
}
331
332impl<T> Dispatcher<T> {
333 pub fn new(scheduler: AsyncScheduler<T>, max_in_flight: usize) -> Self {
335 let max_in_flight = max_in_flight.max(1);
336 Self {
337 scheduler,
338 semaphore: Arc::new(Semaphore::new(max_in_flight)),
339 max_in_flight,
340 }
341 }
342}
343
impl<T: Send + 'static> Dispatcher<T> {
    /// Dequeues tasks until the scheduler closes, spawning `handler` for each
    /// item while holding an owned semaphore permit.
    ///
    /// The permit is moved into the spawned task and dropped when it ends —
    /// including when the handler panics, since the task's locals are dropped
    /// during unwind (exercised by `dispatcher_recovers_permits_after_panic`).
    /// After the dequeue loop ends, re-acquiring all `max_in_flight` permits
    /// waits for every in-flight handler to finish before returning.
    pub async fn run<F, Fut>(&self, handler: F)
    where
        F: Fn(DequeueItem<T>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let handler = Arc::new(handler);
        loop {
            match self.scheduler.dequeue_async().await {
                DequeueResult::Task { tenant, task } => {
                    // Acquire before spawning so at most `max_in_flight`
                    // handlers run at once. `Err` only occurs if the
                    // semaphore is closed (never done here); treat as shutdown.
                    let permit = match Arc::clone(&self.semaphore).acquire_owned().await {
                        Ok(permit) => permit,
                        Err(_) => break,
                    };
                    let handler = Arc::clone(&handler);
                    tokio::spawn(async move {
                        handler(DequeueItem { tenant, task }).await;
                        drop(permit);
                    });
                }
                DequeueResult::Closed => break,
                DequeueResult::Empty => {
                    tokio::task::yield_now().await;
                }
            }
        }

        // Drain: wait for all outstanding handlers to finish.
        // NOTE(review): `as u32` would truncate if max_in_flight exceeded
        // u32::MAX — practically unreachable, but a checked cast would be safer.
        let _ = self.semaphore.acquire_many(self.max_in_flight as u32).await;
    }
}
375
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{Duration, Instant};

    /// Small two-shard scheduler config shared by the tests below.
    fn config() -> SchedulerConfig {
        SchedulerConfig {
            shards: 2,
            max_global: 128,
            max_per_tenant: 128,
            quantum: 1,
            quantum_by_tenant: HashMap::new(),
            quantum_provider: None,
            backpressure: BackpressurePolicy::Reject,
            backpressure_by_tenant: HashMap::new(),
            top_tenants_capacity: 0,
        }
    }

    /// Unit-cost, normal-priority task carrying `payload`.
    fn task(payload: u64) -> Task<u64> {
        Task {
            payload,
            enqueue_ts: Instant::now(),
            deadline: None,
            priority: Priority::Normal,
            cost: 1,
        }
    }

    // A task enqueued with a handle can be cancelled before it is dequeued.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_scheduler_enqueue_cancel_roundtrip() {
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(config())));
        let tenant = TenantKey::from(1);

        let handle = match scheduler.enqueue_with_handle(tenant, task(1)) {
            EnqueueWithHandleResult::Enqueued(handle) => handle,
            other => panic!("expected handle, got {:?}", other),
        };
        assert!(matches!(scheduler.cancel(handle), CancelResult::Cancelled));
    }

    // The direct (no-worker) receiver delivers an enqueued item.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_receiver_receives_items() {
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(config())));
        let tenant = TenantKey::from(42);
        let _ = scheduler.enqueue(tenant, task(7));

        let item = scheduler.receiver().recv().await.expect("item");
        assert_eq!(item.tenant, tenant);
        assert_eq!(item.task.payload, 7);
    }

    // The AsyncStream adapter yields an enqueued item.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_stream_yields_items() {
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(config())));
        let tenant = TenantKey::from(3);
        let _ = scheduler.enqueue(tenant, task(11));

        let mut stream = scheduler.stream();
        let item = stream.next().await.expect("stream item");
        assert_eq!(item.tenant, tenant);
        assert_eq!(item.task.payload, 11);
    }

    // The worker-thread-backed receiver forwards an enqueued item.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_worker_receiver_receives_items() {
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(config())));
        let tenant = TenantKey::from(6);
        let _ = scheduler.enqueue(tenant, task(12));

        let mut receiver = scheduler.receiver_with_worker(16);
        let item = tokio::time::timeout(Duration::from_secs(1), receiver.recv())
            .await
            .expect("worker recv timed out")
            .expect("worker recv should yield item");
        assert_eq!(item.tenant, tenant);
        assert_eq!(item.task.payload, 12);

        scheduler.close();
    }

    // With one shard and quantum 1, two tenants' tasks interleave fairly.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_worker_receiver_observes_basic_fair_order() {
        let mut cfg = config();
        cfg.shards = 1;
        cfg.quantum = 1;
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(cfg)));

        let tenant_a = TenantKey::from(1);
        let tenant_b = TenantKey::from(2);
        let _ = scheduler.enqueue(tenant_a, task(1));
        let _ = scheduler.enqueue(tenant_a, task(2));
        let _ = scheduler.enqueue(tenant_b, task(3));
        let _ = scheduler.enqueue(tenant_b, task(4));

        let mut receiver = scheduler.receiver_with_worker(16);
        let mut observed = Vec::new();
        for _ in 0..4 {
            let item = tokio::time::timeout(Duration::from_secs(1), receiver.recv())
                .await
                .expect("worker recv timed out")
                .expect("expected dequeued item");
            observed.push((item.tenant.as_u64(), item.task.payload));
        }

        // Expect round-robin alternation between tenants a and b.
        assert_eq!(observed, vec![(1, 1), (2, 3), (1, 2), (2, 4)]);
        scheduler.close();
    }

    // Dropping the receiver must join its worker thread promptly (the worker
    // polls with a short timeout and re-checks the shutdown flag).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_worker_receiver_drop_stops_worker() {
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::<u64>::new(config())));

        let start = Instant::now();
        {
            let _receiver = scheduler.receiver_with_worker(8);
        }
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "worker drop should join promptly"
        );
    }

    // A panicking handler must release its semaphore permit so the next
    // task can still be dispatched (max_in_flight = 1 makes this strict).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dispatcher_recovers_permits_after_panic() {
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(config())));
        let tenant = TenantKey::from(5);
        let _ = scheduler.enqueue(tenant, task(1));
        let _ = scheduler.enqueue(tenant, task(2));

        let dispatcher = Dispatcher::new(scheduler.clone(), 1);
        let served = Arc::new(AtomicU64::new(0));
        let served_clone = Arc::clone(&served);

        let runner = tokio::spawn(async move {
            dispatcher
                .run(move |item| {
                    let served = Arc::clone(&served_clone);
                    async move {
                        if item.task.payload == 1 {
                            panic!("simulated panic");
                        }
                        served.fetch_add(1, Ordering::Relaxed);
                    }
                })
                .await;
        });

        // Give both tasks time to be dispatched before closing.
        tokio::time::sleep(std::time::Duration::from_millis(120)).await;
        scheduler.close();
        let _ = runner.await;

        assert_eq!(
            served.load(Ordering::Relaxed),
            1,
            "second task should execute despite panic in first task"
        );
    }

    // Measurement helper (ignored by default): average cost of routing each
    // dequeue through spawn_blocking.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[ignore = "measurement helper for dequeue_async spawn_blocking overhead"]
    async fn measure_dequeue_async_spawn_blocking_cost() {
        let mut cfg = config();
        cfg.max_global = 1_024;
        cfg.max_per_tenant = 1_024;
        let scheduler = AsyncScheduler::new(Arc::new(Scheduler::new(cfg)));
        let tenant = TenantKey::from(9);
        let samples = 512u64;

        for i in 0..samples {
            let result = scheduler.enqueue(tenant, task(i));
            assert!(matches!(result, EnqueueResult::Enqueued));
        }

        let start = Instant::now();
        for _ in 0..samples {
            let result = scheduler.dequeue_async().await;
            assert!(matches!(result, DequeueResult::Task { .. }));
        }
        let elapsed = start.elapsed();
        let avg = elapsed / samples as u32;
        println!(
            "dequeue_async_spawn_blocking: samples={} total_ms={:.3} avg_us={:.3}",
            samples,
            elapsed.as_secs_f64() * 1_000.0,
            duration_to_us(avg)
        );
    }

    /// Converts a `Duration` to fractional microseconds for reporting.
    fn duration_to_us(duration: Duration) -> f64 {
        duration.as_secs_f64() * 1_000_000.0
    }
}