Skip to main content

kaio_runtime/
module.rs

1//! PTX module loading and kernel function handles.
2
3use std::sync::Arc;
4
5use cudarc::driver::{CudaFunction, CudaModule};
6
7use crate::error::Result;
8
9/// A loaded PTX module on the GPU device.
10///
11/// Created via [`KaioDevice::load_ptx`](crate::device::KaioDevice::load_ptx).
12/// Use [`function`](Self::function) to get a handle to a specific kernel
13/// entry point, then launch it via cudarc's `launch_builder`.
14pub struct KaioModule {
15    inner: Arc<CudaModule>,
16}
17
18impl KaioModule {
19    /// Wrap a raw cudarc module.
20    pub(crate) fn from_raw(inner: Arc<CudaModule>) -> Self {
21        Self { inner }
22    }
23
24    /// Get a kernel function handle by name.
25    ///
26    /// The name must match the `.entry` name in the PTX source
27    /// (e.g. `"vector_add"`).
28    pub fn function(&self, name: &str) -> Result<KaioFunction> {
29        let func = self.inner.load_function(name)?;
30        Ok(KaioFunction { inner: func })
31    }
32}
33
34/// A handle to a kernel function within a loaded PTX module.
35///
36/// Use [`inner`](Self::inner) to access the underlying `CudaFunction`
37/// for passing to cudarc's `launch_builder`. In Phase 1, kernel launch
38/// goes through cudarc directly — Phase 2's macro will generate typed
39/// safe wrappers.
40pub struct KaioFunction {
41    inner: CudaFunction,
42}
43
44impl KaioFunction {
45    /// Access the underlying [`CudaFunction`] for cudarc's launch builder.
46    ///
47    /// # Example
48    ///
49    /// ```ignore
50    /// let cfg = LaunchConfig::for_num_elems(n);
51    /// unsafe {
52    ///     device.stream()
53    ///         .launch_builder(func.inner())
54    ///         .arg(buf_a.inner())
55    ///         .arg(&n)
56    ///         .launch(cfg)?;
57    /// }
58    /// ```
59    pub fn inner(&self) -> &CudaFunction {
60        &self.inner
61    }
62}