Skip to main content

firq_tower/
lib.rs

1//! Tower integration for Firq scheduling.
2//!
3//! `firq-tower` exposes a configurable layer that:
4//! - extracts a tenant key from requests
5//! - enqueues requests through Firq
6//! - supports cancellation before execution turn
7//! - enforces in-flight limits with a semaphore
8//! - maps scheduling rejections to HTTP-friendly errors
9//!
10//! # Example (header-style tenant extraction)
11//!
12//! ```rust,no_run
13//! use firq_tower::{Firq, TenantKey};
14//! use std::collections::HashMap;
15//! use std::convert::Infallible;
16//! use std::future::Ready;
17//! use std::task::{Context, Poll};
18//! use tower::{Service, ServiceBuilder};
19//!
20//! #[derive(Clone)]
21//! struct Request {
22//!     headers: HashMap<&'static str, &'static str>,
23//!     body: &'static str,
24//! }
25//!
26//! #[derive(Clone)]
27//! struct EchoService;
28//!
29//! impl Service<Request> for EchoService {
30//!     type Response = &'static str;
31//!     type Error = Infallible;
32//!     type Future = Ready<Result<Self::Response, Self::Error>>;
33//!     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
34//!         Poll::Ready(Ok(()))
35//!     }
36//!     fn call(&mut self, req: Request) -> Self::Future {
37//!         std::future::ready(Ok(req.body))
38//!     }
39//! }
40//!
41//! let layer = Firq::new().build(|req: &Request| {
42//!     req.headers
43//!         .get("x-tenant-id")
44//!         .and_then(|raw| raw.parse::<u64>().ok())
45//!         .map(TenantKey::from)
46//!         .unwrap_or(TenantKey::from(0))
47//! });
48//! let _svc = ServiceBuilder::new().layer(layer).service(EchoService);
49//! ```
50
51mod builder;
52pub use builder::Firq;
53pub use firq_async::{
54    self, AsyncScheduler, BackpressurePolicy, CancelResult, CloseMode, DequeueResult,
55    EnqueueRejectReason, EnqueueResult, EnqueueWithHandleResult, Priority, QueueTimeBucket,
56    Scheduler, SchedulerConfig, SchedulerStats, Task, TaskHandle, TenantCount, TenantKey,
57};
58
59use std::any::Any;
60use std::fmt;
61use std::future::Future;
62use std::marker::PhantomData;
63use std::pin::Pin;
64use std::sync::Arc;
65use std::sync::Mutex;
66use std::sync::atomic::{AtomicBool, Ordering};
67use std::task::{Context, Poll};
68
69use tokio::sync::{Semaphore, oneshot};
70use tower_service::Service;
71
/// Type-erased deadline hook: receives the request as `&dyn Any` and may return
/// an instant, which the service stores as the queued task's `deadline`.
pub(crate) type ErasedDeadlineExtractor =
    Arc<dyn Fn(&dyn Any) -> Option<std::time::Instant> + Sync + Send>;
