kaio_runtime/module.rs
1//! PTX module loading and kernel function handles.
2
3use std::sync::Arc;
4
5use cudarc::driver::{CudaFunction, CudaModule};
6
7use crate::error::Result;
8
9/// A loaded PTX module on the GPU device.
10///
11/// Created via [`KaioDevice::load_ptx`](crate::device::KaioDevice::load_ptx).
12/// Use [`function`](Self::function) to get a handle to a specific kernel
13/// entry point, then launch it via cudarc's `launch_builder`.
14pub struct KaioModule {
15 inner: Arc<CudaModule>,
16}
17
18impl KaioModule {
19 /// Wrap a raw cudarc module.
20 pub(crate) fn from_raw(inner: Arc<CudaModule>) -> Self {
21 Self { inner }
22 }
23
24 /// Get a kernel function handle by name.
25 ///
26 /// The name must match the `.entry` name in the PTX source
27 /// (e.g. `"vector_add"`).
28 pub fn function(&self, name: &str) -> Result<KaioFunction> {
29 let func = self.inner.load_function(name)?;
30 Ok(KaioFunction { inner: func })
31 }
32}
33
34/// A handle to a kernel function within a loaded PTX module.
35///
36/// Use [`inner`](Self::inner) to access the underlying `CudaFunction`
37/// for passing to cudarc's `launch_builder`. In Phase 1, kernel launch
38/// goes through cudarc directly — Phase 2's macro will generate typed
39/// safe wrappers.
40pub struct KaioFunction {
41 inner: CudaFunction,
42}
43
44impl KaioFunction {
45 /// Access the underlying [`CudaFunction`] for cudarc's launch builder.
46 ///
47 /// # Example
48 ///
49 /// ```ignore
50 /// let cfg = LaunchConfig::for_num_elems(n);
51 /// unsafe {
52 /// device.stream()
53 /// .launch_builder(func.inner())
54 /// .arg(buf_a.inner())
55 /// .arg(&n)
56 /// .launch(cfg)?;
57 /// }
58 /// ```
59 pub fn inner(&self) -> &CudaFunction {
60 &self.inner
61 }
62}