Skip to main content

atomr_accel_cutlass/
actor.rs

1//! `CutlassActor` — host-side dispatcher for CUTLASS template
2//! instantiations.
3//!
4//! The actor is intentionally GPU-agnostic at the type level:
5//! launching a compiled kernel goes through
6//! `atomr_accel_cuda::kernel::NvrtcActor` once that crate's `nvrtc`
7//! feature is on. When no NVRTC actor is wired in (e.g. the host-only
8//! test runs that this crate ships), the actor records the rendered
9//! `.cu` source plus the lowered kernel name in the
10//! [`crate::plan_cache::PlanCache`] and replies with the cache hit.
11//!
12//! Wiring into `KernelChildren::register_extra` is left to the
13//! caller: this crate doesn't reach into the device actor, but it
14//! exposes [`CutlassProps`] so a downstream `ContextActor` can
15//! `register_extra("cutlass", atomr_accel_cutlass::props(64))`.
16
17use std::sync::Arc;
18
19use parking_lot::Mutex;
20
21use crate::conv::CutlassConvDispatch;
22use crate::gemm::{CutlassGemmDispatch, RefitMsg};
23use crate::plan_cache::{CachedPlan, PlanCache};
24
25#[cfg(feature = "grouped")]
26use crate::grouped_gemm::CutlassGroupedGemmDispatch;
27
28/// Top-level actor mailbox.
29///
30/// `Refit` carries new weight bytes for an already-compiled plan;
31/// the actor copies them into the kernel's bound workspace without
32/// recompiling. `reply` is an opaque sender that downstream code
33/// can pick — we don't depend on `tokio::sync::oneshot` here so that
34/// host-only callers can use a `std::sync::mpsc` reply channel
35/// instead.
36pub enum CutlassMsg {
37    Gemm(Box<dyn CutlassGemmDispatch>),
38    #[cfg(feature = "grouped")]
39    GroupedGemm(Box<dyn CutlassGroupedGemmDispatch>),
40    Conv(Box<dyn CutlassConvDispatch>),
41    Refit {
42        msg: RefitMsg,
43        reply: Box<dyn FnOnce(Result<(), String>) + Send + 'static>,
44    },
45}
46
47impl std::fmt::Debug for CutlassMsg {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            CutlassMsg::Gemm(d) => f
51                .debug_struct("Gemm")
52                .field("dtype", &d.dtype())
53                .field("arch", &d.arch())
54                .finish(),
55            #[cfg(feature = "grouped")]
56            CutlassMsg::GroupedGemm(d) => f
57                .debug_struct("GroupedGemm")
58                .field("dtype", &d.dtype())
59                .field("arch", &d.arch())
60                .field("group_count", &d.group_count())
61                .finish(),
62            CutlassMsg::Conv(d) => f
63                .debug_struct("Conv")
64                .field("kind", &d.kind_name())
65                .field("dtype", &d.dtype())
66                .field("arch", &d.arch())
67                .finish(),
68            CutlassMsg::Refit { msg, .. } => f
69                .debug_struct("Refit")
70                .field("plan_key", &msg.plan_key)
71                .field("weights_len", &msg.weights.len())
72                .finish(),
73        }
74    }
75}
76
/// Shared, thread-safe callback that receives each freshly rendered `.cu`
/// source. The first argument is the source text and the second is the
/// kernel name; an `Err` carries a description of why it was rejected.
///
/// Typically wired to `atomr-accel-cuda::kernel::NvrtcActor::Compile` when
/// the downstream binary opts into `nvrtc`.
pub type CompileSink =
    Arc<dyn Fn(&str, &str) -> Result<(), String> + Send + Sync>;
