Skip to main content

atomr_accel_cutlass/
plan_cache.rs

1//! Plan cache: in-process LRU keyed by template instantiation
2//! identity. Backs the `CutlassActor` so that repeated GEMM messages
3//! with identical `(template_id, shape, dtype, arch)` skip both the
4//! Rust-side template render and the NVRTC compile.
5//!
6//! The on-disk cache used by `NvrtcActor` (Phase 0.6) takes care of
7//! cross-process / cross-restart caching; this LRU is purely an
8//! in-process win to avoid rendering the same `.cu` source string
9//! repeatedly.
10
11use std::any::Any;
12use std::num::NonZeroUsize;
13use std::sync::Arc;
14
15use lru::LruCache;
16use parking_lot::Mutex;
17
18use crate::conv::{ConvKind, ConvLayout, ConvShape};
19use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
20use crate::gemm::{GemmEpilogue, GemmLayout, GemmShape};
21#[cfg(feature = "grouped")]
22use crate::grouped_gemm::GroupedLayout;
23
/// Plan-cache key identifying one template instantiation.
///
/// A key is a one-byte template discriminator plus a 192-bit digest
/// (`[u64; 3]`) of the instantiation's shape, layouts, dtypes, target arch
/// and (for GEMMs) epilogue. It is `Copy` and small (typically 32 bytes:
/// the `u8` is padded to the payload's 8-byte alignment), so it is cheap to
/// pass by value and to store in the LRU.
///
/// The key is opaque on purpose: callers should construct it via the
/// dedicated `PlanKey::gemm` / `PlanKey::grouped_gemm` / `PlanKey::conv`
/// constructors so we control the layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlanKey {
    /// Discriminator: 1 = gemm, 2 = grouped_gemm, 3 = conv.
    template_id: u8,
    /// Digest of the shape/layout/dtype/arch identity. The encoding is
    /// internal to this module and is not exposed; only `template_id`
    /// has a public accessor.
    payload: [u64; 3],
}
39
40impl PlanKey {
41    /// Element size of the cache key in bytes — used by the cache size
42    /// reporter and by callers that want to budget the LRU.
43    pub const SIZE_BYTES: usize = core::mem::size_of::<PlanKey>();
44
45    pub fn template_id(&self) -> u8 {
46        self.template_id
47    }
48
49    /// Build a key for a single GEMM.
50    #[allow(clippy::too_many_arguments)]
51    pub fn gemm<T: GemmSupported>(
52        shape: GemmShape,
53        layout_a: GemmLayout,
54        layout_b: GemmLayout,
55        layout_c: GemmLayout,
56        epilogue: GemmEpilogue,
57        accum: CutlassDtype,
58        out: CutlassDtype,
59        arch: SmArch,
60        persistent: bool,
61    ) -> Self {
62        let mut h = Hasher::new();
63        h.add_u32(shape.m);
64        h.add_u32(shape.n);
65        h.add_u32(shape.k);
66        h.add_u8(layout_a as u8);
67        h.add_u8(layout_b as u8);
68        h.add_u8(layout_c as u8);
69        h.add_str(T::DTYPE.short_name());
70        h.add_str(accum.short_name());
71        h.add_str(out.short_name());
72        h.add_str(arch.short_name());
73        h.add_u8(persistent as u8);
74        h.add_str(epilogue.short_name());
75        match epilogue {
76            GemmEpilogue::Linear { alpha, beta }
77            | GemmEpilogue::LinearReLU { alpha, beta }
78            | GemmEpilogue::LinearGelu { alpha, beta } => {
79                h.add_u32(alpha.to_bits());
80                h.add_u32(beta.to_bits());
81            }
82        }
83        Self {
84            template_id: 1,
85            payload: h.finish(),
86        }
87    }
88
89    /// Build a key for a grouped GEMM. `shape_summary` is the
90    /// `(max_m, max_n, max_k, group_count)` tuple from
91    /// `GroupedGemmShape::summary`.
92    #[cfg(feature = "grouped")]
93    #[allow(clippy::too_many_arguments)]
94    pub fn grouped_gemm<T: GemmSupported>(
95        shape_summary: (u32, u32, u32, usize),
96        layout_a: GemmLayout,
97        layout_b: GemmLayout,
98        layout_c: GemmLayout,
99        grouped_layout: GroupedLayout,
100        epilogue: GemmEpilogue,
101        accum: CutlassDtype,
102        out: CutlassDtype,
103        arch: SmArch,
104        persistent: bool,
105    ) -> Self {
106        let mut h = Hasher::new();
107        h.add_u32(shape_summary.0);
108        h.add_u32(shape_summary.1);
109        h.add_u32(shape_summary.2);
110        h.add_u32(shape_summary.3 as u32);
111        h.add_u8(layout_a as u8);
112        h.add_u8(layout_b as u8);
113        h.add_u8(layout_c as u8);
114        h.add_str(grouped_layout.short_name());
115        h.add_str(T::DTYPE.short_name());
116        h.add_str(accum.short_name());
117        h.add_str(out.short_name());
118        h.add_str(arch.short_name());
119        h.add_u8(persistent as u8);
120        h.add_str(epilogue.short_name());
121        Self {
122            template_id: 2,
123            payload: h.finish(),
124        }
125    }
126
127    /// Stub overload kept available even when the `grouped` feature
128    /// is off, so callers that conditionally produce grouped keys
129    /// still link. The non-`grouped` build returns a deterministic
130    /// placeholder that distinguishes itself from the gemm/conv keys.
131    #[cfg(not(feature = "grouped"))]
132    #[allow(dead_code)]
133    pub(crate) fn grouped_gemm_unsupported() -> Self {
134        Self {
135            template_id: 2,
136            payload: [0, 0, 0],
137        }
138    }
139
140    pub(crate) fn conv<T: GemmSupported>(
141        kind: ConvKind,
142        shape: ConvShape,
143        layout: ConvLayout,
144        accum: CutlassDtype,
145        out: CutlassDtype,
146        arch: SmArch,
147    ) -> Self {
148        let mut h = Hasher::new();
149        h.add_str(kind.short_name());
150        h.add_u32(shape.n);
151        h.add_u32(shape.h);
152        h.add_u32(shape.w);
153        h.add_u32(shape.c);
154        h.add_u32(shape.k);
155        h.add_u32(shape.r);
156        h.add_u32(shape.s);
157        h.add_u32(shape.pad_h);
158        h.add_u32(shape.pad_w);
159        h.add_u32(shape.stride_h);
160        h.add_u32(shape.stride_w);
161        h.add_u32(shape.dil_h);
162        h.add_u32(shape.dil_w);
163        h.add_str(layout.short_name());
164        h.add_str(T::DTYPE.short_name());
165        h.add_str(accum.short_name());
166        h.add_str(out.short_name());
167        h.add_str(arch.short_name());
168        Self {
169            template_id: 3,
170            payload: h.finish(),
171        }
172    }
173}
174
/// Compact 192-bit hashing helper used to build `PlanKey` payloads.
///
/// Three `DefaultHasher` (SipHash) streams ("lanes") are fed the same
/// logical input. Each lane is primed with a different seed word, and the
/// integer writers feed each lane a differently perturbed copy of the value.
/// The three 64-bit digests form the payload.
///
/// This is a hash, not an injective encoding: distinct inputs can collide
/// in principle, but with 192 bits of output the probability is negligible
/// for an in-process cache. Separation between template kinds comes from
/// `PlanKey::template_id`, not from this hasher.
///
/// `DefaultHasher::new()` is deterministic within a build, but its
/// algorithm is unspecified across Rust releases. These digests must
/// therefore not be persisted or compared across builds.
struct Hasher {
    a: std::collections::hash_map::DefaultHasher,
    b: std::collections::hash_map::DefaultHasher,
    c: std::collections::hash_map::DefaultHasher,
}

