1use std::any::Any;
12use std::num::NonZeroUsize;
13use std::sync::Arc;
14
15use lru::LruCache;
16use parking_lot::Mutex;
17
18use crate::conv::{ConvKind, ConvLayout, ConvShape};
19use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
20use crate::gemm::{GemmEpilogue, GemmLayout, GemmShape};
21#[cfg(feature = "grouped")]
22use crate::grouped_gemm::GroupedLayout;
23
/// Compact, hashable identity of a generated kernel plan.
///
/// `template_id` tags the template family and `payload` holds a 192-bit fingerprint of
/// the parameters that shaped the kernel. Two keys compare equal only when both the tag
/// and all three fingerprint words match.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PlanKey {
    /// Template family discriminant.
    template_id: u8,
    /// Three 64-bit digests of the template parameters.
    payload: [u64; 3],
}
39
40impl PlanKey {
41 pub const SIZE_BYTES: usize = core::mem::size_of::<PlanKey>();
44
45 pub fn template_id(&self) -> u8 {
46 self.template_id
47 }
48
49 #[allow(clippy::too_many_arguments)]
51 pub fn gemm<T: GemmSupported>(
52 shape: GemmShape,
53 layout_a: GemmLayout,
54 layout_b: GemmLayout,
55 layout_c: GemmLayout,
56 epilogue: GemmEpilogue,
57 accum: CutlassDtype,
58 out: CutlassDtype,
59 arch: SmArch,
60 persistent: bool,
61 ) -> Self {
62 let mut h = Hasher::new();
63 h.add_u32(shape.m);
64 h.add_u32(shape.n);
65 h.add_u32(shape.k);
66 h.add_u8(layout_a as u8);
67 h.add_u8(layout_b as u8);
68 h.add_u8(layout_c as u8);
69 h.add_str(T::DTYPE.short_name());
70 h.add_str(accum.short_name());
71 h.add_str(out.short_name());
72 h.add_str(arch.short_name());
73 h.add_u8(persistent as u8);
74 h.add_str(epilogue.short_name());
75 match epilogue {
76 GemmEpilogue::Linear { alpha, beta }
77 | GemmEpilogue::LinearReLU { alpha, beta }
78 | GemmEpilogue::LinearGelu { alpha, beta } => {
79 h.add_u32(alpha.to_bits());
80 h.add_u32(beta.to_bits());
81 }
82 }
83 Self {
84 template_id: 1,
85 payload: h.finish(),
86 }
87 }
88
89 #[cfg(feature = "grouped")]
93 #[allow(clippy::too_many_arguments)]
94 pub fn grouped_gemm<T: GemmSupported>(
95 shape_summary: (u32, u32, u32, usize),
96 layout_a: GemmLayout,
97 layout_b: GemmLayout,
98 layout_c: GemmLayout,
99 grouped_layout: GroupedLayout,
100 epilogue: GemmEpilogue,
101 accum: CutlassDtype,
102 out: CutlassDtype,
103 arch: SmArch,
104 persistent: bool,
105 ) -> Self {
106 let mut h = Hasher::new();
107 h.add_u32(shape_summary.0);
108 h.add_u32(shape_summary.1);
109 h.add_u32(shape_summary.2);
110 h.add_u32(shape_summary.3 as u32);
111 h.add_u8(layout_a as u8);
112 h.add_u8(layout_b as u8);
113 h.add_u8(layout_c as u8);
114 h.add_str(grouped_layout.short_name());
115 h.add_str(T::DTYPE.short_name());
116 h.add_str(accum.short_name());
117 h.add_str(out.short_name());
118 h.add_str(arch.short_name());
119 h.add_u8(persistent as u8);
120 h.add_str(epilogue.short_name());
121 Self {
122 template_id: 2,
123 payload: h.finish(),
124 }
125 }
126
127 #[cfg(not(feature = "grouped"))]
132 #[allow(dead_code)]
133 pub(crate) fn grouped_gemm_unsupported() -> Self {
134 Self {
135 template_id: 2,
136 payload: [0, 0, 0],
137 }
138 }
139
140 pub(crate) fn conv<T: GemmSupported>(
141 kind: ConvKind,
142 shape: ConvShape,
143 layout: ConvLayout,
144 accum: CutlassDtype,
145 out: CutlassDtype,
146 arch: SmArch,
147 ) -> Self {
148 let mut h = Hasher::new();
149 h.add_str(kind.short_name());
150 h.add_u32(shape.n);
151 h.add_u32(shape.h);
152 h.add_u32(shape.w);
153 h.add_u32(shape.c);
154 h.add_u32(shape.k);
155 h.add_u32(shape.r);
156 h.add_u32(shape.s);
157 h.add_u32(shape.pad_h);
158 h.add_u32(shape.pad_w);
159 h.add_u32(shape.stride_h);
160 h.add_u32(shape.stride_w);
161 h.add_u32(shape.dil_h);
162 h.add_u32(shape.dil_w);
163 h.add_str(layout.short_name());
164 h.add_str(T::DTYPE.short_name());
165 h.add_str(accum.short_name());
166 h.add_str(out.short_name());
167 h.add_str(arch.short_name());
168 Self {
169 template_id: 3,
170 payload: h.finish(),
171 }
172 }
173}
174
/// Incremental fingerprint builder used by the `PlanKey` constructors.
///
/// Every field is fed into three `DefaultHasher` lanes, each pre-seeded with a different
/// constant, and the three outputs form a 192-bit digest. All `DefaultHasher::new()`
/// instances hash identically within one build, so digests are reproducible in-process.
/// However, std does not promise the algorithm is stable across Rust releases, so these
/// values must never be persisted.
struct Hasher {
    a: std::collections::hash_map::DefaultHasher,
    b: std::collections::hash_map::DefaultHasher,
    c: std::collections::hash_map::DefaultHasher,
}