81
82/// Inner state of a [`CutlassActor`].
83pub struct CutlassInner {
84    pub plan_cache: Arc<PlanCache>,
85    /// Optional NVRTC-shaped sink: when present, the actor forwards
86    /// rendered `.cu` source to it for compilation. Left as a generic
87    /// `Box<dyn Fn ...>` so the cutlass crate doesn't pull
88    /// `atomr-accel-cuda::nvrtc` into its compile graph when the
89    /// feature is off.
90    pub compile_sink: Option<CompileSink>,
91    /// Counter of dispatched messages — exposed for the
92    /// `actor::tests::cutlass_msg_constructs` test and for telemetry.
93    pub dispatched: Mutex<u64>,
94}
95
96impl CutlassInner {
97    pub fn new(plan_cache_capacity: usize) -> Self {
98        Self {
99            plan_cache: Arc::new(PlanCache::new(plan_cache_capacity)),
100            compile_sink: None,
101            dispatched: Mutex::new(0),
102        }
103    }
104
105    pub fn dispatched(&self) -> u64 {
106        *self.dispatched.lock()
107    }
108}
109
110/// Host-side actor. Holds an [`Arc<CutlassInner>`] so messages can be
111/// processed from a worker thread without locking the actor itself
112/// after construction.
113pub struct CutlassActor {
114    inner: Arc<CutlassInner>,
115}
116
117impl CutlassActor {
118    pub fn new(plan_cache_capacity: usize) -> Self {
119        Self {
120            inner: Arc::new(CutlassInner::new(plan_cache_capacity)),
121        }
122    }
123
124    /// `true` when the crate was built with `cutlass-prebuilt` and
125    /// `nvcc` was found at build time, so `libatomr_cutlass_prebuilt.a`
126    /// is statically linked into the binary. Phase 6.1 ships one
127    /// canonical GEMM placeholder cell (proves the wiring); Phase 6.2
128    /// expands the cell matrix and routes hits through the prebuilt
129    /// symbol table before falling back to NVRTC.
130    pub fn prebuilt_active() -> bool {
131        cfg!(cutlass_prebuilt_active)
132    }
133
134    pub fn inner(&self) -> Arc<CutlassInner> {
135        self.inner.clone()
136    }
137
138    /// Synchronously process a message. The real production path
139    /// runs through `atomr_core::actor::Actor::handle`; this method
140    /// is the host-only fast path that the unit tests exercise.
141    pub fn handle(&self, msg: CutlassMsg) {
142        *self.inner.dispatched.lock() += 1;
143        match msg {
144            CutlassMsg::Gemm(d) => {
145                let key = d.plan_key();
146                if self.inner.plan_cache.get(&key).is_none() {
147                    let (src, name) = d.render_cu();
148                    if let Some(sink) = &self.inner.compile_sink {
149                        if let Err(e) = sink(&src, &name) {
150                            tracing::warn!(error = %e, "cutlass compile sink rejected source");
151                        }
152                    }
153                    self.inner.plan_cache.insert(CachedPlan {
154                        key,
155                        source: Arc::new(src),
156                        kernel_name: Arc::new(name),
157                        kernel_handle: None,
158                    });
159                }
160            }
161            #[cfg(feature = "grouped")]
162            CutlassMsg::GroupedGemm(d) => {
163                let key = d.plan_key();
164                if self.inner.plan_cache.get(&key).is_none() {
165                    let (src, name) = d.render_cu();
166                    if let Some(sink) = &self.inner.compile_sink {
167                        let _ = sink(&src, &name);
168                    }
169                    self.inner.plan_cache.insert(CachedPlan {
170                        key,
171                        source: Arc::new(src),
172                        kernel_name: Arc::new(name),
173                        kernel_handle: None,
174                    });
175                }
176            }
177            CutlassMsg::Conv(d) => {
178                let key = d.plan_key();
179                if self.inner.plan_cache.get(&key).is_none() {
180                    let (src, name) = d.render_cu();
181                    if let Some(sink) = &self.inner.compile_sink {
182                        let _ = sink(&src, &name);
183                    }
184                    self.inner.plan_cache.insert(CachedPlan {
185                        key,
186                        source: Arc::new(src),
187                        kernel_name: Arc::new(name),
188                        kernel_handle: None,
189                    });
190                }
191            }
192            CutlassMsg::Refit { msg, reply } => {
193                let exists = self.inner.plan_cache.get(&msg.plan_key).is_some();
194                if exists {
195                    reply(Ok(()));
196                } else {
197                    reply(Err(format!(
198                        "cutlass refit: no plan for key {:?}",
199                        msg.plan_key
200                    )));
201                }
202            }
203        }
204    }
205}
206
/// Small, copyable recipe for constructing a [`CutlassActor`].
///
/// It mirrors the `atomr_core::actor::Props` shape used elsewhere in the
/// workspace without depending on `atomr-core` directly — once the
/// upstream `KernelChildren::register_extra` API stabilizes, this struct
/// is what the device actor calls into.
#[derive(Clone, Copy, Debug)]
pub struct CutlassProps {
    /// Capacity passed to the actor's plan cache when the actor is created.
    pub plan_cache_capacity: usize,
}
216
217impl CutlassProps {
218    pub fn new(plan_cache_capacity: usize) -> Self {
219        Self {
220            plan_cache_capacity,
221        }
222    }
223
224    pub fn create(self) -> CutlassActor {
225        CutlassActor::new(self.plan_cache_capacity)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conv::{ConvFwdRequest, ConvShape};
    use crate::dtype::{CutlassDtype, SmArch, F16};
    use crate::gemm::{GemmEpilogue, GemmLayout, GemmRequest, GemmShape};
    use crate::plan_cache::PlanKey;

    /// Walks one actor through every always-available message variant and
    /// checks the dispatch counter, the plan cache, and both refit outcomes.
    #[test]
    fn cutlass_msg_constructs() {
        let actor = CutlassActor::new(8);
        let inner = actor.inner();

        // Sends a `Refit` and hands back whatever the reply callback got.
        let refit = |plan_key, weights: Vec<u8>| -> Result<(), String> {
            let (tx, rx) = std::sync::mpsc::channel();
            actor.handle(CutlassMsg::Refit {
                msg: RefitMsg { plan_key, weights },
                reply: Box::new(move |outcome| {
                    let _ = tx.send(outcome);
                }),
            });
            rx.recv().unwrap()
        };

        // A Gemm dispatch bumps the counter and populates the cache.
        let gemm = GemmRequest::<F16>::new(GemmShape::new(64, 64, 64), SmArch::Sm80);
        let gemm_key = gemm.plan_key();
        actor.handle(CutlassMsg::Gemm(Box::new(gemm.clone())));
        assert_eq!(inner.dispatched(), 1);
        assert!(inner.plan_cache.get(&gemm_key).is_some());

        // So does a Conv dispatch.
        let conv = ConvFwdRequest::<F16>::new(ConvShape::nhwc(1, 8, 8, 16, 32, 3, 3), SmArch::Sm80);
        let conv_key = conv.plan_key();
        actor.handle(CutlassMsg::Conv(Box::new(conv)));
        assert_eq!(inner.dispatched(), 2);
        assert!(inner.plan_cache.get(&conv_key).is_some());

        // Refit against the plan cached above succeeds.
        assert!(refit(gemm_key, vec![0u8; 16]).is_ok());

        // Refit against a key that was never dispatched is rejected.
        let bogus = PlanKey::gemm::<F16>(
            GemmShape::new(1, 1, 1),
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
            GemmEpilogue::default(),
            CutlassDtype::F32,
            CutlassDtype::F16,
            SmArch::Sm80,
            false,
        );
        assert!(refit(bogus, Vec::new()).is_err());

        // Idempotent dispatch: re-sending the same Gemm must not change the
        // number of cached plans.
        let len_before = inner.plan_cache.len();
        actor.handle(CutlassMsg::Gemm(Box::new(gemm)));
        let len_after = inner.plan_cache.len();
        assert_eq!(len_before, len_after);
    }

    /// A grouped GEMM dispatch leaves its plan in the cache.
    #[cfg(feature = "grouped")]
    #[test]
    fn grouped_dispatch() {
        use crate::grouped_gemm::{GroupedGemmRequest, GroupedGemmShape};

        let actor = CutlassActor::new(4);
        let shapes = GroupedGemmShape::new(vec![GemmShape::new(64, 64, 64)]);
        let request = GroupedGemmRequest::<F16>::new(shapes, SmArch::Sm90a);
        let plan_key = request.plan_key();

        actor.handle(CutlassMsg::GroupedGemm(Box::new(request)));

        assert!(actor.inner().plan_cache.get(&plan_key).is_some());
    }
}