1use std::sync::Arc;
18
19use parking_lot::Mutex;
20
21use crate::conv::CutlassConvDispatch;
22use crate::gemm::{CutlassGemmDispatch, RefitMsg};
23use crate::plan_cache::{CachedPlan, PlanCache};
24
25#[cfg(feature = "grouped")]
26use crate::grouped_gemm::CutlassGroupedGemmDispatch;
27
28pub enum CutlassMsg {
37 Gemm(Box<dyn CutlassGemmDispatch>),
38 #[cfg(feature = "grouped")]
39 GroupedGemm(Box<dyn CutlassGroupedGemmDispatch>),
40 Conv(Box<dyn CutlassConvDispatch>),
41 Refit {
42 msg: RefitMsg,
43 reply: Box<dyn FnOnce(Result<(), String>) + Send + 'static>,
44 },
45}
46
47impl std::fmt::Debug for CutlassMsg {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 CutlassMsg::Gemm(d) => f
51 .debug_struct("Gemm")
52 .field("dtype", &d.dtype())
53 .field("arch", &d.arch())
54 .finish(),
55 #[cfg(feature = "grouped")]
56 CutlassMsg::GroupedGemm(d) => f
57 .debug_struct("GroupedGemm")
58 .field("dtype", &d.dtype())
59 .field("arch", &d.arch())
60 .field("group_count", &d.group_count())
61 .finish(),
62 CutlassMsg::Conv(d) => f
63 .debug_struct("Conv")
64 .field("kind", &d.kind_name())
65 .field("dtype", &d.dtype())
66 .field("arch", &d.arch())
67 .finish(),
68 CutlassMsg::Refit { msg, .. } => f
69 .debug_struct("Refit")
70 .field("plan_key", &msg.plan_key)
71 .field("weights_len", &msg.weights.len())
72 .finish(),
73 }
74 }
75}
76
/// Shared, thread-safe callback that is offered freshly rendered kernel source.
///
/// [`CutlassActor::handle`] invokes it as `sink(source, kernel_name)`. An `Err`
/// carries a human-readable reason for rejecting the source.
pub type CompileSink =
    Arc<dyn Fn(&str, &str) -> Result<(), String> + Send + Sync>;