impl Hasher {
    /// Creates a builder whose three lanes start from distinct seeds.
    fn new() -> Self {
        use std::hash::Hasher as _;
        let mut a = std::collections::hash_map::DefaultHasher::new();
        let mut b = std::collections::hash_map::DefaultHasher::new();
        let mut c = std::collections::hash_map::DefaultHasher::new();
        a.write_u64(0xA5A5_A5A5_A5A5_A5A5);
        b.write_u64(0x5A5A_5A5A_5A5A_5A5A);
        c.write_u64(0xC3C3_C3C3_C3C3_C3C3);
        Self { a, b, c }
    }

    /// Mixes one byte; lanes `b` and `c` receive offset copies of it.
    fn add_u8(&mut self, v: u8) {
        use std::hash::Hasher as _;
        self.a.write_u8(v);
        self.b.write_u8(v.wrapping_add(0x55));
        self.c.write_u8(v.wrapping_add(0xAA));
    }

    /// Mixes a `u32`; lanes `b` and `c` receive bit-rotated copies of it.
    fn add_u32(&mut self, v: u32) {
        use std::hash::Hasher as _;
        self.a.write_u32(v);
        self.b.write_u32(v.rotate_left(11));
        self.c.write_u32(v.rotate_left(23));
    }

    /// Mixes a string, prefixed with its byte length.
    ///
    /// The prefix makes the encoding prefix-free. Without it, consecutive calls with
    /// `("ab", "c")` and `("a", "bc")` would feed identical bytes and yield the same
    /// digest. The length is written as a `u64` so its width does not depend on the
    /// target.
    fn add_str(&mut self, s: &str) {
        use std::hash::Hasher as _;
        let len = s.len() as u64;
        for lane in [&mut self.a, &mut self.b, &mut self.c] {
            lane.write_u64(len);
            lane.write(s.as_bytes());
        }
    }