impl Hasher {
    /// Creates the three lanes, each primed with its own seed word so the
    /// lanes diverge even for identical subsequent input.
    fn new() -> Self {
        use std::hash::Hasher as _;
        let mut a = std::collections::hash_map::DefaultHasher::new();
        let mut b = std::collections::hash_map::DefaultHasher::new();
        let mut c = std::collections::hash_map::DefaultHasher::new();
        a.write_u64(0xA5A5_A5A5_A5A5_A5A5);
        b.write_u64(0x5A5A_5A5A_5A5A_5A5A);
        c.write_u64(0xC3C3_C3C3_C3C3_C3C3);
        Self { a, b, c }
    }

    /// Mixes one byte into every lane, offset differently per lane.
    fn add_u8(&mut self, v: u8) {
        use std::hash::Hasher as _;
        self.a.write_u8(v);
        self.b.write_u8(v.wrapping_add(0x55));
        self.c.write_u8(v.wrapping_add(0xAA));
    }

    /// Mixes a `u32` into every lane, rotated differently per lane.
    fn add_u32(&mut self, v: u32) {
        use std::hash::Hasher as _;
        self.a.write_u32(v);
        self.b.write_u32(v.rotate_left(11));
        self.c.write_u32(v.rotate_left(23));
    }

    /// Mixes a string into every lane.
    ///
    /// The byte length is written before the bytes so consecutive strings
    /// are self-delimiting. `Hasher::write` adds no delimiter of its own, so
    /// without the prefix `add_str("ab"); add_str("c")` and
    /// `add_str("a"); add_str("bc")` would feed an identical byte stream to
    /// every lane and produce identical keys.
    fn add_str(&mut self, s: &str) {
        use std::hash::Hasher as _;
        // Fixed-width `u64` length keeps the encoding independent of `usize`.
        let len = s.len() as u64;
        let bytes = s.as_bytes();
        self.a.write_u64(len);
        self.a.write(bytes);
        self.b.write_u64(len);
        self.b.write(bytes);
        self.c.write_u64(len);
        self.c.write(bytes);
    }