81
82pub struct CutlassInner {
84 pub plan_cache: Arc<PlanCache>,
85 pub compile_sink: Option<CompileSink>,
91 pub dispatched: Mutex<u64>,
94}
95
96impl CutlassInner {
97 pub fn new(plan_cache_capacity: usize) -> Self {
98 Self {
99 plan_cache: Arc::new(PlanCache::new(plan_cache_capacity)),
100 compile_sink: None,
101 dispatched: Mutex::new(0),
102 }
103 }
104
105 pub fn dispatched(&self) -> u64 {
106 *self.dispatched.lock()
107 }
108}
109
110pub struct CutlassActor {
114 inner: Arc<CutlassInner>,
115}
116
117impl CutlassActor {
118 pub fn new(plan_cache_capacity: usize) -> Self {
119 Self {
120 inner: Arc::new(CutlassInner::new(plan_cache_capacity)),
121 }
122 }
123
124 pub fn prebuilt_active() -> bool {
131 cfg!(cutlass_prebuilt_active)
132 }
133
134 pub fn inner(&self) -> Arc<CutlassInner> {
135 self.inner.clone()
136 }
137
138 pub fn handle(&self, msg: CutlassMsg) {
142 *self.inner.dispatched.lock() += 1;
143 match msg {
144 CutlassMsg::Gemm(d) => {
145 let key = d.plan_key();
146 if self.inner.plan_cache.get(&key).is_none() {
147 let (src, name) = d.render_cu();
148 if let Some(sink) = &self.inner.compile_sink {
149 if let Err(e) = sink(&src, &name) {
150 tracing::warn!(error = %e, "cutlass compile sink rejected source");
151 }
152 }
153 self.inner.plan_cache.insert(CachedPlan {
154 key,
155 source: Arc::new(src),
156 kernel_name: Arc::new(name),
157 kernel_handle: None,
158 });
159 }
160 }
161 #[cfg(feature = "grouped")]
162 CutlassMsg::GroupedGemm(d) => {
163 let key = d.plan_key();
164 if self.inner.plan_cache.get(&key).is_none() {
165 let (src, name) = d.render_cu();
166 if let Some(sink) = &self.inner.compile_sink {
167 let _ = sink(&src, &name);
168 }
169 self.inner.plan_cache.insert(CachedPlan {
170 key,
171 source: Arc::new(src),
172 kernel_name: Arc::new(name),
173 kernel_handle: None,
174 });
175 }
176 }
177 CutlassMsg::Conv(d) => {
178 let key = d.plan_key();
179 if self.inner.plan_cache.get(&key).is_none() {
180 let (src, name) = d.render_cu();
181 if let Some(sink) = &self.inner.compile_sink {
182 let _ = sink(&src, &name);
183 }
184 self.inner.plan_cache.insert(CachedPlan {
185 key,
186 source: Arc::new(src),
187 kernel_name: Arc::new(name),
188 kernel_handle: None,
189 });
190 }
191 }
192 CutlassMsg::Refit { msg, reply } => {
193 let exists = self.inner.plan_cache.get(&msg.plan_key).is_some();
194 if exists {
195 reply(Ok(()));
196 } else {
197 reply(Err(format!(
198 "cutlass refit: no plan for key {:?}",
199 msg.plan_key
200 )));
201 }
202 }
203 }
204 }
205}
206
/// Construction parameters for a [`CutlassActor`].
///
/// The type is `Copy`, so it can be passed around and reused freely.
#[derive(Debug, Clone, Copy)]
pub struct CutlassProps {
    /// Value forwarded to `CutlassActor::new` when the actor is created.
    pub plan_cache_capacity: usize,
}
216
217impl CutlassProps {
218 pub fn new(plan_cache_capacity: usize) -> Self {
219 Self {
220 plan_cache_capacity,
221 }
222 }
223
224 pub fn create(self) -> CutlassActor {
225 CutlassActor::new(self.plan_cache_capacity)
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conv::{ConvFwdRequest, ConvShape};
    use crate::dtype::{CutlassDtype, SmArch, F16};
    use crate::gemm::{GemmEpilogue, GemmLayout, GemmRequest, GemmShape};
    use crate::plan_cache::PlanKey;

    /// Sends a `Refit` message and returns whatever the actor passed to the
    /// reply callback.
    fn refit_outcome(actor: &CutlassActor, msg: RefitMsg) -> Result<(), String> {
        let (tx, rx) = std::sync::mpsc::channel();
        actor.handle(CutlassMsg::Refit {
            msg,
            reply: Box::new(move |outcome| {
                let _ = tx.send(outcome);
            }),
        });
        rx.recv().unwrap()
    }

    /// Walks one actor through GEMM dispatch, conv dispatch, refit against a
    /// known and an unknown key, and a repeated GEMM dispatch.
    #[test]
    fn cutlass_msg_constructs() {
        let actor = CutlassActor::new(8);
        let gemm = GemmRequest::<F16>::new(GemmShape::new(64, 64, 64), SmArch::Sm80);
        let gemm_key = gemm.plan_key();

        // First GEMM dispatch: counted once and cached under its key.
        actor.handle(CutlassMsg::Gemm(Box::new(gemm.clone())));
        assert_eq!(actor.inner().dispatched(), 1);
        assert!(actor.inner().plan_cache.get(&gemm_key).is_some());

        // Conv dispatch goes through the same count-and-cache path.
        let conv = ConvFwdRequest::<F16>::new(
            ConvShape::nhwc(1, 8, 8, 16, 32, 3, 3),
            SmArch::Sm80,
        );
        let conv_key = conv.plan_key();
        actor.handle(CutlassMsg::Conv(Box::new(conv)));
        assert_eq!(actor.inner().dispatched(), 2);
        assert!(actor.inner().plan_cache.get(&conv_key).is_some());

        // Refit against the cached GEMM plan succeeds.
        let known = refit_outcome(
            &actor,
            RefitMsg {
                plan_key: gemm_key,
                weights: vec![0u8; 16],
            },
        );
        assert!(known.is_ok());

        // Refit against a key that was never dispatched to this actor fails.
        let missing_key = PlanKey::gemm::<F16>(
            GemmShape::new(1, 1, 1),
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmEpilogue::default(),
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
            false,
        );
        let unknown = refit_outcome(
            &actor,
            RefitMsg {
                plan_key: missing_key,
                weights: vec![],
            },
        );
        assert!(unknown.is_err());

        // Dispatching the same GEMM again must not add a cache entry.
        let len_before = actor.inner().plan_cache.len();
        actor.handle(CutlassMsg::Gemm(Box::new(gemm)));
        let len_after = actor.inner().plan_cache.len();
        assert_eq!(len_before, len_after);
    }

    /// A grouped-GEMM dispatch caches a plan under the request's key.
    #[cfg(feature = "grouped")]
    #[test]
    fn grouped_dispatch() {
        use crate::grouped_gemm::{GroupedGemmRequest, GroupedGemmShape};

        let actor = CutlassActor::new(4);
        let shapes = GroupedGemmShape::new(vec![GemmShape::new(64, 64, 64)]);
        let grouped = GroupedGemmRequest::<F16>::new(shapes, SmArch::Sm90a);
        let grouped_key = grouped.plan_key();

        actor.handle(CutlassMsg::GroupedGemm(Box::new(grouped)));
        assert!(actor.inner().plan_cache.get(&grouped_key).is_some());
    }
}