Skip to main content

cubecl_core/compute/
launcher.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::marker::PhantomData;
3
4use crate::Runtime;
5use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
6use crate::{InfoBuilder, KernelSettings, ScalarArgType};
7#[cfg(feature = "std")]
8use core::cell::RefCell;
9use cubecl_ir::{AddressType, Scope, StorageType, Type};
10use cubecl_runtime::server::{Binding, CubeCount, TensorMapBinding};
11use cubecl_runtime::{
12    client::ComputeClient,
13    kernel::{CubeKernel, KernelTask},
14    server::KernelArguments,
15};
16
#[cfg(feature = "std")]
std::thread_local! {
    // Per-thread builder for tensor/array metadata and scalars. With `std`, every
    // `KernelLauncher` on the current thread writes into this one builder.
    static INFO: RefCell<InfoBuilder> = RefCell::new(InfoBuilder::default());
}

#[cfg(feature = "std")]
std::thread_local! {
    // Per-thread scope; only used for resolving types.
    static SCOPE: RefCell<Scope> = RefCell::new(Scope::root(false));
}
23
24/// Prepare a kernel for [launch](KernelLauncher::launch).
25pub struct KernelLauncher<R: Runtime> {
26    buffers: Vec<Binding>,
27    tensor_maps: Vec<TensorMapBinding>,
28    address_type: AddressType,
29    pub settings: KernelSettings,
30    #[cfg(not(feature = "std"))]
31    info: InfoBuilder,
32    #[cfg(not(feature = "std"))]
33    pub scope: Scope,
34    _runtime: PhantomData<R>,
35}
36
37impl<R: Runtime> KernelLauncher<R> {
38    #[cfg(feature = "std")]
39    pub fn with_scope<T>(&mut self, fun: impl FnMut(&mut Scope) -> T) -> T {
40        SCOPE.with_borrow_mut(fun)
41    }
42
43    #[cfg(not(feature = "std"))]
44    pub fn with_scope<T>(&mut self, mut fun: impl FnMut(&mut Scope) -> T) -> T {
45        fun(&mut self.scope)
46    }
47
48    #[cfg(feature = "std")]
49    fn with_info<T>(&mut self, fun: impl FnMut(&mut InfoBuilder) -> T) -> T {
50        INFO.with_borrow_mut(fun)
51    }
52
53    #[cfg(not(feature = "std"))]
54    fn with_info<T>(&mut self, mut fun: impl FnMut(&mut InfoBuilder) -> T) -> T {
55        fun(&mut self.info)
56    }
57
58    /// Register a scalar to be launched.
59    pub fn register_scalar<C: ScalarArgType>(&mut self, scalar: C) {
60        self.with_info(|info| info.scalars.push(scalar));
61    }
62
63    /// Register a scalar to be launched from raw data.
64    pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
65        self.with_info(|info| info.scalars.push_raw(bytes, dtype));
66    }
67
68    /// Launch the kernel.
69    #[track_caller]
70    pub fn launch<K: CubeKernel>(
71        self,
72        cube_count: CubeCount,
73        kernel: K,
74        client: &ComputeClient<R>,
75    ) {
76        let bindings = self.into_bindings();
77        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
78
79        client.launch(kernel, cube_count, bindings)
80    }
81
82    /// Launch the kernel without check bounds.
83    ///
84    /// # Safety
85    ///
86    /// The kernel must not:
87    /// - Contain any out of bounds reads or writes. Doing so is immediate UB.
88    /// - Contain any loops that never terminate. These may be optimized away entirely or cause
89    ///   other unpredictable behaviour.
90    #[track_caller]
91    pub unsafe fn launch_unchecked<K: CubeKernel>(
92        self,
93        cube_count: CubeCount,
94        kernel: K,
95        client: &ComputeClient<R>,
96    ) {
97        unsafe {
98            let bindings = self.into_bindings();
99            let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
100
101            client.launch_unchecked(kernel, cube_count, bindings)
102        }
103    }
104
105    /// We need to create the bindings in the same order they are defined in the compilation step.
106    ///
107    /// The function [`crate::KernelIntegrator::integrate`] stars by registering the input tensors followed
108    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
109    /// are registered in the same order they are added. This is why we store the scalar data type
110    /// in the `scalar_order` vector, so that we can register them in the same order.
111    ///
112    /// Also returns an ordered list of constant bindings. The ordering between constants and tensors
113    /// is up to the runtime.
114    fn into_bindings(mut self) -> KernelArguments {
115        let mut bindings = KernelArguments::new();
116        let address_type = self.address_type;
117        let info = self.with_info(|info| info.finish(address_type));
118
119        bindings.buffers = self.buffers;
120        bindings.tensor_maps = self.tensor_maps;
121        bindings.info = info;
122
123        bindings
124    }
125}
126
127// Tensors/arrays
128impl<R: Runtime> KernelLauncher<R> {
129    /// Push a new input tensor to the state.
130    pub fn register_tensor(&mut self, tensor: TensorArg<R>, ty: Type) {
131        if let Some(tensor) = self.process_tensor(tensor, ty) {
132            self.buffers.push(tensor);
133        }
134    }
135
136    fn process_tensor(&mut self, tensor: TensorArg<R>, ty: Type) -> Option<Binding> {
137        let tensor = match tensor {
138            TensorArg::Handle { handle, .. } => handle,
139            TensorArg::Alias { .. } => return None,
140        };
141
142        let elem_size = ty.size();
143        let vectorization = ty.vector_size();
144
145        let buffer_len = tensor.handle.size_in_used() / elem_size as u64;
146        let len = tensor.shape.iter().product::<usize>() / vectorization;
147        let address_type = self.address_type;
148        self.with_info(|info| {
149            info.metadata.register_tensor(
150                tensor.strides.len() as u64,
151                buffer_len,
152                len as u64,
153                tensor.shape.clone(),
154                tensor.strides.clone(),
155                address_type,
156            )
157        });
158        Some(tensor.handle)
159    }
160
161    /// Push a new input array to the state.
162    pub fn register_array(&mut self, array: ArrayArg<R>, ty: Type) {
163        if let Some(tensor) = self.process_array(array, ty) {
164            self.buffers.push(tensor);
165        }
166    }
167
168    fn process_array(&mut self, array: ArrayArg<R>, ty: Type) -> Option<Binding> {
169        let array = match array {
170            ArrayArg::Handle { handle, .. } => handle,
171            ArrayArg::Alias { .. } => return None,
172        };
173
174        let elem_size = ty.size();
175        let vectorization = ty.vector_size();
176
177        let buffer_len = array.handle.size_in_used() / elem_size as u64;
178        let address_type = self.address_type;
179        self.with_info(|info| {
180            info.metadata.register_array(
181                buffer_len,
182                array.length[0] as u64 / vectorization as u64,
183                address_type,
184            )
185        });
186        Some(array.handle)
187    }
188
189    /// Push a new tensor to the state.
190    pub fn register_tensor_map<K: TensorMapKind>(&mut self, map: TensorMapArg<R, K>, ty: Type) {
191        let binding = self
192            .process_tensor(map.tensor, ty)
193            .expect("Can't use alias for TensorMap");
194
195        let map = map.metadata.clone();
196        self.tensor_maps.push(TensorMapBinding { binding, map });
197    }
198}
199
200impl<R: Runtime> KernelLauncher<R> {
201    pub fn new(settings: KernelSettings) -> Self {
202        Self {
203            address_type: settings.address_type,
204            settings,
205            buffers: Vec::new(),
206            tensor_maps: Vec::new(),
207            _runtime: PhantomData,
208            #[cfg(not(feature = "std"))]
209            info: InfoBuilder::default(),
210            #[cfg(not(feature = "std"))]
211            scope: Scope::root(false),
212        }
213    }
214}