    /// Consumes the builder and returns the three lane digests.
    fn finish(self) -> [u64; 3] {
        use std::hash::Hasher as _;
        [self.a.finish(), self.b.finish(), self.c.finish()]
    }
}
222
223pub struct CachedPlan {
226 pub key: PlanKey,
227 pub source: Arc<String>,
228 pub kernel_name: Arc<String>,
229 pub kernel_handle: Option<Arc<dyn Any + Send + Sync>>,
231}
232
233impl core::fmt::Debug for CachedPlan {
234 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
235 f.debug_struct("CachedPlan")
236 .field("key", &self.key)
237 .field("kernel_name", &*self.kernel_name)
238 .field("source_len", &self.source.len())
239 .field("has_kernel_handle", &self.kernel_handle.is_some())
240 .finish()
241 }
242}
243
244pub struct PlanCache {
249 inner: Mutex<LruCache<PlanKey, Arc<CachedPlan>>>,
250 capacity: usize,
251}
252
253impl PlanCache {
254 pub fn new(capacity: usize) -> Self {
255 let cap = NonZeroUsize::new(capacity.max(1)).expect("capacity > 0");
256 Self {
257 inner: Mutex::new(LruCache::new(cap)),
258 capacity: cap.get(),
259 }
260 }
261
262 pub fn capacity(&self) -> usize {
263 self.capacity
264 }
265
266 pub fn len(&self) -> usize {
267 self.inner.lock().len()
268 }
269
270 pub fn is_empty(&self) -> bool {
271 self.len() == 0
272 }
273
274 pub fn get(&self, key: &PlanKey) -> Option<Arc<CachedPlan>> {
275 self.inner.lock().get(key).cloned()
276 }
277
278 pub fn insert(&self, plan: CachedPlan) -> Arc<CachedPlan> {
279 let key = plan.key;
280 let arc = Arc::new(plan);
281 self.inner.lock().put(key, arc.clone());
282 arc
283 }
284
285 pub fn clear(&self) {
286 self.inner.lock().clear();
287 }
288}
289
290impl Default for PlanCache {
291 fn default() -> Self {
292 Self::new(64)
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::F16;

    /// Row-major F16 GEMM key on SM80 whose only varying dimension is `m`.
    fn gemm_key(m: u32) -> PlanKey {
        PlanKey::gemm::<F16>(
            GemmShape::new(m, 64, 64),
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmEpilogue::default(),
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
            false,
        )
    }

    /// Minimal plan for `key` with no kernel handle attached.
    fn stub_plan(key: PlanKey, source: &str, kernel_name: &str) -> CachedPlan {
        CachedPlan {
            key,
            source: Arc::new(String::from(source)),
            kernel_name: Arc::new(String::from(kernel_name)),
            kernel_handle: None,
        }
    }

    #[test]
    fn plan_cache_lru_round_trip() {
        // Keys should stay small enough to copy and hash cheaply.
        const _: () = assert!(PlanKey::SIZE_BYTES <= 64);

        let cache = PlanCache::new(2);
        assert_eq!(cache.capacity(), 2);
        assert!(cache.is_empty());

        let oldest = cache.insert(stub_plan(gemm_key(1), "a", "k1"));
        let middle = cache.insert(stub_plan(gemm_key(2), "b", "k2"));
        assert_eq!(cache.len(), 2);

        // Touching `oldest` promotes it, leaving `middle` as the eviction victim.
        cache
            .get(&oldest.key)
            .expect("freshly inserted plan must be cached");
        cache.insert(stub_plan(gemm_key(3), "c", "k3"));

        assert_eq!(cache.len(), 2);
        assert!(cache.get(&middle.key).is_none());
        assert!(cache.get(&oldest.key).is_some());

        assert_ne!(gemm_key(1), gemm_key(2));

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn plan_keys_distinct_across_template_kinds() {
        let conv = PlanKey::conv::<F16>(
            ConvKind::Fprop,
            ConvShape::nhwc(1, 1, 1, 1, 1, 1, 1),
            ConvLayout::Nhwc,
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
        );
        let gemm = gemm_key(1);
        assert_ne!(gemm.template_id(), conv.template_id());
        assert_ne!(gemm, conv);
    }
}