    /// Consumes the helper and returns the three lane digests.
    fn finish(self) -> [u64; 3] {
        use std::hash::Hasher as _;
        [self.a.finish(), self.b.finish(), self.c.finish()]
    }
}
222
223/// Cached plan entry. The payload is type-erased so a single cache
224/// can hold gemm / grouped-gemm / conv plans side by side.
225pub struct CachedPlan {
226    pub key: PlanKey,
227    pub source: Arc<String>,
228    pub kernel_name: Arc<String>,
229    /// Optional opaque payload, e.g. the post-NVRTC `KernelHandle`.
230    pub kernel_handle: Option<Arc<dyn Any + Send + Sync>>,
231}
232
233impl core::fmt::Debug for CachedPlan {
234    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
235        f.debug_struct("CachedPlan")
236            .field("key", &self.key)
237            .field("kernel_name", &*self.kernel_name)
238            .field("source_len", &self.source.len())
239            .field("has_kernel_handle", &self.kernel_handle.is_some())
240            .finish()
241    }
242}
243
244/// LRU plan cache. Default capacity is 64 entries — a single device
245/// rarely keeps more than that many distinct CUTLASS template
246/// instantiations live at once, and at ~64 KiB / cached `.cu` source
247/// the cache stays well under a megabyte.
248pub struct PlanCache {
249    inner: Mutex<LruCache<PlanKey, Arc<CachedPlan>>>,
250    capacity: usize,
251}
252
253impl PlanCache {
254    pub fn new(capacity: usize) -> Self {
255        let cap = NonZeroUsize::new(capacity.max(1)).expect("capacity > 0");
256        Self {
257            inner: Mutex::new(LruCache::new(cap)),
258            capacity: cap.get(),
259        }
260    }
261
262    pub fn capacity(&self) -> usize {
263        self.capacity
264    }
265
266    pub fn len(&self) -> usize {
267        self.inner.lock().len()
268    }
269
270    pub fn is_empty(&self) -> bool {
271        self.len() == 0
272    }
273
274    pub fn get(&self, key: &PlanKey) -> Option<Arc<CachedPlan>> {
275        self.inner.lock().get(key).cloned()
276    }
277
278    pub fn insert(&self, plan: CachedPlan) -> Arc<CachedPlan> {
279        let key = plan.key;
280        let arc = Arc::new(plan);
281        self.inner.lock().put(key, arc.clone());
282        arc
283    }
284
285    pub fn clear(&self) {
286        self.inner.lock().clear();
287    }
288}
289
290impl Default for PlanCache {
291    fn default() -> Self {
292        Self::new(64)
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::F16;
    use crate::gemm::{GemmLayout, GemmShape};

    /// Single-GEMM key that varies only in `m`.
    fn k(m: u32) -> PlanKey {
        PlanKey::gemm::<F16>(
            GemmShape::new(m, 64, 64),
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmEpilogue::default(),
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
            false,
        )
    }

    /// Fixed 64x64x64 F16 GEMM key that varies only in its epilogue.
    fn gemm_with_epilogue(epilogue: GemmEpilogue) -> PlanKey {
        PlanKey::gemm::<F16>(
            GemmShape::new(64, 64, 64),
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            epilogue,
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
            false,
        )
    }

    #[test]
    fn plan_cache_lru_round_trip() {
        let cache = PlanCache::new(2);
        assert_eq!(cache.capacity(), 2);
        assert!(cache.is_empty());

        let p1 = cache.insert(CachedPlan {
            key: k(1),
            source: Arc::new("a".into()),
            kernel_name: Arc::new("k1".into()),
            kernel_handle: None,
        });
        let p2 = cache.insert(CachedPlan {
            key: k(2),
            source: Arc::new("b".into()),
            kernel_name: Arc::new("k2".into()),
            kernel_handle: None,
        });
        assert_eq!(cache.len(), 2);

        // Hit on p1 promotes it; inserting p3 must evict p2.
        let _ = cache.get(&p1.key).unwrap();
        let _ = cache.insert(CachedPlan {
            key: k(3),
            source: Arc::new("c".into()),
            kernel_name: Arc::new("k3".into()),
            kernel_handle: None,
        });
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&p2.key).is_none());
        assert!(cache.get(&p1.key).is_some());

        // PlanKey is a `u8` template_id (padded to the payload's 8-byte
        // alignment) plus a 24-byte `[u64; 3]` payload, so 32 bytes on
        // typical targets. Guarded so a future struct-layout change can't
        // accidentally explode the cache memory budget.
        const _: () = assert!(PlanKey::SIZE_BYTES <= 64);

        // Distinct shapes -> distinct keys.
        assert_ne!(k(1), k(2));

        // Clear empties the cache.
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn plan_keys_distinct_across_template_kinds() {
        let gemm = k(1);
        let conv = PlanKey::conv::<F16>(
            ConvKind::Fprop,
            ConvShape::nhwc(1, 1, 1, 1, 1, 1, 1),
            ConvLayout::Nhwc,
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
        );
        assert_ne!(gemm, conv);
        assert_ne!(gemm.template_id(), conv.template_id());
    }

    #[test]
    fn plan_keys_are_deterministic_and_track_epilogue_scalars() {
        // Identical inputs must produce identical keys or the cache never hits.
        assert_eq!(k(128), k(128));

        let base = gemm_with_epilogue(GemmEpilogue::Linear { alpha: 1.0, beta: 0.0 });
        let same = gemm_with_epilogue(GemmEpilogue::Linear { alpha: 1.0, beta: 0.0 });
        let other_alpha = gemm_with_epilogue(GemmEpilogue::Linear { alpha: 2.0, beta: 0.0 });
        let other_beta = gemm_with_epilogue(GemmEpilogue::Linear { alpha: 1.0, beta: 1.0 });
        assert_eq!(base, same);
        assert_ne!(base, other_alpha);
        assert_ne!(base, other_beta);
    }
}