// loom_rs/runtime.rs

//! Loom runtime implementation.
//!
//! The runtime combines a tokio async runtime with a rayon thread pool,
//! both configured with CPU pinning.
//!
//! # Performance
//!
//! This module is designed for zero unnecessary overhead:
//! - `spawn_async()`: ~10ns overhead (TaskTracker token only)
//! - `spawn_compute()`: ~100-500ns (cross-thread signaling, 0 bytes after warmup)
//! - `install()`: ~0ns (zero overhead, direct rayon access)
//!
//! # Thread Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────┐
//! │                     LoomRuntime                              │
//! │  pools: ComputePoolRegistry (per-type lock-free pools)      │
//! │  (One pool per result type, shared across all threads)      │
//! └─────────────────────────────────────────────────────────────┘
//!          │ on_thread_start           │ start_handler
//!          ▼                           ▼
//! ┌─────────────────────┐     ┌─────────────────────┐
//! │   Tokio Workers     │     │   Rayon Workers     │
//! │  thread_local! {    │     │  thread_local! {    │
//! │    RUNTIME: Weak<>  │     │    RUNTIME: Weak<>  │
//! │  }                  │     │  }                  │
//! └─────────────────────┘     └─────────────────────┘
//! ```
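//!
//! Because both pools inject a `Weak` runtime reference into a thread-local
//! at startup, code on any managed thread can recover the runtime. A minimal
//! sketch (illustrative only; `heavy_work()` is a hypothetical function):
//!
//! ```ignore
//! runtime.block_on(async {
//!     // Works inside block_on and on tokio/rayon worker threads.
//!     let rt = loom_rs::current_runtime().expect("on a managed thread");
//!     let out = rt.spawn_compute(|| heavy_work()).await;
//! });
//! ```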

use crate::affinity::{pin_to_cpu, CpuAllocator};
use crate::bridge::{PooledRayonTask, TaskState};
use crate::config::LoomConfig;
use crate::context::{clear_current_runtime, set_current_runtime};
use crate::cpuset::{available_cpus, format_cpuset, parse_and_validate_cpuset};
use crate::error::{LoomError, Result};
use crate::mab::{Arm, ComputeHint, Context, FunctionKey, MabKnobs, MabScheduler};
use crate::metrics::LoomMetrics;
use crate::pool::ComputePoolRegistry;

use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock, Weak};
use std::time::Instant;
use tokio::sync::Notify;
use tokio_util::task::TaskTracker;
use tracing::{debug, info, warn};

/// State for tracking in-flight compute tasks.
///
/// Combines the task counter with a notification mechanism for efficient
/// shutdown waiting (avoids spin loops).
struct ComputeTaskState {
    /// Number of tasks currently executing on rayon
    count: AtomicUsize,
    /// Notified when count reaches 0
    notify: Notify,
}

impl ComputeTaskState {
    fn new() -> Self {
        Self {
            count: AtomicUsize::new(0),
            notify: Notify::new(),
        }
    }
}

/// Guard for tracking async task metrics.
///
/// Panic-safe: task_completed is called even if the future panics.
struct AsyncMetricsGuard {
    inner: Arc<LoomRuntimeInner>,
}

impl AsyncMetricsGuard {
    fn new(inner: Arc<LoomRuntimeInner>) -> Self {
        inner.prometheus_metrics.task_started();
        Self { inner }
    }
}

impl Drop for AsyncMetricsGuard {
    fn drop(&mut self) {
        self.inner.prometheus_metrics.task_completed();
    }
}

/// Guard for tracking compute task state and metrics.
///
/// Panic-safe: executes even if the task closure panics.
///
/// SAFETY: The state lives in LoomRuntimeInner, which outlives all rayon tasks
/// because block_until_idle waits for the compute-task count to reach 0.
struct ComputeTaskGuard {
    state: *const ComputeTaskState,
    metrics: *const LoomMetrics,
}

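// SAFETY: the raw pointers target fields of `LoomRuntimeInner`, which (per
// the struct docs above) outlives all rayon tasks, so the guard can be sent
// to a rayon worker without the pointers dangling.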
unsafe impl Send for ComputeTaskGuard {}

impl ComputeTaskGuard {
    /// Create a new guard, tracking submission in MAB metrics.
    ///
    /// This should be called BEFORE spawning on rayon.
    fn new(state: &ComputeTaskState, metrics: &LoomMetrics) -> Self {
        state.count.fetch_add(1, Ordering::Relaxed);
        metrics.rayon_submitted();
        Self {
            state: state as *const ComputeTaskState,
            metrics: metrics as *const LoomMetrics,
        }
    }

    /// Mark that the rayon task has started executing.
    ///
    /// This should be called at the START of the rayon closure.
    fn started(&self) {
        // SAFETY: metrics outlives rayon tasks
        unsafe {
            (*self.metrics).rayon_started();
        }
    }
}

impl Drop for ComputeTaskGuard {
    fn drop(&mut self) {
        // SAFETY: state and metrics outlive rayon tasks due to shutdown waiting
        unsafe {
            // Track MAB metrics completion (panic-safe)
            (*self.metrics).rayon_completed();

            let prev = (*self.state).count.fetch_sub(1, Ordering::Release);
            if prev == 1 {
                // Count just went from 1 to 0, notify waiters
                (*self.state).notify.notify_waiters();
            }
        }
    }
}

/// A bespoke thread pool runtime combining tokio and rayon with CPU pinning.
///
/// The runtime provides:
/// - A tokio async runtime for I/O-bound work
/// - A rayon thread pool for CPU-bound parallel work
/// - Automatic CPU pinning for both runtimes
/// - A task tracker for graceful shutdown
/// - Zero-allocation compute spawning after warmup
///
/// # Performance Guarantees
///
/// | Method | Overhead | Allocations | Tracked |
/// |--------|----------|-------------|---------|
/// | `spawn_async()` | ~10ns | Token only | Yes |
/// | `spawn_compute()` | ~100-500ns | 0 bytes (after warmup) | Yes |
/// | `install()` | ~0ns | None | No |
/// | `rayon_pool()` | 0ns | None | No |
/// | `tokio_handle()` | 0ns | None | No |
///
/// # Examples
///
/// ```ignore
/// use loom_rs::LoomBuilder;
///
/// let runtime = LoomBuilder::new()
///     .prefix("myapp")
///     .tokio_threads(2)
///     .rayon_threads(6)
///     .build()?;
///
/// runtime.block_on(async {
///     // Spawn tracked async I/O task
///     let io_handle = runtime.spawn_async(async {
///         fetch_data().await
///     });
///
///     // Spawn tracked compute task and await result
///     let result = runtime.spawn_compute(|| {
///         expensive_computation()
///     }).await;
///
///     // Zero-overhead parallel iterators (within tracked context)
///     let processed = runtime.install(|| {
///         data.par_iter().map(|x| process(x)).collect()
///     });
/// });
///
/// // Graceful shutdown from main thread
/// runtime.block_until_idle();
/// ```
pub struct LoomRuntime {
    pub(crate) inner: Arc<LoomRuntimeInner>,
}

/// Inner state shared with thread-locals.
///
/// This is Arc-wrapped and shared with tokio/rayon worker threads via thread-local
/// storage, enabling `current_runtime()` to work from any managed thread.
pub(crate) struct LoomRuntimeInner {
    config: LoomConfig,
    tokio_runtime: tokio::runtime::Runtime,
    pub(crate) rayon_pool: rayon::ThreadPool,
    task_tracker: TaskTracker,
    /// Track in-flight rayon tasks for graceful shutdown
    compute_state: ComputeTaskState,
    /// Per-type object pools for zero-allocation spawn_compute
    pub(crate) pools: ComputePoolRegistry,
    /// Number of tokio worker threads
    pub(crate) tokio_threads: usize,
    /// Number of rayon worker threads
    pub(crate) rayon_threads: usize,
    /// CPUs allocated to tokio workers
    pub(crate) tokio_cpus: Vec<usize>,
    /// CPUs allocated to rayon workers
    pub(crate) rayon_cpus: Vec<usize>,
    /// Lazily initialized shared MAB scheduler
    mab_scheduler: OnceLock<Arc<MabScheduler>>,
    /// MAB knobs configuration
    pub(crate) mab_knobs: MabKnobs,
    /// Prometheus metrics - single source of truth for all metrics
    /// (serves both Prometheus exposition and MAB scheduling)
    pub(crate) prometheus_metrics: LoomMetrics,
}

impl LoomRuntime {
    /// Create a LoomRuntime from an existing inner reference.
    ///
    /// This does **not** create a new runtime; it only creates another
    /// handle that points at the same `LoomRuntimeInner`. As a result,
    /// multiple `LoomRuntime` values may refer to the same underlying
    /// runtime state.
    ///
    /// This is intended for internal use by `current_runtime()` to wrap the
    /// thread-local inner reference. Callers must **not** treat the returned
    /// handle as an independently owned runtime for the purpose of shutdown
    /// or teardown. Invoking shutdown-related methods from multiple wrappers
    /// that share the same inner state may lead to unexpected behavior.
    pub(crate) fn from_inner(inner: Arc<LoomRuntimeInner>) -> Self {
        Self { inner }
    }

    /// Create a runtime from a configuration.
    ///
    /// This is typically called via `LoomBuilder::build()`.
    pub(crate) fn from_config(config: LoomConfig, pool_size: usize) -> Result<Self> {
        // Determine available CPUs
        // Priority: CUDA device cpuset > user cpuset > all available CPUs
        // Error if both cuda_device and cpuset are specified (mutually exclusive)
        let cpus = {
            #[cfg(feature = "cuda")]
            {
                // Check for conflicting configuration first
                if config.cuda_device.is_some() && config.cpuset.is_some() {
                    return Err(LoomError::CudaCpusetConflict);
                }

                if let Some(ref selector) = config.cuda_device {
                    match crate::cuda::cpuset_for_cuda_device(selector)? {
                        Some(cuda_cpus) => cuda_cpus,
                        None => {
                            // Could not determine CUDA locality, fall back to all CPUs
                            available_cpus()
                        }
                    }
                } else if let Some(ref cpuset_str) = config.cpuset {
                    parse_and_validate_cpuset(cpuset_str)?
                } else {
                    available_cpus()
                }
            }
            #[cfg(not(feature = "cuda"))]
            {
                if let Some(ref cpuset_str) = config.cpuset {
                    parse_and_validate_cpuset(cpuset_str)?
                } else {
                    available_cpus()
                }
            }
        };

        if cpus.is_empty() {
            return Err(LoomError::NoCpusAvailable);
        }

        let total_cpus = cpus.len();
        let tokio_threads = config.effective_tokio_threads();
        let rayon_threads = config.effective_rayon_threads(total_cpus);

        // Validate we have enough CPUs
        let total_threads = tokio_threads + rayon_threads;
        if total_threads > total_cpus {
            return Err(LoomError::InsufficientCpus {
                requested: total_threads,
                available: total_cpus,
            });
        }

        // Split CPUs between tokio and rayon
        let (tokio_cpus, rayon_cpus) = cpus.split_at(tokio_threads.min(cpus.len()));
        let tokio_cpus = tokio_cpus.to_vec();
        let rayon_cpus = if rayon_cpus.is_empty() {
            // If we don't have dedicated rayon CPUs, share with tokio
            tokio_cpus.clone()
        } else {
            rayon_cpus.to_vec()
        };

        info!(
            prefix = %config.prefix,
            tokio_threads,
            rayon_threads,
            total_cpus,
            pool_size,
            "building loom runtime"
        );

        // Use Arc<str> for prefix to avoid cloning on each thread start
        let prefix: Arc<str> = config.prefix.as_str().into();

        // Build the inner runtime with Arc::new_cyclic so the tokio/rayon
        // start handlers can capture a Weak reference back to it before it
        // is fully constructed.
        let inner = Arc::new_cyclic(|weak: &Weak<LoomRuntimeInner>| {
            let weak_clone = weak.clone();

            // Build tokio runtime with thread-local injection
            let tokio_runtime = Self::build_tokio_runtime(
                &prefix,
                tokio_threads,
                tokio_cpus.clone(),
                weak_clone.clone(),
            )
            .expect("failed to build tokio runtime");

            // Build rayon pool with thread-local injection
            let rayon_pool =
                Self::build_rayon_pool(&prefix, rayon_threads, rayon_cpus.clone(), weak_clone)
                    .expect("failed to build rayon pool");

            // Extract MAB knobs, using defaults if not configured
            let mab_knobs = config.mab_knobs.clone().unwrap_or_default();

            // Create Prometheus metrics with the runtime's prefix
            let prometheus_metrics = LoomMetrics::with_prefix(&config.prefix);

            // Register with provided registry if available
            if let Some(ref registry) = config.prometheus_registry {
                if let Err(e) = prometheus_metrics.register(registry) {
                    warn!(%e, "failed to register prometheus metrics");
                }
            }

            LoomRuntimeInner {
                config,
                tokio_runtime,
                rayon_pool,
                task_tracker: TaskTracker::new(),
                compute_state: ComputeTaskState::new(),
                pools: ComputePoolRegistry::new(pool_size),
                tokio_threads,
                rayon_threads,
                tokio_cpus,
                rayon_cpus,
                mab_scheduler: OnceLock::new(),
                mab_knobs,
                prometheus_metrics,
            }
        });

        Ok(Self { inner })
    }

    fn build_tokio_runtime(
        prefix: &Arc<str>,
        num_threads: usize,
        cpus: Vec<usize>,
        runtime_weak: Weak<LoomRuntimeInner>,
    ) -> Result<tokio::runtime::Runtime> {
        let allocator = Arc::new(CpuAllocator::new(cpus));

        // Thread name counter
        let thread_counter = Arc::new(AtomicUsize::new(0));
        let name_prefix = Arc::clone(prefix);

        let start_weak = runtime_weak.clone();
        let start_allocator = allocator.clone();
        let start_prefix = Arc::clone(prefix);

        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(num_threads)
            .thread_name_fn(move || {
                let id = thread_counter.fetch_add(1, Ordering::SeqCst);
                format!("{}-tokio-{:04}", name_prefix, id)
            })
            .on_thread_start(move || {
                // Pin CPU
                let cpu_id = start_allocator.allocate();
                if let Err(e) = pin_to_cpu(cpu_id) {
                    warn!(%e, %start_prefix, cpu_id, "failed to pin tokio thread");
                } else {
                    debug!(cpu_id, %start_prefix, "pinned tokio thread to CPU");
                }

                // Inject runtime reference into thread-local
                set_current_runtime(start_weak.clone());
            })
            .on_thread_stop(|| {
                clear_current_runtime();
            })
            .enable_all()
            .build()?;

        Ok(runtime)
    }

    fn build_rayon_pool(
        prefix: &Arc<str>,
        num_threads: usize,
        cpus: Vec<usize>,
        runtime_weak: Weak<LoomRuntimeInner>,
    ) -> Result<rayon::ThreadPool> {
        let allocator = Arc::new(CpuAllocator::new(cpus));
        let name_prefix = Arc::clone(prefix);

        let start_weak = runtime_weak.clone();
        let start_allocator = allocator.clone();
        let start_prefix = Arc::clone(prefix);

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .thread_name(move |i| format!("{}-rayon-{:04}", name_prefix, i))
            .start_handler(move |thread_index| {
                // Pin CPU
                let cpu_id = start_allocator.allocate();
                debug!(thread_index, cpu_id, %start_prefix, "rayon thread starting");
                if let Err(e) = pin_to_cpu(cpu_id) {
                    warn!(%e, %start_prefix, cpu_id, thread_index, "failed to pin rayon thread");
                }

                // Inject runtime reference into thread-local
                set_current_runtime(start_weak.clone());
            })
            .exit_handler(|_thread_index| {
                clear_current_runtime();
            })
            .build()?;

        Ok(pool)
    }

    /// Get the resolved configuration.
    pub fn config(&self) -> &LoomConfig {
        &self.inner.config
    }

    /// Get the tokio runtime handle.
    ///
    /// This can be used to spawn untracked tasks or enter the runtime context.
    /// For tracked async tasks, prefer `spawn_async()`.
    ///
    /// # Performance
    ///
    /// Zero overhead - returns a reference.
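    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative only; `fetch_data()` is a hypothetical
    /// async fn) of spawning directly on the handle:
    ///
    /// ```ignore
    /// // NOTE: bypasses the TaskTracker, so block_until_idle() will not
    /// // wait for this task.
    /// runtime.tokio_handle().spawn(async {
    ///     fetch_data().await
    /// });
    /// ```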
    pub fn tokio_handle(&self) -> &tokio::runtime::Handle {
        self.inner.tokio_runtime.handle()
    }

    /// Get the rayon thread pool.
    ///
    /// This can be used to execute parallel iterators or spawn untracked work directly.
    /// For tracked compute tasks, prefer `spawn_compute()`.
    /// For zero-overhead parallel iterators, prefer `install()`.
    ///
    /// # Performance
    ///
    /// Zero overhead - returns a reference.
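    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative only) of driving a parallel iterator
    /// on the pool directly:
    ///
    /// ```ignore
    /// use rayon::prelude::*;
    ///
    /// let sum = runtime.rayon_pool().install(|| {
    ///     (0..1_000).into_par_iter().sum::<i64>()
    /// });
    /// ```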
    pub fn rayon_pool(&self) -> &rayon::ThreadPool {
        &self.inner.rayon_pool
    }

    /// Get the task tracker for graceful shutdown.
    ///
    /// Use this to track spawned tasks and wait for them to complete.
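    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative only; `do_work()` is a hypothetical
    /// async fn) mirroring how `spawn_async()` tracks tasks internally:
    ///
    /// ```ignore
    /// let token = runtime.task_tracker().token();
    /// runtime.tokio_handle().spawn(async move {
    ///     let _token = token; // released when the task finishes
    ///     do_work().await;
    /// });
    /// ```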
    pub fn task_tracker(&self) -> &TaskTracker {
        &self.inner.task_tracker
    }

    /// Block on a future using the tokio runtime.
    ///
    /// This is the main entry point for running async code from the main thread.
    /// The current runtime is available via `loom_rs::current_runtime()` within
    /// the block_on scope.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     // Async code here
    ///     // loom_rs::current_runtime() works here
    /// });
    /// ```
    pub fn block_on<F: Future>(&self, f: F) -> F::Output {
        // Set current runtime for the main thread during block_on
        set_current_runtime(Arc::downgrade(&self.inner));
        let result = self.inner.tokio_runtime.block_on(f);
        clear_current_runtime();
        result
    }

    /// Spawn a tracked async task on tokio.
    ///
    /// The task is tracked for graceful shutdown via `block_until_idle()`.
    ///
    /// # Performance
    ///
    /// Overhead: ~10ns (TaskTracker token only).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     let handle = runtime.spawn_async(async {
    ///         // I/O-bound async work
    ///         fetch_data().await
    ///     });
    ///
    ///     let result = handle.await.unwrap();
    /// });
    /// ```
    #[inline]
    pub fn spawn_async<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        // Track task for MAB metrics (panic-safe via guard)
        let metrics_guard = AsyncMetricsGuard::new(Arc::clone(&self.inner));
        let token = self.inner.task_tracker.token();
        self.inner.tokio_runtime.spawn(async move {
            let _tracker = token;
            let _metrics = metrics_guard;
            future.await
        })
    }

    /// Spawn CPU-bound work on rayon and await the result.
    ///
    /// The task is tracked for graceful shutdown via `block_until_idle()`.
    /// Automatically uses per-type object pools for zero allocation after warmup.
    ///
    /// # Performance
    ///
    /// | State | Allocations | Overhead |
    /// |-------|-------------|----------|
    /// | Pool hit | 0 bytes | ~100-500ns |
    /// | Pool miss | ~32 bytes | ~100-500ns |
    /// | First call per type | Pool + state | ~1µs |
    ///
    /// For zero-overhead parallel iterators (within an already-tracked context),
    /// use `install()` instead.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     let result = runtime.spawn_compute(|| {
    ///         // CPU-intensive work
    ///         expensive_computation()
    ///     }).await;
    /// });
    /// ```
    #[inline]
    pub async fn spawn_compute<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.inner.spawn_compute(f).await
    }

    /// Spawn work with adaptive inline/offload decision.
    ///
    /// Uses MAB (Multi-Armed Bandit) to learn whether this function type should
    /// run inline on tokio or offload to rayon. Good for handler patterns where
    /// work duration varies by input.
    ///
    /// Unlike `spawn_compute()` which always offloads, this adaptively chooses
    /// based on learned behavior and current system pressure.
    ///
    /// # Performance
    ///
    /// | Scenario | Behavior | Overhead |
    /// |----------|----------|----------|
    /// | Fast work | Inlines after learning | ~100ns (decision only) |
    /// | Slow work | Offloads after learning | ~100-500ns (+ offload) |
    /// | Cold start | Explores both arms | Variable |
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     // MAB will learn whether this is fast or slow
    ///     let result = runtime.spawn_adaptive(|| {
    ///         process_item(item)
    ///     }).await;
    /// });
    /// ```
    pub async fn spawn_adaptive<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.spawn_adaptive_with_hint(ComputeHint::Unknown, f).await
    }

    /// Spawn with hint for cold-start guidance.
    ///
    /// The hint helps the scheduler make better initial decisions before it has
    /// learned the actual execution time of this function type.
    ///
    /// # Hints
    ///
    /// - `ComputeHint::Low` - Expected < 50µs (likely inline-safe)
    /// - `ComputeHint::Medium` - Expected 50-500µs (borderline)
    /// - `ComputeHint::High` - Expected > 500µs (should test offload early)
    /// - `ComputeHint::Unknown` - No hint (default exploration)
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use loom_rs::ComputeHint;
    ///
    /// runtime.block_on(async {
    ///     // Hint that this is likely slow work
    ///     let result = runtime.spawn_adaptive_with_hint(
    ///         ComputeHint::High,
    ///         || expensive_computation()
    ///     ).await;
    /// });
    /// ```
    pub async fn spawn_adaptive_with_hint<F, R>(&self, hint: ComputeHint, f: F) -> R
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let ctx = self.collect_context();
        let key = FunctionKey::from_type::<F>();
        let scheduler = self.mab_scheduler();

        let (id, arm) = scheduler.choose_with_hint(key, &ctx, hint);
        let start = Instant::now();

        let result = match arm {
            Arm::InlineTokio => f(),
            Arm::OffloadRayon => self.inner.spawn_compute(f).await,
        };

        let elapsed_us = start.elapsed().as_secs_f64() * 1_000_000.0;
        scheduler.finish(id, elapsed_us, Some(elapsed_us));
        result
    }

    /// Execute work on rayon with zero overhead (sync, blocking).
    ///
    /// This installs the rayon pool for the current scope, allowing direct use
    /// of rayon's parallel iterators.
    ///
    /// **NOT tracked** - use within an already-tracked task (e.g., inside
    /// `spawn_async` or `spawn_compute`) for proper shutdown tracking.
    ///
    /// # Performance
    ///
    /// Zero overhead - direct rayon access.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     // This is a tracked context (we're in block_on)
    ///     let processed = runtime.install(|| {
    ///         use rayon::prelude::*;
    ///         data.par_iter().map(|x| process(x)).collect::<Vec<_>>()
    ///     });
    /// });
    /// ```
    #[inline]
    pub fn install<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.inner.rayon_pool.install(f)
    }

    /// Begin graceful shutdown by closing the task tracker.
    ///
    /// After calling this, `spawn_async()` and `spawn_compute()` still
    /// work, but the shutdown process has begun. Use `is_idle()` or
    /// `wait_for_shutdown()` to check for or await completion.
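    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative only) pairing `shutdown()` with
    /// `is_idle()`:
    ///
    /// ```ignore
    /// runtime.shutdown();
    /// if !runtime.is_idle() {
    ///     // Tasks are still draining; block until they finish.
    ///     runtime.block_until_idle();
    /// }
    /// ```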
    pub fn shutdown(&self) {
        self.inner.task_tracker.close();
    }

    /// Check if all tracked tasks have completed.
    ///
    /// Returns `true` if `shutdown()` has been called and all tracked async
    /// tasks and compute tasks have finished.
    ///
    /// # Performance
    ///
    /// Near-zero overhead - a few atomic loads.
    #[inline]
    pub fn is_idle(&self) -> bool {
        self.inner.task_tracker.is_closed()
            && self.inner.task_tracker.is_empty()
            && self.inner.compute_state.count.load(Ordering::Acquire) == 0
    }

    /// Get the number of compute tasks currently in flight.
    ///
    /// Useful for debugging shutdown issues or monitoring workload.
    ///
    /// # Example
    ///
    /// ```ignore
    /// if runtime.compute_tasks_in_flight() > 0 {
    ///     tracing::warn!("Still waiting for {} compute tasks",
    ///         runtime.compute_tasks_in_flight());
    /// }
    /// ```
    #[inline]
    pub fn compute_tasks_in_flight(&self) -> usize {
        self.inner.compute_state.count.load(Ordering::Relaxed)
    }

    /// Wait for all tracked tasks to complete (async).
    ///
    /// Call from within `block_on()`. Requires `shutdown()` to be called first,
    /// otherwise this will wait forever.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     runtime.spawn_async(work());
    ///     runtime.shutdown();
    ///     runtime.wait_for_shutdown().await;
    /// });
    /// ```
    pub async fn wait_for_shutdown(&self) {
        self.inner.task_tracker.wait().await;

        // Wait for compute tasks efficiently (no spin loop). The Notified
        // future is created BEFORE the count is checked, so a task that
        // completes between the check and the await cannot cause a lost
        // wakeup.
        let mut logged = false;
        loop {
            let notified = self.inner.compute_state.notify.notified();
            let count = self.inner.compute_state.count.load(Ordering::Acquire);
            if count == 0 {
                break;
            }
            if !logged {
                debug!(count, "waiting for compute tasks to complete");
                logged = true;
            }
            notified.await;
        }
    }

    /// Block until all tracked tasks complete (from main thread).
    ///
    /// This is the primary shutdown method. It:
    /// 1. Calls `shutdown()` to close the task tracker
    /// 2. Waits for all tracked async and compute tasks to finish
    ///
    /// # Examples
    ///
    /// ```ignore
    /// runtime.block_on(async {
    ///     runtime.spawn_async(background_work());
    ///     runtime.spawn_compute(|| cpu_work());
    /// });
    ///
    /// // Graceful shutdown from main thread
    /// runtime.block_until_idle();
    /// ```
    pub fn block_until_idle(&self) {
        self.shutdown();
        self.block_on(self.wait_for_shutdown());
    }

    /// Get the shared MAB scheduler for handler patterns.
    ///
    /// The scheduler is lazily initialized on first call. Use this when you
    /// need to make manual scheduling decisions in handler code.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use loom_rs::mab::{FunctionKey, Arm};
    ///
    /// let sched = runtime.mab_scheduler();
    /// let key = FunctionKey::from_type::<MyHandler>();
    /// let ctx = runtime.collect_context();
    ///
    /// let (id, arm) = sched.choose(key, &ctx);
    /// let result = match arm {
    ///     Arm::InlineTokio => my_work(),
    ///     Arm::OffloadRayon => runtime.block_on(async {
    ///         runtime.spawn_compute(|| my_work()).await
    ///     }),
    /// };
    /// sched.finish(id, elapsed_us, Some(fn_us));
    /// ```
    pub fn mab_scheduler(&self) -> Arc<MabScheduler> {
        self.inner
            .mab_scheduler
            .get_or_init(|| {
                Arc::new(MabScheduler::with_metrics(
                    self.inner.mab_knobs.clone(),
                    self.inner.prometheus_metrics.clone(),
                ))
            })
            .clone()
    }

    /// Collect current runtime context for MAB scheduling decisions.
    ///
    /// Returns a snapshot of current metrics including inflight tasks,
    /// spawn rate, and queue depth.
    pub fn collect_context(&self) -> Context {
        self.inner.prometheus_metrics.collect_context(
            self.inner.tokio_threads as u32,
            self.inner.rayon_threads as u32,
        )
    }

    /// Get the number of tokio worker threads.
    pub fn tokio_threads(&self) -> usize {
        self.inner.tokio_threads
    }

    /// Get the number of rayon worker threads.
    pub fn rayon_threads(&self) -> usize {
        self.inner.rayon_threads
    }

    /// Get the Prometheus metrics.
    ///
    /// The metrics are always collected (low-overhead atomic operations).
    /// If a Prometheus registry was provided via `LoomBuilder::prometheus_registry()`,
    /// the metrics are also registered for exposition.
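    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative only) of exposing the metrics through
    /// your own `prometheus::Registry`, mirroring the test at the bottom of
    /// this file:
    ///
    /// ```ignore
    /// let registry = prometheus::Registry::new();
    /// runtime.prometheus_metrics().register(&registry)?;
    /// let families = registry.gather(); // ready for exposition
    /// ```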
    pub fn prometheus_metrics(&self) -> &LoomMetrics {
        &self.inner.prometheus_metrics
    }

    /// Get the CPUs allocated to tokio workers.
    pub fn tokio_cpus(&self) -> &[usize] {
        &self.inner.tokio_cpus
    }

    /// Get the CPUs allocated to rayon workers.
    pub fn rayon_cpus(&self) -> &[usize] {
        &self.inner.rayon_cpus
    }
}

impl LoomRuntimeInner {
    /// Spawn CPU-bound work on rayon and await the result.
    ///
    /// Uses per-type object pools for zero allocation after warmup.
    #[inline]
    pub async fn spawn_compute<F, R>(self: &Arc<Self>, f: F) -> R
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let pool = self.pools.get_or_create::<R>();

        // Try to get state from pool, or allocate new
        let state = pool.pop().unwrap_or_else(|| Arc::new(TaskState::new()));

        // Create the pooled task
        let (task, completion, state_for_return) = PooledRayonTask::new(state);

        // Create guard BEFORE spawning - it increments counter and tracks MAB metrics
        let guard = ComputeTaskGuard::new(&self.compute_state, &self.prometheus_metrics);

        self.rayon_pool.spawn(move || {
            // Track rayon task start for queue depth calculation
            guard.started();

            // Execute work inside guard scope so counter decrements BEFORE completing.
            // This ensures the async future sees count=0 when it wakes up.
            let result = {
                let _guard = guard;
                f()
            };
            completion.complete(result);
        });

        let result = task.await;

        // Return state to pool for reuse
        state_for_return.reset();
        pool.push(state_for_return);

        result
    }
}

impl std::fmt::Debug for LoomRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoomRuntime")
            .field("config", &self.inner.config)
            .field(
                "compute_tasks_in_flight",
                &self.inner.compute_state.count.load(Ordering::Relaxed),
            )
            .finish_non_exhaustive()
    }
}

impl std::fmt::Display for LoomRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "LoomRuntime[{}]: tokio({}, cpus={}) rayon({}, cpus={})",
            self.inner.config.prefix,
            self.inner.tokio_threads,
            format_cpuset(&self.inner.tokio_cpus),
            self.inner.rayon_threads,
            format_cpuset(&self.inner.rayon_cpus),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::DEFAULT_POOL_SIZE;

    fn test_config() -> LoomConfig {
        LoomConfig {
            prefix: "test".to_string(),
            cpuset: None,
            tokio_threads: Some(1),
            rayon_threads: Some(1),
            compute_pool_size: DEFAULT_POOL_SIZE,
            #[cfg(feature = "cuda")]
            cuda_device: None,
            mab_knobs: None,
            calibration: None,
            prometheus_registry: None,
        }
    }

    #[test]
    fn test_runtime_creation() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();
        assert_eq!(runtime.config().prefix, "test");
    }

    #[test]
    fn test_block_on() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let result = runtime.block_on(async { 42 });
        assert_eq!(result, 42);
    }

    #[test]
    fn test_spawn_compute() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let result =
            runtime.block_on(async { runtime.spawn_compute(|| (0..100).sum::<i32>()).await });
        assert_eq!(result, 4950);
    }

    #[test]
    fn test_spawn_async() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let result = runtime.block_on(async {
            let handle = runtime.spawn_async(async { 42 });
            handle.await.unwrap()
        });
        assert_eq!(result, 42);
    }

    #[test]
    fn test_install() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let result = runtime.install(|| {
            use rayon::prelude::*;
            (0..100).into_par_iter().sum::<i32>()
        });
        assert_eq!(result, 4950);
    }

    #[test]
    fn test_shutdown_and_idle() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        // Initially not idle (tracker not closed)
        assert!(!runtime.is_idle());

        // After shutdown with no tasks, should be idle
        runtime.shutdown();
        assert!(runtime.is_idle());
    }

    #[test]
    fn test_block_until_idle() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        runtime.block_on(async {
            runtime.spawn_async(async { 42 });
            runtime.spawn_compute(|| 100).await;
        });

        runtime.block_until_idle();
        assert!(runtime.is_idle());
    }

    #[test]
    fn test_insufficient_cpus_error() {
        let mut config = test_config();
        config.cpuset = Some("0".to_string()); // Only 1 CPU
        config.tokio_threads = Some(2);
        config.rayon_threads = Some(2);

        let result = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE);
        assert!(matches!(result, Err(LoomError::InsufficientCpus { .. })));
    }

    #[test]
    fn test_current_runtime_in_block_on() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        runtime.block_on(async {
            // current_runtime should work inside block_on
            let current = crate::context::current_runtime();
            assert!(current.is_some());
        });

        // Outside block_on, should be None
        assert!(crate::context::current_runtime().is_none());
    }

    #[test]
    fn test_spawn_compute_pooling() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        // Warmup - first call allocates
        runtime.block_on(async {
            runtime.spawn_compute(|| 1i32).await;
        });

        // Subsequent calls should reuse pooled state (we can't easily verify this
        // without internal access, but we can verify it works)
        runtime.block_on(async {
            for i in 0..100 {
                let result = runtime.spawn_compute(move || i).await;
                assert_eq!(result, i);
            }
        });
    }

    #[test]
    fn test_spawn_compute_guard_drops_on_scope_exit() {
        // This test verifies the guard's Drop implementation works correctly.
        // We can't easily test panic behavior in rayon (panics abort by default),
        // but we can verify the guard decrements the counter when it goes out of scope.
        use crate::metrics::LoomMetrics;
        use std::sync::atomic::Ordering;

        let state = super::ComputeTaskState::new();
        let metrics = LoomMetrics::new();

        // Create a guard (increments counter)
        {
            let _guard = super::ComputeTaskGuard::new(&state, &metrics);
            assert_eq!(state.count.load(Ordering::Relaxed), 1);
        }
        // Guard dropped, counter should be 0
        assert_eq!(state.count.load(Ordering::Relaxed), 0);

        // Test multiple guards
        let state = super::ComputeTaskState::new();

        let guard1 = super::ComputeTaskGuard::new(&state, &metrics);
        assert_eq!(state.count.load(Ordering::Relaxed), 1);

        let guard2 = super::ComputeTaskGuard::new(&state, &metrics);
        assert_eq!(state.count.load(Ordering::Relaxed), 2);

        drop(guard1);
        assert_eq!(state.count.load(Ordering::Relaxed), 1);

        drop(guard2);
        assert_eq!(state.count.load(Ordering::Relaxed), 0);

        // The notification mechanism is verified by the fact that wait_for_shutdown
        // doesn't spin-loop forever when compute tasks complete
    }

    #[test]
    fn test_compute_tasks_in_flight() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        // Initially no tasks
        assert_eq!(runtime.compute_tasks_in_flight(), 0);

        // After spawning and completing, should be back to 0
        runtime.block_on(async {
            runtime.spawn_compute(|| 42).await;
        });
        assert_eq!(runtime.compute_tasks_in_flight(), 0);
    }

    #[test]
    fn test_display() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let display = format!("{}", runtime);
        assert!(display.starts_with("LoomRuntime[test]:"));
        assert!(display.contains("tokio(1, cpus="));
        assert!(display.contains("rayon(1, cpus="));
    }

    #[test]
    fn test_cpuset_only() {
        let mut config = test_config();
        config.cpuset = Some("0".to_string());
        config.tokio_threads = Some(1);
        config.rayon_threads = Some(0);

        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();
        // Should use the user-provided cpuset
        assert_eq!(runtime.inner.tokio_cpus, vec![0]);
    }

    /// Test that CUDA cpuset conflict error is properly detected.
    /// This test requires actual CUDA hardware to verify the conflict.
    #[cfg(feature = "cuda-tests")]
    #[test]
    fn test_cuda_cpuset_conflict_error() {
        let mut config = test_config();
        config.cuda_device = Some(crate::cuda::CudaDeviceSelector::DeviceId(0));
        config.cpuset = Some("0".to_string()); // Conflict: both specified

        let result = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE);
        assert!(
            matches!(result, Err(LoomError::CudaCpusetConflict)),
            "expected CudaCpusetConflict error, got {:?}",
            result
        );
    }

    /// Test that CUDA device alone (without cpuset) works.
    #[cfg(feature = "cuda-tests")]
    #[test]
    fn test_cuda_device_only() {
        let mut config = test_config();
        config.cuda_device = Some(crate::cuda::CudaDeviceSelector::DeviceId(0));
        config.cpuset = None;

        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();
        // Should have found CUDA-local CPUs
        assert!(!runtime.inner.tokio_cpus.is_empty());
    }

    // =============================================================================
    // spawn_adaptive Tests
    // =============================================================================

    #[test]
    fn test_spawn_adaptive_runs_work() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let result = runtime.block_on(async { runtime.spawn_adaptive(|| 42).await });

        assert_eq!(result, 42);
    }

    #[test]
    fn test_spawn_adaptive_with_hint() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        let result = runtime.block_on(async {
            runtime
                .spawn_adaptive_with_hint(crate::ComputeHint::Low, || 100)
                .await
        });

        assert_eq!(result, 100);
    }

    #[test]
    fn test_spawn_adaptive_multiple_calls() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        runtime.block_on(async {
            // Run many fast tasks to let MAB learn
            for i in 0..50 {
                let result = runtime.spawn_adaptive(move || i * 2).await;
                assert_eq!(result, i * 2);
            }
        });
    }

    #[test]
    fn test_spawn_adaptive_records_metrics() {
        let config = test_config();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        runtime.block_on(async {
            // Run some adaptive tasks
            for _ in 0..10 {
                runtime.spawn_adaptive(|| std::hint::black_box(42)).await;
            }
        });

        // Check that metrics were recorded
        let metrics = runtime.prometheus_metrics();
        let total_decisions = metrics.inline_decisions.get() + metrics.offload_decisions.get();
        assert!(
            total_decisions >= 10,
            "Should have recorded at least 10 decisions, got {}",
            total_decisions
        );
    }

    #[test]
    fn test_prometheus_metrics_use_prefix() {
        let mut config = test_config();
        config.prefix = "myapp".to_string();
        let runtime = LoomRuntime::from_config(config, DEFAULT_POOL_SIZE).unwrap();

        // The metrics should use the prefix from config
        // We can verify by checking the registry if one was provided
        let registry = prometheus::Registry::new();
        runtime
            .prometheus_metrics()
            .register(&registry)
            .expect("registration should succeed");

        let families = registry.gather();
        // Find a metric with our prefix
        let myapp_metric = families.iter().find(|f| f.get_name().starts_with("myapp_"));
        assert!(
            myapp_metric.is_some(),
            "Should find metrics with 'myapp_' prefix"
        );

        // Should not find metrics with default 'loom_' prefix
        let loom_metric = families.iter().find(|f| f.get_name().starts_with("loom_"));
        assert!(
            loom_metric.is_none(),
            "Should not find metrics with 'loom_' prefix"
        );
    }
}
1279}