//! Extended occupancy helpers for CPU-side occupancy estimation.
//!
//! Unlike the GPU-side queries in [`crate::occupancy`] that call
//! `cuOccupancy*` driver functions, this module provides **pure computation**
//! for analysing occupancy trade-offs without requiring a live GPU.
//!
//! # Features
//!
//! - [`OccupancyCalculator`] — CPU-side occupancy estimation
//! - [`OccupancyGrid`] — sweep block sizes to find the optimum
//! - [`DynamicSmemOccupancy`] — occupancy with shared-memory callbacks
//! - [`ClusterOccupancy`] — Hopper+ thread block cluster support
//!
//! # Example
//!
//! ```rust
//! use oxicuda_driver::occupancy_ext::*;
//!
//! let info = DeviceOccupancyInfo {
//!     sm_count: 84,
//!     max_threads_per_sm: 1536,
//!     max_blocks_per_sm: 16,
//!     max_registers_per_sm: 65536,
//!     max_shared_memory_per_sm: 102400,
//!     warp_size: 32,
//! };
//! let calc = OccupancyCalculator::new(info);
//! let est = calc.estimate_occupancy(256, 32, 0);
//! assert!(est.occupancy_ratio > 0.0);
//! ```
use crate::device::Device;
#[cfg(not(target_os = "macos"))]
use crate::error::CudaError;
use crate::error::CudaResult;
// ---------------------------------------------------------------------------
// DeviceOccupancyInfo
// ---------------------------------------------------------------------------

/// Hardware parameters needed for CPU-side occupancy estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceOccupancyInfo {
    /// Number of streaming multiprocessors on the device.
    pub sm_count: u32,
    /// Maximum resident threads per SM.
    pub max_threads_per_sm: u32,
    /// Maximum concurrent blocks per SM.
    pub max_blocks_per_sm: u32,
    /// Total 32-bit registers available per SM.
    pub max_registers_per_sm: u32,
    /// Shared memory capacity per SM in bytes.
    pub max_shared_memory_per_sm: u32,
    /// Threads per warp (typically 32).
    pub warp_size: u32,
}

impl DeviceOccupancyInfo {
    /// Maximum number of warps that can be resident on one SM.
    ///
    /// Returns 0 when `warp_size` is 0 to avoid a division panic.
    fn max_warps_per_sm(&self) -> u32 {
        match self.warp_size {
            0 => 0,
            ws => self.max_threads_per_sm / ws,
        }
    }

    /// Return synthetic [`DeviceOccupancyInfo`] for a given SM compute
    /// capability, enabling CPU-side occupancy analysis without a live GPU.
    ///
    /// Covers all major NVIDIA GPU architectures from Turing through Blackwell.
    /// Unknown architectures fall back to Ampere SM 8.6 defaults.
    ///
    /// # SM capability table
    ///
    /// | Architecture      | sm_major | sm_minor | SMs | Threads/SM | Smem/SM |
    /// |-------------------|----------|----------|-----|------------|---------|
    /// | Turing            | 7        | 5        | 68  | 1024       | 65536   |
    /// | Ampere A100       | 8        | 0        | 108 | 2048       | 167936  |
    /// | Ampere GA10x      | 8        | 6        | 84  | 1536       | 102400  |
    /// | Ada Lovelace      | 8        | 9        | 76  | 1536       | 101376  |
    /// | Hopper H100       | 9        | 0        | 132 | 2048       | 232448  |
    /// | Blackwell B100    | 10       | 0        | 132 | 2048       | 262144  |
    /// | Blackwell B200    | 12       | 0        | 148 | 2048       | 262144  |
    #[must_use]
    pub fn for_compute_capability(sm_major: u32, sm_minor: u32) -> Self {
        // Every covered architecture shares a 64 K-register file and 32-wide
        // warps; only the remaining four parameters differ per architecture.
        let arch = |sm_count, max_threads_per_sm, max_blocks_per_sm, max_shared_memory_per_sm| {
            Self {
                sm_count,
                max_threads_per_sm,
                max_blocks_per_sm,
                max_registers_per_sm: 65536,
                max_shared_memory_per_sm,
                warp_size: 32,
            }
        };
        match (sm_major, sm_minor) {
            (7, 5) => arch(68, 1024, 16, 65536),    // Turing (sm_75)
            (8, 0) => arch(108, 2048, 32, 167936),  // Ampere A100 (sm_80)
            (8, 6) => arch(84, 1536, 16, 102400),   // Ampere GA10x (sm_86, e.g. RTX 3090)
            (8, 9) => arch(76, 1536, 24, 101376),   // Ada Lovelace (sm_89, e.g. RTX 4090)
            (9, 0) => arch(132, 2048, 32, 232448),  // Hopper H100 (sm_90)
            (10, 0) => arch(132, 2048, 32, 262144), // Blackwell B100 (sm_100), 256 KB shared/SM
            (12, 0) => arch(148, 2048, 32, 262144), // Blackwell B200 (sm_120), 256 KB shared/SM
            // Unknown / future — fall back to Ampere GA10x defaults.
            _ => arch(84, 1536, 16, 102400),
        }
    }
}
// ---------------------------------------------------------------------------
// LimitingFactor
// ---------------------------------------------------------------------------