74
75/// Function used to map enqueue rejections into HTTP-facing errors.
76pub type RejectionMapper = Arc<dyn Fn(EnqueueRejectReason) -> FirqHttpRejection + Send + Sync>;
77
78/// Structured HTTP rejection payload produced by the layer.
79#[derive(Clone, Debug)]
80pub struct FirqHttpRejection {
81    /// HTTP status code.
82    pub status: u16,
83    /// Stable machine-readable error code.
84    pub code: &'static str,
85    /// Human-readable error message.
86    pub message: &'static str,
87    /// Underlying scheduling rejection reason.
88    pub reason: EnqueueRejectReason,
89}
90
91/// Default mapping from scheduler rejection reasons to HTTP responses.
92pub fn default_rejection_mapper(reason: EnqueueRejectReason) -> FirqHttpRejection {
93    match reason {
94        EnqueueRejectReason::TenantFull => FirqHttpRejection {
95            status: 429,
96            code: "tenant_full",
97            message: "tenant queue limit reached",
98            reason,
99        },
100        EnqueueRejectReason::GlobalFull => FirqHttpRejection {
101            status: 503,
102            code: "global_full",
103            message: "scheduler global capacity reached",
104            reason,
105        },
106        EnqueueRejectReason::Timeout => FirqHttpRejection {
107            status: 503,
108            code: "timeout",
109            message: "request timed out waiting for scheduler capacity",
110            reason,
111        },
112    }
113}
114
115/// Internal permit payload sent to the background dequeue worker.
116pub struct FirqPermit {
117    tx: oneshot::Sender<tokio::sync::OwnedSemaphorePermit>,
118}
119
120/// Layer/service error type.
121#[derive(Debug)]
122pub enum FirqError<E> {
123    /// Error returned by the wrapped inner service.
124    Service(E),
125    /// Request rejected by scheduler policy.
126    Rejected(FirqHttpRejection),
127    /// Scheduler has been closed.
128    Closed,
129    /// Failed waiting for permit (expired/dropped/cancelled).
130    PermitError,
131}
132
133impl<E: fmt::Display> fmt::Display for FirqError<E> {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        match self {
136            FirqError::Service(e) => write!(f, "Service error: {}", e),
137            FirqError::Rejected(rejection) => write!(
138                f,
139                "Request rejected: status={} code={} reason={:?}",
140                rejection.status, rejection.code, rejection.reason
141            ),
142            FirqError::Closed => write!(f, "Scheduler closed"),
143            FirqError::PermitError => {
144                write!(f, "Permit acquisition failed (task expired or dropped)")
145            }
146        }
147    }
148}
149
150impl<E: std::error::Error + 'static> std::error::Error for FirqError<E> {
151    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
152        match self {
153            FirqError::Service(e) => Some(e),
154            _ => None,
155        }
156    }
157}
158
159pub trait KeyExtractor<Request>: Clone {
160    /// Extracts a tenant key from an incoming request.
161    fn extract(&self, req: &Request) -> TenantKey;
162}
163
164impl<F, Request> KeyExtractor<Request> for F
165where
166    F: Fn(&Request) -> TenantKey + Clone,
167{
168    fn extract(&self, req: &Request) -> TenantKey {
169        (self)(req)
170    }
171}
172
173/// Tower `Layer` implementation configured with a Firq scheduler.
174pub struct FirqLayer<Request, K> {
175    scheduler: AsyncScheduler<FirqPermit>,
176    extractor: K,
177    in_flight: Arc<Semaphore>,
178    max_in_flight: usize,
179    _worker: Arc<BackgroundWorker>,
180    deadline_extractor: Option<ErasedDeadlineExtractor>,
181    rejection_mapper: RejectionMapper,
182    _marker: PhantomData<Request>,
183}
184
185impl<Request, K: Clone> Clone for FirqLayer<Request, K> {
186    fn clone(&self) -> Self {
187        Self {
188            scheduler: self.scheduler.clone(),
189            extractor: self.extractor.clone(),
190            in_flight: Arc::clone(&self.in_flight),
191            max_in_flight: self.max_in_flight,
192            _worker: Arc::clone(&self._worker),
193            deadline_extractor: self.deadline_extractor.clone(),
194            rejection_mapper: Arc::clone(&self.rejection_mapper),
195            _marker: PhantomData,
196        }
197    }
198}
199
200impl<Request, K> FirqLayer<Request, K> {
201    /// Creates a new `FirqLayer`.
202    pub fn new(
203        scheduler: AsyncScheduler<FirqPermit>,
204        extractor: K,
205        in_flight_limit: usize,
206        deadline_extractor: Option<ErasedDeadlineExtractor>,
207        rejection_mapper: RejectionMapper,
208    ) -> Self {
209        let worker_scheduler = scheduler.inner().clone();
210        let in_flight = Arc::new(Semaphore::new(in_flight_limit.max(1)));
211        let worker_in_flight = Arc::clone(&in_flight);
212        let shutdown = Arc::new(AtomicBool::new(false));
213        let worker_shutdown = Arc::clone(&shutdown);
214        let worker_handle = std::thread::spawn(move || {
215            let runtime = tokio::runtime::Builder::new_current_thread()
216                .enable_time()
217                .build()
218                .expect("failed to build tower worker runtime");
219            loop {
220                if worker_shutdown.load(Ordering::Acquire) {
221                    break;
222                }
223                match worker_scheduler.dequeue_blocking() {
224                    DequeueResult::Task { task, .. } => {
225                        let permit =
226                            match runtime.block_on(worker_in_flight.clone().acquire_owned()) {
227                                Ok(permit) => permit,
228                                Err(_) => break,
229                            };
230                        let _ = task.payload.tx.send(permit);
231                    }
232                    DequeueResult::Closed => break,
233                    DequeueResult::Empty => {}
234                }
235            }
236        });
237
238        let worker = Arc::new(BackgroundWorker::new(
239            scheduler.clone(),
240            Arc::clone(&in_flight),
241            shutdown,
242            worker_handle,
243        ));
244
245        Self {
246            scheduler,
247            extractor,
248            in_flight,
249            max_in_flight: in_flight_limit.max(1),
250            _worker: worker,
251            deadline_extractor,
252            rejection_mapper,
253            _marker: PhantomData,
254        }
255    }
256
257    pub fn scheduler(&self) -> &AsyncScheduler<FirqPermit> {
258        &self.scheduler
259    }
260
261    /// Returns configured in-flight execution limit.
262    pub fn in_flight_limit(&self) -> usize {
263        self.max_in_flight
264    }
265
266    /// Returns currently active in-flight executions.
267    pub fn in_flight_active(&self) -> usize {
268        self.max_in_flight
269            .saturating_sub(self.in_flight.available_permits())
270    }
271
272    /// Returns current in-flight saturation ratio in `[0.0, 1.0+]`.
273    pub fn in_flight_saturation_ratio(&self) -> f64 {
274        if self.max_in_flight == 0 {
275            0.0
276        } else {
277            self.in_flight_active() as f64 / self.max_in_flight as f64
278        }
279    }
280}
281
282impl<Request> FirqLayer<Request, ()> {
283    /// Returns the `Firq` builder.
284    pub fn builder() -> Firq {
285        Firq::new()
286    }
287}
288
289impl<S, Request, K> tower::Layer<S> for FirqLayer<Request, K>
290where
291    K: KeyExtractor<Request> + Clone,
292{
293    type Service = FirqService<S, K, Request>;
294
295    fn layer(&self, inner: S) -> Self::Service {
296        FirqService {
297            inner,
298            scheduler: self.scheduler.clone(),
299            extractor: self.extractor.clone(),
300            in_flight: Arc::clone(&self.in_flight),
301            _worker: Arc::clone(&self._worker),
302            deadline_extractor: self.deadline_extractor.clone(),
303            rejection_mapper: Arc::clone(&self.rejection_mapper),
304            _marker: PhantomData,
305        }
306    }
307}
308
309/// Tower `Service` produced by [`FirqLayer`].
310pub struct FirqService<S, K, Request> {
311    inner: S,
312    scheduler: AsyncScheduler<FirqPermit>,
313    extractor: K,
314    in_flight: Arc<Semaphore>,
315    _worker: Arc<BackgroundWorker>,
316    deadline_extractor: Option<ErasedDeadlineExtractor>,
317    rejection_mapper: RejectionMapper,
318    _marker: PhantomData<Request>,
319}
320
321impl<S, K, Request> Clone for FirqService<S, K, Request>
322where
323    S: Clone,
324    K: Clone,
325{
326    fn clone(&self) -> Self {
327        Self {
328            inner: self.inner.clone(),
329            scheduler: self.scheduler.clone(),
330            extractor: self.extractor.clone(),
331            in_flight: Arc::clone(&self.in_flight),
332            _worker: Arc::clone(&self._worker),
333            deadline_extractor: self.deadline_extractor.clone(),
334            rejection_mapper: Arc::clone(&self.rejection_mapper),
335            _marker: PhantomData,
336        }
337    }
338}
339
340impl<S, K, Request> Service<Request> for FirqService<S, K, Request>
341where
342    S: Service<Request> + Clone + Send + 'static,
343    S::Future: Send + 'static,
344    K: KeyExtractor<Request> + Send + 'static,
345    Request: Send + 'static,
346{
347    type Response = S::Response;
348    type Error = FirqError<S::Error>;
349    type Future = Pin<Box<dyn Future<Output = Result<S::Response, Self::Error>> + Send>>;
350
351    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
352        self.inner.poll_ready(cx).map_err(FirqError::Service)
353    }
354
355    fn call(&mut self, req: Request) -> Self::Future {
356        let tenant = self.extractor.extract(&req);
357        let deadline = self
358            .deadline_extractor
359            .as_ref()
360            .and_then(|extractor| extractor(&req as &dyn Any));
361
362        let (tx, rx) = oneshot::channel();
363        let task = Task {
364            payload: FirqPermit { tx },
365            enqueue_ts: std::time::Instant::now(),
366            deadline,
367            priority: Default::default(),
368            cost: 1,
369        };
370
371        let scheduler = self.scheduler.clone();
372        let mapper = Arc::clone(&self.rejection_mapper);
373
374        match scheduler.enqueue_with_handle(tenant, task) {
375            EnqueueWithHandleResult::Enqueued(handle) => {
376                let mut inner = self.inner.clone();
377                Box::pin(async move {
378                    let mut guard = PendingCancelGuard::new(scheduler.clone(), handle);
379                    let permit = rx.await.map_err(|_| FirqError::PermitError)?;
380                    guard.disarm();
381
382                    let response = inner.call(req).await.map_err(FirqError::Service)?;
383                    drop(permit);
384                    Ok(response)
385                })
386            }
387            EnqueueWithHandleResult::Rejected(reason) => {
388                let rejection = mapper(reason);
389                Box::pin(async move { Err(FirqError::Rejected(rejection)) })
390            }
391            EnqueueWithHandleResult::Closed => Box::pin(async { Err(FirqError::Closed) }),
392        }
393    }
394}
395
396struct PendingCancelGuard {
397    scheduler: AsyncScheduler<FirqPermit>,
398    handle: Option<TaskHandle>,
399}
400
401impl PendingCancelGuard {
402    fn new(scheduler: AsyncScheduler<FirqPermit>, handle: TaskHandle) -> Self {
403        Self {
404            scheduler,
405            handle: Some(handle),
406        }
407    }
408
409    fn disarm(&mut self) {
410        self.handle = None;
411    }
412}
413
414impl Drop for PendingCancelGuard {
415    fn drop(&mut self) {
416        if let Some(handle) = self.handle.take() {
417            let _ = self.scheduler.cancel(handle);
418        }
419    }
420}
421
422struct BackgroundWorker {
423    scheduler: AsyncScheduler<FirqPermit>,
424    in_flight: Arc<Semaphore>,
425    shutdown: Arc<AtomicBool>,
426    handle: Mutex<Option<std::thread::JoinHandle<()>>>,
427}
428
429impl BackgroundWorker {
430    fn new(
431        scheduler: AsyncScheduler<FirqPermit>,
432        in_flight: Arc<Semaphore>,
433        shutdown: Arc<AtomicBool>,
434        handle: std::thread::JoinHandle<()>,
435    ) -> Self {
436        Self {
437            scheduler,
438            in_flight,
439            shutdown,
440            handle: Mutex::new(Some(handle)),
441        }
442    }
443}
444
445impl Drop for BackgroundWorker {
446    fn drop(&mut self) {
447        self.shutdown.store(true, Ordering::Release);
448        self.in_flight.close();
449        self.scheduler.close_immediate();
450        let mut guard = self.handle.lock().expect("worker mutex poisoned");
451        if let Some(handle) = guard.take() {
452            let _ = handle.join();
453        }
454    }
455}