/// The resource that limits occupancy the most.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LimitingFactor {
    /// Block size limits the number of warps (warp slots per SM exhausted).
    Threads,
    /// Register pressure limits concurrent warps.
    Registers,
    /// Shared memory exhaustion limits concurrent blocks.
    SharedMemory,
    /// Hardware block-per-SM cap is the bottleneck.
    Blocks,
    /// All resources have headroom (or estimation was trivial,
    /// e.g. a zero block size or warp size).
    None,
}

// ---------------------------------------------------------------------------
// OccupancyEstimate
// ---------------------------------------------------------------------------

/// Result of a CPU-side occupancy estimation for one configuration.
#[derive(Debug, Clone, Copy)]
pub struct OccupancyEstimate {
    /// Active warps per SM for this configuration.
    pub active_warps_per_sm: u32,
    /// Maximum possible warps per SM (hardware limit).
    pub max_warps_per_sm: u32,
    /// Fraction of max warps that are active (0.0 ..= 1.0).
    pub occupancy_ratio: f64,
    /// Which resource is the tightest bottleneck.
    pub limiting_factor: LimitingFactor,
}
199// ---------------------------------------------------------------------------
200// OccupancyCalculator
201// ---------------------------------------------------------------------------
202
203/// CPU-side occupancy estimator — no GPU calls required.
204///
205/// Given device hardware parameters, this struct computes how many warps
206/// can be concurrently resident for a given kernel configuration.
207#[derive(Debug, Clone)]
208pub struct OccupancyCalculator {
209    info: DeviceOccupancyInfo,
210}
211
212impl OccupancyCalculator {
213    /// Create a new calculator from device occupancy information.
214    pub fn new(device_info: DeviceOccupancyInfo) -> Self {
215        Self { info: device_info }
216    }
217
218    /// Return a reference to the underlying device info.
219    pub fn device_info(&self) -> &DeviceOccupancyInfo {
220        &self.info
221    }
222
223    /// Estimate occupancy for the given kernel configuration.
224    ///
225    /// # Parameters
226    ///
227    /// * `block_size` — threads per block.
228    /// * `registers_per_thread` — registers consumed by each thread.
229    /// * `shared_memory` — shared memory per block in bytes.
230    pub fn estimate_occupancy(
231        &self,
232        block_size: u32,
233        registers_per_thread: u32,
234        shared_memory: u32,
235    ) -> OccupancyEstimate {
236        let max_warps = self.info.max_warps_per_sm();
237
238        // Degenerate cases
239        if block_size == 0 || self.info.warp_size == 0 || max_warps == 0 {
240            return OccupancyEstimate {
241                active_warps_per_sm: 0,
242                max_warps_per_sm: max_warps,
243                occupancy_ratio: 0.0,
244                limiting_factor: LimitingFactor::None,
245            };
246        }
247
248        let warps_per_block = block_size.div_ceil(self.info.warp_size);
249
250        // --- Limit 1: max blocks per SM (hardware cap) -----------------------
251        let blocks_by_block_limit = self.info.max_blocks_per_sm;
252
253        // --- Limit 2: threads (warps) ----------------------------------------
254        let blocks_by_threads = max_warps.checked_div(warps_per_block).unwrap_or(0);
255
256        // --- Limit 3: registers -----------------------------------------------
257        let blocks_by_registers = if registers_per_thread == 0 || warps_per_block == 0 {
258            u32::MAX // registers not a bottleneck
259        } else {
260            let regs_per_block = registers_per_thread * warps_per_block * self.info.warp_size;
261            self.info
262                .max_registers_per_sm
263                .checked_div(regs_per_block)
264                .unwrap_or(u32::MAX)
265        };
266
267        // --- Limit 4: shared memory -------------------------------------------
268        let blocks_by_smem = if shared_memory == 0 {
269            u32::MAX // smem not a bottleneck
270        } else if self.info.max_shared_memory_per_sm == 0 {
271            0
272        } else {
273            self.info.max_shared_memory_per_sm / shared_memory
274        };
275
276        // Take the minimum across all limits
277        let active_blocks = blocks_by_block_limit
278            .min(blocks_by_threads)
279            .min(blocks_by_registers)
280            .min(blocks_by_smem);
281
282        let active_warps = active_blocks * warps_per_block;
283        let clamped_warps = active_warps.min(max_warps);
284        let ratio = if max_warps > 0 {
285            clamped_warps as f64 / max_warps as f64
286        } else {
287            0.0
288        };
289
290        // Determine limiting factor
291        let effective = active_blocks;
292        let limiting_factor = if effective == 0 {
293            if blocks_by_smem == 0 {
294                LimitingFactor::SharedMemory
295            } else if blocks_by_registers == 0 {
296                LimitingFactor::Registers
297            } else if blocks_by_threads == 0 {
298                LimitingFactor::Threads
299            } else {
300                LimitingFactor::Blocks
301            }
302        } else if effective == blocks_by_smem
303            && blocks_by_smem
304                <= blocks_by_registers
305                    .min(blocks_by_threads)
306                    .min(blocks_by_block_limit)
307        {
308            LimitingFactor::SharedMemory
309        } else if effective == blocks_by_registers
310            && blocks_by_registers <= blocks_by_threads.min(blocks_by_block_limit)
311        {
312            LimitingFactor::Registers
313        } else if effective == blocks_by_threads && blocks_by_threads <= blocks_by_block_limit {
314            LimitingFactor::Threads
315        } else if effective == blocks_by_block_limit {
316            LimitingFactor::Blocks
317        } else {
318            LimitingFactor::None
319        };
320
321        OccupancyEstimate {
322            active_warps_per_sm: clamped_warps,
323            max_warps_per_sm: max_warps,
324            occupancy_ratio: ratio,
325            limiting_factor,
326        }
327    }
328}
329
330// ---------------------------------------------------------------------------
331// OccupancyPoint / OccupancyGrid
332// ---------------------------------------------------------------------------
333
334/// A single data point from a block-size sweep.
335#[derive(Debug, Clone, Copy)]
336pub struct OccupancyPoint {
337    /// Block size (threads per block) for this point.
338    pub block_size: u32,
339    /// Occupancy ratio (0.0 .. 1.0).
340    pub occupancy: f64,
341    /// Active warps per SM.
342    pub active_warps: u32,
343    /// Limiting resource at this block size.
344    pub limiting_factor: LimitingFactor,
345}
346
347/// Sweep block sizes to find the configuration that maximises occupancy.
348pub struct OccupancyGrid;
349
350impl OccupancyGrid {
351    /// Sweep block sizes from `warp_size` to `max_threads_per_sm` in
352    /// increments of `warp_size` and return occupancy at each step.
353    pub fn sweep(
354        calculator: &OccupancyCalculator,
355        registers_per_thread: u32,
356        shared_memory: u32,
357    ) -> Vec<OccupancyPoint> {
358        let ws = calculator.info.warp_size;
359        if ws == 0 {
360            return Vec::new();
361        }
362        let max_threads = calculator.info.max_threads_per_sm;
363        let mut points = Vec::new();
364        let mut bs = ws;
365        while bs <= max_threads {
366            let est = calculator.estimate_occupancy(bs, registers_per_thread, shared_memory);
367            points.push(OccupancyPoint {
368                block_size: bs,
369                occupancy: est.occupancy_ratio,
370                active_warps: est.active_warps_per_sm,
371                limiting_factor: est.limiting_factor,
372            });
373            bs += ws;
374        }
375        points
376    }
377
378    /// Pick the block size with the highest occupancy.
379    ///
380    /// Ties are broken by choosing the **smallest** block size.
381    /// Returns `0` if the slice is empty.
382    pub fn best_block_size(points: &[OccupancyPoint]) -> u32 {
383        let mut best: Option<&OccupancyPoint> = Option::None;
384        for pt in points {
385            best = Some(match best {
386                Option::None => pt,
387                Some(prev) => {
388                    if pt.occupancy > prev.occupancy
389                        || (pt.occupancy == prev.occupancy && pt.block_size < prev.block_size)
390                    {
391                        pt
392                    } else {
393                        prev
394                    }
395                }
396            });
397        }
398        best.map_or(0, |p| p.block_size)
399    }
400}
401
402// ---------------------------------------------------------------------------
403// DynamicSmemOccupancy
404// ---------------------------------------------------------------------------
405
406/// Occupancy estimation where shared memory varies with block size.
407pub struct DynamicSmemOccupancy;
408
409impl DynamicSmemOccupancy {
410    /// Sweep block sizes using a callback `smem_fn(block_size) -> smem_bytes`.
411    pub fn with_smem_function<F>(
412        calculator: &OccupancyCalculator,
413        smem_fn: F,
414        registers_per_thread: u32,
415    ) -> Vec<OccupancyPoint>
416    where
417        F: Fn(u32) -> u32,
418    {
419        let ws = calculator.info.warp_size;
420        if ws == 0 {
421            return Vec::new();
422        }
423        let max_threads = calculator.info.max_threads_per_sm;
424        let mut points = Vec::new();
425        let mut bs = ws;
426        while bs <= max_threads {
427            let smem = smem_fn(bs);
428            let est = calculator.estimate_occupancy(bs, registers_per_thread, smem);
429            points.push(OccupancyPoint {
430                block_size: bs,
431                occupancy: est.occupancy_ratio,
432                active_warps: est.active_warps_per_sm,
433                limiting_factor: est.limiting_factor,
434            });
435            bs += ws;
436        }
437        points
438    }
439
440    /// A shared-memory function that scales linearly with block size.
441    ///
442    /// Returns `block_size * bytes_per_thread`.
443    pub fn linear_smem(bytes_per_thread: u32) -> impl Fn(u32) -> u32 {
444        move |block_size: u32| block_size * bytes_per_thread
445    }
446
447    /// A shared-memory function for tile-based kernels.
448    ///
449    /// Returns `tile_size * tile_size * element_size` (constant per block).
450    pub fn tile_smem(tile_size: u32, element_size: u32) -> impl Fn(u32) -> u32 {
451        move |_block_size: u32| tile_size * tile_size * element_size
452    }
453}
454
// ---------------------------------------------------------------------------
// ClusterOccupancy (Hopper+)
// ---------------------------------------------------------------------------

/// Thread block cluster configuration for Hopper+ GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClusterConfig {
    /// Cluster extent in X dimension (number of blocks).
    pub cluster_x: u32,
    /// Cluster extent in Y dimension.
    pub cluster_y: u32,
    /// Cluster extent in Z dimension.
    pub cluster_z: u32,
}

impl ClusterConfig {
    /// Total blocks in one cluster (`x * y * z`).
    ///
    /// Saturates at `u32::MAX` instead of overflowing for absurd extents.
    pub fn total_blocks(&self) -> u32 {
        self.cluster_x
            .saturating_mul(self.cluster_y)
            .saturating_mul(self.cluster_z)
    }
}
/// Result of a cluster occupancy estimation.
#[derive(Debug, Clone, Copy)]
pub struct ClusterOccupancyEstimate {
    /// Number of blocks per cluster.
    pub blocks_per_cluster: u32,
    /// Maximum whole clusters that fit per SM (leftover blocks that do not
    /// fill a complete cluster are not counted).
    pub clusters_per_sm: u32,
    /// Effective occupancy ratio (0.0 ..= 1.0).
    pub effective_occupancy: f64,
    /// Total shared memory consumed by one cluster (bytes).
    pub cluster_smem_total: u32,
}
490/// Hopper+ thread block cluster occupancy estimation.
491pub struct ClusterOccupancy;
492
493impl ClusterOccupancy {
494    /// Estimate occupancy when blocks are grouped into clusters.
495    ///
496    /// A cluster of `cluster_size` blocks must all reside on the same GPC
497    /// (GPU Processing Cluster). This effectively reduces the number of
498    /// independent blocks schedulable per SM.
499    ///
500    /// # Parameters
501    ///
502    /// * `calculator` — occupancy calculator with device info.
503    /// * `block_size` — threads per block.
504    /// * `cluster_size` — number of blocks per cluster.
505    /// * `registers_per_thread` — registers per thread.
506    /// * `shared_memory` — shared memory per block (bytes).
507    pub fn estimate_cluster_occupancy(
508        calculator: &OccupancyCalculator,
509        block_size: u32,
510        cluster_size: u32,
511        registers_per_thread: u32,
512        shared_memory: u32,
513    ) -> ClusterOccupancyEstimate {
514        if cluster_size == 0 {
515            return ClusterOccupancyEstimate {
516                blocks_per_cluster: 0,
517                clusters_per_sm: 0,
518                effective_occupancy: 0.0,
519                cluster_smem_total: 0,
520            };
521        }
522
523        // First, get the per-block occupancy estimate
524        let est = calculator.estimate_occupancy(block_size, registers_per_thread, shared_memory);
525
526        let max_warps = est.max_warps_per_sm;
527        let warps_per_block = if calculator.info.warp_size == 0 {
528            0
529        } else {
530            block_size.div_ceil(calculator.info.warp_size)
531        };
532
533        // How many blocks could fit per SM (from the standard estimate)?
534        let blocks_per_sm = est
535            .active_warps_per_sm
536            .checked_div(warps_per_block)
537            .unwrap_or(0);
538
539        // Clusters must schedule in whole units
540        let clusters_per_sm = blocks_per_sm / cluster_size;
541        let active_blocks = clusters_per_sm * cluster_size;
542        let active_warps = active_blocks * warps_per_block;
543
544        let effective_occupancy = if max_warps > 0 {
545            (active_warps.min(max_warps)) as f64 / max_warps as f64
546        } else {
547            0.0
548        };
549
550        ClusterOccupancyEstimate {
551            blocks_per_cluster: cluster_size,
552            clusters_per_sm,
553            effective_occupancy,
554            cluster_smem_total: cluster_size * shared_memory,
555        }
556    }
557}
558
// ---------------------------------------------------------------------------
// Device convenience extension
// ---------------------------------------------------------------------------

impl Device {
    /// Gather all occupancy-relevant hardware attributes into a
    /// [`DeviceOccupancyInfo`] struct.
    ///
    /// On macOS (where no NVIDIA driver is available) this returns
    /// synthetic values for a typical SM 8.6 (Ampere) GPU so that
    /// CPU-side occupancy analysis can still run.
    ///
    /// # Errors
    ///
    /// Returns a [`CudaError`](crate::error::CudaError) if an attribute query fails on a real GPU.
    pub fn occupancy_info(&self) -> CudaResult<DeviceOccupancyInfo> {
        // On macOS the driver is never present — return synthetic defaults.
        #[cfg(target_os = "macos")]
        {
            let _ = self; // suppress unused warning
            Ok(DeviceOccupancyInfo {
                sm_count: 84,
                max_threads_per_sm: 1536,
                max_blocks_per_sm: 16,
                max_registers_per_sm: 65536,
                max_shared_memory_per_sm: 102400,
                warp_size: 32,
            })
        }

        #[cfg(not(target_os = "macos"))]
        {
            // NOTE(review): every attribute-query failure below is collapsed
            // to `CudaError::NotInitialized`, discarding the original error —
            // confirm this normalization is intentional.
            let sm_count = self
                .multiprocessor_count()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_threads_per_sm = self
                .max_threads_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_blocks_per_sm = self
                .max_blocks_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_registers_per_sm = self
                .max_registers_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let max_shared_memory_per_sm = self
                .max_shared_memory_per_multiprocessor()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;
            let warp_size = self
                .warp_size()
                .map(|v| v as u32)
                .map_err(|_| CudaError::NotInitialized)?;

            Ok(DeviceOccupancyInfo {
                sm_count,
                max_threads_per_sm,
                max_blocks_per_sm,
                max_registers_per_sm,
                max_shared_memory_per_sm,
                warp_size,
            })
        }
    }
}
// ===========================================================================
// Tests
// ===========================================================================
632#[cfg(test)]
633mod tests {
634    use super::*;
635
    /// A typical SM 8.6 (e.g. RTX 3090) device info for testing.
    fn ampere_info() -> DeviceOccupancyInfo {
        DeviceOccupancyInfo {
            sm_count: 82,
            max_threads_per_sm: 1536,
            max_blocks_per_sm: 16,
            max_registers_per_sm: 65536,
            max_shared_memory_per_sm: 102400,
            warp_size: 32,
        }
    }
    // --- Basic occupancy estimation ------------------------------------------

    #[test]
    fn test_basic_occupancy_estimation() {
        let calc = OccupancyCalculator::new(ampere_info());
        let est = calc.estimate_occupancy(256, 32, 0);
        // 256 threads = 8 warps/block; max warps = 1536/32 = 48.
        // threads: 48/8 = 6 blocks; hardware cap: 16 (not limiting);
        // registers: 32 * 8 * 32 = 8192/block => 65536/8192 = 8 blocks;
        // min(16, 6, 8) = 6 blocks => 6*8 = 48 warps => 100% occupancy.
        assert_eq!(est.max_warps_per_sm, 48);
        assert!(est.occupancy_ratio > 0.0);
        assert!(est.active_warps_per_sm > 0);
    }

    #[test]
    fn test_full_occupancy() {
        let calc = OccupancyCalculator::new(ampere_info());
        // 32 threads/block, light register pressure, no smem.
        let est = calc.estimate_occupancy(32, 16, 0);
        // 1 warp/block; filling 48 warp slots would need 48 blocks,
        // but max_blocks_per_sm = 16 => 16 warps => 16/48 ≈ 33%.
        assert_eq!(est.active_warps_per_sm, 16);
    }
    // --- Limiting factor detection -------------------------------------------

    #[test]
    fn test_limiting_factor_threads() {
        let calc = OccupancyCalculator::new(ampere_info());
        // Large block (1024 threads = 32 warps), low registers, no smem:
        // 48 / 32 = 1 block => 32 warps => warp slots are the limit.
        let est = calc.estimate_occupancy(1024, 16, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Threads);
    }

    #[test]
    fn test_limiting_factor_registers() {
        let calc = OccupancyCalculator::new(ampere_info());
        // 256 threads = 8 warps, 128 registers each:
        // regs/block = 128 * 8 * 32 = 32768 => 65536/32768 = 2 blocks;
        // threads: 48/8 = 6 blocks; cap: 16 => min(16, 6, 2) = 2 (registers).
        let est = calc.estimate_occupancy(256, 128, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Registers);
    }

    #[test]
    fn test_limiting_factor_shared_memory() {
        let calc = OccupancyCalculator::new(ampere_info());
        // 128 threads = 4 warps, low regs, 51200 bytes smem:
        // smem: 102400 / 51200 = 2 blocks;
        // threads: 48 / 4 = 12; cap: 16 => min(16, 12, regs, 2) = 2 (smem).
        let est = calc.estimate_occupancy(128, 16, 51200);
        assert_eq!(est.limiting_factor, LimitingFactor::SharedMemory);
    }

    #[test]
    fn test_limiting_factor_blocks() {
        let info = DeviceOccupancyInfo {
            max_blocks_per_sm: 4,
            ..ampere_info()
        };
        let calc = OccupancyCalculator::new(info);
        // 64 threads = 2 warps, low regs, no smem:
        // threads: 48/2 = 24 blocks; cap: 4 => min(4, 24) = 4 (block cap).
        let est = calc.estimate_occupancy(64, 16, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::Blocks);
    }

    #[test]
    fn test_limiting_factor_none_zero_block() {
        let calc = OccupancyCalculator::new(ampere_info());
        // Degenerate zero block size reports no limiting factor at all.
        let est = calc.estimate_occupancy(0, 32, 0);
        assert_eq!(est.limiting_factor, LimitingFactor::None);
        assert_eq!(est.active_warps_per_sm, 0);
        assert_eq!(est.occupancy_ratio, 0.0);
    }
    // --- Block size sweep ----------------------------------------------------

    #[test]
    fn test_sweep_returns_points() {
        let calc = OccupancyCalculator::new(ampere_info());
        let points = OccupancyGrid::sweep(&calc, 32, 0);
        // warp_size = 32, max_threads = 1536 => 1536/32 = 48 sample points.
        assert_eq!(points.len(), 48);
        assert_eq!(points[0].block_size, 32);
        assert_eq!(points[47].block_size, 1536);
    }

    #[test]
    fn test_best_block_size() {
        let calc = OccupancyCalculator::new(ampere_info());
        let points = OccupancyGrid::sweep(&calc, 32, 0);
        let best = OccupancyGrid::best_block_size(&points);
        // Should pick a warp-aligned block size with non-zero occupancy.
        assert!(best > 0);
        assert_eq!(best % 32, 0);
    }

    #[test]
    fn test_best_block_size_empty() {
        // An empty sweep yields the sentinel block size 0.
        assert_eq!(OccupancyGrid::best_block_size(&[]), 0);
    }
    // --- Dynamic shared memory -----------------------------------------------

    #[test]
    fn test_dynamic_smem_linear() {
        let calc = OccupancyCalculator::new(ampere_info());
        let smem_fn = DynamicSmemOccupancy::linear_smem(8); // 8 bytes per thread
        let points = DynamicSmemOccupancy::with_smem_function(&calc, smem_fn, 32);
        assert!(!points.is_empty());
        // At block_size = 32 => 256 bytes smem; at 1024 => 8192 bytes.
        // Sanity-check that the first and last sampled occupancies are
        // valid ratios in [0, 1].
        let first_occ = points[0].occupancy;
        let last_occ = points[points.len() - 1].occupancy;
        assert!((0.0..=1.0).contains(&first_occ));
        assert!((0.0..=1.0).contains(&last_occ));
    }

    #[test]
    fn test_dynamic_smem_tile() {
        let calc = OccupancyCalculator::new(ampere_info());
        let smem_fn = DynamicSmemOccupancy::tile_smem(16, 4); // 16*16 * 4 B = 1024 bytes
        let points = DynamicSmemOccupancy::with_smem_function(&calc, smem_fn, 32);
        // Tile smem is constant (1024 bytes) regardless of block size.
        assert!(!points.is_empty());
    }
    // --- Cluster occupancy ---------------------------------------------------

    #[test]
    fn test_cluster_occupancy_basic() {
        let calc = OccupancyCalculator::new(ampere_info());
        let result = ClusterOccupancy::estimate_cluster_occupancy(&calc, 128, 2, 32, 4096);
        assert_eq!(result.blocks_per_cluster, 2);
        assert!(result.effective_occupancy >= 0.0 && result.effective_occupancy <= 1.0);
        assert_eq!(result.cluster_smem_total, 2 * 4096);
    }

    #[test]
    fn test_cluster_occupancy_zero_cluster() {
        let calc = OccupancyCalculator::new(ampere_info());
        // A zero-sized cluster can never schedule anything.
        let result = ClusterOccupancy::estimate_cluster_occupancy(&calc, 128, 0, 32, 0);
        assert_eq!(result.clusters_per_sm, 0);
        assert_eq!(result.effective_occupancy, 0.0);
    }

    // --- ClusterConfig -------------------------------------------------------

    #[test]
    fn test_cluster_config_total_blocks() {
        let cfg = ClusterConfig {
            cluster_x: 2,
            cluster_y: 3,
            cluster_z: 4,
        };
        assert_eq!(cfg.total_blocks(), 24);
    }
811    // --- Edge cases ----------------------------------------------------------
812
813    #[test]
814    fn test_block_size_exceeds_max() {
815        let calc = OccupancyCalculator::new(ampere_info());
816        // Block size larger than max_threads_per_sm (not strictly invalid for
817        // the estimator but should still produce a reasonable result).
818        let est = calc.estimate_occupancy(2048, 32, 0);
819        // 2048 / 32 = 64 warps per block, but only 48 max => 0 blocks fit
820        assert_eq!(est.active_warps_per_sm, 0);
821        assert_eq!(est.occupancy_ratio, 0.0);
822    }
823
824    // --- SM100 / SM120 (Blackwell) occupancy coverage ----------------------------
825
826    fn sm100_info() -> DeviceOccupancyInfo {
827        DeviceOccupancyInfo::for_compute_capability(10, 0)
828    }
829
830    fn sm120_info() -> DeviceOccupancyInfo {
831        DeviceOccupancyInfo::for_compute_capability(12, 0)
832    }
833
834    #[test]
835    fn test_sm100_device_info_attributes() {
836        let info = sm100_info();
837        assert_eq!(info.sm_count, 132, "Blackwell B100 has 132 SMs");
838        assert_eq!(info.max_threads_per_sm, 2048);
839        assert_eq!(info.max_blocks_per_sm, 32);
840        assert_eq!(info.max_shared_memory_per_sm, 262144, "256 KiB shared/SM");
841        assert_eq!(info.warp_size, 32);
842    }
843
844    #[test]
845    fn test_sm120_device_info_attributes() {
846        let info = sm120_info();
847        assert_eq!(info.sm_count, 148, "Blackwell B200 has 148 SMs");
848        assert_eq!(info.max_threads_per_sm, 2048);
849        assert_eq!(info.max_blocks_per_sm, 32);
850        assert_eq!(info.max_shared_memory_per_sm, 262144, "256 KiB shared/SM");
851        assert_eq!(info.warp_size, 32);
852    }
853
854    #[test]
855    fn test_sm100_occupancy_estimation() {
856        let calc = OccupancyCalculator::new(sm100_info());
857        // 256 threads = 8 warps/block; max_warps = 2048/32 = 64
858        // 64 / 8 = 8 blocks; limited by max_blocks_per_sm = 32 → 8 blocks fit
859        // active_warps = 8 * 8 = 64; ratio = 64/64 = 1.0 (full occupancy)
860        let est = calc.estimate_occupancy(256, 0, 0);
861        assert!(
862            est.occupancy_ratio > 0.0,
863            "Blackwell B100 must report positive occupancy"
864        );
865        assert!(
866            est.active_warps_per_sm <= 64,
867            "Active warps must not exceed hardware limit"
868        );
869    }
870
871    #[test]
872    fn test_sm120_full_occupancy() {
873        let calc = OccupancyCalculator::new(sm120_info());
874        // 64-thread block = 2 warps; 64 max warps → 32 blocks, limited by
875        // max_blocks_per_sm = 32 ⇒ 32 blocks × 2 warps = 64 warps ⇒ ratio=1.0
876        let est = calc.estimate_occupancy(64, 0, 0);
877        assert_eq!(est.occupancy_ratio, 1.0, "Should reach full occupancy");
878        assert_eq!(est.active_warps_per_sm, 64);
879    }
880
881    #[test]
882    fn test_sm100_large_shared_memory_limit() {
883        let calc = OccupancyCalculator::new(sm100_info());
884        // 128 KiB per block: 262144 / 131072 = 2 blocks fit
885        let smem_per_block = 131_072u32;
886        let est = calc.estimate_occupancy(1024, 0, smem_per_block);
887        // active_blocks ≤ 2; active_warps = 2 × (1024/32) = 2 × 32 = 64
888        assert!(
889            matches!(est.limiting_factor, LimitingFactor::SharedMemory),
890            "Large smem must be the bottleneck"
891        );
892    }
893
894    #[test]
895    fn test_for_compute_capability_unknown_falls_back() {
896        // An architecture not in the table must return a sane fallback.
897        let info = DeviceOccupancyInfo::for_compute_capability(99, 99);
898        let calc = OccupancyCalculator::new(info);
899        let est = calc.estimate_occupancy(256, 0, 0);
900        assert!(est.occupancy_ratio > 0.0);
901    }
902
903    #[test]
904    fn test_sm100_vs_sm90_shared_memory_capacity() {
905        let hopper = DeviceOccupancyInfo::for_compute_capability(9, 0);
906        let blackwell = sm100_info();
907        // Blackwell has strictly more shared memory per SM than Hopper.
908        assert!(
909            blackwell.max_shared_memory_per_sm > hopper.max_shared_memory_per_sm,
910            "Blackwell B100 must have larger smem than Hopper H100"
911        );
912    }
913
914    #[test]
915    fn test_sm120_vs_sm100_sm_count() {
916        let b100 = sm100_info();
917        let b200 = sm120_info();
918        // B200 has more SMs than B100.
919        assert!(
920            b200.sm_count > b100.sm_count,
921            "Blackwell B200 must have more SMs than B100"
922        );
923    }
924
925    #[test]
926    fn test_for_compute_capability_all_known_arches() {
927        // Ensure all known architectures parse without panic and return
928        // warp_size == 32.
929        let arches = [(7, 5), (8, 0), (8, 6), (8, 9), (9, 0), (10, 0), (12, 0)];
930        for (major, minor) in arches {
931            let info = DeviceOccupancyInfo::for_compute_capability(major, minor);
932            assert_eq!(info.warp_size, 32, "sm_{major}{minor} warp_size must be 32");
933            assert!(info.sm_count > 0);
934            assert!(info.max_threads_per_sm > 0);
935        }
936    }
937}